1313# limitations under the License.
1414import torch
1515import os
16- import torchvision . transforms as transforms
17- import torchvision . datasets as datasets
16+ from torchvision import transforms
17+ from torchvision import datasets
1818
1919def get_dataset (dataset_name , data_dir , split , rand_fraction = None ,clean = False , transform = None , imsize = None , bucket = 'pytorch-data' , ** kwargs ):
2020
2121 if dataset_name in [ 'cifar10' , 'cifar100' ]:
22- dataset = globals ()[f'get_{ dataset_name } ' ](dataset_name , data_dir , split , imsize = imsize , bucket = bucket , ** kwargs )
22+ dataset = globals ()[f'get_{ dataset_name } ' ](dataset_name , data_dir , split , imsize = imsize , bucket = bucket , ** kwargs )
2323 elif dataset_name in [ 'cifar10vit224' , 'cifar100vit224' ,'cifar10vit384' , 'cifar100vit384' ,]:
2424 imsize = int (dataset_name .split ('vit' )[- 1 ])
2525 dataset_name = dataset_name .split ('vit' )[0 ]
2626 #print ('here')
27- dataset = globals ()['get_cifar_vit' ](dataset_name , data_dir , split , imsize = imsize , bucket = bucket , ** kwargs )
27+ dataset = globals ()['get_cifar_vit' ](dataset_name , data_dir , split , imsize = imsize , bucket = bucket , ** kwargs )
2828 else :
2929 assert 'cifar' in dataset_name
3030 print (dataset_name )
@@ -59,10 +59,10 @@ def get_transform(split, normalize=None, transform=None, imsize=None, aug='large
5959 if transform is None :
6060 if normalize is None :
6161 if aug == 'large' :
62-
62+
6363 normalize = transforms .Normalize (mean = [0.485 , 0.456 , 0.406 ], std = [0.229 , 0.224 , 0.225 ])
6464 else :
65- normalize = transforms .Normalize (mean = [0.4914 , 0.4822 , 0.4465 ], std = [0.2023 , 0.1994 , 0.2010 ])
65+ normalize = transforms .Normalize (mean = [0.4914 , 0.4822 , 0.4465 ], std = [0.2023 , 0.1994 , 0.2010 ])
6666 transform = transforms .Compose (get_aug (split , imsize = imsize , aug = aug )
6767 + [transforms .ToTensor (), normalize ])
6868 return transform
@@ -71,7 +71,7 @@ def get_transform(split, normalize=None, transform=None, imsize=None, aug='large
7171def get_cifar10 (dataset_name , data_dir , split , transform = None , imsize = None , bucket = 'pytorch-data' , ** kwargs ):
7272 if imsize == 224 :
7373 transform = get_transform (split , transform = transform , imsize = imsize , aug = 'large' )
74- else :
74+ else :
7575 transform = get_transform (split , transform = transform , imsize = imsize , aug = 'small' )
7676 return datasets .CIFAR10 (data_dir , train = (split == 'train' ), transform = transform , download = True , ** kwargs )
7777
@@ -88,7 +88,7 @@ def get_cifar100N(dataset_name, data_dir, split, rand_fraction=None,transform=No
8888 if split == 'train' :
8989 return CIFAR100N (root = data_dir , train = (split == 'train' ), transform = transform , download = True , rand_fraction = rand_fraction )
9090 else :
91- return datasets .CIFAR100 (data_dir , train = (split == 'train' ), transform = transform , download = True , ** kwargs )
91+ return datasets .CIFAR100 (data_dir , train = (split == 'train' ), transform = transform , download = True , ** kwargs )
9292
9393def get_cifar_vit (dataset_name , data_dir , split , transform = None , imsize = None , bucket = 'pytorch-data' , ** kwargs ):
9494 if imsize == 224 :
@@ -111,12 +111,12 @@ def get_cifar_vit(dataset_name, data_dir, split, transform=None, imsize=None, bu
111111 if dataset_name == 'cifar10' :
112112 return datasets .CIFAR10 (data_dir , train = (split == 'train' ), transform = transform_data , download = True , ** kwargs )
113113 elif dataset_name == 'cifar100' :
114-
114+
115115 return datasets .CIFAR100 (data_dir , train = (split == 'train' ), transform = transform_data , download = True , ** kwargs )
116116 else :
117117 assert dataset_name in ['cifar10' , 'cifar100' ]
118118 else :
119-
119+
120120 if split == 'train' :
121121 transform_data = transforms .Compose ([# transforms.ColorJitter(brightness= 0.4, contrast= 0.4, saturation= 0.4, hue= 0.1),
122122 transforms .Resize (imsize ),
@@ -164,4 +164,4 @@ def get_imagenet_vit(dataset_name, data_dir, split, transform=None, imsize=None,
164164 #return torch.utils.data.distributed.DistributedSampler(train_dataset)
165165 else :
166166 return datasets .ImageFolder (valdir , transform_data )
167- #Ereturn torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=True)
167+ #Ereturn torch.utils.data.distributed.DistributedSampler(val_dataset, shuffle=False, drop_last=True)
0 commit comments