Skip to content

Commit 8ee9261

Browse files
committed
[mlir][Vector] Add vector.to_elements lowering to LLVM
Only elements with at least one use are lowered to `llvm.extractelement` op.
1 parent 8a65196 commit 8ee9261

File tree

2 files changed

+68
-2
lines changed

2 files changed

+68
-2
lines changed

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1985,6 +1985,37 @@ struct VectorFromElementsLowering
19851985
}
19861986
};
19871987

1988+
/// Conversion pattern for a `vector.to_elements`.
1989+
struct VectorToElementsLowering
1990+
: public ConvertOpToLLVMPattern<vector::ToElementsOp> {
1991+
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
1992+
1993+
LogicalResult
1994+
matchAndRewrite(vector::ToElementsOp toElementsOp, OpAdaptor adaptor,
1995+
ConversionPatternRewriter &rewriter) const override {
1996+
Location loc = toElementsOp.getLoc();
1997+
auto idxType = typeConverter->convertType(rewriter.getIndexType());
1998+
Value source = adaptor.getSource();
1999+
2000+
SmallVector<Value> results(toElementsOp->getNumResults());
2001+
for (auto [idx, element] : llvm::enumerate(toElementsOp.getElements())) {
2002+
// Create an extractelement operation only for results that are not dead.
2003+
if (!element.use_empty()) {
2004+
auto constIdx = rewriter.create<LLVM::ConstantOp>(
2005+
loc, idxType, rewriter.getIntegerAttr(idxType, idx));
2006+
auto llvmType = typeConverter->convertType(element.getType());
2007+
2008+
Value result = rewriter.create<LLVM::ExtractElementOp>(
2009+
loc, llvmType, source, constIdx);
2010+
results[idx] = result;
2011+
}
2012+
}
2013+
2014+
rewriter.replaceOp(toElementsOp, results);
2015+
return success();
2016+
}
2017+
};
2018+
19882019
/// Conversion pattern for vector.step.
19892020
struct VectorScalableStepOpLowering
19902021
: public ConvertOpToLLVMPattern<vector::StepOp> {
@@ -2035,7 +2066,8 @@ void mlir::populateVectorToLLVMConversionPatterns(
20352066
VectorScalableInsertOpLowering, VectorScalableExtractOpLowering,
20362067
MaskedReductionOpConversion, VectorInterleaveOpLowering,
20372068
VectorDeinterleaveOpLowering, VectorFromElementsLowering,
2038-
VectorScalableStepOpLowering>(converter);
2069+
VectorToElementsLowering, VectorScalableStepOpLowering>(
2070+
converter);
20392071
}
20402072

20412073
void mlir::populateVectorToLLVMMatrixConversionPatterns(

mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1875,7 +1875,7 @@ func.func @store_0d(%memref : memref<200x100xf32>, %i : index, %j : index) {
18751875
// CHECK: %[[CAST_MEMREF:.*]] = builtin.unrealized_conversion_cast %{{.*}} : memref<200x100xf32> to !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
18761876
// CHECK: %[[CST:.*]] = arith.constant dense<1.100000e+01> : vector<f32>
18771877
// CHECK: %[[VAL:.*]] = builtin.unrealized_conversion_cast %[[CST]] : vector<f32> to vector<1xf32>
1878-
// CHECK: %[[REF:.*]] = llvm.extractvalue %[[CAST_MEMREF]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
1878+
// CHECK: %[[REF:.*]] = llvm.extractvalue %[[CAST_MEMREF]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
18791879
// CHECK: %[[C100:.*]] = llvm.mlir.constant(100 : index) : i64
18801880
// CHECK: %[[MUL:.*]] = llvm.mul %[[I]], %[[C100]] : i64
18811881
// CHECK: %[[ADD:.*]] = llvm.add %[[MUL]], %[[J]] : i64
@@ -2421,6 +2421,40 @@ func.func @from_elements_0d(%arg0: f32) -> vector<f32> {
24212421

24222422
// -----
24232423

2424+
// CHECK-LABEL: func.func @vector_to_elements_no_dead_elements
2425+
// CHECK-SAME: %[[A:.*]]: vector<4xf32>)
2426+
// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
2427+
// CHECK: %[[ELEM0:.*]] = llvm.extractelement %[[A]][%[[C0]] : i64] : vector<4xf32>
2428+
// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64
2429+
// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[A]][%[[C1]] : i64] : vector<4xf32>
2430+
// CHECK: %[[C2:.*]] = llvm.mlir.constant(2 : i64) : i64
2431+
// CHECK: %[[ELEM2:.*]] = llvm.extractelement %[[A]][%[[C2]] : i64] : vector<4xf32>
2432+
// CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : i64) : i64
2433+
// CHECK: %[[ELEM3:.*]] = llvm.extractelement %[[A]][%[[C3]] : i64] : vector<4xf32>
2434+
// CHECK: return %[[ELEM0]], %[[ELEM1]], %[[ELEM2]], %[[ELEM3]] : f32, f32, f32, f32
2435+
func.func @vector_to_elements_no_dead_elements(%a: vector<4xf32>) -> (f32, f32, f32, f32) {
2436+
%0:4 = vector.to_elements %a : vector<4xf32>
2437+
return %0#0, %0#1, %0#2, %0#3 : f32, f32, f32, f32
2438+
}
2439+
2440+
// -----
2441+
2442+
// CHECK-LABEL: func.func @vector_to_elements_dead_elements
2443+
// CHECK-SAME: %[[A:.*]]: vector<4xf32>)
2444+
// CHECK-NOT: llvm.mlir.constant(0 : i64) : i64
2445+
// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i64) : i64
2446+
// CHECK: %[[ELEM1:.*]] = llvm.extractelement %[[A]][%[[C1]] : i64] : vector<4xf32>
2447+
// CHECK-NOT: llvm.mlir.constant(2 : i64) : i64
2448+
// CHECK: %[[C3:.*]] = llvm.mlir.constant(3 : i64) : i64
2449+
// CHECK: %[[ELEM3:.*]] = llvm.extractelement %[[A]][%[[C3]] : i64] : vector<4xf32>
2450+
// CHECK: return %[[ELEM1]], %[[ELEM3]] : f32, f32
2451+
func.func @vector_to_elements_dead_elements(%a: vector<4xf32>) -> (f32, f32) {
2452+
%0:4 = vector.to_elements %a : vector<4xf32>
2453+
return %0#1, %0#3 : f32, f32
2454+
}
2455+
2456+
// -----
2457+
24242458
//===----------------------------------------------------------------------===//
24252459
// vector.step
24262460
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)