19
19
def transforms_noaug_train (
20
20
img_size : Union [int , Tuple [int , int ]] = 224 ,
21
21
interpolation : str = 'bilinear' ,
22
- use_prefetcher : bool = False ,
23
22
mean : Tuple [float , ...] = IMAGENET_DEFAULT_MEAN ,
24
23
std : Tuple [float , ...] = IMAGENET_DEFAULT_STD ,
24
+ use_prefetcher : bool = False ,
25
+ normalize : bool = True ,
25
26
):
26
27
""" No-augmentation image transforms for training.
27
28
@@ -31,6 +32,7 @@ def transforms_noaug_train(
31
32
mean: Image normalization mean.
32
33
std: Image normalization standard deviation.
33
34
use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize.
35
+ normalize: Normalize tensor output w/ provided mean/std (if prefetcher not used).
34
36
35
37
Returns:
36
38
@@ -45,6 +47,9 @@ def transforms_noaug_train(
45
47
if use_prefetcher :
46
48
# prefetcher and collate will handle tensor conversion and norm
47
49
tfl += [ToNumpy ()]
50
+ elif not normalize :
51
+ # when normalize disabled, converted to tensor without scaling, keeps original dtype
52
+ tfl += [transforms .PILToTensor ()]
48
53
else :
49
54
tfl += [
50
55
transforms .ToTensor (),
@@ -77,6 +82,7 @@ def transforms_imagenet_train(
77
82
re_count : int = 1 ,
78
83
re_num_splits : int = 0 ,
79
84
use_prefetcher : bool = False ,
85
+ normalize : bool = True ,
80
86
separate : bool = False ,
81
87
):
82
88
""" ImageNet-oriented image transforms for training.
@@ -103,6 +109,7 @@ def transforms_imagenet_train(
103
109
re_count: Number of random erasing regions.
104
110
re_num_splits: Control split of random erasing across batch size.
105
111
use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize.
112
+ normalize: Normalize tensor output w/ provided mean/std (if prefetcher not used).
106
113
separate: Output transforms in 3-stage tuple.
107
114
108
115
Returns:
@@ -209,12 +216,15 @@ def transforms_imagenet_train(
209
216
if use_prefetcher :
210
217
# prefetcher and collate will handle tensor conversion and norm
211
218
final_tfl += [ToNumpy ()]
219
+ elif not normalize :
220
+ # when normalize disabled, converted to tensor without scaling, keeps original dtype
221
+ final_tfl += [transforms .PILToTensor ()]
212
222
else :
213
223
final_tfl += [
214
224
transforms .ToTensor (),
215
225
transforms .Normalize (
216
226
mean = torch .tensor (mean ),
217
- std = torch .tensor (std )
227
+ std = torch .tensor (std ),
218
228
),
219
229
]
220
230
if re_prob > 0. :
@@ -243,6 +253,7 @@ def transforms_imagenet_eval(
243
253
mean : Tuple [float , ...] = IMAGENET_DEFAULT_MEAN ,
244
254
std : Tuple [float , ...] = IMAGENET_DEFAULT_STD ,
245
255
use_prefetcher : bool = False ,
256
+ normalize : bool = True ,
246
257
):
247
258
""" ImageNet-oriented image transform for evaluation and inference.
248
259
@@ -255,6 +266,7 @@ def transforms_imagenet_eval(
255
266
mean: Image normalization mean.
256
267
std: Image normalization standard deviation.
257
268
use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize.
269
+ normalize: Normalize tensor output w/ provided mean/std (if prefetcher not used).
258
270
259
271
Returns:
260
272
Composed transform pipeline
@@ -304,13 +316,16 @@ def transforms_imagenet_eval(
304
316
if use_prefetcher :
305
317
# prefetcher and collate will handle tensor conversion and norm
306
318
tfl += [ToNumpy ()]
319
+ elif not normalize :
320
+ # when normalize disabled, converted to tensor without scaling, keeps original dtype
321
+ tfl += [transforms .PILToTensor ()]
307
322
else :
308
323
tfl += [
309
324
transforms .ToTensor (),
310
325
transforms .Normalize (
311
326
mean = torch .tensor (mean ),
312
327
std = torch .tensor (std ),
313
- )
328
+ ),
314
329
]
315
330
316
331
return transforms .Compose (tfl )
@@ -342,6 +357,7 @@ def create_transform(
342
357
crop_border_pixels : Optional [int ] = None ,
343
358
tf_preprocessing : bool = False ,
344
359
use_prefetcher : bool = False ,
360
+ normalize : bool = True ,
345
361
separate : bool = False ,
346
362
):
347
363
"""
@@ -373,6 +389,7 @@ def create_transform(
373
389
crop_border_pixels: Inference crop border of specified # pixels around edge of original image.
374
390
tf_preprocessing: Use TF 1.0 inference preprocessing for testing model ports
375
391
use_prefetcher: Pre-fetcher enabled. Do not convert image to tensor or normalize.
392
+ normalize: Normalize tensor output w/ provided mean/std (if prefetcher not used).
376
393
separate: Output transforms in 3-stage tuple.
377
394
378
395
Returns:
@@ -397,9 +414,10 @@ def create_transform(
397
414
transform = transforms_noaug_train (
398
415
img_size ,
399
416
interpolation = interpolation ,
400
- use_prefetcher = use_prefetcher ,
401
417
mean = mean ,
402
418
std = std ,
419
+ use_prefetcher = use_prefetcher ,
420
+ normalize = normalize ,
403
421
)
404
422
elif is_training :
405
423
transform = transforms_imagenet_train (
@@ -415,26 +433,28 @@ def create_transform(
415
433
gaussian_blur_prob = gaussian_blur_prob ,
416
434
auto_augment = auto_augment ,
417
435
interpolation = interpolation ,
418
- use_prefetcher = use_prefetcher ,
419
436
mean = mean ,
420
437
std = std ,
421
438
re_prob = re_prob ,
422
439
re_mode = re_mode ,
423
440
re_count = re_count ,
424
441
re_num_splits = re_num_splits ,
442
+ use_prefetcher = use_prefetcher ,
443
+ normalize = normalize ,
425
444
separate = separate ,
426
445
)
427
446
else :
428
447
assert not separate , "Separate transforms not supported for validation preprocessing"
429
448
transform = transforms_imagenet_eval (
430
449
img_size ,
431
450
interpolation = interpolation ,
432
- use_prefetcher = use_prefetcher ,
433
451
mean = mean ,
434
452
std = std ,
435
453
crop_pct = crop_pct ,
436
454
crop_mode = crop_mode ,
437
455
crop_border_pixels = crop_border_pixels ,
456
+ use_prefetcher = use_prefetcher ,
457
+ normalize = normalize ,
438
458
)
439
459
440
460
return transform
0 commit comments