Skip to content

Conversation

kparzysz
Copy link
Contributor

@kparzysz kparzysz commented Sep 4, 2025

This is a follow-up to PR153488 and PR155840, this time for expressions of logical type. The handling of logical operations in Expr differs slightly from regular arithmetic operations. The difference is that the specific operation (e.g. and, or, etc.) is not a part of the type, but stored as a data member.
Both the matching code and the reconstruction code needed to be extended to correctly handle the data member.

This fixes #144944

This is a follow-up to PR153488 and PR155840, this time for expressions
of logical type. The handling of logical operations in Expr<T> differs
slightly from regular arithmetic operations. The difference is that the
specific operation (e.g. and, or, etc.) is not a part of the type, but
stored as a data member.
Both the matching code and the reconstruction code needed to be extended
to correctly handle the data member.

This fixes llvm#144944
@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir flang:openmp flang:semantics labels Sep 4, 2025
@llvmbot
Copy link
Member

llvmbot commented Sep 4, 2025

@llvm/pr-subscribers-flang-semantics

@llvm/pr-subscribers-flang-fir-hlfir

Author: Krzysztof Parzyszek (kparzysz)

Changes

This is a follow-up to PR153488 and PR155840, this time for expressions of logical type. The handling of logical operations in Expr<T> differs slightly from regular arithmetic operations. The difference is that the specific operation (e.g. and, or, etc.) is not a part of the type, but stored as a data member.
Both the matching code and the reconstruction code needed to be extended to correctly handle the data member.

This fixes #144944


Full diff: https://github.com/llvm/llvm-project/pull/156961.diff

3 Files Affected:

  • (modified) flang/include/flang/Evaluate/match.h (+59-4)
  • (modified) flang/lib/Semantics/check-omp-atomic.cpp (+56-26)
  • (added) flang/test/Lower/OpenMP/atomic-update-reassoc-logical.f90 (+137)
diff --git a/flang/include/flang/Evaluate/match.h b/flang/include/flang/Evaluate/match.h
index 01932226fa500..32a4a7409fba7 100644
--- a/flang/include/flang/Evaluate/match.h
+++ b/flang/include/flang/Evaluate/match.h
@@ -11,6 +11,7 @@
 #include "flang/Common/Fortran-consts.h"
 #include "flang/Common/visit.h"
 #include "flang/Evaluate/expression.h"
+#include "flang/Support/Fortran.h"
 #include "llvm/ADT/STLExtras.h"
 
 #include <tuple>
@@ -86,9 +87,12 @@ template <typename T> struct TypePattern {
   mutable const MatchType *ref{nullptr};
 };
 
-/// Matches one of the patterns provided as template arguments. All of these
-/// patterns should have the same number of operands, i.e. they all should
-/// try to match input expression with the same number of children, i.e.
+/// Matches one of the patterns provided as template arguments.
+/// Upon creation of an AnyOfPattern object with some arguments, say args,
+/// each of the pattern objects will be created using args as arguments to
+/// the constructor. This means that each of the patterns should be
+/// constructible from args, in particular all patterns should take the same
+/// number of inputs. So, for example,
 /// AnyOfPattern<SomeBinaryOp, OtherBinaryOp> is ok, whereas
 /// AnyOfPattern<SomeBinaryOp, SomeTernaryOp> is not.
 template <typename... Patterns> struct AnyOfPattern {
@@ -178,9 +182,51 @@ struct OperationPattern : public TypePattern<OpType> {
 };
 
 template <typename OpType, typename... Ops>
-OperationPattern(const Ops &...ops, llvm::type_identity<OpType>)
+OperationPattern(const Ops &..., llvm::type_identity<OpType>)
     -> OperationPattern<OpType, Ops...>;
 
+// Encode the actual operator in the type, so that the class is constructible
+// only from operand patterns. This will make it usable in AnyOfPattern.
+template <common::LogicalOperator Operator, typename ValType, typename... Ops>
+struct LogicalOperationPattern
+    : public OperationPattern<LogicalOperation<ValType::kind>, Ops...> {
+  using Base = OperationPattern<LogicalOperation<ValType::kind>, Ops...>;
+  static constexpr common::LogicalOperator opCode{Operator};
+
+private:
+  template <int K> bool matchOp(const LogicalOperation<K> &op) const {
+    if constexpr (ValType::kind == K) {
+      return op.logicalOperator == opCode;
+    }
+    return false;
+  }
+  template <typename U> bool matchOp(const U &) const { return false; }
+
+public:
+  LogicalOperationPattern(const Ops &...ops, llvm::type_identity<ValType> = {})
+      : Base(ops...) {}
+
+  template <typename T> bool match(const evaluate::Expr<T> &input) const {
+    // All logical operations (for a given type T) have the same operation
+    // type (LogicalOperation<T::kind>), so the type-based matching will not
+    // be able to tell specific operations from one another.
+    // Check the operation code first, if that matches then use the the
+    // base class's match.
+    if (common::visit([&](auto &&s) { return matchOp(s); }, deparen(input).u)) {
+      return Base::match(input);
+    } else {
+      return false;
+    }
+  }
+
+  template <typename U> bool match(const U &input) const { //
+    return false;
+  }
+};
+
+// No deduction guide for LogicalOperationPattern, since the "Operator"
+// parameter cannot be deduced from the constructor arguments.
+
 // Namespace-level definitions
 
 template <typename T> using Expr = ExprPattern<T>;
@@ -188,6 +234,15 @@ template <typename T> using Expr = ExprPattern<T>;
 template <typename OpType, typename... Ops>
 using Op = OperationPattern<OpType, Ops...>;
 
+template <common::LogicalOperator Operator, typename ValType, typename... Ops>
+using LogicalOp = LogicalOperationPattern<Operator, ValType, Ops...>;
+
+template <common::LogicalOperator Operator, typename Type, typename Op0,
+    typename Op1>
+LogicalOp<Operator, Type, Op0, Op1> logical(const Op0 &op0, const Op1 &op1) {
+  return LogicalOp<Operator, Type, Op0, Op1>(op0, op1);
+}
+
 template <typename Pattern, typename Input>
 bool match(const Pattern &pattern, const Input &input) {
   return pattern.match(input);
diff --git a/flang/lib/Semantics/check-omp-atomic.cpp b/flang/lib/Semantics/check-omp-atomic.cpp
index f25497ece61c4..ab8aa5f342e48 100644
--- a/flang/lib/Semantics/check-omp-atomic.cpp
+++ b/flang/lib/Semantics/check-omp-atomic.cpp
@@ -61,8 +61,7 @@ template <common::TypeCategory C, int K>
 struct IsIntegral<evaluate::Type<C, K>> {
   static constexpr bool value{//
       C == common::TypeCategory::Integer ||
-      C == common::TypeCategory::Unsigned ||
-      C == common::TypeCategory::Logical};
+      C == common::TypeCategory::Unsigned};
 };
 
 template <typename T> constexpr bool is_integral_v{IsIntegral<T>::value};
@@ -83,10 +82,25 @@ constexpr bool is_floating_point_v{IsFloatingPoint<T>::value};
 template <typename T>
 constexpr bool is_numeric_v{is_integral_v<T> || is_floating_point_v<T>};
 
+template <typename...> struct IsLogical {
+  static constexpr bool value{false};
+};
+
+template <common::TypeCategory C, int K>
+struct IsLogical<evaluate::Type<C, K>> {
+  static constexpr bool value{C == common::TypeCategory::Logical};
+};
+
+template <typename T> constexpr bool is_logical_v{IsLogical<T>::value};
+
 template <typename T, typename Op0, typename Op1>
 using ReassocOpBase = evaluate::match::AnyOfPattern< //
     evaluate::match::Add<T, Op0, Op1>, //
-    evaluate::match::Mul<T, Op0, Op1>>;
+    evaluate::match::Mul<T, Op0, Op1>, //
+    evaluate::match::LogicalOp<common::LogicalOperator::And, T, Op0, Op1>,
+    evaluate::match::LogicalOp<common::LogicalOperator::Or, T, Op0, Op1>,
+    evaluate::match::LogicalOp<common::LogicalOperator::Eqv, T, Op0, Op1>,
+    evaluate::match::LogicalOp<common::LogicalOperator::Neqv, T, Op0, Op1>>;
 
 template <typename T, typename Op0, typename Op1>
 struct ReassocOp : public ReassocOpBase<T, Op0, Op1> {
@@ -110,8 +124,8 @@ struct ReassocRewriter : public evaluate::rewrite::Identity {
   // Try to find cases where the input expression is of the form
   // (1) (a . b) . c, or
   // (2) a . (b . c),
-  // where . denotes an associative operation (currently + or *), and a, b, c
-  // are some subexpresions.
+  // where . denotes an associative operation, and a, b, c are some
+  // subexpresions.
   // If one of the operands in the nested operation is the atomic variable
   // (with some possible type conversions applied to it), bring it to the
   // top-level operation, and move the top-level operand into the nested
@@ -119,7 +133,7 @@ struct ReassocRewriter : public evaluate::rewrite::Identity {
   // For example, assuming x is the atomic variable:
   //   (a + x) + b  ->  (a + b) + x,  i.e. (conceptually) swap x and b.
   template <typename T, typename U,
-      typename = std::enable_if_t<is_numeric_v<T>>>
+      typename = std::enable_if_t<is_numeric_v<T> || is_logical_v<T>>>
   evaluate::Expr<T> operator()(evaluate::Expr<T> &&x, const U &u) {
     if constexpr (is_floating_point_v<T>) {
       if (!context_.langOptions().AssociativeMath) {
@@ -133,8 +147,8 @@ struct ReassocRewriter : public evaluate::rewrite::Identity {
     // some order) from the example above.
     evaluate::match::Expr<T> sub[3];
     auto inner{reassocOp<T>(sub[0], sub[1])};
-    auto outer1{reassocOp<T>(inner, sub[2])}; // inner + something
-    auto outer2{reassocOp<T>(sub[2], inner)}; // something + inner
+    auto outer1{reassocOp<T>(inner, sub[2])}; // inner . something
+    auto outer2{reassocOp<T>(sub[2], inner)}; // something . inner
 #if !defined(__clang__) && !defined(_MSC_VER) && \
     (__GNUC__ < 8 || (__GNUC__ == 8 && __GNUC_MINOR__ < 5))
     // If GCC version < 8.5, use this definition. For the other definition
@@ -167,23 +181,9 @@ struct ReassocRewriter : public evaluate::rewrite::Identity {
       }
       return common::visit(
           [&](auto &&s) {
-            using Expr = evaluate::Expr<T>;
-            using TypeS = llvm::remove_cvref_t<decltype(s)>;
-            // This visitor has to be semantically correct for all possible
-            // types of s even though at runtime s will only be one of the
-            // matched types.
-            // Limit the construction to the operation types that we tried
-            // to match (otherwise TypeS(op1, op2) would fail for non-binary
-            // operations).
-            if constexpr (common::HasMember<TypeS, MatchTypes>) {
-              Expr atom{*sub[atomIdx].ref};
-              Expr op1{*sub[(atomIdx + 1) % 3].ref};
-              Expr op2{*sub[(atomIdx + 2) % 3].ref};
-              return Expr(
-                  TypeS(atom, Expr(TypeS(std::move(op1), std::move(op2)))));
-            } else {
-              return Expr(TypeS(s));
-            }
+            // Build the new expression from the matched components.
+            return Reconstruct<T, MatchTypes>(s, *sub[atomIdx].ref,
+                *sub[(atomIdx + 1) % 3].ref, *sub[(atomIdx + 2) % 3].ref);
           },
           evaluate::match::deparen(x).u);
     }
@@ -191,13 +191,43 @@ struct ReassocRewriter : public evaluate::rewrite::Identity {
   }
 
   template <typename T, typename U,
-      typename = std::enable_if_t<!is_numeric_v<T>>>
+      typename = std::enable_if_t<!is_numeric_v<T> && !is_logical_v<T>>>
   evaluate::Expr<T> operator()(
       evaluate::Expr<T> &&x, const U &u, NonIntegralTag = {}) {
     return Id::operator()(std::move(x), u);
   }
 
 private:
+  template <typename T, typename MatchTypes, typename S>
+  evaluate::Expr<T> Reconstruct(const S &op, evaluate::Expr<T> atom,
+      evaluate::Expr<T> op1, evaluate::Expr<T> op2) {
+    using TypeS = llvm::remove_cvref_t<decltype(op)>;
+    // This function has to be semantically correct for all possible types
+    // of S even though at runtime s will only be one of the matched types.
+    // Limit the construction to the operation types that we tried to match
+    // (otherwise TypeS(op1, op2) would fail for non-binary operations).
+    if constexpr (!common::HasMember<TypeS, MatchTypes>) {
+      return evaluate::Expr<T>(TypeS(op));
+    } else if constexpr (is_logical_v<T>) {
+      constexpr int K{T::kind};
+      if constexpr (std::is_same_v<TypeS, evaluate::LogicalOperation<K>>) {
+        // Logical operators take an extra argument in their constructor,
+        // so they need their own reconstruction code.
+        common::LogicalOperator opCode{op.logicalOperator};
+        return evaluate::Expr<T>(TypeS( //
+            opCode, std::move(atom),
+            evaluate::Expr<T>(TypeS( //
+                opCode, std::move(op1), std::move(op2)))));
+      }
+    } else {
+      // Generic reconstruction.
+      return evaluate::Expr<T>(TypeS( //
+          std::move(atom),
+          evaluate::Expr<T>(TypeS( //
+              std::move(op1), std::move(op2)))));
+    }
+  }
+
   template <typename T> bool IsAtom(const evaluate::Expr<T> &x) const {
     return IsSameOrConvertOf(evaluate::AsGenericExpr(AsRvalue(x)), atom_);
   }
diff --git a/flang/test/Lower/OpenMP/atomic-update-reassoc-logical.f90 b/flang/test/Lower/OpenMP/atomic-update-reassoc-logical.f90
new file mode 100644
index 0000000000000..ccde4fed12f2f
--- /dev/null
+++ b/flang/test/Lower/OpenMP/atomic-update-reassoc-logical.f90
@@ -0,0 +1,137 @@
+!RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=60 %s -o - | FileCheck %s
+
+subroutine f00(x, y, z)
+  implicit none
+  logical :: x, y, z
+
+  !$omp atomic update
+  x = x .and. y .and. z
+end
+
+!CHECK-LABEL: func.func @_QPf00
+!CHECK: %[[X:[0-9]+]]:2 = hlfir.declare %arg0
+!CHECK: %[[Y:[0-9]+]]:2 = hlfir.declare %arg1
+!CHECK: %[[Z:[0-9]+]]:2 = hlfir.declare %arg2
+!CHECK: %[[LOAD_Y:[0-9]+]] = fir.load %[[Y]]#0 : !fir.ref<!fir.logical<4>>
+!CHECK: %[[LOAD_Z:[0-9]+]] = fir.load %[[Z]]#0 : !fir.ref<!fir.logical<4>>
+!CHECK: %[[CVT_Y:[0-9]+]] = fir.convert %[[LOAD_Y]] : (!fir.logical<4>) -> i1
+!CHECK: %[[CVT_Z:[0-9]+]] = fir.convert %[[LOAD_Z]] : (!fir.logical<4>) -> i1
+!CHECK: %[[AND_YZ:[0-9]+]] = arith.andi %[[CVT_Y]], %[[CVT_Z]] : i1
+!CHECK: omp.atomic.update memory_order(relaxed) %[[X]]#0 : !fir.ref<!fir.logical<4>> {
+!CHECK: ^bb0(%[[ARG:arg[0-9]+]]: !fir.logical<4>):
+!CHECK:   %[[CVT_X:[0-9]+]] = fir.convert %[[ARG]] : (!fir.logical<4>) -> i1
+!CHECK:   %[[AND_XYZ:[0-9]+]] = arith.andi %[[CVT_X]], %[[AND_YZ]] : i1
+!CHECK:   %[[RET:[0-9]+]] = fir.convert %[[AND_XYZ]] : (i1) -> !fir.logical<4>
+!CHECK:   omp.yield(%[[RET]] : !fir.logical<4>)
+!CHECK: }
+
+
+subroutine f01(x, y, z)
+  implicit none
+  logical :: x, y, z
+
+  !$omp atomic update
+  x = x .or. y .or. z
+end
+
+!CHECK-LABEL: func.func @_QPf01
+!CHECK: %[[X:[0-9]+]]:2 = hlfir.declare %arg0
+!CHECK: %[[Y:[0-9]+]]:2 = hlfir.declare %arg1
+!CHECK: %[[Z:[0-9]+]]:2 = hlfir.declare %arg2
+!CHECK: %[[LOAD_Y:[0-9]+]] = fir.load %[[Y]]#0 : !fir.ref<!fir.logical<4>>
+!CHECK: %[[LOAD_Z:[0-9]+]] = fir.load %[[Z]]#0 : !fir.ref<!fir.logical<4>>
+!CHECK: %[[CVT_Y:[0-9]+]] = fir.convert %[[LOAD_Y]] : (!fir.logical<4>) -> i1
+!CHECK: %[[CVT_Z:[0-9]+]] = fir.convert %[[LOAD_Z]] : (!fir.logical<4>) -> i1
+!CHECK: %[[OR_YZ:[0-9]+]] = arith.ori %[[CVT_Y]], %[[CVT_Z]] : i1
+!CHECK: omp.atomic.update memory_order(relaxed) %[[X]]#0 : !fir.ref<!fir.logical<4>> {
+!CHECK: ^bb0(%[[ARG:arg[0-9]+]]: !fir.logical<4>):
+!CHECK:   %[[CVT_X:[0-9]+]] = fir.convert %[[ARG]] : (!fir.logical<4>) -> i1
+!CHECK:   %[[OR_XYZ:[0-9]+]] = arith.ori %[[CVT_X]], %[[OR_YZ]] : i1
+!CHECK:   %[[RET:[0-9]+]] = fir.convert %[[OR_XYZ]] : (i1) -> !fir.logical<4>
+!CHECK:   omp.yield(%[[RET]] : !fir.logical<4>)
+!CHECK: }
+
+
+subroutine f02(x, y, z)
+  implicit none
+  logical :: x, y, z
+
+  !$omp atomic update
+  x = x .eqv. y .eqv. z
+end
+
+!CHECK-LABEL: func.func @_QPf02
+!CHECK: %[[X:[0-9]+]]:2 = hlfir.declare %arg0
+!CHECK: %[[Y:[0-9]+]]:2 = hlfir.declare %arg1
+!CHECK: %[[Z:[0-9]+]]:2 = hlfir.declare %arg2
+!CHECK: %[[LOAD_Y:[0-9]+]] = fir.load %[[Y]]#0 : !fir.ref<!fir.logical<4>>
+!CHECK: %[[LOAD_Z:[0-9]+]] = fir.load %[[Z]]#0 : !fir.ref<!fir.logical<4>>
+!CHECK: %[[CVT_Y:[0-9]+]] = fir.convert %[[LOAD_Y]] : (!fir.logical<4>) -> i1
+!CHECK: %[[CVT_Z:[0-9]+]] = fir.convert %[[LOAD_Z]] : (!fir.logical<4>) -> i1
+!CHECK: %[[EQV_YZ:[0-9]+]] = arith.cmpi eq, %[[CVT_Y]], %[[CVT_Z]] : i1
+!CHECK: omp.atomic.update memory_order(relaxed) %[[X]]#0 : !fir.ref<!fir.logical<4>> {
+!CHECK: ^bb0(%[[ARG:arg[0-9]+]]: !fir.logical<4>):
+!CHECK:   %[[CVT_X:[0-9]+]] = fir.convert %[[ARG]] : (!fir.logical<4>) -> i1
+!CHECK:   %[[EQV_XYZ:[0-9]+]] = arith.cmpi eq, %[[CVT_X]], %[[EQV_YZ]] : i1
+!CHECK:   %[[RET:[0-9]+]] = fir.convert %[[EQV_XYZ]] : (i1) -> !fir.logical<4>
+!CHECK:   omp.yield(%[[RET]] : !fir.logical<4>)
+!CHECK: }
+
+
+subroutine f03(x, y, z)
+  implicit none
+  logical :: x, y, z
+
+  !$omp atomic update
+  x = x .neqv. y .neqv. z
+end
+
+!CHECK-LABEL: func.func @_QPf03
+!CHECK: %[[X:[0-9]+]]:2 = hlfir.declare %arg0
+!CHECK: %[[Y:[0-9]+]]:2 = hlfir.declare %arg1
+!CHECK: %[[Z:[0-9]+]]:2 = hlfir.declare %arg2
+!CHECK: %[[LOAD_Y:[0-9]+]] = fir.load %[[Y]]#0 : !fir.ref<!fir.logical<4>>
+!CHECK: %[[LOAD_Z:[0-9]+]] = fir.load %[[Z]]#0 : !fir.ref<!fir.logical<4>>
+!CHECK: %[[CVT_Y:[0-9]+]] = fir.convert %[[LOAD_Y]] : (!fir.logical<4>) -> i1
+!CHECK: %[[CVT_Z:[0-9]+]] = fir.convert %[[LOAD_Z]] : (!fir.logical<4>) -> i1
+!CHECK: %[[NEQV_YZ:[0-9]+]] = arith.cmpi ne, %[[CVT_Y]], %[[CVT_Z]] : i1
+!CHECK: omp.atomic.update memory_order(relaxed) %[[X]]#0 : !fir.ref<!fir.logical<4>> {
+!CHECK: ^bb0(%[[ARG:arg[0-9]+]]: !fir.logical<4>):
+!CHECK:   %[[CVT_X:[0-9]+]] = fir.convert %[[ARG]] : (!fir.logical<4>) -> i1
+!CHECK:   %[[NEQV_XYZ:[0-9]+]] = arith.cmpi ne, %[[CVT_X]], %[[NEQV_YZ]] : i1
+!CHECK:   %[[RET:[0-9]+]] = fir.convert %[[NEQV_XYZ]] : (i1) -> !fir.logical<4>
+!CHECK:   omp.yield(%[[RET]] : !fir.logical<4>)
+!CHECK: }
+
+
+subroutine f04(x, a, b, c)
+  implicit none
+  logical(kind=4) :: x
+  logical(kind=8) :: a, b, c
+
+  !$omp atomic update
+  x = ((b .and. a) .and. x) .and. c
+end
+
+!CHECK-LABEL: func.func @_QPf04
+!CHECK: %[[A:[0-9]+]]:2 = hlfir.declare %arg1
+!CHECK: %[[B:[0-9]+]]:2 = hlfir.declare %arg2
+!CHECK: %[[C:[0-9]+]]:2 = hlfir.declare %arg3
+!CHECK: %[[X:[0-9]+]]:2 = hlfir.declare %arg0
+!CHECK: %[[LOAD_B:[0-9]+]] = fir.load %[[B]]#0 : !fir.ref<!fir.logical<8>>
+!CHECK: %[[LOAD_A:[0-9]+]] = fir.load %[[A]]#0 : !fir.ref<!fir.logical<8>>
+!CHECK: %[[CVT_B:[0-9]+]] = fir.convert %[[LOAD_B]] : (!fir.logical<8>) -> i1
+!CHECK: %[[CVT_A:[0-9]+]] = fir.convert %[[LOAD_A]] : (!fir.logical<8>) -> i1
+!CHECK: %[[AND_BA:[0-9]+]] = arith.andi %[[CVT_B]], %[[CVT_A]] : i1
+!CHECK: %[[LOAD_C:[0-9]+]] = fir.load %[[C]]#0 : !fir.ref<!fir.logical<8>>
+!CHECK: %[[CVT_C:[0-9]+]] = fir.convert %[[LOAD_C]] : (!fir.logical<8>) -> i1
+!CHECK: %[[AND_BAC:[0-9]+]] = arith.andi %[[AND_BA]], %[[CVT_C]] : i1
+!CHECK: omp.atomic.update memory_order(relaxed) %[[X]]#0 : !fir.ref<!fir.logical<4>> {
+!CHECK: ^bb0(%[[ARG:arg[0-9]+]]: !fir.logical<4>):
+!CHECK:   %[[CVT8_X:[0-9]+]] = fir.convert %[[ARG]] : (!fir.logical<4>) -> !fir.logical<8>
+!CHECK:   %[[CVT_X:[0-9]+]] = fir.convert %[[CVT8_X]] : (!fir.logical<8>) -> i1
+!CHECK:   %[[AND_XBAC:[0-9]+]] = arith.andi %[[CVT_X]], %[[AND_BAC]] : i1
+
+!CHECK:   %[[RET:[0-9]+]] = fir.convert %[[AND_XBAC]] : (i1) -> !fir.logical<4>
+!CHECK:   omp.yield(%[[RET]] : !fir.logical<4>)
+!CHECK: }

@llvmbot
Copy link
Member

llvmbot commented Sep 4, 2025

@llvm/pr-subscribers-flang-openmp

Author: Krzysztof Parzyszek (kparzysz)

Changes

This is a follow-up to PR153488 and PR155840, this time for expressions of logical type. The handling of logical operations in Expr<T> differs slightly from regular arithmetic operations. The difference is that the specific operation (e.g. and, or, etc.) is not a part of the type, but stored as a data member.
Both the matching code and the reconstruction code needed to be extended to correctly handle the data member.

This fixes #144944


Full diff: https://github.com/llvm/llvm-project/pull/156961.diff

3 Files Affected:

  • (modified) flang/include/flang/Evaluate/match.h (+59-4)
  • (modified) flang/lib/Semantics/check-omp-atomic.cpp (+56-26)
  • (added) flang/test/Lower/OpenMP/atomic-update-reassoc-logical.f90 (+137)
diff --git a/flang/include/flang/Evaluate/match.h b/flang/include/flang/Evaluate/match.h
index 01932226fa500..32a4a7409fba7 100644
--- a/flang/include/flang/Evaluate/match.h
+++ b/flang/include/flang/Evaluate/match.h
@@ -11,6 +11,7 @@
 #include "flang/Common/Fortran-consts.h"
 #include "flang/Common/visit.h"
 #include "flang/Evaluate/expression.h"
+#include "flang/Support/Fortran.h"
 #include "llvm/ADT/STLExtras.h"
 
 #include <tuple>
@@ -86,9 +87,12 @@ template <typename T> struct TypePattern {
   mutable const MatchType *ref{nullptr};
 };
 
-/// Matches one of the patterns provided as template arguments. All of these
-/// patterns should have the same number of operands, i.e. they all should
-/// try to match input expression with the same number of children, i.e.
+/// Matches one of the patterns provided as template arguments.
+/// Upon creation of an AnyOfPattern object with some arguments, say args,
+/// each of the pattern objects will be created using args as arguments to
+/// the constructor. This means that each of the patterns should be
+/// constructible from args, in particular all patterns should take the same
+/// number of inputs. So, for example,
 /// AnyOfPattern<SomeBinaryOp, OtherBinaryOp> is ok, whereas
 /// AnyOfPattern<SomeBinaryOp, SomeTernaryOp> is not.
 template <typename... Patterns> struct AnyOfPattern {
@@ -178,9 +182,51 @@ struct OperationPattern : public TypePattern<OpType> {
 };
 
 template <typename OpType, typename... Ops>
-OperationPattern(const Ops &...ops, llvm::type_identity<OpType>)
+OperationPattern(const Ops &..., llvm::type_identity<OpType>)
     -> OperationPattern<OpType, Ops...>;
 
+// Encode the actual operator in the type, so that the class is constructible
+// only from operand patterns. This will make it usable in AnyOfPattern.
+template <common::LogicalOperator Operator, typename ValType, typename... Ops>
+struct LogicalOperationPattern
+    : public OperationPattern<LogicalOperation<ValType::kind>, Ops...> {
+  using Base = OperationPattern<LogicalOperation<ValType::kind>, Ops...>;
+  static constexpr common::LogicalOperator opCode{Operator};
+
+private:
+  template <int K> bool matchOp(const LogicalOperation<K> &op) const {
+    if constexpr (ValType::kind == K) {
+      return op.logicalOperator == opCode;
+    }
+    return false;
+  }
+  template <typename U> bool matchOp(const U &) const { return false; }
+
+public:
+  LogicalOperationPattern(const Ops &...ops, llvm::type_identity<ValType> = {})
+      : Base(ops...) {}
+
+  template <typename T> bool match(const evaluate::Expr<T> &input) const {
+    // All logical operations (for a given type T) have the same operation
+    // type (LogicalOperation<T::kind>), so the type-based matching will not
+    // be able to tell specific operations from one another.
+    // Check the operation code first, if that matches then use the the
+    // base class's match.
+    if (common::visit([&](auto &&s) { return matchOp(s); }, deparen(input).u)) {
+      return Base::match(input);
+    } else {
+      return false;
+    }
+  }
+
+  template <typename U> bool match(const U &input) const { //
+    return false;
+  }
+};
+
+// No deduction guide for LogicalOperationPattern, since the "Operator"
+// parameter cannot be deduced from the constructor arguments.
+
 // Namespace-level definitions
 
 template <typename T> using Expr = ExprPattern<T>;
@@ -188,6 +234,15 @@ template <typename T> using Expr = ExprPattern<T>;
 template <typename OpType, typename... Ops>
 using Op = OperationPattern<OpType, Ops...>;
 
+template <common::LogicalOperator Operator, typename ValType, typename... Ops>
+using LogicalOp = LogicalOperationPattern<Operator, ValType, Ops...>;
+
+template <common::LogicalOperator Operator, typename Type, typename Op0,
+    typename Op1>
+LogicalOp<Operator, Type, Op0, Op1> logical(const Op0 &op0, const Op1 &op1) {
+  return LogicalOp<Operator, Type, Op0, Op1>(op0, op1);
+}
+
 template <typename Pattern, typename Input>
 bool match(const Pattern &pattern, const Input &input) {
   return pattern.match(input);
diff --git a/flang/lib/Semantics/check-omp-atomic.cpp b/flang/lib/Semantics/check-omp-atomic.cpp
index f25497ece61c4..ab8aa5f342e48 100644
--- a/flang/lib/Semantics/check-omp-atomic.cpp
+++ b/flang/lib/Semantics/check-omp-atomic.cpp
@@ -61,8 +61,7 @@ template <common::TypeCategory C, int K>
 struct IsIntegral<evaluate::Type<C, K>> {
   static constexpr bool value{//
       C == common::TypeCategory::Integer ||
-      C == common::TypeCategory::Unsigned ||
-      C == common::TypeCategory::Logical};
+      C == common::TypeCategory::Unsigned};
 };
 
 template <typename T> constexpr bool is_integral_v{IsIntegral<T>::value};
@@ -83,10 +82,25 @@ constexpr bool is_floating_point_v{IsFloatingPoint<T>::value};
 template <typename T>
 constexpr bool is_numeric_v{is_integral_v<T> || is_floating_point_v<T>};
 
+template <typename...> struct IsLogical {
+  static constexpr bool value{false};
+};
+
+template <common::TypeCategory C, int K>
+struct IsLogical<evaluate::Type<C, K>> {
+  static constexpr bool value{C == common::TypeCategory::Logical};
+};
+
+template <typename T> constexpr bool is_logical_v{IsLogical<T>::value};
+
 template <typename T, typename Op0, typename Op1>
 using ReassocOpBase = evaluate::match::AnyOfPattern< //
     evaluate::match::Add<T, Op0, Op1>, //
-    evaluate::match::Mul<T, Op0, Op1>>;
+    evaluate::match::Mul<T, Op0, Op1>, //
+    evaluate::match::LogicalOp<common::LogicalOperator::And, T, Op0, Op1>,
+    evaluate::match::LogicalOp<common::LogicalOperator::Or, T, Op0, Op1>,
+    evaluate::match::LogicalOp<common::LogicalOperator::Eqv, T, Op0, Op1>,
+    evaluate::match::LogicalOp<common::LogicalOperator::Neqv, T, Op0, Op1>>;
 
 template <typename T, typename Op0, typename Op1>
 struct ReassocOp : public ReassocOpBase<T, Op0, Op1> {
@@ -110,8 +124,8 @@ struct ReassocRewriter : public evaluate::rewrite::Identity {
   // Try to find cases where the input expression is of the form
   // (1) (a . b) . c, or
   // (2) a . (b . c),
-  // where . denotes an associative operation (currently + or *), and a, b, c
-  // are some subexpresions.
+  // where . denotes an associative operation, and a, b, c are some
+  // subexpresions.
   // If one of the operands in the nested operation is the atomic variable
   // (with some possible type conversions applied to it), bring it to the
   // top-level operation, and move the top-level operand into the nested
@@ -119,7 +133,7 @@ struct ReassocRewriter : public evaluate::rewrite::Identity {
   // For example, assuming x is the atomic variable:
   //   (a + x) + b  ->  (a + b) + x,  i.e. (conceptually) swap x and b.
   template <typename T, typename U,
-      typename = std::enable_if_t<is_numeric_v<T>>>
+      typename = std::enable_if_t<is_numeric_v<T> || is_logical_v<T>>>
   evaluate::Expr<T> operator()(evaluate::Expr<T> &&x, const U &u) {
     if constexpr (is_floating_point_v<T>) {
       if (!context_.langOptions().AssociativeMath) {
@@ -133,8 +147,8 @@ struct ReassocRewriter : public evaluate::rewrite::Identity {
     // some order) from the example above.
     evaluate::match::Expr<T> sub[3];
     auto inner{reassocOp<T>(sub[0], sub[1])};
-    auto outer1{reassocOp<T>(inner, sub[2])}; // inner + something
-    auto outer2{reassocOp<T>(sub[2], inner)}; // something + inner
+    auto outer1{reassocOp<T>(inner, sub[2])}; // inner . something
+    auto outer2{reassocOp<T>(sub[2], inner)}; // something . inner
 #if !defined(__clang__) && !defined(_MSC_VER) && \
     (__GNUC__ < 8 || (__GNUC__ == 8 && __GNUC_MINOR__ < 5))
     // If GCC version < 8.5, use this definition. For the other definition
@@ -167,23 +181,9 @@ struct ReassocRewriter : public evaluate::rewrite::Identity {
       }
       return common::visit(
           [&](auto &&s) {
-            using Expr = evaluate::Expr<T>;
-            using TypeS = llvm::remove_cvref_t<decltype(s)>;
-            // This visitor has to be semantically correct for all possible
-            // types of s even though at runtime s will only be one of the
-            // matched types.
-            // Limit the construction to the operation types that we tried
-            // to match (otherwise TypeS(op1, op2) would fail for non-binary
-            // operations).
-            if constexpr (common::HasMember<TypeS, MatchTypes>) {
-              Expr atom{*sub[atomIdx].ref};
-              Expr op1{*sub[(atomIdx + 1) % 3].ref};
-              Expr op2{*sub[(atomIdx + 2) % 3].ref};
-              return Expr(
-                  TypeS(atom, Expr(TypeS(std::move(op1), std::move(op2)))));
-            } else {
-              return Expr(TypeS(s));
-            }
+            // Build the new expression from the matched components.
+            return Reconstruct<T, MatchTypes>(s, *sub[atomIdx].ref,
+                *sub[(atomIdx + 1) % 3].ref, *sub[(atomIdx + 2) % 3].ref);
           },
           evaluate::match::deparen(x).u);
     }
@@ -191,13 +191,43 @@ struct ReassocRewriter : public evaluate::rewrite::Identity {
   }
 
   template <typename T, typename U,
-      typename = std::enable_if_t<!is_numeric_v<T>>>
+      typename = std::enable_if_t<!is_numeric_v<T> && !is_logical_v<T>>>
   evaluate::Expr<T> operator()(
       evaluate::Expr<T> &&x, const U &u, NonIntegralTag = {}) {
     return Id::operator()(std::move(x), u);
   }
 
 private:
+  template <typename T, typename MatchTypes, typename S>
+  evaluate::Expr<T> Reconstruct(const S &op, evaluate::Expr<T> atom,
+      evaluate::Expr<T> op1, evaluate::Expr<T> op2) {
+    using TypeS = llvm::remove_cvref_t<decltype(op)>;
+    // This function has to be semantically correct for all possible types
+    // of S even though at runtime s will only be one of the matched types.
+    // Limit the construction to the operation types that we tried to match
+    // (otherwise TypeS(op1, op2) would fail for non-binary operations).
+    if constexpr (!common::HasMember<TypeS, MatchTypes>) {
+      return evaluate::Expr<T>(TypeS(op));
+    } else if constexpr (is_logical_v<T>) {
+      constexpr int K{T::kind};
+      if constexpr (std::is_same_v<TypeS, evaluate::LogicalOperation<K>>) {
+        // Logical operators take an extra argument in their constructor,
+        // so they need their own reconstruction code.
+        common::LogicalOperator opCode{op.logicalOperator};
+        return evaluate::Expr<T>(TypeS( //
+            opCode, std::move(atom),
+            evaluate::Expr<T>(TypeS( //
+                opCode, std::move(op1), std::move(op2)))));
+      }
+    } else {
+      // Generic reconstruction.
+      return evaluate::Expr<T>(TypeS( //
+          std::move(atom),
+          evaluate::Expr<T>(TypeS( //
+              std::move(op1), std::move(op2)))));
+    }
+  }
+
   template <typename T> bool IsAtom(const evaluate::Expr<T> &x) const {
     return IsSameOrConvertOf(evaluate::AsGenericExpr(AsRvalue(x)), atom_);
   }
diff --git a/flang/test/Lower/OpenMP/atomic-update-reassoc-logical.f90 b/flang/test/Lower/OpenMP/atomic-update-reassoc-logical.f90
new file mode 100644
index 0000000000000..ccde4fed12f2f
--- /dev/null
+++ b/flang/test/Lower/OpenMP/atomic-update-reassoc-logical.f90
@@ -0,0 +1,137 @@
+!RUN: %flang_fc1 -emit-hlfir -fopenmp -fopenmp-version=60 %s -o - | FileCheck %s
+
+subroutine f00(x, y, z)
+  implicit none
+  logical :: x, y, z
+
+  !$omp atomic update
+  x = x .and. y .and. z
+end
+
+!CHECK-LABEL: func.func @_QPf00
+!CHECK: %[[X:[0-9]+]]:2 = hlfir.declare %arg0
+!CHECK: %[[Y:[0-9]+]]:2 = hlfir.declare %arg1
+!CHECK: %[[Z:[0-9]+]]:2 = hlfir.declare %arg2
+!CHECK: %[[LOAD_Y:[0-9]+]] = fir.load %[[Y]]#0 : !fir.ref<!fir.logical<4>>
+!CHECK: %[[LOAD_Z:[0-9]+]] = fir.load %[[Z]]#0 : !fir.ref<!fir.logical<4>>
+!CHECK: %[[CVT_Y:[0-9]+]] = fir.convert %[[LOAD_Y]] : (!fir.logical<4>) -> i1
+!CHECK: %[[CVT_Z:[0-9]+]] = fir.convert %[[LOAD_Z]] : (!fir.logical<4>) -> i1
+!CHECK: %[[AND_YZ:[0-9]+]] = arith.andi %[[CVT_Y]], %[[CVT_Z]] : i1
+!CHECK: omp.atomic.update memory_order(relaxed) %[[X]]#0 : !fir.ref<!fir.logical<4>> {
+!CHECK: ^bb0(%[[ARG:arg[0-9]+]]: !fir.logical<4>):
+!CHECK:   %[[CVT_X:[0-9]+]] = fir.convert %[[ARG]] : (!fir.logical<4>) -> i1
+!CHECK:   %[[AND_XYZ:[0-9]+]] = arith.andi %[[CVT_X]], %[[AND_YZ]] : i1
+!CHECK:   %[[RET:[0-9]+]] = fir.convert %[[AND_XYZ]] : (i1) -> !fir.logical<4>
+!CHECK:   omp.yield(%[[RET]] : !fir.logical<4>)
+!CHECK: }
+
+
+subroutine f01(x, y, z)
+  implicit none
+  logical :: x, y, z
+
+  !$omp atomic update
+  x = x .or. y .or. z
+end
+
+!CHECK-LABEL: func.func @_QPf01
+!CHECK: %[[X:[0-9]+]]:2 = hlfir.declare %arg0
+!CHECK: %[[Y:[0-9]+]]:2 = hlfir.declare %arg1
+!CHECK: %[[Z:[0-9]+]]:2 = hlfir.declare %arg2
+!CHECK: %[[LOAD_Y:[0-9]+]] = fir.load %[[Y]]#0 : !fir.ref<!fir.logical<4>>
+!CHECK: %[[LOAD_Z:[0-9]+]] = fir.load %[[Z]]#0 : !fir.ref<!fir.logical<4>>
+!CHECK: %[[CVT_Y:[0-9]+]] = fir.convert %[[LOAD_Y]] : (!fir.logical<4>) -> i1
+!CHECK: %[[CVT_Z:[0-9]+]] = fir.convert %[[LOAD_Z]] : (!fir.logical<4>) -> i1
+!CHECK: %[[OR_YZ:[0-9]+]] = arith.ori %[[CVT_Y]], %[[CVT_Z]] : i1
+!CHECK: omp.atomic.update memory_order(relaxed) %[[X]]#0 : !fir.ref<!fir.logical<4>> {
+!CHECK: ^bb0(%[[ARG:arg[0-9]+]]: !fir.logical<4>):
+!CHECK:   %[[CVT_X:[0-9]+]] = fir.convert %[[ARG]] : (!fir.logical<4>) -> i1
+!CHECK:   %[[OR_XYZ:[0-9]+]] = arith.ori %[[CVT_X]], %[[OR_YZ]] : i1
+!CHECK:   %[[RET:[0-9]+]] = fir.convert %[[OR_XYZ]] : (i1) -> !fir.logical<4>
+!CHECK:   omp.yield(%[[RET]] : !fir.logical<4>)
+!CHECK: }
+
+
+subroutine f02(x, y, z)
+  implicit none
+  logical :: x, y, z
+
+  !$omp atomic update
+  x = x .eqv. y .eqv. z
+end
+
+!CHECK-LABEL: func.func @_QPf02
+!CHECK: %[[X:[0-9]+]]:2 = hlfir.declare %arg0
+!CHECK: %[[Y:[0-9]+]]:2 = hlfir.declare %arg1
+!CHECK: %[[Z:[0-9]+]]:2 = hlfir.declare %arg2
+!CHECK: %[[LOAD_Y:[0-9]+]] = fir.load %[[Y]]#0 : !fir.ref<!fir.logical<4>>
+!CHECK: %[[LOAD_Z:[0-9]+]] = fir.load %[[Z]]#0 : !fir.ref<!fir.logical<4>>
+!CHECK: %[[CVT_Y:[0-9]+]] = fir.convert %[[LOAD_Y]] : (!fir.logical<4>) -> i1
+!CHECK: %[[CVT_Z:[0-9]+]] = fir.convert %[[LOAD_Z]] : (!fir.logical<4>) -> i1
+!CHECK: %[[EQV_YZ:[0-9]+]] = arith.cmpi eq, %[[CVT_Y]], %[[CVT_Z]] : i1
+!CHECK: omp.atomic.update memory_order(relaxed) %[[X]]#0 : !fir.ref<!fir.logical<4>> {
+!CHECK: ^bb0(%[[ARG:arg[0-9]+]]: !fir.logical<4>):
+!CHECK:   %[[CVT_X:[0-9]+]] = fir.convert %[[ARG]] : (!fir.logical<4>) -> i1
+!CHECK:   %[[EQV_XYZ:[0-9]+]] = arith.cmpi eq, %[[CVT_X]], %[[EQV_YZ]] : i1
+!CHECK:   %[[RET:[0-9]+]] = fir.convert %[[EQV_XYZ]] : (i1) -> !fir.logical<4>
+!CHECK:   omp.yield(%[[RET]] : !fir.logical<4>)
+!CHECK: }
+
+
+subroutine f03(x, y, z)
+  implicit none
+  logical :: x, y, z
+
+  !$omp atomic update
+  x = x .neqv. y .neqv. z
+end
+
+!CHECK-LABEL: func.func @_QPf03
+!CHECK: %[[X:[0-9]+]]:2 = hlfir.declare %arg0
+!CHECK: %[[Y:[0-9]+]]:2 = hlfir.declare %arg1
+!CHECK: %[[Z:[0-9]+]]:2 = hlfir.declare %arg2
+!CHECK: %[[LOAD_Y:[0-9]+]] = fir.load %[[Y]]#0 : !fir.ref<!fir.logical<4>>
+!CHECK: %[[LOAD_Z:[0-9]+]] = fir.load %[[Z]]#0 : !fir.ref<!fir.logical<4>>
+!CHECK: %[[CVT_Y:[0-9]+]] = fir.convert %[[LOAD_Y]] : (!fir.logical<4>) -> i1
+!CHECK: %[[CVT_Z:[0-9]+]] = fir.convert %[[LOAD_Z]] : (!fir.logical<4>) -> i1
+!CHECK: %[[NEQV_YZ:[0-9]+]] = arith.cmpi ne, %[[CVT_Y]], %[[CVT_Z]] : i1
+!CHECK: omp.atomic.update memory_order(relaxed) %[[X]]#0 : !fir.ref<!fir.logical<4>> {
+!CHECK: ^bb0(%[[ARG:arg[0-9]+]]: !fir.logical<4>):
+!CHECK:   %[[CVT_X:[0-9]+]] = fir.convert %[[ARG]] : (!fir.logical<4>) -> i1
+!CHECK:   %[[NEQV_XYZ:[0-9]+]] = arith.cmpi ne, %[[CVT_X]], %[[NEQV_YZ]] : i1
+!CHECK:   %[[RET:[0-9]+]] = fir.convert %[[NEQV_XYZ]] : (i1) -> !fir.logical<4>
+!CHECK:   omp.yield(%[[RET]] : !fir.logical<4>)
+!CHECK: }
+
+
+subroutine f04(x, a, b, c)
+  implicit none
+  logical(kind=4) :: x
+  logical(kind=8) :: a, b, c
+
+  !$omp atomic update
+  x = ((b .and. a) .and. x) .and. c
+end
+
+!CHECK-LABEL: func.func @_QPf04
+!CHECK: %[[A:[0-9]+]]:2 = hlfir.declare %arg1
+!CHECK: %[[B:[0-9]+]]:2 = hlfir.declare %arg2
+!CHECK: %[[C:[0-9]+]]:2 = hlfir.declare %arg3
+!CHECK: %[[X:[0-9]+]]:2 = hlfir.declare %arg0
+!CHECK: %[[LOAD_B:[0-9]+]] = fir.load %[[B]]#0 : !fir.ref<!fir.logical<8>>
+!CHECK: %[[LOAD_A:[0-9]+]] = fir.load %[[A]]#0 : !fir.ref<!fir.logical<8>>
+!CHECK: %[[CVT_B:[0-9]+]] = fir.convert %[[LOAD_B]] : (!fir.logical<8>) -> i1
+!CHECK: %[[CVT_A:[0-9]+]] = fir.convert %[[LOAD_A]] : (!fir.logical<8>) -> i1
+!CHECK: %[[AND_BA:[0-9]+]] = arith.andi %[[CVT_B]], %[[CVT_A]] : i1
+!CHECK: %[[LOAD_C:[0-9]+]] = fir.load %[[C]]#0 : !fir.ref<!fir.logical<8>>
+!CHECK: %[[CVT_C:[0-9]+]] = fir.convert %[[LOAD_C]] : (!fir.logical<8>) -> i1
+!CHECK: %[[AND_BAC:[0-9]+]] = arith.andi %[[AND_BA]], %[[CVT_C]] : i1
+!CHECK: omp.atomic.update memory_order(relaxed) %[[X]]#0 : !fir.ref<!fir.logical<4>> {
+!CHECK: ^bb0(%[[ARG:arg[0-9]+]]: !fir.logical<4>):
+!CHECK:   %[[CVT8_X:[0-9]+]] = fir.convert %[[ARG]] : (!fir.logical<4>) -> !fir.logical<8>
+!CHECK:   %[[CVT_X:[0-9]+]] = fir.convert %[[CVT8_X]] : (!fir.logical<8>) -> i1
+!CHECK:   %[[AND_XBAC:[0-9]+]] = arith.andi %[[CVT_X]], %[[AND_BAC]] : i1
+
+!CHECK:   %[[RET:[0-9]+]] = fir.convert %[[AND_XBAC]] : (i1) -> !fir.logical<4>
+!CHECK:   omp.yield(%[[RET]] : !fir.logical<4>)
+!CHECK: }

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang:openmp flang:semantics flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[flang][OpenMP] OMP ATOMIC restriction too strong?
3 participants