65 lines
1.7 KiB
Python
65 lines
1.7 KiB
Python
|
import argparse
|
||
|
import pickle
|
||
|
import pprint
|
||
|
|
||
|
import numpy as np
|
||
|
from gymnasium.envs.registration import register
|
||
|
from imitation.algorithms import density as db
|
||
|
from imitation.data import serialize
|
||
|
from imitation.util import util
|
||
|
from stable_baselines3 import PPO
|
||
|
from stable_baselines3.common.policies import ActorCriticPolicy
|
||
|
|
||
|
# Command-line interface: a single mandatory option giving the dataset path.
parser = argparse.ArgumentParser()
parser.add_argument(
    "-d",
    "--dataset",
    type=str,
    required=True,
)
args = parser.parse_args()
|
||
|
|
||
|
# Fixed seed; the same generator is handed to the vectorised environment
# here and to the density algorithm further down.
rng = np.random.default_rng(seed=0)

# Make the project environment available under the id "BDX_env", then build
# a vectorised environment with two copies of it.
register(
    id="BDX_env",
    entry_point="env:BDXEnv",
)
env = util.make_vec_env(
    "BDX_env",
    rng=rng,
    n_envs=2,
)
|
||
|
|
||
|
# Load the demonstrations from the pickle file named on the command line.
# The file is opened in a `with` block so the handle is closed as soon as the
# data has been read (the bare `open(...)` call previously left it open).
# NOTE(security): unpickling can execute arbitrary code; only pass dataset
# files that come from a trusted source.
with open(args.dataset, "rb") as dataset_file:
    dataset = pickle.load(dataset_file)
|
||
|
|
||
|
# PPO learner that the density algorithm drives as its RL algorithm.
imitation_trainer = PPO(
    policy=ActorCriticPolicy,
    env=env,
    learning_rate=3e-4,
    n_steps=2048,
    gamma=0.95,
    ent_coef=1e-4,
)

# Density-based reward model built from the loaded demonstrations, using a
# Gaussian kernel over state-action pairs with standardised inputs.
density_trainer = db.DensityAlgorithm(
    demonstrations=dataset,
    venv=env,
    rl_algo=imitation_trainer,
    rng=rng,
    density_type=db.DensityType.STATE_ACTION_DENSITY,
    kernel="gaussian",
    kernel_bandwidth=0.4,
    is_stationary=True,
    standardise_inputs=True,
    allow_variable_horizon=True,
)

# Run the density algorithm's own training step before any policy training.
density_trainer.train()
|
||
|
|
||
|
|
||
|
def print_stats(density_trainer, n_trajectories):
    """Print policy evaluation stats for the true and the imitation reward.

    Calls ``density_trainer.test_policy`` twice: first with only
    ``n_trajectories`` and then again with ``true_reward=False``. Each result
    is pretty-printed under its own heading.

    Args:
        density_trainer: Object exposing a ``test_policy`` method that accepts
            the keyword arguments ``n_trajectories`` and ``true_reward``.
        n_trajectories: Value forwarded as ``n_trajectories`` to both calls.

    Returns:
        None. Output goes to standard output.
    """
    # (heading, extra keyword arguments) for each evaluation, in print order.
    reports = (
        ("True reward function stats:", {}),
        ("Imitation reward function stats:", {"true_reward": False}),
    )
    for heading, extra_kwargs in reports:
        result = density_trainer.test_policy(
            n_trajectories=n_trajectories, **extra_kwargs
        )
        print(heading)
        pprint.pprint(result)
|
||
|
|
||
|
|
||
|
# Baseline: evaluate on a single trajectory before any policy training.
print("Stats before training:")
print_stats(density_trainer, 1)

# Train the policy for 1_000_000 steps with a progress bar; the original
# author's intent is for it to approach expert performance.
density_trainer.train_policy(1_000_000, progress_bar=True)

# Same single-trajectory evaluation, now on the trained policy.
print("Stats after training:")
print_stats(density_trainer, 1)

# Save the trained policy in the current working directory.
density_trainer.policy.save("density_policy_ppo.zip")
|