Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion flang/include/flang/Optimizer/OpenMP/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def FunctionFilteringPass : Pass<"omp-function-filtering"> {
];
}

def DoConcurrentConversionPass : Pass<"omp-do-concurrent-conversion", "mlir::func::FuncOp"> {
def DoConcurrentConversionPass : Pass<"omp-do-concurrent-conversion", "mlir::ModuleOp"> {
let summary = "Map `DO CONCURRENT` loops to OpenMP worksharing loops.";

let description = [{ This is an experimental pass to map `DO CONCURRENT` loops
Expand Down
70 changes: 39 additions & 31 deletions flang/lib/Optimizer/OpenMP/DoConcurrentConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,9 +173,11 @@ class DoConcurrentConversion

DoConcurrentConversion(
mlir::MLIRContext *context, bool mapToDevice,
llvm::DenseSet<fir::DoConcurrentOp> &concurrentLoopsToSkip)
llvm::DenseSet<fir::DoConcurrentOp> &concurrentLoopsToSkip,
mlir::SymbolTable &moduleSymbolTable)
: OpConversionPattern(context), mapToDevice(mapToDevice),
concurrentLoopsToSkip(concurrentLoopsToSkip) {}
concurrentLoopsToSkip(concurrentLoopsToSkip),
moduleSymbolTable(moduleSymbolTable) {}

mlir::LogicalResult
matchAndRewrite(fir::DoConcurrentOp doLoop, OpAdaptor adaptor,
Expand Down Expand Up @@ -332,8 +334,8 @@ class DoConcurrentConversion
loop.getLocalVars(),
loop.getLocalSymsAttr().getAsRange<mlir::SymbolRefAttr>(),
loop.getRegionLocalArgs())) {
auto localizer = mlir::SymbolTable::lookupNearestSymbolFrom<
fir::LocalitySpecifierOp>(loop, sym);
auto localizer = moduleSymbolTable.lookup<fir::LocalitySpecifierOp>(
sym.getLeafReference());
if (localizer.getLocalitySpecifierType() ==
fir::LocalitySpecifierType::LocalInit)
TODO(localizer.getLoc(),
Expand All @@ -352,6 +354,8 @@ class DoConcurrentConversion
cloneFIRRegionToOMP(localizer.getDeallocRegion(),
privatizer.getDeallocRegion());

moduleSymbolTable.insert(privatizer);

wsloopClauseOps.privateVars.push_back(op);
wsloopClauseOps.privateSyms.push_back(
mlir::SymbolRefAttr::get(privatizer));
Expand All @@ -362,28 +366,34 @@ class DoConcurrentConversion
loop.getReduceVars(), loop.getReduceByrefAttr().asArrayRef(),
loop.getReduceSymsAttr().getAsRange<mlir::SymbolRefAttr>(),
loop.getRegionReduceArgs())) {
auto firReducer =
mlir::SymbolTable::lookupNearestSymbolFrom<fir::DeclareReductionOp>(
loop, sym);
auto firReducer = moduleSymbolTable.lookup<fir::DeclareReductionOp>(
sym.getLeafReference());

mlir::OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPointAfter(firReducer);

auto ompReducer = mlir::omp::DeclareReductionOp::create(
rewriter, firReducer.getLoc(),
sym.getLeafReference().str() + ".omp",
firReducer.getTypeAttr().getValue());

cloneFIRRegionToOMP(firReducer.getAllocRegion(),
ompReducer.getAllocRegion());
cloneFIRRegionToOMP(firReducer.getInitializerRegion(),
ompReducer.getInitializerRegion());
cloneFIRRegionToOMP(firReducer.getReductionRegion(),
ompReducer.getReductionRegion());
cloneFIRRegionToOMP(firReducer.getAtomicReductionRegion(),
ompReducer.getAtomicReductionRegion());
cloneFIRRegionToOMP(firReducer.getCleanupRegion(),
ompReducer.getCleanupRegion());
std::string ompReducerName = sym.getLeafReference().str() + ".omp";

auto ompReducer =
moduleSymbolTable.lookup<mlir::omp::DeclareReductionOp>(
rewriter.getStringAttr(ompReducerName));

if (!ompReducer) {
ompReducer = mlir::omp::DeclareReductionOp::create(
rewriter, firReducer.getLoc(), ompReducerName,
firReducer.getTypeAttr().getValue());

cloneFIRRegionToOMP(firReducer.getAllocRegion(),
ompReducer.getAllocRegion());
cloneFIRRegionToOMP(firReducer.getInitializerRegion(),
ompReducer.getInitializerRegion());
cloneFIRRegionToOMP(firReducer.getReductionRegion(),
ompReducer.getReductionRegion());
cloneFIRRegionToOMP(firReducer.getAtomicReductionRegion(),
ompReducer.getAtomicReductionRegion());
cloneFIRRegionToOMP(firReducer.getCleanupRegion(),
ompReducer.getCleanupRegion());
moduleSymbolTable.insert(ompReducer);
}

wsloopClauseOps.reductionVars.push_back(op);
wsloopClauseOps.reductionByref.push_back(byRef);
Expand Down Expand Up @@ -431,6 +441,7 @@ class DoConcurrentConversion

bool mapToDevice;
llvm::DenseSet<fir::DoConcurrentOp> &concurrentLoopsToSkip;
mlir::SymbolTable &moduleSymbolTable;
};

class DoConcurrentConversionPass
Expand All @@ -444,12 +455,9 @@ class DoConcurrentConversionPass
: DoConcurrentConversionPassBase(options) {}

void runOnOperation() override {
mlir::func::FuncOp func = getOperation();

if (func.isDeclaration())
return;

mlir::ModuleOp module = getOperation();
mlir::MLIRContext *context = &getContext();
mlir::SymbolTable moduleSymbolTable(module);

if (mapTo != flangomp::DoConcurrentMappingKind::DCMK_Host &&
mapTo != flangomp::DoConcurrentMappingKind::DCMK_Device) {
Expand All @@ -463,7 +471,7 @@ class DoConcurrentConversionPass
mlir::RewritePatternSet patterns(context);
patterns.insert<DoConcurrentConversion>(
context, mapTo == flangomp::DoConcurrentMappingKind::DCMK_Device,
concurrentLoopsToSkip);
concurrentLoopsToSkip, moduleSymbolTable);
mlir::ConversionTarget target(*context);
target.addDynamicallyLegalOp<fir::DoConcurrentOp>(
[&](fir::DoConcurrentOp op) {
Expand All @@ -472,8 +480,8 @@ class DoConcurrentConversionPass
target.markUnknownOpDynamicallyLegal(
[](mlir::Operation *) { return true; });

if (mlir::failed(mlir::applyFullConversion(getOperation(), target,
std::move(patterns)))) {
if (mlir::failed(
mlir::applyFullConversion(module, target, std::move(patterns)))) {
signalPassFailure();
}
}
Expand Down
32 changes: 32 additions & 0 deletions flang/test/Transforms/DoConcurrent/reduction_symbol_resultion.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
! RUN: %flang_fc1 -emit-hlfir -fopenmp -fdo-concurrent-to-openmp=host %s -o - \
! RUN: | FileCheck %s

subroutine test1(x,s,N)
real :: x(N), s
integer :: N
do concurrent(i=1:N) reduce(+:s)
s=s+x(i)
end do
end subroutine test1
subroutine test2(x,s,N)
real :: x(N), s
integer :: N
do concurrent(i=1:N) reduce(+:s)
s=s+x(i)
end do
end subroutine test2

! CHECK: omp.declare_reduction @[[RED_SYM:.*]] : f32 init
! CHECK-NOT: omp.declare_reduction

! CHECK-LABEL: func.func @_QPtest1
! CHECK: omp.parallel {
! CHECK: omp.wsloop reduction(@[[RED_SYM]] {{.*}} : !fir.ref<f32>) {
! CHECK: }
! CHECK: }

! CHECK-LABEL: func.func @_QPtest2
! CHECK: omp.parallel {
! CHECK: omp.wsloop reduction(@[[RED_SYM]] {{.*}} : !fir.ref<f32>) {
! CHECK: }
! CHECK: }