|
17 | 17 | # model_name = "microsoft/DialoGPT-small" |
18 | 18 | tokenizer = AutoTokenizer.from_pretrained(model_name) |
19 | 19 | model = AutoModelForCausalLM.from_pretrained(model_name) |
20 | | - |
| 20 | +print("====Greedy search chat====") |
21 | 21 | # chatting 5 times with greedy search |
22 | 22 | for step in range(5): |
23 | 23 | # take user input |
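The loop body that the diff collapses between these hunks follows the usual DialoGPT chat pattern: encode the user turn, append it to the running history, then generate. A minimal sketch of one greedy-search turn, assuming `torch` is imported and `max_length=1000` as in the last hunk:

```python
import torch

# take user input and append the end-of-string token
text = input(">> You: ")
input_ids = tokenizer.encode(text + tokenizer.eos_token, return_tensors="pt")

# append the new input to the chat history (first turn has no history yet)
bot_input_ids = torch.cat([chat_history_ids, input_ids], dim=-1) if step > 0 else input_ids

# greedy search: with no sampling flags set, generate() picks the single
# most probable token at every step
chat_history_ids = model.generate(
    bot_input_ids,
    max_length=1000,
    pad_token_id=tokenizer.eos_token_id,
)
```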
|
35 | 35 | #print the output |
36 | 36 | output = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True) |
37 | 37 | print(f"DialoGPT: {output}") |
38 | | - |
| 38 | +print("====Beam search chat====") |
39 | 39 | # chatting 5 times with beam search |
40 | 40 | for step in range(5): |
41 | 41 | # take user input |
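The beam-search loop differs only in the `generate()` call; a sketch with `num_beams=3` and `early_stopping=True` as assumed values (the real ones sit in the collapsed lines):

```python
# beam search: track the 3 most probable continuations and return the best one
chat_history_ids = model.generate(
    bot_input_ids,
    max_length=1000,
    num_beams=3,
    early_stopping=True,
    pad_token_id=tokenizer.eos_token_id,
)
```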
|
55 | 55 | #print the output |
56 | 56 | output = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True) |
57 | 57 | print(f"DialoGPT: {output}") |
58 | | - |
| 58 | +print("====Sampling chat====") |
59 | 59 | # chatting 5 times with sampling |
60 | 60 | for step in range(5): |
61 | 61 | # take user input |
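For plain sampling, `do_sample=True` is switched on and the top-k filter disabled; a sketch (`top_k=0` is an assumption):

```python
# pure sampling: draw each next token from the full model distribution
chat_history_ids = model.generate(
    bot_input_ids,
    max_length=1000,
    do_sample=True,
    top_k=0,
    pad_token_id=tokenizer.eos_token_id,
)
```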
|
75 | 75 | #print the output |
76 | 76 | output = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True) |
77 | 77 | print(f"DialoGPT: {output}") |
78 | | - |
| 78 | +print("====Sampling chat with tweaking temperature====") |
79 | 79 | # chatting 5 times with sampling & tweaking temperature |
80 | 80 | for step in range(5): |
81 | 81 | # take user input |
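Temperature rescales the logits before sampling: values below 1.0 sharpen the distribution and make replies more conservative, values above 1.0 flatten it. A sketch, assuming `temperature=0.75` as in the final hunk:

```python
# temperature < 1.0 sharpens the distribution, reducing randomness
chat_history_ids = model.generate(
    bot_input_ids,
    max_length=1000,
    do_sample=True,
    top_k=0,
    temperature=0.75,
    pad_token_id=tokenizer.eos_token_id,
)
```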
|
96 | 96 | #print the output |
97 | 97 | output = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True) |
98 | 98 | print(f"DialoGPT: {output}") |
99 | | - |
| 99 | +print("====Top-K sampling chat with tweaking temperature====") |
100 | 100 | # chatting 5 times with Top K sampling & tweaking temperature |
101 | 101 | for step in range(5): |
102 | 102 | # take user input |
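Top-K sampling limits each draw to the K most probable tokens, cutting off the long unreliable tail of the distribution; a sketch with an assumed `top_k=100`:

```python
# sample only among the 100 most probable next tokens
chat_history_ids = model.generate(
    bot_input_ids,
    max_length=1000,
    do_sample=True,
    top_k=100,
    temperature=0.75,
    pad_token_id=tokenizer.eos_token_id,
)
```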
|
117 | 117 | #print the output |
118 | 118 | output = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True) |
119 | 119 | print(f"DialoGPT: {output}") |
120 | | - |
| 120 | +print("====Nucleus sampling (top-p) chat with tweaking temperature====") |
121 | 121 | # chatting 5 times with nucleus sampling & tweaking temperature |
122 | 122 | for step in range(5): |
123 | 123 | # take user input |
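Nucleus (top-p) sampling keeps the smallest set of tokens whose cumulative probability reaches `top_p`, so the candidate pool adapts to the model's confidence at each step. A sketch using `top_p=0.95` (the value visible in the last hunk), with the top-k filter disabled:

```python
# sample from the smallest token set whose cumulative probability >= 0.95
chat_history_ids = model.generate(
    bot_input_ids,
    max_length=1000,
    do_sample=True,
    top_p=0.95,
    top_k=0,
    temperature=0.75,
    pad_token_id=tokenizer.eos_token_id,
)
```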
|
139 | 139 | #print the output |
140 | 140 | output = tokenizer.decode(chat_history_ids[:, bot_input_ids.shape[-1]:][0], skip_special_tokens=True) |
141 | 141 | print(f"DialoGPT: {output}") |
142 | | - |
| 142 | +print("====Nucleus & Top-K sampling chat with tweaking temperature & multiple sentences====") |
143 | 143 | # chatting 5 times with nucleus & top-k sampling & tweaking temperature & multiple |
144 | 144 | # sentences |
145 | 145 | for step in range(5): |
|
155 | 155 | max_length=1000, |
156 | 156 | do_sample=True, |
157 | 157 | top_p=0.95, |
158 | | - top_k=50,Y |
| 158 | + top_k=50, |
159 | 159 | temperature=0.75, |
160 | 160 | num_return_sequences=5, |
161 | 161 | pad_token_id=tokenizer.eos_token_id |
|
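With `num_return_sequences=5`, the call above returns five candidate replies rather than one, so the collapsed lines have to pick a single sequence to carry forward as chat history. A sketch of one way to do that; the interactive selection step is illustrative, not taken from the diff:

```python
# decode each of the 5 candidates (only the tokens generated after the prompt)
for i in range(chat_history_ids.shape[0]):
    candidate = tokenizer.decode(
        chat_history_ids[i, bot_input_ids.shape[-1]:], skip_special_tokens=True
    )
    print(f"DialoGPT {i}: {candidate}")

# keep one candidate as the running history so the next turn stays consistent
choice = int(input("Choose a response (0-4): "))  # illustrative, not in the diff
chat_history_ids = chat_history_ids[choice].unsqueeze(0)
```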