Skip to content

Commit d1769c1

Browse files
authored
Fix KeyError occurring using fine_tunes.prepare_data (openai#125)
* Initial commit * Add fix * Reinstate reset_index() * Add suggestions * Remove print stmt * punctuation * Add test for fine_tunes.prepare_data * Renamed file, added docstrings * Move comment placement
1 parent 34a1209 commit d1769c1

File tree

2 files changed

+60
-5
lines changed

2 files changed

+60
-5
lines changed
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import json
2+
import subprocess
3+
from tempfile import NamedTemporaryFile
4+
5+
6+
def test_long_examples_validator() -> None:
    """
    Ensure that long_examples_validator() handles previously applied
    recommendations, namely dropped duplicates, without resulting in a
    KeyError.

    Runs the real ``openai tools fine_tunes.prepare_data`` CLI against a
    temporary JSONL training file and answers "y" to every interactive
    recommendation prompt, so the duplicate-removal step is applied before
    the long-example step and shifts the long rows' indices.
    """
    # Data: two identical "long" rows sandwich a short one, so dropping the
    # duplicate changes which indices hold long examples.
    short_prompt = "a prompt "
    long_prompt = short_prompt * 500

    short_completion = "a completion "
    long_completion = short_completion * 500

    # The order of these matters: the duplicates must surround the short row.
    unprepared_training_data = [
        {"prompt": long_prompt, "completion": long_completion},  # 1 of 2 duplicates
        {"prompt": short_prompt, "completion": short_completion},
        {"prompt": long_prompt, "completion": long_completion},  # 2 of 2 duplicates
    ]

    # suffix includes the leading dot so the temp file is named "*.jsonl",
    # which is what the CLI's format detection keys on (the original used
    # suffix="jsonl", producing a name with no ".jsonl" extension).
    with NamedTemporaryFile(suffix=".jsonl", mode="w") as training_data:
        for prompt_completion_row in unprepared_training_data:
            training_data.write(json.dumps(prompt_completion_row) + "\n")
        training_data.flush()

        # Invoke the CLI as an argv list without a shell: no quoting issues
        # with the temp-file path, and no shell process in between (the
        # original passed a single f-string command with shell=True).
        prepared_data_cmd_output = subprocess.run(
            ["openai", "tools", "fine_tunes.prepare_data", "-f", training_data.name],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            encoding="utf-8",
            input="y\ny\ny\ny\ny",  # apply all recommendations, one at a time
        )

    # validate data was prepared successfully
    assert prepared_data_cmd_output.stderr == ""
    # validate get_long_indexes() applied during optional_fn() call in long_examples_validator()
    assert "indices of the long examples has changed" in prepared_data_cmd_output.stdout
    # No return: the function is annotated -> None, and pytest warns on test
    # functions that return a non-None value (the original returned stdout).

openai/validators.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -158,17 +158,24 @@ def long_examples_validator(df):
158158

159159
ft_type = infer_task_type(df)
160160
if ft_type != "open-ended generation":
161-
long_examples = df.apply(
162-
lambda x: len(x.prompt) + len(x.completion) > 10000, axis=1
163-
)
164-
long_indexes = df.reset_index().index[long_examples].tolist()
161+
def get_long_indexes(d):
162+
long_examples = d.apply(
163+
lambda x: len(x.prompt) + len(x.completion) > 10000, axis=1
164+
)
165+
return d.reset_index().index[long_examples].tolist()
166+
167+
long_indexes = get_long_indexes(df)
165168

166169
if len(long_indexes) > 0:
167170
immediate_msg = f"\n- There are {len(long_indexes)} examples that are very long. These are rows: {long_indexes}\nFor conditional generation, and for classification the examples shouldn't be longer than 2048 tokens."
168171
optional_msg = f"Remove {len(long_indexes)} long examples"
169172

170173
def optional_fn(x):
171-
return x.drop(long_indexes)
174+
175+
long_indexes_to_drop = get_long_indexes(x)
176+
if long_indexes != long_indexes_to_drop:
177+
sys.stdout.write(f"The indices of the long examples has changed as a result of a previously applied recommendation.\nThe {len(long_indexes_to_drop)} long examples to be dropped are now at the following indices: {long_indexes_to_drop}\n")
178+
return x.drop(long_indexes_to_drop)
172179

173180
return Remediation(
174181
name="long_examples",

0 commit comments

Comments
 (0)