Skip to content

Commit 07c9b21

Browse files
committed
avoid loading checkpoint twice, store the transformed image for later use
1 parent bb85f1e commit 07c9b21

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

ImageNet/training_scripts/imagenet_training/classification_test.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from timm.models import create_model
88
from timm.data import resolve_data_config
99
from timm.data.transforms_factory import create_transform
10-
10+
import torchvision
1111

1212
parser = argparse.ArgumentParser(description='PyTorch ImageNet Inference')
1313
parser.add_argument('--model', '-m', metavar='MODEL', default='simpnet', help='model architecture (default: simpnet)')
@@ -31,10 +31,10 @@
3131
mode = args.netmode,
3232
)
3333

34-
print('Restoring model state from checkpoint...')
35-
model_weights = torch.load(args.weights, map_location='cpu')
36-
model.load_state_dict(model_weights)
37-
model.eval()
34+
# print('Restoring model state from checkpoint...')
35+
# model_weights = torch.load(args.weights, map_location='cpu')
36+
# model.load_state_dict(model_weights)
37+
# model.eval()
3838

3939
if args.jit:
4040
dummy_input = torch.randn(1, 3, 224, 224, device="cpu")
@@ -47,6 +47,8 @@
4747
filename = "./misc_files/dog.jpg"
4848
img = Image.open(filename).convert('RGB')
4949
tensor = transform(img).unsqueeze(0)
50+
# save the transformed image for visualization or testing the ported models
51+
torchvision.utils.save_image(tensor.squeeze(0),'img_test_transformed.jpg')
5052

5153
with torch.no_grad():
5254
out = model(tensor)

0 commit comments

Comments
 (0)