Skip to content

Conversation

yangtetris
Copy link
Contributor

@yangtetris yangtetris commented Aug 21, 2025

This PR is a follow-up to #151175 that supported lowering multi-dimensional vector.from_elements op to LLVM by introducing a unrolling pattern.

Changes

Add vector.shape_cast based flattening pattern for vector.from_elements

This change introduces a new linearization pattern that uses vector.shape_cast to flatten multi-dimensional vector.from_elements operations. This provides an alternative approach to the unrolling-based method introduced in #151175.

Example:

// Before
%v = vector.from_elements %e0, %e1, %e2, %e3 : vector<2x2xf32>

// After
%flat = vector.from_elements %e0, %e1, %e2, %e3 : vector<4xf32>
%result = vector.shape_cast %flat : vector<4xf32> to vector<2x2xf32>

@llvmbot
Copy link
Member

llvmbot commented Aug 21, 2025

@llvm/pr-subscribers-mlir-spirv
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Author: Yang Bai (yangtetris)

Changes

This PR is a follow-up to #151175 that supported lowering multi-dimensional vector.from_elements op to LLVM by introducing a unrolling pattern.

Changes

1. Add vector.shape_cast based flattening pattern for vector.from_elements

This change introduces a new linearization pattern that uses vector.shape_cast to flatten multi-dimensional vector.from_elements operations. This provides an alternative approach to the unrolling-based method introduced in #151175.

Example:

// Before
%v = vector.from_elements %e0, %e1, %e2, %e3 : vector&lt;2x2xf32&gt;

// After
%flat = vector.from_elements %e0, %e1, %e2, %e3 : vector&lt;4xf32&gt;
%result = vector.shape_cast %flat : vector&lt;4xf32&gt; to vector&lt;2x2xf32&gt;

2. Integrate UnrollFromElements pattern into SPIR-V pipeline

The UnrollFromElements pattern from #151175 is now integrated into spirv::unrollVectorsInFuncBodies to ensure proper handling of multi-dimensional vectors in SPIR-V lowering. Additional tests are included to verify the pattern works correctly in the SPIR-V context.

Example:

// Before
%v = vector.from_elements %e0, %e1, %e2, %e3 : vector&lt;2x2xf32&gt;

// After test-spirv-vector-unrolling
%vec_1d_0 = vector.from_elements %e0, %e1 : vector&lt;2xf32&gt;
%vec_1d_1 = vector.from_elements %e2, %e3 : vector&lt;2xf32&gt;

Full diff: https://github.com/llvm/llvm-project/pull/154664.diff

5 Files Affected:

  • (modified) mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp (+3)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp (+38-1)
  • (modified) mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir (+12)
  • (modified) mlir/test/Conversion/ConvertToSPIRV/vector.mlir (+11)
  • (modified) mlir/test/Dialect/Vector/linearize.mlir (+14)
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 49f4ce8de7c76..7815c3b377316 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -1496,6 +1496,9 @@ LogicalResult mlir::spirv::unrollVectorsInFuncBodies(Operation *op) {
     auto options = vector::UnrollVectorOptions().setNativeShapeFn(
         [](auto op) { return mlir::spirv::getNativeVectorShape(op); });
     populateVectorUnrollPatterns(patterns, options);
+    // The unroll pattern for vector.from_elements op doesn't belong to the
+    // above set.
+    vector::populateVectorFromElementsLoweringPatterns(patterns);
     if (failed(applyPatternsGreedily(op, std::move(patterns))))
       return failure();
   }
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
index 491b448e9e1e9..2cb6d47f37128 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorLinearize.cpp
@@ -762,6 +762,42 @@ struct LinearizeVectorStore final
   }
 };
 
+/// This pattern linearizes `vector.from_elements` operations by converting
+/// the result type to a 1-D vector while preserving all element values.
+/// The transformation creates a linearized `vector.from_elements` followed by
+/// a `vector.shape_cast` to restore the original multidimensional shape.
+///
+/// Example:
+///
+///     %0 = vector.from_elements %a, %b, %c, %d : vector<2x2xf32>
+///
+///   is converted to:
+///
+///     %0 = vector.from_elements %a, %b, %c, %d : vector<4xf32>
+///     %1 = vector.shape_cast %0 : vector<4xf32> to vector<2x2xf32>
+///
+struct LinearizeVectorFromElements final
+    : public OpConversionPattern<vector::FromElementsOp> {
+  using OpConversionPattern::OpConversionPattern;
+  LinearizeVectorFromElements(const TypeConverter &typeConverter,
+                              MLIRContext *context, PatternBenefit benefit = 1)
+      : OpConversionPattern(typeConverter, context, benefit) {}
+  LogicalResult
+  matchAndRewrite(vector::FromElementsOp fromElementsOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    VectorType dstTy =
+        getTypeConverter()->convertType<VectorType>(fromElementsOp.getType());
+    assert(dstTy && "vector type destination expected.");
+
+    auto elements = fromElementsOp.getElements();
+    assert(elements.size() == static_cast<size_t>(dstTy.getNumElements()) &&
+           "expected same number of elements");
+    rewriter.replaceOpWithNewOp<vector::FromElementsOp>(fromElementsOp, dstTy,
+                                                        elements);
+    return success();
+  }
+};
+
 } // namespace
 
 /// This method defines the set of operations that are linearizable, and hence
@@ -854,7 +890,8 @@ void mlir::vector::populateVectorLinearizeBasePatterns(
   patterns
       .add<LinearizeConstantLike, LinearizeVectorizable, LinearizeVectorBitCast,
            LinearizeVectorSplat, LinearizeVectorCreateMask, LinearizeVectorLoad,
-           LinearizeVectorStore>(typeConverter, patterns.getContext());
+           LinearizeVectorStore, LinearizeVectorFromElements>(
+          typeConverter, patterns.getContext());
 }
 
 void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
diff --git a/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir b/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir
index d68ba44ee8840..48c6dd5dc8c31 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/vector-unroll.mlir
@@ -100,3 +100,15 @@ func.func @transpose(%arg0 : vector<2x3xi32>) -> (vector<3x2xi32>) {
   %0 = vector.transpose %arg0, [1, 0] : vector<2x3xi32> to vector<3x2xi32>
   return %0 : vector<3x2xi32>
 }
+
+// -----
+
+// CHECK-LABEL: @from_elements_2d
+// CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32, %[[ARG3:.*]]: f32)
+func.func @from_elements_2d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> vector<2x2xf32> {
+  // CHECK: %[[VEC0:.*]] = vector.from_elements %[[ARG0]], %[[ARG1]] : vector<2xf32>
+  // CHECK: %[[VEC1:.*]] = vector.from_elements %[[ARG2]], %[[ARG3]] : vector<2xf32>
+  // CHECK: return %[[VEC0]], %[[VEC1]] : vector<2xf32>, vector<2xf32>
+  %0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x2xf32>
+  return %0 : vector<2x2xf32>
+}
diff --git a/mlir/test/Conversion/ConvertToSPIRV/vector.mlir b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
index eb9feaad15c5b..ff66fe576fc2a 100644
--- a/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
+++ b/mlir/test/Conversion/ConvertToSPIRV/vector.mlir
@@ -356,3 +356,14 @@ func.func @shape_cast_size1_vector(%arg0 : vector<f32>) -> vector<1xf32> {
   %1 = vector.shape_cast %arg0 : vector<f32> to vector<1xf32>
   return %1 : vector<1xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @from_elements
+//  CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: f32, %[[ARG2:.*]]: f32)
+//       CHECK:   %[[RETVAL:.+]] = spirv.CompositeConstruct %[[ARG0]], %[[ARG1]], %[[ARG2]] : (f32, f32, f32) -> vector<3xf32>
+//       CHECK:   spirv.ReturnValue %[[RETVAL]]
+func.func @from_elements(%arg0: f32, %arg1: f32, %arg2: f32) -> vector<3xf32> {
+  %0 = vector.from_elements %arg0, %arg1, %arg2 : vector<3xf32>
+  return %0 : vector<3xf32>
+}
diff --git a/mlir/test/Dialect/Vector/linearize.mlir b/mlir/test/Dialect/Vector/linearize.mlir
index 2e630bf93622e..5e8bfd0698b33 100644
--- a/mlir/test/Dialect/Vector/linearize.mlir
+++ b/mlir/test/Dialect/Vector/linearize.mlir
@@ -524,3 +524,17 @@ func.func @linearize_vector_store_scalable(%arg0: memref<2x8xf32>, %arg1: vector
   vector.store %arg1, %arg0[%c0, %c0] : memref<2x8xf32>, vector<1x[4]xf32>
   return
 }
+
+// -----
+
+// Test pattern LinearizeVectorFromElements.
+
+// CHECK-LABEL: test_vector_from_elements
+// CHECK-SAME: %[[ARG_0:.*]]: f32, %[[ARG_1:.*]]: f32, %[[ARG_2:.*]]: f32, %[[ARG_3:.*]]: f32
+func.func @test_vector_from_elements(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> vector<2x2xf32> {
+  // CHECK: %[[FROM_ELEMENTS:.*]] = vector.from_elements %[[ARG_0]], %[[ARG_1]], %[[ARG_2]], %[[ARG_3]] : vector<4xf32>
+  // CHECK: %[[CAST:.*]] = vector.shape_cast %[[FROM_ELEMENTS]] : vector<4xf32> to vector<2x2xf32>
+  // CHECK: return %[[CAST]] : vector<2x2xf32>
+  %1 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x2xf32>
+  return %1 : vector<2x2xf32>
+}

@yangtetris yangtetris marked this pull request as draft August 21, 2025 04:39
@yangtetris yangtetris force-pushed the mlir/vecor-from-elements-flattening branch from 3f76a9c to 9471daa Compare August 21, 2025 06:40
Copy link
Contributor

@newling newling left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. FYI we've been discussing converting these patterns to greedy patterns #146030

yangtetris and others added 2 commits August 22, 2025 15:01
Co-authored-by: James Newling <james.newling@gmail.com>
@yangtetris
Copy link
Contributor Author

LGTM. FYI we've been discussing converting these patterns to greedy patterns #146030

Thanks. Actually, I was confused when I first saw that a type converter was used there.

Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome, thanks!

@kuhar kuhar requested a review from Groverkss August 23, 2025 06:03
@yangtetris yangtetris marked this pull request as ready for review August 26, 2025 05:02
Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

@dcaballe dcaballe merged commit 5fdd3a1 into llvm:main Aug 28, 2025
11 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants