Skip to content

Commit 7842e9e

Browse files
authored
[mlir][Vector] Lower vector.to_elements to LLVM (#145766)
Only elements with at least one use are lowered to `llvm.extractelement` op.
1 parent abf8e25 commit 7842e9e

File tree

2 files changed

+73
-2
lines changed

2 files changed

+73
-2
lines changed

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1985,6 +1985,38 @@ struct VectorFromElementsLowering
19851985
}
19861986
};
19871987

1988+
/// Conversion pattern for a `vector.to_elements`.
1989+
struct VectorToElementsLowering
1990+
: public ConvertOpToLLVMPattern<vector::ToElementsOp> {
1991+
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
1992+
1993+
LogicalResult
1994+
matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
1995+
ConversionPatternRewriter &rewriter) const override {
1996+
Location loc = toElementsOp.getLoc();
1997+
auto idxType = typeConverter->convertType(rewriter.getIndexType());
1998+
Value source = adaptor.getSource();
1999+
2000+
SmallVector<Value> results(toElementsOp->getNumResults());
2001+
for (auto [idx, element] : llvm::enumerate(toElementsOp.getElements())) {
2002+
// Create an extractelement operation only for results that are not dead.
2003+
if (element.use_empty())
2004+
continue;
2005+
2006+
auto constIdx = rewriter.create<LLVM::ConstantOp>(
2007+
loc, idxType, rewriter.getIntegerAttr(idxType, idx));
2008+
auto llvmType = typeConverter->convertType(element.getType());
2009+
2010+
Value result = rewriter.create<LLVM::ExtractElementOp>(loc, llvmType,
2011+
source, constIdx);
2012+
results[idx] = result;
2013+
}
2014+
2015+
rewriter.replaceOp(toElementsOp, results);
2016+
return success();
2017+
}
2018+
};
2019+
19882020
/// Conversion pattern for vector.step.
19892021
struct VectorScalableStepOpLowering
19902022
: public ConvertOpToLLVMPattern<vector::StepOp> {
@@ -2035,7 +2067,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
20352067
VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
20362068
MaskedReductionOpConversion, VectorInterleaveOpLowering,
20372069
VectorDeinterleaveOpLowering, VectorFromElementsLowering,
2038-
VectorScalableStepOpLowering>(converter);
2070+
VectorToElementsLowering, VectorScalableStepOpLowering>(
2071+
converter);
20392072
}
20402073

20412074
void mlir::populateVectorToLLVMMatrixConversionPatterns(

mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1875,7 +1875,7 @@ func.func @store_0d(%memref : memref<200x100xf32>, %i : index, %j : index) {
18751875
// CHECK: %[[CAST_MEMREF:.*]] = builtin.unrealized_conversion_cast %{{.*}} : memref<200x100xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
18761876
// CHECK: %[[CST:.*]] = arith.constant dense<1.100000e+01> : vector<f32>
18771877
// CHECK: %[[VAL:.*]] = builtin.unrealized_conversion_cast %[[CST]] : vector<f32> to vector<1xf32>
1878-
// CHECK: %[[REF:.*]] = llvm.extractvalue %[[CAST_MEMREF]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
1878+
// CHECK: %[[REF:.*]] = llvm.extractvalue %[[CAST_MEMREF]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
18791879
// CHECK: %[[C100:.*]] = llvm.mlir.constant(100 : index) : i64
18801880
// CHECK: %[[MUL:.*]] = llvm.mul %[[I]], %[[C100]] : i64
18811881
// CHECK: %[[ADD:.*]] = llvm.add %[[MUL]], %[[J]] : i64
@@ -2421,6 +2421,44 @@ func.func @from_elements_0d(%arg0: f32) -> vector<f32> {
24212421

24222422
// -----
24232423

2424+
//===----------------------------------------------------------------------===//
2425+
// vector.to_elements
2426+
//===----------------------------------------------------------------------===//
2427+
2428+
// CHECK-LABEL: func.func @to_elements_no_dead_elements
2429+
// CHECK-SAME: %[[A:.*]]: vector<4xf32>)
2430+
// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
2431+
// CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[A]][%[[C0]] : i64] : vector<4xf32>
2432+
// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64
2433+
// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[A]][%[[C1]] : i64] : vector<4xf32>
2434+
// CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : i64) : i64
2435+
// CHECK: %[[ELEM2:.*]] = llvm.extractelement %[[A]][%[[C2]] : i64] : vector<4xf32>
2436+
// CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : i64) : i64
2437+
// CHECK: %[[ELEM3:.*]] = llvm.extractelement %[[A]][%[[C3]] : i64] : vector<4xf32>
2438+
// CHECK: return %[[ELEM0]], %[[ELEM1]], %[[ELEM2]], %[[ELEM3]] : f32, f32, f32, f32
2439+
func.func @to_elements_no_dead_elements(%a: vector<4xf32>) -> (f32, f32, f32, f32) {
2440+
%0:4 = vector.to_elements %a : vector<4xf32>
2441+
return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
2442+
}
2443+
2444+
// -----
2445+
2446+
// CHECK-LABEL: func.func @to_elements_dead_elements
2447+
// CHECK-SAME: %[[A:.*]]: vector<4xf32>)
2448+
// CHECK-NOT: llvm.mlir.constant(0 : i64) : i64
2449+
// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64
2450+
// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[A]][%[[C1]] : i64] : vector<4xf32>
2451+
// CHECK-NOT: llvm.mlir.constant(2 : i64) : i64
2452+
// CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : i64) : i64
2453+
// CHECK: %[[ELEM3:.*]] = llvm.extractelement %[[A]][%[[C3]] : i64] : vector<4xf32>
2454+
// CHECK: return %[[ELEM1]], %[[ELEM3]] : f32, f32
2455+
func.func @to_elements_dead_elements(%a: vector<4xf32>) -> (f32, f32) {
2456+
%0:4 = vector.to_elements %a : vector<4xf32>
2457+
return %0#1, %0#3 : f32, f32
2458+
}
2459+
2460+
// -----
2461+
24242462
//===----------------------------------------------------------------------===//
24252463
// vector.step
24262464
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)