[mlir][tosa-to-linalg] Add acc_type lowering Support #134267
Conversation
Add support for lowering of convolution operations where the `acc_type` attribute differs from the result type of the operation. The only such case for convolutions in the TOSA-v1.0 specification is an fp16 convolution that internally uses an fp32 accumulator; all other operations have accumulator types that match their output/result types. Add lit tests for the fp16 convolution with fp32 accumulator operators described above. Signed-off-by: Jack Frankland <[email protected]>
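To make the new behaviour concrete, here is a minimal sketch of the ops the conversion is expected to produce for the fp16/fp32 case, using the shapes from the new `conv2d_f16_f32_acc` lit test below. The `%bias_bcast` name is illustrative, and the `linalg.generic`/`arith.extf` bias broadcast that produces it is omitted:

```mlir
// Input: an fp16 tosa.conv2d that accumulates in f32 (acc_type = f32).
%0 = tosa.conv2d %input, %weights, %bias, %input_zp, %weight_zp
       {acc_type = f32, pad = array<i64: 0, 0, 0, 0>,
        stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>}
     : (tensor<1x49x42x27xf16>, tensor<28x3x3x27xf16>, tensor<28xf16>,
        tensor<1xf16>, tensor<1xf16>) -> tensor<1x45x40x28xf16>

// After the TOSA-to-linalg-named conversion (roughly): the bias is broadcast
// and widened with arith.extf, the convolution accumulates into an f32
// tensor, and a tosa.cast narrows the accumulator back to the fp16 result.
%acc = linalg.conv_2d_nhwc_fhwc
         {dilations = dense<[2, 1]> : tensor<2xi64>,
          strides = dense<[1, 1]> : tensor<2xi64>}
         ins(%input, %weights : tensor<1x49x42x27xf16>, tensor<28x3x3x27xf16>)
         outs(%bias_bcast : tensor<1x45x40x28xf32>) -> tensor<1x45x40x28xf32>
%res = tosa.cast %acc : (tensor<1x45x40x28xf32>) -> tensor<1x45x40x28xf16>
```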
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-linalg

Author: Jack Frankland (FranklandJack)

Changes

Add support for lowering of convolution operations where the `acc_type` attribute differs from the result type of the operation. The only such case for convolutions in the TOSA-v1.0 specification is an fp16 convolution that internally uses an fp32 accumulator; all other operations have accumulator types that match their output/result types.

Add lit tests for the fp16 convolution with fp32 accumulator operators described above.

Full diff: https://github.com/llvm/llvm-project/pull/134267.diff

2 Files Affected:
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
index fc1cad2423450..86f5e9baf4a94 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp
@@ -119,10 +119,11 @@ static AffineMap getBroadcastingMap(PatternRewriter &rewriter, Value source,
}
// Broadcast the source value to all the outer dimensions of the result value.
-// If required, the element type is expanded using an arith.extsi operation.
-static mlir::Value linalgBroadcastAndMaybeExtSI(PatternRewriter &rewriter,
- Location loc, Value source,
- Value result) {
+// If required, the element type is expanded using an arith.extsi or arith.extf
+// operation as appropriate.
+static mlir::Value linalgBroadcastAndMaybeExt(PatternRewriter &rewriter,
+ Location loc, Value source,
+ Value result) {
ShapedType resultTy = cast<ShapedType>(result.getType());
const int64_t resultRank = resultTy.getRank();
// Creating maps for the input and output of the broacast-like generic op.
@@ -135,11 +136,16 @@ static mlir::Value linalgBroadcastAndMaybeExtSI(PatternRewriter &rewriter,
.create<linalg::GenericOp>(
loc, resultTy, ValueRange({source}), result, indexingMaps,
getNParallelLoopsAttrs(resultTy.getRank()),
- [](OpBuilder &builder, Location loc, ValueRange args) {
+ [&resultTy](OpBuilder &builder, Location loc, ValueRange args) {
Value biasVal = args[0];
Type resType = args[1].getType();
if (resType != biasVal.getType()) {
- biasVal = builder.create<arith::ExtSIOp>(loc, resType, biasVal);
+ biasVal =
+ resultTy.getElementType().isFloat()
+ ? builder.create<arith::ExtFOp>(loc, resType, biasVal)
+ .getResult()
+ : builder.create<arith::ExtSIOp>(loc, resType, biasVal)
+ .getResult();
}
builder.create<linalg::YieldOp>(loc, biasVal);
})
@@ -253,12 +259,14 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
ShapedType resultTy = cast<ShapedType>(op->getResult(0).getType());
Type inputETy = inputTy.getElementType();
- Type resultETy = resultTy.getElementType();
DenseI64ArrayAttr padAttr = op.getPadAttr();
DenseI64ArrayAttr strideTosaAttr = op.getStrideAttr();
DenseI64ArrayAttr dilationTosaAttr = op.getDilationAttr();
+ Type accETy = op.getAccType();
+ Type accTy = RankedTensorType::get(resultTy.getShape(), accETy);
+
// Get and verify zero points.
FailureOr<int64_t> maybeIZp = op.getInputZeroPoint();
if (failed(maybeIZp))
@@ -385,10 +393,10 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
auto dilationAttr = rewriter.getI64TensorAttr(dilation);
Value biasEmptyTensor = rewriter.create<tensor::EmptyOp>(
- loc, resultTy.getShape(), resultETy, filteredDims);
+ loc, resultTy.getShape(), accETy, filteredDims);
Value broadcastBias =
- linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);
+ linalgBroadcastAndMaybeExt(rewriter, loc, bias, biasEmptyTensor);
if (hasZp) {
auto iZp = rewriter.getI32IntegerAttr(inputZpVal);
@@ -410,10 +418,15 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
Value conv = rewriter
.create<LinalgConvOp>(
- loc, resultTy, ValueRange{input, weight},
+ loc, accTy, ValueRange{input, weight},
ValueRange{broadcastBias}, strideAttr, dilationAttr)
->getResult(0);
+ // We may need to truncate back to the result type if the accumulator was
+ // wider than the result.
+ if (resultTy != accTy)
+ conv = rewriter.create<tosa::CastOp>(loc, resultTy, conv);
+
rewriter.replaceOp(op, conv);
return success();
}
@@ -444,6 +457,8 @@ class DepthwiseConvConverter
auto strideTosaAttr = cast<DenseI64ArrayAttr>(op->getAttr("stride"));
auto dilationTosaAttr = cast<DenseI64ArrayAttr>(op->getAttr("dilation"));
+ Type accETy = op.getAccType();
+
if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
return rewriter.notifyMatchFailure(
op, "tosa.depthwise_conv ops require static shapes");
@@ -516,11 +531,11 @@ class DepthwiseConvConverter
ShapedType linalgConvTy =
RankedTensorType::get({resultShape[0], resultShape[1], resultShape[2],
weightShape[2], weightShape[3]},
- resultETy);
+ accETy);
- auto resultZeroAttr = rewriter.getZeroAttr(resultETy);
+ auto resultZeroAttr = rewriter.getZeroAttr(accETy);
Value emptyTensor = rewriter.create<tensor::EmptyOp>(
- loc, linalgConvTy.getShape(), resultETy, filteredDims);
+ loc, linalgConvTy.getShape(), accETy, filteredDims);
Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
Value zeroTensor = rewriter
.create<linalg::FillOp>(loc, ValueRange{zero},
@@ -543,6 +558,15 @@ class DepthwiseConvConverter
ValueRange{zeroTensor}, strideAttr, dilationAttr)
.getResult(0);
+ // We may need to truncate back to the result type if the accumulator was
+ // wider than the result.
+ if (accETy != resultETy)
+ conv = rewriter.create<tosa::CastOp>(
+ loc,
+ RankedTensorType::get(cast<ShapedType>(conv.getType()).getShape(),
+ resultETy),
+ conv);
+
SmallVector<ReassociationExprs, 4> reassociationMap;
createDepthwiseConvCollapseMap(resultRank, reassociationMap, rewriter);
Value convReshape = rewriter.create<tensor::CollapseShapeOp>(
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
index 19c12ba3edbd4..242772fe5cdcf 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir
@@ -658,6 +658,20 @@ func.func @conv2d_quant(%arg0 : tensor<1x12x12x1xi8>, %arg1 : tensor<1024x3x3x1x
// -----
+// CHECK-LABEL: @conv2d_f16_f32_acc
+func.func @conv2d_f16_f32_acc(%input: tensor<1x49x42x27xf16>, %weights: tensor<28x3x3x27xf16>, %bias: tensor<28xf16>) -> () {
+ %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
+ %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
+ // CHECK: linalg.generic {{{.*}}} ins(%{{.*}} : tensor<28xf16>) outs(%{{.*}} : tensor<1x45x40x28xf32>)
+ // CHECK: arith.extf %{{.*}} : f16 to f32
+ // CHECK: %[[CONV:.*]] = linalg.conv_2d_nhwc_fhwc {{{.*}}} ins(%{{.*}}, %{{.*}} : tensor<1x49x42x27xf16>, tensor<28x3x3x27xf16>) outs(%{{.*}} : tensor<1x45x40x28xf32>) -> tensor<1x45x40x28xf32>
+ // CHECK: tosa.cast %[[CONV]] : (tensor<1x45x40x28xf32>) -> tensor<1x45x40x28xf16>
+ %0 = tosa.conv2d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x49x42x27xf16>, tensor<28x3x3x27xf16>, tensor<28xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x45x40x28xf16>
+ return
+}
+
+// -----
+
// CHECK: #[[$MAP0:.*]] = affine_map<(d0, d1, d2, d3) -> (d3)>
// CHECK: #[[$MAP1:.*]] = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
@@ -848,6 +862,18 @@ func.func @depthwise_int_conv_zero_zp(%arg0 : tensor<1x7x5x3xi8>, %arg1 : tensor
// -----
+// CHECK-LABEL: @depthwise_conv2d_f16_f32_acc
+func.func @depthwise_conv2d_f16_f32_acc(%arg0 : tensor<1x7x5x3xf16>, %arg1 : tensor<3x1x3x11xf16>, %arg2 : tensor<33xf16>) -> () {
+ %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
+ %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
+ // CHECK: %[[CONV:.*]] = linalg.depthwise_conv_2d_nhwc_hwcm {{{.*}}} ins(%{{.*}}, %{{.*}} : tensor<1x7x5x3xf16>, tensor<3x1x3x11xf16>) outs(%{{.*}} : tensor<1x5x5x3x11xf32>) -> tensor<1x5x5x3x11xf32>
+ // CHECK: tosa.cast %[[CONV]] : (tensor<1x5x5x3x11xf32>) -> tensor<1x5x5x3x11xf16>
+ %2 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1> } : (tensor<1x7x5x3xf16>, tensor<3x1x3x11xf16>, tensor<33xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x5x5x33xf16>
+ return
+}
+
+// -----
+
// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d4)>
// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
@@ -918,6 +944,20 @@ func.func @conv3d_i8(%input: tensor<1x49x48x47x27xi8>, %weights: tensor<28x3x4x5
// -----
+// CHECK-LABEL: @conv3d_f16_f32_acc
+func.func @conv3d_f16_f32_acc(%input: tensor<1x49x48x47x27xf16>, %weights: tensor<28x3x4x5x27xf16>, %bias: tensor<28xf16>) -> () {
+ %input_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
+ %weight_zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf16>}> : () -> tensor<1xf16>
+ // CHECK: linalg.generic {{{.*}}} ins(%{{.*}} : tensor<28xf16>) outs(%{{.*}} : tensor<1x47x45x43x28xf32>)
+ // CHECK: arith.extf %{{.*}} : f16 to f32
+ // CHECK: %[[CONV:.*]] = linalg.conv_3d_ndhwc_dhwcf {{{.*}}} ins(%{{.*}}, %{{.*}} : tensor<1x49x48x47x27xf16>, tensor<3x4x5x27x28xf16>) outs(%{{.*}} : tensor<1x47x45x43x28xf32>) -> tensor<1x47x45x43x28xf32>
+ // CHECK: tosa.cast %[[CONV]] : (tensor<1x47x45x43x28xf32>) -> tensor<1x47x45x43x28xf16>
+ %0 = tosa.conv3d %input, %weights, %bias, %input_zp, %weight_zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xf16>, tensor<28x3x4x5x27xf16>, tensor<28xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<1x47x45x43x28xf16>
+ return
+}
+
+// -----
+
// CHECK-LABEL: @test_transpose
// CHECK-SAME: (%[[ARG0:.+]]: tensor<1x2x3xi32>)
func.func @test_transpose(%arg0: tensor<1x2x3xi32>) -> () {
@llvm/pr-subscribers-mlir-tosa

Author: Jack Frankland (FranklandJack)

(Same patch summary and full diff as in the comment above.)
LGTM - thanks! Probably needs someone more familiar with linalg to have a look as well
Add support for lowering of convolution operations where the `acc_type` attribute differs from the result type of the operation. The only such case for convolutions in the TOSA-v1.0 specification is an fp16 convolution that internally uses an fp32 accumulator; all other operations have accumulator types that match their output/result types.

Add lit tests for the fp16 convolution with fp32 accumulator operators described above.
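For the depthwise path, the narrowing cast is applied while the accumulator is still in the expanded 5-D (N, H, W, C, M) layout, before it is collapsed back to the 4-D result shape. A minimal sketch based on the new `depthwise_conv2d_f16_f32_acc` lit test; SSA names such as `%zero_init` are illustrative, and the bias addition that follows is omitted:

```mlir
// Accumulate into f32 in the expanded NHWCM layout...
%acc = linalg.depthwise_conv_2d_nhwc_hwcm
         {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}
         ins(%arg0, %arg1 : tensor<1x7x5x3xf16>, tensor<3x1x3x11xf16>)
         outs(%zero_init : tensor<1x5x5x3x11xf32>) -> tensor<1x5x5x3x11xf32>
// ...narrow back to fp16 while still rank-5...
%narrowed = tosa.cast %acc : (tensor<1x5x5x3x11xf32>) -> tensor<1x5x5x3x11xf16>
// ...then collapse the channel and multiplier dims into the final 4-D result.
%collapsed = tensor.collapse_shape %narrowed [[0], [1], [2], [3, 4]]
    : tensor<1x5x5x3x11xf16> into tensor<1x5x5x33xf16>
```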