2025-07-28 12:35:44 +08:00

84 lines
2.2 KiB
Python

import argparse
import os
from glob import glob
import gymnasium as gym
from gymnasium.envs.registration import register
from sb3_contrib import TQC
from stable_baselines3 import A2C, PPO, SAC, TD3
# Make the custom humanoid environment available as ``gym.make("BDX_env")``.
# The environment class is looked up lazily from the ``env_humanoid`` module.
# ``autoreset=True`` asks Gymnasium to reset the env automatically when an
# episode ends (presumably via its auto-reset wrapper — see Gymnasium docs),
# which the rollout loop in ``test`` relies on since it never calls ``reset``
# after the first episode.
# NOTE(review): a ``max_episode_steps=200`` time limit used to be set here and
# is currently disabled; episodes only end when the env itself says so.
_BDX_ENV_KWARGS = {
    "id": "BDX_env",
    "entry_point": "env_humanoid:BDXEnv",
    "autoreset": True,
}
register(**_BDX_ENV_KWARGS)
def _model_number(model_path):
    """Return the checkpoint number embedded in *model_path*'s file name, or None.

    The file name (without its extension) is split on "_" and the right-most
    all-digit token is used, so both ``model_120.zip`` and
    ``rl_model_120_steps.zip`` yield 120. Names with no numeric token yield None.
    """
    stem = os.path.splitext(os.path.basename(model_path))[0]
    for token in reversed(stem.split("_")):
        if token.isdecimal():
            return int(token)
    return None


def _find_latest_model(directory):
    """Return the ``*.zip`` file in *directory* with the highest checkpoint number.

    Files whose name carries no number are skipped instead of crashing.
    Returns None when the directory holds no numbered ``.zip`` checkpoint.
    """
    numbered = [
        (number, path)
        for path in glob(os.path.join(directory, "*.zip"))
        if (number := _model_number(path)) is not None
    ]
    if not numbered:
        return None
    # max() keeps the first of equal numbers, matching the old strict ">" scan,
    # but (unlike starting from 0) it also accepts a checkpoint numbered 0.
    return max(numbered, key=lambda item: item[0])[1]


def test(env, sb3_algo, path_to_model):
    """Load a trained Stable-Baselines3 model and roll it out in *env*.

    Args:
        env: Gymnasium environment to run the policy in; it is also handed to
            the model's ``load`` call.
        sb3_algo: Algorithm name, one of "SAC", "TD3", "A2C", "TQC", "PPO".
        path_to_model: Path to a ``.zip`` checkpoint, or a directory. For a
            directory, the ``*.zip`` file with the largest number in its name
            is used.

    Returns:
        None. Prints a message and returns early when no checkpoint is found
        or the algorithm name is unknown; otherwise steps the environment
        until 501 episode endings (termination or truncation) have been seen.
    """
    algorithms = {"SAC": SAC, "TD3": TD3, "A2C": A2C, "TQC": TQC, "PPO": PPO}

    if not path_to_model.endswith(".zip"):
        latest_model_path = _find_latest_model(path_to_model)
        if latest_model_path is None:
            print("No models found in directory: ", path_to_model)
            return
        print("Using latest model: ", latest_model_path)
        path_to_model = latest_model_path

    algo_cls = algorithms.get(sb3_algo)
    if algo_cls is None:
        print("Algorithm not found")
        return
    # All algorithms, PPO included, are restored through the class-level
    # ``load``. The old PPO branch called ``model.policy.load(...)``, which is a
    # classmethod returning a *new* policy, so the trained weights were
    # discarded and a freshly initialised policy was being tested.
    model = algo_cls.load(path_to_model, env=env)

    obs, _ = env.reset()
    # Counts down once per episode ending; the loop stops after 501 endings.
    # No explicit reset is done between episodes: this relies on the env
    # auto-resetting (it is registered with autoreset=True in this file) —
    # TODO confirm for environments created elsewhere.
    episodes_left = 500
    while True:
        action, _ = model.predict(obs)
        obs, _, terminated, truncated, _ = env.step(action)
        # An episode ends on either termination or truncation.
        if terminated or truncated:
            episodes_left -= 1
            if episodes_left < 0:
                break
if __name__ == "__main__":
    # Command-line entry point: roll out a saved model in the BDX environment
    # with on-screen rendering.
    parser = argparse.ArgumentParser(description="Test model")
    parser.add_argument(
        "-p",
        "--path",
        metavar="path_to_model",
        # Required: test() dereferences the path immediately, so omitting it
        # used to crash with an AttributeError on None.
        required=True,
        help="Path to the model. If directory, will use the latest model.",
    )
    parser.add_argument(
        "-a",
        "--algo",
        default="SAC",
        help="SB3 algorithm the model was trained with (SAC, TD3, A2C, TQC, PPO).",
    )
    args = parser.parse_args()
    gymenv = gym.make("BDX_env", render_mode="human")
    try:
        test(gymenv, args.algo, path_to_model=args.path)
    finally:
        # Release the render window / simulator even if the rollout fails
        # or is interrupted.
        gymenv.close()