# Behavioral-cloning training script for the custom BDX humanoid environment.
import argparse
import pickle

import gymnasium as gym
import numpy as np
from gymnasium.envs.registration import register

from imitation.algorithms import bc
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
# Check this out https://imitation.readthedocs.io/en/latest/algorithms/bc.html

parser = argparse.ArgumentParser()
parser.add_argument(
    "-d",
    "--dataset",
    type=str,
    required=True,
    help="Path to the pickled demonstration dataset to clone from",
)
args = parser.parse_args()

# SECURITY NOTE: pickle.load executes arbitrary code embedded in the file —
# only load demonstration datasets you produced yourself / trust fully.
# Use a context manager so the file handle is closed deterministically
# (the original `pickle.load(open(...))` leaked the handle).
with open(args.dataset, "rb") as f:
    dataset = pickle.load(f)
# Make the custom BDX humanoid environment discoverable by gym.make().
register(id="BDX_env", entry_point="env_humanoid:BDXEnv")

# Headless training run — no rendering requested.
env = gym.make("BDX_env", render_mode=None)

# Seeded generator handed to the BC trainer below for reproducible runs.
rng = np.random.default_rng(seed=0)
# Borrow a freshly initialised PPO MlpPolicy network (two hidden layers,
# 400 and 300 units) as the network that BC will train.
# NOTE: a SAC policy was tried here and did not work with the BC trainer.
ppo_policy = PPO(
    "MlpPolicy",
    env,
    policy_kwargs=dict(net_arch=[400, 300]),
).policy

bc_trainer = bc.BC(
    observation_space=env.observation_space,
    action_space=env.action_space,
    demonstrations=dataset,
    rng=rng,
    device="cpu",
    policy=ppo_policy,
)

# One behavioural-cloning run: 10 passes over the demonstration data.
bc_trainer.train(n_epochs=10)

# Persist the cloned policy's weights.
bc_trainer.policy.save("bc_policy_ppo.zip")

# Optional sanity check (left disabled):
# reward, _ = evaluate_policy(bc_trainer.policy, env, 1)
# print(reward)