119 lines
3.0 KiB
Python
119 lines
3.0 KiB
Python
import argparse
|
|
import time
|
|
from glob import glob
|
|
|
|
import gymnasium as gym
|
|
import mujoco
|
|
import mujoco.viewer
|
|
import numpy as np
|
|
from gymnasium.envs.registration import register
|
|
from stable_baselines3 import PPO, SAC
|
|
|
|
from mini_bdx.utils.mujoco_utils import check_contact
|
|
|
|
|
|
def get_observation(data, left_contact, right_contact):
    """Build a flat observation vector from the current MuJoCo state.

    The vector is, in order: a flattened copy of ``data.qpos``, a flattened
    copy of ``data.qvel``, then the two contact flags ``left_contact`` and
    ``right_contact``.

    Args:
        data: MuJoCo ``MjData`` (anything exposing ``qpos``/``qvel`` arrays).
        left_contact: contact flag for the left foot.
        right_contact: contact flag for the right foot.

    Returns:
        A 1-D NumPy array produced by ``np.concatenate``. Copies are taken,
        so later simulation steps do not alter it.
    """
    parts = (
        data.qpos.flat.copy(),  # position/rotation of trunk + position of all joints
        data.qvel.flat.copy(),  # positional/angular velocity of trunk + of all joints
        [left_contact, right_contact],
    )
    return np.concatenate(parts)
|
|
|
|
|
|
def key_callback(keycode):
    """Keyboard hook for the passive MuJoCo viewer; intentionally ignores every key.

    Args:
        keycode: key code supplied by the viewer (unused).
    """
|
|
|
|
|
|
def get_model_from_dir(path_to_model):
    """Resolve a checkpoint path, choosing the latest checkpoint in a directory.

    If ``path_to_model`` ends with ``.zip`` it is returned unchanged. Otherwise
    it is treated as a directory. Among its ``*.zip`` files whose name ends in
    ``_<integer>`` (e.g. ``best_model_12.zip``), the one with the largest
    integer is returned. Files without such a numeric suffix are skipped
    rather than causing a crash.

    Args:
        path_to_model: path to a ``.zip`` checkpoint or to a directory of them.

    Returns:
        The checkpoint path, or ``None`` (after printing a message) when the
        directory holds no numbered ``.zip`` checkpoint.
    """
    import os

    if path_to_model.endswith(".zip"):
        return path_to_model

    # None (not 0) so that a checkpoint numbered 0 can still be selected.
    latest_model_id = None
    latest_model_path = None
    for model_path in glob(os.path.join(path_to_model, "*.zip")):
        stem = os.path.splitext(os.path.basename(model_path))[0]
        try:
            model_id = int(stem.split("_")[-1])
        except ValueError:
            # Not a numbered checkpoint (e.g. "notes.zip"); ignore it.
            continue
        if latest_model_id is None or model_id > latest_model_id:
            latest_model_id = model_id
            latest_model_path = model_path

    if latest_model_path is None:
        print("No models found in directory: ", path_to_model)
    return latest_model_path
|
|
|
|
|
|
def get_feet_contact(data, model):
    """Query floor contact for both feet.

    Calls ``check_contact`` against the ``"floor"`` geometry for
    ``"foot_module"`` and then for ``"foot_module_2"``.

    Returns:
        ``(right_contact, left_contact)``. Note the right-then-left order:
        ``"foot_module"`` is treated as the right foot and ``"foot_module_2"``
        as the left.
    """
    contacts = {
        foot: check_contact(data, model, foot, "floor")
        for foot in ("foot_module", "foot_module_2")
    }
    return contacts["foot_module"], contacts["foot_module_2"]
|
|
|
|
|
|
def play(
    env,
    path_to_model,
    algo="PPO",
    scene_path="../../mini_bdx/robots/bdx/scene.xml",
):
    """Run a trained Stable-Baselines3 policy in an interactive MuJoCo viewer.

    Each step builds an observation directly from the MuJoCo state (joint
    state plus foot contacts), asks the policy for an action, writes it to
    ``data.ctrl`` and advances the simulation by one ``mj_step``. The loop
    ends when the viewer window is closed or on Ctrl+C.

    Args:
        env: Gymnasium environment handed to the SB3 loader. SB3 checks that
            its spaces match the ones stored in the checkpoint.
        path_to_model: a ``.zip`` checkpoint, or a directory from which the
            highest-numbered checkpoint is taken (see ``get_model_from_dir``).
        algo: ``"PPO"`` (default) or ``"SAC"``; the algorithm class used to
            load the checkpoint.
        scene_path: MuJoCo XML scene to simulate. The default is relative to
            the current working directory.

    Raises:
        ValueError: if ``algo`` is not a supported algorithm name.
    """
    model_path = get_model_from_dir(path_to_model)
    if model_path is None:
        # get_model_from_dir has already printed why nothing was found.
        return

    algorithms = {"PPO": PPO, "SAC": SAC}
    try:
        algo_cls = algorithms[algo.upper()]
    except KeyError:
        raise ValueError(
            f"Unknown algo {algo!r}; expected one of {sorted(algorithms)}"
        ) from None

    # Load the whole checkpoint through the algorithm class. ``policy.load``
    # is a classmethod that *returns* a new policy instead of loading weights
    # into the existing one, so calling it on a fresh model's policy and
    # ignoring the result would leave the policy untrained.
    nn_model = algo_cls.load(model_path, env=env)

    model = mujoco.MjModel.from_xml_path(scene_path)
    data = mujoco.MjData(model)

    # The context manager closes the viewer exactly once on every exit path.
    with mujoco.viewer.launch_passive(
        model, data, key_callback=key_callback
    ) as viewer:
        try:
            while viewer.is_running():
                right_contact, left_contact = get_feet_contact(data, model)
                obs = get_observation(data, left_contact, right_contact)
                action, _ = nn_model.predict(obs)
                data.ctrl[:] = action

                mujoco.mj_step(model, data)
                viewer.sync()
                # Throttle to a fraction of the physics timestep for viewing;
                # the 2.5 divisor looks empirically chosen (TODO confirm).
                time.sleep(model.opt.timestep / 2.5)
        except KeyboardInterrupt:
            pass
|
|
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Test model")
    parser.add_argument(
        "-p",
        "--path",
        metavar="path_to_model",
        # Required: play() cannot do anything without a checkpoint, and a
        # missing value would otherwise surface later as an AttributeError.
        required=True,
        help="Path to the model. If directory, will use the latest model.",
    )
    # NOTE(review): --algo is parsed but not forwarded to play() below.
    parser.add_argument("-a", "--algo", default="SAC")
    args = parser.parse_args()

    # Entry point uses gymnasium's "module:Class" form, so env_humanoid must
    # be importable (e.g. run from this script's directory).
    register(id="BDX_env", entry_point="env_humanoid:BDXEnv")
    env = gym.make("BDX_env", render_mode=None)
    play(env, path_to_model=args.path)
|