diff --git a/scikits/learn/preprocessing.py b/scikits/learn/preprocessing.py
new file mode 100644
index 0000000000000..9e66a2e76d672
--- /dev/null
+++ b/scikits/learn/preprocessing.py
@@ -0,0 +1,97 @@
+import numpy as np
+
+from .base import BaseEstimator
+
+
+class Scaler(BaseEstimator):
+    """Standardize a dataset along a given axis.
+
+    The data is centered by subtracting the mean taken along ``axis`` and,
+    optionally, scaled to unit variance by dividing by the standard
+    deviation taken along the same axis.
+
+    Parameters
+    ----------
+    axis : int, optional (default=0)
+        Axis along which the mean and standard deviation are computed.
+
+    with_std : bool, optional (default=True)
+        If True, also divide the centered data by its standard deviation.
+
+    Attributes
+    ----------
+    mean : ndarray
+        Mean along ``axis``; set by ``fit``.
+
+    std : ndarray
+        Standard deviation along ``axis`` with zeros replaced by 1; set by
+        ``fit`` only when ``with_std`` is True.
+    """
+
+    def __init__(self, axis=0, with_std=True):
+        self.axis = axis
+        self.with_std = with_std
+
+    def fit(self, X, y=None, **params):
+        """Compute the mean (and standard deviation) used by ``transform``.
+
+        Parameters
+        ----------
+        X : array-like
+            Data from which the statistics are computed.
+
+        y : ignored
+            Accepted but not used.
+
+        **params : keyword arguments
+            Forwarded to ``_set_params`` before fitting.
+
+        Returns
+        -------
+        self
+        """
+        self._set_params(**params)
+        # Bring the scaled axis to the front so the statistics broadcast
+        # against the trailing dimensions in ``transform``.
+        X = np.rollaxis(np.asarray(X), self.axis)
+        self.mean = X.mean(axis=0)
+        if self.with_std:
+            std = X.std(axis=0)
+            # A constant slice has zero spread: divide by 1 instead of 0 so
+            # it is only centered rather than turned into NaN.
+            self.std = np.where(std == 0.0, 1.0, std)
+        return self
+
+    def transform(self, X, y=None, copy=True):
+        """Center (and scale) ``X`` with the statistics computed by ``fit``.
+
+        Parameters
+        ----------
+        X : array-like
+            Data to standardize.
+
+        y : ignored
+            Accepted but not used.
+
+        copy : bool, optional (default=True)
+            If False, float or complex arrays are modified in place. Other
+            inputs are always converted to a new float array, since the
+            result cannot be stored in them.
+
+        Returns
+        -------
+        X : ndarray
+            The standardized data.
+        """
+        X = np.asarray(X)
+        if X.dtype.kind not in 'fc':
+            # Integer or boolean storage cannot hold the centered values.
+            X = X.astype(np.float64)
+        elif copy:
+            X = X.copy()
+        # ``Xr`` is a view on ``X``: the in-place updates below change ``X``.
+        Xr = np.rollaxis(X, self.axis)
+        Xr -= self.mean
+        if self.with_std:
+            Xr /= self.std
+        return X
diff --git a/scikits/learn/tests/test_preprocessing.py b/scikits/learn/tests/test_preprocessing.py
new file mode 100644
index 0000000000000..efc6e1f22ed7c
--- /dev/null
+++ b/scikits/learn/tests/test_preprocessing.py
@@ -0,0 +1,42 @@
+import numpy as np
+
+from numpy.testing import assert_array_almost_equal
+
+from scikits.learn.preprocessing import Scaler
+
+
+def test_scaler():
+    """Test centering and scaling of a dataset along each axis"""
+    # Seeded generator so that a failure can be reproduced.
+    rng = np.random.RandomState(0)
+    X = rng.randn(4, 5)
+
+    # The default copy=True must leave the input untouched.
+    X_orig = X.copy()
+    X_scaled = Scaler(axis=1).fit(X).transform(X)
+    assert_array_almost_equal(X, X_orig)
+    assert_array_almost_equal(X_scaled.mean(axis=1), 4 * [0.0])
+    assert_array_almost_equal(X_scaled.std(axis=1), 4 * [1.0])
+
+    scaler = Scaler(axis=1)
+    X_scaled = scaler.fit(X).transform(X, copy=False)
+    assert_array_almost_equal(X_scaled.mean(axis=1), 4 * [0.0])
+    assert_array_almost_equal(X_scaled.std(axis=1), 4 * [1.0])
+
+    scaler = Scaler(axis=0, with_std=False)
+    X_scaled = scaler.fit(X).transform(X, copy=False)
+    assert_array_almost_equal(X_scaled.mean(axis=0), 5 * [0.0])
+
+
+def test_scaler_constant_feature():
+    """A zero-variance column is centered without producing NaN"""
+    X = np.array([[1.0, 2.0], [1.0, 4.0]])
+    X_scaled = Scaler().fit(X).transform(X)
+    assert_array_almost_equal(X_scaled, [[0.0, -1.0], [0.0, 1.0]])
+
+
+def test_scaler_integer_input():
+    """Integer data is standardized in floating point"""
+    X = np.array([[1, 2], [3, 6]])
+    X_scaled = Scaler().fit(X).transform(X, copy=False)
+    assert_array_almost_equal(X_scaled, [[-1.0, -1.0], [1.0, 1.0]])