# Copyright 2024-2025 Arm Limited and/or its affiliates.
-# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx import GraphModule
+
from torch.library import impl, Library

lib = Library("tosa", "DEF")

@impl(lib, "_table")
def _table_impl(*args, **kwargs):  # pyre-ignore
-    return args[0]
+    in_dtype = args[0].dtype
+    if in_dtype == torch.int8:
+        return args[0]
+    return args[0].to(dtype=torch.int32)


class InsertTableOpsPass(ExportPass):
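The new `_table_impl` fake kernel mirrors the TOSA TABLE output widths: an INT8 table keeps an INT8 result, while wider inputs map to an INT32 result. A minimal standalone sketch of that dtype contract, using a hypothetical `table_meta` helper rather than the registered op:

```python
# Sketch of the dtype contract implemented by _table_impl above
# (plain torch only; not wired up to the "tosa" library here).
import torch

def table_meta(x: torch.Tensor) -> torch.Tensor:
    # int8 tables keep their dtype; wider inputs yield an int32 result,
    # matching TOSA TABLE's int16 -> int32 output.
    return x if x.dtype == torch.int8 else x.to(dtype=torch.int32)

assert table_meta(torch.zeros(4, dtype=torch.int8)).dtype == torch.int8
assert table_meta(torch.zeros(4, dtype=torch.int16)).dtype == torch.int32
```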
@@ -59,29 +62,89 @@ def register_buffer(self, buffer_name: str, buffer: torch.Tensor) -> None:
        """
        self.exported_program.state_dict[buffer_name] = buffer

-    def generate_table_values(
+    def generate_8bit_table_values(
        self,
        torch_op: Callable[[torch.Tensor], torch.Tensor],
        in_quantargs: QuantArgs,
        out_quantargs: QuantArgs,
-    ) -> torch.Tensor:
+    ) -> tuple[torch.Tensor, int]:
+        """Compute LUT values for an INT8 TOSA.TABLE. Also returns 0, since no shifting is required after an 8-bit table.
+        The INT8 table is a simple 256-value one-to-one LUT.
+        """
+
        def f(x: torch.Tensor) -> torch.Tensor:
            x = in_quantargs.dequantize_value(x)
            x = torch_op(x)
            return out_quantargs.quantize_value(x)

-        input_dtype = in_quantargs.dtype
-        steps = in_quantargs.qmax - in_quantargs.qmin + 1
-        return f(
+        return (
+            f(
+                torch.linspace(
+                    start=in_quantargs.qmin,
+                    end=in_quantargs.qmax,
+                    steps=256,
+                    # use torch.int64 to avoid overflow when dequantizing (subtracting zp).
+                    # e.g. torch.tensor(-50, dtype=torch.int8) - 100 == torch.tensor(106, dtype=torch.int8)
+                    dtype=torch.int64,
+                )
+            ).to(dtype=torch.int8),
+            0,
+        )
+
+    def generate_16_bit_table_values(
+        self,
+        torch_op: Callable[[torch.Tensor], torch.Tensor],
+        in_quantargs: QuantArgs,
+        out_quantargs: QuantArgs,
+    ) -> tuple[torch.Tensor, int]:
+        """Compute LUT values for an INT16 TOSA.TABLE with 32-bit output (in practice 23-bit, see the specification).
+        The output of the table has 7 fractional bits, which means it is interpreted as 128x too large
+        unless accounted for. Right-shift the table values to fit in 16 bits and return an lshift of
+        rshift - 7 to compensate for the fractional bits.
+        """
+
+        def f(x: torch.Tensor) -> torch.Tensor:
+            # Don't use the 7 LSBs.
+            x = in_quantargs.dequantize_value((x & ~0x7F))
+            x = torch_op(x)
+            return out_quantargs.quantize_value(x)
+
+        lut_values = f(
            torch.linspace(
                start=in_quantargs.qmin,
-                end=in_quantargs.qmax,
-                steps=steps,
+                end=in_quantargs.qmax + 1,
+                steps=513,
                # use torch.int64 to avoid overflow when dequantizing (subtracting zp).
                # e.g. torch.tensor(-50, dtype=torch.int8) - 100 == torch.tensor(106, dtype=torch.int8)
                dtype=torch.int64,
            )
-        ).to(dtype=input_dtype)
+        )
+        # Calculate how much we need to shift the table values to fit in 16 bits:
+        # ceil(log2(max absolute table value)) + 1 bit for signedness - 16.
+        # Note: for out_quantargs.dtype == torch.int16, rshift == 0.
+        rshift = int(torch.ceil(torch.log2(lut_values.abs().max()))) + 1 - 16
+        lut_values = lut_values >> rshift
+        return lut_values.to(dtype=torch.int16), rshift - 7
+
+    def generate_table_values(
+        self,
+        torch_op: Callable[[torch.Tensor], torch.Tensor],
+        in_quantargs: QuantArgs,
+        out_quantargs: QuantArgs,
+    ) -> tuple[torch.Tensor, int]:
+        match out_quantargs.dtype:
+            case torch.int8:
+                return self.generate_8bit_table_values(
+                    torch_op, in_quantargs, out_quantargs
+                )
+            case torch.int16 | torch.int32:
+                return self.generate_16_bit_table_values(
+                    torch_op, in_quantargs, out_quantargs
+                )
+            case _:
+                raise ValueError(
+                    f"Unsupported output dtype for table: {out_quantargs.dtype}"
+                )

    def call(self, graph_module: GraphModule) -> PassResult:
        modified = False
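For reference, a standalone sketch of how the 513-entry INT16 LUT above is sampled: the input range is walked in steps of 128 up to qmax + 1, and the 7 LSBs are masked off before dequantization so each sample lands on an exact table index. `build_int16_lut` and its scale/zero-point parameters are illustrative stand-ins for the QuantArgs plumbing in the pass:

```python
# Hedged sketch of the 16-bit LUT sampling, with simplified per-tensor
# quantization helpers standing in for QuantArgs.
import torch

def build_int16_lut(op, scale_in, zp_in, scale_out, zp_out):
    qmin, qmax = -32768, 32767

    def dequantize(x):  # stand-in for in_quantargs.dequantize_value
        return (x - zp_in) * scale_in

    def quantize(x):  # stand-in for out_quantargs.quantize_value
        return torch.round(x / scale_out + zp_out).to(torch.int64)

    # 513 sample points: qmin, qmin + 128, ..., qmax + 1 (one per table entry).
    x = torch.linspace(start=qmin, end=qmax + 1, steps=513, dtype=torch.int64)
    # Mask the 7 LSBs, as in f() above, so entries correspond to table indices.
    return quantize(op(dequantize(x & ~0x7F)))

lut = build_int16_lut(torch.sigmoid, 1 / 32768, 0, 1 / 32768, 0)
print(lut.shape)  # torch.Size([513])
```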
@@ -100,10 +163,12 @@ def call(self, graph_module: GraphModule) -> PassResult:
                    op_target=torch.ops.tosa._table.default,
                    args=(node.args[0],),
                )
+                output_node = table_node
                assert len(input_qparams) == 1
                assert len(output_qparams) == 1
-                # Generate table buffer
-                buffer = self.generate_table_values(
+
+                # Generate table buffer and how much to lshift the table output.
+                buffer, lshift = self.generate_table_values(
                    torch_op=self.table_ops[node.target],
                    in_quantargs=input_qparams[0],
                    out_quantargs=output_qparams[0],
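The `(buffer, lshift)` pair returned by `generate_table_values` drives the rescale inserted in the next hunk. A small worked example of that bookkeeping, mirroring the rshift formula above with a hypothetical `plan_shifts` helper (math.log2 stands in for the torch.log2 call):

```python
# Hedged sketch of the shift bookkeeping for the INT16 table path.
import math

def plan_shifts(max_abs_lut_value: int) -> tuple[int, int]:
    # Bits needed for the largest magnitude, plus one sign bit,
    # minus the 16 bits available per table entry.
    rshift = math.ceil(math.log2(max_abs_lut_value)) + 1 - 16
    # The table output carries 7 fractional bits (x128), so the
    # follow-up rescale multiplies by 2**lshift to undo both effects.
    lshift = rshift - 7
    return rshift, lshift

# int16 output range: largest LUT value ~2**15 - 1 -> no pre-shift needed.
assert plan_shifts(2**15 - 1) == (0, -7)
# int32 output range needing 24 bits: entries are pre-shifted right by 8,
# and the rescale after the table scales by 2**1 == 2.
assert plan_shifts(2**23 - 1) == (8, 1)
```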
@@ -114,10 +179,20 @@ def call(self, graph_module: GraphModule) -> PassResult:
                self.register_buffer(
                    buffer_name=table_node.name.replace("_default", ""), buffer=buffer
                )
-                node.replace_all_uses_with(table_node)
+
+                if lshift != 0:
+                    scale = 2.0**lshift
+                    rescale_node = create_node(
+                        graph=graph_module.graph,
+                        op_target=torch.ops.tosa._rescale.default,
+                        args=(table_node, output_qparams[0].dtype, scale, 0, 0),
+                    )
+                    output_node = rescale_node
+
+                node.replace_all_uses_with(output_node)
                graph_module.graph.erase_node(node)
-                table_node.meta["input_qparams"] = input_qparams
-                table_node.meta["output_qparams"] = output_qparams
+                output_node.meta["input_qparams"] = input_qparams
+                output_node.meta["output_qparams"] = output_qparams
                modified = True

        if modified:
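With this hunk, a lowered op becomes a `tosa._table` node followed by an optional `tosa._rescale` whenever `lshift != 0`. A rough numeric model of why `scale = 2.0**lshift` recovers the intended quantized value, assuming the INT16 table behavior described in the docstrings (output = entry * 128, entries pre-shifted right by rshift); `recovered_value` is purely illustrative and not part of the pass:

```python
# Hedged model of why scale == 2.0**lshift undoes both the pre-shift
# applied to the table entries and the 7 fractional bits of the output.
def recovered_value(q: int, rshift: int) -> float:
    entry = q >> rshift          # table entry stored after the right shift
    table_output = entry << 7    # modeled TABLE result: entry * 128
    lshift = rshift - 7
    scale = 2.0**lshift          # the _rescale scale inserted by the pass
    return table_output * scale  # approximately q again

assert recovered_value(q=20_000, rshift=0) == 20_000          # int16 case, scale = 1/128
assert recovered_value(q=5_000_000, rshift=8) == 4_999_936.0  # int32 case, small rounding from the shift
```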