
Commit f2d5e8b

linting
1 parent 66ecff4 commit f2d5e8b

2 files changed: +122, -3 lines

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp (3 additions, 3 deletions)
@@ -429,7 +429,7 @@ namespace {
 struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
   using OpConversionPattern::OpConversionPattern;
 
-  ConvertVectorStore(MLIRContext *context, bool useAtomicWrites)
+  ConvertVectorStore(MLIRContext *context, bool useAtomicWrites)
       : OpConversionPattern<vector::StoreOp>(context),
         useAtomicWrites_(useAtomicWrites) {}
 
@@ -583,8 +583,8 @@ struct ConvertVectorStore final : OpConversionPattern<vector::StoreOp> {
       extractSliceIntoByte(rewriter, loc, valueToStore, 0,
                            frontSubWidthStoreElem, *foldedNumFrontPadElems);
 
-      atomicStore(rewriter, loc, memrefBase, currentDestIndex,
-                  cast<VectorValue>(value), frontMask.getResult());
+      subEmulatedWidthStore(rewriter, loc, memrefBase, currentDestIndex,
+                            cast<VectorValue>(value), frontMask.getResult());
     }
 
     if (currentSourceIndex >= origElements) {
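
For context: this pattern emulates sub-byte element stores (e.g. vector<7xi2>) on a byte-addressed memref by splitting each store into an optional leading partial-byte RMW, a run of full-byte stores, and an optional trailing partial-byte RMW; the rename reflects that the sub-width write need not be atomic when atomic-store=false. The standalone C++ sketch below illustrates only that splitting arithmetic; planStore and its fields are hypothetical names, not the MLIR implementation.

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Hypothetical sketch of the front/full/back split the emulation performs
// for a narrow-element vector.store on an i8 memref. Not the MLIR code.
struct StorePlan {
  int64_t firstByte;  // byte index of the first (possibly partial) write
  int64_t frontElems; // elements handled by the leading partial-byte RMW
  int64_t fullBytes;  // bytes written with plain full-width stores
  int64_t backElems;  // elements handled by the trailing partial-byte RMW
};

StorePlan planStore(int64_t startElem, int64_t numElems, int64_t elemBits) {
  const int64_t perByte = 8 / elemBits; // e.g. 4 i2 elements per byte
  StorePlan p{};
  p.firstByte = startElem / perByte;
  const int64_t pad = startElem % perByte; // slots already used up front
  p.frontElems = pad == 0 ? 0 : std::min(numElems, perByte - pad);
  const int64_t rest = numElems - p.frontElems;
  p.fullBytes = rest / perByte;
  p.backElems = rest % perByte;
  return p;
}

int main() {
  // vector<7xi2> at linear element offset 7, as in the @vector_store_i2_rmw
  // test below: expect a 1-element front RMW, 1 full byte, 2-element back RMW.
  const StorePlan p = planStore(7, 7, 2);
  std::printf("byte %lld: front=%lld full=%lld back=%lld\n",
              (long long)p.firstByte, (long long)p.frontElems,
              (long long)p.fullBytes, (long long)p.backElems);
  return 0;
}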
New test file (119 additions, 0 deletions)

@@ -0,0 +1,119 @@
// RUN: mlir-opt --test-emulate-narrow-int="arith-compute-bitwidth=1 memref-load-bitwidth=8 atomic-store=false" --cse --split-input-file %s | FileCheck %s

// TODO: remove memref.alloc() in the tests to eliminate noise.
// memref.alloc exists here because sub-byte vector data types such as i2
// are currently not supported as input arguments.
func.func @vector_store_i2_const_index_two_rmw(%arg0: vector<3xi2>) {
  %0 = memref.alloc() : memref<3x3xi2>
  %c0 = arith.constant 0 : index
  %c2 = arith.constant 2 : index
  vector.store %arg0, %0[%c2, %c0] : memref<3x3xi2>, vector<3xi2>
  return
}
// The store covers bits [12, 18), i.e. bytes [1, 2] of the 3 bytes in total,
// so both bytes need an RMW sequence.

// CHECK: func @vector_store_i2_const_index_two_rmw(
// CHECK-SAME: %[[ARG0:.+]]: vector<3xi2>)
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<3xi8>
// CHECK: %[[C1:.+]] = arith.constant 1 : index

// Part 1 RMW sequence:
// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, true, true]>
// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
// CHECK-SAME: {offsets = [0], sizes = [2], strides = [1]} : vector<3xi2> to vector<2xi2>
// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]]
// CHECK-SAME: {offsets = [2], strides = [1]} : vector<2xi2> into vector<4xi2>
// CHECK: %[[LOAD:.+]] = vector.load
// CHECK: %[[UPCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi8> to vector<4xi2>
// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[UPCAST]]
// CHECK: %[[DOWNCAST:.+]] = vector.bitcast %[[SELECT]]
// CHECK: vector.store %[[DOWNCAST]], %[[ALLOC]][%[[C1]]]

// Part 2 RMW sequence:
// CHECK: %[[OFFSET:.+]] = arith.addi %[[C1]], %[[C1]] : index
// CHECK: %[[EXTRACT2:.+]] = vector.extract_strided_slice %[[ARG0]]
// CHECK-SAME: {offsets = [2], sizes = [1], strides = [1]} : vector<3xi2> to vector<1xi2>
// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT2]], %[[CST0]]
// CHECK-SAME: {offsets = [0], strides = [1]} : vector<1xi2> into vector<4xi2>
// CHECK: %[[CST1:.+]] = arith.constant dense<[true, false, false, false]> : vector<4xi1>
// CHECK: %[[LOAD2:.+]] = vector.load
// CHECK: %[[UPCAST2:.+]] = vector.bitcast %[[LOAD2]] : vector<1xi8> to vector<4xi2>
// CHECK: %[[SELECT2:.+]] = arith.select %[[CST1]], %[[INSERT2]], %[[UPCAST2]]
// CHECK: %[[DOWNCAST2:.+]] = vector.bitcast %[[SELECT2]]
// CHECK: vector.store %[[DOWNCAST2]], %[[ALLOC]][%[[OFFSET]]]


// -----

func.func @vector_store_i2_rmw(%arg0: vector<7xi2>) {
  %0 = memref.alloc() : memref<3x7xi2>
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  vector.store %arg0, %0[%c1, %c0] : memref<3x7xi2>, vector<7xi2>
  return
}

// CHECK: func @vector_store_i2_rmw(
// CHECK-SAME: %[[ARG0:.+]]:
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<6xi8>
// CHECK: %[[C1:.+]] = arith.constant 1 : index
// CHECK: %[[CST:.+]] = arith.constant dense<[false, false, false, true]>
// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
// CHECK: %[[EXTRACT:.+]] = vector.extract_strided_slice %[[ARG0]]
// CHECK-SAME: {offsets = [0], sizes = [1], strides = [1]}
// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[EXTRACT]], %[[CST0]]
// CHECK-SAME: {offsets = [3], strides = [1]}
// First sub-width RMW:
// CHECK: %[[LOAD:.+]] = vector.load %[[ALLOC]][%[[C1]]]
// CHECK: %[[UPCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi8> to vector<4xi2>
// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[UPCAST]]
// CHECK: %[[DOWNCAST:.+]] = vector.bitcast %[[SELECT]]
// CHECK: vector.store %[[DOWNCAST]], %[[ALLOC]][%[[C1]]]

// Full-width store:
// CHECK: %[[INDEX:.+]] = arith.addi %[[C1]], %[[C1]]
// CHECK: %[[EXTRACT1:.+]] = vector.extract_strided_slice %[[ARG0]]
// CHECK-SAME: {offsets = [1], sizes = [4], strides = [1]}
// CHECK: %[[BITCAST:.+]] = vector.bitcast %[[EXTRACT1]]
// CHECK: vector.store %[[BITCAST]], %[[ALLOC]][%[[INDEX]]]

// Second sub-width RMW:
// CHECK: %[[INDEX2:.+]] = arith.addi %[[INDEX]], %[[C1]]
// CHECK: %[[EXTRACT2:.+]] = vector.extract_strided_slice %[[ARG0]]
// CHECK-SAME: {offsets = [5], sizes = [2], strides = [1]}
// CHECK: %[[INSERT2:.+]] = vector.insert_strided_slice %[[EXTRACT2]]
// CHECK-SAME: {offsets = [0], strides = [1]}
// CHECK: %[[CST1:.+]] = arith.constant dense<[true, true, false, false]>
// CHECK: %[[LOAD1:.+]] = vector.load %[[ALLOC]][%[[INDEX2]]]
// CHECK: %[[UPCAST1:.+]] = vector.bitcast %[[LOAD1]]
// CHECK: %[[SELECT1:.+]] = arith.select %[[CST1]], %[[INSERT2]], %[[UPCAST1]]
// CHECK: %[[DOWNCAST1:.+]] = vector.bitcast %[[SELECT1]]
// CHECK: vector.store %[[DOWNCAST1]], %[[ALLOC]][%[[INDEX2]]]

// -----

func.func @vector_store_i2_single_rmw(%arg0: vector<1xi2>) {
  %0 = memref.alloc() : memref<4x1xi2>
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  vector.store %arg0, %0[%c1, %c0] : memref<4x1xi2>, vector<1xi2>
  return
}

// In this test, only one RMW store is emitted.
// CHECK: func @vector_store_i2_single_rmw(
// CHECK-SAME: %[[ARG0:.+]]: vector<1xi2>)
// CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<1xi8>
// CHECK: %[[C0:.+]] = arith.constant 0 : index
// CHECK: %[[CST:.+]] = arith.constant dense<[false, true, false, false]>
// CHECK: %[[CST0:.+]] = arith.constant dense<0> : vector<4xi2>
// CHECK: %[[INSERT:.+]] = vector.insert_strided_slice %[[ARG0]], %[[CST0]]
// CHECK-SAME: {offsets = [1], strides = [1]} : vector<1xi2> into vector<4xi2>
// CHECK: %[[LOAD:.+]] = vector.load %[[ALLOC]][%[[C0]]] : memref<1xi8>, vector<1xi8>
// CHECK: %[[UPCAST:.+]] = vector.bitcast %[[LOAD]] : vector<1xi8> to vector<4xi2>
// CHECK: %[[SELECT:.+]] = arith.select %[[CST]], %[[INSERT]], %[[UPCAST]]
// CHECK: %[[DOWNCAST:.+]] = vector.bitcast %[[SELECT]]
// CHECK: vector.store %[[DOWNCAST]], %[[ALLOC]][%[[C0]]]
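
The CHECK constants in all three tests follow from where the store's first element lands within a byte: the select mask is true exactly at the element slots being overwritten. Below is a small standalone check of that arithmetic, under the same assumptions as the sketch above (illustrative names only, not mlir-opt output).

#include <cstdint>
#include <cstdio>

// Worked check of the byte/slot arithmetic behind the three tests above,
// for i2 elements packed four per i8 byte. Names are illustrative only.
int main() {
  const int64_t perByte = 4; // 8 bits / 2 bits per i2 element

  struct Case { const char *name; int64_t startElem; };
  const Case cases[] = {
      {"two_rmw (3 elems at [2,0] of 3x3xi2, start 2*3=6)", 6},
      {"rmw (7 elems at [1,0] of 3x7xi2, start 1*7=7)", 7},
      {"single_rmw (1 elem at [1,0] of 4x1xi2, start 1)", 1},
  };

  for (const Case &c : cases) {
    const int64_t byte = c.startElem / perByte; // first byte touched
    const int64_t slot = c.startElem % perByte; // first slot overwritten
    // two_rmw    -> byte 1, slot 2: mask [false, false, true, true]
    // rmw        -> byte 1, slot 3: mask [false, false, false, true]
    // single_rmw -> byte 0, slot 1: mask [false, true, false, false]
    std::printf("%s: byte %lld, first slot %lld\n", c.name,
                (long long)byte, (long long)slot);
  }
  return 0;
}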
