diff --git a/axbench/scripts/inference.py b/axbench/scripts/inference.py
index e3478a3..a88cf2b 100644
--- a/axbench/scripts/inference.py
+++ b/axbench/scripts/inference.py
@@ -15,7 +15,7 @@ import pyreax
 
-import os, argparse, yaml, json, glob, pickle, time
+import os, sys, argparse, yaml, json, glob, pickle, time
 import pandas as pd
 from tqdm.auto import tqdm
 import torch
@@ -24,19 +24,22 @@ import atexit
 from pyreax import (
-    EXAMPLE_TAG, 
+    EXAMPLE_TAG,
     ReAXFactory,
 )
 from args.dataset_args import DatasetArgs
 from transformers import set_seed
 
-# all supported methods
 import axbench
 from axbench import SteeringDatasetFactory
 
 from openai import AsyncOpenAI
 import httpx, asyncio
 
+import torch.multiprocessing as mp
+from torch.multiprocessing import Queue, Process
+from queue import Empty
+
 import logging
 logging.basicConfig(format='%(asctime)s,%(msecs)03d %(levelname)-8s [%(filename)s:%(lineno)d] %(message)s',
                     datefmt='%Y-%m-%d:%H:%M:%S',
@@ -52,37 +55,19 @@ STEERING_EXCLUDE_MODELS = {}
 LATENT_EXCLUDE_MODELS = {"PromptSteering"}
 
-
 def load_config(config_path):
-    """
-    Load metadata from a JSON lines file.
-    """
     with open(Path(config_path) / CONFIG_FILE) as f:
         d = json.load(f)
     return d
 
-
 def load_state(dump_dir, mode):
-    """
-    Load the state from a file if it exists.
-
-    Args:
-        dump_dir (str): The directory to load the state file from.
-
-    Returns:
-        dict: The loaded state dictionary, or None if no state file exists.
-    """
     state_path = os.path.join(f"{dump_dir}/inference", f"{mode}_{STATE_FILE}")
     if os.path.exists(state_path):
         with open(state_path, "rb") as f:
             return pickle.load(f)
     return None
 
-
 def load_metadata_flatten(metadata_path):
-    """
-    Load flatten metadata from a JSON lines file.
-    """
     metadata = []
     group_id = 0
     with open(Path(metadata_path) / METADATA_FILE, 'r') as f:
@@ -99,11 +84,10 @@ def load_metadata_flatten(metadata_path):
                 "contrast_concepts_map": {concept: contrast_concepts_map},
                 "group_id": group_id
             }
-            metadata += [flatten_data] # Return the metadata as is
+            metadata += [flatten_data]
            group_id += 1
     return metadata
 
-
 def save(
     dump_dir, state, concept_id, partition, current_df, rotation_freq):
@@ -120,36 +104,34 @@ def save(
     """
     dump_dir = Path(dump_dir) / "inference"
     dump_dir.mkdir(parents=True, exist_ok=True)
-    
+
     # Save state
     state_path = os.path.join(dump_dir, f"{partition}_{STATE_FILE}")
     with open(state_path, "wb") as f:
         pickle.dump(state, f)
-    
+
     # Save DataFrame
     fragment_index = concept_id // rotation_freq
     df_path = os.path.join(dump_dir, f"{partition}_data_fragment_{fragment_index}.parquet")
-    
+
     if os.path.exists(df_path):
         existing_df = pd.read_parquet(df_path)
         combined_df = pd.concat([existing_df, current_df], ignore_index=True)
     else:
         combined_df = current_df
-
-    combined_df.to_parquet(df_path, engine='pyarrow')
+    combined_df.to_parquet(df_path, engine='pyarrow')
 
 
 def create_data_latent(dataset_factory, metadata, concept_id, num_of_examples, args):
-    # prepare concept related data.
     concept = metadata[concept_id]["concept"]
     sae_link = metadata[concept_id]["ref"]
     group_id = metadata[concept_id]["group_id"]
-    sae_id = int(sae_link.split("/")[-1]) 
+    sae_id = int(sae_link.split("/")[-1])
     concept_genres_map = metadata[concept_id]["concept_genres_map"]
     contrast_concepts_map = metadata[concept_id]["contrast_concepts_map"]
     _, eval_contrast_concepts_map = \
         dataset_factory.prepare_concepts(
-            [concept], 
+            [concept],
             concept_genres_map=concept_genres_map, contrast_concepts_map=contrast_concepts_map,
             api_tag="inference")
     current_df = dataset_factory.create_eval_df(
@@ -162,14 +144,12 @@ def create_data_latent(dataset_factory, metadata, concept_id, num_of_examples, a
     current_df["group_id"] = group_id
     return current_df
 
-
 def create_data_steering(
-    dataset_factory, metadata, concept_id, num_of_examples, 
+    dataset_factory, metadata, concept_id, num_of_examples,
     n_steering_factors, steering_datasets, args):
-    # prepare concept related data.
     concept = metadata[concept_id]["concept"]
     sae_link = metadata[concept_id]["ref"]
-    sae_id = int(sae_link.split("/")[-1]) 
+    sae_id = int(sae_link.split("/")[-1])
     current_df = dataset_factory.create_eval_df(
         [concept], num_of_examples, n_steering_factors, steering_datasets,
@@ -180,9 +160,117 @@ def create_data_steering(
     return current_df, (concept_id, sae_link, sae_id)
 
 
+def create_tokenizer(args):
+    tokenizer = AutoTokenizer.from_pretrained(args.steering_model_name, model_max_length=512)
+    tokenizer.padding_side = "right"
+    return tokenizer
 
-def infer_steering(args):
+def create_base_model(args, device):
+    logger.warning(f"Loading base model to {device}")
+    base_model = AutoModelForCausalLM.from_pretrained(
+        args.steering_model_name,
+        device_map=device
+    )
+    base_model.config.use_cache = False
+    return base_model.eval()
+
+def setup_benchmark_model(model_class, base_model, tokenizer, layer, metadata, mode, args):
+    model = model_class(
+        base_model, tokenizer, layer=layer,
+        low_rank_dimension=len(metadata)
+    ).to(base_model.device)
+
+    load_kwargs = {"mode": mode} if mode == "steering" else {}
+    model.load(
+        dump_dir=args.train_dir,
+        sae_path=metadata[0]["ref"],
+        **load_kwargs
+    )
+
+    if mode == "steering":
+        model.pre_compute_mean_activations(
+            os.path.join(args.dump_dir, "inference"),
+            master_data_dir=args.master_data_dir
+        )
+
+    return model
+def infer_latent(args):
+    data_dir = args.data_dir
+    train_dir = args.train_dir
+    dump_dir = args.dump_dir
+    num_of_examples = args.latent_num_of_examples
+    rotation_freq = args.rotation_freq
+    config = load_config(train_dir)
+    metadata = load_metadata_flatten(data_dir)
+    layer = config["layer"]
+
+    client = AsyncOpenAI(
+        api_key=os.environ.get("OPENAI_API_KEY"),
+        timeout=60.0,
+        http_client=httpx.AsyncClient(
+            limits=httpx.Limits(
+                max_keepalive_connections=100,
+                max_connections=1000
+            ),
+            headers={"Connection": "close"},
+        ),
+        max_retries=3,
+    )
+
+    device = "cuda:0" if torch.cuda.is_available() else "cpu"
+    base_model = create_base_model(args, device)
+    tokenizer = create_tokenizer(args)
+
+    benchmark_models = {}
+    for model_name in args.models:
+        if model_name in LATENT_EXCLUDE_MODELS:
+            continue
+        model_class = getattr(axbench, model_name)
+        model = setup_benchmark_model(
+            model_class, base_model, tokenizer, layer,
+            metadata, "latent", args
+        )
+        logger.info(f"Model {model_name} initialized on {device}")
+
+        benchmark_models[model_name] = model
+
+    dataset_factory = ReAXFactory(
+        base_model, client, tokenizer, dump_dir,
+        use_cache=True, master_data_dir=args.master_data_dir,
+        lm_model=args.lm_model
+    )
+    atexit.register(dataset_factory.save_cache)
+    atexit.register(dataset_factory.reset_stats)
+
+    state = load_state(args.dump_dir, "latent")
+    start_concept_id = state.get("concept_id", 0) if state else 0
+
+    progress_bar = tqdm(
+        range(start_concept_id, len(metadata)),
+        initial=start_concept_id,
+        total=len(metadata),
+        desc="Infer Latent"
+    )
+
+    for concept_id in progress_bar:
+        # Create data for current concept
+        current_df = create_data_latent(
+            dataset_factory, metadata, concept_id, num_of_examples, args)
+
+        # Process concept
+        for model_name, model in benchmark_models.items():
+            results = model.predict_latent(current_df)
+            for k, v in results.items():
+                current_df[f"{model_name}_{k}"] = v
+
+        # Save results
+        save(dump_dir, {"concept_id": concept_id + 1}, concept_id, "latent",
+             current_df, rotation_freq)
+        # Clean up GPU memory
+        torch.cuda.empty_cache()
+
+def infer_steering(args):
     data_dir = args.data_dir
     train_dir = args.train_dir
     dump_dir = args.dump_dir
@@ -193,20 +281,8 @@ def infer_steering(args):
     layer = config["layer"]
     n_steering_factors = args.n_steering_factors
     steering_datasets = args.steering_datasets
-
-    # Load lm.
-    model = AutoModelForCausalLM.from_pretrained(
-        args.steering_model_name if args.steering_model_name else args.model_name,
-        device_map="cpu"
-    )
-    model.config.use_cache = False
-    model = model.cuda()
-    model = model.eval()
-    tokenizer = AutoTokenizer.from_pretrained(args.steering_model_name)
-    tokenizer.padding_side = "right"
 
-    # Create a new OpenAI client.
-    lm_client = AsyncOpenAI(
+    client = AsyncOpenAI(
         api_key=os.environ.get("OPENAI_API_KEY"),
         timeout=60.0,
         http_client=httpx.AsyncClient(
@@ -219,58 +295,105 @@
         max_retries=3,
     )
 
-    state = load_state(args.dump_dir, "steering")
-    start_concept_id = state.get("concept_id", 0) if state else 0
-    logger.warning(f"Starting concept index: {start_concept_id}")
-    progress_bar = tqdm(range(start_concept_id, len(metadata)), desc="Inferencing with concepts")
+    # Create base model and tokenizer
+    device = "cuda:0" if torch.cuda.is_available() else "cpu"
+    base_model = create_base_model(args, device)
+    tokenizer = create_tokenizer(args)
+
+    # Create benchmark models
+    benchmark_models = {}
+    for model_name in args.models:
+        if model_name in STEERING_EXCLUDE_MODELS:
+            continue
+        model_class = getattr(axbench, model_name)
+        model = setup_benchmark_model(
+            model_class, base_model, tokenizer, layer,
+            metadata, "steering", args
+        )
+        benchmark_models[model_name] = model
 
-    # We dont need to load dataset factory for steering, only existing datasets.
+    # Create dataset factory
     dataset_factory = SteeringDatasetFactory(
-        model, tokenizer, dump_dir,
-        master_data_dir=args.master_data_dir, lm_client=lm_client, 
+        base_model, tokenizer, dump_dir,
+        master_data_dir=args.master_data_dir,
+        lm_client=client,
         lm_model=args.lm_model
     )
 
-    # Pre-load inference models.
-    benchmark_models = []
-    for model_name in args.models:
-        model_class = getattr(axbench, model_name)
-        logger.warning(f"Loading {model_class} from disk for inference.\n")
-        benchmark_model = model_class(
-            model, tokenizer, layer=layer,
-            low_rank_dimension=len(metadata))
-        benchmark_model.load(
-            dump_dir=train_dir, sae_path=metadata[0]["ref"], mode="steering")
-        benchmark_models += [benchmark_model]
-        # Pre-compute mean activations for steering eval based on latent eval.
-        benchmark_model.pre_compute_mean_activations(
-            os.path.join(dump_dir, "inference"), master_data_dir=args.master_data_dir)
+    state = load_state(args.dump_dir, "steering")
+    start_concept_id = state.get("concept_id", 0) if state else 0
+
+    progress_bar = tqdm(
+        range(start_concept_id, len(metadata)),
+        initial=start_concept_id,
+        total=len(metadata),
+        desc="Infer Steering"
+    )
 
-    torch.cuda.empty_cache()
     for concept_id in progress_bar:
-        # Create.
+        # Create data for current concept
         current_df, (_, sae_link, sae_id) = create_data_steering(
-            dataset_factory, metadata, concept_id, num_of_examples,
-            n_steering_factors, steering_datasets, args)
-        # Evaluate.
-        for model_idx, model_name in enumerate(args.models):
-            if model_name in STEERING_EXCLUDE_MODELS:
-                continue
-            results = benchmark_models[model_idx].predict_steer(
-                current_df, concept_id=concept_id, sae_link=sae_link, sae_id=sae_id,
-                batch_size=args.steering_batch_size,
+            dataset_factory, metadata, concept_id,
+            num_of_examples,
+            n_steering_factors,
+            steering_datasets, args)
+
+        # Process concept
+        for model_name, model in benchmark_models.items():
+            results = model.predict_steer(
+                current_df,
+                concept_id=concept_id,
+                sae_link=sae_link,
+                sae_id=sae_id,
+                batch_size=args.steering_batch_size,
                 eval_output_length=args.steering_output_length
             )
             for k, v in results.items():
                 current_df[f"{model_name}_{k}"] = v
-        # Save.
+        # Save results
        save(dump_dir, {"concept_id": concept_id + 1}, concept_id, "steering",
-            current_df, rotation_freq)
+             current_df, rotation_freq)
+        # Clean up GPU memory
+        torch.cuda.empty_cache()
 
 
-def infer_latent(args):
-
+def worker_process(model_name, gpu_id, task_queue, result_queue, args, metadata, layer, shared_base_model=None):
+    # Set GPU environment
+    torch.cuda.set_device(gpu_id)
+    device = f'cuda:{gpu_id}'
+
+    # Load base model and tokenizer
+    if shared_base_model is not None:
+        base_model = shared_base_model
+    else:
+        base_model = create_base_model(args, device)
+    base_model.eval()
+    tokenizer = create_tokenizer(args)
+
+    # Create benchmark model
+    model_class = getattr(axbench, model_name)
+    model = setup_benchmark_model(
+        model_class, base_model, tokenizer, layer,
+        metadata, "steering", args
+    )
+
+    while True:
+        task = task_queue.get()
+        if task is None:  # Exit signal
+            break
+        concept_id, current_df, sae_link, sae_id = task
+        results = model.predict_steer(
+            current_df,
+            concept_id=concept_id,
+            sae_link=sae_link,
+            sae_id=sae_id,
+            batch_size=args.steering_batch_size,
+            eval_output_length=args.steering_output_length
+        )
+        result_queue.put((model_name, concept_id, results))
+
+def infer_steering_multi_gpu(args):
     data_dir = args.data_dir
     train_dir = args.train_dir
     dump_dir = args.dump_dir
@@ -279,14 +402,10 @@ def infer_latent(args):
     config = load_config(train_dir)
     metadata = load_metadata_flatten(data_dir)
     layer = config["layer"]
-
-    # Load lm.
-    model = AutoModelForCausalLM.from_pretrained(args.model_name, device_map="cpu")
-    model.config.use_cache = False
-    model = model.cuda()
-    model = model.eval()
-    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
-    tokenizer.padding_side = "right"
+    n_steering_factors = args.n_steering_factors
+    steering_datasets = args.steering_datasets
+
+    factory_tokenizer = create_tokenizer(args)
 
     # Create a new OpenAI client.
     client = AsyncOpenAI(
@@ -302,49 +421,102 @@
         max_retries=3,
     )
 
-    # Load dataset factory for evals.
-    dataset_factory = ReAXFactory(
-        model, client, tokenizer, dump_dir,
-        use_cache=True, master_data_dir=args.master_data_dir,
+    # Create dataset factory and base model on GPU 0
+    base_model = create_base_model(args, "cuda:0")
+    dataset_factory = SteeringDatasetFactory(
+        base_model, factory_tokenizer, dump_dir,
+        master_data_dir=args.master_data_dir,
+        lm_client=client,
         lm_model=args.lm_model
     )
-    atexit.register(dataset_factory.save_cache)
-    atexit.register(dataset_factory.reset_stats)
 
-    # Pre-load inference models.
-    benchmark_models = []
-    for model_name in args.models:
-        model_class = getattr(axbench, model_name)
-        logger.warning(f"Loading {model_class} from disk for inference.\n")
-        benchmark_model = model_class(
-            model, tokenizer, layer=layer,
-            low_rank_dimension=len(metadata))
-        benchmark_model.load(
-            dump_dir=train_dir, sae_path=metadata[0]["ref"])
-        benchmark_models += [benchmark_model]
+    # Set up multiprocessing
+    num_gpus = torch.cuda.device_count()
+    if num_gpus == 0:
+        num_gpus = 1  # Use CPU if no GPU is available
 
-    state = load_state(args.dump_dir, "latent")
+    task_queue = Queue()
+    result_queue = Queue()
+
+    processes = []
+    model_names = [m for m in args.models if m not in STEERING_EXCLUDE_MODELS]
+    num_workers = len(model_names)
+
+    for idx, model_name in enumerate(model_names):
+        gpu_id = idx % num_gpus if torch.cuda.is_available() else -1
+        p = Process(target=worker_process, args=(
+            model_name, gpu_id, task_queue, result_queue, args, metadata, layer,
+            base_model if gpu_id == 0 else None))
+        p.start()
+        processes.append(p)
+
+    state = load_state(args.dump_dir, "steering")
     start_concept_id = state.get("concept_id", 0) if state else 0
-    logger.warning(f"Starting concept index: {start_concept_id}")
-    progress_bar = tqdm(range(start_concept_id, len(metadata)), desc="Inferencing with concepts")
-
-    torch.cuda.empty_cache()
-    for concept_id in progress_bar:
-        # Create.
-        current_df = create_data_latent(
-            dataset_factory, metadata, concept_id, num_of_examples, args)
-        # Evaluate.
-        for model_idx, model_name in enumerate(args.models):
-            if model_name in LATENT_EXCLUDE_MODELS:
+    progress_bar = tqdm(
+        range(start_concept_id, len(metadata)),
+        initial=start_concept_id,
+        total=len(metadata),
+        desc="Processing concepts in multi-GPU steering"
+    )
+
+    for concept_id in progress_bar:
+        # Create data for current concept
+        current_df, (_, sae_link, sae_id) = create_data_steering(
+            dataset_factory, metadata, concept_id,
+            num_of_examples,
+            n_steering_factors,
+            steering_datasets, args)
+
+        # Distribute task to all workers
+        for _ in range(num_workers):
+            task_queue.put((concept_id, current_df, sae_link, sae_id))
+
+        # Collect results from all workers for this concept
+        results_for_concept = {}
+        remaining_workers = num_workers
+        while remaining_workers > 0:
+            try:
+                model_name, c_id, results = result_queue.get(timeout=30)
+                if results is not None:
+                    results_for_concept[model_name] = results
+                remaining_workers -= 1
+            except Empty:
+                # Check if any worker died
+                if any(not p.is_alive() for p in processes):
+                    logger.error("One or more workers failed, terminating all processes")
+                    # Terminate all processes
+                    for p in processes:
+                        if p.is_alive():
+                            p.terminate()
+                    # Wait for all processes to finish
+                    for p in processes:
+                        p.join()
+                    # Clean up GPU memory
+                    torch.cuda.empty_cache()
+                    # Exit the program
+                    sys.exit(1)
+                # All workers still alive, continue waiting
                 continue
-            results = benchmark_models[model_idx].predict_latent(current_df)
+
+        # Update DataFrame with results
+        for model_name, results in results_for_concept.items():
             for k, v in results.items():
                 current_df[f"{model_name}_{k}"] = v
-
-        # Save.
-        save(dump_dir, {"concept_id": concept_id + 1}, concept_id, "latent",
-            current_df, rotation_freq)
+
+        # Save results
+        save(dump_dir, {"concept_id": concept_id + 1}, concept_id, "steering",
+             current_df, rotation_freq)
+
+    # Send exit signal to all workers
+    for _ in range(num_workers):
+        task_queue.put(None)
+
+    # Clean up workers
+    for p in processes:
+        p.join()
+    torch.cuda.empty_cache()
+
 
 def main():
     custom_args = [
@@ -358,10 +530,13 @@
         }
     ]
     args = DatasetArgs(custom_args=custom_args)
+    if not args.steering_model_name:
+        args.steering_model_name = args.model_name
+
     logger.warning("Inferencing with following configuration:")
     logger.warning(args)
     set_seed(args.seed)
-    
+
     def check_latent_eval_done(args):
         # Check if at least one latent eval fragment exists.
         if os.path.exists(os.path.join(
@@ -375,8 +550,11 @@ def check_latent_eval_done(args):
     # steering eval must be done after latent eval.
     if not check_latent_eval_done(args):
         raise ValueError("Latent eval must be done before steering eval.")
-    infer_steering(args)
-    
+    if args.multi_gpu:
+        mp.set_start_method('spawn', force=True)
+        infer_steering_multi_gpu(args)
+    else:
+        infer_steering(args)
 
 if __name__ == "__main__":
     main()
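Note: infer_steering_multi_gpu above fans each concept out to one persistent worker process per benchmark model through a shared torch.multiprocessing queue, and blocks until every worker has reported back before saving. The standalone sketch below shows only that queue/sentinel protocol under stated assumptions: the _worker function, the "ModelA"/"ModelB" names, the string payloads, and the sleep standing in for a long GPU-bound predict_steer call are illustrative inventions, not axbench APIs.

import time
from queue import Empty

import torch.multiprocessing as mp


def _worker(name, task_queue, result_queue):
    # Pull tasks until the None sentinel arrives, then exit.
    while True:
        task = task_queue.get()
        if task is None:
            break
        concept_id, payload = task
        time.sleep(0.2)  # stand-in for a long predict_steer() call
        result_queue.put((name, concept_id, f"{name} handled {payload}"))


if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)  # required before doing CUDA work in child processes
    task_queue, result_queue = mp.Queue(), mp.Queue()
    names = ["ModelA", "ModelB"]  # hypothetical benchmark model names
    workers = [mp.Process(target=_worker, args=(n, task_queue, result_queue)) for n in names]
    for p in workers:
        p.start()

    for concept_id in range(3):
        # Fan the same task out once per worker, as the diff does per concept.
        for _ in workers:
            task_queue.put((concept_id, f"concept-{concept_id}"))
        # Collect one result per outstanding task before moving on; bail out if a worker died.
        collected, remaining = {}, len(workers)
        while remaining > 0:
            try:
                name, c_id, result = result_queue.get(timeout=30)
                collected[name] = (c_id, result)
                remaining -= 1
            except Empty:
                if any(not p.is_alive() for p in workers):
                    raise RuntimeError("a worker died")  # the diff logs and calls sys.exit(1) here
        print(concept_id, collected)

    for _ in workers:
        task_queue.put(None)  # sentinel: tell each worker to exit
    for p in workers:
        p.join()

Because all workers share one task queue, a fast worker could in principle take more than one copy of a task; the per-concept fan-out only maps one task to each worker while every worker is busy with its long-running call, which is the same assumption the diff's implementation appears to rely on.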