Skip to content

[MLIR][Linalg] Ternary Op & Linalg select #91461

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
May 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ def UnaryFnAttr : EnumAttr<Linalg_Dialect, UnaryFn, "unary_fn"> {
def BinaryFnAttr : EnumAttr<Linalg_Dialect, BinaryFn, "binary_fn"> {
let assemblyFormat = "`<` $value `>`";
}
// Enum attribute wrapping TernaryFn, printed/parsed as `<fn_name>` (mirrors
// UnaryFnAttr and BinaryFnAttr above).
def TernaryFnAttr : EnumAttr<Linalg_Dialect, TernaryFn, "ternary_fn"> {
  let assemblyFormat = "`<` $value `>`";
}
def TypeFnAttr : EnumAttr<Linalg_Dialect, TypeFn, "type_fn"> {
let assemblyFormat = "`<` $value `>`";
}
Expand Down
6 changes: 6 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgEnums.td
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ def BinaryFn : I32EnumAttr<"BinaryFn", "", [
let genSpecializedAttr = 0;
let cppNamespace = "::mlir::linalg";
}
// Scalar ternary functions selectable in named-op payloads; `select` is the
// only case so far (condition, true value, false value).
def TernaryFn : I32EnumAttr<"TernaryFn", "", [
  I32EnumAttrCase<"select", 0>
]> {
  let genSpecializedAttr = 0;
  let cppNamespace = "::mlir::linalg";
}
def TypeFn : I32EnumAttr<"TypeFn", "", [
I32EnumAttrCase<"cast_signed", 0>,
I32EnumAttrCase<"cast_unsigned", 1>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1008,6 +1008,63 @@ structured_op: !LinalgStructuredOpConfig
- !ScalarExpression
scalar_arg: rhs
--- !LinalgOpConfig
# NOTE(review): this config appears to be generated from the opdsl `select`
# definition (core_named_ops.py) — confirm and regenerate rather than
# hand-editing. Four rank-0 operands (cond: U, lhs/rhs/O: T1) with identity
# indexing maps and a single ternary `select` scalar assignment.
metadata: !LinalgOpMetadata
  name: select
  cpp_class_name: SelectOp
  doc: |-
    Chooses one value based on a binary condition supplied as its first operand.

    The shapes and element types must be identical. The appropriate casts,
    broadcasts and reductions should be done previously to calling this op.

    This means reduction/broadcast/element cast semantics is explicit. Further
    passes can take that into account when lowering this code. For example,
    a `linalg.broadcast` + `linalg.select` sequence can be lowered to a
    `linalg.generic` with different affine maps for the two operands.
structured_op: !LinalgStructuredOpConfig
  args:
  - !LinalgOperandDefConfig
    name: cond
    kind: input_tensor
    type_var: U
    shape_map: affine_map<() -> ()>
  - !LinalgOperandDefConfig
    name: lhs
    kind: input_tensor
    type_var: T1
    shape_map: affine_map<() -> ()>
  - !LinalgOperandDefConfig
    name: rhs
    kind: input_tensor
    type_var: T1
    shape_map: affine_map<() -> ()>
  - !LinalgOperandDefConfig
    name: O
    kind: output_tensor
    type_var: T1
    shape_map: affine_map<() -> ()>
  indexing_maps: !LinalgIndexingMapsConfig
    static_indexing_maps:
    - affine_map<() -> ()>
    - affine_map<() -> ()>
    - affine_map<() -> ()>
    - affine_map<() -> ()>
  iterator_types: []
  assignments:
  - !ScalarAssign
    arg: O
    value: !ScalarExpression
      scalar_fn:
        kind: ternary
        fn_name: select
        operands:
        - !ScalarExpression
          scalar_arg: cond
        - !ScalarExpression
          scalar_arg: lhs
        - !ScalarExpression
          scalar_arg: rhs
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
name: matmul
cpp_class_name: MatmulOp
Expand Down
19 changes: 19 additions & 0 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,25 @@ class RegionBuilderHelper {
llvm_unreachable("unsupported binary function");
}

// Build the ternary functions defined by OpDSL.
//
// For TernaryFn::select this emits `arith.select(arg0, arg1, arg2)`:
// arg0 is the i1 condition, arg1/arg2 are the value operands, which must
// both be floating point or both be integer.
Value buildTernaryFn(TernaryFn ternaryFn, Value arg0, Value arg1,
                     Value arg2) {
  bool headBool =
      isInteger(arg0) && arg0.getType().getIntOrFloatBitWidth() == 1;
  // Only the value operands (arg1/arg2) participate in the numeric-type
  // check; the original code tested arg0 (the i1 condition) for the float
  // case and tested arg1 twice for the integer case, never checking arg2.
  bool tailFloatingPoint = isFloatingPoint(arg1) && isFloatingPoint(arg2);
  bool tailInteger = isInteger(arg1) && isInteger(arg2);
  OpBuilder::InsertionGuard g(builder);
  builder.setInsertionPointToEnd(&block);
  switch (ternaryFn) {
  case TernaryFn::select:
    // Both conditions must hold: a bool condition AND numeric value
    // operands (the original `&&` only fired when both checks failed).
    if (!headBool || !(tailFloatingPoint || tailInteger))
      llvm_unreachable("unsupported non numeric type");
    return builder.create<arith::SelectOp>(arg0.getLoc(), arg0, arg1, arg2);
  }
  llvm_unreachable("unsupported ternary function");
}

// Build the type functions defined by OpDSL.
Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
switch (typeFn) {
Expand Down
61 changes: 59 additions & 2 deletions mlir/python/mlir/dialects/linalg/opdsl/lang/comprehension.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,8 @@ def __repr__(self):
class FunctionKind(Enum):
    """Kinds of scalar functions that can appear in a TensorFn."""

    UNARY = 0
    BINARY = 1
    # Three-operand scalar functions (e.g. select).
    TERNARY = 2
    TYPE = 3


class UnaryFnType:
Expand Down Expand Up @@ -339,6 +340,33 @@ class BinaryFn:
powf = BinaryFnType("powf")


class TernaryFnType:
    """Ternary function.

    Wraps the name of a scalar function of three tensor expressions;
    calling an instance builds the corresponding TERNARY `TensorFn`.
    """

    def __init__(self, fn_name: str):
        self.fn_name = fn_name

    def __call__(
        self, arg0: TensorExpression, arg1: TensorExpression, arg2: TensorExpression
    ) -> "TensorFn":
        operands = [arg0, arg1, arg2]
        return TensorFn(FunctionKind.TERNARY, self.fn_name, None, None, operands)

    def __repr__(self):
        return self.fn_name


class TernaryFn:
    """Ternary function namespace.

    Mirrors the BinaryFn namespace above: each ternary scalar function is
    exposed as a class-level TernaryFnType instance.
    """

    # select(cond, lhs, rhs): chooses lhs or rhs based on a boolean cond.
    select = TernaryFnType("select")


class TypeFnType:
"""Type conversion function.

Expand Down Expand Up @@ -437,7 +465,8 @@ class OperandKind(Enum):
INDEX_ATTR = 3
UNARY_FN_ATTR = 4
BINARY_FN_ATTR = 5
TYPE_FN_ATTR = 6
TERNARY_FN_ATTR = 6
TYPE_FN_ATTR = 7


class OperandDef:
Expand Down Expand Up @@ -489,6 +518,7 @@ def is_attribute(self) -> bool:
self.kind == OperandKind.INDEX_ATTR
or self.kind == OperandKind.UNARY_FN_ATTR
or self.kind == OperandKind.BINARY_FN_ATTR
or self.kind == OperandKind.TERNARY_FN_ATTR
or self.kind == OperandKind.TYPE_FN_ATTR
)

Expand Down Expand Up @@ -670,6 +700,33 @@ def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse:
return ReduceFnUse(None, self, *reduce_dims)


class TernaryFnAttrDef:
    """Ternary function attribute definition.

    Ternary function attributes provide a way to make the arithmetic
    computation parametrizable. Every attribute specifies a default ternary
    function that may be overwritten at operation instantiation time.
    """

    def __init__(self, default: "TernaryFnType"):
        if not isinstance(default, TernaryFnType):
            raise ValueError(
                f"TernaryFnAttrDef requires default of type TernaryFnType "
                f"but got {default}"
            )
        self.operand_def = OperandDef(
            OperandKind.TERNARY_FN_ATTR, default_fn=default.fn_name
        )

    def __call__(
        self, arg0: TensorExpression, arg1: TensorExpression, arg2: TensorExpression
    ) -> TensorFn:
        # A ternary function takes three operands; the previous signature
        # accepted only two (copy-paste from the binary attribute case) and
        # silently dropped the third argument of the scalar expression.
        return TensorFn(
            FunctionKind.TERNARY, None, self.operand_def, None, [arg0, arg1, arg2]
        )

    def __getitem__(self, reduce_dims: Tuple[DimDef]) -> ReduceFnUse:
        return ReduceFnUse(None, self, *reduce_dims)


class TypeFnAttrDef:
"""Type conversion function attribute definition.

Expand Down
7 changes: 7 additions & 0 deletions mlir/python/mlir/dialects/linalg/opdsl/lang/emitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def prepare_common_structured_op(
in [
OperandKind.UNARY_FN_ATTR,
OperandKind.BINARY_FN_ATTR,
OperandKind.TERNARY_FN_ATTR,
OperandKind.TYPE_FN_ATTR,
]
]
Expand Down Expand Up @@ -180,6 +181,12 @@ def prepare_common_structured_op(
f"Attribute {fn_attr.name} needs to be of type "
f"BinaryFnType but got {type(attr_val)}"
)
elif attr_kind == OperandKind.TERNARY_FN_ATTR:
if not isinstance(fn, TernaryFnType):
raise ValueError(
f"Attribute {fn_attr.name} needs to be of type "
f"TernaryFnType but got {type(attr_val)}"
)
else:
if not isinstance(fn, TypeFnType):
raise ValueError(
Expand Down
20 changes: 20 additions & 0 deletions mlir/python/mlir/dialects/linalg/opdsl/ops/core_named_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,26 @@ def powf(
O[None] = BinaryFn.powf(lhs[None], rhs[None])


@linalg_structured_op
def select(
    cond=TensorDef(U),  # boolean condition tensor (type var U, expected i1)
    lhs=TensorDef(T1),  # value taken where cond is true
    rhs=TensorDef(T1),  # value taken where cond is false
    O=TensorDef(T1, output=True),
):
    """Chooses one value based on a binary condition supplied as its first operand.

    The shapes and element types must be identical. The appropriate casts,
    broadcasts and reductions should be done previously to calling this op.

    This means reduction/broadcast/element cast semantics is explicit. Further
    passes can take that into account when lowering this code. For example,
    a `linalg.broadcast` + `linalg.select` sequence can be lowered to a
    `linalg.generic` with different affine maps for the two operands.
    """
    # Elementwise: O = select(cond, lhs, rhs) via the TernaryFn namespace.
    O[None] = TernaryFn.select(cond[None], lhs[None], rhs[None])


@linalg_structured_op
def matmul(
A=TensorDef(T1, S.M, S.K),
Expand Down
25 changes: 25 additions & 0 deletions mlir/test/Dialect/Linalg/generalize-named-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -791,6 +791,31 @@ func.func @generalize_powf(%lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,

// -----

// Generalization: linalg.select becomes a linalg.generic with identity
// indexing maps on all four operands and an arith.select payload.
func.func @generalize_select(%cond: memref<7x14x21xi1>, %lhs: memref<7x14x21xf32>, %rhs: memref<7x14x21xf32>,
                             %out: memref<7x14x21xf32>) {
  linalg.select ins(%cond, %lhs, %rhs: memref<7x14x21xi1>, memref<7x14x21xf32>, memref<7x14x21xf32>)
                outs(%out: memref<7x14x21xf32>)
  return
}

// CHECK: #[[MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)>

// CHECK: func @generalize_select
// CHECK-SAME: (%[[COND:.+]]: memref<7x14x21xi1>, %[[LHS:.+]]: memref<7x14x21xf32>, %[[RHS:.+]]: memref<7x14x21xf32>,
// CHECK-SAME:  %[[OUT:.+]]: memref<7x14x21xf32>)

// CHECK: linalg.generic
// CHECK-SAME: indexing_maps = [#[[MAP]], #[[MAP]], #[[MAP]], #[[MAP]]]
// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]}
// CHECK-SAME: ins(%[[COND]], %[[LHS]], %[[RHS]] : memref<7x14x21xi1>, memref<7x14x21xf32>, memref<7x14x21xf32>)
// CHECK-SAME: outs(%[[OUT]] : memref<7x14x21xf32>)

// CHECK: ^{{.+}}(%[[BBARG0:.+]]: i1, %[[BBARG1:.+]]: f32, %[[BBARG2:.+]]: f32, %[[BBARG3:.+]]: f32)
// CHECK-NEXT: %[[select:.+]] = arith.select %[[BBARG0]], %[[BBARG1]], %[[BBARG2]] : f32
// CHECK-NEXT: linalg.yield %[[select]] : f32


// -----

// CHECK-LABEL: func @fill_tensor
func.func @fill_tensor(%f: f32, %v: vector<2x4xf32>) -> (tensor<f32>, tensor<vector<2x4xf32>>) {
Expand Down
16 changes: 16 additions & 0 deletions mlir/test/Dialect/Linalg/named-ops-fail.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -334,3 +334,19 @@ func.func @powf_broadcast(%arg0: memref<8x16xf32>, %arg1: memref<4x8x16xf32>, %a
return
}

// -----

// Negative test: lhs (f16) and rhs/result (f32) element types differ, so the
// verifier must reject the op.
func.func @select_type_cast(%arg0: memref<4x8x16xi1>, %arg1: memref<4x8x16xf16>, %arg2: memref<4x8x16xf32>, %arg3: memref<4x8x16xf32>) {
  // CHECK: op failed to verify that all of {true_value, false_value, result} have same type
  linalg.select ins(%arg0, %arg1, %arg2 : memref<4x8x16xi1>, memref<4x8x16xf16>, memref<4x8x16xf32>) outs(%arg3: memref<4x8x16xf32>)
  return
}

// -----

// Negative test: the first (condition) operand must be i1, not f32.
func.func @select_wrong_condition_type(%arg0: memref<4x8x16xf32>, %arg1: memref<4x8x16xf32>, %arg2: memref<4x8x16xf32>, %arg3: memref<4x8x16xf32>) {
  // CHECK: op operand #0 must be bool-like, but got 'f32'
  linalg.select ins(%arg0, %arg1, %arg2 : memref<4x8x16xf32>, memref<4x8x16xf32>, memref<4x8x16xf32>) outs(%arg3: memref<4x8x16xf32>)
  return
}

34 changes: 34 additions & 0 deletions mlir/test/Dialect/Linalg/named-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1924,3 +1924,37 @@ func.func @fill_tensor(%f: f32, %v: vector<2x4xf32>) -> (tensor<f32>, tensor<vec
%1 = linalg.fill ins(%v : vector<2x4xf32>) outs(%e1 : tensor<vector<2x4xf32>>) -> tensor<vector<2x4xf32>>
return %0, %1: tensor<f32>, tensor<vector<2x4xf32>>
}

// -----

// Round-trip test: linalg.select parses/prints with fully dynamic memrefs.
// CHECK-LABEL: func @select_dynamic
func.func @select_dynamic(%arg0: memref<?x?x?xi1>, %arg1: memref<?x?x?xf32>, %arg2: memref<?x?x?xf32>, %arg3: memref<?x?x?xf32>) {
  // CHECK: linalg.select
  // CHECK-SAME: ins(%{{.+}}, %{{.+}}, %{{.+}} : memref<?x?x?xi1>, memref<?x?x?xf32>, memref<?x?x?xf32>)
  // CHECK-SAME: outs(%{{.+}} : memref<?x?x?xf32>)
  linalg.select ins(%arg0, %arg1, %arg2 : memref<?x?x?xi1>, memref<?x?x?xf32>, memref<?x?x?xf32>) outs(%arg3: memref<?x?x?xf32>)
  return
}

// -----

// Round-trip test: linalg.select with fully static memref shapes.
// CHECK-LABEL: func @select_static
func.func @select_static(%arg0: memref<4x8x16xi1>, %arg1: memref<4x8x16xf32>, %arg2: memref<4x8x16xf32>, %arg3: memref<4x8x16xf32>) {
  // CHECK: linalg.select
  // CHECK-SAME: ins(%{{.+}}, %{{.+}}, %{{.+}} : memref<4x8x16xi1>, memref<4x8x16xf32>, memref<4x8x16xf32>)
  // CHECK-SAME: outs(%{{.+}} : memref<4x8x16xf32>)
  linalg.select ins(%arg0, %arg1, %arg2 : memref<4x8x16xi1>, memref<4x8x16xf32>, memref<4x8x16xf32>) outs(%arg3: memref<4x8x16xf32>)
  return
}

// -----

// Round-trip test: tensor-semantics form of linalg.select with a result value.
// CHECK-LABEL: func @select_tensor
func.func @select_tensor(%arg0: tensor<4x8x16xi1>, %arg1: tensor<4x8x16xf32>, %arg2: tensor<4x8x16xf32>) -> tensor<4x8x16xf32> {
  %0 = tensor.empty() : tensor<4x8x16xf32>
  // CHECK: linalg.select
  // CHECK-SAME: ins(%{{.+}}, %{{.+}}, %{{.+}} : tensor<4x8x16xi1>, tensor<4x8x16xf32>, tensor<4x8x16xf32>)
  // CHECK-SAME: outs(%{{.+}} : tensor<4x8x16xf32>)
  %1 = linalg.select ins(%arg0, %arg1, %arg2 : tensor<4x8x16xi1>, tensor<4x8x16xf32>, tensor<4x8x16xf32>) outs(%0: tensor<4x8x16xf32>) -> tensor<4x8x16xf32>
  return %1 : tensor<4x8x16xf32>
}
Loading