Skip to content

Commit a969404

Browse files
authored
[mlir][linalg] regionBuilder for transpose, broadcast (llvm#69742)
Currently, `linalg.transpose` and `linalg.broadcast` can't be emitted through either the C API or the python bindings (which of course go through the C API). See https://discourse.llvm.org/t/how-to-build-linalg-transposeop-in-mlir-pybind/73989/10. The reason is even though they're named ops, there is no opdsl `@linalg_structured_op` for them and thus while they can be instantiated they cannot be passed to [`mlirLinalgFillBuiltinNamedOpRegion`](https://github.com/llvm/llvm-project/blob/a7cccb9cbb2b9954684cbea37615303a59719973/mlir/lib/CAPI/Dialect/Linalg.cpp#L18). I believe the issue is they both take a `IndexAttrDef` but `IndexAttrDef` cannot represent dynamic rank. Note, if I'm mistaken and there is a way to write the `@linalg_structured_op` let me know. The solution here simply implements the `regionBuilder` interface which is then picked up by [`LinalgDialect::addNamedOpBuilders`](https://github.com/llvm/llvm-project/blob/7557530f428a2f226d8d925c33d527dfcfdcb0c5/mlir/lib/Dialect/Linalg/IR/LinalgDialect.cpp#L116). Extension classes are added "by hand" that mirror the API of the `@linalg_structured_op`s. Note, the extension classes are added to to `dialects/linalg/__init__.py` instead of `dialects/linalg/opdsl/ops/core_named_ops.py` in order that they're not confused for opdsl generators/emitters.
1 parent 7ba99fd commit a969404

File tree

3 files changed

+141
-2
lines changed

3 files changed

+141
-2
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -442,10 +442,16 @@ def TransposeOp : LinalgStructuredBase_Op<"transpose", [
442442
// Implement functions necessary for DestinationStyleOpInterface.
443443
MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }
444444

445+
static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block,
446+
mlir::ArrayRef<mlir::NamedAttribute>) {
447+
OpBuilder::InsertionGuard guard(b);
448+
b.create<linalg::YieldOp>(b.getLoc(), block.getArgument(0));
449+
}
450+
445451
static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
446452
mlir::ArrayRef<mlir::NamedAttribute>)>
447453
getRegionBuilder() {
448-
return nullptr;
454+
return regionBuilder;
449455
}
450456

451457
static void createRegion(::mlir::OpBuilder &opBuilder,
@@ -510,10 +516,16 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
510516
// Implement functions necessary for DestinationStyleOpInterface.
511517
MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }
512518

519+
static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block,
520+
mlir::ArrayRef<mlir::NamedAttribute>) {
521+
OpBuilder::InsertionGuard guard(b);
522+
b.create<linalg::YieldOp>(b.getLoc(), block.getArgument(0));
523+
}
524+
513525
static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
514526
mlir::ArrayRef<mlir::NamedAttribute>)>
515527
getRegionBuilder() {
516-
return nullptr;
528+
return regionBuilder;
517529
}
518530
}];
519531

mlir/python/mlir/dialects/linalg/__init__.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,51 @@
5555
# TODO: guard against surprises and fail create Runtime Custom Ops with
5656
# the same name as existing Core Named Ops.
5757
from .opdsl.ops.core_named_ops import *
58+
from .opdsl.lang.emitter import isa
59+
60+
from ...ir import *
61+
from .._ods_common import get_op_result_or_value as _get_op_result_or_value
62+
63+
64+
def transpose(
65+
input: Union[Operation, OpView, Sequence[Value]],
66+
*,
67+
outs: List[Union[Operation, OpView, Sequence[Value]]],
68+
permutation: Union[DenseI64ArrayAttr, List[int]],
69+
):
70+
input = _get_op_result_or_value(input)
71+
if len(outs) > 1:
72+
raise ValueError(f"{outs=} must have length 1.")
73+
init = _get_op_result_or_value(outs[0])
74+
result_types = [init.type] if isa(RankedTensorType, init.type) else []
75+
76+
op = TransposeOp(
77+
result=result_types,
78+
input=input,
79+
init=init,
80+
permutation=permutation,
81+
)
82+
fill_builtin_region(op.operation)
83+
return op
84+
85+
86+
def broadcast(
87+
input: Union[Operation, OpView, Sequence[Value]],
88+
*,
89+
outs: List[Union[Operation, OpView, Sequence[Value]]],
90+
dimensions: Union[DenseI64ArrayAttr, List[int]],
91+
):
92+
input = _get_op_result_or_value(input)
93+
if len(outs) > 1:
94+
raise ValueError(f"{outs=} must have length 1.")
95+
init = _get_op_result_or_value(outs[0])
96+
result_types = [init.type] if isa(RankedTensorType, init.type) else []
97+
98+
op = BroadcastOp(
99+
result=result_types,
100+
input=input,
101+
init=init,
102+
dimensions=dimensions,
103+
)
104+
fill_builtin_region(op.operation)
105+
return op

mlir/test/python/dialects/linalg/ops.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,3 +157,82 @@ def pass_an_op_directly(arg0, arg1):
157157
return linalg.matmul(lhs, rhs, outs=init)
158158

159159
print(module)
160+
161+
162+
# CHECK-LABEL: TEST: testIdentityRegionOps
163+
@run
164+
def testIdentityRegionOps():
165+
with Context(), Location.unknown():
166+
module = Module.create()
167+
f32 = F32Type.get()
168+
with InsertionPoint(module.body):
169+
# CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<1x13xf32>
170+
# CHECK: %[[VAL_1:.*]] = tensor.empty() : tensor<13x1xf32>
171+
op1 = tensor.EmptyOp([1, 13], f32)
172+
op2 = tensor.EmptyOp([13, 1], f32)
173+
# CHECK: %[[VAL_2:.*]] = linalg.transpose ins(%[[VAL_0]] : tensor<1x13xf32>) outs(%[[VAL_1]] : tensor<13x1xf32>) permutation = [1, 0]
174+
op3 = linalg.TransposeOp(
175+
result=[RankedTensorType.get((13, 1), f32)],
176+
input=op1,
177+
init=op2,
178+
permutation=[1, 0],
179+
)
180+
linalg.fill_builtin_region(op3.operation)
181+
182+
# CHECK: %[[VAL_3:.*]] = linalg.transpose ins(%[[VAL_1]] : tensor<13x1xf32>) outs(%[[VAL_0]] : tensor<1x13xf32>) permutation = [1, 0]
183+
op4 = linalg.transpose(op2, outs=[op1], permutation=[1, 0])
184+
185+
# CHECK: func.func @transpose_op(%[[VAL_4:.*]]: memref<1x13xf32>, %[[VAL_5:.*]]: memref<13x1xf32>)
186+
@func.FuncOp.from_py_func(
187+
MemRefType.get((1, 13), f32),
188+
MemRefType.get((13, 1), f32),
189+
)
190+
def transpose_op(op1, op2):
191+
# CHECK: linalg.transpose ins(%[[VAL_4]] : memref<1x13xf32>) outs(%[[VAL_5]] : memref<13x1xf32>) permutation = [1, 0]
192+
op3 = linalg.TransposeOp(
193+
result=[],
194+
input=op1,
195+
init=op2,
196+
permutation=[1, 0],
197+
)
198+
linalg.fill_builtin_region(op3.operation)
199+
# CHECK: linalg.transpose ins(%[[VAL_5]] : memref<13x1xf32>) outs(%[[VAL_4]] : memref<1x13xf32>) permutation = [1, 0]
200+
op4 = linalg.transpose(op2, outs=[op1], permutation=[1, 0])
201+
202+
# CHECK: %[[VAL_6:.*]] = tensor.empty() : tensor<16xf32>
203+
# CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<16x64xf32>
204+
op1 = tensor.EmptyOp([16], f32)
205+
op2 = tensor.EmptyOp([16, 64], f32)
206+
# CHECK: %[[VAL_8:.*]] = linalg.broadcast ins(%[[VAL_6]] : tensor<16xf32>) outs(%[[VAL_7]] : tensor<16x64xf32>) dimensions = [1]
207+
op3 = linalg.BroadcastOp(
208+
result=[RankedTensorType.get((16, 64), f32)],
209+
input=op1,
210+
init=op2,
211+
dimensions=[1],
212+
)
213+
linalg.fill_builtin_region(op3.operation)
214+
215+
# CHECK: %[[VAL_9:.*]] = tensor.empty() : tensor<64xf32>
216+
op4 = tensor.EmptyOp([64], f32)
217+
# CHECK: %[[VAL_10:.*]] = linalg.broadcast ins(%[[VAL_9]] : tensor<64xf32>) outs(%[[VAL_7]] : tensor<16x64xf32>) dimensions = [0]
218+
op5 = linalg.broadcast(op4, outs=[op2], dimensions=[0])
219+
220+
# CHECK: func.func @broadcast_op(%[[VAL_11:.*]]: memref<16xf32>, %[[VAL_12:.*]]: memref<16x64xf32>, %[[VAL_13:.*]]: memref<64xf32>)
221+
@func.FuncOp.from_py_func(
222+
MemRefType.get((16,), f32),
223+
MemRefType.get((16, 64), f32),
224+
MemRefType.get((64,), f32),
225+
)
226+
def broadcast_op(op1, op2, op3):
227+
# CHECK: linalg.broadcast ins(%[[VAL_11]] : memref<16xf32>) outs(%[[VAL_12]] : memref<16x64xf32>) dimensions = [1]
228+
op4 = linalg.BroadcastOp(
229+
result=[],
230+
input=op1,
231+
init=op2,
232+
dimensions=[1],
233+
)
234+
linalg.fill_builtin_region(op4.operation)
235+
# CHECK: linalg.broadcast ins(%[[VAL_13]] : memref<64xf32>) outs(%[[VAL_12]] : memref<16x64xf32>) dimensions = [0]
236+
op5 = linalg.broadcast(op3, outs=[op2], dimensions=[0])
237+
238+
print(module)

0 commit comments

Comments
 (0)