Commit 293c21d
[mlir][tosa] Improve lowering of tosa.conv2d (#74143)
The existing lowering of tosa.conv2d emits a separate linalg.generic operation to add the bias after computing the convolution. This change eliminates that additional step by passing the bias value as the initial (outs) operand of the generated linalg.conv_2d_* operation, so the convolution accumulates directly on top of the broadcast bias.

Rather than:

    %init = tensor.empty()
    %conv = linalg.conv_2d ins(%A, %B) outs(%init)
    %init2 = tensor.empty()
    %bias = linalg.generic ins(%conv, %bias) outs(%init2) {
      // perform add operation
    }

the lowering now produces:

    %init = tensor.empty()
    %bias_expanded = linalg.broadcast ins(%bias) outs(%init)
    %conv = linalg.conv_2d ins(%A, %B) outs(%bias_expanded)

This is the same strategy as #73049, applied to convolutions.
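For illustration, the new lowering of the static f32 test case below comes out roughly as follows (a hand-assembled sketch based on the updated CHECK lines in tosa-to-linalg-named.mlir; the affine-map aliases and SSA names are illustrative, not verbatim pass output):

    #bias_map = affine_map<(d0, d1, d2, d3) -> (d3)>
    #out_map = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>

    func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> tensor<1x45x40x28xf32> {
      // Broadcast the 1-D bias across the convolution output shape.
      %init = tensor.empty() : tensor<1x45x40x28xf32>
      %bias_expanded = linalg.generic {indexing_maps = [#bias_map, #out_map], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%bias : tensor<28xf32>) outs(%init : tensor<1x45x40x28xf32>) {
      ^bb0(%in: f32, %out: f32):
        linalg.yield %in : f32
      } -> tensor<1x45x40x28xf32>
      // The convolution accumulates onto the broadcast bias, so no
      // separate bias-add is emitted afterwards.
      %conv = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%input, %weights : tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%bias_expanded : tensor<1x45x40x28xf32>) -> tensor<1x45x40x28xf32>
      return %conv : tensor<1x45x40x28xf32>
    }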
1 parent: 7081585

2 files changed (+68, -103 lines)

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp (9 additions, 37 deletions)
@@ -344,15 +344,6 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
           weightPermValue);
     }
 
-    auto resultZeroAttr = rewriter.getZeroAttr(resultETy);
-    Value emptyTensor = rewriter.create<tensor::EmptyOp>(
-        loc, resultTy.getShape(), resultETy, filteredDims);
-    Value zero = rewriter.create<arith::ConstantOp>(loc, resultZeroAttr);
-    Value zeroTensor = rewriter
-                           .create<linalg::FillOp>(loc, ValueRange{zero},
-                                                   ValueRange{emptyTensor})
-                           .result();
-
     // Extract the attributes for convolution.
     ArrayRef<int64_t> stride = strideTosaAttr;
     ArrayRef<int64_t> dilation = dilationTosaAttr;
@@ -361,57 +352,38 @@ class ConvConverter : public OpConversionPattern<TosaConvOp> {
     auto strideAttr = rewriter.getI64TensorAttr(stride);
     auto dilationAttr = rewriter.getI64TensorAttr(dilation);
 
-    // Create maps for the bias broadcasting
-    SmallVector<AffineMap, 4> indexingMaps;
-    indexingMaps.push_back(AffineMap::get(
-        /*dimCount=*/resultTy.getRank(), /*symbolCount=*/0,
-        {rewriter.getAffineDimExpr(resultTy.getRank() - 1)},
-        rewriter.getContext()));
-    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));
-    indexingMaps.push_back(rewriter.getMultiDimIdentityMap(resultTy.getRank()));
-
     Value biasEmptyTensor = rewriter.create<tensor::EmptyOp>(
         loc, resultTy.getShape(), resultETy, filteredDims);
 
+    Value broadcastBias =
+        linalgBroadcastAndMaybeExtSI(rewriter, loc, bias, biasEmptyTensor);
+
     if (isQuantized) {
       auto quantizationInfo = *op.getQuantizationInfo();
       auto iZp = rewriter.getI32IntegerAttr(quantizationInfo.getInputZp());
       auto kZp = rewriter.getI32IntegerAttr(quantizationInfo.getWeightZp());
 
       auto iZpVal = rewriter.create<arith::ConstantOp>(loc, iZp);
       auto kZpVal = rewriter.create<arith::ConstantOp>(loc, kZp);
+
       Value conv =
           rewriter
              .create<LinalgConvQOp>(
                   loc, resultTy, ValueRange{input, weight, iZpVal, kZpVal},
-                  ValueRange{zeroTensor}, strideAttr, dilationAttr)
+                  ValueRange{broadcastBias}, strideAttr, dilationAttr)
               ->getResult(0);
-      Value result = linalgIntBroadcastExtSIAdd(rewriter, loc, bias, conv,
-                                                biasEmptyTensor, indexingMaps);
-      rewriter.replaceOp(op, result);
+
+      rewriter.replaceOp(op, conv);
      return success();
     }
 
     Value conv = rewriter
                      .create<LinalgConvOp>(
                          loc, resultTy, ValueRange{input, weight},
-                         ValueRange{zeroTensor}, strideAttr, dilationAttr)
+                         ValueRange{broadcastBias}, strideAttr, dilationAttr)
                      ->getResult(0);
 
-    Value result =
-        rewriter
-            .create<linalg::GenericOp>(
-                loc, resultTy, ValueRange({bias, conv}), biasEmptyTensor,
-                indexingMaps, getNParallelLoopsAttrs(resultTy.getRank()),
-                [&](OpBuilder &nestedBuilder, Location nestedLoc,
-                    ValueRange args) {
-                  Value added = nestedBuilder.create<arith::AddFOp>(
-                      loc, args[0], args[1]);
-                  nestedBuilder.create<linalg::YieldOp>(nestedLoc, added);
-                })
-            .getResult(0);
-
-    rewriter.replaceOp(op, result);
+    rewriter.replaceOp(op, conv);
     return success();
   }
 };
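Both the quantized and floating-point paths now build the initial accumulator with the linalgBroadcastAndMaybeExtSI helper (introduced in #73049). Judging from the updated CHECK lines in the test file below, for an i8 bias accumulated into i32 it emits a broadcasting linalg.generic that sign-extends each bias element, roughly (a sketch; SSA names are illustrative):

    %init = tensor.empty() : tensor<1x45x40x28xi32>
    %broadcast = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%bias : tensor<28xi8>) outs(%init : tensor<1x45x40x28xi32>) {
    ^bb0(%in: i8, %out: i32):
      // Sign-extend the narrow bias into the i32 accumulator type.
      %ext = arith.extsi %in : i8 to i32
      linalg.yield %ext : i32
    } -> tensor<1x45x40x28xi32>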

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-named.mlir (59 additions, 66 deletions)
@@ -378,16 +378,14 @@ func.func @avg_pool_dyn(%arg0: tensor<?x6x34x62xf32>) -> (tensor<?x5x33x62xf32>)
 func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi8>, %bias: tensor<28xi8>) -> () {
   // HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi64>
   // HWCF: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[TRANSPOSE_DIMS]] : (tensor<28x1x1x27xi8>, tensor<4xi64>) -> tensor<1x1x27x28xi8>
-  // CHECK: %[[M_IN:.+]] = tensor.empty()
-  // CHECK: %[[CST:.+]] = arith.constant 0
-  // CHECK: %[[FILL:.+]] = linalg.fill
-  // CHECK: %[[B_IN:.+]] = tensor.empty()
-  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %c0_i32_0, %c0_i32_1 : tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, i32, i32) outs(%[[FILL]] : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
-  // HWCF: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSE]], %c0_i32_0, %c0_i32_1 : tensor<1x49x42x27xi8>, tensor<1x1x27x28xi8>, i32, i32) outs(%{{[a-zA-Z0-9_]*}} : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
-  // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xi8>, tensor<1x45x40x28xi32>) outs(%[[B_IN]] : tensor<1x45x40x28xi32>)
+  // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x45x40x28xi32>
+  // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xi8>) outs(%[[INIT]] : tensor<1x45x40x28xi32>) {
   // CHECK: arith.extsi
-  // CHECK: arith.addi
   // CHECK: linalg.yield
+  // CHECK: } -> tensor<1x45x40x28xi32>
+  // CHECK: linalg.conv_2d_nhwc_fhwc_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1, %c0_i32, %c0_i32_0 : tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, i32, i32) outs(%[[BROADCAST]] : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
+  // HWCF: linalg.conv_2d_nhwc_hwcf_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSE]], %c0_i32, %c0_i32_0 : tensor<1x49x42x27xi8>, tensor<1x1x27x28xi8>, i32, i32) outs(%{{[a-zA-Z0-9_]*}} : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>
+
   %0 = tosa.conv2d %input, %weights, %bias {dilation = array<i64: 2, 1>, pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = 0, weight_zp = 0>, stride = array<i64: 1, 1>} : (tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, tensor<28xi8>) -> tensor<1x45x40x28xi32>
   return
 }
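In the quantized case the input and weight zero points become explicit i32 operands of the *_q convolution, which accumulates directly onto the broadcast bias. Per the CHECK lines above, the result is roughly (a sketch; SSA names are illustrative):

    %izp = arith.constant 0 : i32
    %kzp = arith.constant 0 : i32
    %conv = linalg.conv_2d_nhwc_fhwc_q {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%input, %weights, %izp, %kzp : tensor<1x49x42x27xi8>, tensor<28x1x1x27xi8>, i32, i32) outs(%broadcast : tensor<1x45x40x28xi32>) -> tensor<1x45x40x28xi32>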
@@ -401,15 +399,14 @@ func.func @conv2d_i8(%input: tensor<1x49x42x27xi8>, %weights: tensor<28x1x1x27xi
 func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
   // HWCF: %[[TRANSPOSE_DIMS:.+]] = arith.constant dense<[1, 2, 3, 0]> : tensor<4xi64>
   // HWCF: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[TRANSPOSE_DIMS]] : (tensor<28x3x3x27xf32>, tensor<4xi64>) -> tensor<3x3x27x28xf32>
-  // CHECK: %[[M_IN:.+]] = tensor.empty()
-  // CHECK: %[[CST:.+]] = arith.constant 0
-  // CHECK: %[[FILL:.+]] = linalg.fill
-  // CHECK: %[[B_IN:.+]] = tensor.empty()
-  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%[[FILL]] : tensor<1x45x40x28xf32>)
-  // HWCF: %[[CONV:.+]] = linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSE]] : tensor<1x49x42x27xf32>, tensor<3x3x27x28xf32>) outs(%{{[a-zA-Z0-9_]*}} : tensor<1x45x40x28xf32>
-  // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<1x45x40x28xf32>) outs(%[[B_IN]] : tensor<1x45x40x28xf32>)
-  // CHECK: arith.addf
+
+  // CHECK: %[[INIT:.+]] = tensor.empty() : tensor<1x45x40x28xf32>
+  // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xf32>) outs(%[[INIT]] : tensor<1x45x40x28xf32>) {
   // CHECK: linalg.yield
+  // CHECK: } -> tensor<1x45x40x28xf32>
+  // CHECK: linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%1 : tensor<1x45x40x28xf32>) -> tensor<1x45x40x28xf32>
+
+  // HWCF: linalg.conv_2d_nhwc_hwcf {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %[[TRANSPOSE]] : tensor<1x49x42x27xf32>, tensor<3x3x27x28xf32>) outs(%{{[a-zA-Z0-9_]*}} : tensor<1x45x40x28xf32>
   %0 = tosa.conv2d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>) -> tensor<1x45x40x28xf32>
   return
 }
@@ -421,16 +418,14 @@ func.func @conv2d_f32(%input: tensor<1x49x42x27xf32>, %weights: tensor<28x3x3x27
 
 // CHECK-LABEL: @conv2d_dyn
 func.func @conv2d_dyn(%input: tensor<?x49x42x27xf32>, %weights: tensor<28x3x3x27xf32>, %bias: tensor<28xf32>) -> () {
-  // CHECK: %[[C0:.+]] = arith.constant 0
-  // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]]
-  // CHECK: %[[M_IN:.+]] = tensor.empty(%[[BATCH]])
-  // CHECK: %[[CST:.+]] = arith.constant 0
-  // CHECK: %[[FILL:.+]] = linalg.fill
-  // CHECK: %[[B_IN:.+]] = tensor.empty(%[[BATCH]])
-  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%[[FILL]] : tensor<?x45x40x28xf32>)
-  // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<?x45x40x28xf32>) outs(%[[B_IN]] : tensor<?x45x40x28xf32>)
-  // CHECK: %[[ADD:.+]] = arith.addf
-  // CHECK: linalg.yield %[[ADD]] : f32
+  // CHECK: %[[C0:.+]] = arith.constant 0 : index
+  // CHECK: %[[BATCH:.+]] = tensor.dim %arg0, %[[C0]] : tensor<?x49x42x27xf32>
+  // CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x45x40x28xf32>
+  // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xf32>) outs(%[[INIT]] : tensor<?x45x40x28xf32>) {
+  // CHECK: ^bb0(%[[IN:.+]]: f32, %{{.+}}: f32):
+  // CHECK: linalg.yield %[[IN]] : f32
+  // CHECK: } -> tensor<?x45x40x28xf32>
+  // CHECK: %2 = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<?x49x42x27xf32>, tensor<28x3x3x27xf32>) outs(%[[BROADCAST]] : tensor<?x45x40x28xf32>) -> tensor<?x45x40x28xf32>
   %0 = tosa.conv2d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<?x49x42x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>) -> tensor<?x45x40x28xf32>
   return
 }
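The dynamic-batch case follows the same pattern; the only difference is that the broadcast destination is sized from the input at runtime, per the CHECK lines above (a sketch; SSA names are illustrative):

    %c0 = arith.constant 0 : index
    %batch = tensor.dim %input, %c0 : tensor<?x49x42x27xf32>
    %init = tensor.empty(%batch) : tensor<?x45x40x28xf32>
    // %init then feeds the broadcasting linalg.generic exactly as in the static case.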
@@ -481,14 +476,12 @@ func.func @conv2d_dyn_w_h(%input: tensor<1x?x?x27xf32>, %weights: tensor<28x3x3x
   // CHECK: %[[W_OUT:.+]] = arith.addi %[[DIVIDED_0]], %[[ONE_0]] : index
 
   // Running convolution
-  // CHECK: %[[M_IN:.+]] = tensor.empty(%[[H_OUT]], %[[W_OUT]])
-  // CHECK: %[[CST:.+]] = arith.constant 0
-  // CHECK: %[[FILL:.+]] = linalg.fill
-  // CHECK: %[[B_IN:.+]] = tensor.empty(%[[H_OUT]], %[[W_OUT]])
-  // CHECK: %[[CONV:.+]] = linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x?x?x27xf32>, tensor<28x3x3x27xf32>) outs(%[[FILL]] : tensor<1x?x?x28xf32>)
-  // CHECK: %[[B:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2, %[[CONV]] : tensor<28xf32>, tensor<1x?x?x28xf32>) outs(%[[B_IN]] : tensor<1x?x?x28xf32>)
-  // CHECK: %[[ADD:.+]] = arith.addf
-  // CHECK: linalg.yield %[[ADD]] : f32
+  // CHECK: %[[INIT:.+]] = tensor.empty(%[[H_OUT]], %[[W_OUT]]) : tensor<1x?x?x28xf32>
+  // CHECK: %[[BROADCAST:.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%arg2 : tensor<28xf32>) outs(%[[INIT]] : tensor<1x?x?x28xf32>) {
+  // CHECK: linalg.yield
+  // CHECK: } -> tensor<1x?x?x28xf32>
+  // CHECK: linalg.conv_2d_nhwc_fhwc {dilations = dense<[2, 1]> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<1x?x?x27xf32>, tensor<28x3x3x27xf32>) outs(%17 : tensor<1x?x?x28xf32>) -> tensor<1x?x?x28xf32>
+
   %0 = tosa.conv2d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 2, 1>} : (tensor<1x?x?x27xf32>, tensor<28x3x3x27xf32>, tensor<28xf32>) -> tensor<1x?x?x28xf32>
   return
 }
@@ -678,52 +671,52 @@ func.func @depthwise_conv2d_dyn_w_h(%arg0: tensor<2x?x?x3xf32>, %arg1: tensor<3x
 
 // -----
 
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d4)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+
 // CHECK-LABEL: @conv3d_f32
 func.func @conv3d_f32(%input: tensor<1x49x48x47x27xf32>, %weights: tensor<28x3x4x5x27xf32>, %bias: tensor<28xf32>) -> () {
-  // CHECK-DAG: %[[PERMS:.+]] = arith.constant dense<[1, 2, 3, 4, 0]>
-  // CHECK-DAG: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[PERMS]]
-  // CHECK-DAG: %[[EMPTY:.+]] = tensor.empty()
-  // CHECK-DAG: %[[ZERO:.+]] = arith.constant 0
-  // CHECK-DAG: %[[FILL:.+]] = linalg.fill ins(%[[ZERO]] : f32) outs(%[[EMPTY]] : tensor<1x47x45x43x28xf32>)
-  // CHECK-DAG: %[[EMPTY:.+]] = tensor.empty()
-  // CHECK-DAG: %[[CONV3D:.+]] = linalg.conv_3d_ndhwc_dhwcf
+  // CHECK-DAG: %[[PERMS:.+]] = arith.constant dense<[1, 2, 3, 4, 0]>
+  // CHECK-DAG: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[PERMS]]
+  // CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<1x47x45x43x28xf32>
+  // CHECK: %[[BROADCAST:.+]] = linalg.generic
+  // CHECK-SAME: {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
+  // CHECK-SAME: ins(%arg2 : tensor<28xf32>) outs(%1 : tensor<1x47x45x43x28xf32>) {
+  // CHECK: ^bb0(%[[IN:.+]]: f32, %[[OUT:.+]]: f32):
+  // CHECK: linalg.yield %[[IN]] : f32
+  // CHECK: } -> tensor<1x47x45x43x28xf32>
+  // CHECK: linalg.conv_3d_ndhwc_dhwcf
   // CHECK-SAME: {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
   // CHECK-SAME: ins(%arg0, %[[TRANSPOSE]] : tensor<1x49x48x47x27xf32>, tensor<3x4x5x27x28xf32>)
-  // CHECK-SAME: outs(%[[FILL]] : tensor<1x47x45x43x28xf32>) -> tensor<1x47x45x43x28xf32>
-  // CHECK: %[[GENERIC:.+]] = linalg.generic
-  // CHECK-SAME: {indexing_maps = [#map, #map1, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
-  // CHECK-SAME: ins(%arg2, %[[CONV3D]] : tensor<28xf32>, tensor<1x47x45x43x28xf32>)
-  // CHECK-SAME: outs(%[[EMPTY]] : tensor<1x47x45x43x28xf32>) {
-  // CHECK: ^bb0(%[[A1:.+]]: f32, %[[A2:.+]]: f32, %{{.+}}: f32):
-  // CHECK: %[[ADD:.+]] = arith.addf %[[A1]], %[[A2]] : f32
-  // CHECK: linalg.yield %[[ADD]]
+  // CHECK-SAME: outs(%[[BROADCAST]] : tensor<1x47x45x43x28xf32>) -> tensor<1x47x45x43x28xf32>
   %0 = tosa.conv3d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xf32>, tensor<28x3x4x5x27xf32>, tensor<28xf32>) -> tensor<1x47x45x43x28xf32>
   return
 }
 
 // -----
 
+// CHECK: #[[$MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d4)>
+// CHECK: #[[$MAP2:.+]] = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
+
 // CHECK-LABEL: @conv3d_i8
 func.func @conv3d_i8(%input: tensor<1x49x48x47x27xi8>, %weights: tensor<28x3x4x5x27xi8>, %bias: tensor<28xi32>) -> () {
-  // CHECK-DAG: %[[PERMS:.+]] = arith.constant dense<[1, 2, 3, 4, 0]>
-  // CHECK-DAG: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[PERMS]]
-  // CHECK-DAG: %[[EMPTY:.+]] = tensor.empty()
-  // CHECK-DAG: %[[ZERO:.+]] = arith.constant 0
-  // CHECK-DAG: %[[FILL:.+]] = linalg.fill ins(%[[ZERO]] : i32) outs(%[[EMPTY]] : tensor<1x47x45x43x28xi32>)
-  // CHECK-DAG: %[[EMPTY:.+]] = tensor.empty()
-  // CHECK-DAG: %[[IZP:.+]] = arith.constant -128 : i32
-  // CHECK-DAG: %[[FZP:.+]] = arith.constant 42 : i32
-  // CHECK-DAG: %[[CONV3D:.+]] = linalg.conv_3d_ndhwc_dhwcf_q
+  // CHECK-DAG: %[[PERMS:.+]] = arith.constant dense<[1, 2, 3, 4, 0]>
+  // CHECK-DAG: %[[TRANSPOSE:.+]] = tosa.transpose %arg1, %[[PERMS]]
+  // CHECK-DAG: %[[INIT:.+]] = tensor.empty() : tensor<1x47x45x43x28xi32>
+  // CHECK: %[[BROADCAST:.+]] = linalg.generic
+  // CHECK-SAME: {indexing_maps = [#[[$MAP1]], #[[$MAP2]]], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
+  // CHECK-SAME: ins(%arg2 : tensor<28xi32>)
+  // CHECK-SAME: outs(%[[INIT]] : tensor<1x47x45x43x28xi32>) {
+  // CHECK: ^bb0(%[[IN:.+]]: i32, %[[OUT:.+]]: i32):
+  // CHECK: linalg.yield %[[IN]] : i32
+  // CHECK: } -> tensor<1x47x45x43x28xi32>
+  // CHECK: %[[IZP:.+]] = arith.constant -128 : i32
+  // CHECK: %[[FZP:.+]] = arith.constant 42 : i32
+  // CHECK: linalg.conv_3d_ndhwc_dhwcf_q
   // CHECK-SAME: {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}
   // CHECK-SAME: ins(%arg0, %[[TRANSPOSE]], %[[IZP]], %[[FZP]] : tensor<1x49x48x47x27xi8>, tensor<3x4x5x27x28xi8>, i32, i32)
-  // CHECK-SAME: outs(%[[FILL]] : tensor<1x47x45x43x28xi32>) -> tensor<1x47x45x43x28xi32>
-  // CHECK: %[[GENERIC:.+]] = linalg.generic
-  // CHECK-SAME: {indexing_maps = [#map, #map1, #map1], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
-  // CHECK-SAME: ins(%arg2, %[[CONV3D]] : tensor<28xi32>, tensor<1x47x45x43x28xi32>)
-  // CHECK-SAME: outs(%[[EMPTY]] : tensor<1x47x45x43x28xi32>) {
-  // CHECK: ^bb0(%[[A1:.+]]: i32, %[[A2:.+]]: i32, %{{.+}}: i32):
-  // CHECK: %[[ADD:.+]] = arith.addi %[[A1]], %[[A2]] : i32
-  // CHECK: linalg.yield %[[ADD]]
+  // CHECK-SAME: outs(%[[BROADCAST]] : tensor<1x47x45x43x28xi32>) -> tensor<1x47x45x43x28xi32>
+
   %0 = tosa.conv3d %input, %weights, %bias {pad = array<i64: 0, 0, 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = -128, weight_zp = 42>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>} : (tensor<1x49x48x47x27xi8>, tensor<28x3x4x5x27xi8>, tensor<28xi32>) -> tensor<1x47x45x43x28xi32>
   return
 }
