
Commit c0d2b96

Add PyTorch example
1 parent 6b08aef commit c0d2b96

1 file changed: 135 additions, 0 deletions
#!/usr/bin/env python
# coding: utf8
"""Define a text classification model using PyTorch, and wrap it with Thinc's
PytorchWrapper class, so it can be used in spaCy's TextCategorizer component.

The model is added to spacy.pipeline, and predictions are available via
`doc.cats`. For more details, see the documentation:

* Deep learning: https://alpha.spacy.io/usage/deep-learning
* Text classification: https://alpha.spacy.io/usage/text-classification

Developed for: spaCy 2.0.0a19
Last updated for: spaCy 2.0.0a19
"""
from __future__ import unicode_literals, print_function
import plac
import random
from pathlib import Path
import thinc.extra.datasets
import thinc.extra.wrappers

import spacy
from spacy.gold import GoldParse, minibatch
from spacy.util import compounding


@plac.annotations(
    model=("Model name. Defaults to blank 'en' model.", "option", "m", str),
    output_dir=("Optional output directory", "option", "o", Path),
    n_texts=("Number of texts to train from", "option", "t", int),
    n_iter=("Number of training iterations", "option", "n", int))
def main(model=None, output_dir=None, n_iter=20, n_texts=2000):
    if model is not None:
        nlp = spacy.load(model)  # load existing spaCy model
        print("Loaded model '%s'" % model)
    else:
        nlp = spacy.blank('en')  # create blank Language class
        print("Created blank 'en' model")

    # Create the PyTorch neural network model, and wrap it with Thinc. This
    # gives it the API that spaCy expects.
    pt_model = create_model()
    textcat = thinc.extra.wrappers.PytorchWrapper(pt_model)
    nlp.add_pipe(textcat, last=True)

    # add label to text classifier
    textcat.add_label('POSITIVE')

    # load the IMDB dataset
    print("Loading IMDB data...")
    (train_texts, train_cats), (dev_texts, dev_cats) = load_data(limit=n_texts)
    print("Using %d training examples" % n_texts)
    train_docs = [nlp.tokenizer(text) for text in train_texts]
    train_gold = [GoldParse(doc, cats=cats) for doc, cats in
                  zip(train_docs, train_cats)]
    train_data = list(zip(train_docs, train_gold))

    # get names of other pipes to disable them during training
    other_pipes = [pipe for pipe in nlp.pipe_names if pipe != 'textcat']
    with nlp.disable_pipes(*other_pipes):  # only train textcat
        optimizer = nlp.begin_training()
        print("Training the model...")
        print('{:^5}\t{:^5}\t{:^5}\t{:^5}'.format('LOSS', 'P', 'R', 'F'))
        for i in range(n_iter):
            losses = {}
            # batch up the examples using spaCy's minibatch
            batches = minibatch(train_data, size=compounding(4., 32., 1.001))
            for batch in batches:
                docs, golds = zip(*batch)
                nlp.update(docs, golds, sgd=optimizer, drop=0.2, losses=losses)
            with textcat.model.use_params(optimizer.averages):
                # evaluate on the dev data split off in load_data()
                scores = evaluate(nlp.tokenizer, textcat, dev_texts, dev_cats)
            print('{0:.3f}\t{1:.3f}\t{2:.3f}\t{3:.3f}'  # print a simple table
                  .format(losses['textcat'], scores['textcat_p'],
                          scores['textcat_r'], scores['textcat_f']))

    # test the trained model
    test_text = "This movie sucked"
    doc = nlp(test_text)
    print(test_text, doc.cats)

    if output_dir is not None:
        output_dir = Path(output_dir)
        if not output_dir.exists():
            output_dir.mkdir()
        nlp.to_disk(output_dir)
        print("Saved model to", output_dir)

        # test the saved model
        print("Loading from", output_dir)
        nlp2 = spacy.load(output_dir)
        doc2 = nlp2(test_text)
        print(test_text, doc2.cats)
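

# NOTE: The commit calls create_model() above but never defines it. The
# sketch below is an assumption added for illustration, not part of the
# original commit: a minimal PyTorch feed-forward network that maps a
# fixed-size document vector to one score per label. The input width,
# hidden width, and single 'POSITIVE' output are all assumed values; the
# real model would need whatever input shape the wrapped pipeline actually
# produces. Imports are kept local to the sketch so the block stays
# self-contained.
import torch
import torch.nn.functional as F


class SimpleTextClassifier(torch.nn.Module):
    """Hypothetical classifier: one hidden layer over document vectors."""
    def __init__(self, input_width=300, hidden_width=64, n_labels=1):
        super(SimpleTextClassifier, self).__init__()
        self.hidden = torch.nn.Linear(input_width, hidden_width)
        self.output = torch.nn.Linear(hidden_width, n_labels)

    def forward(self, doc_vectors):
        # doc_vectors: float tensor of shape (batch_size, input_width)
        hidden = F.relu(self.hidden(doc_vectors))
        # sigmoid squashes each label score into a [0, 1] probability
        return torch.sigmoid(self.output(hidden))


def create_model():
    return SimpleTextClassifier()

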
def load_data(limit=0, split=0.8):
    """Load data from the IMDB dataset."""
    # Partition off part of the train data for evaluation
    train_data, _ = thinc.extra.datasets.imdb()
    random.shuffle(train_data)
    train_data = train_data[-limit:]
    texts, labels = zip(*train_data)
    cats = [{'POSITIVE': bool(y)} for y in labels]
    split = int(len(train_data) * split)
    return (texts[:split], cats[:split]), (texts[split:], cats[split:])


def evaluate(tokenizer, textcat, texts, cats):
    """Compute precision, recall and F-score on the dev texts."""
    docs = (tokenizer(text) for text in texts)
    # initialise the counts with a small constant to avoid division by zero
    tp = 1e-8  # True positives
    fp = 1e-8  # False positives
    fn = 1e-8  # False negatives
    tn = 1e-8  # True negatives
    for i, doc in enumerate(textcat.pipe(docs)):
        gold = cats[i]
        for label, score in doc.cats.items():
            if label not in gold:
                continue
            if score >= 0.5 and gold[label] >= 0.5:
                tp += 1.
            elif score >= 0.5 and gold[label] < 0.5:
                fp += 1.
            elif score < 0.5 and gold[label] < 0.5:
                tn += 1.
            elif score < 0.5 and gold[label] >= 0.5:
                fn += 1.
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    f_score = 2 * (precision * recall) / (precision + recall)
    return {'textcat_p': precision, 'textcat_r': recall, 'textcat_f': f_score}


if __name__ == '__main__':
    plac.call(main)
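
The script is driven by plac, so the annotated options map directly to
command-line flags. A minimal sketch of an invocation, assuming the file is
saved as pytorch_textcat.py (the filename is not shown in this commit view)
and that the missing create_model() has been filled in:

    python pytorch_textcat.py -t 2000 -n 20 -o /tmp/textcat_model

Without -m it starts from a blank 'en' model; pass -m with a model name to
fine-tune an existing spaCy model instead.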
