
Commit 0f9a31e

Author: Ferdinand Lemaire (committed)
Add relu_nc and test for unfusing
1 parent 9f69638 commit 0f9a31e

4 files changed, +87 -5 lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml

Lines changed: 48 additions & 2 deletions
@@ -2478,7 +2478,7 @@ metadata: !LinalgOpMetadata
     The partial multiplication results are reduced into a 2D output.
 
     Numeric casting is performed on the operands to the inner multiply, promoting
-    them to the same data type as the accumulator/output.
+    them to the same data type as the accumulator/output."
   implements:
   - LinalgContractionOpInterface
 structured_op: !LinalgStructuredOpConfig
@@ -4097,7 +4097,7 @@ structured_op: !LinalgStructuredOpConfig
     kind: input_tensor
     type_var: T1
     shape_map: affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s9, s1 *
-      s2 + s3 * s4, s5 * s6 + s7 * s8)>
+      s2 + s3 * s4, s5 * s6 + s7 * s8)>
   - !LinalgOperandDefConfig
     name: K
     kind: input_tensor
@@ -5837,3 +5837,49 @@ structured_op: !LinalgStructuredOpConfig
               scalar_arg: W
             - !ScalarExpression
               scalar_arg: B
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: relu_nc
+  cpp_class_name: ReluNcOp
+  doc: |-
+    Applies the ReLU activation function to every value in the tensor.
+
+    Layout:
+      * Input: NC
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: IFM
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<()[s0, s1] -> (s0, s1)>
+  - !LinalgOperandDefConfig
+    name: OFM
+    kind: output_tensor
+    type_var: T1
+    shape_map: affine_map<()[s0, s1] -> (s0, s1)>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<(d0, d1)[s0, s1] -> (d0, d1)>
+    - affine_map<(d0, d1)[s0, s1] -> (d0, d1)>
+  iterator_types:
+  - parallel
+  - parallel
+  assignments:
+  - !ScalarAssign
+    arg: OFM
+    value: !ScalarExpression
+      scalar_fn:
+        kind: binary
+        fn_name: max_signed
+        operands:
+        - !ScalarExpression
+          scalar_arg: IFM
+        - !ScalarExpression
+          scalar_fn:
+            kind: type
+            fn_name: cast_signed
+            type_var: T1
+            operands:
+            - !ScalarExpression
+              scalar_const: '0.000000e+00 : f64'
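
For orientation, here is a minimal usage sketch (not part of the commit) of the named op this YAML entry generates: linalg.relu_nc applied to a rank-2 NC tensor, mirroring the ins/outs syntax exercised by the test added below. The function name and the 4x16 shape are illustrative only.

func.func @relu_nc_example(%ifm: tensor<4x16xf32>) -> tensor<4x16xf32> {
  // Allocate an uninitialized output tensor of the same NC shape.
  %init = tensor.empty() : tensor<4x16xf32>
  // Elementwise max(x, 0) over the whole tensor, written to the output.
  %ofm = linalg.relu_nc ins(%ifm : tensor<4x16xf32>)
                        outs(%init : tensor<4x16xf32>) -> tensor<4x16xf32>
  return %ofm : tensor<4x16xf32>
}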

mlir/lib/Dialect/Linalg/Transforms/Unfuse.cpp

Lines changed: 3 additions & 2 deletions
@@ -712,7 +712,7 @@ struct LinearReluLowering : OpRewritePattern<LinearReluOp> {
 
    Value linearResult = unfuseLinear<LinearReluOp>(op, rewriter);
 
-    rewriter.replaceOpWithNewOp<Relu2DNchwOp>(
+    rewriter.replaceOpWithNewOp<ReluNcOp>(
        op,
        /*resultTensorTypes=*/linearResult.getType(),
        /*inputs=*/linearResult,
@@ -734,7 +734,8 @@ struct LinalgUnfusePass : public impl::LinalgUnfuseBase<LinalgUnfusePass> {
        Conv2DTensorAddLreluAveragePoolLowering,
        Conv2DActivationMaxpoolOpLowering<Conv2DLreluMaxpoolOp>,
        Conv2DActivationMaxpoolOpLowering<Conv2DReluMaxpoolOp>,
-        SoftmaxLowering, GlobalAveragePool2DLowering, LinearLowering>(
+        SoftmaxLowering, GlobalAveragePool2DLowering, LinearLowering,
+        LinearReluLowering>(
        &getContext());
 
    (void)applyPatternsAndFoldGreedily(getOperation().getBody(),

mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py

Lines changed: 16 additions & 1 deletion
@@ -1396,4 +1396,19 @@ def linear_relu(
   domain(D.W, D.H, D.K)
   # implementation is incorrect the addition of the bias should happen after
   # the multiplication, not on each element
-  O[D.W, D.K] += I[D.W, D.H]*W[D.K, D.H] + B[D.K]
+  O[D.W, D.K] += I[D.W, D.H]*W[D.K, D.H] + B[D.K]
+
+
+@linalg_structured_op
+def relu_nc(
+    IFM=TensorDef(T1, Batch, S.C ),
+    OFM=TensorDef(T1, Batch, S.C, output=True )):
+  """Applies the ReLU activation function to every value in the tensor.
+
+  Layout:
+    * Input: NC
+  """
+  domain(D.b, D.c)
+  OFM[D.b, D.c] = BinaryFn.max_signed(
+      IFM[D.b, D.c], TypeFn.cast_signed(T1, const(0.0))
+  )
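
As a reading aid, here is roughly the linalg.generic that the relu_nc definition above corresponds to once generalized: an elementwise max against a zero constant cast to the element type. This is a hand-written sketch under those assumptions, not tool output; the float max op is spelled arith.maximumf here, while older MLIR releases call it arith.maxf.

#nc = affine_map<(d0, d1) -> (d0, d1)>

func.func @relu_nc_generalized(%ifm: tensor<4x16xf32>,
                               %init: tensor<4x16xf32>) -> tensor<4x16xf32> {
  // cast_signed(T1, 0.0) from the op definition becomes a plain f32 zero here.
  %zero = arith.constant 0.0 : f32
  %ofm = linalg.generic
      {indexing_maps = [#nc, #nc], iterator_types = ["parallel", "parallel"]}
      ins(%ifm : tensor<4x16xf32>) outs(%init : tensor<4x16xf32>) {
  ^bb0(%in: f32, %out: f32):
    // BinaryFn.max_signed on a float element type: max(x, 0).
    %m = arith.maximumf %in, %zero : f32
    linalg.yield %m : f32
  } -> tensor<4x16xf32>
  return %ofm : tensor<4x16xf32>
}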

mlir/test/Dialect/Linalg/unfuse.mlir

Lines changed: 20 additions & 0 deletions
@@ -448,6 +448,26 @@ func.func @unfuse_linear(%input: tensor<1x2048xf32>, %weights: tensor<1000x2048x
 // CHECK: %[[bias2dshape:.+]] = tensor.empty() : tensor<1x1000xf32>
 // CHECK: %[[bias2d:.+]] = linalg.broadcast_1d_to_2d ins(%arg2 : tensor<1000xf32>) outs(%2 : tensor<1x1000xf32>) -> tensor<1x1000xf32>
 // CHECK: %[[out:.+]] = linalg.matmul ins(%[[input]], %[[tweights]] : tensor<1x2048xf32>, tensor<2048x1000xf32>) outs(%[[bias2d]] : tensor<1x1000xf32>) -> tensor<1x1000xf32
+// CHECK: return %[[out]]
+
+  return %result : tensor<1x1000xf32>
+}
+
+// -----
+
+// CHECK: func.func @unfuse_linearRelu
+// CHECK-SAME: %[[input:.+]]: tensor<1x2048xf32>, %[[weights:.+]]: tensor<1000x2048xf32>, %[[bias:.+]]: tensor<1000xf32>
+func.func @unfuse_linearRelu(%input: tensor<1x2048xf32>, %weights: tensor<1000x2048xf32>, %bias: tensor<1000xf32>) -> tensor<1x1000xf32> {
+  %zero = arith.constant 0.0 : f32
+  %init = tensor.splat %zero : tensor<1x1000xf32>
+  %result = linalg.linear_relu ins(%input, %weights, %bias: tensor<1x2048xf32>, tensor<1000x2048xf32>, tensor<1000xf32>) outs(%init: tensor<1x1000xf32>) -> tensor<1x1000xf32>
+
+// CHECK: %[[tweightshape:.+]] = tensor.empty() : tensor<2048x1000xf32>
+// CHECK: %[[tweights:.+]] = linalg.transpose2d ins(%arg1 : tensor<1000x2048xf32>) outs(%0 : tensor<2048x1000xf32>) -> tensor<2048x1000xf32>
+// CHECK: %[[bias2dshape:.+]] = tensor.empty() : tensor<1x1000xf32>
+// CHECK: %[[bias2d:.+]] = linalg.broadcast_1d_to_2d ins(%arg2 : tensor<1000xf32>) outs(%2 : tensor<1x1000xf32>) -> tensor<1x1000xf32>
+// CHECK: %[[matmul:.+]] = linalg.matmul ins(%[[input]], %[[tweights]] : tensor<1x2048xf32>, tensor<2048x1000xf32>) outs(%[[bias2d]] : tensor<1x1000xf32>) -> tensor<1x1000xf32
+// CHECK: %[[out:.*]] = linalg.relu_nc ins(%[[matmul]] : tensor<1x1000xf32>) outs(%[[matmul]] : tensor<1x1000xf32>) -> tensor<1x1000xf32>
 // CHECK: return %[[out]]
 
   return %result : tensor<1x1000xf32>