Skip to content

[mlir][tosa]Fix Rescale shift attr data type #71084

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 10, 2024

Conversation

Tai78641
Copy link
Contributor

@Tai78641 Tai78641 commented Nov 2, 2023

Change Rescale shift attribute to be DenseI8ArrayAttr to match spec (instead of DenseI32ArrayAttr)

This replaces https://reviews.llvm.org/D157439

Change Rescale shift attribute to be DenseI8ArrayAttr to match spec
(instead of DenseI32ArrayAttr)

Signed-off-by: Tai Ly <[email protected]>
@llvmbot
Copy link
Member

llvmbot commented Nov 2, 2023

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir-tosa

Author: Tai Ly (Tai78641)

Changes

Change Rescale shift attribute to be DenseI8ArrayAttr to match spec (instead of DenseI32ArrayAttr)

This replaces https://reviews.llvm.org/D157439


Full diff: https://github.com/llvm/llvm-project/pull/71084.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+1-1)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+14-14)
  • (modified) mlir/test/Dialect/Tosa/ops.mlir (+2-2)
  • (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+1-1)
  • (modified) mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp (+3-2)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index 81b9e93c2095f57..128a2273f78506a 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -1840,7 +1840,7 @@ def Tosa_RescaleOp: Tosa_Op<"rescale", [Pure,
     I32Attr:$input_zp,
     I32Attr:$output_zp,
     DenseI32ArrayAttr:$multiplier,
-    DenseI32ArrayAttr:$shift,
+    DenseI8ArrayAttr:$shift,
     BoolAttr:$scale32,
     BoolAttr:$double_round,
     BoolAttr:$per_channel
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index aa53b366f6da684..ac14829755cadea 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -93,7 +93,7 @@ func.func @test_add_0d(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
   // CHECK:   linalg.yield [[ADDF]] : f32
   // CHECK: } -> tensor<f32>
   %0 = tosa.add %arg0, %arg1 : (tensor<f32>, tensor<f32>) -> tensor<f32>
-  
+
   // CHECK: return [[RESULT]] : tensor<f32>
   return %0 : tensor<f32>
 }
@@ -223,7 +223,7 @@ func.func @test_add_1d_broadcast_static_to_static(%arg0: tensor<1xf32>, %arg1: t
   // CHECK:   linalg.yield %[[VAL_4]] : f32
   // CHECK: } -> tensor<3xf32>
   %0 = tosa.add %arg0, %arg1 : (tensor<1xf32>, tensor<3xf32>) -> tensor<3xf32>
-  
+
   // CHECK: return %[[RESULT]] : tensor<3xf32>
   return %0 : tensor<3xf32>
 }
@@ -352,7 +352,7 @@ func.func @test_add_2d_different_ranks(%arg0: tensor<3x4xf32>, %arg1: tensor<2x3
   // CHECK:   linalg.yield %[[VAL_4]] : f32
   // CHECK: } -> tensor<2x3x4xf32>
   %0 = tosa.add %arg0, %arg1 : (tensor<3x4xf32>, tensor<2x3x4xf32>) -> tensor<2x3x4xf32>
-  
+
   // CHECK: return %[[RESULT]] : tensor<2x3x4xf32>
   return %0 : tensor<2x3x4xf32>
 }
@@ -1119,7 +1119,7 @@ func.func @rescale_i8(%arg0 : tensor<2xi8>) -> () {
   // CHECK-DAG: [[BOUNDED:%.+]] = arith.select [[MAXLT]], [[CMAX]], [[LOWER]]
   // CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
   // CHECK-DAG: linalg.yield [[TRUNC]]
-  %0 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i32: 15>, scale32 = false, double_round = false, per_channel = false} : (tensor<2xi8>) -> tensor<2xi8>
+  %0 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false} : (tensor<2xi8>) -> tensor<2xi8>
 
   // CHECK: [[C0:%.+]] = arith.constant 19689
   // CHECK: [[C1:%.+]] = arith.constant 15
@@ -1141,7 +1141,7 @@ func.func @rescale_i8(%arg0 : tensor<2xi8>) -> () {
   // CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
   // CHECK-DAG: [[CAST:%.+]] = builtin.unrealized_conversion_cast [[TRUNC]] : i8 to ui8
   // CHECK: linalg.yield [[CAST]]
-  %1 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i32: 15>, scale32 = false, double_round = false, per_channel = false} : (tensor<2xi8>) -> tensor<2xui8>
+  %1 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false} : (tensor<2xi8>) -> tensor<2xui8>
 
   // CHECK: return
   return
@@ -1158,13 +1158,13 @@ func.func @rescale_i8_dyn_batch(%arg0 : tensor<?x2xi8>) -> () {
   // CHECK: %[[BATCH:.+]] = tensor.dim %[[ARG0]], %[[C0]]
   // CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x2xi8>
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x2xi8>) outs(%[[INIT]] : tensor<?x2xi8>)
-  %0 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i32: 15>, scale32 = false, double_round = false, per_channel = false} : (tensor<?x2xi8>) -> tensor<?x2xi8>
+  %0 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false} : (tensor<?x2xi8>) -> tensor<?x2xi8>
 
   // CHECK: %[[C0:.+]] = arith.constant 0
   // CHECK: %[[BATCH:.+]] = tensor.dim %[[ARG0]], %[[C0]]
   // CHECK: %[[INIT:.+]] = tensor.empty(%[[BATCH]]) : tensor<?x2xui8>
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP0]], #[[$MAP0]]], iterator_types = ["parallel", "parallel"]} ins(%[[ARG0]] : tensor<?x2xi8>) outs(%[[INIT]] : tensor<?x2xui8>)
-  %1 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i32: 15>, scale32 = false, double_round = false, per_channel = false} : (tensor<?x2xi8>) -> tensor<?x2xui8>
+  %1 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false} : (tensor<?x2xi8>) -> tensor<?x2xui8>
 
   return
 }
@@ -1182,7 +1182,7 @@ func.func @rescale_dyn(%arg0 : tensor<1x?x?x32xi32>) -> () {
   // CHECK: %[[DIM2:.+]] = tensor.dim %[[ARG0]], %[[C2]]
   // CHECK: %[[INIT:.+]] = tensor.empty(%[[DIM1]], %[[DIM2]])
   // CHECK: [[GENERIC:%.+]] = linalg.generic {indexing_maps = [#[[$MAP1]], #[[$MAP1]]], iterator_types = ["parallel", "parallel", "parallel", "parallel"]} ins(%[[ARG0]] : tensor<1x?x?x32xi32>) outs(%[[INIT]] : tensor<1x?x?x32xi8>)
-  %0 = tosa.rescale %arg0 {double_round = true, input_zp = 0 : i32, multiplier = array<i32: 1376784203>, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array<i32: 38>} : (tensor<1x?x?x32xi32>) -> tensor<1x?x?x32xi8>
+  %0 = tosa.rescale %arg0 {double_round = true, input_zp = 0 : i32, multiplier = array<i32: 1376784203>, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array<i8: 38>} : (tensor<1x?x?x32xi32>) -> tensor<1x?x?x32xi8>
   return
 }
 
@@ -1213,7 +1213,7 @@ func.func @rescale_ui8(%arg0 : tensor<2xui8>) -> () {
   // CHECK-DAG: [[BOUNDED:%.+]] = arith.select [[MAXLT]], [[CMAX]], [[LOWER]]
   // CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
   // CHECK: linalg.yield [[TRUNC]]
-  %0 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i32: 15>, scale32 = false, double_round = false, per_channel = false} : (tensor<2xui8>) -> tensor<2xi8>
+  %0 = tosa.rescale %arg0 {input_zp = 17 : i32, output_zp = 22 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = false, double_round = false, per_channel = false} : (tensor<2xui8>) -> tensor<2xi8>
 
   return
 }
@@ -1245,7 +1245,7 @@ func.func @rescale_per_channel(%arg0 : tensor<3xi8>) -> (tensor<3xi8>) {
   // CHECK-DAG: [[BOUNDED:%.+]] = arith.select [[MAXLT]], [[CMAX]], [[LOWER]]
   // CHECK-DAG: [[TRUNC:%.+]] = arith.trunci [[BOUNDED]]
   // CHECK-DAG: linalg.yield [[TRUNC]]
-  %0 = tosa.rescale %arg0 {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = array<i32: 42, 43, 44>, shift = array<i32: 14, 15, 64>, scale32 = false, double_round = false, per_channel = false} : (tensor<3xi8>) -> tensor<3xi8>
+  %0 = tosa.rescale %arg0 {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = array<i32: 42, 43, 44>, shift = array<i8: 14, 15, 64>, scale32 = false, double_round = false, per_channel = false} : (tensor<3xi8>) -> tensor<3xi8>
 
   // CHECK: return [[GENERIC]]
   return %0 : tensor<3xi8>
@@ -1256,18 +1256,18 @@ func.func @rescale_per_channel(%arg0 : tensor<3xi8>) -> (tensor<3xi8>) {
 // CHECK-LABEL: @rescaleDoubleRound
 func.func @rescaleDoubleRound(%arg0 : tensor<2xi8>) -> (tensor<2xi8>) {
   // CHECK: linalg.generic
-  // CHECK: tosa.apply_scale 
+  // CHECK: tosa.apply_scale
   // CHECK-SAME:  {double_round = true}
-  %0 = tosa.rescale %arg0 {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = array<i32: 19689>, shift = array<i32: 33>, scale32 = true, double_round = true, per_channel = false} : (tensor<2xi8>) -> tensor<2xi8>
+  %0 = tosa.rescale %arg0 {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = array<i32: 19689>, shift = array<i8: 33>, scale32 = true, double_round = true, per_channel = false} : (tensor<2xi8>) -> tensor<2xi8>
   return %0 : tensor<2xi8>
 }
 
 // CHECK-LABEL: @rescaleUnnecessaryDoubleRound
 func.func @rescaleUnnecessaryDoubleRound(%arg0 : tensor<2xi8>) -> (tensor<2xi8>) {
   // CHECK: linalg.generic
-  // CHECK: tosa.apply_scale 
+  // CHECK: tosa.apply_scale
   // CHECK-SAME:  {double_round = false}
-  %0 = tosa.rescale %arg0 {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = array<i32: 19689>, shift = array<i32: 15>, scale32 = true, double_round = true, per_channel = false} : (tensor<2xi8>) -> tensor<2xi8>
+  %0 = tosa.rescale %arg0 {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = array<i32: 19689>, shift = array<i8: 15>, scale32 = true, double_round = true, per_channel = false} : (tensor<2xi8>) -> tensor<2xi8>
   return %0 : tensor<2xi8>
 }
 
diff --git a/mlir/test/Dialect/Tosa/ops.mlir b/mlir/test/Dialect/Tosa/ops.mlir
index a3e2b66e0305281..64c9386c3d26f5f 100644
--- a/mlir/test/Dialect/Tosa/ops.mlir
+++ b/mlir/test/Dialect/Tosa/ops.mlir
@@ -64,7 +64,7 @@ func.func @test_conv2d_q8xi4(%arg0: tensor<1x11x11x3xi8>) -> tensor<1x1x1x3xi8>
   %0 = "tosa.const"() {value = dense<0> : tensor<3x11x11x3xi4>} : () -> tensor<3x11x11x3xi4>
   %1 = "tosa.const"() {value = dense<[12, 23, 55]> : tensor<3xi32>} : () -> tensor<3xi32>
   %2 = "tosa.conv2d"(%arg0, %0, %1) {dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>, quantization_info = #tosa.conv_quant<input_zp = 0, weight_zp = 0>, stride = array<i64: 1, 1>} : (tensor<1x11x11x3xi8>, tensor<3x11x11x3xi4>, tensor<3xi32>) -> tensor<1x1x1x3xi32>
-  %3 = "tosa.rescale"(%2) {double_round = true, input_zp = 0 : i32, multiplier = array<i32: 2026291432, 1079222024, 1693132724>, output_zp = 27 : i32, per_channel = true, scale32 = true, shift = array<i32: 37, 36, 37>} : (tensor<1x1x1x3xi32>) -> tensor<1x1x1x3xi8>
+  %3 = "tosa.rescale"(%2) {double_round = true, input_zp = 0 : i32, multiplier = array<i32: 2026291432, 1079222024, 1693132724>, output_zp = 27 : i32, per_channel = true, scale32 = true, shift = array<i8: 37, 36, 37>} : (tensor<1x1x1x3xi32>) -> tensor<1x1x1x3xi8>
   return %3 : tensor<1x1x1x3xi8>
 }
 
@@ -562,7 +562,7 @@ func.func @test_cast3(%arg0: tensor<13x21x3xi32>) -> tensor<13x21x3x!quant.unifo
 // -----
 // CHECK-LABEL: rescale
 func.func @test_rescale(%arg0: tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>> {
-    %0 = tosa.rescale %arg0 {double_round = false, input_zp = 127 : i32, multiplier = array<i32: 1073741824>, output_zp = -1 : i32, per_channel = false, scale32 = true, shift = array<i32: 30>} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
+    %0 = tosa.rescale %arg0 {double_round = false, input_zp = 127 : i32, multiplier = array<i32: 1073741824>, output_zp = -1 : i32, per_channel = false, scale32 = true, shift = array<i8: 30>} : (tensor<13x21x3x!quant.uniform<u8:f32, 0.015655439347028732:127>>) -> tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
     return %0 : tensor<13x21x3x!quant.uniform<i8:f32, 0.015655439347028732:-1>>
 }
 
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 7af66ae1dbc90f0..ca30c8127ef0cf1 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -94,7 +94,7 @@ func.func @test_unary_i32(%arg0 : tensor<4xi32>) -> () {
   %5 = tosa.reverse %arg0 { axis = 0 : i32 } : (tensor<4xi32>) -> tensor<?xi32>
 
   // CHECK: tosa.rescale %arg0 {{.+}} : (tensor<4xi32>) -> tensor<4xi16>
-  %6 = tosa.rescale %arg0 {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = array<i32: 42, 43>, shift = array<i32: 14, 15>, scale32 = false, double_round = false, per_channel = false} : (tensor<4xi32>) -> tensor<*xi16>
+  %6 = tosa.rescale %arg0 {input_zp = 243 : i32, output_zp = 252 : i32, multiplier = array<i32: 42, 43>, shift = array<i8: 14, 15>, scale32 = false, double_round = false, per_channel = false} : (tensor<4xi32>) -> tensor<*xi16>
 
   // CHECK: tosa.identity %arg0 : (tensor<4xi32>) -> tensor<4xi32>
   %7 = tosa.identity %arg0 : (tensor<4xi32>) -> tensor<?xi32>
diff --git a/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp b/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
index 9642301e8111c1c..e5a3e2b6fccaa32 100644
--- a/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
+++ b/mlir/test/lib/Dialect/Tosa/TosaTestPasses.cpp
@@ -169,8 +169,9 @@ ConvertTosaConv2DOp::matchAndRewrite(Operation *op,
       op->getLoc(), outputType, newTosaConv2DOp.getResult(),
       rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(outputZp),
       rewriter.getDenseI32ArrayAttr({multiplier}),
-      rewriter.getDenseI32ArrayAttr({shift}), rewriter.getBoolAttr(true),
-      rewriter.getBoolAttr(true), rewriter.getBoolAttr(false));
+      rewriter.getDenseI8ArrayAttr({static_cast<int8_t>(shift)}),
+      rewriter.getBoolAttr(true), rewriter.getBoolAttr(true),
+      rewriter.getBoolAttr(false));
 
   rewriter.replaceOp(op, {newTosaRescaleOp.getResult()});
   return success();

Copy link
Contributor

@GeorgeARM GeorgeARM left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@jpienaar
Copy link
Member

jpienaar commented Jan 2, 2024

Note https://github.com/tensorflow/tensorflow/pull/61501 modulo one missing cast to int8_t should be ready to go if you give headsup.

@GeorgeARM GeorgeARM merged commit af78e5d into llvm:main Jan 10, 2024
hanhanW added a commit to iree-org/llvm-project that referenced this pull request Jan 22, 2024
MaheshRavishankar pushed a commit to iree-org/llvm-project that referenced this pull request Jan 23, 2024
MaheshRavishankar pushed a commit to iree-org/llvm-project that referenced this pull request Jan 23, 2024
MaheshRavishankar pushed a commit to iree-org/llvm-project that referenced this pull request Jan 23, 2024
MaheshRavishankar pushed a commit to iree-org/llvm-project that referenced this pull request Jan 23, 2024
MaheshRavishankar pushed a commit to iree-org/llvm-project that referenced this pull request Jan 23, 2024
MaheshRavishankar pushed a commit to iree-org/llvm-project that referenced this pull request Jan 23, 2024
MaheshRavishankar pushed a commit to iree-org/llvm-project that referenced this pull request Jan 23, 2024
MaheshRavishankar pushed a commit to iree-org/llvm-project that referenced this pull request Jan 23, 2024
MaheshRavishankar pushed a commit to iree-org/llvm-project that referenced this pull request Jan 23, 2024
MaheshRavishankar pushed a commit to iree-org/llvm-project that referenced this pull request Jan 23, 2024
MaheshRavishankar pushed a commit to iree-org/llvm-project that referenced this pull request Jan 24, 2024
MaheshRavishankar pushed a commit to iree-org/llvm-project that referenced this pull request Jan 26, 2024
justinfargnoli pushed a commit to justinfargnoli/llvm-project that referenced this pull request Jan 28, 2024
Change Rescale shift attribute to be DenseI8ArrayAttr to match spec
(instead of DenseI32ArrayAttr)

This replaces https://reviews.llvm.org/D157439

Signed-off-by: Tai Ly <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants