Skip to content

Commit 8fc433f

Browse files
committed
[mlir][MemRef] Move narrow type emulation common methods to MemRefUtils.
It also unifies the computation of StridedLayoutAttr: if the stride is a statically known value, we can just use it. Differential Revision: https://reviews.llvm.org/D155017
1 parent a48f32d commit 8fc433f

File tree

11 files changed

+162
-215
lines changed

11 files changed

+162
-215
lines changed

mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
#ifndef MLIR_DIALECT_MEMREF_UTILS_MEMREFUTILS_H
1717
#define MLIR_DIALECT_MEMREF_UTILS_MEMREFUTILS_H
1818

19+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
20+
1921
namespace mlir {
2022

2123
class MemRefType;
@@ -26,6 +28,37 @@ namespace memref {
2628
/// contiguous chunk of memory.
2729
bool isStaticShapeAndContiguousRowMajor(MemRefType type);
2830

31+
/// Returns the flattened 1-D memref and linearized offset for narrow type
32+
/// emulation.
33+
///
34+
/// The emulation only works on 1D memref types. To make this work on N-D
35+
/// memref, we need to linearize the offset.
36+
///
37+
/// For example, to emulate i4 to i8, the following op:
38+
///
39+
/// %0 = memref.load %arg0[%v0, %v1] :
40+
/// memref<?x?xi4, strided<[?, ?], offset: ?>>
41+
///
42+
/// can be replaced with
43+
///
44+
/// %b, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0
45+
///
46+
/// %linearized_offset = %v0 * %strides#0 + %v1 * %strides#1
47+
/// %linearized_size = %sizes#0 * %sizes#1
48+
/// %scaled_linear_offset = %linearized_offset / (8 / 4)
49+
/// %scaled_base_offset = %offset / (8 / 4)
50+
///
51+
/// %linearized = memref.reinterpret_cast %b, offset = [%scaled_base_offset],
52+
/// sizes = [%linearized_size], strides = [%strides#1]
53+
///
54+
/// %new_load = memref.load %linearized[%scaled_linear_offset] :
55+
/// memref<?xi8, strided<[?], offset: ?>>
56+
std::pair<Value, Value>
57+
getLinearizeMemRefAndOffset(Location loc, MemRefType sourceType, int srcBits,
58+
int dstBits, SmallVector<Value> indices,
59+
memref::ExtractStridedMetadataOp stridedMetadata,
60+
OpBuilder &builder);
61+
2962
} // namespace memref
3063
} // namespace mlir
3164

mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
1212
#include "mlir/Dialect/Func/IR/FuncOps.h"
1313
#include "mlir/Dialect/MemRef/IR/MemRef.h"
14-
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
1514
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
1615
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1716
#include "mlir/IR/Matchers.h"

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
#include "mlir/Dialect/Arith/IR/Arith.h"
1010
#include "mlir/Dialect/Arith/Utils/Utils.h"
1111
#include "mlir/Dialect/MemRef/IR/MemRef.h"
12-
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
1312
#include "mlir/Dialect/Utils/StaticValueUtils.h"
1413
#include "mlir/IR/AffineMap.h"
1514
#include "mlir/IR/Builders.h"

mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp

Lines changed: 8 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1616
#include "mlir/Dialect/MemRef/Transforms/Passes.h"
1717
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
18+
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
1819
#include "mlir/Dialect/Vector/IR/VectorOps.h"
1920
#include "mlir/Transforms/DialectConversion.h"
2021
#include "llvm/Support/FormatVariadic.h"
@@ -27,102 +28,6 @@ using namespace mlir;
2728
// Utility functions
2829
//===----------------------------------------------------------------------===//
2930

30-
/// The emulation only works on 1D memref types.
31-
/// To make this work on N-D memref, we need to linearize the offset.
32-
///
33-
/// For example, to emulate i4 to i8, the following op:
34-
///
35-
/// %0 = memref.load %arg0[%v0, %v1] :
36-
/// memref<?x?xi4, strided<[?, ?], offset: ?>>
37-
///
38-
/// can be replaced with
39-
///
40-
/// %b, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %0
41-
///
42-
/// %linearized_offset = %v0 * %stride#0 + %v1 * %stride#1
43-
/// %linearized_size = %size0 * %size1
44-
/// %scaled_linear_offset = %linearized_offset / 8 * 4
45-
/// %scaled_base_offset = %offset / 8 * 4
46-
///
47-
/// %linearized = memref.reinterpret_cast %b, offset = [%scaled_base_offset],
48-
/// sizes = [%linearized_size], strides = [%stride#1]
49-
///
50-
/// %new_load = memref.load %linearized[%scaled_linear_offset] :
51-
/// memref<?xi8, strided<[?], offset: ?>>
52-
53-
static Value
54-
linearizeMemrefLoad(Location loc, MemRefType sourceType, int srcBits,
55-
int dstBits, SmallVector<Value> indices,
56-
memref::ExtractStridedMetadataOp stridedMetadata,
57-
OpBuilder &builder) {
58-
auto srcElementType = sourceType.getElementType();
59-
unsigned sourceRank = indices.size();
60-
61-
Value baseBuffer = stridedMetadata.getBaseBuffer();
62-
SmallVector<Value> baseSizes = stridedMetadata.getSizes();
63-
SmallVector<Value> baseStrides = stridedMetadata.getStrides();
64-
Value baseOffset = stridedMetadata.getOffset();
65-
assert(indices.size() == baseStrides.size());
66-
67-
// Create the affine symbols and values for linearization.
68-
SmallVector<AffineExpr> symbols(2 * sourceRank + 2);
69-
bindSymbolsList(builder.getContext(), MutableArrayRef{symbols});
70-
symbols[0] = builder.getAffineSymbolExpr(0);
71-
AffineExpr addMulMap = symbols.front();
72-
AffineExpr mulMap = symbols.front();
73-
74-
SmallVector<OpFoldResult> offsetValues(2 * sourceRank + 2);
75-
offsetValues[0] = builder.getIndexAttr(0);
76-
SmallVector<OpFoldResult> sizeValues(sourceRank + 1);
77-
sizeValues[0] = builder.getIndexAttr(1);
78-
79-
for (unsigned i = 0; i < sourceRank; ++i) {
80-
unsigned offsetIdx = 2 * i + 1;
81-
addMulMap = addMulMap + symbols[offsetIdx] * symbols[offsetIdx + 1];
82-
offsetValues[offsetIdx] = indices[i];
83-
offsetValues[offsetIdx + 1] = baseStrides[i];
84-
85-
unsigned sizeIdx = i + 1;
86-
mulMap = mulMap * symbols[sizeIdx];
87-
sizeValues[sizeIdx] = baseSizes[i];
88-
}
89-
90-
// Adjust linearizedOffset by the scale factor (dstBits / srcBits).
91-
OpFoldResult scaler = builder.getIndexAttr(dstBits / srcBits);
92-
AffineExpr scaledAddMulMap = addMulMap.floorDiv(symbols.back());
93-
offsetValues.back() = scaler;
94-
95-
OpFoldResult linearizedOffset = affine::makeComposedFoldedAffineApply(
96-
builder, loc, scaledAddMulMap, offsetValues);
97-
OpFoldResult linearizedSize =
98-
affine::makeComposedFoldedAffineApply(builder, loc, mulMap, sizeValues);
99-
100-
// Adjust baseOffset by the scale factor (dstBits / srcBits).
101-
AffineExpr s0, s1;
102-
bindSymbols(builder.getContext(), s0, s1);
103-
OpFoldResult adjustBaseOffset = affine::makeComposedFoldedAffineApply(
104-
builder, loc, s0.floorDiv(s1), {baseOffset, scaler});
105-
106-
// Flatten n-D MemRef to 1-D MemRef.
107-
auto layoutAttr = StridedLayoutAttr::get(
108-
sourceType.getContext(), ShapedType::kDynamic, {ShapedType::kDynamic});
109-
int64_t staticShape = sourceType.hasStaticShape()
110-
? sourceType.getNumElements()
111-
: ShapedType::kDynamic;
112-
auto flattenMemrefType = MemRefType::get(
113-
staticShape, srcElementType, layoutAttr, sourceType.getMemorySpace());
114-
115-
auto reinterpret = builder.create<memref::ReinterpretCastOp>(
116-
loc, flattenMemrefType, baseBuffer,
117-
getValueOrCreateConstantIndexOp(builder, loc, adjustBaseOffset),
118-
getValueOrCreateConstantIndexOp(builder, loc, linearizedSize),
119-
baseStrides.back());
120-
121-
return builder.create<memref::LoadOp>(
122-
loc, srcElementType, reinterpret.getResult(),
123-
getValueOrCreateConstantIndexOp(builder, loc, linearizedOffset));
124-
}
125-
12631
/// When data is loaded/stored in `targetBits` granularity, but is used in
12732
/// `sourceBits` granularity (`sourceBits` < `targetBits`), the `targetBits` is
12833
/// treated as an array of elements of width `sourceBits`.
@@ -239,8 +144,13 @@ struct ConvertMemRefLoad final : OpConversionPattern<memref::LoadOp> {
239144

240145
lastIdx = stridedMetadata.getOffset();
241146
} else {
242-
newLoad = linearizeMemrefLoad(loc, sourceType, srcBits, dstBits, indices,
243-
stridedMetadata, rewriter);
147+
auto [reinterpret, linearizedOffset] =
148+
memref::getLinearizeMemRefAndOffset(loc, sourceType, srcBits, dstBits,
149+
adaptor.getIndices(),
150+
stridedMetadata, rewriter);
151+
152+
newLoad = rewriter.create<memref::LoadOp>(loc, srcElementType,
153+
reinterpret, linearizedOffset);
244154

245155
lastIdx = adaptor.getIndices().back();
246156
}

mlir/lib/Dialect/MemRef/Utils/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,4 +6,6 @@ add_mlir_dialect_library(MLIRMemRefUtils
66

77
LINK_LIBS PUBLIC
88
MLIRIR
9+
MLIRAffineDialect
10+
MLIRArithUtils
911
)

mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
//===----------------------------------------------------------------------===//
1212

1313
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
14+
#include "mlir/Dialect/Affine/IR/AffineOps.h"
15+
#include "mlir/Dialect/Arith/Utils/Utils.h"
1416
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1517

1618
namespace mlir {
@@ -44,5 +46,80 @@ bool isStaticShapeAndContiguousRowMajor(MemRefType type) {
4446
return curDim < 0;
4547
}
4648

49+
std::pair<Value, Value>
50+
getLinearizeMemRefAndOffset(Location loc, MemRefType sourceType, int srcBits,
51+
int dstBits, SmallVector<Value> indices,
52+
memref::ExtractStridedMetadataOp stridedMetadata,
53+
OpBuilder &builder) {
54+
auto srcElementType = sourceType.getElementType();
55+
unsigned sourceRank = indices.size();
56+
57+
Value baseBuffer = stridedMetadata.getBaseBuffer();
58+
SmallVector<Value> baseSizes = stridedMetadata.getSizes();
59+
SmallVector<Value> baseStrides = stridedMetadata.getStrides();
60+
Value baseOffset = stridedMetadata.getOffset();
61+
assert(indices.size() == baseStrides.size());
62+
63+
// Create the affine symbols and values for linearization.
64+
SmallVector<AffineExpr> symbols(2 * sourceRank + 2);
65+
bindSymbolsList(builder.getContext(), MutableArrayRef{symbols});
66+
symbols[0] = builder.getAffineSymbolExpr(0);
67+
AffineExpr addMulMap = symbols.front();
68+
AffineExpr mulMap = symbols.front();
69+
70+
SmallVector<OpFoldResult> offsetValues(2 * sourceRank + 2);
71+
offsetValues[0] = builder.getIndexAttr(0);
72+
SmallVector<OpFoldResult> sizeValues(sourceRank + 1);
73+
sizeValues[0] = builder.getIndexAttr(1);
74+
75+
for (unsigned i = 0; i < sourceRank; ++i) {
76+
unsigned offsetIdx = 2 * i + 1;
77+
addMulMap = addMulMap + symbols[offsetIdx] * symbols[offsetIdx + 1];
78+
offsetValues[offsetIdx] = indices[i];
79+
offsetValues[offsetIdx + 1] = baseStrides[i];
80+
81+
unsigned sizeIdx = i + 1;
82+
mulMap = mulMap * symbols[sizeIdx];
83+
sizeValues[sizeIdx] = baseSizes[i];
84+
}
85+
86+
// Adjust linearizedOffset by the scale factor (dstBits / srcBits).
87+
OpFoldResult scaler = builder.getIndexAttr(dstBits / srcBits);
88+
AffineExpr scaledAddMulMap = addMulMap.floorDiv(symbols.back());
89+
offsetValues.back() = scaler;
90+
91+
OpFoldResult linearizedOffset = affine::makeComposedFoldedAffineApply(
92+
builder, loc, scaledAddMulMap, offsetValues);
93+
OpFoldResult linearizedSize =
94+
affine::makeComposedFoldedAffineApply(builder, loc, mulMap, sizeValues);
95+
96+
// Adjust baseOffset by the scale factor (dstBits / srcBits).
97+
AffineExpr s0, s1;
98+
bindSymbols(builder.getContext(), s0, s1);
99+
OpFoldResult adjustBaseOffset = affine::makeComposedFoldedAffineApply(
100+
builder, loc, s0.floorDiv(s1), {baseOffset, scaler});
101+
102+
// Flatten n-D MemRef to 1-D MemRef.
103+
std::optional<int64_t> stride =
104+
getConstantIntValue(stridedMetadata.getConstifiedMixedStrides().back());
105+
auto layoutAttr =
106+
StridedLayoutAttr::get(sourceType.getContext(), ShapedType::kDynamic,
107+
{stride ? stride.value() : ShapedType::kDynamic});
108+
int64_t staticShape = sourceType.hasStaticShape()
109+
? sourceType.getNumElements()
110+
: ShapedType::kDynamic;
111+
auto flattenMemrefType = MemRefType::get(
112+
staticShape, srcElementType, layoutAttr, sourceType.getMemorySpace());
113+
114+
auto reinterpret = builder.create<memref::ReinterpretCastOp>(
115+
loc, flattenMemrefType, baseBuffer,
116+
getValueOrCreateConstantIndexOp(builder, loc, adjustBaseOffset),
117+
getValueOrCreateConstantIndexOp(builder, loc, linearizedSize),
118+
baseStrides.back());
119+
120+
return std::make_pair(reinterpret, getValueOrCreateConstantIndexOp(
121+
builder, loc, linearizedOffset));
122+
}
123+
47124
} // namespace memref
48125
} // namespace mlir

mlir/lib/Dialect/Vector/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ add_mlir_dialect_library(MLIRVectorTransforms
3737
MLIRIR
3838
MLIRLinalgDialect
3939
MLIRMemRefDialect
40+
MLIRMemRefUtils
4041
MLIRSCFDialect
4142
MLIRSideEffectInterfaces
4243
MLIRTensorDialect

0 commit comments

Comments
 (0)