diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index 22608a16cc1ab..7e5ce26b5f733 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -427,6 +427,21 @@ bool mlir::affine::isValidSymbol(Value value) { return false; } +/// A utility function to check if a value is defined at the top level of +/// `region` or is an argument of `region` or is defined above the region. +static bool isTopLevelValueOrAbove(Value value, Region *region) { + Region *parentRegion = value.getParentRegion(); + do { + if (parentRegion == region) + return true; + Operation *regionOp = region->getParentOp(); + if (regionOp->hasTrait()) + break; + region = region->getParentOp()->getParentRegion(); + } while (region); + return false; +} + /// A value can be used as a symbol for `region` iff it meets one of the /// following conditions: /// *) It is a constant. @@ -445,19 +460,12 @@ bool mlir::affine::isValidSymbol(Value value, Region *region) { return false; // A top-level value is a valid symbol. - if (region && ::isTopLevelValue(value, region)) + if (region && isTopLevelValueOrAbove(value, region)) return true; auto *defOp = value.getDefiningOp(); - if (!defOp) { - // A block argument that is not a top-level value is a valid symbol if it - // dominates region's parent op. - Operation *regionOp = region ? region->getParentOp() : nullptr; - if (regionOp && !regionOp->hasTrait()) - if (auto *parentOpRegion = region->getParentOp()->getParentRegion()) - return isValidSymbol(value, parentOpRegion); + if (!defOp) return false; - } // Constant operation is ok. Attribute operandCst; @@ -475,12 +483,6 @@ bool mlir::affine::isValidSymbol(Value value, Region *region) { if (auto dimOp = dyn_cast(defOp)) return isDimOpValidSymbol(dimOp, region); - // Check for values dominating `region`'s parent op. - Operation *regionOp = region ? region->getParentOp() : nullptr; - if (regionOp && !regionOp->hasTrait()) - if (auto *parentRegion = region->getParentOp()->getParentRegion()) - return isValidSymbol(value, parentRegion); - return false; }