diff --git a/flang/include/flang/Evaluate/match.h b/flang/include/flang/Evaluate/match.h index 01932226fa500..32a4a7409fba7 100644 --- a/flang/include/flang/Evaluate/match.h +++ b/flang/include/flang/Evaluate/match.h @@ -11,6 +11,7 @@ #include "flang/Common/Fortran-consts.h" #include "flang/Common/visit.h" #include "flang/Evaluate/expression.h" +#include "flang/Support/Fortran.h" #include "llvm/ADT/STLExtras.h" #include @@ -86,9 +87,12 @@ template struct TypePattern { mutable const MatchType *ref{nullptr}; }; -/// Matches one of the patterns provided as template arguments. All of these -/// patterns should have the same number of operands, i.e. they all should -/// try to match input expression with the same number of children, i.e. +/// Matches one of the patterns provided as template arguments. +/// Upon creation of an AnyOfPattern object with some arguments, say args, +/// each of the pattern objects will be created using args as arguments to +/// the constructor. This means that each of the patterns should be +/// constructible from args, in particular all patterns should take the same +/// number of inputs. So, for example, /// AnyOfPattern is ok, whereas /// AnyOfPattern is not. template struct AnyOfPattern { @@ -178,9 +182,51 @@ struct OperationPattern : public TypePattern { }; template -OperationPattern(const Ops &...ops, llvm::type_identity) +OperationPattern(const Ops &..., llvm::type_identity) -> OperationPattern; +// Encode the actual operator in the type, so that the class is constructible +// only from operand patterns. This will make it usable in AnyOfPattern. +template +struct LogicalOperationPattern + : public OperationPattern, Ops...> { + using Base = OperationPattern, Ops...>; + static constexpr common::LogicalOperator opCode{Operator}; + +private: + template bool matchOp(const LogicalOperation &op) const { + if constexpr (ValType::kind == K) { + return op.logicalOperator == opCode; + } + return false; + } + template bool matchOp(const U &) const { return false; } + +public: + LogicalOperationPattern(const Ops &...ops, llvm::type_identity = {}) + : Base(ops...) {} + + template bool match(const evaluate::Expr &input) const { + // All logical operations (for a given type T) have the same operation + // type (LogicalOperation), so the type-based matching will not + // be able to tell specific operations from one another. + // Check the operation code first, if that matches then use the the + // base class's match. + if (common::visit([&](auto &&s) { return matchOp(s); }, deparen(input).u)) { + return Base::match(input); + } else { + return false; + } + } + + template bool match(const U &input) const { // + return false; + } +}; + +// No deduction guide for LogicalOperationPattern, since the "Operator" +// parameter cannot be deduced from the constructor arguments. + // Namespace-level definitions template using Expr = ExprPattern; @@ -188,6 +234,15 @@ template using Expr = ExprPattern; template using Op = OperationPattern; +template +using LogicalOp = LogicalOperationPattern; + +template +LogicalOp logical(const Op0 &op0, const Op1 &op1) { + return LogicalOp(op0, op1); +} + template bool match(const Pattern &pattern, const Input &input) { return pattern.match(input); diff --git a/flang/lib/Semantics/check-omp-atomic.cpp b/flang/lib/Semantics/check-omp-atomic.cpp index f25497ece61c4..ab8aa5f342e48 100644 --- a/flang/lib/Semantics/check-omp-atomic.cpp +++ b/flang/lib/Semantics/check-omp-atomic.cpp @@ -61,8 +61,7 @@ template struct IsIntegral> { static constexpr bool value{// C == common::TypeCategory::Integer || - C == common::TypeCategory::Unsigned || - C == common::TypeCategory::Logical}; + C == common::TypeCategory::Unsigned}; }; template constexpr bool is_integral_v{IsIntegral::value}; @@ -83,10 +82,25 @@ constexpr bool is_floating_point_v{IsFloatingPoint::value}; template constexpr bool is_numeric_v{is_integral_v || is_floating_point_v}; +template struct IsLogical { + static constexpr bool value{false}; +}; + +template +struct IsLogical> { + static constexpr bool value{C == common::TypeCategory::Logical}; +}; + +template constexpr bool is_logical_v{IsLogical::value}; + template using ReassocOpBase = evaluate::match::AnyOfPattern< // evaluate::match::Add, // - evaluate::match::Mul>; + evaluate::match::Mul, // + evaluate::match::LogicalOp, + evaluate::match::LogicalOp, + evaluate::match::LogicalOp, + evaluate::match::LogicalOp>; template struct ReassocOp : public ReassocOpBase { @@ -110,8 +124,8 @@ struct ReassocRewriter : public evaluate::rewrite::Identity { // Try to find cases where the input expression is of the form // (1) (a . b) . c, or // (2) a . (b . c), - // where . denotes an associative operation (currently + or *), and a, b, c - // are some subexpresions. + // where . denotes an associative operation, and a, b, c are some + // subexpresions. // If one of the operands in the nested operation is the atomic variable // (with some possible type conversions applied to it), bring it to the // top-level operation, and move the top-level operand into the nested @@ -119,7 +133,7 @@ struct ReassocRewriter : public evaluate::rewrite::Identity { // For example, assuming x is the atomic variable: // (a + x) + b -> (a + b) + x, i.e. (conceptually) swap x and b. template >> + typename = std::enable_if_t || is_logical_v>> evaluate::Expr operator()(evaluate::Expr &&x, const U &u) { if constexpr (is_floating_point_v) { if (!context_.langOptions().AssociativeMath) { @@ -133,8 +147,8 @@ struct ReassocRewriter : public evaluate::rewrite::Identity { // some order) from the example above. evaluate::match::Expr sub[3]; auto inner{reassocOp(sub[0], sub[1])}; - auto outer1{reassocOp(inner, sub[2])}; // inner + something - auto outer2{reassocOp(sub[2], inner)}; // something + inner + auto outer1{reassocOp(inner, sub[2])}; // inner . something + auto outer2{reassocOp(sub[2], inner)}; // something . inner #if !defined(__clang__) && !defined(_MSC_VER) && \ (__GNUC__ < 8 || (__GNUC__ == 8 && __GNUC_MINOR__ < 5)) // If GCC version < 8.5, use this definition. For the other definition @@ -167,23 +181,9 @@ struct ReassocRewriter : public evaluate::rewrite::Identity { } return common::visit( [&](auto &&s) { - using Expr = evaluate::Expr; - using TypeS = llvm::remove_cvref_t; - // This visitor has to be semantically correct for all possible - // types of s even though at runtime s will only be one of the - // matched types. - // Limit the construction to the operation types that we tried - // to match (otherwise TypeS(op1, op2) would fail for non-binary - // operations). - if constexpr (common::HasMember) { - Expr atom{*sub[atomIdx].ref}; - Expr op1{*sub[(atomIdx + 1) % 3].ref}; - Expr op2{*sub[(atomIdx + 2) % 3].ref}; - return Expr( - TypeS(atom, Expr(TypeS(std::move(op1), std::move(op2))))); - } else { - return Expr(TypeS(s)); - } + // Build the new expression from the matched components. + return Reconstruct(s, *sub[atomIdx].ref, + *sub[(atomIdx + 1) % 3].ref, *sub[(atomIdx + 2) % 3].ref); }, evaluate::match::deparen(x).u); } @@ -191,13 +191,43 @@ struct ReassocRewriter : public evaluate::rewrite::Identity { } template >> + typename = std::enable_if_t && !is_logical_v>> evaluate::Expr operator()( evaluate::Expr &&x, const U &u, NonIntegralTag = {}) { return Id::operator()(std::move(x), u); } private: + template + evaluate::Expr Reconstruct(const S &op, evaluate::Expr atom, + evaluate::Expr op1, evaluate::Expr op2) { + using TypeS = llvm::remove_cvref_t; + // This function has to be semantically correct for all possible types + // of S even though at runtime s will only be one of the matched types. + // Limit the construction to the operation types that we tried to match + // (otherwise TypeS(op1, op2) would fail for non-binary operations). + if constexpr (!common::HasMember) { + return evaluate::Expr(TypeS(op)); + } else if constexpr (is_logical_v) { + constexpr int K{T::kind}; + if constexpr (std::is_same_v>) { + // Logical operators take an extra argument in their constructor, + // so they need their own reconstruction code. + common::LogicalOperator opCode{op.logicalOperator}; + return evaluate::Expr(TypeS( // + opCode, std::move(atom), + evaluate::Expr(TypeS( // + opCode, std::move(op1), std::move(op2))))); + } + } else { + // Generic reconstruction. + return evaluate::Expr(TypeS( // + std::move(atom), + evaluate::Expr(TypeS( // + std::move(op1), std::move(op2))))); + } + } + template bool IsAtom(const evaluate::Expr &x) const { return IsSameOrConvertOf(evaluate::AsGenericExpr(AsRvalue(x)), atom_); } diff --git a/flang/test/Lower/OpenMP/atomic-update-reassoc-logical.f90 b/flang/test/Lower/OpenMP/atomic-update-reassoc-logical.f90 new file mode 100644 index 0000000000000..ccde4fed12f2f --- /dev/null +++ b/flang/test/Lower/OpenMP/atomic-update-reassoc-logical.f90 @@ -0,0 +1,137 @@ +!RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=60 %s -o - | FileCheck %s + +subroutine f00(x, y, z) + implicit none + logical :: x, y, z + + !$omp atomic update + x = x .and. y .and. z +end + +!CHECK-LABEL: func.func @_QPf00 +!CHECK: %[[X:[0-9]+]]:2 = hlfir.declare %arg0 +!CHECK: %[[Y:[0-9]+]]:2 = hlfir.declare %arg1 +!CHECK: %[[Z:[0-9]+]]:2 = hlfir.declare %arg2 +!CHECK: %[[LOAD_Y:[0-9]+]] = fir.load %[[Y]]#0 : !fir.ref> +!CHECK: %[[LOAD_Z:[0-9]+]] = fir.load %[[Z]]#0 : !fir.ref> +!CHECK: %[[CVT_Y:[0-9]+]] = fir.convert %[[LOAD_Y]] : (!fir.logical<4>) -> i1 +!CHECK: %[[CVT_Z:[0-9]+]] = fir.convert %[[LOAD_Z]] : (!fir.logical<4>) -> i1 +!CHECK: %[[AND_YZ:[0-9]+]] = arith.andi %[[CVT_Y]], %[[CVT_Z]] : i1 +!CHECK: omp.atomic.update memory_order(relaxed) %[[X]]#0 : !fir.ref> { +!CHECK: ^bb0(%[[ARG:arg[0-9]+]]: !fir.logical<4>): +!CHECK: %[[CVT_X:[0-9]+]] = fir.convert %[[ARG]] : (!fir.logical<4>) -> i1 +!CHECK: %[[AND_XYZ:[0-9]+]] = arith.andi %[[CVT_X]], %[[AND_YZ]] : i1 +!CHECK: %[[RET:[0-9]+]] = fir.convert %[[AND_XYZ]] : (i1) -> !fir.logical<4> +!CHECK: omp.yield(%[[RET]] : !fir.logical<4>) +!CHECK: } + + +subroutine f01(x, y, z) + implicit none + logical :: x, y, z + + !$omp atomic update + x = x .or. y .or. z +end + +!CHECK-LABEL: func.func @_QPf01 +!CHECK: %[[X:[0-9]+]]:2 = hlfir.declare %arg0 +!CHECK: %[[Y:[0-9]+]]:2 = hlfir.declare %arg1 +!CHECK: %[[Z:[0-9]+]]:2 = hlfir.declare %arg2 +!CHECK: %[[LOAD_Y:[0-9]+]] = fir.load %[[Y]]#0 : !fir.ref> +!CHECK: %[[LOAD_Z:[0-9]+]] = fir.load %[[Z]]#0 : !fir.ref> +!CHECK: %[[CVT_Y:[0-9]+]] = fir.convert %[[LOAD_Y]] : (!fir.logical<4>) -> i1 +!CHECK: %[[CVT_Z:[0-9]+]] = fir.convert %[[LOAD_Z]] : (!fir.logical<4>) -> i1 +!CHECK: %[[OR_YZ:[0-9]+]] = arith.ori %[[CVT_Y]], %[[CVT_Z]] : i1 +!CHECK: omp.atomic.update memory_order(relaxed) %[[X]]#0 : !fir.ref> { +!CHECK: ^bb0(%[[ARG:arg[0-9]+]]: !fir.logical<4>): +!CHECK: %[[CVT_X:[0-9]+]] = fir.convert %[[ARG]] : (!fir.logical<4>) -> i1 +!CHECK: %[[OR_XYZ:[0-9]+]] = arith.ori %[[CVT_X]], %[[OR_YZ]] : i1 +!CHECK: %[[RET:[0-9]+]] = fir.convert %[[OR_XYZ]] : (i1) -> !fir.logical<4> +!CHECK: omp.yield(%[[RET]] : !fir.logical<4>) +!CHECK: } + + +subroutine f02(x, y, z) + implicit none + logical :: x, y, z + + !$omp atomic update + x = x .eqv. y .eqv. z +end + +!CHECK-LABEL: func.func @_QPf02 +!CHECK: %[[X:[0-9]+]]:2 = hlfir.declare %arg0 +!CHECK: %[[Y:[0-9]+]]:2 = hlfir.declare %arg1 +!CHECK: %[[Z:[0-9]+]]:2 = hlfir.declare %arg2 +!CHECK: %[[LOAD_Y:[0-9]+]] = fir.load %[[Y]]#0 : !fir.ref> +!CHECK: %[[LOAD_Z:[0-9]+]] = fir.load %[[Z]]#0 : !fir.ref> +!CHECK: %[[CVT_Y:[0-9]+]] = fir.convert %[[LOAD_Y]] : (!fir.logical<4>) -> i1 +!CHECK: %[[CVT_Z:[0-9]+]] = fir.convert %[[LOAD_Z]] : (!fir.logical<4>) -> i1 +!CHECK: %[[EQV_YZ:[0-9]+]] = arith.cmpi eq, %[[CVT_Y]], %[[CVT_Z]] : i1 +!CHECK: omp.atomic.update memory_order(relaxed) %[[X]]#0 : !fir.ref> { +!CHECK: ^bb0(%[[ARG:arg[0-9]+]]: !fir.logical<4>): +!CHECK: %[[CVT_X:[0-9]+]] = fir.convert %[[ARG]] : (!fir.logical<4>) -> i1 +!CHECK: %[[EQV_XYZ:[0-9]+]] = arith.cmpi eq, %[[CVT_X]], %[[EQV_YZ]] : i1 +!CHECK: %[[RET:[0-9]+]] = fir.convert %[[EQV_XYZ]] : (i1) -> !fir.logical<4> +!CHECK: omp.yield(%[[RET]] : !fir.logical<4>) +!CHECK: } + + +subroutine f03(x, y, z) + implicit none + logical :: x, y, z + + !$omp atomic update + x = x .neqv. y .neqv. z +end + +!CHECK-LABEL: func.func @_QPf03 +!CHECK: %[[X:[0-9]+]]:2 = hlfir.declare %arg0 +!CHECK: %[[Y:[0-9]+]]:2 = hlfir.declare %arg1 +!CHECK: %[[Z:[0-9]+]]:2 = hlfir.declare %arg2 +!CHECK: %[[LOAD_Y:[0-9]+]] = fir.load %[[Y]]#0 : !fir.ref> +!CHECK: %[[LOAD_Z:[0-9]+]] = fir.load %[[Z]]#0 : !fir.ref> +!CHECK: %[[CVT_Y:[0-9]+]] = fir.convert %[[LOAD_Y]] : (!fir.logical<4>) -> i1 +!CHECK: %[[CVT_Z:[0-9]+]] = fir.convert %[[LOAD_Z]] : (!fir.logical<4>) -> i1 +!CHECK: %[[NEQV_YZ:[0-9]+]] = arith.cmpi ne, %[[CVT_Y]], %[[CVT_Z]] : i1 +!CHECK: omp.atomic.update memory_order(relaxed) %[[X]]#0 : !fir.ref> { +!CHECK: ^bb0(%[[ARG:arg[0-9]+]]: !fir.logical<4>): +!CHECK: %[[CVT_X:[0-9]+]] = fir.convert %[[ARG]] : (!fir.logical<4>) -> i1 +!CHECK: %[[NEQV_XYZ:[0-9]+]] = arith.cmpi ne, %[[CVT_X]], %[[NEQV_YZ]] : i1 +!CHECK: %[[RET:[0-9]+]] = fir.convert %[[NEQV_XYZ]] : (i1) -> !fir.logical<4> +!CHECK: omp.yield(%[[RET]] : !fir.logical<4>) +!CHECK: } + + +subroutine f04(x, a, b, c) + implicit none + logical(kind=4) :: x + logical(kind=8) :: a, b, c + + !$omp atomic update + x = ((b .and. a) .and. x) .and. c +end + +!CHECK-LABEL: func.func @_QPf04 +!CHECK: %[[A:[0-9]+]]:2 = hlfir.declare %arg1 +!CHECK: %[[B:[0-9]+]]:2 = hlfir.declare %arg2 +!CHECK: %[[C:[0-9]+]]:2 = hlfir.declare %arg3 +!CHECK: %[[X:[0-9]+]]:2 = hlfir.declare %arg0 +!CHECK: %[[LOAD_B:[0-9]+]] = fir.load %[[B]]#0 : !fir.ref> +!CHECK: %[[LOAD_A:[0-9]+]] = fir.load %[[A]]#0 : !fir.ref> +!CHECK: %[[CVT_B:[0-9]+]] = fir.convert %[[LOAD_B]] : (!fir.logical<8>) -> i1 +!CHECK: %[[CVT_A:[0-9]+]] = fir.convert %[[LOAD_A]] : (!fir.logical<8>) -> i1 +!CHECK: %[[AND_BA:[0-9]+]] = arith.andi %[[CVT_B]], %[[CVT_A]] : i1 +!CHECK: %[[LOAD_C:[0-9]+]] = fir.load %[[C]]#0 : !fir.ref> +!CHECK: %[[CVT_C:[0-9]+]] = fir.convert %[[LOAD_C]] : (!fir.logical<8>) -> i1 +!CHECK: %[[AND_BAC:[0-9]+]] = arith.andi %[[AND_BA]], %[[CVT_C]] : i1 +!CHECK: omp.atomic.update memory_order(relaxed) %[[X]]#0 : !fir.ref> { +!CHECK: ^bb0(%[[ARG:arg[0-9]+]]: !fir.logical<4>): +!CHECK: %[[CVT8_X:[0-9]+]] = fir.convert %[[ARG]] : (!fir.logical<4>) -> !fir.logical<8> +!CHECK: %[[CVT_X:[0-9]+]] = fir.convert %[[CVT8_X]] : (!fir.logical<8>) -> i1 +!CHECK: %[[AND_XBAC:[0-9]+]] = arith.andi %[[CVT_X]], %[[AND_BAC]] : i1 + +!CHECK: %[[RET:[0-9]+]] = fir.convert %[[AND_XBAC]] : (i1) -> !fir.logical<4> +!CHECK: omp.yield(%[[RET]] : !fir.logical<4>) +!CHECK: }