Skip to content

[MLIR][Vector] Generalize DropUnitDimFromElementwiseOps to non leading / trailing dimensions. #98455

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into the target branch on
Jul 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 34 additions & 23 deletions mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1622,7 +1622,33 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
}
};

/// For vectors with either leading or trailing unit dim, replaces:
// Returns `inVecTy` with every unit, fixed-width dimension removed, always
// keeping at least one dimension so that no 0-D vector is produced (a fully
// collapsed shape yields vector<1xeType>). Scalable unit dimensions ([1])
// are kept: dropping one would require migrating the scalable flag onto a
// fixed-width dimension (e.g. vector<[1]x4xf32> -> vector<[4]xf32>), which
// could be implemented in the future.
static VectorType dropNonScalableUnitDimFromType(VectorType inVecTy) {
  ArrayRef<int64_t> shape = inVecTy.getShape();
  ArrayRef<bool> scalableFlags = inVecTy.getScalableDims();
  SmallVector<int64_t> resultShape;
  SmallVector<bool> resultScalableFlags;
  for (unsigned idx = 0, rank = shape.size(); idx < rank; ++idx) {
    // Keep a dimension when it is non-unit or scalable.
    if (shape[idx] != 1 || scalableFlags[idx]) {
      resultShape.push_back(shape[idx]);
      resultScalableFlags.push_back(scalableFlags[idx]);
    }
  }
  // Every dimension was a fixed unit dim; fall back to vector<1xeType>.
  if (resultShape.empty()) {
    resultShape.push_back(1);
    resultScalableFlags.push_back(false);
  }
  return VectorType::get(resultShape, inVecTy.getElementType(),
                         resultScalableFlags);
}

/// For vectors with at least one unit dim, replaces:
/// elementwise(a, b)
/// with:
/// sc_a = shape_cast(a)
Expand All @@ -1634,20 +1660,16 @@ struct ChainedReduction final : OpRewritePattern<vector::ReductionOp> {
/// required to be rank > 1.
///
/// Ex:
/// ```
/// %mul = arith.mulf %B_row, %A_row : vector<1x[4]xf32>
/// %cast = vector.shape_cast %mul : vector<1x[4]xf32> to vector<[4]xf32>
/// ```
///
/// gets converted to:
///
/// ```
/// %B_row_sc = vector.shape_cast %B_row : vector<1x[4]xf32> to vector<[4]xf32>
/// %A_row_sc = vector.shape_cast %A_row : vector<1x[4]xf32> to vector<[4]xf32>
/// %mul = arith.mulf %B_row_sc, %A_row_sc : vector<[4]xf32>
/// %cast_new = vector.shape_cast %mul : vector<[4]xf32> to vector<1x[4]xf32>
/// %cast = vector.shape_cast %cast_new : vector<1x[4]xf32> to vector<[4]xf32>
/// ```
///
/// Patterns for folding shape_casts should instantly eliminate `%cast_new` and
/// `%cast`.
Expand All @@ -1672,37 +1694,26 @@ struct DropUnitDimFromElementwiseOps final
if (sourceVectorType.getRank() < 2)
return failure();

bool hasTrailingDimUnitFixed =
((sourceVectorType.getShape().back() == 1) &&
(!sourceVectorType.getScalableDims().back()));
bool hasLeadingDimUnitFixed =
((sourceVectorType.getShape().front() == 1) &&
(!sourceVectorType.getScalableDims().front()));
if (!hasLeadingDimUnitFixed && !hasTrailingDimUnitFixed)
return failure();

// Drop leading/trailing unit dim by applying vector.shape_cast to all
// operands
int64_t dim = hasLeadingDimUnitFixed ? 0 : sourceVectorType.getRank() - 1;

SmallVector<Value> newOperands;
auto loc = op->getLoc();
for (auto operand : op->getOperands()) {
auto opVectorType = cast<VectorType>(operand.getType());
VectorType newVType = VectorType::Builder(opVectorType).dropDim(dim);
auto newVType = dropNonScalableUnitDimFromType(opVectorType);
if (newVType == opVectorType)
return rewriter.notifyMatchFailure(op, "No unit dimension to remove.");

auto opSC = rewriter.create<vector::ShapeCastOp>(loc, newVType, operand);
newOperands.push_back(opSC);
}

VectorType newResultVectorType =
VectorType::Builder(resultVectorType).dropDim(dim);
// Create an updated elementwise Op without leading/trailing unit dim
dropNonScalableUnitDimFromType(resultVectorType);
// Create an updated elementwise Op without unit dim.
Operation *elementwiseOp =
rewriter.create(loc, op->getName().getIdentifier(), newOperands,
newResultVectorType, op->getAttrs());

// Restore the leading/trailing unit dim by applying vector.shape_cast
// to the result
// Restore the unit dim by applying vector.shape_cast to the result.
rewriter.replaceOpWithNewOp<ShapeCastOp>(op, resultVectorType,
elementwiseOp->getResult(0));

Expand Down
51 changes: 51 additions & 0 deletions mlir/test/Dialect/Vector/vector-transfer-flatten.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -604,6 +604,57 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,

// -----

// Checks that DropUnitDimFromElementwiseOps removes an *inner* (non-leading,
// non-trailing) unit dim: both mulf operands are shape_cast from 8x1x3 to
// 8x3 and the elementwise op is rebuilt on the reduced type.
func.func @fold_inner_unit_dim(%arg0 : vector<8x1x3xf128>,
                               %arg1 : vector<1x8x3xf128>) -> vector<8x3xf128> {
  // Leading-unit cast feeding the elementwise op; folds with the new casts.
  %sc_arg1 = vector.shape_cast %arg1 : vector<1x8x3xf128> to vector<8x1x3xf128>
  %mul = arith.mulf %arg0, %sc_arg1 : vector<8x1x3xf128>
  %res = vector.shape_cast %mul : vector<8x1x3xf128> to vector<8x3xf128>
  return %res : vector<8x3xf128>
}

// CHECK-LABEL: func.func @fold_inner_unit_dim(
// CHECK-SAME: %[[VAL_0:.*]]: vector<8x1x3xf128>,
// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8x3xf128>) -> vector<8x3xf128> {
// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x1x3xf128> to vector<8x3xf128>
// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8x3xf128> to vector<8x3xf128>
// CHECK: %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x3xf128>
// CHECK: return %[[VAL_4]] : vector<8x3xf128>

// -----

// Same as @fold_inner_unit_dim but with a scalable [1] dimension present:
// the fixed inner unit dim is dropped while the scalable [1] dim is kept
// (8x1x[1]x3 -> 8x[1]x3).
func.func @fold_inner_unit_dim_scalable(%arg0 : vector<8x1x[1]x3xf128>,
                                        %arg1 : vector<1x8x[1]x3xf128>) -> vector<8x[1]x3xf128> {
  %sc_arg1 = vector.shape_cast %arg1 : vector<1x8x[1]x3xf128> to vector<8x1x[1]x3xf128>
  %mul = arith.mulf %arg0, %sc_arg1 : vector<8x1x[1]x3xf128>
  %res = vector.shape_cast %mul : vector<8x1x[1]x3xf128> to vector<8x[1]x3xf128>
  return %res : vector<8x[1]x3xf128>
}

// CHECK-LABEL: func.func @fold_inner_unit_dim_scalable(
// CHECK-SAME: %[[VAL_0:.*]]: vector<8x1x[1]x3xf128>,
// CHECK-SAME: %[[VAL_1:.*]]: vector<1x8x[1]x3xf128>) -> vector<8x[1]x3xf128> {
// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<8x1x[1]x3xf128> to vector<8x[1]x3xf128>
// CHECK: %[[VAL_3:.*]] = vector.shape_cast %[[VAL_1]] : vector<1x8x[1]x3xf128> to vector<8x[1]x3xf128>
// CHECK: %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x[1]x3xf128>
// CHECK: return %[[VAL_4]] : vector<8x[1]x3xf128>

// -----

// All dims of the operand are fixed unit dims; the pattern must collapse
// 1x1 to vector<1xf32> (never to a 0-D vector).
func.func @fold_all_unit_dims(%arg0: vector<1x1xf32>) -> vector<1xf32> {
  %0 = arith.mulf %arg0, %arg0 : vector<1x1xf32>
  %res = vector.shape_cast %0 : vector<1x1xf32> to vector<1xf32>
  return %res : vector<1xf32>
}

// CHECK-LABEL: func.func @fold_all_unit_dims(
// CHECK-SAME: %[[VAL_0:.*]]: vector<1x1xf32>) -> vector<1xf32>
// CHECK: %[[VAL_1:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x1xf32> to vector<1xf32>
// CHECK: %[[VAL_2:.*]] = vector.shape_cast %[[VAL_0]] : vector<1x1xf32> to vector<1xf32>
// CHECK: %[[VAL_3:.*]] = arith.mulf %[[VAL_1]], %[[VAL_2]] : vector<1xf32>
// CHECK: return %[[VAL_3]] : vector<1xf32>

// -----

func.func @negative_out_of_bound_transfer_read(
%arg : memref<?x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {
%c0 = arith.constant 0 : index
Expand Down
Loading