Skip to content

Conversation

matthias-springer
Copy link
Member

Make PyShapedType public, so that downstream projects can define types that implement the ShapedType type interface in Python.

Make `PyShapedType` public, so that downstream projects can define types that implement the `ShapedType` type interface in Python.
@llvmbot llvmbot added the mlir label Aug 26, 2024
@llvmbot
Copy link
Member

llvmbot commented Aug 26, 2024

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

Make PyShapedType public, so that downstream projects can define types that implement the ShapedType type interface in Python.


Full diff: https://github.com/llvm/llvm-project/pull/106105.diff

2 Files Affected:

  • (added) mlir/include/mlir/Bindings/Python/IRTypes.h (+31)
  • (modified) mlir/lib/Bindings/Python/IRTypes.cpp (+91-89)
diff --git a/mlir/include/mlir/Bindings/Python/IRTypes.h b/mlir/include/mlir/Bindings/Python/IRTypes.h
new file mode 100644
index 00000000000000..9afad4c23b3f35
--- /dev/null
+++ b/mlir/include/mlir/Bindings/Python/IRTypes.h
@@ -0,0 +1,31 @@
+//===- IRTypes.h - Type Interfaces ----------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_BINDINGS_PYTHON_IRTYPES_H
+#define MLIR_BINDINGS_PYTHON_IRTYPES_H
+
+#include "mlir/Bindings/Python/PybindAdaptors.h"
+
+namespace mlir {
+
+/// Shaped Type Interface - ShapedType
+class PyShapedType : public python::PyConcreteType<PyShapedType> {
+public:
+  static const IsAFunctionTy isaFunction;
+  static constexpr const char *pyClassName = "ShapedType";
+  using PyConcreteType::PyConcreteType;
+
+  static void bindDerived(ClassTy &c);
+
+private:
+  void requireHasRank();
+};
+
+} // namespace mlir
+
+#endif // MLIR_BINDINGS_PYTHON_IRTYPES_H
diff --git a/mlir/lib/Bindings/Python/IRTypes.cpp b/mlir/lib/Bindings/Python/IRTypes.cpp
index c3d42c0ef8e3cb..2ee1d89c38884d 100644
--- a/mlir/lib/Bindings/Python/IRTypes.cpp
+++ b/mlir/lib/Bindings/Python/IRTypes.cpp
@@ -10,6 +10,8 @@
 
 #include "PybindUtils.h"
 
+#include "mlir/Bindings/Python/IRTypes.h"
+
 #include "mlir-c/BuiltinAttributes.h"
 #include "mlir-c/BuiltinTypes.h"
 #include "mlir-c/Support.h"
@@ -418,98 +420,98 @@ class PyComplexType : public PyConcreteType<PyComplexType> {
   }
 };
 
-class PyShapedType : public PyConcreteType<PyShapedType> {
-public:
-  static constexpr IsAFunctionTy isaFunction = mlirTypeIsAShaped;
-  static constexpr const char *pyClassName = "ShapedType";
-  using PyConcreteType::PyConcreteType;
+} // namespace
 
-  static void bindDerived(ClassTy &c) {
-    c.def_property_readonly(
-        "element_type",
-        [](PyShapedType &self) { return mlirShapedTypeGetElementType(self); },
-        "Returns the element type of the shaped type.");
-    c.def_property_readonly(
-        "has_rank",
-        [](PyShapedType &self) -> bool { return mlirShapedTypeHasRank(self); },
-        "Returns whether the given shaped type is ranked.");
-    c.def_property_readonly(
-        "rank",
-        [](PyShapedType &self) {
-          self.requireHasRank();
-          return mlirShapedTypeGetRank(self);
-        },
-        "Returns the rank of the given ranked shaped type.");
-    c.def_property_readonly(
-        "has_static_shape",
-        [](PyShapedType &self) -> bool {
-          return mlirShapedTypeHasStaticShape(self);
-        },
-        "Returns whether the given shaped type has a static shape.");
-    c.def(
-        "is_dynamic_dim",
-        [](PyShapedType &self, intptr_t dim) -> bool {
-          self.requireHasRank();
-          return mlirShapedTypeIsDynamicDim(self, dim);
-        },
-        py::arg("dim"),
-        "Returns whether the dim-th dimension of the given shaped type is "
-        "dynamic.");
-    c.def(
-        "get_dim_size",
-        [](PyShapedType &self, intptr_t dim) {
-          self.requireHasRank();
-          return mlirShapedTypeGetDimSize(self, dim);
-        },
-        py::arg("dim"),
-        "Returns the dim-th dimension of the given ranked shaped type.");
-    c.def_static(
-        "is_dynamic_size",
-        [](int64_t size) -> bool { return mlirShapedTypeIsDynamicSize(size); },
-        py::arg("dim_size"),
-        "Returns whether the given dimension size indicates a dynamic "
-        "dimension.");
-    c.def(
-        "is_dynamic_stride_or_offset",
-        [](PyShapedType &self, int64_t val) -> bool {
-          self.requireHasRank();
-          return mlirShapedTypeIsDynamicStrideOrOffset(val);
-        },
-        py::arg("dim_size"),
-        "Returns whether the given value is used as a placeholder for dynamic "
-        "strides and offsets in shaped types.");
-    c.def_property_readonly(
-        "shape",
-        [](PyShapedType &self) {
-          self.requireHasRank();
-
-          std::vector<int64_t> shape;
-          int64_t rank = mlirShapedTypeGetRank(self);
-          shape.reserve(rank);
-          for (int64_t i = 0; i < rank; ++i)
-            shape.push_back(mlirShapedTypeGetDimSize(self, i));
-          return shape;
-        },
-        "Returns the shape of the ranked shaped type as a list of integers.");
-    c.def_static(
-        "get_dynamic_size", []() { return mlirShapedTypeGetDynamicSize(); },
-        "Returns the value used to indicate dynamic dimensions in shaped "
-        "types.");
-    c.def_static(
-        "get_dynamic_stride_or_offset",
-        []() { return mlirShapedTypeGetDynamicStrideOrOffset(); },
-        "Returns the value used to indicate dynamic strides or offsets in "
-        "shaped types.");
-  }
+// Shaped Type Interface - ShapedType
+void mlir::PyShapedType::bindDerived(ClassTy &c) {
+  c.def_property_readonly(
+      "element_type",
+      [](PyShapedType &self) { return mlirShapedTypeGetElementType(self); },
+      "Returns the element type of the shaped type.");
+  c.def_property_readonly(
+      "has_rank",
+      [](PyShapedType &self) -> bool { return mlirShapedTypeHasRank(self); },
+      "Returns whether the given shaped type is ranked.");
+  c.def_property_readonly(
+      "rank",
+      [](PyShapedType &self) {
+        self.requireHasRank();
+        return mlirShapedTypeGetRank(self);
+      },
+      "Returns the rank of the given ranked shaped type.");
+  c.def_property_readonly(
+      "has_static_shape",
+      [](PyShapedType &self) -> bool {
+        return mlirShapedTypeHasStaticShape(self);
+      },
+      "Returns whether the given shaped type has a static shape.");
+  c.def(
+      "is_dynamic_dim",
+      [](PyShapedType &self, intptr_t dim) -> bool {
+        self.requireHasRank();
+        return mlirShapedTypeIsDynamicDim(self, dim);
+      },
+      py::arg("dim"),
+      "Returns whether the dim-th dimension of the given shaped type is "
+      "dynamic.");
+  c.def(
+      "get_dim_size",
+      [](PyShapedType &self, intptr_t dim) {
+        self.requireHasRank();
+        return mlirShapedTypeGetDimSize(self, dim);
+      },
+      py::arg("dim"),
+      "Returns the dim-th dimension of the given ranked shaped type.");
+  c.def_static(
+      "is_dynamic_size",
+      [](int64_t size) -> bool { return mlirShapedTypeIsDynamicSize(size); },
+      py::arg("dim_size"),
+      "Returns whether the given dimension size indicates a dynamic "
+      "dimension.");
+  c.def(
+      "is_dynamic_stride_or_offset",
+      [](PyShapedType &self, int64_t val) -> bool {
+        self.requireHasRank();
+        return mlirShapedTypeIsDynamicStrideOrOffset(val);
+      },
+      py::arg("dim_size"),
+      "Returns whether the given value is used as a placeholder for dynamic "
+      "strides and offsets in shaped types.");
+  c.def_property_readonly(
+      "shape",
+      [](PyShapedType &self) {
+        self.requireHasRank();
+
+        std::vector<int64_t> shape;
+        int64_t rank = mlirShapedTypeGetRank(self);
+        shape.reserve(rank);
+        for (int64_t i = 0; i < rank; ++i)
+          shape.push_back(mlirShapedTypeGetDimSize(self, i));
+        return shape;
+      },
+      "Returns the shape of the ranked shaped type as a list of integers.");
+  c.def_static(
+      "get_dynamic_size", []() { return mlirShapedTypeGetDynamicSize(); },
+      "Returns the value used to indicate dynamic dimensions in shaped "
+      "types.");
+  c.def_static(
+      "get_dynamic_stride_or_offset",
+      []() { return mlirShapedTypeGetDynamicStrideOrOffset(); },
+      "Returns the value used to indicate dynamic strides or offsets in "
+      "shaped types.");
+}
 
-private:
-  void requireHasRank() {
-    if (!mlirShapedTypeHasRank(*this)) {
-      throw py::value_error(
-          "calling this method requires that the type has a rank.");
-    }
+void mlir::PyShapedType::requireHasRank() {
+  if (!mlirShapedTypeHasRank(*this)) {
+    throw py::value_error(
+        "calling this method requires that the type has a rank.");
   }
-};
+}
+
+const mlir::PyShapedType::IsAFunctionTy mlir::PyShapedType::isaFunction =
+    mlirTypeIsAShaped;
+
+namespace {
 
 /// Vector Type subclass - VectorType.
 class PyVectorType : public PyConcreteType<PyVectorType, PyShapedType> {

Copy link
Contributor

@stellaraccident stellaraccident left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems ok to me. Thanks for plussing me in: missed this one going by.

@matthias-springer matthias-springer merged commit 82579f9 into main Aug 26, 2024
10 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/python_shaped_type branch August 26, 2024 22:46
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants