Skip to content

Commit 6400d17

Browse files
edited
1 parent fa6480b commit 6400d17

File tree

1 file changed

+10
-7
lines changed
  • 序列预测/PCA去趋势化/dev2

1 file changed

+10
-7
lines changed

序列预测/PCA去趋势化/dev2/main.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -123,15 +123,18 @@ def use_pca(dataset):
123123
train_x = (train_x_raw - x_min) / (x_max - x_min)
124124
train_y = (train_y_raw - y_min) / (y_max - y_min)
125125

126-
# test
127-
test_num = 3
128-
test_x = train_x[0:test_num]
129-
test_y = train_y[0:test_num]
130-
126+
# 构造训练和测试集
127+
total_len = train_x.shape[0]
128+
train_len = int(total_len * 0.75)
129+
test_len = int(total_len *0.25)
130+
131+
test_x = train_x[train_len:]
132+
test_y = train_y[train_len:]
133+
train_x =train_x[0:train_len]
134+
train_y =train_y[0:train_len]
131135

132136
# 放入lstm训练
133137
# lstm的hyper-parameter
134-
135138
hidden_size = 400
136139
layer_num = 1
137140
max_epoch = 5000
@@ -199,7 +202,7 @@ def print_to_console(i, train_y ,train_y_pred):
199202
print_to_console(i,train_y, train_y_pred)
200203
if i % 50 ==0:
201204
print ("test : ")
202-
feed_dict = {x_input: test_x, y_real: test_y, keep_prob: 1.0, batch_size: test_num}
205+
feed_dict = {x_input: test_x, y_real: test_y, keep_prob: 1.0, batch_size: test_len}
203206
test_y_pred = sess.run(y_pred, feed_dict=feed_dict)
204207
print_to_console(i, test_y,test_y_pred)
205208
print ("--- test end ---")

0 commit comments

Comments
 (0)