
Commit 24cf476

[mlir] Add support for vector.store sub-byte emulation. (#70293)

1 parent d065965

2 files changed: +144, −1

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 66 additions & 1 deletion
@@ -34,6 +34,71 @@ using namespace mlir;
 
 namespace {
 
+//===----------------------------------------------------------------------===//
+// ConvertVectorStore
+//===----------------------------------------------------------------------===//
+
+struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::StoreOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+
+    auto loc = op.getLoc();
+    auto convertedType = cast<MemRefType>(adaptor.getBase().getType());
+    Type oldElementType = op.getValueToStore().getType().getElementType();
+    Type newElementType = convertedType.getElementType();
+    int srcBits = oldElementType.getIntOrFloatBitWidth();
+    int dstBits = newElementType.getIntOrFloatBitWidth();
+
+    if (dstBits % srcBits != 0) {
+      return rewriter.notifyMatchFailure(
+          op, "only dstBits % srcBits == 0 supported");
+    }
+    int scale = dstBits / srcBits;
+
+    // Adjust the number of elements to store when emulating narrow types.
+    // Here only 1-D vector stores are considered, and N-D memref types
+    // should be linearized.
+    // For example, to emulate i4 with i8, the following op:
+    //
+    //   vector.store %arg1, %0[%arg2, %arg3] : memref<4x8xi4>, vector<8xi4>
+    //
+    // can be replaced with
+    //
+    //   %bitcast = vector.bitcast %arg1 : vector<8xi4> to vector<4xi8>
+    //   vector.store %bitcast, %alloc[%linear_index] : memref<16xi8>,
+    //     vector<4xi8>
+
+    auto origElements = op.getValueToStore().getType().getNumElements();
+    if (origElements % scale != 0)
+      return failure();
+
+    auto stridedMetadata =
+        rewriter.create<memref::ExtractStridedMetadataOp>(loc, op.getBase());
+
+    OpFoldResult linearizedIndices;
+    std::tie(std::ignore, linearizedIndices) =
+        memref::getLinearizedMemRefOffsetAndSize(
+            rewriter, loc, srcBits, dstBits,
+            stridedMetadata.getConstifiedMixedOffset(),
+            stridedMetadata.getConstifiedMixedSizes(),
+            stridedMetadata.getConstifiedMixedStrides(),
+            getAsOpFoldResult(adaptor.getIndices()));
+
+    auto numElements = origElements / scale;
+    auto bitCast = rewriter.create<vector::BitCastOp>(
+        loc, VectorType::get(numElements, newElementType),
+        op.getValueToStore());
+
+    rewriter.replaceOpWithNewOp<vector::StoreOp>(
+        op, bitCast.getResult(), adaptor.getBase(),
+        getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices));
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // ConvertVectorLoad
 //===----------------------------------------------------------------------===//
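To make the comment's example concrete, here is the index arithmetic the pattern performs, restated as a worked example: with srcBits = 4 and dstBits = 8, scale = dstBits / srcBits = 2. A memref<4x8xi4> has a row stride of 8 i4 elements, so element (r, c) lives at linear i4 offset r * 8 + c, and the emulated i8 index is

    (r * 8 + c) floordiv 2 = r * 4 + c floordiv 2

which is exactly the affine map affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)> that the new @vector_store_i4 test below checks for.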
@@ -755,7 +820,7 @@ void vector::populateVectorNarrowTypeEmulationPatterns(
     RewritePatternSet &patterns) {
 
   // Populate `vector.*` conversion patterns.
-  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad,
+  patterns.add<ConvertVectorLoad, ConvertVectorMaskedLoad, ConvertVectorStore,
                ConvertVectorTransferRead>(typeConverter, patterns.getContext());
 }
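A usage note, not part of this diff: these populate functions are meant to be driven by a dialect conversion pass. Below is a minimal sketch of that wiring, loosely modeled on MLIR's narrow-type emulation test pass; the pass context, the 8-bit target width, and the legality callback are illustrative assumptions, not code from this commit.

    // Sketch: assumes we are inside a pass's runOnOperation(), with the usual
    // Arith/MemRef/Vector transform headers included.
    MLIRContext *ctx = &getContext();

    // Map sub-byte element types onto a wider storage type (8 bits here).
    arith::NarrowTypeEmulationConverter typeConverter(/*targetBitwidth=*/8);
    memref::populateMemRefNarrowTypeEmulationConversions(typeConverter);

    // vector.store is legal only once its memref operand has a legal type.
    ConversionTarget target(*ctx);
    target.addDynamicallyLegalOp<vector::StoreOp>([&](vector::StoreOp op) {
      return typeConverter.isLegal(op.getBase().getType());
    });

    RewritePatternSet patterns(ctx);
    memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);
    vector::populateVectorNarrowTypeEmulationPatterns(typeConverter, patterns);

    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();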

mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir

Lines changed: 78 additions & 0 deletions
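The RUN lines sit at the top of this file, outside the changed region, so they do not appear in the hunk below. Judging by the CHECK and CHECK32 prefixes, they presumably resemble the following, exercising 8-bit and 32-bit load/store emulation respectively (the exact flag spelling is an assumption):

    // RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=8" %s | FileCheck %s
    // RUN: mlir-opt --test-emulate-narrow-int="memref-load-bitwidth=32" %s | FileCheck %s --check-prefix=CHECK32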
@@ -350,3 +350,81 @@ func.func @vector_extract_cst_maskedload_i4() -> vector<8x8x16xi4> {
 // CHECK32-SAME:   memref<128xi32>, vector<2xi1>, vector<2xi32> into vector<2xi32>
 // CHECK32: %[[BITCAST:.+]] = vector.bitcast %[[LOAD]] : vector<2xi32> to vector<16xi4>
 // CHECK32: %[[SELECT:.+]] = arith.select %[[ORIG_EXT2]], %[[BITCAST]], %[[PASSTHRU]] : vector<16xi1>, vector<16xi4>
+
+// -----
+
+func.func @vector_store_i8(%arg0: vector<8xi8>, %arg1: index, %arg2: index) {
+  %0 = memref.alloc() : memref<4x8xi8>
+  vector.store %arg0, %0[%arg1, %arg2] : memref<4x8xi8>, vector<8xi8>
+  return
+}
+
+// Expect no conversions, i8 is supported.
+// CHECK: func @vector_store_i8
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: vector<8xi8>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<4x8xi8>
+// CHECK: vector.store %[[ARG0]], %[[ALLOC]][%[[ARG1]], %[[ARG2]]] : memref<4x8xi8>, vector<8xi8>
+
+// CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 2 + s1 floordiv 4)>
+// CHECK32: func @vector_store_i8
+// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]]: vector<8xi8>
+// CHECK32-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK32-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<8xi32>
+// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG2]]]
+// CHECK32: %[[VEC_I32:.+]] = vector.bitcast %[[ARG0]] : vector<8xi8> to vector<2xi32>
+// CHECK32: vector.store %[[VEC_I32]], %[[ALLOC]][%[[INDEX]]] : memref<8xi32>, vector<2xi32>
+
+// -----
+
+func.func @vector_store_i4(%arg0: vector<8xi4>, %arg1: index, %arg2: index) {
+  %0 = memref.alloc() : memref<4x8xi4>
+  vector.store %arg0, %0[%arg1, %arg2] : memref<4x8xi4>, vector<8xi4>
+  return
+}
+
+// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
+// CHECK: func @vector_store_i4
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: vector<8xi4>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<16xi8>
+// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG2]]]
+// CHECK: %[[VEC_I8:.+]] = vector.bitcast %[[ARG0]] : vector<8xi4> to vector<4xi8>
+// CHECK: vector.store %[[VEC_I8]], %[[ALLOC]][%[[INDEX]]] : memref<16xi8>, vector<4xi8>
+
+// CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
+// CHECK32: func @vector_store_i4
+// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]]: vector<8xi4>
+// CHECK32-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK32-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK32: %[[ALLOC:.+]] = memref.alloc() : memref<4xi32>
+// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG2]]]
+// CHECK32: %[[VEC_I32:.+]] = vector.bitcast %[[ARG0]] : vector<8xi4> to vector<1xi32>
+// CHECK32: vector.store %[[VEC_I32]], %[[ALLOC]][%[[INDEX]]] : memref<4xi32>, vector<1xi32>
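Reading the two check blocks above side by side shows how the scale drives both shapes: with 8-bit emulation, scale = 8 / 4 = 2, so the vector<8xi4> value is bitcast to vector<4xi8> and the 32-element memref<4x8xi4> flattens to memref<16xi8>; with 32-bit emulation, scale = 32 / 4 = 8, giving vector<1xi32> and memref<4xi32>.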
+
+// -----
+
+func.func @vector_store_i4_dynamic(%arg0: vector<8xi4>, %arg1: index, %arg2: index, %arg3: index, %arg4: index) {
+  %0 = memref.alloc(%arg1, %arg2) : memref<?x?xi4>
+  vector.store %arg0, %0[%arg3, %arg4] : memref<?x?xi4>, vector<8xi4>
+  return
+}
+
+// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)>
+// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 2)>
+// CHECK: func @vector_store_i4_dynamic
+// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: vector<8xi4>
+// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
+// CHECK: %[[SIZE:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG2]]]
+// CHECK: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi8>
+// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG3]], %[[ARG2]], %[[ARG4]]]
+// CHECK: %[[VEC_I8:.+]] = vector.bitcast %[[ARG0]] : vector<8xi4> to vector<4xi8>
+// CHECK: vector.store %[[VEC_I8]], %[[ALLOC]][%[[INDEX]]] : memref<?xi8>, vector<4xi8>
+
+// CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 8)>
+// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 8)>
+// CHECK32: func @vector_store_i4_dynamic
+// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]]: vector<8xi4>
+// CHECK32-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
+// CHECK32-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
+// CHECK32-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
+// CHECK32-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
+// CHECK32: %[[SIZE:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG2]]]
+// CHECK32: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi32>
+// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG3]], %[[ARG2]], %[[ARG4]]]
+// CHECK32: %[[VEC_I32:.+]] = vector.bitcast %[[ARG0]] : vector<8xi4> to vector<1xi32>
+// CHECK32: vector.store %[[VEC_I32]], %[[ALLOC]][%[[INDEX]]] : memref<?xi32>, vector<1xi32>
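In the dynamic-shape test, the same linearization also produces the allocation size: a memref<?x?xi4> with dimensions d0 x d1 holds d0 * d1 i4 values, which pack into (d0 * d1) floordiv 2 bytes, the #[[MAP]] computed for %[[SIZE]] (floordiv 8 in the i32 variant). For d0 = 4 and d1 = 8 this reproduces the 16-byte buffer of the static @vector_store_i4 case.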
