Skip to content

Commit 3bfd036

Browse files
committed
Add normalize flag to transforms factory, allow return of non-normalized native dtype torch.Tensors
1 parent a69863a commit 3bfd036

File tree

1 file changed

+26
-6
lines changed

1 file changed

+26
-6
lines changed

timm/data/transforms_factory.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,10 @@
1919
def transforms_noaug_train(
2020
img_size: Union[int, Tuple[int, int]] = 224,
2121
interpolation: str = 'bilinear',
22-
use_prefetcher: bool = False,
2322
mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
2423
std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
24+
use_prefetcher: bool = False,
25+
normalize: bool = True,
2526
):
2627
""" No-augmentation image transforms for training.
2728
@@ -31,6 +32,7 @@ def transforms_noaug_train(
3132
mean: Image normalization mean.
3233
std: Image normalization standard deviation.
3334
use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize.
35+
normalize: Normalize tensor output w/ provided mean/std (if prefetcher not used).
3436
3537
Returns:
3638
@@ -45,6 +47,9 @@ def transforms_noaug_train(
4547
if use_prefetcher:
4648
# prefetcher and collate will handle tensor conversion and norm
4749
tfl += [ToNumpy()]
50+
elif not normalize:
51+
# when normalize disabled, converted to tensor without scaling, keeps original dtype
52+
tfl += [transforms.PILToTensor()]
4853
else:
4954
tfl += [
5055
transforms.ToTensor(),
@@ -77,6 +82,7 @@ def transforms_imagenet_train(
7782
re_count: int = 1,
7883
re_num_splits: int = 0,
7984
use_prefetcher: bool = False,
85+
normalize: bool = True,
8086
separate: bool = False,
8187
):
8288
""" ImageNet-oriented image transforms for training.
@@ -103,6 +109,7 @@ def transforms_imagenet_train(
103109
re_count: Number of random erasing regions.
104110
re_num_splits: Control split of random erasing across batch size.
105111
use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize.
112+
normalize: Normalize tensor output w/ provided mean/std (if prefetcher not used).
106113
separate: Output transforms in 3-stage tuple.
107114
108115
Returns:
@@ -209,12 +216,15 @@ def transforms_imagenet_train(
209216
if use_prefetcher:
210217
# prefetcher and collate will handle tensor conversion and norm
211218
final_tfl += [ToNumpy()]
219+
elif not normalize:
220+
# when normalize disabled, converted to tensor without scaling, keeps original dtype
221+
final_tfl += [transforms.PILToTensor()]
212222
else:
213223
final_tfl += [
214224
transforms.ToTensor(),
215225
transforms.Normalize(
216226
mean=torch.tensor(mean),
217-
std=torch.tensor(std)
227+
std=torch.tensor(std),
218228
),
219229
]
220230
if re_prob > 0.:
@@ -243,6 +253,7 @@ def transforms_imagenet_eval(
243253
mean: Tuple[float, ...] = IMAGENET_DEFAULT_MEAN,
244254
std: Tuple[float, ...] = IMAGENET_DEFAULT_STD,
245255
use_prefetcher: bool = False,
256+
normalize: bool = True,
246257
):
247258
""" ImageNet-oriented image transform for evaluation and inference.
248259
@@ -255,6 +266,7 @@ def transforms_imagenet_eval(
255266
mean: Image normalization mean.
256267
std: Image normalization standard deviation.
257268
use_prefetcher: Prefetcher enabled. Do not convert image to tensor or normalize.
269+
normalize: Normalize tensor output w/ provided mean/std (if prefetcher not used).
258270
259271
Returns:
260272
Composed transform pipeline
@@ -304,13 +316,16 @@ def transforms_imagenet_eval(
304316
if use_prefetcher:
305317
# prefetcher and collate will handle tensor conversion and norm
306318
tfl += [ToNumpy()]
319+
elif not normalize:
320+
# when normalize disabled, converted to tensor without scaling, keeps original dtype
321+
tfl += [transforms.PILToTensor()]
307322
else:
308323
tfl += [
309324
transforms.ToTensor(),
310325
transforms.Normalize(
311326
mean=torch.tensor(mean),
312327
std=torch.tensor(std),
313-
)
328+
),
314329
]
315330

316331
return transforms.Compose(tfl)
@@ -342,6 +357,7 @@ def create_transform(
342357
crop_border_pixels: Optional[int] = None,
343358
tf_preprocessing: bool = False,
344359
use_prefetcher: bool = False,
360+
normalize: bool = True,
345361
separate: bool = False,
346362
):
347363
"""
@@ -373,6 +389,7 @@ def create_transform(
373389
crop_border_pixels: Inference crop border of specified # pixels around edge of original image.
374390
tf_preprocessing: Use TF 1.0 inference preprocessing for testing model ports
375391
use_prefetcher: Pre-fetcher enabled. Do not convert image to tensor or normalize.
392+
normalize: Normalize tensor output w/ provided mean/std (if prefetcher not used).
376393
separate: Output transforms in 3-stage tuple.
377394
378395
Returns:
@@ -397,9 +414,10 @@ def create_transform(
397414
transform = transforms_noaug_train(
398415
img_size,
399416
interpolation=interpolation,
400-
use_prefetcher=use_prefetcher,
401417
mean=mean,
402418
std=std,
419+
use_prefetcher=use_prefetcher,
420+
normalize=normalize,
403421
)
404422
elif is_training:
405423
transform = transforms_imagenet_train(
@@ -415,26 +433,28 @@ def create_transform(
415433
gaussian_blur_prob=gaussian_blur_prob,
416434
auto_augment=auto_augment,
417435
interpolation=interpolation,
418-
use_prefetcher=use_prefetcher,
419436
mean=mean,
420437
std=std,
421438
re_prob=re_prob,
422439
re_mode=re_mode,
423440
re_count=re_count,
424441
re_num_splits=re_num_splits,
442+
use_prefetcher=use_prefetcher,
443+
normalize=normalize,
425444
separate=separate,
426445
)
427446
else:
428447
assert not separate, "Separate transforms not supported for validation preprocessing"
429448
transform = transforms_imagenet_eval(
430449
img_size,
431450
interpolation=interpolation,
432-
use_prefetcher=use_prefetcher,
433451
mean=mean,
434452
std=std,
435453
crop_pct=crop_pct,
436454
crop_mode=crop_mode,
437455
crop_border_pixels=crop_border_pixels,
456+
use_prefetcher=use_prefetcher,
457+
normalize=normalize,
438458
)
439459

440460
return transform

0 commit comments

Comments
 (0)