|
# --- CLI arguments ---------------------------------------------------------
# Model selection and dataset size.
parser.add_argument('--model', '-m', metavar='MODEL', default='simpnet', help='model architecture (default: simpnet)')
parser.add_argument('--num-classes', type=int, default=1000, help='Number classes in dataset')
# Weight loading: either an explicit checkpoint path or the pretrained flag.
parser.add_argument('--weights', default='', type=str, metavar='PATH', help='path to model weights (default: none)')
parser.add_argument('--pretrained', dest='pretrained', action='store_true', help='use pre-trained model')
# Optional TorchScript conversion before running classification.
parser.add_argument('--jit', action='store_true', default=False, help='convert the model to jit before doing classification!')
# Network hyper-parameters for the simpnet variants (scale factor and which
# variant — presumably the 5M- vs 8M-parameter nets; TODO confirm in create_model).
parser.add_argument('--netscale', type=float, default=1.0, help='scale of the net (default 1.0)')
parser.add_argument('--netidx', type=int, default=0, help='which network to use (5mil or 8mil)')
|
|
# Build the classifier. Weights come from --pretrained and/or the --weights
# checkpoint path; scale/network_idx/mode select the simpnet variant.
model = create_model(
    args.model,
    num_classes=args.num_classes,
    pretrained=args.pretrained,
    checkpoint_path=args.weights,
    scale=args.netscale,
    network_idx=args.netidx,
    mode=args.netmode,
)
# Inference mode: fixes batch-norm statistics and disables dropout.
model.eval()
33 | 36 |
|
34 |
| -# print('Restoring model state from checkpoint...') |
35 |
| -# model_weights = torch.load(args.weights, map_location='cpu') |
36 |
| -# model.load_state_dict(model_weights) |
37 |
| -# model.eval() |
| 37 | +if not args.pretrained and not args.weights: |
| 38 | + print(f'WARNING: No pretrained weights specified! (pretrained is False and there is no checkpoint specified!)') |
38 | 39 |
|
if args.jit:
    # TorchScript-compile the model by tracing it with a fixed-size example
    # input (one 3x224x224 image on CPU).
    example = torch.randn(1, 3, 224, 224, device="cpu")
    model = torch.jit.trace(model, example)
|
42 | 43 |
|
# Derive the preprocessing pipeline (resize/crop/normalize parameters) from
# the model's default data configuration, then build the matching transform.
config = resolve_data_config({}, model=model)
transform = create_transform(**config)

# Sample image to classify.
filename = "./misc_files/dog.jpg"
|
|
# Forward pass without gradient tracking (inference only).
with torch.no_grad():
    out = model(tensor)
# Convert the logits of the first (only) batch element into class probabilities.
probabilities = torch.nn.functional.softmax(out[0], dim=0)

print(probabilities.shape)  # prints: torch.Size([1000])
57 | 57 |
|
# Map class indices to human-readable ImageNet labels, one label per line.
filename = "./misc_files/imagenet_classes.txt"
with open(filename, "r") as f:
    categories = [s.strip() for s in f.readlines()]

# Print top categories per image
print('Top categories:')
top5_prob, top5_catid = torch.topk(probabilities, 5)
for i in range(top5_prob.size(0)):
    print(categories[top5_catid[i]], top5_prob[i].item())
|
|
0 commit comments