Skip to content

Conversation

silee2
Copy link
Contributor

@silee2 silee2 commented Aug 28, 2025

Fixes two issues with the XeGPU to XeVM pass:

  1. The xegpu.update_nd_offset op lowering generated an incorrect code sequence.
  2. xegpu.store_nd did not lower a single-element vector.

@llvmbot
Copy link
Member

llvmbot commented Aug 28, 2025

@llvm/pr-subscribers-mlir

Author: Sang Ik Lee (silee2)

Changes

Fixes two issues with the XeGPU to XeVM pass:

  1. The xegpu.update_nd_offset op lowering generated an incorrect code sequence.
  2. xegpu.store_nd did not lower a single-element vector.

Full diff: https://github.com/llvm/llvm-project/pull/155946.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp (+15-10)
  • (modified) mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir (+30-16)
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index d8dd09a6280c0..a7f2dc2d6a43e 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -259,7 +259,7 @@ class UpdateNdOffsetToXeVMPattern
     // Only 2D offsets are supported for now.
     if (mixedOffsets.size() != 2)
       return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
-    auto tdesc = adaptor.getTensorDesc();
+    auto payload = adaptor.getTensorDesc();
     // Utility for updating payload offset values from op fold result.
     auto updateOffset = [&](unsigned idx, int payloadPos) -> Value {
       Value offset =
@@ -267,15 +267,15 @@ class UpdateNdOffsetToXeVMPattern
       offset = getValueOrCreateCastToIndexLike(rewriter, loc,
                                                rewriter.getI32Type(), offset);
       Value oldOffset =
-          vector::ExtractOp::create(rewriter, loc, tdesc, payloadPos);
+          vector::ExtractOp::create(rewriter, loc, payload, payloadPos);
       Value newOffset = arith::AddIOp::create(rewriter, loc, oldOffset, offset);
-      return vector::InsertOp::create(rewriter, loc, newOffset, tdesc,
+      return vector::InsertOp::create(rewriter, loc, newOffset, payload,
                                       payloadPos);
     };
     // Update offsets in the payload.
-    auto val = updateOffset(0, static_cast<int>(NdTdescOffset::TensorOffsetH));
-    val = updateOffset(1, static_cast<int>(NdTdescOffset::TensorOffsetW));
-    rewriter.replaceOp(op, val);
+    payload = updateOffset(0, static_cast<int>(NdTdescOffset::TensorOffsetH));
+    payload = updateOffset(1, static_cast<int>(NdTdescOffset::TensorOffsetW));
+    rewriter.replaceOp(op, payload);
     return success();
   }
 };
@@ -354,18 +354,23 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
     auto tileH = tdescTy.getDimSize(0);
     int32_t vblocks = tdescTy.getArrayLength();
     if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
-      VectorType srcVecTy = dyn_cast<VectorType>(adaptor.getValue().getType());
+      Value src = adaptor.getValue();
+      // If store value is a scalar, get value from op instead of adaptor.
+      // Adaptor might have optimized away single element vector
+      if (src.getType().isIntOrFloat()) {
+        src = op.getValue();
+      }
+      VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
       if (!srcVecTy)
         return rewriter.notifyMatchFailure(
             op, "Expected store value to be a vector type.");
-      auto storeCacheControl =
-          translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
-      Value src = adaptor.getValue();
       // Get flat vector type of integer type with matching element bit size.
       VectorType newSrcVecTy =
           encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
       if (srcVecTy != newSrcVecTy)
         src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
+      auto storeCacheControl =
+          translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
       xevm::BlockStore2dOp::create(
           rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
           offsetH, elemBitSize, tileW, tileH, src,
diff --git a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
index 4ff95b40fe68c..ed664a739d134 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
@@ -2,9 +2,9 @@
 
 gpu.module @create_nd_tdesc {
   // CHECK-LABEL: gpu.func @create_nd_tdesc
-  // CHECK-SAME: %[[ARG0:.*]]: memref<8x16xf32, 1>, %[[ARG1:.*]]: ui64,
+  // CHECK-SAME: %[[ARG0:.*]]: memref<16x32xf32, 1>, %[[ARG1:.*]]: ui64,
   // CHECK-SAME: %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index, %[[ARG6:.*]]: index, %[[ARG7:.*]]: index
-  gpu.func @create_nd_tdesc(%src: memref<8x16xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index,
+  gpu.func @create_nd_tdesc(%src: memref<16x32xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index,
   %stride1: index, %stride2: index, %offset1: index, %offset2: index) kernel {
         // CHECK: %[[VAR0:.*]] = index.castu %[[ARG1]] : ui64 to index
         // CHECK: %[[BASE_ADDR:.*]] = arith.index_castui %[[VAR0]] : index to i64
@@ -23,17 +23,17 @@ gpu.module @create_nd_tdesc {
         %ptr_tdesc = xegpu.create_nd_tdesc %ptr, shape:[%shape1, %shape2], strides:[%stride1, %stride2]
             : ui64 -> !xegpu.tensor_desc<8x16xf32>
 
-        // CHECK: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<8x16xf32, 1> to memref<8x16xf32>
-        %srcce = memref.memory_space_cast %src : memref<8x16xf32, 1> to memref<8x16xf32>
+        // CHECK: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<16x32xf32, 1> to memref<16x32xf32>
+        %srcce = memref.memory_space_cast %src : memref<16x32xf32, 1> to memref<16x32xf32>
 
         // CHECK: %[[CST_1:.*]] = arith.constant dense<0> : vector<8xi32>
-        // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<8x16xf32> -> index
+        // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<16x32xf32> -> index
         // CHECK: %[[OFFSET_W2:.*]] = arith.constant 0 : i32
         // CHECK: %[[OFFSET_H2:.*]] = arith.constant 0 : i32
+        // CHECK: %[[C32_I64:.*]] = arith.constant 32 : i64
+        // CHECK: %[[SHAPE_W2:.*]] = arith.trunci %[[C32_I64]] : i64 to i32
         // CHECK: %[[C16_I64:.*]] = arith.constant 16 : i64
-        // CHECK: %[[SHAPE_W2:.*]] = arith.trunci %c16_i64 : i64 to i32
-        // CHECK: %[[C8_I64:.*]] = arith.constant 8 : i64
-        // CHECK: %[[SHAPE_H2:.*]] = arith.trunci %c8_i64 : i64 to i32
+        // CHECK: %[[SHAPE_H2:.*]] = arith.trunci %[[C16_I64]] : i64 to i32
         // CHECK: %[[BASE_ADDR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
         // CHECK: %[[VAR14:.*]] = vector.bitcast %[[CST_1]] : vector<8xi32> to vector<4xi64>
         // CHECK: %[[VAR15:.*]] = vector.insert %[[BASE_ADDR2]], %[[VAR14]] [0] : i64 into vector<4xi64>
@@ -41,17 +41,17 @@ gpu.module @create_nd_tdesc {
         // CHECK: %[[VAR17:.*]] = vector.insert %[[SHAPE_W2]], %[[VAR16]] [2] : i32 into vector<8xi32>
         // CHECK: %[[VAR18:.*]] = vector.insert %[[SHAPE_H2]], %[[VAR17]] [3] : i32 into vector<8xi32>
         // CHECK: %[[VAR19:.*]] = vector.insert %[[OFFSET_W2]], %[[VAR18]] [4] : i32 into vector<8xi32>
-        // CHECK: %[[VAR20:.*]] = vector.insert %[[OFFSET_H2]], %[[VAR19]] [5] : i32 into vector<8xi32>
-        %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+        // CHECK: %[[PAYLOAD:.*]] = vector.insert %[[OFFSET_H2]], %[[VAR19]] [5] : i32 into vector<8xi32>
+        %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
 
         // CHECK: %[[CST_4:.*]] = arith.constant dense<0> : vector<8xi32>
-        // CHECK: %[[INTPTR_2:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<8x16xf32> -> index
+        // CHECK: %[[INTPTR_2:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<16x32xf32> -> index
         // CHECK: %[[OFFSET_W3:.*]] = arith.index_cast %[[ARG7]] : index to i32
         // CHECK: %[[OFFSET_H3:.*]] = arith.index_cast %[[ARG6]] : index to i32
-        // CHECK: %[[C16_I64_6:.*]] = arith.constant 16 : i64
-        // CHECK: %[[SHAPE_W3:.*]] = arith.trunci %[[C16_I64_6]] : i64 to i32
-        // CHECK: %[[C8_I64_7:.*]] = arith.constant 8 : i64
-        // CHECK: %[[SHAPE_H3:.*]] = arith.trunci %[[C8_I64_7]] : i64 to i32
+        // CHECK: %[[C32_I64_6:.*]] = arith.constant 32 : i64
+        // CHECK: %[[SHAPE_W3:.*]] = arith.trunci %[[C32_I64_6]] : i64 to i32
+        // CHECK: %[[C16_I64_7:.*]] = arith.constant 16 : i64
+        // CHECK: %[[SHAPE_H3:.*]] = arith.trunci %[[C16_I64_7]] : i64 to i32
         // CHECK: %[[BASE_ADDR3:.*]] = arith.index_castui %[[INTPTR_2]] : index to i64
         // CHECK: %[[VAR26:.*]] = vector.bitcast %[[CST_4]] : vector<8xi32> to vector<4xi64>
         // CHECK: %[[VAR27:.*]] = vector.insert %[[BASE_ADDR3]], %[[VAR26]] [0] : i64 into vector<4xi64>
@@ -60,7 +60,21 @@ gpu.module @create_nd_tdesc {
         // CHECK: %[[VAR30:.*]] = vector.insert %[[SHAPE_H3]], %[[VAR29]] [3] : i32 into vector<8xi32>
         // CHECK: %[[VAR31:.*]] = vector.insert %[[OFFSET_W3]], %[[VAR30]] [4] : i32 into vector<8xi32>
         // CHECK: %[[VAR32:.*]] = vector.insert %[[OFFSET_H3]], %[[VAR31]] [5] : i32 into vector<8xi32>
-        %src_tdesc2 = xegpu.create_nd_tdesc %srcce[%offset1, %offset2] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+        %src_tdesc2 = xegpu.create_nd_tdesc %srcce[%offset1, %offset2] : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+
+        // CHECK: %[[C8:.*]] = arith.constant 8 : index
+        %c8 = arith.constant 8 : index
+        // CHECK: %[[C16:.*]] = arith.constant 16 : index
+        %c16 = arith.constant 16 : index
+        // CHECK: %[[VAR33:.*]] = arith.index_cast %[[C8]] : index to i32
+        // CHECK: %[[OLD_OFFSET_H:.*]] = vector.extract %[[PAYLOAD]][5] : i32 from vector<8xi32>
+        // CHECK: %[[NEW_OFFSET_H:.*]] = arith.addi %[[OLD_OFFSET_H]], %[[VAR33]] : i32
+        // CHECK: %[[NEW_PAYLOAD:.*]] = vector.insert %[[NEW_OFFSET_H]], %[[PAYLOAD]] [5] : i32 into vector<8xi32>
+        // CHECK: %[[VAR37:.*]] = arith.index_cast %[[C16]] : index to i32
+        // CHECK: %[[OLD_OFFSET_H:.*]] = vector.extract %[[NEW_PAYLOAD]][4] : i32 from vector<8xi32>
+        // CHECK: %[[NEW_OFFSET_H:.*]] = arith.addi %[[OLD_OFFSET_H]], %[[VAR37]] : i32
+        // CHECK: %[[FINAL_PAYLOAD:.*]] = vector.insert %[[NEW_OFFSET_H]], %[[NEW_PAYLOAD]] [4] : i32 into vector<8xi32>
+        %updated_tdesc = xegpu.update_nd_offset %src_tdesc, [%c8, %c16] : !xegpu.tensor_desc<8x16xf32>
         gpu.return
     }
 }

@llvmbot
Copy link
Member

llvmbot commented Aug 28, 2025

@llvm/pr-subscribers-mlir-gpu

Author: Sang Ik Lee (silee2)

Changes

Fixes two issues with the XeGPU to XeVM pass:

  1. The xegpu.update_nd_offset op lowering generated an incorrect code sequence.
  2. xegpu.store_nd did not lower a single-element vector.

Full diff: https://github.com/llvm/llvm-project/pull/155946.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp (+15-10)
  • (modified) mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir (+30-16)
diff --git a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
index d8dd09a6280c0..a7f2dc2d6a43e 100644
--- a/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
+++ b/mlir/lib/Conversion/XeGPUToXeVM/XeGPUToXeVM.cpp
@@ -259,7 +259,7 @@ class UpdateNdOffsetToXeVMPattern
     // Only 2D offsets are supported for now.
     if (mixedOffsets.size() != 2)
       return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
-    auto tdesc = adaptor.getTensorDesc();
+    auto payload = adaptor.getTensorDesc();
     // Utility for updating payload offset values from op fold result.
     auto updateOffset = [&](unsigned idx, int payloadPos) -> Value {
       Value offset =
@@ -267,15 +267,15 @@ class UpdateNdOffsetToXeVMPattern
       offset = getValueOrCreateCastToIndexLike(rewriter, loc,
                                                rewriter.getI32Type(), offset);
       Value oldOffset =
-          vector::ExtractOp::create(rewriter, loc, tdesc, payloadPos);
+          vector::ExtractOp::create(rewriter, loc, payload, payloadPos);
       Value newOffset = arith::AddIOp::create(rewriter, loc, oldOffset, offset);
-      return vector::InsertOp::create(rewriter, loc, newOffset, tdesc,
+      return vector::InsertOp::create(rewriter, loc, newOffset, payload,
                                       payloadPos);
     };
     // Update offsets in the payload.
-    auto val = updateOffset(0, static_cast<int>(NdTdescOffset::TensorOffsetH));
-    val = updateOffset(1, static_cast<int>(NdTdescOffset::TensorOffsetW));
-    rewriter.replaceOp(op, val);
+    payload = updateOffset(0, static_cast<int>(NdTdescOffset::TensorOffsetH));
+    payload = updateOffset(1, static_cast<int>(NdTdescOffset::TensorOffsetW));
+    rewriter.replaceOp(op, payload);
     return success();
   }
 };
@@ -354,18 +354,23 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
     auto tileH = tdescTy.getDimSize(0);
     int32_t vblocks = tdescTy.getArrayLength();
     if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
-      VectorType srcVecTy = dyn_cast<VectorType>(adaptor.getValue().getType());
+      Value src = adaptor.getValue();
+      // If store value is a scalar, get value from op instead of adaptor.
+      // Adaptor might have optimized away single element vector
+      if (src.getType().isIntOrFloat()) {
+        src = op.getValue();
+      }
+      VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
       if (!srcVecTy)
         return rewriter.notifyMatchFailure(
             op, "Expected store value to be a vector type.");
-      auto storeCacheControl =
-          translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
-      Value src = adaptor.getValue();
       // Get flat vector type of integer type with matching element bit size.
       VectorType newSrcVecTy =
           encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
       if (srcVecTy != newSrcVecTy)
         src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
+      auto storeCacheControl =
+          translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
       xevm::BlockStore2dOp::create(
           rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
           offsetH, elemBitSize, tileW, tileH, src,
diff --git a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
index 4ff95b40fe68c..ed664a739d134 100644
--- a/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
+++ b/mlir/test/Conversion/XeGPUToXeVM/create_nd_tdesc.mlir
@@ -2,9 +2,9 @@
 
 gpu.module @create_nd_tdesc {
   // CHECK-LABEL: gpu.func @create_nd_tdesc
-  // CHECK-SAME: %[[ARG0:.*]]: memref<8x16xf32, 1>, %[[ARG1:.*]]: ui64,
+  // CHECK-SAME: %[[ARG0:.*]]: memref<16x32xf32, 1>, %[[ARG1:.*]]: ui64,
   // CHECK-SAME: %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index, %[[ARG5:.*]]: index, %[[ARG6:.*]]: index, %[[ARG7:.*]]: index
-  gpu.func @create_nd_tdesc(%src: memref<8x16xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index,
+  gpu.func @create_nd_tdesc(%src: memref<16x32xf32, 1>, %ptr: ui64, %shape1: index, %shape2: index,
   %stride1: index, %stride2: index, %offset1: index, %offset2: index) kernel {
         // CHECK: %[[VAR0:.*]] = index.castu %[[ARG1]] : ui64 to index
         // CHECK: %[[BASE_ADDR:.*]] = arith.index_castui %[[VAR0]] : index to i64
@@ -23,17 +23,17 @@ gpu.module @create_nd_tdesc {
         %ptr_tdesc = xegpu.create_nd_tdesc %ptr, shape:[%shape1, %shape2], strides:[%stride1, %stride2]
             : ui64 -> !xegpu.tensor_desc<8x16xf32>
 
-        // CHECK: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<8x16xf32, 1> to memref<8x16xf32>
-        %srcce = memref.memory_space_cast %src : memref<8x16xf32, 1> to memref<8x16xf32>
+        // CHECK: %[[MEMSPACECAST:.*]] = memref.memory_space_cast %[[ARG0]] : memref<16x32xf32, 1> to memref<16x32xf32>
+        %srcce = memref.memory_space_cast %src : memref<16x32xf32, 1> to memref<16x32xf32>
 
         // CHECK: %[[CST_1:.*]] = arith.constant dense<0> : vector<8xi32>
-        // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<8x16xf32> -> index
+        // CHECK: %[[INTPTR:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<16x32xf32> -> index
         // CHECK: %[[OFFSET_W2:.*]] = arith.constant 0 : i32
         // CHECK: %[[OFFSET_H2:.*]] = arith.constant 0 : i32
+        // CHECK: %[[C32_I64:.*]] = arith.constant 32 : i64
+        // CHECK: %[[SHAPE_W2:.*]] = arith.trunci %[[C32_I64]] : i64 to i32
         // CHECK: %[[C16_I64:.*]] = arith.constant 16 : i64
-        // CHECK: %[[SHAPE_W2:.*]] = arith.trunci %c16_i64 : i64 to i32
-        // CHECK: %[[C8_I64:.*]] = arith.constant 8 : i64
-        // CHECK: %[[SHAPE_H2:.*]] = arith.trunci %c8_i64 : i64 to i32
+        // CHECK: %[[SHAPE_H2:.*]] = arith.trunci %[[C16_I64]] : i64 to i32
         // CHECK: %[[BASE_ADDR2:.*]] = arith.index_castui %[[INTPTR]] : index to i64
         // CHECK: %[[VAR14:.*]] = vector.bitcast %[[CST_1]] : vector<8xi32> to vector<4xi64>
         // CHECK: %[[VAR15:.*]] = vector.insert %[[BASE_ADDR2]], %[[VAR14]] [0] : i64 into vector<4xi64>
@@ -41,17 +41,17 @@ gpu.module @create_nd_tdesc {
         // CHECK: %[[VAR17:.*]] = vector.insert %[[SHAPE_W2]], %[[VAR16]] [2] : i32 into vector<8xi32>
         // CHECK: %[[VAR18:.*]] = vector.insert %[[SHAPE_H2]], %[[VAR17]] [3] : i32 into vector<8xi32>
         // CHECK: %[[VAR19:.*]] = vector.insert %[[OFFSET_W2]], %[[VAR18]] [4] : i32 into vector<8xi32>
-        // CHECK: %[[VAR20:.*]] = vector.insert %[[OFFSET_H2]], %[[VAR19]] [5] : i32 into vector<8xi32>
-        %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+        // CHECK: %[[PAYLOAD:.*]] = vector.insert %[[OFFSET_H2]], %[[VAR19]] [5] : i32 into vector<8xi32>
+        %src_tdesc = xegpu.create_nd_tdesc %srcce : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
 
         // CHECK: %[[CST_4:.*]] = arith.constant dense<0> : vector<8xi32>
-        // CHECK: %[[INTPTR_2:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<8x16xf32> -> index
+        // CHECK: %[[INTPTR_2:.*]] = memref.extract_aligned_pointer_as_index %[[MEMSPACECAST]] : memref<16x32xf32> -> index
         // CHECK: %[[OFFSET_W3:.*]] = arith.index_cast %[[ARG7]] : index to i32
         // CHECK: %[[OFFSET_H3:.*]] = arith.index_cast %[[ARG6]] : index to i32
-        // CHECK: %[[C16_I64_6:.*]] = arith.constant 16 : i64
-        // CHECK: %[[SHAPE_W3:.*]] = arith.trunci %[[C16_I64_6]] : i64 to i32
-        // CHECK: %[[C8_I64_7:.*]] = arith.constant 8 : i64
-        // CHECK: %[[SHAPE_H3:.*]] = arith.trunci %[[C8_I64_7]] : i64 to i32
+        // CHECK: %[[C32_I64_6:.*]] = arith.constant 32 : i64
+        // CHECK: %[[SHAPE_W3:.*]] = arith.trunci %[[C32_I64_6]] : i64 to i32
+        // CHECK: %[[C16_I64_7:.*]] = arith.constant 16 : i64
+        // CHECK: %[[SHAPE_H3:.*]] = arith.trunci %[[C16_I64_7]] : i64 to i32
         // CHECK: %[[BASE_ADDR3:.*]] = arith.index_castui %[[INTPTR_2]] : index to i64
         // CHECK: %[[VAR26:.*]] = vector.bitcast %[[CST_4]] : vector<8xi32> to vector<4xi64>
         // CHECK: %[[VAR27:.*]] = vector.insert %[[BASE_ADDR3]], %[[VAR26]] [0] : i64 into vector<4xi64>
@@ -60,7 +60,21 @@ gpu.module @create_nd_tdesc {
         // CHECK: %[[VAR30:.*]] = vector.insert %[[SHAPE_H3]], %[[VAR29]] [3] : i32 into vector<8xi32>
         // CHECK: %[[VAR31:.*]] = vector.insert %[[OFFSET_W3]], %[[VAR30]] [4] : i32 into vector<8xi32>
         // CHECK: %[[VAR32:.*]] = vector.insert %[[OFFSET_H3]], %[[VAR31]] [5] : i32 into vector<8xi32>
-        %src_tdesc2 = xegpu.create_nd_tdesc %srcce[%offset1, %offset2] : memref<8x16xf32> -> !xegpu.tensor_desc<8x16xf32>
+        %src_tdesc2 = xegpu.create_nd_tdesc %srcce[%offset1, %offset2] : memref<16x32xf32> -> !xegpu.tensor_desc<8x16xf32>
+
+        // CHECK: %[[C8:.*]] = arith.constant 8 : index
+        %c8 = arith.constant 8 : index
+        // CHECK: %[[C16:.*]] = arith.constant 16 : index
+        %c16 = arith.constant 16 : index
+        // CHECK: %[[VAR33:.*]] = arith.index_cast %[[C8]] : index to i32
+        // CHECK: %[[OLD_OFFSET_H:.*]] = vector.extract %[[PAYLOAD]][5] : i32 from vector<8xi32>
+        // CHECK: %[[NEW_OFFSET_H:.*]] = arith.addi %[[OLD_OFFSET_H]], %[[VAR33]] : i32
+        // CHECK: %[[NEW_PAYLOAD:.*]] = vector.insert %[[NEW_OFFSET_H]], %[[PAYLOAD]] [5] : i32 into vector<8xi32>
+        // CHECK: %[[VAR37:.*]] = arith.index_cast %[[C16]] : index to i32
+        // CHECK: %[[OLD_OFFSET_H:.*]] = vector.extract %[[NEW_PAYLOAD]][4] : i32 from vector<8xi32>
+        // CHECK: %[[NEW_OFFSET_H:.*]] = arith.addi %[[OLD_OFFSET_H]], %[[VAR37]] : i32
+        // CHECK: %[[FINAL_PAYLOAD:.*]] = vector.insert %[[NEW_OFFSET_H]], %[[NEW_PAYLOAD]] [4] : i32 into vector<8xi32>
+        %updated_tdesc = xegpu.update_nd_offset %src_tdesc, [%c8, %c16] : !xegpu.tensor_desc<8x16xf32>
         gpu.return
     }
 }

@silee2 silee2 requested a review from mshahneo August 28, 2025 23:36
Copy link
Contributor

@charithaintc charithaintc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM % comments.

@@ -259,23 +259,23 @@ class UpdateNdOffsetToXeVMPattern
// Only 2D offsets are supported for now.
if (mixedOffsets.size() != 2)
return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
auto tdesc = adaptor.getTensorDesc();
auto payload = adaptor.getTensorDesc();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: spell the type name.

if (src.getType().isIntOrFloat()) {
src = op.getValue();
}
VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
if (!srcVecTy)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in the prev code it fails here?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe separate test for update_nd? or rename the test.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

Copy link
Contributor

@mshahneo mshahneo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

@silee2
Copy link
Contributor Author

silee2 commented Aug 28, 2025

Since this is a hot fix, will address the comments in a follow up PR.

@silee2 silee2 merged commit d943efb into llvm:main Aug 28, 2025
12 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants