Welcome to TPT: Think • Prune • Train! A framework for teaching large language models to solve math problems by learning from (and improving on) their own reasoning traces.

TPT is a three-step, iterative workflow:

- Think: The model generates multiple, detailed solution traces.
- Prune: We automatically keep only the traces that reach the correct answer.
- Train: The model fine-tunes on this high-quality synthetic data to boost its skills.

Loop the cycle and watch the model level up. ✨
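If you want the whole loop at a glance before the per-stage recipes, here is a hypothetical Python driver that chains the three scripts in this repo. The flag values mirror the examples below, but the per-cycle paths (and the assumption that `evmath.py` writes `correct_answers.json` into its `--samples_dir`, and that each checkpoint is loadable from `--output_dir`) are invented for illustration:

```python
# Hypothetical end-to-end TPT driver. Script names come from this repo;
# flags mirror the README recipes, per-cycle paths are illustrative.
import subprocess

def run(cmd):
    print(">>>", " ".join(cmd))
    subprocess.run(cmd, check=True)  # fail fast if any stage errors

model = "google/gemma-2-2b-it"
for cycle in range(3):  # run as many TPT iterations as you like
    samples = f"samples/math_train/cycle{cycle}"
    # Think: sample 5 solution traces per question
    run(["python", "gen_synth.py", "--model_name", model,
         "--max_model_len", "1500", "--num_samples", "5",
         "--math", "data/gsm8ktrain.json", "--output_dir", samples])
    # Prune: score the traces, keep only the correct ones
    run(["python", "evmath.py", "--samples_dir", samples,
         "--answer_path", "data/gsm8ktrain.json", "--num_samples", "5"])
    run(["python", "make_json.py",
         "--input", f"{samples}/correct_answers.json",  # assumed location
         "--train_output", f"data/next/train2k_c{cycle}.json",
         "--eval_output", f"data/next/evnext_c{cycle}.json",
         "--train_size", "2000"])
    # Train: fine-tune on the pruned data; the checkpoint seeds the next cycle
    run(["python", "sft_math.py", "--model_name_or_path", model,
         "--train_data_path", f"data/next/train2k_c{cycle}.json",
         "--eval_data_path", f"data/next/evnext_c{cycle}.json",
         "--learning_rate", "1e-6", "--output_dir", f"gemma-tpt/cycle{cycle}"])
    model = f"gemma-tpt/cycle{cycle}"  # next Think uses the improved model
```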
Below is the minimal command-line recipe for each stage. Adjust paths/flags to taste.
Produce N solution attempts per question.
```bash
python gen_synth.py \
  --model_name google/gemma-2-2b-it \
  --max_model_len 1500 \
  --num_samples 5 \
  --math data/gsm8ktrain.json \
  --output_dir samples/math_train/2b
```
Outputs go to `samples/math_train/ft/e0.json` ... `e5.json`.
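For intuition, the sampling that Think performs looks roughly like the standalone `transformers` sketch below. This is not `gen_synth.py`'s actual implementation; the prompt string, temperature, and token budget are placeholder assumptions:

```python
# Conceptual sketch of the Think stage: sample several independent solution
# traces per question with stochastic decoding. Illustrative only; the real
# script's prompt template and decoding parameters may differ.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "google/gemma-2-2b-it"
tok = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

question = "A placeholder GSM8K-style word problem goes here."
inputs = tok(question, return_tensors="pt").to(model.device)

outputs = model.generate(
    **inputs,
    do_sample=True,            # temperature sampling -> diverse traces
    temperature=0.8,           # assumed value, not taken from gen_synth.py
    max_new_tokens=512,
    num_return_sequences=5,    # plays the role of --num_samples 5
)
traces = [tok.decode(o, skip_special_tokens=True) for o in outputs]
```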
- Score correctness with `evmath.py` (example):

```bash
python evmath.py \
  --samples_dir samples/math_train/ft \
  --answer_path data/gsm8ktrain.json \
  --num_samples 5
```

This writes `correct_answers.json` and `pass_at_k_results.json`.

- Create new train/eval JSON:
```bash
python make_json.py \
  --input samples/math_train/correct_answers.json \
  --train_output data/next/train2k.json \
  --eval_output data/next/evnext.json \
  --train_size 2000
```
Use the new data in the next TPT cycle (back to Train).
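To make the pruning criterion concrete, here is a toy version of the keep-only-correct filter, plus the standard unbiased pass@k estimator (Chen et al., 2021), which is presumably what `pass_at_k_results.json` reports. `evmath.py`'s real answer parsing may well be stricter than this regex:

```python
# Toy pruning logic: a trace survives only if its final number matches the
# gold answer. Illustrative; not evmath.py's actual parser.
import re
from math import comb

def final_answer(trace):
    """Pull the last number out of a solution trace."""
    nums = re.findall(r"-?\d+(?:\.\d+)?", trace.replace(",", ""))
    return float(nums[-1]) if nums else None

def prune(traces, gold):
    """Keep only traces whose final answer equals the gold answer."""
    return [t for t in traces if final_answer(t) == float(gold)]

def pass_at_k(n, c, k):
    """Unbiased pass@k: 1 - C(n-c, k) / C(n, k), given n samples, c correct."""
    if n - c < k:
        return 1.0
    return 1.0 - comb(n - c, k) / comb(n, k)

traces = ["... so Natalia sold 72 clips.", "... which gives 70 clips."]
print(prune(traces, gold="72"))   # keeps only the first trace
print(pass_at_k(n=5, c=1, k=1))   # 0.2
```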
Fine-tune the base model (the one used to generate the data) on the dataset you just created.
```bash
python sft_math.py \
  --model_name_or_path google/gemma-2-2b-it \
  --train_data_path data/next/train2k.json \
  --eval_data_path data/next/evnext.json \
  --learning_rate 1e-6 \
  --output_dir gemma-tpt
```
This produces a checkpoint under gemma-tpt/ and logs to W&B (set your project and name inside the script).
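Under the hood, the Train stage is ordinary supervised fine-tuning on the pruned traces. A bare-bones `transformers` sketch is below; the JSON field names (`question`, `answer`) are assumptions about the data layout, and `sft_math.py`'s real formatting, hyperparameters, and W&B logging are omitted:

```python
# Minimal SFT sketch: pack question + correct trace into one causal-LM
# example and fine-tune. JSON field names are assumed, not confirmed.
import json
import torch
from torch.utils.data import Dataset
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          Trainer, TrainingArguments)

model_name = "google/gemma-2-2b-it"
tok = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)

class TraceDataset(Dataset):
    def __init__(self, path):
        with open(path) as f:
            self.rows = json.load(f)
    def __len__(self):
        return len(self.rows)
    def __getitem__(self, i):
        row = self.rows[i]
        enc = tok(row["question"] + "\n" + row["answer"],
                  truncation=True, max_length=1024,
                  padding="max_length", return_tensors="pt")
        ids = enc.input_ids.squeeze(0)
        mask = enc.attention_mask.squeeze(0)
        labels = ids.clone()
        labels[mask == 0] = -100  # no loss on padding tokens
        return {"input_ids": ids, "attention_mask": mask, "labels": labels}

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="gemma-tpt", learning_rate=1e-6,
                           num_train_epochs=1, per_device_train_batch_size=2),
    train_dataset=TraceDataset("data/next/train2k.json"),
    eval_dataset=TraceDataset("data/next/evnext.json"),
)
trainer.train()
```

The very low learning rate in the README recipe (1e-6) is notable: it keeps each fine-tuned model close to the one that generated the data, which matters when the cycle is repeated.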
```text
TPT/
├── data/              # Datasets (initial + generated)
├── gemma-tpt/         # Model checkpoints & artifacts
├── samples/           # Synthetic traces
├── wandb/             # Experiment tracking
├── evmath.py          # Scoring / pruning script
├── gen_eval.py        # Generates evaluation questions
├── gen_synth.py       # Synthetic generation script (Think)
├── make_json.py       # Builds new train/eval JSON (Prune)
├── sft_math.py        # Supervised fine-tune (Train)
├── README.md          # You are here
└── requirements.txt   # Python deps
```
- Python 3.10
- pip
```bash
git clone <repository-url>
cd <repository-folder>

# Create & activate venv
python3.10 -m venv tpt_env
source tpt_env/bin/activate   # Windows: tpt_env\Scripts\activate

# Install deps
python3.10 -m pip install -r requirements.txt

# Extra: flashinfer wheel (for vLLM FlashAttention support)
python3.10 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.3-cp310-cp310-linux_x86_64.whl
```
Activate later with:

```bash
source tpt_env/bin/activate
```

Ready? Time to Think → Prune → Train and watch your model improve!