
[mlir] Add narrow type emulation for memref.reinterpret_cast #73144

Merged
merged 2 commits on Nov 27, 2023
136 changes: 95 additions & 41 deletions mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -17,18 +17,77 @@
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Support/MathExtras.h"
#include "mlir/Transforms/DialectConversion.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include <cassert>
#include <type_traits>

using namespace mlir;

//===----------------------------------------------------------------------===//
// Utility functions
//===----------------------------------------------------------------------===//

/// Converts a memref::SubViewOp or memref::ReinterpretCastOp to the converted
/// type. The result MemRefType of the old op must have a rank of at most 1 and
/// strides of 1, with static offset and sizes. The offset, measured in bits,
/// must be a multiple of the bitwidth of the new converted element type.
template <typename MemRefOpTy>
static LogicalResult convertCastingOp(ConversionPatternRewriter &rewriter,
                                      typename MemRefOpTy::Adaptor adaptor,
                                      MemRefOpTy op, MemRefType newTy) {
  static_assert(std::is_same<MemRefOpTy, memref::SubViewOp>() ||
                    std::is_same<MemRefOpTy, memref::ReinterpretCastOp>(),
                "Expected only memref::SubViewOp or memref::ReinterpretCastOp");

  auto convertedElementType = newTy.getElementType();
  auto oldElementType = op.getType().getElementType();
  int srcBits = oldElementType.getIntOrFloatBitWidth();
  int dstBits = convertedElementType.getIntOrFloatBitWidth();
  if (dstBits % srcBits != 0) {
    return rewriter.notifyMatchFailure(op,
                                       "only dstBits % srcBits == 0 supported");
  }

  // Only support stride of 1.
  if (llvm::any_of(op.getStaticStrides(),
                   [](int64_t stride) { return stride != 1; })) {
    return rewriter.notifyMatchFailure(op->getLoc(),
                                       "stride != 1 is not supported");
  }

  auto sizes = op.getStaticSizes();
  int64_t offset = op.getStaticOffset(0);
  // Only support static sizes and offsets.
  if (llvm::any_of(sizes,
                   [](int64_t size) { return size == ShapedType::kDynamic; }) ||
      offset == ShapedType::kDynamic) {
    return rewriter.notifyMatchFailure(
        op->getLoc(), "dynamic size or offset is not supported");
  }

  int elementsPerByte = dstBits / srcBits;
  if (offset % elementsPerByte != 0) {
    return rewriter.notifyMatchFailure(
        op->getLoc(), "offset not multiple of elementsPerByte is not "
                      "supported");
  }

  SmallVector<int64_t> size;
  if (sizes.size())
    size.push_back(ceilDiv(sizes[0], elementsPerByte));
  offset = offset / elementsPerByte;

  rewriter.replaceOpWithNewOp<MemRefOpTy>(op, newTy,
                                          *adaptor.getODSOperands(0).begin(),
                                          offset, size, op.getStaticStrides());
  return success();
}
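For intuition, the rescaling above converts an offset and size counted in srcBits-wide elements into an offset and size counted in dstBits-wide storage elements. Below is a minimal standalone sketch of that arithmetic, not part of this patch, using the same numbers as the 1-D test added below (offset 8 and size 25 in i4 elements become offset 4 and size 13 in i8 bytes); the helper name is illustrative only.

// Illustration only (not part of this patch): the offset/size rescaling
// performed by convertCastingOp when srcBits-wide elements are packed into
// dstBits-wide storage elements.
#include <cassert>
#include <cstdint>
#include <cstdio>

static void rescale(int64_t offset, int64_t size, int srcBits, int dstBits) {
  assert(dstBits % srcBits == 0 && "only dstBits % srcBits == 0 supported");
  int64_t elementsPerByte = dstBits / srcBits;
  assert(offset % elementsPerByte == 0 && "offset must stay representable");
  int64_t newOffset = offset / elementsPerByte;
  // ceilDiv: a partially used trailing storage element still has to be kept.
  int64_t newSize = (size + elementsPerByte - 1) / elementsPerByte;
  std::printf("offset %lld -> %lld, size %lld -> %lld\n",
              (long long)offset, (long long)newOffset,
              (long long)size, (long long)newSize);
}

int main() {
  // Mirrors @reinterpret_cast_memref_load_1D below: i4 -> i8 turns
  // offset 8 / size 25 into offset 4 / size 13.
  rescale(/*offset=*/8, /*size=*/25, /*srcBits=*/4, /*dstBits=*/8);
}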

/// When data is loaded/stored in `targetBits` granularity, but is used in
/// `sourceBits` granularity (`sourceBits` < `targetBits`), the `targetBits` is
/// treated as an array of elements of width `sourceBits`.
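Concretely, with sourceBits = 4 and targetBits = 8, element i lives in storage element i floordiv 2 and starts at bit (i mod 2) * 4, which is exactly the pair of affine maps checked in the new tests below. A small illustrative sketch of that index arithmetic, not part of this patch:

// Illustration only (not part of this patch): where the i-th sourceBits-wide
// element lands inside targetBits-wide storage. For i4 in i8 this matches the
// affine maps checked below: s0 floordiv 2 and s0 * 4 - (s0 floordiv 2) * 8.
#include <cstdint>
#include <cstdio>

int main() {
  const int64_t sourceBits = 4, targetBits = 8;
  const int64_t scale = targetBits / sourceBits;
  for (int64_t i = 0; i < 4; ++i) {
    int64_t storageIndex = i / scale;             // which storage element to load
    int64_t bitOffset = (i % scale) * sourceBits; // shift within that element
    std::printf("element %lld -> storage %lld, bit offset %lld\n",
                (long long)i, (long long)storageIndex, (long long)bitOffset);
  }
}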
@@ -211,6 +270,37 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
}
};

//===----------------------------------------------------------------------===//
// ConvertMemRefReinterpretCast
//===----------------------------------------------------------------------===//

/// Output types should be at most one dimensional, so only the 0 or 1
/// dimensional cases are supported.
struct ConvertMemRefReinterpretCast final
    : OpConversionPattern<memref::ReinterpretCastOp> {
  using OpConversionPattern::OpConversionPattern;

  LogicalResult
  matchAndRewrite(memref::ReinterpretCastOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    MemRefType newTy =
        dyn_cast<MemRefType>(getTypeConverter()->convertType(op.getType()));
    if (!newTy) {
      return rewriter.notifyMatchFailure(
          op->getLoc(),
          llvm::formatv("failed to convert memref type: {0}", op.getType()));
    }

    // Only support for 0 or 1 dimensional cases.
    if (op.getType().getRank() > 1) {
      return rewriter.notifyMatchFailure(
          op->getLoc(), "subview with rank > 1 is not supported");
    }

    return convertCastingOp(rewriter, adaptor, op, newTy);
  }
};
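Note that for the rank-0 case accepted here, convertCastingOp leaves the size list empty and only rescales the offset, which is what the `if (sizes.size())` guard above encodes. A tiny illustration of that shape handling, with hypothetical helper names that are not part of this patch:

// Illustration only (hypothetical helper, not part of this patch): how the
// rank-0 and rank-1 cases flow through the size/offset handling above.
#include <cstdint>
#include <vector>

struct CastShape {
  std::vector<int64_t> sizes; // empty for a 0-D memref
  int64_t offset;
};

static CastShape rescaleShape(const CastShape &in, int64_t elementsPerByte) {
  CastShape out;
  // Same pattern as the `if (sizes.size())` guard in convertCastingOp:
  // a 0-D memref has no size to rescale, only an offset.
  if (!in.sizes.empty())
    out.sizes.push_back((in.sizes[0] + elementsPerByte - 1) / elementsPerByte);
  out.offset = in.offset / elementsPerByte;
  return out;
}

int main() {
  CastShape zeroD = rescaleShape({/*sizes=*/{}, /*offset=*/0}, /*elementsPerByte=*/2);
  CastShape oneD = rescaleShape({/*sizes=*/{25}, /*offset=*/8}, /*elementsPerByte=*/2);
  (void)zeroD;
  (void)oneD;
  return 0;
}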

//===----------------------------------------------------------------------===//
// ConvertMemRefSubview
//===----------------------------------------------------------------------===//
@@ -233,50 +323,13 @@ struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
llvm::formatv("failed to convert memref type: {0}", op.getType()));
}

auto convertedElementType = newTy.getElementType();
auto oldElementType = op.getType().getElementType();
int srcBits = oldElementType.getIntOrFloatBitWidth();
int dstBits = convertedElementType.getIntOrFloatBitWidth();
if (dstBits % srcBits != 0) {
return rewriter.notifyMatchFailure(
op, "only dstBits % srcBits == 0 supported");
}

// Only support offset for 1-D subview.
if (op.getType().getRank() != 1) {
return rewriter.notifyMatchFailure(
op->getLoc(), "subview with rank > 1 is not supported");
}

// Only support stride of 1.
if (op.getStaticStride(0) != 1) {
return rewriter.notifyMatchFailure(
op->getLoc(), "subview with stride != 1 is not supported");
}

int64_t size = op.getStaticSize(0);
int64_t offset = op.getStaticOffset(0);
// Only support static sizes and offsets.
if (size == ShapedType::kDynamic || offset == ShapedType::kDynamic) {
return rewriter.notifyMatchFailure(
op->getLoc(), "subview with dynamic size or offset is not supported");
}

int elementsPerByte = dstBits / srcBits;
if (offset % elementsPerByte != 0) {
return rewriter.notifyMatchFailure(
op->getLoc(),
"subview with offset not multiple of elementsPerByte is not "
"supported");
}

size = ceilDiv(size, elementsPerByte);
offset = offset / elementsPerByte;

rewriter.replaceOpWithNewOp<memref::SubViewOp>(
op, newTy, *adaptor.getODSOperands(0).begin(), offset, size,
op.getStaticStrides());
return success();
return convertCastingOp(rewriter, adaptor, op, newTy);
}
};

@@ -291,9 +344,10 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
    RewritePatternSet &patterns) {

  // Populate `memref.*` conversion patterns.
  patterns.add<ConvertMemRefAlloc, ConvertMemRefLoad,
               ConvertMemRefAssumeAlignment, ConvertMemRefSubview>(
      typeConverter, patterns.getContext());
  patterns
      .add<ConvertMemRefAlloc, ConvertMemRefLoad, ConvertMemRefAssumeAlignment,
           ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
          typeConverter, patterns.getContext());
  memref::populateResolveExtractStridedMetadataPatterns(patterns);
}
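As a usage sketch only, none of which is part of this diff, these patterns would typically be driven through a dialect conversion along the following lines; the converter construction, header paths, and legality setup are assumptions modeled on the existing narrow-type emulation flow rather than anything this PR adds.

// Sketch: wiring the memref narrow-type emulation patterns into a
// conversion. Header paths and the exact converter setup are assumptions,
// not dictated by this diff.
#include "mlir/Dialect/Arith/Transforms/NarrowTypeEmulationConverter.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Transforms/DialectConversion.h"

using namespace mlir;

static void runNarrowTypeEmulation(Operation *root, MLIRContext *ctx) {
  // Emulate i4 (and other sub-byte types) on top of i8 storage.
  arith::NarrowTypeEmulationConverter typeConverter(/*targetBitwidth=*/8);
  memref::populateMemRefNarrowTypeEmulationConversions(typeConverter);

  ConversionTarget target(*ctx);
  target.markUnknownOpDynamicallyLegal(
      [&](Operation *op) { return typeConverter.isLegal(op); });

  RewritePatternSet patterns(ctx);
  memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);

  (void)applyPartialConversion(root, target, std::move(patterns));
}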

58 changes: 58 additions & 0 deletions mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -174,3 +174,61 @@ func.func @memref_strided_i4(%idx : index) -> i4 {
// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<16xi32>
// CHECK32: %[[SUBVIEW:.+]] = memref.subview %[[ALLOC]][4] [4] [1] : memref<16xi32> to memref<4xi32, strided<[1], offset: 4>>
// CHECK32: %[[LOAD:.+]] = memref.load %[[SUBVIEW]]

// -----

func.func @reinterpret_cast_memref_load_0D() -> i4 {
  %0 = memref.alloc() : memref<5xi4>
  %reinterpret_cast_0 = memref.reinterpret_cast %0 to offset: [0], sizes: [], strides: [] : memref<5xi4> to memref<i4>
  %1 = memref.load %reinterpret_cast_0[] : memref<i4>
  return %1 : i4
}
// CHECK-LABEL: func @reinterpret_cast_memref_load_0D()
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
// CHECK: %[[RE_CAST:.+]] = memref.reinterpret_cast %[[ALLOC]] to offset: [0], sizes: [], strides: [] : memref<3xi8> to memref<i8>
// CHECK: %[[LOAD:.+]] = memref.load %[[RE_CAST]][] : memref<i8>
// CHECK: %[[TRUNC:.+]] = arith.trunci %[[LOAD]] : i8 to i4
// CHECK: return %[[TRUNC]]

// CHECK32-LABEL: func @reinterpret_cast_memref_load_0D()
// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<1xi32>
// CHECK32: %[[RE_CAST:.+]] = memref.reinterpret_cast %[[ALLOC]] to offset: [0], sizes: [], strides: [] : memref<1xi32> to memref<i32>
// CHECK32: %[[LOAD:.+]] = memref.load %[[RE_CAST]][] : memref<i32>
// CHECK32: %[[TRUNC:.+]] = arith.trunci %[[LOAD]] : i32 to i4
// CHECK32: return %[[TRUNC]]

// -----

func.func @reinterpret_cast_memref_load_1D(%arg0: index) -> i4 {
  %0 = memref.alloc() : memref<5x5xi4>
  %reinterpret_cast_0 = memref.reinterpret_cast %0 to offset: [8], sizes: [25], strides: [1] : memref<5x5xi4> to memref<25xi4, strided<[1], offset:8>>
  %1 = memref.load %reinterpret_cast_0[%arg0] : memref<25xi4, strided<[1], offset:8>>
  return %1 : i4
}
// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 2)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 2) * 8)>
// CHECK: func @reinterpret_cast_memref_load_1D(
// CHECK-SAME: %[[ARG0:.+]]: index
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<13xi8>
// CHECK: %[[RE_CAST:.+]] = memref.reinterpret_cast %[[ALLOC]] to offset: [4], sizes: [13], strides: [1] : memref<13xi8> to memref<13xi8, strided<[1], offset: 4>>
// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
// CHECK: %[[LOAD:.+]] = memref.load %[[RE_CAST]][%[[INDEX]]] : memref<13xi8, strided<[1], offset: 4>>
// CHECK: %[[OFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
// CHECK: %[[CAST:.+]] = arith.index_cast %[[OFFSET]] : index to i8
// CHECK: %[[SHR:.+]] = arith.shrsi %[[LOAD]], %[[CAST]] : i8
// CHECK: %[[TRUNC:.+]] = arith.trunci %[[SHR]] : i8 to i4
// CHECK: return %[[TRUNC]]

// CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0] -> (s0 floordiv 8)>
// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 8) * 32)>
// CHECK32: func @reinterpret_cast_memref_load_1D(
// CHECK32-SAME: %[[ARG0:.+]]: index
// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<4xi32>
// CHECK32: %[[RE_CAST:.+]] = memref.reinterpret_cast %[[ALLOC]] to offset: [1], sizes: [4], strides: [1] : memref<4xi32> to memref<4xi32, strided<[1], offset: 1>>
// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG0]]]
// CHECK32: %[[LOAD:.+]] = memref.load %[[RE_CAST]][%[[INDEX]]] : memref<4xi32, strided<[1], offset: 1>>
// CHECK32: %[[OFFSET:.+]] = affine.apply #[[MAP1]]()[%[[ARG0]]]
// CHECK32: %[[CAST:.+]] = arith.index_cast %[[OFFSET]] : index to i32
// CHECK32: %[[SHR:.+]] = arith.shrsi %[[LOAD]], %[[CAST]] : i32
// CHECK32: %[[TRUNC:.+]] = arith.trunci %[[SHR]] : i32 to i4
// CHECK32: return %[[TRUNC]]