Skip to content

Commit 06effd8

Browse files
committed
update avg_checkpoints.py to include existing averaged weights as well
1 parent 07c9b21 commit 06effd8

File tree

1 file changed

+20
-5
lines changed

1 file changed

+20
-5
lines changed

ImageNet/training_scripts/imagenet_training/avg_checkpoints.py

Lines changed: 20 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,7 @@
1717
import glob
1818
import hashlib
1919
from timm.models.helpers import load_state_dict
20+
from validate import validate
2021

2122
parser = argparse.ArgumentParser(description='PyTorch Checkpoint Averager')
2223
parser.add_argument('--input', default='', type=str, metavar='PATH', help='path to base input folder containing checkpoints')
@@ -25,7 +26,7 @@
2526
parser.add_argument('--no-use-ema', dest='no_use_ema', action='store_true', help='Force not using ema version of weights (if present)')
2627
parser.add_argument('--no-sort', dest='no_sort', action='store_true', help='Do not sort and select by checkpoint metric, also makes "n" argument irrelevant')
2728
parser.add_argument('-n', type=int, default=10, metavar='N', help='Number of checkpoints to average')
28-
29+
parser.add_argument('--avg_weights', default='', type=str, metavar='PATH',help='avg fmodel filepath')
2930

3031
def checkpoint_metric(checkpoint_path):
3132
if not checkpoint_path or not os.path.isfile(checkpoint_path):
@@ -50,24 +51,38 @@ def main():
5051
args.sort = not args.no_sort
5152

5253
if os.path.exists(args.output):
53-
print("Error: Output filename ({}) already exists.".format(args.output))
54-
exit(1)
54+
with open(args.output, 'rb') as f:
55+
sha_hash = hashlib.sha256(f.read()).hexdigest()
56+
print(f'{args.output}')
57+
name,ext = os.path.splitext(args.output)
58+
new_name = f'{name}_{str(sha_hash)[-10:]}{ext}'
59+
os.rename(args.output, new_name)
60+
print(f'renamed "{args.output}" to "{new_name}"')
5561

5662
pattern = args.input
5763
if not args.input.endswith(os.path.sep) and not args.filter.startswith(os.path.sep):
5864
pattern += os.path.sep
5965
pattern += args.filter
6066
checkpoints = glob.glob(pattern, recursive=True)
61-
67+
print(f'checkpoints: {checkpoints}')
68+
6269
if args.sort:
6370
checkpoint_metrics = []
6471
for c in checkpoints:
6572
metric = checkpoint_metric(c)
6673
if metric is not None:
6774
checkpoint_metrics.append((metric, c))
75+
76+
if args.avg_weights:
77+
if os.path.exists(args.avg_weights):
78+
checkpoint = torch.load(args.avg_weights, map_location='cpu')
79+
acc = float(args.avg_weights.split('_')[-1].split('.pth')[0])
80+
checkpoint_metrics.append((acc, args.avg_weights))
81+
else:
82+
print(f'FILE DOESNT EXIST!')
6883
checkpoint_metrics = list(sorted(checkpoint_metrics))
6984
checkpoint_metrics = checkpoint_metrics[-args.n:]
70-
print("Selected checkpoints:")
85+
print(f"Selected checkpoints:'({len(checkpoint_metrics)})'")
7186
[print(m, c) for m, c in checkpoint_metrics]
7287
avg_checkpoints = [c for m, c in checkpoint_metrics]
7388
else:

0 commit comments

Comments
 (0)