
Step-by-Step Reasoning for Math Problems via Twisted Sequential Monte Carlo


This repository contains the reproduced code for Step-by-Step Reasoning for Math Problems via Twisted Sequential Monte Carlo (ICLR 2025, arXiv:2410.01920). The implementation is based on gpt-accelera. If you have any questions, please contact me at [email protected].

Setup

Create and activate the conda virtual environment:

conda env create -f environment.yml
conda activate TSMC
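
After activation, a quick sanity check (a minimal snippet, assuming the environment installs PyTorch with CUDA support) confirms that all 4 GPUs used below are visible:

import torch

# expect CUDA to be available with 4 visible devices for the commands below
print(torch.__version__)
print(torch.cuda.is_available(), torch.cuda.device_count())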

Download data (see ./data/README.md).

Training

As an example, let us train llemma-7b on 4 A6000 GPUs. First, prepare the model checkpoint in the gpt-fast format:

export DATA_DIR=/path/to/your/data/directory
export MODEL_REPO=EleutherAI/llemma_7b
export OMP_NUM_THREADS=4
export TOKENIZERS_PARALLELISM=True

python scripts/download.py \
    --repo_id $MODEL_REPO \
    --local_dir $DATA_DIR/checkpoints

python scripts/convert_hf_checkpoint.py \
    --checkpoint_dir $DATA_DIR/checkpoints/$MODEL_REPO \
    --target_precision bf16
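
The conversion should leave a model.pth under the checkpoint directory; the training and inference commands below all point at this path. A quick check (a minimal snippet, assuming DATA_DIR and MODEL_REPO are still exported):

import os

ckpt = os.path.join(os.environ["DATA_DIR"], "checkpoints",
                    os.environ["MODEL_REPO"], "model.pth")
print(ckpt, "exists:", os.path.exists(ckpt))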

Fine-tune the generator

Then we fine-tune the generator with PRM800K:

torchrun --standalone --nproc_per_node=4 \
    train_sft.py \
    --do_train \
    --checkpoint_path $DATA_DIR/checkpoints/$MODEL_REPO/model.pth \
    --source_max_len 768 \
    --target_max_len 768 \
    --total_max_len 768 \
    --per_device_train_batch_size 32 \
    --micro_train_batch_size 4 \
    --learning_rate 2e-5 \
    --lr_eta_min 2e-7 \
    --num_train_epochs 3 \
    --dataset $DATA_DIR/all_data/prm800k_sft_train.json \
    --dataset_format "prm-v2" \
    --add_eos_to_marked_target \
    --save_strategy epoch \
    --save_total_limit 1 \
    --save_dir $DATA_DIR/checkpoints/prm800k_sft \
    --resume_from_checkpoint
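
For reference, the batch-size flags above imply gradient accumulation: assuming the usual composition of these flags, each device processes micro-batches of 4 and accumulates 8 of them per optimizer step, for a global batch of 128 across 4 GPUs:

# assumed relationship between the batch-size flags (not verified against train_sft.py)
per_device_train_batch_size = 32
micro_train_batch_size = 4
num_gpus = 4

grad_accum_steps = per_device_train_batch_size // micro_train_batch_size  # 8
global_batch = per_device_train_batch_size * num_gpus                     # 128
print(grad_accum_steps, global_batch)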

Estimate the value function

To estimate the value function, we first generate training and validation datasets with the fine-tuned generator:

torchrun --standalone --nproc_per_node=4 \
    inference_generate.py \
    --prompt_file $DATA_DIR/all_data/math500_train.json \
    --checkpoint_path $DATA_DIR/checkpoints/$MODEL_REPO/model.pth \
    --finetune_checkpoint_path $DATA_DIR/checkpoints/prm800k_sft \
    --num_samples 40 \
    --output_file math500_tsmc_train.json \
    --batch_size 40 \
    --top_k 20 \
    --temperature 0.7 \
    --default_compile \
    --max_new_tokens 768 \
    --generate_training_data \
    --resume_generation 
    
torchrun --standalone --nproc_per_node=4 \
    inference_generate.py \
    --prompt_file $DATA_DIR/all_data/math500_valid.json \
    --checkpoint_path $DATA_DIR/checkpoints/$MODEL_REPO/model.pth \
    --finetune_checkpoint_path $DATA_DIR/checkpoints/prm800k_sft \
    --num_samples 40 \
    --output_file math500_tsmc_valid.json \
    --batch_size 40 \
    --top_k 20 \
    --temperature 0.7 \
    --default_compile \
    --max_new_tokens 768 \
    --generate_training_data \
    --resume_generation     
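
Both commands draw 40 samples per problem with temperature 0.7 and top-k 20. A minimal sketch of what one top-k/temperature decoding step looks like (illustrative only, not the repo's implementation in inference_generate.py):

import torch

def sample_next_token(logits, top_k=20, temperature=0.7):
    # rescale logits, keep only the top_k candidates, then sample among them
    logits = logits / temperature
    vals, idx = torch.topk(logits, top_k)
    probs = torch.softmax(vals, dim=-1)
    return idx[torch.multinomial(probs, num_samples=1)]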

Then launch the training with

torchrun --standalone --nproc_per_node=4 \
    train_tsmc.py \
    --train_prompt_file $DATA_DIR/all_data/math500_tsmc_train.json \
    --valid_prompt_file $DATA_DIR/all_data/math500_tsmc_valid.json \
    --output_file_prefix "train_tsmc_eval" \
    --checkpoint_path $DATA_DIR/checkpoints/$MODEL_REPO/model.pth \
    --per_device_train_batch_size 10 \
    --per_device_eval_batch_size 10 \
    --warmup_ratio 0.05 \
    --tensor_parallel_size 4 \
    --source_max_len 384 \
    --target_max_len 768 \
    --learning_rate 1e-5 \
    --lr_eta_min 1e-7 \
    --num_train_epochs 2 \
    --save_strategy steps \
    --save_steps 1000 \
    --save_total_limit 1 \
    --save_only_model True \
    --loss_function CTL \
    --train_step \
    --compile \
    --save_dir $DATA_DIR/checkpoints/math500_tsmc_value
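
Here CTL stands for contrastive twist learning, which trains the value (twist) function to raise the score of partial solutions that lead to correct answers relative to the sampled pool. A heavily simplified sketch of such a contrastive objective (a sketch under assumed inputs; the repo's actual loss lives in train_tsmc.py):

import torch

def ctl_loss(log_psi, outcome_weights):
    # log_psi: (M,) predicted log twist values for M partial solutions
    # outcome_weights: (M,) nonnegative weights, e.g. final-answer correctness
    # positive term: expected log twist under outcome-weighted samples
    pos = (outcome_weights / outcome_weights.sum()) @ log_psi
    # negative term: log normalizer over the sampled pool
    neg = torch.logsumexp(log_psi, dim=0) - torch.log(
        torch.tensor(float(log_psi.numel())))
    return neg - pos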

Inference

Here we set the total majority-voting sample size N to 240 and the TSMC batch size M to 40 (see our paper for the meaning of these parameters); a schematic of the TSMC loop follows the command. Launch the inference through

torchrun --standalone --nproc_per_node=4 \
    inference_tsmc.py \
    --prompt_file $DATA_DIR/all_data/math500_test.json \
    --checkpoint_path $DATA_DIR/checkpoints/$MODEL_REPO/model.pth \
    --finetune_checkpoint_path $DATA_DIR/checkpoints/prm800k_sft \
    --finetune_reward_checkpoint_path $DATA_DIR/checkpoints/math500_tsmc_value \
    --num_samples 240 \
    --output_file math500_tsmc_solution.json \
    --batch_size 40 \
    --tsmc_batch_size 40 \
    --stop_step 5 \
    --top_k 20 \
    --compile \
    --temperature 0.7 \
    --tsmc_temperature 0.5 \
    --warmup 50 \
    --max_new_tokens 768 \
    --resume_generation
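
Schematically, TSMC maintains M = 40 particles per run and alternates resampling, extension, and reweighting at every reasoning step; with N = 240 total samples, this presumably corresponds to 6 independent runs whose outputs are pooled for voting. A minimal sketch of one such step (extend_fn and twist_fn are hypothetical placeholders for the generator and the value model, not functions from this repo):

import torch

def tsmc_step(particles, log_weights, extend_fn, twist_fn):
    # 1) resample particles in proportion to their twisted weights
    M = len(particles)
    idx = torch.multinomial(torch.softmax(log_weights, dim=0), M, replacement=True)
    particles = [particles[i] for i in idx]
    # 2) extend each partial solution by one reasoning step with the generator
    particles = [extend_fn(p) for p in particles]
    # 3) reweight with the twist (value) function
    log_weights = torch.stack([twist_fn(p) for p in particles])
    return particles, log_weights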

Finally, we can evaluate the generated solutions through

python evaluate.py \
    --prompt_file <test prompt file> \
    --solution_file math500_tsmc_solution.json \
    --tsmc_batch_size 40 \
    --num_samples 240 \
    --n_trials 100 \
    --method orm_score
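
The orm_score method suggests that votes are weighted by each solution's reward-model score rather than counted uniformly, and n_trials presumably repeats the evaluation over resampled subsets to reduce variance. A minimal sketch of such a weighted vote (the repo's actual aggregation is in evaluate.py):

from collections import defaultdict

def weighted_majority_vote(answers, scores):
    # sum the ORM score of every solution that reaches each distinct final
    # answer, then return the answer with the largest total
    totals = defaultdict(float)
    for answer, score in zip(answers, scores):
        totals[answer] += score
    return max(totals, key=totals.get)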

Citation

If you find this work helpful, please cite this paper:

@inproceedings{
    feng2025stepbystep,
    title={Step-by-Step Reasoning for Math Problems via Twisted Sequential Monte Carlo},
    author={Shengyu Feng and Xiang Kong and Shuang Ma and Aonan Zhang and Dong Yin and Chong Wang and Ruoming Pang and Yiming Yang},
    booktitle={The Thirteenth International Conference on Learning Representations},
    year={2025},
    url={https://openreview.net/forum?id=Ze4aPP0tIn}
}
