Skip to content

Conversation

LewisCrawford
Copy link
Contributor

Constant fold the NVVM intrinsics for add, mul, div, fma with specific rounding modes.

Constant fold the NVVM intrinsics for add, mul, div, fma with specific
rounding modes.
@LewisCrawford LewisCrawford requested a review from Artem-B August 7, 2025 16:22
@llvmbot llvmbot added llvm:instcombine Covers the InstCombine, InstSimplify and AggressiveInstCombine passes backend:NVPTX llvm:ir llvm:analysis Includes value tracking, cost tables and constant folding llvm:transforms labels Aug 7, 2025
@llvmbot
Copy link
Member

llvmbot commented Aug 7, 2025

@llvm/pr-subscribers-llvm-transforms

@llvm/pr-subscribers-backend-nvptx

Author: Lewis Crawford (LewisCrawford)

Changes

Constant fold the NVVM intrinsics for add, mul, div, fma with specific rounding modes.


Patch is 146.54 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/152544.diff

6 Files Affected:

  • (modified) llvm/include/llvm/IR/NVVMIntrinsicUtils.h (+172)
  • (modified) llvm/lib/Analysis/ConstantFolding.cpp (+178)
  • (added) llvm/test/Transforms/InstSimplify/const-fold-nvvm-add.ll (+876)
  • (added) llvm/test/Transforms/InstSimplify/const-fold-nvvm-div.ll (+880)
  • (added) llvm/test/Transforms/InstSimplify/const-fold-nvvm-fma.ll (+874)
  • (added) llvm/test/Transforms/InstSimplify/const-fold-nvvm-mul.ll (+994)
diff --git a/llvm/include/llvm/IR/NVVMIntrinsicUtils.h b/llvm/include/llvm/IR/NVVMIntrinsicUtils.h
index 11bfd733a8854..cc4929a1ff8da 100644
--- a/llvm/include/llvm/IR/NVVMIntrinsicUtils.h
+++ b/llvm/include/llvm/IR/NVVMIntrinsicUtils.h
@@ -414,6 +414,178 @@ inline DenormalMode GetNVVMDenormMode(bool ShouldFTZ) {
   return DenormalMode::getIEEE();
 }
 
+inline bool FAddShouldFTZ(Intrinsic::ID IntrinsicID) {
+  switch (IntrinsicID) { // Classify NVVM rounded-add IDs by their ftz suffix.
+  case Intrinsic::nvvm_add_rm_ftz_f:
+  case Intrinsic::nvvm_add_rn_ftz_f:
+  case Intrinsic::nvvm_add_rp_ftz_f:
+  case Intrinsic::nvvm_add_rz_ftz_f:
+    return true; // Every *_ftz_f add variant reports FTZ.
+  // The plain _f and _d add variants report no FTZ.
+  case Intrinsic::nvvm_add_rm_f:
+  case Intrinsic::nvvm_add_rn_f:
+  case Intrinsic::nvvm_add_rp_f:
+  case Intrinsic::nvvm_add_rz_f:
+  case Intrinsic::nvvm_add_rm_d:
+  case Intrinsic::nvvm_add_rn_d:
+  case Intrinsic::nvvm_add_rp_d:
+  case Intrinsic::nvvm_add_rz_d:
+    return false;
+  } // Any ID not listed above is a caller error.
+  llvm_unreachable("Checking FTZ flag for invalid NVVM add intrinsic");
+}
+
+inline APFloat::roundingMode GetFAddRoundingMode(Intrinsic::ID IntrinsicID) {
+  switch (IntrinsicID) { // Map the rm/rn/rp/rz suffix to an APFloat mode.
+  case Intrinsic::nvvm_add_rm_f:
+  case Intrinsic::nvvm_add_rm_d:
+  case Intrinsic::nvvm_add_rm_ftz_f:
+    return APFloat::rmTowardNegative; // rm variants
+  case Intrinsic::nvvm_add_rn_f:
+  case Intrinsic::nvvm_add_rn_d:
+  case Intrinsic::nvvm_add_rn_ftz_f:
+    return APFloat::rmNearestTiesToEven; // rn variants
+  case Intrinsic::nvvm_add_rp_f:
+  case Intrinsic::nvvm_add_rp_d:
+  case Intrinsic::nvvm_add_rp_ftz_f:
+    return APFloat::rmTowardPositive; // rp variants
+  case Intrinsic::nvvm_add_rz_f:
+  case Intrinsic::nvvm_add_rz_d:
+  case Intrinsic::nvvm_add_rz_ftz_f:
+    return APFloat::rmTowardZero; // rz variants
+  } // Any ID not listed above is a caller error.
+  llvm_unreachable("Invalid FP intrinsic rounding mode for NVVM add");
+}
+
+inline bool FMulShouldFTZ(Intrinsic::ID IntrinsicID) {
+  switch (IntrinsicID) { // Classify NVVM rounded-mul IDs by their ftz suffix.
+  case Intrinsic::nvvm_mul_rm_ftz_f:
+  case Intrinsic::nvvm_mul_rn_ftz_f:
+  case Intrinsic::nvvm_mul_rp_ftz_f:
+  case Intrinsic::nvvm_mul_rz_ftz_f:
+    return true; // Every *_ftz_f mul variant reports FTZ.
+  // The plain _f and _d mul variants report no FTZ.
+  case Intrinsic::nvvm_mul_rm_f:
+  case Intrinsic::nvvm_mul_rn_f:
+  case Intrinsic::nvvm_mul_rp_f:
+  case Intrinsic::nvvm_mul_rz_f:
+  case Intrinsic::nvvm_mul_rm_d:
+  case Intrinsic::nvvm_mul_rn_d:
+  case Intrinsic::nvvm_mul_rp_d:
+  case Intrinsic::nvvm_mul_rz_d:
+    return false;
+  } // Any ID not listed above is a caller error.
+  llvm_unreachable("Checking FTZ flag for invalid NVVM mul intrinsic");
+}
+
+inline APFloat::roundingMode GetFMulRoundingMode(Intrinsic::ID IntrinsicID) {
+  switch (IntrinsicID) { // Map the rm/rn/rp/rz suffix to an APFloat mode.
+  case Intrinsic::nvvm_mul_rm_f:
+  case Intrinsic::nvvm_mul_rm_d:
+  case Intrinsic::nvvm_mul_rm_ftz_f:
+    return APFloat::rmTowardNegative; // rm variants
+  case Intrinsic::nvvm_mul_rn_f:
+  case Intrinsic::nvvm_mul_rn_d:
+  case Intrinsic::nvvm_mul_rn_ftz_f:
+    return APFloat::rmNearestTiesToEven; // rn variants
+  case Intrinsic::nvvm_mul_rp_f:
+  case Intrinsic::nvvm_mul_rp_d:
+  case Intrinsic::nvvm_mul_rp_ftz_f:
+    return APFloat::rmTowardPositive; // rp variants
+  case Intrinsic::nvvm_mul_rz_f:
+  case Intrinsic::nvvm_mul_rz_d:
+  case Intrinsic::nvvm_mul_rz_ftz_f:
+    return APFloat::rmTowardZero; // rz variants
+  } // Any ID not listed above is a caller error.
+  llvm_unreachable("Invalid FP intrinsic rounding mode for NVVM mul");
+}
+
+inline bool FDivShouldFTZ(Intrinsic::ID IntrinsicID) {
+  switch (IntrinsicID) { // Classify NVVM rounded-div IDs by their ftz suffix.
+  case Intrinsic::nvvm_div_rm_ftz_f:
+  case Intrinsic::nvvm_div_rn_ftz_f:
+  case Intrinsic::nvvm_div_rp_ftz_f:
+  case Intrinsic::nvvm_div_rz_ftz_f:
+    return true; // Every *_ftz_f div variant reports FTZ.
+  // The plain _f and _d div variants report no FTZ.
+  case Intrinsic::nvvm_div_rm_f:
+  case Intrinsic::nvvm_div_rn_f:
+  case Intrinsic::nvvm_div_rp_f:
+  case Intrinsic::nvvm_div_rz_f:
+  case Intrinsic::nvvm_div_rm_d:
+  case Intrinsic::nvvm_div_rn_d:
+  case Intrinsic::nvvm_div_rp_d:
+  case Intrinsic::nvvm_div_rz_d:
+    return false;
+  } // Any ID not listed above is a caller error.
+  llvm_unreachable("Checking FTZ flag for invalid NVVM div intrinsic");
+}
+
+inline APFloat::roundingMode GetFDivRoundingMode(Intrinsic::ID IntrinsicID) {
+  switch (IntrinsicID) { // Map the rm/rn/rp/rz suffix to an APFloat mode.
+  case Intrinsic::nvvm_div_rm_f:
+  case Intrinsic::nvvm_div_rm_d:
+  case Intrinsic::nvvm_div_rm_ftz_f:
+    return APFloat::rmTowardNegative; // rm variants
+  case Intrinsic::nvvm_div_rn_f:
+  case Intrinsic::nvvm_div_rn_d:
+  case Intrinsic::nvvm_div_rn_ftz_f:
+    return APFloat::rmNearestTiesToEven; // rn variants
+  case Intrinsic::nvvm_div_rp_f:
+  case Intrinsic::nvvm_div_rp_d:
+  case Intrinsic::nvvm_div_rp_ftz_f:
+    return APFloat::rmTowardPositive; // rp variants
+  case Intrinsic::nvvm_div_rz_f:
+  case Intrinsic::nvvm_div_rz_d:
+  case Intrinsic::nvvm_div_rz_ftz_f:
+    return APFloat::rmTowardZero; // rz variants
+  } // Any ID not listed above is a caller error.
+  llvm_unreachable("Invalid FP intrinsic rounding mode for NVVM div");
+}
+
+inline bool FMAShouldFTZ(Intrinsic::ID IntrinsicID) {
+  switch (IntrinsicID) { // Classify NVVM rounded-fma IDs by their ftz suffix.
+  case Intrinsic::nvvm_fma_rm_ftz_f:
+  case Intrinsic::nvvm_fma_rn_ftz_f:
+  case Intrinsic::nvvm_fma_rp_ftz_f:
+  case Intrinsic::nvvm_fma_rz_ftz_f:
+    return true; // Every *_ftz_f fma variant reports FTZ.
+  // The plain _f and _d fma variants report no FTZ.
+  case Intrinsic::nvvm_fma_rm_f:
+  case Intrinsic::nvvm_fma_rn_f:
+  case Intrinsic::nvvm_fma_rp_f:
+  case Intrinsic::nvvm_fma_rz_f:
+  case Intrinsic::nvvm_fma_rm_d:
+  case Intrinsic::nvvm_fma_rn_d:
+  case Intrinsic::nvvm_fma_rp_d:
+  case Intrinsic::nvvm_fma_rz_d:
+    return false;
+  } // Any ID not listed above is a caller error.
+  llvm_unreachable("Checking FTZ flag for invalid NVVM fma intrinsic");
+}
+
+inline APFloat::roundingMode GetFMARoundingMode(Intrinsic::ID IntrinsicID) {
+  switch (IntrinsicID) { // Map the rm/rn/rp/rz suffix to an APFloat mode.
+  case Intrinsic::nvvm_fma_rm_f:
+  case Intrinsic::nvvm_fma_rm_d:
+  case Intrinsic::nvvm_fma_rm_ftz_f:
+    return APFloat::rmTowardNegative; // rm variants
+  case Intrinsic::nvvm_fma_rn_f:
+  case Intrinsic::nvvm_fma_rn_d:
+  case Intrinsic::nvvm_fma_rn_ftz_f:
+    return APFloat::rmNearestTiesToEven; // rn variants
+  case Intrinsic::nvvm_fma_rp_f:
+  case Intrinsic::nvvm_fma_rp_d:
+  case Intrinsic::nvvm_fma_rp_ftz_f:
+    return APFloat::rmTowardPositive; // rp variants
+  case Intrinsic::nvvm_fma_rz_f:
+  case Intrinsic::nvvm_fma_rz_d:
+  case Intrinsic::nvvm_fma_rz_ftz_f:
+    return APFloat::rmTowardZero; // rz variants
+  } // Any ID not listed above is a caller error.
+  llvm_unreachable("Invalid FP intrinsic rounding mode for NVVM fma");
+}
+
 } // namespace nvvm
 } // namespace llvm
 #endif // LLVM_IR_NVVMINTRINSICUTILS_H
diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp
index dd98b62baca33..69fdd4f2b3e71 100644
--- a/llvm/lib/Analysis/ConstantFolding.cpp
+++ b/llvm/lib/Analysis/ConstantFolding.cpp
@@ -1843,6 +1843,62 @@ bool llvm::canConstantFoldCallTo(const CallBase *Call, const Function *F) {
   case Intrinsic::nvvm_sqrt_rn_ftz_f:
     return !Call->isStrictFP();
 
+  // NVVM add intrinsics with explicit rounding modes
+  case Intrinsic::nvvm_add_rm_d:
+  case Intrinsic::nvvm_add_rn_d:
+  case Intrinsic::nvvm_add_rp_d:
+  case Intrinsic::nvvm_add_rz_d:
+  case Intrinsic::nvvm_add_rm_f:
+  case Intrinsic::nvvm_add_rn_f:
+  case Intrinsic::nvvm_add_rp_f:
+  case Intrinsic::nvvm_add_rz_f:
+  case Intrinsic::nvvm_add_rm_ftz_f:
+  case Intrinsic::nvvm_add_rn_ftz_f:
+  case Intrinsic::nvvm_add_rp_ftz_f:
+  case Intrinsic::nvvm_add_rz_ftz_f:
+
+  // NVVM div intrinsics with explicit rounding modes
+  case Intrinsic::nvvm_div_rm_d:
+  case Intrinsic::nvvm_div_rn_d:
+  case Intrinsic::nvvm_div_rp_d:
+  case Intrinsic::nvvm_div_rz_d:
+  case Intrinsic::nvvm_div_rm_f:
+  case Intrinsic::nvvm_div_rn_f:
+  case Intrinsic::nvvm_div_rp_f:
+  case Intrinsic::nvvm_div_rz_f:
+  case Intrinsic::nvvm_div_rm_ftz_f:
+  case Intrinsic::nvvm_div_rn_ftz_f:
+  case Intrinsic::nvvm_div_rp_ftz_f:
+  case Intrinsic::nvvm_div_rz_ftz_f:
+
+  // NVVM mul intrinsics with explicit rounding modes
+  case Intrinsic::nvvm_mul_rm_d:
+  case Intrinsic::nvvm_mul_rn_d:
+  case Intrinsic::nvvm_mul_rp_d:
+  case Intrinsic::nvvm_mul_rz_d:
+  case Intrinsic::nvvm_mul_rm_f:
+  case Intrinsic::nvvm_mul_rn_f:
+  case Intrinsic::nvvm_mul_rp_f:
+  case Intrinsic::nvvm_mul_rz_f:
+  case Intrinsic::nvvm_mul_rm_ftz_f:
+  case Intrinsic::nvvm_mul_rn_ftz_f:
+  case Intrinsic::nvvm_mul_rp_ftz_f:
+  case Intrinsic::nvvm_mul_rz_ftz_f:
+
+  // NVVM fma intrinsics with explicit rounding modes
+  case Intrinsic::nvvm_fma_rm_d:
+  case Intrinsic::nvvm_fma_rn_d:
+  case Intrinsic::nvvm_fma_rp_d:
+  case Intrinsic::nvvm_fma_rz_d:
+  case Intrinsic::nvvm_fma_rm_f:
+  case Intrinsic::nvvm_fma_rn_f:
+  case Intrinsic::nvvm_fma_rp_f:
+  case Intrinsic::nvvm_fma_rz_f:
+  case Intrinsic::nvvm_fma_rm_ftz_f:
+  case Intrinsic::nvvm_fma_rn_ftz_f:
+  case Intrinsic::nvvm_fma_rp_ftz_f:
+  case Intrinsic::nvvm_fma_rz_ftz_f:
+
   // Sign operations are actually bitwise operations, they do not raise
   // exceptions even for SNANs.
   case Intrinsic::fabs:
@@ -3318,6 +3374,96 @@ static Constant *ConstantFoldIntrinsicCall2(Intrinsic::ID IntrinsicID, Type *Ty,
 
         return ConstantFP::get(Ty->getContext(), Res);
       }
+
+      case Intrinsic::nvvm_add_rm_f:
+      case Intrinsic::nvvm_add_rn_f:
+      case Intrinsic::nvvm_add_rp_f:
+      case Intrinsic::nvvm_add_rz_f:
+      case Intrinsic::nvvm_add_rm_d:
+      case Intrinsic::nvvm_add_rn_d:
+      case Intrinsic::nvvm_add_rp_d:
+      case Intrinsic::nvvm_add_rz_d:
+      case Intrinsic::nvvm_add_rm_ftz_f:
+      case Intrinsic::nvvm_add_rn_ftz_f:
+      case Intrinsic::nvvm_add_rp_ftz_f:
+      case Intrinsic::nvvm_add_rz_ftz_f: {
+
+        bool IsFTZ = nvvm::FAddShouldFTZ(IntrinsicID);
+        APFloat A = IsFTZ ? FTZPreserveSign(Op1V) : Op1V;
+        APFloat B = IsFTZ ? FTZPreserveSign(Op2V) : Op2V;
+
+        APFloat::roundingMode RoundMode =
+            nvvm::GetFAddRoundingMode(IntrinsicID);
+
+        APFloat Res = A;
+        APFloat::opStatus Status = Res.add(B, RoundMode);
+
+        if (!Res.isNaN() &&
+            (Status == APFloat::opOK || Status == APFloat::opInexact)) {
+          Res = IsFTZ ? FTZPreserveSign(Res) : Res;
+          return ConstantFP::get(Ty->getContext(), Res);
+        }
+        return nullptr;
+      }
+
+      case Intrinsic::nvvm_mul_rm_f:
+      case Intrinsic::nvvm_mul_rn_f:
+      case Intrinsic::nvvm_mul_rp_f:
+      case Intrinsic::nvvm_mul_rz_f:
+      case Intrinsic::nvvm_mul_rm_d:
+      case Intrinsic::nvvm_mul_rn_d:
+      case Intrinsic::nvvm_mul_rp_d:
+      case Intrinsic::nvvm_mul_rz_d:
+      case Intrinsic::nvvm_mul_rm_ftz_f:
+      case Intrinsic::nvvm_mul_rn_ftz_f:
+      case Intrinsic::nvvm_mul_rp_ftz_f:
+      case Intrinsic::nvvm_mul_rz_ftz_f: {
+
+        bool IsFTZ = nvvm::FMulShouldFTZ(IntrinsicID);
+        APFloat A = IsFTZ ? FTZPreserveSign(Op1V) : Op1V;
+        APFloat B = IsFTZ ? FTZPreserveSign(Op2V) : Op2V;
+
+        APFloat::roundingMode RoundMode =
+            nvvm::GetFMulRoundingMode(IntrinsicID);
+
+        APFloat Res = A;
+        APFloat::opStatus Status = Res.multiply(B, RoundMode);
+
+        if (!Res.isNaN() &&
+            (Status == APFloat::opOK || Status == APFloat::opInexact)) {
+          Res = IsFTZ ? FTZPreserveSign(Res) : Res;
+          return ConstantFP::get(Ty->getContext(), Res);
+        }
+        return nullptr;
+      }
+
+      case Intrinsic::nvvm_div_rm_f:
+      case Intrinsic::nvvm_div_rn_f:
+      case Intrinsic::nvvm_div_rp_f:
+      case Intrinsic::nvvm_div_rz_f:
+      case Intrinsic::nvvm_div_rm_d:
+      case Intrinsic::nvvm_div_rn_d:
+      case Intrinsic::nvvm_div_rp_d:
+      case Intrinsic::nvvm_div_rz_d:
+      case Intrinsic::nvvm_div_rm_ftz_f:
+      case Intrinsic::nvvm_div_rn_ftz_f:
+      case Intrinsic::nvvm_div_rp_ftz_f:
+      case Intrinsic::nvvm_div_rz_ftz_f: {
+        bool IsFTZ = nvvm::FDivShouldFTZ(IntrinsicID);
+        APFloat A = IsFTZ ? FTZPreserveSign(Op1V) : Op1V;
+        APFloat B = IsFTZ ? FTZPreserveSign(Op2V) : Op2V;
+        APFloat::roundingMode RoundMode =
+            nvvm::GetFDivRoundingMode(IntrinsicID);
+
+        APFloat Res = A;
+        APFloat::opStatus Status = Res.divide(B, RoundMode);
+        if (!Res.isNaN() &&
+            (Status == APFloat::opOK || Status == APFloat::opInexact)) {
+          Res = IsFTZ ? FTZPreserveSign(Res) : Res;
+          return ConstantFP::get(Ty->getContext(), Res);
+        }
+        return nullptr;
+      }
       }
 
       if (!Ty->isHalfTy() && !Ty->isFloatTy() && !Ty->isDoubleTy())
@@ -3729,6 +3875,38 @@ static Constant *ConstantFoldScalarCall3(StringRef Name,
           V.fusedMultiplyAdd(C2, C3, APFloat::rmNearestTiesToEven);
           return ConstantFP::get(Ty->getContext(), V);
         }
+
+        case Intrinsic::nvvm_fma_rm_f:
+        case Intrinsic::nvvm_fma_rn_f:
+        case Intrinsic::nvvm_fma_rp_f:
+        case Intrinsic::nvvm_fma_rz_f:
+        case Intrinsic::nvvm_fma_rm_d:
+        case Intrinsic::nvvm_fma_rn_d:
+        case Intrinsic::nvvm_fma_rp_d:
+        case Intrinsic::nvvm_fma_rz_d:
+        case Intrinsic::nvvm_fma_rm_ftz_f:
+        case Intrinsic::nvvm_fma_rn_ftz_f:
+        case Intrinsic::nvvm_fma_rp_ftz_f:
+        case Intrinsic::nvvm_fma_rz_ftz_f: {
+          bool IsFTZ = nvvm::FMAShouldFTZ(IntrinsicID);
+          APFloat A = IsFTZ ? FTZPreserveSign(C1) : C1;
+          APFloat B = IsFTZ ? FTZPreserveSign(C2) : C2;
+          APFloat C = IsFTZ ? FTZPreserveSign(C3) : C3;
+
+          APFloat::roundingMode RoundMode =
+              nvvm::GetFMARoundingMode(IntrinsicID);
+
+          APFloat Res = A;
+          APFloat::opStatus Status = Res.fusedMultiplyAdd(B, C, RoundMode);
+
+          if (!Res.isNaN() &&
+              (Status == APFloat::opOK || Status == APFloat::opInexact)) {
+            Res = IsFTZ ? FTZPreserveSign(Res) : Res;
+            return ConstantFP::get(Ty->getContext(), Res);
+          }
+          return nullptr;
+        }
+
         case Intrinsic::amdgcn_cubeid:
         case Intrinsic::amdgcn_cubema:
         case Intrinsic::amdgcn_cubesc:
diff --git a/llvm/test/Transforms/InstSimplify/const-fold-nvvm-add.ll b/llvm/test/Transforms/InstSimplify/const-fold-nvvm-add.ll
new file mode 100644
index 0000000000000..a3d87439d74cf
--- /dev/null
+++ b/llvm/test/Transforms/InstSimplify/const-fold-nvvm-add.ll
@@ -0,0 +1,876 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -passes=instsimplify -march=nvptx64 --mcpu=sm_86 --mattr=+ptx72 -S | FileCheck %s
+
+; Check constant-folding for NVVM add intrinsics with different rounding modes
+
+;###############################################################
+;#                    Add(1.25, -2.0)                          #
+;###############################################################
+; Tests addition of two normal numbers (1.25 and -2.0) where the result
+; is exactly representable. All rounding modes produce the same result.
+
+define double @test_1_25_minus_2_rm_d() {
+; CHECK-LABEL: define double @test_1_25_minus_2_rm_d() {
+; CHECK-NEXT:    ret double -7.500000e-01
+;
+  %res = call double @llvm.nvvm.add.rm.d(double 1.25, double -2.0)
+  ret double %res
+}
+
+define double @test_1_25_minus_2_rn_d() {
+; CHECK-LABEL: define double @test_1_25_minus_2_rn_d() {
+; CHECK-NEXT:    ret double -7.500000e-01
+;
+  %res = call double @llvm.nvvm.add.rn.d(double 1.25, double -2.0)
+  ret double %res
+}
+
+define double @test_1_25_minus_2_rp_d() {
+; CHECK-LABEL: define double @test_1_25_minus_2_rp_d() {
+; CHECK-NEXT:    ret double -7.500000e-01
+;
+  %res = call double @llvm.nvvm.add.rp.d(double 1.25, double -2.0)
+  ret double %res
+}
+
+define double @test_1_25_minus_2_rz_d() {
+; CHECK-LABEL: define double @test_1_25_minus_2_rz_d() {
+; CHECK-NEXT:    ret double -7.500000e-01
+;
+  %res = call double @llvm.nvvm.add.rz.d(double 1.25, double -2.0)
+  ret double %res
+}
+
+define float @test_1_25_minus_2_rm_f() {
+; CHECK-LABEL: define float @test_1_25_minus_2_rm_f() {
+; CHECK-NEXT:    ret float -7.500000e-01
+;
+  %res = call float @llvm.nvvm.add.rm.f(float 1.25, float -2.0)
+  ret float %res
+}
+
+define float @test_1_25_minus_2_rn_f() {
+; CHECK-LABEL: define float @test_1_25_minus_2_rn_f() {
+; CHECK-NEXT:    ret float -7.500000e-01
+;
+  %res = call float @llvm.nvvm.add.rn.f(float 1.25, float -2.0)
+  ret float %res
+}
+
+define float @test_1_25_minus_2_rp_f() {
+; CHECK-LABEL: define float @test_1_25_minus_2_rp_f() {
+; CHECK-NEXT:    ret float -7.500000e-01
+;
+  %res = call float @llvm.nvvm.add.rp.f(float 1.25, float -2.0)
+  ret float %res
+}
+
+define float @test_1_25_minus_2_rz_f() {
+; CHECK-LABEL: define float @test_1_25_minus_2_rz_f() {
+; CHECK-NEXT:    ret float -7.500000e-01
+;
+  %res = call float @llvm.nvvm.add.rz.f(float 1.25, float -2.0)
+  ret float %res
+}
+
+define float @test_1_25_minus_2_rm_ftz_f() {
+; CHECK-LABEL: define float @test_1_25_minus_2_rm_ftz_f() {
+; CHECK-NEXT:    ret float -7.500000e-01
+;
+  %res = call float @llvm.nvvm.add.rm.ftz.f(float 1.25, float -2.0)
+  ret float %res
+}
+
+define float @test_1_25_minus_2_rn_ftz_f() {
+; CHECK-LABEL: define float @test_1_25_minus_2_rn_ftz_f() {
+; CHECK-NEXT:    ret float -7.500000e-01
+;
+  %res = call float @llvm.nvvm.add.rn.ftz.f(float 1.25, float -2.0)
+  ret float %res
+}
+
+define float @test_1_25_minus_2_rp_ftz_f() {
+; CHECK-LABEL: define float @test_1_25_minus_2_rp_ftz_f() {
+; CHECK-NEXT:    ret float -7.500000e-01
+;
+  %res = call float @llvm.nvvm.add.rp.ftz.f(float 1.25, float -2.0)
+  ret float %res
+}
+
+define float @test_1_25_minus_2_rz_ftz_f() {
+; CHECK-LABEL: define float @test_1_25_minus_2_rz_ftz_f() {
+; CHECK-NEXT:    ret float -7.500000e-01
+;
+  %res = call float @llvm.nvvm.add.rz.ftz.f(float 1.25, float -2.0)
+  ret float %res
+}
+
+;###############################################################
+;#                          Add(0.0, NaN)                      #
+;###############################################################
+; Tests addition of a zero with NaN.
+; The result is always NaN and the operation is not constant-folded.
+
+define double @test_zero_plus_nan_rm_d() {
+; CHECK-LABEL: define double @test_zero_plus_nan_rm_d() {
+; CHECK-NEXT:    [[RES:%.*]] = call double @llvm.nvvm.add.rm.d(double 0.000000e+00, double 0x7FF4444400000000)
+; CHECK-NEXT:    ret double [[RES]]
+;
+  %res = call double @llvm.nvvm.add.rm.d(double 0.0, double 0x7ff4444400000000)
+  ret double %res
+}
+
+define double @test_zero_plus_nan_rn_d() {
+; CHECK-LABEL: define double @test_zero_plus_nan_rn_d() {
+; CHECK-NEXT:    [[RES:%.*]] = call double @llvm.nvvm.add.rn.d(double 0.000000e+00, double 0x7FF4444400000000)
+; CHECK-NEXT:    ret double [[RES]]
+;
+  %res = call double @llvm.nvvm.add.rn.d(double 0.0, double 0x7ff4444400000000)
+  ret double %res
+}
+
+define double @test_zero_plus_nan_rp_d() {
+; CHECK-LABEL: define double @test_zero_plus_nan_rp_d() {
+; CHECK-NEXT:    [[RES:%.*]] = call double @llvm.nvvm.add.rp.d(double 0.000000e+00, double 0x7FF4444400000000)
+; CHECK-NEXT:    ret double [[RES]]
+;
+  %res = call double @llvm.nvvm.add.rp.d(double 0.0, double 0x7ff4444400000000)
+  ret double %res
+}
+
+define double @test_zero_plus_nan_rz_d() {
+; CHECK-LABEL: define double @test_zero_plus_nan_rz_d() {
+; CHECK-NEXT:    [[RES:%.*]] = call double @llvm.nvvm.add.rz.d(double 0.000000e+00, double 0x7FF4444400000000)
+; CHECK-NEXT:    ret double [[RES]]
+;
+  %res = call double @llvm.nvvm.add.rz.d(double 0.0, double 0x7ff4444400000000)
+  ret double %res
+}
+
+define float @test_zero_plus_nan_rm_f() {
+; CHECK-LABEL: define float @test_zero_plus_nan_rm_f() {
+; CHECK-NEXT:    [[RES:%.*]] = call float @llvm.nvvm.add.rm.f(float 0.000000e+00, float 0x7FFF444400000000)
+; CHECK-NEXT:    ret float [[RES]]
+;
+  %res = call float @llvm.nvvm.add.rm.f(float 0.0, float 0x7FFF444400000000)
+  ret float %res
+}
+
+define float @test_zero_plus_nan_rn_f() {
+; CHECK-LABEL: define float @test_zero_plus_nan_rn_f() {
+; CHECK-NEXT:    [[RES:%.*]] = call float @llvm.nvvm.add.rn.f(float 0.000000e+00, float 0x7FFF444400000000)
+; CHECK-NEXT:    ret float [[RES]]
+;
+  %res = call float @llvm.nvvm.add.rn.f(float 0.0, float 0x7FFF444400000000)
+  ret float %res
+}
+
+define float @test_zero_plus_nan_rp_f() {
+; CHECK-LABEL: define float @test_zero_plus_nan_rp_f() {
+; CHECK-NEXT:    [[RES:%.*]] = call float @llvm.nvvm.add.rp.f(float 0.000000e+00, float 0x7FFF444400000000)
+; CHECK-NEXT:    ret float [[RES]]
+;
+  %res = call float @llvm.nvvm.add.rp.f(float 0.0, float 0x7FFF444400000000)
+  ret float %res
+}
+
+define float @test_zero_plus_nan_rz_f() {
+; CHECK-LABEL: define float @test_zero_plus_nan_rz_f() {
+; CHECK-NEXT:    [[RES:%.*]] = call float @llvm.nvvm.add.rz.f(float 0.000000e+00, float 0x7FFF444400000000)
+; CHECK-NEXT:    ret float [[RES]]
+;
+  %res = call float @llvm.nvvm.add.rz.f(float 0.0, float 0x7FFF444400000000)
+  ret float %res
+}
+
+defin...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Aug 7, 2025

@llvm/pr-subscribers-llvm-ir

Author: Lewis Crawford (LewisCrawford)

Changes

Constant fold the NVVM intrinsics for add, mul, div, fma with specific rounding modes.


Patch is 146.54 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/152544.diff

6 Files Affected:

  • (modified) llvm/include/llvm/IR/NVVMIntrinsicUtils.h (+172)
  • (modified) llvm/lib/Analysis/ConstantFolding.cpp (+178)
  • (added) llvm/test/Transforms/InstSimplify/const-fold-nvvm-add.ll (+876)
  • (added) llvm/test/Transforms/InstSimplify/const-fold-nvvm-div.ll (+880)
  • (added) llvm/test/Transforms/InstSimplify/const-fold-nvvm-fma.ll (+874)
  • (added) llvm/test/Transforms/InstSimplify/const-fold-nvvm-mul.ll (+994)
diff --git a/llvm/include/llvm/IR/NVVMIntrinsicUtils.h b/llvm/include/llvm/IR/NVVMIntrinsicUtils.h
index 11bfd733a8854..cc4929a1ff8da 100644
--- a/llvm/include/llvm/IR/NVVMIntrinsicUtils.h
+++ b/llvm/include/llvm/IR/NVVMIntrinsicUtils.h
@@ -414,6 +414,178 @@ inline DenormalMode GetNVVMDenormMode(bool ShouldFTZ) {
   return DenormalMode::getIEEE();
 }
 
+inline bool FAddShouldFTZ(Intrinsic::ID IntrinsicID) {
+  switch (IntrinsicID) { // Classify NVVM rounded-add IDs by their ftz suffix.
+  case Intrinsic::nvvm_add_rm_ftz_f:
+  case Intrinsic::nvvm_add_rn_ftz_f:
+  case Intrinsic::nvvm_add_rp_ftz_f:
+  case Intrinsic::nvvm_add_rz_ftz_f:
+    return true; // Every *_ftz_f add variant reports FTZ.
+  // The plain _f and _d add variants report no FTZ.
+  case Intrinsic::nvvm_add_rm_f:
+  case Intrinsic::nvvm_add_rn_f:
+  case Intrinsic::nvvm_add_rp_f:
+  case Intrinsic::nvvm_add_rz_f:
+  case Intrinsic::nvvm_add_rm_d:
+  case Intrinsic::nvvm_add_rn_d:
+  case Intrinsic::nvvm_add_rp_d:
+  case Intrinsic::nvvm_add_rz_d:
+    return false;
+  } // Any ID not listed above is a caller error.
+  llvm_unreachable("Checking FTZ flag for invalid NVVM add intrinsic");
+}
+
+inline APFloat::roundingMode GetFAddRoundingMode(Intrinsic::ID IntrinsicID) {
+  switch (IntrinsicID) { // Map the rm/rn/rp/rz suffix to an APFloat mode.
+  case Intrinsic::nvvm_add_rm_f:
+  case Intrinsic::nvvm_add_rm_d:
+  case Intrinsic::nvvm_add_rm_ftz_f:
+    return APFloat::rmTowardNegative; // rm variants
+  case Intrinsic::nvvm_add_rn_f:
+  case Intrinsic::nvvm_add_rn_d:
+  case Intrinsic::nvvm_add_rn_ftz_f:
+    return APFloat::rmNearestTiesToEven; // rn variants
+  case Intrinsic::nvvm_add_rp_f:
+  case Intrinsic::nvvm_add_rp_d:
+  case Intrinsic::nvvm_add_rp_ftz_f:
+    return APFloat::rmTowardPositive; // rp variants
+  case Intrinsic::nvvm_add_rz_f:
+  case Intrinsic::nvvm_add_rz_d:
+  case Intrinsic::nvvm_add_rz_ftz_f:
+    return APFloat::rmTowardZero; // rz variants
+  } // Any ID not listed above is a caller error.
+  llvm_unreachable("Invalid FP intrinsic rounding mode for NVVM add");
+}
+
+inline bool FMulShouldFTZ(Intrinsic::ID IntrinsicID) {
+  switch (IntrinsicID) { // Classify NVVM rounded-mul IDs by their ftz suffix.
+  case Intrinsic::nvvm_mul_rm_ftz_f:
+  case Intrinsic::nvvm_mul_rn_ftz_f:
+  case Intrinsic::nvvm_mul_rp_ftz_f:
+  case Intrinsic::nvvm_mul_rz_ftz_f:
+    return true; // Every *_ftz_f mul variant reports FTZ.
+  // The plain _f and _d mul variants report no FTZ.
+  case Intrinsic::nvvm_mul_rm_f:
+  case Intrinsic::nvvm_mul_rn_f:
+  case Intrinsic::nvvm_mul_rp_f:
+  case Intrinsic::nvvm_mul_rz_f:
+  case Intrinsic::nvvm_mul_rm_d:
+  case Intrinsic::nvvm_mul_rn_d:
+  case Intrinsic::nvvm_mul_rp_d:
+  case Intrinsic::nvvm_mul_rz_d:
+    return false;
+  } // Any ID not listed above is a caller error.
+  llvm_unreachable("Checking FTZ flag for invalid NVVM mul intrinsic");
+}
+
+inline APFloat::roundingMode GetFMulRoundingMode(Intrinsic::ID IntrinsicID) {
+  switch (IntrinsicID) { // Map the rm/rn/rp/rz suffix to an APFloat mode.
+  case Intrinsic::nvvm_mul_rm_f:
+  case Intrinsic::nvvm_mul_rm_d:
+  case Intrinsic::nvvm_mul_rm_ftz_f:
+    return APFloat::rmTowardNegative; // rm variants
+  case Intrinsic::nvvm_mul_rn_f:
+  case Intrinsic::nvvm_mul_rn_d:
+  case Intrinsic::nvvm_mul_rn_ftz_f:
+    return APFloat::rmNearestTiesToEven; // rn variants
+  case Intrinsic::nvvm_mul_rp_f:
+  case Intrinsic::nvvm_mul_rp_d:
+  case Intrinsic::nvvm_mul_rp_ftz_f:
+    return APFloat::rmTowardPositive; // rp variants
+  case Intrinsic::nvvm_mul_rz_f:
+  case Intrinsic::nvvm_mul_rz_d:
+  case Intrinsic::nvvm_mul_rz_ftz_f:
+    return APFloat::rmTowardZero; // rz variants
+  } // Any ID not listed above is a caller error.
+  llvm_unreachable("Invalid FP intrinsic rounding mode for NVVM mul");
+}
+
+inline bool FDivShouldFTZ(Intrinsic::ID IntrinsicID) {
+  switch (IntrinsicID) { // Classify NVVM rounded-div IDs by their ftz suffix.
+  case Intrinsic::nvvm_div_rm_ftz_f:
+  case Intrinsic::nvvm_div_rn_ftz_f:
+  case Intrinsic::nvvm_div_rp_ftz_f:
+  case Intrinsic::nvvm_div_rz_ftz_f:
+    return true; // Every *_ftz_f div variant reports FTZ.
+  // The plain _f and _d div variants report no FTZ.
+  case Intrinsic::nvvm_div_rm_f:
+  case Intrinsic::nvvm_div_rn_f:
+  case Intrinsic::nvvm_div_rp_f:
+  case Intrinsic::nvvm_div_rz_f:
+  case Intrinsic::nvvm_div_rm_d:
+  case Intrinsic::nvvm_div_rn_d:
+  case Intrinsic::nvvm_div_rp_d:
+  case Intrinsic::nvvm_div_rz_d:
+    return false;
+  } // Any ID not listed above is a caller error.
+  llvm_unreachable("Checking FTZ flag for invalid NVVM div intrinsic");
+}
+
+inline APFloat::roundingMode GetFDivRoundingMode(Intrinsic::ID IntrinsicID) {
+  switch (IntrinsicID) { // Map the rm/rn/rp/rz suffix to an APFloat mode.
+  case Intrinsic::nvvm_div_rm_f:
+  case Intrinsic::nvvm_div_rm_d:
+  case Intrinsic::nvvm_div_rm_ftz_f:
+    return APFloat::rmTowardNegative; // rm variants
+  case Intrinsic::nvvm_div_rn_f:
+  case Intrinsic::nvvm_div_rn_d:
+  case Intrinsic::nvvm_div_rn_ftz_f:
+    return APFloat::rmNearestTiesToEven; // rn variants
+  case Intrinsic::nvvm_div_rp_f:
+  case Intrinsic::nvvm_div_rp_d:
+  case Intrinsic::nvvm_div_rp_ftz_f:
+    return APFloat::rmTowardPositive; // rp variants
+  case Intrinsic::nvvm_div_rz_f:
+  case Intrinsic::nvvm_div_rz_d:
+  case Intrinsic::nvvm_div_rz_ftz_f:
+    return APFloat::rmTowardZero; // rz variants
+  } // Any ID not listed above is a caller error.
+  llvm_unreachable("Invalid FP intrinsic rounding mode for NVVM div");
+}
+
+inline bool FMAShouldFTZ(Intrinsic::ID IntrinsicID) {
+  switch (IntrinsicID) { // Classify NVVM rounded-fma IDs by their ftz suffix.
+  case Intrinsic::nvvm_fma_rm_ftz_f:
+  case Intrinsic::nvvm_fma_rn_ftz_f:
+  case Intrinsic::nvvm_fma_rp_ftz_f:
+  case Intrinsic::nvvm_fma_rz_ftz_f:
+    return true; // Every *_ftz_f fma variant reports FTZ.
+  // The plain _f and _d fma variants report no FTZ.
+  case Intrinsic::nvvm_fma_rm_f:
+  case Intrinsic::nvvm_fma_rn_f:
+  case Intrinsic::nvvm_fma_rp_f:
+  case Intrinsic::nvvm_fma_rz_f:
+  case Intrinsic::nvvm_fma_rm_d:
+  case Intrinsic::nvvm_fma_rn_d:
+  case Intrinsic::nvvm_fma_rp_d:
+  case Intrinsic::nvvm_fma_rz_d:
+    return false;
+  } // Any ID not listed above is a caller error.
+  llvm_unreachable("Checking FTZ flag for invalid NVVM fma intrinsic");
+}
+
+inline APFloat::roundingMode GetFMARoundingMode(Intrinsic::ID IntrinsicID) {
+  switch (IntrinsicID) { // Map the rm/rn/rp/rz suffix to an APFloat mode.
+  case Intrinsic::nvvm_fma_rm_f:
+  case Intrinsic::nvvm_fma_rm_d:
+  case Intrinsic::nvvm_fma_rm_ftz_f:
+    return APFloat::rmTowardNegative; // rm variants
+  case Intrinsic::nvvm_fma_rn_f:
+  case Intrinsic::nvvm_fma_rn_d:
+  case Intrinsic::nvvm_fma_rn_ftz_f:
+    return APFloat::rmNearestTiesToEven; // rn variants
+  case Intrinsic::nvvm_fma_rp_f:
+  case Intrinsic::nvvm_fma_rp_d:
+  case Intrinsic::nvvm_fma_rp_ftz_f:
+    return APFloat::rmTowardPositive; // rp variants
+  case Intrinsic::nvvm_fma_rz_f:
+  case Intrinsic::nvvm_fma_rz_d:
+  case Intrinsic::nvvm_fma_rz_ftz_f:
+    return APFloat::rmTowardZero; // rz variants
+  } // Any ID not listed above is a caller error.
+  llvm_unreachable("Invalid FP intrinsic rounding mode for NVVM fma");
+}
+
 } // namespace nvvm
 } // namespace llvm
 #endif // LLVM_IR_NVVMINTRINSICUTILS_H
diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp
index dd98b62baca33..69fdd4f2b3e71 100644
--- a/llvm/lib/Analysis/ConstantFolding.cpp
+++ b/llvm/lib/Analysis/ConstantFolding.cpp
@@ -1843,6 +1843,62 @@ bool llvm::canConstantFoldCallTo(const CallBase *Call, const Function *F) {
   case Intrinsic::nvvm_sqrt_rn_ftz_f:
     return !Call->isStrictFP();
 
+  // NVVM add intrinsics with explicit rounding modes
+  case Intrinsic::nvvm_add_rm_d:
+  case Intrinsic::nvvm_add_rn_d:
+  case Intrinsic::nvvm_add_rp_d:
+  case Intrinsic::nvvm_add_rz_d:
+  case Intrinsic::nvvm_add_rm_f:
+  case Intrinsic::nvvm_add_rn_f:
+  case Intrinsic::nvvm_add_rp_f:
+  case Intrinsic::nvvm_add_rz_f:
+  case Intrinsic::nvvm_add_rm_ftz_f:
+  case Intrinsic::nvvm_add_rn_ftz_f:
+  case Intrinsic::nvvm_add_rp_ftz_f:
+  case Intrinsic::nvvm_add_rz_ftz_f:
+
+  // NVVM div intrinsics with explicit rounding modes
+  case Intrinsic::nvvm_div_rm_d:
+  case Intrinsic::nvvm_div_rn_d:
+  case Intrinsic::nvvm_div_rp_d:
+  case Intrinsic::nvvm_div_rz_d:
+  case Intrinsic::nvvm_div_rm_f:
+  case Intrinsic::nvvm_div_rn_f:
+  case Intrinsic::nvvm_div_rp_f:
+  case Intrinsic::nvvm_div_rz_f:
+  case Intrinsic::nvvm_div_rm_ftz_f:
+  case Intrinsic::nvvm_div_rn_ftz_f:
+  case Intrinsic::nvvm_div_rp_ftz_f:
+  case Intrinsic::nvvm_div_rz_ftz_f:
+
+  // NVVM mul intrinsics with explicit rounding modes
+  case Intrinsic::nvvm_mul_rm_d:
+  case Intrinsic::nvvm_mul_rn_d:
+  case Intrinsic::nvvm_mul_rp_d:
+  case Intrinsic::nvvm_mul_rz_d:
+  case Intrinsic::nvvm_mul_rm_f:
+  case Intrinsic::nvvm_mul_rn_f:
+  case Intrinsic::nvvm_mul_rp_f:
+  case Intrinsic::nvvm_mul_rz_f:
+  case Intrinsic::nvvm_mul_rm_ftz_f:
+  case Intrinsic::nvvm_mul_rn_ftz_f:
+  case Intrinsic::nvvm_mul_rp_ftz_f:
+  case Intrinsic::nvvm_mul_rz_ftz_f:
+
+  // NVVM fma intrinsics with explicit rounding modes
+  case Intrinsic::nvvm_fma_rm_d:
+  case Intrinsic::nvvm_fma_rn_d:
+  case Intrinsic::nvvm_fma_rp_d:
+  case Intrinsic::nvvm_fma_rz_d:
+  case Intrinsic::nvvm_fma_rm_f:
+  case Intrinsic::nvvm_fma_rn_f:
+  case Intrinsic::nvvm_fma_rp_f:
+  case Intrinsic::nvvm_fma_rz_f:
+  case Intrinsic::nvvm_fma_rm_ftz_f:
+  case Intrinsic::nvvm_fma_rn_ftz_f:
+  case Intrinsic::nvvm_fma_rp_ftz_f:
+  case Intrinsic::nvvm_fma_rz_ftz_f:
+
   // Sign operations are actually bitwise operations, they do not raise
   // exceptions even for SNANs.
   case Intrinsic::fabs:
@@ -3318,6 +3374,96 @@ static Constant *ConstantFoldIntrinsicCall2(Intrinsic::ID IntrinsicID, Type *Ty,
 
         return ConstantFP::get(Ty->getContext(), Res);
       }
+
+      case Intrinsic::nvvm_add_rm_f:
+      case Intrinsic::nvvm_add_rn_f:
+      case Intrinsic::nvvm_add_rp_f:
+      case Intrinsic::nvvm_add_rz_f:
+      case Intrinsic::nvvm_add_rm_d:
+      case Intrinsic::nvvm_add_rn_d:
+      case Intrinsic::nvvm_add_rp_d:
+      case Intrinsic::nvvm_add_rz_d:
+      case Intrinsic::nvvm_add_rm_ftz_f:
+      case Intrinsic::nvvm_add_rn_ftz_f:
+      case Intrinsic::nvvm_add_rp_ftz_f:
+      case Intrinsic::nvvm_add_rz_ftz_f: {
+
+        bool IsFTZ = nvvm::FAddShouldFTZ(IntrinsicID);
+        APFloat A = IsFTZ ? FTZPreserveSign(Op1V) : Op1V;
+        APFloat B = IsFTZ ? FTZPreserveSign(Op2V) : Op2V;
+
+        APFloat::roundingMode RoundMode =
+            nvvm::GetFAddRoundingMode(IntrinsicID);
+
+        APFloat Res = A;
+        APFloat::opStatus Status = Res.add(B, RoundMode);
+
+        if (!Res.isNaN() &&
+            (Status == APFloat::opOK || Status == APFloat::opInexact)) {
+          Res = IsFTZ ? FTZPreserveSign(Res) : Res;
+          return ConstantFP::get(Ty->getContext(), Res);
+        }
+        return nullptr;
+      }
+
+      case Intrinsic::nvvm_mul_rm_f:
+      case Intrinsic::nvvm_mul_rn_f:
+      case Intrinsic::nvvm_mul_rp_f:
+      case Intrinsic::nvvm_mul_rz_f:
+      case Intrinsic::nvvm_mul_rm_d:
+      case Intrinsic::nvvm_mul_rn_d:
+      case Intrinsic::nvvm_mul_rp_d:
+      case Intrinsic::nvvm_mul_rz_d:
+      case Intrinsic::nvvm_mul_rm_ftz_f:
+      case Intrinsic::nvvm_mul_rn_ftz_f:
+      case Intrinsic::nvvm_mul_rp_ftz_f:
+      case Intrinsic::nvvm_mul_rz_ftz_f: {
+
+        bool IsFTZ = nvvm::FMulShouldFTZ(IntrinsicID);
+        APFloat A = IsFTZ ? FTZPreserveSign(Op1V) : Op1V;
+        APFloat B = IsFTZ ? FTZPreserveSign(Op2V) : Op2V;
+
+        APFloat::roundingMode RoundMode =
+            nvvm::GetFMulRoundingMode(IntrinsicID);
+
+        APFloat Res = A;
+        APFloat::opStatus Status = Res.multiply(B, RoundMode);
+
+        if (!Res.isNaN() &&
+            (Status == APFloat::opOK || Status == APFloat::opInexact)) {
+          Res = IsFTZ ? FTZPreserveSign(Res) : Res;
+          return ConstantFP::get(Ty->getContext(), Res);
+        }
+        return nullptr;
+      }
+
+      case Intrinsic::nvvm_div_rm_f:
+      case Intrinsic::nvvm_div_rn_f:
+      case Intrinsic::nvvm_div_rp_f:
+      case Intrinsic::nvvm_div_rz_f:
+      case Intrinsic::nvvm_div_rm_d:
+      case Intrinsic::nvvm_div_rn_d:
+      case Intrinsic::nvvm_div_rp_d:
+      case Intrinsic::nvvm_div_rz_d:
+      case Intrinsic::nvvm_div_rm_ftz_f:
+      case Intrinsic::nvvm_div_rn_ftz_f:
+      case Intrinsic::nvvm_div_rp_ftz_f:
+      case Intrinsic::nvvm_div_rz_ftz_f: {
+        bool IsFTZ = nvvm::FDivShouldFTZ(IntrinsicID);
+        APFloat A = IsFTZ ? FTZPreserveSign(Op1V) : Op1V;
+        APFloat B = IsFTZ ? FTZPreserveSign(Op2V) : Op2V;
+        APFloat::roundingMode RoundMode =
+            nvvm::GetFDivRoundingMode(IntrinsicID);
+
+        APFloat Res = A;
+        APFloat::opStatus Status = Res.divide(B, RoundMode);
+        if (!Res.isNaN() &&
+            (Status == APFloat::opOK || Status == APFloat::opInexact)) {
+          Res = IsFTZ ? FTZPreserveSign(Res) : Res;
+          return ConstantFP::get(Ty->getContext(), Res);
+        }
+        return nullptr;
+      }
       }
 
       if (!Ty->isHalfTy() && !Ty->isFloatTy() && !Ty->isDoubleTy())
@@ -3729,6 +3875,38 @@ static Constant *ConstantFoldScalarCall3(StringRef Name,
           V.fusedMultiplyAdd(C2, C3, APFloat::rmNearestTiesToEven);
           return ConstantFP::get(Ty->getContext(), V);
         }
+
+        case Intrinsic::nvvm_fma_rm_f:
+        case Intrinsic::nvvm_fma_rn_f:
+        case Intrinsic::nvvm_fma_rp_f:
+        case Intrinsic::nvvm_fma_rz_f:
+        case Intrinsic::nvvm_fma_rm_d:
+        case Intrinsic::nvvm_fma_rn_d:
+        case Intrinsic::nvvm_fma_rp_d:
+        case Intrinsic::nvvm_fma_rz_d:
+        case Intrinsic::nvvm_fma_rm_ftz_f:
+        case Intrinsic::nvvm_fma_rn_ftz_f:
+        case Intrinsic::nvvm_fma_rp_ftz_f:
+        case Intrinsic::nvvm_fma_rz_ftz_f: {
+          bool IsFTZ = nvvm::FMAShouldFTZ(IntrinsicID);
+          APFloat A = IsFTZ ? FTZPreserveSign(C1) : C1;
+          APFloat B = IsFTZ ? FTZPreserveSign(C2) : C2;
+          APFloat C = IsFTZ ? FTZPreserveSign(C3) : C3;
+
+          APFloat::roundingMode RoundMode =
+              nvvm::GetFMARoundingMode(IntrinsicID);
+
+          APFloat Res = A;
+          APFloat::opStatus Status = Res.fusedMultiplyAdd(B, C, RoundMode);
+
+          if (!Res.isNaN() &&
+              (Status == APFloat::opOK || Status == APFloat::opInexact)) {
+            Res = IsFTZ ? FTZPreserveSign(Res) : Res;
+            return ConstantFP::get(Ty->getContext(), Res);
+          }
+          return nullptr;
+        }
+
         case Intrinsic::amdgcn_cubeid:
         case Intrinsic::amdgcn_cubema:
         case Intrinsic::amdgcn_cubesc:
diff --git a/llvm/test/Transforms/InstSimplify/const-fold-nvvm-add.ll b/llvm/test/Transforms/InstSimplify/const-fold-nvvm-add.ll
new file mode 100644
index 0000000000000..a3d87439d74cf
--- /dev/null
+++ b/llvm/test/Transforms/InstSimplify/const-fold-nvvm-add.ll
@@ -0,0 +1,876 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -passes=instsimplify -march=nvptx64 --mcpu=sm_86 --mattr=+ptx72 -S | FileCheck %s
+
+; Check constant-folding for NVVM add intrinsics with different rounding modes
+
+;###############################################################
+;#                    Add(1.25, -2.0)                          #
+;###############################################################
+; Tests addition of two normal numbers (1.25 and -2.0) where the result
+; is exactly representable. All rounding modes produce the same result.
+
+define double @test_1_25_minus_2_rm_d() {
+; CHECK-LABEL: define double @test_1_25_minus_2_rm_d() {
+; CHECK-NEXT:    ret double -7.500000e-01
+;
+  %res = call double @llvm.nvvm.add.rm.d(double 1.25, double -2.0)
+  ret double %res
+}
+
+define double @test_1_25_minus_2_rn_d() {
+; CHECK-LABEL: define double @test_1_25_minus_2_rn_d() {
+; CHECK-NEXT:    ret double -7.500000e-01
+;
+  %res = call double @llvm.nvvm.add.rn.d(double 1.25, double -2.0)
+  ret double %res
+}
+
+define double @test_1_25_minus_2_rp_d() {
+; CHECK-LABEL: define double @test_1_25_minus_2_rp_d() {
+; CHECK-NEXT:    ret double -7.500000e-01
+;
+  %res = call double @llvm.nvvm.add.rp.d(double 1.25, double -2.0)
+  ret double %res
+}
+
+define double @test_1_25_minus_2_rz_d() {
+; CHECK-LABEL: define double @test_1_25_minus_2_rz_d() {
+; CHECK-NEXT:    ret double -7.500000e-01
+;
+  %res = call double @llvm.nvvm.add.rz.d(double 1.25, double -2.0)
+  ret double %res
+}
+
+define float @test_1_25_minus_2_rm_f() {
+; CHECK-LABEL: define float @test_1_25_minus_2_rm_f() {
+; CHECK-NEXT:    ret float -7.500000e-01
+;
+  %res = call float @llvm.nvvm.add.rm.f(float 1.25, float -2.0)
+  ret float %res
+}
+
+define float @test_1_25_minus_2_rn_f() {
+; CHECK-LABEL: define float @test_1_25_minus_2_rn_f() {
+; CHECK-NEXT:    ret float -7.500000e-01
+;
+  %res = call float @llvm.nvvm.add.rn.f(float 1.25, float -2.0)
+  ret float %res
+}
+
+define float @test_1_25_minus_2_rp_f() {
+; CHECK-LABEL: define float @test_1_25_minus_2_rp_f() {
+; CHECK-NEXT:    ret float -7.500000e-01
+;
+  %res = call float @llvm.nvvm.add.rp.f(float 1.25, float -2.0)
+  ret float %res
+}
+
+define float @test_1_25_minus_2_rz_f() {
+; CHECK-LABEL: define float @test_1_25_minus_2_rz_f() {
+; CHECK-NEXT:    ret float -7.500000e-01
+;
+  %res = call float @llvm.nvvm.add.rz.f(float 1.25, float -2.0)
+  ret float %res
+}
+
+define float @test_1_25_minus_2_rm_ftz_f() {
+; CHECK-LABEL: define float @test_1_25_minus_2_rm_ftz_f() {
+; CHECK-NEXT:    ret float -7.500000e-01
+;
+  %res = call float @llvm.nvvm.add.rm.ftz.f(float 1.25, float -2.0)
+  ret float %res
+}
+
+define float @test_1_25_minus_2_rn_ftz_f() {
+; CHECK-LABEL: define float @test_1_25_minus_2_rn_ftz_f() {
+; CHECK-NEXT:    ret float -7.500000e-01
+;
+  %res = call float @llvm.nvvm.add.rn.ftz.f(float 1.25, float -2.0)
+  ret float %res
+}
+
+define float @test_1_25_minus_2_rp_ftz_f() {
+; CHECK-LABEL: define float @test_1_25_minus_2_rp_ftz_f() {
+; CHECK-NEXT:    ret float -7.500000e-01
+;
+  %res = call float @llvm.nvvm.add.rp.ftz.f(float 1.25, float -2.0)
+  ret float %res
+}
+
+define float @test_1_25_minus_2_rz_ftz_f() {
+; CHECK-LABEL: define float @test_1_25_minus_2_rz_ftz_f() {
+; CHECK-NEXT:    ret float -7.500000e-01
+;
+  %res = call float @llvm.nvvm.add.rz.ftz.f(float 1.25, float -2.0)
+  ret float %res
+}
+
+;###############################################################
+;#                          Add(0.0, NaN)                      #
+;###############################################################
+; Tests addition of zero and a NaN operand.
+; The result is always NaN, so the call is left in place (not constant-folded).
+
+define double @test_zero_plus_nan_rm_d() {
+; CHECK-LABEL: define double @test_zero_plus_nan_rm_d() {
+; CHECK-NEXT:    [[RES:%.*]] = call double @llvm.nvvm.add.rm.d(double 0.000000e+00, double 0x7FF4444400000000)
+; CHECK-NEXT:    ret double [[RES]]
+;
+  %res = call double @llvm.nvvm.add.rm.d(double 0.0, double 0x7ff4444400000000)
+  ret double %res
+}
+
+define double @test_zero_plus_nan_rn_d() {
+; CHECK-LABEL: define double @test_zero_plus_nan_rn_d() {
+; CHECK-NEXT:    [[RES:%.*]] = call double @llvm.nvvm.add.rn.d(double 0.000000e+00, double 0x7FF4444400000000)
+; CHECK-NEXT:    ret double [[RES]]
+;
+  %res = call double @llvm.nvvm.add.rn.d(double 0.0, double 0x7ff4444400000000)
+  ret double %res
+}
+
+define double @test_zero_plus_nan_rp_d() {
+; CHECK-LABEL: define double @test_zero_plus_nan_rp_d() {
+; CHECK-NEXT:    [[RES:%.*]] = call double @llvm.nvvm.add.rp.d(double 0.000000e+00, double 0x7FF4444400000000)
+; CHECK-NEXT:    ret double [[RES]]
+;
+  %res = call double @llvm.nvvm.add.rp.d(double 0.0, double 0x7ff4444400000000)
+  ret double %res
+}
+
+define double @test_zero_plus_nan_rz_d() {
+; CHECK-LABEL: define double @test_zero_plus_nan_rz_d() {
+; CHECK-NEXT:    [[RES:%.*]] = call double @llvm.nvvm.add.rz.d(double 0.000000e+00, double 0x7FF4444400000000)
+; CHECK-NEXT:    ret double [[RES]]
+;
+  %res = call double @llvm.nvvm.add.rz.d(double 0.0, double 0x7ff4444400000000)
+  ret double %res
+}
+
+define float @test_zero_plus_nan_rm_f() {
+; CHECK-LABEL: define float @test_zero_plus_nan_rm_f() {
+; CHECK-NEXT:    [[RES:%.*]] = call float @llvm.nvvm.add.rm.f(float 0.000000e+00, float 0x7FFF444400000000)
+; CHECK-NEXT:    ret float [[RES]]
+;
+  %res = call float @llvm.nvvm.add.rm.f(float 0.0, float 0x7FFF444400000000)
+  ret float %res
+}
+
+define float @test_zero_plus_nan_rn_f() {
+; CHECK-LABEL: define float @test_zero_plus_nan_rn_f() {
+; CHECK-NEXT:    [[RES:%.*]] = call float @llvm.nvvm.add.rn.f(float 0.000000e+00, float 0x7FFF444400000000)
+; CHECK-NEXT:    ret float [[RES]]
+;
+  %res = call float @llvm.nvvm.add.rn.f(float 0.0, float 0x7FFF444400000000)
+  ret float %res
+}
+
+define float @test_zero_plus_nan_rp_f() {
+; CHECK-LABEL: define float @test_zero_plus_nan_rp_f() {
+; CHECK-NEXT:    [[RES:%.*]] = call float @llvm.nvvm.add.rp.f(float 0.000000e+00, float 0x7FFF444400000000)
+; CHECK-NEXT:    ret float [[RES]]
+;
+  %res = call float @llvm.nvvm.add.rp.f(float 0.0, float 0x7FFF444400000000)
+  ret float %res
+}
+
+define float @test_zero_plus_nan_rz_f() {
+; CHECK-LABEL: define float @test_zero_plus_nan_rz_f() {
+; CHECK-NEXT:    [[RES:%.*]] = call float @llvm.nvvm.add.rz.f(float 0.000000e+00, float 0x7FFF444400000000)
+; CHECK-NEXT:    ret float [[RES]]
+;
+  %res = call float @llvm.nvvm.add.rz.f(float 0.0, float 0x7FFF444400000000)
+  ret float %res
+}
+
+defin...
[truncated]

@llvmbot
Copy link
Member

llvmbot commented Aug 7, 2025

@llvm/pr-subscribers-llvm-analysis

Author: Lewis Crawford (LewisCrawford)

Changes

Constant fold the NVVM intrinsics for add, mul, div, fma with specific rounding modes.


Patch is 146.54 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/152544.diff

6 Files Affected:

  • (modified) llvm/include/llvm/IR/NVVMIntrinsicUtils.h (+172)
  • (modified) llvm/lib/Analysis/ConstantFolding.cpp (+178)
  • (added) llvm/test/Transforms/InstSimplify/const-fold-nvvm-add.ll (+876)
  • (added) llvm/test/Transforms/InstSimplify/const-fold-nvvm-div.ll (+880)
  • (added) llvm/test/Transforms/InstSimplify/const-fold-nvvm-fma.ll (+874)
  • (added) llvm/test/Transforms/InstSimplify/const-fold-nvvm-mul.ll (+994)
diff --git a/llvm/include/llvm/IR/NVVMIntrinsicUtils.h b/llvm/include/llvm/IR/NVVMIntrinsicUtils.h
index 11bfd733a8854..cc4929a1ff8da 100644
--- a/llvm/include/llvm/IR/NVVMIntrinsicUtils.h
+++ b/llvm/include/llvm/IR/NVVMIntrinsicUtils.h
@@ -414,6 +414,178 @@ inline DenormalMode GetNVVMDenormMode(bool ShouldFTZ) {
   return DenormalMode::getIEEE();
 }
 
+/// Reports whether \p IntrinsicID is one of the flush-to-zero (`ftz`) forms of
+/// the NVVM rounding-mode add intrinsics.
+///
+/// \returns true for the four `nvvm_add_*_ftz_f` intrinsics and false for the
+/// plain `_f` and `_d` forms. Any other ID is a caller error and reaches
+/// llvm_unreachable.
+inline bool FAddShouldFTZ(Intrinsic::ID IntrinsicID) {
+  switch (IntrinsicID) {
+  // Double-precision variants.
+  case Intrinsic::nvvm_add_rm_d:
+  case Intrinsic::nvvm_add_rn_d:
+  case Intrinsic::nvvm_add_rp_d:
+  case Intrinsic::nvvm_add_rz_d:
+  // Single-precision variants without the ftz suffix.
+  case Intrinsic::nvvm_add_rm_f:
+  case Intrinsic::nvvm_add_rn_f:
+  case Intrinsic::nvvm_add_rp_f:
+  case Intrinsic::nvvm_add_rz_f:
+    return false;
+
+  // Single-precision variants carrying the ftz suffix.
+  case Intrinsic::nvvm_add_rm_ftz_f:
+  case Intrinsic::nvvm_add_rn_ftz_f:
+  case Intrinsic::nvvm_add_rp_ftz_f:
+  case Intrinsic::nvvm_add_rz_ftz_f:
+    return true;
+
+  default:
+    // Not a rounding-mode add intrinsic; diagnosed after the switch.
+    break;
+  }
+  llvm_unreachable("Checking FTZ flag for invalid NVVM add intrinsic");
+}
+
+/// Maps an NVVM rounding-mode add intrinsic to the APFloat rounding mode named
+/// by the infix in its identifier:
+///   rm -> rmTowardNegative, rn -> rmNearestTiesToEven,
+///   rp -> rmTowardPositive, rz -> rmTowardZero.
+/// The `_f`, `_d` and `_ftz_f` forms of a given infix all map to the same mode.
+///
+/// \param IntrinsicID one of the `nvvm_add_{rm,rn,rp,rz}{_f,_d,_ftz_f}` IDs.
+/// \returns the matching APFloat::roundingMode. Any other ID is a caller error
+/// and reaches llvm_unreachable.
+inline APFloat::roundingMode GetFAddRoundingMode(Intrinsic::ID IntrinsicID) {
+  switch (IntrinsicID) {
+  // rm: round toward negative.
+  case Intrinsic::nvvm_add_rm_f:
+  case Intrinsic::nvvm_add_rm_d:
+  case Intrinsic::nvvm_add_rm_ftz_f:
+    return APFloat::rmTowardNegative;
+  // rn: round to nearest, ties to even.
+  case Intrinsic::nvvm_add_rn_f:
+  case Intrinsic::nvvm_add_rn_d:
+  case Intrinsic::nvvm_add_rn_ftz_f:
+    return APFloat::rmNearestTiesToEven;
+  // rp: round toward positive.
+  case Intrinsic::nvvm_add_rp_f:
+  case Intrinsic::nvvm_add_rp_d:
+  case Intrinsic::nvvm_add_rp_ftz_f:
+    return APFloat::rmTowardPositive;
+  // rz: round toward zero.
+  case Intrinsic::nvvm_add_rz_f:
+  case Intrinsic::nvvm_add_rz_d:
+  case Intrinsic::nvvm_add_rz_ftz_f:
+    return APFloat::rmTowardZero;
+  }
+  llvm_unreachable("Invalid FP intrinsic rounding mode for NVVM add");
+}
+
+/// Reports whether \p IntrinsicID is one of the flush-to-zero (`ftz`) forms of
+/// the NVVM rounding-mode mul intrinsics.
+///
+/// \returns true for the four `nvvm_mul_*_ftz_f` intrinsics and false for the
+/// plain `_f` and `_d` forms. Any other ID is a caller error and reaches
+/// llvm_unreachable.
+inline bool FMulShouldFTZ(Intrinsic::ID IntrinsicID) {
+  switch (IntrinsicID) {
+  // Double-precision variants.
+  case Intrinsic::nvvm_mul_rm_d:
+  case Intrinsic::nvvm_mul_rn_d:
+  case Intrinsic::nvvm_mul_rp_d:
+  case Intrinsic::nvvm_mul_rz_d:
+  // Single-precision variants without the ftz suffix.
+  case Intrinsic::nvvm_mul_rm_f:
+  case Intrinsic::nvvm_mul_rn_f:
+  case Intrinsic::nvvm_mul_rp_f:
+  case Intrinsic::nvvm_mul_rz_f:
+    return false;
+
+  // Single-precision variants carrying the ftz suffix.
+  case Intrinsic::nvvm_mul_rm_ftz_f:
+  case Intrinsic::nvvm_mul_rn_ftz_f:
+  case Intrinsic::nvvm_mul_rp_ftz_f:
+  case Intrinsic::nvvm_mul_rz_ftz_f:
+    return true;
+
+  default:
+    // Not a rounding-mode mul intrinsic; diagnosed after the switch.
+    break;
+  }
+  llvm_unreachable("Checking FTZ flag for invalid NVVM mul intrinsic");
+}
+
+/// Maps an NVVM rounding-mode mul intrinsic to the APFloat rounding mode named
+/// by the infix in its identifier:
+///   rm -> rmTowardNegative, rn -> rmNearestTiesToEven,
+///   rp -> rmTowardPositive, rz -> rmTowardZero.
+/// The `_f`, `_d` and `_ftz_f` forms of a given infix all map to the same mode.
+///
+/// \param IntrinsicID one of the `nvvm_mul_{rm,rn,rp,rz}{_f,_d,_ftz_f}` IDs.
+/// \returns the matching APFloat::roundingMode. Any other ID is a caller error
+/// and reaches llvm_unreachable.
+inline APFloat::roundingMode GetFMulRoundingMode(Intrinsic::ID IntrinsicID) {
+  switch (IntrinsicID) {
+  // rm: round toward negative.
+  case Intrinsic::nvvm_mul_rm_f:
+  case Intrinsic::nvvm_mul_rm_d:
+  case Intrinsic::nvvm_mul_rm_ftz_f:
+    return APFloat::rmTowardNegative;
+  // rn: round to nearest, ties to even.
+  case Intrinsic::nvvm_mul_rn_f:
+  case Intrinsic::nvvm_mul_rn_d:
+  case Intrinsic::nvvm_mul_rn_ftz_f:
+    return APFloat::rmNearestTiesToEven;
+  // rp: round toward positive.
+  case Intrinsic::nvvm_mul_rp_f:
+  case Intrinsic::nvvm_mul_rp_d:
+  case Intrinsic::nvvm_mul_rp_ftz_f:
+    return APFloat::rmTowardPositive;
+  // rz: round toward zero.
+  case Intrinsic::nvvm_mul_rz_f:
+  case Intrinsic::nvvm_mul_rz_d:
+  case Intrinsic::nvvm_mul_rz_ftz_f:
+    return APFloat::rmTowardZero;
+  }
+  llvm_unreachable("Invalid FP intrinsic rounding mode for NVVM mul");
+}
+
+/// Reports whether \p IntrinsicID is one of the flush-to-zero (`ftz`) forms of
+/// the NVVM rounding-mode div intrinsics.
+///
+/// \returns true for the four `nvvm_div_*_ftz_f` intrinsics and false for the
+/// plain `_f` and `_d` forms. Any other ID is a caller error and reaches
+/// llvm_unreachable.
+inline bool FDivShouldFTZ(Intrinsic::ID IntrinsicID) {
+  switch (IntrinsicID) {
+  // Double-precision variants.
+  case Intrinsic::nvvm_div_rm_d:
+  case Intrinsic::nvvm_div_rn_d:
+  case Intrinsic::nvvm_div_rp_d:
+  case Intrinsic::nvvm_div_rz_d:
+  // Single-precision variants without the ftz suffix.
+  case Intrinsic::nvvm_div_rm_f:
+  case Intrinsic::nvvm_div_rn_f:
+  case Intrinsic::nvvm_div_rp_f:
+  case Intrinsic::nvvm_div_rz_f:
+    return false;
+
+  // Single-precision variants carrying the ftz suffix.
+  case Intrinsic::nvvm_div_rm_ftz_f:
+  case Intrinsic::nvvm_div_rn_ftz_f:
+  case Intrinsic::nvvm_div_rp_ftz_f:
+  case Intrinsic::nvvm_div_rz_ftz_f:
+    return true;
+
+  default:
+    // Not a rounding-mode div intrinsic; diagnosed after the switch.
+    break;
+  }
+  llvm_unreachable("Checking FTZ flag for invalid NVVM div intrinsic");
+}
+
+/// Maps an NVVM rounding-mode div intrinsic to the APFloat rounding mode named
+/// by the infix in its identifier:
+///   rm -> rmTowardNegative, rn -> rmNearestTiesToEven,
+///   rp -> rmTowardPositive, rz -> rmTowardZero.
+/// The `_f`, `_d` and `_ftz_f` forms of a given infix all map to the same mode.
+///
+/// \param IntrinsicID one of the `nvvm_div_{rm,rn,rp,rz}{_f,_d,_ftz_f}` IDs.
+/// \returns the matching APFloat::roundingMode. Any other ID is a caller error
+/// and reaches llvm_unreachable.
+inline APFloat::roundingMode GetFDivRoundingMode(Intrinsic::ID IntrinsicID) {
+  switch (IntrinsicID) {
+  // rm: round toward negative.
+  case Intrinsic::nvvm_div_rm_f:
+  case Intrinsic::nvvm_div_rm_d:
+  case Intrinsic::nvvm_div_rm_ftz_f:
+    return APFloat::rmTowardNegative;
+  // rn: round to nearest, ties to even.
+  case Intrinsic::nvvm_div_rn_f:
+  case Intrinsic::nvvm_div_rn_d:
+  case Intrinsic::nvvm_div_rn_ftz_f:
+    return APFloat::rmNearestTiesToEven;
+  // rp: round toward positive.
+  case Intrinsic::nvvm_div_rp_f:
+  case Intrinsic::nvvm_div_rp_d:
+  case Intrinsic::nvvm_div_rp_ftz_f:
+    return APFloat::rmTowardPositive;
+  // rz: round toward zero.
+  case Intrinsic::nvvm_div_rz_f:
+  case Intrinsic::nvvm_div_rz_d:
+  case Intrinsic::nvvm_div_rz_ftz_f:
+    return APFloat::rmTowardZero;
+  }
+  llvm_unreachable("Invalid FP intrinsic rounding mode for NVVM div");
+}
+
+/// Reports whether \p IntrinsicID is one of the flush-to-zero (`ftz`) forms of
+/// the NVVM rounding-mode fma intrinsics.
+///
+/// \returns true for the four `nvvm_fma_*_ftz_f` intrinsics and false for the
+/// plain `_f` and `_d` forms. Any other ID is a caller error and reaches
+/// llvm_unreachable.
+inline bool FMAShouldFTZ(Intrinsic::ID IntrinsicID) {
+  switch (IntrinsicID) {
+  // Double-precision variants.
+  case Intrinsic::nvvm_fma_rm_d:
+  case Intrinsic::nvvm_fma_rn_d:
+  case Intrinsic::nvvm_fma_rp_d:
+  case Intrinsic::nvvm_fma_rz_d:
+  // Single-precision variants without the ftz suffix.
+  case Intrinsic::nvvm_fma_rm_f:
+  case Intrinsic::nvvm_fma_rn_f:
+  case Intrinsic::nvvm_fma_rp_f:
+  case Intrinsic::nvvm_fma_rz_f:
+    return false;
+
+  // Single-precision variants carrying the ftz suffix.
+  case Intrinsic::nvvm_fma_rm_ftz_f:
+  case Intrinsic::nvvm_fma_rn_ftz_f:
+  case Intrinsic::nvvm_fma_rp_ftz_f:
+  case Intrinsic::nvvm_fma_rz_ftz_f:
+    return true;
+
+  default:
+    // Not a rounding-mode fma intrinsic; diagnosed after the switch.
+    break;
+  }
+  llvm_unreachable("Checking FTZ flag for invalid NVVM fma intrinsic");
+}
+
+/// Maps an NVVM rounding-mode fma intrinsic to the APFloat rounding mode named
+/// by the infix in its identifier:
+///   rm -> rmTowardNegative, rn -> rmNearestTiesToEven,
+///   rp -> rmTowardPositive, rz -> rmTowardZero.
+/// The `_f`, `_d` and `_ftz_f` forms of a given infix all map to the same mode.
+///
+/// \param IntrinsicID one of the `nvvm_fma_{rm,rn,rp,rz}{_f,_d,_ftz_f}` IDs.
+/// \returns the matching APFloat::roundingMode. Any other ID is a caller error
+/// and reaches llvm_unreachable.
+inline APFloat::roundingMode GetFMARoundingMode(Intrinsic::ID IntrinsicID) {
+  switch (IntrinsicID) {
+  // rm: round toward negative.
+  case Intrinsic::nvvm_fma_rm_f:
+  case Intrinsic::nvvm_fma_rm_d:
+  case Intrinsic::nvvm_fma_rm_ftz_f:
+    return APFloat::rmTowardNegative;
+  // rn: round to nearest, ties to even.
+  case Intrinsic::nvvm_fma_rn_f:
+  case Intrinsic::nvvm_fma_rn_d:
+  case Intrinsic::nvvm_fma_rn_ftz_f:
+    return APFloat::rmNearestTiesToEven;
+  // rp: round toward positive.
+  case Intrinsic::nvvm_fma_rp_f:
+  case Intrinsic::nvvm_fma_rp_d:
+  case Intrinsic::nvvm_fma_rp_ftz_f:
+    return APFloat::rmTowardPositive;
+  // rz: round toward zero.
+  case Intrinsic::nvvm_fma_rz_f:
+  case Intrinsic::nvvm_fma_rz_d:
+  case Intrinsic::nvvm_fma_rz_ftz_f:
+    return APFloat::rmTowardZero;
+  }
+  llvm_unreachable("Invalid FP intrinsic rounding mode for NVVM fma");
+}
+
 } // namespace nvvm
 } // namespace llvm
 #endif // LLVM_IR_NVVMINTRINSICUTILS_H
diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp
index dd98b62baca33..69fdd4f2b3e71 100644
--- a/llvm/lib/Analysis/ConstantFolding.cpp
+++ b/llvm/lib/Analysis/ConstantFolding.cpp
@@ -1843,6 +1843,62 @@ bool llvm::canConstantFoldCallTo(const CallBase *Call, const Function *F) {
   case Intrinsic::nvvm_sqrt_rn_ftz_f:
     return !Call->isStrictFP();
 
+  // NVVM add intrinsics with explicit rounding modes
+  case Intrinsic::nvvm_add_rm_d:
+  case Intrinsic::nvvm_add_rn_d:
+  case Intrinsic::nvvm_add_rp_d:
+  case Intrinsic::nvvm_add_rz_d:
+  case Intrinsic::nvvm_add_rm_f:
+  case Intrinsic::nvvm_add_rn_f:
+  case Intrinsic::nvvm_add_rp_f:
+  case Intrinsic::nvvm_add_rz_f:
+  case Intrinsic::nvvm_add_rm_ftz_f:
+  case Intrinsic::nvvm_add_rn_ftz_f:
+  case Intrinsic::nvvm_add_rp_ftz_f:
+  case Intrinsic::nvvm_add_rz_ftz_f:
+
+  // NVVM div intrinsics with explicit rounding modes
+  case Intrinsic::nvvm_div_rm_d:
+  case Intrinsic::nvvm_div_rn_d:
+  case Intrinsic::nvvm_div_rp_d:
+  case Intrinsic::nvvm_div_rz_d:
+  case Intrinsic::nvvm_div_rm_f:
+  case Intrinsic::nvvm_div_rn_f:
+  case Intrinsic::nvvm_div_rp_f:
+  case Intrinsic::nvvm_div_rz_f:
+  case Intrinsic::nvvm_div_rm_ftz_f:
+  case Intrinsic::nvvm_div_rn_ftz_f:
+  case Intrinsic::nvvm_div_rp_ftz_f:
+  case Intrinsic::nvvm_div_rz_ftz_f:
+
+  // NVVM mul intrinsics with explicit rounding modes
+  case Intrinsic::nvvm_mul_rm_d:
+  case Intrinsic::nvvm_mul_rn_d:
+  case Intrinsic::nvvm_mul_rp_d:
+  case Intrinsic::nvvm_mul_rz_d:
+  case Intrinsic::nvvm_mul_rm_f:
+  case Intrinsic::nvvm_mul_rn_f:
+  case Intrinsic::nvvm_mul_rp_f:
+  case Intrinsic::nvvm_mul_rz_f:
+  case Intrinsic::nvvm_mul_rm_ftz_f:
+  case Intrinsic::nvvm_mul_rn_ftz_f:
+  case Intrinsic::nvvm_mul_rp_ftz_f:
+  case Intrinsic::nvvm_mul_rz_ftz_f:
+
+  // NVVM fma intrinsics with explicit rounding modes
+  case Intrinsic::nvvm_fma_rm_d:
+  case Intrinsic::nvvm_fma_rn_d:
+  case Intrinsic::nvvm_fma_rp_d:
+  case Intrinsic::nvvm_fma_rz_d:
+  case Intrinsic::nvvm_fma_rm_f:
+  case Intrinsic::nvvm_fma_rn_f:
+  case Intrinsic::nvvm_fma_rp_f:
+  case Intrinsic::nvvm_fma_rz_f:
+  case Intrinsic::nvvm_fma_rm_ftz_f:
+  case Intrinsic::nvvm_fma_rn_ftz_f:
+  case Intrinsic::nvvm_fma_rp_ftz_f:
+  case Intrinsic::nvvm_fma_rz_ftz_f:
+
   // Sign operations are actually bitwise operations, they do not raise
   // exceptions even for SNANs.
   case Intrinsic::fabs:
@@ -3318,6 +3374,96 @@ static Constant *ConstantFoldIntrinsicCall2(Intrinsic::ID IntrinsicID, Type *Ty,
 
         return ConstantFP::get(Ty->getContext(), Res);
       }
+
+      case Intrinsic::nvvm_add_rm_f:
+      case Intrinsic::nvvm_add_rn_f:
+      case Intrinsic::nvvm_add_rp_f:
+      case Intrinsic::nvvm_add_rz_f:
+      case Intrinsic::nvvm_add_rm_d:
+      case Intrinsic::nvvm_add_rn_d:
+      case Intrinsic::nvvm_add_rp_d:
+      case Intrinsic::nvvm_add_rz_d:
+      case Intrinsic::nvvm_add_rm_ftz_f:
+      case Intrinsic::nvvm_add_rn_ftz_f:
+      case Intrinsic::nvvm_add_rp_ftz_f:
+      case Intrinsic::nvvm_add_rz_ftz_f: {
+
+        bool IsFTZ = nvvm::FAddShouldFTZ(IntrinsicID);
+        APFloat A = IsFTZ ? FTZPreserveSign(Op1V) : Op1V;
+        APFloat B = IsFTZ ? FTZPreserveSign(Op2V) : Op2V;
+
+        APFloat::roundingMode RoundMode =
+            nvvm::GetFAddRoundingMode(IntrinsicID);
+
+        APFloat Res = A;
+        APFloat::opStatus Status = Res.add(B, RoundMode);
+
+        if (!Res.isNaN() &&
+            (Status == APFloat::opOK || Status == APFloat::opInexact)) {
+          Res = IsFTZ ? FTZPreserveSign(Res) : Res;
+          return ConstantFP::get(Ty->getContext(), Res);
+        }
+        return nullptr;
+      }
+
+      case Intrinsic::nvvm_mul_rm_f:
+      case Intrinsic::nvvm_mul_rn_f:
+      case Intrinsic::nvvm_mul_rp_f:
+      case Intrinsic::nvvm_mul_rz_f:
+      case Intrinsic::nvvm_mul_rm_d:
+      case Intrinsic::nvvm_mul_rn_d:
+      case Intrinsic::nvvm_mul_rp_d:
+      case Intrinsic::nvvm_mul_rz_d:
+      case Intrinsic::nvvm_mul_rm_ftz_f:
+      case Intrinsic::nvvm_mul_rn_ftz_f:
+      case Intrinsic::nvvm_mul_rp_ftz_f:
+      case Intrinsic::nvvm_mul_rz_ftz_f: {
+
+        bool IsFTZ = nvvm::FMulShouldFTZ(IntrinsicID);
+        APFloat A = IsFTZ ? FTZPreserveSign(Op1V) : Op1V;
+        APFloat B = IsFTZ ? FTZPreserveSign(Op2V) : Op2V;
+
+        APFloat::roundingMode RoundMode =
+            nvvm::GetFMulRoundingMode(IntrinsicID);
+
+        APFloat Res = A;
+        APFloat::opStatus Status = Res.multiply(B, RoundMode);
+
+        if (!Res.isNaN() &&
+            (Status == APFloat::opOK || Status == APFloat::opInexact)) {
+          Res = IsFTZ ? FTZPreserveSign(Res) : Res;
+          return ConstantFP::get(Ty->getContext(), Res);
+        }
+        return nullptr;
+      }
+
+      case Intrinsic::nvvm_div_rm_f:
+      case Intrinsic::nvvm_div_rn_f:
+      case Intrinsic::nvvm_div_rp_f:
+      case Intrinsic::nvvm_div_rz_f:
+      case Intrinsic::nvvm_div_rm_d:
+      case Intrinsic::nvvm_div_rn_d:
+      case Intrinsic::nvvm_div_rp_d:
+      case Intrinsic::nvvm_div_rz_d:
+      case Intrinsic::nvvm_div_rm_ftz_f:
+      case Intrinsic::nvvm_div_rn_ftz_f:
+      case Intrinsic::nvvm_div_rp_ftz_f:
+      case Intrinsic::nvvm_div_rz_ftz_f: {
+        bool IsFTZ = nvvm::FDivShouldFTZ(IntrinsicID);
+        APFloat A = IsFTZ ? FTZPreserveSign(Op1V) : Op1V;
+        APFloat B = IsFTZ ? FTZPreserveSign(Op2V) : Op2V;
+        APFloat::roundingMode RoundMode =
+            nvvm::GetFDivRoundingMode(IntrinsicID);
+
+        APFloat Res = A;
+        APFloat::opStatus Status = Res.divide(B, RoundMode);
+        if (!Res.isNaN() &&
+            (Status == APFloat::opOK || Status == APFloat::opInexact)) {
+          Res = IsFTZ ? FTZPreserveSign(Res) : Res;
+          return ConstantFP::get(Ty->getContext(), Res);
+        }
+        return nullptr;
+      }
       }
 
       if (!Ty->isHalfTy() && !Ty->isFloatTy() && !Ty->isDoubleTy())
@@ -3729,6 +3875,38 @@ static Constant *ConstantFoldScalarCall3(StringRef Name,
           V.fusedMultiplyAdd(C2, C3, APFloat::rmNearestTiesToEven);
           return ConstantFP::get(Ty->getContext(), V);
         }
+
+        case Intrinsic::nvvm_fma_rm_f:
+        case Intrinsic::nvvm_fma_rn_f:
+        case Intrinsic::nvvm_fma_rp_f:
+        case Intrinsic::nvvm_fma_rz_f:
+        case Intrinsic::nvvm_fma_rm_d:
+        case Intrinsic::nvvm_fma_rn_d:
+        case Intrinsic::nvvm_fma_rp_d:
+        case Intrinsic::nvvm_fma_rz_d:
+        case Intrinsic::nvvm_fma_rm_ftz_f:
+        case Intrinsic::nvvm_fma_rn_ftz_f:
+        case Intrinsic::nvvm_fma_rp_ftz_f:
+        case Intrinsic::nvvm_fma_rz_ftz_f: {
+          bool IsFTZ = nvvm::FMAShouldFTZ(IntrinsicID);
+          APFloat A = IsFTZ ? FTZPreserveSign(C1) : C1;
+          APFloat B = IsFTZ ? FTZPreserveSign(C2) : C2;
+          APFloat C = IsFTZ ? FTZPreserveSign(C3) : C3;
+
+          APFloat::roundingMode RoundMode =
+              nvvm::GetFMARoundingMode(IntrinsicID);
+
+          APFloat Res = A;
+          APFloat::opStatus Status = Res.fusedMultiplyAdd(B, C, RoundMode);
+
+          if (!Res.isNaN() &&
+              (Status == APFloat::opOK || Status == APFloat::opInexact)) {
+            Res = IsFTZ ? FTZPreserveSign(Res) : Res;
+            return ConstantFP::get(Ty->getContext(), Res);
+          }
+          return nullptr;
+        }
+
         case Intrinsic::amdgcn_cubeid:
         case Intrinsic::amdgcn_cubema:
         case Intrinsic::amdgcn_cubesc:
diff --git a/llvm/test/Transforms/InstSimplify/const-fold-nvvm-add.ll b/llvm/test/Transforms/InstSimplify/const-fold-nvvm-add.ll
new file mode 100644
index 0000000000000..a3d87439d74cf
--- /dev/null
+++ b/llvm/test/Transforms/InstSimplify/const-fold-nvvm-add.ll
@@ -0,0 +1,876 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -passes=instsimplify -march=nvptx64 --mcpu=sm_86 --mattr=+ptx72 -S | FileCheck %s
+
+; Check constant-folding for NVVM add intrinsics with different rounding modes
+
+;###############################################################
+;#                    Add(1.25, -2.0)                          #
+;###############################################################
+; Tests addition of two normal numbers (1.25 and -2.0) where the result
+; is exactly representable. All rounding modes produce the same result.
+
+define double @test_1_25_minus_2_rm_d() {
+; CHECK-LABEL: define double @test_1_25_minus_2_rm_d() {
+; CHECK-NEXT:    ret double -7.500000e-01
+;
+  %res = call double @llvm.nvvm.add.rm.d(double 1.25, double -2.0)
+  ret double %res
+}
+
+define double @test_1_25_minus_2_rn_d() {
+; CHECK-LABEL: define double @test_1_25_minus_2_rn_d() {
+; CHECK-NEXT:    ret double -7.500000e-01
+;
+  %res = call double @llvm.nvvm.add.rn.d(double 1.25, double -2.0)
+  ret double %res
+}
+
+define double @test_1_25_minus_2_rp_d() {
+; CHECK-LABEL: define double @test_1_25_minus_2_rp_d() {
+; CHECK-NEXT:    ret double -7.500000e-01
+;
+  %res = call double @llvm.nvvm.add.rp.d(double 1.25, double -2.0)
+  ret double %res
+}
+
+define double @test_1_25_minus_2_rz_d() {
+; CHECK-LABEL: define double @test_1_25_minus_2_rz_d() {
+; CHECK-NEXT:    ret double -7.500000e-01
+;
+  %res = call double @llvm.nvvm.add.rz.d(double 1.25, double -2.0)
+  ret double %res
+}
+
+define float @test_1_25_minus_2_rm_f() {
+; CHECK-LABEL: define float @test_1_25_minus_2_rm_f() {
+; CHECK-NEXT:    ret float -7.500000e-01
+;
+  %res = call float @llvm.nvvm.add.rm.f(float 1.25, float -2.0)
+  ret float %res
+}
+
+define float @test_1_25_minus_2_rn_f() {
+; CHECK-LABEL: define float @test_1_25_minus_2_rn_f() {
+; CHECK-NEXT:    ret float -7.500000e-01
+;
+  %res = call float @llvm.nvvm.add.rn.f(float 1.25, float -2.0)
+  ret float %res
+}
+
+define float @test_1_25_minus_2_rp_f() {
+; CHECK-LABEL: define float @test_1_25_minus_2_rp_f() {
+; CHECK-NEXT:    ret float -7.500000e-01
+;
+  %res = call float @llvm.nvvm.add.rp.f(float 1.25, float -2.0)
+  ret float %res
+}
+
+define float @test_1_25_minus_2_rz_f() {
+; CHECK-LABEL: define float @test_1_25_minus_2_rz_f() {
+; CHECK-NEXT:    ret float -7.500000e-01
+;
+  %res = call float @llvm.nvvm.add.rz.f(float 1.25, float -2.0)
+  ret float %res
+}
+
+define float @test_1_25_minus_2_rm_ftz_f() {
+; CHECK-LABEL: define float @test_1_25_minus_2_rm_ftz_f() {
+; CHECK-NEXT:    ret float -7.500000e-01
+;
+  %res = call float @llvm.nvvm.add.rm.ftz.f(float 1.25, float -2.0)
+  ret float %res
+}
+
+define float @test_1_25_minus_2_rn_ftz_f() {
+; CHECK-LABEL: define float @test_1_25_minus_2_rn_ftz_f() {
+; CHECK-NEXT:    ret float -7.500000e-01
+;
+  %res = call float @llvm.nvvm.add.rn.ftz.f(float 1.25, float -2.0)
+  ret float %res
+}
+
+define float @test_1_25_minus_2_rp_ftz_f() {
+; CHECK-LABEL: define float @test_1_25_minus_2_rp_ftz_f() {
+; CHECK-NEXT:    ret float -7.500000e-01
+;
+  %res = call float @llvm.nvvm.add.rp.ftz.f(float 1.25, float -2.0)
+  ret float %res
+}
+
+define float @test_1_25_minus_2_rz_ftz_f() {
+; CHECK-LABEL: define float @test_1_25_minus_2_rz_ftz_f() {
+; CHECK-NEXT:    ret float -7.500000e-01
+;
+  %res = call float @llvm.nvvm.add.rz.ftz.f(float 1.25, float -2.0)
+  ret float %res
+}
+
+;###############################################################
+;#                          Add(0.0, NaN)                      #
+;###############################################################
+; Tests addition of zero and a NaN operand.
+; The result is always NaN, so the call is left in place (not constant-folded).
+
+define double @test_zero_plus_nan_rm_d() {
+; CHECK-LABEL: define double @test_zero_plus_nan_rm_d() {
+; CHECK-NEXT:    [[RES:%.*]] = call double @llvm.nvvm.add.rm.d(double 0.000000e+00, double 0x7FF4444400000000)
+; CHECK-NEXT:    ret double [[RES]]
+;
+  %res = call double @llvm.nvvm.add.rm.d(double 0.0, double 0x7ff4444400000000)
+  ret double %res
+}
+
+define double @test_zero_plus_nan_rn_d() {
+; CHECK-LABEL: define double @test_zero_plus_nan_rn_d() {
+; CHECK-NEXT:    [[RES:%.*]] = call double @llvm.nvvm.add.rn.d(double 0.000000e+00, double 0x7FF4444400000000)
+; CHECK-NEXT:    ret double [[RES]]
+;
+  %res = call double @llvm.nvvm.add.rn.d(double 0.0, double 0x7ff4444400000000)
+  ret double %res
+}
+
+define double @test_zero_plus_nan_rp_d() {
+; CHECK-LABEL: define double @test_zero_plus_nan_rp_d() {
+; CHECK-NEXT:    [[RES:%.*]] = call double @llvm.nvvm.add.rp.d(double 0.000000e+00, double 0x7FF4444400000000)
+; CHECK-NEXT:    ret double [[RES]]
+;
+  %res = call double @llvm.nvvm.add.rp.d(double 0.0, double 0x7ff4444400000000)
+  ret double %res
+}
+
+define double @test_zero_plus_nan_rz_d() {
+; CHECK-LABEL: define double @test_zero_plus_nan_rz_d() {
+; CHECK-NEXT:    [[RES:%.*]] = call double @llvm.nvvm.add.rz.d(double 0.000000e+00, double 0x7FF4444400000000)
+; CHECK-NEXT:    ret double [[RES]]
+;
+  %res = call double @llvm.nvvm.add.rz.d(double 0.0, double 0x7ff4444400000000)
+  ret double %res
+}
+
+define float @test_zero_plus_nan_rm_f() {
+; CHECK-LABEL: define float @test_zero_plus_nan_rm_f() {
+; CHECK-NEXT:    [[RES:%.*]] = call float @llvm.nvvm.add.rm.f(float 0.000000e+00, float 0x7FFF444400000000)
+; CHECK-NEXT:    ret float [[RES]]
+;
+  %res = call float @llvm.nvvm.add.rm.f(float 0.0, float 0x7FFF444400000000)
+  ret float %res
+}
+
+define float @test_zero_plus_nan_rn_f() {
+; CHECK-LABEL: define float @test_zero_plus_nan_rn_f() {
+; CHECK-NEXT:    [[RES:%.*]] = call float @llvm.nvvm.add.rn.f(float 0.000000e+00, float 0x7FFF444400000000)
+; CHECK-NEXT:    ret float [[RES]]
+;
+  %res = call float @llvm.nvvm.add.rn.f(float 0.0, float 0x7FFF444400000000)
+  ret float %res
+}
+
+define float @test_zero_plus_nan_rp_f() {
+; CHECK-LABEL: define float @test_zero_plus_nan_rp_f() {
+; CHECK-NEXT:    [[RES:%.*]] = call float @llvm.nvvm.add.rp.f(float 0.000000e+00, float 0x7FFF444400000000)
+; CHECK-NEXT:    ret float [[RES]]
+;
+  %res = call float @llvm.nvvm.add.rp.f(float 0.0, float 0x7FFF444400000000)
+  ret float %res
+}
+
+define float @test_zero_plus_nan_rz_f() {
+; CHECK-LABEL: define float @test_zero_plus_nan_rz_f() {
+; CHECK-NEXT:    [[RES:%.*]] = call float @llvm.nvvm.add.rz.f(float 0.000000e+00, float 0x7FFF444400000000)
+; CHECK-NEXT:    ret float [[RES]]
+;
+  %res = call float @llvm.nvvm.add.rz.f(float 0.0, float 0x7FFF444400000000)
+  ret float %res
+}
+
+defin...
[truncated]

Copy link
Member

@Artem-B Artem-B left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

@LewisCrawford
Copy link
Contributor Author

Thanks for reviewing this! :) I'm going to wait a few days before merging this until some downstream issues have been resolved, but will merge it here soon.

@LewisCrawford LewisCrawford merged commit 729701b into llvm:main Aug 22, 2025
15 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend:NVPTX llvm:analysis Includes value tracking, cost tables and constant folding llvm:instcombine Covers the InstCombine, InstSimplify and AggressiveInstCombine passes llvm:ir llvm:transforms
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants