import numpy as np
from faker import Faker
import random
from tqdm import tqdm
from babel.dates import format_date
from keras.utils import to_categorical
import keras.backend as K
import matplotlib.pyplot as plt

fake = Faker()
Faker.seed(12345)  # newer Faker releases seed via the class; instance-level fake.seed() was removed
random.seed(12345)

# Define the formats of the data we would like to generate.
# 'full' is repeated so that fully spelled-out dates are sampled more often by random.choice.
FORMATS = ['short',
           'medium',
           'long',
           'full',
           'full',
           'full',
           'full',
           'full',
           'full',
           'full',
           'full',
           'full',
           'full',
           'd MMM YYY',
           'd MMMM YYY',
           'dd MMM YYY',
           'd MMM, YYY',
           'd MMMM, YYY',
           'dd, MMM YYY',
           'd MM YY',
           'd MMMM YYY',
           'MMMM d YYY',
           'MMMM d, YYY',
           'dd.MM.YY']

# change this if you want it to work with another language
LOCALES = ['en_US']

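# For reference, a few of these patterns rendered for 1979-05-03 under en_US
# (illustrative; the exact output comes from babel's CLDR data):
#   'short'      -> '5/3/79'
#   'medium'     -> 'May 3, 1979'
#   'd MMM YYY'  -> '3 May 1979'
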
def load_date():
    """
    Loads one fake date.
    :returns: tuple containing human readable string, machine readable string, and date object
    """
    dt = fake.date_object()

    try:
        human_readable = format_date(dt, format=random.choice(FORMATS), locale='en_US')  # locale=random.choice(LOCALES)
        human_readable = human_readable.lower()
        human_readable = human_readable.replace(',', '')
        machine_readable = dt.isoformat()

    except AttributeError:
        return None, None, None

    return human_readable, machine_readable, dt

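# Usage sketch (values are illustrative; actual output depends on the seeds above):
#   h, m, d = load_date()
#   h -> e.g. '9 may 1998', m -> '1998-05-09', d -> datetime.date(1998, 5, 9)
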
def load_dataset(m):
    """
    Loads a dataset with m examples and builds the vocabularies.
    :m: the number of examples to generate
    """

    human_vocab = set()
    machine_vocab = set()
    dataset = []

    for i in tqdm(range(m)):
        h, machine_str, _ = load_date()  # `machine_str` avoids shadowing the argument m
        if h is not None:
            dataset.append((h, machine_str))
            human_vocab.update(tuple(h))
            machine_vocab.update(tuple(machine_str))

    # Map each character to an index; the last two indices are reserved for <unk> and <pad>.
    human = dict(zip(sorted(human_vocab) + ['<unk>', '<pad>'],
                     list(range(len(human_vocab) + 2))))
    inv_machine = dict(enumerate(sorted(machine_vocab)))
    machine = {v: k for k, v in inv_machine.items()}

    return dataset, human, machine, inv_machine

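# Usage sketch (m = 10000 here is just an example size):
#   dataset, human_vocab, machine_vocab, inv_machine_vocab = load_dataset(10000)
#   dataset[:2] -> [(human_readable, machine_readable), ...] string pairs
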
def preprocess_data(dataset, human_vocab, machine_vocab, Tx, Ty):
    """
    Converts the (human, machine) string pairs into padded integer arrays
    plus their one-hot encodings.
    """
    X, Y = zip(*dataset)

    X = np.array([string_to_int(i, Tx, human_vocab) for i in X])
    Y = [string_to_int(t, Ty, machine_vocab) for t in Y]

    # One-hot encode each character index.
    Xoh = np.array(list(map(lambda x: to_categorical(x, num_classes=len(human_vocab)), X)))
    Yoh = np.array(list(map(lambda x: to_categorical(x, num_classes=len(machine_vocab)), Y)))

    return X, np.array(Y), Xoh, Yoh

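# Shape sketch, assuming Tx = 30 and Ty = 10 as elsewhere in this module:
#   X, Y, Xoh, Yoh = preprocess_data(dataset, human_vocab, machine_vocab, 30, 10)
#   X.shape   -> (m, 30),                     Y.shape   -> (m, 10)
#   Xoh.shape -> (m, 30, len(human_vocab)),   Yoh.shape -> (m, 10, len(machine_vocab))
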
def string_to_int(string, length, vocab):
    """
    Converts the input string into a list of integers representing the positions
    of its characters in "vocab".

    Arguments:
    string -- input string, e.g. 'Wed 10 Jul 2007'
    length -- the number of time steps you'd like; determines if the output is padded or cut
    vocab -- vocabulary, dictionary used to index every character of your "string"

    Returns:
    rep -- list of integers (size = length) representing the positions of the string's characters in the vocabulary
    """

    # make lower to standardize
    string = string.lower()
    string = string.replace(',', '')

    if len(string) > length:
        string = string[:length]

    # Characters not in the vocabulary map to the index of '<unk>'.
    rep = list(map(lambda x: vocab.get(x, vocab['<unk>']), string))

    if len(string) < length:
        rep += [vocab['<pad>']] * (length - len(string))

    return rep

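# Example, assuming human_vocab came from load_dataset above:
#   string_to_int('3 May 1979', 30, human_vocab)
#   -> [human_vocab['3'], human_vocab[' '], human_vocab['m'], ...] padded with
#      human_vocab['<pad>'] up to length 30
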
def int_to_string(ints, inv_vocab):
    """
    Outputs a machine readable list of characters based on a list of indexes in the machine's vocabulary.

    Arguments:
    ints -- list of integers representing indexes in the machine's vocabulary
    inv_vocab -- dictionary mapping machine readable indexes to machine readable characters

    Returns:
    l -- list of characters corresponding to the indexes of ints thanks to the inv_vocab mapping
    """

    l = [inv_vocab[i] for i in ints]
    return l

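# Round-trip sketch: machine-readable dates are exactly 10 characters, so
#   int_to_string(string_to_int('1979-05-03', 10, machine_vocab), inv_machine_vocab)
# returns the characters of '1979-05-03' as a list.
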
EXAMPLES = ['3 May 1979', '5 Apr 09', '20th February 2016', 'Wed 10 Jul 2007']

TIME_STEPS = 30  # input sequence length; must match the Tx the model was trained with

def run_example(model, input_vocabulary, inv_output_vocabulary, text):
    encoded = string_to_int(text, TIME_STEPS, input_vocabulary)
    prediction = model.predict(np.array([encoded]))
    prediction = np.argmax(prediction[0], axis=-1)
    return int_to_string(prediction, inv_output_vocabulary)

def run_examples(model, input_vocabulary, inv_output_vocabulary, examples=EXAMPLES):
    predicted = []
    for example in examples:
        predicted.append(''.join(run_example(model, input_vocabulary, inv_output_vocabulary, example)))
        print('input:', example)
        print('output:', predicted[-1])
    return predicted

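# Usage sketch, assuming `model` accepts a (1, TIME_STEPS) integer array as input:
#   run_examples(model, human_vocab, inv_machine_vocab)
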

def softmax(x, axis=1):
    """Softmax activation function.
    # Arguments
        x: Tensor.
        axis: Integer, axis along which the softmax normalization is applied.
    # Returns
        Tensor, output of softmax transformation.
    # Raises
        ValueError: In case `dim(x) == 1`.
    """
    ndim = K.ndim(x)
    if ndim == 2:
        return K.softmax(x)
    elif ndim > 2:
        # Subtract the max along `axis` for numerical stability, then normalize.
        e = K.exp(x - K.max(x, axis=axis, keepdims=True))
        s = K.sum(e, axis=axis, keepdims=True)
        return e / s
    else:
        raise ValueError('Cannot apply softmax to a tensor that is 1D')

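# Sanity sketch for the ndim > 2 branch: softmax over the chosen axis sums to 1, e.g.
#   s = softmax(K.constant([[[1.0, 2.0, 3.0], [1.0, 1.0, 1.0]]]), axis=-1)
#   K.eval(s).sum(axis=-1) -> all ones (up to float precision)
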

def plot_attention_map(model, input_vocabulary, inv_output_vocabulary, text, n_s=128, num=6, Tx=30, Ty=10):
    """
    Plots the attention map for `text` and returns it as a (Ty, Tx) array.
    """
    attention_map = np.zeros((10, 30))
    Ty, Tx = attention_map.shape  # note: this overrides the Ty/Tx arguments

    # Initial hidden and cell states fed to the post-attention LSTM.
    s0 = np.zeros((1, n_s))
    c0 = np.zeros((1, n_s))
    layer = model.layers[num]  # the attention layer whose outputs we want to inspect

    encoded = np.array(string_to_int(text, Tx, input_vocabulary)).reshape((1, Tx))
    encoded = np.array(list(map(lambda x: to_categorical(x, num_classes=len(input_vocabulary)), encoded)))

    # Build a backend function returning the attention weights at every output step.
    f = K.function(model.inputs, [layer.get_output_at(t) for t in range(Ty)])
    r = f([encoded, s0, c0])

    for t in range(Ty):
        for t_prime in range(Tx):
            attention_map[t][t_prime] = r[t][0, t_prime, 0]

    # Normalize attention map
    # row_max = attention_map.max(axis=1)
    # attention_map = attention_map / row_max[:, None]

    prediction = model.predict([encoded, s0, c0])

    predicted_text = []
    for i in range(len(prediction)):
        predicted_text.append(int(np.argmax(prediction[i], axis=1)))

    predicted_text = int_to_string(predicted_text, inv_output_vocabulary)
    text_ = list(text)

    # get the lengths of the strings
    input_length = len(text)
    output_length = Ty

    # Plot the attention_map
    plt.clf()
    f = plt.figure(figsize=(8, 8.5))
    ax = f.add_subplot(1, 1, 1)

    # add image
    i = ax.imshow(attention_map, interpolation='nearest', cmap='Blues')

    # add colorbar
    cbaxes = f.add_axes([0.2, 0, 0.6, 0.03])
    cbar = f.colorbar(i, cax=cbaxes, orientation='horizontal')
    cbar.ax.set_xlabel('Alpha value (Probability output of the "softmax")', labelpad=2)

    # add labels
    ax.set_yticks(range(output_length))
    ax.set_yticklabels(predicted_text[:output_length])

    ax.set_xticks(range(input_length))
    ax.set_xticklabels(text_[:input_length], rotation=45)

    ax.set_xlabel('Input Sequence')
    ax.set_ylabel('Output Sequence')

    # add grid
    ax.grid()

    return attention_map
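
# Usage sketch (n_s=128 and num=6 are this module's defaults; `num` must point
# at the attention layer of your particular model, so adjust it if your
# architecture differs; the date string is just a sample input):
#   attention_map = plot_attention_map(model, human_vocab, inv_machine_vocab,
#                                      'Tuesday 09 Oct 1993')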