#!/usr/bin/env python
# coding: utf8
"""Define a text classification model in PyTorch and wrap it with Thinc's
PyTorchWrapper class, so it can be used in spaCy's TextCategorizer component.

The component is added to the spaCy pipeline, and predictions are available
via `doc.cats`. For more details, see the documentation:

* Deep learning: https://alpha.spacy.io/usage/deep-learning
* Text classification: https://alpha.spacy.io/usage/text-classification

Developed for: spaCy 2.0.0a19
Last updated for: spaCy 2.0.0a19
"""
from __future__ import unicode_literals, print_function
import plac
import random
from pathlib import Path
import thinc.extra.datasets
from thinc.extra.wrappers import PyTorchWrapper

import spacy
from spacy.pipeline import TextCategorizer
from spacy.gold import GoldParse, minibatch
from spacy.util import compounding


@plac.annotations(
    model=("Model name. Defaults to blank 'en' model.", "option", "m", str),
    output_dir=("Optional output directory", "option", "o", Path),
    n_texts=("Number of texts to train from", "option", "t", int),
    n_iter=("Number of training iterations", "option", "n", int))
def main(model=None, output_dir=None, n_iter=20, n_texts=2000):
    if model is not None:
        nlp = spacy.load(model)  # load existing spaCy model
        print("Loaded model '%s'" % model)
    else:
        nlp = spacy.blank('en')  # create blank Language class
        print("Created blank 'en' model")

    # Create the PyTorch network (see the hypothetical create_model() sketch
    # below) and wrap it with Thinc's PyTorchWrapper, which gives it the API
    # spaCy expects of a component's model. Passing the wrapped network to
    # the TextCategorizer turns it into a proper pipeline component that can
    # be added to the pipeline and assigned labels.
    pt_model = create_model()
    textcat = TextCategorizer(nlp.vocab, model=PyTorchWrapper(pt_model))
    nlp.add_pipe(textcat, last=True)

    # add label to text classifier
    textcat.add_label('POSITIVE')

    # load the IMDB dataset
    print("Loading IMDB data...")
    (train_texts, train_cats), (dev_texts, dev_cats) = load_data(limit=n_texts)
    print("Using %d training examples" % n_texts)
    train_docs = [nlp.tokenizer(text) for text in train_texts]
    train_gold = [GoldParse(doc, cats=cats) for doc, cats in
                  zip(train_docs, train_cats)]
    train_data = list(zip(train_docs, train_gold))

    # get names of other pipes to disable them during training
    other_pipes = [pipe for pipe in nlp.pipe_names if pipe != 'textcat']
    with nlp.disable_pipes(*other_pipes):  # only train textcat
        optimizer = nlp.begin_training()
        print("Training the model...")
        print('{:^5}\t{:^5}\t{:^5}\t{:^5}'.format('LOSS', 'P', 'R', 'F'))
        for i in range(n_iter):
            losses = {}
            # batch up the examples using spaCy's minibatch
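            # compounding(4., 32., 1.001) yields batch sizes that start at 4
            # and grow by a factor of 1.001 per batch, capped at 32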
            batches = minibatch(train_data, size=compounding(4., 32., 1.001))
            for batch in batches:
                docs, golds = zip(*batch)
                nlp.update(docs, golds, sgd=optimizer, drop=0.2, losses=losses)
            with textcat.model.use_params(optimizer.averages):
                # evaluate on the dev data split off in load_data()
                scores = evaluate(nlp.tokenizer, textcat, dev_texts, dev_cats)
            print('{0:.3f}\t{1:.3f}\t{2:.3f}\t{3:.3f}'  # print a simple table
                  .format(losses['textcat'], scores['textcat_p'],
                          scores['textcat_r'], scores['textcat_f']))

    # test the trained model
    test_text = "This movie sucked"
    doc = nlp(test_text)
    print(test_text, doc.cats)

    if output_dir is not None:
        output_dir = Path(output_dir)
        if not output_dir.exists():
            output_dir.mkdir()
        nlp.to_disk(output_dir)
        print("Saved model to", output_dir)

        # test the saved model
        print("Loading from", output_dir)
        nlp2 = spacy.load(output_dir)
        doc2 = nlp2(test_text)
        print(test_text, doc2.cats)


def load_data(limit=0, split=0.8):
    """Load data from the IMDB dataset."""
    # Partition off part of the train data for evaluation
    train_data, _ = thinc.extra.datasets.imdb()
    random.shuffle(train_data)
    # limit=0 keeps the whole training set, since seq[-0:] equals seq[0:]
    train_data = train_data[-limit:]
    texts, labels = zip(*train_data)
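    # convert the integer labels into the dict format the TextCategorizer
    # expects, e.g. {'POSITIVE': True}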
    cats = [{'POSITIVE': bool(y)} for y in labels]
    split = int(len(train_data) * split)
    return (texts[:split], cats[:split]), (texts[split:], cats[split:])


def evaluate(tokenizer, textcat, texts, cats):
    docs = (tokenizer(text) for text in texts)
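    # initialise the counts with a tiny epsilon so that precision, recall and
    # F-score never divide by zero, even if a class is never predicted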
    tp = 1e-8  # True positives
    fp = 1e-8  # False positives
    fn = 1e-8  # False negatives
    tn = 1e-8  # True negatives
    for i, doc in enumerate(textcat.pipe(docs)):
        gold = cats[i]
        for label, score in doc.cats.items():
            if label not in gold:
                continue
            if score >= 0.5 and gold[label] >= 0.5:
                tp += 1.
            elif score >= 0.5 and gold[label] < 0.5:
                fp += 1.
            elif score < 0.5 and gold[label] < 0.5:
                tn += 1.
            elif score < 0.5 and gold[label] >= 0.5:
                fn += 1.
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f_score = 2 * (precision * recall) / (precision + recall)
    return {'textcat_p': precision, 'textcat_r': recall, 'textcat_f': f_score}
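

# The original script calls create_model() but never defines it. The version
# below is a minimal, hypothetical sketch: a small PyTorch feed-forward
# network that maps a fixed-width input vector per document to a single class
# probability. The input width (300) and hidden width (64) are assumptions,
# not values from the original example, and a complete pipeline would still
# need a layer that turns Docs into input tensors (e.g. a Thinc tok2vec model
# chained in front of the wrapped network).
def create_model(input_width=300, hidden_width=64):
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class TextClassifier(nn.Module):
        def __init__(self):
            super(TextClassifier, self).__init__()
            self.hidden = nn.Linear(input_width, hidden_width)
            self.output = nn.Linear(hidden_width, 1)

        def forward(self, inputs):
            # inputs: a float tensor of shape (batch_size, input_width)
            hidden = F.relu(self.hidden(inputs))
            # sigmoid output: one 'POSITIVE' probability per document
            return torch.sigmoid(self.output(hidden))

    return TextClassifier()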


if __name__ == '__main__':
    plac.call(main)