二つのリンクをつなげたハンドのリーチングを学習させました。
この時の設定やプログラムの忘備録です。
以下、作成したゲームオブジェクトです。
リーてぃんぐのTargetはシンプルです。cube を作っただけで、Rigid Bodyは作っていません。Box Collidorはついています(はじめからついていたんだっけ?)。
ハンドは、中心の土台のBaseと、Link1、Link2、そして青の先端部 Endeffectorからなります。
Link1は、HingeでBaseにつながっています。
Link2は、HingeでLink1につながっています。
Endeffectorは、FixedJointでLink2につながっています。
Link1のHingeの設定です。
Use Spring にチェックを入れ、Springの設定をしています。
プログラムから、Target Positionを動かすことでハンドを動かします。
Link2も同様です。
Agentの設定をしているのは、Endeffectorです。
Script は、Roller Agent です(エージェント名が初めにやったもののままでした)。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
]using System.Collections.Generic; using UnityEngine; using Unity.MLAgents; using Unity.MLAgents.Sensors; using Unity.MLAgents.Actuators; public class RollerAgent : Agent { public Transform TF_Target; public GameObject GO_Link1, GO_Link2; public float speed; private Rigidbody RB_Link1, RB_Link2, RB_Endeffector; private HingeJoint Hinge1, Hinge2; private Vector3 init_pos1, init_pos2, init_posE; private Quaternion init_rot1, init_rot2, init_rotE; private float angle1, angle2; public override void Initialize() { RB_Link1 = GO_Link1.GetComponent<Rigidbody>(); init_pos1 = RB_Link1.transform.localPosition; init_rot1 = RB_Link1.transform.rotation; RB_Link2 = GO_Link2.GetComponent<Rigidbody>(); init_pos2 = RB_Link2.transform.localPosition; init_rot2 = RB_Link2.transform.rotation; RB_Endeffector = this.GetComponent<Rigidbody>(); init_posE = RB_Endeffector.transform.localPosition; init_rotE = RB_Endeffector.transform.rotation; Hinge1 = GO_Link1.GetComponent<HingeJoint>(); Hinge2 = GO_Link2.GetComponent<HingeJoint>(); angle1 = 0; angle2 = 0; } public override void OnEpisodeBegin() { if (TF_Target.localPosition.y < 0) { RB_Link1.transform.localPosition = init_pos1; RB_Link2.transform.localPosition = init_pos2; RB_Endeffector.transform.localPosition = init_posE; RB_Link1.velocity = Vector3.zero; RB_Link1.angularVelocity = Vector3.zero; RB_Link2.velocity = Vector3.zero; RB_Link2.angularVelocity = Vector3.zero; RB_Endeffector.velocity = Vector3.zero; RB_Endeffector.angularVelocity = Vector3.zero; } while(true){ TF_Target.localPosition = new Vector3( Random.value * 8 - 4, 0.5f, Random.value * 8 - 4); float distance = Vector3.Distance( TF_Target.localPosition, Vector3.zero); if (distance < 4.00f){ break; } } } public override void CollectObservations(VectorSensor sensor) { sensor.AddObservation(TF_Target.localPosition); // sensor.AddObservation(RB_Endeffector.velocity); sensor.AddObservation(Hinge1.spring.targetPosition / 180); sensor.AddObservation(Hinge2.spring.targetPosition / 180); } public override void OnActionReceived(ActionBuffers actionBuffers) { float a1 = actionBuffers.ContinuousActions[0]; float a2 = actionBuffers.ContinuousActions[1]; angle1 = speed * a1; angle2 = speed * a2; if(angle1 < -180){ angle1 += 360; } if(angle1 > 180){ angle1 -= 360; } if(angle2 < -180){ angle2 += 360; } if(angle2 > 180){ angle2 -= 360; } JointSpring hs1 = Hinge1.spring; hs1.targetPosition = angle1; Hinge1.spring = hs1; JointSpring hs2 = Hinge2.spring; hs2.targetPosition = angle2; Hinge2.spring = hs2; float distanceToTarget = Vector3.Distance( TF_Target.localPosition, RB_Endeffector.transform.localPosition); if (distanceToTarget < 1.42f) { AddReward(1.0f); EndEpisode(); } if (this.transform.localPosition.y < 0) { EndEpisode(); } } public override void Heuristic(in ActionBuffers actionsOut) { float a1, a2; a1 = 0; a2 = 0; var continuousActionsOut = actionsOut.ContinuousActions; if (Input.GetKey(KeyCode.Alpha1)){ a1 = -0.5f; } if (Input.GetKey(KeyCode.Alpha2)){ a1 = 0.5f; } if (Input.GetKey(KeyCode.Alpha3)){ a2 = -0.5f; } if (Input.GetKey(KeyCode.Alpha4)){ a2 = 0.5f; } continuousActionsOut[0] = a1; continuousActionsOut[1] = a2; } } |
Agentの観測は、Targetの座標と、Hinge1、Hinge2の角度です(Target Position)。
設定ファイルは、
ML-Agents/config/config.yaml に保存しました。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 |
behaviors: BallRoller: trainer_type: ppo hyperparameters: batch_size: 10 buffer_size: 100 learning_rate: 3.0e-4 beta: 5.0e-4 epsilon: 0.2 lambd: 0.99 num_epoch: 3 learning_rate_schedule: linear beta_schedule: constant epsilon_schedule: linear network_settings: normalize: false hidden_units: 128 num_layers: 2 reward_signals: extrinsic: gamma: 0.99 strength: 1.0 max_steps: 500000 time_horizon: 64 summary_freq: 10000 |
動作を確認してから、
全体をプレファブにして、16個並べました。
PPOを使って学習させたら、すぐに学習できました(10分くらいだったような)。
resumeを使うと、追加で学習させることができます。
コメントを残す