Skip to content

Commit b1a8b1b

Browse files
committed
2 parents d90e521 + 10e49cb commit b1a8b1b

File tree

4 files changed

+27
-13
lines changed

4 files changed

+27
-13
lines changed

A.cpp

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,16 @@
33

44
void PythonInit(int argc, char *argv[]){
55
Py_Initialize();
6+
PyEval_InitThreads();
67
wchar_t **argw = new wchar_t*[argc];
78
for(int i = 0; i < argc; i++) argw[i] = Py_DecodeLocale(argv[i], NULL);
89
PySys_SetArgv(argc, argw);
910
}
1011

11-
AgentInterfacePtr PPO_agent_with_param(GymEnvironmentPtr env, std::vector<int> actor_size, double actor_lr, \
12+
AgentInterfacePtr PPO_agent_with_param(std::vector<PytorchEnvironmentPtr> &env, std::vector<int> actor_size, double actor_lr, \
1213
std::vector<int> critic_size, double critic_lr, double critic_decay, double gamma, double lamda, int steps, int batch_size){
13-
int o_size = env->observationSize;
14-
int a_size = env->actionSize;
14+
int o_size = env[0]->observationSize;
15+
int a_size = env[0]->actionSize;
1516
actor_size.insert(actor_size.begin(), o_size); actor_size.insert(actor_size.end(), a_size);
1617
DeepNetworkPtr actor_network = DeepNetworkPtr(new DeepNetwork(actor_size));
1718
AdamPtr actor_opt = AdamPtr(new torch::optim::Adam(actor_network->parameters(), actor_lr));
@@ -29,10 +30,10 @@ AgentInterfacePtr PPO_agent_with_param(GymEnvironmentPtr env, std::vector<int> a
2930
return AgentInterfacePtr(new PPOAgent(env, actor, critic, state_modifier, gamma, lamda, steps, batch_size));
3031
}
3132

32-
void train(AgentInterfacePtr agent, int train_step, torch::Device device){
33+
void train(AgentInterfacePtr agent, int train_step, torch::Device device, bool render){
3334
mkdir("save_model", 0775);
3435
for(int i = 0; i < train_step; i++){
35-
agent->train(5, device, true);
36+
agent->train(5, device, render);
3637
printf("train fin\n");
3738
std::stringstream name;
3839
name << "./save_model/" << i;
@@ -45,7 +46,7 @@ void train(AgentInterfacePtr agent, int train_step, torch::Device device){
4546
}
4647
}
4748

48-
void demo(GymEnvironmentPtr env, AgentInterfacePtr agent, torch::Device device){
49+
void demo(PytorchEnvironmentPtr env, AgentInterfacePtr agent, torch::Device device){
4950
agent->to(device);
5051
torch::Tensor state = env->reset();
5152
while(1){
@@ -62,6 +63,8 @@ static std::string load_model;
6263
static torch::Device device = torch::kCPU;
6364
static int train_step = 0;
6465
static std::string env_type;
66+
static int render = 0;
67+
static int cpu = 16;
6568

6669
void ParseArgs(int argc, char *argv[]){
6770
for(int i = 0; i < argc; i++){
@@ -80,16 +83,17 @@ int main(int argc, char *argv[])
8083
PythonInit(argc, argv);
8184
ParseArgs(argc, argv);
8285

83-
GymEnvironmentPtr env = GymEnvironmentPtr(new GymEnvironment(env_type.c_str(), device));
86+
std::vector<PytorchEnvironmentPtr> env;
87+
for(int i = 0; i < cpu; i++) env.push_back(PytorchEnvironmentPtr(new GymEnvironment(env_type.c_str(), device)));
8488
AgentInterfacePtr agent = PPO_agent_with_param(env, {128, 128}, 1e-4, {128, 128}, 1e-4, 7e-4, 0.994, 0.99, 4096, 80);
8589
//AgentInterfacePtr agent = Vanila_agent_with_param(env, {128, 128}, 1e-4, {128, 128}, 1e-4, 7e-4, 0.994, 2048, 32);
8690
if(load_model != ""){
8791
tinyxml2::XMLDocument doc;
8892
if(doc.LoadFile(load_model.c_str())) return !printf("%s not exist\n", load_model.c_str());
8993
agent->set_xml(doc.RootElement());
9094
}
91-
train(agent, train_step, device);
92-
demo(env, agent, device);
95+
train(agent, train_step, device, render);
96+
if(render) demo(env[0], agent, device);
9397

9498
if (Py_FinalizeEx() < 0) {
9599
return 120;

CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,14 @@ add_definitions(-DTIXML_USE_STL)
77
add_compile_options(-g)
88
find_package(Torch REQUIRED)
99
find_package(PythonLibs REQUIRED)
10+
find_package(OpenMP)
1011
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
1112
include_directories(${PYTHON_INCLUDE_DIRS})
1213
include_directories("./include")
1314
include_directories("./tinyxml2")
1415

1516
file(GLOB srcs A.cpp cpp/*/*.cpp tinyxml2/*.cpp)
1617
add_executable(A ${srcs})
17-
target_link_libraries(A "${TORCH_LIBRARIES}" ${PYTHON_LIBRARIES})
18+
target_link_libraries(A "${TORCH_LIBRARIES}" ${PYTHON_LIBRARIES} ${OpenMP_CXX_FLAGS})
1819
set_property(TARGET A PROPERTY CXX_STANDARD 14)
1920

cpp/Environment/GymEnvironment.cpp

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
PyObject* GymEnvironment::pModule;
66
PyObject* GymEnvironment::pMake;
77

8+
static std::mutex mtx;
9+
810
GymEnvironment::GymEnvironment(const char* name, torch::Device device) :
911
PytorchEnvironment(device), PyWrapper(pModule == NULL? init(name) : PyObject_CallFunctionObjArgs(pMake, PyString(name).obj, NULL)){
1012
pyReset = PyObject_GetAttrString(obj, "reset");
@@ -27,20 +29,27 @@ PyObject* GymEnvironment::init(const char* name){
2729

2830
torch::Tensor GymEnvironment::reset(){
2931
steps = 0;
30-
return PySequenceToTensor(PyObject_CallFunctionObjArgs(pyReset, NULL), true).to(device);
32+
mtx.lock();
33+
torch::Tensor tmp = PySequenceToTensor(PyObject_CallFunctionObjArgs(pyReset, NULL), true).to(device);
34+
mtx.unlock();
35+
return tmp;
3136
}
3237

3338
void GymEnvironment::step(const torch::Tensor &action, torch::Tensor &next_state, double &reward, int &done, int &tl){
39+
mtx.lock();
3440
PyObject* tuple = PyObject_CallFunctionObjArgs(pyStep, PyArray(action.to(torch::kCPU)).obj, NULL);
3541
if(tuple == NULL) PyErr_Print();
3642
next_state = PySequenceToTensor(PyTuple_GetItem(tuple, 0)).to(device);
3743
reward = PyFloat_AsDouble(PyTuple_GetItem(tuple, 1));
3844
done = (PyTuple_GetItem(tuple, 2) == Py_True);
3945
tl = steps >= 2000;
4046
Py_DECREF(tuple);
47+
mtx.unlock();
4148
if(done) next_state = reset();
4249
}
4350

4451
void GymEnvironment::render(){
52+
mtx.lock();
4553
if(PyObject_CallFunctionObjArgs(pyRender, NULL) == NULL) PyErr_Print();
46-
}
54+
mtx.unlock();
55+
}

include/Environment/GymEnvironment.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ class GymEnvironment : public PyWrapper, public PytorchEnvironment{
1010
public:
1111
GymEnvironment(const char* name, torch::Device device);
1212
PyObject* pyReset, *pyStep, *pyRender;
13-
int steps, observationSize, actionSize;
13+
int steps;
1414

1515
virtual torch::Tensor reset();
1616
virtual void step(const torch::Tensor &action, torch::Tensor &next_state, double &reward, int &done, int &tl);

0 commit comments

Comments
 (0)