二つのリンクをつなげたハンドのリーチングを学習させました。

この時の設定やプログラムの忘備録です。

以下、作成したゲームオブジェクトです。

リーてぃんぐのTargetはシンプルです。cube を作っただけで、Rigid Bodyは作っていません。Box Collidorはついています(はじめからついていたんだっけ?)。

ハンドは、中心の土台のBaseと、Link1、Link2、そして青の先端部 Endeffectorからなります。

Link1は、HingeでBaseにつながっています。
Link2は、HingeでLink1につながっています。
Endeffectorは、FixedJointでLink2につながっています。

Link1のHingeの設定です。

Use Spring にチェックを入れ、Springの設定をしています。
プログラムから、Target Positionを動かすことでハンドを動かします。

Link2も同様です。

Agentの設定をしているのは、Endeffectorです。

Script は、Roller Agent です(エージェント名が初めにやったもののままでした)。

]using System.Collections.Generic;
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Sensors;
using Unity.MLAgents.Actuators;
 
 
public class RollerAgent : Agent
{
    public Transform TF_Target;
    public GameObject GO_Link1, GO_Link2;
    public float speed;

    private Rigidbody RB_Link1, RB_Link2, RB_Endeffector;
    private HingeJoint Hinge1, Hinge2;
    private Vector3 init_pos1, init_pos2, init_posE;
    private Quaternion init_rot1, init_rot2, init_rotE;
    private float angle1, angle2;
 
    public override void Initialize()
    {
        RB_Link1 = GO_Link1.GetComponent<Rigidbody>();
        init_pos1 = RB_Link1.transform.localPosition;
        init_rot1 = RB_Link1.transform.rotation;

        RB_Link2 = GO_Link2.GetComponent<Rigidbody>();
        init_pos2 = RB_Link2.transform.localPosition;
        init_rot2 = RB_Link2.transform.rotation;

        RB_Endeffector = this.GetComponent<Rigidbody>();
        init_posE = RB_Endeffector.transform.localPosition;
        init_rotE = RB_Endeffector.transform.rotation;

        Hinge1 = GO_Link1.GetComponent<HingeJoint>();
        Hinge2 = GO_Link2.GetComponent<HingeJoint>();

        angle1 = 0;
        angle2 = 0;
        
    }
 
    public override void OnEpisodeBegin()
    {
        if (TF_Target.localPosition.y < 0)
        {
            RB_Link1.transform.localPosition = init_pos1;
            RB_Link2.transform.localPosition = init_pos2;
            RB_Endeffector.transform.localPosition = init_posE;

            RB_Link1.velocity = Vector3.zero;
            RB_Link1.angularVelocity = Vector3.zero;
            RB_Link2.velocity = Vector3.zero;
            RB_Link2.angularVelocity = Vector3.zero;
            RB_Endeffector.velocity = Vector3.zero;
            RB_Endeffector.angularVelocity = Vector3.zero;

        }

        while(true){
            TF_Target.localPosition = new Vector3(
                Random.value * 8 - 4, 0.5f, Random.value * 8 - 4);
            float distance = Vector3.Distance(
                TF_Target.localPosition, Vector3.zero);
            if (distance < 4.00f){
                break;
            }
        }
    }
 
    public override void CollectObservations(VectorSensor sensor)
    {
        sensor.AddObservation(TF_Target.localPosition);
        // sensor.AddObservation(RB_Endeffector.velocity);
        sensor.AddObservation(Hinge1.spring.targetPosition / 180);
        sensor.AddObservation(Hinge2.spring.targetPosition / 180);
    }
 
    public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        float a1 = actionBuffers.ContinuousActions[0];
        float a2 = actionBuffers.ContinuousActions[1];

        angle1 = speed * a1;
        angle2 = speed * a2;

        if(angle1 < -180){
            angle1 += 360;
        }
        if(angle1 > 180){
            angle1 -= 360;
        }
        if(angle2 < -180){
            angle2 += 360;
        }
        if(angle2 > 180){
            angle2 -= 360;
        }
        JointSpring hs1 = Hinge1.spring;
        hs1.targetPosition = angle1;
        Hinge1.spring = hs1;

        JointSpring hs2 = Hinge2.spring;
        hs2.targetPosition = angle2;
        Hinge2.spring = hs2;

        float distanceToTarget = Vector3.Distance(
            TF_Target.localPosition, 
            RB_Endeffector.transform.localPosition);
        if (distanceToTarget < 1.42f)
        {
            AddReward(1.0f);
            EndEpisode();
        }
        if (this.transform.localPosition.y < 0)
        {
            EndEpisode();
        }
    }
 
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        float a1, a2;
        a1 = 0;
        a2 = 0;
        var continuousActionsOut = actionsOut.ContinuousActions;
        if (Input.GetKey(KeyCode.Alpha1)){
            a1 = -0.5f;
        }
        if (Input.GetKey(KeyCode.Alpha2)){
            a1 = 0.5f;
        }
        if (Input.GetKey(KeyCode.Alpha3)){
            a2 = -0.5f;
        }
        if (Input.GetKey(KeyCode.Alpha4)){
            a2 = 0.5f;
        }
        continuousActionsOut[0] = a1;
        continuousActionsOut[1] = a2;
    }
}

Agentの観測は、Targetの座標と、Hinge1、Hinge2の角度です(Target Position)。

設定ファイルは、
ML-Agents/config/config.yaml に保存しました。

behaviors:
  BallRoller:

    trainer_type: ppo
    hyperparameters:
      batch_size: 10
      buffer_size: 100
      learning_rate: 3.0e-4
      beta: 5.0e-4
      epsilon: 0.2
      lambd: 0.99
      num_epoch: 3
      learning_rate_schedule: linear
      beta_schedule: constant
      epsilon_schedule: linear
    network_settings:
      normalize: false
      hidden_units: 128
      num_layers: 2
    reward_signals:
      extrinsic:
        gamma: 0.99
        strength: 1.0
    max_steps: 500000
    time_horizon: 64
    summary_freq: 10000

動作を確認してから、
全体をプレファブにして、16個並べました。

[code]
mlagents-learn .\config\config.yaml –run-id=multi
[/code]
PPOを使って学習させたら、すぐに学習できました(10分くらいだったような)。

[code]
mlagents-learn .\config\config.yaml –run-id=multi –resume
[/code]

resumeを使うと、追加で学習させることができます。