Skip to content

Commit 821f83c

Browse files
committed
added transfer learning tutorial
1 parent edbb97c commit 821f83c

File tree

13 files changed

+168
-0
lines changed

13 files changed

+168
-0
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ This is a repository of all the tutorials of [The Python Code](https://www.thepy
2626
- [How to Detect Human Faces in Python using OpenCV](https://www.thepythoncode.com/article/detect-faces-opencv-python). ([code](machine-learning/face_detection))
2727
- [Building a Speech Emotion Recognizer using Scikit-learn](https://www.thepythoncode.com/article/building-a-speech-emotion-recognizer-using-sklearn). ([code](machine-learning/speech-emotion-recognition))
2828
- [How to Make an Image Classifier in Python using Keras](https://www.thepythoncode.com/article/image-classification-keras-python). ([code](machine-learning/image-classifier))
29+
- [How to Use Transfer Learning for Image Classification using Keras in Python](https://www.thepythoncode.com/article/use-transfer-learning-for-image-flower-classification-keras-python). ([code](machine-learning/image-classifier-using-transfer-learning))
2930
- [Top 8 Python Libraries For Data Scientists and Machine Learning Engineers](https://www.thepythoncode.com/article/top-python-libraries-for-data-scientists).
3031

3132
- ### [General Python Topics](https://www.thepythoncode.com/topic/general-python-topics)
@@ -46,3 +47,4 @@ This is a repository of all the tutorials of [The Python Code](https://www.thepy
4647
- [How to Extract Weather Data from Google in Python](https://www.thepythoncode.com/article/extract-weather-data-python). ([code](web-scraping/weather-extractor))
4748
- [How to Download All Images from a Web Page in Python](https://www.thepythoncode.com/article/download-web-page-images-python). ([code](web-scraping/download-images))
4849

50+
For any feedback, please consider pulling requests.
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
# [How to Use Transfer Learning for Image Classification using Keras in Python](https://www.thepythoncode.com/article/use-transfer-learning-for-image-flower-classification-keras-python)
2+
To run this:
3+
- `pip3 install -r requirements.txt`
4+
- To train the model (already trained and the optimal weights are in `results` folder):
5+
```
6+
python train.py
7+
```
8+
This will load the flower dataset, construct the `MobileNetV2` model with its weights and starts training.
9+
- 86% accuracy was achieved on 5 classes of flowers which are `daisy`, `dandelion`, `roses`, `sunflowers` and `tulips`.
10+
- To evaluate the model as well as visualizing different flowers and its corresponding predictions:
11+
```
12+
python test.py
13+
```
14+
This will **Outputs:**
15+
```
16+
23/23 [==============================] - 6s 264ms/step
17+
Val loss: 0.5659930361524
18+
Val Accuracy: 0.8166894659134987
19+
```
20+
and **plots**:
21+
![Predicted Flowers](predicted-flowers.png)
22+
Check the [tutorial](https://www.thepythoncode.com/article/use-transfer-learning-for-image-flower-classification-keras-python) for more information.
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
tensorflow
2+
keras
3+
numpy
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
from train import load_data, create_model, IMAGE_SHAPE, batch_size, np
2+
import matplotlib.pyplot as plt
3+
# load the data generators
4+
train_generator, validation_generator, class_names = load_data()
5+
# constructs the model
6+
model = create_model(input_shape=IMAGE_SHAPE)
7+
# load the optimal weights
8+
model.load_weights("results/MobileNetV2_finetune_last5_less_lr-loss-0.45-acc-0.86.h5")
9+
10+
validation_steps_per_epoch = np.ceil(validation_generator.samples / batch_size)
11+
# print the validation loss & accuracy
12+
evaluation = model.evaluate_generator(validation_generator, steps=validation_steps_per_epoch, verbose=1)
13+
print("Val loss:", evaluation[0])
14+
print("Val Accuracy:", evaluation[1])
15+
16+
# get a random batch of images
17+
image_batch, label_batch = next(iter(validation_generator))
18+
# turn the original labels into human-readable text
19+
label_batch = [class_names[np.argmax(label_batch[i])] for i in range(batch_size)]
20+
# predict the images on the model
21+
predicted_class_names = model.predict(image_batch)
22+
predicted_ids = [np.argmax(predicted_class_names[i]) for i in range(batch_size)]
23+
# turn the predicted vectors to human readable labels
24+
predicted_class_names = np.array([class_names[id] for id in predicted_ids])
25+
26+
# some nice plotting
27+
plt.figure(figsize=(10,9))
28+
for n in range(30):
29+
plt.subplot(6,5,n+1)
30+
plt.subplots_adjust(hspace = 0.3)
31+
plt.imshow(image_batch[n])
32+
if predicted_class_names[n] == label_batch[n]:
33+
color = "blue"
34+
title = predicted_class_names[n].title()
35+
else:
36+
color = "red"
37+
title = f"{predicted_class_names[n].title()}, correct:{label_batch[n]}"
38+
plt.title(title, color=color)
39+
plt.axis('off')
40+
_ = plt.suptitle("Model predictions (blue: correct, red: incorrect)")
41+
plt.show()
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
2+
import tensorflow as tf
3+
from keras.models import Model
4+
from keras.applications import MobileNetV2, ResNet50, InceptionV3 # try to use them and see which is better
5+
from keras.layers import Dense
6+
from keras.callbacks import ModelCheckpoint, TensorBoard
7+
from keras.utils import get_file
8+
from keras.preprocessing.image import ImageDataGenerator
9+
import os
10+
import pathlib
11+
import numpy as np
12+
13+
batch_size = 32
14+
num_classes = 5
15+
epochs = 10
16+
17+
IMAGE_SHAPE = (224, 224, 3)
18+
19+
20+
def load_data():
21+
"""This function downloads, extracts, loads, normalizes and one-hot encodes Flower Photos dataset"""
22+
# download the dataset and extract it
23+
data_dir = get_file(origin='https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
24+
fname='flower_photos', untar=True)
25+
data_dir = pathlib.Path(data_dir)
26+
27+
# count how many images are there
28+
image_count = len(list(data_dir.glob('*/*.jpg')))
29+
print("Number of images:", image_count)
30+
31+
# get all classes for this dataset (types of flowers) excluding LICENSE file
32+
CLASS_NAMES = np.array([item.name for item in data_dir.glob('*') if item.name != "LICENSE.txt"])
33+
34+
# roses = list(data_dir.glob('roses/*'))
35+
# 20% validation set 80% training set
36+
image_generator = ImageDataGenerator(rescale=1/255, validation_split=0.2)
37+
38+
# make the training dataset generator
39+
train_data_gen = image_generator.flow_from_directory(directory=str(data_dir), batch_size=batch_size,
40+
classes=list(CLASS_NAMES), target_size=(IMAGE_SHAPE[0], IMAGE_SHAPE[1]),
41+
shuffle=True, subset="training")
42+
# make the validation dataset generator
43+
test_data_gen = image_generator.flow_from_directory(directory=str(data_dir), batch_size=batch_size,
44+
classes=list(CLASS_NAMES), target_size=(IMAGE_SHAPE[0], IMAGE_SHAPE[1]),
45+
shuffle=True, subset="validation")
46+
47+
return train_data_gen, test_data_gen, CLASS_NAMES
48+
49+
50+
def create_model(input_shape):
51+
# load MobileNetV2
52+
model = MobileNetV2(input_shape=input_shape)
53+
# remove the last fully connected layer
54+
model.layers.pop()
55+
# freeze all the weights of the model except the last 4 layers
56+
for layer in model.layers[:-4]:
57+
layer.trainable = False
58+
# construct our own fully connected layer for classification
59+
output = Dense(num_classes, activation="softmax")
60+
# connect that dense layer to the model
61+
output = output(model.layers[-1].output)
62+
63+
model = Model(inputs=model.inputs, outputs=output)
64+
65+
# print the summary of the model architecture
66+
model.summary()
67+
68+
# training the model using rmsprop optimizer
69+
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
70+
return model
71+
72+
73+
if __name__ == "__main__":
74+
75+
# load the data generators
76+
train_generator, validation_generator, class_names = load_data()
77+
78+
# constructs the model
79+
model = create_model(input_shape=IMAGE_SHAPE)
80+
# model name
81+
model_name = "MobileNetV2_finetune_last5"
82+
83+
# some nice callbacks
84+
tensorboard = TensorBoard(log_dir=f"logs/{model_name}")
85+
checkpoint = ModelCheckpoint(f"results/{model_name}" + "-loss-{val_loss:.2f}-acc-{val_acc:.2f}.h5",
86+
save_best_only=True,
87+
verbose=1)
88+
89+
# make sure results folder exist
90+
if not os.path.isdir("results"):
91+
os.mkdir("results")
92+
93+
# count number of steps per epoch
94+
training_steps_per_epoch = np.ceil(train_generator.samples / batch_size)
95+
validation_steps_per_epoch = np.ceil(validation_generator.samples / batch_size)
96+
97+
# train using the generators
98+
model.fit_generator(train_generator, steps_per_epoch=training_steps_per_epoch,
99+
validation_data=validation_generator, validation_steps=validation_steps_per_epoch,
100+
epochs=epochs, verbose=1, callbacks=[tensorboard, checkpoint])

0 commit comments

Comments
 (0)