Skip to content

Commit bbdbc20

Browse files
committed
cleaned classification_test.py a bit!
1 parent d0b5217 commit bbdbc20

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

ImageNet/training_scripts/imagenet_training/classification_test.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
parser.add_argument('--model', '-m', metavar='MODEL', default='simpnet', help='model architecture (default: simpnet)')
1414
parser.add_argument('--num-classes', type=int, default=1000, help='Number classes in dataset')
1515
parser.add_argument('--weights', default='', type=str, metavar='PATH', help='path to model weights (default: none)')
16+
parser.add_argument('--pretrained', dest='pretrained', action='store_true', help='use pre-trained model')
1617
parser.add_argument('--jit', action='store_true', default=False, help='convert the model to jit before doing classification!')
1718
parser.add_argument('--netscale', type=float, default=1.0, help='scale of the net (default 1.0)')
1819
parser.add_argument('--netidx', type=int, default=0, help='which network to use (5mil or 8mil)')
@@ -25,23 +26,22 @@
2526
model = create_model(
2627
args.model,
2728
num_classes=args.num_classes,
29+
pretrained=args.pretrained,
2830
checkpoint_path=args.weights,
2931
scale=args.netscale,
3032
network_idx = args.netidx,
3133
mode = args.netmode,
3234
)
35+
model.eval()
3336

34-
# print('Restoring model state from checkpoint...')
35-
# model_weights = torch.load(args.weights, map_location='cpu')
36-
# model.load_state_dict(model_weights)
37-
# model.eval()
37+
if not args.pretrained and not args.weights:
38+
print(f'WARNING: No pretrained weights specified! (pretrained is False and there is no checkpoint specified!)')
3839

3940
if args.jit:
4041
dummy_input = torch.randn(1, 3, 224, 224, device="cpu")
4142
model = torch.jit.trace(model, dummy_input)
4243

4344
config = resolve_data_config({}, model=model)
44-
print(f'config: {config}')
4545
transform = create_transform(**config)
4646

4747
filename = "./misc_files/dog.jpg"
@@ -53,13 +53,14 @@
5353
with torch.no_grad():
5454
out = model(tensor)
5555
probabilities = torch.nn.functional.softmax(out[0], dim=0)
56-
print(probabilities.shape) # prints: torch.Size([1000])
56+
print(f'{probabilities.shape}') # prints: torch.Size([1000])
5757

5858
filename="./misc_files/imagenet_classes.txt"
5959
with open(filename, "r") as f:
6060
categories = [s.strip() for s in f.readlines()]
6161

6262
# Print top categories per image
63+
print(f'Top categories:')
6364
top5_prob, top5_catid = torch.topk(probabilities, 5)
6465
for i in range(top5_prob.size(0)):
6566
print(categories[top5_catid[i]], top5_prob[i].item())

0 commit comments

Comments
 (0)