-
Notifications
You must be signed in to change notification settings - Fork 24k
/
Copy path_visualize_sharding.py
218 lines (183 loc) · 7.09 KB
/
_visualize_sharding.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
# mypy: allow-untyped-defs
import importlib.util
import numpy as np
from torch._prims_common import ShapeType
from torch.distributed.tensor._utils import _compute_local_shape_and_global_offset
# Public API of this module; the remaining helpers are private.
__all__ = ["visualize_sharding"]
# RGB triple of floats; ``_canonicalize_color`` scales each component by 255,
# so components are presumably expected in [0, 1].
Color = tuple[float, float, float]
def _create_table(
    shards: list[tuple[tuple[int, int], tuple[int, int], int]], device_kind: str = ""
):
    """
    Render the shard layout as a ``tabulate`` grid.

    Table rows and columns are the distinct row / column ranges found in
    ``shards``, in sorted order. An occupied cell reads
    ``"<device_kind>:<id>[,<id>...]"``, listing (in ``shards`` order) every
    device that holds that block; cells that no shard maps to stay empty.
    """
    from tabulate import tabulate

    row_ranges = sorted({shard[0] for shard in shards})
    col_ranges = sorted({shard[1] for shard in shards})

    # Position of each range within the sorted header order.
    row_slot = {rng: i for i, rng in enumerate(row_ranges)}
    col_slot = {rng: j for j, rng in enumerate(col_ranges)}

    # Gather device ids per cell first, then format every cell exactly once.
    owners = [[[] for _ in col_ranges] for _ in row_ranges]
    for shard in shards:
        owners[row_slot[shard[0]]][col_slot[shard[1]]].append(str(shard[2]))

    matrix = [
        [device_kind + ":" + ",".join(ids) if ids else "" for ids in row]
        for row in owners
    ]

    return tabulate(
        matrix,
        headers=[f"Col {rng[0]}-{rng[1]}" for rng in col_ranges],
        showindex=[f"Row {rng[0]}-{rng[1]}" for rng in row_ranges],
    )
def make_color_iter(color_map, num_rows, num_cols):
    """
    Yield ``num_rows * num_cols`` colors by calling ``color_map`` with integer indices.

    Args:
        color_map: callable mapping an integer index to a color, typically a
            matplotlib ``Colormap`` such as ``matplotlib.colormaps["tab20b"]``.
        num_rows: number of table rows.
        num_cols: number of table columns.

    A matplotlib ``Colormap`` only gives distinct colors for integer inputs in
    ``[0, N)``; larger indices all map to the same "over" color. When
    ``color_map`` exposes a positive integer ``N``, the index is wrapped modulo
    ``N``. Grids with more cells than palette entries therefore cycle through the
    palette instead of repeating a single color. Callables without ``N`` receive
    the raw indices ``0, 1, 2, ...``.
    """
    num_colors = num_rows * num_cols
    palette_size = getattr(color_map, "N", None)
    wrap = isinstance(palette_size, int) and palette_size > 0
    for idx in range(num_colors):
        yield color_map(idx % palette_size if wrap else idx)
def _canonicalize_color(color: Color) -> str:
    """
    Return ``color`` as an upper-case ``#RRGGBB`` hex string.

    A string is returned unchanged. Anything else is unpacked as an (r, g, b)
    triple; each component is multiplied by 255 and truncated to an int.
    """
    if not isinstance(color, str):
        red, green, blue = (int(component * 255) for component in color)
        return "#%02X%02X%02X" % (red, green, blue)
    return color
def _get_text_color(color: str) -> str:
r, g, b = map(lambda x: int(x, 16), (color[1:3], color[3:5], color[5:7])) # noqa: C417
if (r * 0.299 + g * 0.587 + b * 0.114) > 186:
return "#000000"
return "#ffffff"
def _create_rich_table(
    shape: ShapeType,
    shards: list[tuple[tuple[int, int], tuple[int, int], int]],
    device_kind: str = "",
    scale: float = 1.0,
    min_width: int = 9,
    max_width: int = 80,
):
    """
    Print a ``rich`` table in which every cell stands for one block of the tensor.

    Cell heights and widths are proportional to the extent of the block they
    represent. Cells are colored when the console reports a color system, and
    drawn with box lines otherwise.

    Args:
        shape: global tensor shape. A 1D shape is drawn as a single column,
            i.e. as if it were ``(shape[0], 1)``.
        shards: ``(row_range, col_range, device_id)`` triples; both ranges are
            inclusive ``(first, last)`` index pairs.
        device_kind: prefix of every cell label, e.g. ``"cuda"``.
        scale: multiplier on the nominal table height of 10 lines.
        min_width: lower bound, in characters, on each cell's width.
        max_width: upper bound on each cell's width; also the console width.
    """
    import matplotlib
    import rich.align
    import rich.box
    import rich.console
    import rich.padding
    import rich.style
    import rich.table

    # A 1D tensor is laid out as one column so that its shards stack vertically.
    # (The previous ``len(shape) > 0`` test indexed ``shape[1]`` on 1D shapes.)
    dtensor_height = shape[0]
    dtensor_width = shape[1] if len(shape) > 1 else 1

    row_ranges = sorted({s[0] for s in shards})
    col_ranges = sorted({s[1] for s in shards})

    # Group device ids by the block they hold, keeping the order of ``shards``;
    # avoids rescanning every shard for every cell.
    cell_devices: dict[tuple[tuple[int, int], tuple[int, int]], list[str]] = {}
    for row_range, col_range, device_id in shards:
        cell_devices.setdefault((row_range, col_range), []).append(str(device_id))

    console = rich.console.Console(width=max_width)
    use_color = console.color_system
    color_iter = make_color_iter(
        matplotlib.colormaps["tab20b"], len(row_ranges), len(col_ranges)
    )

    base_height = int(10 * scale)
    aspect_ratio = dtensor_width / dtensor_height
    base_width = int(base_height * aspect_ratio)
    # Widths are stretched by this factor, presumably to compensate for
    # terminal characters being taller than they are wide.
    height_to_width_ratio = 2.5

    table = rich.table.Table(
        show_header=False,
        show_lines=not use_color,
        padding=0,
        highlight=not use_color,
        pad_edge=False,
        box=rich.box.SQUARE if not use_color else None,
    )

    # Colored cells get one extra character of padding on every side.
    color_padding = 1 if use_color else 0

    for row_range in row_ranges:
        # Ranges are inclusive, so a block spans ``last - first + 1`` indices.
        height = (row_range[1] - row_range[0] + 1) / dtensor_height
        height = int(height * base_height)
        top_padding, row_rem = divmod(height - 2, 2)
        bottom_padding = top_padding + row_rem

        table_row = []
        for col_range in col_ranges:
            entry = device_kind + ":" + ",".join(
                cell_devices.get((row_range, col_range), [])
            )

            width = (col_range[1] - col_range[0] + 1) / dtensor_width
            width = int(width * base_width * height_to_width_ratio)
            width = min(max(width, min_width), max_width)
            left_padding, col_rem = divmod(width - len(entry) - 2, 2)
            right_padding = left_padding + col_rem

            if use_color:
                color = _canonicalize_color(next(color_iter)[:3])
                text_color = _get_text_color(color)
            else:
                color = None
                text_color = None

            # rich padding order is (top, right, bottom, left).
            padding = (
                max(top_padding + color_padding, 0),
                max(right_padding + color_padding, 0),
                max(bottom_padding + color_padding, 0),
                max(left_padding + color_padding, 0),
            )
            table_row.append(
                rich.padding.Padding(
                    rich.align.Align(entry, "center", vertical="middle"),
                    padding,
                    style=rich.style.Style(bgcolor=color, color=text_color),
                )
            )
        table.add_row(*table_row)

    console.print(table, end="\n\n")
def visualize_sharding(dtensor, header="", use_rich: bool = False):
    """
    Visualizes sharding in the terminal for :class:`DTensor` that are 1D or 2D.

    .. note:: This requires the ``tabulate`` package, or ``rich`` and ``matplotlib``.
              No sharding info will be printed for empty tensors

    Output is produced once per DTensor, on the rank whose coordinate is 0 on
    every mesh dimension; other ranks (and ranks outside the mesh) print nothing.

    Args:
        dtensor (:class:`DTensor`): the tensor whose shard layout is shown. A
            1D tensor is drawn as a single column of row ranges.
        header (str): text printed on its own line before the visualization.
            Nothing extra is printed when it is empty (the default).
        use_rich (bool): draw with ``rich`` when both ``rich`` and
            ``matplotlib`` are importable; otherwise ``tabulate`` is used.

    Raises:
        RuntimeError: if ``dtensor`` is 0D or has more than two dimensions.
        ValueError: if the ``rich`` path is not taken and ``tabulate`` is not
            installed.
    """
    if dtensor.numel() == 0:  # Do not print empty dtensors.
        return
    ndim = len(dtensor.shape)
    # 0D tensors have numel() == 1 and previously crashed with IndexError below.
    if ndim == 0 or ndim >= 3:
        raise RuntimeError("visualize sharding supports only 1D or 2D DTensor")

    mesh = dtensor.device_mesh
    if mesh.get_coordinate() is None:  # current rank is not in the mesh
        return
    # Only display the visualization once for each DTensor, on the rank whose
    # coordinate is 0 on all dimensions. For example, if the mesh is a full mesh,
    # we will only print on rank 0.
    local_rank_zero_on_all_dim = all(
        mesh.get_local_rank(mesh_dim=dim) == 0 for dim in range(mesh.ndim)
    )
    if not local_rank_zero_on_all_dim:
        return

    # Global device index -> its coordinate within the device mesh.
    device_coords = {
        int(device_index.item()): list(coord)
        for coord, device_index in np.ndenumerate(np.array(mesh.mesh.tolist()))
    }

    shards = []
    for device_index, coord in device_coords.items():
        local_shape, global_offset = _compute_local_shape_and_global_offset(
            dtensor.shape,
            mesh.shape,
            coord,
            dtensor.placements,
        )
        # Ranges are inclusive ``(first, last)`` index pairs.
        row_range = (global_offset[0], global_offset[0] + local_shape[0] - 1)
        if ndim == 2:
            col_range = (global_offset[1], global_offset[1] + local_shape[1] - 1)
        else:
            # A 1D tensor has no second dim; draw it as a single column.
            col_range = (0, 0)
        shards.append((row_range, col_range, device_index))

    render_with_rich = bool(
        importlib.util.find_spec("rich")
        and importlib.util.find_spec("matplotlib")
        and use_rich
    )
    if not render_with_rich and importlib.util.find_spec("tabulate") is None:
        raise ValueError("`visualize_sharding` requires either `rich` or `tabulate`.")

    # Print the header only once a backend is known to be usable.
    if header:
        print(header)
    if render_with_rich:
        _create_rich_table(dtensor.shape, shards, device_kind=mesh.device_type)
    else:
        print(_create_table(shards, device_kind=mesh.device_type))