Skip to content

Commit c60506c

Browse files
author
Jerry Wu
committed
Refactor and fix tests
1 parent 376f8d5 commit c60506c

File tree

2 files changed

+60
-33
lines changed

2 files changed

+60
-33
lines changed

mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp

Lines changed: 59 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -553,17 +553,26 @@ class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
553553
ControlPropagationFn controlFn;
554554
};
555555

556+
/// Project dimsPos to the inner-most non-unit dim pos with reassocIndices.
557+
///
558+
/// For example, given dimsPos [0, 2], reassocIndices [[0, 1], [2, 3]], and
559+
/// targetShape [16, 16, 32, 1], it returns [1, 2]. Because for pos 0, the
560+
/// inner-most projected dim in pos [0, 1] is 1. And for pos 2, the inner-most
561+
/// non-unit projected dims in pos [2, 3] is 2.
562+
///
563+
/// If all candidates in a reassociation are unit dims, it chooses the
564+
/// inner-most dim pos.
556565
static SmallVector<int64_t>
557566
projectToInnerMostNonUnitDimsPos(ArrayRef<int64_t> dimsPos,
558567
ArrayRef<ReassociationIndices> reassocIndices,
559-
ArrayRef<int64_t> baseShape) {
568+
ArrayRef<int64_t> targetShape) {
560569
SmallVector<int64_t> projectedDimsPos;
561570
for (auto pos : dimsPos) {
562571
// In the case all dims are unit, this will return the inner-most one.
563572
int64_t projectedPos = reassocIndices[pos].back();
564573
for (auto it = reassocIndices[pos].rbegin();
565574
it != reassocIndices[pos].rend(); ++it) {
566-
int64_t dim = baseShape[*it];
575+
int64_t dim = targetShape[*it];
567576
if (dim > 1 || ShapedType::isDynamic(dim)) {
568577
projectedPos = *it;
569578
break;
@@ -574,32 +583,36 @@ projectToInnerMostNonUnitDimsPos(ArrayRef<int64_t> dimsPos,
574583
return projectedDimsPos;
575584
}
576585

577-
static bool
578-
isProjectedDimsDivisibleByTileSizes(ArrayRef<int64_t> projectedDimsPos,
579-
ArrayRef<int64_t> targetShape,
580-
ArrayRef<int64_t> tileSizes) {
581-
for (auto [projectedPos, tileSize] :
582-
llvm::zip_equal(projectedDimsPos, tileSizes)) {
583-
int64_t dim = targetShape[projectedPos];
586+
/// Check if all dims in dimsPos are divisible by the corresponding tile sizes.
587+
static bool isDimsDivisibleByTileSizes(ArrayRef<int64_t> dimsPos,
588+
ArrayRef<int64_t> shape,
589+
ArrayRef<int64_t> tileSizes) {
590+
for (auto [pos, tileSize] : llvm::zip_equal(dimsPos, tileSizes)) {
591+
int64_t dim = shape[pos];
584592
if (ShapedType::isDynamic(dim) || (dim % tileSize) != 0)
585593
return false;
586594
}
587595
return true;
588596
}
589597

598+
/// Permutate the reassociation indices and reindex them in the sequence order.
599+
/// Returns the next dim pos in the sequence.
600+
///
601+
/// For example, given reassocIndices [[0, 1], [2]] and permutation [1, 0], it
602+
/// applies the permutation to get [[2], [0, 1]] and reindexes the indices into
603+
/// [[0], [1, 2]].
590604
static int64_t applyPermutationAndReindexReassoc(
591-
SmallVector<ReassociationIndices> &reassociationIndices,
592-
ArrayRef<int64_t> dimsPerm) {
593-
applyPermutationToVector<ReassociationIndices>(reassociationIndices,
594-
dimsPerm);
595-
int64_t lastPos = 0;
596-
for (ReassociationIndices &indices : reassociationIndices) {
605+
SmallVector<ReassociationIndices> &reassocIndices,
606+
ArrayRef<int64_t> permutation) {
607+
applyPermutationToVector<ReassociationIndices>(reassocIndices, permutation);
608+
int64_t nextPos = 0;
609+
for (ReassociationIndices &indices : reassocIndices) {
597610
for (auto &index : indices) {
598-
index = lastPos;
599-
lastPos += 1;
611+
index = nextPos;
612+
nextPos += 1;
600613
}
601614
}
602-
return lastPos;
615+
return nextPos;
603616
}
604617

605618
/// Bubble up pack op through collapse shape op when the packed dims can be
@@ -614,7 +627,7 @@ static int64_t applyPermutationAndReindexReassoc(
614627
/// inner_dims_pos = [0, 1] inner_tiles = [8, 1] into %empty
615628
/// : tensor<?x4xf32> -> tensor<?x4x8x1xf32>
616629
///
617-
/// Can be transformed into:
630+
/// can be transformed into:
618631
///
619632
/// %pack = tensor.pack %in outer_dims_perm = [1, 2]
620633
/// inner_dims_pos = [1, 2] inner_tiles = [8, 1] into %empty
@@ -632,11 +645,18 @@ bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
632645
ArrayRef<int64_t> srcShape = collapseOp.getSrcType().getShape();
633646
SmallVector<ReassociationIndices> reassocIndices =
634647
collapseOp.getReassociationIndices();
648+
// Project inner tile pos to the dim pos before collapsing. For example, if
649+
// dims [x, y] is collapsed into [z], packing on dim z can be projected back
650+
// to pack on dim y.
651+
//
652+
// Project to inner-most non-unit dims to increase the chance that they can be
653+
// divided by the inner tile sizes. This is correct because for [..., x, 1],
654+
// packing on dim 1 is equivalent to packing on dim x.
635655
SmallVector<int64_t> projectedInnerDimsPos =
636656
projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, srcShape);
637657

638-
if (!isProjectedDimsDivisibleByTileSizes(projectedInnerDimsPos, srcShape,
639-
innerTileSizes)) {
658+
if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, srcShape,
659+
innerTileSizes)) {
640660
return failure();
641661
}
642662
// Expand the outer dims permutation with the associated source dims for the
@@ -658,14 +678,14 @@ bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
658678

659679
SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
660680
// First apply the permutation on the reassociations of the outer dims.
661-
// For example given the permutation [1, 0], the reassociations: [[0, 1], [2]]
681+
// For example given the permutation [1, 0], the reassociations [[0, 1], [2]]
662682
// -> [[0], [1, 2]]
663-
int64_t lastPos =
683+
int64_t nextPos =
664684
applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
665685
// Then add direct mapping for the inner tile dims.
666686
for (size_t i = 0; i < innerDimsPos.size(); ++i) {
667-
newReassocIndices.push_back({lastPos});
668-
lastPos += 1;
687+
newReassocIndices.push_back({nextPos});
688+
nextPos += 1;
669689
}
670690

671691
auto newCollapseOp = rewriter.create<tensor::CollapseShapeOp>(
@@ -723,7 +743,7 @@ class BubbleUpPackOpThroughReshapeOp final
723743
/// %expanded = tensor.expand_shape %unpack [[0, 1], [2]]
724744
/// : tensor<?x256xf32> into tensor<?x256x256xf32>
725745
///
726-
/// Can be transformed into:
746+
/// can be transformed into:
727747
///
728748
/// %expanded = tensor.expand_shape %ain [[0, 1], [2], [3], [4]]
729749
/// : tensor<?x32x8x8xf32> into tensor<?x32x32x8x8xf32>
@@ -741,11 +761,18 @@ pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
741761
ArrayRef<int64_t> dstShape = expandOp.getType().getShape();
742762
SmallVector<ReassociationIndices> reassocIndices =
743763
expandOp.getReassociationIndices();
764+
// Project inner tile pos to the dim pos after expanding. For example, if dims
765+
// [z] is expanded into [x, y], unpacking on dim z can be projected to unpack
766+
// on dim y.
767+
//
768+
// Project to inner-most non-unit dims to increase the chance that they can be
769+
// divided by the inner tile sizes. This is correct because for [..., x, 1],
770+
// unpacking on dim 1 is equivalent to unpacking on dim x.
744771
SmallVector<int64_t> projectedInnerDimsPos =
745772
projectToInnerMostNonUnitDimsPos(innerDimsPos, reassocIndices, dstShape);
746773

747-
if (!isProjectedDimsDivisibleByTileSizes(projectedInnerDimsPos, dstShape,
748-
innerTileSizes)) {
774+
if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, dstShape,
775+
innerTileSizes)) {
749776
return failure();
750777
}
751778
// Expand the outer dims permutation with the associated expanded dims for the
@@ -760,14 +787,14 @@ pushDownUnPackOpThroughExpandShape(tensor::UnPackOp unPackOp,
760787

761788
SmallVector<ReassociationIndices> newReassocIndices = reassocIndices;
762789
// First apply the permutation on the reassociations of the outer dims.
763-
// For example given the permutation [1, 0], the reassociations: [[0, 1], [2]]
790+
// For example given the permutation [1, 0], the reassociations [[0, 1], [2]]
764791
// -> [[0], [1, 2]]
765-
int64_t lastPos =
792+
int64_t nextPos =
766793
applyPermutationAndReindexReassoc(newReassocIndices, outerDimsPerm);
767794
// Then add direct mapping for the inner tile dims.
768795
for (size_t i = 0; i < innerDimsPos.size(); ++i) {
769-
newReassocIndices.push_back({lastPos});
770-
lastPos += 1;
796+
newReassocIndices.push_back({nextPos});
797+
nextPos += 1;
771798
}
772799

773800
RankedTensorType newExpandType =

mlir/test/Dialect/Linalg/data-layout-propagation.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -998,7 +998,7 @@ func.func @push_down_permuted_unpack_through_expand(%5: tensor<4x32x384x8x8xf32>
998998
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]
999999
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2, 3], [4], [5]] : tensor<4x32x384x8x8xf32> into tensor<4x32x12x32x8x8xf32>
10001000
// CHECK: %[[EMPTY:.+]] = tensor.empty() : tensor<4x12x256x256xf32>
1001-
// CHECK: %[[UNPACL:.+]] = tensor.unpack %[[EXPANDED]] outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3, 2] inner_tiles = [8, 8] into %[[EMPTY]] : tensor<4x32x12x32x8x8xf32> -> tensor<4x12x256x256xf32>
1001+
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[EXPANDED]] outer_dims_perm = [0, 3, 1, 2] inner_dims_pos = [3, 2] inner_tiles = [8, 8] into %[[EMPTY]] : tensor<4x32x12x32x8x8xf32> -> tensor<4x12x256x256xf32>
10021002
// CHECK: return %[[UNPACK]] : tensor<4x12x256x256xf32>
10031003

10041004
// -----

0 commit comments

Comments
 (0)