Skip to content

Conversation

ergawy
Copy link
Member

@ergawy ergawy commented Sep 3, 2025

Extends `do concurrent` to OpenMP device mapping by adding support for
mapping `reduce` specifiers to omp `reduction` clauses. The changes
attach 2 `reduction` clauses to the mapped OpenMP construct: one on the
`teams` part of the construct and one on the `wloop` part.
@ergawy ergawy force-pushed the users/ergawy/upstream_dc_device_6_local_specs_device branch from 78fc5ed to 78e1013 Compare September 4, 2025 09:39
@ergawy ergawy force-pushed the users/ergawy/upstream_dc_device_7_reduce_device branch from f748bd2 to 6987182 Compare September 4, 2025 09:39
@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir labels Sep 4, 2025
@llvmbot
Copy link
Member

llvmbot commented Sep 4, 2025

@llvm/pr-subscribers-flang-fir-hlfir

Author: Kareem Ergawy (ergawy)

Changes

Extends do concurrent to OpenMP device mapping by adding support for mapping reduce specifiers to omp reduction clauses. The changes attach 2 reduction clauses to the mapped OpenMP construct: one on the teams part of the construct and one on the wloop part.


Full diff: https://github.com/llvm/llvm-project/pull/156610.diff

2 Files Affected:

  • (modified) flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp (+68-49)
  • (added) flang/test/Transforms/DoConcurrent/reduce_device.mlir (+53)
diff --git a/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp b/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp
index 66b778fecc208..135382abb0227 100644
--- a/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp
+++ b/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp
@@ -140,6 +140,9 @@ void collectLoopLiveIns(fir::DoConcurrentLoopOp loop,
 
   for (mlir::Value local : loop.getLocalVars())
     liveIns.push_back(local);
+
+  for (mlir::Value reduce : loop.getReduceVars())
+    liveIns.push_back(reduce);
 }
 
 /// Collects values that are local to a loop: "loop-local values". A loop-local
@@ -272,7 +275,7 @@ class DoConcurrentConversion
       targetOp =
           genTargetOp(doLoop.getLoc(), rewriter, mapper, loopNestLiveIns,
                       targetClauseOps, loopNestClauseOps, liveInShapeInfoMap);
-      genTeamsOp(doLoop.getLoc(), rewriter);
+      genTeamsOp(rewriter, loop, mapper);
     }
 
     mlir::omp::ParallelOp parallelOp =
@@ -488,46 +491,7 @@ class DoConcurrentConversion
     if (!mapToDevice)
       genPrivatizers(rewriter, mapper, loop, wsloopClauseOps);
 
-    if (!loop.getReduceVars().empty()) {
-      for (auto [op, byRef, sym, arg] : llvm::zip_equal(
-               loop.getReduceVars(), loop.getReduceByrefAttr().asArrayRef(),
-               loop.getReduceSymsAttr().getAsRange<mlir::SymbolRefAttr>(),
-               loop.getRegionReduceArgs())) {
-        auto firReducer = moduleSymbolTable.lookup<fir::DeclareReductionOp>(
-            sym.getLeafReference());
-
-        mlir::OpBuilder::InsertionGuard guard(rewriter);
-        rewriter.setInsertionPointAfter(firReducer);
-        std::string ompReducerName = sym.getLeafReference().str() + ".omp";
-
-        auto ompReducer =
-            moduleSymbolTable.lookup<mlir::omp::DeclareReductionOp>(
-                rewriter.getStringAttr(ompReducerName));
-
-        if (!ompReducer) {
-          ompReducer = mlir::omp::DeclareReductionOp::create(
-              rewriter, firReducer.getLoc(), ompReducerName,
-              firReducer.getTypeAttr().getValue());
-
-          cloneFIRRegionToOMP(rewriter, firReducer.getAllocRegion(),
-                              ompReducer.getAllocRegion());
-          cloneFIRRegionToOMP(rewriter, firReducer.getInitializerRegion(),
-                              ompReducer.getInitializerRegion());
-          cloneFIRRegionToOMP(rewriter, firReducer.getReductionRegion(),
-                              ompReducer.getReductionRegion());
-          cloneFIRRegionToOMP(rewriter, firReducer.getAtomicReductionRegion(),
-                              ompReducer.getAtomicReductionRegion());
-          cloneFIRRegionToOMP(rewriter, firReducer.getCleanupRegion(),
-                              ompReducer.getCleanupRegion());
-          moduleSymbolTable.insert(ompReducer);
-        }
-
-        wsloopClauseOps.reductionVars.push_back(op);
-        wsloopClauseOps.reductionByref.push_back(byRef);
-        wsloopClauseOps.reductionSyms.push_back(
-            mlir::SymbolRefAttr::get(ompReducer));
-      }
-    }
+    genReductions(rewriter, mapper, loop, wsloopClauseOps);
 
     auto wsloopOp =
         mlir::omp::WsloopOp::create(rewriter, loop.getLoc(), wsloopClauseOps);
@@ -549,8 +513,6 @@ class DoConcurrentConversion
 
     rewriter.setInsertionPointToEnd(&loopNestOp.getRegion().back());
     mlir::omp::YieldOp::create(rewriter, loop->getLoc());
-    loop->getParentOfType<mlir::ModuleOp>().print(
-        llvm::errs(), mlir::OpPrintingFlags().assumeVerified());
 
     return {loopNestOp, wsloopOp};
   }
@@ -771,15 +733,26 @@ class DoConcurrentConversion
                                             liveInName, shape);
   }
 
-  mlir::omp::TeamsOp
-  genTeamsOp(mlir::Location loc,
-             mlir::ConversionPatternRewriter &rewriter) const {
-    auto teamsOp = rewriter.create<mlir::omp::TeamsOp>(
-        loc, /*clauses=*/mlir::omp::TeamsOperands{});
+  mlir::omp::TeamsOp genTeamsOp(mlir::ConversionPatternRewriter &rewriter,
+                                fir::DoConcurrentLoopOp loop,
+                                mlir::IRMapping &mapper) const {
+    mlir::omp::TeamsOperands teamsOps;
+    genReductions(rewriter, mapper, loop, teamsOps);
+
+    mlir::Location loc = loop.getLoc();
+    auto teamsOp = rewriter.create<mlir::omp::TeamsOp>(loc, teamsOps);
+    Fortran::common::openmp::EntryBlockArgs teamsArgs;
+    teamsArgs.reduction.vars = teamsOps.reductionVars;
+    Fortran::common::openmp::genEntryBlock(rewriter, teamsArgs,
+                                           teamsOp.getRegion());
 
-    rewriter.createBlock(&teamsOp.getRegion());
     rewriter.setInsertionPoint(rewriter.create<mlir::omp::TerminatorOp>(loc));
 
+    for (auto [loopVar, teamsArg] : llvm::zip_equal(
+             loop.getReduceVars(), teamsOp.getRegion().getArguments())) {
+      mapper.map(loopVar, teamsArg);
+    }
+
     return teamsOp;
   }
 
@@ -846,6 +819,52 @@ class DoConcurrentConversion
       }
   }
 
+  void genReductions(mlir::ConversionPatternRewriter &rewriter,
+                     mlir::IRMapping &mapper, fir::DoConcurrentLoopOp loop,
+                     mlir::omp::ReductionClauseOps &reductionClauseOps) const {
+    if (!loop.getReduceVars().empty()) {
+      for (auto [var, byRef, sym, arg] : llvm::zip_equal(
+               loop.getReduceVars(), loop.getReduceByrefAttr().asArrayRef(),
+               loop.getReduceSymsAttr().getAsRange<mlir::SymbolRefAttr>(),
+               loop.getRegionReduceArgs())) {
+        auto firReducer = moduleSymbolTable.lookup<fir::DeclareReductionOp>(
+            sym.getLeafReference());
+
+        mlir::OpBuilder::InsertionGuard guard(rewriter);
+        rewriter.setInsertionPointAfter(firReducer);
+        std::string ompReducerName = sym.getLeafReference().str() + ".omp";
+
+        auto ompReducer =
+            moduleSymbolTable.lookup<mlir::omp::DeclareReductionOp>(
+                rewriter.getStringAttr(ompReducerName));
+
+        if (!ompReducer) {
+          ompReducer = mlir::omp::DeclareReductionOp::create(
+              rewriter, firReducer.getLoc(), ompReducerName,
+              firReducer.getTypeAttr().getValue());
+
+          cloneFIRRegionToOMP(rewriter, firReducer.getAllocRegion(),
+                              ompReducer.getAllocRegion());
+          cloneFIRRegionToOMP(rewriter, firReducer.getInitializerRegion(),
+                              ompReducer.getInitializerRegion());
+          cloneFIRRegionToOMP(rewriter, firReducer.getReductionRegion(),
+                              ompReducer.getReductionRegion());
+          cloneFIRRegionToOMP(rewriter, firReducer.getAtomicReductionRegion(),
+                              ompReducer.getAtomicReductionRegion());
+          cloneFIRRegionToOMP(rewriter, firReducer.getCleanupRegion(),
+                              ompReducer.getCleanupRegion());
+          moduleSymbolTable.insert(ompReducer);
+        }
+
+        reductionClauseOps.reductionVars.push_back(
+            mapToDevice ? mapper.lookup(var) : var);
+        reductionClauseOps.reductionByref.push_back(byRef);
+        reductionClauseOps.reductionSyms.push_back(
+            mlir::SymbolRefAttr::get(ompReducer));
+      }
+    }
+  }
+
   bool mapToDevice;
   llvm::DenseSet<fir::DoConcurrentOp> &concurrentLoopsToSkip;
   mlir::SymbolTable &moduleSymbolTable;
diff --git a/flang/test/Transforms/DoConcurrent/reduce_device.mlir b/flang/test/Transforms/DoConcurrent/reduce_device.mlir
new file mode 100644
index 0000000000000..3e46692a15dca
--- /dev/null
+++ b/flang/test/Transforms/DoConcurrent/reduce_device.mlir
@@ -0,0 +1,53 @@
+// RUN: fir-opt --omp-do-concurrent-conversion="map-to=device" %s -o - | FileCheck %s
+
+fir.declare_reduction @add_reduction_f32 : f32 init {
+^bb0(%arg0: f32):
+  %cst = arith.constant 0.000000e+00 : f32
+  fir.yield(%cst : f32)
+} combiner {
+^bb0(%arg0: f32, %arg1: f32):
+  %0 = arith.addf %arg0, %arg1 fastmath<contract> : f32
+  fir.yield(%0 : f32)
+}
+
+func.func @_QPfoo() {
+  %0 = fir.dummy_scope : !fir.dscope
+  %3 = fir.alloca f32 {bindc_name = "s", uniq_name = "_QFfooEs"}
+  %4:2 = hlfir.declare %3 {uniq_name = "_QFfooEs"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
+    %c1 = arith.constant 1 : index
+  %c10 = arith.constant 1 : index
+  fir.do_concurrent {
+    %7 = fir.alloca i32 {bindc_name = "i"}
+    %8:2 = hlfir.declare %7 {uniq_name = "_QFfooEi"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
+    fir.do_concurrent.loop (%arg0) = (%c1) to (%c10) step (%c1) reduce(@add_reduction_f32 #fir.reduce_attr<add> %4#0 -> %arg1 : !fir.ref<f32>) {
+      %9 = fir.convert %arg0 : (index) -> i32
+      fir.store %9 to %8#0 : !fir.ref<i32>
+      %10:2 = hlfir.declare %arg1 {uniq_name = "_QFfooEs"} : (!fir.ref<f32>) -> (!fir.ref<f32>, !fir.ref<f32>)
+      %11 = fir.load %10#0 : !fir.ref<f32>
+      %cst = arith.constant 1.000000e+00 : f32
+      %12 = arith.addf %11, %cst fastmath<contract> : f32
+      hlfir.assign %12 to %10#0 : f32, !fir.ref<f32>
+    }
+  }
+  return
+}
+
+// CHECK: omp.declare_reduction @[[OMP_RED:.*.omp]] : f32
+
+// CHECK: %[[S_DECL:.*]]:2 = hlfir.declare %6 {uniq_name = "_QFfooEs"}
+// CHECK: %[[S_MAP:.*]] = omp.map.info var_ptr(%[[S_DECL]]#1
+
+// CHECK: omp.target host_eval({{.*}}) map_entries({{.*}}, %[[S_MAP]] -> %[[S_TARGET_ARG:.*]] : {{.*}}) {
+// CHECK:   %[[S_DEV_DECL:.*]]:2 = hlfir.declare %[[S_TARGET_ARG]]
+// CHECK:   omp.teams reduction(@[[OMP_RED]] %[[S_DEV_DECL]]#0 -> %[[RED_TEAMS_ARG:.*]] : !fir.ref<f32>) {
+// CHECK:   omp.parallel {
+// CHECK:     omp.distribute {
+// CHECK:       omp.wsloop reduction(@[[OMP_RED]] %[[RED_TEAMS_ARG]] -> %[[RED_WS_ARG:.*]] : {{.*}}) {
+// CHECK:         %[[S_WS_DECL:.*]]:2 = hlfir.declare %[[RED_WS_ARG]] {uniq_name = "_QFfooEs"}
+// CHECK:         %[[S_VAL:.*]] = fir.load %[[S_WS_DECL]]#0
+// CHECK:         %[[RED_RES:.*]] = arith.addf %[[S_VAL]], %{{.*}} fastmath<contract> : f32
+// CHECK:         hlfir.assign %[[RED_RES]] to %[[S_WS_DECL]]#0
+// CHECK:       }
+// CHECK:     }
+// CHECK:   }
+// CHECK: }

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants