Skip to content

Commit b6b4362

Browse files
add foldConstantOp fold function and apply it to extractOp and insertOp.
1 parent 1b1270f commit b6b4362

File tree

3 files changed

+81
-2
lines changed

3 files changed

+81
-2
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1977,6 +1977,46 @@ static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
19771977
return fromElementsOp.getElements()[flatIndex];
19781978
}
19791979

1980+
// If the dynamic operands of `extractOp` or `insertOp` is result of
1981+
// `constantOp`, then fold it.
1982+
template <typename T>
1983+
static void foldConstantOp(T op, SmallVectorImpl<Value> &operands) {
1984+
auto staticPosition = op.getStaticPosition().vec();
1985+
OperandRange dynamicPosition = op.getDynamicPosition();
1986+
1987+
// If the dynamic operands is empty, it is returned directly.
1988+
if (!dynamicPosition.size())
1989+
return;
1990+
unsigned index = 0;
1991+
1992+
// `opChange` is a flog. If it is true, it means to update `op` in place.
1993+
bool opChange = false;
1994+
for (unsigned i = 0, e = staticPosition.size(); i < e; ++i) {
1995+
if (!ShapedType::isDynamic(staticPosition[i]))
1996+
continue;
1997+
Value position = dynamicPosition[index++];
1998+
1999+
// If it is a block parameter, proceed to the next iteration.
2000+
if (!position.getDefiningOp()) {
2001+
operands.push_back(position);
2002+
continue;
2003+
}
2004+
2005+
if (auto constantOp =
2006+
mlir::dyn_cast<arith::ConstantIndexOp>(position.getDefiningOp())) {
2007+
opChange = true;
2008+
staticPosition[i] = constantOp.value();
2009+
continue;
2010+
}
2011+
operands.push_back(position);
2012+
}
2013+
2014+
if (opChange) {
2015+
op.setStaticPosition(staticPosition);
2016+
op.getOperation()->setOperands(operands);
2017+
}
2018+
}
2019+
19802020
OpFoldResult ExtractOp::fold(FoldAdaptor) {
19812021
// Fold "vector.extract %v[] : vector<2x2xf32> from vector<2x2xf32>" to %v.
19822022
// Note: Do not fold "vector.extract %v[] : f32 from vector<f32>" (type
@@ -1999,6 +2039,8 @@ OpFoldResult ExtractOp::fold(FoldAdaptor) {
19992039
return val;
20002040
if (auto val = foldScalarExtractFromFromElements(*this))
20012041
return val;
2042+
SmallVector<Value> operands = {getVector()};
2043+
foldConstantOp(*this, operands);
20022044
return OpFoldResult();
20032045
}
20042046

@@ -3028,6 +3070,8 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
30283070
// (type mismatch).
30293071
if (getNumIndices() == 0 && getSourceType() == getType())
30303072
return getSource();
3073+
SmallVector<Value> operands = {getSource(), getDest()};
3074+
foldConstantOp(*this, operands);
30313075
return {};
30323076
}
30333077

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4115,3 +4115,39 @@ func.func @step_scalable() -> vector<[4]xindex> {
41154115
%0 = vector.step : vector<[4]xindex>
41164116
return %0 : vector<[4]xindex>
41174117
}
4118+
4119+
// -----
4120+
4121+
// CHECK-LABEL: @extract_arith_constnt
4122+
// Extracts one i32 element from a constant vector<32x1xi32>, using the same
// `arith.constant 0 : index` value for both positions of `vector.extract`.
// The FileCheck directives after the function expect the lowering to use a
// static index 0 in `llvm.extractvalue` and an `llvm.mlir.constant(0 : i64)`
// index in `llvm.extractelement`.
func.func @extract_arith_constnt() -> i32 {
4123+
%v = arith.constant dense<0> : vector<32x1xi32>
4124+
%c_0 = arith.constant 0 : index
4125+
%elem = vector.extract %v[%c_0, %c_0] : i32 from vector<32x1xi32>
4126+
return %elem : i32
4127+
}
4128+
4129+
// CHECK: %[[VAL_0:.*]] = arith.constant dense<0> : vector<32x1xi32>
4130+
// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : vector<32x1xi32> to !llvm.array<32 x vector<1xi32>>
4131+
// CHECK: %[[VAL_2:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.array<32 x vector<1xi32>>
4132+
// CHECK: %[[VAL_3:.*]] = llvm.mlir.constant(0 : i64) : i64
4133+
// CHECK: %{{.*}} = llvm.extractelement %[[VAL_2]]{{\[}}%[[VAL_3]] : i64] : vector<1xi32>
4134+
4135+
// -----
4136+
4137+
// CHECK-LABEL: @insert_arith_constnt()
4138+
4139+
// Inserts the i32 constant 1 into a constant vector<32x1xi32>, using the same
// `arith.constant 0 : index` value for both positions of `vector.insert`.
// The FileCheck directives after the function expect the lowering to use a
// static index 0 in `llvm.extractvalue` / `llvm.insertvalue` and an
// `llvm.mlir.constant(0 : i64)` index in `llvm.insertelement`.
func.func @insert_arith_constnt() -> vector<32x1xi32> {
4140+
%v = arith.constant dense<0> : vector<32x1xi32>
4141+
%c_0 = arith.constant 0 : index
4142+
%c_1 = arith.constant 1 : i32
4143+
%v_1 = vector.insert %c_1, %v[%c_0, %c_0] : i32 into vector<32x1xi32>
4144+
return %v_1 : vector<32x1xi32>
4145+
}
4146+
4147+
// CHECK: %[[VAL_0:.*]] = arith.constant dense<0> : vector<32x1xi32>
4148+
// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : vector<32x1xi32> to !llvm.array<32 x vector<1xi32>>
4149+
// CHECK: %[[VAL_2:.*]] = arith.constant 1 : i32
4150+
// CHECK: %[[VAL_3:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.array<32 x vector<1xi32>>
4151+
// CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(0 : i64) : i64
4152+
// CHECK: %[[VAL_5:.*]] = llvm.insertelement %[[VAL_2]], %[[VAL_3]]{{\[}}%[[VAL_4]] : i64] : vector<1xi32>
4153+
// CHECK: %{{.*}} = llvm.insertvalue %[[VAL_5]], %[[VAL_1]][0] : !llvm.array<32 x vector<1xi32>>

mlir/test/Dialect/Vector/vector-warp-distribute.mlir

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -778,12 +778,11 @@ func.func @warp_constant(%laneid: index) -> (vector<1xf32>) {
778778

779779
// CHECK-PROP-LABEL: func.func @vector_extract_1d(
780780
// CHECK-PROP-DAG: %[[C5_I32:.*]] = arith.constant 5 : i32
781-
// CHECK-PROP-DAG: %[[C1:.*]] = arith.constant 1 : index
782781
// CHECK-PROP: %[[R:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<2xf32>) {
783782
// CHECK-PROP: %[[V:.*]] = "some_def"() : () -> vector<64xf32>
784783
// CHECK-PROP: gpu.yield %[[V]] : vector<64xf32>
785784
// CHECK-PROP: }
786-
// CHECK-PROP: %[[E:.*]] = vector.extract %[[R]][%[[C1]]] : f32 from vector<2xf32>
785+
// CHECK-PROP: %[[E:.*]] = vector.extract %[[R]][1] : f32 from vector<2xf32>
787786
// CHECK-PROP: %[[SHUFFLED:.*]], %{{.*}} = gpu.shuffle idx %[[E]], %[[C5_I32]]
788787
// CHECK-PROP: return %[[SHUFFLED]] : f32
789788
func.func @vector_extract_1d(%laneid: index) -> (f32) {

0 commit comments

Comments
 (0)