Skip to content

Commit e37a84c

Browse files
David Cavazos and nicain authored
dataflow tests: simplify wait function + fix landsat prime flaky test (GoogleCloudPlatform#7443)
* dataflow tests: simplify wait function to only accept job id
* make timeout into a constant
* make timeout into a constant
* remove unused import
* move reshuffle before GPU operation
* manually fuse functions for the GPU
* add default values to function
* avoid multiprocessing since it cannot be pickled

Co-authored-by: nicain <[email protected]>
1 parent 5c3322f commit e37a84c

File tree

3 files changed

+67
-87
lines changed

3 files changed

+67
-87
lines changed

dataflow/conftest.py

Lines changed: 1 addition & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -360,23 +360,17 @@ def dataflow_jobs_get(job_id: str, project: str = PROJECT) -> Dict[str, Any]:
360360

361361
@staticmethod
362362
def dataflow_jobs_wait(
363-
job_id: str = None,
364-
job_name: str = None,
363+
job_id: str,
365364
project: str = PROJECT,
366365
region: str = REGION,
367366
target_states: Set[str] = {"JOB_STATE_DONE"},
368-
list_page_size: int = LIST_PAGE_SIZE,
369367
timeout_sec: str = TIMEOUT_SEC,
370368
poll_interval_sec: int = POLL_INTERVAL_SEC,
371369
) -> Optional[str]:
372370
"""For a list of all the valid states:
373371
https://cloud.google.com/dataflow/docs/reference/rest/v1b3/projects.jobs#Job.JobState
374372
"""
375373

376-
assert job_id or job_name, "required to pass either a job_id or a job_name"
377-
if not job_id:
378-
job_id = Utils.dataflow_job_id(job_name, project, list_page_size)
379-
380374
finish_states = {
381375
"JOB_STATE_DONE",
382376
"JOB_STATE_FAILED",

dataflow/gpu-examples/tensorflow-landsat-prime/e2e_test.py

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -23,7 +23,7 @@
2323
from google.cloud import storage
2424
import pytest
2525

26-
NAME = "dataflow/gpu-examples/tensorflow-landsat"
26+
NAME = "dataflow/gpu-examples/tensorflow-landsat-prime"
2727

2828

2929
@pytest.fixture(scope="session")
@@ -61,7 +61,8 @@ def test_tensorflow_landsat(
6161
) -> None:
6262
# Wait until the job finishes.
6363
timeout = 30 * 60 # 30 minutes
64-
utils.dataflow_jobs_wait(job_name=utils.hyphen_name(NAME), timeout_sec=timeout)
64+
job_id = utils.dataflow_job_id(utils.hyphen_name(NAME))
65+
utils.dataflow_jobs_wait(job_id, timeout_sec=timeout)
6566

6667
# Check that output files were created and are not empty.
6768
storage_client = storage.Client()

dataflow/gpu-examples/tensorflow-landsat-prime/main.py

Lines changed: 63 additions & 78 deletions
Original file line number | Diff line number | Diff line change
@@ -148,8 +148,7 @@ def get_band_paths(scene: str, band_names: List[str]) -> Tuple[str, List[str]]:
148148
g = m.groupdict()
149149
scene_dir = f"gs://gcp-public-data-landsat/{g['sensor']}/{g['collection']}/{g['wrs_path']}/{g['wrs_row']}/{scene}"
150150

151-
band_paths = [
152-
f"{scene_dir}/{scene}_{band_name}.TIF" for band_name in band_names]
151+
band_paths = [f"{scene_dir}/{scene}_{band_name}.TIF" for band_name in band_names]
153152

154153
for band_path in band_paths:
155154
if not tf.io.gfile.exists(band_path):
@@ -158,38 +157,31 @@ def get_band_paths(scene: str, band_names: List[str]) -> Tuple[str, List[str]]:
158157
return scene, band_paths
159158

160159

161-
def load_values(scene: str, band_paths: List[str]) -> Tuple[str, np.ndarray]:
162-
"""Loads a scene's bands data as a numpy array.
160+
def save_to_gcs(
161+
scene: str, pixels: np.ndarray, output_path_prefix: str, format: str = "JPEG"
162+
) -> None:
163+
"""Saves a PIL.Image as a JPEG file in the desired path.
163164
164165
Args:
165166
scene: Landsat 8 scene ID.
166-
band_paths: A list of the [Red, Green, Blue] band paths.
167-
168-
Returns:
169-
A (scene, values) pair.
170-
171-
The values are stored in a three-dimensional float32 array with shape:
172-
(band, width, height)
167+
image: A PIL.Image object.
168+
output_path_prefix: Path prefix to save the output files.
169+
format: Image format to save files.
173170
"""
174-
175-
def read_band(band_path: str) -> np.array:
176-
# Use rasterio to read the GeoTIFF values from the band files.
177-
with tf.io.gfile.GFile(band_path, "rb") as f, rasterio.open(f) as data:
178-
return data.read(1)
179-
180-
logging.info(f"{scene}: load_values({band_paths})")
181-
values = [read_band(band_path) for band_path in band_paths]
182-
return scene, np.array(values, np.float32)
171+
filename = os.path.join(output_path_prefix, scene + "." + format.lower())
172+
with tf.io.gfile.GFile(filename, "w") as f:
173+
Image.fromarray(pixels, mode="RGB").save(f, format)
183174

184175

185-
def preprocess_pixels(
176+
def load_as_rgb(
186177
scene: str,
187-
values: np.ndarray,
188-
min_value: float = 0.0,
189-
max_value: float = 1.0,
190-
gamma: float = 1.0,
191-
) -> Tuple[str, tf.Tensor]:
192-
"""Prepares the band data into a pixel-ready format for an RGB image.
178+
band_paths: List[str],
179+
min_value: float = DEFAULT_MIN_BAND_VALUE,
180+
max_value: float = DEFAULT_MAX_BAND_VALUE,
181+
gamma: float = DEFAULT_GAMMA,
182+
) -> Tuple[str, np.ndarray]:
183+
"""Loads a scene's bands data and converts it into a pixel-ready format
184+
for an RGB image.
193185
194186
The input band values come in the shape (band, width, height) with
195187
unbounded positive numbers depending on the sensor's exposure.
@@ -198,20 +190,32 @@ def preprocess_pixels(
198190
199191
Args:
200192
scene: Landsat 8 scene ID.
201-
values: Band values in the shape (band, width, height).
193+
band_paths: A list of the [Red, Green, Blue] band paths.
202194
min_value: Minimum band value.
203195
max_value: Maximum band value.
204196
gamma: Gamma correction value.
205197
206198
Returns:
207-
A (scene, pixels) pair. The pixels are Image-ready values.
199+
A (scene, pixels) pair.
200+
201+
The pixel values are stored in a three-dimensional uint8 array with shape:
202+
(width, height, rgb_channels)
208203
"""
204+
205+
def read_band(band_path: str) -> np.ndarray:
206+
# Use rasterio to read the GeoTIFF values from the band files.
207+
with tf.io.gfile.GFile(band_path, "rb") as f, rasterio.open(f) as data:
208+
return data.read(1).astype(np.float32)
209+
209210
logging.info(
210-
f"{scene}: preprocess_pixels({values.shape}:{values.dtype}, min={min_value}, max={max_value}, gamma={gamma})"
211+
f"{scene}: load_as_image({band_paths}, min={min_value}, max={max_value}, gamma={gamma})"
211212
)
212213

213-
# Reshape (band, width, height) into (width, height, band).
214-
pixels = tf.transpose(values, (1, 2, 0))
214+
# Read the GeoTIFF files.
215+
band_values = [read_band(band_path) for band_path in band_paths]
216+
217+
# We get the band values into the shape (width, height, band).
218+
pixels = np.stack(band_values, axis=-1)
215219

216220
# Rescale to values from 0.0 to 1.0 and clamp them into that range.
217221
pixels -= min_value
@@ -221,25 +225,9 @@ def preprocess_pixels(
221225
# Apply gamma correction.
222226
pixels **= 1.0 / gamma
223227

224-
# Return the pixel values as int8 in the range from 0 to 255,
228+
# Return the pixel values as uint8 in the range from 0 to 255,
225229
# which is what PIL.Image expects.
226-
return scene, tf.cast(pixels * 255.0, dtype=tf.uint8)
227-
228-
229-
def save_to_gcs(
230-
scene: str, image: Image.Image, output_path_prefix: str, format: str = "JPEG"
231-
) -> None:
232-
"""Saves a PIL.Image as a JPEG file in the desired path.
233-
234-
Args:
235-
scene: Landsat 8 scene ID.
236-
image: A PIL.Image object.
237-
output_path_prefix: Path prefix to save the output files.
238-
format: Image format to save files.
239-
"""
240-
filename = os.path.join(output_path_prefix, scene + "." + format.lower())
241-
with tf.io.gfile.GFile(filename, "w") as f:
242-
image.save(f, format)
230+
return scene, tf.cast(pixels * 255.0, dtype=tf.uint8).numpy()
243231

244232

245233
def run(
@@ -269,26 +257,23 @@ def run(
269257
(
270258
pipeline
271259
| "Create scene IDs" >> beam.Create(scenes)
272-
| "Check GPU availability" >> beam.Map(
260+
| "Check GPU availability"
261+
>> beam.Map(
273262
lambda x, unused_side_input: x,
274263
unused_side_input=beam.pvalue.AsSingleton(
275264
pipeline
276265
| beam.Create([None])
277266
| beam.Map(check_gpus).with_resource_hints(accelerator=gpu_hint)
278267
),
279268
)
269+
| "Get RGB band paths" >> beam.Map(get_band_paths, rgb_band_names)
280270
# We reshuffle to prevent fusion and allow all I/O operations to happen in parallel.
281271
# For more information, see the "Preventing fusion" section in the documentation:
282272
# https://cloud.google.com/dataflow/docs/guides/deploying-a-pipeline#preventing-fusion
283273
| "Reshuffle" >> beam.Reshuffle()
284-
| "Get RGB band paths" >> beam.Map(get_band_paths, rgb_band_names)
285-
| "Load RGB band values" >> beam.MapTuple(load_values)
286-
| "Preprocess pixels GPU" >> beam.MapTuple(
287-
preprocess_pixels, min_value, max_value, gamma
288-
).with_resource_hints(accelerator=gpu_hint)
289-
| "Convert to image" >> beam.MapTuple(
290-
lambda scene, rgb_pixels: (
291-
scene, Image.fromarray(rgb_pixels.numpy(), mode="RGB"))
274+
| "Load bands as RGB"
275+
>> beam.MapTuple(load_as_rgb, min_value, max_value, gamma).with_resource_hints(
276+
accelerator=gpu_hint
292277
)
293278
| "Save to Cloud Storage" >> beam.MapTuple(save_to_gcs, output_path_prefix)
294279
)
@@ -303,47 +288,46 @@ def run(
303288
"--output-path-prefix",
304289
required=True,
305290
help="Path prefix for output image files. "
306-
"This can be a Google Cloud Storage path.")
291+
"This can be a Google Cloud Storage path.",
292+
)
307293
parser.add_argument(
308294
"--scene",
309295
dest="scenes",
310296
action="append",
311297
help="One or more Landsat scene IDs to process, for example "
312298
"LC08_L1TP_109078_20200411_20200422_01_T1. "
313299
"They must be in the format: "
314-
"https://www.usgs.gov/faqs/what-naming-convention-landsat-collections-level-1-scenes")
315-
parser.add_argument(
316-
"--gpu-type",
317-
default=DEFAULT_GPU_TYPE,
318-
help="GPU type to use.")
300+
"https://www.usgs.gov/faqs/what-naming-convention-landsat-collections-level-1-scenes",
301+
)
302+
parser.add_argument("--gpu-type", default=DEFAULT_GPU_TYPE, help="GPU type to use.")
319303
parser.add_argument(
320-
"--gpu-count",
321-
type=int,
322-
default=DEFAULT_GPU_COUNT,
323-
help="GPU count to use.")
304+
"--gpu-count", type=int, default=DEFAULT_GPU_COUNT, help="GPU count to use."
305+
)
324306
parser.add_argument(
325307
"--rgb-band-names",
326308
nargs=3,
327309
default=DEFAULT_RGB_BAND_NAMES,
328-
help="List of three band names to be mapped to the RGB channels.")
310+
help="List of three band names to be mapped to the RGB channels.",
311+
)
329312
parser.add_argument(
330313
"--min",
331314
type=float,
332315
default=DEFAULT_MIN_BAND_VALUE,
333-
help="Minimum value of the band value range.")
316+
help="Minimum value of the band value range.",
317+
)
334318
parser.add_argument(
335319
"--max",
336320
type=float,
337321
default=DEFAULT_MAX_BAND_VALUE,
338-
help="Maximum value of the band value range.")
322+
help="Maximum value of the band value range.",
323+
)
339324
parser.add_argument(
340-
"--gamma",
341-
type=float,
342-
default=DEFAULT_GAMMA,
343-
help="Gamma correction factor.")
325+
"--gamma", type=float, default=DEFAULT_GAMMA, help="Gamma correction factor."
326+
)
344327
args, beam_args = parser.parse_known_args()
345328

346-
run(scenes=args.scenes or DEFAULT_SCENES,
329+
run(
330+
scenes=args.scenes or DEFAULT_SCENES,
347331
output_path_prefix=args.output_path_prefix,
348332
vis_params={
349333
"rgb_band_names": args.rgb_band_names,
@@ -353,4 +337,5 @@ def run(
353337
},
354338
gpu_type=args.gpu_type,
355339
gpu_count=args.gpu_count,
356-
beam_args=beam_args)
340+
beam_args=beam_args,
341+
)

0 commit comments

Comments
 (0)