[flang][OpenMP] Resolve `do concurrent` locality/reduction symbols through one module-level symbol table

Summary of the hunks below:
  * Passes.td: the `omp-do-concurrent-conversion` pass is anchored on
    `mlir::ModuleOp` instead of `mlir::func::FuncOp`.
  * DoConcurrentConversion.cpp: the pass builds one `mlir::SymbolTable` for
    the module and hands it to the conversion pattern. Localizers and FIR
    reducers are looked up through that table; `privatizer` is inserted into
    it; an `<sym>.omp` `omp.declare_reduction` is created only when the table
    does not already contain one, and is inserted after creation.
  * New lit test: two subroutines with the same `reduce(+:s)` clause must end
    up referencing a single `omp.declare_reduction` symbol.

NOTE(review): this patch text was rebuilt from a whitespace-collapsed copy in
which `<...>` template argument lists had been stripped. Blank context lines
were placed so that every hunk matches its `@@` line counts; indentation
follows LLVM formatting conventions. Confirm against the target tree, or apply
with `git apply --ignore-whitespace`.

diff --git a/flang/include/flang/Optimizer/OpenMP/Passes.td b/flang/include/flang/Optimizer/OpenMP/Passes.td
index 99202f6ee81e7..e2f092024c250 100644
--- a/flang/include/flang/Optimizer/OpenMP/Passes.td
+++ b/flang/include/flang/Optimizer/OpenMP/Passes.td
@@ -50,7 +50,7 @@ def FunctionFilteringPass : Pass<"omp-function-filtering"> {
   ];
 }
 
-def DoConcurrentConversionPass : Pass<"omp-do-concurrent-conversion", "mlir::func::FuncOp"> {
+def DoConcurrentConversionPass : Pass<"omp-do-concurrent-conversion", "mlir::ModuleOp"> {
   let summary = "Map `DO CONCURRENT` loops to OpenMP worksharing loops.";
   let description = [{
     This is an experimental pass to map `DO CONCURRENT` loops
diff --git a/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp b/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp
index 2b3ac169e8b5b..c928b76065ade 100644
--- a/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp
+++ b/flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp
@@ -173,9 +173,11 @@ class DoConcurrentConversion
 
   DoConcurrentConversion(
       mlir::MLIRContext *context, bool mapToDevice,
-      llvm::DenseSet<fir::DoConcurrentOp> &concurrentLoopsToSkip)
+      llvm::DenseSet<fir::DoConcurrentOp> &concurrentLoopsToSkip,
+      mlir::SymbolTable &moduleSymbolTable)
       : OpConversionPattern(context), mapToDevice(mapToDevice),
-        concurrentLoopsToSkip(concurrentLoopsToSkip) {}
+        concurrentLoopsToSkip(concurrentLoopsToSkip),
+        moduleSymbolTable(moduleSymbolTable) {}
 
   mlir::LogicalResult
   matchAndRewrite(fir::DoConcurrentOp doLoop, OpAdaptor adaptor,
@@ -332,8 +334,8 @@ class DoConcurrentConversion
                loop.getLocalVars(),
                loop.getLocalSymsAttr().getAsRange<mlir::SymbolRefAttr>(),
                loop.getRegionLocalArgs())) {
-        auto localizer = mlir::SymbolTable::lookupNearestSymbolFrom<
-            fir::LocalitySpecifierOp>(loop, sym);
+        auto localizer = moduleSymbolTable.lookup<fir::LocalitySpecifierOp>(
+            sym.getLeafReference());
         if (localizer.getLocalitySpecifierType() ==
             fir::LocalitySpecifierType::LocalInit)
           TODO(localizer.getLoc(),
@@ -352,6 +354,8 @@ class DoConcurrentConversion
         cloneFIRRegionToOMP(localizer.getDeallocRegion(),
                             privatizer.getDeallocRegion());
 
+        moduleSymbolTable.insert(privatizer);
+
         wsloopClauseOps.privateVars.push_back(op);
         wsloopClauseOps.privateSyms.push_back(
             mlir::SymbolRefAttr::get(privatizer));
@@ -362,28 +366,34 @@ class DoConcurrentConversion
                loop.getReduceVars(), loop.getReduceByrefAttr().asArrayRef(),
                loop.getReduceSymsAttr().getAsRange<mlir::SymbolRefAttr>(),
                loop.getRegionReduceArgs())) {
-        auto firReducer =
-            mlir::SymbolTable::lookupNearestSymbolFrom<fir::DeclareReductionOp>(
-                loop, sym);
+        auto firReducer = moduleSymbolTable.lookup<fir::DeclareReductionOp>(
+            sym.getLeafReference());
 
         mlir::OpBuilder::InsertionGuard guard(rewriter);
         rewriter.setInsertionPointAfter(firReducer);
-
-        auto ompReducer = mlir::omp::DeclareReductionOp::create(
-            rewriter, firReducer.getLoc(),
-            sym.getLeafReference().str() + ".omp",
-            firReducer.getTypeAttr().getValue());
-
-        cloneFIRRegionToOMP(firReducer.getAllocRegion(),
-                            ompReducer.getAllocRegion());
-        cloneFIRRegionToOMP(firReducer.getInitializerRegion(),
-                            ompReducer.getInitializerRegion());
-        cloneFIRRegionToOMP(firReducer.getReductionRegion(),
-                            ompReducer.getReductionRegion());
-        cloneFIRRegionToOMP(firReducer.getAtomicReductionRegion(),
-                            ompReducer.getAtomicReductionRegion());
-        cloneFIRRegionToOMP(firReducer.getCleanupRegion(),
-                            ompReducer.getCleanupRegion());
+        std::string ompReducerName = sym.getLeafReference().str() + ".omp";
+
+        auto ompReducer =
+            moduleSymbolTable.lookup<mlir::omp::DeclareReductionOp>(
+                rewriter.getStringAttr(ompReducerName));
+
+        if (!ompReducer) {
+          ompReducer = mlir::omp::DeclareReductionOp::create(
+              rewriter, firReducer.getLoc(), ompReducerName,
+              firReducer.getTypeAttr().getValue());
+
+          cloneFIRRegionToOMP(firReducer.getAllocRegion(),
+                              ompReducer.getAllocRegion());
+          cloneFIRRegionToOMP(firReducer.getInitializerRegion(),
+                              ompReducer.getInitializerRegion());
+          cloneFIRRegionToOMP(firReducer.getReductionRegion(),
+                              ompReducer.getReductionRegion());
+          cloneFIRRegionToOMP(firReducer.getAtomicReductionRegion(),
+                              ompReducer.getAtomicReductionRegion());
+          cloneFIRRegionToOMP(firReducer.getCleanupRegion(),
+                              ompReducer.getCleanupRegion());
+          moduleSymbolTable.insert(ompReducer);
+        }
 
         wsloopClauseOps.reductionVars.push_back(op);
         wsloopClauseOps.reductionByref.push_back(byRef);
@@ -431,6 +441,7 @@ class DoConcurrentConversion
 
   bool mapToDevice;
   llvm::DenseSet<fir::DoConcurrentOp> &concurrentLoopsToSkip;
+  mlir::SymbolTable &moduleSymbolTable;
 };
 
 class DoConcurrentConversionPass
@@ -444,12 +455,9 @@ class DoConcurrentConversionPass
       : DoConcurrentConversionPassBase(options) {}
 
   void runOnOperation() override {
-    mlir::func::FuncOp func = getOperation();
-
-    if (func.isDeclaration())
-      return;
-
+    mlir::ModuleOp module = getOperation();
     mlir::MLIRContext *context = &getContext();
+    mlir::SymbolTable moduleSymbolTable(module);
 
     if (mapTo != flangomp::DoConcurrentMappingKind::DCMK_Host &&
         mapTo != flangomp::DoConcurrentMappingKind::DCMK_Device) {
@@ -463,7 +471,7 @@ class DoConcurrentConversionPass
     mlir::RewritePatternSet patterns(context);
     patterns.insert<DoConcurrentConversion>(
         context, mapTo == flangomp::DoConcurrentMappingKind::DCMK_Device,
-        concurrentLoopsToSkip);
+        concurrentLoopsToSkip, moduleSymbolTable);
     mlir::ConversionTarget target(*context);
     target.addDynamicallyLegalOp<fir::DoConcurrentOp>(
         [&](fir::DoConcurrentOp op) {
@@ -472,8 +480,8 @@ class DoConcurrentConversionPass
     target.markUnknownOpDynamicallyLegal(
         [](mlir::Operation *) { return true; });
 
-    if (mlir::failed(mlir::applyFullConversion(getOperation(), target,
-                                               std::move(patterns)))) {
+    if (mlir::failed(
+            mlir::applyFullConversion(module, target, std::move(patterns)))) {
       signalPassFailure();
     }
   }
diff --git a/flang/test/Transforms/DoConcurrent/reduction_symbol_resultion.f90 b/flang/test/Transforms/DoConcurrent/reduction_symbol_resultion.f90
new file mode 100644
index 0000000000000..ab56a4f6c7e70
--- /dev/null
+++ b/flang/test/Transforms/DoConcurrent/reduction_symbol_resultion.f90
@@ -0,0 +1,32 @@
+! RUN: %flang_fc1 -emit-hlfir -fopenmp -fdo-concurrent-to-openmp=host %s -o - \
+! RUN:   | FileCheck %s
+
+subroutine test1(x,s,N)
+  real :: x(N), s
+  integer :: N
+  do concurrent(i=1:N) reduce(+:s)
+     s=s+x(i)
+  end do
+end subroutine test1
+subroutine test2(x,s,N)
+  real :: x(N), s
+  integer :: N
+  do concurrent(i=1:N) reduce(+:s)
+     s=s+x(i)
+  end do
+end subroutine test2
+
+! CHECK:       omp.declare_reduction @[[RED_SYM:.*]] : f32 init
+! CHECK-NOT:   omp.declare_reduction
+
+! CHECK-LABEL: func.func @_QPtest1
+! CHECK:         omp.parallel {
+! CHECK:           omp.wsloop reduction(@[[RED_SYM]] {{.*}} : !fir.ref<f32>) {
+! CHECK:           }
+! CHECK:         }
+
+! CHECK-LABEL: func.func @_QPtest2
+! CHECK:         omp.parallel {
+! CHECK:           omp.wsloop reduction(@[[RED_SYM]] {{.*}} : !fir.ref<f32>) {
+! CHECK:           }
+! CHECK:         }