
Commit d8dc1c2

[MLIR][Linalg] Add max named op to linalg
I've been trying to come up with a simple and clean implementation for ReLU. TOSA uses `clamp`, which is probably the goal, but that means table-gen work to make it efficient (attributes, lowering only `min` or `max`).

For now, `max` is a reasonable named op beyond just ReLU, so we can start using it for tiling and fusion and, upon success, create a more complete `clamp` op that doesn't need a whole tensor filled with zeroes or ones to implement the different activation functions.

As with other named ops, we start by "requiring" explicit type casts, broadcasts, and zero-filled constant tensors (leaving their recognition to a more complex pattern-matcher), and can slowly simplify with attributes or structured matchers (e.g. PDL) in the future.

Differential Revision: https://reviews.llvm.org/D154703
1 parent 9d5cfed · commit d8dc1c2
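To illustrate the trade-off the commit message describes, here is a sketch (not part of this change; the function name, shapes, and zero splat are assumed for illustration): with only `max` as a named op, ReLU still needs a whole tensor of zeroes materialized as the second operand, which is exactly the overhead a future `clamp` op would avoid.

// Sketch only: ReLU(x) = max(x, 0) expressed with the new named op.
func.func @relu(%x: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
  %zero = arith.constant 0.0 : f32
  %splat = tensor.empty() : tensor<4x8x16xf32>
  // A whole tensor filled with zeroes, needed only to act as the second operand.
  %zeros = linalg.fill ins(%zero : f32) outs(%splat : tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
  %init = tensor.empty() : tensor<4x8x16xf32>
  %relu = linalg.max ins(%x, %zeros : tensor<4x8x16xf32>, tensor<4x8x16xf32>)
                     outs(%init : tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
  return %relu : tensor<4x8x16xf32>
}

A pattern-matcher can later recognize the `linalg.fill` + `linalg.max` pair and rewrite it into a more compact form, which is the simplification path the commit message leaves for the future.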

File tree

5 files changed, +143 -0 lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml

Lines changed: 49 additions & 0 deletions
@@ -613,6 +613,55 @@ structured_op: !LinalgStructuredOpConfig
         - !ScalarExpression
           scalar_arg: rhs
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: max
+  cpp_class_name: MaxOp
+  doc: |-
+    Takes the max (signed) between the input and a constant.
+
+    The shapes and element types must be identical. The appropriate casts,
+    broadcasts and reductions should be done previously to calling this op.
+
+    This means reduction/broadcast/element cast semantics is explicit. Further
+    passes can take that into account when lowering this code. For example,
+    a `linalg.broadcast` + `linalg.div` sequence can be lowered to a
+    `linalg.generic` with different affine maps for the two operands.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: lhs
+    kind: input_tensor
+    type_var: T
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: rhs
+    kind: input_tensor
+    type_var: T
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: out
+    kind: output_tensor
+    type_var: T
+    shape_map: affine_map<() -> ()>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+  iterator_types: []
+  assignments:
+  - !ScalarAssign
+    arg: out
+    value: !ScalarExpression
+      scalar_fn:
+        kind: binary
+        fn_name: max_signed
+        operands:
+        - !ScalarExpression
+          scalar_arg: lhs
+        - !ScalarExpression
+          scalar_arg: rhs
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: matmul
   cpp_class_name: MatmulOp
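The doc above requires casts and broadcasts to be done before calling the op. As a sketch of what that looks like in practice (not part of this commit; the function name, shapes, and the `dimensions` choice are assumed), a lower-rank operand is broadcast explicitly with `linalg.broadcast` before feeding `linalg.max`:

// Sketch only: broadcast semantics stay explicit in the IR, as the op doc asks.
func.func @max_with_explicit_broadcast(%small: tensor<8x16xf32>,
                                       %big: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
  %empty = tensor.empty() : tensor<4x8x16xf32>
  // Explicitly broadcast the rank-2 operand to rank 3; dimension 0 is the added one.
  %bcast = linalg.broadcast ins(%small : tensor<8x16xf32>)
                            outs(%empty : tensor<4x8x16xf32>) dimensions = [0]
  %init = tensor.empty() : tensor<4x8x16xf32>
  %max = linalg.max ins(%bcast, %big : tensor<4x8x16xf32>, tensor<4x8x16xf32>)
                    outs(%init : tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
  return %max : tensor<4x8x16xf32>
}

Keeping the broadcast as a separate op is what lets later passes fold the pair into a single `linalg.generic` with different affine maps, as the doc notes for the `linalg.broadcast` + `linalg.div` case.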

mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py

Lines changed: 19 additions & 0 deletions
@@ -219,6 +219,25 @@ def div_unsigned(
     O[None] = lhs[None] / rhs[None]


+@linalg_structured_op
+def max(
+    lhs=TensorDef(T1),
+    rhs=TensorDef(T1),
+    O=TensorDef(T1, output=True),
+):
+    """Takes the max (signed) between two inputs, elementwise.
+
+    The shapes and element types must be identical. The appropriate casts,
+    broadcasts and reductions should be done previously to calling this op.
+
+    This means reduction/broadcast/element cast semantics is explicit. Further
+    passes can take that into account when lowering this code. For example,
+    a `linalg.broadcast` + `linalg.div` sequence can be lowered to a
+    `linalg.generic` with different affine maps for the two operands.
+    """
+    O[None] = BinaryFn.max_signed(lhs[None], rhs[None])
+
+
 @linalg_structured_op
 def matmul(
     A=TensorDef(T1, S.M, S.K),

mlir/test/Dialect/Linalg/generalize-named-ops.mlir

Lines changed: 25 additions & 0 deletions
@@ -537,3 +537,28 @@ func.func @generalize_negf(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>)
 // CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
 // CHECK-NEXT: %[[negf:.+]] = arith.negf %[[BBARG0]] : f32
 // CHECK-NEXT: linalg.yield %[[negf]] : f32
+
+// -----
+
+func.func @generalize_max(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
+                          %out: memref<7x14x21xf32>) {
+  linalg.max ins(%lhs, %rhs : memref<7x14x21xf32>, memref<7x14x21xf32>)
+             outs(%out : memref<7x14x21xf32>)
+  return
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK: func @generalize_max
+// CHECK-SAME: (%[[LHS:.+]]: memref<7x14x21xf32>, %[[RHS:.+]]: memref<7x14x21xf32>,
+// CHECK-SAME:  %[[OUT:.+]]: memref<7x14x21xf32>)
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME: ins(%[[LHS]], %[[RHS]] : memref<7x14x21xf32>, memref<7x14x21xf32>)
+// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)
+
+// CHECK: ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32)
+// CHECK-NEXT: %[[max:.+]] = arith.maxf %[[BBARG0]], %[[BBARG1]] : f32
+// CHECK-NEXT: linalg.yield %[[max]] : f32
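For readers unfamiliar with the CHECK prefixes above, the generalized IR they match looks roughly like the following (reconstructed for illustration; the exact SSA and block-argument names are assumed, not taken from the commit):

// Approximate generalized form of linalg.max, reconstructed from the CHECK lines.
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
func.func @generalize_max(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
                          %out: memref<7x14x21xf32>) {
  linalg.generic {indexing_maps = [#map, #map, #map],
                  iterator_types = ["parallel", "parallel", "parallel"]}
      ins(%lhs, %rhs : memref<7x14x21xf32>, memref<7x14x21xf32>)
      outs(%out : memref<7x14x21xf32>) {
  // Three block arguments: the two inputs and the output element.
  ^bb0(%in: f32, %in_1: f32, %out_1: f32):
    %0 = arith.maxf %in, %in_1 : f32
    linalg.yield %0 : f32
  }
  return
}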

mlir/test/Dialect/Linalg/named-ops-fail.mlir

Lines changed: 16 additions & 0 deletions
@@ -173,3 +173,19 @@ func.func @negf_broadcast(%arg: memref<8x16xf32>, %out: memref<4x8x16xf32>) {
   linalg.negf ins(%arg : memref<8x16xf32>) outs(%out: memref<4x8x16xf32>)
   return
 }
+
+// -----
+
+func.func @max_type_cast(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf16>, %arg2: memref<4x8x16xf32>) {
+  // CHECK: op requires the same type for all operands and results
+  linalg.max ins(%arg0, %arg1 : memref<4x8x16xf32>, memref<4x8x16xf16>) outs(%arg2: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
+func.func @max_broadcast(%arg0: memref<8x16xf32>, %arg1: memref<4x8x16xf32>, %arg2: memref<4x8x16xf32>) {
+  // CHECK: op expected operand rank (2) to match the result rank of indexing_map #0 (3)
+  linalg.max ins(%arg0, %arg1 : memref<8x16xf32>, memref<4x8x16xf32>) outs(%arg2: memref<4x8x16xf32>)
+  return
+}

mlir/test/Dialect/Linalg/named-ops.mlir

Lines changed: 34 additions & 0 deletions
@@ -1540,3 +1540,37 @@ func.func @negf_tensor(%arg0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
   %1 = linalg.negf ins(%arg0 : tensor<4x8x16xf32>) outs(%0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
   return %1 : tensor<4x8x16xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @max_dynamic
+func.func @max_dynamic(%arg0: memref<?x?x?xf32>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>) {
+  // CHECK: linalg.max
+  // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<?x?x?xf32>, memref<?x?x?xf32>)
+  // CHECK-SAME: outs(%{{.+}} : memref<?x?x?xf32>)
+  linalg.max ins(%arg0, %arg1 : memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg2: memref<?x?x?xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @max_static
+func.func @max_static(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf32>, %arg2: memref<4x8x16xf32>) {
+  // CHECK: linalg.max
+  // CHECK-SAME: ins(%{{.+}}, %{{.+}} : memref<4x8x16xf32>, memref<4x8x16xf32>)
+  // CHECK-SAME: outs(%{{.+}} : memref<4x8x16xf32>)
+  linalg.max ins(%arg0, %arg1 : memref<4x8x16xf32>, memref<4x8x16xf32>) outs(%arg2: memref<4x8x16xf32>)
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func @max_tensor
+func.func @max_tensor(%arg0: tensor<4x8x16xf32>, %arg1: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
+  %0 = tensor.empty() : tensor<4x8x16xf32>
+  // CHECK: linalg.max
+  // CHECK-SAME: ins(%{{.+}}, %{{.+}} : tensor<4x8x16xf32>, tensor<4x8x16xf32>)
+  // CHECK-SAME: outs(%{{.+}} : tensor<4x8x16xf32>)
+  %1 = linalg.max ins(%arg0, %arg1 : tensor<4x8x16xf32>, tensor<4x8x16xf32>) outs(%0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
+  return %1 : tensor<4x8x16xf32>
+}
