diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp index b3ef23085c186..068bd4b216bce 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRVPass.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h" #include "mlir/Dialect/UB/IR/UBOps.h" +#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h" #include "mlir/Transforms/DialectConversion.h" namespace mlir { @@ -37,6 +38,13 @@ void ConvertVectorToSPIRVPass::runOnOperation() { MLIRContext *context = &getContext(); Operation *op = getOperation(); + { + RewritePatternSet patterns(context); + vector::populateVectorFromElementsLoweringPatterns(patterns); + if (failed(applyPatternsGreedily(op, std::move(patterns)))) + return signalPassFailure(); + } + auto targetAttr = spirv::lookupTargetEnvOrDefault(op); std::unique_ptr target = SPIRVConversionTarget::get(targetAttr); diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir index 8918f91ef9145..a39d2509de363 100644 --- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir +++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir @@ -308,6 +308,22 @@ func.func @from_elements_3x(%arg0 : f32, %arg1 : f32, %arg2 : f32) -> vector<3xf return %0: vector<3xf32> } +// CHECK-LABEL: @from_elements_3d +// CHECK-SAME: %[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32, %[[ARG3:.+]]: f32 +func.func @from_elements_3d(%arg0: f32, %arg1: f32, %arg2: f32, %arg3: f32) -> vector<2x1x2xf32> { + // CHECK-DAG: %[[VEC3D:.+]] = ub.poison : vector<2x1x2xf32> + // CHECK-DAG: %[[VEC2D:.+]] = ub.poison : vector<1x2xf32> + // CHECK: %[[VEC1_0:.+]] = spirv.CompositeConstruct %[[ARG0]], %[[ARG1]] : (f32, f32) -> vector<2xf32> + // CHECK: %[[VEC1_1:.+]] = vector.insert %[[VEC1_0]], %[[VEC2D]] [0] + // CHECK: %[[VEC1_2:.+]] = vector.insert %[[VEC1_1]], %[[VEC3D]] [0] + // CHECK: %[[VEC2_0:.+]] = spirv.CompositeConstruct %[[ARG2]], %[[ARG3]] : (f32, f32) -> vector<2xf32> + // CHECK: %[[VEC2_1:.+]] = vector.insert %[[VEC2_0]], %[[VEC2D]] [0] + // CHECK: %[[VEC2_2:.+]] = vector.insert %[[VEC2_1]], %[[VEC1_2]] [1] + // CHECK: return %[[VEC2_2]] + %0 = vector.from_elements %arg0, %arg1, %arg2, %arg3 : vector<2x1x2xf32> + return %0 : vector<2x1x2xf32> +} + // ----- // CHECK-LABEL: @insert