Skip to content

Commit 70343c8

Browse files
authored
[mlir][flang] Added Weighted[Region]BranchOpInterface's. (#142079)
The new interfaces provide getters and setters for the weight information about the branches of BranchOpInterface and RegionBranchOpInterface operations. These interfaces are done the same way as LLVM dialect's BranchWeightOpInterface. The plan is to produce this information in Flang, e.g. mark most probably "cold" code as such and allow LLVM to order basic blocks accordingly. An example of such a code is copy loops generated for arrays repacking - we can mark it as "cold" assuming that the copy will not happen dynamically. If the copy actually happens the overhead of the copy is probably high enough so that we may not care about the little overhead of jumping to the "cold" code and fetching it.
1 parent af65cb6 commit 70343c8

File tree

23 files changed

+461
-135
lines changed

23 files changed

+461
-135
lines changed

flang/include/flang/Optimizer/Dialect/FIROps.td

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2323,9 +2323,13 @@ def fir_DoLoopOp : region_Op<"do_loop", [AttrSizedOperandSegments,
23232323
}];
23242324
}
23252325

2326-
def fir_IfOp : region_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterface, [
2327-
"getRegionInvocationBounds", "getEntrySuccessorRegions"]>, RecursiveMemoryEffects,
2328-
NoRegionArguments]> {
2326+
def fir_IfOp
2327+
: region_Op<
2328+
"if", [DeclareOpInterfaceMethods<
2329+
RegionBranchOpInterface, ["getRegionInvocationBounds",
2330+
"getEntrySuccessorRegions"]>,
2331+
RecursiveMemoryEffects, NoRegionArguments,
2332+
WeightedRegionBranchOpInterface]> {
23292333
let summary = "if-then-else conditional operation";
23302334
let description = [{
23312335
Used to conditionally execute operations. This operation is the FIR
@@ -2342,7 +2346,8 @@ def fir_IfOp : region_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterfac
23422346
```
23432347
}];
23442348

2345-
let arguments = (ins I1:$condition);
2349+
let arguments = (ins I1:$condition,
2350+
OptionalAttr<DenseI32ArrayAttr>:$region_weights);
23462351
let results = (outs Variadic<AnyType>:$results);
23472352

23482353
let regions = (region
@@ -2371,6 +2376,11 @@ def fir_IfOp : region_Op<"if", [DeclareOpInterfaceMethods<RegionBranchOpInterfac
23712376

23722377
void resultToSourceOps(llvm::SmallVectorImpl<mlir::Value> &results,
23732378
unsigned resultNum);
2379+
2380+
/// Returns the display name string for the region_weights attribute.
2381+
static constexpr llvm::StringRef getWeightsAttrAssemblyName() {
2382+
return "weights";
2383+
}
23742384
}];
23752385
}
23762386

flang/lib/Optimizer/Dialect/FIROps.cpp

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4418,6 +4418,19 @@ mlir::ParseResult fir::IfOp::parse(mlir::OpAsmParser &parser,
44184418
parser.resolveOperand(cond, i1Type, result.operands))
44194419
return mlir::failure();
44204420

4421+
if (mlir::succeeded(
4422+
parser.parseOptionalKeyword(getWeightsAttrAssemblyName()))) {
4423+
if (parser.parseLParen())
4424+
return mlir::failure();
4425+
mlir::DenseI32ArrayAttr weights;
4426+
if (parser.parseCustomAttributeWithFallback(weights, mlir::Type{}))
4427+
return mlir::failure();
4428+
if (weights)
4429+
result.addAttribute(getRegionWeightsAttrName(result.name), weights);
4430+
if (parser.parseRParen())
4431+
return mlir::failure();
4432+
}
4433+
44214434
if (parser.parseOptionalArrowTypeList(result.types))
44224435
return mlir::failure();
44234436

@@ -4449,6 +4462,11 @@ llvm::LogicalResult fir::IfOp::verify() {
44494462
void fir::IfOp::print(mlir::OpAsmPrinter &p) {
44504463
bool printBlockTerminators = false;
44514464
p << ' ' << getCondition();
4465+
if (auto weights = getRegionWeightsAttr()) {
4466+
p << ' ' << getWeightsAttrAssemblyName() << '(';
4467+
p.printStrippedAttrOrType(weights);
4468+
p << ')';
4469+
}
44524470
if (!getResults().empty()) {
44534471
p << " -> (" << getResultTypes() << ')';
44544472
printBlockTerminators = true;
@@ -4464,7 +4482,8 @@ void fir::IfOp::print(mlir::OpAsmPrinter &p) {
44644482
p.printRegion(otherReg, /*printEntryBlockArgs=*/false,
44654483
printBlockTerminators);
44664484
}
4467-
p.printOptionalAttrDict((*this)->getAttrs());
4485+
p.printOptionalAttrDict((*this)->getAttrs(),
4486+
/*elideAttrs=*/{getRegionWeightsAttrName()});
44684487
}
44694488

44704489
void fir::IfOp::resultToSourceOps(llvm::SmallVectorImpl<mlir::Value> &results,

flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,9 +212,12 @@ class CfgIfConv : public mlir::OpRewritePattern<fir::IfOp> {
212212
}
213213

214214
rewriter.setInsertionPointToEnd(condBlock);
215-
rewriter.create<mlir::cf::CondBranchOp>(
215+
auto branchOp = rewriter.create<mlir::cf::CondBranchOp>(
216216
loc, ifOp.getCondition(), ifOpBlock, llvm::ArrayRef<mlir::Value>(),
217217
otherwiseBlock, llvm::ArrayRef<mlir::Value>());
218+
llvm::ArrayRef<int32_t> weights = ifOp.getWeights();
219+
if (!weights.empty())
220+
branchOp.setWeights(weights);
218221
rewriter.replaceOp(ifOp, continueBlock->getArguments());
219222
return success();
220223
}

flang/test/Fir/cfg-conversion-if.fir

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
// RUN: fir-opt --split-input-file --cfg-conversion %s | FileCheck %s
2+
3+
func.func private @callee() -> none
4+
5+
// CHECK-LABEL: func.func @if_then(
6+
// CHECK-SAME: %[[ARG0:.*]]: i1) {
7+
// CHECK: cf.cond_br %[[ARG0]] weights([10, 90]), ^bb1, ^bb2
8+
// CHECK: ^bb1:
9+
// CHECK: %[[VAL_0:.*]] = fir.call @callee() : () -> none
10+
// CHECK: cf.br ^bb2
11+
// CHECK: ^bb2:
12+
// CHECK: return
13+
// CHECK: }
14+
func.func @if_then(%cond: i1) {
15+
fir.if %cond weights([10, 90]) {
16+
fir.call @callee() : () -> none
17+
}
18+
return
19+
}
20+
21+
// -----
22+
23+
// CHECK-LABEL: func.func @if_then_else(
24+
// CHECK-SAME: %[[ARG0:.*]]: i1) -> i32 {
25+
// CHECK: %[[VAL_0:.*]] = arith.constant 0 : i32
26+
// CHECK: %[[VAL_1:.*]] = arith.constant 1 : i32
27+
// CHECK: cf.cond_br %[[ARG0]] weights([90, 10]), ^bb1, ^bb2
28+
// CHECK: ^bb1:
29+
// CHECK: cf.br ^bb3(%[[VAL_0]] : i32)
30+
// CHECK: ^bb2:
31+
// CHECK: cf.br ^bb3(%[[VAL_1]] : i32)
32+
// CHECK: ^bb3(%[[VAL_2:.*]]: i32):
33+
// CHECK: cf.br ^bb4
34+
// CHECK: ^bb4:
35+
// CHECK: return %[[VAL_2]] : i32
36+
// CHECK: }
37+
func.func @if_then_else(%cond: i1) -> i32 {
38+
%c0 = arith.constant 0 : i32
39+
%c1 = arith.constant 1 : i32
40+
%result = fir.if %cond weights([90, 10]) -> i32 {
41+
fir.result %c0 : i32
42+
} else {
43+
fir.result %c1 : i32
44+
}
45+
return %result : i32
46+
}

flang/test/Fir/fir-ops.fir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1015,3 +1015,19 @@ func.func @test_box_total_elements(%arg0: !fir.class<!fir.type<sometype{i:i32}>>
10151015
%6 = arith.addi %2, %5 : index
10161016
return %6 : index
10171017
}
1018+
1019+
// CHECK-LABEL: func.func @test_if_weights(
1020+
// CHECK-SAME: %[[ARG0:.*]]: i1) {
1021+
func.func @test_if_weights(%cond: i1) {
1022+
// CHECK: fir.if %[[ARG0]] weights([99, 1]) {
1023+
// CHECK: }
1024+
fir.if %cond weights([99, 1]) {
1025+
}
1026+
// CHECK: fir.if %[[ARG0]] weights([99, 1]) {
1027+
// CHECK: } else {
1028+
// CHECK: }
1029+
fir.if %cond weights ([99,1]) {
1030+
} else {
1031+
}
1032+
return
1033+
}

flang/test/Fir/invalid.fir

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1393,3 +1393,31 @@ fir.local {type = local_init} @x.localizer : f32 init {
13931393
^bb0(%arg0: f32, %arg1: f32):
13941394
fir.yield(%arg0 : f32)
13951395
}
1396+
1397+
// -----
1398+
1399+
func.func @wrong_weights_number_in_if_then(%cond: i1) {
1400+
// expected-error @below {{expects number of region weights to match number of regions: 1 vs 2}}
1401+
fir.if %cond weights([50]) {
1402+
}
1403+
return
1404+
}
1405+
1406+
// -----
1407+
1408+
func.func @wrong_weights_number_in_if_then_else(%cond: i1) {
1409+
// expected-error @below {{expects number of region weights to match number of regions: 3 vs 2}}
1410+
fir.if %cond weights([50, 40, 10]) {
1411+
} else {
1412+
}
1413+
return
1414+
}
1415+
1416+
// -----
1417+
1418+
func.func @negative_weight_in_if_then(%cond: i1) {
1419+
// expected-error @below {{weight #0 must be non-negative}}
1420+
fir.if %cond weights([-1, 101]) {
1421+
}
1422+
return
1423+
}

mlir/include/mlir/Dialect/ControlFlow/IR/ControlFlowOps.td

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -112,10 +112,11 @@ def BranchOp : CF_Op<"br", [
112112
// CondBranchOp
113113
//===----------------------------------------------------------------------===//
114114

115-
def CondBranchOp : CF_Op<"cond_br",
116-
[AttrSizedOperandSegments,
117-
DeclareOpInterfaceMethods<BranchOpInterface, ["getSuccessorForOperands"]>,
118-
Pure, Terminator]> {
115+
def CondBranchOp
116+
: CF_Op<"cond_br", [AttrSizedOperandSegments,
117+
DeclareOpInterfaceMethods<
118+
BranchOpInterface, ["getSuccessorForOperands"]>,
119+
WeightedBranchOpInterface, Pure, Terminator]> {
119120
let summary = "Conditional branch operation";
120121
let description = [{
121122
The `cf.cond_br` terminator operation represents a conditional branch on a
@@ -144,20 +145,23 @@ def CondBranchOp : CF_Op<"cond_br",
144145
```
145146
}];
146147

147-
let arguments = (ins I1:$condition,
148-
Variadic<AnyType>:$trueDestOperands,
149-
Variadic<AnyType>:$falseDestOperands);
148+
let arguments = (ins I1:$condition, Variadic<AnyType>:$trueDestOperands,
149+
Variadic<AnyType>:$falseDestOperands,
150+
OptionalAttr<DenseI32ArrayAttr>:$branch_weights);
150151
let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest);
151152

152-
let builders = [
153-
OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
154-
"ValueRange":$trueOperands, "Block *":$falseDest,
155-
"ValueRange":$falseOperands), [{
156-
build($_builder, $_state, condition, trueOperands, falseOperands, trueDest,
153+
let builders = [OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
154+
"ValueRange":$trueOperands,
155+
"Block *":$falseDest,
156+
"ValueRange":$falseOperands),
157+
[{
158+
build($_builder, $_state, condition, trueOperands, falseOperands, /*branch_weights=*/{}, trueDest,
157159
falseDest);
158160
}]>,
159-
OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
160-
"Block *":$falseDest, CArg<"ValueRange", "{}">:$falseOperands), [{
161+
OpBuilder<(ins "Value":$condition, "Block *":$trueDest,
162+
"Block *":$falseDest,
163+
CArg<"ValueRange", "{}">:$falseOperands),
164+
[{
161165
build($_builder, $_state, condition, trueDest, ValueRange(), falseDest,
162166
falseOperands);
163167
}]>];
@@ -216,7 +220,7 @@ def CondBranchOp : CF_Op<"cond_br",
216220

217221
let hasCanonicalizer = 1;
218222
let assemblyFormat = [{
219-
$condition `,`
223+
$condition (`weights` `(` $branch_weights^ `)` )? `,`
220224
$trueDest (`(` $trueDestOperands^ `:` type($trueDestOperands) `)`)? `,`
221225
$falseDest (`(` $falseDestOperands^ `:` type($falseDestOperands) `)`)?
222226
attr-dict

mlir/include/mlir/Dialect/LLVMIR/LLVMInterfaces.td

Lines changed: 0 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -168,42 +168,6 @@ def NonNegFlagInterface : OpInterface<"NonNegFlagInterface"> {
168168
];
169169
}
170170

171-
def BranchWeightOpInterface : OpInterface<"BranchWeightOpInterface"> {
172-
let description = [{
173-
An interface for operations that can carry branch weights metadata. It
174-
provides setters and getters for the operation's branch weights attribute.
175-
The default implementation of the interface methods expect the operation to
176-
have an attribute of type DenseI32ArrayAttr named branch_weights.
177-
}];
178-
179-
let cppNamespace = "::mlir::LLVM";
180-
181-
let methods = [
182-
InterfaceMethod<
183-
/*desc=*/ "Returns the branch weights attribute or nullptr",
184-
/*returnType=*/ "::mlir::DenseI32ArrayAttr",
185-
/*methodName=*/ "getBranchWeightsOrNull",
186-
/*args=*/ (ins),
187-
/*methodBody=*/ [{}],
188-
/*defaultImpl=*/ [{
189-
auto op = cast<ConcreteOp>(this->getOperation());
190-
return op.getBranchWeightsAttr();
191-
}]
192-
>,
193-
InterfaceMethod<
194-
/*desc=*/ "Sets the branch weights attribute",
195-
/*returnType=*/ "void",
196-
/*methodName=*/ "setBranchWeights",
197-
/*args=*/ (ins "::mlir::DenseI32ArrayAttr":$attr),
198-
/*methodBody=*/ [{}],
199-
/*defaultImpl=*/ [{
200-
auto op = cast<ConcreteOp>(this->getOperation());
201-
op.setBranchWeightsAttr(attr);
202-
}]
203-
>
204-
];
205-
}
206-
207171
def AccessGroupOpInterface : OpInterface<"AccessGroupOpInterface"> {
208172
let description = [{
209173
An interface for memory operations that can carry access groups metadata.

mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td

Lines changed: 27 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -660,12 +660,12 @@ def LLVM_FPTruncOp : LLVM_CastOp<"fptrunc", "FPTrunc",
660660
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>>;
661661

662662
// Call-related operations.
663-
def LLVM_InvokeOp : LLVM_Op<"invoke", [
664-
AttrSizedOperandSegments,
665-
DeclareOpInterfaceMethods<BranchOpInterface>,
666-
DeclareOpInterfaceMethods<CallOpInterface>,
667-
DeclareOpInterfaceMethods<BranchWeightOpInterface>,
668-
Terminator]> {
663+
def LLVM_InvokeOp
664+
: LLVM_Op<"invoke", [AttrSizedOperandSegments,
665+
DeclareOpInterfaceMethods<BranchOpInterface>,
666+
DeclareOpInterfaceMethods<CallOpInterface>,
667+
DeclareOpInterfaceMethods<WeightedBranchOpInterface>,
668+
Terminator]> {
669669
let arguments = (ins
670670
OptionalAttr<TypeAttrOf<LLVM_FunctionType>>:$var_callee_type,
671671
OptionalAttr<FlatSymbolRefAttr>:$callee,
@@ -734,12 +734,12 @@ def LLVM_VaArgOp : LLVM_Op<"va_arg"> {
734734
// CallOp
735735
//===----------------------------------------------------------------------===//
736736

737-
def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
738-
[AttrSizedOperandSegments,
739-
DeclareOpInterfaceMethods<FastmathFlagsInterface>,
740-
DeclareOpInterfaceMethods<CallOpInterface>,
741-
DeclareOpInterfaceMethods<SymbolUserOpInterface>,
742-
DeclareOpInterfaceMethods<BranchWeightOpInterface>]> {
737+
def LLVM_CallOp
738+
: LLVM_MemAccessOpBase<
739+
"call", [AttrSizedOperandSegments,
740+
DeclareOpInterfaceMethods<FastmathFlagsInterface>,
741+
DeclareOpInterfaceMethods<CallOpInterface>,
742+
DeclareOpInterfaceMethods<SymbolUserOpInterface>]> {
743743
let summary = "Call to an LLVM function.";
744744
let description = [{
745745
In LLVM IR, functions may return either 0 or 1 value. LLVM IR dialect
@@ -788,21 +788,16 @@ def LLVM_CallOp : LLVM_MemAccessOpBase<"call",
788788
OptionalAttr<FlatSymbolRefAttr>:$callee,
789789
Variadic<LLVM_Type>:$callee_operands,
790790
DefaultValuedAttr<LLVM_FastmathFlagsAttr, "{}">:$fastmathFlags,
791-
OptionalAttr<DenseI32ArrayAttr>:$branch_weights,
792791
DefaultValuedAttr<CConv, "CConv::C">:$CConv,
793792
DefaultValuedAttr<TailCallKind, "TailCallKind::None">:$TailCallKind,
794793
OptionalAttr<LLVM_MemoryEffectsAttr>:$memory_effects,
795-
UnitAttr:$convergent,
796-
UnitAttr:$no_unwind,
797-
UnitAttr:$will_return,
794+
UnitAttr:$convergent, UnitAttr:$no_unwind, UnitAttr:$will_return,
798795
VariadicOfVariadic<LLVM_Type, "op_bundle_sizes">:$op_bundle_operands,
799796
DenseI32ArrayAttr:$op_bundle_sizes,
800797
OptionalAttr<ArrayAttr>:$op_bundle_tags,
801798
OptionalAttr<DictArrayAttr>:$arg_attrs,
802-
OptionalAttr<DictArrayAttr>:$res_attrs,
803-
UnitAttr:$no_inline,
804-
UnitAttr:$always_inline,
805-
UnitAttr:$inline_hint);
799+
OptionalAttr<DictArrayAttr>:$res_attrs, UnitAttr:$no_inline,
800+
UnitAttr:$always_inline, UnitAttr:$inline_hint);
806801
// Append the aliasing related attributes defined in LLVM_MemAccessOpBase.
807802
let arguments = !con(args, aliasAttrs);
808803
let results = (outs Optional<LLVM_Type>:$result);
@@ -1047,11 +1042,12 @@ def LLVM_BrOp : LLVM_TerminatorOp<"br",
10471042
LLVM_TerminatorPassthroughOpBuilder
10481043
];
10491044
}
1050-
def LLVM_CondBrOp : LLVM_TerminatorOp<"cond_br",
1051-
[AttrSizedOperandSegments,
1052-
DeclareOpInterfaceMethods<BranchOpInterface>,
1053-
DeclareOpInterfaceMethods<BranchWeightOpInterface>,
1054-
Pure]> {
1045+
def LLVM_CondBrOp
1046+
: LLVM_TerminatorOp<
1047+
"cond_br", [AttrSizedOperandSegments,
1048+
DeclareOpInterfaceMethods<BranchOpInterface>,
1049+
DeclareOpInterfaceMethods<WeightedBranchOpInterface>,
1050+
Pure]> {
10551051
let arguments = (ins I1:$condition,
10561052
Variadic<LLVM_Type>:$trueDestOperands,
10571053
Variadic<LLVM_Type>:$falseDestOperands,
@@ -1136,11 +1132,12 @@ def LLVM_UnreachableOp : LLVM_TerminatorOp<"unreachable"> {
11361132
}];
11371133
}
11381134

1139-
def LLVM_SwitchOp : LLVM_TerminatorOp<"switch",
1140-
[AttrSizedOperandSegments,
1141-
DeclareOpInterfaceMethods<BranchOpInterface>,
1142-
DeclareOpInterfaceMethods<BranchWeightOpInterface>,
1143-
Pure]> {
1135+
def LLVM_SwitchOp
1136+
: LLVM_TerminatorOp<
1137+
"switch", [AttrSizedOperandSegments,
1138+
DeclareOpInterfaceMethods<BranchOpInterface>,
1139+
DeclareOpInterfaceMethods<WeightedBranchOpInterface>,
1140+
Pure]> {
11441141
let arguments = (ins
11451142
AnySignlessInteger:$value,
11461143
Variadic<AnyType>:$defaultOperands,

0 commit comments

Comments
 (0)