diff --git a/mlir/docs/Bindings/Python.md b/mlir/docs/Bindings/Python.md index bef9e7f54948d..8067c5cc217a7 100644 --- a/mlir/docs/Bindings/Python.md +++ b/mlir/docs/Bindings/Python.md @@ -216,13 +216,26 @@ added to an attached operation, they need to be re-parented to the containing module). Due to the validity and parenting accounting needs, `PyOperation` is the owner -for regions and blocks and needs to be a top-level type that we can count on not -aliasing. This let's us do things like selectively invalidating instances when -mutations occur without worrying that there is some alias to the same operation -in the hierarchy. Operations are also the only entity that are allowed to be in -a detached state, and they are interned at the context level so that there is -never more than one Python `mlir.ir.Operation` object for a unique -`MlirOperation`, regardless of how it is obtained. +for regions and blocks. Operations are also the only entities which are allowed to be in +a detached state. + +**Note**: Multiple `PyOperation` objects (i.e., the Python objects themselves) can alias a single `mlir::Operation`. +This means, for example, if you have `py_op1` and `py_op2` which wrap the same `mlir::Operation op` +and you somehow transform `op` (e.g., you run a pass on `op`) then walking the MLIR AST via either/or `py_op1`, `py_op2` +will reflect the same MLIR AST. This is perfectly safe and supported. What is not supported is invalidating any +operation while there exist multiple Python objects wrapping that operation **and then manipulating those wrappers**. +For example if `py_op1` and `py_op2` wrap the same operation under a root `py_op3` and then `py_op3` is +transformed such that the operation referenced (by `py_op1`, `py_op2`) is erased. Then `py_op1`, `py_op2` +become "undefined" in a sense; manipulating them in any way is "formally forbidden". Note, this also applies to +`SymbolTable` mutation, which is considered a transformation of the root `SymbolTable`-supporting operation for the +purposes of the discussion here. Metaphorically, one can think of this similarly to how STL container iterators are invalidated once the container itself is changed. The "best practices" recommendation is to structure your code such that + +1. First, query/manipulate various Python wrapper objects `py_op1`, `py_op2`, `py_op3`, etc.; +2. Second, Transform the AST/erase operations/etc. via a single root object; +3. End. + +Ideally this should be done in a function body so that "End" corresponds to the end of the function and there are no +risks of Python wrapper objects leaking/living longer than necessary. The C/C++ API allows for Region/Block to also be detached, but it simplifies the ownership model a lot to eliminate that possibility in this API, allowing the @@ -238,11 +251,6 @@ blocks. We may end up needing an op-local one at some point TBD, depending on how hard it is to guarantee how mutations interact with their Python peer objects. We can cross that bridge easily when we get there. -Module, when used purely from the Python API, can't alias anyway, so we can use -it as a top-level ref type without a live-list for interning. If the API ever -changes such that this cannot be guaranteed (i.e. by letting you marshal a -native-defined Module in), then there would need to be a live table for it too. - ## User-level API ### Context Management @@ -1229,4 +1237,4 @@ The exceptions to the free-threading compatibility: - Usage of `Location.emit_error` is unsafe (due to thread-unsafe `llvm::raw_ostream`). - Usage of `Module.dump` is unsafe (due to thread-unsafe `llvm::raw_ostream`). - Usage of `mlir.dialects.transform.interpreter` is unsafe. -- Usage of `mlir.dialects.gpu` and `gpu-module-to-binary` is unsafe. \ No newline at end of file +- Usage of `mlir.dialects.gpu` and `gpu-module-to-binary` is unsafe. diff --git a/mlir/include/mlir-c/IR.h b/mlir/include/mlir-c/IR.h index 71c7d4378677f..e97369778b377 100644 --- a/mlir/include/mlir-c/IR.h +++ b/mlir/include/mlir-c/IR.h @@ -415,6 +415,9 @@ MLIR_CAPI_EXPORTED MlirOperation mlirModuleGetOperation(MlirModule module); /// The returned module is null when the input operation was not a ModuleOp. MLIR_CAPI_EXPORTED MlirModule mlirModuleFromOperation(MlirOperation op); +/// Checks if two modules are equal. +MLIR_CAPI_EXPORTED bool mlirModuleEqual(MlirModule lhs, MlirModule rhs); + //===----------------------------------------------------------------------===// // Operation state. //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 15889ddabd2c4..8ab8901cdc41f 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -67,6 +67,12 @@ Returns a new MlirModule or raises an MLIRError if the parsing fails. See also: https://mlir.llvm.org/docs/LangRef/ )"; +static const char kModuleCAPICreate[] = + R"(Creates a Module from a MlirModule wrapped by a capsule (i.e. module._CAPIPtr). +Note this returns a new object BUT _clear_mlir_module(module) must be called to +prevent double-frees (of the underlying mlir::Module). +)"; + static const char kOperationCreateDocstring[] = R"(Creates a new operation. @@ -702,84 +708,6 @@ size_t PyMlirContext::getLiveCount() { return getLiveContexts().size(); } -size_t PyMlirContext::getLiveOperationCount() { - nb::ft_lock_guard lock(liveOperationsMutex); - return liveOperations.size(); -} - -std::vector PyMlirContext::getLiveOperationObjects() { - std::vector liveObjects; - nb::ft_lock_guard lock(liveOperationsMutex); - for (auto &entry : liveOperations) - liveObjects.push_back(entry.second.second); - return liveObjects; -} - -size_t PyMlirContext::clearLiveOperations() { - - LiveOperationMap operations; - { - nb::ft_lock_guard lock(liveOperationsMutex); - std::swap(operations, liveOperations); - } - for (auto &op : operations) - op.second.second->setInvalid(); - size_t numInvalidated = operations.size(); - return numInvalidated; -} - -void PyMlirContext::clearOperation(MlirOperation op) { - PyOperation *pyOp; - { - nb::ft_lock_guard lock(liveOperationsMutex); - auto it = liveOperations.find(op.ptr); - if (it == liveOperations.end()) { - return; - } - pyOp = it->second.second; - liveOperations.erase(it); - } - pyOp->setInvalid(); -} - -void PyMlirContext::clearOperationsInside(PyOperationBase &op) { - using callBackData = struct { - PyOperation &rootOp; - bool rootSeen; - }; - callBackData data{op.getOperation(), false}; - // Mark all ops below the op that the passmanager will be rooted - // at (but not op itself - note the preorder) as invalid. - MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op, - void *userData) { - callBackData *data = static_cast(userData); - if (LLVM_LIKELY(data->rootSeen)) - data->rootOp.getOperation().getContext()->clearOperation(op); - else - data->rootSeen = true; - return MlirWalkResult::MlirWalkResultAdvance; - }; - mlirOperationWalk(op.getOperation(), invalidatingCallback, - static_cast(&data), MlirWalkPreOrder); -} -void PyMlirContext::clearOperationsInside(MlirOperation op) { - PyOperationRef opRef = PyOperation::forOperation(getRef(), op); - clearOperationsInside(opRef->getOperation()); -} - -void PyMlirContext::clearOperationAndInside(PyOperationBase &op) { - MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op, - void *userData) { - PyMlirContextRef &contextRef = *static_cast(userData); - contextRef->clearOperation(op); - return MlirWalkResult::MlirWalkResultAdvance; - }; - mlirOperationWalk(op.getOperation(), invalidatingCallback, - &op.getOperation().getContext(), MlirWalkPreOrder); -} - -size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); } - nb::object PyMlirContext::contextEnter(nb::object context) { return PyThreadContextEntry::pushContext(context); } @@ -1151,38 +1079,23 @@ PyLocation &DefaultingPyLocation::resolve() { PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module) : BaseContextObject(std::move(contextRef)), module(module) {} -PyModule::~PyModule() { - nb::gil_scoped_acquire acquire; - auto &liveModules = getContext()->liveModules; - assert(liveModules.count(module.ptr) == 1 && - "destroying module not in live map"); - liveModules.erase(module.ptr); - mlirModuleDestroy(module); -} +PyModule::~PyModule() { mlirModuleDestroy(module); } PyModuleRef PyModule::forModule(MlirModule module) { MlirContext context = mlirModuleGetContext(module); PyMlirContextRef contextRef = PyMlirContext::forContext(context); - nb::gil_scoped_acquire acquire; - auto &liveModules = contextRef->liveModules; - auto it = liveModules.find(module.ptr); - if (it == liveModules.end()) { - // Create. - PyModule *unownedModule = new PyModule(std::move(contextRef), module); - // Note that the default return value policy on cast is automatic_reference, - // which does not take ownership (delete will not be called). - // Just be explicit. - nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership); - unownedModule->handle = pyRef; - liveModules[module.ptr] = - std::make_pair(unownedModule->handle, unownedModule); - return PyModuleRef(unownedModule, std::move(pyRef)); - } - // Use existing. - PyModule *existing = it->second.second; - nb::object pyRef = nb::borrow(it->second.first); - return PyModuleRef(existing, std::move(pyRef)); + // Create. + PyModule *unownedModule = new PyModule(std::move(contextRef), module); + // Note that the default return value policy on cast is `automatic_reference`, + // which means "does not take ownership, does not call delete/dtor". + // We use `take_ownership`, which means "Python will call the C++ destructor + // and delete operator when the Python wrapper is garbage collected", because + // MlirModule actually wraps OwningOpRef (see mlirModuleCreateParse + // etc). + nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership); + unownedModule->handle = pyRef; + return PyModuleRef(unownedModule, std::move(pyRef)); } nb::object PyModule::createFromCapsule(nb::object capsule) { @@ -1207,15 +1120,11 @@ PyOperation::~PyOperation() { // If the operation has already been invalidated there is nothing to do. if (!valid) return; - - // Otherwise, invalidate the operation and remove it from live map when it is - // attached. - if (isAttached()) { - getContext()->clearOperation(*this); - } else { - // And destroy it when it is detached, i.e. owned by Python, in which case - // all nested operations must be invalidated at removed from the live map as - // well. + // Otherwise, invalidate the operation when it is attached. + if (isAttached()) + setInvalid(); + else { + // And destroy it when it is detached, i.e. owned by Python. erase(); } } @@ -1252,35 +1161,15 @@ PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef, PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef, MlirOperation operation, nb::object parentKeepAlive) { - nb::ft_lock_guard lock(contextRef->liveOperationsMutex); - auto &liveOperations = contextRef->liveOperations; - auto it = liveOperations.find(operation.ptr); - if (it == liveOperations.end()) { - // Create. - PyOperationRef result = createInstance(std::move(contextRef), operation, - std::move(parentKeepAlive)); - liveOperations[operation.ptr] = - std::make_pair(result.getObject(), result.get()); - return result; - } - // Use existing. - PyOperation *existing = it->second.second; - nb::object pyRef = nb::borrow(it->second.first); - return PyOperationRef(existing, std::move(pyRef)); + return createInstance(std::move(contextRef), operation, + std::move(parentKeepAlive)); } PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef, MlirOperation operation, nb::object parentKeepAlive) { - nb::ft_lock_guard lock(contextRef->liveOperationsMutex); - auto &liveOperations = contextRef->liveOperations; - assert(liveOperations.count(operation.ptr) == 0 && - "cannot create detached operation that already exists"); - (void)liveOperations; PyOperationRef created = createInstance(std::move(contextRef), operation, std::move(parentKeepAlive)); - liveOperations[operation.ptr] = - std::make_pair(created.getObject(), created.get()); created->attached = false; return created; } @@ -1652,7 +1541,7 @@ nb::object PyOperation::createOpView() { void PyOperation::erase() { checkValid(); - getContext()->clearOperationAndInside(*this); + setInvalid(); mlirOperationDestroy(operation); } @@ -3023,14 +2912,6 @@ void mlir::python::populateIRCore(nb::module_ &m) { PyMlirContextRef ref = PyMlirContext::forContext(self.get()); return ref.releaseObject(); }) - .def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount) - .def("_get_live_operation_objects", - &PyMlirContext::getLiveOperationObjects) - .def("_clear_live_operations", &PyMlirContext::clearLiveOperations) - .def("_clear_live_operations_inside", - nb::overload_cast( - &PyMlirContext::clearOperationsInside)) - .def("_get_live_module_count", &PyMlirContext::getLiveModuleCount) .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule) .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule) .def("__enter__", &PyMlirContext::contextEnter) @@ -3348,7 +3229,9 @@ void mlir::python::populateIRCore(nb::module_ &m) { //---------------------------------------------------------------------------- nb::class_(m, "Module", nb::is_weak_referenceable()) .def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule) - .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule) + .def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule, + kModuleCAPICreate) + .def("_clear_mlir_module", &PyModule::clearMlirModule) .def_static( "parse", [](const std::string &moduleAsm, DefaultingPyMlirContext context) { @@ -3428,7 +3311,13 @@ void mlir::python::populateIRCore(nb::module_ &m) { // Defer to the operation's __str__. return self.attr("operation").attr("__str__")(); }, - kOperationStrDunderDocstring); + kOperationStrDunderDocstring) + .def( + "__eq__", + [](PyModule &self, PyModule &other) { + return mlirModuleEqual(self.get(), other.get()); + }, + "other"_a); //---------------------------------------------------------------------------- // Mapping of Operation. @@ -3440,7 +3329,8 @@ void mlir::python::populateIRCore(nb::module_ &m) { }) .def("__eq__", [](PyOperationBase &self, PyOperationBase &other) { - return &self.getOperation() == &other.getOperation(); + return mlirOperationEqual(self.getOperation().get(), + other.getOperation().get()); }) .def("__eq__", [](PyOperationBase &self, nb::object other) { return false; }) diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 6617b41cc916c..0cc0459ebc9a0 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -218,40 +218,6 @@ class PyMlirContext { /// Gets the count of live context objects. Used for testing. static size_t getLiveCount(); - /// Get a list of Python objects which are still in the live context map. - std::vector getLiveOperationObjects(); - - /// Gets the count of live operations associated with this context. - /// Used for testing. - size_t getLiveOperationCount(); - - /// Clears the live operations map, returning the number of entries which were - /// invalidated. To be used as a safety mechanism so that API end-users can't - /// corrupt by holding references they shouldn't have accessed in the first - /// place. - size_t clearLiveOperations(); - - /// Removes an operation from the live operations map and sets it invalid. - /// This is useful for when some non-bindings code destroys the operation and - /// the bindings need to made aware. For example, in the case when pass - /// manager is run. - /// - /// Note that this does *NOT* clear the nested operations. - void clearOperation(MlirOperation op); - - /// Clears all operations nested inside the given op using - /// `clearOperation(MlirOperation)`. - void clearOperationsInside(PyOperationBase &op); - void clearOperationsInside(MlirOperation op); - - /// Clears the operaiton _and_ all operations inside using - /// `clearOperation(MlirOperation)`. - void clearOperationAndInside(PyOperationBase &op); - - /// Gets the count of live modules associated with this context. - /// Used for testing. - size_t getLiveModuleCount(); - /// Enter and exit the context manager. static nanobind::object contextEnter(nanobind::object context); void contextExit(const nanobind::object &excType, @@ -278,25 +244,6 @@ class PyMlirContext { static nanobind::ft_mutex live_contexts_mutex; static LiveContextMap &getLiveContexts(); - // Interns all live modules associated with this context. Modules tracked - // in this map are valid. When a module is invalidated, it is removed - // from this map, and while it still exists as an instance, any - // attempt to access it will raise an error. - using LiveModuleMap = - llvm::DenseMap>; - LiveModuleMap liveModules; - - // Interns all live operations associated with this context. Operations - // tracked in this map are valid. When an operation is invalidated, it is - // removed from this map, and while it still exists as an instance, any - // attempt to access it will raise an error. - using LiveOperationMap = - llvm::DenseMap>; - nanobind::ft_mutex liveOperationsMutex; - - // Guarded by liveOperationsMutex in free-threading mode. - LiveOperationMap liveOperations; - bool emitErrorDiagnostics = false; MlirContext context; @@ -548,8 +495,8 @@ class PyModule; using PyModuleRef = PyObjectRef; class PyModule : public BaseContextObject { public: - /// Returns a PyModule reference for the given MlirModule. This may return - /// a pre-existing or new object. + /// Returns a PyModule reference for the given MlirModule. This always returns + /// a new object. static PyModuleRef forModule(MlirModule module); PyModule(PyModule &) = delete; PyModule(PyMlirContext &&) = delete; @@ -570,11 +517,12 @@ class PyModule : public BaseContextObject { nanobind::object getCapsule(); /// Creates a PyModule from the MlirModule wrapped by a capsule. - /// Note that PyModule instances are uniqued, so the returned object - /// may be a pre-existing object. Ownership of the underlying MlirModule - /// is taken by calling this function. + /// Note this returns a new object BUT clearMlirModule() must be called to + /// prevent double-frees (of the underlying mlir::Module). static nanobind::object createFromCapsule(nanobind::object capsule); + void clearMlirModule() { module = {nullptr}; } + private: PyModule(PyMlirContextRef contextRef, MlirModule module); MlirModule module; diff --git a/mlir/lib/Bindings/Python/Pass.cpp b/mlir/lib/Bindings/Python/Pass.cpp index 20017e25b69bb..817479ee2421b 100644 --- a/mlir/lib/Bindings/Python/Pass.cpp +++ b/mlir/lib/Bindings/Python/Pass.cpp @@ -159,11 +159,7 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { "ValueError if the pipeline can't be parsed.") .def( "run", - [](PyPassManager &passManager, PyOperationBase &op, - bool invalidateOps) { - if (invalidateOps) { - op.getOperation().getContext()->clearOperationsInside(op); - } + [](PyPassManager &passManager, PyOperationBase &op) { // Actually run the pass manager. PyMlirContext::ErrorCapture errors(op.getOperation().getContext()); MlirLogicalResult status = mlirPassManagerRunOnOp( @@ -172,7 +168,7 @@ void mlir::python::populatePassManagerSubmodule(nb::module_ &m) { throw MLIRError("Failure while executing pass pipeline", errors.take()); }, - "operation"_a, "invalidate_ops"_a = true, + "operation"_a, "Run the pass manager on the provided operation, raising an " "MLIRError on failure.") .def( diff --git a/mlir/lib/Bindings/Python/TransformInterpreter.cpp b/mlir/lib/Bindings/Python/TransformInterpreter.cpp index f9b0fed62778f..920bca886f617 100644 --- a/mlir/lib/Bindings/Python/TransformInterpreter.cpp +++ b/mlir/lib/Bindings/Python/TransformInterpreter.cpp @@ -67,7 +67,6 @@ static void populateTransformInterpreterSubmodule(nb::module_ &m) { // root. This is awkward, but we don't have access to PyMlirContext // object here otherwise. nb::object obj = nb::cast(payloadRoot); - obj.attr("context").attr("_clear_live_operations_inside")(payloadRoot); MlirLogicalResult result = mlirTransformApplyNamedSequence( payloadRoot, transformRoot, transformModule, options.options); diff --git a/mlir/lib/CAPI/IR/IR.cpp b/mlir/lib/CAPI/IR/IR.cpp index 8491553dab76f..c7069f0017b5d 100644 --- a/mlir/lib/CAPI/IR/IR.cpp +++ b/mlir/lib/CAPI/IR/IR.cpp @@ -465,6 +465,10 @@ MlirModule mlirModuleFromOperation(MlirOperation op) { return wrap(dyn_cast(unwrap(op))); } +bool mlirModuleEqual(MlirModule lhs, MlirModule rhs) { + return unwrap(lhs) == unwrap(rhs); +} + //===----------------------------------------------------------------------===// // Operation state API. //===----------------------------------------------------------------------===// diff --git a/mlir/test/python/ir/module.py b/mlir/test/python/ir/module.py index 6065e59fd6ed9..a552eaa662af4 100644 --- a/mlir/test/python/ir/module.py +++ b/mlir/test/python/ir/module.py @@ -121,27 +121,17 @@ def testRoundtripBinary(): def testModuleOperation(): ctx = Context() module = Module.parse(r"""module @successfulParse {}""", ctx) - assert ctx._get_live_module_count() == 1 op1 = module.operation - assert ctx._get_live_operation_count() == 1 - live_ops = ctx._get_live_operation_objects() - assert len(live_ops) == 1 - assert live_ops[0] is op1 - live_ops = None # CHECK: module @successfulParse print(op1) # Ensure that operations are the same on multiple calls. op2 = module.operation - assert ctx._get_live_operation_count() == 1 - assert op1 is op2 + assert not op1 is op2 + assert op1 == op2 # Test live operation clearing. op1 = module.operation - assert ctx._get_live_operation_count() == 1 - num_invalidated = ctx._clear_live_operations() - assert num_invalidated == 1 - assert ctx._get_live_operation_count() == 0 op1 = None gc.collect() op1 = module.operation @@ -155,9 +145,6 @@ def testModuleOperation(): op1 = None op2 = None gc.collect() - print("LIVE OPERATIONS:", ctx._get_live_operation_count()) - assert ctx._get_live_operation_count() == 0 - assert ctx._get_live_module_count() == 0 # CHECK-LABEL: TEST: testModuleCapsule @@ -165,16 +152,17 @@ def testModuleOperation(): def testModuleCapsule(): ctx = Context() module = Module.parse(r"""module @successfulParse {}""", ctx) - assert ctx._get_live_module_count() == 1 # CHECK: "mlir.ir.Module._CAPIPtr" module_capsule = module._CAPIPtr print(module_capsule) module_dup = Module._CAPICreate(module_capsule) - assert module is module_dup + assert not module is module_dup + assert module == module_dup + module._clear_mlir_module() + assert not module == module_dup assert module_dup.context is ctx # Gc and verify destructed. module = None module_capsule = None module_dup = None gc.collect() - assert ctx._get_live_module_count() == 0 diff --git a/mlir/test/python/ir/operation.py b/mlir/test/python/ir/operation.py index bf16e3f75d60d..94f39c0fbd077 100644 --- a/mlir/test/python/ir/operation.py +++ b/mlir/test/python/ir/operation.py @@ -907,7 +907,13 @@ def testCapsuleConversions(): m_capsule = m._CAPIPtr assert '"mlir.ir.Operation._CAPIPtr"' in repr(m_capsule) m2 = Operation._CAPICreate(m_capsule) - assert m2 is m + assert not m2 is m + assert m2 == m + # Gc and verify destructed. + m = None + m_capsule = None + m2 = None + gc.collect() # CHECK-LABEL: TEST: testOperationErase diff --git a/mlir/test/python/ir/symbol_table.py b/mlir/test/python/ir/symbol_table.py index 8b6d7ea5a197d..7afd539271d21 100644 --- a/mlir/test/python/ir/symbol_table.py +++ b/mlir/test/python/ir/symbol_table.py @@ -56,14 +56,6 @@ def testSymbolTableInsert(): print(m1) assert "bar" not in symbol_table - try: - print(bar) - except RuntimeError as e: - if "the operation has been invalidated" not in str(e): - raise - else: - assert False, "expected RuntimeError due to invalidated operation" - qux = m2.body.operations[0] m1.body.append(qux) symbol_table.insert(qux) diff --git a/mlir/test/python/pass_manager.py b/mlir/test/python/pass_manager.py index e26d42bb32913..0896cd9784641 100644 --- a/mlir/test/python/pass_manager.py +++ b/mlir/test/python/pass_manager.py @@ -176,14 +176,6 @@ def testRunPipelineError(): @run def testPostPassOpInvalidation(): with Context() as ctx: - log_op_count = lambda: log("live ops:", ctx._get_live_operation_count()) - - # CHECK: invalidate_ops=False - log("invalidate_ops=False") - - # CHECK: live ops: 0 - log_op_count() - module = ModuleOp.parse( """ module { @@ -196,9 +188,6 @@ def testPostPassOpInvalidation(): """ ) - # CHECK: live ops: 1 - log_op_count() - outer_const_op = module.body.operations[0] # CHECK: %[[VAL0:.*]] = arith.constant 10 : i64 log(outer_const_op) @@ -214,12 +203,7 @@ def testPostPassOpInvalidation(): # CHECK: %[[VAL1]] = arith.constant 10 : i64 log(inner_const_op) - # CHECK: live ops: 4 - log_op_count() - - PassManager.parse("builtin.module(canonicalize)").run( - module, invalidate_ops=False - ) + PassManager.parse("builtin.module(canonicalize)").run(module) # CHECK: func.func @foo() { # CHECK: return # CHECK: } @@ -233,9 +217,6 @@ def testPostPassOpInvalidation(): # CHECK: invalidate_ops=True log("invalidate_ops=True") - # CHECK: live ops: 4 - log_op_count() - module = ModuleOp.parse( """ module { @@ -247,36 +228,9 @@ def testPostPassOpInvalidation(): } """ ) - outer_const_op = module.body.operations[0] - func_op = module.body.operations[1] - inner_const_op = func_op.body.blocks[0].operations[0] - - # CHECK: live ops: 4 - log_op_count() PassManager.parse("builtin.module(canonicalize)").run(module) - # CHECK: live ops: 1 - log_op_count() - - try: - log(func_op) - except RuntimeError as e: - # CHECK: the operation has been invalidated - log(e) - - try: - log(outer_const_op) - except RuntimeError as e: - # CHECK: the operation has been invalidated - log(e) - - try: - log(inner_const_op) - except RuntimeError as e: - # CHECK: the operation has been invalidated - log(e) - # CHECK: func.func @foo() { # CHECK: return # CHECK: }