import argparse
import pickle
import pprint
import numpy as np
from gymnasium.envs.registration import register
from imitation.algorithms import density as db
from imitation.data import serialize
from imitation.util import util
from stable_baselines3 import PPO
from stable_baselines3.common.policies import ActorCriticPolicy
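
# Command-line interface: path to the pickled expert demonstration dataset.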
parser = argparse.ArgumentParser()
parser.add_argument("-d", "--dataset", type=str, required=True)
args = parser.parse_args()
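
# Fixed-seed RNG, registration of the custom BDXEnv (from env.py), and a 2-copy vectorized env.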
rng = np.random.default_rng(0)
register(id="BDX_env", entry_point="env:BDXEnv")
env = util.make_vec_env("BDX_env", rng=rng, n_envs=2)
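
# Load the pickled expert demonstrations passed via --dataset.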
with open(args.dataset, "rb") as f:
    dataset = pickle.load(f)
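
# Inner RL learner (PPO) that the density algorithm will train on its imitation reward.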
imitation_trainer = PPO(
    ActorCriticPolicy, env, learning_rate=3e-4, gamma=0.95, ent_coef=1e-4, n_steps=2048
)
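
# Density-based imitation: fit a Gaussian kernel density model over expert
# (state, action) pairs and use its log-density as the reward signal for PPO.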
density_trainer = db.DensityAlgorithm(
    venv=env,
    rng=rng,
    demonstrations=dataset,
    rl_algo=imitation_trainer,
    density_type=db.DensityType.STATE_ACTION_DENSITY,
    is_stationary=True,
    kernel="gaussian",
    kernel_bandwidth=0.4,
    standardise_inputs=True,
    allow_variable_horizon=True,
)
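
# Fit the kernel density model to the demonstrations (does not train the policy yet).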
density_trainer.train()
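
# Helper: evaluate the current policy under both the environment's true reward
# and the learned imitation (density) reward.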
def print_stats(density_trainer, n_trajectories):
    stats = density_trainer.test_policy(n_trajectories=n_trajectories)
    print("True reward function stats:")
    pprint.pprint(stats)
    stats_im = density_trainer.test_policy(
        true_reward=False, n_trajectories=n_trajectories
    )
    print("Imitation reward function stats:")
    pprint.pprint(stats_im)
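
# Baseline evaluation on a single trajectory before policy training.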
print("Stats before training:")
print_stats(density_trainer, 1)
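
# Train PPO against the learned density reward.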
density_trainer.train_policy(
    1_000_000,
    progress_bar=True,
)  # Train for 1_000_000 steps to approach expert performance.
print("Stats after training:")
print_stats(density_trainer, 1)
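
# Save the trained PPO policy for later reuse.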
density_trainer.policy.save("density_policy_ppo.zip")