Skip to content

Commit 7890702

Browse files
committed
[MLIR][NVVM][NVGPU] Combine prefetch and prefetch.tensormap
This change combines the `prefetch` and `prefetch.tensormap` NVVM Ops into a single `prefetch` Op. The `tensormap` variant is lowered through the newly added intrinsics. The NVGPU `tma.prefetch.descriptor` Op now lowers to the `prefetch` Op instead of the `prefetch.tensormap` Op.
1 parent d1e6923 commit 7890702

File tree

8 files changed

+243
-75
lines changed

8 files changed

+243
-75
lines changed

mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td

Lines changed: 42 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ include "mlir/Dialect/LLVMIR/LLVMTypes.td"
2525
def LLVM_PointerGeneric : LLVM_PointerInAddressSpace<0>;
2626
def LLVM_PointerGlobal : LLVM_PointerInAddressSpace<1>;
2727
def LLVM_PointerShared : LLVM_PointerInAddressSpace<3>;
28+
def LLVM_PointerConst : LLVM_PointerInAddressSpace<4>;
2829
def LLVM_PointerLocal : LLVM_PointerInAddressSpace<5>;
2930
def LLVM_PointerTensor : LLVM_PointerInAddressSpace<6>;
3031
def LLVM_PointerSharedCluster : LLVM_PointerInAddressSpace<7>;
@@ -2427,15 +2428,25 @@ def PrefetchCacheLevelAttr : EnumAttr<NVVM_Dialect, PrefetchCacheLevel, "prefetc
24272428
let assemblyFormat = "$value";
24282429
}
24292430

2430-
def NVVM_PrefetchOp : NVVM_Op<"prefetch"> {
2431+
def NVVM_PrefetchOp : NVVM_Op<"prefetch",
2432+
[DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]> {
24312433
let summary = "Brings the cache line containing an address into the specified cache level";
24322434
let description = [{
2433-
Operand `addr` can be a global, local or generic address pointer. No
2434-
operation is performed if `addr` maps to a `shared` memory location.
2435+
Prefetches the cache line containing the address given by `addr`. The
2436+
operand may be a global, local, or generic pointer. When `tensormap` is
2437+
specified, the operand may instead be a constant or generic pointer. If the
2438+
address maps to shared memory, the operation has no effect.
2439+
2440+
Exactly one of `cacheLevel` or `tensormap` must be present. The `cacheLevel`
2441+
attribute selects the target cache level. When combined with `uniform`, the
2442+
prefetch is performed to the uniform cache, in which case `addr` must be a
2443+
generic pointer.
2444+
2445+
When `tensormap` is used, the line containing `addr` is brought from the
2446+
constant or parameter state space for later use by `cp.async.bulk.tensor`.
2447+
If `in_param_space` is specified, the generic pointer is interpreted as
2448+
referring to the parameter state space.
24352449

2436-
The `cacheLevel` attribute specifies the cache level to which the cache line
2437-
containing the specified address is brought.
2438-
24392450
`uniform` can be specified after the `cacheLevel` to indicate that the
24402451
prefetch is performed to the specified uniform cache level. If `uniform` is
24412452
specified, `addr` must be a generic address pointer and no operation is
@@ -2446,33 +2457,41 @@ def NVVM_PrefetchOp : NVVM_Op<"prefetch"> {
24462457

24472458
[For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/#data-movement-and-conversion-instructions-prefetch-prefetchu)
24482459
}];
2449-
let arguments = (ins PrefetchCacheLevelAttr:$cacheLevel,
2450-
UnitAttr:$uniform,
2460+
let arguments = (ins OptionalAttr<PrefetchCacheLevelAttr>:$cacheLevel,
2461+
OptionalAttr<CacheEvictionPriorityAttr>:$evictPriority,
24512462
AnyTypeOf<[LLVM_PointerGlobal,
24522463
LLVM_PointerLocal,
2453-
LLVM_PointerGeneric]>:$addr,
2454-
OptionalAttr<CacheEvictionPriorityAttr>:$evictPriority);
2455-
let assemblyFormat = "`level` `=` $cacheLevel (`uniform` $uniform^)? `,` $addr (`,` `evict_priority` `=` $evictPriority^)? attr-dict `:` type($addr)";
2464+
LLVM_PointerGeneric,
2465+
LLVM_PointerConst]>:$addr,
2466+
PtxPredicate:$predicate,
2467+
UnitAttr:$tensormap,
2468+
UnitAttr:$uniform,
2469+
UnitAttr:$in_param_space);
2470+
let assemblyFormat = "(`level` `=` $cacheLevel^ (`uniform` $uniform^)? `,`)? (`tensormap` $tensormap^ (`in_param_space` $in_param_space^)? `,`)? (`evict_priority` `=` $evictPriority^ `,`)? $addr (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
24562471
let hasVerifier = 1;
24572472

24582473
let extraClassDeclaration = [{
2459-
static llvm::Intrinsic::ID getIntrinsicID(NVVM::PrefetchOp &op);
2460-
}];
2461-
let llvmBuilder = [{
2462-
auto intId = NVVM::PrefetchOp::getIntrinsicID(op);
2463-
createIntrinsicCall(builder, intId, $addr);
2474+
static NVVM::IDArgPair
2475+
getIntrinsicIDAndArgs(NVVM::PrefetchOp &op,LLVM::ModuleTranslation &mt,
2476+
llvm::IRBuilderBase &builder);
2477+
bool hasIntrinsic() { return !getPredicate() || !getTensormap(); }
24642478
}];
2465-
}
2466-
2467-
def NVVM_PrefetchTensorMapOp : NVVM_Op<"prefetch.tensormap",
2468-
[DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>,
2469-
Arguments<(ins LLVM_AnyPointer:$tmaDescriptor, PtxPredicate:$predicate)> {
2470-
let assemblyFormat = "$tmaDescriptor (`,` `predicate` `=` $predicate^)? attr-dict `:` type(operands)";
24712479
let extraClassDefinition = [{
2472-
std::string $cppClass::getPtx() {
2480+
std::string $cppClass::getPtx() {
2481+
// Inline PTX is only supported for prefetch tensormap
24732482
return std::string("prefetch.tensormap [%0];");
24742483
}
24752484
}];
2485+
let llvmBuilder = [{
2486+
auto [id, args] = NVVM::PrefetchOp::getIntrinsicIDAndArgs(op,
2487+
moduleTranslation, builder);
2488+
2489+
if(op.getTensormap())
2490+
// Overloaded intrinsic
2491+
createIntrinsicCall(builder, id, args, {args[0]->getType()});
2492+
else
2493+
createIntrinsicCall(builder, id, args);
2494+
}];
24762495
}
24772496

24782497
def NVVM_CpAsyncBulkPrefetchOp : NVVM_Op<"cp.async.bulk.prefetch"> {

mlir/lib/Conversion/NVGPUToNVVM/NVGPUToNVVM.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1700,8 +1700,10 @@ struct NVGPUTmaPrefetchOpLowering
17001700
LogicalResult
17011701
matchAndRewrite(nvgpu::TmaPrefetchOp op, OpAdaptor adaptor,
17021702
ConversionPatternRewriter &rewriter) const override {
1703-
rewriter.replaceOpWithNewOp<NVVM::PrefetchTensorMapOp>(
1704-
op, adaptor.getTensorMapDescriptor(), adaptor.getPredicate());
1703+
rewriter.replaceOpWithNewOp<NVVM::PrefetchOp>(
1704+
op, /* CacheLevel */ nullptr, /* Cache Eviction Priority */ nullptr,
1705+
adaptor.getTensorMapDescriptor(), adaptor.getPredicate(),
1706+
/* Tensormap UnitAttr */ mlir::UnitAttr::get(op.getContext()));
17051707
return success();
17061708
}
17071709
};

mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp

Lines changed: 97 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#include "llvm/IR/IRBuilder.h"
3434
#include "llvm/Support/Casting.h"
3535
#include "llvm/Support/FormatVariadic.h"
36+
#include "llvm/Support/NVPTXAddrSpace.h"
3637
#include "llvm/Support/raw_ostream.h"
3738
#include <cassert>
3839
#include <optional>
@@ -1236,30 +1237,70 @@ LogicalResult NVVM::PrefetchOp::verify() {
12361237
unsigned addressSpace =
12371238
llvm::cast<LLVM::LLVMPointerType>(getAddr().getType()).getAddressSpace();
12381239
std::optional<NVVM::CacheEvictionPriority> evictPriority = getEvictPriority();
1240+
std::optional<NVVM::PrefetchCacheLevel> cacheLevel = getCacheLevel();
12391241

1240-
if (getUniform()) {
1241-
if (getCacheLevel() != CacheLevel::L1)
1242-
return emitOpError("unsupported cache level, the only supported uniform "
1243-
"cache level is L1");
1242+
if (getTensormap() && cacheLevel)
1243+
return emitOpError("cannot specify both tensormap and cache level");
12441244

1245-
if (addressSpace != MemSpace::kGenericMemorySpace)
1245+
if (getTensormap()) {
1246+
if (addressSpace != MemSpace::kGenericMemorySpace &&
1247+
addressSpace != MemSpace::kConstantMemorySpace) {
12461248
return emitOpError(
1247-
"prefetch to uniform cache requires a generic pointer");
1248-
}
1249+
"prefetch tensormap requires a generic or constant pointer");
1250+
}
12491251

1250-
if (evictPriority) {
1251-
if (getCacheLevel() != CacheLevel::L2)
1252+
if (evictPriority) {
12521253
return emitOpError(
1253-
"cache eviction priority supported only for cache level L2");
1254-
1255-
if (addressSpace != MemSpace::kGlobalMemorySpace)
1256-
return emitOpError("cache eviction priority requires a global pointer");
1254+
"prefetch tensormap does not support eviction priority");
1255+
}
12571256

1258-
if (*evictPriority != NVVM::CacheEvictionPriority::EvictNormal &&
1259-
*evictPriority != NVVM::CacheEvictionPriority::EvictLast)
1257+
if (getInParamSpace() && addressSpace != MemSpace::kGenericMemorySpace) {
12601258
return emitOpError(
1261-
"unsupported cache eviction priority, only evict_last and "
1262-
"evict_normal are supported");
1259+
"in_param_space can only be specified for a generic pointer");
1260+
}
1261+
1262+
} else if (cacheLevel) {
1263+
if (addressSpace != MemSpace::kGenericMemorySpace &&
1264+
addressSpace != MemSpace::kGlobalMemorySpace &&
1265+
addressSpace != MemSpace::kLocalMemorySpace) {
1266+
return emitOpError("prefetch to cache level requires a generic, global, "
1267+
"or local pointer");
1268+
}
1269+
1270+
if (getUniform()) {
1271+
if (*cacheLevel != CacheLevel::L1) {
1272+
return emitOpError(
1273+
"unsupported cache level, the only supported uniform "
1274+
"cache level is L1");
1275+
}
1276+
1277+
if (addressSpace != MemSpace::kGenericMemorySpace) {
1278+
return emitOpError(
1279+
"prefetch to uniform cache requires a generic pointer");
1280+
}
1281+
}
1282+
1283+
if (evictPriority) {
1284+
if (*cacheLevel != CacheLevel::L2)
1285+
return emitOpError(
1286+
"cache eviction priority supported only for cache level L2");
1287+
1288+
if (addressSpace != MemSpace::kGlobalMemorySpace)
1289+
return emitOpError("cache eviction priority requires a global pointer");
1290+
1291+
if (*evictPriority != NVVM::CacheEvictionPriority::EvictNormal &&
1292+
*evictPriority != NVVM::CacheEvictionPriority::EvictLast)
1293+
return emitOpError(
1294+
"unsupported cache eviction priority, only evict_last and "
1295+
"evict_normal are supported");
1296+
}
1297+
1298+
if (getPredicate())
1299+
return emitOpError("predicate supported only on prefetch tensormap");
1300+
1301+
} else {
1302+
return emitOpError(
1303+
"requires specification of either cache level or tensormap");
12631304
}
12641305

12651306
return success();
@@ -1794,43 +1835,67 @@ NVVM::IDArgPair DotAccumulate2WayOp::getIntrinsicIDAndArgs(
17941835
return {ids[type], args};
17951836
}
17961837

1797-
llvm::Intrinsic::ID PrefetchOp::getIntrinsicID(NVVM::PrefetchOp &op) {
1838+
static llvm::Value *getParamCastedAddr(llvm::Value *addr,
1839+
llvm::IRBuilderBase &builder) {
1840+
return builder.CreateAddrSpaceCast(
1841+
addr,
1842+
llvm::PointerType::get(builder.getContext(),
1843+
llvm::NVPTXAS::AddressSpace::ADDRESS_SPACE_PARAM));
1844+
}
1845+
1846+
NVVM::IDArgPair
1847+
PrefetchOp::getIntrinsicIDAndArgs(NVVM::PrefetchOp &op,
1848+
LLVM::ModuleTranslation &mt,
1849+
llvm::IRBuilderBase &builder) {
17981850
using MemSpace = NVVM::NVVMMemorySpace;
17991851
using CacheLevel = NVVM::PrefetchCacheLevel;
18001852

1801-
NVVM::PrefetchCacheLevel cacheLevel = op.getCacheLevel();
1853+
std::optional<NVVM::PrefetchCacheLevel> cacheLevel = op.getCacheLevel();
18021854
std::optional<NVVM::CacheEvictionPriority> evictPriority =
18031855
op.getEvictPriority();
18041856
unsigned addressSpace =
18051857
llvm::cast<LLVM::LLVMPointerType>(op.getAddr().getType())
18061858
.getAddressSpace();
18071859

1808-
if (op.getUniform() && cacheLevel == CacheLevel::L1)
1809-
return llvm::Intrinsic::nvvm_prefetchu_L1;
1860+
llvm::SmallVector<llvm::Value *> args;
1861+
llvm::Value *addr = mt.lookupValue(op.getAddr());
1862+
args.push_back(op.getInParamSpace() ? getParamCastedAddr(addr, builder)
1863+
: addr);
1864+
1865+
if (op.getTensormap())
1866+
return {llvm::Intrinsic::nvvm_prefetch_tensormap, args};
1867+
1868+
if (op.getUniform() && *cacheLevel == CacheLevel::L1)
1869+
return {llvm::Intrinsic::nvvm_prefetchu_L1, args};
18101870

1811-
if (evictPriority && cacheLevel == CacheLevel::L2) {
1871+
if (evictPriority && *cacheLevel == CacheLevel::L2) {
18121872
switch (*evictPriority) {
18131873
case NVVM::CacheEvictionPriority::EvictLast:
1814-
return llvm::Intrinsic::nvvm_prefetch_global_L2_evict_last;
1874+
return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_last, args};
18151875
case NVVM::CacheEvictionPriority::EvictNormal:
1816-
return llvm::Intrinsic::nvvm_prefetch_global_L2_evict_normal;
1876+
return {llvm::Intrinsic::nvvm_prefetch_global_L2_evict_normal, args};
18171877
default:
18181878
llvm_unreachable("Invalid cache eviction priority");
18191879
}
18201880
}
18211881

18221882
switch (addressSpace) {
18231883
case MemSpace::kGenericMemorySpace:
1824-
return cacheLevel == CacheLevel::L1 ? llvm::Intrinsic::nvvm_prefetch_L1
1825-
: llvm::Intrinsic::nvvm_prefetch_L2;
1884+
return *cacheLevel == CacheLevel::L1
1885+
? NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L1, args})
1886+
: NVVM::IDArgPair({llvm::Intrinsic::nvvm_prefetch_L2, args});
18261887
case MemSpace::kGlobalMemorySpace:
1827-
return cacheLevel == CacheLevel::L1
1828-
? llvm::Intrinsic::nvvm_prefetch_global_L1
1829-
: llvm::Intrinsic::nvvm_prefetch_global_L2;
1888+
return *cacheLevel == CacheLevel::L1
1889+
? NVVM::IDArgPair(
1890+
{llvm::Intrinsic::nvvm_prefetch_global_L1, args})
1891+
: NVVM::IDArgPair(
1892+
{llvm::Intrinsic::nvvm_prefetch_global_L2, args});
18301893
case MemSpace::kLocalMemorySpace:
1831-
return cacheLevel == CacheLevel::L1
1832-
? llvm::Intrinsic::nvvm_prefetch_local_L1
1833-
: llvm::Intrinsic::nvvm_prefetch_local_L2;
1894+
return *cacheLevel == CacheLevel::L1
1895+
? NVVM::IDArgPair(
1896+
{llvm::Intrinsic::nvvm_prefetch_local_L1, args})
1897+
: NVVM::IDArgPair(
1898+
{llvm::Intrinsic::nvvm_prefetch_local_L2, args});
18341899
default:
18351900
llvm_unreachable("Invalid pointer address space");
18361901
}

mlir/test/Conversion/NVGPUToNVVM/nvgpu-to-nvvm.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -817,9 +817,9 @@ func.func @create_tensor_map(%devicePtr2d : memref<64x128xf32>, %devicePtr1d : m
817817
// CHECK-SAME: %[[arg0:[a-zA-Z0-9_]+]]: !nvgpu.tensormap.descriptor<tensor = memref<128xf32, 3>, swizzle = none, l2promo = none, oob = nan, interleave = none>, %[[arg1:[a-zA-Z0-9_]+]]: i1
818818
func.func @tma_prefetch(%tensorMap1d: !tensorMap1d, %p : i1) {
819819
// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[arg0]] : !nvgpu.tensormap.descriptor<tensor = memref<128xf32, 3>, swizzle = none, l2promo = none, oob = nan, interleave = none> to !llvm.ptr
820-
// CHECK: nvvm.prefetch.tensormap %[[S0]] : !llvm.ptr
820+
// CHECK: nvvm.prefetch tensormap, %[[S0]] : !llvm.ptr
821821
nvgpu.tma.prefetch.descriptor %tensorMap1d: !tensorMap1d
822-
// CHECK: nvvm.prefetch.tensormap %[[S0]], predicate = %[[arg1]] : !llvm.ptr, i1
822+
// CHECK: nvvm.prefetch tensormap, %[[S0]], predicate = %[[arg1]] : !llvm.ptr, i1
823823
nvgpu.tma.prefetch.descriptor %tensorMap1d, predicate = %p: !tensorMap1d
824824
func.return
825825
}

mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -582,10 +582,10 @@ func.func @elect_one_leader_sync() {
582582

583583
// CHECK-LABEL: @init_mbarrier_arrive_expect_tx
584584
llvm.func @init_mbarrier_arrive_expect_tx(%desc : !llvm.ptr, %pred : i1) {
585-
//CHECK: llvm.inline_asm has_side_effects asm_dialect = att "prefetch.tensormap [$0];", "l"
586-
nvvm.prefetch.tensormap %desc : !llvm.ptr
585+
//CHECK: nvvm.prefetch tensormap, %{{.*}}
586+
nvvm.prefetch tensormap, %desc : !llvm.ptr
587587
//CHECK: llvm.inline_asm has_side_effects asm_dialect = att "@$1 prefetch.tensormap [$0];", "l,b"
588-
nvvm.prefetch.tensormap %desc, predicate = %pred : !llvm.ptr, i1
588+
nvvm.prefetch tensormap, %desc, predicate = %pred : !llvm.ptr, i1
589589
llvm.return
590590
}
591591

mlir/test/Dialect/LLVMIR/nvvm.mlir

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -597,7 +597,7 @@ func.func @dot_accumulate_2way(%a_vec: vector<2xi16>, %b_vec: vector<4xi8>, %c:
597597
}
598598

599599
// CHECK-LABEL: @prefetch
600-
func.func @prefetch(%gen_ptr: !llvm.ptr, %local_ptr: !llvm.ptr<5>, %global_ptr: !llvm.ptr<1>) {
600+
func.func @prefetch(%gen_ptr: !llvm.ptr, %local_ptr: !llvm.ptr<5>, %global_ptr: !llvm.ptr<1>, %const_ptr: !llvm.ptr<4>) {
601601
// CHECK: nvvm.prefetch level = L1, %{{.*}}
602602
nvvm.prefetch level = L1, %gen_ptr : !llvm.ptr<0>
603603
// CHECK: nvvm.prefetch level = L1, %{{.*}}
@@ -610,12 +610,24 @@ func.func @prefetch(%gen_ptr: !llvm.ptr, %local_ptr: !llvm.ptr<5>, %global_ptr:
610610
nvvm.prefetch level = L2, %local_ptr : !llvm.ptr<5>
611611
// CHECK: nvvm.prefetch level = L2, %{{.*}}
612612
nvvm.prefetch level = L2, %global_ptr : !llvm.ptr<1>
613-
// CHECK: nvvm.prefetch level = L2, %{{.*}}
614-
nvvm.prefetch level = L2, %global_ptr, evict_priority = evict_last : !llvm.ptr<1>
615-
// CHECK: nvvm.prefetch level = L2, %{{.*}}
616-
nvvm.prefetch level = L2, %global_ptr, evict_priority = evict_normal : !llvm.ptr<1>
613+
// CHECK: nvvm.prefetch level = L2, evict_priority = evict_last, %{{.*}}
614+
nvvm.prefetch level = L2, evict_priority = evict_last, %global_ptr :
615+
!llvm.ptr<1>
616+
// CHECK: nvvm.prefetch level = L2, evict_priority = evict_normal, %{{.*}}
617+
nvvm.prefetch level = L2, evict_priority = evict_normal, %global_ptr : !llvm.ptr<1>
617618
// CHECK: nvvm.prefetch level = L1 uniform, %{{.*}}
618619
nvvm.prefetch level = L1 uniform, %gen_ptr : !llvm.ptr
620+
// CHECK: nvvm.prefetch tensormap, %{{.*}}
621+
nvvm.prefetch tensormap, %gen_ptr : !llvm.ptr
622+
// CHECK: nvvm.prefetch tensormap, %{{.*}}
623+
nvvm.prefetch tensormap, %const_ptr : !llvm.ptr<4>
624+
// CHECK: nvvm.prefetch tensormap in_param_space, %{{.*}}
625+
nvvm.prefetch tensormap in_param_space, %gen_ptr : !llvm.ptr
626+
return
627+
}
628+
629+
// CHECK-LABEL: @prefetch_tensormap
630+
func.func @prefetch_tensormap(%gen_ptr: !llvm.ptr, %const_ptr: !llvm.ptr<4>) {
619631
return
620632
}
621633

mlir/test/Target/LLVMIR/nvvm/prefetch.mlir

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ llvm.func @prefetch_L2_eviction_priority(%global_ptr: !llvm.ptr<1>) {
3232
// CHECK-NEXT: call void @llvm.nvvm.prefetch.global.L2.evict.normal(ptr addrspace(1) %0)
3333
// CHECK-NEXT: ret void
3434
// CHECK-NEXT: }
35-
nvvm.prefetch level = L2, %global_ptr, evict_priority = evict_last : !llvm.ptr<1>
36-
nvvm.prefetch level = L2, %global_ptr, evict_priority = evict_normal : !llvm.ptr<1>
35+
nvvm.prefetch level = L2, evict_priority = evict_last, %global_ptr : !llvm.ptr<1>
36+
nvvm.prefetch level = L2, evict_priority = evict_normal, %global_ptr : !llvm.ptr<1>
3737
llvm.return
3838
}
3939

@@ -45,3 +45,17 @@ llvm.func @prefetch_L1_uniform(%gen_ptr: !llvm.ptr) {
4545
nvvm.prefetch level = L1 uniform, %gen_ptr : !llvm.ptr
4646
llvm.return
4747
}
48+
49+
llvm.func @prefetch_tensormap(%gen_ptr: !llvm.ptr, %const_ptr: !llvm.ptr<4>) {
50+
// CHECK-LABEL: define void @prefetch_tensormap(ptr %0, ptr addrspace(4) %1) {
51+
// CHECK-NEXT: call void @llvm.nvvm.prefetch.tensormap.p0(ptr %0)
52+
// CHECK-NEXT: call void @llvm.nvvm.prefetch.tensormap.p4(ptr addrspace(4) %1)
53+
// CHECK-NEXT: %3 = addrspacecast ptr %0 to ptr addrspace(101)
54+
// CHECK-NEXT: call void @llvm.nvvm.prefetch.tensormap.p101(ptr addrspace(101) %3)
55+
// CHECK-NEXT: ret void
56+
// CHECK-NEXT: }
57+
nvvm.prefetch tensormap, %gen_ptr : !llvm.ptr
58+
nvvm.prefetch tensormap, %const_ptr: !llvm.ptr<4>
59+
nvvm.prefetch tensormap in_param_space, %gen_ptr : !llvm.ptr
60+
llvm.return
61+
}

0 commit comments

Comments
 (0)