### Describe the workflow you want to enable
I aim to enable a more efficient training workflow for Multilayer Perceptrons (MLPs) in scikit-learn by optimizing the performance of the `SGDOptimizer` and `AdamOptimizer` classes. Currently, these optimizers use list comprehensions in their `_get_updates` methods to compute parameter updates, which can be computationally expensive for large neural networks with many parameters (e.g., hidden layers with thousands of neurons). The proposed vectorized operations would let users train larger MLPs faster, particularly on datasets requiring extensive iteration, without altering the existing API or user experience. The optimization would also address redundant computation in `SGDOptimizer` when Nesterov's momentum is enabled, further improving training speed. This enhancement will benefit users working on deep learning tasks within scikit-learn, such as image classification or regression with complex models, by reducing training time and improving scalability.
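
For reference, `AdamOptimizer._get_updates` currently rebuilds the moment lists with per-parameter list comprehensions on every step, roughly as follows (paraphrased and simplified from `sklearn/neural_network/_stochastic_optimizers.py`):

```python
# Paraphrase of the current implementation: each call reallocates the
# moment lists and their arrays instead of updating them in place.
def _get_updates(self, grads):
    self.t += 1
    self.ms = [
        self.beta_1 * m + (1 - self.beta_1) * grad
        for m, grad in zip(self.ms, grads)
    ]
    self.vs = [
        self.beta_2 * v + (1 - self.beta_2) * (grad**2)
        for v, grad in zip(self.vs, grads)
    ]
    self.learning_rate = (
        self.learning_rate_init
        * np.sqrt(1 - self.beta_2**self.t)
        / (1 - self.beta_1**self.t)
    )
    return [
        -self.learning_rate * m / (np.sqrt(v) + self.epsilon)
        for m, v in zip(self.ms, self.vs)
    ]
```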

### Describe your proposed solution
To optimize the performance of `SGDOptimizer` and `AdamOptimizer`, I propose the following changes to `sklearn/neural_network/_stochastic_optimizers.py`:

**1. Vectorized operations.** Replace the list comprehensions in `_get_updates` with in-place NumPy operations. This leverages NumPy's optimized C implementation, reduces Python-level overhead, and avoids reallocating the moment arrays on every step. For example, in `AdamOptimizer._get_updates`, the list comprehensions that rebuild the first and second moments can be replaced with in-place updates of the existing arrays. Example implementation for `AdamOptimizer._get_updates`:
```python
from typing import List

import numpy as np


# Within AdamOptimizer:
def _get_updates(self, grads: List[np.ndarray]) -> List[np.ndarray]:
    self.t += 1
    lr_t = self.learning_rate_init * np.sqrt(1 - self.beta_2**self.t) / (1 - self.beta_1**self.t)
    for m, v, grad in zip(self.ms, self.vs, grads):
        # First moment, in place: m <- beta_1 * m + (1 - beta_1) * grad.
        np.multiply(self.beta_1, m, out=m)
        m += (1 - self.beta_1) * grad
        # Second moment, in place: v <- beta_2 * v + (1 - beta_2) * grad**2.
        # grad is left unmodified so both moments see the raw gradient.
        np.multiply(self.beta_2, v, out=v)
        v += (1 - self.beta_2) * grad**2
    return [-lr_t * m / (np.sqrt(v) + self.epsilon) for m, v in zip(self.ms, self.vs)]
```
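
The impact could be measured with a simple timing harness like the sketch below; the dataset and layer sizes are illustrative placeholders, not benchmarked figures:

```python
# Hypothetical benchmark sketch: time MLPClassifier.fit before and after the
# optimizer change. Sizes are chosen so the per-step update cost is visible.
import time

from sklearn.datasets import make_classification
from sklearn.neural_network import MLPClassifier

X, y = make_classification(n_samples=5_000, n_features=100, random_state=0)
clf = MLPClassifier(
    hidden_layer_sizes=(1024, 1024),  # large layers stress _get_updates
    solver="adam",
    max_iter=20,
    random_state=0,
)
start = time.perf_counter()
clf.fit(X, y)
print(f"fit time: {time.perf_counter() - start:.2f} s")
```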
**2. Optimize Nesterov's momentum in `SGDOptimizer`.** Streamline `SGDOptimizer._get_updates` to avoid redundant computation when `nesterov=True`. Currently, updates are computed twice (once for the velocity update and again for the Nesterov correction); the proposal reuses the freshly updated velocities to compute the final update in a single pass. Example implementation:
```python
from typing import List

import numpy as np


# Within SGDOptimizer:
def _get_updates(self, grads: List[np.ndarray]) -> List[np.ndarray]:
    # Momentum step: v <- momentum * v - learning_rate * grad.
    self.velocities = [
        self.momentum * velocity - self.learning_rate * grad
        for velocity, grad in zip(self.velocities, grads)
    ]
    if self.nesterov:
        # Nesterov look-ahead, reusing the freshly updated velocities.
        return [
            self.momentum * velocity - self.learning_rate * grad
            for velocity, grad in zip(self.velocities, grads)
        ]
    return self.velocities
```
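
To confirm the streamlined version is behavior-preserving, a sanity check along the following lines could accompany the PR. This is illustrative only; it replays the proposed update by hand and exercises the private `_get_updates` API:

```python
# Compare the existing SGDOptimizer against the proposed single-pass update
# on random gradients; both should produce identical Nesterov updates.
import numpy as np
from sklearn.neural_network._stochastic_optimizers import SGDOptimizer

rng = np.random.default_rng(0)
params = [rng.standard_normal((20, 10)), rng.standard_normal(10)]
grads = [rng.standard_normal(p.shape) for p in params]

opt = SGDOptimizer(params, learning_rate_init=0.1, momentum=0.9, nesterov=True)
reference = opt._get_updates(grads)

# Proposed version, replayed by hand from the same zero initial velocities.
velocities = [np.zeros_like(p) for p in params]
velocities = [0.9 * v - 0.1 * g for v, g in zip(velocities, grads)]
proposed = [0.9 * v - 0.1 * g for v, g in zip(velocities, grads)]

for ref, new in zip(reference, proposed):
    assert np.allclose(ref, new)
```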

### Describe alternatives you've considered, if relevant
**Switching to a compiled backend (e.g., Numba or Cython):**
Using Numba to JIT-compile the `_get_updates` methods, or Cython to move them into compiled code, could speed up execution further. However, Numba would introduce a new runtime dependency, and adding Cython here would increase build and maintenance complexity for a small hot path. Vectorized NumPy operations are preferred because they capture most of the benefit with no new tooling.
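
For context, a JIT-compiled variant might look like the sketch below (assumes Numba is installed; shown only to make the trade-off concrete):

```python
# Rejected-alternative sketch: a Numba-compiled in-place momentum update.
import numpy as np
from numba import njit


@njit(cache=True)
def momentum_update(velocity, grad, momentum, lr):
    # velocity and grad are 1-D contiguous arrays; update happens in place.
    for i in range(velocity.shape[0]):
        velocity[i] = momentum * velocity[i] - lr * grad[i]


v = np.zeros(1024)
g = np.random.randn(1024)
momentum_update(v, g, 0.9, 0.1)
```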
**Parallelization with multithreading:**
Updating the parameter arrays in parallel with Python's `threading` or `multiprocessing` could improve performance on multi-core systems. However, this approach would add synchronization overhead and is impractical for the relatively small per-step computations in MLP training, where NumPy's vectorization already provides efficient single-threaded performance.
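
A sketch of what that alternative would entail (illustrative only; for the handful of weight matrices in a typical MLP, the pool overhead tends to outweigh any gain):

```python
# Rejected-alternative sketch: threading the per-layer momentum updates.
from concurrent.futures import ThreadPoolExecutor

import numpy as np

def update_layer(velocity, grad, momentum=0.9, lr=0.1):
    return momentum * velocity - lr * grad

velocities = [np.zeros((512, 512)) for _ in range(4)]
grads = [np.random.randn(512, 512) for _ in range(4)]

with ThreadPoolExecutor() as pool:
    velocities = list(pool.map(update_layer, velocities, grads))
```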

### Additional context
**Maintainability:** The proposed changes use existing NumPy functions, ensuring compatibility with the current codebase and ease of maintenance.

**Next steps:** I am willing to contribute a pull request with the implementation, updated tests, and benchmark results if this feature request is approved. Feedback on specific datasets or test cases for validation would be appreciated.