84 lines
2.2 KiB
Python
84 lines
2.2 KiB
Python
|
import argparse
|
||
|
import os
|
||
|
from glob import glob
|
||
|
|
||
|
import gymnasium as gym
|
||
|
from gymnasium.envs.registration import register
|
||
|
from sb3_contrib import TQC
|
||
|
from stable_baselines3 import A2C, PPO, SAC, TD3
|
||
|
|
||
|
# Register the BDX environment under the id "BDX_env" so that
# gym.make("BDX_env") in the script entry point can construct it.
# The entry point string names the BDXEnv attribute of the env_humanoid module.
register(
    id="BDX_env",
    entry_point="env_humanoid:BDXEnv",
    autoreset=True,
    # No step limit is registered; episodes end only when the env says so.
    # max_episode_steps=200,
)
|
||
|
|
||
|
|
||
|
def test(env, sb3_algo, path_to_model):
|
||
|
|
||
|
if not path_to_model.endswith(".zip"):
|
||
|
models_paths = glob(path_to_model + "/*.zip")
|
||
|
latest_model_id = 0
|
||
|
latest_model_path = None
|
||
|
for model_path in models_paths:
|
||
|
model_id = model_path.split("/")[-1][: -len(".zip")].split("_")[-1]
|
||
|
if int(model_id) > latest_model_id:
|
||
|
latest_model_id = int(model_id)
|
||
|
latest_model_path = model_path
|
||
|
|
||
|
if latest_model_path is None:
|
||
|
print("No models found in directory: ", path_to_model)
|
||
|
return
|
||
|
|
||
|
print("Using latest model: ", latest_model_path)
|
||
|
|
||
|
path_to_model = latest_model_path
|
||
|
|
||
|
match sb3_algo:
|
||
|
case "SAC":
|
||
|
model = SAC.load(path_to_model, env=env)
|
||
|
case "TD3":
|
||
|
model = TD3.load(path_to_model, env=env)
|
||
|
case "A2C":
|
||
|
model = A2C.load(path_to_model, env=env)
|
||
|
case "TQC":
|
||
|
model = TQC.load(path_to_model, env=env)
|
||
|
case "PPO":
|
||
|
model = PPO("MlpPolicy", env)
|
||
|
model.policy.load(path_to_model)
|
||
|
# model = PPO.load(path_to_model, env=env)
|
||
|
|
||
|
case _:
|
||
|
print("Algorithm not found")
|
||
|
return
|
||
|
|
||
|
obs = env.reset()[0]
|
||
|
done = False
|
||
|
extra_steps = 500
|
||
|
while True:
|
||
|
action, _ = model.predict(obs)
|
||
|
obs, _, done, _, _ = env.step(action)
|
||
|
|
||
|
if done:
|
||
|
extra_steps -= 1
|
||
|
|
||
|
if extra_steps < 0:
|
||
|
break
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # Command-line entry point: parse the checkpoint path and algorithm name,
    # build the registered environment with on-screen rendering, and run the
    # model in it.
    parser = argparse.ArgumentParser(description="Test model")
    parser.add_argument(
        "-p",
        "--path",
        metavar="path_to_model",
        # Without a path test() would receive None and fail on .endswith().
        required=True,
        help="Path to the model. If directory, will use the latest model.",
    )
    parser.add_argument(
        "-a",
        "--algo",
        default="SAC",
        help="Algorithm the model was trained with: SAC, TD3, A2C, TQC or PPO.",
    )
    args = parser.parse_args()

    gymenv = gym.make("BDX_env", render_mode="human")
    try:
        test(gymenv, args.algo, path_to_model=args.path)
    finally:
        # Release the environment even if the rollout raises.
        gymenv.close()
|