Commit ad3bd59
Finalized trainer.
1 parent d105ad8 commit ad3bd59

3 files changed: +15 -8 lines changed

.gitignore
Lines changed: 5 additions & 1 deletion

@@ -3,4 +3,8 @@ token.txt
 
 # data
 repos/
-data.txt
+data.txt
+
+#model
+GPyT/
+wandb/
model/train.py
Lines changed: 9 additions & 6 deletions

@@ -6,11 +6,13 @@
 from transformers import Trainer, TrainingArguments
 import os
 
+os.environ["WANDB_DISABLED"] = "true"
+
 def encode(lines):
     return tokenizer(lines['text'], add_special_tokens=True, truncation=True, max_length=512)
 
 TRAIN_BASE = False
-TOKENIZER_DIR = "tokenizer"
+TOKENIZER_DIR = "../tokenizer"
 
 paths = ["../data.txt"]
 
@@ -44,11 +46,11 @@ def encode(lines):
 
 config = GPT2Config(
     vocab_size = tokenizer.vocab_size,
-    bos_token_id = tokenizer.bos_token_id,
-    eos_token_id = tokenizer.eos_token_id
+    bos_token = tokenizer.bos_token_id,
+    eos_token = tokenizer.eos_token_id
 )
 
-model = GPT2LMHeadModel()
+model = GPT2LMHeadModel(config)
 
 dataset = load_dataset("text", data_files=paths)
 
@@ -62,12 +64,13 @@ def encode(lines):
 
 training_args = TrainingArguments(
     output_dir="../GPyT",
-    per_device_train_batch_size=10,
     overwrite_output_dir=True,
     num_train_epochs=1,
+    per_device_train_batch_size=10,
     save_steps=100,
     save_total_limit=2,
-    prediction_loss_only=True
+    prediction_loss_only=True,
+    remove_unused_columns=False
 )
 
 trainer = Trainer(
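
Two of these changes deserve a note. GPT2LMHeadModel() with no arguments raises a TypeError, since the model requires a config object, so GPT2LMHeadModel(config) is a genuine fix. The rename of bos_token_id/eos_token_id to bos_token/eos_token, on the other hand, looks accidental: GPT2Config's parameters are named bos_token_id and eos_token_id, and unrecognized keyword arguments are simply stored as extra config attributes, so the token IDs silently keep GPT-2's defaults. Finally, remove_unused_columns=False stops Trainer from dropping dataset columns that do not match the model's forward signature.

Below is a minimal sketch of the script these hunks appear to converge on. Everything between dataset = load_dataset(...) and the truncated trainer = Trainer( line sits outside the diff, so the dataset.map call, the pad-token line, and the DataCollatorForLanguageModeling wiring are assumptions rather than the author's code; the sketch also uses the standard *_id keyword names rather than the renamed ones.

# Sketch of the training script after this commit (assumptions noted inline).
from datasets import load_dataset
from transformers import (
    DataCollatorForLanguageModeling,
    GPT2Config,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    Trainer,
    TrainingArguments,
)
import os

os.environ["WANDB_DISABLED"] = "true"  # silence the Weights & Biases integration

TOKENIZER_DIR = "../tokenizer"
paths = ["../data.txt"]

# Assumption: a GPT-2-style tokenizer was trained earlier and saved to TOKENIZER_DIR.
tokenizer = GPT2Tokenizer.from_pretrained(TOKENIZER_DIR)
tokenizer.pad_token = tokenizer.eos_token  # assumption: the collator needs a pad token

def encode(lines):
    return tokenizer(lines["text"], add_special_tokens=True,
                     truncation=True, max_length=512)

config = GPT2Config(
    vocab_size=tokenizer.vocab_size,
    bos_token_id=tokenizer.bos_token_id,  # standard names; the diff's bos_token/
    eos_token_id=tokenizer.eos_token_id,  # eos_token kwargs would be ignored
)
model = GPT2LMHeadModel(config)  # fresh, randomly initialized weights

dataset = load_dataset("text", data_files=paths)
# Assumption: batched tokenization, dropping the raw text column so the
# collator only sees token fields.
dataset = dataset.map(encode, batched=True, remove_columns=["text"])

training_args = TrainingArguments(
    output_dir="../GPyT",
    overwrite_output_dir=True,
    num_train_epochs=1,
    per_device_train_batch_size=10,
    save_steps=100,
    save_total_limit=2,
    prediction_loss_only=True,
    remove_unused_columns=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset["train"],
    # Assumption: causal-LM collation; the diff stops at `trainer = Trainer(`.
    data_collator=DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False),
)
trainer.train()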

preprocess/tokenizer.py
Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@
 from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer
 
 TRAIN_BASE = False
-TOKENIZER_DIR = "tokenizer"
+TOKENIZER_DIR = "../tokenizer"
 
 paths = ["../data.txt"]
 
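The same relative-path fix lands in the preprocessing script, which implies both scripts are run from inside their own subdirectories. Here is a minimal sketch of loading the shared tokenizer from the new location; the directory layout is an inference from the ../ prefixes, not stated in the commit:

# Implied project layout (an inference, not part of the diff):
#   project/
#     tokenizer/               <- TOKENIZER_DIR = "../tokenizer"
#     data.txt                 <- paths = ["../data.txt"]
#     model/train.py
#     preprocess/tokenizer.py
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("../tokenizer")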