Skip to content

Commit 5f5b3bb

Browse files
authored
[mlir][ArmSME] Add rewrites to swap extract of extend (#80407)
In mixed matmul lowering (e.g., i8 to i32) we're seeing the following sequence:

  %0 = arith.extsi %src : vector<4x[8]xi8> to vector<4x[8]xi32>
  %1 = vector.extract %0[0] : vector<[8]xi32> from vector<4x[8]xi32>
  %lhs = vector.scalable.extract %1[0] : vector<[4]xi32> from vector<[8]xi32>
  ... (same for rhs)
  %2 = vector.outerproduct %lhs, %rhs, %acc : vector<[4]xi32>, vector<[4]xi32>
  // x4 chained by accumulator

This chain of 4 outer products can be fused into a single 4-way widening variant, but the pass doesn't match on the IR, as it expects the source of the inputs to be an extend and it can't look through the extracts.

This patch fixes this with two rewrites that swap extract(extend) into extend(extract).

Related to #78975, #79288.
1 parent 8e00fc3 commit 5f5b3bb

File tree

2 files changed

+197
-1
lines changed

2 files changed

+197
-1
lines changed

mlir/lib/Dialect/ArmSME/Transforms/OuterProductFusion.cpp

Lines changed: 103 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,104 @@ class OuterProductFusion2Way
261261
}
262262
};
263263

264+
// Rewrites: vector.extract(arith.extend) -> arith.extend(vector.extract).
265+
//
266+
// This transforms IR like:
267+
// %0 = arith.extsi %src : vector<4x[8]xi8> to vector<4x[8]xi32>
268+
// %1 = vector.extract %0[0] : vector<[8]xi32> from vector<4x[8]xi32>
269+
// Into:
270+
// %0 = vector.extract %src[0] : vector<[8]xi8> from vector<4x[8]xi8>
271+
// %1 = arith.extsi %0 : vector<[8]xi8> to vector<[8]xi32>
272+
//
273+
// This enables outer product fusion in the `-arm-sme-outer-product-fusion`
274+
// pass when the result is the input to an outer product.
275+
struct SwapVectorExtractOfArithExtend
276+
: public OpRewritePattern<vector::ExtractOp> {
277+
using OpRewritePattern::OpRewritePattern;
278+
279+
LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
280+
PatternRewriter &rewriter) const override {
281+
VectorType resultType = llvm::dyn_cast<VectorType>(extractOp.getType());
282+
if (!resultType)
283+
return rewriter.notifyMatchFailure(extractOp,
284+
"extracted type is not a vector type");
285+
286+
auto numScalableDims = llvm::count(resultType.getScalableDims(), true);
287+
if (numScalableDims != 1)
288+
return rewriter.notifyMatchFailure(
289+
extractOp, "extracted type is not a 1-D scalable vector type");
290+
291+
auto *extendOp = extractOp.getVector().getDefiningOp();
292+
if (!isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(
293+
extendOp))
294+
return rewriter.notifyMatchFailure(extractOp,
295+
"extract not from extend op");
296+
297+
auto loc = extractOp.getLoc();
298+
StringAttr extendOpName = extendOp->getName().getIdentifier();
299+
Value extendSource = extendOp->getOperand(0);
300+
301+
// Create new extract from source of extend.
302+
Value newExtract = rewriter.create<vector::ExtractOp>(
303+
loc, extendSource, extractOp.getMixedPosition());
304+
305+
// Extend new extract to original result type.
306+
Operation *newExtend =
307+
rewriter.create(loc, extendOpName, Value(newExtract), resultType);
308+
309+
rewriter.replaceOp(extractOp, newExtend);
310+
311+
return success();
312+
}
313+
};
314+
315+
// Same as above, but for vector.scalable.extract.
316+
//
317+
// This transforms IR like:
318+
// %0 = arith.extsi %src : vector<[8]xi8> to vector<[8]xi32>
319+
// %1 = vector.scalable.extract %0[0] : vector<[4]xi32> from vector<[8]xi32>
320+
// Into:
321+
// %0 = vector.scalable.extract %src[0] : vector<[4]xi8> from vector<[8]xi8>
322+
// %1 = arith.extsi %0 : vector<[4]xi8> to vector<[4]xi32>
323+
//
324+
// This enables outer product fusion in the `-arm-sme-outer-product-fusion`
325+
// pass when the result is the input to an outer product.
326+
struct SwapVectorScalableExtractOfArithExtend
327+
: public OpRewritePattern<vector::ScalableExtractOp> {
328+
using OpRewritePattern::OpRewritePattern;
329+
330+
LogicalResult matchAndRewrite(vector::ScalableExtractOp extractOp,
331+
PatternRewriter &rewriter) const override {
332+
auto *extendOp = extractOp.getSource().getDefiningOp();
333+
if (!isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(
334+
extendOp))
335+
return rewriter.notifyMatchFailure(extractOp,
336+
"extract not from extend op");
337+
338+
auto loc = extractOp.getLoc();
339+
VectorType resultType = extractOp.getResultVectorType();
340+
341+
Value extendSource = extendOp->getOperand(0);
342+
StringAttr extendOpName = extendOp->getName().getIdentifier();
343+
VectorType extendSourceVectorType =
344+
cast<VectorType>(extendSource.getType());
345+
346+
// Create new extract from source of extend.
347+
VectorType extractResultVectorType =
348+
resultType.clone(extendSourceVectorType.getElementType());
349+
Value newExtract = rewriter.create<vector::ScalableExtractOp>(
350+
loc, extractResultVectorType, extendSource, extractOp.getPos());
351+
352+
// Extend new extract to original result type.
353+
Operation *newExtend =
354+
rewriter.create(loc, extendOpName, Value(newExtract), resultType);
355+
356+
rewriter.replaceOp(extractOp, newExtend);
357+
358+
return success();
359+
}
360+
};
361+
264362
struct OuterProductFusionPass
265363
: public arm_sme::impl::OuterProductFusionBase<OuterProductFusionPass> {
266364

@@ -278,7 +376,11 @@ struct OuterProductFusionPass
278376

279377
// Adds the outer product fusion patterns to `patterns`.
//
// NOTE(review): this span is a diff view. The line following the `281-`
// marker is the pre-change registration; the lines following the `+` markers
// are the post-change body that replaces it.
void mlir::arm_sme::populateOuterProductFusionPatterns(
280378
RewritePatternSet &patterns) {
281-
patterns.add<OuterProductFusion2Way>(patterns.getContext());
379+
MLIRContext *context = patterns.getContext();
380+
// Note: High benefit to ensure extract(extend) are swapped first.
// The two swap patterns are registered with an explicit value of 1024, while
// OuterProductFusion2Way below is registered without one.
381+
patterns.add<SwapVectorExtractOfArithExtend,
382+
SwapVectorScalableExtractOfArithExtend>(context, 1024);
383+
patterns.add<OuterProductFusion2Way>(context);
282384
}
283385

284386
std::unique_ptr<Pass> mlir::arm_sme::createOuterProductFusionPass() {

mlir/test/Dialect/ArmSME/outer-product-fusion.mlir

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,48 @@ func.func @outerproduct_sub_widening_2way_unsigned_i16i16i32(
213213
return %1 : vector<[4]x[4]xi32>
214214
}
215215

216+
/// Tests for related patterns.
217+
218+
// -----
219+
220+
// CHECK-LABEL: @extract_from_arith_ext(
221+
// CHECK-SAME: %[[SRC:.*]]: vector<4x[8]xi8>
222+
// CHECK: %[[EXTRACT:.*]] = vector.extract %[[SRC]][0] : vector<[8]xi8> from vector<4x[8]xi8>
223+
// CHECK: %[[EXTEND:.*]] = arith.extsi %[[EXTRACT]] : vector<[8]xi8> to vector<[8]xi32>
224+
// CHECK: return %[[EXTEND]]
225+
func.func @extract_from_arith_ext(%src: vector<4x[8]xi8>) -> vector<[8]xi32> {
  // Input form: sign-extend the whole 2-D vector, then take row 0 of the
  // widened value.
  %widened = arith.extsi %src : vector<4x[8]xi8> to vector<4x[8]xi32>
  %row = vector.extract %widened[0] : vector<[8]xi32> from vector<4x[8]xi32>
  return %row : vector<[8]xi32>
}
230+
231+
// -----
232+
233+
// CHECK-LABEL: @non_constant_extract_from_arith_ext(
234+
// CHECK-SAME: %[[SRC:[a-z0-9]+]]: vector<4x[8]xi8>,
235+
// CHECK-SAME: %[[DIM:[a-z0-9]+]]: index
236+
// CHECK: %[[EXTRACT:.*]] = vector.extract %[[SRC]][%[[DIM]]] : vector<[8]xi8> from vector<4x[8]xi8>
237+
// CHECK: %[[EXTEND:.*]] = arith.extui %[[EXTRACT]] : vector<[8]xi8> to vector<[8]xi32>
238+
// CHECK: return %[[EXTEND]]
239+
func.func @non_constant_extract_from_arith_ext(%src: vector<4x[8]xi8>, %dim: index) -> vector<[8]xi32> {
  // Input form: zero-extend the whole 2-D vector, then extract at a position
  // given by an SSA value rather than a constant.
  %widened = arith.extui %src : vector<4x[8]xi8> to vector<4x[8]xi32>
  %row = vector.extract %widened[%dim] : vector<[8]xi32> from vector<4x[8]xi32>
  return %row : vector<[8]xi32>
}
244+
245+
// -----
246+
247+
// CHECK-LABEL: @scalable_extract_from_arith_ext(
248+
// CHECK-SAME: %[[SRC:.*]]: vector<[8]xf16>
249+
// CHECK: %[[EXTRACT:.*]] = vector.scalable.extract %[[SRC]][0] : vector<[4]xf16> from vector<[8]xf16>
250+
// CHECK: %[[EXTEND:.*]] = arith.extf %[[EXTRACT]] : vector<[4]xf16> to vector<[4]xf32>
251+
// CHECK: return %[[EXTEND]]
252+
func.func @scalable_extract_from_arith_ext(%src: vector<[8]xf16>) -> vector<[4]xf32> {
  // Input form: floating-point extend of the full scalable vector, followed
  // by a scalable extract of a sub-vector at offset 0.
  %widened = arith.extf %src : vector<[8]xf16> to vector<[8]xf32>
  %slice = vector.scalable.extract %widened[0] : vector<[4]xf32> from vector<[8]xf32>
  return %slice : vector<[4]xf32>
}
257+
216258
/// Negative tests
217259

218260
// -----
@@ -362,3 +404,55 @@ func.func @outerproduct_widening_2way__bad_defining_op(
362404

363405
return %1 : vector<[4]x[4]xf32>
364406
}
407+
408+
/// Negative tests for related patterns.
409+
410+
// -----
411+
412+
/// Non-vector extracts should be ignored.
413+
414+
// CHECK-LABEL: @extract_scalar_from_arith_ext
415+
// CHECK-NEXT: arith.extsi
416+
// CHECK-NEXT: vector.extract
417+
func.func @extract_scalar_from_arith_ext(%src: vector<4x[8]xi8>) -> i32 {
  // The extract yields a scalar (i32), not a vector.
  %widened = arith.extsi %src : vector<4x[8]xi8> to vector<4x[8]xi32>
  %elem = vector.extract %widened[0, 0] : i32 from vector<4x[8]xi32>
  return %elem : i32
}
422+
423+
// -----
424+
425+
/// Extracted type should be a 1-D scalable vector type.
426+
427+
// CHECK-LABEL: @extract_fixed_1d_vec_from_arith_ext
428+
// CHECK-NEXT: arith.extsi
429+
// CHECK-NEXT: vector.extract
430+
func.func @extract_fixed_1d_vec_from_arith_ext(%src: vector<4x8xi8>) -> vector<8xi32> {
  // All dimensions here are fixed-size; the extracted row has no scalable
  // dimension.
  %widened = arith.extsi %src : vector<4x8xi8> to vector<4x8xi32>
  %row = vector.extract %widened[0] : vector<8xi32> from vector<4x8xi32>
  return %row : vector<8xi32>
}
435+
436+
// -----
437+
438+
/// Extract must come from an arith extend.
439+
440+
// CHECK-LABEL: @extract_from_non_arith_ext
441+
// CHECK-NEXT: vector.extract
442+
// CHECK-NEXT: return
443+
func.func @extract_from_non_arith_ext(%src: vector<4x[8]xi32>) -> vector<[8]xi32> {
  // The extracted-from value is a function argument, not the result of an
  // extend.
  %row = vector.extract %src[0] : vector<[8]xi32> from vector<4x[8]xi32>
  return %row : vector<[8]xi32>
}
447+
448+
// -----
449+
450+
/// Scalable extract must come from an arith extend.
451+
452+
// CHECK-LABEL: @scalable_extract_from_non_arith_ext
453+
// CHECK-NEXT: vector.scalable.extract
454+
// CHECK-NEXT: return
455+
func.func @scalable_extract_from_non_arith_ext(%src: vector<[8]xf32>) -> vector<[4]xf32> {
  // The extracted-from value is a function argument, not the result of an
  // extend.
  %slice = vector.scalable.extract %src[0] : vector<[4]xf32> from vector<[8]xf32>
  return %slice : vector<[4]xf32>
}

0 commit comments

Comments
 (0)