[mlir][vector] Drop innermost unit dims on transfer_write. #78554

Merged. 4 commits merged on Jan 19, 2024.
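This change refactors the existing innermost-unit-dim folding for vector.transfer_read into two shared helpers (getTransferFoldableInnerUnitDims and getMemRefTypeWithDroppingInnerDims) and adds a matching pattern for vector.transfer_write. As a rough sketch based on the new test case at the end of the diff (SSA value names shortened for readability), the rewrite turns

  vector.transfer_write %vec, %dest[%c0, %idx, %c0, %c0] {in_bounds = [true, true, true, true]}
      : vector<1x16x16x1xf32>, memref<1x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>

into

  %subview = memref.subview %dest[0, 0, 0, 0] [1, 512, 16, 1] [1, 1, 1, 1]
      : memref<1x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>
        to memref<1x512x16xf32, strided<[8192, 16, 1], offset: ?>>
  %cast = vector.shape_cast %vec : vector<1x16x16x1xf32> to vector<1x16x16xf32>
  vector.transfer_write %cast, %subview[%c0, %idx, %c0] {in_bounds = [true, true, true]}
      : vector<1x16x16xf32>, memref<1x512x16xf32, strided<[8192, 16, 1], offset: ?>>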
197 changes: 143 additions & 54 deletions mlir/lib/Dialect/Vector/Transforms/VectorTransforms.cpp
@@ -1152,8 +1152,71 @@ struct FoldI1Select : public OpRewritePattern<arith::SelectOp> {
}
};

// Drop inner most contiguous unit dimensions from transfer_read operand.
class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
/// Returns the number of dims that can be folded away from transfer ops.
/// Returns failure if the strides and offset of `srcType` cannot be resolved.
static FailureOr<size_t>
getTransferFoldableInnerUnitDims(MemRefType srcType, VectorType vectorType) {
SmallVector<int64_t> srcStrides;
int64_t srcOffset;
if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
return failure();

// According to vector.transfer_read/write semantics, the vector can be a
// slice of the source: its rank may be smaller than the source rank, in which
// case it maps onto the minor (trailing) source dimensions. Thus, the check
// index has to be offset by `rankDiff` when indexing into `srcStrides` and
// the source dim sizes.
size_t result = 0;
int rankDiff = srcType.getRank() - vectorType.getRank();
for (int64_t i = 0, e = vectorType.getRank(); i < e; ++i) {
// Check that the inner dim size is 1 for both memref/tensor type and
// vector slice. It can be folded only if they are 1 and the stride is 1.
int dim = vectorType.getRank() - i - 1;
if (srcStrides[dim + rankDiff] == 1 &&
srcType.getDimSize(dim + rankDiff) == 1 &&
vectorType.getDimSize(dim) == 1) {
result++;
} else {
break;
}
}
return result;
}
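// Worked example (illustrative, not from the patch): for the types used in
// the new test below,
//   srcType    = memref<1x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>
//   vectorType = vector<1x16x16x1xf32>
// only the trailing dimension has size 1 in both types (with stride 1), so
// the helper would return 1; the loop then stops because the next dimension
// has size 16.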

/// Returns a MemRef type in which the innermost `dimsToDrop` dimensions of
/// `srcType` have been dropped.
static MemRefType getMemRefTypeWithDroppingInnerDims(OpBuilder &builder,
MemRefType srcType,
size_t dimsToDrop) {
MemRefType resultMemrefType;
MemRefLayoutAttrInterface layout = srcType.getLayout();
if (isa<AffineMapAttr>(layout) && layout.isIdentity()) {
return MemRefType::get(srcType.getShape().drop_back(dimsToDrop),
srcType.getElementType(), nullptr,
srcType.getMemorySpace());
}
MemRefLayoutAttrInterface updatedLayout;
if (auto strided = dyn_cast<StridedLayoutAttr>(layout)) {
auto strides = llvm::to_vector(strided.getStrides().drop_back(dimsToDrop));
updatedLayout = StridedLayoutAttr::get(strided.getContext(),
strided.getOffset(), strides);
} else {
AffineMap map = srcType.getLayout().getAffineMap();
int numSymbols = map.getNumSymbols();
for (size_t i = 0; i < dimsToDrop; ++i) {
int dim = srcType.getRank() - i - 1;
map = map.replace(builder.getAffineDimExpr(dim),
builder.getAffineConstantExpr(0), map.getNumDims() - 1,
numSymbols);
}
}
return MemRefType::get(srcType.getShape().drop_back(dimsToDrop),
srcType.getElementType(), updatedLayout,
srcType.getMemorySpace());
}
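// For instance (illustrative, not from the patch), dropping one innermost
// unit dim turns
//   memref<1x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>
// into
//   memref<1x512x16xf32, strided<[8192, 16, 1], offset: ?>>,
// which is the subview result type checked in the new transfer_write test
// below. For an identity layout, only the trailing shape entries are dropped.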

/// Drop innermost contiguous unit dimensions from the transfer_read operand.
class DropInnerMostUnitDimsTransferRead
: public OpRewritePattern<vector::TransferReadOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
@@ -1177,65 +1240,22 @@ class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
if (targetType.getRank() <= 1)
return failure();

SmallVector<int64_t> srcStrides;
int64_t srcOffset;
if (failed(getStridesAndOffset(srcType, srcStrides, srcOffset)))
return failure();

// According to vector.transfer_read semantics, the result can be a slice.
// It pads the indices with `1` starting from beginning. Thus, we have to
// offset the check index with `rankDiff` in `srcStrides` and source dim
// sizes.
size_t dimsToDrop = 0;
int rankDiff = srcType.getRank() - targetType.getRank();
for (int64_t i = 0, e = targetType.getRank(); i < e; ++i) {
// Check that the inner dim size is 1 for both memref/tensor type and
// vector slice. It can be folded only if they are 1 and the stride is 1.
int dim = targetType.getRank() - i - 1;
if (srcStrides[dim + rankDiff] == 1 &&
srcType.getDimSize(dim + rankDiff) == 1 &&
targetType.getDimSize(dim) == 1) {
dimsToDrop++;
} else {
break;
}
}
FailureOr<size_t> maybeDimsToDrop =
getTransferFoldableInnerUnitDims(srcType, targetType);
if (failed(maybeDimsToDrop))
return failure();

size_t dimsToDrop = maybeDimsToDrop.value();
if (dimsToDrop == 0)
return failure();

auto resultTargetVecType =
VectorType::get(targetType.getShape().drop_back(dimsToDrop),
targetType.getElementType());

MemRefType resultMemrefType;
MemRefLayoutAttrInterface layout = srcType.getLayout();
if (isa<AffineMapAttr>(layout) && layout.isIdentity()) {
resultMemrefType = MemRefType::get(
srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
nullptr, srcType.getMemorySpace());
} else {
MemRefLayoutAttrInterface updatedLayout;
if (auto strided = dyn_cast<StridedLayoutAttr>(layout)) {
auto strides =
llvm::to_vector(strided.getStrides().drop_back(dimsToDrop));
updatedLayout = StridedLayoutAttr::get(strided.getContext(),
strided.getOffset(), strides);
} else {
AffineMap map = srcType.getLayout().getAffineMap();
int numSymbols = map.getNumSymbols();
for (size_t i = 0; i < dimsToDrop; ++i) {
int dim = srcType.getRank() - i - 1;
map = map.replace(rewriter.getAffineDimExpr(dim),
rewriter.getAffineConstantExpr(0),
map.getNumDims() - 1, numSymbols);
}
}
resultMemrefType = MemRefType::get(
srcType.getShape().drop_back(dimsToDrop), srcType.getElementType(),
updatedLayout, srcType.getMemorySpace());
}

auto loc = readOp.getLoc();
MemRefType resultMemrefType =
getMemRefTypeWithDroppingInnerDims(rewriter, srcType, dimsToDrop);
SmallVector<int64_t> offsets(srcType.getRank(), 0);
SmallVector<int64_t> strides(srcType.getRank(), 1);

@@ -1261,6 +1281,73 @@ class DropInnerMostUnitDims : public OpRewritePattern<vector::TransferReadOp> {
}
};

/// Drop innermost contiguous unit dimensions from the transfer_write operand.
class DropInnerMostUnitDimsTransferWrite
: public OpRewritePattern<vector::TransferWriteOp> {
using OpRewritePattern::OpRewritePattern;

LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
PatternRewriter &rewriter) const override {
// TODO: support 0-d corner case.
if (writeOp.getTransferRank() == 0)
return failure();

// TODO: support mask.
if (writeOp.getMask())
return failure();

auto srcType = dyn_cast<MemRefType>(writeOp.getSource().getType());
if (!srcType || !srcType.hasStaticShape())
return failure();

if (!writeOp.getPermutationMap().isMinorIdentity())
return failure();

auto targetType = writeOp.getVectorType();
if (targetType.getRank() <= 1)
return failure();

FailureOr<size_t> maybeDimsToDrop =
getTransferFoldableInnerUnitDims(srcType, targetType);
if (failed(maybeDimsToDrop))
return failure();

size_t dimsToDrop = maybeDimsToDrop.value();
if (dimsToDrop == 0)
return failure();

auto resultTargetVecType =
VectorType::get(targetType.getShape().drop_back(dimsToDrop),
targetType.getElementType());

auto loc = writeOp.getLoc();
MemRefType resultMemrefType =
getMemRefTypeWithDroppingInnerDims(rewriter, srcType, dimsToDrop);
SmallVector<int64_t> offsets(srcType.getRank(), 0);
SmallVector<int64_t> strides(srcType.getRank(), 1);

ArrayAttr inBoundsAttr =
writeOp.getInBounds()
? rewriter.getArrayAttr(
writeOp.getInBoundsAttr().getValue().drop_back(dimsToDrop))
: ArrayAttr();
Value rankedReducedView = rewriter.create<memref::SubViewOp>(
loc, resultMemrefType, writeOp.getSource(), offsets, srcType.getShape(),
strides);
auto permMap = getTransferMinorIdentityMap(
cast<ShapedType>(rankedReducedView.getType()), resultTargetVecType);

auto shapeCast = rewriter.createOrFold<vector::ShapeCastOp>(
loc, resultTargetVecType, writeOp.getVector());
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
writeOp, shapeCast, rankedReducedView,
writeOp.getIndices().drop_back(dimsToDrop), AffineMapAttr::get(permMap),
// TODO: support mask.
/*mask=*/Value(), inBoundsAttr);
return success();
}
};

/// Canonicalization of a `vector.contraction %a, %b, %c` with row-major matmul
/// semantics to a contraction suitable for MMT (matrix matrix multiplication
/// with the RHS transposed) lowering.
@@ -1696,7 +1783,9 @@ void mlir::vector::populateVectorReductionToContractPatterns(
void mlir::vector::
populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
RewritePatternSet &patterns, PatternBenefit benefit) {
patterns.add<DropInnerMostUnitDims>(patterns.getContext(), benefit);
patterns.add<DropInnerMostUnitDimsTransferRead,
DropInnerMostUnitDimsTransferWrite>(patterns.getContext(),
benefit);
}

void mlir::vector::populateSinkVectorBroadcastPatterns(
@@ -76,3 +76,24 @@ func.func @contiguous_inner_most_dim_out_of_bounds_2d(%arg0: memref<1x1xf32>) ->
// CHECK-NOT: memref.subview
// CHECK: %[[READ:.+]] = vector.transfer_read %[[SRC]]
// CHECK: return %[[READ]] : vector<4x8xf32>

// -----

func.func @drop_inner_most_dim_for_transfer_write(%arg0: memref<1x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>, %arg1: vector<1x16x16x1xf32>, %arg2: index) {
%c0 = arith.constant 0 : index
vector.transfer_write %arg1, %arg0[%c0, %arg2, %c0, %c0]
{in_bounds = [true, true, true, true]}
: vector<1x16x16x1xf32>, memref<1x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>>
return
}
// CHECK: func.func @drop_inner_most_dim_for_transfer_write
// CHECK-SAME: %[[DEST:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[VEC:[a-zA-Z0-9]+]]
// CHECK-SAME: %[[IDX:[a-zA-Z0-9]+]]
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[SUBVIEW:.+]] = memref.subview %[[DEST]]
// CHECK-SAME: [0, 0, 0, 0] [1, 512, 16, 1] [1, 1, 1, 1]
// CHECK-SAME: memref<1x512x16x1xf32, strided<[8192, 16, 1, 1], offset: ?>> to memref<1x512x16xf32, strided<[8192, 16, 1], offset: ?>>
// CHECK: %[[CAST:.+]] = vector.shape_cast %[[VEC]] : vector<1x16x16x1xf32> to vector<1x16x16xf32>
// CHECK: vector.transfer_write %[[CAST]], %[[SUBVIEW]]
// CHECK-SAME: [%[[C0]], %[[IDX]], %[[C0]]]
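
For reference, a minimal sketch of how a pass might pick up the extended pattern set; the wrapper function and header paths below follow the usual MLIR conventions and are illustrative rather than part of this patch:

#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Collects the collapse patterns, which now include
// DropInnerMostUnitDimsTransferWrite, and applies them greedily to `op`.
static mlir::LogicalResult collapseInnerMostContiguousDims(mlir::Operation *op) {
  mlir::RewritePatternSet patterns(op->getContext());
  mlir::vector::populateVectorTransferCollapseInnerMostContiguousDimsPatterns(
      patterns);
  return mlir::applyPatternsAndFoldGreedily(op, std::move(patterns));
}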