[mlir][linalg] regionBuilder for transpose, broadcast #69742
Conversation
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir

Author: Maksim Levental (makslevental)

Changes

Currently, `linalg.transpose` and `linalg.broadcast` can't be emitted through either the C API or the Python bindings (which of course go through the C API). See https://discourse.llvm.org/t/how-to-build-linalg-transposeop-in-mlir-pybind/73989/10.

The reason is that even though they're named ops, there is no opdsl `@linalg_structured_op` for them, and thus while they can be instantiated they cannot be passed to `mlirLinalgFillBuiltinNamedOpRegion`. I believe the issue is that they both take an `IndexAttrDef`, but `IndexAttrDef` cannot represent dynamic rank. Note: if I'm mistaken and there is a way to write the `@linalg_structured_op`, let me know.

The solution here simply implements the `regionBuilder` interface, which is then picked up by `LinalgDialect::addNamedOpBuilders`.

Extension classes are added "by hand" that mirror the API of the `@linalg_structured_op`s. Note: the extension classes are added to `dialects/linalg/__init__.py` instead of `dialects/linalg/opdsl/ops/core_named_ops.py` so that they're not confused for opdsl generators/emitters.

Full diff: https://github.com/llvm/llvm-project/pull/69742.diff

3 Files Affected:

- mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td (+14/-2)
- mlir/python/mlir/dialects/linalg/__init__.py (+50)
- mlir/test/python/dialects/linalg/ops.py (+79)
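For context on the opdsl limitation described above: a fixed-rank transpose is expressible in opdsl; it is only the rank-generic `permutation` attribute that opdsl cannot model. A hypothetical sketch of a 2-D-only definition, purely for illustration (the op name and structure are invented here, not part of this PR):

```python
from mlir.dialects.linalg.opdsl.lang import *


@linalg_structured_op
def transpose_2d(
    I=TensorDef(T, S.M, S.N),
    O=TensorDef(T, S.N, S.M, output=True),
):
    # The body is a pure identity: each output element is the corresponding
    # input element with the two dimensions swapped. The rank (2) is baked
    # into the definition; there is no way to parameterize it over an
    # arbitrary-rank permutation, which is what IndexAttrDef would need.
    domain(D.m, D.n)
    O[D.n, D.m] = TypeFn.cast_signed(T, I[D.m, D.n])
```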
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
index 21a5e5cc47aeb5c..751edd022883011 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
@@ -442,10 +442,16 @@ def TransposeOp : LinalgStructuredBase_Op<"transpose", [
// Implement functions necessary for DestinationStyleOpInterface.
MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }
+ static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block,
+ mlir::ArrayRef<mlir::NamedAttribute>) {
+ OpBuilder::InsertionGuard guard(b);
+ b.create<linalg::YieldOp>(b.getLoc(), block.getArgument(0));
+ }
+
static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
mlir::ArrayRef<mlir::NamedAttribute>)>
getRegionBuilder() {
- return nullptr;
+ return regionBuilder;
}
static void createRegion(::mlir::OpBuilder &opBuilder,
@@ -510,10 +516,16 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
// Implement functions necessary for DestinationStyleOpInterface.
MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }
+ static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block,
+ mlir::ArrayRef<mlir::NamedAttribute>) {
+ OpBuilder::InsertionGuard guard(b);
+ b.create<linalg::YieldOp>(b.getLoc(), block.getArgument(0));
+ }
+
static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
mlir::ArrayRef<mlir::NamedAttribute>)>
getRegionBuilder() {
- return nullptr;
+ return regionBuilder;
}
}];
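Both `regionBuilder` implementations emit the same identity body, a single `linalg.yield` of the block argument carrying the input element, since neither transpose nor broadcast computes anything per element. A minimal sketch of checking from Python that the region now gets populated (assumes the in-tree `mlir` Python packaging and the `linalg.transpose` wrapper added below):

```python
from mlir.ir import Context, Location, Module, InsertionPoint, F32Type
from mlir.dialects import linalg, tensor

with Context(), Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    with InsertionPoint(module.body):
        src = tensor.EmptyOp([1, 13], f32)
        dst = tensor.EmptyOp([13, 1], f32)
        op = linalg.transpose(src, outs=[dst], permutation=[1, 0])
        # The op's body should now hold exactly one operation: the
        # linalg.yield created by the regionBuilder above.
        body = op.operation.regions[0].blocks[0]
        assert len(body.operations) == 1
    print(module)
```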
diff --git a/mlir/python/mlir/dialects/linalg/__init__.py b/mlir/python/mlir/dialects/linalg/__init__.py
index 1353870ec7257a9..8006950145e08e1 100644
--- a/mlir/python/mlir/dialects/linalg/__init__.py
+++ b/mlir/python/mlir/dialects/linalg/__init__.py
@@ -55,3 +55,53 @@
# TODO: guard against surprises and fail create Runtime Custom Ops with
# the same name as existing Core Named Ops.
from .opdsl.ops.core_named_ops import *
+from .opdsl.lang.emitter import isa
+
+from ...ir import *
+from .._ods_common import get_op_result_or_value as _get_op_result_or_value
+
+
+def transpose(
+ input: Union[Operation, OpView, Sequence[Value]],
+ *,
+ outs: List[Union[Operation, OpView, Sequence[Value]]],
+ permutation: Union[DenseI64ArrayAttr, List[int]],
+):
+
+ input = _get_op_result_or_value(input)
+ if len(outs) > 1:
+ raise ValueError(f"{outs=} must have length 1.")
+ init = _get_op_result_or_value(outs[0])
+ result_types = [init.type] if isa(RankedTensorType, init.type) else []
+
+ op = TransposeOp(
+ result=result_types,
+ input=input,
+ init=init,
+ permutation=permutation,
+ )
+ fill_builtin_region(op.operation)
+ return op
+
+
+def broadcast(
+ input: Union[Operation, OpView, Sequence[Value]],
+ *,
+ outs: List[Union[Operation, OpView, Sequence[Value]]],
+ dimensions: Union[DenseI64ArrayAttr, List[int]],
+):
+
+ input = _get_op_result_or_value(input)
+ if len(outs) > 1:
+ raise ValueError(f"{outs=} must have length 1.")
+ init = _get_op_result_or_value(outs[0])
+ result_types = [init.type] if isa(RankedTensorType, init.type) else []
+
+ op = BroadcastOp(
+ result=result_types,
+ input=input,
+ init=init,
+ dimensions=dimensions,
+ )
+ fill_builtin_region(op.operation)
+ return op
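The new wrappers can be called like any other opdsl-generated builder; a short usage sketch for `broadcast`, mirroring the tests below (assumes the in-tree `mlir` Python packaging):

```python
from mlir.ir import Context, Location, Module, InsertionPoint, F32Type
from mlir.dialects import linalg, tensor

with Context(), Location.unknown():
    module = Module.create()
    f32 = F32Type.get()
    with InsertionPoint(module.body):
        vec = tensor.EmptyOp([16], f32)
        mat = tensor.EmptyOp([16, 64], f32)
        # Broadcast the 16-element vector across a new dimension at index 1;
        # the region is filled automatically via fill_builtin_region.
        linalg.broadcast(vec, outs=[mat], dimensions=[1])
    print(module)
```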
diff --git a/mlir/test/python/dialects/linalg/ops.py b/mlir/test/python/dialects/linalg/ops.py
index b728e0083781492..b147551c2e73dbd 100644
--- a/mlir/test/python/dialects/linalg/ops.py
+++ b/mlir/test/python/dialects/linalg/ops.py
@@ -157,3 +157,82 @@ def pass_an_op_directly(arg0, arg1):
return linalg.matmul(lhs, rhs, outs=init)
print(module)
+
+
+# CHECK-LABEL: TEST: testIdentityRegionOps
+@run
+def testIdentityRegionOps():
+ with Context(), Location.unknown():
+ module = Module.create()
+ f32 = F32Type.get()
+ with InsertionPoint(module.body):
+ # CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<1x13xf32>
+ # CHECK: %[[VAL_1:.*]] = tensor.empty() : tensor<13x1xf32>
+ op1 = tensor.EmptyOp([1, 13], f32)
+ op2 = tensor.EmptyOp([13, 1], f32)
+ # CHECK: %[[VAL_2:.*]] = linalg.transpose ins(%[[VAL_0]] : tensor<1x13xf32>) outs(%[[VAL_1]] : tensor<13x1xf32>) permutation = [1, 0]
+ op3 = linalg.TransposeOp(
+ result=[RankedTensorType.get((13, 1), f32)],
+ input=op1,
+ init=op2,
+ permutation=[1, 0],
+ )
+ linalg.fill_builtin_region(op3.operation)
+
+ # CHECK: %[[VAL_3:.*]] = linalg.transpose ins(%[[VAL_1]] : tensor<13x1xf32>) outs(%[[VAL_0]] : tensor<1x13xf32>) permutation = [1, 0]
+ op4 = linalg.transpose(op2, outs=[op1], permutation=[1, 0])
+
+ # CHECK: func.func @transpose_op(%[[VAL_4:.*]]: memref<1x13xf32>, %[[VAL_5:.*]]: memref<13x1xf32>)
+ @func.FuncOp.from_py_func(
+ MemRefType.get((1, 13), f32),
+ MemRefType.get((13, 1), f32),
+ )
+ def transpose_op(op1, op2):
+ # CHECK: linalg.transpose ins(%[[VAL_4]] : memref<1x13xf32>) outs(%[[VAL_5]] : memref<13x1xf32>) permutation = [1, 0]
+ op3 = linalg.TransposeOp(
+ result=[],
+ input=op1,
+ init=op2,
+ permutation=[1, 0],
+ )
+ linalg.fill_builtin_region(op3.operation)
+ # CHECK: linalg.transpose ins(%[[VAL_5]] : memref<13x1xf32>) outs(%[[VAL_4]] : memref<1x13xf32>) permutation = [1, 0]
+ op4 = linalg.transpose(op2, outs=[op1], permutation=[1, 0])
+
+ # CHECK: %[[VAL_6:.*]] = tensor.empty() : tensor<16xf32>
+ # CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<16x64xf32>
+ op1 = tensor.EmptyOp([16], f32)
+ op2 = tensor.EmptyOp([16, 64], f32)
+ # CHECK: %[[VAL_8:.*]] = linalg.broadcast ins(%[[VAL_6]] : tensor<16xf32>) outs(%[[VAL_7]] : tensor<16x64xf32>) dimensions = [1]
+ op3 = linalg.BroadcastOp(
+ result=[RankedTensorType.get((16, 64), f32)],
+ input=op1,
+ init=op2,
+ dimensions=[1],
+ )
+ linalg.fill_builtin_region(op3.operation)
+
+ # CHECK: %[[VAL_9:.*]] = tensor.empty() : tensor<64xf32>
+ op4 = tensor.EmptyOp([64], f32)
+ # CHECK: %[[VAL_10:.*]] = linalg.broadcast ins(%[[VAL_9]] : tensor<64xf32>) outs(%[[VAL_7]] : tensor<16x64xf32>) dimensions = [0]
+ op5 = linalg.broadcast(op4, outs=[op2], dimensions=[0])
+
+ # CHECK: func.func @broadcast_op(%[[VAL_11:.*]]: memref<16xf32>, %[[VAL_12:.*]]: memref<16x64xf32>, %[[VAL_13:.*]]: memref<64xf32>)
+ @func.FuncOp.from_py_func(
+ MemRefType.get((16,), f32),
+ MemRefType.get((16, 64), f32),
+ MemRefType.get((64,), f32),
+ )
+ def broadcast_op(op1, op2, op3):
+ # CHECK: linalg.broadcast ins(%[[VAL_11]] : memref<16xf32>) outs(%[[VAL_12]] : memref<16x64xf32>) dimensions = [1]
+ op4 = linalg.BroadcastOp(
+ result=[],
+ input=op1,
+ init=op2,
+ dimensions=[1],
+ )
+ linalg.fill_builtin_region(op4.operation)
+ # CHECK: linalg.broadcast ins(%[[VAL_13]] : memref<64xf32>) outs(%[[VAL_12]] : memref<16x64xf32>) dimensions = [0]
+ op5 = linalg.broadcast(op3, outs=[op2], dimensions=[0])
+
+ print(module)
✅ With the latest revision this PR passed the Python code formatter.
Force-pushed from 442585c to 55f260e.