diff --git a/llvm/include/llvm/Frontend/OpenMP/OMP.td b/llvm/include/llvm/Frontend/OpenMP/OMP.td index 202f684d808bc..6eaf19eed40c8 100644 --- a/llvm/include/llvm/Frontend/OpenMP/OMP.td +++ b/llvm/include/llvm/Frontend/OpenMP/OMP.td @@ -1322,6 +1322,17 @@ def OMP_EndWorkshare : Directive<[Spelling<"end workshare">]> { let category = OMP_Workshare.category; let languages = [L_Fortran]; } +def OMP_Workdistribute : Directive<[Spelling<"workdistribute">]> { + let association = AS_Block; + let category = CA_Executable; + let languages = [L_Fortran]; +} +def OMP_EndWorkdistribute : Directive<[Spelling<"end workdistribute">]> { + let leafConstructs = OMP_Workdistribute.leafConstructs; + let association = OMP_Workdistribute.association; + let category = OMP_Workdistribute.category; + let languages = [L_Fortran]; +} //===----------------------------------------------------------------------===// // Definitions of OpenMP compound directives @@ -2482,6 +2493,35 @@ def OMP_TargetTeamsDistributeSimd let leafConstructs = [OMP_Target, OMP_Teams, OMP_Distribute, OMP_Simd]; let category = CA_Executable; } +def OMP_TargetTeamsWorkdistribute : Directive<[Spelling<"target teams workdistribute">]> { + let allowedClauses = [ + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + ]; + let allowedOnceClauses = [ + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + ]; + let leafConstructs = [OMP_Target, OMP_Teams, OMP_Workdistribute]; + let category = CA_Executable; + let languages = [L_Fortran]; +} def OMP_target_teams_loop : Directive<[Spelling<"target teams loop">]> { let allowedClauses = [ VersionedClause, @@ -2723,6 +2763,25 @@ def OMP_TeamsDistributeSimd : Directive<[Spelling<"teams distribute simd">]> { let leafConstructs = [OMP_Teams, OMP_Distribute, OMP_Simd]; let category = CA_Executable; } +def OMP_TeamsWorkdistribute : Directive<[Spelling<"teams workdistribute">]> { + let allowedClauses = [ + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + ]; + let allowedOnceClauses = [ + VersionedClause, + VersionedClause, + VersionedClause, + VersionedClause, + ]; + let leafConstructs = [OMP_Teams, OMP_Workdistribute]; + let category = CA_Executable; + let languages = [L_Fortran]; +} def OMP_teams_loop : Directive<[Spelling<"teams loop">]> { let allowedClauses = [ VersionedClause, diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td index c956d69781b3d..2548a8ab4aac6 100644 --- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td +++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td @@ -2209,4 +2209,27 @@ def TargetFreeMemOp : OpenMP_Op<"target_freemem", let assemblyFormat = "$device `,` $heapref attr-dict `:` type($device) `,` qualified(type($heapref))"; } +//===----------------------------------------------------------------------===// +// workdistribute Construct +//===----------------------------------------------------------------------===// + +def WorkdistributeOp : OpenMP_Op<"workdistribute"> { + let summary = "workdistribute directive"; + let description = [{ + workdistribute divides execution of the enclosed structured block into + separate units of work, each executed only once by each + initial thread in the league. + ``` + !$omp target teams + !$omp workdistribute + y = a * x + y + !$omp end workdistribute + !$omp end target teams + ``` + }]; + let regions = (region AnyRegion:$region); + let hasVerifier = 1; + let assemblyFormat = "$region attr-dict"; +} + #endif // OPENMP_OPS diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index fa94219016c1e..6e43f28e8d93d 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -3975,6 +3975,58 @@ llvm::LogicalResult omp::TargetAllocMemOp::verify() { return mlir::success(); } +//===----------------------------------------------------------------------===// +// WorkdistributeOp +//===----------------------------------------------------------------------===// + +LogicalResult WorkdistributeOp::verify() { + // Check that region exists and is not empty + Region ®ion = getRegion(); + if (region.empty()) + return emitOpError("region cannot be empty"); + // Verify single entry point. + Block &entryBlock = region.front(); + if (entryBlock.empty()) + return emitOpError("region must contain a structured block"); + // Verify single exit point. + bool hasTerminator = false; + for (Block &block : region) { + if (isa(block.back())) { + if (hasTerminator) { + return emitOpError("region must have exactly one terminator"); + } + hasTerminator = true; + } + } + if (!hasTerminator) { + return emitOpError("region must be terminated with omp.terminator"); + } + auto walkResult = region.walk([&](Operation *op) -> WalkResult { + // No implicit barrier at end + if (isa(op)) { + return emitOpError( + "explicit barriers are not allowed in workdistribute region"); + } + // Check for invalid nested constructs + if (isa(op)) { + return emitOpError( + "nested parallel constructs not allowed in workdistribute"); + } + if (isa(op)) { + return emitOpError( + "nested teams constructs not allowed in workdistribute"); + } + return WalkResult::advance(); + }); + if (walkResult.wasInterrupted()) + return failure(); + + Operation *parentOp = (*this)->getParentOp(); + if (!llvm::dyn_cast(parentOp)) + return emitOpError("workdistribute must be nested under teams"); + return success(); +} + #define GET_ATTRDEF_CLASSES #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc" diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir index 5088f2dfa7d7a..986c3844d0bb9 100644 --- a/mlir/test/Dialect/OpenMP/invalid.mlir +++ b/mlir/test/Dialect/OpenMP/invalid.mlir @@ -3017,3 +3017,110 @@ func.func @invalid_allocate_allocator(%arg0 : memref) -> () { return } + +// ----- +func.func @invalid_workdistribute_empty_region() -> () { + omp.teams { + // expected-error @below {{region cannot be empty}} + omp.workdistribute { + } + omp.terminator + } + return +} + +// ----- +func.func @invalid_workdistribute_no_terminator() -> () { + omp.teams { + // expected-error @below {{region must be terminated with omp.terminator}} + omp.workdistribute { + %c0 = arith.constant 0 : i32 + } + omp.terminator + } + return +} + +// ----- +func.func @invalid_workdistribute_wrong_terminator() -> () { + omp.teams { + // expected-error @below {{region must be terminated with omp.terminator}} + omp.workdistribute { + %c0 = arith.constant 0 : i32 + func.return + } + omp.terminator + } + return +} + +// ----- +func.func @invalid_workdistribute_multiple_terminators() -> () { + omp.teams { + // expected-error @below {{region must have exactly one terminator}} + omp.workdistribute { + %cond = arith.constant true + cf.cond_br %cond, ^bb1, ^bb2 + ^bb1: + omp.terminator + ^bb2: + omp.terminator + } + omp.terminator + } + return +} + +// ----- +func.func @invalid_workdistribute_with_barrier() -> () { + omp.teams { + // expected-error @below {{explicit barriers are not allowed in workdistribute region}} + omp.workdistribute { + %c0 = arith.constant 0 : i32 + omp.barrier + omp.terminator + } + omp.terminator + } + return +} + +// ----- +func.func @invalid_workdistribute_nested_parallel() -> () { + omp.teams { + // expected-error @below {{nested parallel constructs not allowed in workdistribute}} + omp.workdistribute { + omp.parallel { + omp.terminator + } + omp.terminator + } + omp.terminator + } + return +} + +// ----- +// Test: nested teams not allowed in workdistribute +func.func @invalid_workdistribute_nested_teams() -> () { + omp.teams { + // expected-error @below {{nested teams constructs not allowed in workdistribute}} + omp.workdistribute { + omp.teams { + omp.terminator + } + omp.terminator + } + omp.terminator + } + return +} + +// ----- +func.func @invalid_workdistribute() -> () { +// expected-error @below {{workdistribute must be nested under teams}} + omp.workdistribute { + omp.terminator + } + return +} diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir index 8c846cde1a3ca..3c2e0a3b7cc15 100644 --- a/mlir/test/Dialect/OpenMP/ops.mlir +++ b/mlir/test/Dialect/OpenMP/ops.mlir @@ -3238,3 +3238,15 @@ func.func @omp_allocate_dir(%arg0 : memref, %arg1 : memref) -> () { return } +// CHECK-LABEL: func.func @omp_workdistribute +func.func @omp_workdistribute() { + // CHECK: omp.teams + omp.teams { + // CHECK: omp.workdistribute + omp.workdistribute { + omp.terminator + } + omp.terminator + } + return +}