Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions machine_learning/mlp_activation_comparison.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import matplotlib.pyplot as plt
import numpy as np
from sklearn.datasets import make_moons
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier


# Compare different activation functions in MLPClassifier
def compare_activations():

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please provide return type hint for the function: compare_activations. If the function does not return a value, please provide the type hint as: def function() -> None:

As there is no test file in this pull request nor any test function or class in the file machine_learning/mlp_activation_comparison.py, please provide doctest for the function compare_activations

X, y = make_moons(n_samples=200, noise=0.25, random_state=3)

Check failure on line 10 in machine_learning/mlp_activation_comparison.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (N806)

machine_learning/mlp_activation_comparison.py:10:5: N806 Variable `X` in function should be lowercase
X_train, X_test, y_train, y_test = train_test_split(

Check failure on line 11 in machine_learning/mlp_activation_comparison.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (N806)

machine_learning/mlp_activation_comparison.py:11:14: N806 Variable `X_test` in function should be lowercase

Check failure on line 11 in machine_learning/mlp_activation_comparison.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (N806)

machine_learning/mlp_activation_comparison.py:11:5: N806 Variable `X_train` in function should be lowercase

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Variable and function names should follow the snake_case naming convention. Please update the following name accordingly: X_train

Variable and function names should follow the snake_case naming convention. Please update the following name accordingly: X_test

X, y, stratify=y, random_state=42
)

activations = ["identity", "logistic", "tanh", "relu"]

for activation in activations:
mlp = MLPClassifier(
hidden_layer_sizes=[50],
max_iter=1000,
activation=activation,
random_state=0,
)
mlp.fit(X_train, y_train)

print(
f"Activation: {activation}, "
f"Train Accuracy: {mlp.score(X_train, y_train):.2f}, "
f"Test Accuracy: {mlp.score(X_test, y_test):.2f}"
)

# Decision boundary
x_min, x_max = X[:, 0].min() - 0.5, X[:, 0].max() + 0.5
y_min, y_max = X[:, 1].min() - 0.5, X[:, 1].max() + 0.5
xx, yy = np.meshgrid(
np.linspace(x_min, x_max, 200),
np.linspace(y_min, y_max, 200),
)
Z = mlp.predict(np.c_[xx.ravel(), yy.ravel()])

Check failure on line 39 in machine_learning/mlp_activation_comparison.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (N806)

machine_learning/mlp_activation_comparison.py:39:9: N806 Variable `Z` in function should be lowercase
Z = Z.reshape(xx.shape)

Check failure on line 40 in machine_learning/mlp_activation_comparison.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (N806)

machine_learning/mlp_activation_comparison.py:40:9: N806 Variable `Z` in function should be lowercase

plt.contourf(xx, yy, Z, alpha=0.3)
plt.scatter(X_train[:, 0], X_train[:, 1], c=y_train, marker="o", label="Train")
plt.scatter(X_test[:, 0], X_test[:, 1], c=y_test, marker="s", label="Test")
plt.title(f"Activation: {activation}")
plt.legend()
plt.show()


if __name__ == "__main__":
compare_activations()
Loading