2025-07-28 12:35:44 +08:00

131 lines
3.7 KiB
Python

import argparse
import os
from glob import glob
import cv2
import FramesViewer.utils as fv_utils
import gymnasium as gym
import mujoco
import numpy as np
from gymnasium.envs.registration import register
from sb3_contrib import TQC
from stable_baselines3 import A2C, PPO, SAC, TD3
# Register the footsteps environment under the id "BDX_env" so it can be
# built with `gym.make("BDX_env", ...)`; the class is loaded lazily from
# `footsteps_env.BDXEnv`. `autoreset=True` is Gymnasium's option to reset
# the environment automatically when an episode ends -- `test()` below only
# calls `env.reset()` once and relies on this. No episode-length cap
# (`max_episode_steps`) is configured here.
register(
    entry_point="footsteps_env:BDXEnv",
    id="BDX_env",
    autoreset=True,
)
def draw_clock(clock):
    """Show a two-component clock signal as a hand on a dial in an OpenCV window.

    ``clock[0]`` and ``clock[1]`` are used as the x / y offsets of the hand
    tip from the dial centre, scaled by the dial radius (100 px). Presumably
    they form a (cos, sin)-like pair in [-1, 1] -- TODO confirm against the
    environment's clock signal.
    """
    radius = 100
    side = 2 * radius
    centre = (radius, radius)
    tip = (
        int(radius + radius * clock[0]),
        int(radius + radius * clock[1]),
    )

    canvas = np.zeros((side, side, 3), np.uint8)
    # Filled (thickness -1) white disc as the dial background.
    canvas = cv2.circle(canvas, centre, radius, (255, 255, 255), -1)
    # Hand: red in OpenCV's BGR channel order, 2 px wide.
    canvas = cv2.line(canvas, centre, tip, (0, 0, 255), 2)

    cv2.imshow("clock", canvas)
    # Short wait so the HighGUI window gets a chance to process events/repaint.
    cv2.waitKey(1)
def draw_frame(pose, i, env):
    """Add a labelled red arrow marker for ``pose`` to the MuJoCo human viewer.

    Args:
        pose: Homogeneous transform indexed as a 4x4 matrix: rotation in
            ``pose[:3, :3]``, translation in ``pose[:3, 3]``.
        i: Label drawn next to the marker (converted with ``str``).
        env: Gymnasium environment, wrapped or not, whose underlying MuJoCo
            env owns ``mujoco_renderer``.
    """
    # Presumably rotates the frame 90 degrees about its own y axis so the
    # arrow geom is drawn along the frame's x axis -- TODO confirm against
    # FramesViewer's rotateInSelf conventions.
    pose = fv_utils.rotateInSelf(pose, [0, 90, 0])
    # `gym.make` returns a wrapped env (e.g. the autoreset wrapper requested
    # at registration). Reach the renderer via `.unwrapped`: forwarding
    # attribute lookups through wrappers is deprecated in Gymnasium 0.29 and
    # removed in 1.0. `.unwrapped` on a bare env returns the env itself.
    viewer = env.unwrapped.mujoco_renderer._get_viewer(render_mode="human")
    viewer.add_marker(
        pos=pose[:3, 3],
        mat=pose[:3, :3],
        size=[0.005, 0.005, 0.1],
        type=mujoco.mjtGeom.mjGEOM_ARROW,
        rgba=[1, 0, 0, 1],
        label=str(i),
    )
def _latest_model_path(directory):
    """Return the checkpoint in ``directory`` with the highest ``_<int>`` suffix, or None.

    Checkpoints are ``*.zip`` files whose stem ends with ``_<integer>``
    (e.g. ``model_12000.zip``). Files without a numeric suffix (e.g.
    ``best_model.zip``) are skipped instead of aborting the lookup.
    """
    best_id = None
    best_path = None
    for model_path in glob(os.path.join(directory, "*.zip")):
        stem = os.path.basename(model_path)[: -len(".zip")]
        try:
            model_id = int(stem.split("_")[-1])
        except ValueError:
            continue
        # Comparing against None (not 0) lets a checkpoint numbered 0 win
        # when it is the only one.
        if best_id is None or model_id > best_id:
            best_id, best_path = model_id, model_path
    return best_path


def _load_model(sb3_algo, path_to_model, env):
    """Load ``path_to_model`` with the named algorithm class, or return None if unknown."""
    algorithms = {"SAC": SAC, "TD3": TD3, "A2C": A2C, "TQC": TQC, "PPO": PPO}
    algo_cls = algorithms.get(sb3_algo)
    if algo_cls is None:
        return None
    return algo_cls.load(path_to_model, env=env)


def test(env, sb3_algo, path_to_model):
    """Roll out a trained policy in ``env`` and draw footstep targets in the viewer.

    Args:
        env: Gymnasium environment built with ``gym.make("BDX_env", ...)``.
        sb3_algo: Algorithm the model was trained with: ``"SAC"``, ``"TD3"``,
            ``"A2C"``, ``"TQC"`` or ``"PPO"``.
        path_to_model: A ``.zip`` checkpoint, or a directory in which the
            checkpoint with the highest trailing ``_<int>`` id is used.

    Each step draws: the xy midpoint of ``next_footsteps[2]`` and ``[3]``
    ("base target"), the xy position of the ``base`` body ("base pos"), and
    every footstep from index 2 on, labelled 0, 1, ...

    Prints a message and returns early if no checkpoint is found or the
    algorithm name is unknown.
    """
    if not path_to_model.endswith(".zip"):
        latest_model_path = _latest_model_path(path_to_model)
        if latest_model_path is None:
            print("No models found in directory: ", path_to_model)
            return
        print("Using latest model: ", latest_model_path)
        path_to_model = latest_model_path

    model = _load_model(sb3_algo, path_to_model, env)
    if model is None:
        print("Algorithm not found")
        return

    # Env-specific attributes (next_footsteps, data) live on the base env;
    # reading them through Gymnasium wrappers is deprecated (0.29) and
    # removed (1.0).
    base_env = env.unwrapped
    obs = env.reset()[0]
    # Despite its name, this is decremented only on steps where `done`
    # (the "terminated" flag; "truncated" is ignored) is True, so the loop
    # stops at the 501st such step. The env is registered with autoreset,
    # so stepping continues across episodes.
    extra_steps = 500
    while True:
        action, _ = model.predict(obs)
        obs, _, done, _, _ = env.step(action)

        footsteps = base_env.next_footsteps
        # Presumably footsteps[2] and [3] are the next left/right pair --
        # TODO confirm against footsteps_env.
        base_target_2D = np.mean([footsteps[2][:2, 3], footsteps[3][:2, 3]], axis=0)
        base_target_frame = np.eye(4)
        base_target_frame[:2, 3] = base_target_2D
        draw_frame(base_target_frame, "base target", env)

        base_pos_frame = np.eye(4)
        base_pos_frame[:2, 3] = base_env.data.body("base").xpos[:2]
        draw_frame(base_pos_frame, "base pos", env)

        for i, footstep in enumerate(footsteps[2:]):
            draw_frame(footstep, i, env)

        if done:
            extra_steps -= 1
            if extra_steps < 0:
                break
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Test model")
    # Required: `test()` calls `.endswith` on this value, so a missing
    # path (None) would crash only after the environment was created.
    parser.add_argument(
        "-p",
        "--path",
        metavar="path_to_model",
        required=True,
        help="Path to the model. If directory, will use the latest model.",
    )
    parser.add_argument(
        "-a",
        "--algo",
        default="SAC",
        help="Algorithm the model was trained with: SAC, TD3, A2C, TQC or PPO.",
    )
    args = parser.parse_args()
    gymenv = gym.make("BDX_env", render_mode="human")
    try:
        test(gymenv, args.algo, path_to_model=args.path)
    finally:
        # Release the environment/viewer even if the rollout raises or is
        # interrupted with Ctrl+C.
        gymenv.close()