|
| 1 | +''' |
| 2 | +Created on Dec 9, 2015 |
| 3 | +
|
| 4 | +@author: donghyunz |
| 5 | +''' |
| 6 | +import argparse |
| 7 | +import sys |
| 8 | +from data_manager import Data_Factory |
| 9 | + |
| 10 | +parser = argparse.ArgumentParser() |
| 11 | + |
| 12 | +# Option for pre-processing data |
| 13 | +parser.add_argument("-c", "--do_preprocess", type=bool, |
| 14 | + help="True or False to preprocess raw data for ConvMF (default = False)", default=False) |
| 15 | +parser.add_argument("-r", "--raw_rating_data_path", type=str, |
| 16 | + help="Path to raw rating data. data format - user id::item id::rating") |
| 17 | +parser.add_argument("-i", "--raw_item_document_data_path", type=str, |
| 18 | + help="Path to raw item document data. item document consists of multiple text. data format - item id::text1|text2...") |
| 19 | +parser.add_argument("-m", "--min_rating", type=int, |
| 20 | + help="Users who have less than \"min_rating\" ratings will be removed (default = 1)", default=1) |
| 21 | +parser.add_argument("-l", "--max_length_document", type=int, |
| 22 | + help="Maximum length of document of each item (default = 300)", default=300) |
| 23 | +parser.add_argument("-f", "--max_df", type=float, |
| 24 | + help="Threshold to ignore terms that have a document frequency higher than the given value (default = 0.5)", default=0.5) |
| 25 | +parser.add_argument("-s", "--vocab_size", type=int, |
| 26 | + help="Size of vocabulary (default = 8000)", default=8000) |
| 27 | +parser.add_argument("-t", "--split_ratio", type=float, |
| 28 | + help="Ratio: 1-ratio, ratio/2 and ratio/2 of the entire dataset (R) will be training, valid and test set, respectively (default = 0.2)", default=0.2) |
| 29 | + |
| 30 | +# Option for pre-processing data and running ConvMF |
| 31 | +parser.add_argument("-d", "--data_path", type=str, |
| 32 | + help="Path to training, valid and test data sets") |
| 33 | +parser.add_argument("-a", "--aux_path", type=str, help="Path to R, D_all sets") |
| 34 | + |
| 35 | +# Option for running ConvMF |
| 36 | +parser.add_argument("-o", "--res_dir", type=str, |
| 37 | + help="Path to ConvMF's result") |
| 38 | +parser.add_argument("-e", "--emb_dim", type=int, |
| 39 | + help="Size of latent dimension for word vectors (default: 200)", default=200) |
| 40 | +parser.add_argument("-p", "--pretrain_w2v", type=str, |
| 41 | + help="Path to pretrain word embedding model to initialize word vectors") |
| 42 | +parser.add_argument("-g", "--give_item_weight", type=bool, |
| 43 | + help="True or False to give item weight of ConvMF (default = False)", default=True) |
| 44 | +parser.add_argument("-k", "--dimension", type=int, |
| 45 | + help="Size of latent dimension for users and items (default: 50)", default=50) |
| 46 | +parser.add_argument("-u", "--lambda_u", type=float, |
| 47 | + help="Value of user regularizer") |
| 48 | +parser.add_argument("-v", "--lambda_v", type=float, |
| 49 | + help="Value of item regularizer") |
| 50 | +parser.add_argument("-n", "--max_iter", type=int, |
| 51 | + help="Value of max iteration (default: 200)", default=200) |
| 52 | +parser.add_argument("-w", "--num_kernel_per_ws", type=int, |
| 53 | + help="Number of kernels per window size for CNN module (default: 100)", default=100) |
| 54 | + |
args = parser.parse_args()
do_preprocess = args.do_preprocess
data_path = args.data_path
aux_path = args.aux_path
# data_path and aux_path are required in both modes.
if data_path is None:
    sys.exit("Argument missing - data_path is required")
if aux_path is None:
    sys.exit("Argument missing - aux_path is required")

data_factory = Data_Factory()

if do_preprocess:
    # --- Preprocessing mode: build R / D_all from raw files ---
    path_rating = args.raw_rating_data_path
    path_itemtext = args.raw_item_document_data_path
    # Both raw inputs are mandatory here; fail fast with a clear message
    # instead of crashing later inside Data_Factory.preprocess.
    if path_rating is None:
        sys.exit("Argument missing - raw_rating_data_path is required for preprocessing")
    if path_itemtext is None:
        sys.exit("Argument missing - raw_item_document_data_path is required for preprocessing")
    min_rating = args.min_rating
    max_length = args.max_length_document
    max_df = args.max_df
    vocab_size = args.vocab_size
    split_ratio = args.split_ratio

    print("=================================Preprocess Option Setting=================================")
    print("\tsaving preprocessed aux path - %s" % aux_path)
    print("\tsaving preprocessed data path - %s" % data_path)
    print("\trating data path - %s" % path_rating)
    print("\tdocument data path - %s" % path_itemtext)
    print("\tmin_rating: %d\n\tmax_length_document: %d\n\tmax_df: %.1f\n\tvocab_size: %d\n\tsplit_ratio: %.1f"
          % (min_rating, max_length, max_df, vocab_size, split_ratio))
    print("===========================================================================================")

    R, D_all = data_factory.preprocess(
        path_rating, path_itemtext, min_rating, max_length, max_df, vocab_size)
    data_factory.save(aux_path, R, D_all)
    data_factory.generate_train_valid_test_file_from_R(
        data_path, R, split_ratio)
else:
    # --- Training mode: load preprocessed data and run ConvMF ---
    res_dir = args.res_dir
    emb_dim = args.emb_dim
    pretrain_w2v = args.pretrain_w2v
    dimension = args.dimension
    lambda_u = args.lambda_u
    lambda_v = args.lambda_v
    max_iter = args.max_iter
    num_kernel_per_ws = args.num_kernel_per_ws
    give_item_weight = args.give_item_weight

    # These three have no defaults and are required for training.
    if res_dir is None:
        sys.exit("Argument missing - res_dir is required")
    if lambda_u is None:
        sys.exit("Argument missing - lambda_u is required")
    if lambda_v is None:
        sys.exit("Argument missing - lambda_v is required")

    print("===================================ConvMF Option Setting===================================")
    print("\taux path - %s" % aux_path)
    print("\tdata path - %s" % data_path)
    print("\tresult path - %s" % res_dir)
    print("\tpretrained w2v data path - %s" % pretrain_w2v)
    print("\tdimension: %d\n\tlambda_u: %.4f\n\tlambda_v: %.4f\n\tmax_iter: %d\n\tnum_kernel_per_ws: %d"
          % (dimension, lambda_u, lambda_v, max_iter, num_kernel_per_ws))
    print("===========================================================================================")

    R, D_all = data_factory.load(aux_path)
    CNN_X = D_all['X_sequence']
    # +1 to reserve index 0 (presumably padding in the CNN module — TODO confirm).
    vocab_size = len(D_all['X_vocab']) + 1

    print("\tJay::vocab_size is %d" % vocab_size)
    print("\tJay::cnn_x is %d" % len(CNN_X))

    from models import ConvMF

    if pretrain_w2v is None:
        init_W = None
    else:
        # Build the word-embedding initialization matrix from pretrained word2vec.
        init_W = data_factory.read_pretrained_word2vec(
            pretrain_w2v, D_all['X_vocab'], emb_dim)

    train_user = data_factory.read_rating(data_path + '/train_user.dat')
    train_item = data_factory.read_rating(data_path + '/train_item.dat')
    valid_user = data_factory.read_rating(data_path + '/valid_user.dat')
    test_user = data_factory.read_rating(data_path + '/test_user.dat')

    ConvMF(max_iter=max_iter, res_dir=res_dir,
           lambda_u=lambda_u, lambda_v=lambda_v, dimension=dimension, vocab_size=vocab_size, init_W=init_W,
           give_item_weight=give_item_weight, CNN_X=CNN_X, emb_dim=emb_dim, num_kernel_per_ws=num_kernel_per_ws,
           train_user=train_user, train_item=train_item, valid_user=valid_user, test_user=test_user, R=R)
0 commit comments