|
1 | 1 | """
|
2 | 2 | Benchmarks of Non-Negative Matrix Factorization
|
3 | 3 | """
|
| 4 | +# Author : Tom Dupre la Tour <tom.dupre-la-tour@m4x.org> |
| 5 | +# License: BSD 3 clause |
4 | 6 |
|
5 | 7 | from __future__ import print_function
|
6 |
| - |
7 |
| -from collections import defaultdict |
8 |
| -import gc |
9 | 8 | from time import time
|
| 9 | +import sys |
10 | 10 |
|
11 | 11 | import six
|
12 | 12 |
|
13 | 13 | import numpy as np
|
14 |
| -from scipy.linalg import norm |
15 |
| - |
16 |
| -from sklearn.decomposition.nmf import NMF, _initialize_nmf |
17 |
| -from sklearn.datasets.samples_generator import make_low_rank_matrix |
18 |
| -from sklearn.externals.six.moves import xrange |
19 |
| - |
20 |
| - |
21 |
| -def alt_nnmf(V, r, max_iter=1000, tol=1e-3, init='random'): |
22 |
| - """ |
23 |
| - A, S = nnmf(X, r, tol=1e-3, R=None) |
24 |
| -
|
25 |
| - Implement Lee & Seung's algorithm |
26 |
| -
|
27 |
| - Parameters |
28 |
| - ---------- |
29 |
| - V : 2-ndarray, [n_samples, n_features] |
30 |
| - input matrix |
31 |
| - r : integer |
32 |
| - number of latent features |
33 |
| - max_iter : integer, optional |
34 |
| - maximum number of iterations (default: 1000) |
35 |
| - tol : double |
36 |
| - tolerance threshold for early exit (when the update factor is within |
37 |
| - tol of 1., the function exits) |
38 |
| - init : string |
39 |
| - Method used to initialize the procedure. |
40 |
| -
|
41 |
| - Returns |
42 |
| - ------- |
43 |
| - A : 2-ndarray, [n_samples, r] |
44 |
| - Component part of the factorization |
45 |
| -
|
46 |
| - S : 2-ndarray, [r, n_features] |
47 |
| - Data part of the factorization |
48 |
| - Reference |
49 |
| - --------- |
50 |
| - "Algorithms for Non-negative Matrix Factorization" |
51 |
| - by Daniel D Lee, Sebastian H Seung |
52 |
| - (available at http://citeseer.ist.psu.edu/lee01algorithms.html) |
53 |
| - """ |
54 |
| - # Nomenclature in the function follows Lee & Seung |
55 |
| - eps = 1e-5 |
56 |
| - n, m = V.shape |
57 |
| - W, H = _initialize_nmf(V, r, init, random_state=0) |
58 |
| - |
59 |
| - for i in xrange(max_iter): |
60 |
| - updateH = np.dot(W.T, V) / (np.dot(np.dot(W.T, W), H) + eps) |
61 |
| - H *= updateH |
62 |
| - updateW = np.dot(V, H.T) / (np.dot(W, np.dot(H, H.T)) + eps) |
63 |
| - W *= updateW |
64 |
| - if i % 10 == 0: |
65 |
| - max_update = max(updateW.max(), updateH.max()) |
66 |
| - if abs(1. - max_update) < tol: |
67 |
| - break |
68 |
| - return W, H |
69 |
| - |
70 |
| - |
71 |
| -def report(error, time): |
72 |
| - print("Frobenius loss: %.5f" % error) |
73 |
| - print("Took: %.2fs" % time) |
74 |
| - print() |
75 |
| - |
76 |
| - |
77 |
| -def benchmark(samples_range, features_range, rank=50, tolerance=1e-5): |
78 |
| - timeset = defaultdict(lambda: []) |
79 |
| - err = defaultdict(lambda: []) |
80 |
| - |
81 |
| - for n_samples in samples_range: |
82 |
| - for n_features in features_range: |
83 |
| - print("%2d samples, %2d features" % (n_samples, n_features)) |
84 |
| - print('=======================') |
85 |
| - X = np.abs(make_low_rank_matrix(n_samples, n_features, |
86 |
| - effective_rank=rank, tail_strength=0.2)) |
87 |
| - |
88 |
| - gc.collect() |
89 |
| - print("benchmarking nndsvd-nmf: ") |
90 |
| - tstart = time() |
91 |
| - m = NMF(n_components=30, tol=tolerance, init='nndsvd').fit(X) |
92 |
| - tend = time() - tstart |
93 |
| - timeset['nndsvd-nmf'].append(tend) |
94 |
| - err['nndsvd-nmf'].append(m.reconstruction_err_) |
95 |
| - report(m.reconstruction_err_, tend) |
96 |
| - |
97 |
| - gc.collect() |
98 |
| - print("benchmarking nndsvda-nmf: ") |
99 |
| - tstart = time() |
100 |
| - m = NMF(n_components=30, init='nndsvda', |
101 |
| - tol=tolerance).fit(X) |
102 |
| - tend = time() - tstart |
103 |
| - timeset['nndsvda-nmf'].append(tend) |
104 |
| - err['nndsvda-nmf'].append(m.reconstruction_err_) |
105 |
| - report(m.reconstruction_err_, tend) |
106 |
| - |
107 |
| - gc.collect() |
108 |
| - print("benchmarking nndsvdar-nmf: ") |
109 |
| - tstart = time() |
110 |
| - m = NMF(n_components=30, init='nndsvdar', |
111 |
| - tol=tolerance).fit(X) |
112 |
| - tend = time() - tstart |
113 |
| - timeset['nndsvdar-nmf'].append(tend) |
114 |
| - err['nndsvdar-nmf'].append(m.reconstruction_err_) |
115 |
| - report(m.reconstruction_err_, tend) |
116 |
| - |
117 |
| - gc.collect() |
118 |
| - print("benchmarking random-nmf") |
119 |
| - tstart = time() |
120 |
| - m = NMF(n_components=30, init='random', max_iter=1000, |
121 |
| - tol=tolerance).fit(X) |
122 |
| - tend = time() - tstart |
123 |
| - timeset['random-nmf'].append(tend) |
124 |
| - err['random-nmf'].append(m.reconstruction_err_) |
125 |
| - report(m.reconstruction_err_, tend) |
126 |
| - |
127 |
| - gc.collect() |
128 |
| - print("benchmarking alt-random-nmf") |
129 |
| - tstart = time() |
130 |
| - W, H = alt_nnmf(X, r=30, init='random', tol=tolerance) |
131 |
| - tend = time() - tstart |
132 |
| - timeset['alt-random-nmf'].append(tend) |
133 |
| - err['alt-random-nmf'].append(np.linalg.norm(X - np.dot(W, H))) |
134 |
| - report(norm(X - np.dot(W, H)), tend) |
135 |
| - |
136 |
| - return timeset, err |
| 14 | +import matplotlib.pyplot as plt |
| 15 | +import pandas |
| 16 | + |
| 17 | +from sklearn.utils.testing import ignore_warnings |
| 18 | +from sklearn.feature_extraction.text import TfidfVectorizer |
| 19 | +from sklearn.decomposition.nmf import NMF |
| 20 | +from sklearn.decomposition.nmf import _initialize_nmf |
| 21 | +from sklearn.decomposition.nmf import _beta_divergence |
| 22 | +from sklearn.externals.joblib import Memory |
| 23 | +from sklearn.exceptions import ConvergenceWarning |
| 24 | + |
| 25 | +mem = Memory(cachedir='.', verbose=0) |
| 26 | + |
| 27 | + |
| 28 | +def plot_results(results_df, plot_name): |
| 29 | + if results_df is None: |
| 30 | + return None |
| 31 | + |
| 32 | + plt.figure(figsize=(16, 6)) |
| 33 | + colors = 'bgr' |
| 34 | + markers = 'ovs' |
| 35 | + ax = plt.subplot(1, 3, 1) |
| 36 | + for i, init in enumerate(np.unique(results_df['init'])): |
| 37 | + plt.subplot(1, 3, i + 1, sharex=ax, sharey=ax) |
| 38 | + for j, method in enumerate(np.unique(results_df['method'])): |
| 39 | + mask = np.logical_and(results_df['init'] == init, |
| 40 | + results_df['method'] == method) |
| 41 | + selected_items = results_df[mask] |
| 42 | + |
| 43 | + plt.plot(selected_items['time'], selected_items['loss'], |
| 44 | + color=colors[j % len(colors)], ls='-', |
| 45 | + marker=markers[j % len(markers)], |
| 46 | + label=method) |
| 47 | + |
| 48 | + plt.legend(loc=0, fontsize='x-small') |
| 49 | + plt.xlabel("Time (s)") |
| 50 | + plt.ylabel("loss") |
| 51 | + plt.title("%s" % init) |
| 52 | + plt.suptitle(plot_name, fontsize=16) |
| 53 | + |
| 54 | + |
| 55 | +# The deprecated projected-gradient solver raises a UserWarning as convergence |
| 56 | +# is not reached; the coordinate-descent solver raises a ConvergenceWarning. |
| 57 | +@ignore_warnings(category=(ConvergenceWarning, UserWarning, |
| 58 | + DeprecationWarning)) |
| 59 | +# use joblib to cache the results. |
| 60 | +# X_shape is specified in arguments for avoiding hashing X |
| 61 | +@mem.cache(ignore=['X', 'W0', 'H0']) |
| 62 | +def bench_one(name, X, W0, H0, X_shape, clf_type, clf_params, init, |
| 63 | + n_components, random_state): |
| 64 | + W = W0.copy() |
| 65 | + H = H0.copy() |
| 66 | + |
| 67 | + clf = clf_type(**clf_params) |
| 68 | + st = time() |
| 69 | + W = clf.fit_transform(X, W=W, H=H) |
| 70 | + end = time() |
| 71 | + H = clf.components_ |
| 72 | + |
| 73 | + this_loss = _beta_divergence(X, W, H, 2.0, True) |
| 74 | + duration = end - st |
| 75 | + return this_loss, duration |
| 76 | + |
| 77 | + |
| 78 | +def run_bench(X, clfs, plot_name, n_components, tol, alpha, l1_ratio): |
| 79 | + start = time() |
| 80 | + results = [] |
| 81 | + for name, clf_type, iter_range, clf_params in clfs: |
| 82 | + print("Training %s:" % name) |
| 83 | + for rs, init in enumerate(('nndsvd', 'nndsvdar', 'random')): |
| 84 | + print(" %s %s: " % (init, " " * (8 - len(init))), end="") |
| 85 | + W, H = _initialize_nmf(X, n_components, init, 1e-6, rs) |
| 86 | + |
| 87 | + for max_iter in iter_range: |
| 88 | + clf_params['alpha'] = alpha |
| 89 | + clf_params['l1_ratio'] = l1_ratio |
| 90 | + clf_params['max_iter'] = max_iter |
| 91 | + clf_params['tol'] = tol |
| 92 | + clf_params['random_state'] = rs |
| 93 | + clf_params['init'] = 'custom' |
| 94 | + clf_params['n_components'] = n_components |
| 95 | + |
| 96 | + this_loss, duration = bench_one(name, X, W, H, X.shape, |
| 97 | + clf_type, clf_params, |
| 98 | + init, n_components, rs) |
| 99 | + |
| 100 | + init_name = "init='%s'" % init |
| 101 | + results.append((name, this_loss, duration, init_name)) |
| 102 | + # print("loss: %.6f, time: %.3f sec" % (this_loss, duration)) |
| 103 | + print(".", end="") |
| 104 | + sys.stdout.flush() |
| 105 | + print(" ") |
| 106 | + |
| 107 | + # Use a panda dataframe to organize the results |
| 108 | + results_df = pandas.DataFrame(results, |
| 109 | + columns="method loss time init".split()) |
| 110 | + print("Total time = %0.3f sec\n" % (time() - start)) |
| 111 | + |
| 112 | + # plot the results |
| 113 | + plot_results(results_df, plot_name) |
| 114 | + return results_df |
| 115 | + |
| 116 | + |
| 117 | +def load_20news(): |
| 118 | + print("Loading 20 newsgroups dataset") |
| 119 | + print("-----------------------------") |
| 120 | + from sklearn.datasets import fetch_20newsgroups |
| 121 | + dataset = fetch_20newsgroups(shuffle=True, random_state=1, |
| 122 | + remove=('headers', 'footers', 'quotes')) |
| 123 | + vectorizer = TfidfVectorizer(max_df=0.95, min_df=2, stop_words='english') |
| 124 | + tfidf = vectorizer.fit_transform(dataset.data) |
| 125 | + return tfidf |
| 126 | + |
| 127 | + |
| 128 | +def load_faces(): |
| 129 | + print("Loading Olivetti face dataset") |
| 130 | + print("-----------------------------") |
| 131 | + from sklearn.datasets import fetch_olivetti_faces |
| 132 | + faces = fetch_olivetti_faces(shuffle=True) |
| 133 | + return faces.data |
| 134 | + |
| 135 | + |
| 136 | +def build_clfs(cd_iters, mu_iters): |
| 137 | + clfs = [("Coordinate Descent", NMF, cd_iters, {'solver': 'cd'}), |
| 138 | + ("Multiplicative Update", NMF, mu_iters, {'solver': 'mu'}), |
| 139 | + ] |
| 140 | + return clfs |
137 | 141 |
|
138 | 142 |
|
139 | 143 | if __name__ == '__main__':
|
140 |
| - from mpl_toolkits.mplot3d import axes3d # register the 3d projection |
141 |
| - axes3d |
142 |
| - import matplotlib.pyplot as plt |
143 |
| - |
144 |
| - samples_range = np.linspace(50, 500, 3).astype(np.int) |
145 |
| - features_range = np.linspace(50, 500, 3).astype(np.int) |
146 |
| - timeset, err = benchmark(samples_range, features_range) |
147 |
| - |
148 |
| - for i, results in enumerate((timeset, err)): |
149 |
| - fig = plt.figure('scikit-learn Non-Negative Matrix Factorization' |
150 |
| - 'benchmark results') |
151 |
| - ax = fig.gca(projection='3d') |
152 |
| - for c, (label, timings) in zip('rbgcm', sorted(six.iteritems(results))): |
153 |
| - X, Y = np.meshgrid(samples_range, features_range) |
154 |
| - Z = np.asarray(timings).reshape(samples_range.shape[0], |
155 |
| - features_range.shape[0]) |
156 |
| - # plot the actual surface |
157 |
| - ax.plot_surface(X, Y, Z, rstride=8, cstride=8, alpha=0.3, |
158 |
| - color=c) |
159 |
| - # dummy point plot to stick the legend to since surface plot do not |
160 |
| - # support legends (yet?) |
161 |
| - ax.plot([1], [1], [1], color=c, label=label) |
162 |
| - |
163 |
| - ax.set_xlabel('n_samples') |
164 |
| - ax.set_ylabel('n_features') |
165 |
| - zlabel = 'Time (s)' if i == 0 else 'reconstruction error' |
166 |
| - ax.set_zlabel(zlabel) |
167 |
| - ax.legend() |
168 |
| - plt.show() |
| 144 | + alpha = 0. |
| 145 | + l1_ratio = 0.5 |
| 146 | + n_components = 10 |
| 147 | + tol = 1e-15 |
| 148 | + |
| 149 | + # first benchmark on 20 newsgroup dataset: sparse, shape(11314, 39116) |
| 150 | + plot_name = "20 Newsgroups sparse dataset" |
| 151 | + cd_iters = np.arange(1, 30) |
| 152 | + mu_iters = np.arange(1, 30) |
| 153 | + clfs = build_clfs(cd_iters, mu_iters) |
| 154 | + X_20news = load_20news() |
| 155 | + run_bench(X_20news, clfs, plot_name, n_components, tol, alpha, l1_ratio) |
| 156 | + |
| 157 | + # second benchmark on Olivetti faces dataset: dense, shape(400, 4096) |
| 158 | + plot_name = "Olivetti Faces dense dataset" |
| 159 | + cd_iters = np.arange(1, 30) |
| 160 | + mu_iters = np.arange(1, 30) |
| 161 | + clfs = build_clfs(cd_iters, mu_iters) |
| 162 | + X_faces = load_faces() |
| 163 | + run_bench(X_faces, clfs, plot_name, n_components, tol, alpha, l1_ratio,) |
| 164 | + |
| 165 | + plt.show() |
0 commit comments