ThomasSimonini committed
Commit b10dd53
1 Parent(s): e1e46b3

Update README.md

Files changed (1)
  1. README.md +154 -1
README.md CHANGED

tags:
  - reinforcement-learning
  - stable-baselines3
---
# PPO Agent playing PongNoFrameskip-v4
This is a trained model of a PPO agent playing PongNoFrameskip-v4, using the stable-baselines3 library.

<video src="https://huggingface.co/ThomasSimonini/ppo-SpaceInvadersNoFrameskip-v4/resolve/main/output.mp4" controls autoplay loop></video>

# Usage (with Stable-baselines3)

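To use this model, download the checkpoint from the Hub and load it with SB3. A minimal sketch (the fully annotated version, including the Colab pickle workaround, follows in the sections below):

```python
from huggingface_sb3 import load_from_hub
from stable_baselines3 import PPO

# Download the checkpoint from the Hub, then load it with SB3
checkpoint = load_from_hub("ThomasSimonini/ppo-PongNoFrameskip-v4", "ppo-PongNoFrameskip-v4.zip")
model = PPO.load(checkpoint)
```
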
## Evaluation Results
Mean reward: 21.00 +/- 0.0

## Watch your agent interact
- You need to use `gym==0.19` since it **includes the Atari ROMs**.
- The action space is 6 since we use only the **legit actions**.

```python
# Install these libraries (don't forget to restart the runtime after installing the libraries)
!pip install stable-baselines3[extra]
!pip install huggingface_sb3
!pip install huggingface_hub
!pip install pickle5

# Import the libraries
import gym

from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack
from stable_baselines3.common.evaluation import evaluate_policy

from huggingface_sb3 import load_from_hub

# Load the model
checkpoint = load_from_hub("ThomasSimonini/ppo-PongNoFrameskip-v4", "ppo-PongNoFrameskip-v4.zip")

# Colab runs Python 3.7 while this agent was trained with Python 3.8,
# so we override these schedule objects to avoid pickle errors:
custom_objects = {
    "learning_rate": 0.0,
    "lr_schedule": lambda _: 0.0,
    "clip_range": lambda _: 0.0,
}

model = PPO.load(checkpoint, custom_objects=custom_objects)

## Evaluate the agent
env = make_atari_env('PongNoFrameskip-v4', n_envs=1)
env = VecFrameStack(env, n_stack=4)

mean_reward, std_reward = evaluate_policy(model, env, n_eval_episodes=10)
print(f"mean_reward={mean_reward:.2f} +/- {std_reward}")

## Generate a video of your agent performing with Colab
!pip install gym pyvirtualdisplay > /dev/null 2>&1
!apt-get install -y xvfb python-opengl ffmpeg > /dev/null 2>&1
!pip install colabgymrender==1.0.2

# env.play() is provided by colabgymrender's Recorder wrapper; the original
# snippet calls env.play() without showing the wrapping step, so this is an
# assumed fix:
from colabgymrender.recorder import Recorder
env = Recorder(env, './video')

observation = env.reset()
terminal = False
while not terminal:
    action, _state = model.predict(observation)
    observation, reward, terminal, info = env.step(action)
env.play()
```
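
If you are running locally instead of on Colab, you can skip the recorder and simply render each step. A minimal sketch (an addition, not from the original card; it assumes a machine with a display, and `deterministic=True` makes the policy greedy):

```python
# Watch the agent locally; render() will not work on a bare Colab VM.
env = make_atari_env('PongNoFrameskip-v4', n_envs=1)
env = VecFrameStack(env, n_stack=4)

observation = env.reset()
for _ in range(1000):
    action, _states = model.predict(observation, deterministic=True)
    observation, reward, done, info = env.step(action)
    env.render()
```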

## Training Code
- You need to use `gym==0.19` since it **includes the Atari ROMs**.
- The action space is 6 since we use only the **legit actions**.

```python
import wandb

from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack, VecVideoRecorder
from stable_baselines3.common.callbacks import CheckpointCallback

from wandb.integration.sb3 import WandbCallback

from huggingface_sb3 import push_to_hub

config = {
    "env_name": "PongNoFrameskip-v4",
    "num_envs": 8,
    "total_timesteps": int(10e6),
    "seed": 4089164106,
}

run = wandb.init(
    project="HFxSB3",
    config=config,
    sync_tensorboard=True,  # Auto-upload sb3's tensorboard metrics
    monitor_gym=True,       # Auto-upload the videos of agents playing the game
    save_code=True,         # Save the code to W&B
)

# There already exists an environment generator
# that will make and wrap Atari environments correctly.
# Here we also use multi-worker training (n_envs=8 => 8 environments)
env = make_atari_env(config["env_name"], n_envs=config["num_envs"], seed=config["seed"])

print("ENV ACTION SPACE: ", env.action_space.n)

# Frame-stacking with 4 frames
env = VecFrameStack(env, n_stack=4)
# Video recorder
env = VecVideoRecorder(env, "videos", record_video_trigger=lambda x: x % 100000 == 0, video_length=2000)

# Hyperparameters taken from the RL Zoo's trained Pong agent:
# https://github.com/DLR-RM/rl-trained-agents/blob/10a9c31e806820d59b20d8b85ca67090338ea912/ppo/PongNoFrameskip-v4_1/PongNoFrameskip-v4/config.yml
model = PPO(
    policy="CnnPolicy",
    env=env,
    batch_size=256,
    clip_range=0.1,
    ent_coef=0.01,
    gae_lambda=0.9,
    gamma=0.99,
    learning_rate=2.5e-4,
    max_grad_norm=0.5,
    n_epochs=4,
    n_steps=128,
    vf_coef=0.5,
    tensorboard_log="runs",
    verbose=1,
)

model.learn(
    total_timesteps=config["total_timesteps"],
    callback=[
        WandbCallback(
            gradient_save_freq=1000,
            model_save_path=f"models/{run.id}",
        ),
        CheckpointCallback(
            save_freq=10000,
            save_path="./pong",
            name_prefix=config["env_name"],
        ),
    ],
)

model.save("ppo-PongNoFrameskip-v4.zip")
push_to_hub(
    repo_id="ThomasSimonini/ppo-PongNoFrameskip-v4",
    filename="ppo-PongNoFrameskip-v4.zip",
    commit_message="Added Pong trained agent",
)
```
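
`CheckpointCallback` also drops periodic snapshots under `./pong`. A sketch of resuming training from one of them (an addition, not from the original card; SB3 names snapshots `<name_prefix>_<num_timesteps>_steps.zip`, and the step count below is illustrative):

```python
# Resume training from a checkpoint; the filename is illustrative,
# so check ./pong for the actual snapshot names.
from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_atari_env
from stable_baselines3.common.vec_env import VecFrameStack

env = make_atari_env("PongNoFrameskip-v4", n_envs=8, seed=0)
env = VecFrameStack(env, n_stack=4)

model = PPO.load("./pong/PongNoFrameskip-v4_100000_steps.zip", env=env)
model.learn(total_timesteps=1_000_000, reset_num_timesteps=False)
```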