Skip to content

Conversation

ro-i
Copy link
Contributor

@ro-i ro-i commented May 20, 2025

In case a ptrmask cannot be converted to the new address space due to an unknown mask value, this needs to be detcted and an addrspacecast is needed to not hinder a future use of the unconverted return value of ptrmask. Otherwise, users of this value will become invalid by receiving a nullptr as an operand.

This LLVM defect was identified via the AMD Fuzzing project.

(See https://reviews.llvm.org/D80129 for an explanation of why some ptrmasks are impossible to convert to other addrspaces.)

@llvmbot
Copy link
Member

llvmbot commented May 20, 2025

@llvm/pr-subscribers-llvm-analysis
@llvm/pr-subscribers-llvm-transforms

@llvm/pr-subscribers-backend-amdgpu

Author: Robert Imschweiler (ro-i)

Changes

In case a ptrmask cannot be converted to the new address space due to an unknown mask value, this needs to be detcted and an addrspacecast is needed to not hinder a future use of the unconverted return value of ptrmask. Otherwise, users of this value will become invalid by receiving a nullptr as an operand.

This LLVM defect was identified via the AMD Fuzzing project.

(See https://reviews.llvm.org/D80129 for an explanation of why some ptrmasks are impossible to convert to other addrspaces.)


Full diff: https://github.com/llvm/llvm-project/pull/140802.diff

2 Files Affected:

  • (modified) llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp (+14-1)
  • (modified) llvm/test/Transforms/InferAddressSpaces/AMDGPU/ptrmask.ll (+18)
diff --git a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
index d3771c0903456..4f2e8bbd1102a 100644
--- a/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
+++ b/llvm/lib/Transforms/Scalar/InferAddressSpaces.cpp
@@ -1338,7 +1338,20 @@ bool InferAddressSpacesImpl::rewriteWithNewAddressSpaces(
 
     unsigned OperandNo = PoisonUse->getOperandNo();
     assert(isa<PoisonValue>(NewV->getOperand(OperandNo)));
-    NewV->setOperand(OperandNo, ValueWithNewAddrSpace.lookup(PoisonUse->get()));
+    WeakTrackingVH NewOp = ValueWithNewAddrSpace.lookup(PoisonUse->get());
+    if (NewOp) {
+      NewV->setOperand(OperandNo, NewOp);
+    } else {
+      // Something went wrong while converting the instruction defining the new
+      // operand value.  -> Replace the poison value with the previous operand
+      // value combined with an addrspace case.
+      Value *PoisonOp = NewV->getOperand(OperandNo);
+      Value *OldOp = V->getOperand(OperandNo);
+      Value *AddrSpaceCast =
+          new AddrSpaceCastInst(OldOp, PoisonOp->getType(), "",
+                                cast<Instruction>(NewV)->getIterator());
+      NewV->setOperand(OperandNo, AddrSpaceCast);
+    }
   }
 
   SmallVector<Instruction *, 16> DeadInstructions;
diff --git a/llvm/test/Transforms/InferAddressSpaces/AMDGPU/ptrmask.ll b/llvm/test/Transforms/InferAddressSpaces/AMDGPU/ptrmask.ll
index 6ef926f935830..1c1d1df79520d 100644
--- a/llvm/test/Transforms/InferAddressSpaces/AMDGPU/ptrmask.ll
+++ b/llvm/test/Transforms/InferAddressSpaces/AMDGPU/ptrmask.ll
@@ -343,6 +343,24 @@ define i8 @ptrmask_cast_local_to_flat_load_range_mask(ptr addrspace(3) %src.ptr,
   ret i8 %load
 }
 
+; Non-const masks with no known range should not prevent other ptr-manipulating
+; instructions (such as gep) from being converted.
+define i8 @ptrmask_cast_local_to_flat_unknown_mask(ptr addrspace(3) %src.ptr, i64 %mask, i64 %idx) {
+; CHECK-LABEL: @ptrmask_cast_local_to_flat_unknown_mask(
+; CHECK-NEXT:    [[CAST:%.*]] = addrspacecast ptr addrspace(3) [[SRC_PTR:%.*]] to ptr
+; CHECK-NEXT:    [[MASKED:%.*]] = call ptr @llvm.ptrmask.p0.i64(ptr [[CAST]], i64 [[MASK:%.*]])
+; CHECK-NEXT:    [[TMP1:%.*]] = addrspacecast ptr [[MASKED]] to ptr addrspace(3)
+; CHECK-NEXT:    [[GEP:%.*]] = getelementptr i8, ptr addrspace(3) [[TMP1]], i64 [[IDX:%.*]]
+; CHECK-NEXT:    [[LOAD:%.*]] = load i8, ptr addrspace(3) [[GEP]], align 1
+; CHECK-NEXT:    ret i8 [[LOAD]]
+;
+  %cast = addrspacecast ptr addrspace(3) %src.ptr to ptr
+  %masked = call ptr @llvm.ptrmask.p0.i64(ptr %cast, i64 %mask)
+  %gep = getelementptr i8, ptr %masked, i64 %idx
+  %load = load i8, ptr %gep
+  ret i8 %load
+}
+
 declare ptr @llvm.ptrmask.p0.i64(ptr, i64) #0
 declare ptr addrspace(5) @llvm.ptrmask.p5.i32(ptr addrspace(5), i32) #0
 declare ptr addrspace(3) @llvm.ptrmask.p3.i32(ptr addrspace(3), i32) #0

@arsenm arsenm requested review from krzysz00 and arichardson May 22, 2025 17:03
Copy link
Member

@arichardson arichardson left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not too familiar with this pass just some comments based on looking at the code.

@krzysz00
Copy link
Contributor

(note: will review properly after the holiday)

@llvmbot llvmbot added the llvm:analysis Includes value tracking, cost tables and constant folding label May 22, 2025
Copy link
Contributor

@krzysz00 krzysz00 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall this seems reasonable to me (sorry, lost the review) but I don't know if I've got the requisite knowledge of the code to approve this.

Maybe if there are no objections

ro-i added 6 commits June 10, 2025 06:36
In case a ptrmask cannot be converted to the new address space due to an
unknown mask value, this needs to be detcted and an addrspacecast is
needed to not hinder a future use of the unconverted return value of
ptrmask. Otherwise, users of this value will become invalid by receiving
a nullptr as an operand.

This LLVM defect was identified via the AMD Fuzzing project.
@ro-i ro-i force-pushed the ptrmask-addrspace-fix branch from 9102778 to af027cb Compare June 10, 2025 12:03
@ro-i
Copy link
Contributor Author

ro-i commented Jun 10, 2025

Rebased and resolved conflict

@ro-i
Copy link
Contributor Author

ro-i commented Jun 30, 2025

Ping.

// (e.g. 32-bit) casts work by chopping off the high bits.
if (FromASBitSize < ToASBitSize)
return 0;
return ToASBitSize;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These implicit conversions from unsigned are broken. I think the KnownBits constructors should have explicit added.

This should always return a KnownBits with the correct bitwidth and change the mask value, not the bitwidth

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the KnownBits constructors should have explicit added.

do you dislike the implicit constructor of KnownBits because it's intuitively a bit ambiguous if the unsigned argument stands for "number of known bits" or "bitwidth"?

// pointer in the destination addrspace.
// The default implementation returns an empty optional in case one of the
// addrspaces is not integral.
virtual std::optional<std::pair<KnownBits, KnownBits>>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't need to return an optional; the default can simply return a completely unknown value

return FromPtrBits.trunc(ToASBitSize);
// By default, we do not assume that null results in null again, except for
// addrspace 0.
if (!FromAS && FromPtrBits.isZero())
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
if (!FromAS && FromPtrBits.isZero())
if (FromAS == 0 && FromPtrBits.isZero())

Comment on lines +195 to +197
// The default implementation returns an empty optional in case the source
// addrspace is not an integral addrspace.
virtual std::optional<KnownBits>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto, just use KnownBits's conservative default

if (FromASBitSize < ToASBitSize)
return 0;
return ToASBitSize;
if (FromPtrBits.getBitWidth() >= ToASBitSize)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this all equivalent to

  return FromAS == 0 ? FromPtrBits.zextOrTrunc(ToASBitSize) : FromPtrBits.anyextOrTrunc(ToASBitSize);

@arsenm arsenm requested a review from shiltian September 5, 2025 01:28
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:AMDGPU llvm:analysis Includes value tracking, cost tables and constant folding llvm:transforms
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants