超详细的Unity小白的 ML-Agents(Release 22)学习记录2:3D平衡球案例
超详细的Unity小白的 ML-Agents(Release 22)学习记录2:3D平衡球案例
0. 资料来源
官方Github文档:https://github.com/Unity-Technologies/ml-agents/blob/develop/docs/Getting-Started.md
1. 基础准备
① 完成 Unity (2023.2.13f1) 版本的安装 与 ML Agents 和 ML Agents Extensions 包的安装,Ubuntu系统的可参考我之前发表的博客:
https://blog.csdn.net/CDSN985144132/article/details/145768687?sharetype=blogdetail&sharerId=145768687&sharerefer=PC&sharesource=CDSN985144132&spm=1011.2480.3001.8118
② 安装支持 C# 的代码编辑软件,并连接 Unity Editor
Windows 系统安装 Visual Studio 用于修改 Unity 内的 C# 脚本 (推荐 Windows 系统方案)。
Ubuntu 可采用 VS Code (安装 .NET框架 以编写C#,比较麻烦,且 Unity Editor 界面字体小,官方确认Ubuntu 无UI Scaling功能)。
Visual Studio 与 Unity Editor 的连接方法参考以下博客:
https://blog.csdn.net/xks18232047575/article/details/134545094
③ Unity Editor、Ml-agents 、TensorBoard 的基本操作可参考我的上一篇教程:
https://blog.csdn.net/CDSN985144132/article/details/145779969
2. 打开项目文件
① 在Unity Hub 中:
Projects → Open → Add Projects from disk → 选择 ml-agents/Projects 根目录,进入 Unity Editor
② 在 Unity Editor 中:
Project → Assets → ML-Agents → Examples → 3D Ball → Scenes
3. Unity环境核心概念
3.1 环境构成要素
① 场景 Scene :包含Agent及其交互对象的虚拟空间
② 游戏对象 GameObject :所有实体的基础容器,通过Inspector窗口可查看组件构成
③ 组件系统:通过添加不同组件(如Transform、Rigidbody等)定义对象行为
3.2 多 Agent 设计
① 3D Balance Ball 场景包含 12 个相同 Agent 立方体
② 并行训练优势:
数据采集效率提升 12 倍
增强策略泛化能力
加速收敛(单个 Agent 平均500万步→12 个 Agen t约42万步)
4. Agent 核心概念
4.1 角色定义
① Agent是环境中执行观察和动作的实体。
② 在 3D Balance Ball 环境中,12个名为 Agent 的 GameObject 上挂载了 Agent 组件。
4.2 关键属性
① Behavior Parameters(行为参数):
每个Agent必须配置,决定Agent如何做出决策。
② Max Step(最大步数):
定义Agent单次训练周期(episode)的最大步数(3D Balance Ball中为5000步),超时后环境重置。
4.3 观察空间(Vector Observation Space)
① 空间维度: 8维向量,即包含以下 8 项信息。
② 包含的观测信息:
Agent立方体的旋转分量(x, z轴),共2项。
球体相对于Agent的位置(x, y, z坐标),共 3 项。
球体相对于Agent的速度(x, y, z分量),共 3 项。
4.4 动作空间(Actions)
① 动作类型:连续型动作(Continuous Actions)。
② 动作维度:2 维向量。
③ 动作功能:控制 Agent 绕 x ,z 轴旋转,以保持球体平衡。
5. 理解 Ball3DAgent.cs 和 Ball3DHardAgent.cs 代码(已添加注释):
在Unity Editor 的 Projects 窗口找到 Assets\Example\3DBall\Scripts,可看到这两个代码,双击打开可编辑。
※ 以下添加注释的版本仅供大家学习理解,正式运行如报错,可采用官方版本。
5.1 Ball3DAgent.cs (已添加注释):
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using Random = UnityEngine.Random;
public class Ball3DAgent : Agent
{
[Header("Specific to Ball3D")] //[Header("3D平衡球专用参数")]
public GameObject ball; // 被平衡的球体对象
[Tooltip("Whether to use vector observation. This option should be checked " +
"in 3DBall scene, and unchecked in Visual3DBall scene. ")]//[Tooltip("是否使用向量观测。在3DBall场景中应勾选,在Visual3DBall场景中取消勾选")]
public bool useVecObs; // 切换观测模式(向量/视觉)
Rigidbody m_BallRb; // 球体的刚体组件
EnvironmentParameters m_ResetParams; // 环境重置参数
// 初始化方法
public override void Initialize()
{
m_BallRb = ball.GetComponent<Rigidbody>(); // 获取球体刚体
m_ResetParams = Academy.Instance.EnvironmentParameters; // 连接学院参数
SetResetParameters(); // 初始参数设置
}
// 收集观测数据(Agent的"眼睛")
public override void CollectObservations(VectorSensor sensor)
{
if (useVecObs)
{
// 观测值1:Agent自身Z轴旋转角度(范围[-1,1])
sensor.AddObservation(gameObject.transform.rotation.z);
// 观测值2:Agent自身X轴旋转角度(范围[-1,1])
sensor.AddObservation(gameObject.transform.rotation.x);
// 观测值3-5:球体相对于Agent的位置(三维向量)
sensor.AddObservation(ball.transform.position - gameObject.transform.position);
// 观测值6-8:球体的速度(三维向量)
sensor.AddObservation(m_BallRb.velocity);
}
}
// 执行动作并计算奖励(Agent的"大脑决策")
public override void OnActionReceived(ActionBuffers actionBuffers)
{
// 解析连续动作(范围[-2, 2])
var actionZ = 2f * Mathf.Clamp(actionBuffers.ContinuousActions[0], -1f, 1f);
var actionX = 2f * Mathf.Clamp(actionBuffers.ContinuousActions[1], -1f, 1f);
// Z轴旋转控制(限制在±25度范围内)
if ((gameObject.transform.rotation.z < 0.25f && actionZ > 0f) ||
(gameObject.transform.rotation.z > -0.25f && actionZ < 0f))
{
gameObject.transform.Rotate(new Vector3(0, 0, 1), actionZ);
}
// X轴旋转控制(限制在±25度范围内)
if ((gameObject.transform.rotation.x < 0.25f && actionX > 0f) ||
(gameObject.transform.rotation.x > -0.25f && actionX < 0f))
{
gameObject.transform.Rotate(new Vector3(1, 0, 0), actionX);
}
// 失败条件检测:
// 1. 球体低于Agent 2个单位
// 2. 球体在X/Z方向偏移超过3个单位
if ((ball.transform.position.y - gameObject.transform.position.y) < -2f ||
Mathf.Abs(ball.transform.position.x - gameObject.transform.position.x) > 3f ||
Mathf.Abs(ball.transform.position.z - gameObject.transform.position.z) > 3f)
{
SetReward(-1f); // 惩罚:给予-1奖励
EndEpisode(); // 结束当前训练回合
}
else
{
SetReward(0.1f); // 持续奖励:每步+0.1奖励
}
}
// 回合开始时重置环境
public override void OnEpisodeBegin()
{
// 重置Agent旋转(添加随机初始倾斜)
gameObject.transform.rotation = Quaternion.identity;
gameObject.transform.Rotate(new Vector3(1, 0, 0), Random.Range(-10f, 10f));
gameObject.transform.Rotate(new Vector3(0, 0, 1), Random.Range(-10f, 10f));
// 重置球体状态
m_BallRb.velocity = Vector3.zero; // 清除速度
ball.transform.position = new Vector3(Random.Range(-1.5f, 1.5f), 4f, Random.Range(-1.5f, 1.5f))
+ gameObject.transform.position; // 随机初始位置
SetResetParameters(); // 应用环境参数重置
}
// 人工控制模式(用于测试)
public override void Heuristic(in ActionBuffers actionsOut)
{
var continuousActionsOut = actionsOut.ContinuousActions;
// 键盘映射:左右箭头控制Z轴,上下箭头控制X轴
continuousActionsOut[0] = -Input.GetAxis("Horizontal"); // 左右→Z轴旋转
continuousActionsOut[1] = Input.GetAxis("Vertical"); // 上下→X轴旋转
}
// 设置球体物理属性
void SetBall()
{
// 从学院参数获取质量(默认1.0)
m_BallRb.mass = m_ResetParams.GetWithDefault("mass", 1.0f);
// 获取缩放比例(默认1.0)
var scale = m_ResetParams.GetWithDefault("scale", 1.0f);
ball.transform.localScale = new Vector3(scale, scale, scale);
}
// 重置环境参数
void SetResetParameters()
{
SetBall(); // 每次重置时更新球体属性
}
}
5.2 Ball3DHardAgent.cs (已添加注释):
using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors.Reflection;
// 3D平衡球困难模式智能体控制脚本
// 核心功能:通过旋转平台保持球体平衡,球体掉落或偏移时重置训练周期
public class Ball3DHardAgent : Agent
{
[Header("平衡球配置")]
public GameObject ball; // 需要平衡的球体对象
private Rigidbody m_BallRb; // 球体的刚体物理组件
private EnvironmentParameters m_ResetParams; // 环境动态参数控制器
// 初始化智能体(ML-Agents框架生命周期方法)
public override void Initialize()
{
m_BallRb = ball.GetComponent<Rigidbody>();
m_ResetParams = Academy.Instance.EnvironmentParameters;
SetResetParameters(); // 应用环境参数初始化
}
/* 观测空间定义
-----------------------------------------------------------
观测值1:平台旋转角度(Z轴和X轴弧度值)
numStackedObservations=9 表示堆叠最近9帧数据*/
[Observable(numStackedObservations: 9)]
Vector2 Rotation => new Vector2(
transform.rotation.z, // Z轴旋转(左右倾斜)
transform.rotation.x // X轴旋转(前后倾斜)
);
// 观测值2:球体相对于平台的坐标偏移量(XYZ三维向量)
[Observable(numStackedObservations: 9)]
Vector3 PositionDelta => ball.transform.position - transform.position;
// 动作处理核心方法(ML-Agents每物理帧调用)
public override void OnActionReceived(ActionBuffers actionBuffers)
{
// 解析连续动作空间的两个维度(范围[-1,1])
var actions = actionBuffers.ContinuousActions;
float rotateZ = 2f * Mathf.Clamp(actions[0], -1f, 1f); // 放大到[-2°,2°]
float rotateX = 2f * Mathf.Clamp(actions[1], -1f, 1f);
// 安全旋转控制(Z轴左右倾斜)
if ((transform.rotation.z < 0.25f && rotateZ > 0f) ||
(transform.rotation.z > -0.25f && rotateZ < 0f))
{
transform.Rotate(0, 0, 1, rotateZ); // 绕Z轴旋转
}
// 安全旋转控制(X轴前后倾斜)
if ((transform.rotation.x < 0.25f && rotateX > 0f) ||
(transform.rotation.x > -0.25f && rotateX < 0f))
{
transform.Rotate(1, 0, 0, rotateX); // 绕X轴旋转
}
/* 终止条件判断
---------------------------------------------------
球体位置检测:下落超过2米或水平偏移超过3米*/
Vector3 ballPos = ball.transform.position;
bool isFalling = (ballPos.y - transform.position.y) < -2f;
bool isOutOfRangeX = Mathf.Abs(ballPos.x - transform.position.x) > 3f;
bool isOutOfRangeZ = Mathf.Abs(ballPos.z - transform.position.z) > 3f;
if (isFalling || isOutOfRangeX || isOutOfRangeZ)
{
SetReward(-1f); // 失败惩罚
EndEpisode(); // 终止当前训练周期
}
else
{
SetReward(0.1f); // 持续平衡奖励
}
}
// 训练周期重置方法(每次重新开始时调用)
public override void OnEpisodeBegin()
{
// 随机初始化平台倾斜角度(X/Z轴±10°范围内)
transform.rotation = Quaternion.identity;
transform.Rotate(Random.Range(-10f, 10f), 0, 0); // X轴倾斜
transform.Rotate(0, 0, Random.Range(-10f, 10f)); // Z轴倾斜
// 重置球体状态
m_BallRb.velocity = Vector3.zero; // 清除运动惯性
ball.transform.position = transform.position + new Vector3(
Random.Range(-1.5f, 1.5f), // 水平随机位置
4f, // 初始高度4米
Random.Range(-1.5f, 1.5f)
);
}
// 动态调整球体物理参数(质量、尺寸)
void SetBall()
{
m_BallRb.mass = m_ResetParams.GetWithDefault("mass", 1.0f); // 质量参数
float scale = m_ResetParams.GetWithDefault("scale", 1.0f); // 尺寸参数
ball.transform.localScale = Vector3.one * scale;
}
// 环境参数应用入口
void SetResetParameters() => SetBall();
}
5.3 Ball3DAgent.cs 和 Ball3DHardAgent.cs 的区别 (个人小总结,不一定准确)
① 观测值收集方式:
Ball3DAgent: 使用传统手动添加观测的方式,通过 CollectObservations 方法显式收集。
sensor.AddObservation(gameObject.transform.rotation.z);
sensor.AddObservation(gameObject.transform.rotation.x);
// ...其他观测
通过 useVecObs 开关控制是否启用向量观测,适合需要动态切换观测模式的场景。
Ball3DHardAgent: 使用 反射传感器(Reflection Sensors),通过 [Observable] 属性标记自动收集观测。
[Observable(numStackedObservations: 9)]
Vector2 Rotation { get { /* 平台旋转 */ } }
[Observable(numStackedObservations: 9)]
Vector3 PositionDelta { get { /* 球位置差 */ } }
numStackedObservations:9 表示堆叠过去 9 帧的观测数据,为模型提供时序信息。
② 观测数据类型:
Ball3DAgent:
包含 4 个观测值:平台旋转(z/x)、球相对位置、球速度。
Ball3DHardAgent:
仅包含 2 个观测属性(但通过堆叠扩展为 9 帧历史)。
③ 手动控制支持:
Ball3DAgent:
实现 Heuristic 方法,支持通过键盘输入(Horizontal/Vertical 轴)手动控制。
Ball3DHardAgent:
未实现 Heuristic 方法,无法直接通过键盘控制,仅适用于自动训练。
④ 总结:
Ball3DAgent:
更灵活,支持手动控制、动态参数重置和观测模式切换,适合基础训练和调试。
Ball3DHardAgent:
通过堆叠时序观测简化了代码,但可能因缺少速度和动态参数重置增加训练难度,适合复杂环境下的时序依赖性训练。
6. 后续训练过程详见之前发布的博客:
https://blog.csdn.net/CDSN985144132/article/details/145779969
7. 训练超参数位置:
① ml-agents的存放根目录\ml-agents\config\ppo
② 找到对应的 yaml 文件位置,开始训练。
更多推荐
所有评论(0)