
Commit 7b007c0

[mlir][tosa-to-linalg] Add acc_type lowering Support (#134267)
Add support for lowering of convolution operations where the `acc_type` attribute differs from the result type of the operation. The only case of this for convolutions in the TOSA-v1.0 specification is an fp16 convolution which internally uses an fp32 accumulator; all other operations have accumulator types that match their output/result types.

Add lit tests for the fp16 convolution with fp32 accumulator operators described above.

Signed-off-by: Jack Frankland <[email protected]>
1 parent c9157d4 commit 7b007c0
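As a rough sketch of the new lowering shape (the operand names and tensor shapes below are hypothetical, chosen only to mirror the lit tests added in this commit), an fp16 `tosa.conv2d` with `acc_type = f32` now accumulates in fp32 and is narrowed back to fp16 afterwards:

  // Hypothetical input IR: fp16 convolution with an fp32 accumulator.
  %0 = tosa.conv2d %input, %weights, %bias, %input_zp, %weight_zp
         {acc_type = f32, pad = array<i64: 0, 0, 0, 0>,
          stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>}
       : (tensor<1x8x8x4xf16>, tensor<8x3x3x4xf16>, tensor<8xf16>,
          tensor<1xf16>, tensor<1xf16>) -> tensor<1x6x6x8xf16>

  // After the TosaToLinalgNamed conversion (schematically):
  //   * the broadcast bias is widened with arith.extf (f16 -> f32),
  //   * linalg.conv_2d_nhwc_fhwc writes into a tensor<1x6x6x8xf32> accumulator,
  //   * a trailing tosa.cast narrows the result back to tensor<1x6x6x8xf16>.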

File tree

2 files changed: +77 −13 lines

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir

2 files changed

+77
-13
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp

Lines changed: 37 additions & 13 deletions
@@ -119,10 +119,11 @@ static AffineMap getBroadcastingMap(PatternRewriter &rewriter, Value source,
 }

 // Broadcast the source value to all the outer dimensions of the result value.
-// If required, the element type is expanded using an arith.extsi operation.
-static mlir::Value linalgBroadcastAndMaybeExtSI(PatternRewriter &rewriter,
-                                                Location loc, Value source,
-                                                Value result) {
+// If required, the element type is expanded using an arith.extsi or arith.extf
+// operation as appropriate.
+static mlir::Value linalgBroadcastAndMaybeExt(PatternRewriter &rewriter,
+                                              Location loc, Value source,
+                                              Value result) {
   ShapedType resultTy = cast<ShapedType>(result.getType());
   const int64_t resultRank = resultTy.getRank();
   // Creating maps for the input and output of the broacast-like generic op.
@@ -135,11 +136,16 @@ static mlir::Value linalgBroadcastAndMaybeExtSI(PatternRewriter &rewriter,
           .create<linalg::GenericOp>(
               loc, resultTy, ValueRange({source}), result, indexingMaps,
               getNParallelLoopsAttrs(resultTy.getRank()),
-              [](OpBuilder &builder, Location loc, ValueRange args) {
+              [&resultTy](OpBuilder &builder, Location loc, ValueRange args) {
                 Value biasVal = args[0];
                 Type resType = args[1].getType();
                 if (resType != biasVal.getType()) {
-                  biasVal = builder.create<arith::ExtSIOp>(loc, resType, biasVal);
+                  biasVal =
+                      resultTy.getElementType().isFloat()
+                          ? builder.create<arith::ExtFOp>(loc, resType, biasVal)
+                                .getResult()
+                          : builder.create<arith::ExtSIOp>(loc, resType, biasVal)
+                                .getResult();
                 }
                 builder.create<linalg::YieldOp>(loc, biasVal);
               })
@@ -253,12 +259,14 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
     ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());

     Type inputETy = inputTy.getElementType();
-    Type resultETy = resultTy.getElementType();

     DenseI64ArrayAttr padAttr = op.getPadAttr();
     DenseI64ArrayAttr strideTosaAttr = op.getStrideAttr();
     DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr();

+    Type accETy = op.getAccType();
+    Type accTy = RankedTensorType::get(resultTy.getShape(), accETy);
+
     // Get and verify zero points.
     FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
     if (failed(maybeIZp))
@@ -385,10 +393,10 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
     auto dilationAttr = rewriter.getI64TensorAttr(dilation);

     Value biasEmptyTensor = rewriter.create<tensor::EmptyOp>(
-        loc, resultTy.getShape(), resultETy, filteredDims);
+        loc, resultTy.getShape(), accETy, filteredDims);

     Value broadcastBias =
-        linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);
+        linalgBroadcastAndMaybeExt(rewriter, loc, bias, biasEmptyTensor);

     if (hasZp) {
       auto iZp = rewriter.getI32IntegerAttr(inputZpVal);
@@ -410,10 +418,15 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {

     Value conv = rewriter
                      .create<LinalgConvOp>(
-                         loc, resultTy, ValueRange{input, weight},
+                         loc, accTy, ValueRange{input, weight},
                          ValueRange{broadcastBias}, strideAttr, dilationAttr)
                      ->getResult(0);

+    // We may need to truncate back to the result type if the accumulator was
+    // wider than the result.
+    if (resultTy != accTy)
+      conv = rewriter.create<tosa::CastOp>(loc, resultTy, conv);
+
     rewriter.replaceOp(op, conv);
     return success();
   }
@@ -444,6 +457,8 @@ class DepthwiseConvConverter
     auto strideTosaAttr = cast<DenseI64ArrayAttr>(op->getAttr("stride"));
     auto dilationTosaAttr = cast<DenseI64ArrayAttr>(op->getAttr("dilation"));

+    Type accETy = op.getAccType();
+
     if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
       return rewriter.notifyMatchFailure(
           op, "tosa.depthwise_conv ops require static shapes");
@@ -516,11 +531,11 @@ class DepthwiseConvConverter
     ShapedType linalgConvTy =
         RankedTensorType::get({resultShape[0], resultShape[1], resultShape[2],
                                weightShape[2], weightShape[3]},
-                              resultETy);
+                              accETy);

-    auto resultZeroAttr = rewriter.getZeroAttr(resultETy);
+    auto resultZeroAttr = rewriter.getZeroAttr(accETy);
     Value emptyTensor = rewriter.create<tensor::EmptyOp>(
-        loc, linalgConvTy.getShape(), resultETy, filteredDims);
+        loc, linalgConvTy.getShape(), accETy, filteredDims);
     Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
     Value zeroTensor = rewriter
                            .create<linalg::FillOp>(loc, ValueRange{zero},
@@ -543,6 +558,15 @@ class DepthwiseConvConverter
                        ValueRange{zeroTensor}, strideAttr, dilationAttr)
                    .getResult(0);

+    // We may need to truncate back to the result type if the accumulator was
+    // wider than the result.
+    if (accETy != resultETy)
+      conv = rewriter.create<tosa::CastOp>(
+          loc,
+          RankedTensorType::get(cast<ShapedType>(conv.getType()).getShape(),
+                                resultETy),
+          conv);
+
     SmallVector<ReassociationExprs, 4> reassociationMap;
     createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
     Value convReshape = rewriter.create<tensor::CollapseShapeOp>(

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir

Lines changed: 40 additions & 0 deletions
@@ -658,6 +658,20 @@ func.func @conv2d_quant(%arg0 : tensor<1x12x12x1xi8>, %arg1 : tensor<1024x3x3x1x

 // -----

+// CHECK-LABEL: @conv2d_f16_f32_acc
+func.func @conv2d_f16_f32_acc(%input: tensor<1x49x42x27xf16>, %weights: tensor<28x3x3x27xf16>, %bias: tensor<28xf16>) -> () {
+  %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
+  %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
+  // CHECK: linalg.generic {{{.*}}} ins(%{{.*}} : tensor<28xf16>) outs(%{{.*}} : tensor<1x45x40x28xf32>)
+  // CHECK: arith.extf %{{.*}} : f16 to f32
+  // CHECK: %[[CONV:.*]] = linalg.conv_2d_nhwc_fhwc {{{.*}}} ins(%{{.*}}, %{{.*}} : tensor<1x49x42x27xf16>, tensor<28x3x3x27xf16>) outs(%{{.*}} : tensor<1x45x40x28xf32>) -> tensor<1x45x40x28xf32>
+  // CHECK: tosa.cast %[[CONV]] : (tensor<1x45x40x28xf32>) -> tensor<1x45x40x28xf16>
+  %0 = tosa.conv2d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x49x42x27xf16>, tensor<28x3x3x27xf16>, tensor<28xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x45x40x28xf16>
+  return
+}
+
+// -----
+
 // CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d3)>
 // CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>

@@ -848,6 +862,18 @@ func.func @depthwise_int_conv_zero_zp(%arg0 : tensor<1x7x5x3xi8>, %arg1 : tensor

 // -----

+// CHECK-LABEL: @depthwise_conv2d_f16_f32_acc
+func.func @depthwise_conv2d_f16_f32_acc(%arg0 : tensor<1x7x5x3xf16>, %arg1 : tensor<3x1x3x11xf16>, %arg2 : tensor<33xf16>) -> () {
+  %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
+  %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
+  // CHECK: %[[CONV:.*]] = linalg.depthwise_conv_2d_nhwc_hwcm {{{.*}}} ins(%{{.*}}, %{{.*}} : tensor<1x7x5x3xf16>, tensor<3x1x3x11xf16>) outs(%{{.*}} : tensor<1x5x5x3x11xf32>) -> tensor<1x5x5x3x11xf32>
+  // CHECK: tosa.cast %[[CONV]] : (tensor<1x5x5x3x11xf32>) -> tensor<1x5x5x3x11xf16>
+  %2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> } : (tensor<1x7x5x3xf16>, tensor<3x1x3x11xf16>, tensor<33xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x5x5x33xf16>
+  return
+}
+
+// -----
+
 // CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d4)>
 // CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>

@@ -918,6 +944,20 @@ func.func @conv3d_i8(%input: tensor<1x49x48x47x27xi8>, %weights: tensor<28x3x4x5

 // -----

+// CHECK-LABEL: @conv3d_f16_f32_acc
+func.func @conv3d_f16_f32_acc(%input: tensor<1x49x48x47x27xf16>, %weights: tensor<28x3x4x5x27xf16>, %bias: tensor<28xf16>) -> () {
+  %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
+  %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
+  // CHECK: linalg.generic {{{.*}}} ins(%{{.*}} : tensor<28xf16>) outs(%{{.*}} : tensor<1x47x45x43x28xf32>)
+  // CHECK: arith.extf %{{.*}} : f16 to f32
+  // CHECK: %[[CONV:.*]] = linalg.conv_3d_ndhwc_dhwcf {{{.*}}} ins(%{{.*}}, %{{.*}} : tensor<1x49x48x47x27xf16>, tensor<3x4x5x27x28xf16>) outs(%{{.*}} : tensor<1x47x45x43x28xf32>) -> tensor<1x47x45x43x28xf32>
+  // CHECK: tosa.cast %[[CONV]] : (tensor<1x47x45x43x28xf32>) -> tensor<1x47x45x43x28xf16>
+  %0 = tosa.conv3d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xf16>, tensor<28x3x4x5x27xf16>, tensor<28xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x47x45x43x28xf16>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @test_transpose
 // CHECK-SAME: (%[[ARG0:.+]]: tensor<1x2x3xi32>)
 func.func @test_transpose(%arg0: tensor<1x2x3xi32>) -> () {
