Skip to content

Commit d9d65a6

Browse files
committed
ENH improve benchmark on nmf
1 parent 726c8d9 commit d9d65a6

File tree

1 file changed

+152
-155
lines changed

1 file changed

+152
-155
lines changed

benchmarks/bench_plot_nmf.py

Lines changed: 152 additions & 155 deletions
Original file line numberDiff line numberDiff line change
@@ -1,168 +1,165 @@
11
"""
22
Benchmarks of Non-Negative Matrix Factorization
33
"""
4+
# Author : Tom Dupre la Tour <tom.dupre-la-tour@m4x.org>
5+
# License: BSD 3 clause
46

57
from __future__ import print_function
6-
7-
from collections import defaultdict
8-
import gc
98
from time import time
9+
import sys
1010

1111
import six
1212

1313
import numpy as np
14-
from scipy.linalg import norm
15-
16-
from sklearn.decomposition.nmf import NMF, _initialize_nmf
17-
from sklearn.datasets.samples_generator import make_low_rank_matrix
18-
from sklearn.externals.six.moves import xrange
19-
20-
21-
def alt_nnmf(V, r, max_iter=1000, tol=1e-3, init='random'):
    """Factorize ``V`` into ``W`` and ``H`` with multiplicative updates.

    Implement Lee & Seung's algorithm.

    Parameters
    ----------
    V : 2-ndarray, [n_samples, n_features]
        input matrix
    r : integer
        number of latent features
    max_iter : integer, optional
        maximum number of iterations (default: 1000)
    tol : double
        tolerance threshold for early exit (when the largest update factor
        is within tol of 1., the function exits); the test is only evaluated
        every 10th iteration
    init : string
        Method used to initialize the procedure (forwarded to
        ``_initialize_nmf``, always with ``random_state=0``).

    Returns
    -------
    W : 2-ndarray, [n_samples, r]
        Left factor of the factorization

    H : 2-ndarray, [r, n_features]
        Right factor of the factorization

    Reference
    ---------
    "Algorithms for Non-negative Matrix Factorization"
    by Daniel D Lee, Sebastian H Seung
    (available at http://citeseer.ist.psu.edu/lee01algorithms.html)
    """
    # Nomenclature in the function follows Lee & Seung
    eps = 1e-5  # added to every denominator to guard against division by 0
    W, H = _initialize_nmf(V, r, init, random_state=0)

    for i in xrange(max_iter):
        updateH = np.dot(W.T, V) / (np.dot(np.dot(W.T, W), H) + eps)
        H *= updateH
        updateW = np.dot(V, H.T) / (np.dot(W, np.dot(H, H.T)) + eps)
        W *= updateW
        # Only look at the update factors every 10th pass.
        if i % 10 == 0:
            max_update = max(updateW.max(), updateH.max())
            if abs(1. - max_update) < tol:
                break
    return W, H
69-
70-
71-
def report(error, time):
    """Print a loss value and a duration, followed by an empty line.

    Parameters
    ----------
    error : float
        Loss to display, formatted with five decimals.
    time : float
        Elapsed time in seconds, formatted with two decimals.
    """
    summary = [
        "Frobenius loss: %.5f" % error,
        "Took: %.2fs" % time,
        "",  # trailing blank line separating consecutive reports
    ]
    print("\n".join(summary))
75-
76-
77-
def benchmark(samples_range, features_range, rank=50, tolerance=1e-5,
              n_components=30):
    """Time several NMF variants on synthetic low-rank matrices.

    For every (n_samples, n_features) pair a non-negative matrix is built
    from ``make_low_rank_matrix`` and factorized with ``NMF`` under four
    initializations, then with ``alt_nnmf``.

    Parameters
    ----------
    samples_range : iterable of int
        Numbers of rows to try (outer loop).
    features_range : iterable of int
        Numbers of columns to try (inner loop).
    rank : int, optional
        ``effective_rank`` handed to ``make_low_rank_matrix`` (default: 50).
    tolerance : float, optional
        ``tol`` handed to every solver (default: 1e-5).
    n_components : int, optional
        Number of components requested from every solver (default: 30).

    Returns
    -------
    timeset : defaultdict(list)
        Maps a method label to its fit durations in seconds, one entry per
        (n_samples, n_features) pair, in sample-major order.
    err : defaultdict(list)
        Maps a method label to its reconstruction errors, same ordering.
    """
    timeset = defaultdict(list)
    err = defaultdict(list)

    # (result label, banner printed before the run, extra NMF keywords).
    # Only plain data lives here; the estimator is built inside the loop.
    nmf_runs = (
        ('nndsvd-nmf', "benchmarking nndsvd-nmf: ", {'init': 'nndsvd'}),
        ('nndsvda-nmf', "benchmarking nndsvda-nmf: ", {'init': 'nndsvda'}),
        ('nndsvdar-nmf', "benchmarking nndsvdar-nmf: ",
         {'init': 'nndsvdar'}),
        ('random-nmf', "benchmarking random-nmf",
         {'init': 'random', 'max_iter': 1000}),
    )

    for n_samples in samples_range:
        for n_features in features_range:
            print("%2d samples, %2d features" % (n_samples, n_features))
            print('=======================')
            X = np.abs(make_low_rank_matrix(n_samples, n_features,
                                            effective_rank=rank,
                                            tail_strength=0.2))

            for label, banner, extra_params in nmf_runs:
                # Collect garbage up front so it does not land in the timing.
                gc.collect()
                print(banner)
                tstart = time()
                m = NMF(n_components=n_components, tol=tolerance,
                        **extra_params).fit(X)
                tend = time() - tstart
                timeset[label].append(tend)
                err[label].append(m.reconstruction_err_)
                report(m.reconstruction_err_, tend)

            gc.collect()
            print("benchmarking alt-random-nmf")
            tstart = time()
            W, H = alt_nnmf(X, r=n_components, init='random', tol=tolerance)
            tend = time() - tstart
            # Compute the residual once so the stored and the printed value
            # are guaranteed to be the same number.
            residual = np.linalg.norm(X - np.dot(W, H))
            timeset['alt-random-nmf'].append(tend)
            err['alt-random-nmf'].append(residual)
            report(residual, tend)

    return timeset, err
14+
import matplotlib.pyplot as plt
15+
import pandas
16+
17+
from sklearn.utils.testing import ignore_warnings
18+
from sklearn.feature_extraction.text import TfidfVectorizer
19+
from sklearn.decomposition.nmf import NMF
20+
from sklearn.decomposition.nmf import _initialize_nmf
21+
from sklearn.decomposition.nmf import _beta_divergence
22+
from sklearn.externals.joblib import Memory
23+
from sklearn.exceptions import ConvergenceWarning
24+
25+
# joblib cache stored under the current working directory; it memoizes the
# individual benchmark runs decorated with ``mem.cache``.
mem = Memory(cachedir='.', verbose=0)
26+
27+
28+
def plot_results(results_df, plot_name):
    """Plot loss against time, one panel per init and one line per method.

    Parameters
    ----------
    results_df : pandas.DataFrame or None
        Table with the columns ``'init'``, ``'method'``, ``'time'`` and
        ``'loss'``. When None, nothing is drawn.
    plot_name : str
        Title placed above the whole figure.

    Returns
    -------
    None
    """
    if results_df is None:
        return None

    plt.figure(figsize=(16, 6))
    colors = 'bgr'
    markers = 'ovs'

    inits = np.unique(results_df['init'])
    methods = np.unique(results_df['method'])
    # One column per distinct init instead of a fixed 3, so that any number
    # of initializations fits; keep one panel for an empty table.
    n_panels = max(len(inits), 1)

    ax = plt.subplot(1, n_panels, 1)
    for i, init in enumerate(inits):
        # The first panel already exists (and is current); re-creating it
        # would make it share axes with the very subplot it replaces.
        if i > 0:
            plt.subplot(1, n_panels, i + 1, sharex=ax, sharey=ax)
        for j, method in enumerate(methods):
            mask = np.logical_and(results_df['init'] == init,
                                  results_df['method'] == method)
            selected_items = results_df[mask]

            # Cycle through the palettes if there are more methods than
            # colors/markers.
            plt.plot(selected_items['time'], selected_items['loss'],
                     color=colors[j % len(colors)], ls='-',
                     marker=markers[j % len(markers)],
                     label=method)

        plt.legend(loc=0, fontsize='x-small')
        plt.xlabel("Time (s)")
        plt.ylabel("loss")
        plt.title("%s" % init)
    plt.suptitle(plot_name, fontsize=16)
53+
54+
55+
# Convergence is usually not reached within the few iterations allowed here:
# the deprecated projected-gradient solver then raises a UserWarning and the
# coordinate-descent solver a ConvergenceWarning. Silence them all.
@ignore_warnings(category=(ConvergenceWarning, UserWarning,
                           DeprecationWarning))
# Results are cached with joblib. The arrays X, W0 and H0 are excluded from
# the cache key; X_shape is passed separately so that X need not be hashed.
@mem.cache(ignore=['X', 'W0', 'H0'])
def bench_one(name, X, W0, H0, X_shape, clf_type, clf_params, init,
              n_components, random_state):
    """Fit a single estimator from given factors and measure it.

    Parameters
    ----------
    name, X_shape, init, n_components, random_state
        Not read by the body; presumably present so that they take part in
        the joblib cache key.
    X : array-like
        Data passed to ``fit_transform``.
    W0, H0 : arrays
        Initial factors; copies are handed to the estimator so the
        originals stay untouched.
    clf_type : callable
        Estimator class, instantiated as ``clf_type(**clf_params)``.
    clf_params : dict
        Keyword arguments for ``clf_type``.

    Returns
    -------
    loss : float
        ``_beta_divergence(X, W, H, 2.0, True)`` for the fitted factors.
    seconds : float
        Wall-clock duration of the ``fit_transform`` call alone.
    """
    estimator = clf_type(**clf_params)
    W_start, H_start = W0.copy(), H0.copy()

    # Only the fit itself is timed; setup and loss evaluation are excluded.
    tic = time()
    W_fitted = estimator.fit_transform(X, W=W_start, H=H_start)
    toc = time()

    loss = _beta_divergence(X, W_fitted, estimator.components_, 2.0, True)
    return loss, toc - tic
76+
77+
78+
def run_bench(X, clfs, plot_name, n_components, tol, alpha, l1_ratio):
    """Benchmark every solver in ``clfs`` on ``X`` and plot the outcome.

    Each solver is run from three initializations ('nndsvd', 'nndsvdar',
    'random') and, for each of them, once per iteration budget.

    Parameters
    ----------
    X : array-like or sparse matrix
        Data to factorize.
    clfs : iterable of (name, clf_type, iter_range, clf_params)
        Solvers to compare; ``clf_params`` holds solver-specific keyword
        arguments and is left unmodified.
    plot_name : str
        Title of the resulting figure.
    n_components : int
        Number of components of the factorization.
    tol : float
        Tolerance given to every solver.
    alpha : float
        ``alpha`` given to every solver.
    l1_ratio : float
        ``l1_ratio`` given to every solver.

    Returns
    -------
    results_df : pandas.DataFrame
        One row per run, with the columns ``method``, ``loss``, ``time``
        and ``init``.
    """
    start = time()
    results = []
    for name, clf_type, iter_range, clf_params in clfs:
        print("Training %s:" % name)
        for rs, init in enumerate(('nndsvd', 'nndsvdar', 'random')):
            print(" %s %s: " % (init, " " * (8 - len(init))), end="")
            # The factors are computed once per init and reused for every
            # iteration budget; the position of the init doubles as seed.
            W, H = _initialize_nmf(X, n_components, init, 1e-6, rs)
            init_name = "init='%s'" % init

            for max_iter in iter_range:
                # Build a fresh dict rather than writing the settings into
                # the caller's ``clf_params``.
                params = dict(clf_params,
                              alpha=alpha,
                              l1_ratio=l1_ratio,
                              max_iter=max_iter,
                              tol=tol,
                              random_state=rs,
                              init='custom',
                              n_components=n_components)

                this_loss, duration = bench_one(name, X, W, H, X.shape,
                                                clf_type, params,
                                                init, n_components, rs)

                results.append((name, this_loss, duration, init_name))
                # One dot per finished run, shown immediately.
                print(".", end="")
                sys.stdout.flush()
            print(" ")

    # Use a pandas DataFrame to organize the results
    results_df = pandas.DataFrame(results,
                                  columns="method loss time init".split())
    print("Total time = %0.3f sec\n" % (time() - start))

    # plot the results
    plot_results(results_df, plot_name)
    return results_df
115+
116+
117+
def load_20news():
    """Load the 20 newsgroups corpus and return its TF-IDF matrix.

    The posts are fetched shuffled (``random_state=1``) with headers,
    footers and quotes removed, then vectorized with ``TfidfVectorizer``.

    Returns
    -------
    tfidf
        Result of ``TfidfVectorizer.fit_transform`` on the raw documents.
    """
    for banner_line in ("Loading 20 newsgroups dataset",
                        "-----------------------------"):
        print(banner_line)

    from sklearn.datasets import fetch_20newsgroups
    corpus = fetch_20newsgroups(
        shuffle=True,
        random_state=1,
        remove=('headers', 'footers', 'quotes'),
    )
    tfidf_builder = TfidfVectorizer(max_df=0.95, min_df=2,
                                    stop_words='english')
    return tfidf_builder.fit_transform(corpus.data)
126+
127+
128+
def load_faces():
    """Load the Olivetti faces dataset and return its ``data`` attribute.

    The dataset is fetched with ``shuffle=True``.
    """
    for banner_line in ("Loading Olivetti face dataset",
                        "-----------------------------"):
        print(banner_line)

    from sklearn.datasets import fetch_olivetti_faces
    return fetch_olivetti_faces(shuffle=True).data
134+
135+
136+
def build_clfs(cd_iters, mu_iters):
    """Describe the solvers to benchmark.

    Parameters
    ----------
    cd_iters : iterable of int
        Iteration budgets for the coordinate-descent solver.
    mu_iters : iterable of int
        Iteration budgets for the multiplicative-update solver.

    Returns
    -------
    list of (name, estimator class, iteration budgets, params) tuples
        One entry per solver, in the order: coordinate descent,
        multiplicative update.
    """
    solver_specs = (
        ("Coordinate Descent", cd_iters, 'cd'),
        ("Multiplicative Update", mu_iters, 'mu'),
    )
    return [(label, NMF, iters, {'solver': solver})
            for label, iters, solver in solver_specs]
137141

138142

139143
if __name__ == '__main__':
    # ------------------------------------------------------------------
    # Synthetic low-rank benchmark, rendered as 3D surfaces
    # ------------------------------------------------------------------
    from mpl_toolkits.mplot3d import axes3d  # register the 3d projection
    axes3d
    import matplotlib.pyplot as plt

    # Plain ``int`` is what the ``np.int`` alias stood for; the alias is
    # gone from current NumPy releases.
    samples_range = np.linspace(50, 500, 3).astype(int)
    features_range = np.linspace(50, 500, 3).astype(int)
    timeset, err = benchmark(samples_range, features_range)

    # ``benchmark`` appends its results sample-major (outer loop over
    # samples), so the grid needs matrix ('ij') indexing to line up with Z.
    # It does not depend on the method, hence it is built once.
    X, Y = np.meshgrid(samples_range, features_range, indexing='ij')

    for i, results in enumerate((timeset, err)):
        fig = plt.figure('scikit-learn Non-Negative Matrix Factorization '
                         'benchmark results')
        ax = fig.gca(projection='3d')
        for c, (label, timings) in zip('rbgcm',
                                       sorted(six.iteritems(results))):
            Z = np.asarray(timings).reshape(samples_range.shape[0],
                                            features_range.shape[0])
            # plot the actual surface
            ax.plot_surface(X, Y, Z, rstride=8, cstride=8, alpha=0.3,
                            color=c)
            # dummy point plot to stick the legend to since surface plot do
            # not support legends (yet?)
            ax.plot([1], [1], [1], color=c, label=label)

        ax.set_xlabel('n_samples')
        ax.set_ylabel('n_features')
        zlabel = 'Time (s)' if i == 0 else 'reconstruction error'
        ax.set_zlabel(zlabel)
        ax.legend()
        plt.show()

    # ------------------------------------------------------------------
    # Solver comparison on real datasets
    # ------------------------------------------------------------------
    alpha = 0.
    l1_ratio = 0.5
    n_components = 10
    tol = 1e-15

    # first benchmark on 20 newsgroup dataset: sparse, shape(11314, 39116)
    plot_name = "20 Newsgroups sparse dataset"
    cd_iters = np.arange(1, 30)
    mu_iters = np.arange(1, 30)
    clfs = build_clfs(cd_iters, mu_iters)
    X_20news = load_20news()
    run_bench(X_20news, clfs, plot_name, n_components, tol, alpha, l1_ratio)

    # second benchmark on Olivetti faces dataset: dense, shape(400, 4096)
    plot_name = "Olivetti Faces dense dataset"
    cd_iters = np.arange(1, 30)
    mu_iters = np.arange(1, 30)
    clfs = build_clfs(cd_iters, mu_iters)
    X_faces = load_faces()
    run_bench(X_faces, clfs, plot_name, n_components, tol, alpha, l1_ratio)

    plt.show()

0 commit comments

Comments
 (0)