Skip to content

Commit df5242e

Browse files
committed
fix bug on text generator code
1 parent 37824b6 commit df5242e

File tree

2 files changed

+305
-2
lines changed

2 files changed

+305
-2
lines changed
Lines changed: 300 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,300 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"source": [
7+
"import tensorflow as tf\r\n",
8+
"import numpy as np\r\n",
9+
"import os\r\n",
10+
"import pickle\r\n",
11+
"\r\n",
12+
"SEQUENCE_LENGTH = 50\r\n",
13+
"EMBEDDING_DIM = 200\r\n",
14+
"BATCH_SIZE = 128\r\n",
15+
"FILE_PATH = \"data/python_code.py\"\r\n",
16+
"BASENAME = os.path.basename(FILE_PATH) + \"-lower\"\r\n",
17+
"\r\n",
18+
"text = open(FILE_PATH).read()\r\n",
19+
"# comment this if you want to use uppercase letters\r\n",
20+
"text = text.lower()\r\n",
21+
"n_chars = len(text)\r\n",
22+
"vocab = ''.join(sorted(set(text)))\r\n",
23+
"print(\"vocab:\", vocab)\r\n",
24+
"n_unique_chars = len(vocab)\r\n",
25+
"print(\"Number of characters:\", n_chars)\r\n",
26+
"print(\"Number of unique characters:\", n_unique_chars)"
27+
],
28+
"outputs": [],
29+
"metadata": {}
30+
},
31+
{
32+
"cell_type": "code",
33+
"execution_count": null,
34+
"source": [
35+
"# dictionary that converts characters to integers\r\n",
36+
"char2int = {c: i for i, c in enumerate(vocab)}\r\n",
37+
"# dictionary that converts integers to characters\r\n",
38+
"int2char = {i: c for i, c in enumerate(vocab)}\r\n",
39+
"\r\n",
40+
"# save these dictionaries for later generation\r\n",
41+
"pickle.dump(char2int, open(f\"{BASENAME}-char2int.pickle\", \"wb\"))\r\n",
42+
"pickle.dump(int2char, open(f\"{BASENAME}-int2char.pickle\", \"wb\"))"
43+
],
44+
"outputs": [],
45+
"metadata": {}
46+
},
47+
{
48+
"cell_type": "code",
49+
"execution_count": null,
50+
"source": [
51+
"encoded_text = np.array([char2int[c] for c in text])"
52+
],
53+
"outputs": [],
54+
"metadata": {}
55+
},
56+
{
57+
"cell_type": "code",
58+
"execution_count": null,
59+
"source": [
60+
"char_dataset = tf.data.Dataset.from_tensor_slices(encoded_text)\r\n",
61+
"for element in char_dataset.take(5):\r\n",
62+
" print(element.numpy())"
63+
],
64+
"outputs": [],
65+
"metadata": {}
66+
},
67+
{
68+
"cell_type": "code",
69+
"execution_count": null,
70+
"source": [
71+
"for element in char_dataset.batch(SEQUENCE_LENGTH+1).shuffle(1024).take(2):\r\n",
72+
" print(''.join([int2char[c] for c in element.numpy()]))"
73+
],
74+
"outputs": [],
75+
"metadata": {}
76+
},
77+
{
78+
"cell_type": "code",
79+
"execution_count": null,
80+
"source": [
81+
"#help(tf.one_hot)\r\n",
82+
"#help(char_dataset.window)\r\n",
83+
"windows = char_dataset.window(SEQUENCE_LENGTH+1, shift=1, drop_remainder=True)\r\n",
84+
"sequences = windows.flat_map(lambda window: window.batch(SEQUENCE_LENGTH+1))\r\n",
85+
"dataset = sequences.map(lambda x: (x[:-1], x[-1]))\r\n",
86+
"for input_, target in dataset.take(10):\r\n",
87+
" print(input_.numpy().shape)\r\n",
88+
" print(target.numpy().shape)\r\n",
89+
" print(''.join([int2char[c] for c in input_.numpy()]), int2char[target.numpy()])\r\n",
90+
" print(\"=\"*50)"
91+
],
92+
"outputs": [],
93+
"metadata": {}
94+
},
95+
{
96+
"cell_type": "code",
97+
"execution_count": null,
98+
"source": [
99+
"sequences2 = char_dataset.batch(2*SEQUENCE_LENGTH+1, drop_remainder=True)\r\n",
100+
"\r\n",
101+
"def split_sample(sample):\r\n",
102+
" ds = tf.data.Dataset.from_tensors((sample[:SEQUENCE_LENGTH], sample[SEQUENCE_LENGTH]))\r\n",
103+
" for i in range(1, (len(sample)-1) // 2):\r\n",
104+
" input_ = sample[i:i+SEQUENCE_LENGTH]\r\n",
105+
" target = sample[i+SEQUENCE_LENGTH]\r\n",
106+
" other_ds = tf.data.Dataset.from_tensors((input_, target))\r\n",
107+
" ds = ds.concatenate(other_ds)\r\n",
108+
" return ds\r\n",
109+
"\r\n",
110+
"\r\n",
111+
"dataset2 = sequences2.flat_map(split_sample)\r\n",
112+
"for element in dataset2.take(10):\r\n",
113+
" print(element[0].shape, element[1].shape)\r\n",
114+
" print(''.join([int2char[c] for c in element[0].numpy()]), int2char[element[1].numpy()])"
115+
],
116+
"outputs": [],
117+
"metadata": {
118+
"tags": [
119+
"outputPrepend",
120+
"outputPrepend",
121+
"outputPrepend",
122+
"outputPrepend"
123+
]
124+
}
125+
},
126+
{
127+
"cell_type": "code",
128+
"execution_count": null,
129+
"source": [
130+
"for element1, element2 in zip(dataset.take(5), dataset2.take(5)):\r\n",
131+
" print(element1[0].numpy() == element2[0].numpy())\r\n",
132+
" "
133+
],
134+
"outputs": [],
135+
"metadata": {}
136+
},
137+
{
138+
"cell_type": "code",
139+
"execution_count": null,
140+
"source": [
141+
"def one_hot_samples(input_, target):\r\n",
142+
" return tf.one_hot(input_, len(vocab)), tf.one_hot(target, len(vocab))\r\n",
143+
"# return input_, tf.one_hot(target, len(vocab))\r\n",
144+
"\r\n",
145+
"dataset = dataset.map(one_hot_samples)\r\n",
146+
"dataset2 = dataset2.map(one_hot_samples)\r\n",
147+
"for element in dataset.take(10):\r\n",
148+
" print(element[0].shape, element[1].shape)"
149+
],
150+
"outputs": [],
151+
"metadata": {}
152+
},
153+
{
154+
"cell_type": "code",
155+
"execution_count": null,
156+
"source": [
157+
"ds = dataset.shuffle(1024).batch(BATCH_SIZE, drop_remainder=True).cache().prefetch(1).repeat()\r\n",
158+
"ds2 = dataset2.shuffle(1024).batch(BATCH_SIZE, drop_remainder=True).cache().prefetch(1).repeat()"
159+
],
160+
"outputs": [],
161+
"metadata": {}
162+
},
163+
{
164+
"cell_type": "code",
165+
"execution_count": null,
166+
"source": [
167+
"def create_model(vocab_size, embedding_dim, rnn_units, batch_size):\r\n",
168+
" model = tf.keras.Sequential()\r\n",
169+
" # model.add(tf.keras.layers.Embedding(vocab_size, embedding_dim, input_shape=(SEQUENCE_LENGTH,)))\r\n",
170+
" model.add(tf.keras.layers.LSTM(rnn_units, input_shape=(SEQUENCE_LENGTH, len(vocab)), return_sequences=True))\r\n",
171+
" model.add(tf.keras.layers.Dropout(0.3))\r\n",
172+
" model.add(tf.keras.layers.LSTM(rnn_units)),\r\n",
173+
" model.add(tf.keras.layers.Dropout(0.3))\r\n",
174+
" model.add(tf.keras.layers.Dense(vocab_size, activation=\"softmax\"))\r\n",
175+
" return model"
176+
],
177+
"outputs": [],
178+
"metadata": {}
179+
},
180+
{
181+
"cell_type": "code",
182+
"execution_count": null,
183+
"source": [
184+
"model = create_model(len(vocab), embedding_dim=EMBEDDING_DIM, rnn_units=128, batch_size=BATCH_SIZE)\r\n",
185+
"model.summary()\r\n",
186+
"model.compile(optimizer=\"adam\", loss=\"categorical_crossentropy\", metrics=[\"accuracy\"])"
187+
],
188+
"outputs": [],
189+
"metadata": {}
190+
},
191+
{
192+
"cell_type": "code",
193+
"execution_count": null,
194+
"source": [
195+
"EPOCHS = 5\r\n",
196+
"history = model.fit(ds2, steps_per_epoch=(len(encoded_text) - SEQUENCE_LENGTH ) // BATCH_SIZE, epochs=EPOCHS)"
197+
],
198+
"outputs": [],
199+
"metadata": {}
200+
},
201+
{
202+
"cell_type": "code",
203+
"execution_count": null,
204+
"source": [
205+
"# save the model\r\n",
206+
"model_path = f\"results/{BASENAME}-{SEQUENCE_LENGTH}-NOEMBEDDING-moredata.h5\"\r\n",
207+
"model.save(model_path)\r\n",
208+
"# model.load_weights(model_path)"
209+
],
210+
"outputs": [],
211+
"metadata": {}
212+
},
213+
{
214+
"cell_type": "code",
215+
"execution_count": null,
216+
"source": [
217+
"seed = \"\"\"You can be a\"\"\".lower()\r\n",
218+
"s = seed\r\n",
219+
"# generate 400 characters\r\n",
220+
"generated = \"\"\r\n",
221+
"for i in range(200):\r\n",
222+
" # make the input sequence\r\n",
223+
" X = np.zeros((1, SEQUENCE_LENGTH, len(vocab)))\r\n",
224+
" # X = np.zeros((1, SEQUENCE_LENGTH))\r\n",
225+
" for t, char in enumerate(seed):\r\n",
226+
" X[0, (SEQUENCE_LENGTH - len(seed)) + t, char2int[char]] = 1\r\n",
227+
" # predict the next character\r\n",
228+
" predicted = model.predict(X, verbose=0)[0]\r\n",
229+
" # print(predicted)\r\n",
230+
" # converting the vector to an integer\r\n",
231+
" next_index = np.argmax(predicted)\r\n",
232+
"# next_index = np.squeeze(np.round(predicted))\r\n",
233+
" # converting the integer to a character\r\n",
234+
"# print(next_index)\r\n",
235+
" next_char = int2char[next_index]\r\n",
236+
" # add the character to results\r\n",
237+
" generated += next_char\r\n",
238+
" # shift seed and the predicted character\r\n",
239+
" seed = seed[1:] + next_char\r\n",
240+
"\r\n",
241+
"print(\"Generated text:\")\r\n",
242+
"print(s + generated)"
243+
],
244+
"outputs": [],
245+
"metadata": {}
246+
},
247+
{
248+
"cell_type": "code",
249+
"execution_count": null,
250+
"source": [
251+
"char2int\r\n"
252+
],
253+
"outputs": [],
254+
"metadata": {}
255+
},
256+
{
257+
"cell_type": "code",
258+
"execution_count": null,
259+
"source": [],
260+
"outputs": [],
261+
"metadata": {}
262+
},
263+
{
264+
"cell_type": "code",
265+
"execution_count": null,
266+
"source": [],
267+
"outputs": [],
268+
"metadata": {}
269+
}
270+
],
271+
"metadata": {
272+
"file_extension": ".py",
273+
"kernelspec": {
274+
"name": "python3",
275+
"display_name": "Python 3.8.7 64-bit"
276+
},
277+
"language_info": {
278+
"codemirror_mode": {
279+
"name": "ipython",
280+
"version": 3
281+
},
282+
"file_extension": ".py",
283+
"mimetype": "text/x-python",
284+
"name": "python",
285+
"nbconvert_exporter": "python",
286+
"pygments_lexer": "ipython3",
287+
"version": "3.8.7"
288+
},
289+
"mimetype": "text/x-python",
290+
"name": "python",
291+
"npconvert_exporter": "python",
292+
"pygments_lexer": "ipython3",
293+
"version": 3,
294+
"interpreter": {
295+
"hash": "777490da48e046e3b512f0b24bf037db286a787493a11bf82a9e0f2cbf21bb67"
296+
}
297+
},
298+
"nbformat": 4,
299+
"nbformat_minor": 4
300+
}

machine-learning/nlp/text-generator/train.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,10 @@ def one_hot_samples(input_, target):
113113
Dense(n_unique_chars, activation="softmax"),
114114
])
115115

116-
model.load_weights(f"results/{BASENAME}-{sequence_length}.h5")
116+
# define the model path
117+
model_weights_path = f"results/{BASENAME}-{sequence_length}.h5"
118+
# if os.path.isfile(model_weights_path):
119+
# model.load_weights(model_weights_path)
117120

118121
model.summary()
119122
model.compile(loss="categorical_crossentropy", optimizer="adam", metrics=["accuracy"])
@@ -126,4 +129,4 @@ def one_hot_samples(input_, target):
126129
# train the model
127130
model.fit(ds, steps_per_epoch=(len(encoded_text) - sequence_length) // BATCH_SIZE, epochs=EPOCHS)
128131
# save the model
129-
model.save(f"results/{BASENAME}-{sequence_length}.h5")
132+
model.save(model_weights_path)

0 commit comments

Comments
 (0)