Skip to content

Commit 3bdcd27

Browse files
committed
[MLIR] Introduce a SelectOpInterface
This commit introduces a `SelectOpInterface` that can be used to handle select-like operations generically. Select operations are similar to control flow operations, as they forward operands depending on conditions. This is the reason why it was placed to the already existing control flow interfaces.
1 parent 065d2d9 commit 3bdcd27

File tree

7 files changed

+78
-8
lines changed

7 files changed

+78
-8
lines changed

mlir/include/mlir/Dialect/Arith/IR/Arith.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "mlir/IR/OpDefinition.h"
1515
#include "mlir/IR/OpImplementation.h"
1616
#include "mlir/Interfaces/CastInterfaces.h"
17+
#include "mlir/Interfaces/ControlFlowInterfaces.h"
1718
#include "mlir/Interfaces/InferIntRangeInterface.h"
1819
#include "mlir/Interfaces/InferTypeOpInterface.h"
1920
#include "mlir/Interfaces/SideEffectInterfaces.h"

mlir/include/mlir/Dialect/Arith/IR/ArithOps.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
include "mlir/Dialect/Arith/IR/ArithBase.td"
1313
include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td"
1414
include "mlir/Interfaces/CastInterfaces.td"
15+
include "mlir/Interfaces/ControlFlowInterfaces.td"
1516
include "mlir/Interfaces/InferIntRangeInterface.td"
1617
include "mlir/Interfaces/InferTypeOpInterface.td"
1718
include "mlir/Interfaces/SideEffectInterfaces.td"
@@ -1578,6 +1579,7 @@ def SelectOp : Arith_Op<"select", [Pure,
15781579
AllTypesMatch<["true_value", "false_value", "result"]>,
15791580
BooleanConditionOrMatchingShape<"condition", "result">,
15801581
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRangesFromOptional"]>,
1582+
DeclareOpInterfaceMethods<SelectOpInterface>,
15811583
] # ElementwiseMappable.traits> {
15821584
let summary = "select operation";
15831585
let description = [{

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -835,7 +835,8 @@ def LLVM_ShuffleVectorOp : LLVM_Op<"shufflevector",
835835
def LLVM_SelectOp
836836
: LLVM_Op<"select",
837837
[Pure, AllTypesMatch<["trueValue", "falseValue", "res"]>,
838-
DeclareOpInterfaceMethods<FastmathFlagsInterface>]>,
838+
DeclareOpInterfaceMethods<FastmathFlagsInterface>,
839+
DeclareOpInterfaceMethods<SelectOpInterface>]>,
839840
LLVM_Builder<
840841
"$res = builder.CreateSelect($condition, $trueValue, $falseValue);"> {
841842
let arguments = (ins LLVM_ScalarOrVectorOf<I1>:$condition,

mlir/include/mlir/Interfaces/ControlFlowInterfaces.td

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,27 @@ def RegionBranchTerminatorOpInterface :
343343
}];
344344
}
345345

346+
def SelectOpInterface : OpInterface<"SelectOpInterface"> {
347+
let description = [{
348+
This interface provides information for select-like operations, i.e.,
349+
operations that forward specific operands to the output, depending on a
350+
condition.
351+
}];
352+
let cppNamespace = "::mlir";
353+
354+
let methods = [
355+
InterfaceMethod<[{
356+
Returns the operand that would be chosen for a false condition.
357+
}], "::mlir::Value", "getFalseValue", (ins)>,
358+
InterfaceMethod<[{
359+
Returns the operand that would be chosen for a true condition.
360+
}], "::mlir::Value", "getTrueValue", (ins)>,
361+
InterfaceMethod<[{
362+
Returns the condition operand.
363+
}], "::mlir::Value", "getCondition", (ins)>
364+
];
365+
}
366+
346367
//===----------------------------------------------------------------------===//
347368
// ControlFlow Traits
348369
//===----------------------------------------------------------------------===//

mlir/lib/Analysis/SliceWalk.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,11 @@ getBlockPredecessorOperands(BlockArgument blockArg) {
104104

105105
std::optional<SmallVector<Value>>
106106
mlir::getControlFlowPredecessors(Value value) {
107-
SmallVector<Value> result;
108107
if (OpResult opResult = dyn_cast<OpResult>(value)) {
109-
auto regionOp = dyn_cast<RegionBranchOpInterface>(opResult.getOwner());
108+
if (auto selectOp = opResult.getDefiningOp<SelectOpInterface>())
109+
return SmallVector<Value>(
110+
{selectOp.getTrueValue(), selectOp.getFalseValue()});
111+
auto regionOp = opResult.getDefiningOp<RegionBranchOpInterface>();
110112
// If the interface is not implemented, there are no control flow
111113
// predecessors to work with.
112114
if (!regionOp)

mlir/lib/Dialect/LLVMIR/Transforms/InlinerInterfaceImpl.cpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -235,11 +235,6 @@ getUnderlyingObjectSet(Value pointerValue) {
235235
if (auto addrCast = val.getDefiningOp<LLVM::AddrSpaceCastOp>())
236236
return WalkContinuation::advanceTo(addrCast.getOperand());
237237

238-
// TODO: Add a SelectLikeOpInterface and use it in the slicing utility.
239-
if (auto selectOp = val.getDefiningOp<LLVM::SelectOp>())
240-
return WalkContinuation::advanceTo(
241-
{selectOp.getTrueValue(), selectOp.getFalseValue()});
242-
243238
// Attempt to advance to control flow predecessors.
244239
std::optional<SmallVector<Value>> controlFlowPredecessors =
245240
getControlFlowPredecessors(val);

mlir/test/Dialect/LLVMIR/inlining-alias-scopes.mlir

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -508,3 +508,51 @@ llvm.func @noalias_with_region(%arg0: !llvm.ptr) {
508508
llvm.call @region(%arg0) : (!llvm.ptr) -> ()
509509
llvm.return
510510
}
511+
512+
// -----
513+
514+
// CHECK-DAG: #[[DOMAIN:.*]] = #llvm.alias_scope_domain<{{.*}}>
515+
// CHECK-DAG: #[[$ARG_SCOPE:.*]] = #llvm.alias_scope<id = {{.*}}, domain = #[[DOMAIN]]{{(,.*)?}}>
516+
517+
llvm.func @foo(%arg: i32)
518+
519+
llvm.func @func(%arg0: !llvm.ptr {llvm.noalias}, %arg1: !llvm.ptr) {
520+
%cond = llvm.load %arg1 : !llvm.ptr -> i1
521+
%1 = llvm.getelementptr inbounds %arg0[1] : (!llvm.ptr) -> !llvm.ptr, f32
522+
%selected = llvm.select %cond, %arg0, %1 : i1, !llvm.ptr
523+
%2 = llvm.load %selected : !llvm.ptr -> i32
524+
llvm.call @foo(%2) : (i32) -> ()
525+
llvm.return
526+
}
527+
528+
// CHECK-LABEL: llvm.func @selects
529+
// CHECK: llvm.load
530+
// CHECK-NOT: alias_scopes
531+
// CHECK-SAME: noalias_scopes = [#[[$ARG_SCOPE]]]
532+
// CHECK: llvm.load
533+
// CHECK-SAME: alias_scopes = [#[[$ARG_SCOPE]]]
534+
llvm.func @selects(%arg0: !llvm.ptr, %arg1: !llvm.ptr) {
535+
llvm.call @func(%arg0, %arg1) : (!llvm.ptr, !llvm.ptr) -> ()
536+
llvm.return
537+
}
538+
539+
// -----
540+
541+
llvm.func @foo(%arg: i32)
542+
543+
llvm.func @func(%cond: i1, %arg0: !llvm.ptr {llvm.noalias}, %arg1: !llvm.ptr) {
544+
%selected = llvm.select %cond, %arg0, %arg1 : i1, !llvm.ptr
545+
%2 = llvm.load %selected : !llvm.ptr -> i32
546+
llvm.call @foo(%2) : (i32) -> ()
547+
llvm.return
548+
}
549+
550+
// CHECK-LABEL: llvm.func @multi_ptr_select
551+
// CHECK: llvm.load
552+
// CHECK-NOT: alias_scopes
553+
// CHECK-NOT: noalias_scopes
554+
// CHECK: llvm.call @foo
555+
llvm.func @multi_ptr_select(%cond: i1, %arg0: !llvm.ptr, %arg1: !llvm.ptr) {
556+
llvm.call @func(%cond, %arg0, %arg1) : (i1, !llvm.ptr, !llvm.ptr) -> ()
557+
llvm.return
558+
}

0 commit comments

Comments
 (0)