import argparse
import pickle

import gymnasium as gym
import numpy as np
from gymnasium.envs.registration import register
from imitation.algorithms.adversarial.gail import GAIL
from imitation.data.wrappers import RolloutInfoWrapper
from imitation.rewards.reward_nets import BasicRewardNet
from imitation.util.networks import RunningNorm
from imitation.util.util import make_vec_env
from stable_baselines3 import PPO
from stable_baselines3.common.evaluation import evaluate_policy
from stable_baselines3.ppo import MlpPolicy

# Adapted from the imitation GAIL example:
# https://imitation.readthedocs.io/en/latest/algorithms/gail.html

parser = argparse.ArgumentParser()
parser.add_argument(
    "-d", "--dataset", type=str, required=True, help="path to the pickled expert demonstrations"
)
args = parser.parse_args()

# Load the pickled expert demonstrations.
with open(args.dataset, "rb") as f:
    dataset = pickle.load(f)

# Register the custom environment (env.py must define BDXEnv).
register(id="BDX_env", entry_point="env:BDXEnv")

SEED = 42
rng = np.random.default_rng(SEED)

# env = gym.make("BDX_env", render_mode=None)

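# Note: make_vec_env here is imitation.util.util.make_vec_env (seeded via `rng`),
# not stable_baselines3's helper of the same name.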
env = make_vec_env(
    "BDX_env",
    rng=rng,
    n_envs=8,
    # RolloutInfoWrapper records full episodes so imitation can compute rollouts.
    post_wrappers=[lambda env, _: RolloutInfoWrapper(env)],
)

# Generator policy that GAIL trains against the learned reward.
learner = PPO(
    env=env,
    policy=MlpPolicy,
    batch_size=64,
    ent_coef=0.0,
    learning_rate=0.0004,
    gamma=0.95,
    n_epochs=5,
    seed=SEED,
    tensorboard_log="logs",
)

# Reward network used as the GAIL discriminator base; RunningNorm normalizes its inputs.
reward_net = BasicRewardNet(
    observation_space=env.observation_space,
    action_space=env.action_space,
    normalize_input_layer=RunningNorm,
)

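# GAIL expects `demonstrations` in imitation's demonstration format (e.g. a sequence of
# Trajectory/TrajectoryWithRew objects or Transitions); the pickled dataset is assumed
# to already be in that form.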
gail_trainer = GAIL(
    demonstrations=dataset,
    demo_batch_size=1024,
    gen_replay_buffer_capacity=512,
    n_disc_updates_per_round=8,
    venv=env,
    gen_algo=learner,
    reward_net=reward_net,
    allow_variable_horizon=True,
)

print("evaluate the learner before training")
|
|
env.seed(SEED)
|
|
learner_rewards_before_training, _ = evaluate_policy(
|
|
learner,
|
|
env,
|
|
100,
|
|
return_episode_rewards=True,
|
|
)
|
|
|
|
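# Roughly, each GAIL training round collects generator rollouts, performs
# n_disc_updates_per_round discriminator updates on expert-vs-generator batches,
# then updates the PPO generator against the discriminator-derived reward.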
print("train the learner and evaluate again")
|
|
gail_trainer.train(500000) # Train for 800_000 steps to match expert.
|
|
|
|
env.seed(SEED)
learner_rewards_after_training, _ = evaluate_policy(
    learner,
    env,
    100,
    return_episode_rewards=True,
)

print("mean episode reward before training:", np.mean(learner_rewards_before_training))
|
|
print("mean episode reward after training:", np.mean(learner_rewards_after_training))
|
|
|
|
print("Save the policy")
|
|
gail_trainer.policy.save("gail_policy_ppo.zip")
|
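# Minimal reload sketch (assumes stable-baselines3's BasePolicy.load API; the file saved
# above contains only the policy, not the full PPO training state):
#   from stable_baselines3.ppo import MlpPolicy
#   policy = MlpPolicy.load("gail_policy_ppo.zip")
#   action, _ = policy.predict(obs, deterministic=True)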