thu-nics/TaH


Think-at-Hard

Selective Latent Iterations to Improve Reasoning Language Models

🌐 Project Page • 📑 Paper • 🤗 HuggingFace

Think-at-Hard (TaH) improves LLM reasoning by running extra latent iterations only on hard tokens instead of all tokens. A lightweight decider and duo-causal attention enable targeted refinement while keeping full parallelism. TaH outperforms fixed two-iteration baselines by 8–11% while skipping 94% of second iterations, and also beats strong single-iteration Qwen3 models by 4–5%.
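
The selective mechanism can be pictured with a toy sketch. This is not the TaH implementation: `selective_iterations`, `decider_score`, and the 0.5 threshold are hypothetical stand-ins, used only to show how a per-token decision restricts the second iteration to hard tokens. The real model makes this decision in parallel across tokens, while the loop below is sequential for readability.

```python
from typing import Callable, List, Tuple

def selective_iterations(
    tokens: List[str],
    decider_score: Callable[[str], float],  # hypothetical: probability that a token is "hard"
    threshold: float = 0.5,                  # assumed cut-off, not taken from the paper
) -> Tuple[List[int], float]:
    """Return per-token iteration counts and the fraction of second iterations skipped."""
    iter_counts = []
    for tok in tokens:
        # Every token gets one iteration; only tokens judged hard get a second one.
        iter_counts.append(2 if decider_score(tok) >= threshold else 1)
    skipped = iter_counts.count(1) / len(iter_counts)
    return iter_counts, skipped

# Toy decider: pretend that only the token "therefore" is hard.
toy_scores = {"therefore": 0.9}
counts, skipped = selective_iterations(
    ["2", "+", "2", "therefore", "4"],
    decider_score=lambda t: toy_scores.get(t, 0.1),
)
print(counts)   # [1, 1, 1, 2, 1]
print(skipped)  # 0.8
```

In this toy input, four of five tokens skip the second iteration, which is the kind of saving the 94% figure above refers to.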

Feel free to star the repo or cite the paper if you find it interesting.

@article{fu2025tah,
 title={Think-at-Hard: Selective Latent Iterations to Improve Reasoning Language Models}, 
 author={Tianyu Fu and Yichen You and Zekai Chen and Guohao Dai and Huazhong Yang and Yu Wang},
 journal={arXiv preprint arXiv:2510.08577},
 year={2025},
}

Environment Setup

Create a new environment:

conda create -n tah python=3.10
conda activate tah

Install the package:

pip install -e .

For training and evaluation, install additional dependencies:

pip install -e ".[training,evaluation]"

For code generation evaluation, also install evalplus.

Run an example for TaH

python script/playground/inference_example.py

This script demonstrates TaH's selective latent iteration mechanism, with color-coded output showing the iteration count for each token.
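
To give a sense of what color-coded output by iteration count might look like, here is a minimal sketch. It does not reproduce `inference_example.py`: the `colorize` helper, the color choices, and the sample tokens are assumptions made for illustration.

```python
from typing import List

# ANSI colour codes chosen for illustration; the real script may use different colours.
ANSI_BY_ITER = {1: "\033[0m", 2: "\033[91m"}  # 1 iteration: default, 2 iterations: red
RESET = "\033[0m"

def colorize(tokens: List[str], iter_counts: List[int]) -> str:
    """Wrap each token in the colour assigned to its iteration count."""
    if len(tokens) != len(iter_counts):
        raise ValueError("tokens and iter_counts must have the same length")
    return "".join(
        f"{ANSI_BY_ITER.get(n, RESET)}{tok}{RESET}" for tok, n in zip(tokens, iter_counts)
    )

# Made-up tokens and counts: only " is" received a second iteration.
rendered = colorize(["The", " answer", " is", " 42"], [1, 1, 2, 1])
print(rendered)
```

On a terminal that supports ANSI escapes, the token that received two iterations is printed in red and the rest in the default color.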

Run evaluation

Evaluate TaH model

python script/evaluation/eval.py \
 --eval_config ./script/recipes/qwen3_1.7/eval_tah.yaml \
 --model_path nics-efc/TaH-plus-1.7B \
 --dataset_name gsm8k \
 --backend tah \
 --job_nums 8 \
 --tp_size_per_job 1

Key parameters:

  • --eval_config: Path to evaluation config file
  • --model_path: Path to the model
  • --dataset_name: Dataset name (gsm8k, math500, aime24, and others; detailed configs are in tah/evaluate/eval_configs/dataset_configs.json)
  • --backend: Inference backend (tah for TaH)
  • --job_nums: Number of parallel jobs
  • --tp_size_per_job: Tensor parallel size per job
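
If you want to evaluate several datasets in a row, one option is to assemble the command from Python. The sketch below uses only the flags and values shown above; `build_eval_command` is a hypothetical helper, and it assumes one dataset per invocation, since the README does not say whether `--dataset_name` accepts several names.

```python
import shlex
from typing import List

def build_eval_command(dataset_name: str, job_nums: int = 8, tp_size_per_job: int = 1) -> List[str]:
    """Assemble the eval.py invocation shown above for one dataset."""
    return [
        "python", "script/evaluation/eval.py",
        "--eval_config", "./script/recipes/qwen3_1.7/eval_tah.yaml",
        "--model_path", "nics-efc/TaH-plus-1.7B",
        "--dataset_name", dataset_name,
        "--backend", "tah",
        "--job_nums", str(job_nums),
        "--tp_size_per_job", str(tp_size_per_job),
    ]

# Dataset names listed in the README; check dataset_configs.json for the full set.
commands = [build_eval_command(name) for name in ("gsm8k", "math500", "aime24")]
for cmd in commands:
    print(shlex.join(cmd))
    # To actually run it: subprocess.run(cmd, check=True)
```

The loop only prints the commands, so it is safe to run without a GPU; swap the `print` for `subprocess.run` once the output looks right.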

Evaluate standard baseline model

python script/evaluation/eval.py \
 --eval_config ./script/recipes/qwen3_1.7/eval_base.yaml \
 --model_path nics-efc/Standard-1.7B \
 --dataset_name gsm8k \
 --backend hf \
 --job_nums 8 \
 --tp_size_per_job 1

The parameters are the same as for TaH evaluation, except for the backend:

  • --backend: hf or sglang

Train your own TaH model

Training a TaH model consists of a preparation step (Step0) followed by two training stages (Step1 and Step2):

Step0: Prepare model and data

1. Prepare training data

Use a reference model to generate hard token labels for the training and validation data:

### step 0
# download the default subset of OpenR1-Math-220k
python script/preparation/download.py
# filter and split
python script/preparation/filter_split.py
# label the hard tokens
python script/preparation/label.py \
 --num_gpu 8 \
 --dataset_path ./data/initial_data/openr1-math/train.jsonl \
 --test_model_list Qwen/Qwen3-1.7B \
 --output_path ./data/processed_data/openr1-math/1_7/train \
 --max_input_length 10000
python script/preparation/label.py \
 --num_gpu 8 \
 --dataset_path ./data/initial_data/openr1-math/eval.jsonl \
 --test_model_list Qwen/Qwen3-1.7B \
 --output_path ./data/processed_data/openr1-math/1_7/eval \
 --max_input_length 10000
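
The README describes `label.py` as generating mismatch labels. One plausible reading, stated here as an assumption and not verified against the script, is that a token is marked hard when the reference model's top-1 next-token prediction differs from the ground-truth token. The helper `mismatch_labels` and the token ids below are made up for illustration.

```python
from typing import List

def mismatch_labels(predicted_ids: List[int], target_ids: List[int]) -> List[int]:
    """Label a position 1 (hard) when the reference prediction misses the target, else 0."""
    if len(predicted_ids) != len(target_ids):
        raise ValueError("predictions and targets must be aligned position by position")
    return [int(p != t) for p, t in zip(predicted_ids, target_ids)]

# Toy token ids: the reference model is wrong at positions 1 and 4.
labels = mismatch_labels([11, 7, 42, 5, 9], [11, 8, 42, 5, 3])
print(labels)  # [0, 1, 0, 0, 1]
```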

2. (Optional) Prepare pruned model

For the TaH version, prune one layer from the base model so that its parameter count matches the standard baseline (skip this step for the TaH+ version):

### step 0
python script/preparation/prune.py \
 --model Qwen/Qwen3-1.7B-Base \
 --dataset ./data/processed_data/openr1-math/1_7/eval \
 --output ./model/qwen3_1.7_base_pruned \
 --num_prune 1
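
The README does not state how `prune.py` chooses which layer to drop. A common heuristic, shown here purely as an assumption, is to remove the layer whose absence hurts the loss on a calibration set the least. `pick_layers_to_prune`, `eval_loss`, and the importance values are hypothetical.

```python
from typing import Callable, List, Sequence

def pick_layers_to_prune(
    layers: Sequence[str],
    eval_loss: Callable[[List[str]], float],  # hypothetical: loss of a model built from these layers
    num_prune: int = 1,
) -> List[str]:
    """Greedily drop the layer whose removal yields the lowest calibration loss."""
    kept = list(layers)
    pruned = []
    for _ in range(num_prune):
        best = min(kept, key=lambda name: eval_loss([l for l in kept if l != name]))
        kept.remove(best)
        pruned.append(best)
    return pruned

# Toy loss: each layer has a made-up importance; removing it adds that much loss.
importance = {"layer0": 0.9, "layer1": 0.2, "layer2": 0.5}
toy_loss = lambda kept: sum(v for k, v in importance.items() if k not in kept)
pruned_layers = pick_layers_to_prune(list(importance), toy_loss)
print(pruned_layers)  # ['layer1']
```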

Step1: Train with Fixed Iteration Labels

The first stage uses fixed iteration labels for training:

### step 1
python -m accelerate.commands.launch \
 --config_file ./script/recipes/accelerate_configs/zero2.yaml \
 --num_processes 8 \
 ./script/train/SFT_TaH.py \
 --config ./script/recipes/qwen3_1.7/sft_tah_step1.yaml

Key configurations in Step1 (sft_tah_step1.yaml):

  • max_iter: 2: Maximum number of iterations
  • iter_decider: "FixedLabelIterDecider": Use fixed labels to decide iterations
  • iter_label_generator: "FixedIterLabelGenerator": Generate labels from mismatch field in data
  • input_updater: "AdditiveUpdater": Use additive updater for input updates
  • adapter: "lora": Use LoRA adapter for deeper iteration
  • train_loss: "NextTokenPredLoss": Next token prediction loss

Step2: Train Iteration Decider

The second stage trains the iteration decider:

### step 2
python -m accelerate.commands.launch \
 --config_file ./script/recipes/accelerate_configs/zero2.yaml \
 --num_processes 8 \
 ./script/train/SFT_TaH.py \
 --config ./script/recipes/qwen3_1.7/sft_tah_step2.yaml

Key configurations in Step2 (sft_tah_step2.yaml):

  • tah_model_path: Load the model trained in Step1
  • iter_decider: "MLPIterDecider": Use MLP decider to automatically determine iterations
  • train_loss: "IterDeciderLoss": Iteration decider loss function
  • freeze_component: [model.simple_base_model]: Freeze model backbone
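
To make the decider objective concrete, the sketch below treats it as binary classification: the decider outputs a probability of continuing to another iteration, and the target is the 0/1 hard-token label. Plain binary cross-entropy is an assumption here; the actual `IterDeciderLoss` may be weighted or defined differently, and `decider_bce_loss` is a hypothetical name.

```python
import math
from typing import List

def decider_bce_loss(continue_probs: List[float], hard_labels: List[int], eps: float = 1e-7) -> float:
    """Mean binary cross-entropy between predicted continue probabilities and 0/1 labels."""
    if len(continue_probs) != len(hard_labels):
        raise ValueError("one probability is needed per label")
    total = 0.0
    for p, y in zip(continue_probs, hard_labels):
        p = min(max(p, eps), 1.0 - eps)  # clamp to keep log() finite
        total += -(y * math.log(p) + (1 - y) * math.log(1.0 - p))
    return total / len(hard_labels)

# A decider that agrees with the labels should score a lower loss than one that disagrees.
good = decider_bce_loss([0.1, 0.9, 0.2], [0, 1, 0])
bad = decider_bce_loss([0.9, 0.1, 0.8], [0, 1, 0])
print(good < bad)  # True
```

Because the backbone is frozen in this stage, only the decider's parameters would receive gradients from such a loss.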

After two-stage training, the model can automatically decide when to perform latent reasoning iterations.

Understand the Code

Code Structure

TaH/
├── tah/                          # Core package
│   ├── model/                    # Core model components
│   ├── train/                    # Training components
│   ├── evaluate/                 # Evaluation utilities
│   └── utils/                    # General utilities
├── bash/                         # Bash scripts for training and evaluation
├── script/                       # Execution scripts
│   ├── analysis/                 # Analysis scripts
│   ├── evaluation/               # Evaluation scripts
│   ├── preparation/              # Preparation for training
│   │   ├── label.py              # Data labeling (generate mismatch labels)
│   │   └── prune.py              # Model pruning
│   ├── playground/               # Some examples
│   └── recipes/                  # Configuration files
│       ├── qwen3_0.6/            # Qwen3-0.6B-Base configs
│       ├── qwen3_1.7/            # Qwen3-1.7B-Base configs
│       └── accelerate_configs/   # Distributed training configs
└── pyproject.toml                # Project configuration

Future Work

  • Support more inference backends (e.g., SGLang)
  • Optimize iteration decision strategies
  • Integrate TaH with online distillation or RL
  • Support training for larger models

Related Projects

Explore more efficient LLM projects from us:

  • R2R: Token-level routing for reasoning LLMs
  • C2C: Communicate through KV-Cache between LLMs
  • FrF: Efficient video token reduction for LVLMs
  • MoA: Mixture of sparse attention for LLMs

About

Official implementation of paper "Think-at-Hard: Selective Latent Iterations to Improve Reasoning Language Models"
