Skip to content

Commit 8542680

Browse files
author
morelos
committed
[ET-VK][Ops] choose_qparams ops skeleton test framework
Skeleton framework needed to build out the choose_qparams.tensor and choose_qparams_per_token_asymmetric.default operators based on the CPU implementation. Differential Revision: [D76436870](https://our.internmc.facebook.com/intern/diff/D76436870/) ghstack-source-id: 289707204 Pull Request resolved: #11554
1 parent 9fb73e4 commit 8542680

File tree

2 files changed

+242
-0
lines changed

2 files changed

+242
-0
lines changed
Lines changed: 234 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,234 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <gtest/gtest.h>
10+
11+
#include <ATen/ATen.h>
12+
13+
#include <executorch/backends/vulkan/runtime/api/api.h>
14+
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
15+
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
16+
17+
#include <executorch/extension/aten_util/make_aten_functor_from_et_functor.h>
18+
#include <executorch/extension/kernel_util/make_boxed_from_unboxed_functor.h>
19+
20+
#include <cassert>
21+
#include <iostream>
22+
23+
namespace torch {
24+
namespace executor {
25+
namespace native {
26+
27+
// Forward declarations of the functions we're testing (implemented by the
// quantized CPU kernels; see //executorch/kernels/quantized/cpu:op_choose_qparams
// in targets.bzl).
//
// Out-variant: computes quantization parameters for the whole input tensor,
// writing into scale_out / zero_point_out and returning references to them.
std::tuple<Tensor&, Tensor&> choose_qparams_tensor_out(
    const Tensor& input,
    int64_t quant_min,
    int64_t quant_max,
    ET_UNUSED double eps,
    ScalarType dtype,
    Tensor& scale_out,
    Tensor& zero_point_out);

// Per-token asymmetric variant: one (scale, zero_point) pair per token —
// presumably a slice over the last input dimension, matching the output
// shapes built in choose_qparams_per_token_asymmetric_aten below; confirm
// against the CPU kernel.
std::tuple<Tensor&, Tensor&> choose_qparams_per_token_asymmetric_out(
    const Tensor& input,
    ScalarType dtype,
    Tensor& scale_out,
    Tensor& zero_point_out);
42+
43+
// Wrapper function for choose_qparams_tensor_out without context
44+
Tensor& choose_qparams_tensor_out_no_context(
45+
const Tensor& input,
46+
int64_t quant_min,
47+
int64_t quant_max,
48+
ET_UNUSED double eps,
49+
ScalarType dtype,
50+
Tensor& scale_out,
51+
Tensor& zero_point_out) {
52+
torch::executor::native::choose_qparams_tensor_out(
53+
input, quant_min, quant_max, eps, dtype, scale_out, zero_point_out);
54+
return scale_out;
55+
}
56+
57+
// Wrapper function for choose_qparams_per_token_asymmetric_out without context
58+
Tensor& choose_qparams_per_token_asymmetric_out_no_context(
59+
const Tensor& input,
60+
ScalarType dtype,
61+
Tensor& scale_out,
62+
Tensor& zero_point_out) {
63+
torch::executor::native::choose_qparams_per_token_asymmetric_out(
64+
input, dtype, scale_out, zero_point_out);
65+
return scale_out;
66+
}
67+
68+
// ATen wrapper for choose_qparams_tensor
69+
std::tuple<at::Tensor, at::Tensor> choose_qparams_tensor_aten(
70+
const at::Tensor& input,
71+
int64_t quant_min,
72+
int64_t quant_max,
73+
at::ScalarType dtype) {
74+
auto scale_out = at::empty({}, at::device(at::kCPU).dtype(at::kDouble));
75+
auto zero_point_out = at::empty({}, at::device(at::kCPU).dtype(at::kLong));
76+
double eps = 1e-7;
77+
78+
// Convert at::ScalarType to executorch::ScalarType
79+
ScalarType et_dtype;
80+
switch (dtype) {
81+
case at::kByte:
82+
et_dtype = ScalarType::Byte;
83+
break;
84+
case at::kChar:
85+
et_dtype = ScalarType::Char;
86+
break;
87+
case at::kShort:
88+
et_dtype = ScalarType::Short;
89+
break;
90+
case at::kInt:
91+
et_dtype = ScalarType::Int;
92+
break;
93+
case at::kLong:
94+
et_dtype = ScalarType::Long;
95+
break;
96+
case at::kFloat:
97+
et_dtype = ScalarType::Float;
98+
break;
99+
case at::kDouble:
100+
et_dtype = ScalarType::Double;
101+
break;
102+
default:
103+
throw std::runtime_error("Unsupported dtype");
104+
}
105+
106+
// Use WRAP_TO_ATEN with the wrapper function
107+
WRAP_TO_ATEN(choose_qparams_tensor_out_no_context, 5)
108+
(input, quant_min, quant_max, eps, et_dtype, scale_out, zero_point_out);
109+
110+
return {scale_out, zero_point_out};
111+
}
112+
113+
// ATen wrapper for choose_qparams_per_token_asymmetric
114+
std::tuple<at::Tensor, at::Tensor> choose_qparams_per_token_asymmetric_aten(
115+
const at::Tensor& input,
116+
at::ScalarType dtype) {
117+
// Calculate output sizes for scale and zero_point tensors
118+
std::vector<int64_t> output_sizes;
119+
for (int64_t i = 0; i < input.dim() - 1; i++) {
120+
output_sizes.push_back(input.size(i));
121+
}
122+
output_sizes.push_back(1);
123+
124+
auto scale_out =
125+
at::empty(output_sizes, at::device(at::kCPU).dtype(at::kDouble));
126+
auto zero_point_out =
127+
at::empty(output_sizes, at::device(at::kCPU).dtype(at::kLong));
128+
129+
// Convert at::ScalarType to executorch::ScalarType
130+
ScalarType et_dtype;
131+
switch (dtype) {
132+
case at::kByte:
133+
et_dtype = ScalarType::Byte;
134+
break;
135+
case at::kChar:
136+
et_dtype = ScalarType::Char;
137+
break;
138+
case at::kShort:
139+
et_dtype = ScalarType::Short;
140+
break;
141+
case at::kInt:
142+
et_dtype = ScalarType::Int;
143+
break;
144+
case at::kLong:
145+
et_dtype = ScalarType::Long;
146+
break;
147+
case at::kFloat:
148+
et_dtype = ScalarType::Float;
149+
break;
150+
case at::kDouble:
151+
et_dtype = ScalarType::Double;
152+
break;
153+
default:
154+
throw std::runtime_error("Unsupported dtype");
155+
}
156+
157+
// Use WRAP_TO_ATEN with the wrapper function
158+
WRAP_TO_ATEN(choose_qparams_per_token_asymmetric_out_no_context, 2)
159+
(input, et_dtype, scale_out, zero_point_out);
160+
161+
return {scale_out, zero_point_out};
162+
}
163+
164+
} // namespace native
165+
} // namespace executor
166+
} // namespace torch
167+
168+
//
169+
// Test functions
170+
//
171+
172+
// Helper function to get the name of a ScalarType for better error messages
173+
std::string scalar_type_name(c10::ScalarType dtype) {
174+
switch (dtype) {
175+
case c10::kLong:
176+
return "c10::kLong";
177+
case c10::kShort:
178+
return "c10::kShort";
179+
case c10::kComplexHalf:
180+
return "c10::kComplexHalf";
181+
case c10::kComplexFloat:
182+
return "c10::kComplexFloat";
183+
case c10::kComplexDouble:
184+
return "c10::kComplexDouble";
185+
case c10::kBool:
186+
return "c10::kBool";
187+
case c10::kQInt8:
188+
return "c10::kQInt8";
189+
case c10::kQUInt8:
190+
return "c10::kQUInt8";
191+
case c10::kQInt32:
192+
return "c10::kQInt32";
193+
case c10::kBFloat16:
194+
return "c10::kBFloat16";
195+
case c10::kQUInt4x2:
196+
return "c10::kQUInt4x2";
197+
case c10::kQUInt2x4:
198+
return "c10::kQUInt2x4";
199+
default:
200+
return "Unknown(" + std::to_string(static_cast<int>(dtype)) + ")";
201+
}
202+
}
203+
204+
// Maps an at/c10 ScalarType onto the corresponding Vulkan-compute ScalarType.
// Types the Vulkan backend cannot represent trigger VK_THROW with a
// human-readable dtype name.
vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) {
  switch (at_scalartype) {
    case c10::kFloat:
      return vkcompute::vkapi::kFloat;
    case c10::kHalf:
      return vkcompute::vkapi::kHalf;
    case c10::kInt:
      return vkcompute::vkapi::kInt;
    case c10::kLong:
      // No inherent vkapi::kLong; represent 64-bit ints as kInt.
      return vkcompute::vkapi::kInt;
    case c10::kChar:
      return vkcompute::vkapi::kChar;
    case c10::kByte:
      return vkcompute::vkapi::kByte;
    case c10::kDouble:
      return vkcompute::vkapi::kDouble;
    case c10::kShort:
      return vkcompute::vkapi::kShort;
    case c10::kUInt16:
      return vkcompute::vkapi::kUInt16;
    default:
      VK_THROW(
          "Unsupported at::ScalarType: ",
          scalar_type_name(at_scalartype),
          " (",
          static_cast<int>(at_scalartype),
          ")");
  }
}

backends/vulkan/test/op_tests/targets.bzl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,5 +164,13 @@ def define_common_targets(is_fbcode = False):
164164
"//executorch/extension/aten_util:aten_bridge",
165165
]
166166
)
167+
define_test_targets(
168+
"choose_qparams_test",
169+
extra_deps = [
170+
"//executorch/kernels/quantized/cpu:op_choose_qparams",
171+
"//executorch/extension/tensor:tensor",
172+
"//executorch/extension/aten_util:aten_bridge",
173+
]
174+
)
167175
define_test_targets("linear_weight_int4_test")
168176
define_test_targets("rotary_embedding_test")

0 commit comments

Comments
 (0)