
TPT Project

Welcome to TPT – Think • Prune • Train! A framework for teaching large language models to solve math problems by learning from (and improving on) their own reasoning traces.


🚀 What is TPT?

TPT is a three-step, iterative workflow:

  1. Think – The model generates multiple, detailed solution traces.
  2. Prune – We automatically keep only the traces that reach the correct answer.
  3. Train – The model fine-tunes on this high-quality synthetic data to boost its skills.

Loop the cycle → watch the model level up. ✨
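Conceptually, one TPT iteration looks like the toy sketch below. The function names are hypothetical stand-ins, not this repository's API; the real logic lives in gen_synth.py, evmath.py, make_json.py, and sft_math.py.

# Toy, self-contained sketch of the TPT loop; names are illustrative only.
def think(model, question, n=5):
    # Think: sample n solution traces (the real step prompts an LLM).
    return [model(question) for _ in range(n)]

def prune(traces, gold):
    # Prune: keep only traces whose final token matches the gold answer.
    return [t for t in traces if t.split() and t.split()[-1] == gold]

def train(model, kept_traces):
    # Train: fine-tune on the surviving traces (stub; see sft_math.py).
    return model

def tpt(model, dataset, rounds=3):
    for _ in range(rounds):
        kept = [t for q, gold in dataset for t in prune(think(model, q), gold)]
        model = train(model, kept)
    return model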


πŸ› οΈ Workflow & Commands

Below is the minimal command-line recipe for each stage. Adjust paths/flags to taste.

1. Think – Generate Synthetic Traces (💡 gen_synth.py)

Produce N solution attempts per question.

python gen_synth.py \
 --model_name google/gemma-2-2b-it \
 --max_model_len 1500 \
 --num_samples 5 \
 --math data/gsm8ktrain.json \
 --output_dir samples/math_train/2b

Outputs are written under the chosen --output_dir as e0.json ... e5.json. Note that the pruning example below reads from samples/math_train/ft, so point its --samples_dir at whatever directory you used here.
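Multi-sample generation like this is typically done with vLLM (the setup guide below installs a flashinfer wheel for it). A hedged sketch of the core of the Think stage; the prompt template, sampling parameters, and the "question" field name are assumptions, not gen_synth.py's actual settings:

import json
from vllm import LLM, SamplingParams

# Hedged sketch of the Think stage: sample 5 traces per question.
questions = [ex["question"] for ex in json.load(open("data/gsm8ktrain.json"))]
llm = LLM(model="google/gemma-2-2b-it", max_model_len=1500)
params = SamplingParams(n=5, temperature=0.8, max_tokens=1024)
for out in llm.generate(["Solve step by step:\n" + q for q in questions], params):
    traces = [c.text for c in out.outputs]  # five candidate solutions per question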

2. Prune & Split (✂️ evmath.py → 📄 make_json.py)

  1. Score correctness with evmath.py (example):
    python evmath.py --samples_dir samples/math_train/ft --answer_path data/gsm8ktrain.json --num_samples 5
    This writes correct_answers.json and pass_at_k_results.json.
  2. Create new train/eval JSON:
    python make_json.py \
     --input samples/math_train/correct_answers.json \
     --train_output data/next/train2k.json \
     --eval_output data/next/evnext.json \
     --train_size 2000

Use the new data for the Train step below, then loop back to Think for the next TPT cycle.
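The pruning criterion in step 1 is exact match on the final answer. A hedged sketch of what that check might look like for GSM8K-style data; the regex and the "#### <number>" gold-answer convention are assumptions about the data format, not evmath.py's actual code:

import re

def final_number(text):
    # Take the last number in the text as the answer; GSM8K gold answers
    # end with "#### <number>" (assumption about the data format).
    nums = re.findall(r"-?\d+(?:\.\d+)?", text.replace(",", ""))
    return nums[-1] if nums else None

def keep_correct(traces, gold_answer):
    # Prune: keep only traces whose final number matches the gold answer.
    gold = final_number(gold_answer)
    return [t for t in traces if final_number(t) == gold]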

3. Train (🚂 sft_math.py)

Fine-tune the same base model that generated the data on the newly created dataset.

python sft_math.py \
 --model_name_or_path google/gemma-2-2b-it \
 --train_data_path data/next/train2k.json \
 --eval_data_path data/next/evnext.json \
 --learning_rate 1e-6 \
 --output_dir gemma-tpt

This produces a checkpoint under gemma-tpt/ and logs to W&B (set your project and name inside the script).
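At its core this is standard supervised fine-tuning. A minimal sketch using Hugging Face transformers; the "text" field name, batch size, and epoch count are assumptions (only the learning rate comes from the command above):

from datasets import load_dataset
from transformers import (AutoModelForCausalLM, AutoTokenizer,
                          DataCollatorForLanguageModeling, Trainer,
                          TrainingArguments)

model_name = "google/gemma-2-2b-it"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Assumes each JSON record has a "text" field holding prompt + correct trace.
data = load_dataset("json", data_files={"train": "data/next/train2k.json",
                                        "eval": "data/next/evnext.json"})
data = data.map(lambda b: tokenizer(b["text"], truncation=True, max_length=1500),
                batched=True, remove_columns=data["train"].column_names)

args = TrainingArguments(output_dir="gemma-tpt", learning_rate=1e-6,
                         num_train_epochs=1, per_device_train_batch_size=1,
                         report_to="wandb")
Trainer(model=model, args=args,
        train_dataset=data["train"], eval_dataset=data["eval"],
        data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False)
        ).train()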


📂 Repository Structure

TPT/
├── data/              # Datasets (initial + generated)
├── gemma-tpt/         # Model checkpoints & artifacts
├── samples/           # Synthetic traces
├── wandb/             # Experiment tracking
├── evmath.py          # Scoring / pruning script
├── gen_eval.py        # Generates evaluation questions
├── gen_synth.py       # Synthetic generation script (Think)
├── make_json.py       # Builds new train/eval JSON (Prune)
├── sft_math.py        # Supervised fine-tune (Train)
├── README.md          # You are here
└── requirements.txt   # Python deps

βš™οΈ Setup Guide

Prerequisites

  • Python 3.10
  • pip

Installation

git clone <repository-url>
cd <repository-folder>
# Create & activate venv
python3.10 -m venv tpt_env
source tpt_env/bin/activate # Windows: tpt_env\Scripts\activate
# Install deps
python3.10 -m pip install -r requirements.txt
# Extra: flashinfer wheel (for vLLM-FlashAttention)
python3.10 -m pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.1.2/flashinfer-0.1.2+cu121torch2.3-cp310-cp310-linux_x86_64.whl

Activate later with:

source tpt_env/bin/activate 
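To confirm the environment is ready, a quick sanity check (assuming torch and vllm are pinned in requirements.txt):

# sanity_check.py — hypothetical quick test that the key deps import.
import torch
import vllm
print("CUDA available:", torch.cuda.is_available())
print("vLLM version:", vllm.__version__)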

Ready? Time to Think → Prune → Train and watch your model improve!
