Commit d33d349

Author: moneyDboat
Commit message: init push
Parent: 050c848

29 files changed: +29726, −0 lines

__init__.py

Whitespace-only changes.

__init__.pyc (110 Bytes, binary file not shown)

(12 KB, binary file not shown)

__pycache__/models.cpython-36.pyc (3.01 KB, binary file not shown)

__pycache__/util.cpython-36.pyc (963 Bytes, binary file not shown)

data_manager.py

Lines changed: 460 additions & 0 deletions
Large diffs are not rendered by default.
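
data_manager.py itself is not rendered above, but the interface that run.py and models.py consume from it can be inferred from the code further down: Data_Factory.read_rating evidently returns a pair whose first element is a list of index arrays and whose second is the matching list of rating arrays. A toy sketch under that assumption (the values are made up):

import numpy as np

# Hypothetical structure matching how models.ConvMF indexes its inputs:
# train_user[0][i] = indices of the items user i rated,
# train_user[1][i] = the corresponding ratings, in the same order.
train_user = (
    [np.array([0, 2]), np.array([1])],
    [np.array([5.0, 3.0]), np.array([4.0])],
)
# train_item mirrors this from the item side
train_item = (
    [np.array([0]), np.array([1]), np.array([0])],
    [np.array([5.0]), np.array([4.0]), np.array([3.0])],
)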

data_manager.pyc (12.9 KB, binary file not shown)

models.py

Lines changed: 136 additions & 0 deletions
@@ -0,0 +1,136 @@

'''
Created on Dec 8, 2015

@author: donghyun
'''

import os
import time

from util import eval_RMSE
import math
import numpy as np
from text_analysis.models import CNN_module


def ConvMF(res_dir, train_user, train_item, valid_user, test_user,
           R, CNN_X, vocab_size, init_W=None, give_item_weight=True,
           max_iter=50, lambda_u=1, lambda_v=100, dimension=50,
           dropout_rate=0.2, emb_dim=200, max_len=300, num_kernel_per_ws=100):
    # Explicit-feedback setting: confidence a for observed ratings, b for missing ones
    a = 1
    b = 0

    num_user = R.shape[0]
    num_item = R.shape[1]
    PREV_LOSS = 1e-50  # tiny initial value, so the first convergence ratio is large
    if not os.path.exists(res_dir):
        os.makedirs(res_dir)
    # state.log records the per-iteration training state
    f1 = open(res_dir + '/state.log', 'w')

    Train_R_I = train_user[1]
    Train_R_J = train_item[1]
    Test_R = test_user[1]
    Valid_R = valid_user[1]

    # Optional per-item weighting: sqrt of each item's rating count,
    # normalized so the weights sum to num_item
    if give_item_weight is True:
        item_weight = np.array([math.sqrt(len(i))
                                for i in Train_R_J], dtype=float)
        item_weight *= (float(num_item) / item_weight.sum())
    else:
        item_weight = np.ones(num_item, dtype=float)

    pre_val_eval = 1e10

    # dimension: size of the latent dimension for users and items
    # emb_dim: size of the latent dimension for word vectors
    cnn_module = CNN_module(dimension, vocab_size, dropout_rate,
                            emb_dim, max_len, num_kernel_per_ws, init_W)

    # theta: document latent vectors projected by the CNN
    theta = cnn_module.get_projection_layer(CNN_X)
    np.random.seed(133)
    U = np.random.uniform(size=(num_user, dimension))
    V = theta

    endure_count = 5
    count = 0
    for iteration in range(max_iter):
        loss = 0
        tic = time.time()
        print("%d iteration\t(patience: %d)" % (iteration, count))

        # Update user vectors U, one ridge-style solve per user
        VV = b * (V.T.dot(V)) + lambda_u * np.eye(dimension)
        sub_loss = np.zeros(num_user)

        for i in range(num_user):
            idx_item = train_user[0][i]
            V_i = V[idx_item]
            R_i = Train_R_I[i]
            A = VV + (a - b) * (V_i.T.dot(V_i))
            B = (a * V_i * np.tile(R_i, (dimension, 1)).T).sum(0)

            U[i] = np.linalg.solve(A, B)

            sub_loss[i] = -0.5 * lambda_u * np.dot(U[i], U[i])

        loss = loss + np.sum(sub_loss)

        # Update item vectors V, with the CNN output theta acting as a prior
        sub_loss = np.zeros(num_item)
        UU = b * (U.T.dot(U))
        for j in range(num_item):
            idx_user = train_item[0][j]
            U_j = U[idx_user]
            R_j = Train_R_J[j]

            tmp_A = UU + (a - b) * (U_j.T.dot(U_j))
            A = tmp_A + lambda_v * item_weight[j] * np.eye(dimension)
            B = (a * U_j * np.tile(R_j, (dimension, 1)).T
                 ).sum(0) + lambda_v * item_weight[j] * theta[j]
            V[j] = np.linalg.solve(A, B)

            sub_loss[j] = -0.5 * np.square(R_j * a).sum()
            sub_loss[j] = sub_loss[j] + a * np.sum((U_j.dot(V[j])) * R_j)
            sub_loss[j] = sub_loss[j] - 0.5 * np.dot(V[j].dot(tmp_A), V[j])

        loss = loss + np.sum(sub_loss)

        # Retrain the CNN on the updated V, then recompute theta
        seed = np.random.randint(100000)
        history = cnn_module.train(CNN_X, V, item_weight, seed)
        theta = cnn_module.get_projection_layer(CNN_X)
        cnn_loss = history.history['loss'][-1]

        loss -= 0.5 * lambda_v * cnn_loss * num_item

        tr_eval = eval_RMSE(Train_R_I, U, V, train_user[0])
        val_eval = eval_RMSE(Valid_R, U, V, valid_user[0])
        te_eval = eval_RMSE(Test_R, U, V, test_user[0])

        toc = time.time()
        elapsed = toc - tic

        converge = abs((loss - PREV_LOSS) / PREV_LOSS)

        # Save the model whenever the validation RMSE improves; otherwise lose patience
        if val_eval < pre_val_eval:
            cnn_module.save_model(res_dir + '/CNN_weights.hdf5')
            np.savetxt(res_dir + '/U.dat', U)
            np.savetxt(res_dir + '/V.dat', V)
            np.savetxt(res_dir + '/theta.dat', theta)
        else:
            count += 1

        pre_val_eval = val_eval

        print("Loss: %.5f Elapsed: %.4fs Converge: %.6f Tr: %.5f Val: %.5f Te: %.5f" % (
            loss, elapsed, converge, tr_eval, val_eval, te_eval))
        f1.write("Loss: %.5f Elapsed: %.4fs Converge: %.6f Tr: %.5f Val: %.5f Te: %.5f\n" % (
            loss, elapsed, converge, tr_eval, val_eval, te_eval))

        if count == endure_count:
            break

        PREV_LOSS = loss

    f1.close()
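
A side note on the two np.linalg.solve calls above (not part of the commit): with b = 0, each solve is the closed-form minimizer of a ridge-regression subproblem, u_i = (a V_iᵀV_i + lambda_u I)⁻¹ a V_iᵀ r_i, and the item update adds lambda_v * item_weight[j] * theta[j] to the right-hand side, so the CNN's document vector acts as a prior mean for V[j]. A minimal numpy sketch of the user update, with made-up sizes:

import numpy as np

# Hypothetical sizes, for illustration only
dimension, n_rated = 50, 7
V_i = np.random.rand(n_rated, dimension)  # latent vectors of the items user i rated
R_i = np.random.rand(n_rated)             # user i's ratings on those items
a, lambda_u = 1.0, 1.0

# Same update as in the loop above (b = 0 case), written without np.tile:
A = a * V_i.T.dot(V_i) + lambda_u * np.eye(dimension)
B = a * V_i.T.dot(R_i)
u_i = np.linalg.solve(A, B)  # solve is cheaper and more stable than inv(A).dot(B)

The np.tile expression in the loop computes the same a * V_i.T.dot(R_i) product shown here.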

models.pyc (3.37 KB, binary file not shown)

run.py

Lines changed: 140 additions & 0 deletions
@@ -0,0 +1,140 @@

'''
Created on Dec 9, 2015

@author: donghyunz
'''
import argparse
import sys
from data_manager import Data_Factory


def str2bool(v):
    # argparse's type=bool treats any non-empty string (even "False") as True,
    # so boolean flags are parsed explicitly instead
    return str(v).lower() in ('true', '1', 'yes')


parser = argparse.ArgumentParser()

# Options for pre-processing data
parser.add_argument("-c", "--do_preprocess", type=str2bool,
                    help="True or False to preprocess raw data for ConvMF (default = False)", default=False)
parser.add_argument("-r", "--raw_rating_data_path", type=str,
                    help="Path to raw rating data. Data format - user id::item id::rating")
parser.add_argument("-i", "--raw_item_document_data_path", type=str,
                    help="Path to raw item document data. An item document consists of multiple texts. Data format - item id::text1|text2...")
parser.add_argument("-m", "--min_rating", type=int,
                    help="Users with fewer than \"min_rating\" ratings will be removed (default = 1)", default=1)
parser.add_argument("-l", "--max_length_document", type=int,
                    help="Maximum document length for each item (default = 300)", default=300)
parser.add_argument("-f", "--max_df", type=float,
                    help="Threshold for ignoring terms with a document frequency higher than the given value (default = 0.5)", default=0.5)
parser.add_argument("-s", "--vocab_size", type=int,
                    help="Size of vocabulary (default = 8000)", default=8000)
parser.add_argument("-t", "--split_ratio", type=float,
                    help="1-ratio, ratio/2 and ratio/2 of the entire dataset (R) become the training, validation and test sets, respectively (default = 0.2)", default=0.2)

# Options for both pre-processing data and running ConvMF
parser.add_argument("-d", "--data_path", type=str,
                    help="Path to training, validation and test data sets")
parser.add_argument("-a", "--aux_path", type=str, help="Path to R, D_all sets")

# Options for running ConvMF
parser.add_argument("-o", "--res_dir", type=str,
                    help="Path to ConvMF's results")
parser.add_argument("-e", "--emb_dim", type=int,
                    help="Size of latent dimension for word vectors (default: 200)", default=200)
parser.add_argument("-p", "--pretrain_w2v", type=str,
                    help="Path to a pretrained word embedding model used to initialize word vectors")
parser.add_argument("-g", "--give_item_weight", type=str2bool,
                    help="True or False to weight items in ConvMF (default = True)", default=True)
parser.add_argument("-k", "--dimension", type=int,
                    help="Size of latent dimension for users and items (default: 50)", default=50)
parser.add_argument("-u", "--lambda_u", type=float,
                    help="Value of the user regularizer")
parser.add_argument("-v", "--lambda_v", type=float,
                    help="Value of the item regularizer")
parser.add_argument("-n", "--max_iter", type=int,
                    help="Maximum number of iterations (default: 200)", default=200)
parser.add_argument("-w", "--num_kernel_per_ws", type=int,
                    help="Number of kernels per window size for the CNN module (default: 100)", default=100)

args = parser.parse_args()
do_preprocess = args.do_preprocess
data_path = args.data_path
aux_path = args.aux_path
if data_path is None:
    sys.exit("Argument missing - data_path is required")
if aux_path is None:
    sys.exit("Argument missing - aux_path is required")

data_factory = Data_Factory()

if do_preprocess:
    path_rating = args.raw_rating_data_path
    path_itemtext = args.raw_item_document_data_path
    min_rating = args.min_rating
    max_length = args.max_length_document
    max_df = args.max_df
    vocab_size = args.vocab_size
    split_ratio = args.split_ratio

    print("=================================Preprocess Option Setting=================================")
    print("\tsaving preprocessed aux path - %s" % aux_path)
    print("\tsaving preprocessed data path - %s" % data_path)
    print("\trating data path - %s" % path_rating)
    print("\tdocument data path - %s" % path_itemtext)
    print("\tmin_rating: %d\n\tmax_length_document: %d\n\tmax_df: %.1f\n\tvocab_size: %d\n\tsplit_ratio: %.1f"
          % (min_rating, max_length, max_df, vocab_size, split_ratio))
    print("===========================================================================================")

    R, D_all = data_factory.preprocess(
        path_rating, path_itemtext, min_rating, max_length, max_df, vocab_size)
    data_factory.save(aux_path, R, D_all)
    data_factory.generate_train_valid_test_file_from_R(
        data_path, R, split_ratio)
else:
    res_dir = args.res_dir
    emb_dim = args.emb_dim
    pretrain_w2v = args.pretrain_w2v
    dimension = args.dimension
    lambda_u = args.lambda_u
    lambda_v = args.lambda_v
    max_iter = args.max_iter
    num_kernel_per_ws = args.num_kernel_per_ws
    give_item_weight = args.give_item_weight

    if res_dir is None:
        sys.exit("Argument missing - res_dir is required")
    if lambda_u is None:
        sys.exit("Argument missing - lambda_u is required")
    if lambda_v is None:
        sys.exit("Argument missing - lambda_v is required")

    print("===================================ConvMF Option Setting===================================")
    print("\taux path - %s" % aux_path)
    print("\tdata path - %s" % data_path)
    print("\tresult path - %s" % res_dir)
    print("\tpretrained w2v data path - %s" % pretrain_w2v)
    print("\tdimension: %d\n\tlambda_u: %.4f\n\tlambda_v: %.4f\n\tmax_iter: %d\n\tnum_kernel_per_ws: %d"
          % (dimension, lambda_u, lambda_v, max_iter, num_kernel_per_ws))
    print("===========================================================================================")

    R, D_all = data_factory.load(aux_path)
    CNN_X = D_all['X_sequence']
    vocab_size = len(D_all['X_vocab']) + 1

    print("\tJay::vocab_size is %d" % vocab_size)
    print("\tJay::cnn_x is %d" % len(CNN_X))

    # Deferred import: models pulls in the CNN stack, which is only needed here
    from models import ConvMF

    if pretrain_w2v is None:
        init_W = None
    else:
        # Build the word-embedding matrix from the pretrained word vectors
        init_W = data_factory.read_pretrained_word2vec(
            pretrain_w2v, D_all['X_vocab'], emb_dim)

    train_user = data_factory.read_rating(data_path + '/train_user.dat')
    train_item = data_factory.read_rating(data_path + '/train_item.dat')
    valid_user = data_factory.read_rating(data_path + '/valid_user.dat')
    test_user = data_factory.read_rating(data_path + '/test_user.dat')

    ConvMF(max_iter=max_iter, res_dir=res_dir,
           lambda_u=lambda_u, lambda_v=lambda_v, dimension=dimension, vocab_size=vocab_size, init_W=init_W,
           give_item_weight=give_item_weight, CNN_X=CNN_X, emb_dim=emb_dim, num_kernel_per_ws=num_kernel_per_ws,
           train_user=train_user, train_item=train_item, valid_user=valid_user, test_user=test_user, R=R)
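
For reference, the two raw-data formats described by the help strings above split on "::"; a minimal illustration with hypothetical rows (made up, not taken from the commit):

# raw_rating_data_path format: user id::item id::rating
rating_line = "1::1193::5"
user_id, item_id, rating = rating_line.split("::")

# raw_item_document_data_path format: item id::text1|text2...
doc_line = "1193::first review text|second review text"
item_id, texts = doc_line.split("::")
text_list = texts.split("|")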

run_test_ConvMF.sh

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@

#!/usr/bin/env bash
python ./run.py \
    -d ./data/preprocessed/ml-1m/0.2/ \
    -a ./data/preprocessed/ml-1m/ \
    -o ./test/ml-1m/result/1_100_200 \
    -e 200 \
    -p ./data/preprocessed/glove/glove.6B.200d.txt \
    -u 10 \
    -v 100 \
    -g True

run_test_preprocess.sh

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@

#!/usr/bin/env bash
python ./run.py \
    -d ./test/ml-1m/0.2/ \
    -a ./test/ml-1m/ \
    -c True \
    -r ./data/movielens/ml-1m_ratings.dat \
    -i ./data/movielens/ml_plot.dat \
    -m 1
