diff --git a/build_variables.bzl b/build_variables.bzl index 77fad7cdc5cb..678d47b11c65 100644 --- a/build_variables.bzl +++ b/build_variables.bzl @@ -512,6 +512,7 @@ libtorch_distributed_base_sources = [ "torch/csrc/distributed/c10d/TCPStore.cpp", "torch/csrc/distributed/c10d/TCPStoreBackend.cpp", "torch/csrc/distributed/c10d/TCPStoreLibUvBackend.cpp", + "torch/csrc/distributed/c10d/Types.cpp", "torch/csrc/distributed/c10d/Utils.cpp", "torch/csrc/distributed/c10d/Work.cpp", "torch/csrc/distributed/c10d/comm.cpp", diff --git a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp index 087c2831b4ed..c6ffee464e8f 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupGloo.cpp @@ -756,6 +756,9 @@ class AsyncBroadcastWork : public ProcessGroupGloo::AsyncWork { const uint32_t tag; void broadcast(at::Tensor& tensor) { + if (tensor.is_complex()) { + tensor = at::view_as_real(tensor); + } const auto& scalarType = tensor.scalar_type(); gloo::BroadcastOptions opts(context_); opts.setRoot(rootRank); @@ -1061,12 +1064,21 @@ class AsyncReduceWork : public ProcessGroupGloo::AsyncWork { const uint32_t tag; void reduce(std::vector& tensors) { - const auto& scalarType = tensors[0].scalar_type(); + auto tensor = tensors[0]; + if (tensor.is_complex()) { + TORCH_CHECK( + c10d::complexViewAsRealAllowed(reduceOp), + "reduce does not support", + reduceOp, + "on complex tensors"); + tensor = at::view_as_real(tensor); + } gloo::ReduceOptions opts(context_); + const auto& scalarType = tensor.scalar_type(); opts.setRoot(rootRank); opts.setTag(tag); opts.setReduceFunction(getFunction(scalarType, reduceOp)); - GENERATE_ALL_TYPES(scalarType, setOutput, opts, tensors[0]); + GENERATE_ALL_TYPES(scalarType, setOutput, opts, tensor); gloo::reduce(opts); // Gloo doesn't support AVG so we use SUM + division. diff --git a/torch/csrc/distributed/c10d/ProcessGroupGlooDetail.hpp b/torch/csrc/distributed/c10d/ProcessGroupGlooDetail.hpp index 1cf6cf25fff6..0018e55c9e5b 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupGlooDetail.hpp +++ b/torch/csrc/distributed/c10d/ProcessGroupGlooDetail.hpp @@ -286,8 +286,17 @@ class AsyncAllreduceWork : public ProcessGroupGloo::AsyncWork { const uint32_t tag; void allreduce(std::vector& tensors) { - const auto& scalarType = tensors[0].scalar_type(); + auto tensor = tensors[0]; + if (tensor.is_complex()) { + TORCH_CHECK( + c10d::complexViewAsRealAllowed(reduceOp), + "all_reduce does not support", + reduceOp, + "on complex tensors"); + tensor = at::view_as_real(tensor); + } gloo::AllreduceOptions opts(context_); + const auto& scalarType = tensor.scalar_type(); opts.setReduceFunction(getFunction(scalarType, reduceOp)); opts.setTag(tag); GENERATE_ALL_TYPES(scalarType, setOutputs, opts, tensors); diff --git a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp index b4c5c339cc33..4592780247b0 100644 --- a/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp +++ b/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp @@ -68,23 +68,6 @@ inline bool isUnsupportedFloat8(at::ScalarType t) { ); } -bool complexViewAsRealAllowed(const ReduceOp& reduceOp) { - switch (reduceOp) { - // NOLINTNEXTLINE(bugprone-branch-clone) - case ReduceOp::SUM: - return true; - case ReduceOp::AVG: - return true; - case ReduceOp::PREMUL_SUM: - return true; - case ReduceOp::UNUSED: - return true; - default: - return false; - } - return false; -} - #ifdef ENABLE_NCCL_PREMUL_SUM_SUPPORT template ncclRedOpRAII unpackPreMulSum( @@ -4392,7 +4375,7 @@ c10::intrusive_ptr ProcessGroupNCCL::allreduce( auto tensor = tensors.back(); if (tensor.is_complex()) { TORCH_CHECK( - complexViewAsRealAllowed(opts.reduceOp), + c10d::complexViewAsRealAllowed(opts.reduceOp), "all_reduce does not support", opts.reduceOp, "on complex tensors"); @@ -4586,7 +4569,7 @@ c10::intrusive_ptr ProcessGroupNCCL::reduce( auto tensor = tensors.back(); if (tensor.is_complex()) { TORCH_CHECK( - complexViewAsRealAllowed(opts.reduceOp), + c10d::complexViewAsRealAllowed(opts.reduceOp), "reduce does not support", opts.reduceOp, "on complex tensors"); diff --git a/torch/csrc/distributed/c10d/Types.cpp b/torch/csrc/distributed/c10d/Types.cpp new file mode 100644 index 000000000000..ddc07d91c5fe --- /dev/null +++ b/torch/csrc/distributed/c10d/Types.cpp @@ -0,0 +1,22 @@ +#include + +namespace c10d { + +bool complexViewAsRealAllowed(const ReduceOp& reduceOp) { + switch (reduceOp) { + // NOLINTNEXTLINE(bugprone-branch-clone) + case ReduceOp::SUM: + return true; + case ReduceOp::AVG: + return true; + case ReduceOp::PREMUL_SUM: + return true; + case ReduceOp::UNUSED: + return true; + default: + return false; + } + return false; +} + +} // namespace c10d diff --git a/torch/csrc/distributed/c10d/Types.hpp b/torch/csrc/distributed/c10d/Types.hpp index 8fec5dd0e9e2..28b572a4b468 100644 --- a/torch/csrc/distributed/c10d/Types.hpp +++ b/torch/csrc/distributed/c10d/Types.hpp @@ -110,6 +110,8 @@ ReduceOp makeNCCLPreMulSum(const T& factor) { return rop; } +TORCH_API bool complexViewAsRealAllowed(const ReduceOp& reduceOp); + constexpr auto kUnsetTimeout = std::chrono::milliseconds(-1); struct BroadcastOptions {