
Commit 059da56

[MLIR][Linalg] Add memory effects for softmax
Reviewed By: qcolombet

Differential Revision: https://reviews.llvm.org/D157629
1 parent: 6725a6b

3 files changed: +19 −0 lines

mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td

Lines changed: 1 addition & 0 deletions
@@ -95,6 +95,7 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
      PredOpTrait<"input and output have same element type", TCopVTEtIsSameAs<0, 1>>,
      DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
      DeclareOpInterfaceMethods<AggregatedOpInterface, ["decomposeOperation"]>,
+     DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
      DeclareOpInterfaceMethods<TilingInterface,
       ["getIterationDomain",
        "getLoopIteratorTypes",

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 7 additions & 0 deletions
@@ -2345,6 +2345,13 @@ SoftmaxOp::reifyResultShapes(OpBuilder &b,
       .reifyResultShapes(b, reifiedReturnShapes);
 }
 
+void SoftmaxOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  getGenericEffectsImpl(effects, getOperation()->getResults(),
+                        getDpsInputOperands(), getDpsInitOperands());
+}
+
 // Helper functions for softmax decomposition.
 // @{
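
getGenericEffectsImpl is a pre-existing static helper in LinalgOps.cpp shared by the other destination-passing-style Linalg ops; its body is not part of this diff. A minimal sketch of the behavior such a helper presumably has, assuming the convention that tensor operands are pure SSA values while memref operands touch memory:

// Sketch only: approximates the existing LinalgOps.cpp helper, whose body
// this diff does not show. Tensor-typed operands carry no memory effects;
// memref inputs are read, memref inits are both read and written.
static void getGenericEffectsImpl(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects,
    ValueRange results, ValueRange inputOperands, ValueRange outputOperands) {
  for (Value operand : inputOperands) {
    // Tensor inputs are values, not memory; skip them.
    if (!llvm::isa<MemRefType>(operand.getType()))
      continue;
    effects.emplace_back(MemoryEffects::Read::get(), operand,
                         SideEffects::DefaultResource::get());
  }
  for (Value operand : outputOperands) {
    if (!llvm::isa<MemRefType>(operand.getType()))
      continue;
    // Inits may be read as well as written in destination-passing style.
    effects.emplace_back(MemoryEffects::Read::get(), operand,
                         SideEffects::DefaultResource::get());
    effects.emplace_back(MemoryEffects::Write::get(), operand,
                         SideEffects::DefaultResource::get());
  }
}

On an all-tensor softmax this reports no effects at all, which is exactly what lets the canonicalizer treat an unused softmax as trivially dead in the test below.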

mlir/test/Dialect/Linalg/canonicalize.mlir

Lines changed: 11 additions & 0 deletions
@@ -897,3 +897,14 @@ func.func @cast_producer_mixed(%arg0 : tensor<5xf32>, %arg1: memref<?xf32>) {
 // CHECK-SAME:   iterator_types = ["parallel"]
 // CHECK-SAME: } ins(%[[ARG1]] : tensor<5xf32>)
 // CHECK-SAME: outs(%[[ARG2]] : memref<?xf32>) {
+
+// -----
+
+// CHECK-LABEL: dead_softmax
+func.func @dead_softmax(%arg0: tensor<16x64x256xf32>) -> tensor<16x64x256xf32> {
+  %0 = tensor.empty() : tensor<16x64x256xf32>
+  // CHECK-NOT: linalg.softmax
+  %1 = linalg.softmax dimension(1)
+    ins(%arg0 : tensor<16x64x256xf32>) outs(%0 : tensor<16x64x256xf32>) -> tensor<16x64x256xf32>
+  return %arg0 : tensor<16x64x256xf32>
+}
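
Because the softmax result %1 is unused and the op now reports no memory effects on tensor operands, canonicalization (e.g. mlir-opt -canonicalize, via the file's usual RUN line; exact flags assumed here) is expected to erase the op, after which the tensor.empty is dead too, leaving roughly:

func.func @dead_softmax(%arg0: tensor<16x64x256xf32>) -> tensor<16x64x256xf32> {
  return %arg0 : tensor<16x64x256xf32>
}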
