[MLIR][Tensor] Remove tensor.dim canonicalization patterns registered on tensor.expand_shape/tensor.collapse_shape #134219

Merged (5 commits), Apr 11, 2025
87 changes: 1 addition & 86 deletions mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -1986,90 +1986,6 @@ struct FoldCollapseOfCastOp : public OpRewritePattern<CollapseShapeOp> {
}
};

struct FoldDimOfExpandShape : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto expandShapeOp = dimOp.getSource().getDefiningOp<ExpandShapeOp>();
    if (!expandShapeOp)
      return failure();

    // Only constant dimension values are supported.
    std::optional<int64_t> dim = dimOp.getConstantIndex();
    if (!dim.has_value())
      return failure();

    // Skip static dims. These are folded to constant ops.
    RankedTensorType resultType = expandShapeOp.getResultType();
    if (!resultType.isDynamicDim(*dim))
      return failure();

    // Find reassociation group that contains this result dimension.
    int64_t srcDim = expandShapeOp.getCorrespondingSourceDim(*dim);

    // `dim` is the only dynamic dimension in `group`. (Otherwise, the
    // ExpandShapeOp would be ambiguous.)
    int64_t product = 1;
    ReassociationIndices grp = expandShapeOp.getReassociationIndices()[srcDim];
    for (int64_t d : grp) {
      if (d != dim) {
        assert(!resultType.isDynamicDim(d) && "expected static dim");
        product *= resultType.getDimSize(d);
      }
    }

    // result dim size = src dim size / (product(other dims in reassoc group))
    Value srcDimSz =
        rewriter.create<DimOp>(dimOp.getLoc(), expandShapeOp.getSrc(), srcDim);
    AffineExpr expr;
    bindSymbols(dimOp.getContext(), expr);
    rewriter.replaceOpWithNewOp<affine::AffineApplyOp>(
        dimOp, expr.floorDiv(product), srcDimSz);
    return success();
  }
};
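
For reference, a minimal sketch of the folding this pattern performed, on hypothetical IR that is not taken from the patch or its tests (function and value names are made up for illustration): a tensor.dim of a dynamic result dimension of an expand_shape was rewritten to an affine.apply floor-division over the corresponding source dimension.

func.func @dim_of_expand_shape_sketch(%t: tensor<?xf32>, %sz: index) -> index {
  %c0 = arith.constant 0 : index
  // Result dim 0 is dynamic; its reassociation group [0, 1] has static size 8.
  %e = tensor.expand_shape %t [[0, 1]] output_shape [%sz, 8]
      : tensor<?xf32> into tensor<?x8xf32>
  %dim = tensor.dim %e, %c0 : tensor<?x8xf32>
  return %dim : index
}

After the rewrite (the expand_shape is omitted here, since it becomes dead in this sketch):

#map = affine_map<()[s0] -> (s0 floordiv 8)>
func.func @dim_of_expand_shape_sketch(%t: tensor<?xf32>, %sz: index) -> index {
  %c0 = arith.constant 0 : index
  %src_dim = tensor.dim %t, %c0 : tensor<?xf32>
  %result = affine.apply #map()[%src_dim]
  return %result : index
}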

struct FoldDimOfCollapseShape : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto collapseShapeOp = dimOp.getSource().getDefiningOp<CollapseShapeOp>();
    if (!collapseShapeOp)
      return failure();

    // Only constant dimension values are supported.
    std::optional<int64_t> dim = dimOp.getConstantIndex();
    if (!dim.has_value() ||
        dim.value() >= collapseShapeOp.getResultType().getRank())
      return failure();

    // Skip static dims. These are folded to constant ops.
    RankedTensorType resultType = collapseShapeOp.getResultType();
    if (!resultType.isDynamicDim(*dim))
      return failure();

    // Get reassociation group of the result dimension.
    ReassociationIndices group =
        collapseShapeOp.getReassociationIndices()[*dim];

    // result dim size = product(dims in reassoc group)
    SmallVector<Value> srcDimSizes;
    SmallVector<AffineExpr> syms;
    AffineExpr product;
    for (const auto &it : llvm::enumerate(group)) {
      srcDimSizes.push_back(rewriter.create<DimOp>(
          dimOp.getLoc(), collapseShapeOp.getSrc(), it.value()));
      syms.push_back(rewriter.getAffineSymbolExpr(it.index()));
      product = product ? product * syms.back() : syms.back();
    }
    rewriter.replaceOpWithNewOp<affine::AffineApplyOp>(dimOp, product,
                                                       srcDimSizes);
    return success();
  }
};
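
Analogously, a minimal sketch (again hypothetical IR, with made-up names) of the collapse_shape case: the dim of a dynamic collapsed dimension was rewritten to the product of the source dimensions in its reassociation group. The reviewer discussion further below shows the same behaviour on IR derived from the tests in this patch.

func.func @dim_of_collapse_shape_sketch(%t: tensor<?x?xf32>) -> index {
  %c0 = arith.constant 0 : index
  %collapsed = tensor.collapse_shape %t [[0, 1]] : tensor<?x?xf32> into tensor<?xf32>
  %dim = tensor.dim %collapsed, %c0 : tensor<?xf32>
  return %dim : index
}

After the rewrite:

#map = affine_map<()[s0, s1] -> (s0 * s1)>
func.func @dim_of_collapse_shape_sketch(%t: tensor<?x?xf32>) -> index {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %d0 = tensor.dim %t, %c0 : tensor<?x?xf32>
  %d1 = tensor.dim %t, %c1 : tensor<?x?xf32>
  %result = affine.apply #map()[%d0, %d1]
  return %result : index
}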

/// Fold/sink a producer `tensor.cast` with a consumer `tensor.expand_shape` by
/// matching constant output_shape operands of the expand. This makes the
/// `tensor.expand_shape` more static and creates a consumer cast that can be
@@ -2158,8 +2074,7 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
ComposeExpandOfCollapseOp<ExpandShapeOp, CollapseShapeOp>,
ConvertToStaticExpandShape, FoldReshapeWithConstant<ExpandShapeOp>,
FoldReshapeWithSplat<ExpandShapeOp>,
FoldReshapeWithFromElements<ExpandShapeOp>, FoldDimOfExpandShape,
FoldDimOfCollapseShape>(context);
FoldReshapeWithFromElements<ExpandShapeOp>>(context);
}

void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
29 changes: 6 additions & 23 deletions mlir/test/Dialect/Linalg/drop-unit-extent-dims.mlir
@@ -25,22 +25,18 @@ func.func @drop_one_trip_loops(%arg0 : tensor<?x1x?xf32>, %arg1 : f32, %shape: t
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> ()>
// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
// CHECK-DAG: #[[$MAP4:.*]] = affine_map<()[s0, s1] -> (s0 * s1)>
// CHECK-LABEL: func @drop_one_trip_loops
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: tensor.collapse_shape %{{.*}} {{\[\[}}0, 1], [2]]
// CHECK: tensor.collapse_shape %{{.*}} {{\[\[}}0, 1], [2, 3], [4]]
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP3]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
// CHECK: %[[DIM:.*]] = tensor.dim %{{.*}}, %[[C0]]
// CHECK: %[[VAL_1:.*]] = affine.apply #[[$MAP4]]()[%[[DIM]], %[[C1]]]
// CHECK: %[[DIM_1:.*]] = tensor.dim %{{.*}}, %[[C2]]
// CHECK: %[[VAL_2:.*]] = affine.apply #[[$MAP4]]()[%[[DIM_1]], %[[C1]]]
// CHECK: %[[DIM_2:.*]] = tensor.dim %{{.*}}, %[[C2]]
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %{{.*}} {{\[\[}}0, 1], [2, 3], [4]] output_shape [%[[VAL_1]], 1, %[[VAL_2]], 1, %[[DIM_2]]] : tensor<?x?x?xf32> into tensor<?x1x?x1x?xf32>
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %{{.*}} {{\[\[}}0, 1], [2, 3], [4]] output_shape [%[[DIM]], 1, %[[DIM_1]], 1, %[[DIM_2]]] : tensor<?x?x?xf32> into tensor<?x1x?x1x?xf32>

// CHECK-SLICES-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1, d2) -> (d0, d2)>
// CHECK-SLICES-DAG: #[[$MAP2:.*]] = affine_map<(d0, d1, d2) -> ()>
@@ -79,18 +75,15 @@ func.func @drop_one_trip_loops_all_ones(%arg0 : tensor<1x1x1xf32>, %arg1 : f32,
}
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0) -> ()>
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<(d0) -> (d0)>
// CHECK-DAG: #[[$MAP3:.*]] = affine_map<()[s0, s1, s2, s3, s4] -> ((((s0 * s1) * s2) * s3) * s4)>
// CHECK-LABEL: func @drop_one_trip_loops_all_ones
// CHECK: %[[C2:.*]] = arith.constant 2 : index
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: tensor.collapse_shape %{{.*}} []
// CHECK: tensor.collapse_shape %{{.*}} {{\[}}[0, 1, 2, 3, 4]]
// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[$MAP1]], #[[$MAP1]], #[[$MAP2]]]
// CHECK-SAME: iterator_types = ["parallel"]
// CHECK: %[[DIM:.*]] = tensor.dim %{{.*}}, %[[C2]] : tensor<1x1x?x1x1xf32>
// CHECK: %[[SZ:.*]] = affine.apply #[[$MAP3]]()[%[[C1]], %[[C1]], %[[DIM]], %[[C1]], %[[C1]]]
// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %{{.*}} {{\[\[}}0, 1, 2, 3, 4]] output_shape [1, 1, %[[SZ]], 1, 1] : tensor<?xf32> into tensor<1x1x?x1x1xf32>
// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %{{.*}} {{\[\[}}0, 1, 2, 3, 4]] output_shape [1, 1, %[[DIM]], 1, 1] : tensor<?xf32> into tensor<1x1x?x1x1xf32>

// -----

@@ -406,7 +399,6 @@ func.func @unit_dim_for_reduction(%arg0: tensor<1x?x1x?xf32>) -> tensor<1x?xf32>
}
// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0)>
// CHECK-DAG: #[[MAP3:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * s2)>
// CHECK: func @unit_dim_for_reduction
// CHECK-SAME: %[[ARG0:.+]]: tensor<1x?x1x?xf32>
// CHECK: %[[C1:.+]] = arith.constant 1 : index
Expand All @@ -422,8 +414,7 @@ func.func @unit_dim_for_reduction(%arg0: tensor<1x?x1x?xf32>) -> tensor<1x?xf32>
// CHECK-SAME: ins(%[[RESHAPE]] : tensor<?x?xf32>)
// CHECK-SAME: outs(%[[FILL]] : tensor<?xf32>)
// CHECK: %[[DIM_0:.*]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<1x?x1x?xf32>
// CHECK: %[[VAL_3:.*]] = affine.apply #[[$MAP3]]()[%[[C1]], %[[DIM_0]], %[[C1]]]
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[GENERIC]] {{\[\[}}0, 1]] output_shape [1, %[[VAL_3]]] : tensor<?xf32> into tensor<1x?xf32>
// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[GENERIC]] {{\[\[}}0, 1]] output_shape [1, %[[DIM_0]]] : tensor<?xf32> into tensor<1x?xf32>
// CHECK: return %[[EXPANDED]] : tensor<1x?xf32>

// -----
@@ -482,10 +473,8 @@ func.func @unit_dim_for_reduction_inner(%arg0: tensor<?x1x?x1xf32>) -> tensor<?x
}
// CHECK-DAG: #[[MAP:.+]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-DAG: #[[MAP2:.+]] = affine_map<(d0, d1) -> (d0)>
// CHECK-DAG: #[[MAP3:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
// CHECK: func @unit_dim_for_reduction_inner
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x1xf32>
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[CST:.*]] = arith.constant 1.000000e+00 : f32
// CHECK: %[[C2:.*]] = arith.constant 2 : index
Expand All @@ -499,8 +488,7 @@ func.func @unit_dim_for_reduction_inner(%arg0: tensor<?x1x?x1xf32>) -> tensor<?x
// CHECK-SAME: ins(%[[RESHAPE]] : tensor<?x?xf32>)
// CHECK-SAME: outs(%[[FILL]] : tensor<?xf32>)
// CHECK: %[[DIM_0:.+]] = tensor.dim %[[ARG0]], %[[C0]] : tensor<?x1x?x1xf32>
// CHECK: %[[VAL_3:.+]] = affine.apply #[[$MAP3]]()[%[[DIM_0]], %[[C1]]]
// CHECK: %[[RESULT_RESHAPE:.+]] = tensor.expand_shape %[[RESULT]] {{\[}}[0, 1]] output_shape [%[[VAL_3]], 1] : tensor<?xf32> into tensor<?x1xf32>
// CHECK: %[[RESULT_RESHAPE:.+]] = tensor.expand_shape %[[RESULT]] {{\[}}[0, 1]] output_shape [%[[DIM_0]], 1] : tensor<?xf32> into tensor<?x1xf32>
// CHECK: return %[[RESULT_RESHAPE]]

// -----
@@ -1017,7 +1005,6 @@ func.func @drop_unit_pad_dynamic_dims(%arg0: tensor<1x?xf32>) -> tensor<1x?xf32>
return %0 : tensor<1x?xf32>
}

// CHECK-DAG: #[[$MAP:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
// CHECK-DAG: #[[$MAP1:.+]] = affine_map<()[s0] -> (s0 + 11)>
// CHECK-LABEL: func @drop_unit_pad_dynamic_dims
// CHECK: %[[C1:.*]] = arith.constant 1 : index
Expand All @@ -1027,8 +1014,7 @@ func.func @drop_unit_pad_dynamic_dims(%arg0: tensor<1x?xf32>) -> tensor<1x?xf32>
// CHECK: %[[PADDED:.+]] = tensor.pad %[[COLLAPSE]] low[5] high[6]
// CHECK: } : tensor<?xf32> to tensor<?xf32>
// CHECK: %[[DIM:.+]] = tensor.dim %{{.*}}, %[[C1]] : tensor<1x?xf32>
// CHECK: %[[VAL_0:.+]] = affine.apply #[[$MAP]]()[%[[C1]], %[[DIM]]]
// CHECK: %[[VAL_1:.+]] = affine.apply #[[$MAP1]]()[%[[VAL_0]]]
// CHECK: %[[VAL_1:.+]] = affine.apply #[[$MAP1]]()[%[[DIM]]]
// CHECK: %[[EXPANDED:.+]] = tensor.expand_shape %[[PADDED]] {{\[\[}}0, 1]] output_shape [1, %[[VAL_1]]] : tensor<?xf32> into tensor<1x?xf32>

// CHECK-SLICES: #[[$MAP:.+]] = affine_map<()[s0] -> (s0 + 11)>
@@ -1090,20 +1076,17 @@ func.func @drop_known_unit_constant_low_high(%arg0: tensor<1x383x128xf32>) -> te

// -----

// CHECK: #[[$MAP0:.+]] = affine_map<()[s0, s1] -> (s0 * s1)>
// CHECK: #[[$MAP1:.+]] = affine_map<(d0) -> (0, d0)>
// CHECK: #[[$MAP2:.+]] = affine_map<(d0) -> ()>

// CHECK-LABEL: func @drop_unit_dim_corresponding_to_dynamic_dim
// CHECK-SAME: %[[ARG0:.*]]: tensor<1x?x?x1xf32>,
// CHECK-SAME: %[[ARG1:.*]]: index) -> tensor<?x1x61x1xf32> {
// CHECK: %[[VAL_0:.*]] = arith.constant 0 : index
// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index
// CHECK: %[[VAL_2:.*]] = arith.constant dense<1.000000e+00> : tensor<f32>
// CHECK: %[[VAL_3:.*]] = tensor.collapse_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3]] : tensor<1x?x?x1xf32> into tensor<?x?xf32>
// CHECK: %[[VAL_4:.*]] = tensor.empty(%[[ARG1]]) : tensor<?x61xf32>
// CHECK: %[[VAL_5:.*]] = affine.apply #[[$MAP0]](){{\[}}%[[ARG1]], %[[VAL_1]]]
// CHECK: %[[VAL_6:.*]] = tensor.empty(%[[VAL_5]]) : tensor<?x61xf32>
// CHECK: %[[VAL_6:.*]] = tensor.empty(%[[ARG1]]) : tensor<?x61xf32>
// CHECK: %[[VAL_7:.*]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel"]} ins(%[[VAL_3]], %[[VAL_2]], %[[VAL_4]] : tensor<?x?xf32>, tensor<f32>, tensor<?x61xf32>) outs(%[[VAL_6]] : tensor<?x61xf32>) {
// CHECK: ^bb0(%[[VAL_8:.*]]: f32, %[[VAL_9:.*]]: f32, %[[VAL_10:.*]]: f32, %[[VAL_11:.*]]: f32):
// CHECK: %[[VAL_12:.*]] = arith.mulf %[[VAL_8]], %[[VAL_9]] : f32
10 changes: 5 additions & 5 deletions mlir/test/Dialect/Linalg/rank-reduce-contraction-ops.mlir
@@ -76,13 +76,13 @@ func.func @singleton_batch_vecmat(%arg0 : tensor<1x?xf32>, %arg1 : tensor<1x?x?x
// CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?xf32>
// CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: tensor<1x?x?xf32>
// CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<1x?xf32>
// CHECK-DAG: %[[C1:.*]] = arith.constant 1
// CHECK-DAG: %[[C0:.*]] = arith.constant 0
// CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1]]
// CHECK-NEXT: %[[COLLAPSED_RHS:.*]] = tensor.collapse_shape %[[RHS]] {{\[}}[0, 1], [2]]
// CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
// CHECK-NEXT: %[[MATMUL:.+]] = linalg.vecmat
// CHECK-SAME: ins(%[[COLLAPSED_LHS]], %[[COLLAPSED_RHS]] : tensor<?xf32>, tensor<?x?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
// CHECK-NEXT: %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]]
// CHECK-NEXT: %[[DIM1:.*]] = tensor.dim %[[COLLAPSED_INIT]], %[[C0]]
// CHECK-NEXT: %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [1, %[[DIM1]]]
// CHECK-NEXT: return %[[RES]]
%1 = linalg.batch_vecmat ins(%arg0, %arg1 : tensor<1x?xf32>, tensor<1x?x?xf32>)
@@ -134,7 +134,7 @@ func.func @matmul_to_matvec_tensor(%arg0: tensor<?x?xf32>, %arg1: tensor<?x1xf32
// CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
// CHECK-NEXT: %[[MATMUL:.+]] = linalg.matvec
// CHECK-SAME: ins(%[[LHS]], %[[COLLAPSED_RHS]] : tensor<?x?xf32>, tensor<?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
// CHECK-NEXT: %[[DIM0:.*]] = tensor.dim %[[INIT]], %[[C0]]
// CHECK-NEXT: %[[DIM0:.*]] = tensor.dim %[[COLLAPSED_INIT]], %[[C0]]
// CHECK-NEXT: %[[RES:.*]] = tensor.expand_shape %[[MATMUL]] {{\[}}[0, 1]] output_shape [%[[DIM0]], 1]
// CHECK-NEXT: return %[[RES]]
%0 = linalg.matmul ins(%arg0, %arg1: tensor<?x?xf32>, tensor<?x1xf32>) outs(%arg2: tensor<?x1xf32>) -> tensor<?x1xf32>
@@ -171,12 +171,12 @@ func.func @matmul_to_vecmat_tensor(%arg0: tensor<1x?xf32>, %arg1: tensor<?x?xf32
// CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: tensor<1x?xf32>
// CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: tensor<?x?xf32>
// CHECK-SAME: %[[INIT:[a-zA-Z0-9]+]]: tensor<1x?xf32>
// CHECK-DAG: %[[C1:.*]] = arith.constant 1
// CHECK-DAG: %[[C0:.*]] = arith.constant 0
// CHECK-NEXT: %[[COLLAPSED_LHS:.*]] = tensor.collapse_shape %[[LHS]] {{\[}}[0, 1]]
// CHECK-NEXT: %[[COLLAPSED_INIT:.*]] = tensor.collapse_shape %[[INIT]] {{\[}}[0, 1]]
// CHECK-NEXT: %[[RESULT:.*]] = linalg.vecmat
// CHECK-SAME: ins(%[[COLLAPSED_LHS]], %[[RHS]] : tensor<?xf32>, tensor<?x?xf32>) outs(%[[COLLAPSED_INIT]] : tensor<?xf32>)
// CHECK-NEXT: %[[DIM1:.*]] = tensor.dim %[[INIT]], %[[C1]]
// CHECK-NEXT: %[[DIM1:.*]] = tensor.dim %[[COLLAPSED_INIT]], %[[C0]]
// CHECK-NEXT: %[[RES:.*]] = tensor.expand_shape %[[RESULT]] {{\[}}[0, 1]] output_shape [1, %[[DIM1]]]
// CHECK-NEXT: return %[[RES]]
%0 = linalg.matmul ins(%arg0, %arg1: tensor<1x?xf32>, tensor<?x?xf32>) outs(%arg2: tensor<1x?xf32>) -> tensor<1x?xf32>
28 changes: 10 additions & 18 deletions mlir/test/Dialect/Tensor/canonicalize.mlir
@@ -1105,15 +1105,13 @@ func.func @compose_expand_of_collapse_last_two_dims(%arg0: tensor<?x64x1xf32>) -
%expanded = tensor.expand_shape %collapsed [[0, 1]] output_shape [%div, 384] : tensor<?xf32> into tensor<?x384xf32>
return %expanded : tensor<?x384xf32>
}
// CHECK: #[[$MAP:.*]] = affine_map<()[s0] -> (s0 * 64)>
// CHECK-LABEL: @compose_expand_of_collapse_last_two_dims
// CHECK-SAME: %[[ARG0:.+]]: tensor<?x64x1xf32>
// CHECK: %[[CONSTANT0:.+]] = arith.constant 0 : index
// CHECK: %[[CONSTANT384:.+]] = arith.constant 384 : index
// CHECK: %[[CONSTANT0:.+]] = arith.constant 0 : index
// CHECK: %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2]] : tensor<?x64x1xf32> into tensor<?xf32>
// CHECK: %[[DIM:.+]] = tensor.dim %[[ARG0]], %[[CONSTANT0]] : tensor<?x64x1xf32>
// CHECK: %[[AFFAPPLY:.+]] = affine.apply #[[$MAP]]()[%[[DIM]]]
// CHECK: %[[DIVUI:.+]] = arith.divui %[[AFFAPPLY]], %[[CONSTANT384]] : index
// CHECK: %[[DIM:.+]] = tensor.dim %[[COLLAPSE]], %[[CONSTANT0]] : tensor<?xf32>
// CHECK: %[[DIVUI:.+]] = arith.divui %[[DIM]], %[[CONSTANT384]] : index
Comment on lines -1108 to +1114
Member:

@joker-eph For an example of non-local effects: the previous canonicalization pattern for collapse_shape was not moving the collapse_shape to a more canonical form. It's actually not touching the collapse_shape at all, and is doing things around it.

It's very hard to say where the canonical form of the tensor.dim should be here. There is no reason for the tensor.dim, when propagated up, to be more canonical. It can possibly introduce 2 tensor.dim operations, if there were 2 dynamic dimensions before the collapse shape. If there is a cost to a tensor.dim operation, you now pay the cost twice. We cannot really say if one form is more canonical than the other.

Collaborator:

In the current example, simplified, we have:

func.func @dim_expand_shape(%arg0: tensor<?x64x1xf32>) -> index {
  %c0 = arith.constant 0 : index
  %collapsed = tensor.collapse_shape %arg0 [[0, 1, 2]] : tensor<?x64x1xf32> into tensor<?xf32>
  %dim = tensor.dim %collapsed, %c0 : tensor<?xf32>
  return %dim : index
}

If I want to know something about the dim, it's just the first dimension of the collapse_shape, and I now have to reason about it.

After canonicalization we have:

#map = affine_map<()[s0] -> (s0 * 64)>
module {
  func.func @dim_expand_shape(%arg0: tensor<?x64x1xf32>) -> index {
    %c0 = arith.constant 0 : index
    %dim = tensor.dim %arg0, %c0 : tensor<?x64x1xf32>
    %0 = affine.apply #map()[%dim]
    return %0 : index
  }
}

The dimension computation is made explicit in terms of the input, and the tensor.collapse_shape even disappears.

> There is no reason for the tensor.dim, when propagated up, to be more canonical.

It's not always obvious what the "best" canonical form should be, but that does not mean we can't pick one: often it comes down to picking one convention and sticking to it for consistency (I don't know if this is the right convention here, just saying).

> It can possibly introduce 2 tensor.dim operations, if there were 2 dynamic dimensions before the collapse shape.

Here is the example you describe, I believe:

func.func @dim_expand_shape(%arg0: tensor<?x?x64x1xf32>) -> index {
  %c0 = arith.constant 0 : index
  %collapsed = tensor.collapse_shape %arg0 [[0, 1, 2, 3]] : tensor<?x?x64x1xf32> into tensor<?xf32>
  %dim = tensor.dim %collapsed, %c0 : tensor<?xf32>
  return %dim : index
}

Becomes:

#map = affine_map<()[s0, s1] -> ((s0 * s1) * 64)>
module {
  func.func @dim_expand_shape(%arg0: tensor<?x?x64x1xf32>) -> index {
    %c1 = arith.constant 1 : index
    %c0 = arith.constant 0 : index
    %dim = tensor.dim %arg0, %c0 : tensor<?x?x64x1xf32>
    %dim_0 = tensor.dim %arg0, %c1 : tensor<?x?x64x1xf32>
    %0 = affine.apply #map()[%dim, %dim_0]
    return %0 : index
  }
}

As mentioned elsewhere, the cost of "dim" does not seem meaningful to me here. The question is rather whether this IR is more amenable to reasoning about the values, or about equivalence between values, in further analysis. Any argument along this line of reasoning would be very valuable for making a statement about whether this clearly belongs, or does not belong, in canonicalization.

// CHECK: %[[RESULT:.+]] = tensor.expand_shape %[[COLLAPSE]] {{\[}}[0, 1]] output_shape [%[[DIVUI]], 384] : tensor<?xf32> into tensor<?x384xf32>
// CHECK: return %[[RESULT]]

@@ -2137,13 +2135,12 @@ func.func @empty_tensor_canonicalize(%i : index) {

// -----

// CHECK: #[[$map:.*]] = affine_map<()[s0] -> (s0 floordiv 40)>
// CHECK-LABEL: func @dim_of_expand_shape(
// CHECK-SAME: %[[t:.*]]: tensor<?x?xf32>
// CHECK: %[[c1:.*]] = arith.constant 1 : index
// CHECK: %[[dim:.*]] = tensor.dim %[[t]], %[[c1]] : tensor<?x?xf32>
// CHECK: %[[apply:.*]] = affine.apply #[[$map]]()[%[[dim]]]
// CHECK: return %[[apply]]
// CHECK: %[[c2:.*]] = arith.constant 2 : index
// CHECK: %[[expanded:.*]] = tensor.expand_shape %[[t]] {{\[\[}}0], [1, 2, 3, 4, 5]] output_shape [%arg1, 1, %arg2, 5, 1, 8] : tensor<?x?xf32> into tensor<?x1x?x5x1x8xf32>
// CHECK: %[[dim:.*]] = tensor.dim %[[expanded]], %[[c2]] : tensor<?x1x?x5x1x8xf32>
// CHECK: return %[[dim]]
Comment on lines -2140 to +2143
Member:

@joker-eph For an example of the old pattern being wrong, notice that we already know the output shape. The result should just be %arg2 here! But instead the pattern was doing some magic with affine.apply, which is wrong.

Member:

We could implement a folder for this for tensor.dim, but that doesn't change the fact that this pattern is wrong, and doesn't even belong as a canonicalization to expand_shape.
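
For illustration only (not part of this patch; the function name is made up): a sketch of what such a tensor.dim folder could do on the test case being discussed, assuming it reads the dynamic sizes from the expand_shape's output_shape operands.

func.func @fold_dim_via_output_shape(%t: tensor<?x?xf32>, %sz0: index, %sz1: index) -> index {
  %c2 = arith.constant 2 : index
  %e = tensor.expand_shape %t [[0], [1, 2, 3, 4, 5]] output_shape [%sz0, 1, %sz1, 5, 1, 8]
      : tensor<?x?xf32> into tensor<?x1x?x5x1x8xf32>
  %dim = tensor.dim %e, %c2 : tensor<?x1x?x5x1x8xf32>
  // Such a folder would replace %dim with %sz1 directly, since result
  // dimension 2 is exactly the corresponding output_shape operand.
  return %dim : index
}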


func.func @dim_of_expand_shape(%t: tensor<?x?xf32>, %sz0: index, %sz1: index) -> index {
%c2 = arith.constant 2 : index
%0 = tensor.expand_shape %t [[0], [1, 2, 3, 4, 5]] output_shape [%sz0, 1, %sz1, 5, 1, 8]
@@ -2154,17 +2151,12 @@ func.func @dim_of_expand_shape(%t: tensor<?x?xf32>, %sz0: index, %sz1: index) ->

// -----

// CHECK: #[[$map:.*]] = affine_map<()[s0, s1, s2] -> (((s0 * s1) * s2) * 7)>
// CHECK-LABEL: func @dim_of_collapse_shape(
// CHECK-SAME: %[[t:.*]]: tensor<?x?x?x7x?xf32>
// CHECK-DAG: %[[c1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[c2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[c4:.*]] = arith.constant 4 : index
// CHECK-DAG: %[[dim1:.*]] = tensor.dim %[[t]], %[[c1]]
// CHECK-DAG: %[[dim2:.*]] = tensor.dim %[[t]], %[[c2]]
// CHECK-DAG: %[[dim4:.*]] = tensor.dim %[[t]], %[[c4]]
// CHECK: %[[apply:.*]] = affine.apply #[[$map]]()[%[[dim1]], %[[dim2]], %[[dim4]]]
// CHECK: return %[[apply]]
// CHECK-DAG: %[[collapsed:.*]] = tensor.collapse_shape %[[t]] {{\[\[}}0], [1, 2, 3, 4]] : tensor<?x?x?x7x?xf32> into tensor<?x?xf32>
// CHECK-DAG: %[[dim:.*]] = tensor.dim %[[collapsed]], %[[c1]]
// CHECK: return %[[dim]]
func.func @dim_of_collapse_shape(%t: tensor<?x?x?x7x?xf32>) -> index {
%c1 = arith.constant 1 : index
%0 = tensor.collapse_shape %t [[0], [1, 2, 3, 4]]