1 change: 1 addition & 0 deletions doc/modules/classes.rst
Original file line number Diff line number Diff line change
@@ -122,6 +122,7 @@ Nearest Neighbors
:template: class.rst

neighbors.Neighbors
neighbors.NeighborsBarycenter
ball_tree.BallTree

.. autosummary::
15 changes: 12 additions & 3 deletions doc/modules/neighbors.rst
@@ -34,11 +34,20 @@ the k nearest neighbors of a point is assigned to this point.
Regression
==========

Nearest neighbor regression is not (yet) implemented, yet it should be
straightforward using the BallTree class.
The :class:`NeighborsBarycenter` estimator implements a nearest-neighbors
regression method using barycenter weighting of the targets of the
k-neighbors.

.. figure:: ../auto_examples/images/plot_neighbors_regression.png
:target: ../auto_examples/plot_neighbors_regression.html
:align: center
:scale: 75


.. currentmodule:: scikits.learn.ball_tree
.. topic:: Examples:

* :ref:`example_plot_neighbors_regression.py`: an example of regression
using nearest neighbors.
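
The barycenter weighting described in this section can be sketched in plain NumPy. This is only an illustration under stated assumptions, not the code of :class:`NeighborsBarycenter`: it assumes that "barycenter weights" are the weights, summing to one, that best reconstruct the query point from its k nearest neighbors, and that the prediction is the same weighted combination of the neighbors' targets. The names ``barycenter_weights`` and ``predict_barycenter``, the ``reg`` parameter, and the brute-force neighbor search are hypothetical choices made for this sketch.

```python
import numpy as np


def barycenter_weights(x, neighbors, reg=1e-3):
    """Return weights w (summing to 1) that approximately minimise
    ||x - sum_i w_i * neighbors[i]||^2.

    Hypothetical helper for illustration; `reg` is an assumed
    regularisation factor, not a documented library parameter.
    """
    diff = neighbors - x                 # shape (k, n_features)
    gram = np.dot(diff, diff.T)          # local Gram matrix, shape (k, k)
    trace = np.trace(gram)
    # The Gram matrix is singular when k > n_features, so add a small
    # multiple of the identity before solving.
    k = len(neighbors)
    gram = gram + np.eye(k) * reg * (trace if trace > 0 else 1.0)
    w = np.linalg.solve(gram, np.ones(k))
    return w / w.sum()


def predict_barycenter(X, y, T, n_neighbors=5):
    """Predict targets for the rows of T from the training set (X, y)."""
    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=float)
    T = np.asarray(T, dtype=float).reshape(-1, X.shape[1])
    preds = np.empty(len(T))
    for i, t in enumerate(T):
        # Brute-force search: squared Euclidean distance to every sample.
        dist = np.sum((X - t) ** 2, axis=1)
        idx = np.argsort(dist)[:n_neighbors]
        w = barycenter_weights(t, X[idx])
        preds[i] = np.dot(w, y[idx])
    return preds
```

With two neighbors in one dimension and a query lying between them, these weights reduce (up to the small regularisation term) to linear interpolation between the two targets. A ball tree would replace the brute-force distance computation when the training set is large.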

Efficient implementation: the ball tree
==========================================
6 changes: 3 additions & 3 deletions examples/plot_neighbors.py
@@ -18,7 +18,7 @@
# avoid this ugly slicing by using a two-dim dataset
Y = iris.target

h=.02 # step size in the mesh
h = .02 # step size in the mesh

# we create an instance of SVM and fit our data. We do not scale our
# data since we want to plot the support vectors
@@ -27,8 +27,8 @@

# Plot the decision boundary. For that, we will assign a color to each
# point in the mesh [x_min, x_max]x[y_min, y_max].
x_min, x_max = X[:,0].min()-1, X[:,0].max()+1
y_min, y_max = X[:,1].min()-1, X[:,1].max()+1
x_min, x_max = X[:,0].min()-1, X[:,0].max() + 1
y_min, y_max = X[:,1].min()-1, X[:,1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h))
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])

44 changes: 44 additions & 0 deletions examples/plot_neighbors_regression.py
@@ -0,0 +1,44 @@
"""
==============================
k-Nearest Neighbors regression
==============================

Demonstrate the resolution of a regression problem
using k-Nearest Neighbors, with the target interpolated
by barycenter computation.

"""
print(__doc__)

###############################################################################
# Generate sample data
import numpy as np

np.random.seed(0)
X = np.sort(5*np.random.rand(40, 1), axis=0)
T = np.linspace(0, 5, 500)
y = np.sin(X).ravel()

# Add noise to targets
y[::5] += 1*(0.5 - np.random.rand(8))

###############################################################################
# Fit regression model

from scikits.learn import neighbors

knn_barycenter = neighbors.NeighborsBarycenter(n_neighbors=5)
y_ = knn_barycenter.fit(X, y).predict(T)

###############################################################################
# look at the results
import pylab as pl
pl.scatter(X, y, c='k', label='data')
pl.hold('on')
pl.plot(T, y_, c='g', label='k-NN prediction')
pl.xlabel('data')
pl.ylabel('target')
pl.legend()
pl.title('k-NN Regression')
pl.show()

6 changes: 2 additions & 4 deletions examples/svm/plot_weighted_samples.py
@@ -14,7 +14,7 @@

# we create 20 points
np.random.seed(0)
X = np.r_[np.random.randn(10, 2) + [1,1], np.random.randn(10, 2)]
X = np.r_[np.random.randn(10, 2) + [1, 1], np.random.randn(10, 2)]
Y = [1]*10 + [-1]*10
sample_weight = 100 * np.abs(np.random.randn(20))
# and assign a bigger weight to the last 10 samples
@@ -27,16 +27,14 @@
# get the separating hyperplane
xx, yy = np.meshgrid(np.linspace(-4, 5, 500), np.linspace(-4, 5, 500))


# plot the line, the points, and the nearest vectors to the plane
Z = clf.decision_function(np.c_[xx.ravel(), yy.ravel()])
Z = Z.reshape(xx.shape)

# plot the line, the points, and the nearest vectors to the plane
pl.set_cmap(pl.cm.bone)
pl.contourf(xx, yy, Z, alpha=0.75)
pl.scatter(X[:,0], X[:,1], c=Y, s=sample_weight, alpha=0.9)
pl.scatter(X[:, 0], X[:, 1], c=Y, s=sample_weight, alpha=0.9)

pl.axis('tight')
pl.show()
