Skip to content

Commit 31b4919

Browse files
committed
added test code
1 parent 06daf7b commit 31b4919

File tree

4 files changed

+99
-2
lines changed

4 files changed

+99
-2
lines changed

commit0/cli.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from pathlib import Path
33
from typing import Union, List
44
from typing_extensions import Annotated
5+
import commit0.harness.batch_run_pytest_ids
56
import commit0.harness.run_pytest_ids
67
import commit0.harness.get_pytest_ids
78
import commit0.harness.build
@@ -300,6 +301,49 @@ def test(
300301
)
301302

302303

304+
@commit0_app.command()
305+
def batch_test(
306+
test_ids: str = typer.Argument(
307+
None,
308+
help='All ways pytest supports to run and select tests. Please provide a single string. Example: "test_mod.py", "testing/", "test_mod.py::test_func", "-k \'MyClass and not method\'"',
309+
),
310+
backend: str = typer.Option("modal", help="Backend to use for testing"),
311+
timeout: int = typer.Option(1800, help="Timeout for tests in seconds"),
312+
num_cpus: int = typer.Option(1, help="Number of CPUs to use"),
313+
reference: Annotated[
314+
bool, typer.Option("--reference", help="Test the reference commit")
315+
] = False,
316+
coverage: Annotated[
317+
bool, typer.Option("--coverage", help="Whether to get coverage information")
318+
] = False,
319+
rebuild: bool = typer.Option(
320+
False, "--rebuild", help="Whether to rebuild an image"
321+
),
322+
commit0_config_file: str = typer.Option(
323+
".commit0.yaml",
324+
help="Path to the commit0 dot file, where the setup config is stored",
325+
),
326+
verbose: int = typer.Option(
327+
1,
328+
"--verbose",
329+
"-v",
330+
help="Set this to 2 for more logging information",
331+
count=True,
332+
),
333+
) -> None:
334+
"""Run tests on a Commit0 repository."""
335+
commit0.harness.batch_run_pytest_ids.main(
336+
test_ids,
337+
reference,
338+
coverage,
339+
backend,
340+
timeout,
341+
num_cpus,
342+
rebuild,
343+
verbose,
344+
)
345+
346+
303347
@commit0_app.command()
304348
def evaluate(
305349
branch: Union[str, None] = typer.Option(

examples/star/run.sh

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
python examples/star/star.py \
22
--model_name_or_path meta-llama/Llama-3.1-8B-Instruct \
33
--dataset_name commit0/mbpp \
4-
-n 10 \
4+
-n 100 \
55
--output_dir outputs \
66
--low_cpu_mem_usage \
77
--with_tracking \
@@ -10,5 +10,6 @@ python examples/star/star.py \
1010
--learning_rate 1e-6 \
1111
--per_device_train_batch_size 1 \
1212
--gradient_accumulation_steps 8 \
13-
--max_workers 64
13+
--max_workers 64 \
14+
--temperature 1
1415

examples/star/test.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
"""Get test accuracy"""
2+
3+
from datasets import load_dataset
4+
from examples.star.inference import generate_predictions
5+
from examples.star.utils import (
6+
execute_tests,
7+
generate_prompt,
8+
parse_args,
9+
)
10+
11+
12+
def main() -> None:
13+
args = parse_args()
14+
ds = load_dataset(args.dataset_name, args.dataset_config_name)['test']
15+
model_name = args.model_name_or_path
16+
17+
# sample
18+
all_samples = generate_predictions(
19+
model_name, ds, args.temperature, args.n
20+
)
21+
ds.add_column(name="sample", column=all_samples).to_json(
22+
f"{args.output_dir}/data/{model_name.split('/')[-1]}-test-samples.json"
23+
)
24+
assert len(ds) == len(all_samples)
25+
26+
# verify and construct the training set
27+
all_traces, all_execution_results = execute_tests(
28+
ds, all_samples, max_workers=args.max_workers
29+
)
30+
passed = 0
31+
for example, execution_results, samples in zip(
32+
ds, all_execution_results, all_samples
33+
):
34+
for execution_result, sample in zip(execution_results, samples):
35+
# pytest exit code: https://docs.pytest.org/en/stable/reference/exit-codes.html
36+
if execution_result == 0:
37+
passed += 1
38+
print(f"passed: {passed/len(ds)}")
39+
40+
if __name__ == "__main__":
41+
main()
42+
43+
44+
__all__ = []

examples/star/test.sh

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
python examples/star/test.py \
2+
--model_name_or_path $1 \
3+
--dataset_name commit0/mbpp \
4+
-n $2 \
5+
--output_dir outputs \
6+
--max_workers 64 \
7+
--temperature 0
8+

0 commit comments

Comments
 (0)