
Commit 0af3239

[Inductor] Migrate from oneDNN Inner Product to oneDNN MatMul for mkldnn._linear_pointwise and mkldnn._linear_pointwise.binary
ghstack-source-id: c376010
Pull Request resolved: #147360
1 parent: 19ed227
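
oneDNN's inner product primitive consumes the weight in the (out_features, in_features) layout that nn.Linear stores, while oneDNN matmul multiplies src (M, K) by weights (K, N), i.e. it wants the transposed (in_features, out_features) layout. The weight handed to these kernels therefore changes layout, and every place that read the output-feature count from size(0) now reads it from size(1). A minimal sketch of the equivalence the migration relies on (shapes are illustrative):

import torch
import torch.nn.functional as F

x = torch.randn(32, 64)       # (batch, in_features)
w = torch.randn(128, 64)      # nn.Linear layout: (out_features, in_features)
w_t = w.t().contiguous()      # matmul layout: (in_features, out_features)

# inner-product-style linear and matmul on the transposed weight agree:
assert torch.allclose(F.linear(x, w), x @ w_t, atol=1e-5)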

File tree

6 files changed: +43 -36 lines changed


aten/src/ATen/native/mkldnn/Linear.cpp

Lines changed: 19 additions & 17 deletions

@@ -206,14 +206,14 @@ Tensor mkldnn_linear_pointwise(
       dim == 2 ? input : input.reshape({-1, input.size(input.dim() - 1)});
 
   std::vector<int64_t> output_size(input_size.begin(), input_size.end() - 1);
-  output_size.push_back(weight_t.size(0));
+  output_size.push_back(weight_t.size(1));
   auto output = at::empty(output_size, input.options());
   if (output.sym_numel() == 0) {
     return output;
   }
   if (dim != 2) {
     std::vector<int64_t> output_size_reshaped = {input_reshaped.size(0),
-                                                 weight_t.size(0)};
+                                                 weight_t.size(1)};
     output = output.reshape(output_size_reshaped);
   }
@@ -228,7 +228,7 @@ Tensor mkldnn_linear_pointwise(
 
   std::optional<ideep::tensor> mkldnn_bias{std::nullopt};
   if (bias.defined()) {
-    mkldnn_bias = itensor_from_tensor(bias);
+    mkldnn_bias = itensor_from_tensor(bias.reshape({1, weight_t.size(1)}));
   }
   const ideep::tensor w = itensor_from_tensor(weight_t);
 
@@ -241,20 +241,22 @@ Tensor mkldnn_linear_pointwise(
   }
 
   if (mkldnn_bias.has_value()) {
-    ideep::inner_product_forward::compute</*reorder_src=*/false, /*reorder_weight=*/false>(
+    ideep::matmul_forward::compute</*reorder_src=*/false, /*reorder_weight=*/false>(
         mkldnn_input,
         w,
         mkldnn_bias.value(),
         mkldnn_output,
-        op_attr,
-        aprop_kind);
+        1.0f,
+        1.0f,
+        op_attr);
   } else {
-    ideep::inner_product_forward::compute</*reorder_src=*/false, /*reorder_weight=*/false>(
+    ideep::matmul_forward::compute</*reorder_src=*/false, /*reorder_weight=*/false>(
         mkldnn_input,
         w,
         mkldnn_output,
-        op_attr,
-        aprop_kind);
+        1.0f,
+        1.0f,
+        op_attr);
   }
 
   if (dim != 2) {
@@ -300,7 +302,7 @@ Tensor mkldnn_linear_pointwise_binary(
       dim == 2 ? input : input.reshape({-1, input.size(input.dim() - 1)});
 
   std::vector<int64_t> output_size(input_size.begin(), input_size.end() - 1);
-  output_size.push_back(weight_t.size(0));
+  output_size.push_back(weight_t.size(1));
   auto output = at::empty(output_size, input.options());
   if (output.sym_numel() == 0) {
     return output;
@@ -310,7 +312,7 @@ Tensor mkldnn_linear_pointwise_binary(
 
   if (dim != 2) {
     std::vector<int64_t> output_size_reshaped = {
-        input_reshaped.size(0), weight_t.size(0)};
+        input_reshaped.size(0), weight_t.size(1)};
     output = output.reshape(output_size_reshaped);
     other_reshaped = other_reshaped.reshape(output_size_reshaped);
     TORCH_CHECK(
@@ -329,25 +331,25 @@ Tensor mkldnn_linear_pointwise_binary(
 
   std::optional<ideep::tensor> mkldnn_bias{std::nullopt};
   if (bias.defined()) {
-    mkldnn_bias = itensor_from_tensor(bias);
+    mkldnn_bias = itensor_from_tensor(bias.reshape({1, weight_t.size(1)}));
  }
   const ideep::tensor w = itensor_from_tensor(weight_t);
 
   auto other_desc = mkldnn_other.get_desc();
   auto op_attr = ideep::attr_t::fuse_binary(it_binary->second, other_desc);
 
   if (mkldnn_bias.has_value()) {
-    ideep::inner_product_forward::compute_binary</*reorder_src=*/false, /*reorder_weight=*/false>(
+    ideep::matmul_forward::compute_binary</*reorder_src=*/false, /*reorder_weight=*/false>(
         mkldnn_input,
         mkldnn_other,
         w,
         mkldnn_bias.value(),
         mkldnn_output,
-        op_attr,
-        aprop_kind);
+        1.0f,
+        op_attr);
   } else {
-    ideep::inner_product_forward::compute_binary</*reorder_src=*/false, /*reorder_weight=*/false>(
-        mkldnn_input, mkldnn_other, w, mkldnn_output, op_attr, aprop_kind);
+    ideep::matmul_forward::compute_binary</*reorder_src=*/false, /*reorder_weight=*/false>(
+        mkldnn_input, mkldnn_other, w, mkldnn_output, 1.0f, op_attr);
   }
 
   if (dim != 2) {
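
Two details of the new compute calls are worth noting: the bias is reshaped to a (1, out_features) row so it broadcasts over the M rows of the matmul destination (inner product took a 1-D per-channel bias), and the matmul overloads here take two 1.0f coefficients plus op_attr in place of the inner product's op_attr and aprop_kind. A hedged sketch of the bias shape math, with illustrative sizes:

import torch

x = torch.randn(32, 64)      # (M, in_features), after flattening leading dims
w_t = torch.randn(64, 128)   # (in_features, out_features)
b = torch.randn(128)         # 1-D bias as the op receives it

y = x @ w_t + b.reshape(1, 128)   # bias broadcast as a (1, oc) row
assert y.shape == (32, 128)
assert torch.allclose(y, torch.nn.functional.linear(x, w_t.t(), b), atol=1e-5)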

aten/src/ATen/native/mkldnn/MKLDNNConversions.cpp

Lines changed: 5 additions & 6 deletions

@@ -260,21 +260,20 @@ static Tensor mkldnn_reorder_linear_weight(
     const Tensor& self,
     std::optional<int64_t> batch_size_opt) {
   mkldnn_check_low_precision(self.scalar_type(), "mkldnn_reorder_linear_weight");
-  auto out_features = self.size(0);
-  auto in_features = self.size(1);
+  auto in_features = self.size(0);
+  auto out_features = self.size(1);
   auto self_ = self.contiguous();
   auto w = itensor_from_tensor(self_);
   ideep::dims input_size;
   auto dtype = w.get_data_type();
   if (batch_size_opt.has_value()) {
     input_size = {batch_size_opt.value(), in_features};
   }
-  auto packed_desc = ideep::inner_product_forward::expected_weights_desc(
-      {out_features, in_features},
+  auto packed_desc = ideep::matmul_forward::expected_weights_desc(
+      {in_features, out_features},
       input_size,
       /* weight dtype */ dtype,
-      /* src dtype */ dtype,
-      ideep::prop_kind::forward_inference);
+      /* src dtype */ dtype);
   ideep::tensor result;
   result.init(packed_desc);
   result.feed_from(w);
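
The prepack path mirrors the kernel change: the feature counts are read from the transposed layout, and the packed descriptor is queried from the matmul primitive (which no longer takes a prop_kind argument here). A small sketch of the dims involved, with illustrative sizes; the real query goes through ideep:

import torch

w_t = torch.randn(64, 128)   # prepack input: (in_features, out_features)
in_features, out_features = w_t.size(0), w_t.size(1)
batch_size = 32              # optional hint used only for the src shape

weight_dims = [in_features, out_features]   # was [out_features, in_features]
src_dims = [batch_size, in_features]        # input_size for expected_weights_desc
assert weight_dims == [64, 128] and src_dims == [32, 64]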

torch/_inductor/fx_passes/mkldnn_fusion.py

Lines changed: 14 additions & 8 deletions

@@ -925,9 +925,7 @@ def is_linear_add_bias(match):
     linear_node = add_node.args[0]
     packed_weight_node = linear_node.args[1]
     assert packed_weight_node.target == mkldnn._reorder_linear_weight
-    transpose_weight_node = packed_weight_node.args[0]
-    assert transpose_weight_node.target == aten.permute.default
-    weight_meta = transpose_weight_node.args[0].meta.get("val")
+    weight_meta = packed_weight_node.args[0].meta.get("val")
     bias_node = add_node.args[1]
     if isinstance(bias_node, int):
         # we only folding bias if it is a constant
@@ -1300,9 +1298,6 @@ def linear(match, *args, **kwargs):
     )
     weight = args[1] if linear_node.target == aten.mm.default else args[2]
     with graph.inserting_before(linear_node):
-        transpose_weight_node = graph.create_node(
-            "call_function", aten.permute.default, (weight, (1, 0))
-        )
         weight_dtype = weight.meta.get("val").dtype
         is_lp_weight = weight_dtype in (
             torch.bfloat16,
@@ -1313,9 +1308,20 @@ def linear(match, *args, **kwargs):
         assert (
             is_lp_weight or mkldnn._is_mkldnn_acl_supported()
         ), f"only bf16/fp16 weight prepacking supports dynamic shape inputs but got {weight_dtype}"
+        weight_node = (
+            weight
+            if (
+                is_lp_weight
+                or mkldnn._is_mkldnn_acl_supported()
+                or V.aot_compilation
+            )
+            else graph.create_node(
+                "call_function", aten.permute.default, (weight, (1, 0))
+            )
+        )
         # For bfloat16 dynamic shape path, using input size hint to pack weight for a better performance.
         packed_weight_inputs = (
-            transpose_weight_node,
+            weight_node,
             batch_size.node.shape_env.size_hint(batch_size.node.expr)
             if has_free_symbols(batch_size)
             else batch_size,
@@ -1347,7 +1353,7 @@ def linear(match, *args, **kwargs):
             packed_linear_inputs += (bias, "none", [], "")
             packed_linear_op = mkldnn._linear_pointwise.default
         else:
-            packed_linear_inputs += (transpose_weight_node, bias, batch_size)
+            packed_linear_inputs += (weight_node, bias, batch_size)
             packed_linear_op = torch.ops.mkl._mkl_linear
         packed_linear_node = graph.create_node(
             "call_function", packed_linear_op, packed_linear_inputs

torch/_inductor/mkldnn_ir.py

Lines changed: 2 additions & 2 deletions

@@ -833,7 +833,7 @@ def create(cls, x, w, B, attr, scalars, algorithm):
         w = cls.require_contiguous(cls.realize_input(w))
 
         *m, _ic = x.get_size()
-        oc, _ic = w.get_size()
+        _ic, oc = w.get_size()
         output_size = list(m) + [oc]
         inputs = [x, w]
         constant_args = [attr, scalars if scalars else [-1], algorithm]
@@ -887,7 +887,7 @@ def create(cls, x, y, w, B, attr):
         w = cls.require_contiguous(cls.realize_input(w))
 
         *m, _ic = x.get_size()
-        oc, _ic = w.get_size()
+        _ic, oc = w.get_size()
         output_size = list(m) + [oc]
         inputs = [x, y, w]
         constant_args = [attr]
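
The IR nodes compute the output shape the same way as before; only the unpacking order of the weight's size changes. A standalone restatement, runnable as-is:

# Output-shape computation with the new (in_features, out_features) weight layout:
def linear_output_size(x_size, w_size):
    *m, _ic = x_size    # e.g. [batch, seq, in_features]
    _ic, oc = w_size    # weight is (in_features, out_features); was (oc, ic)
    return list(m) + [oc]

assert linear_output_size([8, 16, 64], [64, 128]) == [8, 16, 128]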

torch/_inductor/mkldnn_lowerings.py

Lines changed: 2 additions & 2 deletions

@@ -258,7 +258,7 @@ def epilogue_creator(buf):
 
         kwargs = dict(
             has_bias=b is not None,
-            trans_w=True,
+            trans_w=False,
             epilogue_creator=None if attr == "none" else epilogue_creator,
         )
         if b is not None:
@@ -321,7 +321,7 @@ def epilogue_creator(buf):
 
         kwargs = dict(
             has_bias=b is not None,
-            trans_w=True,
+            trans_w=False,
             epilogue_creator=epilogue_creator,
         )
         kwargs["input_indices"] = [0, 2, 1] if b is None else [3, 0, 2, 1]

torch/_meta_registrations.py

Lines changed: 1 addition & 1 deletion

@@ -2477,7 +2477,7 @@ def meta_mkldnn_convolution_default(
 def meta_linear_pointwise_default(
     input_tensor, weight, bias, attr, scalars, algorithm
 ):
-    return input_tensor.new_empty((*input_tensor.shape[:-1], weight.shape[0]))
+    return input_tensor.new_empty((*input_tensor.shape[:-1], weight.shape[1]))
 
 if torch._C.has_mkl:
     _meta_lib_dont_use_me_use_register_meta_for_mkl = torch.library.Library(
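
The meta kernel has to agree with the real kernel's output shape, so it now takes the output features from dim 1 of the weight. A quick check on meta tensors (illustrative sizes):

import torch

def meta_linear_pointwise_sketch(input_tensor, weight):
    # weight arrives as (in_features, out_features) after this change
    return input_tensor.new_empty((*input_tensor.shape[:-1], weight.shape[1]))

x = torch.empty(2, 3, 64, device="meta")
w = torch.empty(64, 128, device="meta")
assert meta_linear_pointwise_sketch(x, w).shape == (2, 3, 128)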
