diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h index e620067c15be9..ecd829ed14add 100644 --- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h +++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h @@ -213,6 +213,14 @@ scf::ForOp fuseIndependentSiblingForLoops(scf::ForOp target, scf::ForOp source, FailureOr normalizeForallOp(RewriterBase &rewriter, scf::ForallOp forallOp); +/// Check if the provided loops are perfectly nested for-loops. Perfect nesting +/// means: +/// 1. All loops are scf.for operations +/// 2. Each outer loop's region iter args match the inner loop's init args +/// 3. Each outer loop's yields match the inner loop's results +/// 4. Each region iter arg and result has exactly one use +bool isPerfectlyNestedForLoops(MutableArrayRef loops); + } // namespace mlir #endif // MLIR_DIALECT_SCF_UTILS_UTILS_H_ diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index 250c413eff9e5..834c02126fa53 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -1916,63 +1916,6 @@ static FailureOr getConsumerFromLoopUses(RewriterBase &rewriter, return failure(); } -/// Check that the loop is perfectly nested. -/// The loops are expected to be ordered from outer most to inner most. -/// For example: -/// ``` -/// %0 = scf.for() -/// %1 = scf.for() -/// %2 = scf.for() -/// %3 = ... -/// yield %3 -/// yield %2 -/// yield %1 -/// ``` -/// Here loops should be [%0, %1]. -static bool -isPerfectlyNestedForLoops(MutableArrayRef loops) { - assert(!loops.empty() && "unexpected empty loop nest"); - if (loops.size() == 1) { - return isa_and_nonnull(loops.front().getOperation()); - } - for (auto [outerLoop, innerLoop] : - llvm::zip_equal(loops.drop_back(), loops.drop_front())) { - auto outerFor = dyn_cast_or_null(outerLoop.getOperation()); - auto innerFor = dyn_cast_or_null(innerLoop.getOperation()); - if (!outerFor || !innerFor) { - return false; - } - auto outerBBArgs = outerFor.getRegionIterArgs(); - auto innerIterArgs = innerFor.getInitArgs(); - if (outerBBArgs.size() != innerIterArgs.size()) { - return false; - } - - for (auto [outerBBArg, innerIterArg] : - llvm::zip_equal(outerBBArgs, innerIterArgs)) { - if (!llvm::hasSingleElement(outerBBArg.getUses()) || - innerIterArg != outerBBArg) { - return false; - } - } - - ValueRange outerYields = - cast(outerFor.getBody()->getTerminator())->getOperands(); - ValueRange innerResults = innerFor.getResults(); - if (outerYields.size() != innerResults.size()) { - return false; - } - for (auto [outerYield, innerResult] : - llvm::zip_equal(outerYields, innerResults)) { - if (!llvm::hasSingleElement(innerResult.getUses()) || - outerYield != innerResult) { - return false; - } - } - } - return true; -} - /// Fetch the untiled consumer of the outermost scf.for's result which is /// yielded by a tensor.insert_slice from the innermost scf.for. This function /// makes the following assumptions : diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index 49102583ec5e7..684dff8121de6 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -1512,3 +1512,41 @@ FailureOr mlir::normalizeForallOp(RewriterBase &rewriter, rewriter.replaceOp(forallOp, normalizedForallOp); return normalizedForallOp; } + +bool mlir::isPerfectlyNestedForLoops( + MutableArrayRef loops) { + assert(!loops.empty() && "unexpected empty loop nest"); + if (loops.size() == 1) + return isa_and_nonnull(loops.front().getOperation()); + for (auto [outerLoop, innerLoop] : + llvm::zip_equal(loops.drop_back(), loops.drop_front())) { + auto outerFor = dyn_cast_or_null(outerLoop.getOperation()); + auto innerFor = dyn_cast_or_null(innerLoop.getOperation()); + if (!outerFor || !innerFor) + return false; + auto outerBBArgs = outerFor.getRegionIterArgs(); + auto innerIterArgs = innerFor.getInitArgs(); + if (outerBBArgs.size() != innerIterArgs.size()) + return false; + + for (auto [outerBBArg, innerIterArg] : + llvm::zip_equal(outerBBArgs, innerIterArgs)) { + if (!llvm::hasSingleElement(outerBBArg.getUses()) || + innerIterArg != outerBBArg) + return false; + } + + ValueRange outerYields = + cast(outerFor.getBody()->getTerminator())->getOperands(); + ValueRange innerResults = innerFor.getResults(); + if (outerYields.size() != innerResults.size()) + return false; + for (auto [outerYield, innerResult] : + llvm::zip_equal(outerYields, innerResults)) { + if (!llvm::hasSingleElement(innerResult.getUses()) || + outerYield != innerResult) + return false; + } + } + return true; +}