Skip to content

[mlir][linalg] regionBuilder for transpose, broadcast #69742

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 20, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -442,10 +442,16 @@ def TransposeOp : LinalgStructuredBase_Op<"transpose", [
// Implement functions necessary for DestinationStyleOpInterface.
MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }

static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block,
mlir::ArrayRef<mlir::NamedAttribute>) {
OpBuilder::InsertionGuard guard(b);
b.create<linalg::YieldOp>(b.getLoc(), block.getArgument(0));
}

static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
mlir::ArrayRef<mlir::NamedAttribute>)>
getRegionBuilder() {
return nullptr;
return regionBuilder;
}

static void createRegion(::mlir::OpBuilder &opBuilder,
Expand Down Expand Up @@ -510,10 +516,16 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
// Implement functions necessary for DestinationStyleOpInterface.
MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }

static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block,
mlir::ArrayRef<mlir::NamedAttribute>) {
OpBuilder::InsertionGuard guard(b);
b.create<linalg::YieldOp>(b.getLoc(), block.getArgument(0));
}

static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
mlir::ArrayRef<mlir::NamedAttribute>)>
getRegionBuilder() {
return nullptr;
return regionBuilder;
}
}];

Expand Down
48 changes: 48 additions & 0 deletions mlir/python/mlir/dialects/linalg/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,51 @@
# TODO: guard against surprises and fail create Runtime Custom Ops with
# the same name as existing Core Named Ops.
from .opdsl.ops.core_named_ops import *
from .opdsl.lang.emitter import isa

from ...ir import *
from .._ods_common import get_op_result_or_value as _get_op_result_or_value


def transpose(
    input: Union[Operation, OpView, Sequence[Value]],
    *,
    outs: List[Union[Operation, OpView, Sequence[Value]]],
    permutation: Union[DenseI64ArrayAttr, List[int]],
):
    """Emits a `linalg.transpose` op.

    Args:
      input: the value to transpose (an Operation, OpView, or Value).
      outs: a one-element list holding the destination (init) operand.
      permutation: permutation applied to the input dimensions, as a
        DenseI64ArrayAttr or a plain list of ints.

    Returns:
      The created TransposeOp. On ranked-tensor inits the op carries the
      result tensor type; on memref inits it has no results.

    Raises:
      ValueError: if `outs` does not contain exactly one value.
    """
    input = _get_op_result_or_value(input)
    # Use `!= 1` so an empty `outs` raises the documented ValueError instead
    # of falling through to an IndexError on `outs[0]` below.
    if len(outs) != 1:
        raise ValueError(f"{outs=} must have length 1.")
    init = _get_op_result_or_value(outs[0])
    # Destination-passing style: tensor ops yield a result, memref ops do not.
    result_types = [init.type] if isa(RankedTensorType, init.type) else []

    op = TransposeOp(
        result=result_types,
        input=input,
        init=init,
        permutation=permutation,
    )
    # Populate the op's region with the identity body (a bare linalg.yield).
    fill_builtin_region(op.operation)
    return op


def broadcast(
    input: Union[Operation, OpView, Sequence[Value]],
    *,
    outs: List[Union[Operation, OpView, Sequence[Value]]],
    dimensions: Union[DenseI64ArrayAttr, List[int]],
):
    """Emits a `linalg.broadcast` op.

    Args:
      input: the value to broadcast (an Operation, OpView, or Value).
      outs: a one-element list holding the destination (init) operand.
      dimensions: the dimensions added by the broadcast, as a
        DenseI64ArrayAttr or a plain list of ints.

    Returns:
      The created BroadcastOp. On ranked-tensor inits the op carries the
      result tensor type; on memref inits it has no results.

    Raises:
      ValueError: if `outs` does not contain exactly one value.
    """
    input = _get_op_result_or_value(input)
    # Use `!= 1` so an empty `outs` raises the documented ValueError instead
    # of falling through to an IndexError on `outs[0]` below.
    if len(outs) != 1:
        raise ValueError(f"{outs=} must have length 1.")
    init = _get_op_result_or_value(outs[0])
    # Destination-passing style: tensor ops yield a result, memref ops do not.
    result_types = [init.type] if isa(RankedTensorType, init.type) else []

    op = BroadcastOp(
        result=result_types,
        input=input,
        init=init,
        dimensions=dimensions,
    )
    # Populate the op's region with the identity body (a bare linalg.yield).
    fill_builtin_region(op.operation)
    return op
79 changes: 79 additions & 0 deletions mlir/test/python/dialects/linalg/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,3 +157,82 @@ def pass_an_op_directly(arg0, arg1):
return linalg.matmul(lhs, rhs, outs=init)

print(module)


# CHECK-LABEL: TEST: testIdentityRegionOps
@run
def testIdentityRegionOps():
    # Exercises linalg.transpose and linalg.broadcast construction from
    # Python, both via the raw op classes (TransposeOp/BroadcastOp plus an
    # explicit fill_builtin_region call to build the identity body) and via
    # the linalg.transpose/linalg.broadcast convenience wrappers, on both
    # tensor and memref operands.
    #
    # NOTE: the "# CHECK" comments are FileCheck directives matched against
    # the module printed at the end of this function -- do not reword or
    # reorder them.
    with Context(), Location.unknown():
        module = Module.create()
        f32 = F32Type.get()
        with InsertionPoint(module.body):
            # CHECK: %[[VAL_0:.*]] = tensor.empty() : tensor<1x13xf32>
            # CHECK: %[[VAL_1:.*]] = tensor.empty() : tensor<13x1xf32>
            op1 = tensor.EmptyOp([1, 13], f32)
            op2 = tensor.EmptyOp([13, 1], f32)
            # Raw-op path on tensors: result type must be given explicitly.
            # CHECK: %[[VAL_2:.*]] = linalg.transpose ins(%[[VAL_0]] : tensor<1x13xf32>) outs(%[[VAL_1]] : tensor<13x1xf32>) permutation = [1, 0]
            op3 = linalg.TransposeOp(
                result=[RankedTensorType.get((13, 1), f32)],
                input=op1,
                init=op2,
                permutation=[1, 0],
            )
            linalg.fill_builtin_region(op3.operation)

            # Wrapper path on tensors: region and result type are inferred.
            # CHECK: %[[VAL_3:.*]] = linalg.transpose ins(%[[VAL_1]] : tensor<13x1xf32>) outs(%[[VAL_0]] : tensor<1x13xf32>) permutation = [1, 0]
            op4 = linalg.transpose(op2, outs=[op1], permutation=[1, 0])

            # Same two paths on memrefs: no result types in that case.
            # CHECK: func.func @transpose_op(%[[VAL_4:.*]]: memref<1x13xf32>, %[[VAL_5:.*]]: memref<13x1xf32>)
            @func.FuncOp.from_py_func(
                MemRefType.get((1, 13), f32),
                MemRefType.get((13, 1), f32),
            )
            def transpose_op(op1, op2):
                # CHECK: linalg.transpose ins(%[[VAL_4]] : memref<1x13xf32>) outs(%[[VAL_5]] : memref<13x1xf32>) permutation = [1, 0]
                op3 = linalg.TransposeOp(
                    result=[],
                    input=op1,
                    init=op2,
                    permutation=[1, 0],
                )
                linalg.fill_builtin_region(op3.operation)
                # CHECK: linalg.transpose ins(%[[VAL_5]] : memref<13x1xf32>) outs(%[[VAL_4]] : memref<1x13xf32>) permutation = [1, 0]
                op4 = linalg.transpose(op2, outs=[op1], permutation=[1, 0])

            # CHECK: %[[VAL_6:.*]] = tensor.empty() : tensor<16xf32>
            # CHECK: %[[VAL_7:.*]] = tensor.empty() : tensor<16x64xf32>
            op1 = tensor.EmptyOp([16], f32)
            op2 = tensor.EmptyOp([16, 64], f32)
            # Raw-op path for broadcast on tensors.
            # CHECK: %[[VAL_8:.*]] = linalg.broadcast ins(%[[VAL_6]] : tensor<16xf32>) outs(%[[VAL_7]] : tensor<16x64xf32>) dimensions = [1]
            op3 = linalg.BroadcastOp(
                result=[RankedTensorType.get((16, 64), f32)],
                input=op1,
                init=op2,
                dimensions=[1],
            )
            linalg.fill_builtin_region(op3.operation)

            # Wrapper path for broadcast on tensors.
            # CHECK: %[[VAL_9:.*]] = tensor.empty() : tensor<64xf32>
            op4 = tensor.EmptyOp([64], f32)
            # CHECK: %[[VAL_10:.*]] = linalg.broadcast ins(%[[VAL_9]] : tensor<64xf32>) outs(%[[VAL_7]] : tensor<16x64xf32>) dimensions = [0]
            op5 = linalg.broadcast(op4, outs=[op2], dimensions=[0])

            # Both broadcast paths on memrefs.
            # CHECK: func.func @broadcast_op(%[[VAL_11:.*]]: memref<16xf32>, %[[VAL_12:.*]]: memref<16x64xf32>, %[[VAL_13:.*]]: memref<64xf32>)
            @func.FuncOp.from_py_func(
                MemRefType.get((16,), f32),
                MemRefType.get((16, 64), f32),
                MemRefType.get((64,), f32),
            )
            def broadcast_op(op1, op2, op3):
                # CHECK: linalg.broadcast ins(%[[VAL_11]] : memref<16xf32>) outs(%[[VAL_12]] : memref<16x64xf32>) dimensions = [1]
                op4 = linalg.BroadcastOp(
                    result=[],
                    input=op1,
                    init=op2,
                    dimensions=[1],
                )
                linalg.fill_builtin_region(op4.operation)
                # CHECK: linalg.broadcast ins(%[[VAL_13]] : memref<64xf32>) outs(%[[VAL_12]] : memref<16x64xf32>) dimensions = [0]
                op5 = linalg.broadcast(op3, outs=[op2], dimensions=[0])

        # Print the module so the CHECK directives above can be verified.
        print(module)