Commit b1cde79

add sc2 code

1 parent b8d318c commit b1cde79

2 files changed: +319 -1 lines changed

README.md

Lines changed: 176 additions & 1 deletion
# StarCoder 2

<p align="center"><a href="https://huggingface.co/bigcode">[πŸ€— Models]</a> | <a href="">[Paper]</a> | <a href="https://marketplace.visualstudio.com/items?itemName=HuggingFace.huggingface-vscode">[VSCode]</a>
</p>

StarCoder2 is a family of code generation models (3B, 7B, and 15B) trained on 600+ programming languages from [The Stack v2]() and on natural language text such as Wikipedia, arXiv, and GitHub issues. The models use Grouped Query Attention, a context window of 16,384 tokens, and sliding window attention of 4,096 tokens. The 3B and 7B models were trained on 3+ trillion tokens, while the 15B model was trained on 4+ trillion tokens.
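
As a quick check, these architectural settings can be read off the model configuration once you have access to the checkpoints; a minimal sketch (the attribute names follow the `transformers` StarCoder2 config and are worth double-checking against your installed version):
```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("bigcode/starcoder2-3b")
# e.g. context length, sliding-window size, and grouped-query attention heads
print(config.max_position_embeddings, config.sliding_window, config.num_key_value_heads)
```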

# Disclaimer

Before you can use the models, go to `hf.co/bigcode/starcoder2-15b` and accept the agreement, and make sure you are logged in to the Hugging Face Hub:
```bash
huggingface-cli login
```
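
If you prefer to authenticate from Python instead of the CLI, here is a minimal sketch using the `huggingface_hub` client (installed as a dependency of `transformers`); the token is entered interactively or can be passed explicitly:
```python
from huggingface_hub import login

# Prompts for an access token from https://huggingface.co/settings/tokens
# (or pass it directly with login(token="hf_..."))
login()
```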

# Table of Contents
1. [Quickstart](#quickstart)
    - [Installation](#installation)
    - [Model usage and memory footprint](#model-usage-and-memory-footprint)
    - [Text-generation-inference code](#text-generation-inference)
2. [Fine-tuning](#fine-tuning)
    - [Setup](#setup)
    - [Training](#training)
3. [Evaluation](#evaluation)

# Quickstart
StarCoder2 models are intended for code completion; they are not instruction-tuned models, so prompts like "Write a function that computes the square root." do not work well.
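
As a small illustration (the prompts below are made up), give the model code context to complete rather than a natural-language command:
```python
# Completion-style prompt: the model continues the code
prompt = "def square_root(x):"

# Instruction-style prompt: tends to produce poor results with base StarCoder2
# prompt = "Write a function that computes the square root."
```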

## Installation
First, install all the libraries listed in `requirements.txt`:
```bash
pip install -r requirements.txt
# export your HF token, found here: https://huggingface.co/settings/account
export HF_TOKEN=xxx
```

## Model usage and memory footprint
Here are some examples of how to load the model and generate code. Make sure you have installed `transformers` from source (this should be the case if you used `requirements.txt`). We also report the memory footprint of the largest model, `StarCoder2-15B`, for each setup.

### Running the model on CPU / single GPU / multiple GPUs
```python
# pip install git+https://github.com/huggingface/transformers.git # TODO: merge PR to main
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "bigcode/starcoder2-15b"
device = "cuda"  # for GPU usage or "cpu" for CPU usage

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# to use multiple GPUs, do `model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto")`
model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)

inputs = tokenizer.encode("def print_hello_world():", return_tensors="pt").to(device)
outputs = model.generate(inputs)
print(tokenizer.decode(outputs[0]))
```

### Running the model on a GPU using different precisions

* _Using `torch.bfloat16`_

```python
# pip install accelerate
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

checkpoint = "bigcode/starcoder2-15b"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# for fp16 use `torch_dtype=torch.float16` instead
model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto", torch_dtype=torch.bfloat16)

inputs = tokenizer.encode("def print_hello_world():", return_tensors="pt").to("cuda")
outputs = model.generate(inputs)
print(tokenizer.decode(outputs[0]))
```
```python
>>> print(f"Memory footprint: {model.get_memory_footprint() / 1e6:.2f} MB")
Memory footprint: 32251.33 MB
```

#### Quantized Versions through `bitsandbytes`
* _Using 8-bit precision (int8)_

```python
# pip install bitsandbytes accelerate
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# to use 4-bit, use `load_in_4bit=True` instead
quantization_config = BitsAndBytesConfig(load_in_8bit=True)

checkpoint = "bigcode/starcoder2-15b"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint, quantization_config=quantization_config)

inputs = tokenizer.encode("def print_hello_world():", return_tensors="pt").to("cuda")
outputs = model.generate(inputs)
print(tokenizer.decode(outputs[0]))
```
```python
>>> print(f"Memory footprint: {model.get_memory_footprint() / 1e6:.2f} MB")
# load_in_8bit
Memory footprint: 16900.18 MB
# load_in_4bit
>>> print(f"Memory footprint: {model.get_memory_footprint() / 1e6:.2f} MB")
Memory footprint: 9224.60 MB
```
You can also use `pipeline` for generation:
```python
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

checkpoint = "bigcode/starcoder2-15b"

model = AutoModelForCausalLM.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0)
print(pipe("def hello():"))
```
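
StarCoder2 was also trained with fill-in-the-middle (FIM) style prompting; a minimal sketch, assuming the tokenizer defines the StarCoder-style sentinel tokens `<fim_prefix>`, `<fim_suffix>`, and `<fim_middle>` (these are assumptions; check the tokenizer's added tokens for your checkpoint):
```python
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "bigcode/starcoder2-15b"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint, device_map="auto")

# Ask the model to fill in the function body between a prefix and a suffix.
prompt = "<fim_prefix>def fibonacci(n):\n    <fim_suffix>\n    return result<fim_middle>"
inputs = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0]))
```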

## Text-generation-inference

```bash
docker run -p 8080:80 -v $PWD/data:/data -e HUGGING_FACE_HUB_TOKEN=<YOUR BIGCODE ENABLED TOKEN> -d ghcr.io/huggingface/text-generation-inference:latest --model-id bigcode/starcoder2-15b --max-total-tokens 8192
```
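
Once the container is running, you can query it over HTTP; a minimal sketch against TGI's `/generate` endpoint (host, port, and generation parameters are illustrative):
```bash
curl 127.0.0.1:8080/generate \
    -X POST \
    -H 'Content-Type: application/json' \
    -d '{"inputs": "def print_hello_world():", "parameters": {"max_new_tokens": 64}}'
```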
For more details, see [here](https://github.com/huggingface/text-generation-inference).

# Fine-tuning

Here, we showcase how you can fine-tune StarCoder2 models.

## Setup

Install PyTorch ([see documentation](https://pytorch.org/)); for example, the following command works with CUDA 12.1:
```bash
conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
```

Install the requirements (this installs `transformers` from source to support the StarCoder2 architecture):
```bash
pip install -r requirements.txt
```

Before you run any of the scripts, make sure you are logged in to `wandb` and the Hugging Face Hub so that checkpoints can be pushed:
```bash
wandb login
huggingface-cli login
```
With the setup done, clone the repository and change into the corresponding directory.
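
A minimal sketch of that step, assuming the code lives in the `bigcode-project/starcoder2` repository (adjust the URL if you work from a fork):
```bash
git clone https://github.com/bigcode-project/starcoder2.git
cd starcoder2
```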

## Training
To fine-tune efficiently at low cost, we use the [PEFT](https://github.com/huggingface/peft) library for Low-Rank Adaptation (LoRA) training and [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) for 4-bit quantization. We also use the `SFTTrainer` from [TRL](https://github.com/huggingface/trl).

For this example, we will fine-tune StarCoder2-3b on the `Rust` subset of [the-stack-smol](https://huggingface.co/datasets/bigcode/the-stack-smol). This is just for illustration purposes; for a larger and cleaner dataset of Rust code, you can use [The Stack dedup](https://huggingface.co/datasets/bigcode/the-stack-dedup).

To launch the training:
```bash
accelerate launch finetune.py \
        --model_id "bigcode/starcoder2-3b" \
        --dataset_name "bigcode/the-stack-smol" \
        --subset "data/rust" \
        --dataset_text_field "content" \
        --split "train" \
        --max_seq_length 1024 \
        --max_steps 10000 \
        --micro_batch_size 1 \
        --gradient_accumulation_steps 8 \
        --learning_rate 2e-5 \
        --warmup_steps 20 \
        --num_proc "$(nproc)"
```

If you want to fine-tune on other text datasets, change the `dataset_text_field` argument to the name of the column that contains the code/text you want to train on.
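
After training, `finetune.py` saves the LoRA adapter under `<output_dir>/final_checkpoint`. A minimal sketch of loading it back for inference with `peft` (paths assume the default `--output_dir outputs`; merging is optional):
```python
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

base = "bigcode/starcoder2-3b"
tokenizer = AutoTokenizer.from_pretrained(base)
model = AutoModelForCausalLM.from_pretrained(base, torch_dtype=torch.bfloat16, device_map="auto")

# Attach the fine-tuned LoRA weights, then merge them into the base weights
# so generation runs without the PEFT wrapper.
model = PeftModel.from_pretrained(model, "outputs/final_checkpoint")
model = model.merge_and_unload()

inputs = tokenizer.encode("fn main() {", return_tensors="pt").to(model.device)
outputs = model.generate(inputs, max_new_tokens=64)
print(tokenizer.decode(outputs[0]))
```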

# Evaluation
To evaluate StarCoder2 and its derivatives, you can use the [BigCode-Evaluation-Harness](https://github.com/bigcode-project/bigcode-evaluation-harness) for evaluating Code LLMs.
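
For example, here is a sketch of a HumanEval run with the harness (the flags follow its README and may change between versions; `--allow_code_execution` runs generated code, so use a sandboxed environment):
```bash
git clone https://github.com/bigcode-project/bigcode-evaluation-harness.git
cd bigcode-evaluation-harness
pip install -e .

accelerate launch main.py \
    --model bigcode/starcoder2-3b \
    --tasks humaneval \
    --max_length_generation 512 \
    --temperature 0.2 \
    --n_samples 50 \
    --batch_size 10 \
    --allow_code_execution
```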

finetune.py

Lines changed: 143 additions & 0 deletions

```python
# Code adapted from https://github.com/huggingface/trl/blob/main/examples/research_projects/stack_llama/scripts/supervised_finetuning.py
# and https://huggingface.co/blog/gemma-peft
import argparse
import multiprocessing
import os

import torch
import transformers
from accelerate import PartialState
from datasets import load_dataset
from peft import LoraConfig
from transformers import (
    AutoModelForCausalLM,
    BitsAndBytesConfig,
    logging,
    set_seed,
)
from trl import SFTTrainer


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_id", type=str, default="bigcode/starcoder2-3b")
    parser.add_argument("--dataset_name", type=str, default="the-stack-smol")
    parser.add_argument("--subset", type=str, default="data/rust")
    parser.add_argument("--split", type=str, default="train")
    parser.add_argument("--dataset_text_field", type=str, default="content")

    parser.add_argument("--max_seq_length", type=int, default=1024)
    parser.add_argument("--max_steps", type=int, default=1000)
    parser.add_argument("--micro_batch_size", type=int, default=1)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--bf16", type=bool, default=True)

    parser.add_argument("--attention_dropout", type=float, default=0.1)
    parser.add_argument("--learning_rate", type=float, default=2e-4)
    parser.add_argument("--lr_scheduler_type", type=str, default="cosine")
    parser.add_argument("--warmup_steps", type=int, default=100)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--output_dir", type=str, default="outputs")
    parser.add_argument("--num_proc", type=int, default=None)
    return parser.parse_args()


def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )


def main(args):
    # config
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    lora_config = LoraConfig(
        r=8,
        target_modules=[
            "q_proj",
            "o_proj",
            "k_proj",
            "v_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],
        task_type="CAUSAL_LM",
    )

    # load model and dataset
    token = os.environ.get("HF_TOKEN", None)
    model = AutoModelForCausalLM.from_pretrained(
        args.model_id,
        quantization_config=bnb_config,
        device_map={"": PartialState().process_index},
        token=token,
        attention_dropout=args.attention_dropout,
    )
    print_trainable_parameters(model)

    data = load_dataset(
        args.dataset_name,
        data_dir=args.subset,
        split=args.split,
        token=token,
        num_proc=args.num_proc if args.num_proc else multiprocessing.cpu_count(),
    )

    # setup the trainer
    trainer = SFTTrainer(
        model=model,
        train_dataset=data,
        max_seq_length=args.max_seq_length,
        args=transformers.TrainingArguments(
            per_device_train_batch_size=args.micro_batch_size,
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            warmup_steps=args.warmup_steps,
            max_steps=args.max_steps,
            learning_rate=args.learning_rate,
            lr_scheduler_type=args.lr_scheduler_type,
            weight_decay=args.weight_decay,
            bf16=args.bf16,
            logging_strategy="steps",
            logging_steps=10,
            output_dir=args.output_dir,
            optim="paged_adamw_8bit",
            seed=args.seed,
            run_name=f"train-{args.model_id.split('/')[-1]}",
            report_to="wandb",
        ),
        peft_config=lora_config,
        dataset_text_field=args.dataset_text_field,
    )

    # launch
    print("Training...")
    trainer.train()

    print("Saving the last checkpoint of the model")
    model.save_pretrained(os.path.join(args.output_dir, "final_checkpoint/"))
    print("Training Done! πŸ’₯")


if __name__ == "__main__":
    args = get_args()
    set_seed(args.seed)
    os.makedirs(args.output_dir, exist_ok=True)

    logging.set_verbosity_error()

    main(args)
```
