
[MLIR][Linalg] More Linalg named ops #90236


Merged: 7 commits into llvm:main from linalg-named-ops on Apr 28, 2024

Conversation

rengolin
Member

Adding `min`, which was already implemented but not exposed.

Adding a few additional unary ops:

  • Reciprocal as arith.divf(1, arg)
  • Round as math.round(arg)
  • Sqrt as math.sqrt(arg)
  • Rsqrt as math.rsqrt(arg)
  • Square as math.powf(arg, 2)
  • TanH as math.tanh(arg)

All with the semantics agreed at the round table: no implicit broadcast or type cast.

A small update to the builder to get a numeric attribute (and an update of the zero/one functions to use it).

This is just the first step. Soon after this we'll add explicit type casts, select, clamp, square difference, and generic pow.

After that, we'll discuss the semantics of linalg.softmax and how it should be lowered to existing named ops.
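To make the "no implicit broadcast/type cast" semantics concrete, here is a minimal usage sketch (mine; the function name and shapes are illustrative, not from the patch) combining an explicit `linalg.broadcast` with the new `linalg.min` and `linalg.sqrt`:

```mlir
// Hypothetical example: broadcasts must be materialized before the named op.
func.func @min_sqrt_example(%lhs: tensor<4x8xf32>, %vec: tensor<8xf32>,
                            %init: tensor<4x8xf32>) -> tensor<4x8xf32> {
  // Explicit broadcast of the rank-1 operand to the rank-2 shape.
  %bcast = linalg.broadcast ins(%vec : tensor<8xf32>)
             outs(%init : tensor<4x8xf32>) dimensions = [0]
  // Elementwise signed min; shapes and element types must match exactly.
  %min = linalg.min ins(%lhs, %bcast : tensor<4x8xf32>, tensor<4x8xf32>)
           outs(%init : tensor<4x8xf32>) -> tensor<4x8xf32>
  // One of the new unary ops, applied elementwise with no casting.
  %sqrt = linalg.sqrt ins(%min : tensor<4x8xf32>)
            outs(%init : tensor<4x8xf32>) -> tensor<4x8xf32>
  return %sqrt : tensor<4x8xf32>
}
```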

@rengolin rengolin requested review from jpienaar and joker-eph April 26, 2024 17:13
@llvmbot llvmbot added the mlir:core (MLIR Core Infrastructure), mlir:linalg, mlir:python (MLIR Python bindings), and mlir labels on Apr 26, 2024
@llvmbot
Member

llvmbot commented Apr 26, 2024

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Renato Golin (rengolin)


Patch is 35.95 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/90236.diff

10 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td (+7-1)
  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml (+260-1)
  • (modified) mlir/include/mlir/IR/Builders.h (+5-1)
  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+20)
  • (modified) mlir/lib/IR/Builders.cpp (+9-19)
  • (modified) mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py (+5)
  • (modified) mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py (+79)
  • (modified) mlir/test/Dialect/Linalg/generalize-named-ops.mlir (+155)
  • (modified) mlir/test/Dialect/Linalg/named-ops-fail.mlir (+112)
  • (modified) mlir/test/Dialect/Linalg/named-ops.mlir (+220)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
index 59f909aed8f61a..7a350d2c014262 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
@@ -22,7 +22,13 @@ def UnaryFn : I32EnumAttr<"UnaryFn", "", [
   I32EnumAttrCase<"abs", 2>,
   I32EnumAttrCase<"ceil", 3>,
   I32EnumAttrCase<"floor", 4>,
-  I32EnumAttrCase<"negf", 5>
+  I32EnumAttrCase<"negf", 5>,
+  I32EnumAttrCase<"reciprocal", 6>,
+  I32EnumAttrCase<"round", 7>,
+  I32EnumAttrCase<"sqrt", 8>,
+  I32EnumAttrCase<"rsqrt", 9>,
+  I32EnumAttrCase<"square", 10>,
+  I32EnumAttrCase<"tanh", 11>
 ]> {
   let genSpecializedAttr = 0;
   let cppNamespace = "::mlir::linalg";
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
index 1ff6c4086cf357..b7567577347587 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml
@@ -304,6 +304,216 @@ structured_op: !LinalgStructuredOpConfig
         - !ScalarExpression
           scalar_arg: I
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: reciprocal
+  cpp_class_name: ReciprocalOp
+  doc: |-
+    Applies reciprocal(x) elementwise.
+
+    No numeric casting is performed on the input operand.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+  iterator_types: []
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: unary
+        fn_name: reciprocal
+        operands:
+        - !ScalarExpression
+          scalar_arg: I
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: round
+  cpp_class_name: RoundOp
+  doc: |-
+    Applies round(x) elementwise.
+
+    No numeric casting is performed on the input operand.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+  iterator_types: []
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: unary
+        fn_name: round
+        operands:
+        - !ScalarExpression
+          scalar_arg: I
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: sqrt
+  cpp_class_name: SqrtOp
+  doc: |-
+    Applies sqrt(x) elementwise.
+
+    No numeric casting is performed on the input operand.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+  iterator_types: []
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: unary
+        fn_name: sqrt
+        operands:
+        - !ScalarExpression
+          scalar_arg: I
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: rsqrt
+  cpp_class_name: RsqrtOp
+  doc: |-
+    Applies rsqrt(x) elementwise.
+
+    No numeric casting is performed on the input operand.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+  iterator_types: []
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: unary
+        fn_name: rsqrt
+        operands:
+        - !ScalarExpression
+          scalar_arg: I
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: square
+  cpp_class_name: SquareOp
+  doc: |-
+    Applies square(x) elementwise.
+
+    No numeric casting is performed on the input operand.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+  iterator_types: []
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: unary
+        fn_name: square
+        operands:
+        - !ScalarExpression
+          scalar_arg: I
+--- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: tanh
+  cpp_class_name: TanhOp
+  doc: |-
+    Applies tanh(x) elementwise.
+
+    No numeric casting is performed on the input operand.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: I
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+  iterator_types: []
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: unary
+        fn_name: tanh
+        operands:
+        - !ScalarExpression
+          scalar_arg: I
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: elemwise_binary
   cpp_class_name: ElemwiseBinaryOp
@@ -625,7 +835,7 @@ metadata: !LinalgOpMetadata
 
     This means reduction/broadcast/element cast semantics is explicit. Further
     passes can take that into account when lowering this code. For example,
-    a `linalg.broadcast` + `linalg.div` sequence can be lowered to a
+    a `linalg.broadcast` + `linalg.max` sequence can be lowered to a
     `linalg.generic` with different affine maps for the two operands.
 structured_op: !LinalgStructuredOpConfig
   args:
@@ -663,6 +873,55 @@ structured_op: !LinalgStructuredOpConfig
         - !ScalarExpression
           scalar_arg: rhs
 --- !LinalgOpConfig
+metadata: !LinalgOpMetadata
+  name: min
+  cpp_class_name: MinOp
+  doc: |-
+    Takes the min (signed) between two inputs, elementwise.
+
+    The shapes and element types must be identical. The appropriate casts,
+    broadcasts and reductions should be done previously to calling this op.
+
+    This means reduction/broadcast/element cast semantics is explicit. Further
+    passes can take that into account when lowering this code. For example,
+    a `linalg.broadcast` + `linalg.min` sequence can be lowered to a
+    `linalg.generic` with different affine maps for the two operands.
+structured_op: !LinalgStructuredOpConfig
+  args:
+  - !LinalgOperandDefConfig
+    name: lhs
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: rhs
+    kind: input_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  - !LinalgOperandDefConfig
+    name: O
+    kind: output_tensor
+    type_var: T1
+    shape_map: affine_map<() -> ()>
+  indexing_maps: !LinalgIndexingMapsConfig
+    static_indexing_maps:
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+    - affine_map<() -> ()>
+  iterator_types: []
+  assignments:
+  - !ScalarAssign
+    arg: O
+    value: !ScalarExpression
+      scalar_fn:
+        kind: binary
+        fn_name: min_signed
+        operands:
+        - !ScalarExpression
+          scalar_arg: lhs
+        - !ScalarExpression
+          scalar_arg: rhs
+--- !LinalgOpConfig
 metadata: !LinalgOpMetadata
   name: matmul
   cpp_class_name: MatmulOp
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 0d5fa719d0dee2..308c2cd38196cb 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -115,12 +115,16 @@ class Builder {
   ArrayAttr getArrayAttr(ArrayRef<Attribute> value);
 
   // Returns a 0-valued attribute of the given `type`. This function only
-  // supports boolean, integer, and 16-/32-/64-bit float types, and vector or
+  // supports integer, and 16-/32-/64-bit float types, and vector or
   // ranked tensor of them. Returns null attribute otherwise.
   TypedAttr getZeroAttr(Type type);
   // Returns a 1-valued attribute of the given `type`.
   // Type constraints are the same as `getZeroAttr`.
   TypedAttr getOneAttr(Type type);
+  // Returns a numeric attribute of the given `type`.
+  // Type constraints are the same as `getZeroAttr`.
+  // For non-float types, the value is converted before returning the attribute.
+  TypedAttr getNumberAttr(double value, Type type);
 
   // Convenience methods for fixed types.
   FloatAttr getF16FloatAttr(float value);
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 9c5c58fa1fabfb..6ee5864656cf64 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -395,6 +395,26 @@ class RegionBuilderHelper {
       return builder.create<math::FloorOp>(arg.getLoc(), arg);
     case UnaryFn::negf:
       return builder.create<arith::NegFOp>(arg.getLoc(), arg);
+    case UnaryFn::reciprocal: {
+      Attribute oneAttr = builder.getNumberAttr(1.0, arg.getType());
+      auto one = builder.create<arith::ConstantOp>(arg.getLoc(),
+                                                   ::cast<TypedAttr>(oneAttr));
+      return builder.create<arith::DivFOp>(arg.getLoc(), one, arg);
+    }
+    case UnaryFn::round:
+      return builder.create<math::RoundOp>(arg.getLoc(), arg);
+    case UnaryFn::sqrt:
+      return builder.create<math::SqrtOp>(arg.getLoc(), arg);
+    case UnaryFn::rsqrt:
+      return builder.create<math::RsqrtOp>(arg.getLoc(), arg);
+    case UnaryFn::square: {
+      Attribute twoAttr = builder.getNumberAttr(2.0, arg.getType());
+      auto two = builder.create<arith::ConstantOp>(arg.getLoc(),
+                                                   ::cast<TypedAttr>(twoAttr));
+      return builder.create<math::PowFOp>(arg.getLoc(), arg, two);
+    }
+    case UnaryFn::tanh:
+      return builder.create<math::TanhOp>(arg.getLoc(), arg);
     }
     llvm_unreachable("unsupported unary function");
   }
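For illustration (not part of the patch), a hedged sketch of the IR the `square` case above should produce once the named op is generalized, mirroring the tests further down; the shapes are assumed:

```mlir
// Sketch: generalized form of linalg.square on memref<8xf32>.
#map = affine_map<(d0) -> (d0)>
func.func @square_sketch(%arg: memref<8xf32>, %out: memref<8xf32>) {
  // Constant materialized via getNumberAttr(2.0, f32).
  %two = arith.constant 2.000000e+00 : f32
  linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]}
      ins(%arg : memref<8xf32>) outs(%out : memref<8xf32>) {
  ^bb0(%in: f32, %o: f32):
    %sq = math.powf %in, %two : f32
    linalg.yield %sq : f32
  }
  return
}
```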
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index d49f69a7b7ae6b..ae8348c8454e9e 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -329,34 +329,24 @@ ArrayAttr Builder::getAffineMapArrayAttr(ArrayRef<AffineMap> values) {
 }
 
 TypedAttr Builder::getZeroAttr(Type type) {
-  if (llvm::isa<FloatType>(type))
-    return getFloatAttr(type, 0.0);
-  if (llvm::isa<IndexType>(type))
-    return getIndexAttr(0);
-  if (llvm::dyn_cast<IntegerType>(type))
-    return getIntegerAttr(type,
-                          APInt(llvm::cast<IntegerType>(type).getWidth(), 0));
-  if (llvm::isa<RankedTensorType, VectorType>(type)) {
-    auto vtType = llvm::cast<ShapedType>(type);
-    auto element = getZeroAttr(vtType.getElementType());
-    if (!element)
-      return {};
-    return DenseElementsAttr::get(vtType, element);
-  }
-  return {};
+  return getNumberAttr(0.0, type);
 }
 
 TypedAttr Builder::getOneAttr(Type type) {
+  return getNumberAttr(1.0, type);
+}
+
+TypedAttr Builder::getNumberAttr(double value, Type type) {
   if (llvm::isa<FloatType>(type))
-    return getFloatAttr(type, 1.0);
+    return getFloatAttr(type, value);
   if (llvm::isa<IndexType>(type))
-    return getIndexAttr(1);
+    return getIndexAttr(static_cast<int64_t>(value));
   if (llvm::dyn_cast<IntegerType>(type))
     return getIntegerAttr(type,
-                          APInt(llvm::cast<IntegerType>(type).getWidth(), 1));
+                          APInt(llvm::cast<IntegerType>(type).getWidth(),
+                                static_cast<int64_t>(value)));
   if (llvm::isa<RankedTensorType, VectorType>(type)) {
     auto vtType = llvm::cast<ShapedType>(type);
-    auto element = getOneAttr(vtType.getElementType());
+    auto element = getNumberAttr(value, vtType.getElementType());
     if (!element)
       return {};
     return DenseElementsAttr::get(vtType, element);
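As a sketch of what the new `getNumberAttr` covers (the exact textual constants are my assumption), a single entry point now handles the cases the zero/one helpers used to duplicate:

```mlir
// IR fragments for constants built from getNumberAttr(2.0, type):
%f   = arith.constant 2.000000e+00 : f32                   // FloatType
%i   = arith.constant 2 : i32                              // IntegerType: value cast to int
%idx = arith.constant 2 : index                            // IndexType
%v   = arith.constant dense<2.000000e+00> : vector<4xf32>  // splat for vector/tensor
```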
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
index 23d6d26b7e294c..f7bc81bd2f6833 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
@@ -291,6 +291,11 @@ class UnaryFn:
     ceil = UnaryFnType("ceil")
     floor = UnaryFnType("floor")
     negf = UnaryFnType("negf")
+    round = UnaryFnType("round")
+    sqrt = UnaryFnType("sqrt")
+    rsqrt = UnaryFnType("rsqrt")
+    square = UnaryFnType("square")
+    tanh = UnaryFnType("tanh")
 
 
 class BinaryFnType:
diff --git a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
index 5b05364f6d35f3..c97ac448006505 100644
--- a/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
+++ b/mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
@@ -108,6 +108,66 @@ def negf(
     O[None] = UnaryFn.negf(I[None])
 
 
+@linalg_structured_op
+def round(
+    I=TensorDef(T1),
+    O=TensorDef(T1, output=True),
+):
+    """Applies round(x) elementwise.
+
+    No numeric casting is performed on the input operand.
+    """
+    O[None] = UnaryFn.round(I[None])
+
+
+@linalg_structured_op
+def sqrt(
+    I=TensorDef(T1),
+    O=TensorDef(T1, output=True),
+):
+    """Applies sqrt(x) elementwise.
+
+    No numeric casting is performed on the input operand.
+    """
+    O[None] = UnaryFn.sqrt(I[None])
+
+
+@linalg_structured_op
+def rsqrt(
+    I=TensorDef(T1),
+    O=TensorDef(T1, output=True),
+):
+    """Applies rsqrt(x) elementwise.
+
+    No numeric casting is performed on the input operand.
+    """
+    O[None] = UnaryFn.rsqrt(I[None])
+
+
+@linalg_structured_op
+def square(
+    I=TensorDef(T1),
+    O=TensorDef(T1, output=True),
+):
+    """Applies square(x) elementwise.
+
+    No numeric casting is performed on the input operand.
+    """
+    O[None] = UnaryFn.square(I[None])
+
+
+@linalg_structured_op
+def tanh(
+    I=TensorDef(T1),
+    O=TensorDef(T1, output=True),
+):
+    """Applies tanh(x) elementwise.
+
+    No numeric casting is performed on the input operand.
+    """
+    O[None] = UnaryFn.tanh(I[None])
+
+
 @linalg_structured_op
 def elemwise_binary(
     lhs=TensorDef(T1),
@@ -239,6 +299,25 @@ def max(
     O[None] = BinaryFn.max_signed(lhs[None], rhs[None])
 
 
+@linalg_structured_op
+def min(
+    lhs=TensorDef(T1),
+    rhs=TensorDef(T1),
+    O=TensorDef(T1, output=True),
+):
+    """Takes the min (signed) between two inputs, elementwise.
+
+    The shapes and element types must be identical. The appropriate casts,
+    broadcasts and reductions should be done previously to calling this op.
+
+    This means reduction/broadcast/element cast semantics is explicit. Further
+    passes can take that into account when lowering this code. For example,
+    a `linalg.broadcast` + `linalg.min` sequence can be lowered to a
+    `linalg.generic` with different affine maps for the two operands.
+    """
+    O[None] = BinaryFn.min_signed(lhs[None], rhs[None])
+
+
 @linalg_structured_op
 def matmul(
     A=TensorDef(T1, S.M, S.K),
diff --git a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
index e852824cdb7367..160d1e9536bb5b 100644
--- a/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
+++ b/mlir/test/Dialect/Linalg/generalize-named-ops.mlir
@@ -565,6 +565,136 @@ func.func @generalize_negf(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>)
 
 // -----
 
+func.func @generalize_reciprocal(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
+  linalg.reciprocal ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
+  return
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK: func @generalize_reciprocal
+// CHECK-SAME: (%[[ARG:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)
+
+// CHECK: %[[one:.+]] = arith.constant 1.000000e+00 : f32
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME:  ins(%[[ARG]] : memref<7x14x21xf32>) outs(%[[OUT]] : memref<7x14x21xf32>)
+
+// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
+// CHECK-NEXT:      %[[reciprocal:.+]] = arith.divf %[[one]], %[[BBARG0]] : f32
+// CHECK-NEXT:      linalg.yield %[[reciprocal]] : f32
+
+// -----
+
+func.func @generalize_round(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
+  linalg.round ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
+  return
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK: func @generalize_round
+// CHECK-SAME: (%[[ARG:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME:  ins(%[[ARG]] : memref<7x14x21xf32>) outs(%[[OUT]] : memref<7x14x21xf32>)
+
+// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
+// CHECK-NEXT:      %[[round:.+]] = math.round %[[BBARG0]] : f32
+// CHECK-NEXT:      linalg.yield %[[round]] : f32
+
+// -----
+
+func.func @generalize_sqrt(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
+  linalg.sqrt ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
+  return
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK: func @generalize_sqrt
+// CHECK-SAME: (%[[ARG:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME:  ins(%[[ARG]] : memref<7x14x21xf32>) outs(%[[OUT]] : memref<7x14x21xf32>)
+
+// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
+// CHECK-NEXT:      %[[sqrt:.+]] = math.sqrt %[[BBARG0]] : f32
+// CHECK-NEXT:      linalg.yield %[[sqrt]] : f32
+
+// -----
+
+func.func @generalize_rsqrt(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
+  linalg.rsqrt ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
+  return
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK: func @generalize_rsqrt
+// CHECK-SAME: (%[[ARG:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
+// CHECK-SAME:  ins(%[[ARG]] : memref<7x14x21xf32>) outs(%[[OUT]] : memref<7x14x21xf32>)
+
+// CHECK:         ^{{.+}}(%[[BBARG0:.+]]: f32, %[[BBARG1:.+]]: f32)
+// CHECK-NEXT:      %[[rsqrt:.+]] = math.rsqrt %[[BBARG0]] : f32
+// CHECK-NEXT:      linalg.yield %[[rsqrt]] : f32
+
+// -----
+
+func.func @generalize_square(%arg: memref<7x14x21xf32>, %out: memref<7x14x21xf32>) {
+  linalg.square ins(%arg : memref<7x14x21xf32>) outs(%out : memref<7x14x21xf32>)
+  return
+}
+
+// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
+
+// CHECK: func @generalize_square
+// CHECK-SAME: (%[[ARG:.+]]: memref<7x14x21xf32>, %[[OUT:.+]]: memref<7x14x21xf32>)
+
+// CHECK: %[[two:.+]] = arith.constant 2.000000e+00 : f32
+
+// CHECK: linalg.generic
+// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]]]
+// CHECK-SAME: iterator_types = ["...
[truncated]
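The remaining test updates are cut off above. As a rough, hypothetical sketch (not the truncated content itself), the round-trip tests in `named-ops.mlir` for these ops follow this shape:

```mlir
// Sketch of a round-trip test for a new unary named op (shapes assumed).
// CHECK-LABEL: func @rsqrt_tensor
func.func @rsqrt_tensor(%arg: tensor<4x8xf32>, %init: tensor<4x8xf32>) -> tensor<4x8xf32> {
  // CHECK: linalg.rsqrt
  %0 = linalg.rsqrt ins(%arg : tensor<4x8xf32>)
         outs(%init : tensor<4x8xf32>) -> tensor<4x8xf32>
  return %0 : tensor<4x8xf32>
}
```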


github-actions bot commented Apr 26, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@stellaraccident
Contributor

I wouldn't object at some point to someone working out how to split the implementation C++ file as we add more variants. These big op libraries bring a fair amount of single-threaded compile time.

@rengolin rengolin merged commit 4cec3b3 into llvm:main Apr 28, 2024
@rengolin rengolin deleted the linalg-named-ops branch April 28, 2024 14:25
@rengolin rengolin restored the linalg-named-ops branch April 28, 2024 14:25
rengolin added a commit that referenced this pull request May 14, 2024
Following #90236, adding `select` to linalg as `arith.select`. No
implicit type casting.
OpDSL doesn't expose a type restriction for bool, but I saw no reason to add one (that would mean putting in a separate symbolic type and checking the semantics in the builder).

---------

Co-authored-by: Renato Golin <[email protected]>
Co-authored-by: Maksim Levental <[email protected]>
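As a hedged sketch of what that follow-up adds (operand order and shapes assumed, mirroring `arith.select` with the condition first):

```mlir
// Sketch: elementwise select with an i1 condition tensor, no implicit casts.
%0 = linalg.select ins(%cond, %a, %b : tensor<8xi1>, tensor<8xf32>, tensor<8xf32>)
       outs(%init : tensor<8xf32>) -> tensor<8xf32>
```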
Labels: mlir:core (MLIR Core Infrastructure), mlir:linalg, mlir:python (MLIR Python bindings), mlir
5 participants