diff --git a/examples/bovw/plot_bovw.py b/examples/bovw/plot_bovw.py
new file mode 100644
index 0000000000000..e19cad8d2bc13
--- /dev/null
+++ b/examples/bovw/plot_bovw.py
@@ -0,0 +1,108 @@
+"""
+===================
+Bag of Visual Words
+===================
+
+An illustration of a Bag of Visual Words (BoVW) approach using textons for
+image recognition [1]_. The proposed solution is a naive and simplified
+approach using a limited number of patches and words, and it does not rely
+on any other well-known computer vision features. It aims at illustrating
+BoVW within a limited processing time rather than at achieving high
+classification performance.
+
+References
+----------
+.. [1] Varma, Manik, and Andrew Zisserman. "A statistical approach to material
+       classification using image patch exemplars." IEEE Transactions on
+       Pattern Analysis and Machine Intelligence 31.11 (2009): 2032-2047.
+
+"""
+
+# Author: Guillaume Lemaitre
+# License: BSD
+
+import numpy as np
+from scipy.misc import imread
+
+from sklearn.feature_extraction.image import extract_patches_2d
+from sklearn.model_selection import StratifiedKFold
+from sklearn.cluster import MiniBatchKMeans
+from sklearn.ensemble import RandomForestClassifier
+from sklearn.metrics import confusion_matrix
+
+from tudarmstadt import fetch_tu_darmstadt
+
+
+print(__doc__)
+
+
+def image_extraction(path_image, rng, patch_size=(9, 9),
+                     max_patches=100000):
+    """Extract random patches from an image and flatten them into vectors."""
+    # Read the current image
+    im = imread(path_image)
+    # Extract a random subset of patches
+    patch = extract_patches_2d(im, patch_size=patch_size,
+                               max_patches=max_patches, random_state=rng)
+
+    return patch.reshape((patch.shape[0],
+                          np.prod(patch_size) * len(im.shape)))
+
+
+# Define the parameters used afterwards
+patch_size = (9, 9)
+n_jobs = 1
+max_patches_classify = 2000
+nb_words = 50
+rng = 42
+
+# Load the data
+png_files, labels = fetch_tu_darmstadt()
+
+# Extract the patches from each image
+patch_arr = [image_extraction(path_im, rng, patch_size, max_patches_classify)
+             for path_im in png_files]
+
+print('Extracted patches for image classification')
+
+# Apply a stratified K-fold split; the dictionary will be learnt on the
+# training fold only
+skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=rng)
+
+# Get the training and testing indices from the first fold
+train_idx, test_idx = next(skf.split(patch_arr, labels))
+
+# Build the codebook with the requested number of visual words
+vq = MiniBatchKMeans(n_clusters=nb_words, verbose=False, init='random',
+                     batch_size=10 * nb_words, compute_labels=False,
+                     reassignment_ratio=0.0, random_state=rng, n_init=3)
+# Stack the training examples
+stack_training = np.vstack([patch_arr[t] for t in train_idx])
+# Find the centroids
+vq.fit(stack_training)
+
+print('Codebook learnt')
+
+# Build the training and testing data: each image is described by the
+# normalised histogram of its visual-word assignments
+train_data = np.vstack([np.histogram(vq.predict(patch_arr[tr_im]),
+                                     bins=np.arange(nb_words + 1),
+                                     density=True)[0]
+                        for tr_im in train_idx])
+train_label = labels[train_idx]
+
+test_data = np.vstack([np.histogram(vq.predict(patch_arr[te_im]),
+                                    bins=np.arange(nb_words + 1),
+                                    density=True)[0]
+                       for te_im in test_idx])
+test_label = labels[test_idx]
+
+# Classification using Random Forest
+rf = RandomForestClassifier(n_estimators=10, random_state=rng, n_jobs=n_jobs)
+pred = rf.fit(train_data, train_label).predict(test_data)
+
+print('Classification performed - the confusion matrix obtained is:')
+print(confusion_matrix(test_label, pred))
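For reference, a minimal self-contained sketch of the encoding step used in plot_bovw.py: patches are vector-quantised with MiniBatchKMeans and each image is then described by the normalised histogram of its visual-word assignments. The random data and the names 'n_words' and 'patches' below are illustrative only and are not part of the example files.

    import numpy as np
    from sklearn.cluster import MiniBatchKMeans

    rng = np.random.RandomState(42)
    n_words = 50
    # Stand-in for the flattened 9x9x3 image patches of the example.
    patches = rng.rand(2000, 9 * 9 * 3)

    # Learn the codebook (the visual words) from the patch matrix.
    vq = MiniBatchKMeans(n_clusters=n_words, random_state=rng).fit(patches)

    # Describe one "image" (here, a subset of the patches) by the normalised
    # histogram of its word assignments, i.e. its BoVW descriptor.
    words = vq.predict(patches[:200])
    bovw, _ = np.histogram(words, bins=np.arange(n_words + 1), density=True)
    print(bovw.shape)  # (50,)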
diff --git a/examples/bovw/tudarmstadt.py b/examples/bovw/tudarmstadt.py
new file mode 100644
index 0000000000000..93780781db776
--- /dev/null
+++ b/examples/bovw/tudarmstadt.py
@@ -0,0 +1,121 @@
+"""TU Darmstadt dataset.
+
+The original database was available from
+
+    http://host.robots.ox.ac.uk/pascal/VOC/download/tud.tar.gz
+"""
+
+import os
+import tarfile
+from os.path import join, exists
+try:
+    # Python 2
+    from urllib2 import HTTPError
+    from urllib2 import urlopen
+except ImportError:
+    # Python 3+
+    from urllib.error import HTTPError
+    from urllib.request import urlopen
+
+import numpy as np
+
+from sklearn.datasets import get_data_home
+
+DATA_URL = "http://host.robots.ox.ac.uk/pascal/VOC/download/tud.tar.gz"
+TARGET_FILENAME = "tud.pkz"
+
+# Grab the module-level docstring to use as a description of the dataset
+MODULE_DOCS = __doc__
+
+
+def fetch_tu_darmstadt(data_home=None):
+    """Loader for the TU Darmstadt dataset.
+
+    Read more in the :ref:`User Guide `.
+
+    Parameters
+    ----------
+    data_home : optional, default: None
+        Specify another download and cache folder for the datasets. By
+        default all scikit-learn data is stored in '~/scikit_learn_data'
+        subfolders.
+
+    Returns
+    -------
+    images_list : list
+        Python list with the path of each image to consider during the
+        classification.
+
+    labels : array-like, shape (n_images,)
+        An array with the label corresponding to the category of each image.
+        0: motorbikes - 1: cars - 2: cows.
+
+    Notes
+    -----
+    The dataset is composed of 124 motorbike images, 100 car images and
+    112 cow images.
+
+    Examples
+    --------
+    Load the 'tu-darmstadt' dataset:
+
+    >>> from tudarmstadt import fetch_tu_darmstadt
+    >>> import tempfile
+    >>> test_data_home = tempfile.mkdtemp()
+    >>> im_list, labels = fetch_tu_darmstadt(data_home=test_data_home)
+    """
+    # check if the data has already been downloaded
+    data_home = get_data_home(data_home=data_home)
+    data_home = join(data_home, 'tu_darmstadt')
+    if not exists(data_home):
+        os.makedirs(data_home)
+
+    # dataset tar file
+    filename = join(data_home, 'tud.tar.gz')
+
+    # if the file does not exist, download it
+    if not exists(filename):
+        try:
+            db_url = urlopen(DATA_URL)
+            with open(filename, 'wb') as f:
+                f.write(db_url.read())
+            db_url.close()
+        except HTTPError as e:
+            if e.code == 404:
+                e.msg = 'TU Darmstadt dataset not found.'
+            raise
+
+    # Try to extract the complete archive
+    try:
+        tarfile.open(filename, "r:gz").extractall(path=data_home)
+    except Exception:
+        os.remove(filename)
+        raise
+
+    # the file 'motorbikes023' is a gray-scale image and needs to be removed
+    file_removal = [
+        join(data_home,
+             'TUDarmstadt/PNGImages/motorbike-testset/motorbikes023.png'),
+        join(data_home,
+             'TUDarmstadt/Annotations/motorbike-testset/motorbikes023.txt'),
+    ]
+    for f in file_removal:
+        os.remove(f)
+
+    # list the different images
+    data_path = join(data_home, 'TUDarmstadt/PNGImages')
+    images_list = [os.path.join(root, name)
+                   for root, dirs, files in os.walk(data_path)
+                   for name in files
+                   if name.endswith(".png")]
+
+    # create the label array
+    labels = []
+    for imf in images_list:
+        if 'motorbike' in imf:
+            labels.append(0)
+        elif 'cars' in imf:
+            labels.append(1)
+        elif 'cows' in imf:
+            labels.append(2)
+
+    # Return the image paths and the corresponding labels
+    return images_list, np.array(labels)
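As a quick sanity check once plot_bovw.py has run, the confusion matrix can be summarised into overall and per-class accuracy. This snippet is only an illustration: it assumes the 'test_label' and 'pred' arrays created at the end of the example and is not part of the submitted files.

    import numpy as np
    from sklearn.metrics import accuracy_score, confusion_matrix

    # Overall fraction of correctly classified test images.
    print('Accuracy: {:.2f}'.format(accuracy_score(test_label, pred)))

    # Per-class accuracy for motorbikes (0), cars (1) and cows (2),
    # read off the rows of the confusion matrix.
    cm = confusion_matrix(test_label, pred)
    print(cm.diagonal().astype(float) / cm.sum(axis=1))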