
Commit 2478982

Add MQA (#4)
incorporate benchmarks
1 parent 1e59b28 commit 2478982

File tree

12 files changed: +272 -52 lines changed


Makefile

Lines changed: 68 additions & 0 deletions

batch_size := 1

install-mqa-transformers:
	git clone https://github.com/bigcode-project/transformers.git; \
	cd transformers; \
	git checkout mayank/multi_query; \
	pip install .; \
	cd ..; \
	rm -rf transformers;

# BLOOM AliBi
hf-1b-bloom-fp32:
	python src/main.py --hidden_size 2048 --n_head 16 --n_layer 24 --pipeline_class HF_GPU_Pipeline --model_class BLOOM --dtype float32 --batch_size ${batch_size}

hf-1b-bloom-bf16:
	python src/main.py --hidden_size 2048 --n_head 16 --n_layer 24 --pipeline_class HF_GPU_Pipeline --model_class BLOOM --dtype bfloat16 --batch_size ${batch_size}

hf-1b-bloom-int8:
	python src/main.py --hidden_size 2048 --n_head 16 --n_layer 24 --pipeline_class HF_GPU_Pipeline --model_class BLOOM --dtype int8 --batch_size ${batch_size}

ds-inference-1b-bloom-fp16:
	deepspeed --num_gpus 1 src/main.py --hidden_size 2048 --n_head 16 --n_layer 24 --pipeline_class DS_Inference_Pipeline --model_class BLOOM --batch_size ${batch_size}

# GPT2 MHA
hf-1b-GPT2-mha-fp32:
	python src/main.py --hidden_size 2048 --n_head 16 --n_layer 24 --pipeline_class HF_GPU_Pipeline --model_class GPT2 --n_positions 2048 --attention_type 1 --dtype float32 --batch_size ${batch_size}

hf-1b-GPT2-mha-bf16:
	python src/main.py --hidden_size 2048 --n_head 16 --n_layer 24 --pipeline_class HF_GPU_Pipeline --model_class GPT2 --n_positions 2048 --attention_type 1 --dtype bfloat16 --batch_size ${batch_size}

hf-1b-GPT2-mha-int8:
	python src/main.py --hidden_size 2048 --n_head 16 --n_layer 24 --pipeline_class HF_GPU_Pipeline --model_class GPT2 --n_positions 2048 --attention_type 1 --dtype int8 --batch_size ${batch_size}

ds-inference-1b-GPT2-mha-fp16:
	deepspeed --num_gpus 1 src/main.py --hidden_size 2048 --n_head 16 --n_layer 24 --pipeline_class DS_Inference_Pipeline --model_class GPT2 --n_positions 2048 --attention_type 1 --batch_size ${batch_size}

# GPT2 MQA
hf-1b-GPT2-mqa-fp32:
	python src/main.py --hidden_size 2048 --n_head 16 --n_layer 24 --pipeline_class HF_GPU_Pipeline --model_class GPT2 --n_positions 2048 --attention_type 2 --dtype float32 --batch_size ${batch_size}

hf-1b-GPT2-mqa-bf16:
	python src/main.py --hidden_size 2048 --n_head 16 --n_layer 24 --pipeline_class HF_GPU_Pipeline --model_class GPT2 --n_positions 2048 --attention_type 2 --dtype bfloat16 --batch_size ${batch_size}

hf-1b-GPT2-mqa-int8:
	python src/main.py --hidden_size 2048 --n_head 16 --n_layer 24 --pipeline_class HF_GPU_Pipeline --model_class GPT2 --n_positions 2048 --attention_type 2 --dtype int8 --batch_size ${batch_size}

ds-inference-1b-GPT2-mqa-fp16:
	deepspeed --num_gpus 1 src/main.py --hidden_size 2048 --n_head 16 --n_layer 24 --pipeline_class DS_Inference_Pipeline --model_class GPT2 --n_positions 2048 --attention_type 2 --batch_size ${batch_size}

# GPT2 MQA1
hf-1b-GPT2-mqa1-fp32:
	python src/main.py --hidden_size 2048 --n_head 16 --n_layer 24 --pipeline_class HF_GPU_Pipeline --model_class GPT2 --n_positions 2048 --attention_type 3 --dtype float32 --batch_size ${batch_size}

hf-1b-GPT2-mqa1-bf16:
	python src/main.py --hidden_size 2048 --n_head 16 --n_layer 24 --pipeline_class HF_GPU_Pipeline --model_class GPT2 --n_positions 2048 --attention_type 3 --dtype bfloat16 --batch_size ${batch_size}

hf-1b-GPT2-mqa1-int8:
	python src/main.py --hidden_size 2048 --n_head 16 --n_layer 24 --pipeline_class HF_GPU_Pipeline --model_class GPT2 --n_positions 2048 --attention_type 3 --dtype int8 --batch_size ${batch_size}

ds-inference-1b-GPT2-mqa1-fp16:
	deepspeed --num_gpus 1 src/main.py --hidden_size 2048 --n_head 16 --n_layer 24 --pipeline_class DS_Inference_Pipeline --model_class GPT2 --n_positions 2048 --attention_type 3 --batch_size ${batch_size}

# Input length experiments
hf-1b-GPT2-mqa1-int8-input-length:
	python src/main.py --hidden_size 2048 --n_head 16 --n_layer 24 --pipeline_class HF_GPU_Pipeline --model_class GPT2 --n_positions 2048 --attention_type 3 --dtype int8 --batch_size ${batch_size} --max_input_length ${max_input_length}

hf-1b-GPT2-mha-int8-input-length:
	python src/main.py --hidden_size 2048 --n_head 16 --n_layer 24 --pipeline_class HF_GPU_Pipeline --model_class GPT2 --n_positions 2048 --attention_type 1 --dtype int8 --batch_size ${batch_size} --max_input_length ${max_input_length}
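
Note: the targets above differ only in --model_class, --attention_type (1 = multi-head, 2 = multi-query, 3 = the MQA1 variant, judging by the target groupings) and --dtype; batch_size and, for the input-length targets, max_input_length are supplied as make variables, e.g. `make hf-1b-GPT2-mqa-bf16 batch_size=32`. As a purely illustrative sketch (not part of this commit), the same sweep that the run_batch_size.sh script below performs in shell could be driven from Python:

```python
# Hypothetical helper, not in the repo: sweep batch sizes for one Makefile target.
import subprocess

target = "hf-1b-GPT2-mqa-bf16"  # any benchmark target defined above
for batch_size in (1, 2, 4, 8, 16, 32, 64):
    # make variables are overridden on the command line: `make <target> batch_size=N`
    subprocess.run(["make", target, f"batch_size={batch_size}"], check=True)
```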

README.md

Lines changed: 122 additions & 1 deletion

# bigcode-inference-benchmark
A100 80GB

## BLOOM
```python
hidden_size = 2048
n_head = 16
n_layer = 24
total_params = 1311535104
```

Throughput (tokens/sec | msec/token)
| batch_size | HF (fp32) | HF (bf16) | HF (int8) | DS-inference (fp16) |
|:----------:|:---------------:|:---------------:|:---------------:|:-------------------:|
| 1 | 51.59 \| 19.38 | 47.46 \| 21.07 | 16.53 \| 60.49 | 61.61 \| 16.23 |
| 2 | 103.92 \| 9.62 | 96.88 \| 10.32 | 33.79 \| 29.60 | 121.55 \| 8.23 |
| 4 | 211.96 \| 4.72 | 193.72 \| 5.16 | 67.38 \| 14.84 | 240.06 \| 4.17 |
| 8 | 411.79 \| 2.43 | 370.67 \| 2.70 | 134.34 \| 7.44 | 492.42 \| 2.03 |
| 16 | 804.55 \| 1.24 | 781.29 \| 1.28 | 275.69 \| 3.63 | 970.59 \| 1.03 |
| 32 | 1574.68 \| 0.64 | 1539.19 \| 0.65 | 537.14 \| 1.86 | 1999.04 \| 0.50 |
| 64 | 2712.46 \| 0.37 | 3038.01 \| 0.33 | 1070.50 \| 0.93 | 3971.09 \| 0.25 |
| 128 | 2974.36 \| 0.34 | 5795.97 \| 0.17 | 2055.34 \| 0.49 | 7514.59 \| 0.13 |
| 256 | 3695.44 \| 0.27 | 8216.27 \| 0.12 | 3523.77 \| 0.28 | 10226.50 \| 0.10 |
| 384 | 3591.13 \| 0.28 | 9328.18 \| 0.11 | 4585.33 \| 0.22 | 11094.27 \| 0.09 |
| 512 | 3708.54 \| 0.27 | 9446.34 \| 0.11 | 5416.48 \| 0.18 | 11390.85 \| 0.09 |
| 640 | 3859.43 \| 0.26 | 9572.53 \| 0.10 | 6113.65 \| 0.16 | 11625.71 \| 0.09 |
| 768 | 3804.82 \| 0.26 | 9464.75 \| 0.11 | 6582.52 \| 0.15 | 11814.31 \| 0.08 |
| 896 | 3652.42 \| 0.27 | 9482.11 \| 0.11 | 7111.08 \| 0.14 | 11744.38 \| 0.09 |
| 1024 | oom | 9710.46 \| 0.10 | 7486.36 \| 0.13 | 11534.95 \| 0.09 |
| 1152 | oom | 9712.39 \| 0.10 | 7544.99 \| 0.13 | oom |
| 1280 | oom | 9667.19 \| 0.10 | 7858.91 \| 0.13 | oom |
| 1408 | oom | 9771.91 \| 0.10 | 8116.30 \| 0.12 | oom |
| 1536 | oom | 9744.56 \| 0.10 | 8201.28 \| 0.12 | oom |
| 1664 | oom | 9719.82 \| 0.10 | 8227.56 \| 0.12 | oom |
| 1792 | oom | 9690.61 \| 0.10 | 8344.36 \| 0.12 | oom |
| 1920 | oom | oom | oom | oom |

Latency (sec)
| batch_size | HF (fp32) | HF (bf16) | HF (int8) | DS-inference (fp16) |
|:----------:|:---------:|:---------:|:---------:|:-------------------:|
| 1 | 1.94 | 2.11 | 6.05 | 1.62 |
| 2 | 1.92 | 2.06 | 5.92 | 1.65 |
| 4 | 1.89 | 2.06 | 5.94 | 1.67 |
| 8 | 1.94 | 2.16 | 5.96 | 1.62 |
| 16 | 1.99 | 2.05 | 5.80 | 1.65 |
| 32 | 2.03 | 2.08 | 5.96 | 1.60 |
| 64 | 2.36 | 2.11 | 5.98 | 1.61 |
| 128 | 4.30 | 2.21 | 6.23 | 1.70 |
| 256 | 6.93 | 3.12 | 7.26 | 2.50 |
| 384 | 10.69 | 4.12 | 8.37 | 3.46 |
| 512 | 14.82 | 5.42 | 9.45 | 4.49 |
| 640 | 19.85 | 6.69 | 10.47 | 5.51 |
| 768 | 20.18 | 8.11 | 11.67 | 6.50 |
| 896 | 24.53 | 9.45 | 12.60 | 7.63 |
| 1024 | oom | 10.55 | 13.68 | 8.88 |
| 1152 | oom | 11.86 | 15.27 | oom |
| 1280 | oom | 13.24 | 16.29 | oom |
| 1408 | oom | 14.41 | 17.35 | oom |
| 1536 | oom | 15.76 | 18.73 | oom |
| 1664 | oom | 17.12 | 20.22 | oom |
| 1792 | oom | 18.49 | 21.48 | oom |
| 1920 | oom | oom | oom | oom |

## GPT2 Multi-Head Attention
```python
hidden_size = 2048
n_head = 16
n_layer = 24
total_params = 1315725312
```

Throughput (tokens/sec | msec/token)
| batch_size | HF (fp32) | HF (bf16) | HF (int8) | DS-inference (fp16) |
|:----------:|:---------------:|:----------------:|:----------------:|:-------------------:|
| 1 | 43.11 \| 23.20 | 40.69 \| 24.57 | 32.29 \| 30.97 | 122.76 \| 8.15 |
| 2 | 80.76 \| 12.38 | 80.87 \| 12.37 | 63.54 \| 15.74 | 247.85 \| 4.03 |
| 4 | 160.38 \| 6.24 | 154.98 \| 6.45 | 131.00 \| 7.63 | 503.52 \| 1.99 |
| 8 | 328.62 \| 3.04 | 332.90 \| 3.00 | 260.16 \| 3.84 | 1022.20 \| 0.98 |
| 16 | 662.08 \| 1.51 | 669.27 \| 1.49 | 523.29 \| 1.91 | 2027.35 \| 0.49 |
| 32 | 1314.92 \| 0.76 | 1287.95 \| 0.78 | 1055.57 \| 0.95 | 4231.82 \| 0.24 |
| 64 | 2118.17 \| 0.47 | 2487.35 \| 0.40 | 1969.26 \| 0.51 | 8311.39 \| 0.12 |
| 128 | 2860.26 \| 0.35 | 4268.99 \| 0.23 | 3581.49 \| 0.28 | 15879.15 \| 0.06 |
| 256 | 3487.86 \| 0.29 | 6917.01 \| 0.14 | 6132.47 \| 0.16 | 21635.49 \| 0.05 |
| 384 | 3794.16 \| 0.26 | 8821.31 \| 0.11 | 7774.37 \| 0.13 | 23872.25 \| 0.04 |
| 512 | 3804.37 \| 0.26 | 10068.51 \| 0.10 | 8872.88 \| 0.11 | 25009.06 \| 0.04 |
| 640 | 4124.01 \| 0.24 | 10547.88 \| 0.09 | 9956.58 \| 0.10 | oom |
| 768 | 3950.39 \| 0.25 | 10675.09 \| 0.09 | 10584.21 \| 0.09 | oom |
| 896 | 3937.28 \| 0.25 | 10780.82 \| 0.09 | 10994.00 \| 0.09 | oom |
| 1024 | oom | 11192.55 \| 0.09 | 11306.37 \| 0.09 | oom |
| 1152 | oom | 11178.30 \| 0.09 | 11290.51 \| 0.09 | oom |
| 1280 | oom | 11383.98 \| 0.09 | 11459.89 \| 0.09 | oom |
| 1408 | oom | 11477.66 \| 0.09 | 11565.90 \| 0.09 | oom |
| 1536 | oom | 11382.66 \| 0.09 | 11491.99 \| 0.09 | oom |
| 1664 | oom | 11571.52 \| 0.09 | 11603.73 \| 0.09 | oom |
| 1792 | oom | 11394.20 \| 0.09 | 11412.46 \| 0.09 | oom |
| 1920 | oom | oom | oom | oom |

Latency (sec)
| batch_size | HF (fp32) | HF (bf16) | HF (int8) | DS-inference (fp16) |
|:----------:|:---------:|:---------:|:---------:|:-------------------:|
| 1 | 2.32 | 2.46 | 3.10 | 0.81 |
| 2 | 2.48 | 2.47 | 3.15 | 0.81 |
| 4 | 2.49 | 2.58 | 3.05 | 0.79 |
| 8 | 2.43 | 2.40 | 3.07 | 0.78 |
| 16 | 2.42 | 2.39 | 3.06 | 0.79 |
| 32 | 2.43 | 2.48 | 3.03 | 0.76 |
| 64 | 3.02 | 2.57 | 3.25 | 0.77 |
| 128 | 4.48 | 3.00 | 3.57 | 0.81 |
| 256 | 7.34 | 3.70 | 4.17 | 1.18 |
| 384 | 10.12 | 4.35 | 4.94 | 1.61 |
| 512 | 13.46 | 5.09 | 5.77 | 2.05 |
| 640 | 15.52 | 6.07 | 6.43 | oom |
| 768 | 19.44 | 7.19 | 7.26 | oom |
| 896 | 22.76 | 8.31 | 8.15 | oom |
| 1024 | oom | 9.15 | 9.06 | oom |
| 1152 | oom | 10.31 | 10.20 | oom |
| 1280 | oom | 11.24 | 11.17 | oom |
| 1408 | oom | 12.27 | 12.17 | oom |
| 1536 | oom | 13.49 | 13.37 | oom |
| 1664 | oom | 14.38 | 14.34 | oom |
| 1792 | oom | 15.73 | 15.70 | oom |
| 1920 | oom | oom | oom | oom |
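
The two tables are two views of the same runs: with the default max_new_tokens = 100 (see src/utils/arguments.py below), throughput appears to be batch_size × max_new_tokens / latency, and msec/token is simply 1000 / throughput. A quick consistency check against the values above (small deviations come from the table values being rounded):

```python
# Rough sanity check relating the Latency and Throughput tables,
# assuming max_new_tokens = 100 (the default in src/utils/arguments.py).
batch_size = 256
latency_s = 3.12  # BLOOM, HF (bf16), batch_size 256, from the Latency table

throughput = batch_size * 100 / latency_s  # ~8205 tokens/sec; table reports 8216.27
msec_per_token = 1000 / throughput         # ~0.12 msec/token, matching the table
print(round(throughput, 2), round(msec_per_token, 2))
```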

benchmark.sh

Lines changed: 0 additions & 5 deletions
This file was deleted.

run_batch_size.sh

Lines changed: 14 additions & 0 deletions

export CUDA_VISIBLE_DEVICES=0

rm -rf ./tmp

for bs in {1,2,4,8,16,32,64}
do
	make 1ドル batch_size=$bs
done

for i in {1..20}
do
	bs=$(($i*128))
	make 1ドル batch_size=$bs
done

run_input_length.sh

Lines changed: 8 additions & 0 deletions

export CUDA_VISIBLE_DEVICES=0

rm -rf ./tmp

for max_input_length in {4,8,16,32,64,128,256,512,1024,1536,1900}
do
	make 1ドル batch_size=32 max_input_length=$max_input_length
done
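
The sweep stops at 1900 presumably because the GPT2 input-length targets use --n_positions 2048 and the default --max_new_tokens is 100, so prompt length plus generated tokens has to stay within the 2048-token position limit:

```python
# Sanity check on the swept prompt lengths (assumes the defaults noted above).
n_positions = 2048
max_new_tokens = 100  # default in src/utils/arguments.py
for max_input_length in (4, 8, 16, 32, 64, 128, 256, 512, 1024, 1536, 1900):
    assert max_input_length + max_new_tokens <= n_positions
```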

src/main.py

Lines changed: 2 additions & 1 deletion

@@ -7,7 +7,8 @@ def main() -> None:
 
     args = get_args(get_arg_parser())
 
-    inputs = get_dummy_batch(args.batch_size)
+    inputs = get_dummy_batch(args.batch_size, args.max_input_length)
+
     generate_kwargs = dict(max_new_tokens=args.max_new_tokens, do_sample=False)
 
     pipeline_class = getattr(pipelines, args.pipeline_class)

src/pipelines/ds_inference.py

Lines changed: 1 addition & 2 deletions

@@ -3,7 +3,6 @@
 
 import deepspeed
 import torch
-from transformers import BloomForCausalLM
 
 from .pipeline import Pipeline
 
@@ -16,7 +15,7 @@ def __init__(self, args: Namespace) -> None:
 
         # with deepspeed.OnDevice(dtype=torch.bfloat16, device="meta"):
         #     model = BloomForCausalLM._from_config(config, torch_dtype=torch.bfloat16)
-        self.model = BloomForCausalLM._from_config(self.config, torch_dtype=torch.bfloat16)
+        self.model = self.model_class.from_pretrained("tmp", torch_dtype=torch.bfloat16)
         self.model.eval()
 
         # checkpoints_json = os.path.join(args.model_name, "checkpoints.json")

src/pipelines/hf.py

Lines changed: 3 additions & 2 deletions

@@ -1,7 +1,6 @@
 from argparse import Namespace
 
 import torch
-from transformers import BloomForCausalLM
 
 from .pipeline import Pipeline
 
@@ -11,13 +10,15 @@ def __init__(self, args: Namespace, device: str = "cpu") -> None:
         super().__init__(args)
 
         model_kwargs = {}
+
         if args.dtype == torch.int8:
             model_kwargs["load_in_8bit"] = True
+            model_kwargs["device_map"] = "auto"
         else:
             model_kwargs["torch_dtype"] = args.dtype
 
         self.input_device = device
-        self.model = BloomForCausalLM._from_config(self.config, **model_kwargs).to(self.input_device)
+        self.model = self.model_class.from_pretrained("tmp", **model_kwargs).to(self.input_device)
         self.model.eval()
 
 
src/pipelines/pipeline.py

Lines changed: 41 additions & 36 deletions

@@ -1,46 +1,14 @@
+import os
 from argparse import Namespace
-from typing import List, Tuple
+from typing import List, Tuple, Union
 
 import torch
-from transformers import AutoTokenizer, BloomConfig
+from transformers import AutoTokenizer, BloomConfig, BloomForCausalLM, GPT2Config, GPT2LMHeadModel
 
 
 class Pipeline:
     def __init__(self, args: Namespace) -> None:
-        self.config = BloomConfig.from_dict(
-            {
-                "apply_residual_connection_post_layernorm": False,
-                "architectures": ["BloomModel"],
-                "attention_dropout": 0.0,
-                "attention_softmax_in_fp32": True,
-                "bias_dropout_fusion": True,
-                "bos_token_id": 1,
-                "eos_token_id": 2,
-                "hidden_dropout": 0.0,
-                "hidden_size": args.hidden_size,
-                "initializer_range": 0.02,
-                "layer_norm_epsilon": 1e-05,
-                "masked_softmax_fusion": True,
-                "model_type": "bloom",
-                "n_head": args.n_head,
-                "n_inner": None,
-                "n_layer": args.n_layer,
-                "offset_alibi": 100,
-                "pad_token_id": 3,
-                "pretraining_tp": 1,
-                "skip_bias_add": True,
-                "skip_bias_add_qkv": False,
-                "slow_but_exact": False,
-                "transformers_version": "4.22.2",
-                "unk_token_id": 0,
-                "use_cache": True,
-                "vocab_size": 250880,
-            }
-        )
-
-        # hardcoded for now to bigscience/bloom
-        self.tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom")
-
+        self.config, self.tokenizer, self.model_class = get_config_tokenizer_model_class(args)
         self.model = None
         self.input_device = None
 
@@ -69,3 +37,40 @@ def get_num_parameters(self) -> int:
         for i in self.model.parameters():
             param_count += i.numel()
         return param_count
+
+
+def get_config_tokenizer_model_class(args: Namespace) -> Union[BloomConfig, GPT2Config]:
+    tokenizer = AutoTokenizer.from_pretrained("gpt2")
+    tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+
+    if args.model_class.lower() == "bloom":
+        config = BloomConfig(
+            attention_softmax_in_fp32=True,
+            hidden_size=args.hidden_size,
+            n_head=args.n_head,
+            n_layer=args.n_layer,
+            vocab_size=len(tokenizer),
+            bos_token_id=tokenizer.bos_token_id,
+            eos_token_id=tokenizer.eos_token_id,
+            use_cache=True,
+        )
+        model_class = BloomForCausalLM
+    elif args.model_class.lower() == "gpt2":
+        config = GPT2Config(
+            n_embd=args.hidden_size,
+            n_head=args.n_head,
+            n_layer=args.n_layer,
+            n_positions=args.n_positions,
+            bos_token_id=tokenizer.bos_token_id,
+            eos_token_id=tokenizer.eos_token_id,
+            attention_type=args.attention_type,
+            print_details=False,
+            vocab_size=len(tokenizer),
+            use_cache=True,
+        )
+        model_class = GPT2LMHeadModel
+
+    if not os.path.exists("tmp"):
+        model_class._from_config(config).save_pretrained("tmp")
+
+    return config, tokenizer, model_class
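
Net effect of the pipeline changes: instead of a hard-coded BLOOM config and _from_config in every pipeline, the base Pipeline now builds a BloomConfig or GPT2Config from the CLI arguments, saves a randomly initialised model to a local "tmp" directory once, and the HF and DeepSpeed pipelines reload it with from_pretrained (which is what the int8 load_in_8bit path and checkpoint-style loading expect). A minimal, self-contained sketch of that save-once / reload pattern, using a deliberately tiny GPT2 config and omitting the fork-specific attention_type and print_details options:

```python
# Sketch of the save-once / from_pretrained-reload pattern (assumes vanilla
# transformers; sizes shrunk so it runs quickly).
import os

import torch
from transformers import AutoTokenizer, GPT2Config, GPT2LMHeadModel

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.add_special_tokens({"pad_token": "[PAD]"})

config = GPT2Config(
    n_embd=128,
    n_head=4,
    n_layer=2,
    n_positions=256,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
    vocab_size=len(tokenizer),
    use_cache=True,
)

# Save randomly initialised weights to disk once...
if not os.path.exists("tmp"):
    GPT2LMHeadModel._from_config(config).save_pretrained("tmp")

# ...so each pipeline can reload them with from_pretrained and its own kwargs
# (torch_dtype here; load_in_8bit/device_map in the int8 branch of hf.py).
model = GPT2LMHeadModel.from_pretrained("tmp", torch_dtype=torch.bfloat16)
model.eval()
```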

src/utils/arguments.py

Lines changed: 5 additions & 0 deletions

@@ -6,14 +6,19 @@
 def get_arg_parser() -> ArgumentParser:
     parser = ArgumentParser()
     parser.add_argument("--pipeline_class", default="HF_GPU_Pipeline", type=str)
+    parser.add_argument("--model_class", default="GPT2", type=str)
     parser.add_argument("--batch_size", default=1, type=int)
     parser.add_argument("--dtype", default="bfloat16", type=str)
+    parser.add_argument("--max_input_length", default=-1, type=int)
     parser.add_argument("--max_new_tokens", default=100, type=int)
     parser.add_argument("--local_rank", type=int)
     parser.add_argument("--hidden_size", type=int)
+    parser.add_argument("--attention_type", type=int)
+    parser.add_argument("--n_positions", type=int)
     parser.add_argument("--n_head", type=int)
     parser.add_argument("--n_layer", type=int)
     parser.add_argument("--benchmark_cycles", type=int, default=5)
+    parser.add_argument("--clear_every_run", action="store_true")
     return parser
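
For reference, this is how the flags used by the Makefile targets land in the parser; a small standalone sketch that replicates only the relevant subset of get_arg_parser (it does not import the repo):

```python
# Standalone sketch: parse the flags of the `hf-1b-GPT2-mqa-bf16` target with a
# subset of the options defined in get_arg_parser().
from argparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument("--pipeline_class", default="HF_GPU_Pipeline", type=str)
parser.add_argument("--model_class", default="GPT2", type=str)
parser.add_argument("--batch_size", default=1, type=int)
parser.add_argument("--dtype", default="bfloat16", type=str)
parser.add_argument("--max_input_length", default=-1, type=int)
parser.add_argument("--hidden_size", type=int)
parser.add_argument("--attention_type", type=int)
parser.add_argument("--n_positions", type=int)
parser.add_argument("--n_head", type=int)
parser.add_argument("--n_layer", type=int)

args = parser.parse_args(
    "--hidden_size 2048 --n_head 16 --n_layer 24 --model_class GPT2 "
    "--n_positions 2048 --attention_type 2 --dtype bfloat16 --batch_size 32".split()
)
print(args.model_class, args.attention_type, args.max_input_length)  # GPT2 2 -1
```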
