1111import datetime
1212import os
1313import shutil
14+ from pathlib import Path
1415
1516from stanza .resources .common import list_available_languages
1617from stanza .models .common .constant import lcode2lang , lang2lcode
1718
18- from huggingface_hub import Repository , HfApi , HfFolder
19+ from huggingface_hub import HfApi
1920
2021def get_model_card (lang ):
2122 now = datetime .datetime .utcnow ().strftime ("%Y-%m-%d %H:%M:%S.%f" )[:- 3 ]
@@ -41,14 +42,6 @@ def get_model_card(lang):
4142""" .format (short_lang = short_lang , lang_text = lang_text , now = now )
4243 return model_card
4344
44- def write_model_card (repo_local_path , model ):
45- """
46- Write a README for the current model to the given path
47- """
48- readme_path = os .path .join (repo_local_path , "README.md" )
49- with open (readme_path , "w" ) as f :
50- f .write (get_model_card (model ))
51-
5245def parse_args ():
5346 parser = argparse .ArgumentParser (formatter_class = argparse .ArgumentDefaultsHelpFormatter )
5447 parser .add_argument ('--input_dir' , type = str , default = "/u/nlp/software/stanza/models/" , help = 'Directory for loading the stanza models. Will first try input_dir + version, if that exists' )
@@ -61,22 +54,15 @@ def parse_args():
6154 args .lang = list_available_languages ()
6255 return args
6356
64- def copytree (src , dst ):
65- if os .path .exists (dst ):
66- print (f"Cleaning up existing { dst } " )
67- shutil .rmtree (dst )
68- # copy all of the models for this subdir
69- print (f"Copying models from { src } to { dst } " )
70- shutil .copytree (src , dst )
71-
72-
7357def push_to_hub ():
7458 args = parse_args ()
7559 input_dir = args .input_dir
7660 if os .path .exists (input_dir + args .version ):
7761 input_dir = input_dir + args .version
7862 print ("Found directory in %s - using that instead of %s" % (input_dir , args .input_dir ))
7963
64+ new_tag_name = "v" + args .version
65+
8066 api = HfApi ()
8167
8268 print ("Processing languages: {}" .format (args .lang ))
@@ -87,50 +73,34 @@ def push_to_hub():
8773 repo_id = "stanfordnlp/" + repo_name
8874 repo_url = api .create_repo (
8975 repo_id = repo_id ,
90- token = HfFolder .get_token (),
91- exist_ok = True ,
76+ exist_ok = True
9277 )
9378
94- # Clone the repository
95- repo_local_path = os .path .join ("hub" , repo_name )
96-
97- repo = Repository (repo_local_path , clone_from = repo_url )
98- # checkout "main" so that we know we are tracking files correctly
99- repo .git_checkout ("main" )
100- if not repo .is_repo_clean ():
101- print (f"Repo { repo_local_path } is currently not clean. Unwilling to proceed..." )
102- break
103- repo .git_pull (rebase = True )
104-
105- # Make sure jar files are tracked with LFS
106- repo .lfs_track (["*.zip" ])
107- repo .lfs_track (["*.pt" ])
108- repo .push_to_hub (commit_message = "Update tracked files" , clean_ok = True )
109-
110- dst = os .path .join (repo_local_path , "models" )
111- src = os .path .join (input_dir , model )
112- if not os .path .exists (src ):
79+ # Find src folder
80+ src = Path (input_dir ) / model
81+ if not src .exists ():
11382 if not input_dir :
11483 raise FileNotFoundError (f"Could not find models under { src } . Perhaps you forgot to set --input_dir?" )
11584 else :
11685 raise FileNotFoundError (f"Could not find models under { src } " )
117- copytree (src , dst )
118-
119- # Create the model card
120- write_model_card (repo_local_path , model )
121-
122- # Push the model
123- # note: the error of not having anything to push will hopefully
124- # never happen since the README is updated to the millisecond
125- print ("Pushing files to the Hub" )
126- repo .push_to_hub (commit_message = f"Add model { args .version } " )
127-
128- tag = "v" + args .version
129- if repo .tag_exists (tag , remote = repo_url ):
130- repo .delete_tag (tag , remote = repo_url )
131- repo .add_tag (tag_name = tag , message = f"Adding new version of models { tag } " , remote = repo_url )
132- print (f"Added a tag for the new models: { tag } " )
13386
87+ # Update model card in it
88+ (src / "README.md" ).write_text (get_model_card (model ))
89+
90+ # Upload model + model card
91+ # setting delete_patterns will clean up old model files as we go
92+ api .upload_folder (repo_id = repo_id , folder_path = src , commit_message = f"Add model { args .version } " , delete_patterns = "*.pt" )
93+
94+ # Check and delete tag if already exist
95+ refs = api .list_repo_refs (repo_id = repo_id )
96+ for tag in refs .tags :
97+ if tag .name == new_tag_name :
98+ api .delete_tag (repo_id = repo_id , tag = new_tag_name )
99+ break
100+
101+ # Tag model version
102+ api .create_tag (repo_id = repo_id , tag = new_tag_name , tag_message = f"Adding new version of models { new_tag_name } " )
103+ print (f"Added a tag for the new models: { new_tag_name } " )
134104 print (f"View your model in:\n { repo_url } \n \n " )
135105
136106if __name__ == '__main__' :
0 commit comments