-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[mlir][spirv][vector] Use adaptor.getElements() in FromElements lowering. #156972
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…ing. Signed-off-by: hanhanW <hanhan0912@gmail.com>
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-spirv Author: Han-Chung Wang (hanhanW) ChangesFull diff: https://github.com/llvm/llvm-project/pull/156972.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 036cbad0bcfe8..c861935b4bc18 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -278,7 +278,7 @@ struct VectorFromElementsOpConvert final
Type resultType = getTypeConverter()->convertType(op.getType());
if (!resultType)
return failure();
- OperandRange elements = op.getElements();
+ ValueRange elements = adaptor.getElements();
if (isa<spirv::ScalarType>(resultType)) {
// In the case with a single scalar operand / single-element result,
// pass through the scalar.
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 4b56897821dbb..c3688e0657d4b 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -281,33 +281,46 @@ func.func @to_elements_dead_elements(%a: vector<4xf32>) -> (f32, f32) {
// -----
-// CHECK-LABEL: @from_elements_0d
+// CHECK-LABEL: @from_elements_0d_f32
// CHECK-SAME: %[[ARG0:.+]]: f32
// CHECK: %[[RETVAL:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
// CHECK: return %[[RETVAL]]
-func.func @from_elements_0d(%arg0 : f32) -> vector<f32> {
+func.func @from_elements_0d_f32(%arg0 : f32) -> vector<f32> {
%0 = vector.from_elements %arg0 : vector<f32>
return %0: vector<f32>
}
-// CHECK-LABEL: @from_elements_1x
+// CHECK-LABEL: @from_elements_1xf32
// CHECK-SAME: %[[ARG0:.+]]: f32
// CHECK: %[[RETVAL:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
// CHECK: return %[[RETVAL]]
-func.func @from_elements_1x(%arg0 : f32) -> vector<1xf32> {
+func.func @from_elements_1xf32(%arg0 : f32) -> vector<1xf32> {
%0 = vector.from_elements %arg0 : vector<1xf32>
return %0: vector<1xf32>
}
-// CHECK-LABEL: @from_elements_3x
+// CHECK-LABEL: @from_elements_3xf32
// CHECK-SAME: %[[ARG0:.+]]: f32, %[[ARG1:.+]]: f32, %[[ARG2:.+]]: f32
// CHECK: %[[RETVAL:.+]] = spirv.CompositeConstruct %[[ARG0]], %[[ARG1]], %[[ARG2]] : (f32, f32, f32) -> vector<3xf32>
// CHECK: return %[[RETVAL]]
-func.func @from_elements_3x(%arg0 : f32, %arg1 : f32, %arg2 : f32) -> vector<3xf32> {
+func.func @from_elements_3xf32(%arg0 : f32, %arg1 : f32, %arg2 : f32) -> vector<3xf32> {
%0 = vector.from_elements %arg0, %arg1, %arg2 : vector<3xf32>
return %0: vector<3xf32>
}
+func.func @from_elements_3xi8(%arg0 : i8, %arg1 : i8, %arg2 : i8) -> vector<3xi8> {
+ %0 = vector.from_elements %arg0, %arg1, %arg2 : vector<3xi8>
+ return %0: vector<3xi8>
+}
+// CHECK-LABEL: @from_elements_3xi8
+// CHECK-SAME: %[[ARG0:.+]]: i8, %[[ARG1:.+]]: i8, %[[ARG2:.+]]: i8
+// CHECK-DAG: %[[CAST0:.*]] = builtin.unrealized_conversion_cast %[[ARG0]] : i8 to i32
+// CHECK-DAG: %[[CAST1:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : i8 to i32
+// CHECK-DAG: %[[CAST2:.*]] = builtin.unrealized_conversion_cast %[[ARG2]] : i8 to i32
+// CHECK: %[[VAL:.+]] = spirv.CompositeConstruct %[[CAST0]], %[[CAST1]], %[[CAST2]] : (i32, i32, i32) -> vector<3xi32>
+// CHECK: %[[RETVAL:.*]] = builtin.unrealized_conversion_cast %[[VAL]] : vector<3xi32> to vector<3xi8>
+// CHECK: return %[[RETVAL]]
+
// -----
// CHECK-LABEL: @insert
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey @hanhanW , this looks like a better fix than what I was looking at.
I saw on LowerGPUOpsToNVVMOps
and GPUToLLVMConversion
that these passes will first run the vector::populateVectorFromElementsLoweringPatterns(patterns);
patterns which will also flatten N-dimensional vector.from_elements
operations.
I agree that using the adaptor
makes sense here, but I wonder if one should also apply these patterns before lowering to SPIR-V as n-dimensional vectors are illegal in SPIR-V. See for example: https://github.com/llvm/llvm-project/pull/155499/files
EDIT: Probably makes sense to have both. This change makes sure that the types are valid in SPIR-V while the draft PR makes sure that the shapes are valid.
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/116/builds/17908 Here is the relevant piece of the build log for the reference
|
seems unrelated |
Yes, I think they are different issues. The PR fixes a bug in dialect conversion. Whether running the unrolling patterns is a different problem. The PR "fixes" the IREE SPIR-V lit tests just because there are no n-D vectors in the tests. The real fix would be adding the unrolling patterns and lowering them to corresponding SPIR-V ops. We'll need both, IMO. Whether running the pattern in SPIR-V conversion is unknown to me, because people tend to decouple such patterns from dialect conversion, see #151175 (comment). They are not added to LLVM conversion, and I think it is better to follow it for consistency. On IREE side, we've been running a lot of patterns in the final conversion, and we can add it to ConvertToSPIRVPass for now. |
We can add it to the --test-convert-to-spirv pass in mlir |
No description provided.