Commit de0e0a1

Update train.py
1 parent 9e3530e commit de0e0a1

1 file changed: +58 -65 lines changed

train.py: 58 additions & 65 deletions

@@ -19,12 +19,12 @@

 raw_data_path = 'data/train.txt'
 tokenized_data_path = 'data/tokenized/'
-raw = True  # whether to build the dataset from scratch
+raw = False  # whether to build the dataset from scratch
 epochs = 5
 batch_size = 12
 lr = 1.5e-4
 warmup_steps = 2000
-log_step = 250
+log_step = 1
 stride = 768
 gradient_accumulation = 1
 fp16 = False  # do not enable on GPUs that do not support half precision
@@ -67,12 +67,19 @@ def main():
     model.to(device)
     multi_gpu = False
     total_tokens = 0
+    full_line = ''
     print('calculating total steps')
     for i in tqdm(range(num_pieces)):
         with open(tokenized_data_path + 'tokenized_train_{}.txt'.format(i), 'r') as f:
-            total_tokens += len(f.read().split())
-    num_chunks = total_tokens // stride
-    total_steps = int(num_chunks * epochs / batch_size / gradient_accumulation)
+            full_line += f.read()
+    full_line = [int(item) for item in full_line.split()]
+    len_full_line = len(full_line)
+    samples = []
+    start_point = 0
+    while start_point + n_ctx < len_full_line:
+        samples.append(full_line[start_point:start_point+n_ctx])
+        start_point += stride
+    total_steps = int(len(samples) * epochs / batch_size / gradient_accumulation)
     print('total steps = {}'.format(total_steps))
     optimizer = pytorch_transformers.AdamW(model.parameters(), lr=lr, correct_bias=True)
     scheduler = pytorch_transformers.WarmupLinearSchedule(optimizer, warmup_steps=warmup_steps,
@@ -93,67 +100,53 @@ def main():
         print('epoch {}'.format(epoch + 1))
         now = datetime.now()
         print('time: {}'.format(now))
-        x = np.linspace(0, num_pieces - 1, num_pieces, dtype=np.int32)
-        random.shuffle(x)
-        piece_num = 0
-        for i, j in enumerate(x):
-            with open(tokenized_data_path + 'tokenized_train_{}.txt'.format(j), 'r') as f:
+        running_loss = 0
+        random.shuffle(samples)
+        for step in range(len(samples) // batch_size):
+
+            # prepare data
+            batch = samples[step * batch_size: (step + 1) * batch_size]
+            batch_labels = []
+            batch_inputs = []
+            for ids in batch:
+                int_ids_for_labels = [int(x) for x in ids]
+                int_ids_for_inputs = [int(x) for x in ids]
+                batch_labels.append(int_ids_for_labels)
+                batch_inputs.append(int_ids_for_inputs)
+            batch_labels = torch.tensor(batch_labels).long().to(device)
+            batch_inputs = torch.tensor(batch_inputs).long().to(device)
+
+            # forward pass
+            outputs = model.forward(input_ids=batch_inputs, labels=batch_labels)
+            loss, logits = outputs[:2]
+
+            # get loss
+            if multi_gpu:
+                loss = loss.mean()
+            if gradient_accumulation > 1:
+                loss = loss / gradient_accumulation
+
+            # loss backward
+            if fp16:
+                with amp.scale_loss(loss, optimizer) as scaled_loss:
+                    scaled_loss.backward()
+                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_grad_norm)
+            else:
+                loss.backward()
+                torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
+
+            # optimizer step
+            if (step + 1) % gradient_accumulation == 0:
+                running_loss += loss.item()
+                scheduler.step()
+                optimizer.step()
+                optimizer.zero_grad()
+            if (step + 1) % log_step == 0:
+                print('step {} of epoch {}, loss {}'.format(
+                    (step + 1) // gradient_accumulation,
+                    epoch + 1,
+                    running_loss * gradient_accumulation**2 / log_step))
                 running_loss = 0
-                line = f.read()
-                tokens = line.split()
-                tokens = [int(token) for token in tokens]
-                start_point = 0
-                chunks = []
-                while start_point < len(tokens) - n_ctx:
-                    chunks.append(tokens[start_point: start_point + n_ctx])
-                    start_point += stride
-                random.shuffle(chunks)
-                for step in range(len(chunks) // batch_size):
-
-                    # prepare data
-                    batch = chunks[step * batch_size: (step + 1) * batch_size]
-                    batch_labels = []
-                    batch_inputs = []
-                    for ids in batch:
-                        int_ids_for_labels = [int(x) for x in ids]
-                        int_ids_for_inputs = [int(x) for x in ids]
-                        batch_labels.append(int_ids_for_labels)
-                        batch_inputs.append(int_ids_for_inputs)
-                    batch_labels = torch.tensor(batch_labels).long().to(device)
-                    batch_inputs = torch.tensor(batch_inputs).long().to(device)
-
-                    # forward pass
-                    outputs = model.forward(input_ids=batch_inputs, labels=batch_labels)
-                    loss, logits = outputs[:2]
-
-                    # get loss
-                    if multi_gpu:
-                        loss = loss.mean()
-                    if gradient_accumulation > 1:
-                        loss = loss / gradient_accumulation
-
-                    # loss backward
-                    if fp16:
-                        with amp.scale_loss(loss, optimizer) as scaled_loss:
-                            scaled_loss.backward()
-                            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_grad_norm)
-                    else:
-                        loss.backward()
-                        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
-
-                    # optimizer step
-                    if (step + 1) % gradient_accumulation == 0:
-                        running_loss += loss.item()
-                        scheduler.step()
-                        optimizer.step()
-                        optimizer.zero_grad()
-                    if (step + 1) % log_step == 0:
-                        print('step {} of piece {} of epoch {}, loss {}'.format(
-                            (step + 1) // gradient_accumulation,
-                            piece_num, epoch + 1,
-                            running_loss * gradient_accumulation**2 / log_step))
-                        running_loss = 0
-            piece_num += 1

         print('saving model for epoch {}'.format(epoch + 1))
         if not os.path.exists('./model/model_epoch{}'.format(epoch + 1)):
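
Note on the change: the old loop built training chunks piece by piece, reading one tokenized_train_{i}.txt at a time, chunking it, and shuffling only within that piece. The new code concatenates every tokenized piece into a single token stream up front, slices overlapping windows of n_ctx tokens every stride tokens, and then shuffles and batches those windows across the whole corpus each epoch. The sketch below is not code from this repository; it illustrates the windowing and the total_steps arithmetic with toy numbers, and the helper name build_samples is invented for illustration.

def build_samples(token_stream, n_ctx, stride):
    # Slice overlapping windows of length n_ctx, advancing by `stride` tokens each time,
    # mirroring the while-loop added in this commit.
    samples = []
    start_point = 0
    while start_point + n_ctx < len(token_stream):
        samples.append(token_stream[start_point:start_point + n_ctx])
        start_point += stride
    return samples


if __name__ == '__main__':
    tokens = list(range(20))  # stand-in for the concatenated token ids of all pieces
    samples = build_samples(tokens, n_ctx=8, stride=4)
    # 20 tokens with n_ctx=8 and stride=4 give windows starting at 0, 4 and 8,
    # i.e. [0..7], [4..11], [8..15]; adjacent windows overlap by n_ctx - stride tokens.
    print(len(samples), samples[0])
    # total_steps follows the same formula as the diff: windows per epoch divided by
    # batch size, scaled by epochs and gradient accumulation.
    epochs, batch_size, gradient_accumulation = 5, 12, 1
    total_steps = int(len(samples) * epochs / batch_size / gradient_accumulation)
    print('total steps = {}'.format(total_steps))

With the stride of 768 set in this file and a typical GPT-2 context size of n_ctx = 1024 (n_ctx comes from the model config, not this diff), consecutive windows would share 256 tokens, so the corpus is covered densely without splitting samples on piece boundaries.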
