# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import os
+
import paddle
+import paddle.distributed as dist
from paddle import _C_ops
+from paddle.distributed import fleet
+from paddle.distributed.fleet.utils.log_util import logger
from paddle.framework import core

+_mp_async_allreduce = False
+_raise_cuda_env_unset_warning = True
+


def is_fused_matmul_bias_supported():
    if paddle.is_compiled_with_cuda() and not paddle.is_compiled_with_rocm() or paddle.is_compiled_with_xpu():
@@ -29,6 +37,27 @@ def is_fused_matmul_bias_supported():
origin_linear = paddle.nn.functional.linear


+def sync_allreduce(task, dist_tensor, mp_placement_index):
+    new_placements = list()
+    for idx, placement in enumerate(dist_tensor.placements):
+        if idx == mp_placement_index:
+            new_placements.append(dist.Replicate())
+        else:
+            new_placements.append(placement)
+    place = paddle.framework._current_expected_place()
+    place = paddle.framework._get_paddle_place(place)
+
+    task.wait()
+
+    return paddle.Tensor(
+        dist_tensor._local_value(),
+        dims=dist_tensor.shape,
+        process_mesh=dist_tensor.process_mesh,
+        placements=new_placements,
+        place=place,
+    )
+
+
class FusedLinearWithGradAdd(paddle.autograd.PyLayer):
    @staticmethod
    def forward(ctx, x, weight, bias=None, name=None):
@@ -41,41 +70,79 @@ def backward(ctx, y_grad):
        x, weight, bias = ctx.saved_tensor()
        x_grad = paddle.matmul(y_grad, weight, transpose_y=True)

+        task = None
+        if _mp_async_allreduce and x_grad.process_mesh is not None:
+            # Using a small operation to preempt GPU SMs for all_reduce to achieve overlap.
+            if int(os.getenv("CUDA_DEVICE_MAX_CONNECTIONS", "0")) != 1:
+                global _raise_cuda_env_unset_warning
+                if _raise_cuda_env_unset_warning:
+                    logger.warning(
+                        "You set mp_async_allreduce=True, but you forgot to set the environment "
+                        "variable CUDA_DEVICE_MAX_CONNECTIONS=1, which may lead to performance "
+                        "loss. Try to export CUDA_DEVICE_MAX_CONNECTIONS=1 for better performance."
+                    )
+                _raise_cuda_env_unset_warning = False
+
+            mp_placement_index = x_grad.process_mesh.dim_names.index("mp")
+            if mp_placement_index != -1 and x_grad.placements[mp_placement_index].is_partial():
+                hcg = fleet.get_hybrid_communicate_group()
+                model_parallel_group = hcg.get_model_parallel_group()
+                task = dist.stream.all_reduce(
+                    x_grad._local_value(),
+                    group=model_parallel_group,
+                    sync_op=False,
+                )
+
        # _C_ops.fused_linear_param_grad_add(x, y_grad, dw, db, multi precision, has bias)
        if bias is None:
            if hasattr(weight, "main_grad"):
                weight.main_grad, _ = _C_ops.fused_linear_param_grad_add(
                    x, y_grad, weight.main_grad, None, True, False
                )
+                if task is not None:
+                    x_grad = sync_allreduce(task, x_grad, mp_placement_index)
                return x_grad, None
            else:
                if weight.grad is not None:
                    weight.grad, _ = _C_ops.fused_linear_param_grad_add(
                        x, y_grad, weight.grad, None, False if weight.grad.dtype != paddle.float32 else True, False
                    )
+                    if task is not None:
+                        x_grad = sync_allreduce(task, x_grad, mp_placement_index)
                    return x_grad, None
                else:
+                    if task is not None:
+                        x_grad = sync_allreduce(task, x_grad, mp_placement_index)
                    weight_grad, _ = _C_ops.fused_linear_param_grad_add(x, y_grad, None, None, False, False)
                    return x_grad, weight_grad

        if hasattr(weight, "main_grad") and hasattr(bias, "main_grad"):
            weight.main_grad, bias.main_grad = _C_ops.fused_linear_param_grad_add(
                x, y_grad, weight.main_grad, bias.main_grad, True, True
            )
+            if task is not None:
+                x_grad = sync_allreduce(task, x_grad, mp_placement_index)
            return x_grad, None, None
        else:
            if weight.grad is not None:
                assert bias.grad is not None
                weight.grad, bias.grad = _C_ops.fused_linear_param_grad_add(
                    x, y_grad, weight.grad, bias.grad, False if weight.grad.dtype != paddle.float32 else True, True
                )
+                if task is not None:
+                    x_grad = sync_allreduce(task, x_grad, mp_placement_index)
                return x_grad, None, None
            else:
                weight_grad, bias_grad = _C_ops.fused_linear_param_grad_add(x, y_grad, None, None, False, True)
+                if task is not None:
+                    x_grad = sync_allreduce(task, x_grad, mp_placement_index)
                return x_grad, weight_grad, bias_grad


-def mock_layers():
+def mock_layers(mp_async_allreduce=False):
+    global _mp_async_allreduce
+    _mp_async_allreduce = mp_async_allreduce
+
    paddle.nn.functional.linear = FusedLinearWithGradAdd.apply
    if is_fused_matmul_bias_supported():
        paddle.incubate.nn.functional.fused_linear = FusedLinearWithGradAdd.apply
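A minimal usage sketch (not part of the patch), assuming the file above is importable as `fused_layers` — its actual module path is not shown in this diff — and that training already runs under a hybrid-parallel `fleet` setup whose process mesh has an "mp" dimension:

    # export CUDA_DEVICE_MAX_CONNECTIONS=1 before launching, as the warning above recommends.
    import fused_layers  # hypothetical import path for the patched file

    # Replace paddle.nn.functional.linear (and paddle.incubate.nn.functional.fused_linear
    # when supported) with FusedLinearWithGradAdd, and enable the asynchronous mp
    # all_reduce of x_grad so it overlaps with the fused weight-gradient add in backward().
    fused_layers.mock_layers(mp_async_allreduce=True)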