Skip to content

Commit f480e97

Browse files
committed
[mlir][linalg] Fix for bias handling for Winograd
Patch adds handling of bias to the Winograd output transform op decomposition. Signed-off-by: Dmitriy Smirnov <[email protected]>
1 parent 1911a50 commit f480e97

File tree

3 files changed

+32
-5
lines changed

3 files changed

+32
-5
lines changed

mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -837,9 +837,24 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
837837
Value widthOffset =
838838
builder.create<affine::AffineApplyOp>(loc, affineMap, tileWIter);
839839

840+
Value outInitVal =
841+
extract2DDataFrom4D(builder, loc, args[0], NIter, FIter, heightOffset,
842+
widthOffset, retRows, retCols,
843+
/*loopNorFIdx=*/0,
844+
/*loopCorFIdx=*/3, /*heightIdx=*/1,
845+
/*widthIdx=*/2);
846+
Value outVal =
847+
builder
848+
.create<linalg::AddOp>(
849+
loc, outInitVal.getType(), ValueRange{matmulRetValue, outInitVal},
850+
ValueRange{builder.create<tensor::EmptyOp>(
851+
loc, llvm::cast<ShapedType>(outInitVal.getType()).getShape(),
852+
elementType)})
853+
.getResult(0);
854+
840855
// Insert (H, W) to (N, H, W, F).
841856
Value combinedVal =
842-
insert2DDataTo4D(builder, loc, matmulRetValue, args[0], NIter, FIter,
857+
insert2DDataTo4D(builder, loc, outVal, args[0], NIter, FIter,
843858
heightOffset, widthOffset, retRows, retCols,
844859
/*loopNorFIdx=*/0,
845860
/*loopCorFIdx=*/3, /*heightIdx=*/1,

mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,10 @@ module attributes {transform.with_named_sequence} {
109109
// CHECK: linalg.yield %[[IN]] : f32
110110
// CHECK: } -> tensor<4x4xf32>
111111
// CHECK: %[[S24:.*]] = linalg.mul ins(%[[S23]], %[[S21]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S22]] : tensor<4x4xf32>) -> tensor<4x4xf32>
112-
// CHECK: %[[INSERTED_SLICE_9:.*]] = tensor.insert_slice %[[S24]] into %[[ARG10]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1]
112+
// CHECK: %[[S25:.*]] = tensor.extract_slice %[[ARG10]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1]
113+
// CHECK: %[[S26:.*]] = tensor.empty() : tensor<4x4xf32>
114+
// CHECK: %[[S27:.*]] = linalg.add ins(%[[S24]], %[[S25]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S26]] : tensor<4x4xf32>) -> tensor<4x4xf32>
115+
// CHECK: %[[INSERTED_SLICE_9:.*]] = tensor.insert_slice %[[S27]] into %[[ARG10]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1]
113116
// CHECK: scf.yield %[[INSERTED_SLICE_9]]
114117
// CHECK: scf.yield %[[S15]]
115118
// CHECK: %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
@@ -243,7 +246,10 @@ module attributes {transform.with_named_sequence} {
243246
// CHECK: linalg.yield %[[IN]] : f32
244247
// CHECK: } -> tensor<4x4xf32>
245248
// CHECK: %[[S25:.*]] = linalg.mul ins(%[[S24]], %[[S22]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S23]] : tensor<4x4xf32>) -> tensor<4x4xf32>
246-
// CHECK: %[[INSERTED_SLICE_12:.*]] = tensor.insert_slice %[[S25]] into %[[ARG11]][%[[ARG8]], 0, 0, %[[ARG10]]] [1, 4, 4, 1] [1, 1, 1, 1]
249+
// CHECK: %[[S26:.*]] = tensor.extract_slice %[[ARG11]][%[[ARG8]], 0, 0, %[[ARG10]]] [1, 4, 4, 1] [1, 1, 1, 1]
250+
// CHECK: %[[S27:.*]] = tensor.empty() : tensor<4x4xf32>
251+
// CHECK: %[[S28:.*]] = linalg.add ins(%[[S25]], %[[S26]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S27]] : tensor<4x4xf32>) -> tensor<4x4xf32>
252+
// CHECK: %[[INSERTED_SLICE_12:.*]] = tensor.insert_slice %[[S28]] into %[[ARG11]][%[[ARG8]], 0, 0, %[[ARG10]]] [1, 4, 4, 1] [1, 1, 1, 1]
247253
// CHECK: scf.yield %[[INSERTED_SLICE_12]]
248254
// CHECK: scf.yield %[[S15]] : tensor<2x4x4x2xf32>
249255
// CHECK: %[[S13:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
@@ -339,7 +345,10 @@ module attributes {transform.with_named_sequence} {
339345
// CHECK: linalg.yield %[[IN]] : f32
340346
// CHECK: } -> tensor<4x1xf32>
341347
// CHECK: %[[S14:.*]] = linalg.mul ins(%[[S13]], %[[S11]] : tensor<4x1xf32>, tensor<4x1xf32>) outs(%[[S12]] : tensor<4x1xf32>) -> tensor<4x1xf32>
342-
// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S14]] into %[[ARG6]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 4, 1, 1] [1, 1, 1, 1]
348+
// CHECK: %[[S15:.*]] = tensor.extract_slice %[[ARG6]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 4, 1, 1] [1, 1, 1, 1]
349+
// CHECK: %[[S16:.*]] = tensor.empty() : tensor<4x1xf32>
350+
// CHECK: %[[S17:.*]] = linalg.add ins(%[[S14]], %[[S15]] : tensor<4x1xf32>, tensor<4x1xf32>) outs(%[[S16]] : tensor<4x1xf32>) -> tensor<4x1xf32>
351+
// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S17]] into %[[ARG6]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 4, 1, 1] [1, 1, 1, 1]
343352
// CHECK: scf.yield %[[INSERTED_SLICE]]
344353
// CHECK: scf.yield %[[S7]]
345354
// CHECK: return %[[S6]]

mlir/test/Dialect/Linalg/winograd-conv2d-rewrite.mlir

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,10 @@ func.func @conv2d(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg
114114
// CHECK-NEXT: %[[S19:.*]] = linalg.mul ins(%[[S18]], %[[S16]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S17]] : tensor<4x4xf32>) -> tensor<4x4xf32>
115115
// CHECK-NEXT: %[[S20:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
116116
// CHECK-NEXT: %[[S21:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
117-
// CHECK-NEXT: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S19]] into %[[ARG10]][%[[ARG7]], %[[S20]], %[[S21]], %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<4x4xf32> into tensor<2x12x12x2xf32>
117+
// CHECK-NEXT: %[[S22:.*]] = tensor.extract_slice %[[ARG10]][%[[ARG7]], %[[S20]], %[[S21]], %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<2x12x12x2xf32> to tensor<4x4xf32>
118+
// CHECK-NEXT: %[[S23:.*]] = tensor.empty() : tensor<4x4xf32>
119+
// CHECK-NEXT: %[[S24:.*]] = linalg.add ins(%[[S19]], %[[S22]] : tensor<4x4xf32>, tensor<4x4xf32>) outs(%[[S23]] : tensor<4x4xf32>) -> tensor<4x4xf32>
120+
// CHECK-NEXT: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S24]] into %[[ARG10]][%[[ARG7]], %[[S20]], %[[S21]], %[[ARG9]]] [1, 4, 4, 1] [1, 1, 1, 1] : tensor<4x4xf32> into tensor<2x12x12x2xf32>
118121
// CHECK-NEXT: scf.yield %[[INSERTED_SLICE]] : tensor<2x12x12x2xf32>
119122
// CHECK-NEXT: }
120123
// CHECK-NEXT: scf.yield %[[S9]] : tensor<2x12x12x2xf32>

0 commit comments

Comments
 (0)