DOC improve iris example #26973

Merged 28 commits on Sep 7, 2023

Commits
7c728e0
typo
Aug 1, 2023
b289185
add link to the example
Aug 1, 2023
c6b6026
adapt wording: pca dimension instead of direction
eguenther Aug 1, 2023
ea18244
adapt PCA section
eguenther Aug 1, 2023
edec869
adapt format and correct a typo
eguenther Aug 1, 2023
2e0d8ed
fix docstring
eguenther Aug 9, 2023
0f2e5dd
add explanations to the plots
eguenther Aug 9, 2023
465ddf0
Merge branch 'main' into doc_improve_iris_example
eguenther Aug 9, 2023
d096b26
remove empty line in docstring (as suggested by guillaume)
eguenther Aug 12, 2023
f81572b
adapt divider lines to fit text length (as suggested by Guillaume)
eguenther Aug 12, 2023
a3d2bc9
move matplotlib import into the cell where it is used the first time …
eguenther Aug 12, 2023
3f8731e
move PCA import into the cell where it is used the first time (as sug…
eguenther Aug 12, 2023
d752fe6
clean up code for better readability. First plot does not need X and …
eguenther Aug 12, 2023
7b3dc5e
remove typos: the 't' key on my keyboard is stuck :-/
eguenther Aug 12, 2023
259a009
adapt text for PCA (intro and plot description), as suggested by Guil…
eguenther Aug 12, 2023
cfe9790
adapt title for first cell
eguenther Aug 12, 2023
cd3f95b
Merge branch 'main' into doc_improve_iris_example
eguenther Aug 12, 2023
326bfd5
adjust scatter plot to Guillaumes suggestion
eguenther Aug 12, 2023
f0db197
Merge branch 'doc_improve_iris_example' of github.com:eguenther/sciki…
eguenther Aug 12, 2023
1ff67e0
improve text for the scatter plot: now that we have added a legend, w…
eguenther Aug 12, 2023
db80506
trying to fix the failed checks: add empty line before next cell
eguenther Aug 12, 2023
1116fd5
Merge branch 'main' into doc_improve_iris_example
eguenther Aug 15, 2023
3546484
add new cell for comments
eguenther Sep 5, 2023
d83f0e5
remove the part where the plot size was specified, we don't need it
eguenther Sep 5, 2023
57b5be9
add noqa because import is not done at the top of the file!
eguenther Sep 5, 2023
c2acef8
Merge branch 'doc_improve_iris_example' of github.com:eguenther/sciki…
eguenther Sep 5, 2023
5764d95
Merge branch 'main' into doc_improve_iris_example
eguenther Sep 5, 2023
474e8cc
MAINT remove unecessary check
glemaitre Sep 7, 2023
78 changes: 46 additions & 32 deletions examples/datasets/plot_iris_dataset.py
@@ -1,7 +1,7 @@
"""
=========================================================
================
The Iris Dataset
=========================================================
================
This data set consists of 3 different types of irises'
(Setosa, Versicolour, and Virginica) petal and sepal
length, stored in a 150x4 numpy.ndarray
@@ -19,37 +19,47 @@
# Modified for documentation by Jaques Grobler
# License: BSD 3 clause

import matplotlib.pyplot as plt

# unused but required import for doing 3d projections with matplotlib < 3.2
import mpl_toolkits.mplot3d # noqa: F401

# %%
# Loading the iris dataset
# ------------------------
from sklearn import datasets
from sklearn.decomposition import PCA

# import some data to play with
iris = datasets.load_iris()
X = iris.data[:, :2] # we only take the first two features.
y = iris.target

x_min, x_max = X[:, 0].min() - 0.5, X[:, 0].max() + 0.5
y_min, y_max = X[:, 1].min() - 0.5, X[:, 1].max() + 0.5

plt.figure(2, figsize=(8, 6))
plt.clf()
# %%
# Scatter Plot of the Iris dataset
# --------------------------------
import matplotlib.pyplot as plt

_, ax = plt.subplots()
scatter = ax.scatter(iris.data[:, 0], iris.data[:, 1], c=iris.target)
ax.set(xlabel=iris.feature_names[0], ylabel=iris.feature_names[1])
_ = ax.legend(
    scatter.legend_elements()[0], iris.target_names, loc="lower right", title="Classes"
)

# Plot the training points
plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.Set1, edgecolor="k")
plt.xlabel("Sepal length")
plt.ylabel("Sepal width")
# %%
# Each point in the scatter plot represents one of the 150 iris flowers
# in the dataset, with the color indicating its type
# (Setosa, Versicolour, and Virginica).
# A pattern is already visible for the Setosa type, which is
# easily identifiable by its short and wide sepal. Considering only
# these 2 dimensions, sepal width and length, there is still
# overlap between the Versicolour and Virginica types.

# %%
# Plot a PCA representation
# -------------------------
# Let's apply a Principal Component Analysis (PCA) to the iris dataset
# and then plot the irises across the first three PCA dimensions.
# This will allow us to better differentiate between the three types!

plt.xlim(x_min, x_max)
plt.ylim(y_min, y_max)
plt.xticks(())
plt.yticks(())
# unused but required import for doing 3d projections with matplotlib < 3.2
import mpl_toolkits.mplot3d # noqa: F401

from sklearn.decomposition import PCA

# To getter a better understanding of interaction of the dimensions
# plot the first three PCA dimensions
fig = plt.figure(1, figsize=(8, 6))
ax = fig.add_subplot(111, projection="3d", elev=-150, azim=110)

@@ -58,18 +68,22 @@
    X_reduced[:, 0],
    X_reduced[:, 1],
    X_reduced[:, 2],
    c=y,
    cmap=plt.cm.Set1,
    edgecolor="k",
    c=iris.target,
    s=40,
)

ax.set_title("First three PCA directions")
ax.set_xlabel("1st eigenvector")
ax.set_title("First three PCA dimensions")
ax.set_xlabel("1st Eigenvector")
ax.xaxis.set_ticklabels([])
ax.set_ylabel("2nd eigenvector")
ax.set_ylabel("2nd Eigenvector")
ax.yaxis.set_ticklabels([])
ax.set_zlabel("3rd eigenvector")
ax.set_zlabel("3rd Eigenvector")
ax.zaxis.set_ticklabels([])

plt.show()

# %%
# PCA creates 3 new features, each a linear combination of the
# 4 original features, chosen so that each component in turn captures as
# much of the remaining variance as possible. With this transformation,
# the three species can largely be told apart using only the first
# feature (i.e. the first principal component).
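
The closing comment's statements about the PCA features can be checked numerically. The sketch below is not part of this PR; it is a minimal illustration that reuses the same `PCA(n_components=3)` call and inspects the fitted estimator's `components_` and `explained_variance_ratio_` attributes. The per-species loop is an added assumption about how one might look at the separation along the first component.

```python
# Illustrative sketch only (not part of the PR diff).
from sklearn import datasets
from sklearn.decomposition import PCA

iris = datasets.load_iris()

pca = PCA(n_components=3)
X_reduced = pca.fit_transform(iris.data)

# Each row of components_ holds the weights of one linear combination
# of the 4 original features.
print(pca.components_.shape)

# Fraction of the total variance captured by each component,
# listed in decreasing order.
print(pca.explained_variance_ratio_)

# Range of each species along the first component, to see how well
# that single feature separates them.
for label, name in enumerate(iris.target_names):
    first = X_reduced[iris.target == label, 0]
    print(f"{name}: min={first.min():.2f}, max={first.max():.2f}")
```

If the printed ranges for two species overlap, the first component alone does not fully separate them, which is why the example plots three components.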
3 changes: 3 additions & 0 deletions sklearn/datasets/_base.py
@@ -667,6 +667,9 @@ def load_iris(*, return_X_y=False, as_frame=False):
    array([0, 0, 1])
    >>> list(data.target_names)
    ['setosa', 'versicolor', 'virginica']

    See :ref:`sphx_glr_auto_examples_datasets_plot_iris_dataset.py` for a more
    detailed example of how to work with the iris dataset.
    """
    data_file_name = "iris.csv"
    data, target, target_names, fdescr = load_csv_data(
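
For readers following the new docstring link, here is a short sketch of calling `load_iris` itself. It is not taken from the PR; it uses only the default call and the `return_X_y` parameter visible in the function signature above. The `as_frame` parameter is left out so the sketch does not depend on pandas.

```python
# Illustrative sketch only (not part of the PR diff).
from sklearn.datasets import load_iris

# Default call: a Bunch holding the data, the targets, and metadata.
iris = load_iris()
print(iris.data.shape)
print(list(iris.target_names))
print(iris.feature_names)

# return_X_y=True returns only the (data, target) pair.
X, y = load_iris(return_X_y=True)
print(X.shape, y.shape)
```

The Bunch form is convenient when the feature and class names are needed for axis labels and legends, as in the updated example; the tuple form is shorter when only the arrays are used.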