-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[NVPTX] Constant fold NVVM add/mul/div/fma #152544
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Constant fold the NVVM intrinsics for add, mul, div, fma with specific rounding modes.
@llvm/pr-subscribers-llvm-transforms @llvm/pr-subscribers-backend-nvptx Author: Lewis Crawford (LewisCrawford) Changes: Constant fold the NVVM intrinsics for add, mul, div, fma with specific rounding modes. Patch is 146.54 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/152544.diff 6 Files Affected:
diff --git a/llvm/include/llvm/IR/NVVMIntrinsicUtils.h b/llvm/include/llvm/IR/NVVMIntrinsicUtils.h
index 11bfd733a8854..cc4929a1ff8da 100644
--- a/llvm/include/llvm/IR/NVVMIntrinsicUtils.h
+++ b/llvm/include/llvm/IR/NVVMIntrinsicUtils.h
@@ -414,6 +414,178 @@ inline DenormalMode GetNVVMDenormMode(bool ShouldFTZ) {
return DenormalMode::getIEEE();
}
+inline bool FAddShouldFTZ(Intrinsic::ID IntrinsicID) {
+ switch (IntrinsicID) {
+ case Intrinsic::nvvm_add_rm_ftz_f:
+ case Intrinsic::nvvm_add_rn_ftz_f:
+ case Intrinsic::nvvm_add_rp_ftz_f:
+ case Intrinsic::nvvm_add_rz_ftz_f:
+ return true;
+
+ case Intrinsic::nvvm_add_rm_f:
+ case Intrinsic::nvvm_add_rn_f:
+ case Intrinsic::nvvm_add_rp_f:
+ case Intrinsic::nvvm_add_rz_f:
+ case Intrinsic::nvvm_add_rm_d:
+ case Intrinsic::nvvm_add_rn_d:
+ case Intrinsic::nvvm_add_rp_d:
+ case Intrinsic::nvvm_add_rz_d:
+ return false;
+ }
+ llvm_unreachable("Checking FTZ flag for invalid NVVM add intrinsic");
+}
+
+inline APFloat::roundingMode GetFAddRoundingMode(Intrinsic::ID IntrinsicID) {
+ switch (IntrinsicID) {
+ case Intrinsic::nvvm_add_rm_f:
+ case Intrinsic::nvvm_add_rm_d:
+ case Intrinsic::nvvm_add_rm_ftz_f:
+ return APFloat::rmTowardNegative;
+ case Intrinsic::nvvm_add_rn_f:
+ case Intrinsic::nvvm_add_rn_d:
+ case Intrinsic::nvvm_add_rn_ftz_f:
+ return APFloat::rmNearestTiesToEven;
+ case Intrinsic::nvvm_add_rp_f:
+ case Intrinsic::nvvm_add_rp_d:
+ case Intrinsic::nvvm_add_rp_ftz_f:
+ return APFloat::rmTowardPositive;
+ case Intrinsic::nvvm_add_rz_f:
+ case Intrinsic::nvvm_add_rz_d:
+ case Intrinsic::nvvm_add_rz_ftz_f:
+ return APFloat::rmTowardZero;
+ }
+ llvm_unreachable("Invalid FP instrinsic rounding mode for NVVM add");
+}
+
+inline bool FMulShouldFTZ(Intrinsic::ID IntrinsicID) {
+ switch (IntrinsicID) {
+ case Intrinsic::nvvm_mul_rm_ftz_f:
+ case Intrinsic::nvvm_mul_rn_ftz_f:
+ case Intrinsic::nvvm_mul_rp_ftz_f:
+ case Intrinsic::nvvm_mul_rz_ftz_f:
+ return true;
+
+ case Intrinsic::nvvm_mul_rm_f:
+ case Intrinsic::nvvm_mul_rn_f:
+ case Intrinsic::nvvm_mul_rp_f:
+ case Intrinsic::nvvm_mul_rz_f:
+ case Intrinsic::nvvm_mul_rm_d:
+ case Intrinsic::nvvm_mul_rn_d:
+ case Intrinsic::nvvm_mul_rp_d:
+ case Intrinsic::nvvm_mul_rz_d:
+ return false;
+ }
+ llvm_unreachable("Checking FTZ flag for invalid NVVM mul intrinsic");
+}
+
+inline APFloat::roundingMode GetFMulRoundingMode(Intrinsic::ID IntrinsicID) {
+ switch (IntrinsicID) {
+ case Intrinsic::nvvm_mul_rm_f:
+ case Intrinsic::nvvm_mul_rm_d:
+ case Intrinsic::nvvm_mul_rm_ftz_f:
+ return APFloat::rmTowardNegative;
+ case Intrinsic::nvvm_mul_rn_f:
+ case Intrinsic::nvvm_mul_rn_d:
+ case Intrinsic::nvvm_mul_rn_ftz_f:
+ return APFloat::rmNearestTiesToEven;
+ case Intrinsic::nvvm_mul_rp_f:
+ case Intrinsic::nvvm_mul_rp_d:
+ case Intrinsic::nvvm_mul_rp_ftz_f:
+ return APFloat::rmTowardPositive;
+ case Intrinsic::nvvm_mul_rz_f:
+ case Intrinsic::nvvm_mul_rz_d:
+ case Intrinsic::nvvm_mul_rz_ftz_f:
+ return APFloat::rmTowardZero;
+ }
+ llvm_unreachable("Invalid FP instrinsic rounding mode for NVVM mul");
+}
+
+inline bool FDivShouldFTZ(Intrinsic::ID IntrinsicID) {
+ switch (IntrinsicID) {
+ case Intrinsic::nvvm_div_rm_ftz_f:
+ case Intrinsic::nvvm_div_rn_ftz_f:
+ case Intrinsic::nvvm_div_rp_ftz_f:
+ case Intrinsic::nvvm_div_rz_ftz_f:
+ return true;
+
+ case Intrinsic::nvvm_div_rm_f:
+ case Intrinsic::nvvm_div_rn_f:
+ case Intrinsic::nvvm_div_rp_f:
+ case Intrinsic::nvvm_div_rz_f:
+ case Intrinsic::nvvm_div_rm_d:
+ case Intrinsic::nvvm_div_rn_d:
+ case Intrinsic::nvvm_div_rp_d:
+ case Intrinsic::nvvm_div_rz_d:
+ return false;
+ }
+ llvm_unreachable("Checking FTZ flag for invalid NVVM div intrinsic");
+}
+
+inline APFloat::roundingMode GetFDivRoundingMode(Intrinsic::ID IntrinsicID) {
+ switch (IntrinsicID) {
+ case Intrinsic::nvvm_div_rm_f:
+ case Intrinsic::nvvm_div_rm_d:
+ case Intrinsic::nvvm_div_rm_ftz_f:
+ return APFloat::rmTowardNegative;
+ case Intrinsic::nvvm_div_rn_f:
+ case Intrinsic::nvvm_div_rn_d:
+ case Intrinsic::nvvm_div_rn_ftz_f:
+ return APFloat::rmNearestTiesToEven;
+ case Intrinsic::nvvm_div_rp_f:
+ case Intrinsic::nvvm_div_rp_d:
+ case Intrinsic::nvvm_div_rp_ftz_f:
+ return APFloat::rmTowardPositive;
+ case Intrinsic::nvvm_div_rz_f:
+ case Intrinsic::nvvm_div_rz_d:
+ case Intrinsic::nvvm_div_rz_ftz_f:
+ return APFloat::rmTowardZero;
+ }
+ llvm_unreachable("Invalid FP instrinsic rounding mode for NVVM div");
+}
+
+inline bool FMAShouldFTZ(Intrinsic::ID IntrinsicID) {
+ switch (IntrinsicID) {
+ case Intrinsic::nvvm_fma_rm_ftz_f:
+ case Intrinsic::nvvm_fma_rn_ftz_f:
+ case Intrinsic::nvvm_fma_rp_ftz_f:
+ case Intrinsic::nvvm_fma_rz_ftz_f:
+ return true;
+
+ case Intrinsic::nvvm_fma_rm_f:
+ case Intrinsic::nvvm_fma_rn_f:
+ case Intrinsic::nvvm_fma_rp_f:
+ case Intrinsic::nvvm_fma_rz_f:
+ case Intrinsic::nvvm_fma_rm_d:
+ case Intrinsic::nvvm_fma_rn_d:
+ case Intrinsic::nvvm_fma_rp_d:
+ case Intrinsic::nvvm_fma_rz_d:
+ return false;
+ }
+ llvm_unreachable("Checking FTZ flag for invalid NVVM fma intrinsic");
+}
+
+inline APFloat::roundingMode GetFMARoundingMode(Intrinsic::ID IntrinsicID) {
+ switch (IntrinsicID) {
+ case Intrinsic::nvvm_fma_rm_f:
+ case Intrinsic::nvvm_fma_rm_d:
+ case Intrinsic::nvvm_fma_rm_ftz_f:
+ return APFloat::rmTowardNegative;
+ case Intrinsic::nvvm_fma_rn_f:
+ case Intrinsic::nvvm_fma_rn_d:
+ case Intrinsic::nvvm_fma_rn_ftz_f:
+ return APFloat::rmNearestTiesToEven;
+ case Intrinsic::nvvm_fma_rp_f:
+ case Intrinsic::nvvm_fma_rp_d:
+ case Intrinsic::nvvm_fma_rp_ftz_f:
+ return APFloat::rmTowardPositive;
+ case Intrinsic::nvvm_fma_rz_f:
+ case Intrinsic::nvvm_fma_rz_d:
+ case Intrinsic::nvvm_fma_rz_ftz_f:
+ return APFloat::rmTowardZero;
+ }
+ llvm_unreachable("Invalid FP instrinsic rounding mode for NVVM fma");
+}
+
} // namespace nvvm
} // namespace llvm
#endif // LLVM_IR_NVVMINTRINSICUTILS_H
diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp
index dd98b62baca33..69fdd4f2b3e71 100644
--- a/llvm/lib/Analysis/ConstantFolding.cpp
+++ b/llvm/lib/Analysis/ConstantFolding.cpp
@@ -1843,6 +1843,62 @@ bool llvm::canConstantFoldCallTo(const CallBase *Call, const Function *F) {
case Intrinsic::nvvm_sqrt_rn_ftz_f:
return !Call->isStrictFP();
+ // NVVM add intrinsics with explicit rounding modes
+ case Intrinsic::nvvm_add_rm_d:
+ case Intrinsic::nvvm_add_rn_d:
+ case Intrinsic::nvvm_add_rp_d:
+ case Intrinsic::nvvm_add_rz_d:
+ case Intrinsic::nvvm_add_rm_f:
+ case Intrinsic::nvvm_add_rn_f:
+ case Intrinsic::nvvm_add_rp_f:
+ case Intrinsic::nvvm_add_rz_f:
+ case Intrinsic::nvvm_add_rm_ftz_f:
+ case Intrinsic::nvvm_add_rn_ftz_f:
+ case Intrinsic::nvvm_add_rp_ftz_f:
+ case Intrinsic::nvvm_add_rz_ftz_f:
+
+ // NVVM div intrinsics with explicit rounding modes
+ case Intrinsic::nvvm_div_rm_d:
+ case Intrinsic::nvvm_div_rn_d:
+ case Intrinsic::nvvm_div_rp_d:
+ case Intrinsic::nvvm_div_rz_d:
+ case Intrinsic::nvvm_div_rm_f:
+ case Intrinsic::nvvm_div_rn_f:
+ case Intrinsic::nvvm_div_rp_f:
+ case Intrinsic::nvvm_div_rz_f:
+ case Intrinsic::nvvm_div_rm_ftz_f:
+ case Intrinsic::nvvm_div_rn_ftz_f:
+ case Intrinsic::nvvm_div_rp_ftz_f:
+ case Intrinsic::nvvm_div_rz_ftz_f:
+
+ // NVVM mul intrinsics with explicit rounding modes
+ case Intrinsic::nvvm_mul_rm_d:
+ case Intrinsic::nvvm_mul_rn_d:
+ case Intrinsic::nvvm_mul_rp_d:
+ case Intrinsic::nvvm_mul_rz_d:
+ case Intrinsic::nvvm_mul_rm_f:
+ case Intrinsic::nvvm_mul_rn_f:
+ case Intrinsic::nvvm_mul_rp_f:
+ case Intrinsic::nvvm_mul_rz_f:
+ case Intrinsic::nvvm_mul_rm_ftz_f:
+ case Intrinsic::nvvm_mul_rn_ftz_f:
+ case Intrinsic::nvvm_mul_rp_ftz_f:
+ case Intrinsic::nvvm_mul_rz_ftz_f:
+
+ // NVVM fma intrinsics with explicit rounding modes
+ case Intrinsic::nvvm_fma_rm_d:
+ case Intrinsic::nvvm_fma_rn_d:
+ case Intrinsic::nvvm_fma_rp_d:
+ case Intrinsic::nvvm_fma_rz_d:
+ case Intrinsic::nvvm_fma_rm_f:
+ case Intrinsic::nvvm_fma_rn_f:
+ case Intrinsic::nvvm_fma_rp_f:
+ case Intrinsic::nvvm_fma_rz_f:
+ case Intrinsic::nvvm_fma_rm_ftz_f:
+ case Intrinsic::nvvm_fma_rn_ftz_f:
+ case Intrinsic::nvvm_fma_rp_ftz_f:
+ case Intrinsic::nvvm_fma_rz_ftz_f:
+
// Sign operations are actually bitwise operations, they do not raise
// exceptions even for SNANs.
case Intrinsic::fabs:
@@ -3318,6 +3374,96 @@ static Constant *ConstantFoldIntrinsicCall2(Intrinsic::ID IntrinsicID, Type *Ty,
return ConstantFP::get(Ty->getContext(), Res);
}
+
+ case Intrinsic::nvvm_add_rm_f:
+ case Intrinsic::nvvm_add_rn_f:
+ case Intrinsic::nvvm_add_rp_f:
+ case Intrinsic::nvvm_add_rz_f:
+ case Intrinsic::nvvm_add_rm_d:
+ case Intrinsic::nvvm_add_rn_d:
+ case Intrinsic::nvvm_add_rp_d:
+ case Intrinsic::nvvm_add_rz_d:
+ case Intrinsic::nvvm_add_rm_ftz_f:
+ case Intrinsic::nvvm_add_rn_ftz_f:
+ case Intrinsic::nvvm_add_rp_ftz_f:
+ case Intrinsic::nvvm_add_rz_ftz_f: {
+
+ bool IsFTZ = nvvm::FAddShouldFTZ(IntrinsicID);
+ APFloat A = IsFTZ ? FTZPreserveSign(Op1V) : Op1V;
+ APFloat B = IsFTZ ? FTZPreserveSign(Op2V) : Op2V;
+
+ APFloat::roundingMode RoundMode =
+ nvvm::GetFAddRoundingMode(IntrinsicID);
+
+ APFloat Res = A;
+ APFloat::opStatus Status = Res.add(B, RoundMode);
+
+ if (!Res.isNaN() &&
+ (Status == APFloat::opOK || Status == APFloat::opInexact)) {
+ Res = IsFTZ ? FTZPreserveSign(Res) : Res;
+ return ConstantFP::get(Ty->getContext(), Res);
+ }
+ return nullptr;
+ }
+
+ case Intrinsic::nvvm_mul_rm_f:
+ case Intrinsic::nvvm_mul_rn_f:
+ case Intrinsic::nvvm_mul_rp_f:
+ case Intrinsic::nvvm_mul_rz_f:
+ case Intrinsic::nvvm_mul_rm_d:
+ case Intrinsic::nvvm_mul_rn_d:
+ case Intrinsic::nvvm_mul_rp_d:
+ case Intrinsic::nvvm_mul_rz_d:
+ case Intrinsic::nvvm_mul_rm_ftz_f:
+ case Intrinsic::nvvm_mul_rn_ftz_f:
+ case Intrinsic::nvvm_mul_rp_ftz_f:
+ case Intrinsic::nvvm_mul_rz_ftz_f: {
+
+ bool IsFTZ = nvvm::FMulShouldFTZ(IntrinsicID);
+ APFloat A = IsFTZ ? FTZPreserveSign(Op1V) : Op1V;
+ APFloat B = IsFTZ ? FTZPreserveSign(Op2V) : Op2V;
+
+ APFloat::roundingMode RoundMode =
+ nvvm::GetFMulRoundingMode(IntrinsicID);
+
+ APFloat Res = A;
+ APFloat::opStatus Status = Res.multiply(B, RoundMode);
+
+ if (!Res.isNaN() &&
+ (Status == APFloat::opOK || Status == APFloat::opInexact)) {
+ Res = IsFTZ ? FTZPreserveSign(Res) : Res;
+ return ConstantFP::get(Ty->getContext(), Res);
+ }
+ return nullptr;
+ }
+
+ case Intrinsic::nvvm_div_rm_f:
+ case Intrinsic::nvvm_div_rn_f:
+ case Intrinsic::nvvm_div_rp_f:
+ case Intrinsic::nvvm_div_rz_f:
+ case Intrinsic::nvvm_div_rm_d:
+ case Intrinsic::nvvm_div_rn_d:
+ case Intrinsic::nvvm_div_rp_d:
+ case Intrinsic::nvvm_div_rz_d:
+ case Intrinsic::nvvm_div_rm_ftz_f:
+ case Intrinsic::nvvm_div_rn_ftz_f:
+ case Intrinsic::nvvm_div_rp_ftz_f:
+ case Intrinsic::nvvm_div_rz_ftz_f: {
+ bool IsFTZ = nvvm::FDivShouldFTZ(IntrinsicID);
+ APFloat A = IsFTZ ? FTZPreserveSign(Op1V) : Op1V;
+ APFloat B = IsFTZ ? FTZPreserveSign(Op2V) : Op2V;
+ APFloat::roundingMode RoundMode =
+ nvvm::GetFDivRoundingMode(IntrinsicID);
+
+ APFloat Res = A;
+ APFloat::opStatus Status = Res.divide(B, RoundMode);
+ if (!Res.isNaN() &&
+ (Status == APFloat::opOK || Status == APFloat::opInexact)) {
+ Res = IsFTZ ? FTZPreserveSign(Res) : Res;
+ return ConstantFP::get(Ty->getContext(), Res);
+ }
+ return nullptr;
+ }
}
if (!Ty->isHalfTy() && !Ty->isFloatTy() && !Ty->isDoubleTy())
@@ -3729,6 +3875,38 @@ static Constant *ConstantFoldScalarCall3(StringRef Name,
V.fusedMultiplyAdd(C2, C3, APFloat::rmNearestTiesToEven);
return ConstantFP::get(Ty->getContext(), V);
}
+
+ case Intrinsic::nvvm_fma_rm_f:
+ case Intrinsic::nvvm_fma_rn_f:
+ case Intrinsic::nvvm_fma_rp_f:
+ case Intrinsic::nvvm_fma_rz_f:
+ case Intrinsic::nvvm_fma_rm_d:
+ case Intrinsic::nvvm_fma_rn_d:
+ case Intrinsic::nvvm_fma_rp_d:
+ case Intrinsic::nvvm_fma_rz_d:
+ case Intrinsic::nvvm_fma_rm_ftz_f:
+ case Intrinsic::nvvm_fma_rn_ftz_f:
+ case Intrinsic::nvvm_fma_rp_ftz_f:
+ case Intrinsic::nvvm_fma_rz_ftz_f: {
+ bool IsFTZ = nvvm::FMAShouldFTZ(IntrinsicID);
+ APFloat A = IsFTZ ? FTZPreserveSign(C1) : C1;
+ APFloat B = IsFTZ ? FTZPreserveSign(C2) : C2;
+ APFloat C = IsFTZ ? FTZPreserveSign(C3) : C3;
+
+ APFloat::roundingMode RoundMode =
+ nvvm::GetFMARoundingMode(IntrinsicID);
+
+ APFloat Res = A;
+ APFloat::opStatus Status = Res.fusedMultiplyAdd(B, C, RoundMode);
+
+ if (!Res.isNaN() &&
+ (Status == APFloat::opOK || Status == APFloat::opInexact)) {
+ Res = IsFTZ ? FTZPreserveSign(Res) : Res;
+ return ConstantFP::get(Ty->getContext(), Res);
+ }
+ return nullptr;
+ }
+
case Intrinsic::amdgcn_cubeid:
case Intrinsic::amdgcn_cubema:
case Intrinsic::amdgcn_cubesc:
diff --git a/llvm/test/Transforms/InstSimplify/const-fold-nvvm-add.ll b/llvm/test/Transforms/InstSimplify/const-fold-nvvm-add.ll
new file mode 100644
index 0000000000000..a3d87439d74cf
--- /dev/null
+++ b/llvm/test/Transforms/InstSimplify/const-fold-nvvm-add.ll
@@ -0,0 +1,876 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -passes=instsimplify -march=nvptx64 --mcpu=sm_86 --mattr=+ptx72 -S | FileCheck %s
+
+; Check constant-folding for NVVM add intrinsics with different rounding modes
+
+;###############################################################
+;# Add(1.25, -2.0) #
+;###############################################################
+; Tests addition of two normal numbers (1.25 and -2.0) where the result
+; is exactly representable. All rounding modes produce the same result.
+
+define double @test_1_25_minus_2_rm_d() {
+; CHECK-LABEL: define double @test_1_25_minus_2_rm_d() {
+; CHECK-NEXT: ret double -7.500000e-01
+;
+ %res = call double @llvm.nvvm.add.rm.d(double 1.25, double -2.0)
+ ret double %res
+}
+
+define double @test_1_25_minus_2_rn_d() {
+; CHECK-LABEL: define double @test_1_25_minus_2_rn_d() {
+; CHECK-NEXT: ret double -7.500000e-01
+;
+ %res = call double @llvm.nvvm.add.rn.d(double 1.25, double -2.0)
+ ret double %res
+}
+
+define double @test_1_25_minus_2_rp_d() {
+; CHECK-LABEL: define double @test_1_25_minus_2_rp_d() {
+; CHECK-NEXT: ret double -7.500000e-01
+;
+ %res = call double @llvm.nvvm.add.rp.d(double 1.25, double -2.0)
+ ret double %res
+}
+
+define double @test_1_25_minus_2_rz_d() {
+; CHECK-LABEL: define double @test_1_25_minus_2_rz_d() {
+; CHECK-NEXT: ret double -7.500000e-01
+;
+ %res = call double @llvm.nvvm.add.rz.d(double 1.25, double -2.0)
+ ret double %res
+}
+
+define float @test_1_25_minus_2_rm_f() {
+; CHECK-LABEL: define float @test_1_25_minus_2_rm_f() {
+; CHECK-NEXT: ret float -7.500000e-01
+;
+ %res = call float @llvm.nvvm.add.rm.f(float 1.25, float -2.0)
+ ret float %res
+}
+
+define float @test_1_25_minus_2_rn_f() {
+; CHECK-LABEL: define float @test_1_25_minus_2_rn_f() {
+; CHECK-NEXT: ret float -7.500000e-01
+;
+ %res = call float @llvm.nvvm.add.rn.f(float 1.25, float -2.0)
+ ret float %res
+}
+
+define float @test_1_25_minus_2_rp_f() {
+; CHECK-LABEL: define float @test_1_25_minus_2_rp_f() {
+; CHECK-NEXT: ret float -7.500000e-01
+;
+ %res = call float @llvm.nvvm.add.rp.f(float 1.25, float -2.0)
+ ret float %res
+}
+
+define float @test_1_25_minus_2_rz_f() {
+; CHECK-LABEL: define float @test_1_25_minus_2_rz_f() {
+; CHECK-NEXT: ret float -7.500000e-01
+;
+ %res = call float @llvm.nvvm.add.rz.f(float 1.25, float -2.0)
+ ret float %res
+}
+
+define float @test_1_25_minus_2_rm_ftz_f() {
+; CHECK-LABEL: define float @test_1_25_minus_2_rm_ftz_f() {
+; CHECK-NEXT: ret float -7.500000e-01
+;
+ %res = call float @llvm.nvvm.add.rm.ftz.f(float 1.25, float -2.0)
+ ret float %res
+}
+
+define float @test_1_25_minus_2_rn_ftz_f() {
+; CHECK-LABEL: define float @test_1_25_minus_2_rn_ftz_f() {
+; CHECK-NEXT: ret float -7.500000e-01
+;
+ %res = call float @llvm.nvvm.add.rn.ftz.f(float 1.25, float -2.0)
+ ret float %res
+}
+
+define float @test_1_25_minus_2_rp_ftz_f() {
+; CHECK-LABEL: define float @test_1_25_minus_2_rp_ftz_f() {
+; CHECK-NEXT: ret float -7.500000e-01
+;
+ %res = call float @llvm.nvvm.add.rp.ftz.f(float 1.25, float -2.0)
+ ret float %res
+}
+
+define float @test_1_25_minus_2_rz_ftz_f() {
+; CHECK-LABEL: define float @test_1_25_minus_2_rz_ftz_f() {
+; CHECK-NEXT: ret float -7.500000e-01
+;
+ %res = call float @llvm.nvvm.add.rz.ftz.f(float 1.25, float -2.0)
+ ret float %res
+}
+
+;###############################################################
+;# Add(0.0, NaN) #
+;###############################################################
+; Tests addition of a zero with NaN.
+; The result is always NaN and the operation is not constant-folded.
+
+define double @test_zero_plus_nan_rm_d() {
+; CHECK-LABEL: define double @test_zero_plus_nan_rm_d() {
+; CHECK-NEXT: [[RES:%.*]] = call double @llvm.nvvm.add.rm.d(double 0.000000e+00, double 0x7FF4444400000000)
+; CHECK-NEXT: ret double [[RES]]
+;
+ %res = call double @llvm.nvvm.add.rm.d(double 0.0, double 0x7ff4444400000000)
+ ret double %res
+}
+
+define double @test_zero_plus_nan_rn_d() {
+; CHECK-LABEL: define double @test_zero_plus_nan_rn_d() {
+; CHECK-NEXT: [[RES:%.*]] = call double @llvm.nvvm.add.rn.d(double 0.000000e+00, double 0x7FF4444400000000)
+; CHECK-NEXT: ret double [[RES]]
+;
+ %res = call double @llvm.nvvm.add.rn.d(double 0.0, double 0x7ff4444400000000)
+ ret double %res
+}
+
+define double @test_zero_plus_nan_rp_d() {
+; CHECK-LABEL: define double @test_zero_plus_nan_rp_d() {
+; CHECK-NEXT: [[RES:%.*]] = call double @llvm.nvvm.add.rp.d(double 0.000000e+00, double 0x7FF4444400000000)
+; CHECK-NEXT: ret double [[RES]]
+;
+ %res = call double @llvm.nvvm.add.rp.d(double 0.0, double 0x7ff4444400000000)
+ ret double %res
+}
+
+define double @test_zero_plus_nan_rz_d() {
+; CHECK-LABEL: define double @test_zero_plus_nan_rz_d() {
+; CHECK-NEXT: [[RES:%.*]] = call double @llvm.nvvm.add.rz.d(double 0.000000e+00, double 0x7FF4444400000000)
+; CHECK-NEXT: ret double [[RES]]
+;
+ %res = call double @llvm.nvvm.add.rz.d(double 0.0, double 0x7ff4444400000000)
+ ret double %res
+}
+
+define float @test_zero_plus_nan_rm_f() {
+; CHECK-LABEL: define float @test_zero_plus_nan_rm_f() {
+; CHECK-NEXT: [[RES:%.*]] = call float @llvm.nvvm.add.rm.f(float 0.000000e+00, float 0x7FFF444400000000)
+; CHECK-NEXT: ret float [[RES]]
+;
+ %res = call float @llvm.nvvm.add.rm.f(float 0.0, float 0x7FFF444400000000)
+ ret float %res
+}
+
+define float @test_zero_plus_nan_rn_f() {
+; CHECK-LABEL: define float @test_zero_plus_nan_rn_f() {
+; CHECK-NEXT: [[RES:%.*]] = call float @llvm.nvvm.add.rn.f(float 0.000000e+00, float 0x7FFF444400000000)
+; CHECK-NEXT: ret float [[RES]]
+;
+ %res = call float @llvm.nvvm.add.rn.f(float 0.0, float 0x7FFF444400000000)
+ ret float %res
+}
+
+define float @test_zero_plus_nan_rp_f() {
+; CHECK-LABEL: define float @test_zero_plus_nan_rp_f() {
+; CHECK-NEXT: [[RES:%.*]] = call float @llvm.nvvm.add.rp.f(float 0.000000e+00, float 0x7FFF444400000000)
+; CHECK-NEXT: ret float [[RES]]
+;
+ %res = call float @llvm.nvvm.add.rp.f(float 0.0, float 0x7FFF444400000000)
+ ret float %res
+}
+
+define float @test_zero_plus_nan_rz_f() {
+; CHECK-LABEL: define float @test_zero_plus_nan_rz_f() {
+; CHECK-NEXT: [[RES:%.*]] = call float @llvm.nvvm.add.rz.f(float 0.000000e+00, float 0x7FFF444400000000)
+; CHECK-NEXT: ret float [[RES]]
+;
+ %res = call float @llvm.nvvm.add.rz.f(float 0.0, float 0x7FFF444400000000)
+ ret float %res
+}
+
+defin...
[truncated]
|
@llvm/pr-subscribers-llvm-ir Author: Lewis Crawford (LewisCrawford) Changes: Constant fold the NVVM intrinsics for add, mul, div, fma with specific rounding modes. Patch is 146.54 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/152544.diff 6 Files Affected:
diff --git a/llvm/include/llvm/IR/NVVMIntrinsicUtils.h b/llvm/include/llvm/IR/NVVMIntrinsicUtils.h
index 11bfd733a8854..cc4929a1ff8da 100644
--- a/llvm/include/llvm/IR/NVVMIntrinsicUtils.h
+++ b/llvm/include/llvm/IR/NVVMIntrinsicUtils.h
@@ -414,6 +414,178 @@ inline DenormalMode GetNVVMDenormMode(bool ShouldFTZ) {
return DenormalMode::getIEEE();
}
+inline bool FAddShouldFTZ(Intrinsic::ID IntrinsicID) {
+ switch (IntrinsicID) {
+ case Intrinsic::nvvm_add_rm_ftz_f:
+ case Intrinsic::nvvm_add_rn_ftz_f:
+ case Intrinsic::nvvm_add_rp_ftz_f:
+ case Intrinsic::nvvm_add_rz_ftz_f:
+ return true;
+
+ case Intrinsic::nvvm_add_rm_f:
+ case Intrinsic::nvvm_add_rn_f:
+ case Intrinsic::nvvm_add_rp_f:
+ case Intrinsic::nvvm_add_rz_f:
+ case Intrinsic::nvvm_add_rm_d:
+ case Intrinsic::nvvm_add_rn_d:
+ case Intrinsic::nvvm_add_rp_d:
+ case Intrinsic::nvvm_add_rz_d:
+ return false;
+ }
+ llvm_unreachable("Checking FTZ flag for invalid NVVM add intrinsic");
+}
+
+inline APFloat::roundingMode GetFAddRoundingMode(Intrinsic::ID IntrinsicID) {
+ switch (IntrinsicID) {
+ case Intrinsic::nvvm_add_rm_f:
+ case Intrinsic::nvvm_add_rm_d:
+ case Intrinsic::nvvm_add_rm_ftz_f:
+ return APFloat::rmTowardNegative;
+ case Intrinsic::nvvm_add_rn_f:
+ case Intrinsic::nvvm_add_rn_d:
+ case Intrinsic::nvvm_add_rn_ftz_f:
+ return APFloat::rmNearestTiesToEven;
+ case Intrinsic::nvvm_add_rp_f:
+ case Intrinsic::nvvm_add_rp_d:
+ case Intrinsic::nvvm_add_rp_ftz_f:
+ return APFloat::rmTowardPositive;
+ case Intrinsic::nvvm_add_rz_f:
+ case Intrinsic::nvvm_add_rz_d:
+ case Intrinsic::nvvm_add_rz_ftz_f:
+ return APFloat::rmTowardZero;
+ }
+ llvm_unreachable("Invalid FP instrinsic rounding mode for NVVM add");
+}
+
+inline bool FMulShouldFTZ(Intrinsic::ID IntrinsicID) {
+ switch (IntrinsicID) {
+ case Intrinsic::nvvm_mul_rm_ftz_f:
+ case Intrinsic::nvvm_mul_rn_ftz_f:
+ case Intrinsic::nvvm_mul_rp_ftz_f:
+ case Intrinsic::nvvm_mul_rz_ftz_f:
+ return true;
+
+ case Intrinsic::nvvm_mul_rm_f:
+ case Intrinsic::nvvm_mul_rn_f:
+ case Intrinsic::nvvm_mul_rp_f:
+ case Intrinsic::nvvm_mul_rz_f:
+ case Intrinsic::nvvm_mul_rm_d:
+ case Intrinsic::nvvm_mul_rn_d:
+ case Intrinsic::nvvm_mul_rp_d:
+ case Intrinsic::nvvm_mul_rz_d:
+ return false;
+ }
+ llvm_unreachable("Checking FTZ flag for invalid NVVM mul intrinsic");
+}
+
+inline APFloat::roundingMode GetFMulRoundingMode(Intrinsic::ID IntrinsicID) {
+ switch (IntrinsicID) {
+ case Intrinsic::nvvm_mul_rm_f:
+ case Intrinsic::nvvm_mul_rm_d:
+ case Intrinsic::nvvm_mul_rm_ftz_f:
+ return APFloat::rmTowardNegative;
+ case Intrinsic::nvvm_mul_rn_f:
+ case Intrinsic::nvvm_mul_rn_d:
+ case Intrinsic::nvvm_mul_rn_ftz_f:
+ return APFloat::rmNearestTiesToEven;
+ case Intrinsic::nvvm_mul_rp_f:
+ case Intrinsic::nvvm_mul_rp_d:
+ case Intrinsic::nvvm_mul_rp_ftz_f:
+ return APFloat::rmTowardPositive;
+ case Intrinsic::nvvm_mul_rz_f:
+ case Intrinsic::nvvm_mul_rz_d:
+ case Intrinsic::nvvm_mul_rz_ftz_f:
+ return APFloat::rmTowardZero;
+ }
+ llvm_unreachable("Invalid FP instrinsic rounding mode for NVVM mul");
+}
+
+inline bool FDivShouldFTZ(Intrinsic::ID IntrinsicID) {
+ switch (IntrinsicID) {
+ case Intrinsic::nvvm_div_rm_ftz_f:
+ case Intrinsic::nvvm_div_rn_ftz_f:
+ case Intrinsic::nvvm_div_rp_ftz_f:
+ case Intrinsic::nvvm_div_rz_ftz_f:
+ return true;
+
+ case Intrinsic::nvvm_div_rm_f:
+ case Intrinsic::nvvm_div_rn_f:
+ case Intrinsic::nvvm_div_rp_f:
+ case Intrinsic::nvvm_div_rz_f:
+ case Intrinsic::nvvm_div_rm_d:
+ case Intrinsic::nvvm_div_rn_d:
+ case Intrinsic::nvvm_div_rp_d:
+ case Intrinsic::nvvm_div_rz_d:
+ return false;
+ }
+ llvm_unreachable("Checking FTZ flag for invalid NVVM div intrinsic");
+}
+
+inline APFloat::roundingMode GetFDivRoundingMode(Intrinsic::ID IntrinsicID) {
+ switch (IntrinsicID) {
+ case Intrinsic::nvvm_div_rm_f:
+ case Intrinsic::nvvm_div_rm_d:
+ case Intrinsic::nvvm_div_rm_ftz_f:
+ return APFloat::rmTowardNegative;
+ case Intrinsic::nvvm_div_rn_f:
+ case Intrinsic::nvvm_div_rn_d:
+ case Intrinsic::nvvm_div_rn_ftz_f:
+ return APFloat::rmNearestTiesToEven;
+ case Intrinsic::nvvm_div_rp_f:
+ case Intrinsic::nvvm_div_rp_d:
+ case Intrinsic::nvvm_div_rp_ftz_f:
+ return APFloat::rmTowardPositive;
+ case Intrinsic::nvvm_div_rz_f:
+ case Intrinsic::nvvm_div_rz_d:
+ case Intrinsic::nvvm_div_rz_ftz_f:
+ return APFloat::rmTowardZero;
+ }
+ llvm_unreachable("Invalid FP instrinsic rounding mode for NVVM div");
+}
+
+inline bool FMAShouldFTZ(Intrinsic::ID IntrinsicID) {
+ switch (IntrinsicID) {
+ case Intrinsic::nvvm_fma_rm_ftz_f:
+ case Intrinsic::nvvm_fma_rn_ftz_f:
+ case Intrinsic::nvvm_fma_rp_ftz_f:
+ case Intrinsic::nvvm_fma_rz_ftz_f:
+ return true;
+
+ case Intrinsic::nvvm_fma_rm_f:
+ case Intrinsic::nvvm_fma_rn_f:
+ case Intrinsic::nvvm_fma_rp_f:
+ case Intrinsic::nvvm_fma_rz_f:
+ case Intrinsic::nvvm_fma_rm_d:
+ case Intrinsic::nvvm_fma_rn_d:
+ case Intrinsic::nvvm_fma_rp_d:
+ case Intrinsic::nvvm_fma_rz_d:
+ return false;
+ }
+ llvm_unreachable("Checking FTZ flag for invalid NVVM fma intrinsic");
+}
+
+inline APFloat::roundingMode GetFMARoundingMode(Intrinsic::ID IntrinsicID) {
+ switch (IntrinsicID) {
+ case Intrinsic::nvvm_fma_rm_f:
+ case Intrinsic::nvvm_fma_rm_d:
+ case Intrinsic::nvvm_fma_rm_ftz_f:
+ return APFloat::rmTowardNegative;
+ case Intrinsic::nvvm_fma_rn_f:
+ case Intrinsic::nvvm_fma_rn_d:
+ case Intrinsic::nvvm_fma_rn_ftz_f:
+ return APFloat::rmNearestTiesToEven;
+ case Intrinsic::nvvm_fma_rp_f:
+ case Intrinsic::nvvm_fma_rp_d:
+ case Intrinsic::nvvm_fma_rp_ftz_f:
+ return APFloat::rmTowardPositive;
+ case Intrinsic::nvvm_fma_rz_f:
+ case Intrinsic::nvvm_fma_rz_d:
+ case Intrinsic::nvvm_fma_rz_ftz_f:
+ return APFloat::rmTowardZero;
+ }
+ llvm_unreachable("Invalid FP instrinsic rounding mode for NVVM fma");
+}
+
} // namespace nvvm
} // namespace llvm
#endif // LLVM_IR_NVVMINTRINSICUTILS_H
diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp
index dd98b62baca33..69fdd4f2b3e71 100644
--- a/llvm/lib/Analysis/ConstantFolding.cpp
+++ b/llvm/lib/Analysis/ConstantFolding.cpp
@@ -1843,6 +1843,62 @@ bool llvm::canConstantFoldCallTo(const CallBase *Call, const Function *F) {
case Intrinsic::nvvm_sqrt_rn_ftz_f:
return !Call->isStrictFP();
+ // NVVM add intrinsics with explicit rounding modes
+ case Intrinsic::nvvm_add_rm_d:
+ case Intrinsic::nvvm_add_rn_d:
+ case Intrinsic::nvvm_add_rp_d:
+ case Intrinsic::nvvm_add_rz_d:
+ case Intrinsic::nvvm_add_rm_f:
+ case Intrinsic::nvvm_add_rn_f:
+ case Intrinsic::nvvm_add_rp_f:
+ case Intrinsic::nvvm_add_rz_f:
+ case Intrinsic::nvvm_add_rm_ftz_f:
+ case Intrinsic::nvvm_add_rn_ftz_f:
+ case Intrinsic::nvvm_add_rp_ftz_f:
+ case Intrinsic::nvvm_add_rz_ftz_f:
+
+ // NVVM div intrinsics with explicit rounding modes
+ case Intrinsic::nvvm_div_rm_d:
+ case Intrinsic::nvvm_div_rn_d:
+ case Intrinsic::nvvm_div_rp_d:
+ case Intrinsic::nvvm_div_rz_d:
+ case Intrinsic::nvvm_div_rm_f:
+ case Intrinsic::nvvm_div_rn_f:
+ case Intrinsic::nvvm_div_rp_f:
+ case Intrinsic::nvvm_div_rz_f:
+ case Intrinsic::nvvm_div_rm_ftz_f:
+ case Intrinsic::nvvm_div_rn_ftz_f:
+ case Intrinsic::nvvm_div_rp_ftz_f:
+ case Intrinsic::nvvm_div_rz_ftz_f:
+
+ // NVVM mul intrinsics with explicit rounding modes
+ case Intrinsic::nvvm_mul_rm_d:
+ case Intrinsic::nvvm_mul_rn_d:
+ case Intrinsic::nvvm_mul_rp_d:
+ case Intrinsic::nvvm_mul_rz_d:
+ case Intrinsic::nvvm_mul_rm_f:
+ case Intrinsic::nvvm_mul_rn_f:
+ case Intrinsic::nvvm_mul_rp_f:
+ case Intrinsic::nvvm_mul_rz_f:
+ case Intrinsic::nvvm_mul_rm_ftz_f:
+ case Intrinsic::nvvm_mul_rn_ftz_f:
+ case Intrinsic::nvvm_mul_rp_ftz_f:
+ case Intrinsic::nvvm_mul_rz_ftz_f:
+
+ // NVVM fma intrinsics with explicit rounding modes
+ case Intrinsic::nvvm_fma_rm_d:
+ case Intrinsic::nvvm_fma_rn_d:
+ case Intrinsic::nvvm_fma_rp_d:
+ case Intrinsic::nvvm_fma_rz_d:
+ case Intrinsic::nvvm_fma_rm_f:
+ case Intrinsic::nvvm_fma_rn_f:
+ case Intrinsic::nvvm_fma_rp_f:
+ case Intrinsic::nvvm_fma_rz_f:
+ case Intrinsic::nvvm_fma_rm_ftz_f:
+ case Intrinsic::nvvm_fma_rn_ftz_f:
+ case Intrinsic::nvvm_fma_rp_ftz_f:
+ case Intrinsic::nvvm_fma_rz_ftz_f:
+
// Sign operations are actually bitwise operations, they do not raise
// exceptions even for SNANs.
case Intrinsic::fabs:
@@ -3318,6 +3374,96 @@ static Constant *ConstantFoldIntrinsicCall2(Intrinsic::ID IntrinsicID, Type *Ty,
return ConstantFP::get(Ty->getContext(), Res);
}
+
+ case Intrinsic::nvvm_add_rm_f:
+ case Intrinsic::nvvm_add_rn_f:
+ case Intrinsic::nvvm_add_rp_f:
+ case Intrinsic::nvvm_add_rz_f:
+ case Intrinsic::nvvm_add_rm_d:
+ case Intrinsic::nvvm_add_rn_d:
+ case Intrinsic::nvvm_add_rp_d:
+ case Intrinsic::nvvm_add_rz_d:
+ case Intrinsic::nvvm_add_rm_ftz_f:
+ case Intrinsic::nvvm_add_rn_ftz_f:
+ case Intrinsic::nvvm_add_rp_ftz_f:
+ case Intrinsic::nvvm_add_rz_ftz_f: {
+
+ bool IsFTZ = nvvm::FAddShouldFTZ(IntrinsicID);
+ APFloat A = IsFTZ ? FTZPreserveSign(Op1V) : Op1V;
+ APFloat B = IsFTZ ? FTZPreserveSign(Op2V) : Op2V;
+
+ APFloat::roundingMode RoundMode =
+ nvvm::GetFAddRoundingMode(IntrinsicID);
+
+ APFloat Res = A;
+ APFloat::opStatus Status = Res.add(B, RoundMode);
+
+ if (!Res.isNaN() &&
+ (Status == APFloat::opOK || Status == APFloat::opInexact)) {
+ Res = IsFTZ ? FTZPreserveSign(Res) : Res;
+ return ConstantFP::get(Ty->getContext(), Res);
+ }
+ return nullptr;
+ }
+
+ case Intrinsic::nvvm_mul_rm_f:
+ case Intrinsic::nvvm_mul_rn_f:
+ case Intrinsic::nvvm_mul_rp_f:
+ case Intrinsic::nvvm_mul_rz_f:
+ case Intrinsic::nvvm_mul_rm_d:
+ case Intrinsic::nvvm_mul_rn_d:
+ case Intrinsic::nvvm_mul_rp_d:
+ case Intrinsic::nvvm_mul_rz_d:
+ case Intrinsic::nvvm_mul_rm_ftz_f:
+ case Intrinsic::nvvm_mul_rn_ftz_f:
+ case Intrinsic::nvvm_mul_rp_ftz_f:
+ case Intrinsic::nvvm_mul_rz_ftz_f: {
+
+ bool IsFTZ = nvvm::FMulShouldFTZ(IntrinsicID);
+ APFloat A = IsFTZ ? FTZPreserveSign(Op1V) : Op1V;
+ APFloat B = IsFTZ ? FTZPreserveSign(Op2V) : Op2V;
+
+ APFloat::roundingMode RoundMode =
+ nvvm::GetFMulRoundingMode(IntrinsicID);
+
+ APFloat Res = A;
+ APFloat::opStatus Status = Res.multiply(B, RoundMode);
+
+ if (!Res.isNaN() &&
+ (Status == APFloat::opOK || Status == APFloat::opInexact)) {
+ Res = IsFTZ ? FTZPreserveSign(Res) : Res;
+ return ConstantFP::get(Ty->getContext(), Res);
+ }
+ return nullptr;
+ }
+
+ case Intrinsic::nvvm_div_rm_f:
+ case Intrinsic::nvvm_div_rn_f:
+ case Intrinsic::nvvm_div_rp_f:
+ case Intrinsic::nvvm_div_rz_f:
+ case Intrinsic::nvvm_div_rm_d:
+ case Intrinsic::nvvm_div_rn_d:
+ case Intrinsic::nvvm_div_rp_d:
+ case Intrinsic::nvvm_div_rz_d:
+ case Intrinsic::nvvm_div_rm_ftz_f:
+ case Intrinsic::nvvm_div_rn_ftz_f:
+ case Intrinsic::nvvm_div_rp_ftz_f:
+ case Intrinsic::nvvm_div_rz_ftz_f: {
+ bool IsFTZ = nvvm::FDivShouldFTZ(IntrinsicID);
+ APFloat A = IsFTZ ? FTZPreserveSign(Op1V) : Op1V;
+ APFloat B = IsFTZ ? FTZPreserveSign(Op2V) : Op2V;
+ APFloat::roundingMode RoundMode =
+ nvvm::GetFDivRoundingMode(IntrinsicID);
+
+ APFloat Res = A;
+ APFloat::opStatus Status = Res.divide(B, RoundMode);
+ if (!Res.isNaN() &&
+ (Status == APFloat::opOK || Status == APFloat::opInexact)) {
+ Res = IsFTZ ? FTZPreserveSign(Res) : Res;
+ return ConstantFP::get(Ty->getContext(), Res);
+ }
+ return nullptr;
+ }
}
if (!Ty->isHalfTy() && !Ty->isFloatTy() && !Ty->isDoubleTy())
@@ -3729,6 +3875,38 @@ static Constant *ConstantFoldScalarCall3(StringRef Name,
V.fusedMultiplyAdd(C2, C3, APFloat::rmNearestTiesToEven);
return ConstantFP::get(Ty->getContext(), V);
}
+
+ case Intrinsic::nvvm_fma_rm_f:
+ case Intrinsic::nvvm_fma_rn_f:
+ case Intrinsic::nvvm_fma_rp_f:
+ case Intrinsic::nvvm_fma_rz_f:
+ case Intrinsic::nvvm_fma_rm_d:
+ case Intrinsic::nvvm_fma_rn_d:
+ case Intrinsic::nvvm_fma_rp_d:
+ case Intrinsic::nvvm_fma_rz_d:
+ case Intrinsic::nvvm_fma_rm_ftz_f:
+ case Intrinsic::nvvm_fma_rn_ftz_f:
+ case Intrinsic::nvvm_fma_rp_ftz_f:
+ case Intrinsic::nvvm_fma_rz_ftz_f: {
+ bool IsFTZ = nvvm::FMAShouldFTZ(IntrinsicID);
+ APFloat A = IsFTZ ? FTZPreserveSign(C1) : C1;
+ APFloat B = IsFTZ ? FTZPreserveSign(C2) : C2;
+ APFloat C = IsFTZ ? FTZPreserveSign(C3) : C3;
+
+ APFloat::roundingMode RoundMode =
+ nvvm::GetFMARoundingMode(IntrinsicID);
+
+ APFloat Res = A;
+ APFloat::opStatus Status = Res.fusedMultiplyAdd(B, C, RoundMode);
+
+ if (!Res.isNaN() &&
+ (Status == APFloat::opOK || Status == APFloat::opInexact)) {
+ Res = IsFTZ ? FTZPreserveSign(Res) : Res;
+ return ConstantFP::get(Ty->getContext(), Res);
+ }
+ return nullptr;
+ }
+
case Intrinsic::amdgcn_cubeid:
case Intrinsic::amdgcn_cubema:
case Intrinsic::amdgcn_cubesc:
diff --git a/llvm/test/Transforms/InstSimplify/const-fold-nvvm-add.ll b/llvm/test/Transforms/InstSimplify/const-fold-nvvm-add.ll
new file mode 100644
index 0000000000000..a3d87439d74cf
--- /dev/null
+++ b/llvm/test/Transforms/InstSimplify/const-fold-nvvm-add.ll
@@ -0,0 +1,876 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -passes=instsimplify -march=nvptx64 --mcpu=sm_86 --mattr=+ptx72 -S | FileCheck %s
+
+; Check constant-folding for NVVM add intrinsics with different rounding modes
+
+;###############################################################
+;# Add(1.25, -2.0) #
+;###############################################################
+; Tests addition of two normal numbers (1.25 and -2.0) where the result
+; is exactly representable. All rounding modes produce the same result.
+
+define double @test_1_25_minus_2_rm_d() {
+; CHECK-LABEL: define double @test_1_25_minus_2_rm_d() {
+; CHECK-NEXT: ret double -7.500000e-01
+;
+ %res = call double @llvm.nvvm.add.rm.d(double 1.25, double -2.0)
+ ret double %res
+}
+
+define double @test_1_25_minus_2_rn_d() {
+; CHECK-LABEL: define double @test_1_25_minus_2_rn_d() {
+; CHECK-NEXT: ret double -7.500000e-01
+;
+ %res = call double @llvm.nvvm.add.rn.d(double 1.25, double -2.0)
+ ret double %res
+}
+
+define double @test_1_25_minus_2_rp_d() {
+; CHECK-LABEL: define double @test_1_25_minus_2_rp_d() {
+; CHECK-NEXT: ret double -7.500000e-01
+;
+ %res = call double @llvm.nvvm.add.rp.d(double 1.25, double -2.0)
+ ret double %res
+}
+
+define double @test_1_25_minus_2_rz_d() {
+; CHECK-LABEL: define double @test_1_25_minus_2_rz_d() {
+; CHECK-NEXT: ret double -7.500000e-01
+;
+ %res = call double @llvm.nvvm.add.rz.d(double 1.25, double -2.0)
+ ret double %res
+}
+
+define float @test_1_25_minus_2_rm_f() {
+; CHECK-LABEL: define float @test_1_25_minus_2_rm_f() {
+; CHECK-NEXT: ret float -7.500000e-01
+;
+ %res = call float @llvm.nvvm.add.rm.f(float 1.25, float -2.0)
+ ret float %res
+}
+
+define float @test_1_25_minus_2_rn_f() {
+; CHECK-LABEL: define float @test_1_25_minus_2_rn_f() {
+; CHECK-NEXT: ret float -7.500000e-01
+;
+ %res = call float @llvm.nvvm.add.rn.f(float 1.25, float -2.0)
+ ret float %res
+}
+
+define float @test_1_25_minus_2_rp_f() {
+; CHECK-LABEL: define float @test_1_25_minus_2_rp_f() {
+; CHECK-NEXT: ret float -7.500000e-01
+;
+ %res = call float @llvm.nvvm.add.rp.f(float 1.25, float -2.0)
+ ret float %res
+}
+
+define float @test_1_25_minus_2_rz_f() {
+; CHECK-LABEL: define float @test_1_25_minus_2_rz_f() {
+; CHECK-NEXT: ret float -7.500000e-01
+;
+ %res = call float @llvm.nvvm.add.rz.f(float 1.25, float -2.0)
+ ret float %res
+}
+
+define float @test_1_25_minus_2_rm_ftz_f() {
+; CHECK-LABEL: define float @test_1_25_minus_2_rm_ftz_f() {
+; CHECK-NEXT: ret float -7.500000e-01
+;
+ %res = call float @llvm.nvvm.add.rm.ftz.f(float 1.25, float -2.0)
+ ret float %res
+}
+
+define float @test_1_25_minus_2_rn_ftz_f() {
+; CHECK-LABEL: define float @test_1_25_minus_2_rn_ftz_f() {
+; CHECK-NEXT: ret float -7.500000e-01
+;
+ %res = call float @llvm.nvvm.add.rn.ftz.f(float 1.25, float -2.0)
+ ret float %res
+}
+
+define float @test_1_25_minus_2_rp_ftz_f() {
+; CHECK-LABEL: define float @test_1_25_minus_2_rp_ftz_f() {
+; CHECK-NEXT: ret float -7.500000e-01
+;
+ %res = call float @llvm.nvvm.add.rp.ftz.f(float 1.25, float -2.0)
+ ret float %res
+}
+
+define float @test_1_25_minus_2_rz_ftz_f() {
+; CHECK-LABEL: define float @test_1_25_minus_2_rz_ftz_f() {
+; CHECK-NEXT: ret float -7.500000e-01
+;
+ %res = call float @llvm.nvvm.add.rz.ftz.f(float 1.25, float -2.0)
+ ret float %res
+}
+
+;###############################################################
+;# Add(0.0, NaN) #
+;###############################################################
+; Tests addition of a zero with NaN.
+; The result is always NaN and the operation is not constant-folded.
+
+define double @test_zero_plus_nan_rm_d() {
+; CHECK-LABEL: define double @test_zero_plus_nan_rm_d() {
+; CHECK-NEXT: [[RES:%.*]] = call double @llvm.nvvm.add.rm.d(double 0.000000e+00, double 0x7FF4444400000000)
+; CHECK-NEXT: ret double [[RES]]
+;
+ %res = call double @llvm.nvvm.add.rm.d(double 0.0, double 0x7ff4444400000000)
+ ret double %res
+}
+
+define double @test_zero_plus_nan_rn_d() {
+; CHECK-LABEL: define double @test_zero_plus_nan_rn_d() {
+; CHECK-NEXT: [[RES:%.*]] = call double @llvm.nvvm.add.rn.d(double 0.000000e+00, double 0x7FF4444400000000)
+; CHECK-NEXT: ret double [[RES]]
+;
+ %res = call double @llvm.nvvm.add.rn.d(double 0.0, double 0x7ff4444400000000)
+ ret double %res
+}
+
+define double @test_zero_plus_nan_rp_d() {
+; CHECK-LABEL: define double @test_zero_plus_nan_rp_d() {
+; CHECK-NEXT: [[RES:%.*]] = call double @llvm.nvvm.add.rp.d(double 0.000000e+00, double 0x7FF4444400000000)
+; CHECK-NEXT: ret double [[RES]]
+;
+ %res = call double @llvm.nvvm.add.rp.d(double 0.0, double 0x7ff4444400000000)
+ ret double %res
+}
+
+define double @test_zero_plus_nan_rz_d() {
+; CHECK-LABEL: define double @test_zero_plus_nan_rz_d() {
+; CHECK-NEXT: [[RES:%.*]] = call double @llvm.nvvm.add.rz.d(double 0.000000e+00, double 0x7FF4444400000000)
+; CHECK-NEXT: ret double [[RES]]
+;
+ %res = call double @llvm.nvvm.add.rz.d(double 0.0, double 0x7ff4444400000000)
+ ret double %res
+}
+
+define float @test_zero_plus_nan_rm_f() {
+; CHECK-LABEL: define float @test_zero_plus_nan_rm_f() {
+; CHECK-NEXT: [[RES:%.*]] = call float @llvm.nvvm.add.rm.f(float 0.000000e+00, float 0x7FFF444400000000)
+; CHECK-NEXT: ret float [[RES]]
+;
+ %res = call float @llvm.nvvm.add.rm.f(float 0.0, float 0x7FFF444400000000)
+ ret float %res
+}
+
+define float @test_zero_plus_nan_rn_f() {
+; CHECK-LABEL: define float @test_zero_plus_nan_rn_f() {
+; CHECK-NEXT: [[RES:%.*]] = call float @llvm.nvvm.add.rn.f(float 0.000000e+00, float 0x7FFF444400000000)
+; CHECK-NEXT: ret float [[RES]]
+;
+ %res = call float @llvm.nvvm.add.rn.f(float 0.0, float 0x7FFF444400000000)
+ ret float %res
+}
+
+define float @test_zero_plus_nan_rp_f() {
+; CHECK-LABEL: define float @test_zero_plus_nan_rp_f() {
+; CHECK-NEXT: [[RES:%.*]] = call float @llvm.nvvm.add.rp.f(float 0.000000e+00, float 0x7FFF444400000000)
+; CHECK-NEXT: ret float [[RES]]
+;
+ %res = call float @llvm.nvvm.add.rp.f(float 0.0, float 0x7FFF444400000000)
+ ret float %res
+}
+
+define float @test_zero_plus_nan_rz_f() {
+; CHECK-LABEL: define float @test_zero_plus_nan_rz_f() {
+; CHECK-NEXT: [[RES:%.*]] = call float @llvm.nvvm.add.rz.f(float 0.000000e+00, float 0x7FFF444400000000)
+; CHECK-NEXT: ret float [[RES]]
+;
+ %res = call float @llvm.nvvm.add.rz.f(float 0.0, float 0x7FFF444400000000)
+ ret float %res
+}
+
+defin...
[truncated]
|
@llvm/pr-subscribers-llvm-analysis Author: Lewis Crawford (LewisCrawford) ChangesConstant fold the NVVM intrinsics for add, mul, div, fma with specific rounding modes. Patch is 146.54 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/152544.diff 6 Files Affected:
diff --git a/llvm/include/llvm/IR/NVVMIntrinsicUtils.h b/llvm/include/llvm/IR/NVVMIntrinsicUtils.h
index 11bfd733a8854..cc4929a1ff8da 100644
--- a/llvm/include/llvm/IR/NVVMIntrinsicUtils.h
+++ b/llvm/include/llvm/IR/NVVMIntrinsicUtils.h
@@ -414,6 +414,178 @@ inline DenormalMode GetNVVMDenormMode(bool ShouldFTZ) {
return DenormalMode::getIEEE();
}
+inline bool FAddShouldFTZ(Intrinsic::ID IntrinsicID) {
+ switch (IntrinsicID) {
+ case Intrinsic::nvvm_add_rm_ftz_f:
+ case Intrinsic::nvvm_add_rn_ftz_f:
+ case Intrinsic::nvvm_add_rp_ftz_f:
+ case Intrinsic::nvvm_add_rz_ftz_f:
+ return true;
+
+ case Intrinsic::nvvm_add_rm_f:
+ case Intrinsic::nvvm_add_rn_f:
+ case Intrinsic::nvvm_add_rp_f:
+ case Intrinsic::nvvm_add_rz_f:
+ case Intrinsic::nvvm_add_rm_d:
+ case Intrinsic::nvvm_add_rn_d:
+ case Intrinsic::nvvm_add_rp_d:
+ case Intrinsic::nvvm_add_rz_d:
+ return false;
+ }
+ llvm_unreachable("Checking FTZ flag for invalid NVVM add intrinsic");
+}
+
+inline APFloat::roundingMode GetFAddRoundingMode(Intrinsic::ID IntrinsicID) {
+ switch (IntrinsicID) {
+ case Intrinsic::nvvm_add_rm_f:
+ case Intrinsic::nvvm_add_rm_d:
+ case Intrinsic::nvvm_add_rm_ftz_f:
+ return APFloat::rmTowardNegative;
+ case Intrinsic::nvvm_add_rn_f:
+ case Intrinsic::nvvm_add_rn_d:
+ case Intrinsic::nvvm_add_rn_ftz_f:
+ return APFloat::rmNearestTiesToEven;
+ case Intrinsic::nvvm_add_rp_f:
+ case Intrinsic::nvvm_add_rp_d:
+ case Intrinsic::nvvm_add_rp_ftz_f:
+ return APFloat::rmTowardPositive;
+ case Intrinsic::nvvm_add_rz_f:
+ case Intrinsic::nvvm_add_rz_d:
+ case Intrinsic::nvvm_add_rz_ftz_f:
+ return APFloat::rmTowardZero;
+ }
+ llvm_unreachable("Invalid FP intrinsic rounding mode for NVVM add");
+}
+
+inline bool FMulShouldFTZ(Intrinsic::ID IntrinsicID) {
+ switch (IntrinsicID) {
+ case Intrinsic::nvvm_mul_rm_ftz_f:
+ case Intrinsic::nvvm_mul_rn_ftz_f:
+ case Intrinsic::nvvm_mul_rp_ftz_f:
+ case Intrinsic::nvvm_mul_rz_ftz_f:
+ return true;
+
+ case Intrinsic::nvvm_mul_rm_f:
+ case Intrinsic::nvvm_mul_rn_f:
+ case Intrinsic::nvvm_mul_rp_f:
+ case Intrinsic::nvvm_mul_rz_f:
+ case Intrinsic::nvvm_mul_rm_d:
+ case Intrinsic::nvvm_mul_rn_d:
+ case Intrinsic::nvvm_mul_rp_d:
+ case Intrinsic::nvvm_mul_rz_d:
+ return false;
+ }
+ llvm_unreachable("Checking FTZ flag for invalid NVVM mul intrinsic");
+}
+
+inline APFloat::roundingMode GetFMulRoundingMode(Intrinsic::ID IntrinsicID) {
+ switch (IntrinsicID) {
+ case Intrinsic::nvvm_mul_rm_f:
+ case Intrinsic::nvvm_mul_rm_d:
+ case Intrinsic::nvvm_mul_rm_ftz_f:
+ return APFloat::rmTowardNegative;
+ case Intrinsic::nvvm_mul_rn_f:
+ case Intrinsic::nvvm_mul_rn_d:
+ case Intrinsic::nvvm_mul_rn_ftz_f:
+ return APFloat::rmNearestTiesToEven;
+ case Intrinsic::nvvm_mul_rp_f:
+ case Intrinsic::nvvm_mul_rp_d:
+ case Intrinsic::nvvm_mul_rp_ftz_f:
+ return APFloat::rmTowardPositive;
+ case Intrinsic::nvvm_mul_rz_f:
+ case Intrinsic::nvvm_mul_rz_d:
+ case Intrinsic::nvvm_mul_rz_ftz_f:
+ return APFloat::rmTowardZero;
+ }
+ llvm_unreachable("Invalid FP intrinsic rounding mode for NVVM mul");
+}
+
+inline bool FDivShouldFTZ(Intrinsic::ID IntrinsicID) {
+ switch (IntrinsicID) {
+ case Intrinsic::nvvm_div_rm_ftz_f:
+ case Intrinsic::nvvm_div_rn_ftz_f:
+ case Intrinsic::nvvm_div_rp_ftz_f:
+ case Intrinsic::nvvm_div_rz_ftz_f:
+ return true;
+
+ case Intrinsic::nvvm_div_rm_f:
+ case Intrinsic::nvvm_div_rn_f:
+ case Intrinsic::nvvm_div_rp_f:
+ case Intrinsic::nvvm_div_rz_f:
+ case Intrinsic::nvvm_div_rm_d:
+ case Intrinsic::nvvm_div_rn_d:
+ case Intrinsic::nvvm_div_rp_d:
+ case Intrinsic::nvvm_div_rz_d:
+ return false;
+ }
+ llvm_unreachable("Checking FTZ flag for invalid NVVM div intrinsic");
+}
+
+inline APFloat::roundingMode GetFDivRoundingMode(Intrinsic::ID IntrinsicID) {
+ switch (IntrinsicID) {
+ case Intrinsic::nvvm_div_rm_f:
+ case Intrinsic::nvvm_div_rm_d:
+ case Intrinsic::nvvm_div_rm_ftz_f:
+ return APFloat::rmTowardNegative;
+ case Intrinsic::nvvm_div_rn_f:
+ case Intrinsic::nvvm_div_rn_d:
+ case Intrinsic::nvvm_div_rn_ftz_f:
+ return APFloat::rmNearestTiesToEven;
+ case Intrinsic::nvvm_div_rp_f:
+ case Intrinsic::nvvm_div_rp_d:
+ case Intrinsic::nvvm_div_rp_ftz_f:
+ return APFloat::rmTowardPositive;
+ case Intrinsic::nvvm_div_rz_f:
+ case Intrinsic::nvvm_div_rz_d:
+ case Intrinsic::nvvm_div_rz_ftz_f:
+ return APFloat::rmTowardZero;
+ }
+ llvm_unreachable("Invalid FP intrinsic rounding mode for NVVM div");
+}
+
+inline bool FMAShouldFTZ(Intrinsic::ID IntrinsicID) {
+ switch (IntrinsicID) {
+ case Intrinsic::nvvm_fma_rm_ftz_f:
+ case Intrinsic::nvvm_fma_rn_ftz_f:
+ case Intrinsic::nvvm_fma_rp_ftz_f:
+ case Intrinsic::nvvm_fma_rz_ftz_f:
+ return true;
+
+ case Intrinsic::nvvm_fma_rm_f:
+ case Intrinsic::nvvm_fma_rn_f:
+ case Intrinsic::nvvm_fma_rp_f:
+ case Intrinsic::nvvm_fma_rz_f:
+ case Intrinsic::nvvm_fma_rm_d:
+ case Intrinsic::nvvm_fma_rn_d:
+ case Intrinsic::nvvm_fma_rp_d:
+ case Intrinsic::nvvm_fma_rz_d:
+ return false;
+ }
+ llvm_unreachable("Checking FTZ flag for invalid NVVM fma intrinsic");
+}
+
+inline APFloat::roundingMode GetFMARoundingMode(Intrinsic::ID IntrinsicID) {
+ switch (IntrinsicID) {
+ case Intrinsic::nvvm_fma_rm_f:
+ case Intrinsic::nvvm_fma_rm_d:
+ case Intrinsic::nvvm_fma_rm_ftz_f:
+ return APFloat::rmTowardNegative;
+ case Intrinsic::nvvm_fma_rn_f:
+ case Intrinsic::nvvm_fma_rn_d:
+ case Intrinsic::nvvm_fma_rn_ftz_f:
+ return APFloat::rmNearestTiesToEven;
+ case Intrinsic::nvvm_fma_rp_f:
+ case Intrinsic::nvvm_fma_rp_d:
+ case Intrinsic::nvvm_fma_rp_ftz_f:
+ return APFloat::rmTowardPositive;
+ case Intrinsic::nvvm_fma_rz_f:
+ case Intrinsic::nvvm_fma_rz_d:
+ case Intrinsic::nvvm_fma_rz_ftz_f:
+ return APFloat::rmTowardZero;
+ }
+ llvm_unreachable("Invalid FP intrinsic rounding mode for NVVM fma");
+}
+
} // namespace nvvm
} // namespace llvm
#endif // LLVM_IR_NVVMINTRINSICUTILS_H
diff --git a/llvm/lib/Analysis/ConstantFolding.cpp b/llvm/lib/Analysis/ConstantFolding.cpp
index dd98b62baca33..69fdd4f2b3e71 100644
--- a/llvm/lib/Analysis/ConstantFolding.cpp
+++ b/llvm/lib/Analysis/ConstantFolding.cpp
@@ -1843,6 +1843,62 @@ bool llvm::canConstantFoldCallTo(const CallBase *Call, const Function *F) {
case Intrinsic::nvvm_sqrt_rn_ftz_f:
return !Call->isStrictFP();
+ // NVVM add intrinsics with explicit rounding modes
+ case Intrinsic::nvvm_add_rm_d:
+ case Intrinsic::nvvm_add_rn_d:
+ case Intrinsic::nvvm_add_rp_d:
+ case Intrinsic::nvvm_add_rz_d:
+ case Intrinsic::nvvm_add_rm_f:
+ case Intrinsic::nvvm_add_rn_f:
+ case Intrinsic::nvvm_add_rp_f:
+ case Intrinsic::nvvm_add_rz_f:
+ case Intrinsic::nvvm_add_rm_ftz_f:
+ case Intrinsic::nvvm_add_rn_ftz_f:
+ case Intrinsic::nvvm_add_rp_ftz_f:
+ case Intrinsic::nvvm_add_rz_ftz_f:
+
+ // NVVM div intrinsics with explicit rounding modes
+ case Intrinsic::nvvm_div_rm_d:
+ case Intrinsic::nvvm_div_rn_d:
+ case Intrinsic::nvvm_div_rp_d:
+ case Intrinsic::nvvm_div_rz_d:
+ case Intrinsic::nvvm_div_rm_f:
+ case Intrinsic::nvvm_div_rn_f:
+ case Intrinsic::nvvm_div_rp_f:
+ case Intrinsic::nvvm_div_rz_f:
+ case Intrinsic::nvvm_div_rm_ftz_f:
+ case Intrinsic::nvvm_div_rn_ftz_f:
+ case Intrinsic::nvvm_div_rp_ftz_f:
+ case Intrinsic::nvvm_div_rz_ftz_f:
+
+ // NVVM mul intrinsics with explicit rounding modes
+ case Intrinsic::nvvm_mul_rm_d:
+ case Intrinsic::nvvm_mul_rn_d:
+ case Intrinsic::nvvm_mul_rp_d:
+ case Intrinsic::nvvm_mul_rz_d:
+ case Intrinsic::nvvm_mul_rm_f:
+ case Intrinsic::nvvm_mul_rn_f:
+ case Intrinsic::nvvm_mul_rp_f:
+ case Intrinsic::nvvm_mul_rz_f:
+ case Intrinsic::nvvm_mul_rm_ftz_f:
+ case Intrinsic::nvvm_mul_rn_ftz_f:
+ case Intrinsic::nvvm_mul_rp_ftz_f:
+ case Intrinsic::nvvm_mul_rz_ftz_f:
+
+ // NVVM fma intrinsics with explicit rounding modes
+ case Intrinsic::nvvm_fma_rm_d:
+ case Intrinsic::nvvm_fma_rn_d:
+ case Intrinsic::nvvm_fma_rp_d:
+ case Intrinsic::nvvm_fma_rz_d:
+ case Intrinsic::nvvm_fma_rm_f:
+ case Intrinsic::nvvm_fma_rn_f:
+ case Intrinsic::nvvm_fma_rp_f:
+ case Intrinsic::nvvm_fma_rz_f:
+ case Intrinsic::nvvm_fma_rm_ftz_f:
+ case Intrinsic::nvvm_fma_rn_ftz_f:
+ case Intrinsic::nvvm_fma_rp_ftz_f:
+ case Intrinsic::nvvm_fma_rz_ftz_f:
+
// Sign operations are actually bitwise operations, they do not raise
// exceptions even for SNANs.
case Intrinsic::fabs:
@@ -3318,6 +3374,96 @@ static Constant *ConstantFoldIntrinsicCall2(Intrinsic::ID IntrinsicID, Type *Ty,
return ConstantFP::get(Ty->getContext(), Res);
}
+
+ case Intrinsic::nvvm_add_rm_f:
+ case Intrinsic::nvvm_add_rn_f:
+ case Intrinsic::nvvm_add_rp_f:
+ case Intrinsic::nvvm_add_rz_f:
+ case Intrinsic::nvvm_add_rm_d:
+ case Intrinsic::nvvm_add_rn_d:
+ case Intrinsic::nvvm_add_rp_d:
+ case Intrinsic::nvvm_add_rz_d:
+ case Intrinsic::nvvm_add_rm_ftz_f:
+ case Intrinsic::nvvm_add_rn_ftz_f:
+ case Intrinsic::nvvm_add_rp_ftz_f:
+ case Intrinsic::nvvm_add_rz_ftz_f: {
+
+ bool IsFTZ = nvvm::FAddShouldFTZ(IntrinsicID);
+ APFloat A = IsFTZ ? FTZPreserveSign(Op1V) : Op1V;
+ APFloat B = IsFTZ ? FTZPreserveSign(Op2V) : Op2V;
+
+ APFloat::roundingMode RoundMode =
+ nvvm::GetFAddRoundingMode(IntrinsicID);
+
+ APFloat Res = A;
+ APFloat::opStatus Status = Res.add(B, RoundMode);
+
+ if (!Res.isNaN() &&
+ (Status == APFloat::opOK || Status == APFloat::opInexact)) {
+ Res = IsFTZ ? FTZPreserveSign(Res) : Res;
+ return ConstantFP::get(Ty->getContext(), Res);
+ }
+ return nullptr;
+ }
+
+ case Intrinsic::nvvm_mul_rm_f:
+ case Intrinsic::nvvm_mul_rn_f:
+ case Intrinsic::nvvm_mul_rp_f:
+ case Intrinsic::nvvm_mul_rz_f:
+ case Intrinsic::nvvm_mul_rm_d:
+ case Intrinsic::nvvm_mul_rn_d:
+ case Intrinsic::nvvm_mul_rp_d:
+ case Intrinsic::nvvm_mul_rz_d:
+ case Intrinsic::nvvm_mul_rm_ftz_f:
+ case Intrinsic::nvvm_mul_rn_ftz_f:
+ case Intrinsic::nvvm_mul_rp_ftz_f:
+ case Intrinsic::nvvm_mul_rz_ftz_f: {
+
+ bool IsFTZ = nvvm::FMulShouldFTZ(IntrinsicID);
+ APFloat A = IsFTZ ? FTZPreserveSign(Op1V) : Op1V;
+ APFloat B = IsFTZ ? FTZPreserveSign(Op2V) : Op2V;
+
+ APFloat::roundingMode RoundMode =
+ nvvm::GetFMulRoundingMode(IntrinsicID);
+
+ APFloat Res = A;
+ APFloat::opStatus Status = Res.multiply(B, RoundMode);
+
+ if (!Res.isNaN() &&
+ (Status == APFloat::opOK || Status == APFloat::opInexact)) {
+ Res = IsFTZ ? FTZPreserveSign(Res) : Res;
+ return ConstantFP::get(Ty->getContext(), Res);
+ }
+ return nullptr;
+ }
+
+ case Intrinsic::nvvm_div_rm_f:
+ case Intrinsic::nvvm_div_rn_f:
+ case Intrinsic::nvvm_div_rp_f:
+ case Intrinsic::nvvm_div_rz_f:
+ case Intrinsic::nvvm_div_rm_d:
+ case Intrinsic::nvvm_div_rn_d:
+ case Intrinsic::nvvm_div_rp_d:
+ case Intrinsic::nvvm_div_rz_d:
+ case Intrinsic::nvvm_div_rm_ftz_f:
+ case Intrinsic::nvvm_div_rn_ftz_f:
+ case Intrinsic::nvvm_div_rp_ftz_f:
+ case Intrinsic::nvvm_div_rz_ftz_f: {
+ bool IsFTZ = nvvm::FDivShouldFTZ(IntrinsicID);
+ APFloat A = IsFTZ ? FTZPreserveSign(Op1V) : Op1V;
+ APFloat B = IsFTZ ? FTZPreserveSign(Op2V) : Op2V;
+ APFloat::roundingMode RoundMode =
+ nvvm::GetFDivRoundingMode(IntrinsicID);
+
+ APFloat Res = A;
+ APFloat::opStatus Status = Res.divide(B, RoundMode);
+ if (!Res.isNaN() &&
+ (Status == APFloat::opOK || Status == APFloat::opInexact)) {
+ Res = IsFTZ ? FTZPreserveSign(Res) : Res;
+ return ConstantFP::get(Ty->getContext(), Res);
+ }
+ return nullptr;
+ }
}
if (!Ty->isHalfTy() && !Ty->isFloatTy() && !Ty->isDoubleTy())
@@ -3729,6 +3875,38 @@ static Constant *ConstantFoldScalarCall3(StringRef Name,
V.fusedMultiplyAdd(C2, C3, APFloat::rmNearestTiesToEven);
return ConstantFP::get(Ty->getContext(), V);
}
+
+ case Intrinsic::nvvm_fma_rm_f:
+ case Intrinsic::nvvm_fma_rn_f:
+ case Intrinsic::nvvm_fma_rp_f:
+ case Intrinsic::nvvm_fma_rz_f:
+ case Intrinsic::nvvm_fma_rm_d:
+ case Intrinsic::nvvm_fma_rn_d:
+ case Intrinsic::nvvm_fma_rp_d:
+ case Intrinsic::nvvm_fma_rz_d:
+ case Intrinsic::nvvm_fma_rm_ftz_f:
+ case Intrinsic::nvvm_fma_rn_ftz_f:
+ case Intrinsic::nvvm_fma_rp_ftz_f:
+ case Intrinsic::nvvm_fma_rz_ftz_f: {
+ bool IsFTZ = nvvm::FMAShouldFTZ(IntrinsicID);
+ APFloat A = IsFTZ ? FTZPreserveSign(C1) : C1;
+ APFloat B = IsFTZ ? FTZPreserveSign(C2) : C2;
+ APFloat C = IsFTZ ? FTZPreserveSign(C3) : C3;
+
+ APFloat::roundingMode RoundMode =
+ nvvm::GetFMARoundingMode(IntrinsicID);
+
+ APFloat Res = A;
+ APFloat::opStatus Status = Res.fusedMultiplyAdd(B, C, RoundMode);
+
+ if (!Res.isNaN() &&
+ (Status == APFloat::opOK || Status == APFloat::opInexact)) {
+ Res = IsFTZ ? FTZPreserveSign(Res) : Res;
+ return ConstantFP::get(Ty->getContext(), Res);
+ }
+ return nullptr;
+ }
+
case Intrinsic::amdgcn_cubeid:
case Intrinsic::amdgcn_cubema:
case Intrinsic::amdgcn_cubesc:
diff --git a/llvm/test/Transforms/InstSimplify/const-fold-nvvm-add.ll b/llvm/test/Transforms/InstSimplify/const-fold-nvvm-add.ll
new file mode 100644
index 0000000000000..a3d87439d74cf
--- /dev/null
+++ b/llvm/test/Transforms/InstSimplify/const-fold-nvvm-add.ll
@@ -0,0 +1,876 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt < %s -passes=instsimplify -march=nvptx64 --mcpu=sm_86 --mattr=+ptx72 -S | FileCheck %s
+
+; Check constant-folding for NVVM add intrinsics with different rounding modes
+
+;###############################################################
+;# Add(1.25, -2.0) #
+;###############################################################
+; Tests addition of two normal numbers (1.25 and -2.0) where the result
+; is exactly representable. All rounding modes produce the same result.
+
+define double @test_1_25_minus_2_rm_d() {
+; CHECK-LABEL: define double @test_1_25_minus_2_rm_d() {
+; CHECK-NEXT: ret double -7.500000e-01
+;
+ %res = call double @llvm.nvvm.add.rm.d(double 1.25, double -2.0)
+ ret double %res
+}
+
+define double @test_1_25_minus_2_rn_d() {
+; CHECK-LABEL: define double @test_1_25_minus_2_rn_d() {
+; CHECK-NEXT: ret double -7.500000e-01
+;
+ %res = call double @llvm.nvvm.add.rn.d(double 1.25, double -2.0)
+ ret double %res
+}
+
+define double @test_1_25_minus_2_rp_d() {
+; CHECK-LABEL: define double @test_1_25_minus_2_rp_d() {
+; CHECK-NEXT: ret double -7.500000e-01
+;
+ %res = call double @llvm.nvvm.add.rp.d(double 1.25, double -2.0)
+ ret double %res
+}
+
+define double @test_1_25_minus_2_rz_d() {
+; CHECK-LABEL: define double @test_1_25_minus_2_rz_d() {
+; CHECK-NEXT: ret double -7.500000e-01
+;
+ %res = call double @llvm.nvvm.add.rz.d(double 1.25, double -2.0)
+ ret double %res
+}
+
+define float @test_1_25_minus_2_rm_f() {
+; CHECK-LABEL: define float @test_1_25_minus_2_rm_f() {
+; CHECK-NEXT: ret float -7.500000e-01
+;
+ %res = call float @llvm.nvvm.add.rm.f(float 1.25, float -2.0)
+ ret float %res
+}
+
+define float @test_1_25_minus_2_rn_f() {
+; CHECK-LABEL: define float @test_1_25_minus_2_rn_f() {
+; CHECK-NEXT: ret float -7.500000e-01
+;
+ %res = call float @llvm.nvvm.add.rn.f(float 1.25, float -2.0)
+ ret float %res
+}
+
+define float @test_1_25_minus_2_rp_f() {
+; CHECK-LABEL: define float @test_1_25_minus_2_rp_f() {
+; CHECK-NEXT: ret float -7.500000e-01
+;
+ %res = call float @llvm.nvvm.add.rp.f(float 1.25, float -2.0)
+ ret float %res
+}
+
+define float @test_1_25_minus_2_rz_f() {
+; CHECK-LABEL: define float @test_1_25_minus_2_rz_f() {
+; CHECK-NEXT: ret float -7.500000e-01
+;
+ %res = call float @llvm.nvvm.add.rz.f(float 1.25, float -2.0)
+ ret float %res
+}
+
+define float @test_1_25_minus_2_rm_ftz_f() {
+; CHECK-LABEL: define float @test_1_25_minus_2_rm_ftz_f() {
+; CHECK-NEXT: ret float -7.500000e-01
+;
+ %res = call float @llvm.nvvm.add.rm.ftz.f(float 1.25, float -2.0)
+ ret float %res
+}
+
+define float @test_1_25_minus_2_rn_ftz_f() {
+; CHECK-LABEL: define float @test_1_25_minus_2_rn_ftz_f() {
+; CHECK-NEXT: ret float -7.500000e-01
+;
+ %res = call float @llvm.nvvm.add.rn.ftz.f(float 1.25, float -2.0)
+ ret float %res
+}
+
+define float @test_1_25_minus_2_rp_ftz_f() {
+; CHECK-LABEL: define float @test_1_25_minus_2_rp_ftz_f() {
+; CHECK-NEXT: ret float -7.500000e-01
+;
+ %res = call float @llvm.nvvm.add.rp.ftz.f(float 1.25, float -2.0)
+ ret float %res
+}
+
+define float @test_1_25_minus_2_rz_ftz_f() {
+; CHECK-LABEL: define float @test_1_25_minus_2_rz_ftz_f() {
+; CHECK-NEXT: ret float -7.500000e-01
+;
+ %res = call float @llvm.nvvm.add.rz.ftz.f(float 1.25, float -2.0)
+ ret float %res
+}
+
+;###############################################################
+;# Add(0.0, NaN) #
+;###############################################################
+; Tests addition of a zero with NaN.
+; The result is always NaN and the operation is not constant-folded.
+
+define double @test_zero_plus_nan_rm_d() {
+; CHECK-LABEL: define double @test_zero_plus_nan_rm_d() {
+; CHECK-NEXT: [[RES:%.*]] = call double @llvm.nvvm.add.rm.d(double 0.000000e+00, double 0x7FF4444400000000)
+; CHECK-NEXT: ret double [[RES]]
+;
+ %res = call double @llvm.nvvm.add.rm.d(double 0.0, double 0x7ff4444400000000)
+ ret double %res
+}
+
+define double @test_zero_plus_nan_rn_d() {
+; CHECK-LABEL: define double @test_zero_plus_nan_rn_d() {
+; CHECK-NEXT: [[RES:%.*]] = call double @llvm.nvvm.add.rn.d(double 0.000000e+00, double 0x7FF4444400000000)
+; CHECK-NEXT: ret double [[RES]]
+;
+ %res = call double @llvm.nvvm.add.rn.d(double 0.0, double 0x7ff4444400000000)
+ ret double %res
+}
+
+define double @test_zero_plus_nan_rp_d() {
+; CHECK-LABEL: define double @test_zero_plus_nan_rp_d() {
+; CHECK-NEXT: [[RES:%.*]] = call double @llvm.nvvm.add.rp.d(double 0.000000e+00, double 0x7FF4444400000000)
+; CHECK-NEXT: ret double [[RES]]
+;
+ %res = call double @llvm.nvvm.add.rp.d(double 0.0, double 0x7ff4444400000000)
+ ret double %res
+}
+
+define double @test_zero_plus_nan_rz_d() {
+; CHECK-LABEL: define double @test_zero_plus_nan_rz_d() {
+; CHECK-NEXT: [[RES:%.*]] = call double @llvm.nvvm.add.rz.d(double 0.000000e+00, double 0x7FF4444400000000)
+; CHECK-NEXT: ret double [[RES]]
+;
+ %res = call double @llvm.nvvm.add.rz.d(double 0.0, double 0x7ff4444400000000)
+ ret double %res
+}
+
+define float @test_zero_plus_nan_rm_f() {
+; CHECK-LABEL: define float @test_zero_plus_nan_rm_f() {
+; CHECK-NEXT: [[RES:%.*]] = call float @llvm.nvvm.add.rm.f(float 0.000000e+00, float 0x7FFF444400000000)
+; CHECK-NEXT: ret float [[RES]]
+;
+ %res = call float @llvm.nvvm.add.rm.f(float 0.0, float 0x7FFF444400000000)
+ ret float %res
+}
+
+define float @test_zero_plus_nan_rn_f() {
+; CHECK-LABEL: define float @test_zero_plus_nan_rn_f() {
+; CHECK-NEXT: [[RES:%.*]] = call float @llvm.nvvm.add.rn.f(float 0.000000e+00, float 0x7FFF444400000000)
+; CHECK-NEXT: ret float [[RES]]
+;
+ %res = call float @llvm.nvvm.add.rn.f(float 0.0, float 0x7FFF444400000000)
+ ret float %res
+}
+
+define float @test_zero_plus_nan_rp_f() {
+; CHECK-LABEL: define float @test_zero_plus_nan_rp_f() {
+; CHECK-NEXT: [[RES:%.*]] = call float @llvm.nvvm.add.rp.f(float 0.000000e+00, float 0x7FFF444400000000)
+; CHECK-NEXT: ret float [[RES]]
+;
+ %res = call float @llvm.nvvm.add.rp.f(float 0.0, float 0x7FFF444400000000)
+ ret float %res
+}
+
+define float @test_zero_plus_nan_rz_f() {
+; CHECK-LABEL: define float @test_zero_plus_nan_rz_f() {
+; CHECK-NEXT: [[RES:%.*]] = call float @llvm.nvvm.add.rz.f(float 0.000000e+00, float 0x7FFF444400000000)
+; CHECK-NEXT: ret float [[RES]]
+;
+ %res = call float @llvm.nvvm.add.rz.f(float 0.0, float 0x7FFF444400000000)
+ ret float %res
+}
+
+defin...
[truncated]
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM.
Thanks for reviewing this! :) I'm going to wait a few days before merging this until some downstream issues have been resolved, but will merge it here soon. |
Constant fold the NVVM intrinsics for add, mul, div, fma with specific rounding modes.