[MRG] Example Bag-of-Visual-Words #6509

Closed
wants to merge 17 commits
108 changes: 108 additions & 0 deletions examples/bovw/plot_bovw.py
@@ -0,0 +1,108 @@
"""
===================
Bag of Visual Words
===================

An illustration of the Bag of Visual Words (BoVW) approach using textons for
image recognition [1]_. The approach shown here is deliberately naive and
simplified: it uses a limited number of patches and visual words and does not
rely on any other well-known computer vision features. The goal is to
illustrate the BoVW pipeline within a limited processing time rather than to
achieve high classification performance.

References
----------
.. [1] Varma, Manik, and Andrew Zisserman. "A statistical approach to material
       classification using image patch exemplars." IEEE Transactions on
       Pattern Analysis and Machine Intelligence 31.11 (2009): 2032-2047.

"""

# Author: Guillaume Lemaitre <g.lemaitre58@gmail.com>
# License: BSD

import numpy as np
from imageio import imread  # scipy.misc.imread was removed in SciPy 1.2

from sklearn.feature_extraction.image import extract_patches_2d
from sklearn.model_selection import StratifiedKFold
from sklearn.cluster import MiniBatchKMeans
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import confusion_matrix

from tudarmstadt import fetch_tu_darmstadt
Review comment (Member):

This needs to be moved to sklearn/datasets
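A minimal sketch of what the example's import would become once the loader is
moved (the module path and re-export below are assumptions modeled on how
existing fetchers such as fetch_lfw_people are exposed, not part of this PR):

    # hypothetical: the loader lives in sklearn/datasets (e.g. tu_darmstadt.py)
    # and is re-exported from sklearn/datasets/__init__.py, so the example
    # would simply do:
    from sklearn.datasets import fetch_tu_darmstadt

    png_files, labels = fetch_tu_darmstadt()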



print(__doc__)


def image_extraction(path_image, rng, patch_size=(9, 9),
                     max_patches=100000):
    """Extract random patches from an image and flatten them into vectors."""
    # Read the current image; all images are expected to be RGB (the single
    # grayscale image of the dataset is removed by the fetcher)
    im = imread(path_image)
    # Extract random patches and flatten each one into a 1D feature vector
    patch = extract_patches_2d(im, patch_size=patch_size,
                               max_patches=max_patches, random_state=rng)

    return patch.reshape((patch.shape[0], np.prod(patch_size) * len(im.shape)))

# Define the parameters used below; the number of patches and the codebook
# size are kept small so that the example runs quickly
patch_size = (9, 9)
max_patches_classify = 2000
nb_words = 50
n_jobs = 1
rng = 42

# Load the data
png_files, labels = fetch_tu_darmstadt()

# Extract the patches of every image and flatten them into feature vectors
patch_arr = [image_extraction(path_im, rng, patch_size, max_patches_classify)
for path_im in png_files]

print('Extracted patches for image classification')

# Use a stratified K-fold split: the codebook (dictionary of visual words) is
# learnt on the training images of the first fold only
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=rng)

# Get the training and testing indices from the first fold
train_idx, test_idx = next(skf.split(patch_arr, labels))

# Build the codebook: cluster the training patches into `nb_words` visual words
vq = MiniBatchKMeans(n_clusters=nb_words, verbose=False, init='random',
                     batch_size=10 * nb_words, compute_labels=False,
                     reassignment_ratio=0.0, random_state=rng, n_init=3)
# Stack the patches of the training images and find the centroids
stack_training = np.vstack([patch_arr[t] for t in train_idx])
vq.fit(stack_training)

print('Codebook learnt')

# Encode each image as a normalized histogram of visual-word assignments.
# Using `nb_words + 1` bin edges yields exactly one bin per visual word.
train_data = np.vstack([np.histogram(vq.predict(patch_arr[tr_im]),
                                     bins=range(nb_words + 1),
                                     density=True)[0]
                        for tr_im in train_idx])
train_label = labels[train_idx]

test_data = np.vstack([np.histogram(vq.predict(patch_arr[te_im]),
                                    bins=range(nb_words + 1),
                                    density=True)[0]
                       for te_im in test_idx])
test_label = labels[test_idx]

# Classification using Random Forest
rf = RandomForestClassifier(n_estimators=10, random_state=rng, n_jobs=n_jobs)
pred = rf.fit(train_data, train_label).predict(test_data)

print('Classification performed - the confusion matrix obtained is:')
print(confusion_matrix(test_label, pred))
121 changes: 121 additions & 0 deletions examples/bovw/tudarmstadt.py
@@ -0,0 +1,121 @@
"""TU Darmstadt dataset.
The original database was available from
http://host.robots.ox.ac.uk/pascal/VOC/download/tud.tar.gz
"""

import os
from os.path import join, exists
try:
# Python 2
from urllib2 import HTTPError
from urllib2 import urlopen
except ImportError:
# Python 3+
from urllib.error import HTTPError
from urllib.request import urlopen

import tarfile

import numpy as np

from sklearn.datasets import get_data_home

DATA_URL = "http://host.robots.ox.ac.uk/pascal/VOC/download/tud.tar.gz"
TARGET_FILENAME = "tud.pkz"

# Grab the module-level docstring to use as a description of the
# dataset
MODULE_DOCS = __doc__


def fetch_tu_darmstadt(data_home=None):
"""Loader for the TU Darmstadt dataset.

Read more in the :ref:`User Guide <datasets>`.


Parameters
----------
data_home : optional, default: None
Specify another download and cache folder for the datasets. By default
        all scikit-learn data is stored in '~/scikit_learn_data' subfolders.

Returns
-------
images_list : list
Python list with the path of each image to consider during the
classification.

    labels : array-like, shape (n_images,)
        An array with the label of each image, encoding its category:
        0: motorbikes - 1: cars - 2: cows.

    Notes
    -----
    The dataset is composed of 124 motorbike images, 100 car images, and
    112 cow images.

Examples
--------
Load the 'tu-darmstadt' dataset:

>>> from tudarmstadt import fetch_tu_darmstadt
>>> import tempfile
>>> test_data_home = tempfile.mkdtemp()
>>> im_list, labels = fetch_tu_darmstadt(data_home=test_data_home)
"""

    # check if the data has already been downloaded
data_home = get_data_home(data_home=data_home)
data_home = join(data_home, 'tu_darmstadt')
if not exists(data_home):
os.makedirs(data_home)

# dataset tar file
filename = join(data_home, 'tud.tar.gz')

# if the file does not exist, download it
if not exists(filename):
try:
db_url = urlopen(DATA_URL)
with open(filename, 'wb') as f:
f.write(db_url.read())
db_url.close()
except HTTPError as e:
if e.code == 404:
e.msg = 'TU Darmstadt dataset not found.'
raise
    # Try to extract the complete archive
    try:
        tarfile.open(filename, "r:gz").extractall(path=data_home)
    except Exception:
        os.remove(filename)
        raise

    # the image 'motorbikes023' is grayscale and needs to be removed
file_removal = [
join(data_home,
'TUDarmstadt/PNGImages/motorbike-testset/motorbikes023.png'),
join(data_home,
'TUDarmstadt/Annotations/motorbike-testset/motorbikes023.txt'),
]
for f in file_removal:
os.remove(f)

# list the different images
data_path = join(data_home, 'TUDarmstadt/PNGImages')
images_list = [os.path.join(root, name)
for root, dirs, files in os.walk(data_path)
for name in files
                   if name.endswith(".png")]

# create the label array
labels = []
for imf in images_list:
if 'motorbike' in imf:
labels.append(0)
elif 'cars' in imf:
labels.append(1)
elif 'cows' in imf:
labels.append(2)

    # return the image paths and the corresponding label array
return images_list, np.array(labels)
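Side note: the TARGET_FILENAME constant defined above is currently unused. A
minimal sketch of how it could cache the resolved image paths and labels
between calls, assuming joblib-based pickling as done by other fetchers (the
wrapper name is hypothetical and not part of this PR):

    import joblib  # assumption: standalone joblib; older scikit-learn shipped it as sklearn.externals.joblib


    def fetch_tu_darmstadt_cached(data_home=None):
        """Hypothetical wrapper caching the image paths and labels in
        TARGET_FILENAME so the extracted archive is only walked once."""
        cache_path = join(get_data_home(data_home=data_home), TARGET_FILENAME)
        if exists(cache_path):
            return joblib.load(cache_path)
        images_list, labels = fetch_tu_darmstadt(data_home=data_home)
        joblib.dump((images_list, labels), cache_path, compress=6)
        return images_list, labels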