0. 资料来源

官方Github文档:https://github.com/Unity-Technologies/ml-agents/blob/develop/docs/Getting-Started.md

请添加图片描述

1. 基础准备

① 完成 Unity (2023.2.13f1) 版本的安装 与 ML Agents 和 ML Agents Extensions 包的安装,Ubuntu系统的可参考我之前发表的博客:
https://blog.csdn.net/CDSN985144132/article/details/145768687?sharetype=blogdetail&sharerId=145768687&sharerefer=PC&sharesource=CDSN985144132&spm=1011.2480.3001.8118

② 安装支持 C# 的代码编辑软件,并连接 Unity Editor
Windows 系统安装 Visual Studio 用于修改 Unity 内的 C# 脚本 (推荐 Windows 系统方案)。
Ubuntu 可采用 VS Code (安装 .NET框架 以编写C#,比较麻烦,且 Unity Editor 界面字体小,官方确认Ubuntu 无UI Scaling功能)。

Visual Studio 与 Unity Editor 的连接方法参考以下博客:
https://blog.csdn.net/xks18232047575/article/details/134545094

③ Unity Editor、Ml-agents 、TensorBoard 的基本操作可参考我的上一篇教程:
https://blog.csdn.net/CDSN985144132/article/details/145779969

2. 打开项目文件

① 在Unity Hub 中:
Projects → Open → Add Projects from disk → 选择 ml-agents/Projects 根目录,进入 Unity Editor

请添加图片描述② 在 Unity Editor 中:
Project → Assets → ML-Agents → Examples → 3D Ball → Scenes

请添加图片描述

3. Unity环境核心概念

3.1 环境构成要素

① 场景 Scene :包含Agent及其交互对象的虚拟空间
② 游戏对象 GameObject :所有实体的基础容器,通过Inspector窗口可查看组件构成
③ 组件系统:通过添加不同组件(如Transform、Rigidbody等)定义对象行为

3.2 多 Agent 设计

① 3D Balance Ball 场景包含 12 个相同 Agent 立方体
② 并行训练优势:
数据采集效率提升 12 倍
增强策略泛化能力
加速收敛(单个 Agent 平均500万步→12 个 Agen t约42万步)

4. Agent 核心概念

4.1 角色定义

① Agent是环境中执行观察和动作的实体。

② 在 3D Balance Ball 环境中,12个名为 Agent 的 GameObject 上挂载了 Agent 组件。

4.2 关键属性

Behavior Parameters(行为参数):
每个Agent必须配置,决定Agent如何做出决策。

Max Step(最大步数):
定义Agent单次训练周期(episode)的最大步数(3D Balance Ball中为5000步),超时后环境重置。

请添加图片描述

4.3 观察空间(Vector Observation Space)

空间维度: 8维向量,即包含以下 8 项信息。

包含的观测信息:

Agent立方体的旋转分量(x, z轴),共2项。

球体相对于Agent的位置(x, y, z坐标),共 3 项。

球体相对于Agent的速度(x, y, z分量),共 3 项。

4.4 动作空间(Actions)

① 动作类型:连续型动作(Continuous Actions)。

② 动作维度:2 维向量。

③ 动作功能:控制 Agent 绕 x ,z 轴旋转,以保持球体平衡。

请添加图片描述

5. 理解 Ball3DAgent.cs 和 Ball3DHardAgent.cs 代码(已添加注释):

在Unity Editor 的 Projects 窗口找到 Assets\Example\3DBall\Scripts,可看到这两个代码,双击打开可编辑。
※ 以下添加注释的版本仅供大家学习理解,正式运行如报错,可采用官方版本。

请添加图片描述

5.1 Ball3DAgent.cs (已添加注释):

using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors;
using Random = UnityEngine.Random;

public class Ball3DAgent : Agent
{
    [Header("Specific to Ball3D")] //[Header("3D平衡球专用参数")]
    public GameObject ball;  // 被平衡的球体对象
    [Tooltip("Whether to use vector observation. This option should be checked " +
    "in 3DBall scene, and unchecked in Visual3DBall scene. ")]//[Tooltip("是否使用向量观测。在3DBall场景中应勾选,在Visual3DBall场景中取消勾选")]
    public bool useVecObs;    // 切换观测模式(向量/视觉)
    
    Rigidbody m_BallRb;       // 球体的刚体组件
    EnvironmentParameters m_ResetParams;  // 环境重置参数

    // 初始化方法
    public override void Initialize()
    {
        m_BallRb = ball.GetComponent<Rigidbody>();  // 获取球体刚体
        m_ResetParams = Academy.Instance.EnvironmentParameters;  // 连接学院参数
        SetResetParameters();  // 初始参数设置
    }

    // 收集观测数据(Agent的"眼睛")
    public override void CollectObservations(VectorSensor sensor)
    {
        if (useVecObs)
        {
            // 观测值1:Agent自身Z轴旋转角度(范围[-1,1])
            sensor.AddObservation(gameObject.transform.rotation.z);
            
            // 观测值2:Agent自身X轴旋转角度(范围[-1,1])
            sensor.AddObservation(gameObject.transform.rotation.x);
            
            // 观测值3-5:球体相对于Agent的位置(三维向量)
            sensor.AddObservation(ball.transform.position - gameObject.transform.position);
            
            // 观测值6-8:球体的速度(三维向量)
            sensor.AddObservation(m_BallRb.velocity);
        }
    }

    // 执行动作并计算奖励(Agent的"大脑决策")
    public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        // 解析连续动作(范围[-2, 2])
        var actionZ = 2f * Mathf.Clamp(actionBuffers.ContinuousActions[0], -1f, 1f);
        var actionX = 2f * Mathf.Clamp(actionBuffers.ContinuousActions[1], -1f, 1f);

        // Z轴旋转控制(限制在±25度范围内)
        if ((gameObject.transform.rotation.z < 0.25f && actionZ > 0f) ||
            (gameObject.transform.rotation.z > -0.25f && actionZ < 0f))
        {
            gameObject.transform.Rotate(new Vector3(0, 0, 1), actionZ);
        }

        // X轴旋转控制(限制在±25度范围内)
        if ((gameObject.transform.rotation.x < 0.25f && actionX > 0f) ||
            (gameObject.transform.rotation.x > -0.25f && actionX < 0f))
        {
            gameObject.transform.Rotate(new Vector3(1, 0, 0), actionX);
        }

        // 失败条件检测:
        // 1. 球体低于Agent 2个单位
        // 2. 球体在X/Z方向偏移超过3个单位
        if ((ball.transform.position.y - gameObject.transform.position.y) < -2f ||
            Mathf.Abs(ball.transform.position.x - gameObject.transform.position.x) > 3f ||
            Mathf.Abs(ball.transform.position.z - gameObject.transform.position.z) > 3f)
        {
            SetReward(-1f);  // 惩罚:给予-1奖励
            EndEpisode();    // 结束当前训练回合
        }
        else
        {
            SetReward(0.1f); // 持续奖励:每步+0.1奖励
        }
    }

    // 回合开始时重置环境
    public override void OnEpisodeBegin()
    {
        // 重置Agent旋转(添加随机初始倾斜)
        gameObject.transform.rotation = Quaternion.identity;
        gameObject.transform.Rotate(new Vector3(1, 0, 0), Random.Range(-10f, 10f));
        gameObject.transform.Rotate(new Vector3(0, 0, 1), Random.Range(-10f, 10f));
        
        // 重置球体状态
        m_BallRb.velocity = Vector3.zero;  // 清除速度
        ball.transform.position = new Vector3(Random.Range(-1.5f, 1.5f), 4f, Random.Range(-1.5f, 1.5f)) 
            + gameObject.transform.position;  // 随机初始位置
        
        SetResetParameters();  // 应用环境参数重置
    }

    // 人工控制模式(用于测试)
    public override void Heuristic(in ActionBuffers actionsOut)
    {
        var continuousActionsOut = actionsOut.ContinuousActions;
        // 键盘映射:左右箭头控制Z轴,上下箭头控制X轴
        continuousActionsOut[0] = -Input.GetAxis("Horizontal");  // 左右→Z轴旋转
        continuousActionsOut[1] = Input.GetAxis("Vertical");     // 上下→X轴旋转
    }

    // 设置球体物理属性
    void SetBall()
    {
        // 从学院参数获取质量(默认1.0)
        m_BallRb.mass = m_ResetParams.GetWithDefault("mass", 1.0f);
        
        // 获取缩放比例(默认1.0)
        var scale = m_ResetParams.GetWithDefault("scale", 1.0f);
        ball.transform.localScale = new Vector3(scale, scale, scale);
    }

    // 重置环境参数
    void SetResetParameters()
    {
        SetBall();  // 每次重置时更新球体属性
    }
}

5.2 Ball3DHardAgent.cs (已添加注释):

using UnityEngine;
using Unity.MLAgents;
using Unity.MLAgents.Actuators;
using Unity.MLAgents.Sensors.Reflection;

// 3D平衡球困难模式智能体控制脚本
// 核心功能:通过旋转平台保持球体平衡,球体掉落或偏移时重置训练周期
public class Ball3DHardAgent : Agent
{
    [Header("平衡球配置")]
    public GameObject ball;          // 需要平衡的球体对象
    private Rigidbody m_BallRb;       // 球体的刚体物理组件
    private EnvironmentParameters m_ResetParams;  // 环境动态参数控制器

    // 初始化智能体(ML-Agents框架生命周期方法)
    public override void Initialize()
    {
        m_BallRb = ball.GetComponent<Rigidbody>();
        m_ResetParams = Academy.Instance.EnvironmentParameters;
        SetResetParameters();  // 应用环境参数初始化
    }

    /* 观测空间定义
    -----------------------------------------------------------
    观测值1:平台旋转角度(Z轴和X轴弧度值)
    numStackedObservations=9 表示堆叠最近9帧数据*/
    [Observable(numStackedObservations: 9)]
    Vector2 Rotation => new Vector2(
        transform.rotation.z,   // Z轴旋转(左右倾斜) 
        transform.rotation.x     // X轴旋转(前后倾斜)
    );

    // 观测值2:球体相对于平台的坐标偏移量(XYZ三维向量)
    [Observable(numStackedObservations: 9)]
    Vector3 PositionDelta => ball.transform.position - transform.position;

    // 动作处理核心方法(ML-Agents每物理帧调用)
    public override void OnActionReceived(ActionBuffers actionBuffers)
    {
        // 解析连续动作空间的两个维度(范围[-1,1])
        var actions = actionBuffers.ContinuousActions;
        float rotateZ = 2f * Mathf.Clamp(actions[0], -1f, 1f);  // 放大到[-2°,2°]
        float rotateX = 2f * Mathf.Clamp(actions[1], -1f, 1f);

        // 安全旋转控制(Z轴左右倾斜)
        if ((transform.rotation.z < 0.25f && rotateZ > 0f) || 
           (transform.rotation.z > -0.25f && rotateZ < 0f))
        {
            transform.Rotate(0, 0, 1, rotateZ);  // 绕Z轴旋转
        }

        // 安全旋转控制(X轴前后倾斜)
        if ((transform.rotation.x < 0.25f && rotateX > 0f) || 
           (transform.rotation.x > -0.25f && rotateX < 0f))
        {
            transform.Rotate(1, 0, 0, rotateX);  // 绕X轴旋转
        }

        /* 终止条件判断
        ---------------------------------------------------
        球体位置检测:下落超过2米或水平偏移超过3米*/
        Vector3 ballPos = ball.transform.position;
        bool isFalling = (ballPos.y - transform.position.y) < -2f;
        bool isOutOfRangeX = Mathf.Abs(ballPos.x - transform.position.x) > 3f;
        bool isOutOfRangeZ = Mathf.Abs(ballPos.z - transform.position.z) > 3f;

        if (isFalling || isOutOfRangeX || isOutOfRangeZ)
        {
            SetReward(-1f);   // 失败惩罚
            EndEpisode();     // 终止当前训练周期
        }
        else
        {
            SetReward(0.1f);  // 持续平衡奖励
        }
    }

    // 训练周期重置方法(每次重新开始时调用)
    public override void OnEpisodeBegin()
    {
        // 随机初始化平台倾斜角度(X/Z轴±10°范围内)
        transform.rotation = Quaternion.identity;
        transform.Rotate(Random.Range(-10f, 10f), 0, 0);  // X轴倾斜
        transform.Rotate(0, 0, Random.Range(-10f, 10f));  // Z轴倾斜

        // 重置球体状态
        m_BallRb.velocity = Vector3.zero;  // 清除运动惯性
        ball.transform.position = transform.position + new Vector3(
            Random.Range(-1.5f, 1.5f),  // 水平随机位置
            4f,                         // 初始高度4米
            Random.Range(-1.5f, 1.5f)
        );
    }

    // 动态调整球体物理参数(质量、尺寸)
    void SetBall()
    {
        m_BallRb.mass = m_ResetParams.GetWithDefault("mass", 1.0f);  // 质量参数
        float scale = m_ResetParams.GetWithDefault("scale", 1.0f);  // 尺寸参数
        ball.transform.localScale = Vector3.one * scale;
    }

    // 环境参数应用入口
    void SetResetParameters() => SetBall();
}

5.3 Ball3DAgent.cs 和 Ball3DHardAgent.cs 的区别 个人小总结,不一定准确

观测值收集方式:

Ball3DAgent: 使用传统手动添加观测的方式,通过 CollectObservations 方法显式收集。

sensor.AddObservation(gameObject.transform.rotation.z);
sensor.AddObservation(gameObject.transform.rotation.x);
// ...其他观测

通过 useVecObs 开关控制是否启用向量观测,适合需要动态切换观测模式的场景。

Ball3DHardAgent: 使用 反射传感器(Reflection Sensors),通过 [Observable] 属性标记自动收集观测。

[Observable(numStackedObservations: 9)]
Vector2 Rotation { get { /* 平台旋转 */ } }

[Observable(numStackedObservations: 9)]
Vector3 PositionDelta { get { /* 球位置差 */ } }

numStackedObservations:9 表示堆叠过去 9 帧的观测数据,为模型提供时序信息。

观测数据类型:

Ball3DAgent:
包含 4 个观测值:平台旋转(z/x)、球相对位置、球速度。

Ball3DHardAgent:
仅包含 2 个观测属性(但通过堆叠扩展为 9 帧历史)。

手动控制支持:

Ball3DAgent:
实现 Heuristic 方法,支持通过键盘输入(Horizontal/Vertical 轴)手动控制。

Ball3DHardAgent:
未实现 Heuristic 方法,无法直接通过键盘控制,仅适用于自动训练。

总结:

Ball3DAgent:
更灵活,支持手动控制、动态参数重置和观测模式切换,适合基础训练和调试。

Ball3DHardAgent:
通过堆叠时序观测简化了代码,但可能因缺少速度和动态参数重置增加训练难度,适合复杂环境下的时序依赖性训练。

6. 后续训练过程详见之前发布的博客:

https://blog.csdn.net/CDSN985144132/article/details/145779969

7. 训练超参数位置:

① ml-agents的存放根目录\ml-agents\config\ppo
② 找到对应的 yaml 文件位置,开始训练。

请添加图片描述

Logo

分享前沿Unity技术干货和开发经验,精彩的Unity活动和社区相关信息

更多推荐