
Commit d4be4ab

[Auto Parallel] Adapt the mp_async_allreduce optimization in auto parallel (#10770)

* add mp_async_allreduce
* update mock_layer()
* add comment
* remove comment

1 parent b8f4b0c · commit d4be4ab

File tree

5 files changed: +81 -6 lines changed

llm/auto_parallel/deepseek-v3/run_pretrain_auto.py

Lines changed: 4 additions & 2 deletions

@@ -461,9 +461,11 @@ def main():
     model_args, data_args, training_args = parser.parse_args_into_dataclasses()
 
     if training_args.enable_linear_fused_grad_add:
-        from fused_layers import mock_layers
+        from llm.utils.fused_layers import mock_layers
 
-        mock_layers()
+        mock_layers(
+            mp_async_allreduce=True if "enable_mp_async_allreduce" in training_args.tensor_parallel_config else False
+        )
 
     if model_args.tokenizer_name_or_path is None:
         model_args.tokenizer_name_or_path = model_args.model_name_or_path
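The same change is repeated in the gpt-3, llama, and qwen scripts below. As a minimal standalone sketch of what the changed call amounts to (the hard-coded tensor_parallel_config string here is only an illustrative assumption; in the scripts it comes from the parsed TrainingArguments):

    from llm.utils.fused_layers import mock_layers

    # Illustrative stand-in for training_args.tensor_parallel_config.
    tensor_parallel_config = "enable_mp_async_allreduce"

    # Same logic as the diff: monkey-patch paddle.nn.functional.linear and turn on
    # the asynchronous mp all-reduce only when the config contains the switch.
    mock_layers(mp_async_allreduce="enable_mp_async_allreduce" in tensor_parallel_config)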

llm/auto_parallel/gpt-3/run_pretrain_auto.py

Lines changed: 3 additions & 1 deletion

@@ -447,7 +447,9 @@ def main():
     if training_args.enable_linear_fused_grad_add:
         from llm.utils.fused_layers import mock_layers
 
-        mock_layers()
+        mock_layers(
+            mp_async_allreduce=True if "enable_mp_async_allreduce" in training_args.tensor_parallel_config else False
+        )
 
     if model_args.tokenizer_name_or_path is None:
         model_args.tokenizer_name_or_path = model_args.model_name_or_path

llm/auto_parallel/llama/run_pretrain_auto.py

Lines changed: 3 additions & 1 deletion

@@ -465,7 +465,9 @@ def main():
     if training_args.enable_linear_fused_grad_add and not do_enable_sp_async_reduce_scatter:
         from llm.utils.fused_layers import mock_layers
 
-        mock_layers()
+        mock_layers(
+            mp_async_allreduce=True if "enable_mp_async_allreduce" in training_args.tensor_parallel_config else False
+        )
 
     if model_args.tokenizer_name_or_path is None:
         model_args.tokenizer_name_or_path = model_args.model_name_or_path

llm/auto_parallel/qwen/run_pretrain_auto.py

Lines changed: 3 additions & 1 deletion

@@ -436,7 +436,9 @@ def main():
     if training_args.enable_linear_fused_grad_add:
         from llm.utils.fused_layers import mock_layers
 
-        mock_layers()
+        mock_layers(
+            mp_async_allreduce=True if "enable_mp_async_allreduce" in training_args.tensor_parallel_config else False
+        )
 
     if model_args.tokenizer_name_or_path is None:
         model_args.tokenizer_name_or_path = model_args.model_name_or_path

llm/utils/fused_layers.py

Lines changed: 68 additions & 1 deletion

@@ -11,10 +11,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import os
+
 import paddle
+import paddle.distributed as dist
 from paddle import _C_ops
+from paddle.distributed import fleet
+from paddle.distributed.fleet.utils.log_util import logger
 from paddle.framework import core
 
+_mp_async_allreduce = False
+_raise_cuda_env_unset_warning = True
+
 
 def is_fused_matmul_bias_supported():
     if paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm() or paddle.is_compiled_with_xpu():
@@ -29,6 +37,27 @@ def is_fused_matmul_bias_supported():
 origin_linear = paddle.nn.functional.linear
 
 
+def sync_allreduce(task, dist_tensor, mp_placement_index):
+    new_placments = list()
+    for idx, placment in enumerate(dist_tensor.placements):
+        if idx == mp_placement_index:
+            new_placments.append(dist.Replicate())
+        else:
+            new_placments.append(placment)
+    place = paddle.framework._current_expected_place()
+    place = paddle.framework._get_paddle_place(place)
+
+    task.wait()
+
+    return paddle.Tensor(
+        dist_tensor._local_value(),
+        dims=dist_tensor.shape,
+        process_mesh=dist_tensor.process_mesh,
+        placements=new_placments,
+        place=place,
+    )
+
+
 class FusedLinearWithGradAdd(paddle.autograd.PyLayer):
     @staticmethod
     def forward(ctx, x, weight, bias=None, name=None):
@@ -41,41 +70,79 @@ def backward(ctx, y_grad):
         x, weight, bias = ctx.saved_tensor()
         x_grad = paddle.matmul(y_grad, weight, transpose_y=True)
 
+        task = None
+        if _mp_async_allreduce and x_grad.process_mesh is not None:
+            # Using small operation to preempt GPU SMs for all_reduce to achieve overlap.
+            if int(os.getenv("CUDA_DEVICE_MAX_CONNECTIONS", "0")) != 1:
+                global _raise_cuda_env_unset_warning
+                if _raise_cuda_env_unset_warning:
+                    logger.warning(
+                        "You set mp_async_allreduce=True, but you forget to set environment "
+                        "variable CUDA_DEVICE_MAX_CONNECTIONS=1, which may leads to performance "
+                        "loss. Try to export CUDA_DEVICE_MAX_CONNECTIONS=1 for better performance."
+                    )
+                    _raise_cuda_env_unset_warning = False
+
+            mp_placement_index = x_grad.process_mesh.dim_names.index("mp")
+            if mp_placement_index != -1 and x_grad.placements[mp_placement_index].is_partial():
+                hcg = fleet.get_hybrid_communicate_group()
+                model_parallel_group = hcg.get_model_parallel_group()
+                task = dist.stream.all_reduce(
+                    x_grad._local_value(),
+                    group=model_parallel_group,
+                    sync_op=False,
+                )
+
         # _C_ops.fused_linear_param_grad_add(x, y_grad, dw, db, multi precision, has bias)
         if bias is None:
             if hasattr(weight, "main_grad"):
                 weight.main_grad, _ = _C_ops.fused_linear_param_grad_add(
                     x, y_grad, weight.main_grad, None, True, False
                 )
+                if task is not None:
+                    x_grad = sync_allreduce(task, x_grad, mp_placement_index)
                 return x_grad, None
             else:
                 if weight.grad is not None:
                     weight.grad, _ = _C_ops.fused_linear_param_grad_add(
                         x, y_grad, weight.grad, None, False if weight.grad.dtype != paddle.float32 else True, False
                     )
+                    if task is not None:
+                        x_grad = sync_allreduce(task, x_grad, mp_placement_index)
                    return x_grad, None
                 else:
+                    if task is not None:
+                        x_grad = sync_allreduce(task, x_grad, mp_placement_index)
                     weight_grad, _ = _C_ops.fused_linear_param_grad_add(x, y_grad, None, None, False, False)
                     return x_grad, weight_grad
 
         if hasattr(weight, "main_grad") and hasattr(bias, "main_grad"):
             weight.main_grad, bias.main_grad = _C_ops.fused_linear_param_grad_add(
                 x, y_grad, weight.main_grad, bias.main_grad, True, True
             )
+            if task is not None:
+                x_grad = sync_allreduce(task, x_grad, mp_placement_index)
             return x_grad, None, None
         else:
             if weight.grad is not None:
                 assert bias.grad is not None
                 weight.grad, bias.grad = _C_ops.fused_linear_param_grad_add(
                     x, y_grad, weight.grad, bias.grad, False if weight.grad.dtype != paddle.float32 else True, True
                 )
+                if task is not None:
+                    x_grad = sync_allreduce(task, x_grad, mp_placement_index)
                 return x_grad, None, None
             else:
                 weight_grad, bias_grad = _C_ops.fused_linear_param_grad_add(x, y_grad, None, None, False, True)
+                if task is not None:
+                    x_grad = sync_allreduce(task, x_grad, mp_placement_index)
                 return x_grad, weight_grad, bias_grad
 
 
-def mock_layers():
+def mock_layers(mp_async_allreduce=False):
+    global _mp_async_allreduce
+    _mp_async_allreduce = mp_async_allreduce
+
     paddle.nn.functional.linear = FusedLinearWithGradAdd.apply
     if is_fused_matmul_bias_supported():
         paddle.incubate.nn.functional.fused_linear = FusedLinearWithGradAdd.apply
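For orientation: in tensor parallelism the input gradient of a column-parallel linear needs an all-reduce over the mp group, and that communication is independent of the weight-gradient computation, so the two can overlap. A minimal sketch of the pattern, assuming an initialized model-parallel process group (the helper name backward_with_overlap is hypothetical; only dist.stream.all_reduce and the overlap idea come from the diff):

    import paddle
    import paddle.distributed as dist


    def backward_with_overlap(x, y_grad, weight, mp_group):
        # dx needs an all-reduce across the tensor-parallel group.
        x_grad = paddle.matmul(y_grad, weight, transpose_y=True)
        # Launch the all-reduce asynchronously on the communication stream.
        task = dist.stream.all_reduce(x_grad, group=mp_group, sync_op=False)
        # dW does not depend on the reduced dx, so its kernels overlap the all-reduce.
        weight_grad = paddle.matmul(x, y_grad, transpose_x=True)
        # Block only when the reduced dx is actually needed by the caller.
        task.wait()
        return x_grad, weight_grad

The new warning about CUDA_DEVICE_MAX_CONNECTIONS=1 mirrors the comment in the diff: with a single hardware work queue, the all-reduce issued first can claim SMs before the dW kernels, which is what makes the overlap effective.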
