Skip to content

Commit b27c49d

Browse files
committed
fixup! [mlir][Vector] Update patterns for flattening vector.xfer Ops (2/N)
Refactor to use makeComposedFoldedAffineApply
1 parent ee5e355 commit b27c49d

File tree

3 files changed

+28
-25
lines changed

3 files changed

+28
-25
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -544,7 +544,7 @@ class FlattenContiguousRowMajorTransferReadPattern
544544
auto loc = transferReadOp.getLoc();
545545
Value vector = transferReadOp.getVector();
546546
VectorType vectorType = cast<VectorType>(vector.getType());
547-
Value source = transferReadOp.getSource();
547+
auto source = transferReadOp.getSource();
548548
MemRefType sourceType = dyn_cast<MemRefType>(source.getType());
549549

550550
// 0. Check pre-conditions
@@ -602,26 +602,30 @@ class FlattenContiguousRowMajorTransferReadPattern
602602
//
603603
// For this example:
604604
// %2 = vector.transfer_read %arg4[%c0, %arg0, %c0] (...) :
605-
// memref<1x43x2xi32>, vector<1x2xi32>
605+
// memref<1x43x2xi32>, vector<1x2xi32>
606606
// which would be collapsed to:
607607
// %1 = vector.transfer_read %collapse_shape[%c0, %offset] (...) :
608-
// memref<1x86xi32>, vector<2xi32>
608+
// memref<1x86xi32>, vector<2xi32>
609609
// one would get the following offset:
610610
// %offset = %arg0 * 43
611+
AffineExpr offsetE, idx;
612+
bindSymbols(rewriter.getContext(), offsetE, idx);
613+
611614
int64_t outputRank = transferReadOp.getIndices().size();
612-
Value offset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
615+
OpFoldResult offset =
616+
rewriter.create<arith::ConstantIndexOp>(loc, 0).getResult();
613617
for (int64_t i = firstDimToCollapse; i < outputRank; ++i) {
614-
Value dimIdx = rewriter.create<arith::ConstantIndexOp>(loc, i);
615-
auto sourceDimSize =
616-
rewriter.create<memref::DimOp>(loc, source, dimIdx);
617-
618-
offset = rewriter.create<arith::AddIOp>(
619-
loc,
620-
rewriter.create<arith::MulIOp>(loc, transferReadOp.getIndices()[i],
621-
sourceDimSize),
622-
offset);
618+
int64_t dim = dyn_cast<ShapedType>(source.getType()).getDimSize(i);
619+
offset = affine::makeComposedFoldedAffineApply(
620+
rewriter, loc, offsetE + dim * idx,
621+
{offset, transferReadOp.getIndices()[i]});
622+
}
623+
if (offset.is<Value>()) {
624+
collapsedIndices.push_back(offset.get<Value>());
625+
} else {
626+
collapsedIndices.push_back(rewriter.create<arith::ConstantIndexOp>(
627+
loc, *getConstantIntValue(offset)));
623628
}
624-
collapsedIndices.push_back(offset);
625629
}
626630

627631
// 3. Create new vector.transfer_read that reads from the collapsed memref

mlir/test/Dialect/Vector/vector-transfer-flatten.mlir

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -55,21 +55,19 @@ func.func @transfer_read_dims_mismatch_non_zero_indices(
5555
return
5656
}
5757

58+
// CHECK: #[[$ATTR_0:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 * 43)>
59+
5860
// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_zero_indices(
5961
// CHECK-SAME: %[[VAL_0:.*]]: index, %[[VAL_1:.*]]: index,
6062
// CHECK-SAME: %[[VAL_2:.*]]: memref<1x43x4x6xi32>,
6163
// CHECK-SAME: %[[VAL_3:.*]]: memref<1x2x6xi32>) {
62-
// CHECK: %[[VAL_4:.*]] = arith.constant 43 : index
63-
// CHECK: %[[VAL_5:.*]] = arith.constant 4 : index
64-
// CHECK: %[[VAL_6:.*]] = arith.constant 0 : i32
65-
// CHECK: %[[VAL_7:.*]] = arith.constant 0 : index
66-
// CHECK: %[[VAL_8:.*]] = memref.collapse_shape %[[VAL_2]] {{\[\[}}0], [1, 2, 3]] : memref<1x43x4x6xi32> into memref<1x1032xi32>
67-
// CHECK: %[[VAL_9:.*]] = arith.muli %[[VAL_0]], %[[VAL_4]] : index
68-
// CHECK: %[[VAL_10:.*]] = arith.muli %[[VAL_1]], %[[VAL_5]] : index
69-
// CHECK: %[[VAL_11:.*]] = arith.addi %[[VAL_10]], %[[VAL_9]] : index
70-
// CHECK: %[[VAL_12:.*]] = vector.transfer_read %[[VAL_8]]{{\[}}%[[VAL_7]], %[[VAL_11]]], %[[VAL_6]] {in_bounds = [true]} : memref<1x1032xi32>, vector<12xi32>
71-
// CHECK: %[[VAL_13:.*]] = memref.collapse_shape %[[VAL_3]] {{\[\[}}0, 1, 2]] : memref<1x2x6xi32> into memref<12xi32>
72-
// CHECK: vector.transfer_write %[[VAL_12]], %[[VAL_13]]{{\[}}%[[VAL_7]]] {in_bounds = [true]} : vector<12xi32>, memref<12xi32>
64+
// CHECK: %[[VAL_4:.*]] = arith.constant 0 : i32
65+
// CHECK: %[[VAL_5:.*]] = arith.constant 0 : index
66+
// CHECK: %[[VAL_6:.*]] = memref.collapse_shape %[[VAL_2]] {{\[\[}}0], [1, 2, 3]] : memref<1x43x4x6xi32> into memref<1x1032xi32>
67+
// CHECK: %[[VAL_7:.*]] = affine.apply #[[$ATTR_0]](){{\[}}%[[VAL_1]], %[[VAL_0]]]
68+
// CHECK: %[[VAL_8:.*]] = vector.transfer_read %[[VAL_6]]{{\[}}%[[VAL_5]], %[[VAL_7]]], %[[VAL_4]] {in_bounds = [true]} : memref<1x1032xi32>, vector<12xi32>
69+
// CHECK: %[[VAL_9:.*]] = memref.collapse_shape %[[VAL_3]] {{\[\[}}0, 1, 2]] : memref<1x2x6xi32> into memref<12xi32>
70+
// CHECK: vector.transfer_write %[[VAL_8]], %[[VAL_9]]{{\[}}%[[VAL_5]]] {in_bounds = [true]} : vector<12xi32>, memref<12xi32>
7371

7472
// -----
7573

mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,7 @@ struct TestFlattenVectorTransferPatterns
454454
}
455455
void getDependentDialects(DialectRegistry &registry) const override {
456456
registry.insert<memref::MemRefDialect>();
457+
registry.insert<affine::AffineDialect>();
457458
}
458459
void runOnOperation() override {
459460
RewritePatternSet patterns(&getContext());

0 commit comments

Comments
 (0)