Skip to content

Commit 2531535

Browse files
author
Changpeng Fang
committed
AMDGPU: Implement FDIV optimizations in AMDGPUCodeGenPrepare
Summary: RCP has the accuracy limit. If FDIV fpmath require high accuracy rcp may not meet the requirement. However, in DAG lowering, fpmath information gets lost, and thus we may generate either inaccurate rcp related computation or slow code for fdiv. In patch implements fdiv optimizations in the AMDGPUCodeGenPrepare, which could exactly know !fpmath. FastUnsafeRcpLegal: We determine whether it is legal to use rcp based on unsafe-fp-math, fast math flags, denormals and fpmath accuracy request. RCP Optimizations: 1/x -> rcp(x) when fast unsafe rcp is legal or fpmath >= 2.5ULP with denormals flushed. a/b -> a*rcp(b) when fast unsafe rcp is legal. Use fdiv.fast: a/b -> fdiv.fast(a, b) when RCP optimization is not performed and fpmath >= 2.5ULP with denormals flushed. 1/x -> fdiv.fast(1,x) when RCP optimization is not performed and fpmath >= 2.5ULP with denormals. Reviewers: arsenm Differential Revision: https://reviews.llvm.org/D71293
1 parent 7ad17e0 commit 2531535

12 files changed

+479
-172
lines changed

llvm/lib/Target/AMDGPU/AMDGPUCodeGenPrepare.cpp

Lines changed: 110 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -606,12 +606,64 @@ bool AMDGPUCodeGenPrepare::foldBinOpIntoSelect(BinaryOperator &BO) const {
606606
return true;
607607
}
608608

609-
static bool shouldKeepFDivF32(Value *Num, bool UnsafeDiv, bool HasDenormals) {
609+
// Perform RCP optimizations:
610+
//
611+
// 1/x -> rcp(x) when fast unsafe rcp is legal or fpmath >= 2.5ULP with
612+
// denormals flushed.
613+
//
614+
// a/b -> a*rcp(b) when fast unsafe rcp is legal.
615+
static Value *performRCPOpt(Value *Num, Value *Den, bool FastUnsafeRcpLegal,
616+
IRBuilder<> Builder, MDNode *FPMath, Module *Mod,
617+
bool HasDenormals, bool NeedHighAccuracy) {
618+
619+
Type *Ty = Den->getType();
620+
if (!FastUnsafeRcpLegal && Ty->isFloatTy() &&
621+
(HasDenormals || NeedHighAccuracy))
622+
return nullptr;
623+
624+
Function *Decl = Intrinsic::getDeclaration(Mod, Intrinsic::amdgcn_rcp, Ty);
625+
if (const ConstantFP *CLHS = dyn_cast<ConstantFP>(Num)) {
626+
if (FastUnsafeRcpLegal || Ty->isFloatTy() || Ty->isHalfTy()) {
627+
if (CLHS->isExactlyValue(1.0)) {
628+
// v_rcp_f32 and v_rsq_f32 do not support denormals, and according to
629+
// the CI documentation has a worst case error of 1 ulp.
630+
// OpenCL requires <= 2.5 ulp for 1.0 / x, so it should always be OK to
631+
// use it as long as we aren't trying to use denormals.
632+
//
633+
// v_rcp_f16 and v_rsq_f16 DO support denormals.
634+
635+
// NOTE: v_sqrt and v_rcp will be combined to v_rsq later. So we don't
636+
// insert rsq intrinsic here.
637+
638+
// 1.0 / x -> rcp(x)
639+
return Builder.CreateCall(Decl, { Den });
640+
}
641+
642+
// Same as for 1.0, but expand the sign out of the constant.
643+
if (CLHS->isExactlyValue(-1.0)) {
644+
// -1.0 / x -> rcp (fneg x)
645+
Value *FNeg = Builder.CreateFNeg(Den);
646+
return Builder.CreateCall(Decl, { FNeg });
647+
}
648+
}
649+
}
650+
651+
if (FastUnsafeRcpLegal) {
652+
// Turn into multiply by the reciprocal.
653+
// x / y -> x * (1.0 / y)
654+
Value *Recip = Builder.CreateCall(Decl, { Den });
655+
return Builder.CreateFMul(Num, Recip, "", FPMath);
656+
}
657+
return nullptr;
658+
}
659+
660+
static bool shouldKeepFDivF32(Value *Num, bool FastUnsafeRcpLegal,
661+
bool HasDenormals) {
610662
const ConstantFP *CNum = dyn_cast<ConstantFP>(Num);
611663
if (!CNum)
612664
return HasDenormals;
613665

614-
if (UnsafeDiv)
666+
if (FastUnsafeRcpLegal)
615667
return true;
616668

617669
bool IsOne = CNum->isExactlyValue(+1.0) || CNum->isExactlyValue(-1.0);
@@ -620,64 +672,90 @@ static bool shouldKeepFDivF32(Value *Num, bool UnsafeDiv, bool HasDenormals) {
620672
return HasDenormals ^ IsOne;
621673
}
622674

623-
// Insert an intrinsic for fast fdiv for safe math situations where we can
624-
// reduce precision. Leave fdiv for situations where the generic node is
625-
// expected to be optimized.
675+
676+
// Optimizations is performed based on fpmath, fast math flags as wells as
677+
// denormals to lower fdiv using either rcp or fdiv.fast.
678+
//
679+
// FastUnsafeRcpLegal: We determine whether it is legal to use rcp based on
680+
// unsafe-fp-math, fast math flags, denormals and fpmath
681+
// accuracy request.
682+
//
683+
// RCP Optimizations:
684+
// 1/x -> rcp(x) when fast unsafe rcp is legal or fpmath >= 2.5ULP with
685+
// denormals flushed.
686+
// a/b -> a*rcp(b) when fast unsafe rcp is legal.
687+
//
688+
// Use fdiv.fast:
689+
// a/b -> fdiv.fast(a, b) when RCP optimization is not performed and
690+
// fpmath >= 2.5ULP with denormals flushed.
691+
//
692+
// 1/x -> fdiv.fast(1,x) when RCP optimization is not performed and
693+
// fpmath >= 2.5ULP with denormals.
626694
bool AMDGPUCodeGenPrepare::visitFDiv(BinaryOperator &FDiv) {
627-
Type *Ty = FDiv.getType();
628695

629-
if (!Ty->getScalarType()->isFloatTy())
630-
return false;
696+
Type *Ty = FDiv.getType()->getScalarType();
631697

632-
MDNode *FPMath = FDiv.getMetadata(LLVMContext::MD_fpmath);
633-
if (!FPMath)
698+
// No intrinsic for fdiv16 if target does not support f16.
699+
if (Ty->isHalfTy() && !ST->has16BitInsts())
634700
return false;
635701

636702
const FPMathOperator *FPOp = cast<const FPMathOperator>(&FDiv);
637-
float ULP = FPOp->getFPAccuracy();
638-
if (ULP < 2.5f)
639-
return false;
703+
MDNode *FPMath = FDiv.getMetadata(LLVMContext::MD_fpmath);
704+
const bool NeedHighAccuracy = !FPMath || FPOp->getFPAccuracy() < 2.5f;
640705

641706
FastMathFlags FMF = FPOp->getFastMathFlags();
642-
bool UnsafeDiv = HasUnsafeFPMath || FMF.isFast() ||
643-
FMF.allowReciprocal();
707+
// Determine whether it is ok to use rcp based on unsafe-fp-math,
708+
// fast math flags, denormals and accuracy request.
709+
const bool FastUnsafeRcpLegal = HasUnsafeFPMath || FMF.isFast() ||
710+
(FMF.allowReciprocal() && ((!HasFP32Denormals && !NeedHighAccuracy)
711+
|| FMF.approxFunc()));
644712

645-
// With UnsafeDiv node will be optimized to just rcp and mul.
646-
if (UnsafeDiv)
647-
return false;
713+
// Use fdiv.fast for only f32, fpmath >= 2.5ULP and rcp is not used.
714+
const bool UseFDivFast = Ty->isFloatTy() && !NeedHighAccuracy &&
715+
!FastUnsafeRcpLegal;
648716

649-
IRBuilder<> Builder(FDiv.getParent(), std::next(FDiv.getIterator()), FPMath);
717+
IRBuilder<> Builder(FDiv.getParent(), std::next(FDiv.getIterator()));
650718
Builder.setFastMathFlags(FMF);
651719
Builder.SetCurrentDebugLocation(FDiv.getDebugLoc());
652720

653-
Function *Decl = Intrinsic::getDeclaration(Mod, Intrinsic::amdgcn_fdiv_fast);
654-
655721
Value *Num = FDiv.getOperand(0);
656722
Value *Den = FDiv.getOperand(1);
657723

658724
Value *NewFDiv = nullptr;
659-
660-
if (VectorType *VT = dyn_cast<VectorType>(Ty)) {
725+
if (VectorType *VT = dyn_cast<VectorType>(FDiv.getType())) {
661726
NewFDiv = UndefValue::get(VT);
662727

663728
// FIXME: Doesn't do the right thing for cases where the vector is partially
664729
// constant. This works when the scalarizer pass is run first.
665730
for (unsigned I = 0, E = VT->getNumElements(); I != E; ++I) {
666731
Value *NumEltI = Builder.CreateExtractElement(Num, I);
667732
Value *DenEltI = Builder.CreateExtractElement(Den, I);
668-
Value *NewElt;
669-
670-
if (shouldKeepFDivF32(NumEltI, UnsafeDiv, HasFP32Denormals)) {
671-
NewElt = Builder.CreateFDiv(NumEltI, DenEltI);
672-
} else {
673-
NewElt = Builder.CreateCall(Decl, { NumEltI, DenEltI });
733+
Value *NewElt = nullptr;
734+
if (UseFDivFast && !shouldKeepFDivF32(NumEltI, FastUnsafeRcpLegal,
735+
HasFP32Denormals)) {
736+
Function *Decl =
737+
Intrinsic::getDeclaration(Mod, Intrinsic::amdgcn_fdiv_fast);
738+
NewElt = Builder.CreateCall(Decl, { NumEltI, DenEltI }, "", FPMath);
674739
}
740+
if (!NewElt) // Try rcp.
741+
NewElt = performRCPOpt(NumEltI, DenEltI, FastUnsafeRcpLegal, Builder,
742+
FPMath, Mod, HasFP32Denormals, NeedHighAccuracy);
743+
if (!NewElt)
744+
NewElt = Builder.CreateFDiv(NumEltI, DenEltI, "", FPMath);
675745

676746
NewFDiv = Builder.CreateInsertElement(NewFDiv, NewElt, I);
677747
}
678-
} else {
679-
if (!shouldKeepFDivF32(Num, UnsafeDiv, HasFP32Denormals))
680-
NewFDiv = Builder.CreateCall(Decl, { Num, Den });
748+
} else { // Scalar.
749+
if (UseFDivFast && !shouldKeepFDivF32(Num, FastUnsafeRcpLegal,
750+
HasFP32Denormals)) {
751+
Function *Decl =
752+
Intrinsic::getDeclaration(Mod, Intrinsic::amdgcn_fdiv_fast);
753+
NewFDiv = Builder.CreateCall(Decl, { Num, Den }, "", FPMath);
754+
}
755+
if (!NewFDiv) { // Try rcp.
756+
NewFDiv = performRCPOpt(Num, Den, FastUnsafeRcpLegal, Builder, FPMath,
757+
Mod, HasFP32Denormals, NeedHighAccuracy);
758+
}
681759
}
682760

683761
if (NewFDiv) {

llvm/lib/Target/AMDGPU/SIISelLowering.cpp

Lines changed: 44 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -7474,49 +7474,54 @@ SDValue SITargetLowering::lowerFastUnsafeFDIV(SDValue Op,
74747474
SDValue RHS = Op.getOperand(1);
74757475
EVT VT = Op.getValueType();
74767476
const SDNodeFlags Flags = Op->getFlags();
7477-
bool Unsafe = DAG.getTarget().Options.UnsafeFPMath || Flags.hasAllowReciprocal();
74787477

7479-
if (!Unsafe && VT == MVT::f32 && hasFP32Denormals(DAG.getMachineFunction()))
7478+
bool FastUnsafeRcpLegal = DAG.getTarget().Options.UnsafeFPMath ||
7479+
(Flags.hasAllowReciprocal() &&
7480+
((VT == MVT::f32 && hasFP32Denormals(DAG.getMachineFunction())) ||
7481+
VT == MVT::f16 ||
7482+
Flags.hasApproximateFuncs()));
7483+
7484+
// Do rcp optimization only when fast unsafe rcp is legal here.
7485+
// NOTE: We already performed RCP optimization to insert intrinsics in
7486+
// AMDGPUCodeGenPrepare. Ideally there should have no opportunity here to
7487+
// rcp optimization.
7488+
// However, there are cases like FREM, which is expended into a sequence
7489+
// of instructions including FDIV, which may expose new opportunities.
7490+
if (!FastUnsafeRcpLegal)
74807491
return SDValue();
74817492

74827493
if (const ConstantFPSDNode *CLHS = dyn_cast<ConstantFPSDNode>(LHS)) {
7483-
if (Unsafe || VT == MVT::f32 || VT == MVT::f16) {
7484-
if (CLHS->isExactlyValue(1.0)) {
7485-
// v_rcp_f32 and v_rsq_f32 do not support denormals, and according to
7486-
// the CI documentation has a worst case error of 1 ulp.
7487-
// OpenCL requires <= 2.5 ulp for 1.0 / x, so it should always be OK to
7488-
// use it as long as we aren't trying to use denormals.
7489-
//
7490-
// v_rcp_f16 and v_rsq_f16 DO support denormals.
7491-
7492-
// 1.0 / sqrt(x) -> rsq(x)
7493-
7494-
// XXX - Is UnsafeFPMath sufficient to do this for f64? The maximum ULP
7495-
// error seems really high at 2^29 ULP.
7496-
if (RHS.getOpcode() == ISD::FSQRT)
7497-
return DAG.getNode(AMDGPUISD::RSQ, SL, VT, RHS.getOperand(0));
7498-
7499-
// 1.0 / x -> rcp(x)
7500-
return DAG.getNode(AMDGPUISD::RCP, SL, VT, RHS);
7501-
}
7494+
if (CLHS->isExactlyValue(1.0)) {
7495+
// v_rcp_f32 and v_rsq_f32 do not support denormals, and according to
7496+
// the CI documentation has a worst case error of 1 ulp.
7497+
// OpenCL requires <= 2.5 ulp for 1.0 / x, so it should always be OK to
7498+
// use it as long as we aren't trying to use denormals.
7499+
//
7500+
// v_rcp_f16 and v_rsq_f16 DO support denormals.
75027501

7503-
// Same as for 1.0, but expand the sign out of the constant.
7504-
if (CLHS->isExactlyValue(-1.0)) {
7505-
// -1.0 / x -> rcp (fneg x)
7506-
SDValue FNegRHS = DAG.getNode(ISD::FNEG, SL, VT, RHS);
7507-
return DAG.getNode(AMDGPUISD::RCP, SL, VT, FNegRHS);
7508-
}
7502+
// 1.0 / sqrt(x) -> rsq(x)
7503+
7504+
// XXX - Is UnsafeFPMath sufficient to do this for f64? The maximum ULP
7505+
// error seems really high at 2^29 ULP.
7506+
if (RHS.getOpcode() == ISD::FSQRT)
7507+
return DAG.getNode(AMDGPUISD::RSQ, SL, VT, RHS.getOperand(0));
7508+
7509+
// 1.0 / x -> rcp(x)
7510+
return DAG.getNode(AMDGPUISD::RCP, SL, VT, RHS);
75097511
}
7510-
}
75117512

7512-
if (Unsafe) {
7513-
// Turn into multiply by the reciprocal.
7514-
// x / y -> x * (1.0 / y)
7515-
SDValue Recip = DAG.getNode(AMDGPUISD::RCP, SL, VT, RHS);
7516-
return DAG.getNode(ISD::FMUL, SL, VT, LHS, Recip, Flags);
7513+
// Same as for 1.0, but expand the sign out of the constant.
7514+
if (CLHS->isExactlyValue(-1.0)) {
7515+
// -1.0 / x -> rcp (fneg x)
7516+
SDValue FNegRHS = DAG.getNode(ISD::FNEG, SL, VT, RHS);
7517+
return DAG.getNode(AMDGPUISD::RCP, SL, VT, FNegRHS);
7518+
}
75177519
}
75187520

7519-
return SDValue();
7521+
// Turn into multiply by the reciprocal.
7522+
// x / y -> x * (1.0 / y)
7523+
SDValue Recip = DAG.getNode(AMDGPUISD::RCP, SL, VT, RHS);
7524+
return DAG.getNode(ISD::FMUL, SL, VT, LHS, Recip, Flags);
75207525
}
75217526

75227527
static SDValue getFPBinOp(SelectionDAG &DAG, unsigned Opcode, const SDLoc &SL,
@@ -8663,6 +8668,11 @@ SDValue SITargetLowering::performRcpCombine(SDNode *N,
86638668
N->getFlags());
86648669
}
86658670

8671+
if ((VT == MVT::f32 || VT == MVT::f16) && N0.getOpcode() == ISD::FSQRT) {
8672+
return DCI.DAG.getNode(AMDGPUISD::RSQ, SDLoc(N), VT,
8673+
N0.getOperand(0), N->getFlags());
8674+
}
8675+
86668676
return AMDGPUTargetLowering::performRcpCombine(N, DCI);
86678677
}
86688678

0 commit comments

Comments
 (0)