Skip to content

Commit 4cec3b3

Browse files
authored
[MLIR][Linalg] More Linalg named ops (#90236)
Adding `min` that was already implemented but not exposed. Adding a few additional unary ops: * Reciprocal as `arith.div(1,arg)` * Round as `math.round(arg)` * Sqrt as `math.sqrt(arg)` * Rsqrt as `math.rsqrt(arg)` * Square as `math.powf(arg, 2)` * TanH as `math.tanh(arg)` All with the agreed semantics at the round table: no implicit broadcast/type cast.
1 parent dc6ce60 commit 4cec3b3

File tree

8 files changed

+853
-3
lines changed

8 files changed

+853
-3
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,13 @@ def UnaryFn : I32EnumAttr<"UnaryFn", "", [
2222
I32EnumAttrCase<"abs", 2>,
2323
I32EnumAttrCase<"ceil", 3>,
2424
I32EnumAttrCase<"floor", 4>,
25-
I32EnumAttrCase<"negf", 5>
25+
I32EnumAttrCase<"negf", 5>,
26+
I32EnumAttrCase<"reciprocal", 6>,
27+
I32EnumAttrCase<"round", 7>,
28+
I32EnumAttrCase<"sqrt", 8>,
29+
I32EnumAttrCase<"rsqrt", 9>,
30+
I32EnumAttrCase<"square", 10>,
31+
I32EnumAttrCase<"tanh", 11>
2632
]> {
2733
let genSpecializedAttr = 0;
2834
let cppNamespace = "::mlir::linalg";

mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml

Lines changed: 260 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,216 @@ structured_op: !LinalgStructuredOpConfig
304304
- !ScalarExpression
305305
scalar_arg: I
306306
--- !LinalgOpConfig
307+
metadata: !LinalgOpMetadata
308+
name: reciprocal
309+
cpp_class_name: ReciprocalOp
310+
doc: |-
311+
Applies reciprocal(x) elementwise.
312+
313+
No numeric casting is performed on the input operand.
314+
structured_op: !LinalgStructuredOpConfig
315+
args:
316+
- !LinalgOperandDefConfig
317+
name: I
318+
kind: input_tensor
319+
type_var: T1
320+
shape_map: affine_map<() -> ()>
321+
- !LinalgOperandDefConfig
322+
name: O
323+
kind: output_tensor
324+
type_var: T1
325+
shape_map: affine_map<() -> ()>
326+
indexing_maps: !LinalgIndexingMapsConfig
327+
static_indexing_maps:
328+
- affine_map<() -> ()>
329+
- affine_map<() -> ()>
330+
iterator_types: []
331+
assignments:
332+
- !ScalarAssign
333+
arg: O
334+
value: !ScalarExpression
335+
scalar_fn:
336+
kind: unary
337+
fn_name: reciprocal
338+
operands:
339+
- !ScalarExpression
340+
scalar_arg: I
341+
--- !LinalgOpConfig
342+
metadata: !LinalgOpMetadata
343+
name: round
344+
cpp_class_name: RoundOp
345+
doc: |-
346+
Applies round(x) elementwise.
347+
348+
No numeric casting is performed on the input operand.
349+
structured_op: !LinalgStructuredOpConfig
350+
args:
351+
- !LinalgOperandDefConfig
352+
name: I
353+
kind: input_tensor
354+
type_var: T1
355+
shape_map: affine_map<() -> ()>
356+
- !LinalgOperandDefConfig
357+
name: O
358+
kind: output_tensor
359+
type_var: T1
360+
shape_map: affine_map<() -> ()>
361+
indexing_maps: !LinalgIndexingMapsConfig
362+
static_indexing_maps:
363+
- affine_map<() -> ()>
364+
- affine_map<() -> ()>
365+
iterator_types: []
366+
assignments:
367+
- !ScalarAssign
368+
arg: O
369+
value: !ScalarExpression
370+
scalar_fn:
371+
kind: unary
372+
fn_name: round
373+
operands:
374+
- !ScalarExpression
375+
scalar_arg: I
376+
--- !LinalgOpConfig
377+
metadata: !LinalgOpMetadata
378+
name: sqrt
379+
cpp_class_name: SqrtOp
380+
doc: |-
381+
Applies sqrt(x) elementwise.
382+
383+
No numeric casting is performed on the input operand.
384+
structured_op: !LinalgStructuredOpConfig
385+
args:
386+
- !LinalgOperandDefConfig
387+
name: I
388+
kind: input_tensor
389+
type_var: T1
390+
shape_map: affine_map<() -> ()>
391+
- !LinalgOperandDefConfig
392+
name: O
393+
kind: output_tensor
394+
type_var: T1
395+
shape_map: affine_map<() -> ()>
396+
indexing_maps: !LinalgIndexingMapsConfig
397+
static_indexing_maps:
398+
- affine_map<() -> ()>
399+
- affine_map<() -> ()>
400+
iterator_types: []
401+
assignments:
402+
- !ScalarAssign
403+
arg: O
404+
value: !ScalarExpression
405+
scalar_fn:
406+
kind: unary
407+
fn_name: sqrt
408+
operands:
409+
- !ScalarExpression
410+
scalar_arg: I
411+
--- !LinalgOpConfig
412+
metadata: !LinalgOpMetadata
413+
name: rsqrt
414+
cpp_class_name: RsqrtOp
415+
doc: |-
416+
Applies rsqrt(x) elementwise.
417+
418+
No numeric casting is performed on the input operand.
419+
structured_op: !LinalgStructuredOpConfig
420+
args:
421+
- !LinalgOperandDefConfig
422+
name: I
423+
kind: input_tensor
424+
type_var: T1
425+
shape_map: affine_map<() -> ()>
426+
- !LinalgOperandDefConfig
427+
name: O
428+
kind: output_tensor
429+
type_var: T1
430+
shape_map: affine_map<() -> ()>
431+
indexing_maps: !LinalgIndexingMapsConfig
432+
static_indexing_maps:
433+
- affine_map<() -> ()>
434+
- affine_map<() -> ()>
435+
iterator_types: []
436+
assignments:
437+
- !ScalarAssign
438+
arg: O
439+
value: !ScalarExpression
440+
scalar_fn:
441+
kind: unary
442+
fn_name: rsqrt
443+
operands:
444+
- !ScalarExpression
445+
scalar_arg: I
446+
--- !LinalgOpConfig
447+
metadata: !LinalgOpMetadata
448+
name: square
449+
cpp_class_name: SquareOp
450+
doc: |-
451+
Applies square(x) elementwise.
452+
453+
No numeric casting is performed on the input operand.
454+
structured_op: !LinalgStructuredOpConfig
455+
args:
456+
- !LinalgOperandDefConfig
457+
name: I
458+
kind: input_tensor
459+
type_var: T1
460+
shape_map: affine_map<() -> ()>
461+
- !LinalgOperandDefConfig
462+
name: O
463+
kind: output_tensor
464+
type_var: T1
465+
shape_map: affine_map<() -> ()>
466+
indexing_maps: !LinalgIndexingMapsConfig
467+
static_indexing_maps:
468+
- affine_map<() -> ()>
469+
- affine_map<() -> ()>
470+
iterator_types: []
471+
assignments:
472+
- !ScalarAssign
473+
arg: O
474+
value: !ScalarExpression
475+
scalar_fn:
476+
kind: unary
477+
fn_name: square
478+
operands:
479+
- !ScalarExpression
480+
scalar_arg: I
481+
--- !LinalgOpConfig
482+
metadata: !LinalgOpMetadata
483+
name: tanh
484+
cpp_class_name: TanhOp
485+
doc: |-
486+
Applies tanh(x) elementwise.
487+
488+
No numeric casting is performed on the input operand.
489+
structured_op: !LinalgStructuredOpConfig
490+
args:
491+
- !LinalgOperandDefConfig
492+
name: I
493+
kind: input_tensor
494+
type_var: T1
495+
shape_map: affine_map<() -> ()>
496+
- !LinalgOperandDefConfig
497+
name: O
498+
kind: output_tensor
499+
type_var: T1
500+
shape_map: affine_map<() -> ()>
501+
indexing_maps: !LinalgIndexingMapsConfig
502+
static_indexing_maps:
503+
- affine_map<() -> ()>
504+
- affine_map<() -> ()>
505+
iterator_types: []
506+
assignments:
507+
- !ScalarAssign
508+
arg: O
509+
value: !ScalarExpression
510+
scalar_fn:
511+
kind: unary
512+
fn_name: tanh
513+
operands:
514+
- !ScalarExpression
515+
scalar_arg: I
516+
--- !LinalgOpConfig
307517
metadata: !LinalgOpMetadata
308518
name: elemwise_binary
309519
cpp_class_name: ElemwiseBinaryOp
@@ -625,7 +835,7 @@ metadata: !LinalgOpMetadata
625835
626836
This means reduction/broadcast/element cast semantics is explicit. Further
627837
passes can take that into account when lowering this code. For example,
628-
a `linalg.broadcast` + `linalg.div` sequence can be lowered to a
838+
a `linalg.broadcast` + `linalg.max` sequence can be lowered to a
629839
`linalg.generic` with different affine maps for the two operands.
630840
structured_op: !LinalgStructuredOpConfig
631841
args:
@@ -663,6 +873,55 @@ structured_op: !LinalgStructuredOpConfig
663873
- !ScalarExpression
664874
scalar_arg: rhs
665875
--- !LinalgOpConfig
876+
metadata: !LinalgOpMetadata
877+
name: min
878+
cpp_class_name: MinOp
879+
doc: |-
880+
Takes the min (signed) between two inputs, elementwise.
881+
882+
The shapes and element types must be identical. The appropriate casts,
883+
broadcasts and reductions should be done previously to calling this op.
884+
885+
This means reduction/broadcast/element cast semantics is explicit. Further
886+
passes can take that into account when lowering this code. For example,
887+
a `linalg.broadcast` + `linalg.min` sequence can be lowered to a
888+
`linalg.generic` with different affine maps for the two operands.
889+
structured_op: !LinalgStructuredOpConfig
890+
args:
891+
- !LinalgOperandDefConfig
892+
name: lhs
893+
kind: input_tensor
894+
type_var: T1
895+
shape_map: affine_map<() -> ()>
896+
- !LinalgOperandDefConfig
897+
name: rhs
898+
kind: input_tensor
899+
type_var: T1
900+
shape_map: affine_map<() -> ()>
901+
- !LinalgOperandDefConfig
902+
name: O
903+
kind: output_tensor
904+
type_var: T1
905+
shape_map: affine_map<() -> ()>
906+
indexing_maps: !LinalgIndexingMapsConfig
907+
static_indexing_maps:
908+
- affine_map<() -> ()>
909+
- affine_map<() -> ()>
910+
- affine_map<() -> ()>
911+
iterator_types: []
912+
assignments:
913+
- !ScalarAssign
914+
arg: O
915+
value: !ScalarExpression
916+
scalar_fn:
917+
kind: binary
918+
fn_name: min_signed
919+
operands:
920+
- !ScalarExpression
921+
scalar_arg: lhs
922+
- !ScalarExpression
923+
scalar_arg: rhs
924+
--- !LinalgOpConfig
666925
metadata: !LinalgOpMetadata
667926
name: matmul
668927
cpp_class_name: MatmulOp

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,22 @@ class RegionBuilderHelper {
395395
return builder.create<math::FloorOp>(arg.getLoc(), arg);
396396
case UnaryFn::negf:
397397
return builder.create<arith::NegFOp>(arg.getLoc(), arg);
398+
case UnaryFn::reciprocal: {
399+
Attribute oneAttr = builder.getOneAttr(arg.getType());
400+
auto one = builder.create<arith::ConstantOp>(arg.getLoc(),
401+
::cast<TypedAttr>(oneAttr));
402+
return builder.create<arith::DivFOp>(arg.getLoc(), one, arg);
403+
}
404+
case UnaryFn::round:
405+
return builder.create<math::RoundOp>(arg.getLoc(), arg);
406+
case UnaryFn::sqrt:
407+
return builder.create<math::SqrtOp>(arg.getLoc(), arg);
408+
case UnaryFn::rsqrt:
409+
return builder.create<math::RsqrtOp>(arg.getLoc(), arg);
410+
case UnaryFn::square:
411+
return builder.create<arith::MulFOp>(arg.getLoc(), arg, arg);
412+
case UnaryFn::tanh:
413+
return builder.create<math::TanhOp>(arg.getLoc(), arg);
398414
}
399415
llvm_unreachable("unsupported unary function");
400416
}

mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,11 @@ class UnaryFn:
291291
ceil = UnaryFnType("ceil")
292292
floor = UnaryFnType("floor")
293293
negf = UnaryFnType("negf")
294+
round = UnaryFnType("round")
295+
sqrt = UnaryFnType("sqrt")
296+
rsqrt = UnaryFnType("rsqrt")
297+
square = UnaryFnType("square")
298+
tanh = UnaryFnType("tanh")
294299

295300

296301
class BinaryFnType:

0 commit comments

Comments
 (0)