diff --git a/flang/include/flang/Evaluate/match.h b/flang/include/flang/Evaluate/match.h index 79da40f7c1338..01932226fa500 100644 --- a/flang/include/flang/Evaluate/match.h +++ b/flang/include/flang/Evaluate/match.h @@ -8,6 +8,7 @@ #ifndef FORTRAN_EVALUATE_MATCH_H_ #define FORTRAN_EVALUATE_MATCH_H_ +#include "flang/Common/Fortran-consts.h" #include "flang/Common/visit.h" #include "flang/Evaluate/expression.h" #include "llvm/ADT/STLExtras.h" @@ -34,15 +35,29 @@ struct IsOperation> { template constexpr bool is_operation_v{detail::IsOperation::value}; -template -const evaluate::Expr &deparen(const evaluate::Expr &x) { - if (auto *parens{std::get_if>(&x.u)}) { +template +const evaluate::Expr> &deparen(const evaluate::Expr> &x) { + if (auto *parens{std::get_if>>(&x.u)}) { return deparen(parens->template operand<0>()); } else { return x; } } +template +const evaluate::Expr> &deparen( + const evaluate::Expr> &x) { + return x; +} + +// Some expressions (e.g. TypelessExpression) don't allow parentheses, while +// those that do have Expr as the argument to the parentheses. This means +// that there is no consistent return type that works for all expressions. +// Delete this overload explicitly so an attempt to use it creates a clearer +// error message. +const evaluate::Expr &deparen( + const evaluate::Expr &) = delete; + // Expr matchers (patterns) // // Each pattern should implement diff --git a/flang/lib/Semantics/check-omp-atomic.cpp b/flang/lib/Semantics/check-omp-atomic.cpp index 50e63d356be02..f25497ece61c4 100644 --- a/flang/lib/Semantics/check-omp-atomic.cpp +++ b/flang/lib/Semantics/check-omp-atomic.cpp @@ -67,6 +67,22 @@ struct IsIntegral> { template constexpr bool is_integral_v{IsIntegral::value}; +template struct IsFloatingPoint { + static constexpr bool value{false}; +}; + +template +struct IsFloatingPoint> { + static constexpr bool value{// + C == common::TypeCategory::Real || C == common::TypeCategory::Complex}; +}; + +template +constexpr bool is_floating_point_v{IsFloatingPoint::value}; + +template +constexpr bool is_numeric_v{is_integral_v || is_floating_point_v}; + template using ReassocOpBase = evaluate::match::AnyOfPattern< // evaluate::match::Add, // @@ -88,7 +104,8 @@ struct ReassocRewriter : public evaluate::rewrite::Identity { using Id = evaluate::rewrite::Identity; struct NonIntegralTag {}; - ReassocRewriter(const SomeExpr &atom) : atom_(atom) {} + ReassocRewriter(const SomeExpr &atom, const SemanticsContext &context) + : atom_(atom), context_(context) {} // Try to find cases where the input expression is of the form // (1) (a . b) . c, or @@ -102,8 +119,13 @@ struct ReassocRewriter : public evaluate::rewrite::Identity { // For example, assuming x is the atomic variable: // (a + x) + b -> (a + b) + x, i.e. (conceptually) swap x and b. template >> + typename = std::enable_if_t>> evaluate::Expr operator()(evaluate::Expr &&x, const U &u) { + if constexpr (is_floating_point_v) { + if (!context_.langOptions().AssociativeMath) { + return Id::operator()(std::move(x), u); + } + } // As per the above comment, there are 3 subexpressions involved in this // transformation. A match::Expr will match evaluate::Expr when T is // same as U, plus it will store a pointer (ref) to the matched expression. @@ -169,7 +191,7 @@ struct ReassocRewriter : public evaluate::rewrite::Identity { } template >> + typename = std::enable_if_t>> evaluate::Expr operator()( evaluate::Expr &&x, const U &u, NonIntegralTag = {}) { return Id::operator()(std::move(x), u); @@ -181,6 +203,7 @@ struct ReassocRewriter : public evaluate::rewrite::Identity { } const SomeExpr &atom_; + const SemanticsContext &context_; }; struct AnalyzedCondStmt { @@ -809,7 +832,7 @@ OmpStructureChecker::CheckAtomicUpdateAssignment( CheckStorageOverlap(atom, GetNonAtomArguments(atom, update.rhs), source); return std::nullopt; } else if (tryReassoc) { - ReassocRewriter ra(atom); + ReassocRewriter ra(atom, context_); SomeExpr raRhs{evaluate::rewrite::Mutator(ra)(update.rhs)}; std::tie(hasErrors, tryReassoc) = CheckAtomicUpdateAssignmentRhs( diff --git a/flang/test/Lower/OpenMP/atomic-update-reassoc-fp.f90 b/flang/test/Lower/OpenMP/atomic-update-reassoc-fp.f90 new file mode 100644 index 0000000000000..c86589cacd679 --- /dev/null +++ b/flang/test/Lower/OpenMP/atomic-update-reassoc-fp.f90 @@ -0,0 +1,100 @@ +!RUN: %flang_fc1 -emit-hlfir -ffast-math -fopenmp -fopenmp-version=60 %s -o - | FileCheck %s + +subroutine f00(x, y) + implicit none + real :: x, y + + !$omp atomic update + x = ((x + 1) + y) + 2 +end + +!CHECK-LABEL: func.func @_QPf00 +!CHECK: %[[X:[0-9]+]]:2 = hlfir.declare %arg0 +!CHECK: %[[Y:[0-9]+]]:2 = hlfir.declare %arg1 +!CHECK: %cst = arith.constant 1.000000e+00 : f32 +!CHECK: %[[LOAD_Y:[0-9]+]] = fir.load %[[Y]]#0 : !fir.ref +!CHECK: %[[Y_1:[0-9]+]] = arith.addf %cst, %[[LOAD_Y]] fastmath : f32 +!CHECK: %cst_0 = arith.constant 2.000000e+00 : f32 +!CHECK: %[[Y_1_2:[0-9]+]] = arith.addf %[[Y_1]], %cst_0 fastmath : f32 +!CHECK: omp.atomic.update memory_order(relaxed) %[[X]]#0 : !fir.ref { +!CHECK: ^bb0(%[[ARG:arg[0-9]+]]: f32): +!CHECK: %[[ARG_P:[0-9]+]] = arith.addf %[[ARG]], %[[Y_1_2]] fastmath : f32 +!CHECK: omp.yield(%[[ARG_P]] : f32) +!CHECK: } + + +subroutine f01(x, y, z) + implicit none + complex :: x, y, z + + !$omp atomic update + x = (x + y) + z +end + +!CHECK-LABEL: func.func @_QPf01 +!CHECK: %[[X:[0-9]+]]:2 = hlfir.declare %arg0 +!CHECK: %[[Y:[0-9]+]]:2 = hlfir.declare %arg1 +!CHECK: %[[Z:[0-9]+]]:2 = hlfir.declare %arg2 +!CHECK: %[[LOAD_Y:[0-9]+]] = fir.load %[[Y]]#0 : !fir.ref> +!CHECK: %[[LOAD_Z:[0-9]+]] = fir.load %[[Z]]#0 : !fir.ref> +!CHECK: %[[Y_Z:[0-9]+]] = fir.addc %[[LOAD_Y]], %[[LOAD_Z]] {fastmath = #arith.fastmath} : complex +!CHECK: omp.atomic.update memory_order(relaxed) %[[X]]#0 : !fir.ref> { +!CHECK: ^bb0(%[[ARG:arg[0-9]+]]: complex): +!CHECK: %[[ARG_P:[0-9]+]] = fir.addc %[[ARG]], %[[Y_Z]] {fastmath = #arith.fastmath} : complex +!CHECK: omp.yield(%[[ARG_P]] : complex) +!CHECK: } + + +subroutine f02(x, y) + implicit none + complex :: x + real :: y + + !$omp atomic update + x = (real(x) + y) + 1 +end + +!CHECK-LABEL: func.func @_QPf02 +!CHECK: %[[X:[0-9]+]]:2 = hlfir.declare %arg0 +!CHECK: %[[Y:[0-9]+]]:2 = hlfir.declare %arg1 +!CHECK: %[[LOAD_Y:[0-9]+]] = fir.load %[[Y]]#0 : !fir.ref +!CHECK: %cst = arith.constant 1.000000e+00 : f32 +!CHECK: %[[Y_1:[0-9]+]] = arith.addf %[[LOAD_Y]], %cst fastmath : f32 +!CHECK: omp.atomic.update memory_order(relaxed) %[[X]]#0 : !fir.ref> { +!CHECK: ^bb0(%[[ARG:arg[0-9]+]]: complex): +!CHECK: %[[ARG_X:[0-9]+]] = fir.extract_value %[[ARG]], [0 : index] : (complex) -> f32 +!CHECK: %[[ARG_P:[0-9]+]] = arith.addf %[[ARG_X]], %[[Y_1]] fastmath : f32 +!CHECK: %cst_0 = arith.constant 0.000000e+00 : f32 +!CHECK: %[[CPLX:[0-9]+]] = fir.undefined complex +!CHECK: %[[CPLX_I:[0-9]+]] = fir.insert_value %[[CPLX]], %[[ARG_P]], [0 : index] : (complex, f32) -> complex +!CHECK: %[[CPLX_R:[0-9]+]] = fir.insert_value %[[CPLX_I]], %cst_0, [1 : index] : (complex, f32) -> complex +!CHECK: omp.yield(%[[CPLX_R]] : complex) +!CHECK: } + + +subroutine f03(x, a, b, c) + implicit none + real(kind=4) :: x + real(kind=8) :: a, b, c + + !$omp atomic update + x = ((b + a) + x) + c +end + +!CHECK-LABEL: func.func @_QPf03 +!CHECK: %[[A:[0-9]+]]:2 = hlfir.declare %arg1 +!CHECK: %[[B:[0-9]+]]:2 = hlfir.declare %arg2 +!CHECK: %[[C:[0-9]+]]:2 = hlfir.declare %arg3 +!CHECK: %[[X:[0-9]+]]:2 = hlfir.declare %arg0 +!CHECK: %[[LOAD_B:[0-9]+]] = fir.load %[[B]]#0 : !fir.ref +!CHECK: %[[LOAD_A:[0-9]+]] = fir.load %[[A]]#0 : !fir.ref +!CHECK: %[[A_B:[0-9]+]] = arith.addf %[[LOAD_B]], %[[LOAD_A]] fastmath : f64 +!CHECK: %[[LOAD_C:[0-9]+]] = fir.load %[[C]]#0 : !fir.ref +!CHECK: %[[A_B_C:[0-9]+]] = arith.addf %[[A_B]], %[[LOAD_C]] fastmath : f64 +!CHECK: omp.atomic.update memory_order(relaxed) %[[X]]#0 : !fir.ref { +!CHECK: ^bb0(%[[ARG:arg[0-9]+]]: f32): +!CHECK: %[[ARG_8:[0-9]+]] = fir.convert %[[ARG]] : (f32) -> f64 +!CHECK: %[[ARG_P:[0-9]+]] = arith.addf %[[ARG_8]], %[[A_B_C]] fastmath : f64 +!CHECK: %[[ARG_4:[0-9]+]] = fir.convert %[[ARG_P]] : (f64) -> f32 +!CHECK: omp.yield(%[[ARG_4]] : f32) +!CHECK: }