Skip to content

Commit ea6a60a

Browse files
committed
[mlir][vector] Add folder for ExtractStridedSliceOp
Add folder for the case where ExtractStridedSliceOp source comes from a chain of InsertStridedSliceOp. Also add a folder for the trivial case where the ExtractStridedSliceOp is a no-op. Differential Revision: https://reviews.llvm.org/D89850
1 parent bfb04ae commit ea6a60a

File tree

3 files changed

+165
-0
lines changed

3 files changed

+165
-0
lines changed

mlir/include/mlir/Dialect/Vector/VectorOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1016,6 +1016,7 @@ def Vector_ExtractStridedSliceOp :
10161016
void getOffsets(SmallVectorImpl<int64_t> &results);
10171017
}];
10181018
let hasCanonicalizer = 1;
1019+
let hasFolder = 1;
10191020
let assemblyFormat = "$vector attr-dict `:` type($vector) `to` type(results)";
10201021
}
10211022

mlir/lib/Dialect/Vector/VectorOps.cpp

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1629,6 +1629,81 @@ static LogicalResult verify(ExtractStridedSliceOp op) {
16291629
return success();
16301630
}
16311631

1632+
// When the source of ExtractStrided comes from a chain of InsertStrided ops try
1633+
// to use the source o the InsertStrided ops if we can detect that the extracted
1634+
// vector is a subset of one of the vector inserted.
1635+
static LogicalResult
1636+
foldExtractStridedOpFromInsertChain(ExtractStridedSliceOp op) {
1637+
// Helper to extract integer out of ArrayAttr.
1638+
auto getElement = [](ArrayAttr array, int idx) {
1639+
return array[idx].cast<IntegerAttr>().getInt();
1640+
};
1641+
ArrayAttr extractOffsets = op.offsets();
1642+
ArrayAttr extractStrides = op.strides();
1643+
ArrayAttr extractSizes = op.sizes();
1644+
auto insertOp = op.vector().getDefiningOp<InsertStridedSliceOp>();
1645+
while (insertOp) {
1646+
if (op.getVectorType().getRank() !=
1647+
insertOp.getSourceVectorType().getRank())
1648+
return failure();
1649+
ArrayAttr insertOffsets = insertOp.offsets();
1650+
ArrayAttr insertStrides = insertOp.strides();
1651+
// If the rank of extract is greater than the rank of insert, we are likely
1652+
// extracting a partial chunk of the vector inserted.
1653+
if (extractOffsets.size() > insertOffsets.size())
1654+
return failure();
1655+
bool patialoverlap = false;
1656+
bool disjoint = false;
1657+
SmallVector<int64_t, 4> offsetDiffs;
1658+
for (unsigned dim = 0, e = extractOffsets.size(); dim < e; ++dim) {
1659+
if (getElement(extractStrides, dim) != getElement(insertStrides, dim))
1660+
return failure();
1661+
int64_t start = getElement(insertOffsets, dim);
1662+
int64_t end = start + insertOp.getSourceVectorType().getDimSize(dim);
1663+
int64_t offset = getElement(extractOffsets, dim);
1664+
int64_t size = getElement(extractSizes, dim);
1665+
// Check if the start of the extract offset is in the interval inserted.
1666+
if (start <= offset && offset < end) {
1667+
// If the extract interval overlaps but is not fully included we may
1668+
// have a partial overlap that will prevent any folding.
1669+
if (offset + size > end)
1670+
patialoverlap = true;
1671+
offsetDiffs.push_back(offset - start);
1672+
continue;
1673+
}
1674+
disjoint = true;
1675+
break;
1676+
}
1677+
// The extract element chunk is a subset of the insert element.
1678+
if (!disjoint && !patialoverlap) {
1679+
op.setOperand(insertOp.source());
1680+
// OpBuilder is only used as a helper to build an I64ArrayAttr.
1681+
OpBuilder b(op.getContext());
1682+
op.setAttr(ExtractStridedSliceOp::getOffsetsAttrName(),
1683+
b.getI64ArrayAttr(offsetDiffs));
1684+
return success();
1685+
}
1686+
// If the chunk extracted is disjoint from the chunk inserted, keep looking
1687+
// in the insert chain.
1688+
if (disjoint)
1689+
insertOp = insertOp.dest().getDefiningOp<InsertStridedSliceOp>();
1690+
else {
1691+
// The extracted vector partially overlap the inserted vector, we cannot
1692+
// fold.
1693+
return failure();
1694+
}
1695+
}
1696+
return failure();
1697+
}
1698+
1699+
OpFoldResult ExtractStridedSliceOp::fold(ArrayRef<Attribute> operands) {
1700+
if (getVectorType() == getResult().getType())
1701+
return vector();
1702+
if (succeeded(foldExtractStridedOpFromInsertChain(*this)))
1703+
return getResult();
1704+
return {};
1705+
}
1706+
16321707
void ExtractStridedSliceOp::getOffsets(SmallVectorImpl<int64_t> &results) {
16331708
populateFromInt64AttrArray(offsets(), results);
16341709
}

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,95 @@ func @extract_strided_slice_of_constant_mask() -> (vector<2x1xi1>) {
9090

9191
// -----
9292

93+
// CHECK-LABEL: extract_strided_fold
94+
// CHECK-SAME: (%[[ARG:.*]]: vector<4x3xi1>)
95+
// CHECK-NEXT: return %[[ARG]] : vector<4x3xi1>
96+
func @extract_strided_fold(%arg : vector<4x3xi1>) -> (vector<4x3xi1>) {
97+
%0 = vector.extract_strided_slice %arg
98+
{offsets = [0, 0], sizes = [4, 3], strides = [1, 1]}
99+
: vector<4x3xi1> to vector<4x3xi1>
100+
return %0 : vector<4x3xi1>
101+
}
102+
103+
// -----
104+
105+
// CHECK-LABEL: extract_strided_fold_insert
106+
// CHECK-SAME: (%[[ARG:.*]]: vector<4x4xf32>
107+
// CHECK-NEXT: return %[[ARG]] : vector<4x4xf32>
108+
func @extract_strided_fold_insert(%a: vector<4x4xf32>, %b: vector<8x16xf32>)
109+
-> (vector<4x4xf32>) {
110+
%0 = vector.insert_strided_slice %a, %b {offsets = [2, 2], strides = [1, 1]}
111+
: vector<4x4xf32> into vector<8x16xf32>
112+
%1 = vector.extract_strided_slice %0
113+
{offsets = [2, 2], sizes = [4, 4], strides = [1, 1]}
114+
: vector<8x16xf32> to vector<4x4xf32>
115+
return %1 : vector<4x4xf32>
116+
}
117+
118+
// -----
119+
120+
// Case where the vector inserted is a subset of the vector extracted.
121+
// CHECK-LABEL: extract_strided_fold_insert
122+
// CHECK-SAME: (%[[ARG0:.*]]: vector<6x4xf32>
123+
// CHECK-NEXT: %[[EXT:.*]] = vector.extract_strided_slice %[[ARG0]]
124+
// CHECK-SAME: {offsets = [0, 0], sizes = [4, 4], strides = [1, 1]}
125+
// CHECK-SAME: : vector<6x4xf32> to vector<4x4xf32>
126+
// CHECK-NEXT: return %[[EXT]] : vector<4x4xf32>
127+
func @extract_strided_fold_insert(%a: vector<6x4xf32>, %b: vector<8x16xf32>)
128+
-> (vector<4x4xf32>) {
129+
%0 = vector.insert_strided_slice %a, %b {offsets = [2, 2], strides = [1, 1]}
130+
: vector<6x4xf32> into vector<8x16xf32>
131+
%1 = vector.extract_strided_slice %0
132+
{offsets = [2, 2], sizes = [4, 4], strides = [1, 1]}
133+
: vector<8x16xf32> to vector<4x4xf32>
134+
return %1 : vector<4x4xf32>
135+
}
136+
137+
// -----
138+
139+
// Negative test where the extract is not a subset of the element inserted.
140+
// CHECK-LABEL: extract_strided_fold_negative
141+
// CHECK-SAME: (%[[ARG0:.*]]: vector<4x4xf32>, %[[ARG1:.*]]: vector<8x16xf32>
142+
// CHECK: %[[INS:.*]] = vector.insert_strided_slice %[[ARG0]], %[[ARG1]]
143+
// CHECK-SAME: {offsets = [2, 2], strides = [1, 1]}
144+
// CHECK-SAME: : vector<4x4xf32> into vector<8x16xf32>
145+
// CHECK: %[[EXT:.*]] = vector.extract_strided_slice %[[INS]]
146+
// CHECK-SAME: {offsets = [2, 2], sizes = [6, 4], strides = [1, 1]}
147+
// CHECK-SAME: : vector<8x16xf32> to vector<6x4xf32>
148+
// CHECK-NEXT: return %[[EXT]] : vector<6x4xf32>
149+
func @extract_strided_fold_negative(%a: vector<4x4xf32>, %b: vector<8x16xf32>)
150+
-> (vector<6x4xf32>) {
151+
%0 = vector.insert_strided_slice %a, %b {offsets = [2, 2], strides = [1, 1]}
152+
: vector<4x4xf32> into vector<8x16xf32>
153+
%1 = vector.extract_strided_slice %0
154+
{offsets = [2, 2], sizes = [6, 4], strides = [1, 1]}
155+
: vector<8x16xf32> to vector<6x4xf32>
156+
return %1 : vector<6x4xf32>
157+
}
158+
159+
// -----
160+
161+
// Case where we need to go through 2 level of insert element.
162+
// CHECK-LABEL: extract_strided_fold_insert
163+
// CHECK-SAME: (%[[ARG0:.*]]: vector<2x4xf32>, %[[ARG1:.*]]: vector<1x4xf32>,
164+
// CHECK-NEXT: %[[EXT:.*]] = vector.extract_strided_slice %[[ARG1]]
165+
// CHECK-SAME: {offsets = [0, 1], sizes = [1, 1], strides = [1, 1]}
166+
// CHECK-SAME: : vector<1x4xf32> to vector<1x1xf32>
167+
// CHECK-NEXT: return %[[EXT]] : vector<1x1xf32>
168+
func @extract_strided_fold_insert(%a: vector<2x4xf32>, %b: vector<1x4xf32>,
169+
%c : vector<1x4xf32>) -> (vector<1x1xf32>) {
170+
%0 = vector.insert_strided_slice %b, %a {offsets = [0, 0], strides = [1, 1]}
171+
: vector<1x4xf32> into vector<2x4xf32>
172+
%1 = vector.insert_strided_slice %c, %0 {offsets = [1, 0], strides = [1, 1]}
173+
: vector<1x4xf32> into vector<2x4xf32>
174+
%2 = vector.extract_strided_slice %1
175+
{offsets = [0, 1], sizes = [1, 1], strides = [1, 1]}
176+
: vector<2x4xf32> to vector<1x1xf32>
177+
return %2 : vector<1x1xf32>
178+
}
179+
180+
// -----
181+
93182
// CHECK-LABEL: transpose_1D_identity
94183
// CHECK-SAME: ([[ARG:%.*]]: vector<4xf32>)
95184
func @transpose_1D_identity(%arg : vector<4xf32>) -> vector<4xf32> {

0 commit comments

Comments
 (0)