Feel free to star the repo or cite the paper if you find it interesting.
@article{fu2025tah,
  title={Think-at-Hard: Selective Latent Iterations to Improve Reasoning Language Models},
  author={Tianyu Fu and Yichen You and Zekai Chen and Guohao Dai and Huazhong Yang and Yu Wang},
  journal={arXiv preprint arXiv:2510.08577},
  year={2025},
}
- [2025/11] We released the TaH-plus-1.7B checkpoint. The model is fine-tuned from Qwen3-1.7B-Base using 100K samples from the OpenR1 dataset and is capable of QA, math, and coding.
- [2025/11] Our paper was featured as the #2 Paper of the Day on Hugging Face Daily Papers.
Create a new environment:
conda create -n tah python=3.10
conda activate tah
Install the package:
pip install -e .
For training and evaluation, install additional dependencies:
pip install -e ".[training,evaluation]"
For code generation evaluation, install evalplus.
Run the inference example:
python script/playground/inference_example.py
This script demonstrates TaH's selective latent iteration mechanism, with color-coded output showing the iteration count for each token.
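As a rough illustration of what per-token color coding of this kind can look like, the sketch below wraps each token in an ANSI color chosen by its iteration count. The function name `colorize_tokens`, the palette, and the sample tokens are assumptions made for illustration; they are not taken from inference_example.py.

```python
# Hypothetical sketch: not the code in script/playground/inference_example.py.
RESET = "\033[0m"
# Assumed palette: no color for 1 iteration, yellow for 2 iterations.
ITER_COLORS = {1: "", 2: "\033[33m"}


def colorize_tokens(tokens, iter_counts):
    """Return one string in which each token is colored by its iteration count."""
    if len(tokens) != len(iter_counts):
        raise ValueError("tokens and iter_counts must have the same length")
    pieces = []
    for token, count in zip(tokens, iter_counts):
        color = ITER_COLORS.get(count, "\033[31m")  # red for any deeper count
        pieces.append(f"{color}{token}{RESET}" if color else token)
    return "".join(pieces)


if __name__ == "__main__":
    # Made-up tokens and counts, purely to show the rendering.
    print(colorize_tokens(["So", " the", " answer", " is", " 42"], [1, 1, 2, 1, 2]))
```

Tokens that needed a second iteration stand out in yellow, which makes it easy to see where extra computation was spent.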
python script/evaluation/eval.py \
    --eval_config ./script/recipes/qwen3_1.7/eval_tah.yaml \
    --model_path nics-efc/TaH-plus-1.7B \
    --dataset_name gsm8k \
    --backend tah \
    --job_nums 8 \
    --tp_size_per_job 1
Key parameters:
- --eval_config: Path to evaluation config file
- --model_path: Path to the model
- --dataset_name: Dataset name (supports gsm8k, math500, aime24, etc.; detailed configs can be found in tah/evaluate/eval_configs/dataset_configs.json)
- --backend: Inference backend (tah for TaH)
- --job_nums: Number of parallel jobs
- --tp_size_per_job: Tensor parallel size per job
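To evaluate several datasets in one go, the same command can be assembled in a loop. The sketch below only builds and prints the commands, using the flags listed above; the helper name `build_eval_command` and the choice of datasets are assumptions for illustration.

```python
import shlex

# Dataset names taken from the parameter list above.
DATASETS = ["gsm8k", "math500", "aime24"]


def build_eval_command(dataset_name,
                       model_path="nics-efc/TaH-plus-1.7B",
                       eval_config="./script/recipes/qwen3_1.7/eval_tah.yaml",
                       backend="tah", job_nums=8, tp_size_per_job=1):
    """Return the eval.py invocation as an argument list (hypothetical helper)."""
    return [
        "python", "script/evaluation/eval.py",
        "--eval_config", eval_config,
        "--model_path", model_path,
        "--dataset_name", dataset_name,
        "--backend", backend,
        "--job_nums", str(job_nums),
        "--tp_size_per_job", str(tp_size_per_job),
    ]


if __name__ == "__main__":
    for name in DATASETS:
        # Print only; pass the list to subprocess.run(cmd, check=True) to execute.
        print(shlex.join(build_eval_command(name)))
```

Keeping the command as a list avoids shell-quoting problems if a path contains spaces.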
python script/evaluation/eval.py \
    --eval_config ./script/recipes/qwen3_1.7/eval_base.yaml \
    --model_path nics-efc/Standard-1.7B \
    --dataset_name gsm8k \
    --backend hf \
    --job_nums 8 \
    --tp_size_per_job 1
Similar to TaH evaluation, but using:
- --backend hf or --backend sglang
Training a TaH model consists of three stages:
1. Prepare training data
Use a reference model to generate hard token labels for the training and validation data:
### step 0
# download the default subset of OpenR1-Math-220k
python script/preparation/download.py
# filter and split
python script/preparation/filter_split.py
# label the hard tokens
python script/preparation/label.py \
    --num_gpu 8 \
    --dataset_path ./data/initial_data/openr1-math/train.jsonl \
    --test_model_list Qwen/Qwen3-1.7B \
    --output_path ./data/processed_data/openr1-math/1_7/train \
    --max_input_length 10000
python script/preparation/label.py \
    --num_gpu 8 \
    --dataset_path ./data/initial_data/openr1-math/eval.jsonl \
    --test_model_list Qwen/Qwen3-1.7B \
    --output_path ./data/processed_data/openr1-math/1_7/eval \
    --max_input_length 10000
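The labeling idea can be pictured with a small sketch. It assumes a token counts as hard when the reference model's top-1 prediction differs from the ground-truth next token; that reading is inferred from the "mismatch" wording used elsewhere in this README, not from label.py itself. The function name and the token ids are made up.

```python
def label_mismatches(predicted_ids, target_ids):
    """Return 1 (hard) where the prediction differs from the target, else 0.

    Assumption: "hard" means a top-1 mismatch from the reference model.
    """
    if len(predicted_ids) != len(target_ids):
        raise ValueError("predicted_ids and target_ids must have the same length")
    return [int(p != t) for p, t in zip(predicted_ids, target_ids)]


if __name__ == "__main__":
    # Made-up token ids: the reference model misses the second and fourth tokens.
    predicted = [11, 42, 7, 90]
    target = [11, 40, 7, 93]
    print(label_mismatches(predicted, target))
```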
2. (Optional) Prepare pruned model
For the TaH version, prune one layer from the base model to match the parameter count of the standard baseline (skip this step for the TaH+ version):
### step 0
python script/preparation/prune.py \
    --model Qwen/Qwen3-1.7B-Base \
    --dataset ./data/processed_data/openr1-math/1_7/eval \
    --output ./model/qwen3_1.7_base_pruned \
    --num_prune 1
The first stage uses fixed iteration labels for training:
### step 1
python -m accelerate.commands.launch \
    --config_file ./script/recipes/accelerate_configs/zero2.yaml \
    --num_processes 8 \
    ./script/train/SFT_TaH.py \
    --config ./script/recipes/qwen3_1.7/sft_tah_step1.yaml
Key configurations in Step1 (sft_tah_step1.yaml):
- max_iter: 2: Maximum number of iterations
- iter_decider: "FixedLabelIterDecider": Use fixed labels to decide iterations
- iter_label_generator: "FixedIterLabelGenerator": Generate labels from mismatch field in data
- input_updater: "AdditiveUpdater": Use additive updater for input updates
- adapter: "lora": Use LoRA adapter for deeper iteration
- train_loss: "NextTokenPredLoss": Next token prediction loss
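One plausible reading of how fixed labels drive iterations in this stage is sketched below: tokens flagged as mismatches get `max_iter` iterations and all other tokens get a single iteration. This mapping is an assumption for illustration, and `fixed_iter_counts` is a hypothetical name, not the implementation of FixedLabelIterDecider.

```python
def fixed_iter_counts(mismatch, max_iter=2):
    """Map per-token mismatch labels (0/1) to iteration counts.

    Assumption: hard tokens use max_iter iterations, easy tokens use one.
    """
    if max_iter < 1:
        raise ValueError("max_iter must be at least 1")
    return [max_iter if flag else 1 for flag in mismatch]


if __name__ == "__main__":
    labels = [0, 1, 0, 0, 1]  # made-up mismatch labels
    counts = fixed_iter_counts(labels, max_iter=2)
    print(counts)
    # Average iterations per token gives a rough sense of the extra compute.
    print(sum(counts) / len(counts))
```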
The second stage trains the iteration decider:
### step 2
python -m accelerate.commands.launch \
    --config_file ./script/recipes/accelerate_configs/zero2.yaml \
    --num_processes 8 \
    ./script/train/SFT_TaH.py \
    --config ./script/recipes/qwen3_1.7/sft_tah_step2.yaml
Key configurations in Step2 (sft_tah_step2.yaml):
- tah_model_path: Load the model trained in Step1
- iter_decider: "MLPIterDecider": Use MLP decider to automatically determine iterations
- train_loss: "IterDeciderLoss": Iteration decider loss function
- freeze_component: [model.simple_base_model]: Freeze model backbone
After two-stage training, the model can automatically decide when to perform latent reasoning iterations.
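The resulting control flow can be sketched with a stand-in decider. In the snippet below, `wants_more` is a hypothetical callable that plays the role of the learned decider, and the toy rule inside it (digits are hard) exists only to make the example runnable; a trained decider would presumably score model hidden states instead.

```python
from typing import Callable, List, Tuple


def selective_iterations(
    tokens: List[str],
    wants_more: Callable[[str, int], bool],
    max_iter: int = 2,
) -> List[Tuple[str, int]]:
    """Return (token, iterations_used) pairs.

    wants_more(token, iteration) stands in for a learned decider: it returns
    True when another latent iteration should be run for this token.
    """
    results = []
    for token in tokens:
        iteration = 1  # every token gets at least one pass
        while iteration < max_iter and wants_more(token, iteration):
            iteration += 1
        results.append((token, iteration))
    return results


if __name__ == "__main__":
    # Toy decider (assumption): treat numeric tokens as hard.
    def toy_decider(token: str, iteration: int) -> bool:
        return token.strip().isdigit()

    print(selective_iterations(["The", " answer", " is", " 42"], toy_decider))
```

Capping the loop at `max_iter` keeps the per-token cost bounded even if the decider always asks for more.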
TaH/
├── tah/                          # Core package
│   ├── model/                    # Core model components
│   ├── train/                    # Training components
│   ├── evaluate/                 # Evaluation utilities
│   └── utils/                    # General utilities
├── bash/                         # Bash scripts for training and evaluation
├── script/                       # Execution scripts
│   ├── analysis/                 # Analysis scripts
│   ├── evaluation/               # Evaluation scripts
│   ├── preparation/              # Preparation for training
│   │   ├── label.py              # Data labeling (generate mismatch labels)
│   │   └── prune.py              # Model pruning
│   ├── playground/               # Some examples
│   └── recipes/                  # Configuration files
│       ├── qwen3_0.6/            # Qwen3-0.6B-Base configs
│       ├── qwen3_1.7/            # Qwen3-1.7B-Base configs
│       └── accelerate_configs/   # Distributed training configs
└── pyproject.toml                # Project configuration
- Support more inference backends (e.g., SGLang)
- Optimize iteration decision strategies
- Integrate TaH with online distillation or RL
- Support training for larger models
Explore more efficient LLM projects from us: