diff --git a/source/FLEXR_v0/FLEXR_v0/robots/FLEXR_v0.py b/source/FLEXR_v0/FLEXR_v0/robots/FLEXR_v0.py index 6ba3478..1dd8c19 100644 --- a/source/FLEXR_v0/FLEXR_v0/robots/FLEXR_v0.py +++ b/source/FLEXR_v0/FLEXR_v0/robots/FLEXR_v0.py @@ -7,22 +7,29 @@ import math FLEXR_CONFIG = ArticulationCfg( spawn=sim_utils.UsdFileCfg( - usd_path="/home/hexone/Documents/ftr.usd", + usd_path="/home/hexone/Workplace/ws_ftr/asset/FLEXR_v0.usd", ), init_state=ArticulationCfg.InitialStateCfg( + # joint_pos={ + # "arm_FL": math.radians(45.0), + # "arm_FR": math.radians(-45.0), + # "arm_BL": math.radians(-45.0), + # "arm_BR": math.radians(45.0), + # }, joint_pos={ - "arm_FL": math.radians(45.0), - "arm_FR": math.radians(45.0), - "arm_BL": math.radians(45.0), - "arm_BR": math.radians(45.0), + "arm_FL": 0.0, + "arm_FR": 0.0, + "arm_BL": 0.0, + "arm_BR": 0.0, }, + pos=(0.0, 0.0, 0.25), ), actuators={ # 摆臂执行器组 "arm_acts": ImplicitActuatorCfg( joint_names_expr=["arm_(FL|FR|BL|BR)"], effort_limit_sim=100.0, # 最大扭矩 (N·m) - velocity_limit_sim=100.0, # 最大速度 (rad/s) + velocity_limit_sim=3.0, # 最大速度 (rad/s) stiffness=800.0, # 位置控制的刚度 damping=5.0, # 位置控制的阻尼 friction=0.1 # 可选:关节摩擦 diff --git a/source/FLEXR_v0/FLEXR_v0/tasks/direct/flexr_v0/agents/rl_games_ppo_cfg.yaml b/source/FLEXR_v0/FLEXR_v0/tasks/direct/flexr_v0/agents/rl_games_ppo_cfg.yaml index 71216e6..c3844a1 100644 --- a/source/FLEXR_v0/FLEXR_v0/tasks/direct/flexr_v0/agents/rl_games_ppo_cfg.yaml +++ b/source/FLEXR_v0/FLEXR_v0/tasks/direct/flexr_v0/agents/rl_games_ppo_cfg.yaml @@ -30,7 +30,7 @@ params: val: 0 fixed_sigma: True mlp: - units: [32, 32] + units: [32, 32, 32, 32] activation: elu d2rl: False @@ -43,11 +43,11 @@ params: load_path: '' # path to the checkpoint to load config: - name: cartpole_direct + name: flexr_v0_direct env_name: rlgpu device: 'cuda:0' device_name: 'cuda:0' - multi_gpu: False + multi_gpu: True ppo: True mixed_precision: False normalize_input: True diff --git 
a/source/FLEXR_v0/FLEXR_v0/tasks/direct/flexr_v0/agents/rsl_rl_ppo_cfg.py b/source/FLEXR_v0/FLEXR_v0/tasks/direct/flexr_v0/agents/rsl_rl_ppo_cfg.py index 24a741d..133d552 100644 --- a/source/FLEXR_v0/FLEXR_v0/tasks/direct/flexr_v0/agents/rsl_rl_ppo_cfg.py +++ b/source/FLEXR_v0/FLEXR_v0/tasks/direct/flexr_v0/agents/rsl_rl_ppo_cfg.py @@ -11,14 +11,16 @@ from isaaclab_rl.rsl_rl import RslRlOnPolicyRunnerCfg, RslRlPpoActorCriticCfg, R @configclass class PPORunnerCfg(RslRlOnPolicyRunnerCfg): num_steps_per_env = 16 - max_iterations = 150 + max_iterations = 1500 save_interval = 50 - experiment_name = "cartpole_direct" + experiment_name = "flexr_v0_direct" empirical_normalization = False policy = RslRlPpoActorCriticCfg( init_noise_std=1.0, - actor_hidden_dims=[32, 32], - critic_hidden_dims=[32, 32], + # actor_hidden_dims=[256, 128, 64, 32], + # critic_hidden_dims=[256, 128, 64, 32], + actor_hidden_dims=[32, 32, 32], + critic_hidden_dims=[32, 32, 32], activation="elu", ) algorithm = RslRlPpoAlgorithmCfg( diff --git a/source/FLEXR_v0/FLEXR_v0/tasks/direct/flexr_v0/flexr_v0_env.py b/source/FLEXR_v0/FLEXR_v0/tasks/direct/flexr_v0/flexr_v0_env.py index 4149192..f90c7e6 100644 --- a/source/FLEXR_v0/FLEXR_v0/tasks/direct/flexr_v0/flexr_v0_env.py +++ b/source/FLEXR_v0/FLEXR_v0/tasks/direct/flexr_v0/flexr_v0_env.py @@ -13,15 +13,48 @@ import isaaclab.sim as sim_utils from isaaclab.assets import Articulation from isaaclab.envs import DirectRLEnv from isaaclab.sim.spawners.from_files import GroundPlaneCfg, spawn_ground_plane -from isaaclab.utils.math import sample_uniform - +from isaaclab.utils.math import sample_uniform, euler_xyz_from_quat +from isaaclab.sensors import RayCaster, RayCasterCfg, patterns +from isaaclab.markers import VisualizationMarkers, VisualizationMarkersCfg +from isaaclab.utils.assets import ISAAC_NUCLEUS_DIR from .flexr_v0_env_cfg import FlexrV0EnvCfg +# 日志配置 import logging -# Set the default log level to WARNING -logging.basicConfig(level=logging.WARNING) -# 
# Arrow markers used for debug visualisation.
def define_markers() -> VisualizationMarkers:
    """Create the "forward" and "command" arrow markers.

    Both markers use the same arrow asset and scale; they differ only in
    their diffuse colour (cyan for "forward", red for "command").

    Returns:
        A ``VisualizationMarkers`` instance rooted at ``/Visuals/myMarkers``.
    """
    arrow_usd = f"{ISAAC_NUCLEUS_DIR}/Props/UIElements/arrow_x.usd"
    arrow_scale = (0.25, 0.25, 0.5)
    colour_by_name = {
        "forward": (0.0, 1.0, 1.0),
        "command": (1.0, 0.0, 0.0),
    }
    arrow_cfgs = {
        name: sim_utils.UsdFileCfg(
            usd_path=arrow_usd,
            scale=arrow_scale,
            visual_material=sim_utils.PreviewSurfaceCfg(diffuse_color=rgb),
        )
        for name, rgb in colour_by_name.items()
    }
    return VisualizationMarkers(
        cfg=VisualizationMarkersCfg(prim_path="/Visuals/myMarkers", markers=arrow_cfgs)
    )


# TODO: double-check the configuration of the sensor's parent frame.
def define_height_sensor() -> RayCaster:
    """Create the ray-caster used as a height scanner under each robot.

    The sensor is attached to the ``body`` prim of every cloned robot, is
    offset 1.0 along z from that prim, follows only the yaw of its parent,
    and casts a grid pattern (0.1 resolution over a 1.0 x 1.0 area) against
    the ``/World/ground`` mesh. Debug visualisation is switched on.

    Returns:
        The constructed ``RayCaster`` sensor.
    """
    scan_pattern = patterns.GridPatternCfg(resolution=0.1, size=[1.0, 1.0])  # type: ignore
    mount_offset = RayCasterCfg.OffsetCfg(pos=(0.0, 0.0, 1.0))
    return RayCaster(
        cfg=RayCasterCfg(
            prim_path="/World/envs/env_.*/Robot/body",
            update_period=0.02,
            offset=mount_offset,
            attach_yaw_only=True,
            pattern_cfg=scan_pattern,
            debug_vis=True,
            mesh_prim_paths=["/World/ground"],
        )
    )
height_scaner + self.height_sensor = define_height_sensor() + + + def _pre_physics_step(self, actions: torch.Tensor) -> None: self.actions = actions.clone() + self.orientations = self.robot.data.root_quat_w def _apply_action(self) -> None: + + # self._debug_print_idx([0]) + # 确保动作有正确形状 [num_envs, 8] if self.actions.dim() == 1: self.actions = self.actions.unsqueeze(0) @@ -102,11 +152,14 @@ class FlexrV0Env(DirectRLEnv): arm_pos_target = torch.zeros_like(self.robot.data.joint_pos[:, self._arm_joint_indices]) arm_pos_target.copy_(arm_actions) - logging.debug(f"Total joints: {len(self.robot.joint_names)}") - logging.debug(f"Arm joint indices: {self._arm_joint_indices}") - logging.debug(f"Wheel joint counts: { {k:len(v) for k,v in self._actuator_joint_indices.items() if k.startswith('wheel')} }") - logging.debug(f"Initial arm positions: {self.robot.data.joint_pos[0, self._arm_joint_indices]}") - logging.debug(f"Arm actions: {arm_actions[0]}") + # logging.debug(f"Total joints: {len(self.robot.joint_names)}") + # logging.debug(f"Arm joint indices: {self._arm_joint_indices}") + # logging.debug(f"Wheel joint counts: { {k:len(v) for k,v in self._actuator_joint_indices.items() if k.startswith('wheel')} }") + # logging.debug(f"Initial arm positions: {self.robot.data.joint_pos[0, self._arm_joint_indices]}") + # logging.debug(f"Arm actions: {arm_actions[0]}") + + # # 暂时对摆臂设置成默认位置,确认轮的作用 + # arm_pos_target = self.robot.data.default_joint_pos[:, self._arm_joint_indices] # 设置目标位置(只针对arm关节) self.robot.set_joint_position_target( @@ -122,15 +175,19 @@ class FlexrV0Env(DirectRLEnv): joint_indices = self._actuator_joint_indices[group] # 全局索引 wheel_vel_target[:, joint_indices] = wheel_actions[:, i].unsqueeze(-1) - self.robot.set_joint_velocity_target( - wheel_vel_target[:, self._all_wheel_joint_indices], # 只选择轮组部分 - joint_ids=self._all_wheel_joint_indices - ) + # print(f"wheel_actions: {wheel_actions}") + # print(f"actions: {self.actions}") + # print(f"Arm Pos Target: {arm_pos_target}") 
+ # print(f"Wheel Vel Target: {wheel_vel_target}") + + # self.robot.set_joint_velocity_target( + # wheel_vel_target[:, self._all_wheel_joint_indices], # 只选择轮组部分 + # joint_ids=self._all_wheel_joint_indices + # ) def _get_observations(self) -> dict: - # 获取摆臂关节的位置和速度 - arm_pos = self.joint_pos[:, self._arm_joint_indices] # [num_envs, 4] + # 获取摆臂关节速度 arm_vel = self.joint_vel[:, self._arm_joint_indices] # [num_envs, 4] # 计算每个轮组的平均速度 @@ -138,24 +195,34 @@ class FlexrV0Env(DirectRLEnv): wheel_FR_vel = self.joint_vel[:, self._actuator_joint_indices["wheel_FR"]].mean(dim=1, keepdim=True) wheel_BL_vel = self.joint_vel[:, self._actuator_joint_indices["wheel_BL"]].mean(dim=1, keepdim=True) wheel_BR_vel = self.joint_vel[:, self._actuator_joint_indices["wheel_BR"]].mean(dim=1, keepdim=True) - - # 组合轮组速度 wheel_vel = torch.cat([wheel_FL_vel, wheel_FR_vel, wheel_BL_vel, wheel_BR_vel], dim=1) # [num_envs, 4] - # 特权信息 - 暂不加入 - base_pos = self.robot.data.root_pos_w - base_quat = self.robot.data.root_quat_w - base_lin_vel = self.robot.data.root_lin_vel_w - base_ang_vel = self.robot.data.root_ang_vel_w - - # - + # 获取车体姿态和运动信息 + base_quat = self.robot.data.root_quat_w # [num_envs, 4] (w,x,y,z) + base_ang_vel = self.robot.data.root_ang_vel_w # [num_envs, 3] (x,y,z) + + # 将四元数转换为欧拉角(使用Isaac Lab的函数) + roll, pitch, _ = euler_xyz_from_quat(base_quat) # 返回弧度值 + + # # 归一化角度到[-1,1]范围(假设±π/2是最大范围) + # norm_pitch = pitch / (math.pi/2) # 归一化到[-1,1] + # norm_roll = roll / (math.pi/2) # 归一化到[-1,1] + norm_pitch = pitch / (math.pi) # 归一化到[-1,1] + norm_roll = roll / (math.pi) # 归一化到[-1,1] + + # 添加噪声(可选) + noise_std = 0.01 # 噪声标准差 + norm_pitch += torch.randn_like(norm_pitch) * noise_std + norm_roll += torch.randn_like(norm_roll) * noise_std + # 组合所有观测 obs = torch.cat( ( - arm_pos, # 摆臂位置 [num_envs, 4] - arm_vel, # 摆臂速度 [num_envs, 4] - wheel_vel, # 轮组平均速度 [num_envs, 4] + arm_vel, # 摆臂速度 [num_envs, 4] + wheel_vel, # 轮组平均速度 [num_envs, 4] + norm_pitch.unsqueeze(-1), # 归一化pitch [num_envs, 1] + 
def _get_rewards(self) -> torch.Tensor:
    """Compute the per-environment reward for the current control step.

    The root angular acceleration is estimated by finite difference between
    the current root angular velocity and the value cached on the previous
    call, then handed to ``compute_rewards`` together with the reward scales
    from the config.

    Returns:
        The tensor returned by ``compute_rewards``.
    """
    current_ang_vel = self.robot.data.root_ang_vel_w

    # Finite-difference angular acceleration over one control step.
    root_ang_acc = (current_ang_vel - self._last_root_ang_vel) / self.step_dt

    # The cached velocity is not cleared when an environment is reset, so on
    # the first step of a new episode the difference above spans the reset
    # and would show up as a huge, meaningless acceleration. Zero it there.
    # NOTE(review): assumes the base env increments `episode_length_buf`
    # before calling this method, i.e. the first step of an episode sees a
    # value of 1 — confirm against the installed Isaac Lab version.
    first_step = (self.episode_length_buf <= 1).unsqueeze(-1)
    root_ang_acc = torch.where(first_step, torch.zeros_like(root_ang_acc), root_ang_acc)

    # Desired forward speed, broadcast to one value per environment.
    target_velocity = torch.full(
        (self.num_envs, 1), float(self.cfg.target_velocity), device=self.device
    )

    # Evaluate the reward *before* refreshing the cache so that
    # `last_root_ang_vel` really is the previous step's velocity.
    total_reward = compute_rewards(
        root_ang_acc=root_ang_acc,
        last_root_ang_vel=self._last_root_ang_vel,
        dt=self.step_dt,
        rew_scale_smoothness=self.cfg.rew_scale_smoothness,
        rew_scale_alive=self.cfg.rew_scale_alive,
        rew_scale_velocity=self.cfg.rew_scale_velocity,
        base_lin_vel=self.robot.data.root_lin_vel_w,
        target_velocity=target_velocity,
        terminated=self.reset_terminated,
    )

    # In-place copy: `current_ang_vel` is read from the robot's data object
    # and must not be aliased by the cache.
    self._last_root_ang_vel[:] = current_ang_vel

    return total_reward
def _get_dones(self) -> tuple[torch.Tensor, torch.Tensor]:
    """Return the ``(terminated, time_out)`` masks for every environment.

    An environment terminates when the robot has tipped over (see
    ``_check_robot_flipped``) and times out when its episode step counter
    has reached ``max_episode_length``.

    Returns:
        A pair of boolean tensors ``(terminated, time_out)``.
    """
    # Re-read the cached joint state every step. `_get_observations` uses
    # `self.joint_pos` / `self.joint_vel`; without this refresh they are only
    # valid if the data object keeps handing back the very same tensor.
    # NOTE(review): this mirrors the lines the original template had here.
    self.joint_pos = self.robot.data.joint_pos
    self.joint_vel = self.robot.data.joint_vel

    # Use the post-physics orientation for the flip check; the copy taken in
    # `_pre_physics_step` is one control step old by the time we get here.
    self.orientations = self.robot.data.root_quat_w

    # 1. Episode time limit.
    time_out = self.episode_length_buf >= self.max_episode_length

    # 2. Robot tipped over.
    terminated = self._check_robot_flipped()

    # NOTE: the height scanner's `ray_hits_w` was observed with shape
    # [num_envs, 121, 3]. Presumably 121 = 11 x 11 grid points (a 1.0 x 1.0
    # pattern sampled every 0.1 with both edges included) — confirm against
    # the grid-pattern implementation.
    # TODO: optionally also terminate when the robot leaves the terrain bounds.

    # The masks are returned rather than assigned to `self.reset_terminated` /
    # `self.reset_time_outs`: no throw-away zero tensors are allocated and the
    # env's own buffers are not rebound.
    # NOTE(review): assumes the base env copies the returned masks into those
    # buffers — confirm against the installed Isaac Lab version.
    return terminated, time_out


def _check_robot_flipped(self) -> torch.Tensor:
    """Detect environments whose robot has rolled or pitched too far.

    Roll and pitch are extracted from ``self.orientations``, wrapped into
    [-180, 180) degrees so that e.g. 350 deg is treated as -10 deg, and
    compared (in absolute value) with ``cfg.roll_threshold`` and
    ``cfg.pitch_threshold``.

    Returns:
        Boolean tensor, ``True`` where either threshold is exceeded.
    """
    roll, pitch, _ = euler_xyz_from_quat(self.orientations)

    # Wrap to [-pi, pi) in one step; this is valid for any input range,
    # including angles reported in [0, 2*pi).
    two_pi = 2.0 * math.pi
    roll_deg = torch.rad2deg(torch.remainder(roll + math.pi, two_pi) - math.pi)
    pitch_deg = torch.rad2deg(torch.remainder(pitch + math.pi, two_pi) - math.pi)

    # Only the magnitude of the tilt matters, not its direction.
    flipped = (pitch_deg.abs() > self.cfg.pitch_threshold) | (roll_deg.abs() > self.cfg.roll_threshold)

    # Reporting is for debugging only. Guard on the log level first so that
    # the `torch.any` reduction and tensor formatting are skipped entirely
    # when debug output is disabled.
    if logging.getLogger().isEnabledFor(logging.DEBUG) and torch.any(flipped):
        logging.debug(
            "Robot flipped at idx: %s [x: %s, y: %s]",
            torch.nonzero(flipped),
            roll_deg[flipped],
            pitch_deg[flipped],
        )

    return flipped
@torch.jit.script
def compute_rewards(
    root_ang_acc: torch.Tensor,
    last_root_ang_vel: torch.Tensor,
    dt: float,
    rew_scale_smoothness: float,
    rew_scale_alive: float,
    rew_scale_velocity: float,
    base_lin_vel: torch.Tensor,
    target_velocity: torch.Tensor,
    terminated: torch.Tensor,
) -> torch.Tensor:
    """Return the smoothness reward for each environment.

    Only the smoothness term is active. The alive and velocity-tracking
    terms were being computed and then thrown away, so their inputs are kept
    in the signature purely for call-site compatibility and have no effect
    on the result.

    Args:
        root_ang_acc: Root angular acceleration, indexed as [env, axis].
            Only the first two axis components are used (the caller treats
            them as roll and pitch; the third, yaw, is ignored).
        last_root_ang_vel: Previous-step root angular velocity. Unused.
        dt: Control time step; scales the penalty.
        rew_scale_smoothness: Weight of the smoothness penalty.
        rew_scale_alive: Weight of the alive term. Unused.
        rew_scale_velocity: Weight of the velocity-tracking term. Unused.
        base_lin_vel: Root linear velocity. Unused.
        target_velocity: Target forward velocity. Unused.
        terminated: Termination mask. Unused.

    Returns:
        One value per environment:
        ``-rew_scale_smoothness * sum(root_ang_acc[:, :2] ** 2) * dt``.

    NOTE(review): to re-enable the other terms, add
    ``rew_scale_alive * (~terminated).float()`` and
    ``-rew_scale_velocity * |base_lin_vel[:, 0] - target_velocity[:, 0]|``
    to the returned value.
    """
    # Penalise abrupt changes about the first two axes only.
    planar_acc = root_ang_acc[:, :2]
    smoothness_penalty = torch.sum(torch.square(planar_acc), dim=1)
    return -rew_scale_smoothness * smoothness_penalty * dt
@@ -31,7 +32,17 @@ class FlexrV0EnvCfg(DirectRLEnvCfg): robot_cfg: ArticulationCfg = FLEXR_CONFIG.replace(prim_path="/World/envs/env_.*/Robot") # type: ignore # scene - scene: InteractiveSceneCfg = InteractiveSceneCfg(num_envs=128, env_spacing=4.0, replicate_physics=True) + scene: InteractiveSceneCfg = InteractiveSceneCfg(num_envs=4096, env_spacing=4.0, replicate_physics=True) + + # 自定义参数 + # 奖励系数 + rew_scale_smoothness: float = 0.1 # 平顺性奖励系数 + rew_scale_alive: float = 1.0 # 存活奖励系数 + rew_scale_velocity: float = 1.0 # 速度跟踪奖励系数 + # 目标前进速度 (m/s) + target_velocity: float = 4.0 # 默认目标速度 + pitch_threshold = 45.0 # 前后倾覆阈值(度) + roll_threshold = 45.0 # 左右倾覆阈值(度) # 待增加的自定义参数 # # custom parameters/scales