Skip to content

Commit 3ff8708

Browse files
committed
[mlir][arith] Add narrowing patterns for other insertion ops
Allow to commute extension ops over `vector.insertelement` and `vector.insert_strided_slice`. Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D149509
1 parent 0f1a8b4 commit 3ff8708

File tree

2 files changed

+180
-15
lines changed

2 files changed

+180
-15
lines changed

mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp

Lines changed: 66 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -306,27 +306,35 @@ struct ExtensionOverExtractStridedSlice final
306306
}
307307
};
308308

309-
struct ExtensionOverInsert final : NarrowingPattern<vector::InsertOp> {
310-
using NarrowingPattern::NarrowingPattern;
311-
312-
LogicalResult matchAndRewrite(vector::InsertOp op,
313-
PatternRewriter &rewriter) const override {
309+
/// Base pattern for `vector.insert` narrowing patterns.
310+
template <typename InsertionOp>
311+
struct ExtensionOverInsertionPattern : NarrowingPattern<InsertionOp> {
312+
using NarrowingPattern<InsertionOp>::NarrowingPattern;
313+
314+
/// Derived classes must provide a function to create the matching insertion
315+
/// op based on the original op and new arguments.
316+
virtual InsertionOp createInsertionOp(PatternRewriter &rewriter,
317+
InsertionOp origInsert,
318+
Value narrowValue,
319+
Value narrowDest) const = 0;
320+
321+
LogicalResult matchAndRewrite(InsertionOp op,
322+
PatternRewriter &rewriter) const final {
314323
FailureOr<ExtensionOp> ext =
315324
ExtensionOp::from(op.getSource().getDefiningOp());
316325
if (failed(ext))
317326
return failure();
318327

319-
FailureOr<vector::InsertOp> newInsert =
320-
createNarrowInsert(op, rewriter, *ext);
328+
FailureOr<InsertionOp> newInsert = createNarrowInsert(op, rewriter, *ext);
321329
if (failed(newInsert))
322330
return failure();
323331
ext->recreateAndReplace(rewriter, op, *newInsert);
324332
return success();
325333
}
326334

327-
FailureOr<vector::InsertOp> createNarrowInsert(vector::InsertOp op,
328-
PatternRewriter &rewriter,
329-
ExtensionOp insValue) const {
335+
FailureOr<InsertionOp> createNarrowInsert(InsertionOp op,
336+
PatternRewriter &rewriter,
337+
ExtensionOp insValue) const {
330338
// Calculate the operand and result bitwidths. We can only apply narrowing
331339
// when the inserted source value and destination vector require fewer bits
332340
// than the result. Because the source and destination may have different
@@ -337,6 +345,8 @@ struct ExtensionOverInsert final : NarrowingPattern<vector::InsertOp> {
337345
if (failed(origBitsRequired))
338346
return failure();
339347

348+
// TODO: We could relax this check by disregarding bitwidth requirements of
349+
// elements that we know will be replaced by the insertion.
340350
FailureOr<unsigned> destBitsRequired =
341351
calculateBitsRequired(op.getDest(), insValue.getKind());
342352
if (failed(destBitsRequired) || *destBitsRequired >= *origBitsRequired)
@@ -352,12 +362,13 @@ struct ExtensionOverInsert final : NarrowingPattern<vector::InsertOp> {
352362
// both the source and the destination values.
353363
unsigned newInsertionBits =
354364
std::max(*destBitsRequired, *insertedBitsRequired);
355-
FailureOr<Type> newVecTy = getNarrowType(newInsertionBits, op.getType());
365+
FailureOr<Type> newVecTy =
366+
this->getNarrowType(newInsertionBits, op.getType());
356367
if (failed(newVecTy) || *newVecTy == op.getType())
357368
return failure();
358369

359370
FailureOr<Type> newInsertedValueTy =
360-
getNarrowType(newInsertionBits, insValue.getType());
371+
this->getNarrowType(newInsertionBits, insValue.getType());
361372
if (failed(newInsertedValueTy))
362373
return failure();
363374

@@ -366,8 +377,47 @@ struct ExtensionOverInsert final : NarrowingPattern<vector::InsertOp> {
366377
loc, *newInsertedValueTy, insValue.getResult());
367378
Value narrowDest =
368379
rewriter.createOrFold<arith::TruncIOp>(loc, *newVecTy, op.getDest());
369-
return rewriter.create<vector::InsertOp>(loc, narrowValue, narrowDest,
370-
op.getPosition());
380+
return createInsertionOp(rewriter, op, narrowValue, narrowDest);
381+
}
382+
};
383+
384+
struct ExtensionOverInsert final
385+
: ExtensionOverInsertionPattern<vector::InsertOp> {
386+
using ExtensionOverInsertionPattern::ExtensionOverInsertionPattern;
387+
388+
vector::InsertOp createInsertionOp(PatternRewriter &rewriter,
389+
vector::InsertOp origInsert,
390+
Value narrowValue,
391+
Value narrowDest) const override {
392+
return rewriter.create<vector::InsertOp>(
393+
origInsert.getLoc(), narrowValue, narrowDest, origInsert.getPosition());
394+
}
395+
};
396+
397+
struct ExtensionOverInsertElement final
398+
: ExtensionOverInsertionPattern<vector::InsertElementOp> {
399+
using ExtensionOverInsertionPattern::ExtensionOverInsertionPattern;
400+
401+
vector::InsertElementOp createInsertionOp(PatternRewriter &rewriter,
402+
vector::InsertElementOp origInsert,
403+
Value narrowValue,
404+
Value narrowDest) const override {
405+
return rewriter.create<vector::InsertElementOp>(
406+
origInsert.getLoc(), narrowValue, narrowDest, origInsert.getPosition());
407+
}
408+
};
409+
410+
struct ExtensionOverInsertStridedSlice final
411+
: ExtensionOverInsertionPattern<vector::InsertStridedSliceOp> {
412+
using ExtensionOverInsertionPattern::ExtensionOverInsertionPattern;
413+
414+
vector::InsertStridedSliceOp
415+
createInsertionOp(PatternRewriter &rewriter,
416+
vector::InsertStridedSliceOp origInsert, Value narrowValue,
417+
Value narrowDest) const override {
418+
return rewriter.create<vector::InsertStridedSliceOp>(
419+
origInsert.getLoc(), narrowValue, narrowDest, origInsert.getOffsets(),
420+
origInsert.getStrides());
371421
}
372422
};
373423

@@ -400,7 +450,8 @@ void populateArithIntNarrowingPatterns(
400450
// Add commute patterns with a higher benefit. This is to expose more
401451
// optimization opportunities to narrowing patterns.
402452
patterns.add<ExtensionOverExtract, ExtensionOverExtractElement,
403-
ExtensionOverExtractStridedSlice, ExtensionOverInsert>(
453+
ExtensionOverExtractStridedSlice, ExtensionOverInsert,
454+
ExtensionOverInsertElement, ExtensionOverInsertStridedSlice>(
404455
patterns.getContext(), options, PatternBenefit(2));
405456

406457
patterns.add<SIToFPPattern, UIToFPPattern>(patterns.getContext(), options);

mlir/test/Dialect/Arith/int-narrowing.mlir

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,3 +328,117 @@ func.func @extui_over_insert_3xi16_cst_i16(%a: i8) -> vector<3xi32> {
328328
%e = vector.insert %d, %cst [1] : i32 into vector<3xi32>
329329
return %e : vector<3xi32>
330330
}
331+
332+
// CHECK-LABEL: func.func @extsi_over_insertelement_3xi16
333+
// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi16>, %[[ARG1:.+]]: i16, %[[POS:.+]]: i32)
334+
// CHECK-NEXT: %[[INS:.+]] = vector.insertelement %[[ARG1]], %[[ARG0]][%[[POS]] : i32] : vector<3xi16>
335+
// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[INS]] : vector<3xi16> to vector<3xi32>
336+
// CHECK-NEXT: return %[[RET]] : vector<3xi32>
337+
func.func @extsi_over_insertelement_3xi16(%a: vector<3xi16>, %b: i16, %pos: i32) -> vector<3xi32> {
338+
%c = arith.extsi %a : vector<3xi16> to vector<3xi32>
339+
%d = arith.extsi %b : i16 to i32
340+
%e = vector.insertelement %d, %c[%pos : i32] : vector<3xi32>
341+
return %e : vector<3xi32>
342+
}
343+
344+
// CHECK-LABEL: func.func @extui_over_insertelement_3xi16
345+
// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi16>, %[[ARG1:.+]]: i16, %[[POS:.+]]: i32)
346+
// CHECK-NEXT: %[[INS:.+]] = vector.insertelement %[[ARG1]], %[[ARG0]][%[[POS]] : i32] : vector<3xi16>
347+
// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[INS]] : vector<3xi16> to vector<3xi32>
348+
// CHECK-NEXT: return %[[RET]] : vector<3xi32>
349+
func.func @extui_over_insertelement_3xi16(%a: vector<3xi16>, %b: i16, %pos: i32) -> vector<3xi32> {
350+
%c = arith.extui %a : vector<3xi16> to vector<3xi32>
351+
%d = arith.extui %b : i16 to i32
352+
%e = vector.insertelement %d, %c[%pos : i32] : vector<3xi32>
353+
return %e : vector<3xi32>
354+
}
355+
356+
// CHECK-LABEL: func.func @extsi_over_insertelement_3xi16_cst_i16
357+
// CHECK-SAME: (%[[ARG:.+]]: i8, %[[POS:.+]]: i32)
358+
// CHECK-NEXT: %[[CST:.+]] = arith.constant dense<[-1, 128, 0]> : vector<3xi16>
359+
// CHECK-NEXT: %[[SRCE:.+]] = arith.extsi %[[ARG]] : i8 to i32
360+
// CHECK-NEXT: %[[SRCT:.+]] = arith.trunci %[[SRCE]] : i32 to i16
361+
// CHECK-NEXT: %[[INS:.+]] = vector.insertelement %[[SRCT]], %[[CST]][%[[POS]] : i32] : vector<3xi16>
362+
// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[INS]] : vector<3xi16> to vector<3xi32>
363+
// CHECK-NEXT: return %[[RET]] : vector<3xi32>
364+
func.func @extsi_over_insertelement_3xi16_cst_i16(%a: i8, %pos: i32) -> vector<3xi32> {
365+
%cst = arith.constant dense<[-1, 128, 0]> : vector<3xi32>
366+
%d = arith.extsi %a : i8 to i32
367+
%e = vector.insertelement %d, %cst[%pos : i32] : vector<3xi32>
368+
return %e : vector<3xi32>
369+
}
370+
371+
// CHECK-LABEL: func.func @extui_over_insertelement_3xi16_cst_i16
372+
// CHECK-SAME: (%[[ARG:.+]]: i8, %[[POS:.+]]: i32)
373+
// CHECK-NEXT: %[[CST:.+]] = arith.constant dense<[1, 256, 0]> : vector<3xi16>
374+
// CHECK-NEXT: %[[SRCE:.+]] = arith.extui %[[ARG]] : i8 to i32
375+
// CHECK-NEXT: %[[SRCT:.+]] = arith.trunci %[[SRCE]] : i32 to i16
376+
// CHECK-NEXT: %[[INS:.+]] = vector.insertelement %[[SRCT]], %[[CST]][%[[POS]] : i32] : vector<3xi16>
377+
// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[INS]] : vector<3xi16> to vector<3xi32>
378+
// CHECK-NEXT: return %[[RET]] : vector<3xi32>
379+
func.func @extui_over_insertelement_3xi16_cst_i16(%a: i8, %pos: i32) -> vector<3xi32> {
380+
%cst = arith.constant dense<[1, 256, 0]> : vector<3xi32>
381+
%d = arith.extui %a : i8 to i32
382+
%e = vector.insertelement %d, %cst[%pos : i32] : vector<3xi32>
383+
return %e : vector<3xi32>
384+
}
385+
386+
// CHECK-LABEL: func.func @extsi_over_insert_strided_slice_1d
387+
// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi16>, %[[ARG1:.+]]: vector<2xi16>)
388+
// CHECK-NEXT: %[[INS:.+]] = vector.insert_strided_slice %[[ARG1]], %[[ARG0]]
389+
// CHECK-SAME: {offsets = [1], strides = [1]} : vector<2xi16> into vector<3xi16>
390+
// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[INS]] : vector<3xi16> to vector<3xi32>
391+
// CHECK-NEXT: return %[[RET]] : vector<3xi32>
392+
func.func @extsi_over_insert_strided_slice_1d(%a: vector<3xi16>, %b: vector<2xi16>) -> vector<3xi32> {
393+
%c = arith.extsi %a : vector<3xi16> to vector<3xi32>
394+
%d = arith.extsi %b : vector<2xi16> to vector<2xi32>
395+
%e = vector.insert_strided_slice %d, %c {offsets = [1], strides = [1]} : vector<2xi32> into vector<3xi32>
396+
return %e : vector<3xi32>
397+
}
398+
399+
// CHECK-LABEL: func.func @extui_over_insert_strided_slice_1d
400+
// CHECK-SAME: (%[[ARG0:.+]]: vector<3xi16>, %[[ARG1:.+]]: vector<2xi16>)
401+
// CHECK-NEXT: %[[INS:.+]] = vector.insert_strided_slice %[[ARG1]], %[[ARG0]]
402+
// CHECK-SAME: {offsets = [1], strides = [1]} : vector<2xi16> into vector<3xi16>
403+
// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[INS]] : vector<3xi16> to vector<3xi32>
404+
// CHECK-NEXT: return %[[RET]] : vector<3xi32>
405+
func.func @extui_over_insert_strided_slice_1d(%a: vector<3xi16>, %b: vector<2xi16>) -> vector<3xi32> {
406+
%c = arith.extui %a : vector<3xi16> to vector<3xi32>
407+
%d = arith.extui %b : vector<2xi16> to vector<2xi32>
408+
%e = vector.insert_strided_slice %d, %c {offsets = [1], strides = [1]} : vector<2xi32> into vector<3xi32>
409+
return %e : vector<3xi32>
410+
}
411+
412+
// CHECK-LABEL: func.func @extsi_over_insert_strided_slice_cst_2d
413+
// CHECK-SAME: (%[[ARG:.+]]: vector<1x2xi8>)
414+
// CHECK-NEXT: %[[CST:.+]] = arith.constant
415+
// CHECK-SAME{LITERAL}: dense<[[-1, 128, 0], [-129, 42, 1337]]> : vector<2x3xi16>
416+
// CHECK-NEXT: %[[SRCE:.+]] = arith.extsi %[[ARG]] : vector<1x2xi8> to vector<1x2xi32>
417+
// CHECK-NEXT: %[[SRCT:.+]] = arith.trunci %[[SRCE]] : vector<1x2xi32> to vector<1x2xi16>
418+
// CHECK-NEXT: %[[INS:.+]] = vector.insert_strided_slice %[[SRCT]], %[[CST]]
419+
// CHECK-SAME: {offsets = [0, 1], strides = [1, 1]} : vector<1x2xi16> into vector<2x3xi16>
420+
// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[INS]] : vector<2x3xi16> to vector<2x3xi32>
421+
// CHECK-NEXT: return %[[RET]] : vector<2x3xi32>
422+
func.func @extsi_over_insert_strided_slice_cst_2d(%a: vector<1x2xi8>) -> vector<2x3xi32> {
423+
%cst = arith.constant dense<[[-1, 128, 0], [-129, 42, 1337]]> : vector<2x3xi32>
424+
%d = arith.extsi %a : vector<1x2xi8> to vector<1x2xi32>
425+
%e = vector.insert_strided_slice %d, %cst {offsets = [0, 1], strides = [1, 1]} : vector<1x2xi32> into vector<2x3xi32>
426+
return %e : vector<2x3xi32>
427+
}
428+
429+
// CHECK-LABEL: func.func @extui_over_insert_strided_slice_cst_2d
430+
// CHECK-SAME: (%[[ARG:.+]]: vector<1x2xi8>)
431+
// CHECK-NEXT: %[[CST:.+]] = arith.constant
432+
// CHECK-SAME{LITERAL}: dense<[[1, 128, 0], [256, 42, 1337]]> : vector<2x3xi16>
433+
// CHECK-NEXT: %[[SRCE:.+]] = arith.extui %[[ARG]] : vector<1x2xi8> to vector<1x2xi32>
434+
// CHECK-NEXT: %[[SRCT:.+]] = arith.trunci %[[SRCE]] : vector<1x2xi32> to vector<1x2xi16>
435+
// CHECK-NEXT: %[[INS:.+]] = vector.insert_strided_slice %[[SRCT]], %[[CST]]
436+
// CHECK-SAME: {offsets = [0, 1], strides = [1, 1]} : vector<1x2xi16> into vector<2x3xi16>
437+
// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[INS]] : vector<2x3xi16> to vector<2x3xi32>
438+
// CHECK-NEXT: return %[[RET]] : vector<2x3xi32>
439+
func.func @extui_over_insert_strided_slice_cst_2d(%a: vector<1x2xi8>) -> vector<2x3xi32> {
440+
%cst = arith.constant dense<[[1, 128, 0], [256, 42, 1337]]> : vector<2x3xi32>
441+
%d = arith.extui %a : vector<1x2xi8> to vector<1x2xi32>
442+
%e = vector.insert_strided_slice %d, %cst {offsets = [0, 1], strides = [1, 1]} : vector<1x2xi32> into vector<2x3xi32>
443+
return %e : vector<2x3xi32>
444+
}

0 commit comments

Comments
 (0)