Skip to content

[MLIR][Vector] Generalize DropUnitDimFromElementwiseOps to non leading / trailing dimensions. #92934

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jun 20, 2024
55 changes: 29 additions & 26 deletions mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1607,7 +1607,27 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
}
};

/// For vectors with either leading or trailing unit dim, replaces:
// Scalable unit dimensions are not supported. Folding such dimensions would
// require "shifting" the scalable flag onto some other fixed-width dim (e.g.
// vector<[1]x4xf32> -> vector<[4]xf32>). This could be implemented in the
// future.
static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy) {
auto inVecShape = inVecTy.getShape();
SmallVector<int64_t> newShape;
SmallVector<bool> newScalableDims;
for (auto [dim, isScalable] :
llvm::zip_equal(inVecShape, inVecTy.getScalableDims())) {
if (dim == 1 && !isScalable)
continue;

newShape.push_back(dim);
newScalableDims.push_back(isScalable);
}

return VectorType::get(newShape, inVecTy.getElementType(), newScalableDims);
}

/// For vectors with at least an unit dim, replaces:
/// elementwise(a, b)
/// with:
/// sc_a = shape_cast(a)
Expand All @@ -1619,20 +1639,16 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
/// required to be rank > 1.
///
/// Ex:
/// ```
/// %mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32>
/// %cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32>
/// ```
///
/// gets converted to:
///
/// ```
/// %B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32>
/// %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
/// %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
/// %cast_new = vector.shape_cast %mul : vector<[4]xf32> to vector<1x[4]xf32>
/// %cast = vector.shape_cast %cast_new : vector<1x[4]xf32> to vector<[4]xf32>
/// ```
///
/// Patterns for folding shape_casts should instantly eliminate `%cast_new` and
/// `%cast`.
Expand All @@ -1652,42 +1668,29 @@ struct DropUnitDimFromElementwiseOps final
// guaranteed to have identical shapes (with some exceptions such as
// `arith.select`) and it suffices to only check one of them.
auto sourceVectorType = dyn_cast<VectorType>(op->getOperand(0).getType());
if (!sourceVectorType)
return failure();
if (sourceVectorType.getRank() < 2)
return failure();

bool hasTrailingDimUnitFixed =
((sourceVectorType.getShape().back() == 1) &&
(!sourceVectorType.getScalableDims().back()));
bool hasLeadingDimUnitFixed =
((sourceVectorType.getShape().front() == 1) &&
(!sourceVectorType.getScalableDims().front()));
if (!hasLeadingDimUnitFixed && !hasTrailingDimUnitFixed)
if (!sourceVectorType || sourceVectorType.getRank() < 2)
return failure();

// Drop leading/trailing unit dim by applying vector.shape_cast to all
// operands
int64_t dim = hasLeadingDimUnitFixed ? 0 : sourceVectorType.getRank() - 1;

SmallVector<Value> newOperands;
auto loc = op->getLoc();
for (auto operand : op->getOperands()) {
auto opVectorType = cast<VectorType>(operand.getType());
VectorType newVType = VectorType::Builder(opVectorType).dropDim(dim);
auto newVType = dropNonScalableUnitDimFromType(opVectorType);
if (newVType == opVectorType)
return rewriter.notifyMatchFailure(op, "No unit dimension to remove.");

auto opSC = rewriter.create<vector::ShapeCastOp>(loc, newVType, operand);
newOperands.push_back(opSC);
}

VectorType newResultVectorType =
VectorType::Builder(resultVectorType).dropDim(dim);
// Create an updated elementwise Op without leading/trailing unit dim
dropNonScalableUnitDimFromType(resultVectorType);
// Create an updated elementwise Op without unit dim.
Operation *elementwiseOp =
rewriter.create(loc, op->getName().getIdentifier(), newOperands,
newResultVectorType, op->getAttrs());

// Restore the leading/trailing unit dim by applying vector.shape_cast
// to the result
// Restore the unit dim by applying vector.shape_cast to the result.
rewriter.replaceOpWithNewOp<ShapeCastOp>(op, resultVectorType,
elementwiseOp->getResult(0));

Expand Down
36 changes: 36 additions & 0 deletions mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -499,6 +499,42 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,

// -----

func.func @fold_inner_unit_dim(%arg0 : vector<8x1x3xf128>,
%arg1 : vector<1x8x3xf128>) -> vector<8x3xf128> {
%sc_arg1 = vector.shape_cast %arg1 : vector<1x8x3xf128> to vector<8x1x3xf128>
%mul = arith.mulf %arg0, %sc_arg1 : vector<8x1x3xf128>
%res = vector.shape_cast %mul : vector<8x1x3xf128> to vector<8x3xf128>
return %res : vector<8x3xf128>
}

// CHECK-LABEL: func.func @fold_inner_unit_dim(
// CHECK-SAME: %[[VAL_0:.*]]: vector<8x1x3xf128>,
// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8x3xf128>) -> vector<8x3xf128> {
// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x1x3xf128> to vector<8x3xf128>
// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8x3xf128> to vector<8x3xf128>
// CHECK: %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x3xf128>
// CHECK: return %[[VAL_4]] : vector<8x3xf128>

// -----

func.func @fold_inner_unit_dim_scalable(%arg0 : vector<8x1x[1]x3xf128>,
%arg1 : vector<1x8x[1]x3xf128>) -> vector<8x[1]x3xf128> {
%sc_arg1 = vector.shape_cast %arg1 : vector<1x8x[1]x3xf128> to vector<8x1x[1]x3xf128>
%mul = arith.mulf %arg0, %sc_arg1 : vector<8x1x[1]x3xf128>
%res = vector.shape_cast %mul : vector<8x1x[1]x3xf128> to vector<8x[1]x3xf128>
return %res : vector<8x[1]x3xf128>
}

// CHECK-LABEL: func.func @fold_inner_unit_dim_scalable(
// CHECK-SAME: %[[VAL_0:.*]]: vector<8x1x[1]x3xf128>,
// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8x[1]x3xf128>) -> vector<8x[1]x3xf128> {
// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x1x[1]x3xf128> to vector<8x[1]x3xf128>
// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8x[1]x3xf128> to vector<8x[1]x3xf128>
// CHECK: %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x[1]x3xf128>
// CHECK: return %[[VAL_4]] : vector<8x[1]x3xf128>

// -----

func.func @negative_out_of_bound_transfer_read(
%arg : memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
%c0 = arith.constant 0 : index
Expand Down
Loading