
Conversation

durga4github
Contributor

This patch adds im2col and gather mode support to the
TMA Load Op. The lowering is also updated to use
intrinsics, except when a predicate is given. This
completes the Blackwell additions for this Op.

  • The NVVM dialect now supports the Shared::Cluster
    address space, so this patch also updates the Op to
    use AS(7) instead of AS(3). The corresponding
    inline-PTX based unit tests are updated accordingly.
  • lit tests are added for all combinations.
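
A minimal sketch of the updated Op syntax, modeled on the lit tests in this patch (value names are illustrative): the destination now lives in the shared::cluster address space, AS(7), and the load mode is selected with the `mode` attribute.

```mlir
func.func @tma_load_sketch(%tma: !llvm.ptr, %dst: !llvm.ptr<7>, %mbar: !llvm.ptr<3>,
                           %crd0: i32, %crd1: i32, %crd2: i32, %off0: i16, %p: i1) {
  // Default tile mode, no predicate: lowers to an intrinsic call.
  nvvm.cp.async.bulk.tensor.shared.cluster.global %dst, %tma, %mbar,
      box[%crd0, %crd1] : !llvm.ptr<7>, !llvm.ptr
  // Im2col mode with a predicate: keeps the inline-PTX lowering.
  nvvm.cp.async.bulk.tensor.shared.cluster.global %dst, %tma, %mbar,
      box[%crd0, %crd1, %crd2] im2col[%off0] predicate = %p
      {mode = #nvvm.tma_load_mode<im2col>} : !llvm.ptr<7>, !llvm.ptr
  return
}
```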

Signed-off-by: Durgadoss R <durgadossr@nvidia.com>
@llvmbot
Member

llvmbot commented Sep 1, 2025

@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-gpu

Author: Durgadoss R (durga4github)

Changes

This patch adds im2col and gather mode support to the
TMA Load Op. The lowering is also updated to use
intrinsics, except when a predicate is given. This
completes the Blackwell additions for this Op.

  • The NVVM dialect now supports the Shared::Cluster
    address space, so this patch also updates the Op to
    use AS(7) instead of AS(3). The corresponding
    inline-PTX based unit tests are updated accordingly.
  • lit tests are added for all combinations.

Patch is 116.74 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/156347.diff

7 Files Affected:

  • (modified) mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td (+37-29)
  • (modified) mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp (+4)
  • (modified) mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp (+88-8)
  • (modified) mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir (+39-65)
  • (added) mlir/test/Target/LLVMIR/nvvm/tma_load_im2col.mlir (+298)
  • (added) mlir/test/Target/LLVMIR/nvvm/tma_load_invalid.mlir (+37)
  • (added) mlir/test/Target/LLVMIR/nvvm/tma_load_tile.mlir (+204)
diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 9528da05c9fd6..aace19ab834c3 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -2367,6 +2367,23 @@ def TMAStoreModeAttr : EnumAttr<NVVM_Dialect, TMAStoreMode, "tma_store_mode"> {
   let assemblyFormat = "`<` $value `>`";
 }
 
+// Num CTAs in a group participating in the TCGEN05 operation.
+// This corresponds to the "cta_group::1", "cta_group::2"
+// modifiers in the PTX instructions.
+def Tcgen05GroupCTA_1 : I32EnumAttrCase<"CTA_1", 0, "cta_1">;
+def Tcgen05GroupCTA_2 : I32EnumAttrCase<"CTA_2", 1, "cta_2">;
+
+def Tcgen05GroupKind : I32EnumAttr<"Tcgen05GroupKind",
+                            "NVVM Tcgen05 group kind",
+  [Tcgen05GroupCTA_1, Tcgen05GroupCTA_2]> {
+  let genSpecializedAttr = 0;
+  let cppNamespace = "::mlir::NVVM";
+}
+def Tcgen05GroupKindAttr :
+  EnumAttr<NVVM_Dialect, Tcgen05GroupKind, "tcgen05_group"> {
+  let assemblyFormat = "`<` $value `>`";
+}
+
 def NVVM_CpAsyncBulkCommitGroupOp : NVVM_Op<"cp.async.bulk.commit.group">,
   Arguments<(ins )> {
   let assemblyFormat = "attr-dict";
@@ -2413,26 +2430,20 @@ def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp :
   NVVM_Op<"cp.async.bulk.tensor.shared.cluster.global", 
   [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>, 
   AttrSizedOperandSegments, NVVMRequiresSM<90>]>,
-  Arguments<(ins  LLVM_PointerShared:$dstMem,
-                  LLVM_AnyPointer:$tmaDescriptor,
+  Arguments<(ins  LLVM_PointerSharedCluster:$dstMem,
+                  LLVM_PointerGeneric:$tmaDescriptor,
                   Variadic<I32>:$coordinates,
                   LLVM_PointerShared:$mbar,                  
                   Variadic<I16>:$im2colOffsets,
                   Optional<I16>:$multicastMask,
                   Optional<I64>:$l2CacheHint,
+                  DefaultValuedAttr<TMALoadModeAttr, "TMALoadMode::TILE">:$mode,
+                  OptionalAttr<Tcgen05GroupKindAttr>:$group,
                   PtxPredicate:$predicate)> {
   let description = [{
     Initiates an asynchronous copy operation on the tensor data from global 
-    memory to shared memory. 
-
-    The Op operates has two load modes:
-    1) Tiled Mode: It's the default mode. The source multi-dimensional tensor 
-    layout is preserved at the destination. 
-
-    2) Im2col Mode: This mode is used when `im2colOffsets` operands are present.
-    the elements in the Bounding Box of the source tensor are rearranged into
-    columns at the destination. In this mode, the tensor has to be at least 
-    3-dimensional. 
+    memory to shared::cluster memory. This Op supports all the load modes
+    specified in `TMALoadMode`.
 
     The `multicastMask` operand is optional. When it is present, the Op copies
     data from global memory to shared memory of multiple CTAs in the cluster.
@@ -2490,6 +2501,20 @@ def NVVM_CpAsyncBulkTensorGlobalToSharedClusterOp :
     }
   }];
   let hasVerifier = 1;
+
+  let extraClassDeclaration = [{
+    bool hasIntrinsic() { return !getPredicate(); }
+
+    static mlir::NVVM::IDArgPair
+    getIntrinsicIDAndArgs(Operation &op, LLVM::ModuleTranslation &mt,
+                          llvm::IRBuilderBase& builder);
+  }];
+
+  string llvmBuilder = [{
+    auto [id, args] = NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
+                      *op, moduleTranslation, builder);
+    createIntrinsicCall(builder, id, args);
+  }];
 }
 
 def NVVM_CpAsyncBulkTensorSharedCTAToGlobalOp : 
@@ -3314,23 +3339,6 @@ def NVVM_Breakpoint : NVVM_Op<"breakpoint"> {
 //===----------------------------------------------------------------------===//
 // NVVM TCGEN05 Ops
 //===----------------------------------------------------------------------===//
-// Num CTAs in a group participating in the TCGEN05 operation.
-// This corresponds to the "cta_group::1", "cta_group::2"
-// modifiers in the PTX instructions.
-def Tcgen05GroupCTA_1 : I32EnumAttrCase<"CTA_1", 0, "cta_1">;
-def Tcgen05GroupCTA_2 : I32EnumAttrCase<"CTA_2", 1, "cta_2">;
-
-def Tcgen05GroupKind : I32EnumAttr<"Tcgen05GroupKind",
-                            "NVVM Tcgen05 group kind",
-  [Tcgen05GroupCTA_1, Tcgen05GroupCTA_2]> {
-  let genSpecializedAttr = 0;
-  let cppNamespace = "::mlir::NVVM";
-}
-def Tcgen05GroupKindAttr :
-  EnumAttr<NVVM_Dialect, Tcgen05GroupKind, "tcgen05_group"> {
-  let assemblyFormat = "`<` $value `>`";
-}
-
 def Tcgen05FenceBefore : I32EnumAttrCase<"BEFORE_THREAD_SYNC", 0, "before">;
 def Tcgen05FenceAfter  : I32EnumAttrCase<"AFTER_THREAD_SYNC",  1, "after">;
 def Tcgen05FenceKind : I32EnumAttr<"Tcgen05FenceKind", "NVVM Tcgen05 fence kind",
diff --git a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
index ab1666a0e8e75..be913eaaa27b8 100644
--- a/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
+++ b/mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp
@@ -1003,9 +1003,13 @@ struct NVGPUTmaAsyncLoadOpLowering
     for (auto [index, value] : llvm::enumerate(coords)) {
       coords[index] = truncToI32(b, value);
     }
+
+    // TODO: Enhance the NVGPU Op for other modes too
     rewriter.replaceOpWithNewOp<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(
         op, dest, adaptor.getTensorMapDescriptor(), coords, barrier,
         ValueRange{}, adaptor.getMulticastMask(), Value{},
+        NVVM::TMALoadMode::TILE, // default is TILE mode
+        nullptr,                 // default is no cta-group
         adaptor.getPredicate());
     return success();
   }
diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
index ff6ccbaac2b35..2bd550dc3b89b 100644
--- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
+++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp
@@ -49,7 +49,7 @@ using namespace NVVM;
 //===----------------------------------------------------------------------===//
 
 // This verifier is shared among the following Ops:
-// CpAsyncBulkTensorGlobalToSharedClusterOp (TMA Load)
+// CpAsyncBulkTensorSharedCTAToGlobalOp (TMA Store)
 // CpAsyncBulkTensorReduceOp (TMA Store-Reduce)
 static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims,
                                                      bool isIm2Col,
@@ -73,13 +73,6 @@ static LogicalResult cpAsyncBulkTensorCommonVerifier(size_t tensorDims,
   return success();
 }
 
-LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
-  size_t numIm2ColOffsets = getIm2colOffsets().size();
-  bool isIm2Col = numIm2ColOffsets > 0;
-  return cpAsyncBulkTensorCommonVerifier(getCoordinates().size(), isIm2Col,
-                                         numIm2ColOffsets, getLoc());
-}
-
 LogicalResult CpAsyncBulkTensorSharedCTAToGlobalOp::verify() {
   TMAStoreMode mode = getMode();
   // We lower through inline-ptx when getPredicate() is true.
@@ -157,6 +150,17 @@ LogicalResult CpAsyncBulkTensorPrefetchOp::verify() {
                              getMode(), getLoc());
 }
 
+LogicalResult CpAsyncBulkTensorGlobalToSharedClusterOp::verify() {
+  TMALoadMode mode = getMode();
+  if (getPredicate()) {
+    if (mode != TMALoadMode::TILE && mode != TMALoadMode::IM2COL)
+      return emitError(
+          "Inline-ptx lowering supported only for Tile/Im2col mode.");
+  }
+  return verifyTMALoadParams(getCoordinates().size(), getIm2colOffsets().size(),
+                             getMode(), getLoc());
+}
+
 LogicalResult CpAsyncBulkTensorReduceOp::verify() {
   TMAStoreMode mode = getMode();
   size_t dims = getCoordinates().size();
@@ -1495,6 +1499,82 @@ mlir::NVVM::IDArgPair CpAsyncBulkSharedCTAToGlobalOp::getIntrinsicIDAndArgs(
   return {id, std::move(args)};
 }
 
+mlir::NVVM::IDArgPair
+CpAsyncBulkTensorGlobalToSharedClusterOp::getIntrinsicIDAndArgs(
+    Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
+  auto thisOp = cast<NVVM::CpAsyncBulkTensorGlobalToSharedClusterOp>(op);
+  llvm::SmallVector<llvm::Value *> args;
+
+  // Fill the Intrinsic Args
+  args.push_back(mt.lookupValue(thisOp.getDstMem()));
+  args.push_back(mt.lookupValue(thisOp.getMbar()));
+  args.push_back(mt.lookupValue(thisOp.getTmaDescriptor()));
+
+  // Coordinates and im2col-offsets
+  for (auto v : thisOp.getCoordinates())
+    args.push_back(mt.lookupValue(v));
+  for (auto v : thisOp.getIm2colOffsets())
+    args.push_back(mt.lookupValue(v));
+
+  // MulticastMask, if available
+  mlir::Value mcMask = thisOp.getMulticastMask();
+  const bool hasMC = static_cast<bool>(mcMask);
+  llvm::Value *i16Unused =
+      llvm::ConstantInt::get(llvm::Type::getInt16Ty(mt.getLLVMContext()), 0);
+  args.push_back(hasMC ? mt.lookupValue(mcMask) : i16Unused);
+
+  // CacheHint, if available
+  mlir::Value cacheHint = thisOp.getL2CacheHint();
+  const bool hasCacheHint = static_cast<bool>(cacheHint);
+  llvm::Value *i64Unused =
+      llvm::ConstantInt::get(llvm::Type::getInt64Ty(mt.getLLVMContext()), 0);
+  args.push_back(hasCacheHint ? mt.lookupValue(cacheHint) : i64Unused);
+
+  // Flag arguments for multicast, cache-hint and CTAGroup
+  args.push_back(builder.getInt1(hasMC));
+  args.push_back(builder.getInt1(hasCacheHint));
+
+  // Flag argument CTAGroup
+  // CTA_1/2 is mapped to values 1 and 2 for the intrinsics.
+  // Hence, the +1 to getGroup().
+  const int32_t val =
+      thisOp.getGroup() ? (static_cast<int32_t>(*thisOp.getGroup()) + 1) : 0;
+  llvm::Value *cg =
+      llvm::ConstantInt::get(llvm::Type::getInt32Ty(mt.getLLVMContext()), val);
+  args.push_back(cg);
+
+  const unsigned NI = llvm::Intrinsic::not_intrinsic;
+  static constexpr llvm::Intrinsic::ID IDTable[][6] = {
+      {NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_1d,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_2d,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_3d,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_4d,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_5d},
+      {NI, NI, NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_3d,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_4d,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_5d},
+      {NI, NI, NI, llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_3d,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_4d,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_5d},
+      {NI, NI, NI,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_3d,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_4d,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_im2col_w_128_5d},
+      {NI, NI, NI, NI, NI,
+       llvm::Intrinsic::nvvm_cp_async_bulk_tensor_g2s_tile_gather4_2d}};
+
+  static_assert(getMaxEnumValForTMALoadMode() == std::size(IDTable) - 1,
+                "TMALoadModes must match number of rows in IDTable");
+  size_t mode = static_cast<size_t>(thisOp.getMode());
+  size_t dim = thisOp.getCoordinates().size();
+  llvm::Intrinsic::ID id = IDTable[mode][dim];
+  if (id == llvm::Intrinsic::not_intrinsic)
+    llvm_unreachable(
+        "Invalid intrinsic for CpAsyncBulkTensorGlobalToSharedClusterOp.");
+
+  return {id, std::move(args)};
+}
+
 mlir::NVVM::IDArgPair CpAsyncBulkTensorPrefetchOp::getIntrinsicIDAndArgs(
     Operation &op, LLVM::ModuleTranslation &mt, llvm::IRBuilderBase &builder) {
   auto thisOp = cast<NVVM::CpAsyncBulkTensorPrefetchOp>(op);
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 89075120d16ea..f0bcf9f3498b0 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -96,119 +96,93 @@ func.func @cp_async_mbarrier_arrive(%bar_shared: !llvm.ptr<3>, %bar_gen: !llvm.p
 }
 
 // CHECK-LABEL: @tma_load_3d_all
-func.func @tma_load_3d_all(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %off0: i16, %off1: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.im2col.multicast::cluster.L2::cache_hint [$0], [$1, {$2,$3,$4} ], [$5],{$6}, $7, $8;", "r,l,r,r,r,r,h,h,l"
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2] im2col[%off0] multicast_mask = %ctamask l2_cache_hint = %cacheHint : !llvm.ptr<3>, !llvm.ptr  
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$9 cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.im2col.multicast::cluster.L2::cache_hint [$0], [$1, {$2,$3,$4} ], [$5],{$6}, $7, $8;", "r,l,r,r,r,r,h,h,l,b"
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2] im2col[%off0] multicast_mask = %ctamask l2_cache_hint = %cacheHint predicate = %p : !llvm.ptr<3>, !llvm.ptr
+func.func @tma_load_3d_all(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<7>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %off0: i16, %off1: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$9 cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.im2col.multicast::cluster.L2::cache_hint [$0], [$1, {$2,$3,$4} ], [$5],{$6}, $7, $8;", "l,l,r,r,r,r,h,h,l,b"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2] im2col[%off0] multicast_mask = %ctamask l2_cache_hint = %cacheHint predicate = %p {mode = #nvvm.tma_load_mode<im2col>} : !llvm.ptr<7>, !llvm.ptr
   return
 }
 
 // CHECK-LABEL: @tma_load_4d_all
-func.func @tma_load_4d_all(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %off0: i16, %off1: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.im2col.multicast::cluster.L2::cache_hint [$0], [$1, {$2,$3,$4,$5} ], [$6],{$7,$8}, $9, $10;", "r,l,r,r,r,r,r,h,h,h,l"
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2,%crd3] im2col[%off0,%off1] multicast_mask = %ctamask l2_cache_hint = %cacheHint : !llvm.ptr<3>, !llvm.ptr  
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$11 cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.im2col.multicast::cluster.L2::cache_hint [$0], [$1, {$2,$3,$4,$5} ], [$6],{$7,$8}, $9, $10;", "r,l,r,r,r,r,r,h,h,h,l,b"
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2,%crd3] im2col[%off0,%off1] multicast_mask = %ctamask l2_cache_hint = %cacheHint predicate = %p : !llvm.ptr<3>, !llvm.ptr
+func.func @tma_load_4d_all(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<7>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %off0: i16, %off1: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$11 cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.im2col.multicast::cluster.L2::cache_hint [$0], [$1, {$2,$3,$4,$5} ], [$6],{$7,$8}, $9, $10;", "l,l,r,r,r,r,r,h,h,h,l,b"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2,%crd3] im2col[%off0,%off1] multicast_mask = %ctamask l2_cache_hint = %cacheHint predicate = %p {mode = #nvvm.tma_load_mode<im2col>} : !llvm.ptr<7>, !llvm.ptr
   return
 }
 
 // CHECK-LABEL: @tma_load_5d_all
-func.func @tma_load_5d_all(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %crd4: i32, %off0: i16, %off1: i16, %off2: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.im2col.multicast::cluster.L2::cache_hint [$0], [$1, {$2,$3,$4,$5,$6} ], [$7],{$8,$9,$10}, $11, $12;", "r,l,r,r,r,r,r,r,h,h,h,h,l"
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2,%crd3,%crd4] im2col[%off0,%off1,%off2] multicast_mask = %ctamask l2_cache_hint = %cacheHint : !llvm.ptr<3>, !llvm.ptr  
-  // CHECK: lvm.inline_asm has_side_effects asm_dialect = att "@$13 cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.im2col.multicast::cluster.L2::cache_hint [$0], [$1, {$2,$3,$4,$5,$6} ], [$7],{$8,$9,$10}, $11, $12;", "r,l,r,r,r,r,r,r,h,h,h,h,l,b"
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2,%crd3,%crd4] im2col[%off0,%off1,%off2] multicast_mask = %ctamask l2_cache_hint = %cacheHint predicate = %p : !llvm.ptr<3>, !llvm.ptr
+func.func @tma_load_5d_all(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<7>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %crd2: i32, %crd3: i32, %crd4: i32, %off0: i16, %off1: i16, %off2: i16, %ctamask : i16, %cacheHint : i64, %p : i1) {
+  // CHECK: lvm.inline_asm has_side_effects asm_dialect = att "@$13 cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.im2col.multicast::cluster.L2::cache_hint [$0], [$1, {$2,$3,$4,$5,$6} ], [$7],{$8,$9,$10}, $11, $12;", "l,l,r,r,r,r,r,r,h,h,h,h,l,b"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0,%crd1,%crd2,%crd3,%crd4] im2col[%off0,%off1,%off2] multicast_mask = %ctamask l2_cache_hint = %cacheHint predicate = %p {mode = #nvvm.tma_load_mode<im2col>} : !llvm.ptr<7>, !llvm.ptr
   return
 }
 
 // CHECK-LABEL: @tma_load_1d
-func.func @tma_load_1d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %p : i1) {
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2} ], [$3];", "r,l,r,r"
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0] : !llvm.ptr<3>, !llvm.ptr
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$4 cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2} ], [$3];", "r,l,r,r,b"
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0] predicate=%p : !llvm.ptr<3>, !llvm.ptr
+func.func @tma_load_1d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<7>, %barrier: !llvm.ptr<3>, %crd0: i32, %p : i1) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$4 cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2} ], [$3];", "l,l,r,r,b"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor,  %barrier, box[%crd0] predicate=%p : !llvm.ptr<7>, !llvm.ptr
   return
 }
 
 // CHECK-LABEL: @tma_load_2d
-func.func @tma_load_2d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %p : i1) {
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2,$3} ], [$4];", "r,l,r,r,r"
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1] : !llvm.ptr<3>, !llvm.ptr
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$5 cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2,$3} ], [$4];", "r,l,r,r,r,b"
-  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1] predicate=%p  : !llvm.ptr<3>, !llvm.ptr
+func.func @tma_load_2d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<7>, %barrier: !llvm.ptr<3>, %crd0: i32, %crd1: i32, %p : i1) {
+  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$5 cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes [$0], [$1, {$2,$3} ], [$4];", "l,l,r,r,r,b"
+  nvvm.cp.async.bulk.tensor.shared.cluster.global %dest, %tmaDescriptor, %barrier, box[%crd0,%crd1] predicate=%p  : !llvm.ptr<7>, !llvm.ptr
   return
 }
 
 // CHECK-LABEL: @tma_load_3d
-func.func @tma_load_3d(%tmaDescriptor: !llvm.ptr, %dest : !llvm.ptr<3>, %bar...
[truncated]
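
For readers skimming the truncated diff, here is a hedged sketch of the intrinsic call the new lowering emits for a 2-D tile-mode load with a multicast mask. The argument order follows `getIntrinsicIDAndArgs` above; the exact printed IR and value names are illustrative.

```mlir
func.func @tma_load_2d_mc(%tma: !llvm.ptr, %dst: !llvm.ptr<7>, %mbar: !llvm.ptr<3>,
                          %crd0: i32, %crd1: i32, %mc: i16) {
  // Args: dst, mbar, descriptor, coordinates, multicast mask, an unused i64
  // cache-hint placeholder, two i1 flags (has-multicast, has-cache-hint),
  // and an i32 cta-group flag (0 = no cta_group modifier).
  // CHECK: call void @llvm.nvvm.cp.async.bulk.tensor.g2s.tile.2d(ptr addrspace(7) %{{.*}}, ptr addrspace(3) %{{.*}}, ptr %{{.*}}, i32 %{{.*}}, i32 %{{.*}}, i16 %{{.*}}, i64 0, i1 true, i1 false, i32 0)
  nvvm.cp.async.bulk.tensor.shared.cluster.global %dst, %tma, %mbar,
      box[%crd0, %crd1] multicast_mask = %mc : !llvm.ptr<7>, !llvm.ptr
  return
}
```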


@durga4github
Contributor Author

@grypp, please help with a review.

if (getPredicate()) {
  if (mode != TMALoadMode::TILE && mode != TMALoadMode::IM2COL)
    return emitError(
        "Inline-ptx lowering supported only for Tile/Im2col mode.");
Member

Let's change the error message to `Predicate isn't supported ...`.
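
For context, the combination this verifier rejects looks roughly like the sketch below; the `im2col_w` mode mnemonic and the operand arity are assumed for illustration.

```mlir
// A predicated load in any mode other than tile/im2col is rejected.
// expected-error @below {{Inline-ptx lowering supported only for Tile/Im2col mode.}}
nvvm.cp.async.bulk.tensor.shared.cluster.global %dst, %tma, %mbar,
    box[%crd0, %crd1, %crd2] im2col[%off0, %off1] predicate = %p
    {mode = #nvvm.tma_load_mode<im2col_w>} : !llvm.ptr<7>, !llvm.ptr
```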

Contributor Author

Sure, will update.

  let genSpecializedAttr = 0;
  let cppNamespace = "::mlir::NVVM";
}
def Tcgen05GroupKindAttr :
Member

Can we rename this to `CTAGroupKindAttr`? I think TMA lives beyond tcgen05.

Contributor Author

Yes, that makes sense. But this attr is now used in many tcgen05 Ops,
so let me raise a separate PR for the rename (and will then rebase this one on top).
