Skip to content

[MLIR][Linalg] More Linalg named ops #90236

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Apr 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,13 @@ def UnaryFn : I32EnumAttr<"UnaryFn", "", [
I32EnumAttrCase<"abs", 2>,
I32EnumAttrCase<"ceil", 3>,
I32EnumAttrCase<"floor", 4>,
I32EnumAttrCase<"negf", 5>
I32EnumAttrCase<"negf", 5>,
I32EnumAttrCase<"reciprocal", 6>,
I32EnumAttrCase<"round", 7>,
I32EnumAttrCase<"sqrt", 8>,
I32EnumAttrCase<"rsqrt", 9>,
I32EnumAttrCase<"square", 10>,
I32EnumAttrCase<"tanh", 11>
]> {
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::linalg";
Expand Down
261 changes: 260 additions & 1 deletion mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,216 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: I
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: reciprocal
cpp_class_name: ReciprocalOp
doc: |-
Applies reciprocal(x) elementwise.

No numeric casting is performed on the input operand.
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: I
kind: input_tensor
type_var: T1
shape_map: affine_map<() -> ()>
- !LinalgOperandDefConfig
name: O
kind: output_tensor
type_var: T1
shape_map: affine_map<() -> ()>
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<() -> ()>
- affine_map<() -> ()>
iterator_types: []
assignments:
- !ScalarAssign
arg: O
value: !ScalarExpression
scalar_fn:
kind: unary
fn_name: reciprocal
operands:
- !ScalarExpression
scalar_arg: I
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: round
cpp_class_name: RoundOp
doc: |-
Applies round(x) elementwise.

No numeric casting is performed on the input operand.
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: I
kind: input_tensor
type_var: T1
shape_map: affine_map<() -> ()>
- !LinalgOperandDefConfig
name: O
kind: output_tensor
type_var: T1
shape_map: affine_map<() -> ()>
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<() -> ()>
- affine_map<() -> ()>
iterator_types: []
assignments:
- !ScalarAssign
arg: O
value: !ScalarExpression
scalar_fn:
kind: unary
fn_name: round
operands:
- !ScalarExpression
scalar_arg: I
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: sqrt
cpp_class_name: SqrtOp
doc: |-
Applies sqrt(x) elementwise.

No numeric casting is performed on the input operand.
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: I
kind: input_tensor
type_var: T1
shape_map: affine_map<() -> ()>
- !LinalgOperandDefConfig
name: O
kind: output_tensor
type_var: T1
shape_map: affine_map<() -> ()>
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<() -> ()>
- affine_map<() -> ()>
iterator_types: []
assignments:
- !ScalarAssign
arg: O
value: !ScalarExpression
scalar_fn:
kind: unary
fn_name: sqrt
operands:
- !ScalarExpression
scalar_arg: I
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: rsqrt
cpp_class_name: RsqrtOp
doc: |-
Applies rsqrt(x) elementwise.

No numeric casting is performed on the input operand.
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: I
kind: input_tensor
type_var: T1
shape_map: affine_map<() -> ()>
- !LinalgOperandDefConfig
name: O
kind: output_tensor
type_var: T1
shape_map: affine_map<() -> ()>
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<() -> ()>
- affine_map<() -> ()>
iterator_types: []
assignments:
- !ScalarAssign
arg: O
value: !ScalarExpression
scalar_fn:
kind: unary
fn_name: rsqrt
operands:
- !ScalarExpression
scalar_arg: I
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: square
cpp_class_name: SquareOp
doc: |-
Applies square(x) elementwise.

No numeric casting is performed on the input operand.
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: I
kind: input_tensor
type_var: T1
shape_map: affine_map<() -> ()>
- !LinalgOperandDefConfig
name: O
kind: output_tensor
type_var: T1
shape_map: affine_map<() -> ()>
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<() -> ()>
- affine_map<() -> ()>
iterator_types: []
assignments:
- !ScalarAssign
arg: O
value: !ScalarExpression
scalar_fn:
kind: unary
fn_name: square
operands:
- !ScalarExpression
scalar_arg: I
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: tanh
cpp_class_name: TanhOp
doc: |-
Applies tanh(x) elementwise.

No numeric casting is performed on the input operand.
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: I
kind: input_tensor
type_var: T1
shape_map: affine_map<() -> ()>
- !LinalgOperandDefConfig
name: O
kind: output_tensor
type_var: T1
shape_map: affine_map<() -> ()>
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<() -> ()>
- affine_map<() -> ()>
iterator_types: []
assignments:
- !ScalarAssign
arg: O
value: !ScalarExpression
scalar_fn:
kind: unary
fn_name: tanh
operands:
- !ScalarExpression
scalar_arg: I
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: elemwise_binary
cpp_class_name: ElemwiseBinaryOp
Expand Down Expand Up @@ -625,7 +835,7 @@ metadata: !LinalgOpMetadata

This means reduction/broadcast/element cast semantics is explicit. Further
passes can take that into account when lowering this code. For example,
a `linalg.broadcast` + `linalg.div` sequence can be lowered to a
a `linalg.broadcast` + `linalg.max` sequence can be lowered to a
`linalg.generic` with different affine maps for the two operands.
structured_op: !LinalgStructuredOpConfig
args:
Expand Down Expand Up @@ -663,6 +873,55 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: rhs
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: min
cpp_class_name: MinOp
doc: |-
Takes the min (signed) between two inputs, elementwise.

The shapes and element types must be identical. The appropriate casts,
broadcasts and reductions should be done previously to calling this op.

This means reduction/broadcast/element cast semantics is explicit. Further
passes can take that into account when lowering this code. For example,
a `linalg.broadcast` + `linalg.min` sequence can be lowered to a
`linalg.generic` with different affine maps for the two operands.
structured_op: !LinalgStructuredOpConfig
args:
- !LinalgOperandDefConfig
name: lhs
kind: input_tensor
type_var: T1
shape_map: affine_map<() -> ()>
- !LinalgOperandDefConfig
name: rhs
kind: input_tensor
type_var: T1
shape_map: affine_map<() -> ()>
- !LinalgOperandDefConfig
name: O
kind: output_tensor
type_var: T1
shape_map: affine_map<() -> ()>
indexing_maps: !LinalgIndexingMapsConfig
static_indexing_maps:
- affine_map<() -> ()>
- affine_map<() -> ()>
- affine_map<() -> ()>
iterator_types: []
assignments:
- !ScalarAssign
arg: O
value: !ScalarExpression
scalar_fn:
kind: binary
fn_name: min_signed
operands:
- !ScalarExpression
scalar_arg: lhs
- !ScalarExpression
scalar_arg: rhs
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: matmul
cpp_class_name: MatmulOp
Expand Down
16 changes: 16 additions & 0 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,22 @@ class RegionBuilderHelper {
return builder.create<math::FloorOp>(arg.getLoc(), arg);
case UnaryFn::negf:
return builder.create<arith::NegFOp>(arg.getLoc(), arg);
case UnaryFn::reciprocal: {
Attribute oneAttr = builder.getOneAttr(arg.getType());
auto one = builder.create<arith::ConstantOp>(arg.getLoc(),
::cast<TypedAttr>(oneAttr));
return builder.create<arith::DivFOp>(arg.getLoc(), one, arg);
}
case UnaryFn::round:
return builder.create<math::RoundOp>(arg.getLoc(), arg);
case UnaryFn::sqrt:
return builder.create<math::SqrtOp>(arg.getLoc(), arg);
case UnaryFn::rsqrt:
return builder.create<math::RsqrtOp>(arg.getLoc(), arg);
case UnaryFn::square:
return builder.create<arith::MulFOp>(arg.getLoc(), arg, arg);
case UnaryFn::tanh:
return builder.create<math::TanhOp>(arg.getLoc(), arg);
}
llvm_unreachable("unsupported unary function");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,11 @@ class UnaryFn:
ceil = UnaryFnType("ceil")
floor = UnaryFnType("floor")
negf = UnaryFnType("negf")
round = UnaryFnType("round")
sqrt = UnaryFnType("sqrt")
rsqrt = UnaryFnType("rsqrt")
square = UnaryFnType("square")
tanh = UnaryFnType("tanh")


class BinaryFnType:
Expand Down
Loading