This repository contains the reproduced code for Step-by-Step Reasoning for Math Problems via Twisted Sequential Monte Carlo, ICLR 2025. The implementation is based on gpt-accelera. If you have any questions, please contact me at: [email protected].
Create and activate the conda virtual environment through
conda env create -f environment.yml
conda activate TSMC
Download the data (see ./data/README.md).
Let us take the training of llemma-7b on 4 A6000 GPUs as an example. We first need to prepare the model checkpoint in the gpt-fast format:
export DATA_DIR=/path/to/your/data/directory
export MODEL_REPO=EleutherAI/llemma_7b
export OMP_NUM_THREADS=4
export TOKENIZERS_PARALLELISM=True
python scripts/download.py \
--repo_id $MODEL_REPO \
--local_dir $DATA_DIR/checkpoints
python scripts/convert_hf_checkpoint.py \
--checkpoint_dir $DATA_DIR/checkpoints/$MODEL_REPO \
--target_precision bf16
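As a quick sanity check (our addition, not part of the repo), you can confirm that the converted checkpoint landed where the later commands expect it; the model.pth path below mirrors the --checkpoint_path arguments used throughout:
import os

# Verify the converted gpt-fast checkpoint exists (path assumed from the
# --checkpoint_path flags used in the commands below).
data_dir = os.environ["DATA_DIR"]
ckpt = os.path.join(data_dir, "checkpoints", "EleutherAI/llemma_7b", "model.pth")
assert os.path.exists(ckpt), f"missing converted checkpoint: {ckpt}"
print(f"found {ckpt} ({os.path.getsize(ckpt) / 1e9:.1f} GB)")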
Then we fine-tune the generator with PRM800K:
torchrun --standalone --nproc_per_node=4 \
train_sft.py \
--do_train \
--checkpoint_path $DATA_DIR/checkpoints/$MODEL_REPO/model.pth \
--source_max_len 768 \
--target_max_len 768 \
--total_max_len 768 \
--per_device_train_batch_size 32 \
--micro_train_batch_size 4 \
--learning_rate 2e-5 \
--lr_eta_min 2e-7 \
--num_train_epochs 3 \
--dataset $DATA_DIR/all_data/prm800k_sft_train.json \
--dataset_format "prm-v2" \
--add_eos_to_marked_target \
--save_strategy epoch \
--save_total_limit 1 \
--save_dir $DATA_DIR/checkpoints/prm800k_sft \
--resume_from_checkpoint
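For reference, here is how we read the batch-size flags above; treating the per-device batch as micro batches accumulated per optimizer step is our assumption about train_sft.py, not documented behavior:
# Sketch of the effective batch size under the flags above (assumption:
# per-device batches are built from accumulated micro batches).
per_device_batch = 32  # --per_device_train_batch_size
micro_batch = 4        # --micro_train_batch_size
num_gpus = 4           # --nproc_per_node
grad_accum_steps = per_device_batch // micro_batch  # 8 accumulation steps
global_batch = per_device_batch * num_gpus          # 128 sequences per update
print(grad_accum_steps, global_batch)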
To estimate the value function, we need to generate the training and validation datasets with the fine-tuned generator:
torchrun --standalone --nproc_per_node=4 \
inference_generate.py \
--prompt_file $DATA_DIR/all_data/math500_train.json \
--checkpoint_path $DATA_DIR/checkpoints/$MODEL_REPO/model.pth \
--finetune_checkpoint_path $DATA_DIR/checkpoints/prm800k_sft \
--num_samples 40 \
--output_file math500_tsmc_train.json \
--batch_size 40 \
--top_k 20 \
--temperature 0.7 \
--default_compile \
--max_new_tokens 768 \
--generate_training_data \
--resume_generation
torchrun --standalone --nproc_per_node=4 \
inference_generate.py \
--prompt_file $DATA_DIR/all_data/math500_valid.json \
--checkpoint_path $DATA_DIR/checkpoints/$MODEL_REPO/model.pth \
--finetune_checkpoint_path $DATA_DIR/checkpoints/prm800k_sft \
--num_samples 40 \
--output_file math500_tsmc_valid.json \
--batch_size 40 \
--top_k 20 \
--temperature 0.7 \
--default_compile \
--max_new_tokens 768 \
--generate_training_data \
--resume_generation
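To peek at what was generated, the snippet below may help (ours; it assumes --output_file writes a single JSON file, with the record schema defined by inference_generate.py):
import json

# Load the generated training data and inspect one record's fields.
with open("math500_tsmc_train.json") as f:
    records = json.load(f)
print(f"{len(records)} records")
print(records[0])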
Then launch the value-function training with:
torchrun --standalone --nproc_per_node=4 \
train_tsmc.py \
--train_prompt_file $DATA_DIR/all_data/math500_tsmc_train.json \
--valid_prompt_file $DATA_DIR/all_data/math500_tsmc_valid.json \
--output_file_prefix "train_tsmc_eval" \
--checkpoint_path $DATA_DIR/checkpoints/$MODEL_REPO/model.pth \
--per_device_train_batch_size 10 \
--per_device_eval_batch_size 10 \
--warmup_ratio 0.05 \
--tensor_parallel_size 4 \
--source_max_len 384 \
--target_max_len 768 \
--learning_rate 1e-5 \
--lr_eta_min 1e-7 \
--num_train_epochs 2 \
--save_strategy steps \
--save_steps 1000 \
--save_total_limit 1 \
--save_only_model True \
--loss_function CTL \
--train_step \
--compile \
--save_dir $DATA_DIR/checkpoints/math500_tsmc_value
Here we set the total majority-voting sample size N to 240 and the TSMC batch size M to 40 (see our paper for the precise meaning of these parameters; a short numerical sketch follows the command below). Launch the inference through:
torchrun --standalone --nproc_per_node=4 \
inference_tsmc.py \
--prompt_file $DATA_DIR/all_data/math500_test.json \
--checkpoint_path $DATA_DIR/checkpoints/$MODEL_REPO/model.pth \
--finetune_checkpoint_path $DATA_DIR/checkpoints/prm800k_sft \
--finetune_reward_checkpoint_path $DATA_DIR/checkpoints/math500_tsmc_value \
--num_samples 240 \
--output_file math500_tsmc_solution.json \
--batch_size 40 \
--tsmc_batch_size 40 \
--stop_step 5 \
--top_k 20 \
--compile \
--temperature 0.7 \
--tsmc_temperature 0.5 \
--warmup 50 \
--max_new_tokens 768 \
--resume_generation
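A short numerical sketch of the sampling budget above (ours, for illustration): the N = 240 samples per question are drawn as N / M = 6 independent TSMC runs of M = 40 particles, and a final answer can be chosen by majority vote over all 240 samples. The answer strings here are hypothetical:
from collections import Counter

# N / M = 6 independent TSMC runs of M = 40 particles each per question.
N, M = 240, 40
num_tsmc_runs = N // M  # 6
# Hypothetical final answers extracted from the 240 sampled solutions.
answers = ["42"] * 150 + ["41"] * 60 + ["40"] * 30
prediction, votes = Counter(answers).most_common(1)[0]
print(num_tsmc_runs, prediction, votes)  # 6 42 150
Note that evaluate.py below is invoked with --method orm_score, which presumably aggregates by reward-model score rather than plain counting.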
Finally, we can evaluate the generated solutions through:
python evaluate.py \
--prompt_file <test prompt file> \
--solution_file math500_tsmc_solution.json \
--tsmc_batch_size 40 \
--num_samples 240 \
--n_trials 100 \
--method orm_score
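Our reading of --n_trials (an assumption about evaluate.py, not its exact logic) is that accuracy is averaged over repeated random subsamples of the solution pool; generically, that looks like:
import random
from collections import Counter

# Generic trial-based evaluation: repeatedly subsample the answer pool and
# check how often the voted answer matches the gold answer.
def trial_accuracy(answers, gold, subset_size, n_trials=100, seed=0):
    rng = random.Random(seed)
    hits = 0
    for _ in range(n_trials):
        voted, _ = Counter(rng.sample(answers, subset_size)).most_common(1)[0]
        hits += voted == gold
    return hits / n_trials

answers = ["42"] * 150 + ["41"] * 90  # hypothetical extracted answers
print(trial_accuracy(answers, gold="42", subset_size=40))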
If you find this work helpful, please cite this paper:
@inproceedings{feng2025stepbystep,
title={Step-by-Step Reasoning for Math Problems via Twisted Sequential Monte Carlo},
author={Shengyu Feng and Xiang Kong and Shuang Ma and Aonan Zhang and Dong Yin and Chong Wang and Ruoming Pang and Yiming Yang},
booktitle={The Thirteenth International Conference on Learning Representations},
year={2025},
url={https://openreview.net/forum?id=Ze4aPP0tIn}
}