Skip to content

[mlir][MemRef] Extend memref.subview sub-byte type emulation support. #94045

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 67 additions & 29 deletions mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,14 @@ using namespace mlir;
// Utility functions
//===----------------------------------------------------------------------===//

/// Converts a memref::SubViewOp or memref::ReinterpretCastOp to the converted
/// type. The result MemRefType of the old op must have a rank and stride of 1,
/// with static offset and size. The number of bits in the offset must evenly
/// divide the bitwidth of the new converted type.
template <typename MemRefOpTy>
static LogicalResult convertCastingOp(ConversionPatternRewriter &rewriter,
typename MemRefOpTy::Adaptor adaptor,
MemRefOpTy op, MemRefType newTy) {
static_assert(std::is_same<MemRefOpTy, memref::SubViewOp>() ||
std::is_same<MemRefOpTy, memref::ReinterpretCastOp>(),
"Expected only memref::SubViewOp or memref::ReinterpretCastOp");

/// Converts a memref::ReinterpretCastOp to the converted type. The result
/// MemRefType of the old op must have a rank and stride of 1, with static
/// offset and size. The number of bits in the offset must evenly divide the
/// bitwidth of the new converted type.
static LogicalResult
convertCastingOp(ConversionPatternRewriter &rewriter,
memref::ReinterpretCastOp::Adaptor adaptor,
memref::ReinterpretCastOp op, MemRefType newTy) {
auto convertedElementType = newTy.getElementType();
auto oldElementType = op.getType().getElementType();
int srcBits = oldElementType.getIntOrFloatBitWidth();
Expand All @@ -67,24 +63,22 @@ static LogicalResult convertCastingOp(ConversionPatternRewriter &rewriter,
[](int64_t size) { return size == ShapedType::kDynamic; }) ||
offset == ShapedType::kDynamic) {
return rewriter.notifyMatchFailure(
op->getLoc(), "dynamic size or offset is not supported");
op, "dynamic size or offset is not supported");
}

int elementsPerByte = dstBits / srcBits;
if (offset % elementsPerByte != 0) {
return rewriter.notifyMatchFailure(
op->getLoc(), "offset not multiple of elementsPerByte is not "
"supported");
op, "offset not multiple of elementsPerByte is not supported");
}

SmallVector<int64_t> size;
if (sizes.size())
size.push_back(ceilDiv(sizes[0], elementsPerByte));
offset = offset / elementsPerByte;

rewriter.replaceOpWithNewOp<MemRefOpTy>(op, newTy,
*adaptor.getODSOperands(0).begin(),
offset, size, op.getStaticStrides());
rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
op, newTy, adaptor.getSource(), offset, size, op.getStaticStrides());
return success();
}

Expand Down Expand Up @@ -402,29 +396,73 @@ struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {

/// Emulating narrow ints on subview have limited support, supporting only
/// static offset and size and stride of 1. Ideally, the subview should be
/// folded away before running narrow type emulation, and this pattern would
/// never run. This pattern is mostly used for testing pruposes.
/// folded away before running narrow type emulation, and this pattern should
/// only run for cases that can't be folded.
struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(memref::SubViewOp op, OpAdaptor adaptor,
matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
MemRefType newTy =
dyn_cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
MemRefType newTy = dyn_cast<MemRefType>(
getTypeConverter()->convertType(subViewOp.getType()));
if (!newTy) {
return rewriter.notifyMatchFailure(
op->getLoc(),
llvm::formatv("failed to convert memref type: {0}", op.getType()));
subViewOp->getLoc(),
llvm::formatv("failed to convert memref type: {0}",
subViewOp.getType()));
}

// Only support offset for 1-D subview.
if (op.getType().getRank() != 1) {
Location loc = subViewOp.getLoc();
Type convertedElementType = newTy.getElementType();
Type oldElementType = subViewOp.getType().getElementType();
int srcBits = oldElementType.getIntOrFloatBitWidth();
int dstBits = convertedElementType.getIntOrFloatBitWidth();
if (dstBits % srcBits != 0)
return rewriter.notifyMatchFailure(
op->getLoc(), "subview with rank > 1 is not supported");
subViewOp, "only dstBits % srcBits == 0 supported");

// Only support stride of 1.
if (llvm::any_of(subViewOp.getStaticStrides(),
[](int64_t stride) { return stride != 1; })) {
return rewriter.notifyMatchFailure(subViewOp->getLoc(),
"stride != 1 is not supported");
}

return convertCastingOp(rewriter, adaptor, op, newTy);
if (!memref::isStaticShapeAndContiguousRowMajor(subViewOp.getType())) {
return rewriter.notifyMatchFailure(
subViewOp, "the result memref type is not contiguous");
}

auto sizes = subViewOp.getStaticSizes();
int64_t lastOffset = subViewOp.getStaticOffsets().back();
// Only support static sizes and offsets.
if (llvm::any_of(
sizes, [](int64_t size) { return size == ShapedType::kDynamic; }) ||
lastOffset == ShapedType::kDynamic) {
return rewriter.notifyMatchFailure(
subViewOp->getLoc(), "dynamic size or offset is not supported");
}

// Transform the offsets, sizes and strides according to the emulation.
auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
loc, subViewOp.getViewSource());

OpFoldResult linearizedIndices;
auto strides = stridedMetadata.getConstifiedMixedStrides();
memref::LinearizedMemRefInfo linearizedInfo;
std::tie(linearizedInfo, linearizedIndices) =
memref::getLinearizedMemRefOffsetAndSize(
rewriter, loc, srcBits, dstBits,
stridedMetadata.getConstifiedMixedOffset(),
subViewOp.getMixedSizes(), strides,
getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(),
rewriter));

rewriter.replaceOpWithNewOp<memref::SubViewOp>(
subViewOp, newTy, adaptor.getSource(), linearizedIndices,
linearizedInfo.linearizedSize, strides.back());
return success();
}
};

Expand Down
4 changes: 1 addition & 3 deletions mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
AffineExpr mulMap = builder.getAffineConstantExpr(1);

SmallVector<OpFoldResult> offsetValues(2 * sourceRank);
SmallVector<OpFoldResult> sizeValues(sourceRank);

for (unsigned i = 0; i < sourceRank; ++i) {
unsigned offsetIdx = 2 * i;
Expand All @@ -79,8 +78,7 @@ std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
mulMap = mulMap * symbols[i];
}

// Adjust linearizedIndices, size and offset by the scale factor (dstBits /
// srcBits).
// Adjust linearizedIndices and size by the scale factor (dstBits / srcBits).
int64_t scaler = dstBits / srcBits;
addMulMap = addMulMap.floorDiv(scaler);
mulMap = mulMap.floorDiv(scaler);
Expand Down
39 changes: 37 additions & 2 deletions mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=8" --cse --split-input-file %s | FileCheck %s
// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=32" --cse --split-input-file %s | FileCheck %s --check-prefix=CHECK32
// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=8" --cse --verify-diagnostics --split-input-file %s | FileCheck %s
// RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=32" --cse --verify-diagnostics --split-input-file %s | FileCheck %s --check-prefix=CHECK32

// Expect no conversions.
func.func @memref_i8() -> i8 {
Expand Down Expand Up @@ -177,6 +177,41 @@ func.func @memref_strided_i4(%idx : index) -> i4 {

// -----

func.func @memref_subview_dynamic_offset_i4(%idx : index) -> i4 {
%c0 = arith.constant 0 : index
%arr = memref.alloc() : memref<512x64x8x16xi4>
%subview = memref.subview %arr[%idx, 0, 0, 0] [16, 64, 8, 16] [1, 1, 1, 1] : memref<512x64x8x16xi4>
to memref<16x64x8x16xi4, strided<[8192, 128, 16, 1], offset: ?>>
%ld = memref.load %subview[%c0, %c0, %c0, %c0] : memref<16x64x8x16xi4, strided<[8192, 128, 16, 1], offset: ?>>
return %ld : i4
}

// CHECK-LABEL: func.func @memref_subview_dynamic_offset_i4(
// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<2097152xi8>
// CHECK: %[[IDX:.*]] = affine.apply
// CHECK: %[[SUBVIEW:.*]] = memref.subview %[[ALLOC]][%[[IDX]]] [65536] [1] : memref<2097152xi8> to memref<65536xi8, strided<[1], offset: ?>>
// CHECK: memref.load %[[SUBVIEW]]

// CHECK32-LABEL: func.func @memref_subview_dynamic_offset_i4(
// CHECK32: %[[ALLOC:.*]] = memref.alloc() : memref<524288xi32>
// CHECK32: %[[IDX:.*]] = affine.apply
// CHECK32: %[[SUBVIEW:.*]] = memref.subview %[[ALLOC]][%[[IDX]]] [16384] [1] : memref<524288xi32> to memref<16384xi32, strided<[1], offset: ?>>
// CHECK32: memref.load %[[SUBVIEW]]

// -----


func.func @negative_memref_subview_non_contiguous(%idx : index) -> i4 {
%c0 = arith.constant 0 : index
%arr = memref.alloc() : memref<40x40xi4>
// expected-error @+1 {{failed to legalize operation 'memref.subview' that was explicitly marked illegal}}
%subview = memref.subview %arr[%idx, 0] [4, 8] [1, 1] : memref<40x40xi4> to memref<4x8xi4, strided<[40, 1], offset:?>>
%ld = memref.load %subview[%c0, %c0] : memref<4x8xi4, strided<[40, 1], offset:?>>
return %ld : i4
}

// -----

func.func @reinterpret_cast_memref_load_0D() -> i4 {
%0 = memref.alloc() : memref<5xi4>
%reinterpret_cast_0 = memref.reinterpret_cast %0 to offset: [0], sizes: [], strides: [] : memref<5xi4> to memref<i4>
Expand Down
Loading