
Commit 78f61f0

Author: moneyDboat
Commit message: fix a preprocess bug
Parent: 2308add

16 files changed: +28033 -28019 lines
(first changed file's name not rendered)
255 Bytes. Binary file not shown.

__pycache__/models.cpython-36.pyc
20 Bytes. Binary file not shown.

__pycache__/util.cpython-36.pyc
-8 Bytes. Binary file not shown.

data_manager.py

Lines changed: 21 additions & 10 deletions
@@ -64,10 +64,15 @@ def read_rating(self, path):
 
     # Generate the word-vector matrix
     def read_pretrained_word2vec(self, path, vocab, dim):
+        parent_path = '/'.join(path.split('/')[:-1]) + '/'
+        if os.path.isfile(parent_path + 'preW.all'):
+            print('Load pretrained_word2vec from preW.all')
+            W = pickle.load(open(parent_path + 'preW.all', 'rb'))
+            return W
         if os.path.isfile(path):
             raw_word2vec = open(path, 'r')
         else:
-            print ("Path (word2vec) is wrong!")
+            print("Path (word2vec) is wrong!")
             sys.exit()
 
         word2vec_dic = {}
@@ -79,7 +84,7 @@ def read_pretrained_word2vec(self, path, vocab, dim):
             _word = tmp[0]
             _vec = np.array(tmp[1:], dtype=float)
             if _vec.shape[0] != dim:
-                print ("Mismatch the dimension of pre-trained word vector with word embedding dimension!")
+                print("Mismatch the dimension of pre-trained word vector with word embedding dimension!")
                 sys.exit()
             word2vec_dic[_word] = _vec
             mean = mean + _vec
@@ -96,7 +101,10 @@ def read_pretrained_word2vec(self, path, vocab, dim):
             else:
                 W[i + 1] = np.random.normal(mean, 0.1, size=dim)
 
-        print ("%d words exist in the given pretrained model" % count)
+        print("%d words exist in the given pretrained model" % count)
+        print('Saving preW.all file.')
+        pickle.dump(W, open(parent_path + 'preW.all', 'wb'))
+        print('Done')
 
         return W

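The first and third hunks above add a simple disk cache around the expensive word2vec parsing: if a pickled preW.all file already sits next to the raw vector file, it is loaded and returned immediately; otherwise the matrix is built once and dumped for the next run. A minimal sketch of the same pattern, with a hypothetical build_w callable standing in for the parsing loop:

import os
import pickle


def load_or_build_W(path, build_w):
    # Return W from the preW.all cache next to `path`, building it on a miss.
    # `build_w` is a hypothetical stand-in for the parsing loop in
    # read_pretrained_word2vec.
    parent_path = '/'.join(path.split('/')[:-1]) + '/'
    cache_file = parent_path + 'preW.all'
    if os.path.isfile(cache_file):
        # Cache hit: skip parsing the raw word2vec text file entirely.
        with open(cache_file, 'rb') as f:
            return pickle.load(f)
    W = build_w(path)  # expensive: parse the raw vectors line by line
    with open(cache_file, 'wb') as f:
        pickle.dump(W, f)
    return W

Note that the cache key is only the file location: a stale preW.all must be deleted by hand whenever the vocabulary or embedding dimension changes.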
@@ -108,8 +116,10 @@ def split_data(self, ratio, R):
             np.random.shuffle(user_rating)
             train.append((i, user_rating[0]))
 
+        # "*train" unpacks the (user, item) pairs so the item column can be collected
         remain_item = set(range(R.shape[1])) - set(list(zip(*train))[1])
 
+        # make sure the training set contains at least one rating for every user and item
        for j in remain_item:
             item_rating = R.tocsc().T[j].nonzero()[1]
             np.random.shuffle(item_rating)
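
The zip(*train) idiom in this hunk transposes the list of (user, item) pairs into parallel tuples so the set of already-covered items can be collected. A toy example with made-up pairs:

# Made-up training pairs after picking one rating per user.
train = [(0, 5), (1, 2), (2, 5)]

# zip(*train) "opens" the list: it transposes pairs into (users, items).
users, items = zip(*train)           # (0, 1, 2) and (5, 2, 5)

# Out of 6 items total, these still lack a training rating; the loop that
# follows in the hunk gives each of them one.
remain_item = set(range(6)) - set(items)
print(remain_item)                   # {0, 1, 3, 4}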
@@ -135,10 +145,10 @@ def split_data(self, ratio, R):
         trainset_u_idx = set(trainset_u_idx)
         trainset_i_idx = set(trainset_i_idx)
         if len(trainset_u_idx) != R.shape[0] or len(trainset_i_idx) != R.shape[1]:
-            print ("Fatal error in split function. Check your data again or contact authors")
+            print("Fatal error in split function. Check your data again or contact authors")
             sys.exit()
 
-        print ("Finish constructing training set and test set")
+        print("Finish constructing training set and test set")
         return train, valid, test
 
     def generate_train_valid_test_file_from_R(self, path, R, ratio):
@@ -153,7 +163,7 @@ def generate_train_valid_test_file_from_R(self, path, R, ratio):
         - ratio: (1-ratio), ratio/2 and ratio/2 of the entire dataset (R) will be training, valid and test set, respectively
         '''
         train, valid, test = self.split_data(ratio, R)
-        print ("Save training set and test set to %s..." % path)
+        print("Save training set and test set to %s..." % path)
         if not os.path.exists(path):
             os.makedirs(path)

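The docstring in this hunk pins down the split semantics: with ratio = 0.2, 80% of the ratings go to training and 10% each to validation and test. A toy illustration of just that arithmetic, ignoring the per-user and per-item coverage guarantee that split_data additionally enforces:

import numpy as np

ratings = np.arange(100)                       # stand-in for 100 rating entries
np.random.shuffle(ratings)

ratio = 0.2
n_held_out = int(len(ratings) * ratio)         # 20 entries held out
test = ratings[: n_held_out // 2]              # 10 entries -> test
valid = ratings[n_held_out // 2: n_held_out]   # 10 entries -> valid
train = ratings[n_held_out:]                   # 80 entries -> training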
@@ -248,7 +258,7 @@ def generate_train_valid_test_file_from_R(self, path, R, ratio):
         formatted_item_test = []
 
         for j in range(R.shape[1]):
-            if i in item_ratings_train:
+            if j in item_ratings_train:
                 formatted = [str(len(item_ratings_train[j]))]
                 formatted.extend(["%d:%.1f" % (i, R_lil[i, j])
                                   for i in sorted(item_ratings_train[j])])
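
This hunk is the actual bug fix named in the commit message: the loop iterates over item indices j, but the membership test reused the stale user index i left over from the preceding user loop, so the wrong branch was taken for almost every item. A stripped-down reproduction of the bug class, with hypothetical data:

# Hypothetical ratings per item index.
item_ratings_train = {0: [3], 2: [1, 4]}

i = 0  # stale index left over from an earlier loop over users

for j in range(3):
    # The buggy test `if i in item_ratings_train:` is always True here, so the
    # body would also run for j == 1 and item_ratings_train[j] would raise KeyError.
    # The fix tests the item index j that this iteration is actually about.
    if j in item_ratings_train:
        print("item %d has %d training ratings" % (j, len(item_ratings_train[j])))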
@@ -279,9 +289,9 @@ def generate_train_valid_test_file_from_R(self, path, R, ratio):
         f_train_item.close()
         f_valid_item.close()
         f_test_item.close()
-        print ("\ttrain_item.dat, valid_item.dat, test_item.dat files are generated.")
+        print("\ttrain_item.dat, valid_item.dat, test_item.dat files are generated.")
 
-        print ("Done!")
+        print("Done!")
 
     def generate_CTRCDLformat_content_file_from_D_all(self, path, D_all):
         '''
@@ -378,6 +388,7 @@ def preprocess(self, path_rating, path_itemtext, min_rating,
         item = []
         rating = []
 
+        # collect (user, item, rating) triplets, later packed into a sparse CSR matrix
         for line in all_line:
             tmp = line.split('::')
             u = tmp[0]
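
The comment added in this hunk refers to the usual triplet-to-CSR conversion: each '::'-delimited rating line is split into parallel user, item, and rating lists, which are then packed into a scipy.sparse matrix. A self-contained sketch with made-up lines (the real preprocess also remaps raw ids and filters by min_rating):

from scipy.sparse import csr_matrix

# Made-up MovieLens-style lines: user::item::rating
all_line = ["0::1::4.0", "1::0::3.5", "1::2::5.0"]

user, item, rating = [], [], []
for line in all_line:
    u, i, r = line.split('::')[:3]
    user.append(int(u))
    item.append(int(i))
    rating.append(float(r))

# (data, (row, col)) builds the sparse matrix; duplicate entries are summed.
R = csr_matrix((rating, (user, item)))
print(R.toarray())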
@@ -433,7 +444,7 @@ def preprocess(self, path_rating, path_itemtext, min_rating,
 
         # Make vocabulary by document
         vectorizer = TfidfVectorizer(max_df=_max_df, stop_words={
-        'english'}, max_features=_vocab_size)
+            'english'}, max_features=_vocab_size)
         Raw_X = [map_idtoplot[i] for i in range(R.shape[1])]
         vectorizer.fit(Raw_X)
         vocab = vectorizer.vocabulary_
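
One caveat worth flagging, even though this commit only re-indents the call: stop_words={'english'} passes a one-element set, which scikit-learn treats as a custom stop-word list containing the literal token "english", not as the built-in English list. The presumably intended call, with placeholder values standing in for _max_df and _vocab_size:

from sklearn.feature_extraction.text import TfidfVectorizer

# stop_words='english' (a string) selects scikit-learn's built-in list;
# stop_words={'english'} (a set) stops only the word "english" itself.
vectorizer = TfidfVectorizer(max_df=0.5,            # placeholder for _max_df
                             stop_words='english',
                             max_features=8000)     # placeholder for _vocab_size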

models.py

Lines changed: 1 addition & 1 deletion
@@ -35,7 +35,7 @@ def ConvMF(res_dir, train_user, train_item, valid_user, test_user,
     Test_R = test_user[1]
     Valid_R = valid_user[1]
 
-    # don't worry for now about what this part is for
+    # for now it's unclear what this part is for
     if give_item_weight is True:
         item_weight = np.array([math.sqrt(len(i))
                                 for i in Train_R_J], dtype=float)
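
For context on the lines around the translated comment: when give_item_weight is set, each item's weight grows with the square root of its rating count, so frequently rated items get larger but sub-linear confidence. A tiny illustration with made-up per-item rating lists:

import math

import numpy as np

# Made-up Train_R_J: the observed ratings for each of three items.
Train_R_J = [[4.0, 5.0], [3.0], [5.0, 4.0, 3.0, 2.0]]

# sqrt-scaled confidence: 2 ratings -> 1.41, 1 -> 1.0, 4 -> 2.0
item_weight = np.array([math.sqrt(len(i)) for i in Train_R_J], dtype=float)
print(item_weight)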

run.py

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@
 '''
 import argparse
 import sys
+import os
 from data_manager import Data_Factory
 
 parser = argparse.ArgumentParser()

test/ml-1m/0.2/test_item.dat

Lines changed: 3270 additions & 3270 deletions
Large diffs are not rendered by default.
