Draft implementation of MiniBatchNMF #13326

Status: Closed (wants to merge 6 commits)
sklearn/decomposition/benchmark_nmf2.py (115 additions, 0 deletions)
from time import time

from scipy import sparse
import pandas as pd

from sklearn.decomposition.nmf import _beta_divergence
from sklearn.feature_extraction.text import CountVectorizer, HashingVectorizer

# Local modules from this draft: `NMF` is the mini-batch variant benchmarked
# below (it takes a batch_size and exposes partial_fit); `NMFOriginal` is the
# unmodified reference implementation.
from nmf import NMF
from nmf_original import NMFOriginal

import matplotlib.pyplot as plt
from dirty_cat.datasets import fetch_traffic_violations

dataset = 'traffic_violations'

try:
    X = sparse.load_npz('X.npz')
except FileNotFoundError:
    if dataset == 'wiki':
        df = pd.read_csv('/home/pcerda/parietal/online_nmf/scikit-learn/' +
                         'enwiki_1000000_first_paragraphs.csv')
        cats = df['0'].astype(str)
        counter = HashingVectorizer(analyzer='word', ngram_range=(1, 1),
                                    n_features=2**12, norm=None,
                                    alternate_sign=False)
    elif dataset == 'traffic_violations':
        data = fetch_traffic_violations()
        df = pd.read_csv(data['path'])
        cats = df['Model'].astype(str).values
        counter = CountVectorizer(analyzer='char', ngram_range=(3, 3))
    X = counter.fit_transform(cats)
    # sparse.save_npz('X.npz', X)

n_test = 10000
n_train = 50000

X_test = X[:n_test, :]
X = X[n_test:n_train + n_test, :]

n_components = 10

print(X.shape)

time_nmf = []
kl_nmf = []
time_nmf2 = []
kl_nmf2 = []

fig, ax = plt.subplots()
# plt.yscale('log')
fontsize = 16
beta_loss = 'kullback-leibler'

max_iter_nmf = [1, 5, 10, 30, 50, 100]
max_iter_minibatch_nmf = [1, 5, 10, 20, 30, 40]

nmf2 = NMF(
    n_components=n_components, beta_loss=beta_loss, batch_size=1000,
    solver='mu', max_iter=1, random_state=10, tol=0)

for i, max_iter in enumerate(zip(max_iter_nmf, max_iter_minibatch_nmf)):
    nmf = NMFOriginal(n_components=n_components, beta_loss=beta_loss,
                      solver='mu', max_iter=max_iter[0], random_state=10,
                      tol=0)
    t0 = time()
    nmf.fit(X)
    W = nmf.transform(X_test)
    tf = time() - t0
    time_nmf.append(tf)
    print('Time NMF: %.1fs.' % tf)
    kldiv = _beta_divergence(X_test, W, nmf.components_,
                             nmf.beta_loss) / X_test.shape[0]
    kl_nmf.append(kldiv)
    print('KL-div NMF: %.2f' % kldiv)
    del W

    t0 = time()
    # nmf2 = NMF(
    #     n_components=n_components, beta_loss=beta_loss, batch_size=1000,
    #     solver='mu', max_iter=max_iter[1], random_state=10, tol=0)
    nmf2.partial_fit(X)
    W = nmf2.transform(X_test)
    tf = time() - t0
    time_nmf2.append(tf)
    print('Time MiniBatchNMF: %.1fs.' % tf)
    kldiv = _beta_divergence(X_test, W, nmf2.components_,
                             nmf2.beta_loss) / X_test.shape[0]
    kl_nmf2.append(kldiv)
    print('KL-div MiniBatchNMF: %.2f' % kldiv)
    del W

    if i > 0:
        plt.plot(time_nmf, kl_nmf, 'r', marker='o')
        plt.plot(time_nmf2, kl_nmf2, 'b', marker='o')
        plt.pause(.01)
    if i == 1:
        plt.legend(labels=['NMF', 'Online NMF'], fontsize=fontsize)


plt.tick_params(axis='both', which='major', labelsize=fontsize-2)
plt.xlabel('Time (seconds)', fontsize=fontsize)
plt.ylabel(beta_loss, fontsize=fontsize)

if dataset == 'traffic_violations':
    title = 'Traffic Violations; Column: Model'
elif dataset == 'wiki':
    title = 'Wikipedia articles (first paragraph)'
ax.set_title(title, fontsize=fontsize+4)

figname = 'benchmark_nmf_%s.pdf' % dataset
print('Saving: ' + figname)
plt.savefig(figname,
            transparent=False, bbox_inches='tight', pad_inches=0)
plt.show()
sklearn/decomposition/minibatch_nmf.py (280 additions, 0 deletions)
import numpy as np
from scipy import sparse

from sklearn.utils import check_random_state
from sklearn.utils.extmath import row_norms, safe_sparse_dot, randomized_svd
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils import gen_batches
# from sklearn.utils import check_array

from sklearn.cluster.k_means_ import _k_init
from sklearn.decomposition.nmf import _special_sparse_dot
from sklearn.decomposition.nmf import norm
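
# Outline of the draft (a reading guide): each mini-batch Vt is handled by an
# e-step, multiplicative updates of the activations Ht for a fixed dictionary
# W, followed by an m-step, an update of W from discounted sufficient
# statistics A and B, in the spirit of online matrix-factorization methods.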


class MiniBatchNMF(BaseEstimator, TransformerMixin):
"""
Mini batch non-negative matrix factorization by minimizing the
Kullback-Leibler divergence.

Parameters
----------

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Isn't an empty line here against convention ?

n_components: int, default=10
Number of topics of the matrix Factorization.

batch_size: int, default=100

r: float, default=1
Weight parameter for the update of the W matrix

tol: float, default=1E-3
Tolerance for the convergence of the matrix W

mix_iter: int, default=2

max_iter: int, default=10

ngram_range: tuple, default=(2, 4)

init: str, default 'k-means++'
Initialization method of the W matrix.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

perhaps add 2-3 words on what the W matrix is in the NMF formulation


random_state: default=None

Attributes
----------

References
----------
"""

    def __init__(self, n_components=10, batch_size=512,
                 r=.001, init='k-means++',
                 tol=1E-4, min_iter=2, max_iter=5, ngram_range=(2, 4),
                 # -- Review (Contributor): tol does not match the default
                 #    mentioned in the docstring.
                 add_words=False, random_state=None,
                 rescale_W=True, max_iter_e_step=20):

        self.n_components = n_components
        self.r = r
        self.batch_size = batch_size
        self.tol = tol
        self.max_iter = max_iter
        self.min_iter = min_iter
        self.init = init
        self.ngram_range = ngram_range  # stored for API consistency; unused here
        self.add_words = add_words
        self.random_state = check_random_state(random_state)
        self.rescale_W = rescale_W
        self.max_iter_e_step = max_iter_e_step

    def _rescale_W(self, W, A, B):
        s = W.sum(axis=1, keepdims=True)
        np.divide(W, s, out=W, where=(s != 0))
        np.divide(A, s, out=A, where=(s != 0))
        return W, A, B

    def _rescale_H(self, V, H):
        epsilon = 1e-10  # in case of a document having length=0
        H *= np.maximum(epsilon, V.sum(axis=1).A)
        H /= H.sum(axis=1, keepdims=True)
        return H

    def _e_step(self, Vt, W, Ht,
                tol=1E-3, max_iter=20):
        # Multiplicative updates of the activations Ht for a fixed W,
        # decreasing the KL divergence between Vt and Ht @ W.
        if self.rescale_W:
            W_WT1 = W
        else:
            WT1 = np.sum(W, axis=1)
            W_WT1 = W / WT1[:, np.newaxis]
        squared_tol = tol**2
        squared_norm = 1
        for iter in range(max_iter):
            if squared_norm <= squared_tol:
                break
            Ht_W = _special_sparse_dot(Ht, W, Vt)
            Ht_W_data = Ht_W.data
            Vt_data = Vt.data
            np.divide(Vt_data, Ht_W_data, out=Ht_W_data,
                      where=(Ht_W_data != 0))
            Ht_out = Ht * safe_sparse_dot(Ht_W, W_WT1.T)
            squared_norm = np.linalg.norm(
                Ht_out - Ht) / (np.linalg.norm(Ht) + 1E-10)
            # The slice assignment keeps updating the same Ht buffer, which
            # is a view into the caller's H.
            Ht[:] = Ht_out
            # -- Review (Contributor): Is the [:] to make sure that the same
            #    Ht object is used in each iteration?
        return Ht

    def _m_step(self, Vt, W, A, B, Ht, iter):
        # Update W from the discounted sufficient statistics A (numerator)
        # and B (denominator) of the multiplicative update.
        Ht_W = _special_sparse_dot(Ht, W, Vt)
        Ht_W_data = Ht_W.data
        np.divide(Vt.data, Ht_W_data, out=Ht_W_data, where=(Ht_W_data != 0))
        self.rho_ = self.r ** (1 / iter)
        # self.rho_ = .98
        A *= self.rho_
        A += W * safe_sparse_dot(Ht.T, Ht_W)
        B *= self.rho_
        B += Ht.sum(axis=0).reshape(-1, 1)
        np.divide(A, B, out=W, where=(W != 0))
        if self.rescale_W:
            W, A, B = self._rescale_W(W, A, B)
        return W, A, B
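
    # Interpretation of the m-step (a sketch): with forgetting factor
    # rho_ = 0, the update above reduces to one batch multiplicative-update
    # step for W under the KL loss,
    #
    #     W <- W * (Ht.T @ (Vt / (Ht @ W))) / Ht.sum(axis=0)[:, None]
    #
    # so A and B accumulate the numerator and denominator of this step over
    # past mini-batches, discounted by rho_.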

    def _get_H(self, X):
        H_out = np.empty((len(X), self.n_components))
        for x, h_out in zip(X, H_out):
            h_out[:] = self.H_dict[x]
        return H_out

    def _init_vars(self, V):
        if self.init == 'k-means++':
            W = _k_init(
                V, self.n_components, row_norms(V, squared=True),
                random_state=self.random_state,
                n_local_trials=None) + .1
            W /= W.sum(axis=1, keepdims=True)
            H = np.ones((V.shape[0], self.n_components))
            H = self._rescale_H(V, H)
        elif self.init == 'random':
            W = self.random_state.gamma(
                shape=1, scale=1,
                size=(self.n_components, self.n_features_))
            W /= W.sum(axis=1, keepdims=True)
            H = np.ones((V.shape[0], self.n_components))
            H = self._rescale_H(V, H)
        elif self.init == 'nndsvd':
            eps = 1e-6
            # Name the right singular vectors Vh so that the data matrix V
            # is not rebound (V is still needed to size and rescale H below).
            U, S, Vh = randomized_svd(V, self.n_components,
                                      random_state=self.random_state)
            H, W = np.zeros(U.shape), np.zeros(Vh.shape)

            # The leading singular triplet is non-negative
            # so it can be used as is for initialization.
            H[:, 0] = np.sqrt(S[0]) * np.abs(U[:, 0])
            W[0, :] = np.sqrt(S[0]) * np.abs(Vh[0, :])

            for j in range(1, self.n_components):
                x, y = U[:, j], Vh[j, :]

                # extract positive and negative parts of column vectors
                x_p, y_p = np.maximum(x, 0), np.maximum(y, 0)
                x_n, y_n = np.abs(np.minimum(x, 0)), np.abs(np.minimum(y, 0))

                # and their norms
                x_p_nrm, y_p_nrm = norm(x_p), norm(y_p)
                x_n_nrm, y_n_nrm = norm(x_n), norm(y_n)

                m_p, m_n = x_p_nrm * y_p_nrm, x_n_nrm * y_n_nrm

                # choose update
                if m_p > m_n:
                    u = x_p / x_p_nrm
                    v = y_p / y_p_nrm
                    sigma = m_p
                else:
                    u = x_n / x_n_nrm
                    v = y_n / y_n_nrm
                    sigma = m_n

                lbd = np.sqrt(S[j] * sigma)
                H[:, j] = lbd * u
                W[j, :] = lbd * v

            W[W < eps] = 0
            H[H < eps] = 0
            # The SVD-based H is discarded: as in the other branches, H is
            # re-initialized uniformly and rescaled to the sample sums.
            H = np.ones((V.shape[0], self.n_components))
            H = self._rescale_H(V, H)
        else:
            raise ValueError(
                'Initialization method %s does not exist.' % self.init)
        A = W.copy()
        B = np.ones((self.n_components, self.n_features_))
        return H, W, A, B

    def fit(self, X, y=None):
        """Fit the NMF to X.

        Parameters
        ----------
        X : sparse array-like, shape [n_samples, n_features]
            The data matrix to factorize.

        Returns
        -------
        self
        """
        n_samples, self.n_features_ = X.shape

        if sparse.issparse(X):
            H, self.W_, self.A_, self.B_ = self._init_vars(X)
            # self.rho_ = self.r**(self.batch_size / n_samples)
        # else:
        #     not implemented yet
        # -- Review (Member): We'll have to implement for dense.

        n_batch = (n_samples - 1) // self.batch_size + 1
        self.iter = 1
        # -- Review (Member): This variable should probably be named
        #    "self.n_iter_".

        converged = False
        for iter in range(self.max_iter):
            # -- Review (Member): "iter" should be renamed to "n_iter" as
            #    "iter" is a function in Python.
            for i, slice in enumerate(gen_batches(n=n_samples,
                                                  batch_size=self.batch_size)):
                if i == n_batch - 1:
                    # -- Review (Contributor): PEP8? My feel is to add
                    #    spaces around "-".
                    # Copy W: _m_step updates self.W_ in place, so keeping a
                    # bare reference would always measure a zero change.
                    W_last = self.W_.copy()
                H[slice] = self._e_step(X[slice], self.W_, H[slice],
                                        max_iter=self.max_iter_e_step)
                self.W_, self.A_, self.B_ = self._m_step(
                    X[slice], self.W_, self.A_, self.B_, H[slice], self.iter)
                self.iter += 1
                if i == n_batch - 1:
                    W_change = np.linalg.norm(
                        self.W_ - W_last) / np.linalg.norm(W_last)
                    if (W_change < self.tol) and (iter >= self.min_iter - 1):
                        # Flag convergence so the epoch loop stops too (a
                        # bare break would only exit the batch loop).
                        converged = True
                        break
            if converged:
                break
        return self

    def partial_fit(self, X, y=None):
        if hasattr(self, 'iter'):
            assert X.shape[1] == self.n_features_
            n_samples, _ = X.shape

            if sparse.issparse(X):
                H = np.ones((n_samples, self.n_components))
                H = self._rescale_H(X, H)
            # else:
            #     not implemented yet
        else:
            n_samples, self.n_features_ = X.shape

            if sparse.issparse(X):
                # H = np.ones((n_samples, self.n_components))
                # H = self._rescale_H(X, H)
                H, self.W_, self.A_, self.B_ = self._init_vars(X)
                self.iter = 1
                # self.rho = self.r**(self.batch_size / n_samples)
            # else:
            #     not implemented yet

        for slice in gen_batches(n=n_samples, batch_size=self.batch_size):
            H[slice] = self._e_step(X[slice], self.W_, H[slice],
                                    max_iter=self.max_iter_e_step)
            self.W_, self.A_, self.B_ = self._m_step(
                X[slice], self.W_, self.A_, self.B_, H[slice], self.iter)
            self.iter += 1
        return self

    def transform(self, X):
        """Transform X using the trained matrix W.

        Parameters
        ----------
        X : array-like, shape [n_samples, n_features]
            The data to encode.
        # -- Review (Contributor): n_features missing in [] bracket?

        Returns
        -------
        X_new : 2-d array, shape [n_samples, n_components]
            Transformed input.
        """
        assert X.shape[1] == self.n_features_
        n_samples, _ = X.shape

        H = np.ones((n_samples, self.n_components))
        H = self._rescale_H(X, H)

        for slice in gen_batches(n=n_samples, batch_size=self.batch_size):
            H[slice] = self._e_step(X[slice], self.W_, H[slice], max_iter=50)
        return H
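
A minimal usage sketch, assuming the draft module is importable as
sklearn.decomposition.minibatch_nmf and given sparse count input (dense input
is not implemented yet):

from sklearn.feature_extraction.text import CountVectorizer
from sklearn.decomposition.minibatch_nmf import MiniBatchNMF

docs = ['red car', 'blue car', 'red truck', 'blue truck']
# Character 3-gram counts, as in the benchmark script; X is sparse.
X = CountVectorizer(analyzer='char', ngram_range=(3, 3)).fit_transform(docs)

mbnmf = MiniBatchNMF(n_components=2, batch_size=2, random_state=0)
mbnmf.fit(X)             # mini-batch passes over X until W stabilizes
H = mbnmf.transform(X)   # activations, shape (n_samples, n_components)
mbnmf.partial_fit(X)     # further online updates, e.g. on a new batch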