Skip to content

Commit aaab123

Browse files
WauplinAngledLuffa
authored andcommitted
Update hugging_stanza.py to use HfApi instead of Repository
1 parent 6103793 commit aaab123

File tree

1 file changed

+25
-55
lines changed

1 file changed

+25
-55
lines changed

hugging_stanza.py

Lines changed: 25 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,12 @@
1111
import datetime
1212
import os
1313
import shutil
14+
from pathlib import Path
1415

1516
from stanza.resources.common import list_available_languages
1617
from stanza.models.common.constant import lcode2lang, lang2lcode
1718

18-
from huggingface_hub import Repository, HfApi, HfFolder
19+
from huggingface_hub import HfApi
1920

2021
def get_model_card(lang):
2122
now = datetime.datetime.utcnow().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
@@ -41,14 +42,6 @@ def get_model_card(lang):
4142
""".format(short_lang=short_lang, lang_text=lang_text, now=now)
4243
return model_card
4344

44-
def write_model_card(repo_local_path, model):
45-
"""
46-
Write a README for the current model to the given path
47-
"""
48-
readme_path = os.path.join(repo_local_path, "README.md")
49-
with open(readme_path, "w") as f:
50-
f.write(get_model_card(model))
51-
5245
def parse_args():
5346
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
5447
parser.add_argument('--input_dir', type=str, default="/u/nlp/software/stanza/models/", help='Directory for loading the stanza models. Will first try input_dir + version, if that exists')
@@ -61,22 +54,15 @@ def parse_args():
6154
args.lang = list_available_languages()
6255
return args
6356

64-
def copytree(src, dst):
65-
if os.path.exists(dst):
66-
print(f"Cleaning up existing {dst}")
67-
shutil.rmtree(dst)
68-
# copy all of the models for this subdir
69-
print(f"Copying models from {src} to {dst}")
70-
shutil.copytree(src, dst)
71-
72-
7357
def push_to_hub():
7458
args = parse_args()
7559
input_dir = args.input_dir
7660
if os.path.exists(input_dir + args.version):
7761
input_dir = input_dir + args.version
7862
print("Found directory in %s - using that instead of %s" % (input_dir, args.input_dir))
7963

64+
new_tag_name = "v" + args.version
65+
8066
api = HfApi()
8167

8268
print("Processing languages: {}".format(args.lang))
@@ -87,50 +73,34 @@ def push_to_hub():
8773
repo_id = "stanfordnlp/" + repo_name
8874
repo_url = api.create_repo(
8975
repo_id=repo_id,
90-
token=HfFolder.get_token(),
91-
exist_ok=True,
76+
exist_ok=True
9277
)
9378

94-
# Clone the repository
95-
repo_local_path = os.path.join("hub", repo_name)
96-
97-
repo = Repository(repo_local_path, clone_from=repo_url)
98-
# checkout "main" so that we know we are tracking files correctly
99-
repo.git_checkout("main")
100-
if not repo.is_repo_clean():
101-
print(f"Repo {repo_local_path} is currently not clean. Unwilling to proceed...")
102-
break
103-
repo.git_pull(rebase=True)
104-
105-
# Make sure jar files are tracked with LFS
106-
repo.lfs_track(["*.zip"])
107-
repo.lfs_track(["*.pt"])
108-
repo.push_to_hub(commit_message="Update tracked files", clean_ok=True)
109-
110-
dst = os.path.join(repo_local_path, "models")
111-
src = os.path.join(input_dir, model)
112-
if not os.path.exists(src):
79+
# Find src folder
80+
src = Path(input_dir) / model
81+
if not src.exists():
11382
if not input_dir:
11483
raise FileNotFoundError(f"Could not find models under {src}. Perhaps you forgot to set --input_dir?")
11584
else:
11685
raise FileNotFoundError(f"Could not find models under {src}")
117-
copytree(src, dst)
118-
119-
# Create the model card
120-
write_model_card(repo_local_path, model)
121-
122-
# Push the model
123-
# note: the error of not having anything to push will hopefully
124-
# never happen since the README is updated to the millisecond
125-
print("Pushing files to the Hub")
126-
repo.push_to_hub(commit_message=f"Add model {args.version}")
127-
128-
tag = "v" + args.version
129-
if repo.tag_exists(tag, remote=repo_url):
130-
repo.delete_tag(tag, remote=repo_url)
131-
repo.add_tag(tag_name=tag, message=f"Adding new version of models {tag}", remote=repo_url)
132-
print(f"Added a tag for the new models: {tag}")
13386

87+
# Update model card in it
88+
(src / "README.md").write_text(get_model_card(model))
89+
90+
# Upload model + model card
91+
# setting delete_patterns will clean up old model files as we go
92+
api.upload_folder(repo_id=repo_id, folder_path=src, commit_message=f"Add model {args.version}", delete_patterns="*.pt")
93+
94+
# Check and delete tag if already exist
95+
refs = api.list_repo_refs(repo_id=repo_id)
96+
for tag in refs.tags:
97+
if tag.name == new_tag_name:
98+
api.delete_tag(repo_id=repo_id, tag=new_tag_name)
99+
break
100+
101+
# Tag model version
102+
api.create_tag(repo_id=repo_id, tag=new_tag_name, tag_message=f"Adding new version of models {new_tag_name}")
103+
print(f"Added a tag for the new models: {new_tag_name}")
134104
print(f"View your model in:\n {repo_url}\n\n")
135105

136106
if __name__ == '__main__':

0 commit comments

Comments
 (0)