
CLaRa: Bridging Retrieval and Generation with Continuous Latent Reasoning


This software project accompanies the research paper, CLaRa: Bridging Retrieval and Generation with Continuous Latent Reasoning.

Updates

  • Nov 25, 2025. Models are available on Hugging Face.
  • Dec 3, 2025. Evaluation data are available in ./evaluation/evaluation_data.

Motivation

Retrieval-Augmented Generation (RAG) enhances large language models with external knowledge but suffers from long contexts and disjoint retrieval-generation optimization. Existing soft compression frameworks face two key limitations: (i) reconstruction-based objectives bias compressors toward surface patterns rather than semantic preservation; (ii) retrievers and compressors are trained separately, requiring double encoding despite compressed vectors being inherently retrievable.

In this work, we investigate:

  • How can we improve semantic preservation in compressed representations through better pretraining objectives?
  • How can we unify retrieval and generation optimization to avoid redundant encoding and disjoint objectives?

We design a three-stage training approach and introduce document compression techniques to improve RAG efficiency. The key findings are listed below.

Findings

  • Efficient Compression: CLaRa achieves significant compression rates (32x-64x) while preserving essential information for accurate answer generation.

  • Three-Stage Training: A carefully designed three-stage training approach (compression pretraining + compression instruction tuning + end-to-end fine-tuning) enables effective learning of both retrieval and generation.

For more interesting findings, please refer to our original paper!


Three-Stage Training

CLaRa uses a carefully designed three-stage training approach:

Stage 1: Compression Pretraining

  • Train the compressor using the SCP (Salient Compressor Pretraining) framework with QA pairs and paraphrases
  • Retain key semantics through QA-based and paraphrase-guided supervision
  • Support compression rates of 1x-256x

Stage 2: Compression Instruction Tuning

  • Fine-tune the compressor on instruction-following tasks for downstream QA
  • Use text-based QA output to ensure compressed representations retain sufficient semantics

Stage 3: End-to-End Fine-tuning (CLaRa)

  • Jointly train reranker and generator via a single language modeling loss
  • Unify retrieval and generation in a shared continuous space using a differentiable top-k estimator (see the sketch below)
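
Stage 3 hinges on the top-k document selection being differentiable, so the single language modeling loss can push gradients back into the reranker. Below is a minimal sketch of one common construction (successive masked softmax); it illustrates the idea only, and the scores, k, and tau values are hypothetical stand-ins, not the paper's exact estimator.

import torch

def soft_top_k(scores: torch.Tensor, k: int, tau: float = 0.1) -> torch.Tensor:
    """Soft selection weights in [0, 1] summing to ~k, differentiable in `scores`."""
    weights = torch.zeros_like(scores)
    logits = scores.clone()
    for _ in range(k):
        p = torch.softmax(logits / tau, dim=-1)  # soft "pick one document"
        weights = weights + p
        # Multiply each document's remaining probability by (1 - p), so
        # already-picked documents are damped toward zero next iteration.
        logits = logits + tau * torch.log1p(-p.clamp(max=1 - 1e-6))
    return weights

# Toy usage: one query, five candidate documents, select ~2.
scores = torch.tensor([[2.0, 0.1, 1.5, -0.3, 0.7]], requires_grad=True)
w = soft_top_k(scores, k=2)  # mass concentrates on documents 0 and 2
w.sum().backward()           # gradients reach the reranker scores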

In this repository, we release our implementation of CLaRa, built upon OpenRLHF.

Getting Started

The repository is organized as follows:

├── scripts/                      # Training and evaluation scripts
│   ├── train_pretraining.sh     # Stage 1: Compression pretraining
│   ├── train_instruction_tuning.sh  # Stage 2: Compression instruction tuning
│   ├── train_stage_end_to_end.sh    # Stage 3: End-to-end training
│   └── evaluation_end_to_end.sh     # Evaluation scripts
├── openrlhf/                     # Core training framework
│   ├── models/                   # Model implementations
│   │   └── modeling_clara.py   # CLaRa model definition
│   ├── datasets/                 # Dataset handling
│   │   └── sft_dataset.py        # Training dataset
│   ├── trainer/                  # Training utilities
│   │   └── sft_trainer.py        # SFT trainer
│   └── cli/                      # Command line interface
│       └── train_sft.py          # Main training script
├── evaluation/                   # Evaluation framework
├── example/                      # Example training data
│   ├── pretrain_data.jsonl
│   ├── instruction_tuning_data.jsonl
│   └── end_to_end_data.jsonl
└── README.md                     # This file

1. Prepare code and environment

Clone the repository and set up the environment:

# Create conda environment
env=clara
conda create -n $env python=3.10 -y
conda activate $env

# Install dependencies
pip install -r requirements.txt

# Set up environment variables
export PYTHONPATH=/path/to/clara:$PYTHONPATH

Key dependencies include:

  • PyTorch >= 2.0
  • Transformers >= 4.20
  • DeepSpeed >= 0.18
  • Flash Attention 2
  • Accelerate

2. Data preparation

Prepare training data in JSONL format. For pretraining stage:

# Example data format for pretraining
{
    "data_type": "qa",
    "question": ["Question 1",],
    "answers": ["Answer 1"],
    "docs": ["Document 1"]
}

For end-to-end training:

{
    "question": "Single question text",
    "docs": ["Document 1", "Document 2", ...],
    "gold_answer": "Reference answer"
}
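
A minimal sketch for producing files in these formats, one JSON object per line (the file names and record contents here are placeholders):

import json

pretrain_record = {
    "data_type": "qa",
    "question": ["Question 1"],
    "answers": ["Answer 1"],
    "docs": ["Document 1"],
}

end_to_end_record = {
    "question": "Single question text",
    "docs": ["Document 1", "Document 2"],
    "gold_answer": "Reference answer",
}

# JSONL = one serialized record per line
with open("pretrain_data.jsonl", "w") as f:
    f.write(json.dumps(pretrain_record) + "\n")

with open("end_to_end_data.jsonl", "w") as f:
    f.write(json.dumps(end_to_end_record) + "\n")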

3. Start training

Stage 1: Salient Compressor Pretraining (SCP)

Pre-train the document compressor:

bash scripts/train_pretraining.sh

Key parameters:

  • --compress_rate: Compression rate (default: 32; see the arithmetic sketch after this list)
  • --doc_max_length: Maximum document length (default: 256)
  • --stage stage1: Training stage
  • --mse_loss: Use MSE loss to align compressed and original representations
  • --qa_loss: Use QA loss for semantic preservation
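
As a back-of-the-envelope illustration of what the compression rate means, assuming a document of doc_max_length tokens is compressed into roughly doc_max_length / compress_rate vectors (the exact token accounting in the code may differ):

# Rough arithmetic only; the code's exact token accounting may differ.
doc_max_length = 256  # --doc_max_length default
for compress_rate in (4, 16, 32, 64):
    n_vectors = doc_max_length // compress_rate
    print(f"CR={compress_rate:>2}: {doc_max_length} tokens -> {n_vectors} compressed vectors")
# CR= 4: 256 tokens -> 64 compressed vectors
# CR=32: 256 tokens ->  8 compressed vectors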

Stage 2: Compression Instruction Tuning

Fine-tune the compressor on instruction-following tasks:

bash scripts/train_instruction_tuning.sh

Key parameters:

  • --pretrain_checkpoint: Path to stage 1 checkpoint
  • --stage stage1_2: Training stage
  • --generation_top_k: Number of top-k documents used for generation (default: 5)
  • --mse_loss: Use MSE loss for compression training
  • --do_eval_gen: Enable generation evaluation

Stage 3: End-to-End Training

Fine-tune the model end-to-end with retrieval:

bash scripts/train_stage_end_to_end.sh

Key parameters:

  • --pretrain_checkpoint: Path to stage 2 checkpoint
  • --stage stage2: Training stage
  • --generation_top_k: Number of top-k documents used for generation
  • --do_eval_gen: Enable generation evaluation

4. Distributed Training

The training scripts support distributed training across multiple nodes and GPUs:

  • --max_len: Maximum sequence length (default: 2048 for stage1/stage2, 1024 for stage3)
  • --train_batch_size: Training batch size
  • --micro_train_batch_size: Micro batch size for gradient accumulation (see the sketch after this list)
  • --learning_rate: Learning rate (default: 1e-4 for stage1/stage2, 5e-6 for stage3)
  • --max_epochs: Maximum training epochs
  • --zero_stage: ZeRO optimization stage (default: 2)
  • --bf16: Use bfloat16 precision
  • --flash_attn: Use Flash Attention 2
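
For reference, the usual relationship between these batch-size flags under data parallelism, with gradient accumulation making up the difference (values below are hypothetical; check the scripts for the actual derivation):

# Illustrative arithmetic with made-up values.
train_batch_size = 128          # --train_batch_size (global batch)
micro_train_batch_size = 4      # --micro_train_batch_size (per GPU, per step)
world_size = 8                  # total GPUs across all nodes

grad_accum_steps = train_batch_size // (micro_train_batch_size * world_size)
print(f"gradient accumulation steps: {grad_accum_steps}")  # -> 4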

Inference

The CLaRa models can be loaded and used for inference. We provide three models corresponding to different training stages:

Stage 1: Compression Pretraining model

from transformers import AutoModel

model_path = "path/to/stage1/model"
model = AutoModel.from_pretrained(
    model_path, 
    trust_remote_code=True
).to('cuda')

# Example documents
documents = [
    [
        "Document 1 content...",
        "Document 2 content...",
        "Document 3 content..."
    ]
]

questions = ["" for _ in range(len(documents))]

# Generate paraphrase from compressed representations
output = model.generate_from_paraphrase(
    questions=questions, 
    documents=documents, 
    max_new_tokens=64
)

print('Generated paraphrase:', output[0])

Stage 2: Compression Instruction Tuning model

from transformers import AutoModel

model_path = "path/to/stage2/model"
model = AutoModel.from_pretrained(
    model_path, 
    trust_remote_code=True
).to('cuda')

# Example documents and question
documents = [
    [
        "Document 1 content...",
        "Document 2 content...",
        "Document 3 content..."
    ]
]

questions = ["Your question here"]

# Generate answer from compressed representations
output = model.generate_from_text(
    questions=questions, 
    documents=documents, 
    max_new_tokens=64
)

print('Generated answer:', output[0])

Stage 3: End-to-End (CLaRa) model

from transformers import AutoModel

model_path = "path/to/stage3/model"
model = AutoModel.from_pretrained(
    model_path, 
    trust_remote_code=True
).to('cuda')

# Example documents and question
# Note: Stage 3 supports retrieval with multiple candidate documents
documents = [
    ["Document 1 content..." for _ in range(20)]  # 20 candidate documents
]

questions = ["Your question here"]

# Generate answer with retrieval and reranking
# The top-k is decided by generation_top_k in config.json
output, topk_indices = model.generate_from_questions(
    questions=questions, 
    documents=documents, 
    max_new_tokens=64
)

print('Generated answer:', output[0])
print('Top-k selected document indices:', topk_indices)

Evaluation

The evaluation framework is based on standard RAG benchmarks. Run evaluation:

End-to-end evaluation:

bash scripts/evaluation_end_to_end.sh

Instruction tuning evaluation:

bash scripts/evaluation_instruction_tuning.sh

Supported datasets:

  • HotpotQA: Multi-hop question answering
  • MuSiQue: Multi-hop question answering with diverse reasoning
  • 2WikiMultiHopQA: Multi-hop question answering over Wikipedia
  • Natural Questions: Open-domain question answering
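
The numbers reported below follow standard QA accuracy conventions for these benchmarks. For reference, a typical SQuAD-style normalized exact-match scorer looks like the sketch below; this is illustrative, and the repo's evaluation code may compute its metrics differently.

import re
import string

def normalize(text: str) -> str:
    """Lowercase, drop punctuation and articles, collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())

def exact_match(prediction: str, gold_answers: list[str]) -> bool:
    return any(normalize(prediction) == normalize(g) for g in gold_answers)

print(exact_match("The Eiffel Tower", ["Eiffel Tower"]))  # True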

Results

Compression Performance

We evaluate our document compressor on four QA datasets (NQ, HotpotQA, MuSiQue, 2WikiMultiHopQA) under two settings: Normal (retrieving top-5 documents) and Oracle (gold document included). CLaRa consistently outperforms all baselines across different compression ratios.

Main Results (Mistral-7B, Normal Setting; CR = compression rate)

Model                    CR   NQ     HotpotQA  MuSiQue  2Wiki  Avg
AutoCompressor           -    17.24  14.61      3.81    19.89  13.89
XRAG                     128  32.35  25.16      3.64    28.79  22.48
COCOM                    16   24.12  21.48      3.52    24.48  18.40
PCC                      16   31.38  22.29      3.43    19.47  19.14
LLMLingua-2              4    47.53  37.05      9.02    44.35  34.49
PISCO                    16   54.39  41.94     10.09    44.88  37.83
Mistral-7B w/ retrieval  -    54.58  42.94      8.94    44.24  37.67
CLaRa (CR=4)             4    57.05  45.09     10.34    46.94  39.86
CLaRa (CR=16)            16   55.56  43.72     10.55    46.00  38.96
CLaRa (CR=32)            32   54.64  43.52     10.55    46.58  38.82

Oracle Setting Results (Mistral-7B)

Model                    CR   NQ     HotpotQA  MuSiQue  2Wiki  Avg
PISCO                    16   73.44  66.53     33.80    60.45  58.55
Mistral-7B w/ retrieval  -    71.64  70.77     45.72    68.83  64.24
CLaRa (CR=4)             4    76.50  73.81     46.26    70.48  66.76
CLaRa (CR=16)            16   75.48  70.79     43.15    66.16  63.90
CLaRa (CR=32)            32   73.77  69.51     38.31    64.54  61.53

Key Findings:

  • ✅ CLaRa outperforms PISCO by +1.13 points (Normal) and +5.35 points (Oracle) on average
  • ✅ CLaRa outperforms LLMLingua-2 by +5.37 points (Normal) on average
  • ✅ CLaRa matches or exceeds the text-based Mistral-7B baseline, with a +2.36-point average gain

Retrieval Performance

For detailed experimental results and analysis, please refer to our paper.

Acknowledgments

We sincerely appreciate the following works that CLaRa builds upon:

  • OpenRLHF, the training framework our implementation is built on

Citation

@misc{he2025clarabridgingretrievalgeneration,
      title={CLaRa: Bridging Retrieval and Generation with Continuous Latent Reasoning}, 
      author={Jie He and Richard He Bai and Sinead Williamson and Jeff Z. Pan and Navdeep Jaitly and Yizhe Zhang},
      year={2025},
      eprint={2511.18659},
      archivePrefix={arXiv},
      primaryClass={cs.CL},
      url={https://arxiv.org/abs/2511.18659}, 
}
