void PythonInit(int argc, char *argv[]){
    Py_Initialize();
+   PyEval_InitThreads();
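    // PyEval_InitThreads() initialises the GIL so the embedded interpreter can be
    // used from more than one thread (e.g. one per worker environment); on CPython
    // 3.7+ Py_Initialize() already does this, so the call is harmless but redundant.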
    wchar_t **argw = new wchar_t*[argc];
    for (int i = 0; i < argc; i++) argw[i] = Py_DecodeLocale(argv[i], NULL);
    PySys_SetArgv(argc, argw);
}

-AgentInterfacePtr PPO_agent_with_param(GymEnvironmentPtr env, std::vector<int> actor_size, double actor_lr, \
+AgentInterfacePtr PPO_agent_with_param(std::vector<PytorchEnvironmentPtr> &env, std::vector<int> actor_size, double actor_lr, \
    std::vector<int> critic_size, double critic_lr, double critic_decay, double gamma, double lamda, int steps, int batch_size){
-   int o_size = env->observationSize;
-   int a_size = env->actionSize;
+   int o_size = env[0]->observationSize;
+   int a_size = env[0]->actionSize;
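    // Network dimensions are read from env[0]; every parallel worker is assumed to
    // wrap the same task, i.e. identical observation and action sizes.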
    actor_size.insert(actor_size.begin(), o_size); actor_size.insert(actor_size.end(), a_size);
    DeepNetworkPtr actor_network = DeepNetworkPtr(new DeepNetwork(actor_size));
    AdamPtr actor_opt = AdamPtr(new torch::optim::Adam(actor_network->parameters(), actor_lr));
@@ -29,10 +30,10 @@ AgentInterfacePtr PPO_agent_with_param(GymEnvironmentPtr env, std::vector<int> a
    return AgentInterfacePtr(new PPOAgent(env, actor, critic, state_modifier, gamma, lamda, steps, batch_size));
}

-void train(AgentInterfacePtr agent, int train_step, torch::Device device){
+void train(AgentInterfacePtr agent, int train_step, torch::Device device, bool render){
    mkdir("save_model", 0775);
    for (int i = 0; i < train_step; i++){
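        // render replaces the previously hard-coded true, so rollout visualisation
        // can be switched off for headless training runs.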
-       agent->train(5, device, true);
+       agent->train(5, device, render);
        printf("train fin\n");
        std::stringstream name;
        name << "./save_model/" << i;
@@ -45,7 +46,7 @@ void train(AgentInterfacePtr agent, int train_step, torch::Device device){
    }
}

-void demo(GymEnvironmentPtr env, AgentInterfacePtr agent, torch::Device device){
+void demo(PytorchEnvironmentPtr env, AgentInterfacePtr agent, torch::Device device){
    agent->to(device);
    torch::Tensor state = env->reset();
    while (1){
@@ -62,6 +63,8 @@ static std::string load_model;
static torch::Device device = torch::kCPU;
static int train_step = 0;
static std::string env_type;
+static int render = 0;
+static int cpu = 16;
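// cpu sets how many parallel Gym environments main() creates; render is forwarded
// to agent->train() and gates the final demo() run.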

void ParseArgs(int argc, char *argv[]){
    for (int i = 0; i < argc; i++){
@@ -80,16 +83,17 @@ int main(int argc, char *argv[])
    PythonInit(argc, argv);
    ParseArgs(argc, argv);

-   GymEnvironmentPtr env = GymEnvironmentPtr(new GymEnvironment(env_type.c_str(), device));
+   std::vector<PytorchEnvironmentPtr> env;
+   for (int i = 0; i < cpu; i++) env.push_back(PytorchEnvironmentPtr(new GymEnvironment(env_type.c_str(), device)));
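    // One environment per worker is built up front and the whole vector is handed to
    // the agent, presumably so PPO rollouts can be collected from several environments
    // per update instead of a single one.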
    AgentInterfacePtr agent = PPO_agent_with_param(env, {128, 128}, 1e-4, {128, 128}, 1e-4, 7e-4, 0.994, 0.99, 4096, 80);
    // AgentInterfacePtr agent = Vanila_agent_with_param(env, {128, 128}, 1e-4, {128, 128}, 1e-4, 7e-4, 0.994, 2048, 32);
    if (load_model != ""){
        tinyxml2::XMLDocument doc;
        if (doc.LoadFile(load_model.c_str())) return !printf("%s not exist\n", load_model.c_str());
        agent->set_xml(doc.RootElement());
    }
-   train(agent, train_step, device);
-   demo(env, agent, device);
+   train(agent, train_step, device, render);
+   if (render) demo(env[0], agent, device);

    if (Py_FinalizeEx() < 0) {
        return 120;