Commit fcaf6dd

Authored and committed by Mahesh Ravishankar
[mlir][Transforms] CSE of ops with a single block.
Currently CSE does not support CSE of ops with regions. This patch extends CSE support to ops with a single region containing a single block.

Differential Revision: https://reviews.llvm.org/D134306

Depends on D137857
Parent: 7d59b33 · Commit: fcaf6dd

File tree: 5 files changed, +232 -16 lines

  mlir/lib/IR/OperationSupport.cpp
  mlir/lib/Transforms/CSE.cpp
  mlir/test/Dialect/SparseTensor/codegen_buffer_initialization.mlir
  mlir/test/Transforms/cse.mlir
  mlir/test/lib/Dialect/Test/TestOps.td

mlir/lib/IR/OperationSupport.cpp

Lines changed: 27 additions & 9 deletions
@@ -721,16 +721,34 @@ bool OperationEquivalence::isEquivalentTo(
   ValueRange lhsOperands = lhs->getOperands(), rhsOperands = rhs->getOperands();
   SmallVector<Value> lhsOperandStorage, rhsOperandStorage;
   if (lhs->hasTrait<mlir::OpTrait::IsCommutative>()) {
-    lhsOperandStorage.append(lhsOperands.begin(), lhsOperands.end());
-    llvm::sort(lhsOperandStorage, [](Value a, Value b) -> bool {
-      return a.getAsOpaquePointer() < b.getAsOpaquePointer();
-    });
-    lhsOperands = lhsOperandStorage;
+    auto sortValues = [](ValueRange values) {
+      SmallVector<Value> sortedValues = llvm::to_vector(values);
+      llvm::sort(sortedValues, [](Value a, Value b) {
+        auto aArg = a.dyn_cast<BlockArgument>();
+        auto bArg = b.dyn_cast<BlockArgument>();
+
+        // Case 1. Both `a` and `b` are `BlockArgument`s.
+        if (aArg && bArg) {
+          if (aArg.getParentBlock() == bArg.getParentBlock())
+            return aArg.getArgNumber() < bArg.getArgNumber();
+          return aArg.getParentBlock() < bArg.getParentBlock();
+        }
 
-    rhsOperandStorage.append(rhsOperands.begin(), rhsOperands.end());
-    llvm::sort(rhsOperandStorage, [](Value a, Value b) -> bool {
-      return a.getAsOpaquePointer() < b.getAsOpaquePointer();
-    });
+        // Case 2. One of then is a `BlockArgument` and other is not. Treat
+        // `BlockArgument` as lesser.
+        if (aArg && !bArg)
+          return true;
+        if (bArg && !aArg)
+          return false;
+
+        // Case 3. Both are values.
+        return a.getAsOpaquePointer() < b.getAsOpaquePointer();
+      });
+      return sortedValues;
+    };
+    lhsOperandStorage = sortValues(lhsOperands);
+    lhsOperands = lhsOperandStorage;
+    rhsOperandStorage = sortValues(rhsOperands);
     rhsOperands = rhsOperandStorage;
   }
   auto checkValueRangeMapping =
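
To see why this comparator makes commutative operand comparison order-insensitive, the standalone sketch below applies the same three-case ordering to a simplified stand-in type. The names FakeValue, canonicalLess, and the main driver are invented here for illustration (they are not MLIR API); the actual patch sorts mlir::Value operands with llvm::sort as shown in the diff above. Assumes C++17 for the aggregate initialization.

#include <algorithm>
#include <cassert>
#include <vector>

// Simplified stand-in for mlir::Value / mlir::BlockArgument.
struct FakeValue {
  bool isBlockArg = false;               // stands in for dyn_cast<BlockArgument>
  const void *parentBlock = nullptr;     // meaningful only for block arguments
  unsigned argNumber = 0;                // meaningful only for block arguments
  const void *opaque = nullptr;          // stands in for getAsOpaquePointer()
};

// The same three-case ordering the patch introduces.
static bool canonicalLess(const FakeValue &a, const FakeValue &b) {
  // Case 1. Both are block arguments: order by (block, argument number).
  if (a.isBlockArg && b.isBlockArg) {
    if (a.parentBlock == b.parentBlock)
      return a.argNumber < b.argNumber;
    return a.parentBlock < b.parentBlock;
  }
  // Case 2. Exactly one is a block argument: it sorts first.
  if (a.isBlockArg != b.isBlockArg)
    return a.isBlockArg;
  // Case 3. Neither is a block argument: fall back to the opaque pointer.
  return a.opaque < b.opaque;
}

int main() {
  int ssaValue = 0, blockAnchor = 0; // addresses used only as distinct identities
  FakeValue plainValue{false, nullptr, 0, &ssaValue};
  FakeValue blockArgument{true, &blockAnchor, 0, nullptr};

  // The same two operands in swapped order, as a commutative op would allow.
  std::vector<FakeValue> lhs{plainValue, blockArgument};
  std::vector<FakeValue> rhs{blockArgument, plainValue};

  std::sort(lhs.begin(), lhs.end(), canonicalLess);
  std::sort(rhs.begin(), rhs.end(), canonicalLess);

  // After sorting with one deterministic key, both lists line up element-wise:
  // block arguments first, then the remaining values by opaque pointer.
  assert(lhs[0].isBlockArg && rhs[0].isBlockArg);
  assert(lhs[1].opaque == rhs[1].opaque);
  return 0;
}

The point of cases 1 and 2 is that block arguments get a position-based key rather than a pointer-based one, so corresponding arguments of two different blocks can land in the same slot after sorting, which is what the region comparison in CSE.cpp below relies on.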

mlir/lib/Transforms/CSE.cpp

Lines changed: 64 additions & 4 deletions
@@ -47,11 +47,70 @@ struct SimpleOperationInfo : public llvm::DenseMapInfo<Operation *> {
     if (lhs == getTombstoneKey() || lhs == getEmptyKey() ||
         rhs == getTombstoneKey() || rhs == getEmptyKey())
       return false;
+
+    // If op has no regions, operation equivalence w.r.t operands alone is
+    // enough.
+    if (lhs->getNumRegions() == 0 && rhs->getNumRegions() == 0) {
+      return OperationEquivalence::isEquivalentTo(
+          const_cast<Operation *>(lhsC), const_cast<Operation *>(rhsC),
+          OperationEquivalence::exactValueMatch,
+          OperationEquivalence::ignoreValueEquivalence,
+          OperationEquivalence::IgnoreLocations);
+    }
+
+    // If lhs or rhs does not have a single region with a single block, they
+    // aren't CSEed for now.
+    if (lhs->getNumRegions() != 1 || rhs->getNumRegions() != 1 ||
+        !llvm::hasSingleElement(lhs->getRegion(0)) ||
+        !llvm::hasSingleElement(rhs->getRegion(0)))
+      return false;
+
+    // Compare the two blocks.
+    Block &lhsBlock = lhs->getRegion(0).front();
+    Block &rhsBlock = rhs->getRegion(0).front();
+
+    // Don't CSE if number of arguments differ.
+    if (lhsBlock.getNumArguments() != rhsBlock.getNumArguments())
+      return false;
+
+    // Map to store `Value`s from `lhsBlock` that are equivalent to `Value`s in
+    // `rhsBlock`. `Value`s from `lhsBlock` are the key.
+    DenseMap<Value, Value> areEquivalentValues;
+    for (auto bbArgs : llvm::zip(lhs->getRegion(0).getArguments(),
+                                 rhs->getRegion(0).getArguments())) {
+      areEquivalentValues[std::get<0>(bbArgs)] = std::get<1>(bbArgs);
+    }
+
+    // Helper function to get the parent operation.
+    auto getParent = [](Value v) -> Operation * {
+      if (auto blockArg = v.dyn_cast<BlockArgument>())
+        return blockArg.getParentBlock()->getParentOp();
+      return v.getDefiningOp()->getParentOp();
+    };
+
+    // Callback to compare if operands of ops in the region of `lhs` and `rhs`
+    // are equivalent.
+    auto mapOperands = [&](Value lhsValue, Value rhsValue) -> LogicalResult {
+      if (lhsValue == rhsValue)
+        return success();
+      if (areEquivalentValues.lookup(lhsValue) == rhsValue)
+        return success();
+      return failure();
+    };
+
+    // Callback to compare if results of ops in the region of `lhs` and `rhs`
+    // are equivalent.
+    auto mapResults = [&](Value lhsResult, Value rhsResult) -> LogicalResult {
+      if (getParent(lhsResult) == lhs && getParent(rhsResult) == rhs) {
+        auto insertion = areEquivalentValues.insert({lhsResult, rhsResult});
+        return success(insertion.first->second == rhsResult);
+      }
+      return success();
+    };
+
     return OperationEquivalence::isEquivalentTo(
         const_cast<Operation *>(lhsC), const_cast<Operation *>(rhsC),
-        /*mapOperands=*/OperationEquivalence::exactValueMatch,
-        /*mapResults=*/OperationEquivalence::ignoreValueEquivalence,
-        OperationEquivalence::IgnoreLocations);
+        mapOperands, mapResults, OperationEquivalence::IgnoreLocations);
   }
 };
 } // namespace
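
The areEquivalentValues map is what lets two structurally identical regions compare equal even though their SSA values are distinct objects: it is seeded with the paired block arguments and grows as matching results are discovered. The standalone sketch below mimics that flow with plain ints standing in for mlir::Value; the names FakeValue, mapOperands, and mapResults here are illustrative analogues of the callbacks in the diff, not the MLIR API itself.

#include <cassert>
#include <unordered_map>

using FakeValue = int; // stand-in for mlir::Value; each int names one SSA value

int main() {
  // Seeded with the paired block arguments of the two single-block regions:
  // lhs %arg0 (here: 100) is declared equivalent to rhs %arg0 (here: 200).
  std::unordered_map<FakeValue, FakeValue> areEquivalentValues = {{100, 200}};

  // mapOperands analogue: operands match when they are literally the same
  // value (defined above both ops) or already recorded as equivalent.
  auto mapOperands = [&](FakeValue lhsV, FakeValue rhsV) {
    if (lhsV == rhsV)
      return true;
    auto it = areEquivalentValues.find(lhsV);
    return it != areEquivalentValues.end() && it->second == rhsV;
  };

  // mapResults analogue: results produced inside the compared regions are
  // recorded as a new equivalence (or checked against an earlier recording).
  auto mapResults = [&](FakeValue lhsV, FakeValue rhsV) {
    auto insertion = areEquivalentValues.insert({lhsV, rhsV});
    return insertion.first->second == rhsV;
  };

  // "%1 = op %arg0" appears in both regions: the operands map, so the two %1
  // results are recorded as equivalent ...
  assert(mapOperands(100, 200));
  assert(mapResults(101, 201));
  // ... which lets a later use of %1 in both regions map as well.
  assert(mapOperands(101, 201));
  // A mismatched use does not map.
  assert(!mapOperands(101, 999));
  return 0;
}

In the committed code, mapResults only records a pairing when both results are defined by ops nested inside lhs and rhs (the getParent check); results defined elsewhere are left to the usual exact-match rules.
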
@@ -204,7 +263,8 @@ LogicalResult CSE::simplifyOperation(ScopedMapTy &knownValues, Operation *op,
   // Don't simplify operations with nested blocks. We don't currently model
   // equality comparisons correctly among other things. It is also unclear
   // whether we would want to CSE such operations.
-  if (op->getNumRegions() != 0)
+  if (!(op->getNumRegions() == 0 ||
+        (op->getNumRegions() == 1 && llvm::hasSingleElement(op->getRegion(0)))))
     return failure();
 
   // Some simple use case of operation with memory side-effect are dealt with
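
The relaxed gate reads more easily when factored into a named helper. The sketch below is only that rephrasing: the helper name hasCSECompatibleRegionShape is invented here, and the committed change keeps the condition inline as shown above; the calls themselves (getNumRegions, getRegion, llvm::hasSingleElement) are the same ones the diff uses.

#include "mlir/IR/Operation.h"
#include "llvm/ADT/STLExtras.h"

// Sketch: CSE now proceeds for ops with no regions, or with exactly one
// region that holds a single block.
static bool hasCSECompatibleRegionShape(mlir::Operation *op) {
  if (op->getNumRegions() == 0)
    return true;
  return op->getNumRegions() == 1 && llvm::hasSingleElement(op->getRegion(0));
}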

mlir/test/Dialect/SparseTensor/codegen_buffer_initialization.mlir

Lines changed: 0 additions & 1 deletion
@@ -17,7 +17,6 @@
 // CHECK: linalg.fill ins(%[[C0]] : index) outs(%[[T4]] : memref<16xindex>)
 // CHECK: %[[T6:.*]] = memref.alloc() : memref<16xf64>
 // CHECK: %[[T7:.*]] = memref.cast %[[T6]] : memref<16xf64> to memref<?xf64>
-// CHECK: linalg.fill ins(%{{.*}} : f64) outs(%[[T6]] : memref<16xf64>)
 // CHECK: linalg.fill ins(%[[C0]] : index) outs(%[[T1]] : memref<3xindex>)
 // CHECK: memref.store %[[A]], %[[T0]][%[[C0]]] : memref<1xindex>
 // CHECK: %[[P0:.*]] = sparse_tensor.push_back %[[T1]], %[[T3]]

mlir/test/Transforms/cse.mlir

Lines changed: 124 additions & 0 deletions
@@ -322,3 +322,127 @@ func.func @check_cummutative_cse(%a : i32, %b : i32) -> i32 {
   %3 = arith.muli %1, %2 : i32
   return %3 : i32
 }
+
+// Check that an operation with a single region can CSE.
+func.func @cse_single_block_ops(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>)
+    -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+  %0 = test.cse_of_single_block_op inputs(%a, %b) {
+    ^bb0(%arg0 : f32):
+      test.region_yield %arg0 : f32
+  } : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
+  %1 = test.cse_of_single_block_op inputs(%a, %b) {
+    ^bb0(%arg0 : f32):
+      test.region_yield %arg0 : f32
+  } : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
+  return %0, %1 : tensor<?x?xf32>, tensor<?x?xf32>
+}
+// CHECK-LABEL: func @cse_single_block_ops
+// CHECK: %[[OP:.+]] = test.cse_of_single_block_op
+// CHECK-NOT: test.cse_of_single_block_op
+// CHECK: return %[[OP]], %[[OP]]
+
+// Operations with different number of bbArgs dont CSE.
+func.func @no_cse_varied_bbargs(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>)
+    -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+  %0 = test.cse_of_single_block_op inputs(%a, %b) {
+    ^bb0(%arg0 : f32, %arg1 : f32):
+      test.region_yield %arg0 : f32
+  } : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
+  %1 = test.cse_of_single_block_op inputs(%a, %b) {
+    ^bb0(%arg0 : f32):
+      test.region_yield %arg0 : f32
+  } : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
+  return %0, %1 : tensor<?x?xf32>, tensor<?x?xf32>
+}
+// CHECK-LABEL: func @no_cse_varied_bbargs
+// CHECK: %[[OP0:.+]] = test.cse_of_single_block_op
+// CHECK: %[[OP1:.+]] = test.cse_of_single_block_op
+// CHECK: return %[[OP0]], %[[OP1]]
+
+// Operations with different regions dont CSE
+func.func @no_cse_region_difference_simple(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>)
+    -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+  %0 = test.cse_of_single_block_op inputs(%a, %b) {
+    ^bb0(%arg0 : f32, %arg1 : f32):
+      test.region_yield %arg0 : f32
+  } : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
+  %1 = test.cse_of_single_block_op inputs(%a, %b) {
+    ^bb0(%arg0 : f32, %arg1 : f32):
+      test.region_yield %arg1 : f32
+  } : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
+  return %0, %1 : tensor<?x?xf32>, tensor<?x?xf32>
+}
+// CHECK-LABEL: func @no_cse_region_difference_simple
+// CHECK: %[[OP0:.+]] = test.cse_of_single_block_op
+// CHECK: %[[OP1:.+]] = test.cse_of_single_block_op
+// CHECK: return %[[OP0]], %[[OP1]]
+
+// Operation with identical region with multiple statements CSE.
+func.func @cse_single_block_ops_identical_bodies(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>, %c : f32, %d : i1)
+    -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+  %0 = test.cse_of_single_block_op inputs(%a, %b) {
+    ^bb0(%arg0 : f32, %arg1 : f32):
+      %1 = arith.divf %arg0, %arg1 : f32
+      %2 = arith.remf %arg0, %c : f32
+      %3 = arith.select %d, %1, %2 : f32
+      test.region_yield %3 : f32
+  } : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
+  %1 = test.cse_of_single_block_op inputs(%a, %b) {
+    ^bb0(%arg0 : f32, %arg1 : f32):
+      %1 = arith.divf %arg0, %arg1 : f32
+      %2 = arith.remf %arg0, %c : f32
+      %3 = arith.select %d, %1, %2 : f32
+      test.region_yield %3 : f32
+  } : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
+  return %0, %1 : tensor<?x?xf32>, tensor<?x?xf32>
+}
+// CHECK-LABEL: func @cse_single_block_ops_identical_bodies
+// CHECK: %[[OP:.+]] = test.cse_of_single_block_op
+// CHECK-NOT: test.cse_of_single_block_op
+// CHECK: return %[[OP]], %[[OP]]
+
+// Operation with non-identical regions dont CSE.
+func.func @no_cse_single_block_ops_different_bodies(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>, %c : f32, %d : i1)
+    -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+  %0 = test.cse_of_single_block_op inputs(%a, %b) {
+    ^bb0(%arg0 : f32, %arg1 : f32):
+      %1 = arith.divf %arg0, %arg1 : f32
+      %2 = arith.remf %arg0, %c : f32
+      %3 = arith.select %d, %1, %2 : f32
+      test.region_yield %3 : f32
+  } : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
+  %1 = test.cse_of_single_block_op inputs(%a, %b) {
+    ^bb0(%arg0 : f32, %arg1 : f32):
+      %1 = arith.divf %arg0, %arg1 : f32
+      %2 = arith.remf %arg0, %c : f32
+      %3 = arith.select %d, %2, %1 : f32
+      test.region_yield %3 : f32
+  } : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
+  return %0, %1 : tensor<?x?xf32>, tensor<?x?xf32>
+}
+// CHECK-LABEL: func @no_cse_single_block_ops_different_bodies
+// CHECK: %[[OP0:.+]] = test.cse_of_single_block_op
+// CHECK: %[[OP1:.+]] = test.cse_of_single_block_op
+// CHECK: return %[[OP0]], %[[OP1]]
+
+// Account for commutative ops within regions during CSE.
+func.func @cse_single_block_with_commutative_ops(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>, %c : f32)
+    -> (tensor<?x?xf32>, tensor<?x?xf32>) {
+  %0 = test.cse_of_single_block_op inputs(%a, %b) {
+    ^bb0(%arg0 : f32, %arg1 : f32):
+      %1 = arith.addf %arg0, %arg1 : f32
+      %2 = arith.mulf %1, %c : f32
+      test.region_yield %2 : f32
+  } : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
+  %1 = test.cse_of_single_block_op inputs(%a, %b) {
+    ^bb0(%arg0 : f32, %arg1 : f32):
+      %1 = arith.addf %arg1, %arg0 : f32
+      %2 = arith.mulf %c, %1 : f32
+      test.region_yield %2 : f32
+  } : tensor<?x?xf32>, tensor<?x?xf32> -> tensor<?x?xf32>
+  return %0, %1 : tensor<?x?xf32>, tensor<?x?xf32>
+}
+// CHECK-LABEL: func @cse_single_block_with_commutative_ops
+// CHECK: %[[OP:.+]] = test.cse_of_single_block_op
+// CHECK-NOT: test.cse_of_single_block_op
+// CHECK: return %[[OP]], %[[OP]]

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 17 additions & 2 deletions
@@ -670,8 +670,8 @@ def TestProducingBranchOp : TEST_Op<"producing_br",
 
 // Produces an error value on the error path
 def TestInternalBranchOp : TEST_Op<"internal_br",
-  [DeclareOpInterfaceMethods<BranchOpInterface>, Terminator,
-  AttrSizedOperandSegments]> {
+    [DeclareOpInterfaceMethods<BranchOpInterface>, Terminator,
+     AttrSizedOperandSegments]> {
 
   let arguments = (ins Variadic<AnyType>:$successOperands,
                        Variadic<AnyType>:$errorOperands);
@@ -3045,4 +3045,19 @@ def RecursivelySpeculatableOp : TEST_Op<"recursively_speculatable_op", [
   let regions = (region SizedRegion<1>:$body);
 }
 
+//===---------------------------------------------------------------------===//
+// Test CSE
+//===---------------------------------------------------------------------===//
+
+def TestCSEOfSingleBlockOp : TEST_Op<"cse_of_single_block_op",
+    [SingleBlockImplicitTerminator<"RegionYieldOp">, Pure]> {
+  let arguments = (ins Variadic<AnyType>:$inputs);
+  let results = (outs Variadic<AnyType>:$outputs);
+  let regions = (region SizedRegion<1>:$region);
+  let assemblyFormat = [{
+    attr-dict `inputs` `(` $inputs `)`
+    $region `:` type($inputs) `->` type($outputs)
+  }];
+}
+
 #endif // TEST_OPS
