diff --git a/flang/include/flang/Semantics/openmp-directive-sets.h b/flang/include/flang/Semantics/openmp-directive-sets.h index dd610c9702c28..7ced6ed9b44d6 100644 --- a/flang/include/flang/Semantics/openmp-directive-sets.h +++ b/flang/include/flang/Semantics/openmp-directive-sets.h @@ -143,6 +143,7 @@ static const OmpDirectiveSet topTargetSet{ Directive::OMPD_target_teams_distribute_parallel_do_simd, Directive::OMPD_target_teams_distribute_simd, Directive::OMPD_target_teams_loop, + Directive::OMPD_target_teams_workdistribute, }; static const OmpDirectiveSet allTargetSet{topTargetSet}; @@ -172,6 +173,7 @@ static const OmpDirectiveSet topTeamsSet{ Directive::OMPD_teams_distribute_parallel_do_simd, Directive::OMPD_teams_distribute_simd, Directive::OMPD_teams_loop, + Directive::OMPD_teams_workdistribute, }; static const OmpDirectiveSet bottomTeamsSet{ @@ -187,9 +189,16 @@ static const OmpDirectiveSet allTeamsSet{ Directive::OMPD_target_teams_distribute_parallel_do_simd, Directive::OMPD_target_teams_distribute_simd, Directive::OMPD_target_teams_loop, + Directive::OMPD_target_teams_workdistribute, } | topTeamsSet, }; +static const OmpDirectiveSet allWorkdistributeSet{ + Directive::OMPD_workdistribute, + Directive::OMPD_teams_workdistribute, + Directive::OMPD_target_teams_workdistribute, +}; + //===----------------------------------------------------------------------===// // Directive sets for groups of multiple directives //===----------------------------------------------------------------------===// @@ -230,6 +239,9 @@ static const OmpDirectiveSet blockConstructSet{ Directive::OMPD_taskgroup, Directive::OMPD_teams, Directive::OMPD_workshare, + Directive::OMPD_target_teams_workdistribute, + Directive::OMPD_teams_workdistribute, + Directive::OMPD_workdistribute, }; static const OmpDirectiveSet loopConstructSet{ @@ -294,6 +306,7 @@ static const OmpDirectiveSet workShareSet{ Directive::OMPD_scope, Directive::OMPD_sections, Directive::OMPD_single, + Directive::OMPD_workdistribute, } | allDoSet, }; @@ -376,6 +389,7 @@ static const OmpDirectiveSet nestedReduceWorkshareAllowedSet{ }; static const OmpDirectiveSet nestedTeamsAllowedSet{ + Directive::OMPD_workdistribute, Directive::OMPD_distribute, Directive::OMPD_distribute_parallel_do, Directive::OMPD_distribute_parallel_do_simd, diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp index ae60432afccd0..529f51467b45f 100644 --- a/flang/lib/Lower/OpenMP/OpenMP.cpp +++ b/flang/lib/Lower/OpenMP/OpenMP.cpp @@ -534,6 +534,13 @@ static void processHostEvalClauses(lower::AbstractConverter &converter, cp.processCollapse(loc, eval, hostInfo->ops, hostInfo->iv); break; + case OMPD_teams_workdistribute: + cp.processThreadLimit(stmtCtx, hostInfo->ops); + [[fallthrough]]; + case OMPD_target_teams_workdistribute: + cp.processNumTeams(stmtCtx, hostInfo->ops); + break; + // Standalone 'target' case. case OMPD_target: { processSingleNestedIf( @@ -2818,6 +2825,17 @@ genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable, queue, item, clauseOps); } +static mlir::omp::WorkdistributeOp genWorkdistributeOp( + lower::AbstractConverter &converter, lower::SymMap &symTable, + semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval, + mlir::Location loc, const ConstructQueue &queue, + ConstructQueue::const_iterator item) { + return genOpWithBody( + OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval, + llvm::omp::Directive::OMPD_workdistribute), + queue, item); +} + //===----------------------------------------------------------------------===// // Code generation functions for the standalone version of constructs that can // also be a leaf of a composite construct @@ -3454,7 +3472,10 @@ static void genOMPDispatch(lower::AbstractConverter &converter, case llvm::omp::Directive::OMPD_unroll: genUnrollOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item); break; - // case llvm::omp::Directive::OMPD_workdistribute: + case llvm::omp::Directive::OMPD_workdistribute: + newOp = genWorkdistributeOp(converter, symTable, semaCtx, eval, loc, queue, + item); + break; case llvm::omp::Directive::OMPD_workshare: newOp = genWorkshareOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item); diff --git a/flang/lib/Parser/openmp-parsers.cpp b/flang/lib/Parser/openmp-parsers.cpp index 46b14861096f1..3fa7ca76c1642 100644 --- a/flang/lib/Parser/openmp-parsers.cpp +++ b/flang/lib/Parser/openmp-parsers.cpp @@ -1845,11 +1845,15 @@ TYPE_PARSER( // MakeBlockConstruct(llvm::omp::Directive::OMPD_target_data) || MakeBlockConstruct(llvm::omp::Directive::OMPD_target_parallel) || MakeBlockConstruct(llvm::omp::Directive::OMPD_target_teams) || + MakeBlockConstruct( + llvm::omp::Directive::OMPD_target_teams_workdistribute) || MakeBlockConstruct(llvm::omp::Directive::OMPD_target) || MakeBlockConstruct(llvm::omp::Directive::OMPD_task) || MakeBlockConstruct(llvm::omp::Directive::OMPD_taskgroup) || MakeBlockConstruct(llvm::omp::Directive::OMPD_teams) || - MakeBlockConstruct(llvm::omp::Directive::OMPD_workshare)) + MakeBlockConstruct(llvm::omp::Directive::OMPD_teams_workdistribute) || + MakeBlockConstruct(llvm::omp::Directive::OMPD_workshare) || + MakeBlockConstruct(llvm::omp::Directive::OMPD_workdistribute)) #undef MakeBlockConstruct // OMP SECTIONS Directive diff --git a/flang/lib/Semantics/check-omp-structure.cpp b/flang/lib/Semantics/check-omp-structure.cpp index cbe6b2c68bf05..93e8ff1f81450 100644 --- a/flang/lib/Semantics/check-omp-structure.cpp +++ b/flang/lib/Semantics/check-omp-structure.cpp @@ -141,6 +141,67 @@ class OmpWorkshareBlockChecker { parser::CharBlock source_; }; +// 'OmpWorkdistributeBlockChecker' is used to check the validity of the +// assignment statements and the expressions enclosed in an OpenMP +// workdistribute construct +class OmpWorkdistributeBlockChecker { +public: + OmpWorkdistributeBlockChecker( + SemanticsContext &context, parser::CharBlock source) + : context_{context}, source_{source} {} + + template bool Pre(const T &) { return true; } + template void Post(const T &) {} + + bool Pre(const parser::AssignmentStmt &assignment) { + const auto &var{std::get(assignment.t)}; + const auto &expr{std::get(assignment.t)}; + const auto *lhs{GetExpr(context_, var)}; + const auto *rhs{GetExpr(context_, expr)}; + if (lhs && rhs) { + Tristate isDefined{semantics::IsDefinedAssignment( + lhs->GetType(), lhs->Rank(), rhs->GetType(), rhs->Rank())}; + if (isDefined == Tristate::Yes) { + context_.Say(expr.source, + "Defined assignment statement is not " + "allowed in a WORKDISTRIBUTE construct"_err_en_US); + } + } + return true; + } + + bool Pre(const parser::Expr &expr) { + if (const auto *e{GetExpr(context_, expr)}) { + for (const Symbol &symbol : evaluate::CollectSymbols(*e)) { + const Symbol &root{GetAssociationRoot(symbol)}; + if (IsFunction(root)) { + std::string attrs{""}; + if (!IsElementalProcedure(root)) { + attrs = " non-ELEMENTAL"; + } + if (root.attrs().test(Attr::IMPURE)) { + if (attrs != "") { + attrs = "," + attrs; + } + attrs = " IMPURE" + attrs; + } + if (attrs != "") { + context_.Say(expr.source, + "User defined%s function '%s' is not allowed in a " + "WORKDISTRIBUTE construct"_err_en_US, + attrs, root.name()); + } + } + } + } + return false; + } + +private: + SemanticsContext &context_; + parser::CharBlock source_; +}; + // `OmpUnitedTaskDesignatorChecker` is used to check if the designator // can appear within the TASK construct class OmpUnitedTaskDesignatorChecker { @@ -809,6 +870,13 @@ void OmpStructureChecker::Enter(const parser::OpenMPBlockConstruct &x) { "TARGET construct with nested TEAMS region contains statements or " "directives outside of the TEAMS construct"_err_en_US); } + if (GetContext().directive == llvm::omp::Directive::OMPD_workdistribute && + GetContextParent().directive != llvm::omp::Directive::OMPD_teams) { + context_.Say(x.BeginDir().DirName().source, + "%s region can only be strictly nested within the " + "teams region"_err_en_US, + ContextDirectiveAsFortran()); + } } CheckNoBranching(block, beginSpec.DirId(), beginSpec.source); @@ -892,6 +960,17 @@ void OmpStructureChecker::Enter(const parser::OpenMPBlockConstruct &x) { HasInvalidWorksharingNesting( beginSpec.source, llvm::omp::nestedWorkshareErrSet); break; + case llvm::omp::OMPD_workdistribute: + if (!CurrentDirectiveIsNested()) { + context_.Say(beginSpec.source, + "A workdistribute region must be nested inside teams region only."_err_en_US); + } + CheckWorkdistributeBlockStmts(block, beginSpec.source); + break; + case llvm::omp::OMPD_teams_workdistribute: + case llvm::omp::OMPD_target_teams_workdistribute: + CheckWorkdistributeBlockStmts(block, beginSpec.source); + break; case llvm::omp::Directive::OMPD_scope: case llvm::omp::Directive::OMPD_single: // TODO: This check needs to be extended while implementing nesting of @@ -4470,6 +4549,22 @@ void OmpStructureChecker::CheckWorkshareBlockStmts( } } +void OmpStructureChecker::CheckWorkdistributeBlockStmts( + const parser::Block &block, parser::CharBlock source) { + OmpWorkdistributeBlockChecker ompWorkdistributeBlockChecker{context_, source}; + + for (auto it{block.begin()}; it != block.end(); ++it) { + if (parser::Unwrap(*it)) { + parser::Walk(*it, ompWorkdistributeBlockChecker); + } else { + context_.Say(source, + "The structured block in a WORKDISTRIBUTE construct may consist of " + "only " + "SCALAR or ARRAY assignments"_err_en_US); + } + } +} + void OmpStructureChecker::CheckIfContiguous(const parser::OmpObject &object) { if (auto contig{IsContiguous(context_, object)}; contig && !*contig) { const parser::Name *name{GetObjectName(object)}; diff --git a/flang/lib/Semantics/check-omp-structure.h b/flang/lib/Semantics/check-omp-structure.h index 6b33ca6ab583f..350c22ddf51aa 100644 --- a/flang/lib/Semantics/check-omp-structure.h +++ b/flang/lib/Semantics/check-omp-structure.h @@ -242,6 +242,7 @@ class OmpStructureChecker llvmOmpClause clause, const parser::OmpObjectList &ompObjectList); bool CheckTargetBlockOnlyTeams(const parser::Block &); void CheckWorkshareBlockStmts(const parser::Block &, parser::CharBlock); + void CheckWorkdistributeBlockStmts(const parser::Block &, parser::CharBlock); void CheckIteratorRange(const parser::OmpIteratorSpecifier &x); void CheckIteratorModifier(const parser::OmpIterator &x); diff --git a/flang/lib/Semantics/resolve-directives.cpp b/flang/lib/Semantics/resolve-directives.cpp index fe0d2a73805de..7ded1b2dfda22 100644 --- a/flang/lib/Semantics/resolve-directives.cpp +++ b/flang/lib/Semantics/resolve-directives.cpp @@ -1735,10 +1735,13 @@ bool OmpAttributeVisitor::Pre(const parser::OpenMPBlockConstruct &x) { case llvm::omp::Directive::OMPD_task: case llvm::omp::Directive::OMPD_taskgroup: case llvm::omp::Directive::OMPD_teams: + case llvm::omp::Directive::OMPD_workdistribute: case llvm::omp::Directive::OMPD_workshare: case llvm::omp::Directive::OMPD_parallel_workshare: case llvm::omp::Directive::OMPD_target_teams: + case llvm::omp::Directive::OMPD_target_teams_workdistribute: case llvm::omp::Directive::OMPD_target_parallel: + case llvm::omp::Directive::OMPD_teams_workdistribute: PushContext(dirSpec.source, dirId); break; default: @@ -1768,9 +1771,12 @@ void OmpAttributeVisitor::Post(const parser::OpenMPBlockConstruct &x) { case llvm::omp::Directive::OMPD_target: case llvm::omp::Directive::OMPD_task: case llvm::omp::Directive::OMPD_teams: + case llvm::omp::Directive::OMPD_workdistribute: case llvm::omp::Directive::OMPD_parallel_workshare: case llvm::omp::Directive::OMPD_target_teams: - case llvm::omp::Directive::OMPD_target_parallel: { + case llvm::omp::Directive::OMPD_target_parallel: + case llvm::omp::Directive::OMPD_target_teams_workdistribute: + case llvm::omp::Directive::OMPD_teams_workdistribute: { bool hasPrivate; for (const auto *allocName : allocateNames_) { hasPrivate = false; diff --git a/flang/test/Lower/OpenMP/workdistribute.f90 b/flang/test/Lower/OpenMP/workdistribute.f90 new file mode 100644 index 0000000000000..dc66cd73e692b --- /dev/null +++ b/flang/test/Lower/OpenMP/workdistribute.f90 @@ -0,0 +1,30 @@ +! RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s + +! CHECK-LABEL: func @_QPtarget_teams_workdistribute +subroutine target_teams_workdistribute() + integer :: aa(10), bb(10) + ! CHECK: omp.target + ! CHECK: omp.teams + ! CHECK: omp.workdistribute + !$omp target teams workdistribute + aa = bb + ! CHECK: omp.terminator + ! CHECK: omp.terminator + ! CHECK: omp.terminator + !$omp end target teams workdistribute +end subroutine target_teams_workdistribute + +! CHECK-LABEL: func @_QPteams_workdistribute +subroutine teams_workdistribute() + use iso_fortran_env + real(kind=real32) :: a + real(kind=real32), dimension(10) :: x + real(kind=real32), dimension(10) :: y + ! CHECK: omp.teams + ! CHECK: omp.workdistribute + !$omp teams workdistribute + y = a * x + y + ! CHECK: omp.terminator + ! CHECK: omp.terminator + !$omp end teams workdistribute +end subroutine teams_workdistribute diff --git a/flang/test/Semantics/OpenMP/workdistribute01.f90 b/flang/test/Semantics/OpenMP/workdistribute01.f90 new file mode 100644 index 0000000000000..76ddbc74ea4ec --- /dev/null +++ b/flang/test/Semantics/OpenMP/workdistribute01.f90 @@ -0,0 +1,16 @@ +! RUN: %python %S/../test_errors.py %s %flang -fopenmp +! OpenMP Version 6.0 +! workdistribute Construct +! Invalid do construct inside !$omp workdistribute + +subroutine workdistribute() + integer n, i + !ERROR: A workdistribute region must be nested inside teams region only. + !ERROR: The structured block in a WORKDISTRIBUTE construct may consist of only SCALAR or ARRAY assignments + !$omp workdistribute + do i = 1, n + print *, "omp workdistribute" + end do + !$omp end workdistribute + +end subroutine workdistribute diff --git a/flang/test/Semantics/OpenMP/workdistribute02.f90 b/flang/test/Semantics/OpenMP/workdistribute02.f90 new file mode 100644 index 0000000000000..ad2cde2c3daf0 --- /dev/null +++ b/flang/test/Semantics/OpenMP/workdistribute02.f90 @@ -0,0 +1,34 @@ +! RUN: %python %S/../test_errors.py %s %flang -fopenmp +! OpenMP Version 6.0 +! workdistribute Construct +! The !omp workdistribute construct must not contain any user defined +! function calls unless the function is ELEMENTAL. + +module my_mod + contains + integer function my_func() + my_func = 10 + end function my_func + + impure integer function impure_my_func() + impure_my_func = 20 + end function impure_my_func + + impure elemental integer function impure_ele_my_func() + impure_ele_my_func = 20 + end function impure_ele_my_func +end module my_mod + +subroutine workdistribute(aa, bb, cc, n) + use my_mod + integer n + real aa(n), bb(n), cc(n) + !$omp teams + !$omp workdistribute + !ERROR: User defined non-ELEMENTAL function 'my_func' is not allowed in a WORKDISTRIBUTE construct + aa = my_func() + aa = bb * cc + !$omp end workdistribute + !$omp end teams + +end subroutine workdistribute diff --git a/flang/test/Semantics/OpenMP/workdistribute03.f90 b/flang/test/Semantics/OpenMP/workdistribute03.f90 new file mode 100644 index 0000000000000..eac28cb39c47f --- /dev/null +++ b/flang/test/Semantics/OpenMP/workdistribute03.f90 @@ -0,0 +1,34 @@ +! RUN: %python %S/../test_errors.py %s %flang -fopenmp +! OpenMP Version 6.0 +! workdistribute Construct +! All array assignments, scalar assignments, and masked array assignments +! must be intrinsic assignments. + +module defined_assign + interface assignment(=) + module procedure work_assign + end interface + + contains + subroutine work_assign(a,b) + integer, intent(out) :: a + logical, intent(in) :: b(:) + end subroutine work_assign +end module defined_assign + +program omp_workdistribute + use defined_assign + + integer :: a, aa(10), bb(10) + logical :: l(10) + l = .TRUE. + + !$omp teams + !$omp workdistribute + !ERROR: Defined assignment statement is not allowed in a WORKDISTRIBUTE construct + a = l + aa = bb + !$omp end workdistribute + !$omp end teams + +end program omp_workdistribute diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td index 79f25bb05f20e..e9be1a6595da1 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMP.td +++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td @@ -1309,6 +1309,17 @@ def OMP_EndWorkshare : Directive<[Spelling<"end workshare">]> { let category = OMP_Workshare.category; let languages = [L_Fortran]; } +def OMP_Workdistribute : Directive<[Spelling<"workdistribute">]> { + let association = AS_Block; + let category = CA_Executable; + let languages = [L_Fortran]; +} +def OMP_EndWorkdistribute : Directive<[Spelling<"end workdistribute">]> { + let leafConstructs = OMP_Workdistribute.leafConstructs; + let association = OMP_Workdistribute.association; + let category = OMP_Workdistribute.category; + let languages = [L_Fortran]; +} //===----------------------------------------------------------------------===// // Definitions of OpenMP compound directives @@ -2452,6 +2463,35 @@ def OMP_TargetTeamsDistributeSimd let leafConstructs = [OMP_Target, OMP_Teams, OMP_Distribute, OMP_Simd]; let category = CA_Executable; } +def OMP_TargetTeamsWorkdistribute : Directive<[Spelling<"target teams workdistribute">]> { + let allowedClauses = [ + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + ]; + let allowedOnceClauses = [ + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + ]; + let leafConstructs = [OMP_Target, OMP_Teams, OMP_Workdistribute]; + let category = CA_Executable; + let languages = [L_Fortran]; +} def OMP_target_teams_loop : Directive<[Spelling<"target teams loop">]> { let allowedClauses = [ VersionedClause, @@ -2682,6 +2722,25 @@ def OMP_TeamsDistributeSimd : Directive<[Spelling<"teams distribute simd">]> { let leafConstructs = [OMP_Teams, OMP_Distribute, OMP_Simd]; let category = CA_Executable; } +def OMP_TeamsWorkdistribute : Directive<[Spelling<"teams workdistribute">]> { + let allowedClauses = [ + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + ]; + let allowedOnceClauses = [ + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + ]; + let leafConstructs = [OMP_Teams, OMP_Workdistribute]; + let category = CA_Executable; + let languages = [L_Fortran]; +} def OMP_teams_loop : Directive<[Spelling<"teams loop">]> { let allowedClauses = [ VersionedClause, diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index be114ea4fb631..d87e65f04eca5 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -2115,4 +2115,27 @@ def AllocateDirOp : OpenMP_Op<"allocate_dir", clauses = [ let hasVerifier = 1; } +//===----------------------------------------------------------------------===// +// workdistribute Construct +//===----------------------------------------------------------------------===// + +def WorkdistributeOp : OpenMP_Op<"workdistribute"> { + let summary = "workdistribute directive"; + let description = [{ + workdistribute divides execution of the enclosed structured block into + separate units of work, each executed only once by each + initial thread in the league. + ``` + !$omp target teams + !$omp workdistribute + y = a * x + y + !$omp end workdistribute + !$omp end target teams + ``` + }]; + let regions = (region AnyRegion:$region); + let hasVerifier = 1; + let assemblyFormat = "$region attr-dict"; +} + #endif // OPENMP_OPS diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index c1c1767ef90b0..bc2e3feacab02 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -3874,6 +3874,58 @@ LogicalResult AllocateDirOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// WorkdistributeOp +//===----------------------------------------------------------------------===// + +LogicalResult WorkdistributeOp::verify() { + // Check that region exists and is not empty + Region ®ion = getRegion(); + if (region.empty()) + return emitOpError("region cannot be empty"); + // Verify single entry point. + Block &entryBlock = region.front(); + if (entryBlock.empty()) + return emitOpError("region must contain a structured block"); + // Verify single exit point. + bool hasTerminator = false; + for (Block &block : region) { + if (isa(block.back())) { + if (hasTerminator) { + return emitOpError("region must have exactly one terminator"); + } + hasTerminator = true; + } + } + if (!hasTerminator) { + return emitOpError("region must be terminated with omp.terminator"); + } + auto walkResult = region.walk([&](Operation *op) -> WalkResult { + // No implicit barrier at end + if (isa(op)) { + return emitOpError( + "explicit barriers are not allowed in workdistribute region"); + } + // Check for invalid nested constructs + if (isa(op)) { + return emitOpError( + "nested parallel constructs not allowed in workdistribute"); + } + if (isa(op)) { + return emitOpError( + "nested teams constructs not allowed in workdistribute"); + } + return WalkResult::advance(); + }); + if (walkResult.wasInterrupted()) + return failure(); + + Operation *parentOp = (*this)->getParentOp(); + if (!llvm::dyn_cast(parentOp)) + return emitOpError("workdistribute must be nested under teams"); + return success(); +} + #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc" diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir index 5088f2dfa7d7a..986c3844d0bb9 100644 --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -3017,3 +3017,110 @@ func.func @invalid_allocate_allocator(%arg0 : memref) -> () { return } + +// ----- +func.func @invalid_workdistribute_empty_region() -> () { + omp.teams { + // expected-error @below {{region cannot be empty}} + omp.workdistribute { + } + omp.terminator + } + return +} + +// ----- +func.func @invalid_workdistribute_no_terminator() -> () { + omp.teams { + // expected-error @below {{region must be terminated with omp.terminator}} + omp.workdistribute { + %c0 = arith.constant 0 : i32 + } + omp.terminator + } + return +} + +// ----- +func.func @invalid_workdistribute_wrong_terminator() -> () { + omp.teams { + // expected-error @below {{region must be terminated with omp.terminator}} + omp.workdistribute { + %c0 = arith.constant 0 : i32 + func.return + } + omp.terminator + } + return +} + +// ----- +func.func @invalid_workdistribute_multiple_terminators() -> () { + omp.teams { + // expected-error @below {{region must have exactly one terminator}} + omp.workdistribute { + %cond = arith.constant true + cf.cond_br %cond, ^bb1, ^bb2 + ^bb1: + omp.terminator + ^bb2: + omp.terminator + } + omp.terminator + } + return +} + +// ----- +func.func @invalid_workdistribute_with_barrier() -> () { + omp.teams { + // expected-error @below {{explicit barriers are not allowed in workdistribute region}} + omp.workdistribute { + %c0 = arith.constant 0 : i32 + omp.barrier + omp.terminator + } + omp.terminator + } + return +} + +// ----- +func.func @invalid_workdistribute_nested_parallel() -> () { + omp.teams { + // expected-error @below {{nested parallel constructs not allowed in workdistribute}} + omp.workdistribute { + omp.parallel { + omp.terminator + } + omp.terminator + } + omp.terminator + } + return +} + +// ----- +// Test: nested teams not allowed in workdistribute +func.func @invalid_workdistribute_nested_teams() -> () { + omp.teams { + // expected-error @below {{nested teams constructs not allowed in workdistribute}} + omp.workdistribute { + omp.teams { + omp.terminator + } + omp.terminator + } + omp.terminator + } + return +} + +// ----- +func.func @invalid_workdistribute() -> () { +// expected-error @below {{workdistribute must be nested under teams}} + omp.workdistribute { + omp.terminator + } + return +} diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index 8c846cde1a3ca..3c2e0a3b7cc15 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -3238,3 +3238,15 @@ func.func @omp_allocate_dir(%arg0 : memref, %arg1 : memref) -> () { return } +// CHECK-LABEL: func.func @omp_workdistribute +func.func @omp_workdistribute() { + // CHECK: omp.teams + omp.teams { + // CHECK: omp.workdistribute + omp.workdistribute { + omp.terminator + } + omp.terminator + } + return +}