[MLIR][Linalg] Fixes for Winograd decomposition and for tiling #123675

Conversation
This PR addresses issues with 1 x r and r x 1 filters and with tiling. Signed-off-by: Dmitriy Smirnov <[email protected]>
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg Author: Dmitriy Smirnov (d-smirnov). Changes: the PR addresses issues with 1 x r and r x 1 filters and with tiling. Patch is 35.58 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/123675.diff 5 Files Affected:
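For orientation, here is a minimal F(m x 1, r x 1) case of the kind this patch fixes, taken from the new test added below (so the shapes are ones the patch actually exercises): with a 3x1 filter, only the left-side (height) transform applies, and the width-related alpha and tile dimensions of the transform outputs keep their untransformed sizes.

```mlir
// F(4x1, 3x1): filter is 3x1, so only the left (height) transform runs.
// Filter transform: the alphaW dimension stays 1 in the 6x1x5x2 output.
%0 = tensor.empty() : tensor<6x1x5x2xf32>
%1 = linalg.winograd_filter_transform m(4) r(3)
       ins(%arg1 : tensor<2x3x1x5xf32>)
       outs(%0 : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
// Input transform: tileW equals the untransformed input width (2),
// not 1 -- one of the behaviors this patch corrects.
%2 = tensor.empty() : tensor<6x1x1x2x2x5xf32>
%3 = linalg.winograd_input_transform m(4) r(3)
       ins(%arg0 : tensor<2x6x2x5xf32>)
       outs(%2 : tensor<6x1x1x2x2x5xf32>) -> tensor<6x1x1x2x2x5xf32>
```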
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index e42fd5d2ce13c1..f8df828f74851b 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -155,7 +155,7 @@ def Linalg_SoftmaxOp : Linalg_Op<"softmax",
}
def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform",
- [AllElementTypesMatch<["filter", "output"]>,
+ [AllElementTypesMatch<["filter", "output"]>, DestinationStyleOpInterface,
DeclareOpInterfaceMethods<TilingInterface,
["getIterationDomain",
"getLoopIteratorTypes",
@@ -220,12 +220,13 @@ def Linalg_WinogradFilterTransformOp : Linalg_Op<"winograd_filter_transform",
int64_t getFilterCDim() {
return 3;
}
+ MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
}];
let hasVerifier = 1;
}
def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform",
- [AllElementTypesMatch<["input", "output"]>,
+ [AllElementTypesMatch<["input", "output"]>, DestinationStyleOpInterface,
DeclareOpInterfaceMethods<TilingInterface,
["getIterationDomain",
"getLoopIteratorTypes",
@@ -308,6 +309,7 @@ def Linalg_WinogradInputTransformOp : Linalg_Op<"winograd_input_transform",
int64_t getOutputCDim() {
return 5;
}
+ MutableOperandRange getDpsInitsMutable() { return getOutputMutable(); }
}];
let hasVerifier = 1;
}
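A note on the interface change above: declaring DestinationStyleOpInterface on both transform ops and implementing getDpsInitsMutable() lets generic destination-passing-style utilities, tiling in particular, identify the `output` operand as the op's init. A hedged sketch of how a consumer could use it (the variable names here are illustrative, not from the patch):

```c++
// Sketch: with the interface in place, generic code can recover the
// destination (init) operand of a Winograd transform op.
auto dpsOp = cast<DestinationStyleOpInterface>(transformOp.getOperation());
Value init = dpsOp.getDpsInitOperand(0)->get();  // the `output` operand
```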
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index c13b663dbf05b1..0649fbc8e9549b 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -3063,8 +3063,11 @@ LogicalResult WinogradInputTransformOp::verify() {
int m = getM();
int r = getR();
int64_t tileSize = m + r - 1;
- bool leftTransform = inputH != 1;
- bool rightTransform = inputW != 1;
+
+ auto outputType = cast<ShapedType>(getOutput().getType());
+ ArrayRef<int64_t> outputShape = outputType.getShape();
+ bool leftTransform = outputShape[getOutputAlphaHDim()] != 1;
+ bool rightTransform = outputShape[getOutputAlphaWDim()] != 1;
SmallVector<int64_t> expectedOutputShape(6, inputH);
if (ShapedType::isDynamic(inputH)) {
@@ -3073,7 +3076,7 @@ LogicalResult WinogradInputTransformOp::verify() {
} else {
expectedOutputShape[getOutputAlphaHDim()] = leftTransform ? tileSize : 1;
expectedOutputShape[getOutputTileHDim()] =
- leftTransform ? (inputH - (r - 1)) / m : 1;
+ leftTransform ? (inputH - (r - 1)) / m : inputH;
}
if (ShapedType::isDynamic(inputW)) {
expectedOutputShape[getOutputAlphaWDim()] = tileSize;
@@ -3081,13 +3084,11 @@ LogicalResult WinogradInputTransformOp::verify() {
} else {
expectedOutputShape[getOutputAlphaWDim()] = rightTransform ? tileSize : 1;
expectedOutputShape[getOutputTileWDim()] =
- rightTransform ? (inputW - (r - 1)) / m : 1;
+ rightTransform ? (inputW - (r - 1)) / m : inputW;
}
expectedOutputShape[getOutputNDim()] = inputShape[getInputNDim()];
expectedOutputShape[getOutputCDim()] = inputShape[getInputCDim()];
- auto outputType = cast<ShapedType>(getOutput().getType());
- ArrayRef<int64_t> outputShape = outputType.getShape();
if (failed(verifyCompatibleShape(expectedOutputShape, outputShape))) {
return emitOpError("the output shape is not expected");
}
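A quick worked check of the revised verifier logic, using the shapes from the new test below: for input tensor<2x6x2x5xf32> with m = 4 and r = 3, tileSize = m + r - 1 = 6. The output's alphaH dimension is 6 (!= 1), so leftTransform holds and the expected tileH is (6 - (3 - 1)) / 4 = 1; the output's alphaW dimension is 1, so rightTransform is false and, with this change, the expected tileW is inputW = 2 rather than 1. The expected output shape is therefore 6x1x1x2x2x5, matching the test IR further down.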
@@ -3124,15 +3125,17 @@ LogicalResult WinogradInputTransformOp::getResultTilePosition(
ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
SmallVector<OpFoldResult> &resultSizes) {
IntegerAttr zeroAttr = builder.getI64IntegerAttr(0);
- ShapedType inputType = getInputOperandType();
- ArrayRef<int64_t> inputShape = inputType.getShape();
- int64_t inputH = inputShape[getInputHDim()];
- int64_t inputW = inputShape[getInputWDim()];
+ ShapedType outputType = getOutputOperandType();
+ ArrayRef<int64_t> outputShape = outputType.getShape();
+ int64_t outputAlphaH = outputShape[getOutputAlphaHDim()];
+ int64_t outputAlphaW = outputShape[getOutputAlphaWDim()];
+
int64_t m = getM();
int64_t r = getR();
int64_t alpha = m + r - 1;
- int64_t alphaH = inputH != 1 ? alpha : 1;
- int64_t alphaW = inputW != 1 ? alpha : 1;
+ int64_t alphaH = outputAlphaH != 1 ? alpha : 1;
+ int64_t alphaW = outputAlphaW != 1 ? alpha : 1;
+
IntegerAttr alphaHAttr = builder.getI64IntegerAttr(alphaH);
IntegerAttr alphaWAttr = builder.getI64IntegerAttr(alphaW);
@@ -3165,6 +3168,11 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
int64_t m = getM();
int64_t r = getR();
+ ShapedType outputType = getOutputOperandType();
+ ArrayRef<int64_t> outputShape = outputType.getShape();
+ int64_t alphaH = outputShape[getOutputAlphaHDim()];
+ int64_t alphaW = outputShape[getOutputAlphaWDim()];
+
Location loc = getLoc();
MLIRContext *context = builder.getContext();
auto offsetAffineMap =
@@ -3190,9 +3198,9 @@ WinogradInputTransformOp::getTiledImplementation(OpBuilder &builder,
sliceOffsets.append(
{offsets[getOutputNDim()], offsetH, offsetW, offsets[getOutputCDim()]});
OpFoldResult sizeH =
- inputH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
+ alphaH != 1 ? OpFoldResult(mappedSizeH) : OpFoldResult(oneAttr);
OpFoldResult sizeW =
- inputW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
+ alphaW != 1 ? OpFoldResult(mappedSizeW) : OpFoldResult(oneAttr);
sliceSizes.append(
{sizes[getOutputNDim()], sizeH, sizeW, sizes[getOutputCDim()]});
int64_t inputRank = getInputOperandRank();
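The switch from input-derived to output-derived alphaH/alphaW in getResultTilePosition and getTiledImplementation above matters precisely in the F(m x 1, r x 1) and F(1 x m, 1 x r) cases: in the new test the input has W = 2 (!= 1), so deciding from the input shape would wrongly treat the width as transformed, whereas the output's alphaW dimension is 1 and correctly signals that only the left transform runs. Distilled from the diff:

```c++
// Before: transform-or-not inferred from the *input* spatial dims.
int64_t alphaW = inputW != 1 ? alpha : 1;        // wrong for rx1 filters when inputW != 1
// After: inferred from the *output* alpha dims, which encode the decision.
int64_t alphaW = outputAlphaW != 1 ? alpha : 1;
```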
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index 79f77822116fd7..f1059ddf5da2cf 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -514,12 +514,14 @@ Value inputTransform(RewriterBase &rewriter, Location loc, Value input,
Value CIter = ivs[3];
auto context = builder.getContext();
+
+ auto identityAffineMap = rewriter.getMultiDimIdentityMap(1);
auto affineMap =
AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
- Value heightOffset =
- builder.create<affine::AffineApplyOp>(loc, affineMap, tileHIter);
- Value widthOffset =
- builder.create<affine::AffineApplyOp>(loc, affineMap, tileWIter);
+ Value heightOffset = builder.create<affine::AffineApplyOp>(
+ loc, leftTransform ? affineMap : identityAffineMap, tileHIter);
+ Value widthOffset = builder.create<affine::AffineApplyOp>(
+ loc, rightTransform ? affineMap : identityAffineMap, tileWIter);
// Extract (H, W) from (N, H, W, C).
auto extractInput =
@@ -753,12 +755,13 @@ Value outputTransform(RewriterBase &rewriter, Location loc, Value value,
Value zero = builder.create<arith::ConstantOp>(
loc, rewriter.getZeroAttr(elementType));
+ auto identityAffineMap = rewriter.getMultiDimIdentityMap(1);
auto affineMap =
AffineMap::get(1, 0, {builder.getAffineDimExpr(0) * m}, context);
- Value heightOffset =
- builder.create<affine::AffineApplyOp>(loc, affineMap, tileHIter);
- Value widthOffset =
- builder.create<affine::AffineApplyOp>(loc, affineMap, tileWIter);
+ Value heightOffset = builder.create<affine::AffineApplyOp>(
+ loc, leftTransform ? affineMap : identityAffineMap, tileHIter);
+ Value widthOffset = builder.create<affine::AffineApplyOp>(
+ loc, rightTransform ? affineMap : identityAffineMap, tileWIter);
Value outInitVal =
extract2DDataFrom4D(builder, loc, args[0], NIter, FIter, heightOffset,
@@ -1075,16 +1078,17 @@ FailureOr<Operation *>
decomposeWinogradInputTransformHelper(RewriterBase &rewriter,
linalg::WinogradInputTransformOp op) {
Location loc = op.getLoc();
- Value input = op.getInput();
- auto inputType = cast<ShapedType>(input.getType());
- auto inputShape = inputType.getShape();
- int64_t inputH = inputShape[1];
- int64_t inputW = inputShape[2];
+ Value output = op.getOutput();
+ auto outputType = cast<ShapedType>(output.getType());
+ auto outputShape = outputType.getShape();
+
+ int64_t outputH = outputShape[0];
+ int64_t outputW = outputShape[1];
// For F(m x 1, r x 1), we only need to do left side transform.
- bool leftTransform = inputH != 1;
+ bool leftTransform = outputH != 1;
// For F(1 x m, 1 x r), we only need to do right side transform.
- bool rightTransform = inputW != 1;
+ bool rightTransform = outputW != 1;
Value transformedInput =
inputTransform(rewriter, loc, op.getInput(), op.getOutput(), op.getM(),
op.getR(), leftTransform, rightTransform);
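In the decomposition above, when one side's transform is skipped, the tile iterator must index that spatial dimension directly rather than being scaled by m. A hedged sketch of the two offset maps involved (the m = 4 map matches #MAP0 in the CHECK lines below; %tileHIter and %tileWIter are illustrative names):

```mlir
// Transform active on this side: offset = tile index * m.
%h_off = affine.apply affine_map<(d0) -> (d0 * 4)>(%tileHIter)
// Transform skipped on this side: offset = tile index itself (identity map).
%w_off = affine.apply affine_map<(d0) -> (d0)>(%tileWIter)
```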
diff --git a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
index 776dc5b748c846..a9af874ea40933 100644
--- a/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
+++ b/mlir/test/Dialect/Linalg/transform-tile-and-winograd-rewrite.mlir
@@ -61,13 +61,12 @@ module attributes {transform.with_named_sequence} {
// CHECK: scf.yield %[[INSERTED_SLICE]]
// CHECK: scf.yield %[[S9]]
// CHECK: %[[S2:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
-// CHECK: %[[S3:.*]] = tensor.empty() : tensor<6x6x2x2x2x5xf32>
-// CHECK: %[[S4:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S3]])
+// CHECK: %[[S4:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S2]])
// CHECK: %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
// CHECK: %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
// CHECK: %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG5]])
// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, %[[S10]], %[[S11]], 0] [2, 6, 6, 5] [1, 1, 1, 1]
-// CHECK: %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[S2]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
+// CHECK: %[[EXTRACTED_SLICE_7:.*]] = tensor.extract_slice %[[ARG6]][0, 0, %[[ARG3]], %[[ARG5]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
// CHECK: %[[S12:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[EXTRACTED_SLICE_7]])
// CHECK: %[[S13:.*]] = scf.for %[[ARG9:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG10:.*]] = %[[ARG8]])
// CHECK: %[[EXTRACTED_SLICE_8:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][%[[ARG7]], 0, 0, %[[ARG9]]] [1, 6, 6, 1] [1, 1, 1, 1]
@@ -195,13 +194,12 @@ module attributes {transform.with_named_sequence} {
// CHECK: scf.yield %[[S9]] : tensor<6x6x5x2xf32>
// CHECK: %[[PADDED:.*]] = tensor.pad %[[ARG0]] low[0, 0, 0, 0] high[0, 3, 3, 0]
// CHECK: %[[S2:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
-// CHECK: %[[S3:.*]] = tensor.empty() : tensor<6x6x3x3x2x5xf32>
-// CHECK: %[[S4:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S3]])
+// CHECK: %[[S4:.*]] = scf.for %[[ARG4:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG5:.*]] = %[[S2]])
// CHECK: %[[S9:.*]] = scf.for %[[ARG6:.*]] = %[[C0]] to %[[C3]] step %[[C1]] iter_args(%[[ARG7:.*]] = %[[ARG5]])
// CHECK: %[[S10:.*]] = affine.apply #[[$MAP0]](%[[ARG4]])
// CHECK: %[[S11:.*]] = affine.apply #[[$MAP0]](%[[ARG6]])
// CHECK: %[[EXTRACTED_SLICE_9:.*]] = tensor.extract_slice %[[PADDED]][0, %[[S10]], %[[S11]], 0] [2, 6, 6, 5] [1, 1, 1, 1]
-// CHECK: %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[S2]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
+// CHECK: %[[EXTRACTED_SLICE_10:.*]] = tensor.extract_slice %[[ARG7]][0, 0, %[[ARG4]], %[[ARG6]], 0, 0] [6, 6, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
// CHECK: %[[S12:.*]] = scf.for %[[ARG8:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG9:.*]] = %[[EXTRACTED_SLICE_10]])
// CHECK: %[[S13:.*]] = scf.for %[[ARG10:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG11:.*]] = %[[ARG9]])
// CHECK: %[[EXTRACTED_SLICE_11:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE_9]][%[[ARG8]], 0, 0, %[[ARG10]]] [1, 6, 6, 1] [1, 1, 1, 1]
@@ -346,3 +344,106 @@ module attributes {transform.with_named_sequence} {
// CHECK: scf.yield %[[INSERTED_SLICE]]
// CHECK: scf.yield %[[S7]]
// CHECK: return %[[S6]]
+
+// -----
+
+func.func @conv2d_mx1_rx1_2(%arg0: tensor<2x6x2x5xf32>, %arg1: tensor<2x3x1x5xf32>, %arg2: tensor<2x4x2x2xf32>) -> tensor<2x4x2x2xf32> {
+ %cst = arith.constant 0.000000e+00 : f32
+ %0 = tensor.empty() : tensor<6x1x5x2xf32>
+ %1 = linalg.winograd_filter_transform m(4) r(3) ins(%arg1 : tensor<2x3x1x5xf32>) outs(%0 : tensor<6x1x5x2xf32>) -> tensor<6x1x5x2xf32>
+ %2 = tensor.empty() : tensor<6x1x1x2x2x5xf32>
+ %3 = linalg.winograd_input_transform m(4) r(3) ins(%arg0 : tensor<2x6x2x5xf32>) outs(%2 : tensor<6x1x1x2x2x5xf32>) -> tensor<6x1x1x2x2x5xf32>
+ %collapsed = tensor.collapse_shape %1 [[0, 1], [2], [3]] : tensor<6x1x5x2xf32> into tensor<6x5x2xf32>
+ %collapsed_0 = tensor.collapse_shape %3 [[0, 1], [2, 3, 4], [5]] : tensor<6x1x1x2x2x5xf32> into tensor<6x4x5xf32>
+ %4 = tensor.empty() : tensor<6x4x2xf32>
+ %5 = linalg.fill ins(%cst : f32) outs(%4 : tensor<6x4x2xf32>) -> tensor<6x4x2xf32>
+ %6 = linalg.batch_matmul ins(%collapsed_0, %collapsed : tensor<6x4x5xf32>, tensor<6x5x2xf32>) outs(%5 : tensor<6x4x2xf32>) -> tensor<6x4x2xf32>
+ %expanded = tensor.expand_shape %6 [[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 2, 2, 2] : tensor<6x4x2xf32> into tensor<6x1x1x2x2x2xf32>
+ %7 = linalg.winograd_output_transform m(4) r(3) ins(%expanded : tensor<6x1x1x2x2x2xf32>) outs(%arg2 : tensor<2x4x2x2xf32>) -> tensor<2x4x2x2xf32>
+ return %7 : tensor<2x4x2x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+ transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+ %0 = transform.structured.match ops{["linalg.winograd_filter_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %2 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %3, %loop3:2 = transform.structured.tile_using_for %2 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+ %4 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+ %5, %loop5:2 = transform.structured.tile_using_for %4 tile_sizes [0, 0, 1, 1, 0, 0] : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+ %7 = transform.structured.decompose_winograd_op %0 : (!transform.any_op) -> (!transform.any_op)
+ %8 = transform.structured.match ops{["linalg.winograd_input_transform"]} in %3 : (!transform.any_op) -> !transform.any_op
+ %9 = transform.structured.decompose_winograd_op %8 : (!transform.any_op) -> (!transform.any_op)
+ %10 = transform.structured.match ops{["linalg.winograd_output_transform"]} in %5 : (!transform.any_op) -> !transform.any_op
+ %11 = transform.structured.decompose_winograd_op %10 : (!transform.any_op) -> (!transform.any_op)
+ transform.yield
+ }
+}
+
+// CHECK: #[[$MAP0:.+]] = affine_map<(d0) -> (d0 * 4)>
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1) -> ()>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1) -> (d0, d1)>
+// CHECK-LABEL: func.func @conv2d_mx1_rx1
+// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x2x5xf32>, %[[ARG1:.*]]: tensor<2x3x1x5xf32>, %[[ARG2:.*]]: tensor<2x4x2x2xf32>) -> tensor<2x4x2x2xf32> {
+// CHECK: %[[CST:.*]] = arith.constant 3.200000e+01 : f32
+// CHECK: %[[CST_0:.*]] = arith.constant dense<{{.*}}> : tensor<4x6xf32>
+// CHECK: %[[CST_1:.*]] = arith.constant dense<{{.*}}> : tensor<6x6xf32>
+// CHECK: %[[CST_2:.*]] = arith.constant dense<{{.*}}> : tensor<6x3xf32>
+// CHECK: %[[C1:.*]] = arith.constant 1 : index
+// CHECK: %[[C5:.*]] = arith.constant 5 : index
+// CHECK: %[[C2:.*]] = arith.constant 2 : index
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[CST_3:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[S0:.*]] = tensor.empty() : tensor<6x1x5x2xf32>
+// CHECK: %[[S1:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S0]])
+// CHECK: %[[S7:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[ARG4]])
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG1]][%[[ARG3]], 0, 0, %[[ARG5]]] [1, 3, 1, 1] [1, 1, 1, 1]
+// CHECK: %[[S8:.*]] = tensor.empty() : tensor<6x1xf32>
+// CHECK: %[[S9:.*]] = linalg.fill ins(%[[CST_3]] : f32) outs(%[[S8]] : tensor<6x1xf32>) -> tensor<6x1xf32>
+// CHECK: %[[S10:.*]] = linalg.matmul ins(%[[CST_2]], %[[EXTRACTED_SLICE]] : tensor<6x3xf32>, tensor<3x1xf32>) outs(%[[S9]] : tensor<6x1xf32>) -> tensor<6x1xf32>
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S10]] into %[[ARG6]][0, 0, %[[ARG5]], %[[ARG3]]] [6, 1, 1, 1] [1, 1, 1, 1]
+// CHECK: scf.yield %[[INSERTED_SLICE]]
+// CHECK: scf.yield %[[S7]]
+// CHECK: %[[S2:.*]] = tensor.empty() : tensor<6x1x1x2x2x5xf32>
+// CHECK: %[[S3:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[S2]])
+// CHECK: %[[S8:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, 0, %8, 0] [2, 6, 1, 5] [1, 1, 1, 1]
+// CHECK: %[[EXTRACTED_SLICE_5:.*]] = tensor.extract_slice %[[ARG4]][0, 0, 0, %[[ARG3]], 0, 0] [6, 1, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
+// CHECK: %[[S9:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[EXTRACTED_SLICE_5]])
+// CHECK: %[[S10:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C5]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[ARG6]])
+// CHECK: %[[EXTRACTED_SLICE_6:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][%[[ARG5]], 0, 0, %[[ARG7]]] [1, 6, 1, 1] [1, 1, 1, 1]
+// CHECK: %[[S11:.*]] = tensor.empty() : tensor<6x1xf32>
+// CHECK: %[[S12:.*]] = linalg.fill ins(%[[CST_3]] : f32) outs(%[[S11]] : tensor<6x1xf32>) -> tensor<6x1xf32>
+// CHECK: %[[S13:.*]] = linalg.matmul ins(%[[CST_1]], %[[EXTRACTED_SLICE_6]] : tensor<6x6xf32>, tensor<6x1xf32>) outs(%[[S12]] : tensor<6x1xf32>) -> tensor<6x1xf32>
+// CHECK: %[[INSERTED_SLICE_7:.*]] = tensor.insert_slice %[[S13]] into %[[ARG8]][0, 0, 0, 0, %[[ARG5]], %[[ARG7]]] [6, 1, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1]
+// CHECK: scf.yield %[[INSERTED_SLICE_7]]
+// CHECK: scf.yield %[[S10]]
+// CHECK: %[[INSERTED_SLICE:.*]] = tensor.insert_slice %[[S9]] into %[[ARG4]][0, 0, 0, %[[ARG3]], 0, 0] [6, 1, 1, 1, 2, 5] [1, 1, 1, 1, 1, 1]
+// CHECK: scf.yield %[[INSERTED_SLICE]]
+// CHECK: %[[COLLAPSED:.*]] = tensor.collapse_shape %[[S1]] {{\[}}[0, 1], [2], [3]]
+// CHECK: %[[COLLAPSED_4:.*]] = tensor.collapse_shape %[[S3]] {{\[}}[0, 1], [2, 3, 4], [5]]
+// CHECK: %[[S4:.*]] = tensor.empty() : tensor<6x4x2xf32>
+// CHECK: %[[S5:.*]] = linalg.fill ins(%[[CST_3]] : f32) outs(%[[S4]] : tensor<6x4x2xf32>) -> tensor<6x4x2xf32>
+// CHECK: %[[S6:.*]] = linalg.batch_matmul ins(%[[COLLAPSED_4]], %[[COLLAPSED]] : tensor<6x4x5xf32>, tensor<6x5x2xf32>) outs(%[[S5]] : tensor<6x4x2xf32>) -> tensor<6x4x2xf32>
+// CHECK: %[[EXPANDED:.*]] = tensor.expand_shape %[[S6]] {{\[}}[0, 1], [2, 3, 4], [5]] output_shape [6, 1, 1, 2, 2, 2]
+// CHECK: %[[S7:.*]] = scf.for %[[ARG3:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG4:.*]] = %[[ARG2]])
+// CHECK: %[[EXTRACTED_SLICE:.*]] = tensor.extract_slice %[[EXPANDED]][0, 0, 0, %[[ARG3]], 0, 0] [6, 1, 1, 1, 2, 2] [1, 1, 1, 1, 1, 1]
+// CHECK: %[[EXTRACTED_SLICE_5:.*]] = tensor.extract_slice %[[ARG4]][0, 0, 0, 0] [2, 4, 1, 2] [1, 1, 1, 1]
+// CHECK: %[[S8:.*]] = scf.for %[[ARG5:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG6:.*]] = %[[EXTRACTED_SLICE_5]])
+// CHECK: %[[S9:.*]] = scf.for %[[ARG7:.*]] = %[[C0]] to %[[C2]] step %[[C1]] iter_args(%[[ARG8:.*]] = %[[ARG6]])
+// CHECK: %[[EXTRACTED_SLICE_6:.*]] = tensor.extract_slice %[[EXTRACTED_SLICE]][0, 0, 0, 0...
[truncated]
LGTM.