-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][Vector] Lower `vector.to_elements` to LLVM
#145766
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][Vector] Lower `vector.to_elements` to LLVM
#145766
Conversation
@llvm/pr-subscribers-mlir Author: Diego Caballero (dcaballe) Changes: Only elements with at least one use are lowered to `llvm.extractelement` op. Full diff: https://github.com/llvm/llvm-project/pull/145766.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
index d53d11f87efe8..f1543200fb56f 100644
--- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
+++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
@@ -1985,6 +1985,37 @@ struct VectorFromElementsLowering
}
};
+/// Conversion pattern for `vector.to_elements`: every result that has at least one use is rewritten to an `llvm.extractelement` of the adaptor's source operand at a constant index.
+struct VectorToElementsLowering
+ : public ConvertOpToLLVMPattern<vector::ToElementsOp> {
+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
+
+ LogicalResult
+ matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = toElementsOp.getLoc();
+ auto idxType = typeConverter->convertType(rewriter.getIndexType()); // Converted index type, used for the constant element indices.
+ Value source = adaptor.getSource();
+
+ SmallVector<Value> results(toElementsOp->getNumResults()); // One slot per result; slots of results without uses stay default-constructed.
+ for (auto [idx, element] : llvm::enumerate(toElementsOp.getElements())) {
+ // Create an extractelement operation only for results that are not dead, i.e. that have at least one use.
+ if (!element.use_empty()) {
+ auto constIdx = rewriter.create<LLVM::ConstantOp>(
+ loc, idxType, rewriter.getIntegerAttr(idxType, idx));
+ auto llvmType = typeConverter->convertType(element.getType());
+
+ Value result = rewriter.create<LLVM::ExtractElementOp>(
+ loc, llvmType, source, constIdx);
+ results[idx] = result;
+ }
+ }
+
+ rewriter.replaceOp(toElementsOp, results);
+ return success();
+ }
+};
+
/// Conversion pattern for vector.step.
struct VectorScalableStepOpLowering
: public ConvertOpToLLVMPattern<vector::StepOp> {
@@ -2035,7 +2066,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
MaskedReductionOpConversion, VectorInterleaveOpLowering,
VectorDeinterleaveOpLowering, VectorFromElementsLowering,
- VectorScalableStepOpLowering>(converter);
+ VectorToElementsLowering, VectorScalableStepOpLowering>(
+ converter);
}
void mlir::populateVectorToLLVMMatrixConversionPatterns(
diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
index 3df14528bac39..8f73e79d7bfc2 100644
--- a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
+++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir
@@ -1875,7 +1875,7 @@ func.func @store_0d(%memref : memref<200x100xf32>, %i : index, %j : index) {
// CHECK: %[[CAST_MEMREF:.*]] = builtin.unrealized_conversion_cast %{{.*}} : memref<200x100xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[CST:.*]] = arith.constant dense<1.100000e+01> : vector<f32>
// CHECK: %[[VAL:.*]] = builtin.unrealized_conversion_cast %[[CST]] : vector<f32> to vector<1xf32>
-// CHECK: %[[REF:.*]] = llvm.extractvalue %[[CAST_MEMREF]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
+// CHECK: %[[REF:.*]] = llvm.extractvalue %[[CAST_MEMREF]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
// CHECK: %[[C100:.*]] = llvm.mlir.constant(100 : index) : i64
// CHECK: %[[MUL:.*]] = llvm.mul %[[I]], %[[C100]] : i64
// CHECK: %[[ADD:.*]] = llvm.add %[[MUL]], %[[J]] : i64
@@ -2421,6 +2421,40 @@ func.func @from_elements_0d(%arg0: f32) -> vector<f32> {
// -----
+// CHECK-LABEL: func.func @vector_to_elements_no_dead_elements
+ // CHECK-SAME: %[[A:.*]]: vector<4xf32>)
+ // CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
+ // CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[A]][%[[C0]] : i64] : vector<4xf32>
+ // CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64
+ // CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[A]][%[[C1]] : i64] : vector<4xf32>
+ // CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : i64) : i64
+ // CHECK: %[[ELEM2:.*]] = llvm.extractelement %[[A]][%[[C2]] : i64] : vector<4xf32>
+ // CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : i64) : i64
+ // CHECK: %[[ELEM3:.*]] = llvm.extractelement %[[A]][%[[C3]] : i64] : vector<4xf32>
+ // CHECK: return %[[ELEM0]], %[[ELEM1]], %[[ELEM2]], %[[ELEM3]] : f32, f32, f32, f32
+func.func @vector_to_elements_no_dead_elements(%a: vector<4xf32>) -> (f32, f32, f32, f32) {
+ %0:4 = vector.to_elements %a : vector<4xf32>
+ return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
+}
+
+// -----
+
+// CHECK-LABEL: func.func @vector_to_elements_dead_elements
+ // CHECK-SAME: %[[A:.*]]: vector<4xf32>)
+ // CHECK-NOT: llvm.mlir.constant(0 : i64) : i64
+ // CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64
+ // CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[A]][%[[C1]] : i64] : vector<4xf32>
+ // CHECK-NOT: llvm.mlir.constant(2 : i64) : i64
+ // CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : i64) : i64
+ // CHECK: %[[ELEM3:.*]] = llvm.extractelement %[[A]][%[[C3]] : i64] : vector<4xf32>
+ // CHECK: return %[[ELEM1]], %[[ELEM3]] : f32, f32
+func.func @vector_to_elements_dead_elements(%a: vector<4xf32>) -> (f32, f32) {
+ %0:4 = vector.to_elements %a : vector<4xf32>
+ return %0#1, %0#3 : f32, f32
+}
+
+// -----
+
//===----------------------------------------------------------------------===//
// vector.step
//===----------------------------------------------------------------------===//
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LG % minor nits, thanks!
// CHECK-LABEL: func.func @vector_to_elements_no_dead_elements | ||
// CHECK-SAME: %[[A:.*]]: vector<4xf32>) | ||
// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64 | ||
// CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[A]][%[[C0]] : i64] : vector<4xf32> | ||
// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64 | ||
// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[A]][%[[C1]] : i64] : vector<4xf32> | ||
// CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : i64) : i64 | ||
// CHECK: %[[ELEM2:.*]] = llvm.extractelement %[[A]][%[[C2]] : i64] : vector<4xf32> | ||
// CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : i64) : i64 | ||
// CHECK: %[[ELEM3:.*]] = llvm.extractelement %[[A]][%[[C3]] : i64] : vector<4xf32> | ||
// CHECK: return %[[ELEM0]], %[[ELEM1]], %[[ELEM2]], %[[ELEM3]] : f32, f32, f32, f32 | ||
func.func @vector_to_elements_no_dead_elements(%a: vector<4xf32>) -> (f32, f32, f32, f32) { | ||
%0:4 = vector.to_elements %a : vector<4xf32> | ||
return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32 | ||
} | ||
|
||
// ----- | ||
|
||
// CHECK-LABEL: func.func @vector_to_elements_dead_elements | ||
// CHECK-SAME: %[[A:.*]]: vector<4xf32>) | ||
// CHECK-NOT: llvm.mlir.constant(0 : i64) : i64 | ||
// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64 | ||
// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[A]][%[[C1]] : i64] : vector<4xf32> | ||
// CHECK-NOT: llvm.mlir.constant(2 : i64) : i64 | ||
// CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : i64) : i64 | ||
// CHECK: %[[ELEM3:.*]] = llvm.extractelement %[[A]][%[[C3]] : i64] : vector<4xf32> | ||
// CHECK: return %[[ELEM1]], %[[ELEM3]] : f32, f32 | ||
func.func @vector_to_elements_dead_elements(%a: vector<4xf32>) -> (f32, f32) { | ||
%0:4 = vector.to_elements %a : vector<4xf32> | ||
return %0#1, %0#3 : f32, f32 | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nit] Add a block comment and remove `vector` from the function names (for consistency).
// CHECK-LABEL: func.func @vector_to_elements_no_dead_elements | |
// CHECK-SAME: %[[A:.*]]: vector<4xf32>) | |
// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64 | |
// CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[A]][%[[C0]] : i64] : vector<4xf32> | |
// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64 | |
// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[A]][%[[C1]] : i64] : vector<4xf32> | |
// CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : i64) : i64 | |
// CHECK: %[[ELEM2:.*]] = llvm.extractelement %[[A]][%[[C2]] : i64] : vector<4xf32> | |
// CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : i64) : i64 | |
// CHECK: %[[ELEM3:.*]] = llvm.extractelement %[[A]][%[[C3]] : i64] : vector<4xf32> | |
// CHECK: return %[[ELEM0]], %[[ELEM1]], %[[ELEM2]], %[[ELEM3]] : f32, f32, f32, f32 | |
func.func @vector_to_elements_no_dead_elements(%a: vector<4xf32>) -> (f32, f32, f32, f32) { | |
%0:4 = vector.to_elements %a : vector<4xf32> | |
return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32 | |
} | |
// ----- | |
// CHECK-LABEL: func.func @vector_to_elements_dead_elements | |
// CHECK-SAME: %[[A:.*]]: vector<4xf32>) | |
// CHECK-NOT: llvm.mlir.constant(0 : i64) : i64 | |
// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64 | |
// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[A]][%[[C1]] : i64] : vector<4xf32> | |
// CHECK-NOT: llvm.mlir.constant(2 : i64) : i64 | |
// CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : i64) : i64 | |
// CHECK: %[[ELEM3:.*]] = llvm.extractelement %[[A]][%[[C3]] : i64] : vector<4xf32> | |
// CHECK: return %[[ELEM1]], %[[ELEM3]] : f32, f32 | |
func.func @vector_to_elements_dead_elements(%a: vector<4xf32>) -> (f32, f32) { | |
%0:4 = vector.to_elements %a : vector<4xf32> | |
return %0#1, %0#3 : f32, f32 | |
} | |
//===----------------------------------------------------------------------===// | |
// vector.to_elements | |
//===----------------------------------------------------------------------===// | |
// CHECK-LABEL: func.func @to_elements_no_dead_elements | |
// CHECK-SAME: %[[A:.*]]: vector<4xf32>) | |
// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64 | |
// CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[A]][%[[C0]] : i64] : vector<4xf32> | |
// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64 | |
// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[A]][%[[C1]] : i64] : vector<4xf32> | |
// CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : i64) : i64 | |
// CHECK: %[[ELEM2:.*]] = llvm.extractelement %[[A]][%[[C2]] : i64] : vector<4xf32> | |
// CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : i64) : i64 | |
// CHECK: %[[ELEM3:.*]] = llvm.extractelement %[[A]][%[[C3]] : i64] : vector<4xf32> | |
// CHECK: return %[[ELEM0]], %[[ELEM1]], %[[ELEM2]], %[[ELEM3]] : f32, f32, f32, f32 | |
func.func @to_elements_no_dead_elements(%a: vector<4xf32>) -> (f32, f32, f32, f32) { | |
%0:4 = vector.to_elements %a : vector<4xf32> | |
return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32 | |
} | |
// ----- | |
// CHECK-LABEL: func.func @to_elements_dead_elements | |
// CHECK-SAME: %[[A:.*]]: vector<4xf32>) | |
// CHECK-NOT: llvm.mlir.constant(0 : i64) : i64 | |
// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64 | |
// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[A]][%[[C1]] : i64] : vector<4xf32> | |
// CHECK-NOT: llvm.mlir.constant(2 : i64) : i64 | |
// CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : i64) : i64 | |
// CHECK: %[[ELEM3:.*]] = llvm.extractelement %[[A]][%[[C3]] : i64] : vector<4xf32> | |
// CHECK: return %[[ELEM1]], %[[ELEM3]] : f32, f32 | |
func.func @to_elements_dead_elements(%a: vector<4xf32>) -> (f32, f32) { | |
%0:4 = vector.to_elements %a : vector<4xf32> | |
return %0#1, %0#3 : f32, f32 | |
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@dcaballe do you plan to add SPIR-V lowering on your own? If not, could you open an issue and tag me so that I can assign it to someone?
SmallVector<Value> results(toElementsOp->getNumResults()); | ||
for (auto [idx, element] : llvm::enumerate(toElementsOp.getElements())) { | ||
// Create an extractelement operation only for results that are not dead. | ||
if (!element.use_empty()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: you can invert this check and `continue` instead to reduce nesting.
Only elements with at least one use are lowered to `llvm.extractelement` op.
2c40823
to
cc50307
Compare
Only elements with at least one use are lowered to `llvm.extractelement` op.
Only elements with at least one use are lowered to `llvm.extractelement` op.