@@ -22,27 +22,15 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &ss,
   SmallVector<int64_t> innerAxis = layoutCache.getInnerAxis();
   SmallVector<OpFoldResult> tileSizes = layoutCache.getTileSizes();
   ss << "[";
-  for (size_t i = 0; i < outerAxis.size(); ++i) {
-    if (i != 0) {
-      ss << ", ";
-    }
-    ss << outerAxis[i];
-  }
-  for (size_t i = 0; i < innerAxis.size(); ++i) {
-    ss << (i == 0 ? "; " : ", ");
-    ss << innerAxis[i];
+  llvm::interleaveComma(outerAxis, ss);
+  if (!innerAxis.empty()) {
+    ss << "; ";
+    llvm::interleaveComma(innerAxis, ss);
   }
   ss << "]";
   if (!tileSizes.empty()) {
     ss << "; {";
-    for (size_t i = 0; i < tileSizes.size(); ++i) {
-      if (i != 0) {
-        ss << ", ";
-      }
-      if (getConstantIntValue(tileSizes[i]).has_value()) {
-        ss << *getConstantIntValue(tileSizes[i]);
-      }
-    }
+    llvm::interleaveComma(tileSizes, ss);
     ss << "}";
   }
   return ss;
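
Note: llvm::interleaveComma (from llvm/ADT/STLExtras.h) emits the separator
only between elements, so the empty-range and first-element special cases in
the deleted loops disappear. One behavioral nuance visible in the diff: the
old tileSizes loop printed only entries with a constant value, while
interleaveComma now streams every OpFoldResult. A minimal standalone sketch
of the helper (values are illustrative, not from this patch):

    #include "llvm/ADT/STLExtras.h"
    #include "llvm/ADT/SmallVector.h"
    #include "llvm/Support/raw_ostream.h"

    int main() {
      llvm::SmallVector<int64_t> outerAxis = {0, 1, 2};
      // Prints "0, 1, 2": no separator before the first or after the last.
      llvm::interleaveComma(outerAxis, llvm::outs());
      llvm::outs() << "\n";
      return 0;
    }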
@@ -58,11 +46,11 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &ss,
                               const OperatorLayout &opLayout) {
   for (auto &&[idx, layoutCache] :
        llvm::enumerate(opLayout.getSupportedInputLayouts())) {
-    ss << "input " << idx << "'s layoutCache: " << layoutCache << "\n";
+    ss << "input " << idx << "'s layout: " << layoutCache << "\n";
   }
   for (auto &&[idx, layoutCache] :
        llvm::enumerate(opLayout.getSupportedOutputLayouts())) {
-    ss << "output " << idx << "'s layoutCache: " << layoutCache << "\n";
+    ss << "output " << idx << "'s layout: " << layoutCache << "\n";
   }
   return ss;
 }
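
Note: combined with the layout printer above, an OperatorLayout now prints
along these lines (the axes and tile sizes below are made up for
illustration):

    input 0's layout: [0, 1; 1]; {32}
    input 1's layout: [0, 1]
    output 0's layout: [0, 1; 0, 1]; {32, 32}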
@@ -156,15 +144,15 @@ inferTargetLayout(TensorLayout layoutBase,
 }
 
 static size_t getTargetInputIdx(ArrayRef<TensorLayout> curInputLayouts) {
-  for (auto i = 0; i < curInputLayouts.size(); ++i) {
+  for (size_t i = 0; i < curInputLayouts.size(); ++i) {
     if (!curInputLayouts[i].isPlainLayout()) {
       return i;
     }
   }
   return 0;
 }
 
-static bool supportedContractionOpList(linalg::LinalgOp &linalgOp) {
+static bool supportedContractionNamedOpList(linalg::LinalgOp &linalgOp) {
   if (isa<linalg::MatmulOp, linalg::MatmulTransposeAOp,
           linalg::MatmulTransposeBOp, linalg::BatchMatmulOp,
           linalg::BatchMatmulTransposeAOp, linalg::BatchMatmulTransposeBOp>(
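
Note: the rename to supportedContractionNamedOpList makes the predicate's
scope explicit: it matches only this fixed list of named (batch-)matmul ops,
not every op satisfying linalg's ContractionOpInterface (a linalg.generic
expressing a matmul passes the interface check but not this list). The
for-loop change from auto (deduced as int) to size_t also removes a
signed/unsigned comparison against curInputLayouts.size().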
@@ -211,7 +199,7 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) {
     // ------ Get Current Op's Suggested Layout & Do Propagation ------
     IRRewriter rewriter(linalgOp);
     // TODO: extend to packed/vnni matmul ops
-    if (supportedContractionOpList(linalgOp)) {
+    if (supportedContractionNamedOpList(linalgOp)) {
       // get input and output rank
       auto ARank = cast<ShapedType>(linalgOp.getDpsInputs()[0].getType())
                        .getShape()
@@ -253,7 +241,8 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) {
                                          rewriter.getIndexAttr(iin)});
       OperatorLayout suggestedLayout({ALayout, BLayout}, {CLayout});
       layoutCache[linalgOp] = suggestedLayout;
-    } else if (!mlir::linalg::isaContractionOpInterface(linalgOp)) {
+    } else if (!mlir::linalg::isaContractionOpInterface(linalgOp) &&
+               !supportedContractionNamedOpList(linalgOp)) {
       SmallVector<TensorLayout> inputLayouts, outputLayouts;
       size_t targetIdx = getTargetInputIdx(curInputLayouts);
       // TODO(yifei): wisely choose the input format basis
@@ -345,11 +334,12 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) {
 
 namespace utils {
 bool isPackableNamedOp(Operation *op) {
-  if ((isa<linalg::LinalgOp>(op) &&
-       !mlir::linalg::isaContractionOpInterface(
-           dyn_cast<linalg::LinalgOp>(op)) &&
-       !isa<linalgx::Mm4DVnniOp>(op)) ||
-      isa<tensor::ExpandShapeOp>(op) || isa<tensor::PadOp>(op))
+  if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
+    if (!supportedContractionNamedOpList(linalgOp)) {
+      return true;
+    }
+  } else if (isa<tensor::ExpandShapeOp, tensor::CollapseShapeOp, tensor::PadOp>(
+                 op))
     return true;
   return false;
 }
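
Note: the restructured predicate binds the LinalgOp once via dyn_cast instead
of the old isa-plus-dyn_cast pair, and tensor::CollapseShapeOp is now accepted
alongside ExpandShapeOp and PadOp. A minimal usage sketch, assuming an
enclosing pass with a module handle (the caller names here are hypothetical,
not part of this patch):

    // Hypothetical caller: collect the ops the packing pass may rewrite.
    SmallVector<Operation *> packable;
    module->walk([&](Operation *op) {
      if (utils::isPackableNamedOp(op))
        packable.push_back(op);
    });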