[mlir][linalg] regionBuilder for transpose, broadcast #69742


Merged
merged 1 commit into from
Oct 20, 2023

Conversation

makslevental
Contributor

Currently, linalg.transpose and linalg.broadcast can't be emitted through either the C API or the Python bindings (which, of course, go through the C API). See https://discourse.llvm.org/t/how-to-build-linalg-transposeop-in-mlir-pybind/73989/10.

The reason is that even though they're named ops, there is no opdsl @linalg_structured_op for them, and thus, while they can be instantiated, they cannot be passed to mlirLinalgFillBuiltinNamedOpRegion. I believe the issue is that they both take an IndexAttrDef, but IndexAttrDef cannot represent dynamic rank. Note: if I'm mistaken and there is a way to write the @linalg_structured_op, let me know.

The solution here simply implements the regionBuilder interface, which is then picked up by LinalgDialect::addNamedOpBuilders.

Extension classes are added "by hand" that mirror the API of the @linalg_structured_ops. Note: the extension classes are added to dialects/linalg/__init__.py instead of dialects/linalg/opdsl/ops/core_named_ops.py so that they are not mistaken for opdsl generators/emitters.
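For readers new to these ops: the `permutation` attribute that the new `transpose` wrapper takes follows the usual rule that result dimension `i` gets the size of input dimension `permutation[i]`. A minimal pure-Python sketch of that shape rule (a hypothetical helper for illustration only, independent of the MLIR bindings and not part of this PR):

```python
def transpose_shape(in_shape, permutation):
    """Result shape of a transpose: out_shape[i] == in_shape[permutation[i]]."""
    if sorted(permutation) != list(range(len(in_shape))):
        raise ValueError(f"{permutation=} is not a permutation of rank {len(in_shape)}")
    return [in_shape[p] for p in permutation]

# Mirrors the tensor<1x13xf32> -> tensor<13x1xf32> case with permutation = [1, 0].
print(transpose_shape([1, 13], [1, 0]))  # [13, 1]
```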

@llvmbot
Member

llvmbot commented Oct 20, 2023

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Maksim Levental (makslevental)



Full diff: https://github.com/llvm/llvm-project/pull/69742.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td (+14-2)
  • (modified) mlir/python/mlir/dialects/linalg/__init__.py (+50)
  • (modified) mlir/test/python/dialects/linalg/ops.py (+79)
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 21a5e5cc47aeb5c..751edd022883011 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -442,10 +442,16 @@ def TransposeOp : LinalgStructuredBase_Op<"transpose", [
     // Implement functions necessary for DestinationStyleOpInterface.
     MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }
 
+    static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block,
+        mlir::ArrayRef<mlir::NamedAttribute>) {
+      OpBuilder::InsertionGuard guard(b);
+      b.create<linalg::YieldOp>(b.getLoc(), block.getArgument(0));
+    }
+
     static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
         mlir::ArrayRef<mlir::NamedAttribute>)>
       getRegionBuilder() {
-      return nullptr;
+      return regionBuilder;
     }
 
     static void createRegion(::mlir::OpBuilder &opBuilder,
@@ -510,10 +516,16 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
     // Implement functions necessary for DestinationStyleOpInterface.
     MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }
 
+    static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block,
+        mlir::ArrayRef<mlir::NamedAttribute>) {
+      OpBuilder::InsertionGuard guard(b);
+      b.create<linalg::YieldOp>(b.getLoc(), block.getArgument(0));
+    }
+
     static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
         mlir::ArrayRef<mlir::NamedAttribute>)>
       getRegionBuilder() {
-      return nullptr;
+      return regionBuilder;
     }
   }];
 
diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index 1353870ec7257a9..8006950145e08e1 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -55,3 +55,53 @@
 #     TODO: guard against surprises and fail create Runtime Custom Ops with
 #     the same name as existing Core Named Ops.
 from .opdsl.ops.core_named_ops import *
+from .opdsl.lang.emitter import isa
+
+from ...ir import *
+from .._ods_common import get_op_result_or_value as _get_op_result_or_value
+
+
+def transpose(
+    input: Union[Operation, OpView, Sequence[Value]],
+    *,
+    outs: List[Union[Operation, OpView, Sequence[Value]]],
+    permutation: Union[DenseI64ArrayAttr, List[int]],
+):
+
+    input = _get_op_result_or_value(input)
+    if len(outs) > 1:
+        raise ValueError(f"{outs=} must have length 1.")
+    init = _get_op_result_or_value(outs[0])
+    result_types = [init.type] if isa(RankedTensorType, init.type) else []
+
+    op = TransposeOp(
+        result=result_types,
+        input=input,
+        init=init,
+        permutation=permutation,
+    )
+    fill_builtin_region(op.operation)
+    return op
+
+
+def broadcast(
+    input: Union[Operation, OpView, Sequence[Value]],
+    *,
+    outs: List[Union[Operation, OpView, Sequence[Value]]],
+    dimensions: Union[DenseI64ArrayAttr, List[int]],
+):
+
+    input = _get_op_result_or_value(input)
+    if len(outs) > 1:
+        raise ValueError(f"{outs=} must have length 1.")
+    init = _get_op_result_or_value(outs[0])
+    result_types = [init.type] if isa(RankedTensorType, init.type) else []
+
+    op = BroadcastOp(
+        result=result_types,
+        input=input,
+        init=init,
+        dimensions=dimensions,
+    )
+    fill_builtin_region(op.operation)
+    return op
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index b728e0083781492..b147551c2e73dbd 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -157,3 +157,82 @@ def pass_an_op_directly(arg0, arg1):
                 return linalg.matmul(lhs, rhs, outs=init)
 
     print(module)
+
+
+# CHECK-LABEL: TEST: testIdentityRegionOps
+@run
+def testIdentityRegionOps():
+    with Context(), Location.unknown():
+        module = Module.create()
+        f32 = F32Type.get()
+        with InsertionPoint(module.body):
+            # CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<1x13xf32>
+            # CHECK: %[[VAL_1:.*]] = tensor.empty() : tensor<13x1xf32>
+            op1 = tensor.EmptyOp([1, 13], f32)
+            op2 = tensor.EmptyOp([13, 1], f32)
+            # CHECK: %[[VAL_2:.*]] = linalg.transpose ins(%[[VAL_0]] : tensor<1x13xf32>) outs(%[[VAL_1]] : tensor<13x1xf32>) permutation = [1, 0]
+            op3 = linalg.TransposeOp(
+                result=[RankedTensorType.get((13, 1), f32)],
+                input=op1,
+                init=op2,
+                permutation=[1, 0],
+            )
+            linalg.fill_builtin_region(op3.operation)
+
+            # CHECK: %[[VAL_3:.*]] = linalg.transpose ins(%[[VAL_1]] : tensor<13x1xf32>) outs(%[[VAL_0]] : tensor<1x13xf32>) permutation = [1, 0]
+            op4 = linalg.transpose(op2, outs=[op1], permutation=[1, 0])
+
+            # CHECK:         func.func @transpose_op(%[[VAL_4:.*]]: memref<1x13xf32>, %[[VAL_5:.*]]: memref<13x1xf32>)
+            @func.FuncOp.from_py_func(
+                MemRefType.get((1, 13), f32),
+                MemRefType.get((13, 1), f32),
+            )
+            def transpose_op(op1, op2):
+                # CHECK: linalg.transpose ins(%[[VAL_4]] : memref<1x13xf32>) outs(%[[VAL_5]] : memref<13x1xf32>) permutation = [1, 0]
+                op3 = linalg.TransposeOp(
+                    result=[],
+                    input=op1,
+                    init=op2,
+                    permutation=[1, 0],
+                )
+                linalg.fill_builtin_region(op3.operation)
+                # CHECK: linalg.transpose ins(%[[VAL_5]] : memref<13x1xf32>) outs(%[[VAL_4]] : memref<1x13xf32>) permutation = [1, 0]
+                op4 = linalg.transpose(op2, outs=[op1], permutation=[1, 0])
+
+            # CHECK: %[[VAL_6:.*]] = tensor.empty() : tensor<16xf32>
+            # CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<16x64xf32>
+            op1 = tensor.EmptyOp([16], f32)
+            op2 = tensor.EmptyOp([16, 64], f32)
+            # CHECK: %[[VAL_8:.*]] = linalg.broadcast ins(%[[VAL_6]] : tensor<16xf32>) outs(%[[VAL_7]] : tensor<16x64xf32>) dimensions = [1]
+            op3 = linalg.BroadcastOp(
+                result=[RankedTensorType.get((16, 64), f32)],
+                input=op1,
+                init=op2,
+                dimensions=[1],
+            )
+            linalg.fill_builtin_region(op3.operation)
+
+            # CHECK: %[[VAL_9:.*]] = tensor.empty() : tensor<64xf32>
+            op4 = tensor.EmptyOp([64], f32)
+            # CHECK: %[[VAL_10:.*]] = linalg.broadcast ins(%[[VAL_9]] : tensor<64xf32>) outs(%[[VAL_7]] : tensor<16x64xf32>) dimensions = [0]
+            op5 = linalg.broadcast(op4, outs=[op2], dimensions=[0])
+
+            # CHECK: func.func @broadcast_op(%[[VAL_11:.*]]: memref<16xf32>, %[[VAL_12:.*]]: memref<16x64xf32>, %[[VAL_13:.*]]: memref<64xf32>)
+            @func.FuncOp.from_py_func(
+                MemRefType.get((16,), f32),
+                MemRefType.get((16, 64), f32),
+                MemRefType.get((64,), f32),
+            )
+            def broadcast_op(op1, op2, op3):
+                # CHECK: linalg.broadcast ins(%[[VAL_11]] : memref<16xf32>) outs(%[[VAL_12]] : memref<16x64xf32>) dimensions = [1]
+                op4 = linalg.BroadcastOp(
+                    result=[],
+                    input=op1,
+                    init=op2,
+                    dimensions=[1],
+                )
+                linalg.fill_builtin_region(op4.operation)
+                # CHECK: linalg.broadcast ins(%[[VAL_13]] : memref<64xf32>) outs(%[[VAL_12]] : memref<16x64xf32>) dimensions = [0]
+                op5 = linalg.broadcast(op3, outs=[op2], dimensions=[0])
+
+    print(module)
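As a quick cross-check of the broadcast tests above: `linalg.broadcast`'s `dimensions` attribute lists the result dimensions that are added, so removing those dimensions from the init shape must recover the input shape. A hypothetical pure-Python sketch of that invariant (for illustration only, not part of this PR):

```python
def check_broadcast_shapes(in_shape, init_shape, dimensions):
    """True iff init_shape with the `dimensions` indices removed equals in_shape,
    i.e. `dimensions` are exactly the broadcasted (added) result dims."""
    kept = [s for i, s in enumerate(init_shape) if i not in set(dimensions)]
    return kept == list(in_shape)

# tensor<16xf32> -> tensor<16x64xf32> with dimensions = [1]
assert check_broadcast_shapes([16], [16, 64], [1])
# tensor<64xf32> -> tensor<16x64xf32> with dimensions = [0]
assert check_broadcast_shapes([64], [16, 64], [0])
```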

@github-actions

github-actions bot commented Oct 20, 2023

✅ With the latest revision this PR passed the Python code formatter.
