@@ -148,8 +148,7 @@ def get_band_paths(scene: str, band_names: List[str]) -> Tuple[str, List[str]]:
148148 g = m .groupdict ()
149149 scene_dir = f"gs://gcp-public-data-landsat/{ g ['sensor' ]} /{ g ['collection' ]} /{ g ['wrs_path' ]} /{ g ['wrs_row' ]} /{ scene } "
150150
151- band_paths = [
152- f"{ scene_dir } /{ scene } _{ band_name } .TIF" for band_name in band_names ]
151+ band_paths = [f"{ scene_dir } /{ scene } _{ band_name } .TIF" for band_name in band_names ]
153152
154153 for band_path in band_paths :
155154 if not tf .io .gfile .exists (band_path ):
@@ -158,38 +157,31 @@ def get_band_paths(scene: str, band_names: List[str]) -> Tuple[str, List[str]]:
158157 return scene , band_paths
159158
160159
161- def load_values (scene : str , band_paths : List [str ]) -> Tuple [str , np .ndarray ]:
162- """Loads a scene's bands data as a numpy array.
160+ def save_to_gcs (
161+ scene : str , pixels : np .ndarray , output_path_prefix : str , format : str = "JPEG"
162+ ) -> None :
163+ """Saves RGB pixel values as an image file in the desired path.
163164
164165 Args:
165166 scene: Landsat 8 scene ID.
166- band_paths: A list of the [Red, Green, Blue] band paths.
167-
168- Returns:
169- A (scene, values) pair.
170-
171- The values are stored in a three-dimensional float32 array with shape:
172- (band, width, height)
167+ pixels: RGB pixel values as a numpy array, as returned by load_as_rgb.
168+ output_path_prefix: Path prefix to save the output files.
169+ format: Image format to save files.
173170 """
174-
175- def read_band (band_path : str ) -> np .array :
176- # Use rasterio to read the GeoTIFF values from the band files.
177- with tf .io .gfile .GFile (band_path , "rb" ) as f , rasterio .open (f ) as data :
178- return data .read (1 )
179-
180- logging .info (f"{ scene } : load_values({ band_paths } )" )
181- values = [read_band (band_path ) for band_path in band_paths ]
182- return scene , np .array (values , np .float32 )
171+ filename = os .path .join (output_path_prefix , scene + "." + format .lower ())
172+ with tf .io .gfile .GFile (filename , "w" ) as f :
173+ Image .fromarray (pixels , mode = "RGB" ).save (f , format )
183174
184175
185- def preprocess_pixels (
176+ def load_as_rgb (
186177 scene : str ,
187- values : np .ndarray ,
188- min_value : float = 0.0 ,
189- max_value : float = 1.0 ,
190- gamma : float = 1.0 ,
191- ) -> Tuple [str , tf .Tensor ]:
192- """Prepares the band data into a pixel-ready format for an RGB image.
178+ band_paths : List [str ],
179+ min_value : float = DEFAULT_MIN_BAND_VALUE ,
180+ max_value : float = DEFAULT_MAX_BAND_VALUE ,
181+ gamma : float = DEFAULT_GAMMA ,
182+ ) -> Tuple [str , np .ndarray ]:
183+ """Loads a scene's bands data and converts it into a pixel-ready format
184+ for an RGB image.
193185
194186 The input band values come in the shape (band, width, height) with
195187 unbounded positive numbers depending on the sensor's exposure.
@@ -198,20 +190,32 @@ def preprocess_pixels(
198190
199191 Args:
200192 scene: Landsat 8 scene ID.
201- values: Band values in the shape (band, width, height) .
193+ band_paths: A list of the [Red, Green, Blue] band paths .
202194 min_value: Minimum band value.
203195 max_value: Maximum band value.
204196 gamma: Gamma correction value.
205197
206198 Returns:
207- A (scene, pixels) pair. The pixels are Image-ready values.
199+ A (scene, pixels) pair.
200+
201+ The pixel values are stored in a three-dimensional uint8 array with shape:
202+ (width, height, rgb_channels)
208203 """
204+
205+ def read_band (band_path : str ) -> np .ndarray :
206+ # Use rasterio to read the GeoTIFF values from the band files.
207+ with tf .io .gfile .GFile (band_path , "rb" ) as f , rasterio .open (f ) as data :
208+ return data .read (1 ).astype (np .float32 )
209+
209210 logging .info (
210- f"{ scene } : preprocess_pixels( { values . shape } : { values . dtype } , min={ min_value } , max={ max_value } , gamma={ gamma } )"
211+ f"{ scene } : load_as_image( { band_paths } , min={ min_value } , max={ max_value } , gamma={ gamma } )"
211212 )
212213
213- # Reshape (band, width, height) into (width, height, band).
214- pixels = tf .transpose (values , (1 , 2 , 0 ))
214+ # Read the GeoTIFF files.
215+ band_values = [read_band (band_path ) for band_path in band_paths ]
216+
217+ # Stack the bands along the last axis to get the shape (width, height, band).
218+ pixels = np .stack (band_values , axis = - 1 )
215219
216220 # Rescale to values from 0.0 to 1.0 and clamp them into that range.
217221 pixels -= min_value
@@ -221,25 +225,9 @@ def preprocess_pixels(
221225 # Apply gamma correction.
222226 pixels **= 1.0 / gamma
223227
224- # Return the pixel values as int8 in the range from 0 to 255,
228+ # Return the pixel values as uint8 in the range from 0 to 255,
225229 # which is what PIL.Image expects.
226- return scene , tf .cast (pixels * 255.0 , dtype = tf .uint8 )
227-
228-
229- def save_to_gcs (
230- scene : str , image : Image .Image , output_path_prefix : str , format : str = "JPEG"
231- ) -> None :
232- """Saves a PIL.Image as a JPEG file in the desired path.
233-
234- Args:
235- scene: Landsat 8 scene ID.
236- image: A PIL.Image object.
237- output_path_prefix: Path prefix to save the output files.
238- format: Image format to save files.
239- """
240- filename = os .path .join (output_path_prefix , scene + "." + format .lower ())
241- with tf .io .gfile .GFile (filename , "w" ) as f :
242- image .save (f , format )
230+ return scene , tf .cast (pixels * 255.0 , dtype = tf .uint8 ).numpy ()
243231
244232
245233def run (
@@ -269,26 +257,23 @@ def run(
269257 (
270258 pipeline
271259 | "Create scene IDs" >> beam .Create (scenes )
272- | "Check GPU availability" >> beam .Map (
260+ | "Check GPU availability"
261+ >> beam .Map (
273262 lambda x , unused_side_input : x ,
274263 unused_side_input = beam .pvalue .AsSingleton (
275264 pipeline
276265 | beam .Create ([None ])
277266 | beam .Map (check_gpus ).with_resource_hints (accelerator = gpu_hint )
278267 ),
279268 )
269+ | "Get RGB band paths" >> beam .Map (get_band_paths , rgb_band_names )
280270 # We reshuffle to prevent fusion and allow all I/O operations to happen in parallel.
281271 # For more information, see the "Preventing fusion" section in the documentation:
282272 # https://cloud.google.com/dataflow/docs/guides/deploying-a-pipeline#preventing-fusion
283273 | "Reshuffle" >> beam .Reshuffle ()
284- | "Get RGB band paths" >> beam .Map (get_band_paths , rgb_band_names )
285- | "Load RGB band values" >> beam .MapTuple (load_values )
286- | "Preprocess pixels GPU" >> beam .MapTuple (
287- preprocess_pixels , min_value , max_value , gamma
288- ).with_resource_hints (accelerator = gpu_hint )
289- | "Convert to image" >> beam .MapTuple (
290- lambda scene , rgb_pixels : (
291- scene , Image .fromarray (rgb_pixels .numpy (), mode = "RGB" ))
274+ | "Load bands as RGB"
275+ >> beam .MapTuple (load_as_rgb , min_value , max_value , gamma ).with_resource_hints (
276+ accelerator = gpu_hint
292277 )
293278 | "Save to Cloud Storage" >> beam .MapTuple (save_to_gcs , output_path_prefix )
294279 )
@@ -303,47 +288,46 @@ def run(
303288 "--output-path-prefix" ,
304289 required = True ,
305290 help = "Path prefix for output image files. "
306- "This can be a Google Cloud Storage path." )
291+ "This can be a Google Cloud Storage path." ,
292+ )
307293 parser .add_argument (
308294 "--scene" ,
309295 dest = "scenes" ,
310296 action = "append" ,
311297 help = "One or more Landsat scene IDs to process, for example "
312298 "LC08_L1TP_109078_20200411_20200422_01_T1. "
313299 "They must be in the format: "
314- "https://www.usgs.gov/faqs/what-naming-convention-landsat-collections-level-1-scenes" )
315- parser .add_argument (
316- "--gpu-type" ,
317- default = DEFAULT_GPU_TYPE ,
318- help = "GPU type to use." )
300+ "https://www.usgs.gov/faqs/what-naming-convention-landsat-collections-level-1-scenes" ,
301+ )
302+ parser .add_argument ("--gpu-type" , default = DEFAULT_GPU_TYPE , help = "GPU type to use." )
319303 parser .add_argument (
320- "--gpu-count" ,
321- type = int ,
322- default = DEFAULT_GPU_COUNT ,
323- help = "GPU count to use." )
304+ "--gpu-count" , type = int , default = DEFAULT_GPU_COUNT , help = "GPU count to use."
305+ )
324306 parser .add_argument (
325307 "--rgb-band-names" ,
326308 nargs = 3 ,
327309 default = DEFAULT_RGB_BAND_NAMES ,
328- help = "List of three band names to be mapped to the RGB channels." )
310+ help = "List of three band names to be mapped to the RGB channels." ,
311+ )
329312 parser .add_argument (
330313 "--min" ,
331314 type = float ,
332315 default = DEFAULT_MIN_BAND_VALUE ,
333- help = "Minimum value of the band value range." )
316+ help = "Minimum value of the band value range." ,
317+ )
334318 parser .add_argument (
335319 "--max" ,
336320 type = float ,
337321 default = DEFAULT_MAX_BAND_VALUE ,
338- help = "Maximum value of the band value range." )
322+ help = "Maximum value of the band value range." ,
323+ )
339324 parser .add_argument (
340- "--gamma" ,
341- type = float ,
342- default = DEFAULT_GAMMA ,
343- help = "Gamma correction factor." )
325+ "--gamma" , type = float , default = DEFAULT_GAMMA , help = "Gamma correction factor."
326+ )
344327 args , beam_args = parser .parse_known_args ()
345328
346- run (scenes = args .scenes or DEFAULT_SCENES ,
329+ run (
330+ scenes = args .scenes or DEFAULT_SCENES ,
347331 output_path_prefix = args .output_path_prefix ,
348332 vis_params = {
349333 "rgb_band_names" : args .rgb_band_names ,
@@ -353,4 +337,5 @@ def run(
353337 },
354338 gpu_type = args .gpu_type ,
355339 gpu_count = args .gpu_count ,
356- beam_args = beam_args )
340+ beam_args = beam_args ,
341+ )
0 commit comments