# To force CPU-only execution, uncomment the code below
# (note: tf.ConfigProto is the TF1 API; in TF2 use tf.compat.v1.ConfigProto
# and pass the config to a session via tf.compat.v1.Session(config=config))
# import os
# os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"  # see issue #152
# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

# import tensorflow as tf

# config = tf.ConfigProto(intra_op_parallelism_threads=5,
#                         inter_op_parallelism_threads=5,
#                         allow_soft_placement=True,
#                         device_count={'CPU': 1,
#                                       'GPU': 0})


from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences
from keras.utils import to_categorical
from keras.callbacks import ModelCheckpoint, TensorBoard
from sklearn.model_selection import train_test_split
import time
import numpy as np
import pickle

from utils import get_embedding_vectors, get_model, SEQUENCE_LENGTH, EMBEDDING_SIZE, TEST_SIZE
from utils import BATCH_SIZE, EPOCHS, int2label, label2int


def load_data():
    """
    Loads the SMS Spam Collection dataset.
    Each line holds a label ("ham" or "spam") followed by the message text,
    so the first whitespace-separated token is the label.
    """
    texts, labels = [], []
    with open("data/SMSSpamCollection", encoding="utf-8") as f:
        for line in f:
            split = line.split()
            labels.append(split[0].strip())
            texts.append(' '.join(split[1:]).strip())
    return texts, labels


# load the data
X, y = load_data()

# Text tokenization
# vectorize the text, turning each message into a sequence of integers
tokenizer = Tokenizer()
tokenizer.fit_on_texts(X)
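# quick sanity check on the learned vocabulary; Keras reserves index 0
# for padding, hence the +1
print("Vocabulary size:", len(tokenizer.word_index) + 1)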
# dump the tokenizer to a file so the test script can reuse it
pickle.dump(tokenizer, open("results/tokenizer.pickle", "wb"))
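# sketch of how a test script would restore it (assumes the same
# "results/" directory):
#   with open("results/tokenizer.pickle", "rb") as f:
#       tokenizer = pickle.load(f)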

# convert each message to a sequence of integer word indices
X = tokenizer.texts_to_sequences(X)
print(X[0])
# X is a ragged list of lists here; pad_sequences below returns a proper
# 2D numpy array, so no np.array(X) conversion is needed (on recent NumPy
# versions, np.array() on ragged input raises an error anyway)
y = np.array(y)
# pad sequences at the beginning of each sequence with 0's
# for example, if SEQUENCE_LENGTH=4:
# [[5, 3, 2], [5, 1, 2, 3], [3, 4]]
# will be transformed to:
# [[0, 5, 3, 2], [5, 1, 2, 3], [0, 0, 3, 4]]
X = pad_sequences(X, maxlen=SEQUENCE_LENGTH)
print(X[0])
# One-hot encode the labels
# with label2int = {"ham": 0, "spam": 1},
# [spam, ham, spam, ham, ham] is first converted to:
# [1, 0, 1, 0, 0] and then to:
# [[0, 1], [1, 0], [0, 1], [1, 0], [1, 0]]

y = [label2int[label] for label in y]
y = to_categorical(y)

print(y[0])

# split into training and test sets (shuffled by default)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=TEST_SIZE, random_state=7)

# construct the model with 128 LSTM units
model = get_model(tokenizer=tokenizer, lstm_units=128)
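# print layer shapes and parameter counts (assumes get_model returns a
# compiled Keras model, which model.fit below already relies on)
model.summary()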

# initialize the ModelCheckpoint and TensorBoard callbacks
# the checkpoint saves only the best model (lowest validation loss);
# Keras fills in {val_loss:.2f} in the filename at save time
model_checkpoint = ModelCheckpoint("results/spam_classifier_{val_loss:.2f}", save_best_only=True,
                                   verbose=1)
# log training curves for visualization; a unique directory per run
tensorboard = TensorBoard(f"logs/spam_classifier_{time.time()}")
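# the logs can then be viewed with the standard TensorBoard CLI, e.g.:
#   tensorboard --logdir logs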
# print the data shapes
print("X_train.shape:", X_train.shape)
print("X_test.shape:", X_test.shape)
print("y_train.shape:", y_train.shape)
print("y_test.shape:", y_test.shape)
# train the model
model.fit(X_train, y_train, validation_data=(X_test, y_test),
          batch_size=BATCH_SIZE, epochs=EPOCHS,
          callbacks=[tensorboard, model_checkpoint],
          verbose=1)

# evaluate on the test set; evaluate() returns the loss followed by the
# metrics in the order they were passed to model.compile() in utils.get_model,
# here assumed to be [loss, accuracy, precision, recall]
result = model.evaluate(X_test, y_test)
loss = result[0]
accuracy = result[1]
precision = result[2]
recall = result[3]
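# a sturdier alternative (a sketch; works on TF2-style Keras models) is to
# zip metric names with values instead of relying on positional indices:
#   metrics = dict(zip(model.metrics_names, result))
#   print(metrics)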

print(f"[+] Accuracy: {accuracy*100:.2f}%")
print(f"[+] Precision: {precision*100:.2f}%")
print(f"[+] Recall: {recall*100:.2f}%")