🐛 Describe the bug
I want to run an elastic training job across 2 nodes with 4 GPUs in total (2 GPUs per node).
The training script is as follows:
import argparse
import os
import sys
import time
import tempfile
from urllib.parse import urlparse

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP


class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.net1 = nn.Linear(10, 10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(10, 5)

    def forward(self, x):
        return self.net2(self.relu(self.net1(x)))


def train():
    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    print(f"[{os.getpid()}] (rank = {rank}, local_rank = {local_rank}) training...")

    model = ToyModel().cuda(local_rank)
    ddp_model = DDP(model, [local_rank])

    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    epoch = 100
    for i in range(epoch):
        time.sleep(1)
        optimizer.zero_grad()
        outputs = ddp_model(torch.randn(20, 10).to(local_rank))
        labels = torch.randn(20, 5).to(local_rank)
        loss = loss_fn(outputs, labels)
        loss.backward()
        print(f"[{os.getpid()}] epoch {i} (rank = {rank}, local_rank = {local_rank}) loss = {loss.item()}\n")
        optimizer.step()


def run():
    env_dict = {
        key: os.environ[key]
        for key in ("MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "LOCAL_WORLD_SIZE")
    }
    print(f"[{os.getpid()}] Initializing process group with: {env_dict}")
    dist.init_process_group(backend="nccl")
    train()
    dist.destroy_process_group()


if __name__ == "__main__":
    run()
The launch script is as follows:

export MASTER_ADDR="192.168.9.104"
export MASTER_PORT="1234"
export LOGLEVEL="DEBUG"

torchrun \
    --nnodes=1:3 \
    --nproc_per_node=2 \
    --max_restarts=3 \
    --rdzv_id=1 \
    --rdzv_backend=c10d \
    --rdzv_endpoint="192.168.9.104:1234" \
    --master_addr="192.168.9.104" \
    --master_port=1234 \
    train_elastic.py
Error output:
root@iZ2zefk3g6rwyhrmh35s3jZ:/workspace/DDP# sh run_elastic.sh
WARNING:torch.distributed.run:
*****************************************
Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
*****************************************
[435] Initializing process group with: {'MASTER_ADDR': 'iZ2ze9q3ftqtxtqlkrk6tuZ', 'MASTER_PORT': '43007', 'WORLD_SIZE': '4', 'LOCAL_WORLD_SIZE': '2'}
[434] Initializing process group with: {'MASTER_ADDR': 'iZ2ze9q3ftqtxtqlkrk6tuZ', 'MASTER_PORT': '43007', 'WORLD_SIZE': '4', 'LOCAL_WORLD_SIZE': '2'}
[W socket.cpp:558] [c10d] The IPv6 network addresses of (iZ2ze9q3ftqtxtqlkrk6tuZ, 43007) cannot be retrieved (gai error: -2 - Name or service not known).
[W socket.cpp:558] [c10d] The IPv4 network addresses of (iZ2ze9q3ftqtxtqlkrk6tuZ, 43007) cannot be retrieved (gai error: -2 - Name or service not known).
[E socket.cpp:610] [c10d] The client socket has failed to connect to any network address of (iZ2ze9q3ftqtxtqlkrk6tuZ, 43007).
Traceback (most recent call last):
  File "train_elastic.py", line 58, in <module>
[W socket.cpp:558] [c10d] The IPv6 network addresses of (iZ2ze9q3ftqtxtqlkrk6tuZ, 43007) cannot be retrieved (gai error: -2 - Name or service not known).
[W socket.cpp:558] [c10d] The IPv4 network addresses of (iZ2ze9q3ftqtxtqlkrk6tuZ, 43007) cannot be retrieved (gai error: -2 - Name or service not known).
[E socket.cpp:610] [c10d] The client socket has failed to connect to any network address of (iZ2ze9q3ftqtxtqlkrk6tuZ, 43007).
    run()
Traceback (most recent call last):
  File "train_elastic.py", line 52, in run
  File "train_elastic.py", line 58, in <module>
    dist.init_process_group(backend="nccl")
  File "/opt/conda/lib/python3.8/site-packages/torch/distributed/distributed_c10d.py", line 595, in init_process_group
    run()
  File "train_elastic.py", line 52, in run
    store, rank, world_size = next(rendezvous_iterator)
  File "/opt/conda/lib/python3.8/site-packages/torch/distributed/rendezvous.py", line 232, in _env_rendezvous_handler
    store = _create_c10d_store(master_addr, master_port, rank, world_size, timeout)
  File "/opt/conda/lib/python3.8/site-packages/torch/distributed/rendezvous.py", line 160, in _create_c10d_store
    return TCPStore(
RuntimeError: The client socket has failed to connect to any network address of (iZ2ze9q3ftqtxtqlkrk6tuZ, 43007). The IPv6 network addresses of (iZ2ze9q3ftqtxtqlkrk6tuZ, 43007) cannot be retrieved (gai error: -2 - Name or service not known). The IPv4 network addresses of (iZ2ze9q3ftqtxtqlkrk6tuZ, 43007) cannot be retrieved (gai error: -2 - Name or service not known).
    dist.init_process_group(backend="nccl")
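
For reference, the name-resolution failure can be reproduced outside of PyTorch with a quick check. This is just a minimal sketch; the hostname and port are the ones copied from the log above:

import socket

# Hostname and port taken from the error log above.
host, port = "iZ2ze9q3ftqtxtqlkrk6tuZ", 43007

try:
    # Roughly the lookup the c10d client socket performs before connecting.
    print(socket.getaddrinfo(host, port))
except socket.gaierror as e:
    # On my nodes this prints: [Errno -2] Name or service not known
    print(f"cannot resolve {host}: {e}")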
I have read the PyTorch source code and found that when the rendezvous backend is not "static", the agent uses the node's fully qualified hostname as MASTER_ADDR. In my environment that hostname cannot be resolved from the other node, so I want to use an IP address instead. The relevant code:
def _get_addr_and_port(
    rdzv_parameters: RendezvousParameters,
) -> Tuple[Optional[str], Optional[int]]:
    if rdzv_parameters.backend != "static":
        return (None, None)
    endpoint = rdzv_parameters.endpoint
    endpoint = endpoint.strip()
    if not endpoint:
        raise ValueError(
            "Endpoint is missing in endpoint. Try to add --master_addr and --master_port"
        )
    master_addr, master_port = parse_rendezvous_endpoint(endpoint, default_port=-1)
    if master_port == -1:
        raise ValueError(
            f"port is missing in endpoint: {endpoint}. Try to specify --master_port"
        )
    return (master_addr, master_port)

...

@staticmethod
def _set_master_addr_port(
    store: Store, master_addr: Optional[str], master_port: Optional[int]
):
    if master_port is None:
        sock = _get_socket_with_port()
        with closing(sock):
            master_port = sock.getsockname()[1]

    if master_addr is None:
        master_addr = _get_fq_hostname()

    store.set("MASTER_ADDR", master_addr.encode(encoding="UTF-8"))
    store.set("MASTER_PORT", str(master_port).encode(encoding="UTF-8"))
How can I set MASTER_ADDR explicitly in this case? Is this a bug?
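
The only workaround I can think of is to override MASTER_ADDR inside the training script before calling init_process_group. This is only a sketch I have not fully verified: NODE0_IP is a variable name I made up (it would have to be exported on every node before launching torchrun), and it ignores the fact that the worker-group leader may land on a different node after an elastic restart:

import os
import socket

import torch.distributed as dist


def init_with_resolvable_master_addr():
    # MASTER_ADDR is set by the elastic agent to the fully qualified hostname
    # of the node hosting the worker-group leader (see _set_master_addr_port above).
    addr = os.environ["MASTER_ADDR"]
    try:
        socket.getaddrinfo(addr, None)
    except socket.gaierror:
        # The hostname is not resolvable from this node: fall back to an IP
        # that I export myself. NODE0_IP is a made-up variable name, not
        # something torchrun provides.
        os.environ["MASTER_ADDR"] = os.environ["NODE0_IP"]
    dist.init_process_group(backend="nccl")

This only helps when the node whose IP I export is the one actually hosting the store, so a supported way to make the agent publish an IP (or at least a resolvable address) would be much better.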
Versions
Collecting environment information...
PyTorch version: 1.11.0
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A
OS: Ubuntu 18.04.6 LTS (x86_64)
GCC version: Could not collect
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.27
Python version: 3.8.12 (default, Oct 12 2021, 13:49:34) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.10.84-10.2.al8.x86_64-x86_64-with-glibc2.17
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration:
GPU 0: Tesla V100-SXM2-16GB
GPU 1: Tesla V100-SXM2-16GB
GPU 2: Tesla V100-SXM2-16GB
GPU 3: Tesla V100-SXM2-16GB
Nvidia driver version: 460.91.03
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] numpy==1.21.2
[pip3] torch==1.11.0
[pip3] torchelastic==0.2.2
[pip3] torchtext==0.12.0
[pip3] torchvision==0.12.0
[conda] blas 1.0 mkl
[conda] cudatoolkit 11.3.1 ha36c431_9 nvidia
[conda] ffmpeg 4.3 hf484d3e_0 pytorch
[conda] mkl 2021.4.0 h06a4308_640
[conda] mkl-service 2.4.0 py38h7f8727e_0
[conda] mkl_fft 1.3.1 py38hd3c417c_0
[conda] mkl_random 1.2.2 py38h51133e4_0
[conda] numpy 1.21.2 py38h20f2e39_0
[conda] numpy-base 1.21.2 py38h79a1101_0
[conda] pytorch 1.11.0 py3.8_cuda11.3_cudnn8.2.0_0 pytorch
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] torchelastic 0.2.2 pypi_0 pypi
[conda] torchtext 0.12.0 py38 pytorch
[conda] torchvision 0.12.0 py38_cu113 pytorch
cc @pietern @mrshenli @pritamdamania87 @zhaojuanmao @satgera @rohan-varma @gqchen @aazzolini @osalpekar @jiayisuse @SciPioneer @H-Huang