Skip to content

Commit fbbb43f

Browse files
committed
Fixup : comment hook
1 parent 7b09906 commit fbbb43f

File tree

2 files changed

+10
-7
lines changed

2 files changed

+10
-7
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1607,7 +1607,10 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
16071607
}
16081608
};
16091609

1610-
VectorType dropNonScalableUnitDimType(VectorType inVecTy) {
1610+
/// Drop unit dimension of the input VectorType. Scalable dimensions cannot be
1611+
/// folded as we do not want to merge them through a shape_cast and create ill
1612+
/// shaped scalable sizes.
1613+
static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy) {
16111614
auto newVecBuilder = VectorType::Builder(inVecTy);
16121615
auto inVecShape = inVecTy.getShape();
16131616
SmallVector<int64_t> newShape;
@@ -1676,7 +1679,7 @@ struct DropUnitDimFromElementwiseOps final
16761679
auto loc = op->getLoc();
16771680
for (auto operand : op->getOperands()) {
16781681
auto opVectorType = cast<VectorType>(operand.getType());
1679-
auto newVType = dropNonScalableUnitDimType(opVectorType);
1682+
auto newVType = dropNonScalableUnitDimFromType(opVectorType);
16801683
if (newVType == opVectorType)
16811684
return rewriter.notifyMatchFailure(op, "No unit dimension to remove.");
16821685

@@ -1685,7 +1688,7 @@ struct DropUnitDimFromElementwiseOps final
16851688
}
16861689

16871690
VectorType newResultVectorType =
1688-
dropNonScalableUnitDimType(resultVectorType);
1691+
dropNonScalableUnitDimFromType(resultVectorType);
16891692
// Create an updated elementwise Op without unit dim
16901693
Operation *elementwiseOp =
16911694
rewriter.create(loc, op->getName().getIdentifier(), newOperands,

mlir/test/Dialect/Vector/vector-transfer-flatten.mlir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -499,15 +499,15 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
499499

500500
// -----
501501

502-
func.func @fold_unit_inner_dim(%arg0 : vector<8x1x3xf128>,
502+
func.func @fold_inner_unit_dim(%arg0 : vector<8x1x3xf128>,
503503
%arg1 : vector<1x8x3xf128>) -> vector<8x3xf128> {
504504
%sc_arg1 = vector.shape_cast %arg1 : vector<1x8x3xf128> to vector<8x1x3xf128>
505505
%mul = arith.mulf %arg0, %sc_arg1 : vector<8x1x3xf128>
506506
%res = vector.shape_cast %mul : vector<8x1x3xf128> to vector<8x3xf128>
507507
return %res : vector<8x3xf128>
508508
}
509509

510-
// CHECK-LABEL: func.func @fold_unit_inner_dim(
510+
// CHECK-LABEL: func.func @fold_inner_unit_dim(
511511
// CHECK-SAME: %[[VAL_0:.*]]: vector<8x1x3xf128>,
512512
// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8x3xf128>) -> vector<8x3xf128> {
513513
// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x1x3xf128> to vector<8x3xf128>
@@ -517,15 +517,15 @@ func.func @fold_unit_inner_dim(%arg0 : vector<8x1x3xf128>,
517517

518518
// -----
519519

520-
func.func @fold_unit_inner_dim_scalable(%arg0 : vector<8x1x[1]x3xf128>,
520+
func.func @fold_inner_unit_dim_scalable(%arg0 : vector<8x1x[1]x3xf128>,
521521
%arg1 : vector<1x8x[1]x3xf128>) -> vector<8x[1]x3xf128> {
522522
%sc_arg1 = vector.shape_cast %arg1 : vector<1x8x[1]x3xf128> to vector<8x1x[1]x3xf128>
523523
%mul = arith.mulf %arg0, %sc_arg1 : vector<8x1x[1]x3xf128>
524524
%res = vector.shape_cast %mul : vector<8x1x[1]x3xf128> to vector<8x[1]x3xf128>
525525
return %res : vector<8x[1]x3xf128>
526526
}
527527

528-
// CHECK-LABEL: func.func @fold_unit_inner_dim_scalable(
528+
// CHECK-LABEL: func.func @fold_inner_unit_dim_scalable(
529529
// CHECK-SAME: %[[VAL_0:.*]]: vector<8x1x[1]x3xf128>,
530530
// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8x[1]x3xf128>) -> vector<8x[1]x3xf128> {
531531
// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x1x[1]x3xf128> to vector<8x[1]x3xf128>

0 commit comments

Comments
 (0)