-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[MLIR] Fix issues with XeGPU to XeVM pass. #155946
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir Author: Sang Ik Lee (silee2) ChangesFixes two issue with XeGPU to XeVM pass
Full diff: https://github.com/llvm/llvm-project/pull/155946.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index d8dd09a6280c0..a7f2dc2d6a43e 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -259,7 +259,7 @@ class UpdateNdOffsetToXeVMPattern
// Only 2D offsets are supported for now.
if (mixedOffsets.size() != 2)
return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
- auto tdesc = adaptor.getTensorDesc();
+ auto payload = adaptor.getTensorDesc();
// Utility for updating payload offset values from op fold result.
auto updateOffset = [&](unsigned idx, int payloadPos) -> Value {
Value offset =
@@ -267,15 +267,15 @@ class UpdateNdOffsetToXeVMPattern
offset = getValueOrCreateCastToIndexLike(rewriter, loc,
rewriter.getI32Type(), offset);
Value oldOffset =
- vector::ExtractOp::create(rewriter, loc, tdesc, payloadPos);
+ vector::ExtractOp::create(rewriter, loc, payload, payloadPos);
Value newOffset = arith::AddIOp::create(rewriter, loc, oldOffset, offset);
- return vector::InsertOp::create(rewriter, loc, newOffset, tdesc,
+ return vector::InsertOp::create(rewriter, loc, newOffset, payload,
payloadPos);
};
// Update offsets in the payload.
- auto val = updateOffset(0, static_cast<int>(NdTdescOffset::TensorOffsetH));
- val = updateOffset(1, static_cast<int>(NdTdescOffset::TensorOffsetW));
- rewriter.replaceOp(op, val);
+ payload = updateOffset(0, static_cast<int>(NdTdescOffset::TensorOffsetH));
+ payload = updateOffset(1, static_cast<int>(NdTdescOffset::TensorOffsetW));
+ rewriter.replaceOp(op, payload);
return success();
}
};
@@ -354,18 +354,23 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
auto tileH = tdescTy.getDimSize(0);
int32_t vblocks = tdescTy.getArrayLength();
if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
- VectorType srcVecTy = dyn_cast<VectorType>(adaptor.getValue().getType());
+ Value src = adaptor.getValue();
+ // If store value is a scalar, get value from op instead of adaptor.
+ // Adaptor might have optimized away single element vector
+ if (src.getType().isIntOrFloat()) {
+ src = op.getValue();
+ }
+ VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
if (!srcVecTy)
return rewriter.notifyMatchFailure(
op, "Expected store value to be a vector type.");
- auto storeCacheControl =
- translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
- Value src = adaptor.getValue();
// Get flat vector type of integer type with matching element bit size.
VectorType newSrcVecTy =
encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
if (srcVecTy != newSrcVecTy)
src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
+ auto storeCacheControl =
+ translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
xevm::BlockStore2dOp::create(
rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
offsetH, elemBitSize, tileW, tileH, src,
diff --git a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
index 4ff95b40fe68c..ed664a739d134 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
@@ -2,9 +2,9 @@
gpu.module @create_nd_tdesc {
// CHECK-LABEL: gpu.func @create_nd_tdesc
- // CHECK-SAME: %[[ARG0:.*]]: memref<8x16xf32, 1>, %[[ARG1:.*]]: ui64,
+ // CHECK-SAME: %[[ARG0:.*]]: memref<16x32xf32, 1>, %[[ARG1:.*]]: ui64,
// CHECK-SAME: %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index, %[[ARG6:.*]]: index, %[[ARG7:.*]]: index
- gpu.func @create_nd_tdesc(%src: memref<8x16xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index,
+ gpu.func @create_nd_tdesc(%src: memref<16x32xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index,
%stride1: index, %stride2: index, %offset1: index, %offset2: index) kernel {
// CHECK: %[[VAR0:.*]] = index.castu %[[ARG1]] : ui64 to index
// CHECK: %[[BASE_ADDR:.*]] = arith.index_castui %[[VAR0]] : index to i64
@@ -23,17 +23,17 @@ gpu.module @create_nd_tdesc {
%ptr_tdesc = xegpu.create_nd_tdesc %ptr, shape:[%shape1, %shape2], strides:[%stride1, %stride2]
: ui64 -> !xegpu.tensor_desc<8x16xf32>
- // CHECK: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<8x16xf32, 1> to memref<8x16xf32>
- %srcce = memref.memory_space_cast %src : memref<8x16xf32, 1> to memref<8x16xf32>
+ // CHECK: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<16x32xf32, 1> to memref<16x32xf32>
+ %srcce = memref.memory_space_cast %src : memref<16x32xf32, 1> to memref<16x32xf32>
// CHECK: %[[CST_1:.*]] = arith.constant dense<0> : vector<8xi32>
- // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<8x16xf32> -> index
+ // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<16x32xf32> -> index
// CHECK: %[[OFFSET_W2:.*]] = arith.constant 0 : i32
// CHECK: %[[OFFSET_H2:.*]] = arith.constant 0 : i32
+ // CHECK: %[[C32_I64:.*]] = arith.constant 32 : i64
+ // CHECK: %[[SHAPE_W2:.*]] = arith.trunci %[[C32_I64]] : i64 to i32
// CHECK: %[[C16_I64:.*]] = arith.constant 16 : i64
- // CHECK: %[[SHAPE_W2:.*]] = arith.trunci %c16_i64 : i64 to i32
- // CHECK: %[[C8_I64:.*]] = arith.constant 8 : i64
- // CHECK: %[[SHAPE_H2:.*]] = arith.trunci %c8_i64 : i64 to i32
+ // CHECK: %[[SHAPE_H2:.*]] = arith.trunci %[[C16_I64]] : i64 to i32
// CHECK: %[[BASE_ADDR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
// CHECK: %[[VAR14:.*]] = vector.bitcast %[[CST_1]] : vector<8xi32> to vector<4xi64>
// CHECK: %[[VAR15:.*]] = vector.insert %[[BASE_ADDR2]], %[[VAR14]] [0] : i64 into vector<4xi64>
@@ -41,17 +41,17 @@ gpu.module @create_nd_tdesc {
// CHECK: %[[VAR17:.*]] = vector.insert %[[SHAPE_W2]], %[[VAR16]] [2] : i32 into vector<8xi32>
// CHECK: %[[VAR18:.*]] = vector.insert %[[SHAPE_H2]], %[[VAR17]] [3] : i32 into vector<8xi32>
// CHECK: %[[VAR19:.*]] = vector.insert %[[OFFSET_W2]], %[[VAR18]] [4] : i32 into vector<8xi32>
- // CHECK: %[[VAR20:.*]] = vector.insert %[[OFFSET_H2]], %[[VAR19]] [5] : i32 into vector<8xi32>
- %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+ // CHECK: %[[PAYLOAD:.*]] = vector.insert %[[OFFSET_H2]], %[[VAR19]] [5] : i32 into vector<8xi32>
+ %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
// CHECK: %[[CST_4:.*]] = arith.constant dense<0> : vector<8xi32>
- // CHECK: %[[INTPTR_2:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<8x16xf32> -> index
+ // CHECK: %[[INTPTR_2:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<16x32xf32> -> index
// CHECK: %[[OFFSET_W3:.*]] = arith.index_cast %[[ARG7]] : index to i32
// CHECK: %[[OFFSET_H3:.*]] = arith.index_cast %[[ARG6]] : index to i32
- // CHECK: %[[C16_I64_6:.*]] = arith.constant 16 : i64
- // CHECK: %[[SHAPE_W3:.*]] = arith.trunci %[[C16_I64_6]] : i64 to i32
- // CHECK: %[[C8_I64_7:.*]] = arith.constant 8 : i64
- // CHECK: %[[SHAPE_H3:.*]] = arith.trunci %[[C8_I64_7]] : i64 to i32
+ // CHECK: %[[C32_I64_6:.*]] = arith.constant 32 : i64
+ // CHECK: %[[SHAPE_W3:.*]] = arith.trunci %[[C32_I64_6]] : i64 to i32
+ // CHECK: %[[C16_I64_7:.*]] = arith.constant 16 : i64
+ // CHECK: %[[SHAPE_H3:.*]] = arith.trunci %[[C16_I64_7]] : i64 to i32
// CHECK: %[[BASE_ADDR3:.*]] = arith.index_castui %[[INTPTR_2]] : index to i64
// CHECK: %[[VAR26:.*]] = vector.bitcast %[[CST_4]] : vector<8xi32> to vector<4xi64>
// CHECK: %[[VAR27:.*]] = vector.insert %[[BASE_ADDR3]], %[[VAR26]] [0] : i64 into vector<4xi64>
@@ -60,7 +60,21 @@ gpu.module @create_nd_tdesc {
// CHECK: %[[VAR30:.*]] = vector.insert %[[SHAPE_H3]], %[[VAR29]] [3] : i32 into vector<8xi32>
// CHECK: %[[VAR31:.*]] = vector.insert %[[OFFSET_W3]], %[[VAR30]] [4] : i32 into vector<8xi32>
// CHECK: %[[VAR32:.*]] = vector.insert %[[OFFSET_H3]], %[[VAR31]] [5] : i32 into vector<8xi32>
- %src_tdesc2 = xegpu.create_nd_tdesc %srcce[%offset1, %offset2] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+ %src_tdesc2 = xegpu.create_nd_tdesc %srcce[%offset1, %offset2] : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+
+ // CHECK: %[[C8:.*]] = arith.constant 8 : index
+ %c8 = arith.constant 8 : index
+ // CHECK: %[[C16:.*]] = arith.constant 16 : index
+ %c16 = arith.constant 16 : index
+ // CHECK: %[[VAR33:.*]] = arith.index_cast %[[C8]] : index to i32
+ // CHECK: %[[OLD_OFFSET_H:.*]] = vector.extract %[[PAYLOAD]][5] : i32 from vector<8xi32>
+ // CHECK: %[[NEW_OFFSET_H:.*]] = arith.addi %[[OLD_OFFSET_H]], %[[VAR33]] : i32
+ // CHECK: %[[NEW_PAYLOAD:.*]] = vector.insert %[[NEW_OFFSET_H]], %[[PAYLOAD]] [5] : i32 into vector<8xi32>
+ // CHECK: %[[VAR37:.*]] = arith.index_cast %[[C16]] : index to i32
+ // CHECK: %[[OLD_OFFSET_H:.*]] = vector.extract %[[NEW_PAYLOAD]][4] : i32 from vector<8xi32>
+ // CHECK: %[[NEW_OFFSET_H:.*]] = arith.addi %[[OLD_OFFSET_H]], %[[VAR37]] : i32
+ // CHECK: %[[FINAL_PAYLOAD:.*]] = vector.insert %[[NEW_OFFSET_H]], %[[NEW_PAYLOAD]] [4] : i32 into vector<8xi32>
+ %updated_tdesc = xegpu.update_nd_offset %src_tdesc, [%c8, %c16] : !xegpu.tensor_desc<8x16xf32>
gpu.return
}
}
|
@llvm/pr-subscribers-mlir-gpu Author: Sang Ik Lee (silee2) ChangesFixes two issue with XeGPU to XeVM pass
Full diff: https://github.com/llvm/llvm-project/pull/155946.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index d8dd09a6280c0..a7f2dc2d6a43e 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -259,7 +259,7 @@ class UpdateNdOffsetToXeVMPattern
// Only 2D offsets are supported for now.
if (mixedOffsets.size() != 2)
return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
- auto tdesc = adaptor.getTensorDesc();
+ auto payload = adaptor.getTensorDesc();
// Utility for updating payload offset values from op fold result.
auto updateOffset = [&](unsigned idx, int payloadPos) -> Value {
Value offset =
@@ -267,15 +267,15 @@ class UpdateNdOffsetToXeVMPattern
offset = getValueOrCreateCastToIndexLike(rewriter, loc,
rewriter.getI32Type(), offset);
Value oldOffset =
- vector::ExtractOp::create(rewriter, loc, tdesc, payloadPos);
+ vector::ExtractOp::create(rewriter, loc, payload, payloadPos);
Value newOffset = arith::AddIOp::create(rewriter, loc, oldOffset, offset);
- return vector::InsertOp::create(rewriter, loc, newOffset, tdesc,
+ return vector::InsertOp::create(rewriter, loc, newOffset, payload,
payloadPos);
};
// Update offsets in the payload.
- auto val = updateOffset(0, static_cast<int>(NdTdescOffset::TensorOffsetH));
- val = updateOffset(1, static_cast<int>(NdTdescOffset::TensorOffsetW));
- rewriter.replaceOp(op, val);
+ payload = updateOffset(0, static_cast<int>(NdTdescOffset::TensorOffsetH));
+ payload = updateOffset(1, static_cast<int>(NdTdescOffset::TensorOffsetW));
+ rewriter.replaceOp(op, payload);
return success();
}
};
@@ -354,18 +354,23 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
auto tileH = tdescTy.getDimSize(0);
int32_t vblocks = tdescTy.getArrayLength();
if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
- VectorType srcVecTy = dyn_cast<VectorType>(adaptor.getValue().getType());
+ Value src = adaptor.getValue();
+ // If store value is a scalar, get value from op instead of adaptor.
+ // Adaptor might have optimized away single element vector
+ if (src.getType().isIntOrFloat()) {
+ src = op.getValue();
+ }
+ VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
if (!srcVecTy)
return rewriter.notifyMatchFailure(
op, "Expected store value to be a vector type.");
- auto storeCacheControl =
- translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
- Value src = adaptor.getValue();
// Get flat vector type of integer type with matching element bit size.
VectorType newSrcVecTy =
encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
if (srcVecTy != newSrcVecTy)
src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
+ auto storeCacheControl =
+ translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
xevm::BlockStore2dOp::create(
rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
offsetH, elemBitSize, tileW, tileH, src,
diff --git a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
index 4ff95b40fe68c..ed664a739d134 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
@@ -2,9 +2,9 @@
gpu.module @create_nd_tdesc {
// CHECK-LABEL: gpu.func @create_nd_tdesc
- // CHECK-SAME: %[[ARG0:.*]]: memref<8x16xf32, 1>, %[[ARG1:.*]]: ui64,
+ // CHECK-SAME: %[[ARG0:.*]]: memref<16x32xf32, 1>, %[[ARG1:.*]]: ui64,
// CHECK-SAME: %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index, %[[ARG6:.*]]: index, %[[ARG7:.*]]: index
- gpu.func @create_nd_tdesc(%src: memref<8x16xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index,
+ gpu.func @create_nd_tdesc(%src: memref<16x32xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index,
%stride1: index, %stride2: index, %offset1: index, %offset2: index) kernel {
// CHECK: %[[VAR0:.*]] = index.castu %[[ARG1]] : ui64 to index
// CHECK: %[[BASE_ADDR:.*]] = arith.index_castui %[[VAR0]] : index to i64
@@ -23,17 +23,17 @@ gpu.module @create_nd_tdesc {
%ptr_tdesc = xegpu.create_nd_tdesc %ptr, shape:[%shape1, %shape2], strides:[%stride1, %stride2]
: ui64 -> !xegpu.tensor_desc<8x16xf32>
- // CHECK: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<8x16xf32, 1> to memref<8x16xf32>
- %srcce = memref.memory_space_cast %src : memref<8x16xf32, 1> to memref<8x16xf32>
+ // CHECK: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<16x32xf32, 1> to memref<16x32xf32>
+ %srcce = memref.memory_space_cast %src : memref<16x32xf32, 1> to memref<16x32xf32>
// CHECK: %[[CST_1:.*]] = arith.constant dense<0> : vector<8xi32>
- // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<8x16xf32> -> index
+ // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<16x32xf32> -> index
// CHECK: %[[OFFSET_W2:.*]] = arith.constant 0 : i32
// CHECK: %[[OFFSET_H2:.*]] = arith.constant 0 : i32
+ // CHECK: %[[C32_I64:.*]] = arith.constant 32 : i64
+ // CHECK: %[[SHAPE_W2:.*]] = arith.trunci %[[C32_I64]] : i64 to i32
// CHECK: %[[C16_I64:.*]] = arith.constant 16 : i64
- // CHECK: %[[SHAPE_W2:.*]] = arith.trunci %c16_i64 : i64 to i32
- // CHECK: %[[C8_I64:.*]] = arith.constant 8 : i64
- // CHECK: %[[SHAPE_H2:.*]] = arith.trunci %c8_i64 : i64 to i32
+ // CHECK: %[[SHAPE_H2:.*]] = arith.trunci %[[C16_I64]] : i64 to i32
// CHECK: %[[BASE_ADDR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
// CHECK: %[[VAR14:.*]] = vector.bitcast %[[CST_1]] : vector<8xi32> to vector<4xi64>
// CHECK: %[[VAR15:.*]] = vector.insert %[[BASE_ADDR2]], %[[VAR14]] [0] : i64 into vector<4xi64>
@@ -41,17 +41,17 @@ gpu.module @create_nd_tdesc {
// CHECK: %[[VAR17:.*]] = vector.insert %[[SHAPE_W2]], %[[VAR16]] [2] : i32 into vector<8xi32>
// CHECK: %[[VAR18:.*]] = vector.insert %[[SHAPE_H2]], %[[VAR17]] [3] : i32 into vector<8xi32>
// CHECK: %[[VAR19:.*]] = vector.insert %[[OFFSET_W2]], %[[VAR18]] [4] : i32 into vector<8xi32>
- // CHECK: %[[VAR20:.*]] = vector.insert %[[OFFSET_H2]], %[[VAR19]] [5] : i32 into vector<8xi32>
- %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+ // CHECK: %[[PAYLOAD:.*]] = vector.insert %[[OFFSET_H2]], %[[VAR19]] [5] : i32 into vector<8xi32>
+ %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
// CHECK: %[[CST_4:.*]] = arith.constant dense<0> : vector<8xi32>
- // CHECK: %[[INTPTR_2:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<8x16xf32> -> index
+ // CHECK: %[[INTPTR_2:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<16x32xf32> -> index
// CHECK: %[[OFFSET_W3:.*]] = arith.index_cast %[[ARG7]] : index to i32
// CHECK: %[[OFFSET_H3:.*]] = arith.index_cast %[[ARG6]] : index to i32
- // CHECK: %[[C16_I64_6:.*]] = arith.constant 16 : i64
- // CHECK: %[[SHAPE_W3:.*]] = arith.trunci %[[C16_I64_6]] : i64 to i32
- // CHECK: %[[C8_I64_7:.*]] = arith.constant 8 : i64
- // CHECK: %[[SHAPE_H3:.*]] = arith.trunci %[[C8_I64_7]] : i64 to i32
+ // CHECK: %[[C32_I64_6:.*]] = arith.constant 32 : i64
+ // CHECK: %[[SHAPE_W3:.*]] = arith.trunci %[[C32_I64_6]] : i64 to i32
+ // CHECK: %[[C16_I64_7:.*]] = arith.constant 16 : i64
+ // CHECK: %[[SHAPE_H3:.*]] = arith.trunci %[[C16_I64_7]] : i64 to i32
// CHECK: %[[BASE_ADDR3:.*]] = arith.index_castui %[[INTPTR_2]] : index to i64
// CHECK: %[[VAR26:.*]] = vector.bitcast %[[CST_4]] : vector<8xi32> to vector<4xi64>
// CHECK: %[[VAR27:.*]] = vector.insert %[[BASE_ADDR3]], %[[VAR26]] [0] : i64 into vector<4xi64>
@@ -60,7 +60,21 @@ gpu.module @create_nd_tdesc {
// CHECK: %[[VAR30:.*]] = vector.insert %[[SHAPE_H3]], %[[VAR29]] [3] : i32 into vector<8xi32>
// CHECK: %[[VAR31:.*]] = vector.insert %[[OFFSET_W3]], %[[VAR30]] [4] : i32 into vector<8xi32>
// CHECK: %[[VAR32:.*]] = vector.insert %[[OFFSET_H3]], %[[VAR31]] [5] : i32 into vector<8xi32>
- %src_tdesc2 = xegpu.create_nd_tdesc %srcce[%offset1, %offset2] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+ %src_tdesc2 = xegpu.create_nd_tdesc %srcce[%offset1, %offset2] : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+
+ // CHECK: %[[C8:.*]] = arith.constant 8 : index
+ %c8 = arith.constant 8 : index
+ // CHECK: %[[C16:.*]] = arith.constant 16 : index
+ %c16 = arith.constant 16 : index
+ // CHECK: %[[VAR33:.*]] = arith.index_cast %[[C8]] : index to i32
+ // CHECK: %[[OLD_OFFSET_H:.*]] = vector.extract %[[PAYLOAD]][5] : i32 from vector<8xi32>
+ // CHECK: %[[NEW_OFFSET_H:.*]] = arith.addi %[[OLD_OFFSET_H]], %[[VAR33]] : i32
+ // CHECK: %[[NEW_PAYLOAD:.*]] = vector.insert %[[NEW_OFFSET_H]], %[[PAYLOAD]] [5] : i32 into vector<8xi32>
+ // CHECK: %[[VAR37:.*]] = arith.index_cast %[[C16]] : index to i32
+ // CHECK: %[[OLD_OFFSET_H:.*]] = vector.extract %[[NEW_PAYLOAD]][4] : i32 from vector<8xi32>
+ // CHECK: %[[NEW_OFFSET_H:.*]] = arith.addi %[[OLD_OFFSET_H]], %[[VAR37]] : i32
+ // CHECK: %[[FINAL_PAYLOAD:.*]] = vector.insert %[[NEW_OFFSET_H]], %[[NEW_PAYLOAD]] [4] : i32 into vector<8xi32>
+ %updated_tdesc = xegpu.update_nd_offset %src_tdesc, [%c8, %c16] : !xegpu.tensor_desc<8x16xf32>
gpu.return
}
}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM % comments.
@@ -259,23 +259,23 @@ class UpdateNdOffsetToXeVMPattern | |||
// Only 2D offsets are supported for now. | |||
if (mixedOffsets.size() != 2) | |||
return rewriter.notifyMatchFailure(op, "Expected 2D offsets."); | |||
auto tdesc = adaptor.getTensorDesc(); | |||
auto payload = adaptor.getTensorDesc(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: spell the type name.
if (src.getType().isIntOrFloat()) { | ||
src = op.getValue(); | ||
} | ||
VectorType srcVecTy = dyn_cast<VectorType>(src.getType()); | ||
if (!srcVecTy) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
in the prev code it fails here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe separate test for update_nd? or rename the test.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1
Since this is a hot fix, will address the comments in a follow up PR. |
Fixes two issue with XeGPU to XeVM pass