Skip to content

Commit 6dcc864

Browse files
author
Aviad Cohen
committed
[mlir][Linalg]: Optimize any structured linalg operation in transform::PromoteOp to avoid unnecessary copies
Before promotion, there is no need to copy outputs thats are not considered to init tensors.
1 parent ea1909f commit 6dcc864

File tree

4 files changed

+11
-9
lines changed

4 files changed

+11
-9
lines changed

mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -177,10 +177,8 @@ LinalgOpInstancePromotionOptions::LinalgOpInstancePromotionOptions(
177177
Operation *op = opOperand.get().getDefiningOp();
178178
if (auto sv = dyn_cast_or_null<memref::SubViewOp>(op)) {
179179
subViews[operandNumber] = sv;
180-
// In case of linalg generic, copy in only if subview is used in linalg
181-
// payload.
182-
if (!isa<linalg::GenericOp>(linalgOp) ||
183-
linalgOp.payloadUsesValueFromOperand(&opOperand))
180+
// Copy in only if subview is being used by the linalg operation.
181+
if (linalgOp.isDpsInput(&opOperand) || !linalgOp.isInitTensor(&opOperand))
184182
operandsNumbersToCopyIn.insert(operandNumber);
185183
useFullTileBuffers[sv] = vUseFullTileBuffers[operandNumber];
186184
}

mlir/test/Dialect/Linalg/promote.mlir

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ func.func @matmul_f32(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
5454

5555
// CHECK: linalg.copy ins(%[[vA]] : memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%[[partialA]] : memref<?x?xf32, strided<[?, 1], offset: ?>>)
5656
// CHECK: linalg.copy ins(%[[vB]] : memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%[[partialB]] : memref<?x?xf32, strided<[?, 1], offset: ?>>)
57-
// CHECK: linalg.copy ins(%[[vC]] : memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%[[partialC]] : memref<?x?xf32, strided<[?, 1], offset: ?>>)
57+
// CHECK-NOT: linalg.copy ins(%[[vC]] : memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%[[partialC]] : memref<?x?xf32, strided<[?, 1], offset: ?>>)
58+
5859
//
5960
// CHECK: linalg.matmul ins(%[[partialA]], %[[partialB]]{{.*}} outs(%[[partialC]]
6061
//
@@ -124,7 +125,8 @@ func.func @matmul_f64(%A: memref<?xi8>, %M: index, %N: index, %K: index) {
124125

125126
// CHECK: linalg.copy ins(%[[vA_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>) outs(%[[partialA_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>)
126127
// CHECK: linalg.copy ins(%[[vB_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>) outs(%[[partialB_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>)
127-
// CHECK: linalg.copy ins(%[[vC_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>) outs(%[[partialC_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>)
128+
// CHECK-NOT: linalg.copy ins(%[[vC_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>) outs(%[[partialC_f64]] : memref<?x?xf64, strided<[?, 1], offset: ?>>)
129+
128130
//
129131
// CHECK: linalg.matmul ins(%[[partialA_f64]], %[[partialB_f64]]{{.*}} outs(%[[partialC_f64]]
130132
//
@@ -259,7 +261,8 @@ func.func @promote_rank_reducing_subviews(%arg0: memref<?x?x?x64xf32, strided<[
259261
// CHECK: %[[c_view:.+]] = memref.view
260262
// CHECK: %[[c_pro_subview:.+]] = memref.subview %[[c_view]]
261263

262-
// CHECK-COUNT-3: linalg.copy
264+
265+
// CHECK-COUNT-2: linalg.copy
263266
// CHECK: linalg.generic
264267
// CHECK-SAME: ins(%[[a_pro_subview]], %[[b_pro_subview]]
265268
// CHECK-SAME: outs(%[[c_pro_subview]]

mlir/test/Dialect/Linalg/promotion_options.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ func.func @gemm(%a : memref<?x?xf32>, %b : memref<?x?xf32>, %c : memref<?x?xf32>
2828
// CHECK: %[[svCC:.+]] = memref.subview %[[VC]]
2929

3030
// CHECK: linalg.copy ins(%[[svA]] : memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%[[svAA]] : memref<?x?xf32, strided<[16, 1]>>)
31-
// CHECK: linalg.copy ins(%[[svC]] : memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%[[svCC]] : memref<?x?xf32, strided<[16, 1]>>)
31+
// CHECK-NOT: linalg.copy ins(%[[svC]] : memref<?x?xf32, strided<[?, 1], offset: ?>>) outs(%[[svCC]] : memref<?x?xf32, strided<[16, 1]>>)
3232
// CHECK: linalg.matmul ins(%[[VA]], %[[svB]]{{.*}} outs(%[[VC]]
3333
// CHECK: linalg.copy ins(%[[svCC]] : memref<?x?xf32, strided<[16, 1]>>) outs(%[[svC]] : memref<?x?xf32, strided<[?, 1], offset: ?>>)
3434
// CHECK: memref.dealloc %[[tmpA]]

mlir/test/Dialect/Linalg/transform-promotion.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,10 @@ func.func @promote_subview_matmul(%arg0: memref<?x?xf32, strided<[?, 1], offset:
5151
// CHECK: %[[v2:.*]] = memref.view %[[a2]]{{.*}} : memref<24000000xi8> to memref<?x?xf32>
5252
// CHECK: %[[l2:.*]] = memref.subview %[[v2]][0, 0] [%{{.*}}, %{{.*}}] [1, 1]
5353
// CHECK-SAME: memref<?x?xf32> to memref<?x?xf32, strided<[?, 1], offset: ?>>
54+
5455
// CHECK: linalg.copy ins(%[[s0]] : memref<?x?xf32, strided{{.*}}>) outs(%[[l0]] : memref<?x?xf32, strided{{.*}}>)
5556
// CHECK: linalg.copy ins(%[[s1]] : memref<?x?xf32, strided{{.*}}>) outs(%[[l1]] : memref<?x?xf32, strided{{.*}}>)
56-
// CHECK: linalg.copy ins(%[[s2]] : memref<?x?xf32, strided{{.*}}>) outs(%[[l2]] : memref<?x?xf32, strided{{.*}}>)
57+
// CHECK-NOT: linalg.copy ins(%[[s2]] : memref<?x?xf32, strided{{.*}}>) outs(%[[l2]] : memref<?x?xf32, strided{{.*}}>)
5758
// CHECK: linalg.matmul
5859
// CHECK-SAME: ins(%[[v0]], %[[v1]] : memref<?x?xf32>, memref<?x?xf32>)
5960
// CHECK-SAME: outs(%[[v2]] : memref<?x?xf32>)

0 commit comments

Comments
 (0)