1
+ # Copyright 2023 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Creates a dataset to train a machine learning model."""
16
+
17
+ from datetime import datetime , timedelta
18
+ from typing import Dict , Iterable , List , Optional , NamedTuple , Tuple
19
+ import io
20
+ import logging
21
+ import random
22
+ import requests
23
+ import uuid
24
+
25
+ import ee
26
+ from google .api_core import retry , exceptions
27
+ import google .auth
28
+ import numpy as np
29
+ from numpy .lib .recfunctions import structured_to_unstructured
30
+
31
# Earth Engine asset id -> bands selected from that asset for the model inputs.
# Insertion order is the order in which get_image() visits the assets.
INPUTS = {
    "USGS/SRTMGL1_003": ["elevation"],
    "GRIDMET/DROUGHT": ["psdi"],
    "ECMWF/ERA5/DAILY": [
        "mean_2m_air_temperature",
        "total_precipitation",
        "u_component_of_wind_10m",
        "v_component_of_wind_10m",
    ],
    "IDAHO_EPSCOR/GRIDMET": [
        "pr",
        "sph",
        "th",
        "tmmn",
        "tmmx",
        "vs",
        "erc",
    ],
    "CIESIN/GPWv411/GPW_Population_Density": ["population_density"],
    "MODIS/006/MOD14A1": ["FireMask"],
}

# Earth Engine asset id -> bands selected from that asset for the labels.
LABELS = {
    "MODIS/006/MOD14A1": ["FireMask"],
}

# Passed as `scale` when sampling points and used to size the patch region.
# Presumably meters per pixel -- confirm against the Earth Engine docs.
SCALE = 5000

# Length of the date range each image is built from (see get_image).
WINDOW = timedelta(days=1)

# Range from which run() draws its random dates.
START_DATE = datetime(2019, 1, 1)
END_DATE = datetime(2020, 1, 1)
60
+
61
+
62
class Bounds(NamedTuple):
    """Rectangular region given by its four edges, in degrees.

    The field order (west, south, east, north) matters: the tuple is handed
    as-is to `ee.Geometry.Rectangle` when sampling points.
    """

    west: float  # minimum longitude
    south: float  # minimum latitude
    east: float  # maximum longitude
    north: float  # maximum latitude
67
+
68
+
69
class Point(NamedTuple):
    """A geographic location, stored latitude first."""

    lat: float  # latitude
    lon: float  # longitude
72
+
73
+
74
class Example(NamedTuple):
    """One training example: an input patch sequence and its label sequence."""

    inputs: np.ndarray  # result of get_input_sequence()
    labels: np.ndarray  # result of get_label_sequence()
77
+
78
+
79
def ee_init() -> None:
    """Initialize Earth Engine using the application default credentials.

    The client is pointed at the Earth Engine High Volume endpoint:
    https://developers.google.com/earth-engine/cloud/highvolume
    """
    high_volume_url = "https://earthengine-highvolume.googleapis.com"
    credentials, project = google.auth.default()
    ee.Initialize(credentials, project=project, opt_url=high_volume_url)
89
+
90
@retry.Retry(deadline=60 * 20)  # seconds
def ee_fetch(url: str) -> bytes:
    """Download the contents of `url`, retrying when the server throttles us.

    Args:
        url: The URL to fetch with an HTTP GET.

    Returns:
        The raw response body.

    Raises:
        google.api_core.exceptions.TooManyRequests: On an HTTP 429 response.
            It is raised as a `google.api_core` exception so that the Retry
            decorator can act on it.
        requests.HTTPError: On any other unsuccessful HTTP status.
    """
    reply = requests.get(url)

    # "429: Too Many Requests" is safe to retry, so surface it in a form the
    # Retry decorator understands.
    throttled = reply.status_code == 429
    if throttled:
        raise exceptions.TooManyRequests(reply.text)

    # Any other error status is still raised so invalid data is never returned.
    reply.raise_for_status()
    return reply.content
101
+
102
+
103
def get_image(
    date: datetime, bands_schema: Dict[str, List[str]], window: timedelta
) -> ee.Image:
    """Build one multi-band image from every asset listed in `bands_schema`.

    Most assets are treated as image collections: they are filtered to
    [date, date + window), reduced to the requested bands and mosaicked.
    Two assets are special-cased:

    * 'USGS/SRTMGL1_003' is loaded as a single `ee.Image` (no date filter).
    * 'CIESIN/GPWv411/GPW_Population_Density' is date-filtered like the
      others but reduced with `median()` instead of `mosaic()`.

    Band order in the result is: the regular collections (in `bands_schema`
    order), then elevation, then population.

    Args:
        date: Start of the date range.
        bands_schema: Mapping of Earth Engine asset id to the band names to
            select from it. It is NOT modified by this function.
        window: Length of the date range.

    Returns:
        An `ee.Image` combining all selected bands.
    """
    ee_init()

    elevation_id = "USGS/SRTMGL1_003"
    population_id = "CIESIN/GPWv411/GPW_Population_Density"
    start = date.isoformat()
    end = (date + window).isoformat()

    # Work on a shallow copy: callers pass the module-level INPUTS / LABELS
    # dicts, and popping from those would silently drop the elevation and
    # population bands from every later call in the same process.
    remaining = dict(bands_schema)

    elevation = None
    elevation_bands = remaining.pop(elevation_id, None)
    if elevation_bands is not None:
        elevation = ee.Image(elevation_id).select(elevation_bands)

    population = None
    population_bands = remaining.pop(population_id, None)
    if population_bands is not None:
        # NOTE(review): this uses the same date window as the other
        # collections -- confirm the population collection actually has
        # images inside such a short window.
        population = (
            ee.ImageCollection(population_id)
            .filterDate(start, end)
            .select(population_bands)
            .median()
        )

    images = [
        ee.ImageCollection(collection).filterDate(start, end).select(bands).mosaic()
        for collection, bands in remaining.items()
    ]
    if elevation is not None:
        images.append(elevation)
    if population is not None:
        images.append(population)
    return ee.Image(images)
140
+
141
def get_input_image(date: datetime) -> ee.Image:
    """Return the image holding every INPUTS band for the window starting at `date`."""
    return get_image(date=date, bands_schema=INPUTS, window=WINDOW)
143
+
144
+
145
def get_label_image(date: datetime) -> ee.Image:
    """Return the image holding every LABELS band for the window starting at `date`."""
    return get_image(date=date, bands_schema=LABELS, window=WINDOW)
147
+
148
+
149
def sample_labels(
    date: datetime, num_points: int, bounds: Bounds
) -> Iterable[Tuple[datetime, Point]]:
    """Sample locations from the label image of `date`.

    Args:
        date: Date whose label image is sampled.
        num_points: Forwarded to `sample_points`.
        bounds: Region to sample within.

    Yields:
        `(date, Point)` pairs, one per sampled location.
    """
    label_image = get_label_image(date)
    coordinates = sample_points(label_image, num_points, bounds, SCALE)
    yield from ((date, Point(lat, lon)) for lat, lon in coordinates)
155
+
156
+
157
def sample_points(
    image: ee.Image, num_points: int, bounds: Bounds, scale: int
) -> np.ndarray:
    """Sample (lat, lon) coordinates from `image` with a stratified sample.

    Args:
        image: Image to sample; it is cast to integer values before sampling.
        num_points: Number of points passed to `stratifiedSample` (per the
            Earth Engine docs this is a per-class count -- confirm).
        bounds: Region to sample within.
        scale: Passed as the `scale` argument of `stratifiedSample`.

    Returns:
        A float array of shape (N, 2) with one `[lat, lon]` row per sampled
        point. N is 0 when nothing was sampled.
    """

    def get_coordinates(point: ee.Feature) -> ee.Feature:
        # Geometry coordinates come as [lon, lat]; expose them as properties
        # so they can be exported as CSV columns.
        coords = point.geometry().coordinates()
        return ee.Feature(None, {"lat": coords.get(1), "lon": coords.get(0)})

    points = image.int().stratifiedSample(
        num_points,
        region=ee.Geometry.Rectangle(bounds),
        scale=scale,
        geometries=True,
    )
    url = points.map(get_coordinates).getDownloadURL("CSV", ["lat", "lon"])
    table = np.genfromtxt(io.BytesIO(ee_fetch(url)), delimiter=",", skip_header=1)

    # genfromtxt collapses a single data row to a 1-D array and returns an
    # empty 1-D array when there are no rows; force a consistent (N, 2) shape
    # so callers can always iterate over [lat, lon] rows.
    return table.reshape(-1, 2)
172
+
173
+
174
def get_input_sequence(
    date: datetime, point: Point, patch_size: int, num_days: int
) -> np.ndarray:
    """Fetch the input patches for the `num_days` days ending on `date`.

    Args:
        date: Last day of the sequence (inclusive).
        point: Center of the patch.
        patch_size: Patch width and height, in pixels.
        num_days: Number of consecutive days in the sequence.

    Returns:
        The array produced by `get_patch_sequence` for those days, oldest
        day first.
    """
    # Offsets run from -(num_days - 1) up to 0, so the sequence ends on `date`.
    day_offsets = range(1 - num_days, 1)
    images = []
    for offset in day_offsets:
        images.append(get_input_image(date + timedelta(days=offset)))
    return get_patch_sequence(images, point, patch_size, SCALE)
180
+
181
+
182
def get_label_sequence(
    date: datetime, point: Point, patch_size: int, num_days: int
) -> np.ndarray:
    """Fetch the label patches for the `num_days` days following `date`.

    Args:
        date: Reference day; the sequence starts the day after it.
        point: Center of the patch.
        patch_size: Patch width and height, in pixels.
        num_days: Number of consecutive days in the sequence.

    Returns:
        The array produced by `get_patch_sequence` for those days, earliest
        day first.
    """
    # Offsets run from 1 up to num_days, so `date` itself is excluded.
    day_offsets = range(1, num_days + 1)
    images = []
    for offset in day_offsets:
        images.append(get_label_image(date + timedelta(days=offset)))
    return get_patch_sequence(images, point, patch_size, SCALE)
188
+
189
+
190
def get_training_example(
    date: datetime, point: Point, patch_size: int = 64, num_days: int = 2
) -> Example:
    """Build one training example centered on `point` around `date`.

    Args:
        date: Reference day separating inputs (up to and including `date`)
            from labels (the days after it).
        point: Center of the patch.
        patch_size: Patch width and height, in pixels.
        num_days: Number of label days. The input sequence is one day longer
            (`num_days + 1`).

    Returns:
        An `Example` with the input and label patch sequences.
    """
    ee_init()
    inputs = get_input_sequence(date, point, patch_size, num_days + 1)
    labels = get_label_sequence(date, point, patch_size, num_days)
    return Example(inputs, labels)
198
+
199
def try_get_training_example(
    date: datetime, point: Point, patch_size: int = 64, num_hours: int = 2
) -> Iterable[Example]:
    """Yield a training example, or nothing at all if building it fails.

    Any exception is logged and swallowed, so a single bad example is skipped
    instead of propagating to the caller.

    Args:
        date: Reference day of the example.
        point: Center of the patch.
        patch_size: Patch width and height, in pixels.
        num_hours: Despite its name, this is forwarded unchanged as the
            `num_days` argument of `get_training_example`.

    Yields:
        At most one `Example`.
    """
    try:
        example = get_training_example(date, point, patch_size, num_hours)
        yield example
    except Exception as error:
        logging.exception(error)
206
+
207
def get_patch_sequence(
    image_sequence: List[ee.Image], point: Point, patch_size: int, scale: int
) -> np.ndarray:
    """Download one patch per image in `image_sequence`, stacked into an array.

    All images are flattened into a single multi-band image with `toBands()`
    so the whole sequence is fetched in one request. The bands are then
    regrouped per image by their `"<index>_"` name prefix.

    Args:
        image_sequence: Images to extract the patch from, in sequence order.
        point: Center of the patch.
        patch_size: Patch width and height, in pixels.
        scale: Used with `patch_size` to size the requested region.

    Returns:
        The per-image arrays stacked along axis 1. Assuming the download is a
        2-D structured array, the result is (bands, images, height, width)
        -- TODO confirm.
    """

    def unpack(arr: np.ndarray, i: int) -> np.ndarray:
        # Pick the fields that belong to image `i` and move the band axis
        # (last after unstructuring) to the front.
        names = [x for x in arr.dtype.names if x.startswith(f"{i}_")]
        return np.moveaxis(structured_to_unstructured(arr[names]), -1, 0)

    center = ee.Geometry.Point([point.lon, point.lat])
    region = center.buffer(scale * patch_size / 2, 1).bounds(1)
    stacked = ee.ImageCollection(image_sequence).toBands()
    request = {
        "region": region,
        "dimensions": [patch_size, patch_size],
        "format": "NPY",
    }
    payload = ee_fetch(stacked.getDownloadURL(request))

    # NOTE(review): allow_pickle=True lets np.load unpickle objects from the
    # downloaded bytes; confirm whether it is actually required here.
    flat_seq = np.load(io.BytesIO(payload), allow_pickle=True)
    patches = [unpack(flat_seq, index) for index in range(len(image_sequence))]
    return np.stack(patches, axis=1)
225
+
226
def write_npz_file(example: Example, file_prefix: str) -> str:
    """Write `example` to a uniquely named compressed NPZ file.

    Args:
        example: The example to save; stored under the keys `inputs` and
            `labels`.
        file_prefix: Directory or path prefix the file is created under.

    Returns:
        The full name of the file that was written.
    """
    from apache_beam.io.filesystems import FileSystems

    # A random UUID keeps concurrent writers from colliding on file names.
    basename = f"{uuid.uuid4()}.npz"
    filename = FileSystems.join(file_prefix, basename)
    arrays = {"inputs": example.inputs, "labels": example.labels}
    with FileSystems.create(filename) as sink:
        np.savez_compressed(sink, **arrays)
    return filename
233
+
234
+
235
def run(
    output_path: str,
    num_dates: int,
    num_points: int,
    bounds: Bounds,
    patch_size: int,
    max_requests: int,
    beam_args: Optional[List[str]] = None,
) -> None:
    """Run the Apache Beam pipeline that creates the dataset.

    Random dates between START_DATE and END_DATE are drawn, label points are
    sampled for each date, a training example is built for every point, and
    each example is written as an NPZ file under `output_path`.

    Args:
        output_path: Directory or path prefix for the NPZ files.
        num_dates: Number of random dates to draw.
        num_points: Forwarded to `sample_labels` for every date.
        bounds: Region to sample points from.
        patch_size: Patch width and height, in pixels.
        max_requests: Used as the pipeline's `max_num_workers`.
        beam_args: Extra command line arguments for Apache Beam.
    """
    import apache_beam as beam
    from apache_beam.options.pipeline_options import PipelineOptions

    date_span = END_DATE - START_DATE
    random_dates = [START_DATE + date_span * random.random() for _ in range(num_dates)]

    beam_options = PipelineOptions(
        beam_args,
        save_main_session=True,
        requirements_file="requirements.txt",
        max_num_workers=max_requests,
    )
    with beam.Pipeline(options=beam_options) as pipeline:
        dates = pipeline | "Random dates" >> beam.Create(random_dates)
        labeled_points = dates | "Sample labels" >> beam.FlatMap(
            sample_labels, num_points, bounds
        )
        # Redistribute the (date, point) pairs before the per-point fetches.
        shuffled = labeled_points | "Reshuffle" >> beam.Reshuffle()
        examples = shuffled | "Get example" >> beam.FlatMapTuple(
            try_get_training_example, patch_size
        )
        filenames = examples | "Write NPZ files" >> beam.Map(write_npz_file, output_path)
        filenames | "Log files" >> beam.Map(logging.info)
267
+
268
if __name__ == "__main__":
    import argparse

    logging.getLogger().setLevel(logging.INFO)

    # Optional flags as (name, type, default), registered in this order.
    optional_flags = (
        ("--num-dates", int, 20),
        ("--num-points", int, 10),
        ("--west", float, -125.3),
        ("--south", float, 27.4),
        ("--east", float, -66.5),
        ("--north", float, 49.1),
        ("--patch-size", int, 64),
        ("--max-requests", int, 20),
    )

    parser = argparse.ArgumentParser()
    parser.add_argument("--output-path", required=True)
    for flag, flag_type, flag_default in optional_flags:
        parser.add_argument(flag, type=flag_type, default=flag_default)

    # Arguments this script does not recognize are handed to Apache Beam.
    args, beam_args = parser.parse_known_args()

    region = Bounds(args.west, args.south, args.east, args.north)
    run(
        output_path=args.output_path,
        num_dates=args.num_dates,
        num_points=args.num_points,
        bounds=region,
        patch_size=args.patch_size,
        max_requests=args.max_requests,
        beam_args=beam_args,
    )
0 commit comments