Skip to content

Commit 53e8ff1

Browse files
authored
[MLIR] Fixing the memref linearization size computation for non-packed memref (#138922)
Credit to @krzysz00 who discovered this subtle bug in `MemRefUtils`. The problem is in the `getLinearizedMemRefOffsetAndSize()` utility. In particular, how this subroutine computes the linearized size of a memref is incorrect when given a non-packed memref. ### Background As context, in a packed memref of `memref<8x8xf32>`, we'd compute the size by multiplying the sizes of the dimensions together. This is implemented by composing an affine_map of `affine_map<()[s0, s1] -> (s0 * s1)>` and then computing the size via `%size = affine.apply #map()[%c8, %c8]`. However, this is wrong for a non-packed memref of `memref<8x8xf32, strided<[1024, 1]>>`. Since the previously computed multiplication map only considers the dimension sizes, it would incorrectly conclude that the size of the non-packed memref is 64. ### Solution This PR comes up with a fix such that the linearized size computation takes strides into consideration: it computes the maximum of (dim size * dim stride) over all dimensions. We'd compute the size via the affine_map of `affine_map<()[stride0, size0, size1] -> (stride0 * size0, 1 * size1)>` and then compute the size via `%size = affine.max #map()[%stride0, %size0, %size1]`. In particular, for the new non-packed memref, the size is derived as max(1024\*8, 1\*8) = 8192 (rather than the wrong size 64 produced by the packed-memref equation).
1 parent 9692dff commit 53e8ff1

File tree

5 files changed

+51
-82
lines changed

5 files changed

+51
-82
lines changed

mlir/lib/Dialect/AMDGPU/Transforms/TransferReadToLoad.cpp

Lines changed: 4 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -162,60 +162,20 @@ struct TransferReadLowering final : OpRewritePattern<vector::TransferReadOp> {
162162
stridedMetadata.getConstifiedMixedStrides();
163163
SmallVector<OpFoldResult> sizes = stridedMetadata.getConstifiedMixedSizes();
164164
OpFoldResult offset = stridedMetadata.getConstifiedMixedOffset();
165+
memref::LinearizedMemRefInfo linearizedInfo;
165166
OpFoldResult linearizedIndices;
166-
std::tie(std::ignore, linearizedIndices) =
167+
std::tie(linearizedInfo, linearizedIndices) =
167168
memref::getLinearizedMemRefOffsetAndSize(rewriter, loc, elementBitWidth,
168169
elementBitWidth, offset, sizes,
169170
strides, indices);
170171

171-
// TODO(jerryyin): Fix the getLinearizedMemRefOffsetAndSize() function
172-
// Note below doesn't give the correct result for the linearized size.
173-
// Value totalSize = getValueOrCreateConstantIndexOp(
174-
// rewriter, loc, linearizedInfo.linearizedSize);
175-
// It computes the multiplied sizes of all dimensions instead of taking
176-
// the maximum of each dimension size * stride.
177-
SmallVector<AffineExpr> productExpressions;
178-
unsigned sourceRank = cast<ShapedType>(src.getType()).getRank();
179-
180-
SmallVector<AffineExpr> symbols(2 * sourceRank);
181-
SmallVector<Value> offsetValues;
182-
bindSymbolsList(rewriter.getContext(), MutableArrayRef{symbols});
183-
184-
size_t symbolIndex = 0;
185-
for (size_t i = 0; i < sourceRank; ++i) {
186-
AffineExpr strideExpr, sizeExpr;
187-
OpFoldResult stride = strides[i];
188-
OpFoldResult size = sizes[i];
189-
if (auto constantStride = getConstantIntValue(stride)) {
190-
strideExpr = rewriter.getAffineConstantExpr(*constantStride);
191-
} else {
192-
strideExpr = symbols[symbolIndex++];
193-
offsetValues.push_back(
194-
getValueOrCreateConstantIndexOp(rewriter, loc, stride));
195-
}
196-
197-
if (auto constantSize = getConstantIntValue(size)) {
198-
sizeExpr = rewriter.getAffineConstantExpr(*constantSize);
199-
} else {
200-
sizeExpr = symbols[symbolIndex++];
201-
offsetValues.push_back(
202-
getValueOrCreateConstantIndexOp(rewriter, loc, size));
203-
}
204-
205-
productExpressions.push_back(strideExpr * sizeExpr);
206-
}
207-
208-
AffineMap maxMap = AffineMap::get(
209-
/*dimCount=*/0, /*symbolCount=*/symbolIndex, productExpressions,
210-
rewriter.getContext());
211-
Value totalSize =
212-
rewriter.create<affine::AffineMaxOp>(loc, maxMap, offsetValues);
213-
214172
// delta = bufferSize - linearizedOffset
215173
Value vectorSizeOffset =
216174
rewriter.create<arith::ConstantIndexOp>(loc, vectorSize);
217175
Value linearIndex =
218176
getValueOrCreateConstantIndexOp(rewriter, loc, linearizedIndices);
177+
Value totalSize = getValueOrCreateConstantIndexOp(
178+
rewriter, loc, linearizedInfo.linearizedSize);
219179
Value delta = rewriter.create<arith::SubIOp>(loc, totalSize, linearIndex);
220180

221181
// 1) check if delta < vectorSize

mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,6 @@ std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
6666
SmallVector<AffineExpr> symbols(2 * sourceRank);
6767
bindSymbolsList(builder.getContext(), MutableArrayRef{symbols});
6868
AffineExpr addMulMap = builder.getAffineConstantExpr(0);
69-
AffineExpr mulMap = builder.getAffineConstantExpr(1);
7069

7170
SmallVector<OpFoldResult> offsetValues(2 * sourceRank);
7271

@@ -75,18 +74,28 @@ std::pair<LinearizedMemRefInfo, OpFoldResult> getLinearizedMemRefOffsetAndSize(
7574
addMulMap = addMulMap + symbols[offsetIdx] * symbols[offsetIdx + 1];
7675
offsetValues[offsetIdx] = indicesVec[i];
7776
offsetValues[offsetIdx + 1] = strides[i];
78-
79-
mulMap = mulMap * symbols[i];
8077
}
81-
8278
// Adjust linearizedIndices and size by the scale factor (dstBits / srcBits).
8379
int64_t scaler = dstBits / srcBits;
84-
mulMap = mulMap.floorDiv(scaler);
85-
8680
OpFoldResult linearizedIndices = affine::makeComposedFoldedAffineApply(
8781
builder, loc, addMulMap.floorDiv(scaler), offsetValues);
82+
83+
size_t symbolIndex = 0;
84+
SmallVector<OpFoldResult> values;
85+
SmallVector<AffineExpr> productExpressions;
86+
for (unsigned i = 0; i < sourceRank; ++i) {
87+
AffineExpr strideExpr = symbols[symbolIndex++];
88+
values.push_back(strides[i]);
89+
AffineExpr sizeExpr = symbols[symbolIndex++];
90+
values.push_back(sizes[i]);
91+
92+
productExpressions.push_back((strideExpr * sizeExpr).floorDiv(scaler));
93+
}
94+
AffineMap maxMap = AffineMap::get(
95+
/*dimCount=*/0, /*symbolCount=*/symbolIndex, productExpressions,
96+
builder.getContext());
8897
OpFoldResult linearizedSize =
89-
affine::makeComposedFoldedAffineApply(builder, loc, mulMap, sizes);
98+
affine::makeComposedFoldedAffineMax(builder, loc, maxMap, values);
9099

91100
// Adjust baseOffset by the scale factor (dstBits / srcBits).
92101
AffineExpr s0;

mlir/test/Dialect/AMDGPU/transfer-read-to-load.mlir

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,9 @@ func.func @transfer_to_maskedload_fatrawbuffer_f16(%mem : memref<8x8xf16, #amdgp
5252

5353
// -----
5454

55-
// CHECK: #map = affine_map<()[s0, s1, s2] -> (s0 * s1 + s2)>
56-
// CHECK: #map1 = affine_map<()[s0, s1, s2] -> (s0 * s1, s2)>
57-
// CHECK-LABEL: func @transfer_to_maskedload_fatrawbuffer_dynamic_i8(
55+
// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1, s2] -> (s0 * s1 + s2)>
56+
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> (s0 * s1, s2)>
57+
// CHECK: func @transfer_to_maskedload_fatrawbuffer_dynamic_i8(
5858
// CHECK-SAME: %[[ARG0:.*]]: memref<?x?xi8, #amdgpu.address_space<fat_raw_buffer>>
5959
// CHECK-SAME: %[[ARG1:.*]]: index, %[[ARG2:.*]]: index
6060
// CHECK-SAME: %[[ARG3:.*]]: vector<4xi1>
@@ -64,14 +64,14 @@ func.func @transfer_to_maskedload_fatrawbuffer_dynamic_i8(%mem : memref<?x?xi8,
6464
return %res : vector<4xi8>
6565
}
6666

67-
// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<4xi8>
68-
// CHECK: %[[C0:.*]] = arith.constant 0 : index
69-
// CHECK: %[[C4:.*]] = arith.constant 4 : index
70-
// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG0]]
71-
// CHECK: %[[LINEAR:.*]] = affine.apply #map()[%[[ARG1]], %[[STRIDES]]#0, %[[ARG2]]]
72-
// CHECK: %[[SIZE:.*]] = affine.max #map1()[%[[STRIDES]]#0, %[[SIZES]]#0, %[[SIZES]]#1]
73-
// CHECK: %[[IF:.*]] = scf.if
74-
// CHECK: return
67+
// CHECK: %[[CST:.*]] = arith.constant dense<0> : vector<4xi8>
68+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
69+
// CHECK: %[[C4:.*]] = arith.constant 4 : index
70+
// CHECK: %[[BASE:.*]], %[[OFFSET:.*]], %[[SIZES:.*]]:2, %[[STRIDES:.*]]:2 = memref.extract_strided_metadata %[[ARG0]]
71+
// CHECK-DAG: %[[SIZE:.*]] = affine.max #[[MAP1]]()[%[[STRIDES]]#0, %[[SIZES]]#0, %[[SIZES]]#1]
72+
// CHECK-DAG: %[[LINEAR:.*]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[STRIDES]]#0, %[[ARG2]]]
73+
// CHECK: %[[IF:.*]] = scf.if
74+
// CHECK: return
7575

7676
// -----
7777

mlir/test/Dialect/MemRef/emulate-narrow-type.mlir

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -104,15 +104,15 @@ func.func @memref_load_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %a
104104
%1 = memref.load %0[%arg2, %arg3] : memref<?x?xi4>
105105
return %1 : i4
106106
}
107-
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)>
107+
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2, s0 floordiv 2)>
108108
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 2)>
109109
// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * 4 + s2 * 4 - ((s2 + s0 * s1) floordiv 2) * 8)>
110110
// CHECK: func @memref_load_i4_dynamic(
111111
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index
112112
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
113113
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
114114
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
115-
// CHECK: %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
115+
// CHECK: %[[SIZE:.+]] = affine.max #[[MAP0]]()[%[[ARG1]], %[[ARG0]]]
116116
// CHECK: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]])
117117
// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
118118
// CHECK: %[[LOAD:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
@@ -122,15 +122,15 @@ func.func @memref_load_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %a
122122
// CHECK: %[[TRUNC:.+]] = arith.trunci %[[SHIFTRT]] : i8 to i4
123123
// CHECK: return %[[TRUNC]]
124124

125-
// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 8)>
125+
// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 8, s0 floordiv 8)>
126126
// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 8)>
127127
// CHECK32-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * 4 + s2 * 4 - ((s2 + s0 * s1) floordiv 8) * 32)>
128128
// CHECK32: func @memref_load_i4_dynamic(
129129
// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]]: index
130130
// CHECK32-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
131131
// CHECK32-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
132132
// CHECK32-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
133-
// CHECK32: %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
133+
// CHECK32: %[[SIZE:.+]] = affine.max #[[MAP0]]()[%[[ARG1]], %[[ARG0]]]
134134
// CHECK32: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]])
135135
// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
136136
// CHECK32: %[[LOAD:.+]] = memref.load %[[ALLOC]][%[[INDEX]]]
@@ -399,7 +399,7 @@ func.func @memref_store_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %
399399
memref.store %arg4, %0[%arg2, %arg3] : memref<?x?xi4>
400400
return
401401
}
402-
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)>
402+
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2, s0 floordiv 2)>
403403
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 2)>
404404
// CHECK-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * 4 + s2 * 4 - ((s2 + s0 * s1) floordiv 2) * 8)>
405405
// CHECK: func @memref_store_i4_dynamic(
@@ -408,7 +408,7 @@ func.func @memref_store_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %
408408
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
409409
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
410410
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: i4
411-
// CHECK-DAG: %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
411+
// CHECK-DAG: %[[SIZE:.+]] = affine.max #[[MAP0]]()[%[[ARG1]], %[[ARG0]]]
412412
// CHECK-DAG: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi8>
413413
// CHECK-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG4]] : i4 to i8
414414
// CHECK-DAG: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
@@ -423,7 +423,7 @@ func.func @memref_store_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %
423423
// CHECK: %[[WRITE_RMW:.+]] = memref.atomic_rmw ori %[[SHIFTED_VAL]], %[[ALLOC]][%[[INDEX]]] : (i8, memref<?xi8>) -> i8
424424
// CHECK: return
425425

426-
// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 8)>
426+
// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 8, s0 floordiv 8)>
427427
// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 8)>
428428
// CHECK32-DAG: #[[MAP2:.+]] = affine_map<()[s0, s1, s2] -> ((s0 * s1) * 4 + s2 * 4 - ((s2 + s0 * s1) floordiv 8) * 32)>
429429
// CHECK32: func @memref_store_i4_dynamic(
@@ -432,7 +432,7 @@ func.func @memref_store_i4_dynamic(%arg0: index, %arg1 : index, %arg2 : index, %
432432
// CHECK32-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
433433
// CHECK32-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
434434
// CHECK32-SAME: %[[ARG4:[a-zA-Z0-9]+]]: i4
435-
// CHECK32-DAG: %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
435+
// CHECK32-DAG: %[[SIZE:.+]] = affine.max #[[MAP0]]()[%[[ARG1]], %[[ARG0]]]
436436
// CHECK32-DAG: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi32>
437437
// CHECK32-DAG: %[[EXTUI:.+]] = arith.extui %[[ARG4]] : i4 to i32
438438
// CHECK32-DAG: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]

mlir/test/Dialect/Vector/vector-emulate-narrow-type.mlir

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -58,27 +58,27 @@ func.func @vector_load_i4_dynamic(%arg0 : index, %arg1 : index, %arg2 : index, %
5858
%1 = vector.load %0[%arg2, %arg3] : memref<?x?xi4>, vector<8xi4>
5959
return %1 : vector<8xi4>
6060
}
61-
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)>
61+
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2, s0 floordiv 2)>
6262
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 2)>
6363
// CHECK: func.func @vector_load_i4_dynamic(
6464
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
6565
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
6666
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
6767
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index
68-
// CHECK: %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
68+
// CHECK: %[[SIZE:.+]] = affine.max #[[MAP0]]()[%[[ARG1]], %[[ARG0]]]
6969
// CHECK: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi8>
7070
// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
7171
// CHECK: %[[VEC:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<?xi8>, vector<4xi8>
7272
// CHECK: %[[VEC_I4:.+]] = vector.bitcast %[[VEC]] : vector<4xi8> to vector<8xi4>
7373

74-
// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 8)>
74+
// CHECK32-DAG: #[[MAP0:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 8, s0 floordiv 8)>
7575
// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 8)>
7676
// CHECK32: func.func @vector_load_i4_dynamic(
7777
// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: index
7878
// CHECK32-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: index
7979
// CHECK32-SAME: %[[ARG2:[a-zA-Z0-9_]+]]: index
8080
// CHECK32-SAME: %[[ARG3:[a-zA-Z0-9_]+]]: index
81-
// CHECK32: %[[SIZE:.+]] = affine.apply #[[MAP0]]()[%[[ARG0]], %[[ARG1]]]
81+
// CHECK32: %[[SIZE:.+]] = affine.max #[[MAP0]]()[%[[ARG1]], %[[ARG0]]]
8282
// CHECK32: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi32>
8383
// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG2]], %[[ARG1]], %[[ARG3]]]
8484
// CHECK32: %[[VEC:.+]] = vector.load %[[ALLOC]][%[[INDEX]]] : memref<?xi32>, vector<1xi32>
@@ -450,29 +450,29 @@ func.func @vector_store_i4_dynamic(%arg0: vector<8xi4>, %arg1: index, %arg2: ind
450450
return
451451
}
452452

453-
// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)>
453+
// CHECK-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2, s0 floordiv 2)>
454454
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 2)>
455455
// CHECK: func @vector_store_i4_dynamic
456456
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: vector<8xi4>
457457
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
458458
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
459459
// CHECK-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
460460
// CHECK-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
461-
// CHECK: %[[SIZE:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG2]]]
461+
// CHECK: %[[SIZE:.+]] = affine.max #[[MAP]]()[%[[ARG2]], %[[ARG1]]]
462462
// CHECK: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi8>
463463
// CHECK: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG3]], %[[ARG2]], %[[ARG4]]]
464464
// CHECK: %[[VEC_I8:.+]] = vector.bitcast %[[ARG0]] : vector<8xi4> to vector<4xi8>
465465
// CHECK: vector.store %[[VEC_I8:.+]], %[[ALLOC:.+]][%[[INDEX:.+]]] : memref<?xi8>, vector<4xi8>
466466

467-
// CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 8)>
467+
// CHECK32-DAG: #[[MAP:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 8, s0 floordiv 8)>
468468
// CHECK32-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1, s2] -> ((s2 + s0 * s1) floordiv 8)>
469469
// CHECK32: func @vector_store_i4_dynamic
470470
// CHECK32-SAME: %[[ARG0:[a-zA-Z0-9]+]]: vector<8xi4>
471471
// CHECK32-SAME: %[[ARG1:[a-zA-Z0-9]+]]: index
472472
// CHECK32-SAME: %[[ARG2:[a-zA-Z0-9]+]]: index
473473
// CHECK32-SAME: %[[ARG3:[a-zA-Z0-9]+]]: index
474474
// CHECK32-SAME: %[[ARG4:[a-zA-Z0-9]+]]: index
475-
// CHECK32: %[[SIZE:.+]] = affine.apply #[[MAP]]()[%[[ARG1]], %[[ARG2]]]
475+
// CHECK32: %[[SIZE:.+]] = affine.max #[[MAP]]()[%[[ARG2]], %[[ARG1]]]
476476
// CHECK32: %[[ALLOC:.+]] = memref.alloc(%[[SIZE]]) : memref<?xi32>
477477
// CHECK32: %[[INDEX:.+]] = affine.apply #[[MAP1]]()[%[[ARG3]], %[[ARG2]], %[[ARG4]]]
478478
// CHECK32: %[[VEC_I8:.+]] = vector.bitcast %[[ARG0]] : vector<8xi4> to vector<1xi32>
@@ -537,7 +537,7 @@ func.func @vector_maskedstore_i4(
537537
// CHECK: #[[$ATTR_10:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
538538
// CHECK: #[[$ATTR_11:.+]] = affine_map<()[s0] -> (s0 ceildiv 2)>
539539

540-
// CHECK-LABEL: func.func @vector_maskedstore_i4(
540+
// CHECK: func.func @vector_maskedstore_i4(
541541
// CHECK-SAME: %[[IDX_1:[a-zA-Z0-9]+]]: index,
542542
// CHECK-SAME: %[[IDX_2:[a-zA-Z0-9]+]]: index,
543543
// CHECK-SAME: %[[NUM_EL_TO_STORE:[a-zA-Z0-9]+]]: index,
@@ -557,7 +557,7 @@ func.func @vector_maskedstore_i4(
557557
// CHECK32: #[[$ATTR_17:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
558558
// CHECK32: #[[$ATTR_18:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
559559

560-
// CHECK32-LABEL: func.func @vector_maskedstore_i4(
560+
// CHECK32: func.func @vector_maskedstore_i4(
561561
// CHECK32-SAME: %[[IDX_1:[a-zA-Z0-9]+]]: index,
562562
// CHECK32-SAME: %[[IDX_2:[a-zA-Z0-9]+]]: index,
563563
// CHECK32-SAME: %[[NUM_EL_TO_STORE:[a-zA-Z0-9]+]]: index,
@@ -623,7 +623,7 @@ func.func @vector_maskedstore_i4_constant_mask(
623623
}
624624

625625
// CHECK: #[[$ATTR_12:.+]] = affine_map<()[s0, s1] -> (s0 * 4 + s1 floordiv 2)>
626-
// CHECK-LABEL: func.func @vector_maskedstore_i4_constant_mask(
626+
// CHECK: func.func @vector_maskedstore_i4_constant_mask(
627627
// CHECK-SAME: %[[IDX_1:[a-zA-Z0-9]+]]: index,
628628
// CHECK-SAME: %[[IDX_2:[a-zA-Z0-9]+]]: index,
629629
// CHECK-SAME: %[[VAL_TO_STORE:[a-zA-Z0-9]+]]: vector<8xi4>) {
@@ -639,7 +639,7 @@ func.func @vector_maskedstore_i4_constant_mask(
639639
// CHECK: vector.maskedstore %[[ALLOC]]{{\[}}%[[LIDX]]], %[[NEW_MASK]], %[[NEW_VAL]] : memref<12xi8>, vector<4xi1>, vector<4xi8>
640640

641641
// CHECK32: #[[$ATTR_20:.+]] = affine_map<()[s0, s1] -> (s0 + s1 floordiv 8)>
642-
// CHECK32-LABEL: func.func @vector_maskedstore_i4_constant_mask(
642+
// CHECK32: func.func @vector_maskedstore_i4_constant_mask(
643643
// CHECK32-SAME: %[[IDX_1:[a-zA-Z0-9]+]]: index,
644644
// CHECK32-SAME: %[[IDX_2:[a-zA-Z0-9]+]]: index,
645645
// CHECK32-SAME: %[[VAL_TO_STORE:[a-zA-Z0-9]+]]: vector<8xi4>) {

0 commit comments

Comments
 (0)