
Commit a333ab5

feat(ppai): adds wildfires overview and datasets notebook
1 parent cfcf91b commit a333ab5

File tree

3 files changed: +1562 -0 lines changed
Lines changed: 293 additions & 0 deletions

@@ -0,0 +1,293 @@
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Creates a dataset to train a machine learning model."""
16+
17+
from datetime import datetime, timedelta
18+
from typing import Dict, Iterable, List, Optional, NamedTuple, Tuple
19+
import io
20+
import logging
21+
import random
22+
import requests
23+
import uuid
24+
25+
import ee
26+
from google.api_core import retry, exceptions
27+
import google.auth
28+
import numpy as np
29+
from numpy.lib.recfunctions import structured_to_unstructured
30+
31+
INPUTS = {
    'USGS/SRTMGL1_003': ['elevation'],
    'GRIDMET/DROUGHT': ['pdsi'],
    'ECMWF/ERA5/DAILY': [
        'mean_2m_air_temperature',
        'total_precipitation',
        'u_component_of_wind_10m',
        'v_component_of_wind_10m',
    ],
    'IDAHO_EPSCOR/GRIDMET': [
        'pr',
        'sph',
        'th',
        'tmmn',
        'tmmx',
        'vs',
        'erc',
    ],
    'CIESIN/GPWv411/GPW_Population_Density': ['population_density'],
    'MODIS/006/MOD14A1': ['FireMask'],
}

LABELS = {
    'MODIS/006/MOD14A1': ['FireMask'],
}

SCALE = 5000  # meters per pixel
WINDOW = timedelta(days=1)

START_DATE = datetime(2019, 1, 1)
END_DATE = datetime(2020, 1, 1)

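# A quick sanity check (illustrative only; `num_input_bands` is a name
# invented here, not part of the pipeline): the schema above yields
# 15 input bands per day.
#
#   num_input_bands = sum(len(bands) for bands in INPUTS.values())
#   assert num_input_bands == 15
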
class Bounds(NamedTuple):
    west: float
    south: float
    east: float
    north: float


class Point(NamedTuple):
    lat: float
    lon: float


class Example(NamedTuple):
    inputs: np.ndarray
    labels: np.ndarray

def ee_init() -> None:
    """Authenticate and initialize Earth Engine with the default credentials."""
    credentials, project = google.auth.default()
    ee.Initialize(
        credentials,
        project=project,
        # Use the Earth Engine High Volume endpoint.
        # https://developers.google.com/earth-engine/cloud/highvolume
        opt_url="https://earthengine-highvolume.googleapis.com",
    )

@retry.Retry(deadline=60 * 20)  # seconds
def ee_fetch(url: str) -> bytes:
    # If we get "429: Too Many Requests" errors, it's safe to retry the request.
    # The Retry library only works with `google.api_core` exceptions.
    response = requests.get(url)
    if response.status_code == 429:
        raise exceptions.TooManyRequests(response.text)

    # Still raise any other exceptions to make sure we got valid data.
    response.raise_for_status()
    return response.content

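# Example usage (a sketch, not part of the pipeline; `some_image` and
# `region` are hypothetical placeholders): download any Earth Engine URL
# through the retrying helper, mirroring how get_patch_sequence uses it.
#
#   url = some_image.getDownloadURL({"region": region, "format": "NPY"})
#   patch = np.load(io.BytesIO(ee_fetch(url)), allow_pickle=True)
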
def get_image(
    date: datetime, bands_schema: Dict[str, List[str]], window: timedelta
) -> ee.Image:
    ee_init()
    # Copy the schema so we never mutate the module-level INPUTS/LABELS dicts.
    bands_schema = dict(bands_schema)

    # If the elevation dataset is part of bands_schema, deal with it
    # separately: it's a single image, so it can't be filtered by date.
    if 'USGS/SRTMGL1_003' in bands_schema:
        elevation = ee.Image('USGS/SRTMGL1_003').select(
            bands_schema.pop('USGS/SRTMGL1_003')
        )
    else:
        elevation = None

    # If the population dataset is part of bands_schema, deal with it
    # separately: reduce it with a median instead of mosaicking.
    if 'CIESIN/GPWv411/GPW_Population_Density' in bands_schema:
        population = (
            ee.ImageCollection('CIESIN/GPWv411/GPW_Population_Density')
            .filterDate(date.isoformat(), (date + window).isoformat())
            .select(bands_schema.pop('CIESIN/GPWv411/GPW_Population_Density'))
            .median()
        )
    else:
        population = None

    images = [
        ee.ImageCollection(collection)
        .filterDate(date.isoformat(), (date + window).isoformat())
        .select(bands)
        .mosaic()
        for collection, bands in bands_schema.items()
    ]
    # Add the separately handled images to the list.
    if elevation is not None:
        images.append(elevation)
    if population is not None:
        images.append(population)
    return ee.Image(images)


def get_input_image(date: datetime) -> ee.Image:
    return get_image(date, INPUTS, WINDOW)


def get_label_image(date: datetime) -> ee.Image:
    return get_image(date, LABELS, WINDOW)

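# Example (illustrative only; the date is made up): build the stacked input
# image for a single day. The result holds one band per entry in INPUTS,
# each collection filtered to that day's one-day window.
#
#   image = get_input_image(datetime(2019, 8, 1))
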
def sample_labels(
    date: datetime, num_points: int, bounds: Bounds
) -> Iterable[Tuple[datetime, Point]]:
    image = get_label_image(date)
    for lat, lon in sample_points(image, num_points, bounds, SCALE):
        yield (date, Point(lat, lon))


def sample_points(
    image: ee.Image, num_points: int, bounds: Bounds, scale: int
) -> np.ndarray:
    def get_coordinates(point: ee.Feature) -> ee.Feature:
        coords = point.geometry().coordinates()
        return ee.Feature(None, {"lat": coords.get(1), "lon": coords.get(0)})

    points = image.int().stratifiedSample(
        num_points,
        region=ee.Geometry.Rectangle(bounds),
        scale=scale,
        geometries=True,
    )
    url = points.map(get_coordinates).getDownloadURL("CSV", ["lat", "lon"])
    return np.genfromtxt(io.BytesIO(ee_fetch(url)), delimiter=",", skip_header=1)

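# Example (a sketch; the bounds match the continental-US defaults in the
# main block below, and the date is made up): stratified-sample 10
# (lat, lon) pairs from one day's fire mask.
#
#   conus = Bounds(west=-125.3, south=27.4, east=-66.5, north=49.1)
#   image = get_label_image(datetime(2019, 8, 1))
#   coords = sample_points(image, 10, conus, SCALE)
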
def get_input_sequence(
    date: datetime, point: Point, patch_size: int, num_days: int
) -> np.ndarray:
    # Input days run backwards from the sample date: [date - num_days + 1, date].
    dates = [date + timedelta(days=d) for d in range(1 - num_days, 1)]
    images = [get_input_image(d) for d in dates]
    return get_patch_sequence(images, point, patch_size, SCALE)


def get_label_sequence(
    date: datetime, point: Point, patch_size: int, num_days: int
) -> np.ndarray:
    # Label days run forwards from the sample date: [date + 1, date + num_days].
    dates = [date + timedelta(days=d) for d in range(1, num_days + 1)]
    images = [get_label_image(d) for d in dates]
    return get_patch_sequence(images, point, patch_size, SCALE)

def get_training_example(
    date: datetime, point: Point, patch_size: int = 64, num_days: int = 2
) -> Example:
    ee_init()
    return Example(
        get_input_sequence(date, point, patch_size, num_days + 1),
        get_label_sequence(date, point, patch_size, num_days),
    )


def try_get_training_example(
    date: datetime, point: Point, patch_size: int = 64, num_days: int = 2
) -> Iterable[Example]:
    # Swallow and log errors so a single bad point doesn't fail the pipeline.
    try:
        yield get_training_example(date, point, patch_size, num_days)
    except Exception as e:
        logging.exception(e)

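# Example (a sketch; the date and point are made up): fetch one training
# example. With the defaults, the inputs stack num_days + 1 = 3 time steps
# and the labels num_days = 2, each shaped (bands, time, height, width).
#
#   example = get_training_example(datetime(2019, 8, 2), Point(37.0, -120.0))
#   print(example.inputs.shape, example.labels.shape)
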
def get_patch_sequence(
    image_sequence: List[ee.Image], point: Point, patch_size: int, scale: int
) -> np.ndarray:
    def unpack(arr: np.ndarray, i: int) -> np.ndarray:
        # toBands prefixes each band name with its image index, like "0_elevation".
        names = [x for x in arr.dtype.names if x.startswith(f"{i}_")]
        return np.moveaxis(structured_to_unstructured(arr[names]), -1, 0)

    geometry = ee.Geometry.Point([point.lon, point.lat])
    image = ee.ImageCollection(image_sequence).toBands()
    url = image.getDownloadURL(
        {
            "region": geometry.buffer(scale * patch_size / 2, 1).bounds(1),
            "dimensions": [patch_size, patch_size],
            "format": "NPY",
        }
    )
    flat_seq = np.load(io.BytesIO(ee_fetch(url)), allow_pickle=True)
    return np.stack(
        [unpack(flat_seq, i) for i, _ in enumerate(image_sequence)], axis=1
    )

def write_npz_file(example: Example, file_prefix: str) -> str:
    # Import here so the module doesn't require Beam outside the pipeline.
    from apache_beam.io.filesystems import FileSystems

    filename = FileSystems.join(file_prefix, f"{uuid.uuid4()}.npz")
    with FileSystems.create(filename) as f:
        np.savez_compressed(f, inputs=example.inputs, labels=example.labels)
    return filename

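# Example (a sketch; "example.npz" is a placeholder filename): read one of
# the written files back into NumPy arrays.
#
#   with open("example.npz", "rb") as f:
#       data = np.load(f)
#       inputs, labels = data["inputs"], data["labels"]
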
def run(
    output_path: str,
    num_dates: int,
    num_points: int,
    bounds: Bounds,
    patch_size: int,
    max_requests: int,
    beam_args: Optional[List[str]] = None,
) -> None:
    import apache_beam as beam
    from apache_beam.options.pipeline_options import PipelineOptions

    random_dates = [
        START_DATE + (END_DATE - START_DATE) * random.random()
        for _ in range(num_dates)
    ]

    beam_options = PipelineOptions(
        beam_args,
        save_main_session=True,
        requirements_file="requirements.txt",
        # Limit the number of workers to cap concurrent Earth Engine requests.
        max_num_workers=max_requests,
    )
    with beam.Pipeline(options=beam_options) as pipeline:
        (
            pipeline
            | "Random dates" >> beam.Create(random_dates)
            | "Sample labels" >> beam.FlatMap(sample_labels, num_points, bounds)
            | "Reshuffle" >> beam.Reshuffle()
            | "Get example" >> beam.FlatMapTuple(try_get_training_example, patch_size)
            | "Write NPZ files" >> beam.Map(write_npz_file, output_path)
            | "Log files" >> beam.Map(logging.info)
        )

if __name__ == "__main__":
    import argparse

    logging.getLogger().setLevel(logging.INFO)

    parser = argparse.ArgumentParser()
    parser.add_argument("--output-path", required=True)
    parser.add_argument("--num-dates", type=int, default=20)
    parser.add_argument("--num-points", type=int, default=10)
    parser.add_argument("--west", type=float, default=-125.3)
    parser.add_argument("--south", type=float, default=27.4)
    parser.add_argument("--east", type=float, default=-66.5)
    parser.add_argument("--north", type=float, default=49.1)
    parser.add_argument("--patch-size", type=int, default=64)
    parser.add_argument("--max-requests", type=int, default=20)
    args, beam_args = parser.parse_known_args()

    run(
        output_path=args.output_path,
        num_dates=args.num_dates,
        num_points=args.num_points,
        bounds=Bounds(args.west, args.south, args.east, args.north),
        patch_size=args.patch_size,
        max_requests=args.max_requests,
        beam_args=beam_args,
    )
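
# Example invocation (a sketch; the script name, bucket path, and runner
# flag are placeholders, and DataflowRunner also needs --project, --region,
# and --temp_location passed through to Beam):
#
#   python create_dataset.py \
#     --output-path gs://my-bucket/wildfires/data \
#     --runner DataflowRunner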
