[mlir][tosa-to-linalg] Add acc_type lowering Support #134267

Merged · 1 commit · Apr 7, 2025
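In short: the TOSA convolution ops (tosa.conv2d, tosa.conv3d, tosa.depthwise_conv2d) carry an acc_type attribute naming the accumulator element type, which may be wider than the result element type. This patch makes the tosa-to-linalg-named lowering honor it: the bias is broadcast (and widened) into an accumulator-typed tensor, the linalg convolution produces that accumulator type, and a tosa.cast truncates back to the result type when the two differ. A minimal sketch of the effect, using the f16-operand / f32-accumulator conv2d case from the tests added below (SSA names are illustrative):

  %0 = tosa.conv2d %input, %weights, %bias, %input_zp, %weight_zp
         {acc_type = f32, pad = array<i64: 0, 0, 0, 0>,
          stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>}
       : (tensor<1x49x42x27xf16>, tensor<28x3x3x27xf16>, tensor<28xf16>,
          tensor<1xf16>, tensor<1xf16>) -> tensor<1x45x40x28xf16>

  // After -tosa-to-linalg-named, roughly:
  //   %bias_f32 = linalg.generic ... ins(%bias : tensor<28xf16>)
  //                 outs(%empty : tensor<1x45x40x28xf32>)   // arith.extf f16 -> f32 inside
  //   %acc      = linalg.conv_2d_nhwc_fhwc ... -> tensor<1x45x40x28xf32>
  //   %0        = tosa.cast %acc : (tensor<1x45x40x28xf32>) -> tensor<1x45x40x28xf16>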
50 changes: 37 additions & 13 deletions mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -119,10 +119,11 @@ static AffineMap getBroadcastingMap(PatternRewriter &rewriter, Value source,
}

// Broadcast the source value to all the outer dimensions of the result value.
// If required, the element type is expanded using an arith.extsi operation.
static mlir::Value linalgBroadcastAndMaybeExtSI(PatternRewriter &rewriter,
Location loc, Value source,
Value result) {
// If required, the element type is expanded using an arith.extsi or arith.extf
// operation as appropriate.
static mlir::Value linalgBroadcastAndMaybeExt(PatternRewriter &rewriter,
Location loc, Value source,
Value result) {
ShapedType resultTy = cast<ShapedType>(result.getType());
const int64_t resultRank = resultTy.getRank();
// Creating maps for the input and output of the broacast-like generic op.
@@ -135,11 +136,16 @@ static mlir::Value linalgBroadcastAndMaybeExtSI(PatternRewriter &rewriter,
.create<linalg::GenericOp>(
loc, resultTy, ValueRange({source}), result, indexingMaps,
getNParallelLoopsAttrs(resultTy.getRank()),
[](OpBuilder &builder, Location loc, ValueRange args) {
[&resultTy](OpBuilder &builder, Location loc, ValueRange args) {
Value biasVal = args[0];
Type resType = args[1].getType();
if (resType != biasVal.getType()) {
biasVal = builder.create<arith::ExtSIOp>(loc, resType, biasVal);
biasVal =
resultTy.getElementType().isFloat()
? builder.create<arith::ExtFOp>(loc, resType, biasVal)
.getResult()
: builder.create<arith::ExtSIOp>(loc, resType, biasVal)
.getResult();
}
builder.create<linalg::YieldOp>(loc, biasVal);
})
@@ -253,12 +259,14 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());

Type inputETy = inputTy.getElementType();
Type resultETy = resultTy.getElementType();

DenseI64ArrayAttr padAttr = op.getPadAttr();
DenseI64ArrayAttr strideTosaAttr = op.getStrideAttr();
DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr();

Type accETy = op.getAccType();
Type accTy = RankedTensorType::get(resultTy.getShape(), accETy);

// Get and verify zero points.
FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
if (failed(maybeIZp))
@@ -385,10 +393,10 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
auto dilationAttr = rewriter.getI64TensorAttr(dilation);

Value biasEmptyTensor = rewriter.create<tensor::EmptyOp>(
loc, resultTy.getShape(), resultETy, filteredDims);
loc, resultTy.getShape(), accETy, filteredDims);

Value broadcastBias =
linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);
linalgBroadcastAndMaybeExt(rewriter, loc, bias, biasEmptyTensor);

if (hasZp) {
auto iZp = rewriter.getI32IntegerAttr(inputZpVal);
@@ -410,10 +418,15 @@

Value conv = rewriter
.create<LinalgConvOp>(
loc, resultTy, ValueRange{input, weight},
loc, accTy, ValueRange{input, weight},
ValueRange{broadcastBias}, strideAttr, dilationAttr)
->getResult(0);

// We may need to truncate back to the result type if the accumulator was
// wider than the result.
if (resultTy != accTy)
conv = rewriter.create<tosa::CastOp>(loc, resultTy, conv);

rewriter.replaceOp(op, conv);
return success();
}
@@ -444,6 +457,8 @@ class DepthwiseConvConverter
auto strideTosaAttr = cast<DenseI64ArrayAttr>(op->getAttr("stride"));
auto dilationTosaAttr = cast<DenseI64ArrayAttr>(op->getAttr("dilation"));

Type accETy = op.getAccType();

if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
return rewriter.notifyMatchFailure(
op, "tosa.depthwise_conv ops require static shapes");
@@ -516,11 +531,11 @@ class DepthwiseConvConverter
ShapedType linalgConvTy =
RankedTensorType::get({resultShape[0], resultShape[1], resultShape[2],
weightShape[2], weightShape[3]},
resultETy);
accETy);

auto resultZeroAttr = rewriter.getZeroAttr(resultETy);
auto resultZeroAttr = rewriter.getZeroAttr(accETy);
Value emptyTensor = rewriter.create<tensor::EmptyOp>(
loc, linalgConvTy.getShape(), resultETy, filteredDims);
loc, linalgConvTy.getShape(), accETy, filteredDims);
Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
Value zeroTensor = rewriter
.create<linalg::FillOp>(loc, ValueRange{zero},
@@ -543,6 +558,15 @@
ValueRange{zeroTensor}, strideAttr, dilationAttr)
.getResult(0);

// We may need to truncate back to the result type if the accumulator was
// wider than the result.
if (accETy != resultETy)
conv = rewriter.create<tosa::CastOp>(
loc,
RankedTensorType::get(cast<ShapedType>(conv.getType()).getShape(),
resultETy),
conv);

SmallVector<ReassociationExprs, 4> reassociationMap;
createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
Value convReshape = rewriter.create<tensor::CollapseShapeOp>(
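For reference, a sketch of what the updated linalgBroadcastAndMaybeExt emits for an f16 bias feeding an f32 accumulator (shapes taken from the conv2d test below; the affine maps match those checked in the existing tests, and SSA names are illustrative):

  #map0 = affine_map<(d0, d1, d2, d3) -> (d3)>
  #map1 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
  %empty = tensor.empty() : tensor<1x45x40x28xf32>
  %broadcast = linalg.generic
      {indexing_maps = [#map0, #map1],
       iterator_types = ["parallel", "parallel", "parallel", "parallel"]}
      ins(%bias : tensor<28xf16>) outs(%empty : tensor<1x45x40x28xf32>) {
  ^bb0(%in: f16, %out: f32):
    %ext = arith.extf %in : f16 to f32   // widened because the accumulator element type is f32
    linalg.yield %ext : f32
  } -> tensor<1x45x40x28xf32>

The truncation back to the narrower result type is emitted as a tosa.cast rather than a raw arith op, presumably so that the existing tosa.cast lowering takes care of it.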
40 changes: 40 additions & 0 deletions mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -658,6 +658,20 @@ func.func @conv2d_quant(%arg0 : tensor<1x12x12x1xi8>, %arg1 : tensor<1024x3x3x1x

// -----

// CHECK-LABEL: @conv2d_f16_f32_acc
func.func @conv2d_f16_f32_acc(%input: tensor<1x49x42x27xf16>, %weights: tensor<28x3x3x27xf16>, %bias: tensor<28xf16>) -> () {
%input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
%weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
// CHECK: linalg.generic {{{.*}}} ins(%{{.*}} : tensor<28xf16>) outs(%{{.*}} : tensor<1x45x40x28xf32>)
// CHECK: arith.extf %{{.*}} : f16 to f32
// CHECK: %[[CONV:.*]] = linalg.conv_2d_nhwc_fhwc {{{.*}}} ins(%{{.*}}, %{{.*}} : tensor<1x49x42x27xf16>, tensor<28x3x3x27xf16>) outs(%{{.*}} : tensor<1x45x40x28xf32>) -> tensor<1x45x40x28xf32>
// CHECK: tosa.cast %[[CONV]] : (tensor<1x45x40x28xf32>) -> tensor<1x45x40x28xf16>
%0 = tosa.conv2d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x49x42x27xf16>, tensor<28x3x3x27xf16>, tensor<28xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x45x40x28xf16>
return
}

// -----

// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d3)>
// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>

@@ -848,6 +862,18 @@ func.func @depthwise_int_conv_zero_zp(%arg0 : tensor<1x7x5x3xi8>, %arg1 : tensor

// -----

// CHECK-LABEL: @depthwise_conv2d_f16_f32_acc
func.func @depthwise_conv2d_f16_f32_acc(%arg0 : tensor<1x7x5x3xf16>, %arg1 : tensor<3x1x3x11xf16>, %arg2 : tensor<33xf16>) -> () {
%input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
%weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
// CHECK: %[[CONV:.*]] = linalg.depthwise_conv_2d_nhwc_hwcm {{{.*}}} ins(%{{.*}}, %{{.*}} : tensor<1x7x5x3xf16>, tensor<3x1x3x11xf16>) outs(%{{.*}} : tensor<1x5x5x3x11xf32>) -> tensor<1x5x5x3x11xf32>
// CHECK: tosa.cast %[[CONV]] : (tensor<1x5x5x3x11xf32>) -> tensor<1x5x5x3x11xf16>
%2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> } : (tensor<1x7x5x3xf16>, tensor<3x1x3x11xf16>, tensor<33xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x5x5x33xf16>
return
}

// -----

// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d4)>
// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>

@@ -918,6 +944,20 @@ func.func @conv3d_i8(%input: tensor<1x49x48x47x27xi8>, %weights: tensor<28x3x4x5

// -----

// CHECK-LABEL: @conv3d_f16_f32_acc
func.func @conv3d_f16_f32_acc(%input: tensor<1x49x48x47x27xf16>, %weights: tensor<28x3x4x5x27xf16>, %bias: tensor<28xf16>) -> () {
%input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
%weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
// CHECK: linalg.generic {{{.*}}} ins(%{{.*}} : tensor<28xf16>) outs(%{{.*}} : tensor<1x47x45x43x28xf32>)
// CHECK: arith.extf %{{.*}} : f16 to f32
// CHECK: %[[CONV:.*]] = linalg.conv_3d_ndhwc_dhwcf {{{.*}}} ins(%{{.*}}, %{{.*}} : tensor<1x49x48x47x27xf16>, tensor<3x4x5x27x28xf16>) outs(%{{.*}} : tensor<1x47x45x43x28xf32>) -> tensor<1x47x45x43x28xf32>
// CHECK: tosa.cast %[[CONV]] : (tensor<1x47x45x43x28xf32>) -> tensor<1x47x45x43x28xf16>
%0 = tosa.conv3d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xf16>, tensor<28x3x4x5x27xf16>, tensor<28xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x47x45x43x28xf16>
return
}

// -----

// CHECK-LABEL: @test_transpose
// CHECK-SAME: (%[[ARG0:.+]]: tensor<1x2x3xi32>)
func.func @test_transpose(%arg0: tensor<1x2x3xi32>) -> () {