GridsearchCV.fit throws ValueError when passed a large dataframe that contains an Object column #9483

Closed
stoddardg opened this issue Aug 2, 2017 · 6 comments · Fixed by #9507

@stoddardg
stoddardg commented Aug 2, 2017

Description

I get ValueError: buffer source array is read-only in the example below whenever I pass a DataFrame with around 200K rows and at least one column of dtype object into GridSearchCV with n_jobs > 1. The error seems to be caused by passing a DataFrame that has object columns into GridSearchCV.fit. My custom class, DataFrame_Encoder, properly encodes the object columns (by dummy encoding them) when the pipeline executes, but this error occurs before it executes. Things work fine if I use a smaller dataset, drop the object column from the DataFrame, or set n_jobs=1.

My minimal example to reproduce the bug is a bit lengthy, so I've also included a notebook with the code and some theories as to what is happening: https://github.com/stoddardg/sklearn_bug_example/blob/master/Bug%20Exploration.ipynb

Steps/Code to Reproduce

Example:

import pandas as pd

from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.datasets import make_classification
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_extraction import DictVectorizer

import numpy as np


class DataFrame_Encoder(BaseEstimator, TransformerMixin):
    
    def __init__(self, categorical_cols_=None,numeric_cols_=None):
        print("__init__ called")
        self.categorical_cols_ = categorical_cols_
        self.numeric_cols_ = numeric_cols_
    
    def fit(self, df, y=None):
        print("Fit called")
        ### df should be a dataframe that is a mix of categorical and numeric columns
        self.vec_ = DictVectorizer(sparse=False)
        temp_data = df[self.categorical_cols_].astype(str)
        self.vec_.fit(temp_data.to_dict('records'))
        self.feature_names_ = list(self.numeric_cols_) + list(self.vec_.feature_names_)
        return self

    def transform(self, df):
        ### df should be a dataframe that is a mix of categorical and numeric columns
        print("Transform called")
        temp_data = df[self.categorical_cols_].astype(str)
        categorical_data = self.vec_.transform(temp_data.to_dict('records'))
        categorical_df = pd.DataFrame(categorical_data, columns=self.vec_.feature_names_, index=df.index)
        new_data = pd.concat([df[self.numeric_cols_], categorical_df],axis=1)
        return new_data

x,y = make_classification(n_samples=200000,n_features=5)

numeric_features = ['x1','x2','x3','x4','x5']
string_features = ['category']

df = pd.DataFrame(data=x,columns=numeric_features)
df['category'] = 'a'

base_clf = RandomForestClassifier(n_jobs=4)
param_grid = {'clf__n_estimators':[10,100]}

pipeline = Pipeline([
        ('feature_encoder',DataFrame_Encoder()),
        ('clf',base_clf)
])
pipeline.set_params(feature_encoder__categorical_cols_=string_features, feature_encoder__numeric_cols_=numeric_features)

clf = GridSearchCV(pipeline, param_grid,cv=5,n_jobs=2,verbose=1)

clf.fit(df,y)

Traceback (partial):

---------------------------------------------------------------------------
Sub-process traceback:
---------------------------------------------------------------------------
ValueError                                         Thu Aug  3 16:29:11 2017
PID: 16736           Python 3.6.2: /volatile/le243287/miniconda3/bin/python
...........................................................................
/home/le243287/dev/alt-scikit-learn/sklearn/externals/joblib/parallel.py in __call__(self=<sklearn.externals.joblib.parallel.BatchedCalls object>)
    126     def __init__(self, iterator_slice):
    127         self.items = list(iterator_slice)
    128         self._size = len(self.items)
    129 
    130     def __call__(self):
--> 131         return [func(*args, **kwargs) for func, args, kwargs in self.items]
        self.items = [(<function _fit_and_score>, (Pipeline(memory=None,
     steps=[('feature_enco...None, verbose=0,
            warm_start=False))]),               x1        x2        x3        x4  ...124  0.815972        a

[200000 rows x 6 columns], memmap([0, 0, 1, ..., 1, 1, 1]), {'score': <function _passthrough_scorer>}, memmap([ 39843,  39844,  39846, ..., 199997, 199998, 199999]), array([    0,     1,     2, ..., 40174, 40178, 40184]), 1, {'clf__n_estimators': 10}), {'error_score': 'raise', 'fit_params': {}, 'return_n_test_samples': True, 'return_parameters': False, 'return_times': True, 'return_train_score': True})]
    132 
    133     def __len__(self):
    134         return self._size
    135 

...........................................................................
/home/le243287/dev/alt-scikit-learn/sklearn/externals/joblib/parallel.py in <listcomp>(.0=<list_iterator object>)
    126     def __init__(self, iterator_slice):
    127         self.items = list(iterator_slice)
    128         self._size = len(self.items)
    129 
    130     def __call__(self):
--> 131         return [func(*args, **kwargs) for func, args, kwargs in self.items]
        func = <function _fit_and_score>
        args = (Pipeline(memory=None,
     steps=[('feature_enco...None, verbose=0,
            warm_start=False))]),               x1        x2        x3        x4  ...124  0.815972        a

[200000 rows x 6 columns], memmap([0, 0, 1, ..., 1, 1, 1]), {'score': <function _passthrough_scorer>}, memmap([ 39843,  39844,  39846, ..., 199997, 199998, 199999]), array([    0,     1,     2, ..., 40174, 40178, 40184]), 1, {'clf__n_estimators': 10})
        kwargs = {'error_score': 'raise', 'fit_params': {}, 'return_n_test_samples': True, 'return_parameters': False, 'return_times': True, 'return_train_score': True}
    132 
    133     def __len__(self):
    134         return self._size
    135 

...........................................................................
/home/le243287/dev/alt-scikit-learn/sklearn/model_selection/_validation.py in _fit_and_score(estimator=Pipeline(memory=None,
     steps=[('feature_enco...None, verbose=0,
            warm_start=False))]), X=              x1        x2        x3        x4  ...124  0.815972        a

[200000 rows x 6 columns], y=memmap([0, 0, 1, ..., 1, 1, 1]), scorer={'score': <function _passthrough_scorer>}, train=memmap([ 39843,  39844,  39846, ..., 199997, 199998, 199999]), test=array([    0,     1,     2, ..., 40174, 40178, 40184]), verbose=1, parameters={'clf__n_estimators': 10}, fit_params={}, return_train_score=True, return_parameters=False, return_n_test_samples=True, return_times=True, error_score='raise')
    422     if parameters is not None:
    423         estimator.set_params(**parameters)
    424 
    425     start_time = time.time()
    426 
--> 427     X_train, y_train = _safe_split(estimator, X, y, train)
        X_train = undefined
        y_train = undefined
        estimator = Pipeline(memory=None,
     steps=[('feature_enco...None, verbose=0,
            warm_start=False))])
        X =               x1        x2        x3        x4  ...124  0.815972        a

[200000 rows x 6 columns]
        y = memmap([0, 0, 1, ..., 1, 1, 1])
        train = memmap([ 39843,  39844,  39846, ..., 199997, 199998, 199999])
    428     X_test, y_test = _safe_split(estimator, X, y, test, train)
    429 
    430     is_multimetric = not callable(scorer)
    431     n_scorers = len(scorer.keys()) if is_multimetric else 1

...........................................................................
/home/le243287/dev/alt-scikit-learn/sklearn/utils/metaestimators.py in _safe_split(estimator=Pipeline(memory=None,
     steps=[('feature_enco...None, verbose=0,
            warm_start=False))]), X=              x1        x2        x3        x4  ...124  0.815972        a

[200000 rows x 6 columns], y=memmap([0, 0, 1, ..., 1, 1, 1]), indices=memmap([ 39843,  39844,  39846, ..., 199997, 199998, 199999]), train_indices=None)
    195         if train_indices is None:
    196             X_subset = X[np.ix_(indices, indices)]
    197         else:
    198             X_subset = X[np.ix_(indices, train_indices)]
    199     else:
--> 200         X_subset = safe_indexing(X, indices)
        X_subset = undefined
        X =               x1        x2        x3        x4  ...124  0.815972        a

[200000 rows x 6 columns]
        indices = memmap([ 39843,  39844,  39846, ..., 199997, 199998, 199999])
    201 
    202     if y is not None:
    203         y_subset = safe_indexing(y, indices)
    204     else:

...........................................................................
/home/le243287/dev/alt-scikit-learn/sklearn/utils/__init__.py in safe_indexing(X=              x1        x2        x3        x4  ...124  0.815972        a

[200000 rows x 6 columns], indices=memmap([ 39843,  39844,  39846, ..., 199997, 199998, 199999]))
    148         except ValueError:
    149             # Cython typed memoryviews internally used in pandas do not support
    150             # readonly buffers.
    151             warnings.warn("Copying input dataframe for slicing.",
    152                           DataConversionWarning)
--> 153             return X.copy().iloc[indices]
        X.copy.iloc = undefined
        indices = memmap([ 39843,  39844,  39846, ..., 199997, 199998, 199999])
    154     elif hasattr(X, "shape"):
    155         if hasattr(X, 'take') and (hasattr(indices, 'dtype') and
    156                                    indices.dtype.kind == 'i'):
    157             # This is often substantially faster than X[indices]

...........................................................................
/volatile/le243287/miniconda3/lib/python3.6/site-packages/pandas/core/indexing.py in __getitem__(self=<pandas.core.indexing._iLocIndexer object>, key=memmap([ 39843,  39844,  39846, ..., 199997, 199998, 199999]))
   1323             except (KeyError, IndexError):
   1324                 pass
   1325             return self._getitem_tuple(key)
   1326         else:
   1327             key = com._apply_if_callable(key, self.obj)
-> 1328             return self._getitem_axis(key, axis=0)
        self._getitem_axis = <bound method _iLocIndexer._getitem_axis of <pandas.core.indexing._iLocIndexer object>>
        key = memmap([ 39843,  39844,  39846, ..., 199997, 199998, 199999])
   1329 
   1330     def _is_scalar_access(self, key):
   1331         raise NotImplementedError()
   1332 

...........................................................................
/volatile/le243287/miniconda3/lib/python3.6/site-packages/pandas/core/indexing.py in _getitem_axis(self=<pandas.core.indexing._iLocIndexer object>, key=memmap([ 39843,  39844,  39846, ..., 199997, 199998, 199999]), axis=0)
   1733             self._has_valid_type(key, axis)
   1734             return self._getbool_axis(key, axis=axis)
   1735 
   1736         # a list of integers
   1737         elif is_list_like_indexer(key):
-> 1738             return self._get_list_axis(key, axis=axis)
        self._get_list_axis = <bound method _iLocIndexer._get_list_axis of <pandas.core.indexing._iLocIndexer object>>
        key = memmap([ 39843,  39844,  39846, ..., 199997, 199998, 199999])
        axis = 0
   1739 
   1740         # a single integer
   1741         else:
   1742             key = self._convert_scalar_indexer(key, axis)

...........................................................................
/volatile/le243287/miniconda3/lib/python3.6/site-packages/pandas/core/indexing.py in _get_list_axis(self=<pandas.core.indexing._iLocIndexer object>, key=memmap([ 39843,  39844,  39846, ..., 199997, 199998, 199999]), axis=0)
   1710         Returns
   1711         -------
   1712         Series object
   1713         """
   1714         try:
-> 1715             return self.obj.take(key, axis=axis, convert=False)
        self.obj.take = <bound method NDFrame.take of               x1  ...24  0.815972        a

[200000 rows x 6 columns]>
        key = memmap([ 39843,  39844,  39846, ..., 199997, 199998, 199999])
        axis = 0
   1716         except IndexError:
   1717             # re-raise with different error message
   1718             raise IndexError("positional indexers are out-of-bounds")
   1719 

...........................................................................
/volatile/le243287/miniconda3/lib/python3.6/site-packages/pandas/core/generic.py in take(self=              x1        x2        x3        x4  ...124  0.815972        a

[200000 rows x 6 columns], indices=memmap([ 39843,  39844,  39846, ..., 199997, 199998, 199999]), axis=0, convert=False, is_copy=True, **kwargs={})
   1923         """
   1924         nv.validate_take(tuple(), kwargs)
   1925         self._consolidate_inplace()
   1926         new_data = self._data.take(indices,
   1927                                    axis=self._get_block_manager_axis(axis),
-> 1928                                    convert=True, verify=True)
        convert = False
   1929         result = self._constructor(new_data).__finalize__(self)
   1930 
   1931         # maybe set copy if we didn't actually change the index
   1932         if is_copy:

...........................................................................
/volatile/le243287/miniconda3/lib/python3.6/site-packages/pandas/core/internals.py in take(self=BlockManager
Items: Index(['x1', 'x2', 'x3', 'x4...tBlock: slice(5, 6, 1), 1 x 200000, dtype: object, indexer=memmap([ 39843,  39844,  39846, ..., 199997, 199998, 199999]), axis=1, verify=True, convert=True)
   4006                 raise Exception('Indices must be nonzero and less than '
   4007                                 'the axis length')
   4008 
   4009         new_labels = self.axes[axis].take(indexer)
   4010         return self.reindex_indexer(new_axis=new_labels, indexer=indexer,
-> 4011                                     axis=axis, allow_dups=True)
        axis = 1
   4012 
   4013     def merge(self, other, lsuffix='', rsuffix=''):
   4014         if not self._is_indexed_like(other):
   4015             raise AssertionError('Must have same axes to merge managers')

...........................................................................
/volatile/le243287/miniconda3/lib/python3.6/site-packages/pandas/core/internals.py in reindex_indexer(self=BlockManager
Items: Index(['x1', 'x2', 'x3', 'x4...tBlock: slice(5, 6, 1), 1 x 200000, dtype: object, new_axis=Int64Index([ 39843,  39844,  39846,  39848,  398...199999],
           dtype='int64', length=160000), indexer=memmap([ 39843,  39844,  39846, ..., 199997, 199998, 199999]), axis=1, fill_value=None, allow_dups=True, copy=True)
   3892             new_blocks = self._slice_take_blocks_ax0(indexer,
   3893                                                      fill_tuple=(fill_value,))
   3894         else:
   3895             new_blocks = [blk.take_nd(indexer, axis=axis, fill_tuple=(
   3896                 fill_value if fill_value is not None else blk.fill_value,))
-> 3897                 for blk in self.blocks]
        self.blocks = (FloatBlock: slice(0, 5, 1), 5 x 200000, dtype: float64, ObjectBlock: slice(5, 6, 1), 1 x 200000, dtype: object)
   3898 
   3899         new_axes = list(self.axes)
   3900         new_axes[axis] = new_axis
   3901         return self.__class__(new_blocks, new_axes)

...........................................................................
/volatile/le243287/miniconda3/lib/python3.6/site-packages/pandas/core/internals.py in <listcomp>(.0=<tuple_iterator object>)
   3892             new_blocks = self._slice_take_blocks_ax0(indexer,
   3893                                                      fill_tuple=(fill_value,))
   3894         else:
   3895             new_blocks = [blk.take_nd(indexer, axis=axis, fill_tuple=(
   3896                 fill_value if fill_value is not None else blk.fill_value,))
-> 3897                 for blk in self.blocks]
        blk = FloatBlock: slice(0, 5, 1), 5 x 200000, dtype: float64
   3898 
   3899         new_axes = list(self.axes)
   3900         new_axes[axis] = new_axis
   3901         return self.__class__(new_blocks, new_axes)

...........................................................................
/volatile/le243287/miniconda3/lib/python3.6/site-packages/pandas/core/internals.py in take_nd(self=FloatBlock: slice(0, 5, 1), 5 x 200000, dtype: float64, indexer=memmap([ 39843,  39844,  39846, ..., 199997, 199998, 199999]), axis=1, new_mgr_locs=None, fill_tuple=(nan,))
   1041             new_values = algos.take_nd(values, indexer, axis=axis,
   1042                                        allow_fill=False)
   1043         else:
   1044             fill_value = fill_tuple[0]
   1045             new_values = algos.take_nd(values, indexer, axis=axis,
-> 1046                                        allow_fill=True, fill_value=fill_value)
        fill_value = nan
   1047 
   1048         if new_mgr_locs is None:
   1049             if axis == 0:
   1050                 slc = lib.indexer_as_slice(indexer)

...........................................................................
/volatile/le243287/miniconda3/lib/python3.6/site-packages/pandas/core/algorithms.py in take_nd(arr=memmap([[ 1.85430272,  0.02363887, -0.44955668, ... 0.22950348,
          0.80573119,  0.81597234]]), indexer=memmap([ 39843,  39844,  39846, ..., 199997, 199998, 199999]), axis=1, out=array([[ 0.,  0.,  0., ...,  0.,  0.,  0.],
    ...0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.]]), fill_value=nan, mask_info=None, allow_fill=True)
   1466         else:
   1467             out = np.empty(out_shape, dtype=dtype)
   1468 
   1469     func = _get_take_nd_function(arr.ndim, arr.dtype, out.dtype, axis=axis,
   1470                                  mask_info=mask_info)
-> 1471     func(arr, indexer, out, fill_value)
        func = <built-in function take_2d_axis1_float64_float64>
        arr = memmap([[ 1.85430272,  0.02363887, -0.44955668, ... 0.22950348,
          0.80573119,  0.81597234]])
        indexer = memmap([ 39843,  39844,  39846, ..., 199997, 199998, 199999])
        out = array([[ 0.,  0.,  0., ...,  0.,  0.,  0.],
    ...0.],
       [ 0.,  0.,  0., ...,  0.,  0.,  0.]])
        fill_value = nan
   1472 
   1473     if flip_order:
   1474         out = out.T
   1475     return out

...........................................................................
/volatile/le243287/miniconda3/lib/python3.6/site-packages/pandas/_libs/algos.cpython-36m-x86_64-linux-gnu.so in pandas._libs.algos.take_2d_axis1_float64_float64 (pandas/_libs/algos.c:111160)()

...........................................................................
/volatile/le243287/miniconda3/lib/python3.6/site-packages/pandas/_libs/algos.cpython-36m-x86_64-linux-gnu.so in View.MemoryView.memoryview_cwrapper (pandas/_libs/algos.c:124730)()

...........................................................................
/volatile/le243287/miniconda3/lib/python3.6/site-packages/pandas/_libs/algos.cpython-36m-x86_64-linux-gnu.so in View.MemoryView.memoryview.__cinit__ (pandas/_libs/algos.c:120965)()

ValueError: buffer source array is read-only
___________________________________________________________________________

Expected Results

No error is thrown.

Actual Results

I get an incredibly long error message (viewable in the notebook) but the punchline is:

ValueError: buffer source array is read-only

Versions

Darwin-15.6.0-x86_64-i386-64bit
Python 3.6.1 |Continuum Analytics, Inc.| (default, May 11 2017, 13:04:09)
[GCC 4.2.1 Compatible Apple LLVM 6.0 (clang-600.0.57)]
NumPy 1.13.1
SciPy 0.19.1
Scikit-Learn 0.18.2

@jnothman
Member

jnothman commented Aug 2, 2017 via email

@stoddardg
Author

Thanks for the reply. The weird thing about the issue is that it doesn't happen when there are only numeric columns in the DataFrame. Other people experienced this bug in prior versions with numeric-only DataFrames, but those issues seem to have been fixed [1][2] (or at least as far as I can tell from the issue status). That's what caused me to think that the issue might be specific to object columns.

The workaround, for now, is to separate the encoding step from the Pipeline and encode things before GridSearchCV.fit is called. I don't think this is ideal because it breaks the really nice ability of Pipeline to encapsulate everything related to fitting a model.
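
The workaround described above can be sketched roughly as follows. This is a minimal illustration, not the code from the report: it swaps the DictVectorizer-based DataFrame_Encoder for pd.get_dummies, uses a much smaller dataset than the roughly 200K rows that trigger the error, and picks arbitrary grid values. The key assumption is that handing GridSearchCV a plain float array, rather than a DataFrame with an object column, avoids the failing DataFrame indexing path.

```python
import numpy as np
import pandas as pd
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV

# Small stand-in dataset (the report used n_samples=200000).
x, y = make_classification(n_samples=200, n_features=5, random_state=0)
numeric_features = ['x1', 'x2', 'x3', 'x4', 'x5']
df = pd.DataFrame(data=x, columns=numeric_features)
df['category'] = 'a'

# Encode the object column up front, outside of any Pipeline.
# pd.get_dummies is used here in place of DataFrame_Encoder (an assumption).
encoded = pd.get_dummies(df, columns=['category'])
X = encoded.to_numpy(dtype='float64')

# The grid search now only sees a numeric array.
param_grid = {'n_estimators': [5, 10]}
clf = GridSearchCV(RandomForestClassifier(random_state=0), param_grid,
                   cv=3, n_jobs=2)
clf.fit(X, y)
```

The cost is the one noted above: the encoding is no longer fitted per cross-validation fold, so it sits outside what the Pipeline encapsulates.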

I forgot to link some issues that this seems to be related to:

[1] #4772
[2] pandas-dev/pandas#9928 (comment)

There seemed to be a few more when I went down the rabbit hole one day, but I've lost those links.

@lesteve
Member

lesteve commented Aug 3, 2017

I have edited your snippet to add the missing imports (and also added part of the traceback in a "details" section) and I can reproduce the problem.

I think @jnothman's assessment of the problem is pretty accurate. This needs more investigation to figure out whether there is a workaround or a fix.

@lesteve
Member

lesteve commented Aug 7, 2017

@stoddardg I looked into this a little more, and a workaround I found for your particular case is to build the DataFrame in one go, like this:

my_dict = {name: arr for name, arr in zip(numeric_features, x.T)}
my_dict['category'] = 'a'
df = pd.DataFrame(my_dict)

Don't ask me exactly why yet because I have not fully understood the problem ... I'll try to clarify a bit what I have found.

@lesteve
Member

lesteve commented Aug 7, 2017

A minimal snippet showing that the problem is related to pandas:

import numpy as np
import pandas as pd

df = pd.DataFrame({'first': np.ones(100, dtype='float64')})
indices = np.array([1, 3, 6])
indices.flags.writeable = False
df.iloc[indices]

This seems like a variation of pandas-dev/pandas#10043 (fixed by pandas-dev/pandas#10070). The difference here is that the indices are read-only, not the NumPy arrays inside the DataFrame.
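
One possible stopgap on the caller's side is to copy a read-only indexer before handing it to iloc. The sketch below reuses the snippet above; take_rows is a hypothetical helper name, not part of pandas or scikit-learn, and this is only an illustration of the idea, assuming the copy is cheap relative to the indexing itself.

```python
import numpy as np
import pandas as pd


def take_rows(df, indices):
    """Select rows by position, copying the indexer if it is read-only.

    Hypothetical helper for illustration only.
    """
    indices = np.asarray(indices)
    if not indices.flags.writeable:
        # np.array copies by default, and the copy owns writable memory.
        indices = np.array(indices)
    return df.iloc[indices]


df = pd.DataFrame({'first': np.ones(100, dtype='float64')})
indices = np.array([1, 3, 6])
indices.flags.writeable = False  # mimic the read-only memmap indices
subset = take_rows(df, indices)
```

The original indices array is left untouched and stays read-only; only the local copy passed to iloc is writable.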

@lesteve
Member

lesteve commented Aug 7, 2017

I opened an issue on the pandas side: pandas-dev/pandas#17192.
