Commit d19d2f6 — Added CNN network for ddpg
1 parent f26da43
15 files changed: +255 −26 lines

Algorithms/ddpg/core.py  (+152)

@@ -2,6 +2,10 @@
 import torch.nn as nn
 import torch

+##########################################################################################################
+## MLP ACTOR-CRITIC ##
+##########################################################################################################
+
 def mlp(sizes, activation, output_activation=nn.Identity):
     '''
     Create a multi-layer perceptron model from input sizes and activations
@@ -87,6 +91,154 @@ def __init__(self, observation_space, action_space, hidden_sizes=(256, 256), act
         self.pi = MLPActor(obs_dim, act_dim, hidden_sizes, activation, act_limit).to(device)
         self.q = MLPCritic(obs_dim, act_dim, hidden_sizes, activation).to(device)

+    def act(self, obs):
+        with torch.no_grad():
+            return self.pi(obs).cpu().numpy()
+
+
+##########################################################################################################
+## CNN ACTOR-CRITIC ##
+##########################################################################################################
+
+def cnn(in_channels, conv_layer_sizes, activation, batchnorm=True):
+    '''
+    Create a Convolutional Neural Network from the given list of layer specifications.
+    Args:
+        in_channels (int): number of incoming channels
+        conv_layer_sizes (list): list of 3-tuples (output_channel, kernel_size, stride),
+            one per convolutional layer
+        activation (nn.modules.activation): activation function after each conv layer
+        batchnorm (bool): if True, add a BatchNorm2d layer after each activation layer
+    Returns:
+        nn.Sequential module for the CNN
+    '''
+    layers = []
+    for out_channel, kernel, stride in conv_layer_sizes:
+        layers += [nn.Conv2d(in_channels, out_channel, kernel, stride),
+                   activation()]
+        if batchnorm:
+            layers += [nn.BatchNorm2d(out_channel)]
+        in_channels = out_channel
+    return nn.Sequential(*layers)
+
+class CNNActor(nn.Module):
+    def __init__(self, obs_dim, act_dim, conv_layer_sizes, hidden_sizes, activation, act_limit):
+        '''
+        A Convolutional Neural Net for the Actor network.
+        Network architecture: (input) -> CNN -> MLP -> (output)
+        Assumes the input is an image of shape (H, W, C), e.g. (128, 128, 3)
+        Args:
+            obs_dim (tuple): observation dimension of the environment in the form (H, W, C)
+            act_dim (int): action dimension of the environment
+            conv_layer_sizes (list): list of 3-tuples (output_channel, kernel_size, stride)
+                that describes the CNN architecture
+            hidden_sizes (list): number of neurons in each layer of the MLP after the CNN output
+            activation (nn.modules.activation): activation function for each layer of the MLP
+            act_limit (float): the greatest magnitude possible for the action in the environment
+        '''
+        super().__init__()
+        self.pi_cnn = cnn(obs_dim[2], conv_layer_sizes, nn.ReLU, batchnorm=True)
+        self.start_dim = self.calc_shape(obs_dim, self.pi_cnn)
+        mlp_sizes = [self.start_dim] + list(hidden_sizes) + [act_dim]
+        self.pi_mlp = mlp(mlp_sizes, activation, output_activation=nn.Tanh)
+        self.act_limit = act_limit
+
+    def calc_shape(self, obs_dim, pi_cnn):
+        '''
+        Determine the flattened size of the data after the conv layers,
+        i.e. how many input neurons the MLP needs.
+        '''
+        H, W, C = obs_dim
+        dummy_input = torch.randn(1, C, H, W)
+        with torch.no_grad():
+            cnn_out = pi_cnn(dummy_input)
+        return cnn_out.view(-1).shape[0]
+
+    def forward(self, obs):
+        '''
+        Forward propagation for the actor network
+        Args:
+            obs (Tensor [n, C, H, W]): batch of observations from the environment
+        Returns:
+            output of the actor network scaled by act_limit
+        '''
+        obs = self.pi_cnn(obs)
+        obs = obs.view(-1, self.start_dim)
+        obs = self.pi_mlp(obs)
+        return obs * self.act_limit
+
+class CNNCritic(nn.Module):
+    def __init__(self, obs_dim, act_dim, conv_layer_sizes, hidden_sizes, activation):
+        '''
+        A Convolutional Neural Net for the Critic network
+        Args:
+            obs_dim (tuple): observation dimension of the environment in the form (H, W, C)
+            act_dim (int): action dimension of the environment
+            conv_layer_sizes (list): list of 3-tuples (output_channel, kernel_size, stride)
+                that describes the CNN architecture
+            hidden_sizes (list): number of neurons in each layer of the MLP
+            activation (nn.modules.activation): activation function for each layer of the MLP
+        '''
+        super().__init__()
+        self.q_cnn = cnn(obs_dim[2], conv_layer_sizes, nn.ReLU, batchnorm=True)
+        self.start_dim = self.calc_shape(obs_dim, self.q_cnn)
+        self.q_mlp = mlp([self.start_dim + act_dim] + list(hidden_sizes) + [1], activation)
+
+    def calc_shape(self, obs_dim, q_cnn):
+        '''
+        Determine the flattened size of the data after the conv layers,
+        i.e. how many input neurons the MLP needs.
+        '''
+        H, W, C = obs_dim
+        dummy_input = torch.randn(1, C, H, W)
+        with torch.no_grad():
+            cnn_out = q_cnn(dummy_input)
+        return cnn_out.view(-1).shape[0]
+
+    def forward(self, obs, act):
+        '''
+        Forward propagation for the critic network
+        Args:
+            obs (Tensor [n, C, H, W]): batch of observations from the environment
+            act (Tensor [n, act_dim]): batch of actions taken by the actor
+        '''
+        obs = self.q_cnn(obs)
+        obs = obs.view(-1, self.start_dim)
+        q = self.q_mlp(torch.cat([obs, act], dim=-1))
+        return torch.squeeze(q, -1)  # ensure q has the right shape
+
+class CNNActorCritic(nn.Module):
+    def __init__(self, observation_space, action_space, conv_layer_sizes, hidden_sizes=(256, 256), activation=nn.ReLU, device='cpu'):
+        '''
+        A Convolutional Neural Net for the Actor-Critic network
+        Args:
+            observation_space (gym.spaces): observation space of the environment
+            action_space (gym.spaces): action space of the environment
+            conv_layer_sizes (list): list of 3-tuples (output_channel, kernel_size, stride)
+                that describes the CNN architecture
+            hidden_sizes (tuple): number of neurons in each layer of the MLP
+            activation (nn.modules.activation): activation function for each layer of the MLP
+            device (str): whether to run the model on cpu or gpu
+        '''
+        super().__init__()
+        obs_dim = observation_space.shape
+        act_dim = action_space.shape[0]
+        act_limit = action_space.high[0]
+
+        # Create Actor and Critic networks
+        self.pi = CNNActor(obs_dim, act_dim, conv_layer_sizes, hidden_sizes, activation, act_limit).to(device)
+        self.q = CNNCritic(obs_dim, act_dim, conv_layer_sizes, hidden_sizes, activation).to(device)
+
     def act(self, obs):
         with torch.no_grad():
             return self.pi(obs).cpu().numpy()
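A quick sanity check of the new CNN path (not part of the commit): the sketch below mirrors the cnn()/calc_shape() logic above, using the conv specs from ddpg_config_cnn.json further down. build_cnn is an illustrative stand-in, not a repo function.

import torch
import torch.nn as nn

# Conv specs as in ddpg_config_cnn.json: (output_channel, kernel_size, stride)
conv_layer_sizes = [(16, 5, 2), (32, 5, 2), (64, 5, 2), (64, 3, 1)]

def build_cnn(in_channels, specs):
    # Mirrors core.cnn(): Conv2d -> ReLU -> BatchNorm2d per spec tuple
    layers = []
    for out_channel, kernel, stride in specs:
        layers += [nn.Conv2d(in_channels, out_channel, kernel, stride),
                   nn.ReLU(),
                   nn.BatchNorm2d(out_channel)]
        in_channels = out_channel
    return nn.Sequential(*layers)

net = build_cnn(3, conv_layer_sizes)
with torch.no_grad():
    out = net(torch.randn(1, 3, 128, 128))  # dummy (N, C, H, W) input, as in calc_shape
print(out.shape)           # torch.Size([1, 64, 11, 11]) for a 128x128 input
print(out.view(-1).shape)  # 64 * 11 * 11 = 7744 neurons feed the first MLP layer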

Algorithms/ddpg/ddpg.py  (+2 −2)

@@ -93,7 +93,7 @@ def __init__(self, env_fn, save_dir, actor_critic=MLPActorCritic, ac_kwargs=dict
         self.gamma = gamma
         self.tau = tau
         self.act_noise = act_noise
-        self.obs_dim = self.env.observation_space.shape[0]
+        # self.obs_dim = self.env.observation_space.shape[0]
         self.act_dim = self.env.action_space.shape[0]
         self.num_test_episodes = num_test_episodes
         self.max_ep_len = self.env.spec.max_episode_steps if self.env.spec.max_episode_steps is not None else max_ep_len
@@ -267,7 +267,7 @@ def load_weights(self, best=True, load_buffer=True):

             env_pkl_path = os.path.join(self.save_dir, "env.pickle")
             if os.path.isfile(env_pkl_path):
-                self.env = Normalize_Observation.load(env_pkl_path)
+                self.env = self.env.__class__.load(env_pkl_path)
                 print("Environment loaded")

             print('checkpoint loaded at {}'.format(checkpoint_path))
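The load_weights change above (repeated in ppo.py, td3.py and trpo.py below) works because every wrapper in Wrappers/ exposes the same classmethod load, so dispatching through self.env.__class__ restores whichever wrapper class actually wrote env.pickle. A minimal sketch of the pattern, with Wrapped as a hypothetical stand-in:

import pickle

class Wrapped:  # hypothetical stand-in for Normalize_Observation, RLBench_Wrapper, ...
    def save(self, fname):
        with open(fname, 'wb') as f:
            pickle.dump(self, f)

    @classmethod
    def load(cls, filename):
        with open(filename, 'rb') as f:
            return pickle.load(f)

env = Wrapped()
env.save('env.pickle')
env = env.__class__.load('env.pickle')  # resolves to Wrapped.load, no hard-coded class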

Algorithms/ddpg/ddpg_config_cnn.json  (+21)

@@ -0,0 +1,21 @@
+{
+    "ac_kwargs": {
+        "hidden_sizes": [512, 256],
+        "conv_layer_sizes": [[16, 5, 2],
+                             [32, 5, 2],
+                             [64, 5, 2],
+                             [64, 3, 1]]
+    },
+    "replay_size": 1e6,
+    "gamma": 0.99,
+    "tau": 0.995,
+    "pi_lr": 1e-3,
+    "q_lr": 1e-3,
+    "batch_size": 100,
+    "start_steps": 1000,
+    "update_after": 1000,
+    "update_every": 50,
+    "act_noise": 0.1,
+    "max_ep_len": 1000,
+    "save_freq": 1
+}
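For reference, assuming the 128x128 input mentioned in core.py and no padding (the Conv2d default), each layer in this config shrinks the spatial size as out = (in − kernel) // stride + 1; a tiny sketch to verify:

# Spatial size after each no-padding Conv2d: out = (in - kernel) // stride + 1
size = 128
for out_ch, kernel, stride in [(16, 5, 2), (32, 5, 2), (64, 5, 2), (64, 3, 1)]:
    size = (size - kernel) // stride + 1
    print(out_ch, 'x', size)  # 16 x 62, 32 x 29, 64 x 13, 64 x 11
# Final feature map: 64 channels of 11x11 = 7744 inputs to the [512, 256] MLP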

Algorithms/ppo/ppo.py  (+1 −1)

@@ -198,7 +198,7 @@ def load_weights(self, best=True):

             env_pkl_path = os.path.join(self.save_dir, "env.pickle")
             if os.path.isfile(env_pkl_path):
-                self.env = Normalize_Observation.load(env_pkl_path)
+                self.env = self.env.__class__.load(env_pkl_path)
                 print("Environment loaded")
             print('checkpoint loaded at {}'.format(checkpoint_path))
         else:
File renamed without changes.

Algorithms/td3/td3.py  (+1 −1)

@@ -289,7 +289,7 @@ def load_weights(self, best=True, load_buffer=True):

             env_pkl_path = os.path.join(self.save_dir, "env.pickle")
             if os.path.isfile(env_pkl_path):
-                self.env = Normalize_Observation.load(env_pkl_path)
+                self.env = self.env.__class__.load(env_pkl_path)
                 print("Environment loaded")

             print('checkpoint loaded at {}'.format(checkpoint_path))
File renamed without changes.

Algorithms/trpo/trpo.py  (+1 −1)

@@ -296,7 +296,7 @@ def load_weights(self, best=True):

             env_pkl_path = os.path.join(self.save_dir, "env.pickle")
             if os.path.isfile(env_pkl_path):
-                self.env = Normalize_Observation.load(env_pkl_path)
+                self.env = self.env.__class__.load(env_pkl_path)
                 print("Environment loaded")

             print('checkpoint loaded at {}'.format(checkpoint_path))

Wrappers/normalize_observation.py  (−1)

@@ -3,7 +3,6 @@
 import pickle
 from typing import Tuple

-K = n = Ex = Ex2 = 0.0
 class Running_Stat:
     '''
     Class to store variables required to compute 1st and 2nd order statistics

Wrappers/rlbench_wrapper.py  (+36)

@@ -0,0 +1,36 @@
+import gym
+import numpy as np
+import pickle
+from typing import Tuple
+
+class RLBench_Wrapper(gym.ObservationWrapper):
+    '''
+    Observation Wrapper for the RLBench environment to output only one of the
+    camera views during training/testing instead of a dictionary of all camera views
+    '''
+    def __init__(self, env, view):
+        '''
+        Args:
+            view (str): dictionary key specifying which camera view to use.
+                An RLBench observation comes as a dictionary with keys
+                ['state', 'left_shoulder_rgb', 'right_shoulder_rgb', 'wrist_rgb', 'front_rgb']
+        '''
+        super(RLBench_Wrapper, self).__init__(env)
+        self.view = view
+        self.observation_space = self.observation_space[view]
+
+    def reset(self, **kwargs):
+        observation = self.env.reset(**kwargs)
+        return self.observation(observation)
+
+    def observation(self, observation):
+        return observation[self.view]
+
+    def save(self, fname):
+        with open(fname, 'wb') as f:
+            pickle.dump(self, f)
+
+    @classmethod
+    def load(cls, filename):
+        with open(filename, 'rb') as f:
+            return pickle.load(f)
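A hypothetical usage sketch (the rlbench.gym import and env id follow RLBench's documented conventions, but are assumptions, not part of this commit):

import gym
import rlbench.gym  # registers the RLBench gym environments (assumed installed)
from Wrappers.rlbench_wrapper import RLBench_Wrapper

# Observations become a single (H, W, C) image instead of a dict of camera views
env = RLBench_Wrapper(gym.make('reach_target-vision-v0'), view='wrist_rgb')
obs = env.reset()
print(obs.shape)  # e.g. (128, 128, 3), matching the CNNActor input assumption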

Wrappers/serialize_env.py  (+20)

@@ -0,0 +1,20 @@
+import gym
+import numpy as np
+import pickle
+from typing import Tuple
+
+class Serialize_Env(gym.ObservationWrapper):
+    '''
+    Simple wrapper to add save and load functionality to an environment
+    '''
+    def __init__(self, env, training=True):
+        super(Serialize_Env, self).__init__(env)
+
+    def save(self, fname):
+        with open(fname, 'wb') as f:
+            pickle.dump(self, f)
+
+    @classmethod
+    def load(cls, filename):
+        with open(filename, 'rb') as f:
+            return pickle.load(f)
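A minimal round-trip sketch, assuming the wrapped environment itself is pickleable (true for simple classic-control envs, not guaranteed for every env):

import gym
from Wrappers.serialize_env import Serialize_Env

env = Serialize_Env(gym.make('Pendulum-v0'))
env.save('env.pickle')                       # pickles the whole wrapper + env
restored = Serialize_Env.load('env.pickle')  # same object graph back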

test.py  (+4 −8)

@@ -67,6 +67,10 @@ def parse_arguments():

 def main():
     args = parse_arguments()
+
+    save_dir = os.path.join("Model_Weights", args.env, args.agent.lower())
+    config_path = os.path.join(save_dir, args.agent.lower() + "_config.json")
+
     if args.agent.lower() == 'random':
         save_dir = os.path.join("Model_Weights", args.env) if args.gif else None
         if not os.path.isdir(save_dir):
@@ -77,8 +81,6 @@ def main():

     elif args.agent.lower() == 'ddpg':
         from Algorithms.ddpg.ddpg import DDPG
-        save_dir = os.path.join("Model_Weights", args.env, "ddpg")
-        config_path = os.path.join(save_dir, "ddpg_config.json")
         logger_kwargs = {
             "output_dir": save_dir
         }
@@ -89,8 +91,6 @@ def main():
         model.load_weights(load_buffer=False)
     elif args.agent.lower() == 'td3':
         from Algorithms.td3.td3 import TD3
-        save_dir = os.path.join("Model_Weights", args.env, "td3")
-        config_path = os.path.join(save_dir, "td3_config.json")
         logger_kwargs = {
             "output_dir": save_dir
         }
@@ -101,8 +101,6 @@ def main():
         model.load_weights(load_buffer=False)
     elif args.agent.lower() == 'trpo':
         from Algorithms.trpo.trpo import TRPO
-        save_dir = os.path.join("Model_Weights", args.env, "trpo")
-        config_path = os.path.join(save_dir, "trpo_config.json")
         logger_kwargs = {
             "output_dir": save_dir
         }
@@ -113,8 +111,6 @@ def main():
         model.load_weights()
     elif args.agent.lower() == 'ppo':
        from Algorithms.ppo.ppo import PPO
-        save_dir = os.path.join("Model_Weights", args.env, "ppo")
-        config_path = os.path.join(save_dir, "ppo_config.json")
         logger_kwargs = {
             "output_dir": save_dir
         }
