[mlir][MemRef] Extend memref.subview sub-byte type emulation support #89488

Closed
wants to merge 2 commits
162 changes: 92 additions & 70 deletions mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -13,7 +13,6 @@
#include "mlir/Dialect/Arith/Transforms/Passes.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -24,7 +23,6 @@
#include "mlir/Support/MathExtras.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include <cassert>
#include <type_traits>

@@ -34,62 +32,6 @@ using namespace mlir;
// Utility functions
//===----------------------------------------------------------------------===//

/// Converts a memref::SubViewOp or memref::ReinterpretCastOp to the converted
/// type. The result MemRefType of the old op must have a rank and stride of 1,
/// with static offset and size. The number of bits in the offset must evenly
/// divide the bitwidth of the new converted type.
template <typename MemRefOpTy>
static LogicalResult convertCastingOp(ConversionPatternRewriter &rewriter,
typename MemRefOpTy::Adaptor adaptor,
MemRefOpTy op, MemRefType newTy) {
static_assert(std::is_same<MemRefOpTy, memref::SubViewOp>() ||
std::is_same<MemRefOpTy, memref::ReinterpretCastOp>(),
"Expected only memref::SubViewOp or memref::ReinterpretCastOp");

auto convertedElementType = newTy.getElementType();
auto oldElementType = op.getType().getElementType();
int srcBits = oldElementType.getIntOrFloatBitWidth();
int dstBits = convertedElementType.getIntOrFloatBitWidth();
if (dstBits % srcBits != 0) {
return rewriter.notifyMatchFailure(op,
"only dstBits % srcBits == 0 supported");
}

// Only support stride of 1.
if (llvm::any_of(op.getStaticStrides(),
[](int64_t stride) { return stride != 1; })) {
return rewriter.notifyMatchFailure(op->getLoc(),
"stride != 1 is not supported");
}

auto sizes = op.getStaticSizes();
int64_t offset = op.getStaticOffset(0);
// Only support static sizes and offsets.
if (llvm::any_of(sizes,
[](int64_t size) { return size == ShapedType::kDynamic; }) ||
offset == ShapedType::kDynamic) {
return rewriter.notifyMatchFailure(
op->getLoc(), "dynamic size or offset is not supported");
}

int elementsPerByte = dstBits / srcBits;
if (offset % elementsPerByte != 0) {
return rewriter.notifyMatchFailure(
op->getLoc(), "offset not multiple of elementsPerByte is not "
"supported");
}

SmallVector<int64_t> size;
if (sizes.size())
size.push_back(ceilDiv(sizes[0], elementsPerByte));
offset = offset / elementsPerByte;

rewriter.replaceOpWithNewOp<MemRefOpTy>(op, newTy,
*adaptor.getODSOperands(0).begin(),
offset, size, op.getStaticStrides());
return success();
}

/// When data is loaded/stored in `targetBits` granularity, but is used in
/// `sourceBits` granularity (`sourceBits` < `targetBits`), the `targetBits` is
/// treated as an array of elements of width `sourceBits`.
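A minimal, self-contained illustration of that packing arithmetic — not part of this patch; the function and parameter names are made up for the sketch and assume `srcBits` evenly divides `dstBits`:

#include <cassert>
#include <cstdint>
#include <utility>

// Maps the logical index of a `srcBits`-wide element to the index of the
// `dstBits`-wide storage word holding it and the bit offset inside that word.
static std::pair<int64_t, int64_t>
narrowIndexToStorage(int64_t linearIndex, int srcBits, int dstBits) {
  assert(dstBits % srcBits == 0 && "dstBits must be a multiple of srcBits");
  int elementsPerStorageUnit = dstBits / srcBits;
  int64_t storageIndex = linearIndex / elementsPerStorageUnit;
  int64_t bitOffset = (linearIndex % elementsPerStorageUnit) * srcBits;
  return {storageIndex, bitOffset};
}
// Example: for i4 data in i8 storage (srcBits = 4, dstBits = 8), logical
// index 5 lands in byte 2 at bit offset 4.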
@@ -337,7 +279,48 @@ struct ConvertMemRefReinterpretCast final
op->getLoc(), "subview with rank > 1 is not supported");
}

return convertCastingOp(rewriter, adaptor, op, newTy);
Type convertedElementType = newTy.getElementType();
Type oldElementType = op.getType().getElementType();
int srcBits = oldElementType.getIntOrFloatBitWidth();
int dstBits = convertedElementType.getIntOrFloatBitWidth();
if (dstBits % srcBits != 0) {
return rewriter.notifyMatchFailure(
op, "only dstBits % srcBits == 0 supported");
}

// Only support stride of 1.
if (llvm::any_of(op.getStaticStrides(),
[](int64_t stride) { return stride != 1; })) {
return rewriter.notifyMatchFailure(op->getLoc(),
"stride != 1 is not supported");
}

auto sizes = op.getStaticSizes();
int64_t offset = op.getStaticOffset(0);
// Only support static sizes and offsets.
if (llvm::any_of(
sizes, [](int64_t size) { return size == ShapedType::kDynamic; }) ||
offset == ShapedType::kDynamic) {
return rewriter.notifyMatchFailure(
op->getLoc(), "dynamic size or offset is not supported");
}

int elementsPerByte = dstBits / srcBits;
if (offset % elementsPerByte != 0) {
return rewriter.notifyMatchFailure(
op->getLoc(), "offset not multiple of elementsPerByte is not "
"supported");
}

SmallVector<int64_t> size;
if (sizes.size())
size.push_back(ceilDiv(sizes[0], elementsPerByte));
offset = offset / elementsPerByte;

rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
op, newTy, *adaptor.getODSOperands(0).begin(), offset, size,
op.getStaticStrides());
return success();
}
};
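For reference, a worked instance of the arithmetic above under the i4-in-i8 configuration: elementsPerByte = dstBits / srcBits = 8 / 4 = 2, so a reinterpret_cast with static offset 6 and static size 5 is rewritten against the emulated i8 memref with offset 6 / 2 = 3 and size ceilDiv(5, 2) = 3.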

@@ -404,29 +387,68 @@ struct ConvertMemrefStore final : OpConversionPattern<memref::StoreOp> {

/// Emulating narrow ints on subview has limited support, supporting only
/// static offset and size and stride of 1. Ideally, the subview should be
/// folded away before running narrow type emulation, and this pattern would
/// never run. This pattern is mostly used for testing purposes.
/// folded away before running narrow type emulation, and this pattern should
/// only run for cases that can't be folded.
struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
using OpConversionPattern::OpConversionPattern;

LogicalResult
matchAndRewrite(memref::SubViewOp op, OpAdaptor adaptor,
matchAndRewrite(memref::SubViewOp subViewOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
MemRefType newTy =
dyn_cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
MemRefType newTy = dyn_cast<MemRefType>(
getTypeConverter()->convertType(subViewOp.getType()));
if (!newTy) {
return rewriter.notifyMatchFailure(
op->getLoc(),
llvm::formatv("failed to convert memref type: {0}", op.getType()));
subViewOp->getLoc(),
llvm::formatv("failed to convert memref type: {0}",
subViewOp.getType()));
}

// Only support offset for 1-D subview.
if (op.getType().getRank() != 1) {
Location loc = subViewOp.getLoc();
Type convertedElementType = newTy.getElementType();
Type oldElementType = subViewOp.getType().getElementType();
int srcBits = oldElementType.getIntOrFloatBitWidth();
int dstBits = convertedElementType.getIntOrFloatBitWidth();
if (dstBits % srcBits != 0)
return rewriter.notifyMatchFailure(
op->getLoc(), "subview with rank > 1 is not supported");
subViewOp, "only dstBits % srcBits == 0 supported");

// Only support stride of 1.
if (llvm::any_of(subViewOp.getStaticStrides(),
[](int64_t stride) { return stride != 1; })) {
return rewriter.notifyMatchFailure(subViewOp->getLoc(),
"stride != 1 is not supported");
}

auto sizes = subViewOp.getStaticSizes();
int64_t lastOffset = subViewOp.getStaticOffsets().back();
// Only support static sizes and offsets.
if (llvm::any_of(
sizes, [](int64_t size) { return size == ShapedType::kDynamic; }) ||
lastOffset == ShapedType::kDynamic) {
return rewriter.notifyMatchFailure(
subViewOp->getLoc(), "dynamic size or offset is not supported");
}

return convertCastingOp(rewriter, adaptor, op, newTy);
// Transform the offsets, sizes and strides according to the emulation.
auto stridedMetadata = rewriter.create<memref::ExtractStridedMetadataOp>(
loc, subViewOp.getViewSource());

OpFoldResult linearizedIndices;
auto strides = stridedMetadata.getConstifiedMixedStrides();
memref::LinearizedMemRefInfo linearizedInfo;
std::tie(linearizedInfo, linearizedIndices) =
memref::getLinearizedMemRefOffsetAndSize(
rewriter, loc, srcBits, dstBits,
stridedMetadata.getConstifiedMixedOffset(),
subViewOp.getMixedSizes(), strides,
getMixedValues(adaptor.getStaticOffsets(), adaptor.getOffsets(),
rewriter));

rewriter.replaceOpWithNewOp<memref::SubViewOp>(
Contributor:
I think this is wrong. You can only linearize a subview that is "contiguous". So you have to check that the subview is contiguous in memory.

Contributor Author:
Thanks! Lines 417-421 check that the subview has only unit strides. Is there anything else needed? I'm using strides.back() (L450) because I know that all the strides are one and therefore the stride of the new subview would be one.

Contributor:
I don't think just checking the strides of the subview is enough... you have to check the strides of the memref type of the result. Those strides need to be contiguous?

Contributor Author:
Let me add a test to better understand what happens...

Contributor Author:
Hey, I tried with a few tests like these:

func.func @memref_subview_dynamic_offset_i4_1(%idx : index) -> i4 {
  %c0 = arith.constant 0 : index
  %arr = memref.alloc() : memref<512x64x8x16xi4>
  %subview = memref.subview %arr[%idx, 0, 0, 0] [16, 64, 8, 4] [1, 1, 1, 4] : memref<512x64x8x16xi4>
                                                                            to memref<16x64x8x4xi4, strided<[8192, 128, 16, 4], offset: ?>>
  %ld = memref.load %subview[%c0, %c0, %c0, %c0] : memref<16x64x8x4xi4, strided<[8192, 128, 16, 4], offset: ?>>
  return %ld : i4
}

func.func @memref_subview_dynamic_offset_i4_2(%idx : index) -> i4 {
  %c0 = arith.constant 0 : index
  %arr = memref.alloc() : memref<512x64x8x16xi4>
  %subview = memref.subview %arr[%idx, 0, 0, 0] [16, 4, 8, 16] [1, 16, 1, 1] : memref<512x64x8x16xi4>
                                                                             to memref<16x4x8x16xi4, strided<[8192, 2048, 16, 1], offset: ?>>
  %ld = memref.load %subview[%c0, %c0, %c0, %c0] : memref<16x4x8x16xi4, strided<[8192, 2048, 16, 1], offset: ?>>
  return %ld : i4
}

and all of them seem to be covered by the existing rules. Can you think of any other example?

> you have to check the strides of the memref type of the result. Those strides need to be contiguous?

Note that the type of the new subview, newTy, comes from the emulation type converter (L398), where we check the strides of the original memref and then linearize the shape, also resulting in a memref with a single unit stride. Not sure what else I can check.

Contributor:
%subview = memref.subview %arr[%idx, 0, 0, 0] [16, 64, 8, 4] [1, 1, 1, 4] : memref<512x64x8x16xi4>
                                                                            to memref<16x64x8x4xi4, strided<[8192, 128, 16, 4], offset: ?>>

The result subview here is contiguous... If you already checked for it (or if it is somehow already enforced), I might have missed it... but take a subview of this form

%subview = memref.subview %arr[%idx0, %idx] [4, 8] [1, 1] : memref<40x40xi4> to memref<4x8xi4, strided<[40, 1], offset:?>>

Here the result type is not contiguous. You cannot represent this as a linearized type. Is this handled?
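To spell out the contiguity argument: a row-major contiguous 4x8 view needs strides [8, 1], since the outer stride must equal the inner size times the inner stride; with strides [40, 1] each 8-element row is followed by a 32-element gap in the underlying buffer, so the view cannot be described by a single linearized offset/size pair.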

Contributor:
I think we can check it using isStaticShapeAndContiguousRowMajor in this case. It ensures that the memref type is contiguous.

subViewOp, newTy, adaptor.getSource(), linearizedIndices,
linearizedInfo.linearizedSize, strides.back());
return success();
}
};
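A minimal sketch of the guard suggested in the review thread above, assuming memref::isStaticShapeAndContiguousRowMajor keeps its MemRefType-to-bool signature; exactly where it would sit inside matchAndRewrite is left open here:

    // Illustrative only: bail out when the result type is not contiguous
    // row-major, since only contiguous views can be linearized safely.
    if (!memref::isStaticShapeAndContiguousRowMajor(subViewOp.getType())) {
      return rewriter.notifyMatchFailure(
          subViewOp, "only contiguous subview results are supported");
    }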

4 changes: 1 addition & 3 deletions mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
@@ -67,7 +67,6 @@ std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
AffineExpr mulMap = builder.getAffineConstantExpr(1);

SmallVector<OpFoldResult> offsetValues(2 * sourceRank);
SmallVector<OpFoldResult> sizeValues(sourceRank);

for (unsigned i = 0; i < sourceRank; ++i) {
unsigned offsetIdx = 2 * i;
@@ -78,8 +77,7 @@
mulMap = mulMap * symbols[i];
}

// Adjust linearizedIndices, size and offset by the scale factor (dstBits /
// srcBits).
// Adjust linearizedIndices and size by the scale factor (dstBits / srcBits).
int64_t scaler = dstBits / srcBits;
addMulMap = addMulMap.floorDiv(scaler);
mulMap = mulMap.floorDiv(scaler);
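A quick worked instance of that scaling: with srcBits = 4 and dstBits = 8 the scaler is 2, so a source memref whose linearized element count (mulMap) evaluates to 16 i4 elements is backed by 16 floordiv 2 = 8 i8 storage elements, and the linearized index expression is divided by the same factor.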
23 changes: 23 additions & 0 deletions mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -177,6 +177,29 @@ func.func @memref_strided_i4(%idx : index) -> i4 {

// -----

func.func @memref_subview_dynamic_offset_i4(%idx : index) -> i4 {
%c0 = arith.constant 0 : index
%arr = memref.alloc() : memref<512x64x8x16xi4>
%subview = memref.subview %arr[%idx, 0, 0, 0] [16, 64, 8, 16] [1, 1, 1, 1] : memref<512x64x8x16xi4>
to memref<16x64x8x16xi4, strided<[8192, 128, 16, 1], offset: ?>>
%ld = memref.load %subview[%c0, %c0, %c0, %c0] : memref<16x64x8x16xi4, strided<[8192, 128, 16, 1], offset: ?>>
return %ld : i4
}

// CHECK-LABEL: func.func @memref_subview_dynamic_offset_i4(
// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<2097152xi8>
// CHECK: %[[IDX:.*]] = affine.apply
// CHECK: %[[SUBVIEW:.*]] = memref.subview %[[ALLOC]][%[[IDX]]] [65536] [1] : memref<2097152xi8> to memref<65536xi8, strided<[1], offset: ?>>
// CHECK: memref.load %[[SUBVIEW]]

// CHECK32-LABEL: func.func @memref_subview_dynamic_offset_i4(
// CHECK32: %[[ALLOC:.*]] = memref.alloc() : memref<524288xi32>
// CHECK32: %[[IDX:.*]] = affine.apply
// CHECK32: %[[SUBVIEW:.*]] = memref.subview %[[ALLOC]][%[[IDX]]] [16384] [1] : memref<524288xi32> to memref<16384xi32, strided<[1], offset: ?>>
// CHECK32: memref.load %[[SUBVIEW]]
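For reference, the CHECK constants follow directly from the emulation arithmetic: the allocation holds 512 * 64 * 8 * 16 = 4194304 i4 elements, i.e. 2097152 i8 or 524288 i32 storage elements; the 16x64x8x16 subview spans 131072 i4 elements, i.e. 65536 i8 or 16384 i32; and the affine.apply rescales the dynamic offset of %idx * 8192 i4 elements into the corresponding i8/i32 index.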

// -----

func.func @reinterpret_cast_memref_load_0D() -> i4 {
%0 = memref.alloc() : memref<5xi4>
%reinterpret_cast_0 = memref.reinterpret_cast %0 to offset: [0], sizes: [], strides: [] : memref<5xi4> to memref<i4>