Skip to content

Commit 4142932

Browse files
author
MaheshRavishankar
committed
[mlir][Linalg] Move named op conversions out of canonicalizations.
These conversions are better suited to be applied at whole tensor level. Applying these as canonicalizations end up triggering such canonicalizations at all levels of the stack which might be undesirable. For example some of the resulting code patterns wont bufferize in-place and need additional stack buffers. Best is to be more deliberate in when these canonicalizations apply. Differential Revision: https://reviews.llvm.org/D115912
1 parent bee5bc9 commit 4142932

File tree

9 files changed

+202
-139
lines changed

9 files changed

+202
-139
lines changed

mlir/include/mlir/Dialect/Linalg/Passes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ std::unique_ptr<Pass> createLinalgFoldUnitExtentDimsPass();
2626
std::unique_ptr<Pass> createLinalgElementwiseOpFusionPass();
2727
std::unique_ptr<Pass> createFoldReshapeOpsByLinearizationPass();
2828

29+
std::unique_ptr<Pass> createLinalgNamedOpConversionPass();
30+
2931
std::unique_ptr<OperationPass<FuncOp>> createLinalgTilingPass(
3032
ArrayRef<int64_t> tileSizes = {},
3133
linalg::LinalgTilingLoopType loopType = linalg::LinalgTilingLoopType::Loops,

mlir/include/mlir/Dialect/Linalg/Passes.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,12 @@ def LinalgFoldReshapeOpsByLinearization :
100100
let dependentDialects = ["AffineDialect", "memref::MemRefDialect"];
101101
}
102102

103+
def LinalgNamedOpConversion: Pass<"linalg-named-op-conversion"> {
104+
let summary = "Convert from one named linalg op to another.";
105+
let constructor = "mlir::createLinalgNamedOpConversionPass()";
106+
let dependentDialects = ["linalg::LinalgDialect", "tensor::TensorDialect"];
107+
}
108+
103109
def LinalgLowerTiledLoopsToSCF
104110
: FunctionPass<"convert-linalg-tiled-loops-to-scf"> {
105111
let summary = "Lower linalg tiled loops to SCF loops and parallel loops";

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,10 @@ void populateFoldReshapeOpsByLinearizationPatterns(RewritePatternSet &patterns);
8686
void populateFoldUnitDimsReshapeOpsByLinearizationPatterns(
8787
RewritePatternSet &patterns);
8888

89+
/// Patterns to convert from one named op to another. These can be seen as
90+
/// canonicalizations of named ops into another named op.
91+
void populateLinalgNamedOpConversionPatterns(RewritePatternSet &patterns);
92+
8993
/// Populates the given list with patterns to bufferize linalg ops.
9094
void populateLinalgBufferizePatterns(
9195
bufferization::BufferizeTypeConverter &converter,

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp

Lines changed: 1 addition & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -2665,118 +2665,6 @@ struct FoldTensorCastOp : public OpInterfaceRewritePattern<LinalgOp> {
26652665
}
26662666
};
26672667

2668-
static llvm::SmallVector<int64_t> getIndicesVector(int start, int end) {
2669-
return llvm::to_vector<2>(llvm::seq<int64_t>(start, end));
2670-
}
2671-
2672-
LogicalResult matchAndReplaceDepthwiseConv(Operation *operation, Value input,
2673-
Value kernel, Value iZp, Value kZp,
2674-
Value init, Attribute stride,
2675-
Attribute dilation,
2676-
PatternRewriter &rewriter) {
2677-
Location loc = operation->getLoc();
2678-
auto linalgOp = dyn_cast<LinalgOp>(operation);
2679-
// Exit out on the memref version of this operation.
2680-
if (!linalgOp || !linalgOp.hasTensorSemantics())
2681-
return failure();
2682-
2683-
auto result = operation->getResult(0);
2684-
2685-
auto kernelTy = kernel.getType().dyn_cast<RankedTensorType>();
2686-
auto initTy = init.getType().dyn_cast<RankedTensorType>();
2687-
auto resultTy = result.getType().template dyn_cast<RankedTensorType>();
2688-
if (!kernelTy || !initTy || !resultTy)
2689-
return failure();
2690-
2691-
if (kernelTy.getDimSize(3) != 1)
2692-
return failure();
2693-
2694-
// Collapse kernel dims.
2695-
SmallVector<ReassociationIndices, 4> collapsedKernelDims = {
2696-
getIndicesVector(0, 1), getIndicesVector(1, 2), getIndicesVector(2, 4)};
2697-
auto newKernelTy = RankedTensorType::get(
2698-
{kernelTy.getDimSize(0), kernelTy.getDimSize(1), kernelTy.getDimSize(2)},
2699-
kernelTy.getElementType());
2700-
auto collapsedKernel = rewriter.create<tensor::CollapseShapeOp>(
2701-
loc, newKernelTy, kernel, collapsedKernelDims);
2702-
2703-
// Collapse init dims.
2704-
SmallVector<ReassociationIndices, 4> collapsedInitDims = {
2705-
getIndicesVector(0, 1), getIndicesVector(1, 2), getIndicesVector(2, 3),
2706-
getIndicesVector(3, 5)};
2707-
auto newInitTy =
2708-
RankedTensorType::get({initTy.getDimSize(0), initTy.getDimSize(1),
2709-
initTy.getDimSize(2), initTy.getDimSize(3)},
2710-
initTy.getElementType());
2711-
auto collapsedInit = rewriter.create<tensor::CollapseShapeOp>(
2712-
loc, newInitTy, init, collapsedInitDims);
2713-
2714-
Value newConv;
2715-
if (isa<DepthwiseConv2DNhwcHwcmOp>(operation)) {
2716-
newConv = rewriter
2717-
.create<DepthwiseConv2DNhwcHwcOp>(
2718-
loc, newInitTy, ValueRange{input, collapsedKernel},
2719-
ValueRange{collapsedInit}, stride, dilation)
2720-
.getResult(0);
2721-
} else if (isa<DepthwiseConv2DNhwcHwcmQOp>(operation)) {
2722-
newConv =
2723-
rewriter
2724-
.create<DepthwiseConv2DNhwcHwcQOp>(
2725-
loc, newInitTy, ValueRange{input, collapsedKernel, iZp, kZp},
2726-
ValueRange{collapsedInit}, stride, dilation)
2727-
.getResult(0);
2728-
}
2729-
2730-
if (!newConv)
2731-
return failure();
2732-
2733-
// Expand dimensions back out to
2734-
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
2735-
operation, resultTy, newConv, collapsedInitDims);
2736-
return success();
2737-
}
2738-
2739-
struct SimplifyDepthwiseConvOp
2740-
: public OpRewritePattern<DepthwiseConv2DNhwcHwcmOp> {
2741-
using OpRewritePattern<DepthwiseConv2DNhwcHwcmOp>::OpRewritePattern;
2742-
2743-
LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmOp op,
2744-
PatternRewriter &rewriter) const override {
2745-
Operation *operation = op.getOperation();
2746-
Value input = op.getInputOperand(0)->get();
2747-
Value kernel = op.getInputOperand(1)->get();
2748-
Value init = op.getOutputOperand(0)->get();
2749-
2750-
auto stride = op.strides();
2751-
auto dilation = op.dilations();
2752-
2753-
return matchAndReplaceDepthwiseConv(operation, input, kernel, nullptr,
2754-
nullptr, init, stride, dilation,
2755-
rewriter);
2756-
}
2757-
};
2758-
2759-
struct SimplifyDepthwiseConvQOp
2760-
: public OpRewritePattern<DepthwiseConv2DNhwcHwcmQOp> {
2761-
using OpRewritePattern<DepthwiseConv2DNhwcHwcmQOp>::OpRewritePattern;
2762-
2763-
LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmQOp op,
2764-
PatternRewriter &rewriter) const override {
2765-
Operation *operation = op.getOperation();
2766-
Value input = op.getInputOperand(0)->get();
2767-
Value kernel = op.getInputOperand(1)->get();
2768-
Value iZp = op.getInputOperand(2)->get();
2769-
Value kZp = op.getInputOperand(3)->get();
2770-
Value init = op.getOutputOperand(0)->get();
2771-
2772-
auto stride = op.strides();
2773-
auto dilation = op.dilations();
2774-
2775-
return matchAndReplaceDepthwiseConv(operation, input, kernel, iZp, kZp,
2776-
init, stride, dilation, rewriter);
2777-
}
2778-
};
2779-
27802668
} // namespace
27812669

27822670
#define LINALGOP_FOLDERS(XXX) \
@@ -2798,8 +2686,7 @@ LINALGOP_FOLDERS(GenericOp)
27982686

27992687
void LinalgDialect::getCanonicalizationPatterns(
28002688
RewritePatternSet &results) const {
2801-
results.add<EraseDeadLinalgOp, FoldTensorCastOp, SimplifyDepthwiseConvOp,
2802-
SimplifyDepthwiseConvQOp>(getContext());
2689+
results.add<EraseDeadLinalgOp, FoldTensorCastOp>(getContext());
28032690
}
28042691

28052692
Operation *LinalgDialect::materializeConstant(OpBuilder &builder,

mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
1616
Interchange.cpp
1717
Loops.cpp
1818
LinalgStrategyPasses.cpp
19+
NamedOpConversions.cpp
1920
Promotion.cpp
2021
Tiling.cpp
2122
Transforms.cpp
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
//===- NamedOpConversions.cpp - Implements conversions between named ops --===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements conversions between named ops that can be seens as
10+
// canonicalizations of named ops.
11+
//
12+
//===----------------------------------------------------------------------===//
13+
#include "PassDetail.h"
14+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
15+
#include "mlir/Dialect/Linalg/Passes.h"
16+
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
17+
#include "mlir/IR/PatternMatch.h"
18+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19+
20+
#include "llvm/ADT/SmallVector.h"
21+
22+
using namespace mlir;
23+
using namespace mlir::linalg;
24+
25+
static llvm::SmallVector<int64_t> getIndicesVector(int start, int end) {
26+
return llvm::to_vector<2>(llvm::seq<int64_t>(start, end));
27+
}
28+
29+
static LogicalResult
30+
matchAndReplaceDepthwiseConv(Operation *operation, Value input, Value kernel,
31+
Value iZp, Value kZp, Value init, Attribute stride,
32+
Attribute dilation, PatternRewriter &rewriter) {
33+
Location loc = operation->getLoc();
34+
auto linalgOp = dyn_cast<LinalgOp>(operation);
35+
// Exit out on the memref version of this operation.
36+
if (!linalgOp || !linalgOp.hasTensorSemantics())
37+
return failure();
38+
39+
auto result = operation->getResult(0);
40+
41+
auto kernelTy = kernel.getType().dyn_cast<RankedTensorType>();
42+
auto initTy = init.getType().dyn_cast<RankedTensorType>();
43+
auto resultTy = result.getType().template dyn_cast<RankedTensorType>();
44+
if (!kernelTy || !initTy || !resultTy)
45+
return failure();
46+
47+
if (kernelTy.getDimSize(3) != 1)
48+
return failure();
49+
50+
// Collapse kernel dims.
51+
SmallVector<ReassociationIndices, 4> collapsedKernelDims = {
52+
getIndicesVector(0, 1), getIndicesVector(1, 2), getIndicesVector(2, 4)};
53+
auto newKernelTy = RankedTensorType::get(
54+
{kernelTy.getDimSize(0), kernelTy.getDimSize(1), kernelTy.getDimSize(2)},
55+
kernelTy.getElementType());
56+
auto collapsedKernel = rewriter.create<tensor::CollapseShapeOp>(
57+
loc, newKernelTy, kernel, collapsedKernelDims);
58+
59+
// Collapse init dims.
60+
SmallVector<ReassociationIndices, 4> collapsedInitDims = {
61+
getIndicesVector(0, 1), getIndicesVector(1, 2), getIndicesVector(2, 3),
62+
getIndicesVector(3, 5)};
63+
auto newInitTy =
64+
RankedTensorType::get({initTy.getDimSize(0), initTy.getDimSize(1),
65+
initTy.getDimSize(2), initTy.getDimSize(3)},
66+
initTy.getElementType());
67+
auto collapsedInit = rewriter.create<tensor::CollapseShapeOp>(
68+
loc, newInitTy, init, collapsedInitDims);
69+
70+
Value newConv;
71+
if (isa<DepthwiseConv2DNhwcHwcmOp>(operation)) {
72+
newConv = rewriter
73+
.create<DepthwiseConv2DNhwcHwcOp>(
74+
loc, newInitTy, ValueRange{input, collapsedKernel},
75+
ValueRange{collapsedInit}, stride, dilation)
76+
.getResult(0);
77+
} else if (isa<DepthwiseConv2DNhwcHwcmQOp>(operation)) {
78+
newConv =
79+
rewriter
80+
.create<DepthwiseConv2DNhwcHwcQOp>(
81+
loc, newInitTy, ValueRange{input, collapsedKernel, iZp, kZp},
82+
ValueRange{collapsedInit}, stride, dilation)
83+
.getResult(0);
84+
}
85+
86+
if (!newConv)
87+
return failure();
88+
89+
// Expand dimensions back out to
90+
rewriter.replaceOpWithNewOp<tensor::ExpandShapeOp>(
91+
operation, resultTy, newConv, collapsedInitDims);
92+
return success();
93+
}
94+
95+
namespace {
96+
struct SimplifyDepthwiseConvOp
97+
: public OpRewritePattern<DepthwiseConv2DNhwcHwcmOp> {
98+
using OpRewritePattern<DepthwiseConv2DNhwcHwcmOp>::OpRewritePattern;
99+
100+
LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmOp op,
101+
PatternRewriter &rewriter) const override {
102+
Operation *operation = op.getOperation();
103+
Value input = op.getInputOperand(0)->get();
104+
Value kernel = op.getInputOperand(1)->get();
105+
Value init = op.getOutputOperand(0)->get();
106+
107+
auto stride = op.strides();
108+
auto dilation = op.dilations();
109+
110+
return matchAndReplaceDepthwiseConv(operation, input, kernel, nullptr,
111+
nullptr, init, stride, dilation,
112+
rewriter);
113+
}
114+
};
115+
116+
struct SimplifyDepthwiseConvQOp
117+
: public OpRewritePattern<DepthwiseConv2DNhwcHwcmQOp> {
118+
using OpRewritePattern<DepthwiseConv2DNhwcHwcmQOp>::OpRewritePattern;
119+
120+
LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcmQOp op,
121+
PatternRewriter &rewriter) const override {
122+
Operation *operation = op.getOperation();
123+
Value input = op.getInputOperand(0)->get();
124+
Value kernel = op.getInputOperand(1)->get();
125+
Value iZp = op.getInputOperand(2)->get();
126+
Value kZp = op.getInputOperand(3)->get();
127+
Value init = op.getOutputOperand(0)->get();
128+
129+
auto stride = op.strides();
130+
auto dilation = op.dilations();
131+
132+
return matchAndReplaceDepthwiseConv(operation, input, kernel, iZp, kZp,
133+
init, stride, dilation, rewriter);
134+
}
135+
};
136+
137+
struct LinalgNamedOpConversionPass
138+
: public LinalgNamedOpConversionBase<LinalgNamedOpConversionPass> {
139+
LinalgNamedOpConversionPass() = default;
140+
LinalgNamedOpConversionPass(const LinalgNamedOpConversionPass &) {}
141+
142+
void runOnOperation() override {
143+
Operation *op = getOperation();
144+
RewritePatternSet patterns(op->getContext());
145+
populateLinalgNamedOpConversionPatterns(patterns);
146+
if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns))))
147+
return signalPassFailure();
148+
}
149+
};
150+
} // namespace
151+
152+
void mlir::linalg::populateLinalgNamedOpConversionPatterns(
153+
RewritePatternSet &patterns) {
154+
patterns.add<SimplifyDepthwiseConvOp, SimplifyDepthwiseConvQOp>(
155+
patterns.getContext());
156+
}
157+
158+
std::unique_ptr<Pass> mlir::createLinalgNamedOpConversionPass() {
159+
return std::make_unique<LinalgNamedOpConversionPass>();
160+
}

mlir/lib/Dialect/Linalg/Transforms/PassDetail.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ namespace memref {
3838
class MemRefDialect;
3939
} // namespace memref
4040

41+
namespace tensor {
42+
class TensorDialect;
43+
} // namespace tensor
44+
4145
namespace vector {
4246
class VectorDialect;
4347
} // namespace vector

mlir/test/Dialect/Linalg/canonicalize.mlir

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -758,28 +758,3 @@ func @dim_of_tiled_loop_result_no_canonicalize(%arg0: tensor<?x?xf32>, %arg1: te
758758
%r2 = tensor.dim %r, %c0 : tensor<?x?xf32>
759759
return %r2 : index
760760
}
761-
762-
// -----
763-
764-
// CHECK-LABEL: @depthwise_conv
765-
func @depthwise_conv(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x1xf32>, %arg2: tensor<?x?x?x?x1xf32>) -> tensor<?x?x?x?x1xf32> {
766-
// CHECK-DAG: %[[KERNEL:.+]] = tensor.collapse_shape %arg1 {{\[\[}}0], [1], [2, 3]]
767-
// CHECK-DAG: %[[INIT:.+]] = tensor.collapse_shape %arg2 {{\[\[}}0], [1], [2], [3, 4]]
768-
// CHECK-DAG: %[[CONV:.+]] = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %[[KERNEL]] : tensor<?x?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[INIT]] : tensor<?x?x?x?xf32>)
769-
// CHECK: %[[OUT:.+]] = tensor.expand_shape %[[CONV]] {{\[\[}}0], [1], [2], [3, 4]]
770-
%0 = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<?x?x?x1xf32>) outs(%arg2 : tensor<?x?x?x?x1xf32>) -> tensor<?x?x?x?x1xf32>
771-
return %0 : tensor<?x?x?x?x1xf32>
772-
}
773-
774-
775-
// -----
776-
777-
// CHECK-LABEL: @depthwise_conv_q
778-
func @depthwise_conv_q(%arg0: tensor<?x?x?x?xi8>, %arg1: tensor<?x?x?x1xi8>, %arg2: tensor<?x?x?x?x1xi32>, %arg3 : i32, %arg4 : i32) -> tensor<?x?x?x?x1xi32> {
779-
// CHECK-DAG: %[[KERNEL:.+]] = tensor.collapse_shape %arg1 {{\[\[}}0], [1], [2, 3]]
780-
// CHECK-DAG: %[[INIT:.+]] = tensor.collapse_shape %arg2 {{\[\[}}0], [1], [2], [3, 4]]
781-
// CHECK-DAG: %[[CONV:.+]] = linalg.depthwise_conv_2d_nhwc_hwc_q {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %[[KERNEL]], %arg3, %arg4 : tensor<?x?x?x?xi8>, tensor<?x?x?xi8>, i32, i32) outs(%[[INIT]] : tensor<?x?x?x?xi32>)
782-
// CHECK: %[[OUT:.+]] = tensor.expand_shape %[[CONV]] {{\[\[}}0], [1], [2], [3, 4]]
783-
%0 = linalg.depthwise_conv_2d_nhwc_hwcm_q {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1, %arg3, %arg4 : tensor<?x?x?x?xi8>, tensor<?x?x?x1xi8>, i32, i32) outs(%arg2 : tensor<?x?x?x?x1xi32>) -> tensor<?x?x?x?x1xi32>
784-
return %0 : tensor<?x?x?x?x1xi32>
785-
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
// RUN: mlir-opt %s -linalg-named-op-conversion -split-input-file | FileCheck %s
2+
3+
// CHECK-LABEL: @depthwise_conv
4+
func @depthwise_conv(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x1xf32>, %arg2: tensor<?x?x?x?x1xf32>) -> tensor<?x?x?x?x1xf32> {
5+
// CHECK-DAG: %[[KERNEL:.+]] = tensor.collapse_shape %arg1 {{\[\[}}0], [1], [2, 3]]
6+
// CHECK-DAG: %[[INIT:.+]] = tensor.collapse_shape %arg2 {{\[\[}}0], [1], [2], [3, 4]]
7+
// CHECK-DAG: %[[CONV:.+]] = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %[[KERNEL]] : tensor<?x?x?x?xf32>, tensor<?x?x?xf32>) outs(%[[INIT]] : tensor<?x?x?x?xf32>)
8+
// CHECK: %[[OUT:.+]] = tensor.expand_shape %[[CONV]] {{\[\[}}0], [1], [2], [3, 4]]
9+
%0 = linalg.depthwise_conv_2d_nhwc_hwcm {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x?x?x?xf32>, tensor<?x?x?x1xf32>) outs(%arg2 : tensor<?x?x?x?x1xf32>) -> tensor<?x?x?x?x1xf32>
10+
return %0 : tensor<?x?x?x?x1xf32>
11+
}
12+
13+
14+
// -----
15+
16+
// CHECK-LABEL: @depthwise_conv_q
17+
func @depthwise_conv_q(%arg0: tensor<?x?x?x?xi8>, %arg1: tensor<?x?x?x1xi8>, %arg2: tensor<?x?x?x?x1xi32>, %arg3 : i32, %arg4 : i32) -> tensor<?x?x?x?x1xi32> {
18+
// CHECK-DAG: %[[KERNEL:.+]] = tensor.collapse_shape %arg1 {{\[\[}}0], [1], [2, 3]]
19+
// CHECK-DAG: %[[INIT:.+]] = tensor.collapse_shape %arg2 {{\[\[}}0], [1], [2], [3, 4]]
20+
// CHECK-DAG: %[[CONV:.+]] = linalg.depthwise_conv_2d_nhwc_hwc_q {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %[[KERNEL]], %arg3, %arg4 : tensor<?x?x?x?xi8>, tensor<?x?x?xi8>, i32, i32) outs(%[[INIT]] : tensor<?x?x?x?xi32>)
21+
// CHECK: %[[OUT:.+]] = tensor.expand_shape %[[CONV]] {{\[\[}}0], [1], [2], [3, 4]]
22+
%0 = linalg.depthwise_conv_2d_nhwc_hwcm_q {dilations = dense<1> : tensor<2xi64>, strides = dense<2> : tensor<2xi64>} ins(%arg0, %arg1, %arg3, %arg4 : tensor<?x?x?x?xi8>, tensor<?x?x?x1xi8>, i32, i32) outs(%arg2 : tensor<?x?x?x?x1xi32>) -> tensor<?x?x?x?x1xi32>
23+
return %0 : tensor<?x?x?x?x1xi32>
24+
}

0 commit comments

Comments
 (0)