Skip to content

Commit 9f69638

Browse files
author
Ferdinand Lemaire
committed
Refactor unfusing and remove unrelated changes
1 parent df8d212 commit 9f69638

File tree

2 files changed

+26
-49
lines changed

2 files changed

+26
-49
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4096,39 +4096,38 @@ structured_op: !LinalgStructuredOpConfig
40964096
name: I
40974097
kind: input_tensor
40984098
type_var: T1
4099-
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1, s2
4100-
* s3 + s4 * s5, s6 * s7 + s8 * s9)>
4099+
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s9, s1 *
4100+
s2 + s3 * s4, s5 * s6 + s7 * s8)>
41014101
- !LinalgOperandDefConfig
41024102
name: K
41034103
kind: input_tensor
41044104
type_var: T2
4105-
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s1, s4, s8)>
4105+
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s9, s3, s7)>
41064106
- !LinalgOperandDefConfig
41074107
name: O
41084108
kind: output_tensor
41094109
type_var: U
4110-
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1, s2,
4111-
s6)>
4110+
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s9, s1, s5)>
41124111
- !LinalgOperandDefConfig
41134112
name: strides
41144113
kind: index_attr
4115-
index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s3,
4116-
s7)>
4114+
index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s2,
4115+
s6)>
41174116
default_indices:
41184117
- 1
41194118
- 1
41204119
- !LinalgOperandDefConfig
41214120
name: dilations
41224121
kind: index_attr
4123-
index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s5,
4124-
s9)>
4122+
index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s4,
4123+
s8)>
41254124
default_indices:
41264125
- 1
41274126
- 1
41284127
indexing_maps: !LinalgIndexingMapsConfig
41294128
static_indexing_maps:
41304129
- affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9]
4131-
-> (d0, d3, d1 * s3 + d4 * s5, d2 * s7 + d5 * s9)>
4130+
-> (d0, d3, d1 * s2 + d4 * s4, d2 * s6 + d5 * s8)>
41324131
- affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9]
41334132
-> (d3, d4, d5)>
41344133
- affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9]

mlir/lib/Dialect/Linalg/Transforms/Unfuse.cpp

Lines changed: 17 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -651,12 +651,9 @@ struct GlobalAveragePool2DLowering : OpRewritePattern<GlobalAveragePool2DOp> {
651651
}
652652
};
653653

654-
/// Torch MLIR does a similar lowering for their Linear operator to linalg
655-
/// here we implement the same so we can run tests using the unfused version
656-
struct LinearLowering : OpRewritePattern<LinearOp> {
657-
using OpRewritePattern<LinearOp>::OpRewritePattern;
658-
LogicalResult matchAndRewrite(LinearOp op,
659-
PatternRewriter &rewriter) const override {
654+
template <class Linear>
655+
static Value unfuseLinear(Linear &op, PatternRewriter &rewriter) {
656+
660657
Location loc = op.getLoc();
661658
Value input = op.getOperand(0);
662659
Value weights = op.getOperand(1);
@@ -691,48 +688,29 @@ struct LinearLowering : OpRewritePattern<LinearOp> {
691688
->getResult(0);
692689

693690
// Create the matmul operation that does the multiplication and addition
694-
rewriter.replaceOpWithNewOp<MatmulOp>(op, output.getType(),
695-
ValueRange{input, transposeWeightsOp},
696-
broadcastBiasOp);
697-
691+
auto newOp = rewriter.create<MatmulOp>(loc, outputType, ValueRange{op.getOperand(0), transposeWeightsOp},
692+
broadcastBiasOp).getResult(0);
693+
return newOp;
694+
}
695+
/// Torch MLIR does a similar lowering for their Linear operator to linalg
696+
/// here we implement the same so we can run tests using the unfused version
697+
struct LinearLowering : OpRewritePattern<LinearOp> {
698+
using OpRewritePattern<LinearOp>::OpRewritePattern;
699+
LogicalResult matchAndRewrite(LinearOp op,
700+
PatternRewriter &rewriter) const override {
701+
Value matmul = unfuseLinear<LinearOp>(op, rewriter);
702+
rewriter.replaceOp(op, matmul);
698703
return success();
699704
}
700705
};
701706

707+
702708
struct LinearReluLowering : OpRewritePattern<LinearReluOp> {
703709
using OpRewritePattern<LinearReluOp>::OpRewritePattern;
704710
LogicalResult matchAndRewrite(LinearReluOp op,
705711
PatternRewriter &rewriter) const override {
706-
Location loc = op.getLoc();
707-
Value weights = op.getOperand(1);
708-
Value bias = op.getOperand(2);
709-
710-
auto weightsType = weights.getType().cast<RankedTensorType>();
711-
auto biasType = bias.getType().cast<RankedTensorType>();
712-
auto outputType = op->getResult(0).getType().cast<RankedTensorType>();
713-
714-
// Create a linalg op that transposes the weights tensor
715-
// The transposedWeights is simply used to describe the output shape.
716-
llvm::ArrayRef<int64_t> weightsShape = weightsType.getShape();
717-
Value transposedWeights = rewriter.create<tensor::EmptyOp>(
718-
loc,
719-
ArrayRef<int64_t>{weightsShape[1], weightsShape[0]},
720-
weightsType.getElementType());
721-
Value transposeWeightsOp =
722-
rewriter.create<Transpose2DOp>(loc, weights, transposedWeights)
723-
->getResult(0);
724-
725-
// Create a linalg op that broadcasts the 1D bias values across
726-
// the 2nd dimension
727-
Value broadcastedBias = rewriter.create<tensor::EmptyOp>(
728-
loc, outputType.getShape(), biasType.getElementType());
729-
Value broadcastBiasOp =
730-
rewriter.create<Broadcast1DTo2DOp>(loc, bias, broadcastedBias)
731-
->getResult(0);
732712

733-
auto linearResult = rewriter.create<MatmulOp>(loc,
734-
outputType, ValueRange{op.getOperand(0), transposeWeightsOp},
735-
broadcastBiasOp).getResult(0);
713+
Value linearResult = unfuseLinear<LinearReluOp>(op, rewriter);
736714

737715
rewriter.replaceOpWithNewOp<Relu2DNchwOp>(
738716
op,

0 commit comments

Comments
 (0)