
Commit df2458c

Pytorch first experiments
1 parent 111bf7c commit df2458c

7 files changed: 93 additions, 41 deletions

examples/TPU/pytorch/run_tpu_glue.py

Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
+import os
+
+os.environ["TPU_IP_ADDRESS"] = "192.168.0.2"
+os.environ["TPU_NAME"] = "node-1"
+os.environ["XRT_TPU_CONFIG"] = "tpu_worker;0;192.168.0.2:8470"
+
+
+import torch
+import torch_xla
+import torch_xla.core.xla_model as xm
+
+device = xm.xla_device()
+
+from transformers import GPT2LMHeadModel, GPT2Tokenizer
+
+model = GPT2LMHeadModel.from_pretrained("gpt2")
+tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+
+sequence = "This runs on TPU"
+input_ids = torch.tensor([tokenizer.encode(sequence)], device=device)
+
+model.train().to(device)
+
+print(input_ids)
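This script is a smoke test rather than a training run: it points XRT at a specific TPU worker (the address 192.168.0.2 and the node name node-1 are deployment-specific), acquires an XLA device, and builds a tensor on it. Since XLA tensors are evaluated lazily, a forward pass is the natural next check. A minimal sketch, reusing the names defined in the script above (none of this is part of the commit):

    # Sketch: one forward pass on the XLA device, reusing model,
    # input_ids and device from the script above.
    model.to(device)                                  # parameters must live on the TPU
    with torch.no_grad():
        loss = model(input_ids, labels=input_ids)[0]  # transformers returns a tuple, loss first
    print("LM loss:", loss.item())                    # .item() forces the lazy XLA graph to compile and run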
File renamed without changes.
File renamed without changes.

examples/run_glue.py

Lines changed: 68 additions & 41 deletions
@@ -22,6 +22,7 @@
 import logging
 import os
 import random
+from time import time

 import numpy as np
 import torch
@@ -130,59 +131,72 @@ def train(args, train_dataset, model, tokenizer):
     for _ in train_iterator:
         epoch_iterator = tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])
         for step, batch in enumerate(epoch_iterator):
+            start = time()
             model.train()
+
             batch = tuple(t.to(args.device) for t in batch)
             inputs = {'input_ids':      batch[0],
                       'attention_mask': batch[1],
                       'labels':         batch[3]}
+
+            print(batch[0].device)
+
             if args.model_type != 'distilbert':
                 inputs['token_type_ids'] = batch[2] if args.model_type in ['bert', 'xlnet'] else None  # XLM, DistilBERT and RoBERTa don't use segment_ids
+
+            intermediate = time()
+            print("Model took " + str(intermediate - start) + " to put on device.")
+
             outputs = model(**inputs)
-            loss = outputs[0]  # model outputs are always tuple in transformers (see doc)

-            if args.n_gpu > 1:
-                loss = loss.mean() # mean() to average on multi-gpu parallel training
-            if args.gradient_accumulation_steps > 1:
-                loss = loss / args.gradient_accumulation_steps
+            loss = outputs[0]  # model outputs are always tuple in transformers (see doc)

-            if args.fp16:
-                with amp.scale_loss(loss, optimizer) as scaled_loss:
-                    scaled_loss.backward()
-                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
-            else:
-                loss.backward()
-                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
+            # if args.n_gpu > 1:
+            #     loss = loss.mean() # mean() to average on multi-gpu parallel training
+            # if args.gradient_accumulation_steps > 1:
+            #     loss = loss / args.gradient_accumulation_steps
+            #
+            # if args.fp16:
+            #     with amp.scale_loss(loss, optimizer) as scaled_loss:
+            #         scaled_loss.backward()
+            #     torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
+            # else:
+            loss.backward()
+            # torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

             tr_loss += loss.item()
-            if (step + 1) % args.gradient_accumulation_steps == 0:
-                optimizer.step()
-                scheduler.step()  # Update learning rate schedule
-                model.zero_grad()
-                global_step += 1
-
-                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
-                    # Log metrics
-                    if args.local_rank == -1 and args.evaluate_during_training:  # Only evaluate when single GPU otherwise metrics may not average well
-                        results = evaluate(args, model, tokenizer)
-                        for key, value in results.items():
-                            tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
-                    tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
-                    tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
-                    logging_loss = tr_loss
-
-                if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
-                    # Save model checkpoint
-                    output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
-                    if not os.path.exists(output_dir):
-                        os.makedirs(output_dir)
-                    model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
-                    model_to_save.save_pretrained(output_dir)
-                    torch.save(args, os.path.join(output_dir, 'training_args.bin'))
-                    logger.info("Saving model checkpoint to %s", output_dir)
-
-            if args.max_steps > 0 and global_step > args.max_steps:
-                epoch_iterator.close()
-                break
+            # if (step + 1) % args.gradient_accumulation_steps == 0:
+            #     optimizer.step()
+            #     scheduler.step()  # Update learning rate schedule
+            #     model.zero_grad()
+            #     global_step += 1
+            #
+            #     if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
+            #         # Log metrics
+            #         if args.local_rank == -1 and args.evaluate_during_training:  # Only evaluate when single GPU otherwise metrics may not average well
+            #             results = evaluate(args, model, tokenizer)
+            #             for key, value in results.items():
+            #                 tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
+            #         tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
+            #         tb_writer.add_scalar('loss', (tr_loss - logging_loss)/args.logging_steps, global_step)
+            #         logging_loss = tr_loss
+            #
+            #     if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
+            #         # Save model checkpoint
+            #         output_dir = os.path.join(args.output_dir, 'checkpoint-{}'.format(global_step))
+            #         if not os.path.exists(output_dir):
+            #             os.makedirs(output_dir)
+            #         model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
+            #         model_to_save.save_pretrained(output_dir)
+            #         torch.save(args, os.path.join(output_dir, 'training_args.bin'))
+            #         logger.info("Saving model checkpoint to %s", output_dir)
+            #
+            # if args.max_steps > 0 and global_step > args.max_steps:
+            #     epoch_iterator.close()
+            #     break
+
+            end = time()
+            print("Model took " + str(end - start) + " for its forward pass.")
             if args.max_steps > 0 and global_step > args.max_steps:
                 train_iterator.close()
                 break
@@ -379,6 +393,8 @@ def main():
     parser.add_argument('--seed', type=int, default=42,
                         help="random seed for initialization")

+    parser.add_argument('--tpu', action='store_true',
+                        help="Whether to try and connect to a TPU")
     parser.add_argument('--fp16', action='store_true',
                         help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
     parser.add_argument('--fp16_opt_level', type=str, default='O1',
@@ -393,6 +409,7 @@ def main():
     if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train and not args.overwrite_output_dir:
         raise ValueError("Output directory ({}) already exists and is not empty. Use --overwrite_output_dir to overcome.".format(args.output_dir))

+
     # Setup distant debugging if needed
     if args.server_ip and args.server_port:
         # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
@@ -410,6 +427,16 @@ def main():
         device = torch.device("cuda", args.local_rank)
         torch.distributed.init_process_group(backend='nccl')
         args.n_gpu = 1
+
+
+    if args.tpu:
+        os.environ["TPU_IP_ADDRESS"] = "192.168.0.2"
+        os.environ["TPU_NAME"] = "node-1"
+        os.environ["XRT_TPU_CONFIG"] = "tpu_worker;0;192.168.0.2:8470"
+        import torch_xla
+        import torch_xla.core.xla_model as xm
+        device = xm.xla_device()
+
     args.device = device

     # Setup logging
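Two things are worth flagging in this diff. First, the optimizer, scheduler, logging, and checkpoint logic is commented out, so this version only times a forward and backward pass. Second, torch_xla executes lazily: operations are recorded into a graph and only run at a synchronization point, so wrapping model(**inputs) in time() mostly measures Python-side graph tracing rather than TPU execution. A hedged sketch of what the re-enabled update step could look like on TPU (xm.optimizer_step is torch_xla API; the other names come from the script; none of this is in the commit):

    # Sketch: TPU-aware version of the commented-out update step.
    outputs = model(**inputs)
    loss = outputs[0]
    loss.backward()
    xm.optimizer_step(optimizer, barrier=True)  # reduces gradients across replicas, steps,
                                                # and marks the step so the recorded XLA
                                                # graph actually executes on the TPU
    scheduler.step()   # update learning rate schedule
    model.zero_grad()
    global_step += 1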

transformers/file_utils.py

Lines changed: 1 addition & 0 deletions
@@ -27,6 +27,7 @@

 try:
     import tensorflow as tf
+    assert hasattr(tf, "__version__")
     assert int(tf.__version__[0]) >= 2
     _tf_available = True  # pylint: disable=invalid-name
     logger.info("TensorFlow version {} available.".format(tf.__version__))
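The new assert guards against a tensorflow distribution that is importable but has no __version__ attribute (for example a broken or stub install): the AssertionError sends control to the except branch instead of raising an AttributeError on the next line. The full pattern, with the except side (not shown in the hunk) filled in as an assumption, looks roughly like:

    # Guarded optional-import pattern; the except branch is assumed,
    # since the hunk above only shows the try side.
    try:
        import tensorflow as tf
        assert hasattr(tf, "__version__")   # broken installs may lack this
        assert int(tf.__version__[0]) >= 2  # the TF code paths require TF 2.x
        _tf_available = True
    except (ImportError, AssertionError):
        _tf_available = False               # fall back to PyTorch-only mode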
