Description
There is a FIXME comment at line 103 of torchvision/transforms/autoaugment.py:
# FIXME: Eliminate copy-pasted code for fill standardization and _augmentation_space() by moving stuff on a base class
Four classes (AutoAugment, RandAugment, TrivialAugmentWide, AugMix) all independently repeat the same fill-standardization logic in their forward() methods:
fill = self.fill
channels, height, width = F.get_dimensions(img)
if isinstance(img, Tensor):
    if isinstance(fill, (int, float)):
        fill = [float(fill)] * channels
    elif fill is not None:
        fill = [float(f) for f in fill]
They also independently inherit from torch.nn.Module and define interpolation and fill in their __init__.
Proposal
Create an _AutoAugmentBase class (as v2 already does in torchvision/transforms/v2/_auto_augment.py) that:
- Inherits from torch.nn.Module
- Holds the common __init__ params: interpolation and fill
- Provides a _get_fill() helper to eliminate the duplicated fill-standardization logic
This is a pure internal refactor — no public API changes, no behavior changes. All existing tests should pass as-is.
Note: _augmentation_space() is intentionally not unified since each class uses different signatures and contents.
cc @pmeier