Skip to content

Commit df8d212

Browse files
author
Ferdinand Lemaire
committed
Add linearRelu op to linalg structured ops and unfuse pass
1 parent 9eafef6 commit df8d212

File tree

3 files changed

+147
-10
lines changed

3 files changed

+147
-10
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml

Lines changed: 82 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2478,7 +2478,7 @@ metadata: !LinalgOpMetadata
24782478
The partial multiplication results are reduced into a 2D output.
24792479
24802480
Numeric casting is performed on the operands to the inner multiply, promoting
2481-
them to the same data type as the accumulator/output."
2481+
them to the same data type as the accumulator/output.
24822482
implements:
24832483
- LinalgContractionOpInterface
24842484
structured_op: !LinalgStructuredOpConfig
@@ -4096,38 +4096,39 @@ structured_op: !LinalgStructuredOpConfig
40964096
name: I
40974097
kind: input_tensor
40984098
type_var: T1
4099-
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s9, s1 *
4100-
s2 + s3 * s4, s5 * s6 + s7 * s8)>
4099+
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1, s2
4100+
* s3 + s4 * s5, s6 * s7 + s8 * s9)>
41014101
- !LinalgOperandDefConfig
41024102
name: K
41034103
kind: input_tensor
41044104
type_var: T2
4105-
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s9, s3, s7)>
4105+
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s1, s4, s8)>
41064106
- !LinalgOperandDefConfig
41074107
name: O
41084108
kind: output_tensor
41094109
type_var: U
4110-
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s9, s1, s5)>
4110+
shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1, s2,
4111+
s6)>
41114112
- !LinalgOperandDefConfig
41124113
name: strides
41134114
kind: index_attr
4114-
index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s2,
4115-
s6)>
4115+
index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s3,
4116+
s7)>
41164117
default_indices:
41174118
- 1
41184119
- 1
41194120
- !LinalgOperandDefConfig
41204121
name: dilations
41214122
kind: index_attr
4122-
index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s4,
4123-
s8)>
4123+
index_attr_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s5,
4124+
s9)>
41244125
default_indices:
41254126
- 1
41264127
- 1
41274128
indexing_maps: !LinalgIndexingMapsConfig
41284129
static_indexing_maps:
41294130
- affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9]
4130-
-> (d0, d3, d1 * s2 + d4 * s4, d2 * s6 + d5 * s8)>
4131+
-> (d0, d3, d1 * s3 + d4 * s5, d2 * s7 + d5 * s9)>
41314132
- affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9]
41324133
-> (d3, d4, d5)>
41334134
- affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9]
@@ -5766,3 +5767,74 @@ structured_op: !LinalgStructuredOpConfig
57665767
scalar_const: '2.3283063999999999E-10 : f64'
57675768
- !ScalarExpression
57685769
scalar_arg: min
5770+
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
  name: linear_relu
  cpp_class_name: LinearReluOp
  doc: |-
    Performs a linear/fully-connected + relu operation

    Multiplies the WxH input by the KxH weights (contracting over H) and
    accumulates the K bias into the WxK output. NOTE(review): the scalar
    body below contains no relu, and the bias is added on every reduction
    step over H rather than once per output element — the relu (and the
    correct bias handling) are materialized when this op is unfused.

    Layout:
      * I: WH (Input)
      * W: KH (Weights)
      * B: K (Bias)
structured_op: !LinalgStructuredOpConfig
  args:
  - !LinalgOperandDefConfig
    name: I
    kind: input_tensor
    type_var: T1
    shape_map: affine_map<()[s0, s1, s2] -> (s0, s1)>
  - !LinalgOperandDefConfig
    name: W
    kind: input_tensor
    type_var: T1
    shape_map: affine_map<()[s0, s1, s2] -> (s2, s1)>
  - !LinalgOperandDefConfig
    name: B
    kind: input_tensor
    type_var: T1
    shape_map: affine_map<()[s0, s1, s2] -> (s2)>
  - !LinalgOperandDefConfig
    name: O
    kind: output_tensor
    type_var: T1
    shape_map: affine_map<()[s0, s1, s2] -> (s0, s2)>
  indexing_maps: !LinalgIndexingMapsConfig
    static_indexing_maps:
    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)>
    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2)>
    - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
  iterator_types:
  - parallel
  - reduction
  - parallel
  assignments:
  - !ScalarAssign
    arg: O
    value: !ScalarExpression
      scalar_fn:
        kind: binary
        fn_name: add
        operands:
        - !ScalarExpression
          scalar_arg: O
        - !ScalarExpression
          scalar_fn:
            kind: binary
            fn_name: add
            operands:
            - !ScalarExpression
              scalar_fn:
                kind: binary
                fn_name: mul
                operands:
                - !ScalarExpression
                  scalar_arg: I
                - !ScalarExpression
                  scalar_arg: W
            - !ScalarExpression
              scalar_arg: B

mlir/lib/Dialect/Linalg/Transforms/Unfuse.cpp

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "mlir/IR/PatternMatch.h"
2525
#include "mlir/Support/LogicalResult.h"
2626
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
27+
#include "llvm/ADT/ArrayRef.h"
2728
#include "llvm/ADT/SmallVector.h"
2829
#include "llvm/Support/Debug.h"
2930

@@ -698,6 +699,50 @@ struct LinearLowering : OpRewritePattern<LinearOp> {
698699
}
699700
};
700701

702+
struct LinearReluLowering : OpRewritePattern<LinearReluOp> {
703+
using OpRewritePattern<LinearReluOp>::OpRewritePattern;
704+
LogicalResult matchAndRewrite(LinearReluOp op,
705+
PatternRewriter &rewriter) const override {
706+
Location loc = op.getLoc();
707+
Value weights = op.getOperand(1);
708+
Value bias = op.getOperand(2);
709+
710+
auto weightsType = weights.getType().cast<RankedTensorType>();
711+
auto biasType = bias.getType().cast<RankedTensorType>();
712+
auto outputType = op->getResult(0).getType().cast<RankedTensorType>();
713+
714+
// Create a linalg op that transposes the weights tensor
715+
// The transposedWeights is simply used to describe the output shape.
716+
llvm::ArrayRef<int64_t> weightsShape = weightsType.getShape();
717+
Value transposedWeights = rewriter.create<tensor::EmptyOp>(
718+
loc,
719+
ArrayRef<int64_t>{weightsShape[1], weightsShape[0]},
720+
weightsType.getElementType());
721+
Value transposeWeightsOp =
722+
rewriter.create<Transpose2DOp>(loc, weights, transposedWeights)
723+
->getResult(0);
724+
725+
// Create a linalg op that broadcasts the 1D bias values across
726+
// the 2nd dimension
727+
Value broadcastedBias = rewriter.create<tensor::EmptyOp>(
728+
loc, outputType.getShape(), biasType.getElementType());
729+
Value broadcastBiasOp =
730+
rewriter.create<Broadcast1DTo2DOp>(loc, bias, broadcastedBias)
731+
->getResult(0);
732+
733+
auto linearResult = rewriter.create<MatmulOp>(loc,
734+
outputType, ValueRange{op.getOperand(0), transposeWeightsOp},
735+
broadcastBiasOp).getResult(0);
736+
737+
rewriter.replaceOpWithNewOp<Relu2DNchwOp>(
738+
op,
739+
/*resultTensorTypes=*/linearResult.getType(),
740+
/*inputs=*/linearResult,
741+
/*outputs=*/linearResult);
742+
return success();
743+
}
744+
};
745+
701746
struct LinalgUnfusePass : public impl::LinalgUnfuseBase<LinalgUnfusePass> {
702747
void runOnOperation() override {
703748
RewritePatternSet patterns(&getContext());

mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1377,3 +1377,23 @@ def fill_rng_2d(min=ScalarDef(F64),
13771377
scaling = (max - min) * inv_range
13781378
O[D.m, D.n] = TypeFn.cast_signed(
13791379
T, (offset + TypeFn.cast_signed(F64, rand2)) * scaling + min)
1380+
1381+
@linalg_structured_op
def linear_relu(
    I=TensorDef(T1, S.W, S.H),
    W=TensorDef(T1, S.K, S.H),
    B=TensorDef(T1, S.K),
    O=TensorDef(T1, S.W, S.K, output=True)):
  """Performs a linear/fully-connected + relu operation

  Multiplies the WxH input by the KxH weights (contracting over H) and
  accumulates the K bias into the WxK output. NOTE: the relu itself is not
  part of this scalar body; it is expected to be materialized when the op
  is unfused.

  Layout:
    * I: WH (Input)
    * W: KH (Weights)
    * B: K (Bias)
  """
  domain(D.W, D.H, D.K)
  # FIXME: the implementation is incorrect — B[D.K] is accumulated on every
  # reduction step over D.H, so the bias is effectively added S.H times
  # instead of once after the multiplication has been reduced.
  O[D.W, D.K] += I[D.W, D.H] * W[D.K, D.H] + B[D.K]

0 commit comments

Comments
 (0)