36 changes: 23 additions & 13 deletions mlir/docs/Bindings/Python.md
@@ -216,13 +216,28 @@ added to an attached operation, they need to be re-parented to the containing
module).

Due to the validity and parenting accounting needs, `PyOperation` is the owner
for regions and blocks and needs to be a top-level type that we can count on not
aliasing. This let's us do things like selectively invalidating instances when
mutations occur without worrying that there is some alias to the same operation
in the hierarchy. Operations are also the only entity that are allowed to be in
a detached state, and they are interned at the context level so that there is
never more than one Python `mlir.ir.Operation` object for a unique
`MlirOperation`, regardless of how it is obtained.
for regions and blocks. Operations are also the only entities which are allowed to be in
a detached state.

**Note**: Multiple `PyOperation` objects (i.e., the Python objects themselves) can alias a single `mlir::Operation`.
This means, for example, that if `py_op1` and `py_op2` wrap the same `mlir::Operation op`
and you transform `op` (e.g., by running a pass on it), then walking the MLIR AST via either `py_op1` or `py_op2`
will reflect the same MLIR AST. This is perfectly safe and supported. What is not supported is invalidating any
operation while there exist multiple Python objects wrapping that operation **and then manipulating those wrappers**.
For example, suppose `py_op1` and `py_op2` wrap the same operation under a root `py_op3`, and `py_op3` is then
transformed such that the operation referenced by `py_op1` and `py_op2` is erased. Then `py_op1` and `py_op2`
become "undefined" in a sense; manipulating them in any way is "formally forbidden". Note that this also applies to
`SymbolTable` mutation, which is considered a transformation of the root `SymbolTable`-supporting operation for the
purposes of this discussion. Metaphorically, this is similar to how STL container iterators are invalidated once the container itself is changed. The "best practices" recommendation is to structure your code as follows:

1. First, query/manipulate the various Python wrapper objects `py_op1`, `py_op2`, `py_op3`, etc.;
2. Second, transform the AST, erase operations, etc. via a single root object;
3. Finally, invalidate all queried nodes (e.g., using `op._set_invalid()`).

Ideally this should be done in a function body, so that step (3) corresponds to the end of the function and there is no
risk of Python wrapper objects leaking or living longer than necessary. In summary, scope your changes by
nesting: change leaf nodes first, then move up the hierarchy, and only in very rare cases query nested ops after
modifying a parent op.
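
The three steps above can be sketched with a pure-Python toy model. This is not the `mlir.ir` API: `ToyIR`, `ToyOp`, `erase`, `name`, and `rewrite` are hypothetical names invented for illustration, and only `_set_invalid` mirrors a method name that appears in this change. The sketch assumes that wrappers are not interned, so two wrappers can alias one underlying operation, and that equality follows the wrapped handle rather than wrapper identity.

```python
class ToyIR:
    """Stands in for the native IR: a set of live operation handles."""

    def __init__(self, handles):
        self.live = set(handles)

    def erase(self, handle):
        # Native-side erasure; existing wrappers are not notified.
        self.live.discard(handle)


class ToyOp:
    """Hypothetical, non-interned wrapper around one operation handle."""

    def __init__(self, ir, handle):
        self.ir = ir
        self.handle = handle
        self.valid = True

    def _set_invalid(self):
        self.valid = False

    def name(self):
        if not self.valid:
            raise RuntimeError("the operation has been invalidated")
        return f"op{self.handle}"

    def __eq__(self, other):
        # Equality follows the wrapped handle, not wrapper identity.
        return isinstance(other, ToyOp) and self.handle == other.handle

    def __hash__(self):
        return hash(self.handle)


def rewrite(ir):
    # Step 1: query through as many wrappers as needed.
    py_op1 = ToyOp(ir, 7)
    py_op2 = ToyOp(ir, 7)  # aliases the same underlying operation
    names = [py_op1.name(), py_op2.name()]
    # Step 2: transform through a single root (here, the toy IR itself).
    ir.erase(7)
    # Step 3: invalidate every wrapper that was queried.
    for op in (py_op1, py_op2):
        op._set_invalid()
    # The wrappers are returned only so their invalid state can be inspected.
    return names, py_op1, py_op2


ir = ToyIR({7, 8})
names, a, b = rewrite(ir)
```

In real code the wrappers would normally not escape the function at all; the toy returns them only to show that any later use of an invalidated wrapper fails loudly instead of touching an erased operation.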

The C/C++ API allows for Region/Block to also be detached, but it simplifies the
ownership model a lot to eliminate that possibility in this API, allowing the
@@ -238,11 +253,6 @@ blocks. We may end up needing an op-local one at some point TBD, depending on
how hard it is to guarantee how mutations interact with their Python peer
objects. We can cross that bridge easily when we get there.

Module, when used purely from the Python API, can't alias anyway, so we can use
it as a top-level ref type without a live-list for interning. If the API ever
changes such that this cannot be guaranteed (i.e. by letting you marshal a
native-defined Module in), then there would need to be a live table for it too.
Comment on lines -241 to -244
Contributor Author

@makslevental, Aug 26, 2025

This isn't even valid/true anymore: before this PR we already have a `liveModules` live-list, and `Module._CAPICreate` is exactly for "letting you marshal a native-defined Module in".


## User-level API

### Context Management
@@ -1229,4 +1239,4 @@ The exceptions to the free-threading compatibility:
- Usage of `Location.emit_error` is unsafe (due to thread-unsafe `llvm::raw_ostream`).
- Usage of `Module.dump` is unsafe (due to thread-unsafe `llvm::raw_ostream`).
- Usage of `mlir.dialects.transform.interpreter` is unsafe.
- Usage of `mlir.dialects.gpu` and `gpu-module-to-binary` is unsafe.
- Usage of `mlir.dialects.gpu` and `gpu-module-to-binary` is unsafe.
3 changes: 3 additions & 0 deletions mlir/include/mlir-c/IR.h
@@ -415,6 +415,9 @@ MLIR_CAPI_EXPORTED MlirOperation mlirModuleGetOperation(MlirModule module);
/// The returned module is null when the input operation was not a ModuleOp.
MLIR_CAPI_EXPORTED MlirModule mlirModuleFromOperation(MlirOperation op);

/// Checks if two modules are equal.
MLIR_CAPI_EXPORTED bool mlirModuleEqual(MlirModule lhs, MlirModule rhs);

//===----------------------------------------------------------------------===//
// Operation state.
//===----------------------------------------------------------------------===//
196 changes: 46 additions & 150 deletions mlir/lib/Bindings/Python/IRCore.cpp
@@ -67,6 +67,12 @@ Returns a new MlirModule or raises an MLIRError if the parsing fails.
See also: https://mlir.llvm.org/docs/LangRef/
)";

static const char kModuleCAPICreate[] =
R"(Creates a Module from a MlirModule wrapped by a capsule (i.e. module._CAPIPtr).
Note this returns a new object BUT _clear_mlir_module(module) must be called to
prevent double-frees (of the underlying mlir::Module).
)";

static const char kOperationCreateDocstring[] =
R"(Creates a new operation.

@@ -702,84 +708,6 @@ size_t PyMlirContext::getLiveCount() {
return getLiveContexts().size();
}

size_t PyMlirContext::getLiveOperationCount() {
nb::ft_lock_guard lock(liveOperationsMutex);
return liveOperations.size();
}

std::vector<PyOperation *> PyMlirContext::getLiveOperationObjects() {
std::vector<PyOperation *> liveObjects;
nb::ft_lock_guard lock(liveOperationsMutex);
for (auto &entry : liveOperations)
liveObjects.push_back(entry.second.second);
return liveObjects;
}

size_t PyMlirContext::clearLiveOperations() {

LiveOperationMap operations;
{
nb::ft_lock_guard lock(liveOperationsMutex);
std::swap(operations, liveOperations);
}
for (auto &op : operations)
op.second.second->setInvalid();
size_t numInvalidated = operations.size();
return numInvalidated;
}

void PyMlirContext::clearOperation(MlirOperation op) {
PyOperation *pyOp;
{
nb::ft_lock_guard lock(liveOperationsMutex);
auto it = liveOperations.find(op.ptr);
if (it == liveOperations.end()) {
return;
}
pyOp = it->second.second;
liveOperations.erase(it);
}
pyOp->setInvalid();
}

void PyMlirContext::clearOperationsInside(PyOperationBase &op) {
using callBackData = struct {
PyOperation &rootOp;
bool rootSeen;
};
callBackData data{op.getOperation(), false};
// Mark all ops below the op that the passmanager will be rooted
// at (but not op itself - note the preorder) as invalid.
MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
void *userData) {
callBackData *data = static_cast<callBackData *>(userData);
if (LLVM_LIKELY(data->rootSeen))
data->rootOp.getOperation().getContext()->clearOperation(op);
else
data->rootSeen = true;
return MlirWalkResult::MlirWalkResultAdvance;
};
mlirOperationWalk(op.getOperation(), invalidatingCallback,
static_cast<void *>(&data), MlirWalkPreOrder);
}
void PyMlirContext::clearOperationsInside(MlirOperation op) {
PyOperationRef opRef = PyOperation::forOperation(getRef(), op);
clearOperationsInside(opRef->getOperation());
}

void PyMlirContext::clearOperationAndInside(PyOperationBase &op) {
MlirOperationWalkCallback invalidatingCallback = [](MlirOperation op,
void *userData) {
PyMlirContextRef &contextRef = *static_cast<PyMlirContextRef *>(userData);
contextRef->clearOperation(op);
return MlirWalkResult::MlirWalkResultAdvance;
};
mlirOperationWalk(op.getOperation(), invalidatingCallback,
&op.getOperation().getContext(), MlirWalkPreOrder);
}

size_t PyMlirContext::getLiveModuleCount() { return liveModules.size(); }

nb::object PyMlirContext::contextEnter(nb::object context) {
return PyThreadContextEntry::pushContext(context);
}
@@ -1151,38 +1079,23 @@ PyLocation &DefaultingPyLocation::resolve() {
PyModule::PyModule(PyMlirContextRef contextRef, MlirModule module)
: BaseContextObject(std::move(contextRef)), module(module) {}

PyModule::~PyModule() {
nb::gil_scoped_acquire acquire;
auto &liveModules = getContext()->liveModules;
assert(liveModules.count(module.ptr) == 1 &&
"destroying module not in live map");
liveModules.erase(module.ptr);
mlirModuleDestroy(module);
}
PyModule::~PyModule() { mlirModuleDestroy(module); }

PyModuleRef PyModule::forModule(MlirModule module) {
MlirContext context = mlirModuleGetContext(module);
PyMlirContextRef contextRef = PyMlirContext::forContext(context);

nb::gil_scoped_acquire acquire;
auto &liveModules = contextRef->liveModules;
auto it = liveModules.find(module.ptr);
if (it == liveModules.end()) {
// Create.
PyModule *unownedModule = new PyModule(std::move(contextRef), module);
// Note that the default return value policy on cast is automatic_reference,
// which does not take ownership (delete will not be called).
// Just be explicit.
nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership);
unownedModule->handle = pyRef;
liveModules[module.ptr] =
std::make_pair(unownedModule->handle, unownedModule);
return PyModuleRef(unownedModule, std::move(pyRef));
}
// Use existing.
PyModule *existing = it->second.second;
nb::object pyRef = nb::borrow<nb::object>(it->second.first);
return PyModuleRef(existing, std::move(pyRef));
// Create.
PyModule *unownedModule = new PyModule(std::move(contextRef), module);
// Note that the default return value policy on cast is `automatic_reference`,
// which means "does not take ownership, does not call delete/dtor".
// We use `take_ownership`, which means "Python will call the C++ destructor
// and delete operator when the Python wrapper is garbage collected", because
// MlirModule actually wraps OwningOpRef<ModuleOp> (see mlirModuleCreateParse
// etc).
nb::object pyRef = nb::cast(unownedModule, nb::rv_policy::take_ownership);
unownedModule->handle = pyRef;
return PyModuleRef(unownedModule, std::move(pyRef));
}

nb::object PyModule::createFromCapsule(nb::object capsule) {
@@ -1207,15 +1120,11 @@ PyOperation::~PyOperation() {
// If the operation has already been invalidated there is nothing to do.
if (!valid)
return;

// Otherwise, invalidate the operation and remove it from live map when it is
// attached.
if (isAttached()) {
getContext()->clearOperation(*this);
} else {
// And destroy it when it is detached, i.e. owned by Python, in which case
// all nested operations must be invalidated at removed from the live map as
// well.
// Otherwise, invalidate the operation when it is attached.
if (isAttached())
setInvalid();
else {
// And destroy it when it is detached, i.e. owned by Python.
erase();
}
}
@@ -1252,35 +1161,15 @@ PyOperationRef PyOperation::createInstance(PyMlirContextRef contextRef,
PyOperationRef PyOperation::forOperation(PyMlirContextRef contextRef,
MlirOperation operation,
nb::object parentKeepAlive) {
nb::ft_lock_guard lock(contextRef->liveOperationsMutex);
auto &liveOperations = contextRef->liveOperations;
auto it = liveOperations.find(operation.ptr);
if (it == liveOperations.end()) {
// Create.
PyOperationRef result = createInstance(std::move(contextRef), operation,
std::move(parentKeepAlive));
liveOperations[operation.ptr] =
std::make_pair(result.getObject(), result.get());
return result;
}
// Use existing.
PyOperation *existing = it->second.second;
nb::object pyRef = nb::borrow<nb::object>(it->second.first);
return PyOperationRef(existing, std::move(pyRef));
return createInstance(std::move(contextRef), operation,
std::move(parentKeepAlive));
}

PyOperationRef PyOperation::createDetached(PyMlirContextRef contextRef,
MlirOperation operation,
nb::object parentKeepAlive) {
nb::ft_lock_guard lock(contextRef->liveOperationsMutex);
auto &liveOperations = contextRef->liveOperations;
assert(liveOperations.count(operation.ptr) == 0 &&
"cannot create detached operation that already exists");
(void)liveOperations;
PyOperationRef created = createInstance(std::move(contextRef), operation,
std::move(parentKeepAlive));
liveOperations[operation.ptr] =
std::make_pair(created.getObject(), created.get());
created->attached = false;
return created;
}
@@ -1652,7 +1541,7 @@ nb::object PyOperation::createOpView() {

void PyOperation::erase() {
checkValid();
getContext()->clearOperationAndInside(*this);
setInvalid();
mlirOperationDestroy(operation);
}

@@ -3023,14 +2912,6 @@ void mlir::python::populateIRCore(nb::module_ &m) {
PyMlirContextRef ref = PyMlirContext::forContext(self.get());
return ref.releaseObject();
})
.def("_get_live_operation_count", &PyMlirContext::getLiveOperationCount)
.def("_get_live_operation_objects",
&PyMlirContext::getLiveOperationObjects)
.def("_clear_live_operations", &PyMlirContext::clearLiveOperations)
.def("_clear_live_operations_inside",
nb::overload_cast<MlirOperation>(
&PyMlirContext::clearOperationsInside))
.def("_get_live_module_count", &PyMlirContext::getLiveModuleCount)
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyMlirContext::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyMlirContext::createFromCapsule)
.def("__enter__", &PyMlirContext::contextEnter)
@@ -3348,7 +3229,9 @@ void mlir::python::populateIRCore(nb::module_ &m) {
//----------------------------------------------------------------------------
nb::class_<PyModule>(m, "Module", nb::is_weak_referenceable())
.def_prop_ro(MLIR_PYTHON_CAPI_PTR_ATTR, &PyModule::getCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule)
.def(MLIR_PYTHON_CAPI_FACTORY_ATTR, &PyModule::createFromCapsule,
kModuleCAPICreate)
.def("_clear_mlir_module", &PyModule::clearMlirModule)
.def_static(
"parse",
[](const std::string &moduleAsm, DefaultingPyMlirContext context) {
@@ -3428,7 +3311,13 @@ void mlir::python::populateIRCore(nb::module_ &m) {
// Defer to the operation's __str__.
return self.attr("operation").attr("__str__")();
},
kOperationStrDunderDocstring);
kOperationStrDunderDocstring)
.def(
"__eq__",
[](PyModule &self, PyModule &other) {
return mlirModuleEqual(self.get(), other.get());
},
"other"_a);

//----------------------------------------------------------------------------
// Mapping of Operation.
@@ -3440,7 +3329,8 @@ void mlir::python::populateIRCore(nb::module_ &m) {
})
.def("__eq__",
[](PyOperationBase &self, PyOperationBase &other) {
return &self.getOperation() == &other.getOperation();
return mlirOperationEqual(self.getOperation().get(),
other.getOperation().get());
})
.def("__eq__",
[](PyOperationBase &self, nb::object other) { return false; })
@@ -3655,7 +3545,9 @@ void mlir::python::populateIRCore(nb::module_ &m) {
[](PyOperationBase &self) {
return PyOpSuccessors(self.getOperation().getRef());
},
"Returns the list of Operation successors.");
"Returns the list of Operation successors.")
.def("_set_invalid", &PyOperation::setInvalid,
"Invalidate the operation.");

auto opViewClass =
nb::class_<PyOpView, PyOperationBase>(m, "OpView")
@@ -3699,7 +3591,11 @@ void mlir::python::populateIRCore(nb::module_ &m) {
[](PyOperationBase &self) {
return PyOpSuccessors(self.getOperation().getRef());
},
"Returns the list of Operation successors.");
"Returns the list of Operation successors.")
.def(
"_set_invalid",
[](PyOpView &self) { self.getOperation().setInvalid(); },
"Invalidate the operation.");
opViewClass.attr("_ODS_REGIONS") = nb::make_tuple(0, true);
opViewClass.attr("_ODS_OPERAND_SEGMENTS") = nb::none();
opViewClass.attr("_ODS_RESULT_SEGMENTS") = nb::none();