二つのリンクをつなげたハンドのリーチングを学習させました。
この時の設定やプログラムの忘備録です。
以下、作成したゲームオブジェクトです。
リーてぃんぐのTargetはシンプルです。cube を作っただけで、Rigid Bodyは作っていません。Box Collidorはついています(はじめからついていたんだっけ?)。
ハンドは、中心の土台のBaseと、Link1、Link2、そして青の先端部 Endeffectorからなります。
Link1は、HingeでBaseにつながっています。
Link2は、HingeでLink1につながっています。
Endeffectorは、FixedJointでLink2につながっています。
Link1のHingeの設定です。
Use Spring にチェックを入れ、Springの設定をしています。
プログラムから、Target Positionを動かすことでハンドを動かします。
Link2も同様です。
Agentの設定をしているのは、Endeffectorです。
Script は、Roller Agent です(エージェント名が初めにやったもののままでした)。
]using System.Collections.Generic; using UnityEngine; using Unity.MLAgents; using Unity.MLAgents.Sensors; using Unity.MLAgents.Actuators; public class RollerAgent : Agent { public Transform TF_Target; public GameObject GO_Link1, GO_Link2; public float speed; private Rigidbody RB_Link1, RB_Link2, RB_Endeffector; private HingeJoint Hinge1, Hinge2; private Vector3 init_pos1, init_pos2, init_posE; private Quaternion init_rot1, init_rot2, init_rotE; private float angle1, angle2; public override void Initialize() { RB_Link1 = GO_Link1.GetComponent<Rigidbody>(); init_pos1 = RB_Link1.transform.localPosition; init_rot1 = RB_Link1.transform.rotation; RB_Link2 = GO_Link2.GetComponent<Rigidbody>(); init_pos2 = RB_Link2.transform.localPosition; init_rot2 = RB_Link2.transform.rotation; RB_Endeffector = this.GetComponent<Rigidbody>(); init_posE = RB_Endeffector.transform.localPosition; init_rotE = RB_Endeffector.transform.rotation; Hinge1 = GO_Link1.GetComponent<HingeJoint>(); Hinge2 = GO_Link2.GetComponent<HingeJoint>(); angle1 = 0; angle2 = 0; } public override void OnEpisodeBegin() { if (TF_Target.localPosition.y < 0) { RB_Link1.transform.localPosition = init_pos1; RB_Link2.transform.localPosition = init_pos2; RB_Endeffector.transform.localPosition = init_posE; RB_Link1.velocity = Vector3.zero; RB_Link1.angularVelocity = Vector3.zero; RB_Link2.velocity = Vector3.zero; RB_Link2.angularVelocity = Vector3.zero; RB_Endeffector.velocity = Vector3.zero; RB_Endeffector.angularVelocity = Vector3.zero; } while(true){ TF_Target.localPosition = new Vector3( Random.value * 8 - 4, 0.5f, Random.value * 8 - 4); float distance = Vector3.Distance( TF_Target.localPosition, Vector3.zero); if (distance < 4.00f){ break; } } } public override void CollectObservations(VectorSensor sensor) { sensor.AddObservation(TF_Target.localPosition); // sensor.AddObservation(RB_Endeffector.velocity); sensor.AddObservation(Hinge1.spring.targetPosition / 180); sensor.AddObservation(Hinge2.spring.targetPosition / 180); } public override void OnActionReceived(ActionBuffers actionBuffers) { float a1 = actionBuffers.ContinuousActions[0]; float a2 = actionBuffers.ContinuousActions[1]; angle1 = speed * a1; angle2 = speed * a2; if(angle1 < -180){ angle1 += 360; } if(angle1 > 180){ angle1 -= 360; } if(angle2 < -180){ angle2 += 360; } if(angle2 > 180){ angle2 -= 360; } JointSpring hs1 = Hinge1.spring; hs1.targetPosition = angle1; Hinge1.spring = hs1; JointSpring hs2 = Hinge2.spring; hs2.targetPosition = angle2; Hinge2.spring = hs2; float distanceToTarget = Vector3.Distance( TF_Target.localPosition, RB_Endeffector.transform.localPosition); if (distanceToTarget < 1.42f) { AddReward(1.0f); EndEpisode(); } if (this.transform.localPosition.y < 0) { EndEpisode(); } } public override void Heuristic(in ActionBuffers actionsOut) { float a1, a2; a1 = 0; a2 = 0; var continuousActionsOut = actionsOut.ContinuousActions; if (Input.GetKey(KeyCode.Alpha1)){ a1 = -0.5f; } if (Input.GetKey(KeyCode.Alpha2)){ a1 = 0.5f; } if (Input.GetKey(KeyCode.Alpha3)){ a2 = -0.5f; } if (Input.GetKey(KeyCode.Alpha4)){ a2 = 0.5f; } continuousActionsOut[0] = a1; continuousActionsOut[1] = a2; } }
Agentの観測は、Targetの座標と、Hinge1、Hinge2の角度です(Target Position)。
設定ファイルは、
ML-Agents/config/config.yaml に保存しました。
behaviors: BallRoller: trainer_type: ppo hyperparameters: batch_size: 10 buffer_size: 100 learning_rate: 3.0e-4 beta: 5.0e-4 epsilon: 0.2 lambd: 0.99 num_epoch: 3 learning_rate_schedule: linear beta_schedule: constant epsilon_schedule: linear network_settings: normalize: false hidden_units: 128 num_layers: 2 reward_signals: extrinsic: gamma: 0.99 strength: 1.0 max_steps: 500000 time_horizon: 64 summary_freq: 10000
動作を確認してから、
全体をプレファブにして、16個並べました。
[code]
mlagents-learn .\config\config.yaml –run-id=multi
[/code]
PPOを使って学習させたら、すぐに学習できました(10分くらいだったような)。
[code]
mlagents-learn .\config\config.yaml –run-id=multi –resume
[/code]
resumeを使うと、追加で学習させることができます。
コメントを残す