diff --git a/axbench/sweep/aryaman/simple.yaml b/axbench/sweep/aryaman/simple.yaml
index 75fcaec..eff9c11 100644
--- a/axbench/sweep/aryaman/simple.yaml
+++ b/axbench/sweep/aryaman/simple.yaml
@@ -16,12 +16,12 @@ train:
   seed: 42
   use_bf16: true
   models:
-    ReFT:
+    ReAX:
       batch_size: 6
       n_epochs: 12
-      topk: 1
+      k_latent_null_loss: 1
       lr: 0.003
-      coeff_l1_loss_null: 0.05
+      coeff_l1_loss_null: 0.000
       coeff_l1_loss: 0.000
       coeff_norm_loss: 0.000
     DiffMean:
@@ -36,7 +36,7 @@ train:
 
 inference:
   use_bf16: true
-  models: ["LAT", "ReFT", "DiffMean", "PromptSteering"]
+  models: ["LAT", "ReAX", "DiffMean", "PromptSteering"]
   model_name: "google/gemma-2-2b"
   # latent related params
   input_length: 32
@@ -57,7 +57,7 @@ inference:
   temperature: 0.7
 
 evaluate:
-  models: ["LAT", "ReFT", "DiffMean", "PromptSteering"]
+  models: ["LAT", "ReAX", "DiffMean", "PromptSteering"]
   latent_evaluators: [
     "AUCROCEvaluator",
     "HardNegativeEvaluator",
@@ -73,5 +73,4 @@ evaluate:
   run_winrate: true
   winrate_baseline: "PromptSteering"
   # master data dir is shared across all jobs.
-  master_data_dir: "axbench/data"
-
+  master_data_dir: "axbench/data"
\ No newline at end of file
diff --git a/axbench/utils/dataset.py b/axbench/utils/dataset.py
index 3797c04..07eb5fd 100644
--- a/axbench/utils/dataset.py
+++ b/axbench/utils/dataset.py
@@ -188,95 +188,38 @@ def create_eval_df(self, concepts, subset_n, concept_genres_map, train_contrast_
         self.logger.warning(f"Finished creating current dataframe in {round(time.time() - start, 3)} sec.")
         return df
 
+
     def create_train_df(self, concepts, n, concept_genres_map, contrast_concepts_map, **kwargs):
         concept2id = {concept: i for i, concept in enumerate(concepts)}
         lm_model, model, tokenizer = self.lm_model, self.model, self.tokenizer
         start = time.time()
         self.logger.warning("Creating dataframe.")
-        n_per_concept = n // (len(concepts) + 1)
+        n_per_concept = n // len(concepts)
         all_examples = []
         input_length = kwargs.get("input_length", 32)
         output_length = kwargs.get("output_length", 10)
-
-        # for each concept, we create a set of seed random content.
-        concepts_random_content = get_random_content(
-            self.seed_sentences, tokenizer=tokenizer, count=3*n,
-            genres=concept_genres_map, concepts=concepts, length=input_length, split="train"
-        )
-
-        # for concepts with polysemantic senses, we create additional examples.
-        polysemantic_tasks = []
-        for concept in concepts:
-            if len(contrast_concepts_map[concept]) != 0:
-                count = n_per_concept // (len(concepts)*2)
-                polysemantic_concepts = [random.choice(contrast_concepts_map[concept]) for _ in range(count)]
-                polysemantic_tasks.append(modify_content_with_polysemantic_concepts(
-                    client=lm_model, tokenizer=tokenizer,
-                    polysemantic_concepts=polysemantic_concepts,
-                    concept=concept, content=concepts_random_content[concept][:n],
-                    length=input_length
-                ))
-        polysemantic_content = asyncio.run(run_tasks(polysemantic_tasks))
-        polysemantic_content = {content[0]: content[1] for content in polysemantic_content}
-
-        # aggregate these null examples.
-        null_prompts = []
-        for concept in concepts:
-            n_random = (n_per_concept // len(concepts)) if len(contrast_concepts_map[concept]) == 0 else n_per_concept // (len(concepts)*2)
-            for content in concepts_random_content[concept][n:n+n_random]:
-                null_prompts.append(
-                    Prompt(concept=concept, tag="empty", content=content))
-            if len(contrast_concepts_map[concept]) != 0:
-                for content in polysemantic_content[concept]:
-                    null_prompts.append(
-                        Prompt(concept=concept, tag=f"{content[0][0]}//{content[0][1]}",
-                               content=content[1]))
-
-        # get continuations from STEERED MODEL (not datagen model)
-        null_outputs = get_model_continues(
-            model=model, tokenizer=tokenizer, prompts=[p.content for p in null_prompts],
-            max_new_tokens=int(output_length*1.5)
-        )
-        # Save control examples
-        for prompt, output in zip(null_prompts, null_outputs):
-            in_idx = concept2id[prompt.concept]
-            out_idx = sample_index_exclude(len(concepts), in_idx)
-            all_examples += [[
-                prompt.content, output, EXAMPLE_TAG.CONTROL.value,
-                in_idx, out_idx, prompt.tag, "empty",
-            ]]
-
-        # modify exist content to have desired concepts.
-        modify_prompts = []
+        # for each concept, random seed content + steered continuation
         for concept in concepts:
-            for prompt in concepts_random_content[concept][2*n:2*n+len(null_prompts)]:
-                modify_prompts.append(
-                    Prompt(concept=concept, tag="empty", content=prompt)) # include source content ID
-        modify_task = modify_content_with_concept(
-            client=lm_model, tokenizer=tokenizer,
-            content=[(p.concept, p.tag, p.content) for p in modify_prompts], # keep the same interface
-            length=input_length
-        )
-        concept_prompts = asyncio.run(run_tasks([modify_task]))[0]
-
-        # process experiment examples with content tracking
-        inverse_concepts = [concepts[sample_index_exclude(len(concepts), concept2id[prompt[0]])]
-                            for prompt in modify_prompts]
-        continue_task = continue_with_concept(
-            self.lm_model, self.tokenizer,
-            concepts=inverse_concepts, content=concept_prompts, length=output_length)
-        concept_outputs = asyncio.run(run_tasks([continue_task]))[0]
-
-        for i, (prompt, output) in enumerate(zip(concept_prompts, concept_outputs)):
-            in_idx = concept2id[modify_prompts[i][0]]
-            out_idx = concept2id[inverse_concepts[i]]
-            all_examples += [[
-                prompt, output, EXAMPLE_TAG.EXPERIMENT.value,
-                in_idx, out_idx, modify_prompts[i][0], inverse_concepts[i],
-            ]]
+            logger.warning(f"Working on: {concept}")
+            prompts = get_random_content(
+                self.seed_sentences, tokenizer=tokenizer, count=n_per_concept,
+                genres=concept_genres_map, concepts=concepts, length=input_length, split="train"
+            )[concept]
+            logger.warning(f"Made {len(prompts)} prompts")
+            outputs_task = continue_with_concept(
+                self.lm_model, self.tokenizer,
+                concepts=[concept for _ in range(len(prompts))], content=prompts, length=output_length)
+            outputs = asyncio.run(run_tasks([outputs_task]))[0]
+            for i, (prompt, output) in enumerate(zip(prompts, outputs)):
+                in_idx = 0
+                out_idx = concept2id[concept]
+                all_examples += [[
+                    prompt, output, EXAMPLE_TAG.EXPERIMENT.value,
+                    in_idx, out_idx, "empty", concept,
+                ]]
 
         # update the column definitions of the DataFrame
         df = pd.DataFrame(
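
For reference, each row appended by the rewritten create_train_df carries seven positional fields, with the control-example path removed and in_idx now hardcoded to 0. Below is a minimal sketch of one such row, assuming illustrative column names; the real names come from the pd.DataFrame( call that continues past the end of this hunk.

    # Hypothetical sketch of one row emitted by the new create_train_df loop.
    # Column names are assumptions for illustration only; the actual names are
    # set in the pd.DataFrame(...) call truncated at the end of the hunk above.
    import pandas as pd

    row = [
        "A seed sentence drawn by get_random_content(...)",  # prompt text
        "a continuation generated to exhibit the concept",   # steered output
        "EXPERIMENT",                                         # stands in for EXAMPLE_TAG.EXPERIMENT.value
        0,                                                    # in_idx: always 0 on the new code path
        3,                                                    # out_idx: concept2id[concept]
        "empty",                                              # input concept tag
        "terms related to bridges",                           # output concept (illustrative)
    ]
    df = pd.DataFrame(
        [row],
        columns=["input", "output", "tag", "in_idx", "out_idx",
                 "input_concept", "output_concept"],          # assumed names
    )
    print(df.iloc[0])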