Commit 8299fe3

Michael Gschwind authored and facebook-github-bot committed

Add mixed dtype linear (#2023)

Summary:
Pull Request resolved: #2023

Add mixed dtype linear

Reviewed By: mavlyutovr, manuelcandales

Differential Revision: D53995591

fbshipit-source-id: f68e2fd1254cb3717f2276eef9375c944cb99d60

1 parent: 57e192b

File tree

5 files changed: +192 -1 lines changed

examples/models/llama2/ops/quantized.yaml

Lines changed: 6 additions & 0 deletions

@@ -3,3 +3,9 @@
   kernels:
     - arg_meta: null
       kernel_name: torch::executor::quantized_embedding_byte_out
+
+- func: quantized_decomposed::mixed_linear.out(Tensor input, Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::quantized_mixed_linear_out

kernels/portable/cpu/vec_ops.h

Lines changed: 37 additions & 1 deletion

@@ -10,10 +10,12 @@

 #include <algorithm>
 #include <cmath>
+#include <cstdint>
 #include <cstring>
+#include <iostream>
 #include <numeric>
+#include <ostream>
 #include <type_traits>
-
 /**
  * @file
  * This header defines common, low-level operations that can often be

@@ -103,6 +105,40 @@ inline void vec_quantized_matmul_int8(
   }
 }

+static inline size_t bounds_min(size_t a, size_t b) {
+  return (a < b) ? a : b;
+}
+
+/// x: m * n, y: p * n, z: m * p, s: p * groups
+/// z[i][j] = sum(x[i][k] * y[j][k] * s[j][k/g])
+template <typename T, typename U = T, typename V = U>
+inline void vec_quantized_matmul_transb_int8(
+    T* __restrict__ z,
+    const U* __restrict__ x,
+    const int8_t* __restrict__ y,
+    const V* __restrict__ s,
+    int64_t m,
+    int64_t n,
+    int64_t p,
+    int64_t g) {
+  int64_t n_over_g = (n + g - 1) / g;
+
+  for (size_t i = 0; i < m; ++i) {
+    for (size_t j = 0; j < p; ++j) {
+      T sum = 0;
+      for (size_t k = 0; k < n; k += g) {
+        T psum = 0;
+        // the last group may have fewer than g elements
+        for (size_t k2 = k; k2 < bounds_min(k + g, n); k2++) {
+          psum += x[i * n + k2] * y[j * n + k2];
+        }
+        sum += psum * s[j * n_over_g + k / g];
+      }
+      z[i * p + j] = sum;
+    }
+  }
+}
+
 // mat1 (m x n), mat2 (n x p), out (m, p), self (m x p)
 // z[i][j] = sum(x[i][k] * y[k][j]), for k in range(n)
 // T for tensor dtype, U for scalar type
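For intuition, here is a minimal usage sketch of the new helper (not part of the commit) with m = 1, n = 4, p = 2, and group size g = 2, so each output channel carries two scales. The include path and namespace follow the header above; the arithmetic in the comments is worked out by hand.

#include <executorch/kernels/portable/cpu/vec_ops.h>

#include <cstdint>
#include <cstdio>

int main() {
  // x: 1 x 4 activations; y: 2 x 4 int8 weights, already transposed
  // (one row per output channel); s: 2 x 2 per-group scales, g = 2.
  const float x[4] = {1.0f, 2.0f, 3.0f, 4.0f};
  const int8_t y[8] = {1, -1, 2, 2, 0, 1, -2, 1};
  const float s[4] = {0.5f, 0.25f, 1.0f, 2.0f};
  float z[2] = {0.0f, 0.0f};

  torch::executor::vec_quantized_matmul_transb_int8<float>(
      z, x, y, s, /*m=*/1, /*n=*/4, /*p=*/2, /*g=*/2);

  // z[0] = (1*1 + 2*-1) * 0.5 + (3*2 + 4*2) * 0.25 = -0.5 + 3.5 = 3.0
  // z[1] = (1*0 + 2*1) * 1.0 + (3*-2 + 4*1) * 2.0 = 2.0 - 4.0 = -2.0
  std::printf("z = [%f, %f]\n", z[0], z[1]);
  return 0;
}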
kernels/quantized/cpu/op_mixed_linear.cpp

Lines changed: 137 additions & 0 deletions (new file)

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/kernels/portable/cpu/vec_ops.h>
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {

using Tensor = exec_aten::Tensor;

bool check_quantized_mixed_linear_args(
    const Tensor& in,
    const Tensor& weight,
    const Tensor& weight_scales,
    const optional<Tensor>& opt_weight_zero_points,
    const optional<ScalarType> dtype,
    Tensor& out) {
  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(in, 2));
  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(weight, 2));
  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(weight_scales, 1));
  ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(out, 2));

  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_size_at_dims(in, 1, weight, 1));
  ET_LOG_AND_RETURN_IF_FALSE(
      tensors_have_same_size_at_dims(weight_scales, 0, weight, 0));

  ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, weight_scales));
  if (dtype.has_value()) {
    ET_LOG_AND_RETURN_IF_FALSE(out.scalar_type() == dtype.value());
    ET_LOG_MSG_AND_RETURN_IF_FALSE(
        dtype.value() == ScalarType::Float || dtype.value() == ScalarType::Half,
        "dtype must be Float or Half");
  }
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      weight.scalar_type() == ScalarType::Char, "weight dtype must be int8");
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      in.scalar_type() == ScalarType::Float ||
          in.scalar_type() == ScalarType::Half,
      "input dtype must be Float or Half");

  if (opt_weight_zero_points.has_value()) {
    ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_shape(
        opt_weight_zero_points.value(), weight_scales));
    ET_LOG_AND_RETURN_IF_FALSE(
        tensors_have_same_dtype(opt_weight_zero_points.value(), in));
  }

  // Support for non-null zero points is not implemented yet.
  ET_LOG_MSG_AND_RETURN_IF_FALSE(
      !opt_weight_zero_points.has_value(), "zero points not supported yet.");
  return true;
}

Tensor& quantized_mixed_linear_out(
    const Tensor& in,
    const Tensor& weight,
    const Tensor& weight_scales,
    const optional<Tensor>& opt_weight_zero_points,
    const optional<ScalarType> dtype,
    Tensor& out) {
  // No kernel context is in scope in this overload, so fail hard with
  // ET_CHECK rather than ET_KERNEL_CHECK.
  ET_CHECK(check_quantized_mixed_linear_args(
      in, weight, weight_scales, opt_weight_zero_points, dtype, out));

  ScalarType out_dtype = dtype.has_value() ? dtype.value() : out.scalar_type();

  size_t output_ndim = 2;
  exec_aten::SizesType output_sizes[kTensorDimensionLimit];
  output_sizes[0] = in.size(0);
  output_sizes[1] = weight.size(0);

  ET_CHECK_MSG(
      resize_tensor(out, {output_sizes, output_ndim}) == Error::Ok,
      "Failed to resize out tensor.");

  constexpr auto name = "quantized_decomposed::mixed_linear.out";

  ET_SWITCH_TWO_TYPES(Float, Half, in.scalar_type(), ctx, name, CTYPE, [&]() {
    ET_SWITCH_FLOAT_TYPES_AND(Half, out_dtype, ctx, name, CTYPE_OUT, [&]() {
      size_t m = in.size(0);
      size_t n = in.size(1);
      size_t p = weight.size(0);
      // Default to per-channel scales: one group spanning the whole row.
      size_t g = n;

      if (weight_scales.dim() == 2) {
        // Group-wise scales: derive the group size from the number of
        // scale groups per output channel.
        g = (n + weight_scales.size(1) - 1) / weight_scales.size(1);
      }

      // FIXME: this currently ignores dtype
      vec_quantized_matmul_transb_int8<
          CTYPE_OUT, // T *z
          CTYPE>( // U *x, U *s
          out.mutable_data_ptr<CTYPE_OUT>(),
          in.const_data_ptr<CTYPE>(),
          weight.const_data_ptr<int8_t>(),
          weight_scales.const_data_ptr<CTYPE>(),
          m,
          n,
          p,
          g);
    });
  });

  return out;
}

Tensor& quantized_mixed_linear_out(
    RuntimeContext& ctx,
    const Tensor& in,
    const Tensor& weight,
    const Tensor& weight_scales,
    const optional<Tensor>& opt_weight_zero_points,
    const optional<ScalarType> dtype,
    Tensor& out) {
  // TODO(mcandales): Remove the need for this wrapper
  // TODO(mkg): add support for dtype
  (void)ctx;
  return quantized_mixed_linear_out(
      in, weight, weight_scales, opt_weight_zero_points, dtype, out);
}

} // namespace native
} // namespace executor
} // namespace torch
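To make the semantics concrete: in the per-channel case that the argument checks currently allow (rank-1 weight_scales, so g = n), the op is equivalent to scaling each int8 weight row by its single scale and applying an ordinary linear. Below is a standalone reference sketch of that computation; it is not part of the commit, and the helper name and row-major layout are illustrative assumptions.

#include <cstdint>
#include <vector>

// out[i][j] = sum_k in[i][k] * weight[j][k] * scales[j], matching
// vec_quantized_matmul_transb_int8 with g = n (one group per channel).
std::vector<float> mixed_linear_reference(
    const std::vector<float>& in, // m x n, row-major
    const std::vector<int8_t>& weight, // p x n, row-major
    const std::vector<float>& scales, // p
    int64_t m,
    int64_t n,
    int64_t p) {
  std::vector<float> out(m * p, 0.0f);
  for (int64_t i = 0; i < m; ++i) {
    for (int64_t j = 0; j < p; ++j) {
      float acc = 0.0f;
      for (int64_t k = 0; k < n; ++k) {
        acc += in[i * n + k] * static_cast<float>(weight[j * n + k]);
      }
      // One scale per output channel in the per-channel case.
      out[i * p + j] = acc * scales[j];
    }
  }
  return out;
}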

kernels/quantized/cpu/targets.bzl

Lines changed: 6 additions & 0 deletions

@@ -29,6 +29,12 @@ _QUANT_OPS = (
             "//executorch/kernels/portable/cpu:vec_ops",
         ],
     ),
+    op_target(
+        name = "op_mixed_linear",
+        deps = [
+            "//executorch/kernels/portable/cpu:vec_ops",
+        ],
+    ),
     op_target(
         name = "op_quantize",
         deps = [

kernels/quantized/quantized.yaml

Lines changed: 6 additions & 0 deletions

@@ -46,6 +46,12 @@
     - arg_meta: null
       kernel_name: torch::executor::quantized_mixed_mm_out

+- func: quantized_decomposed::mixed_linear.out(Tensor input, Tensor weight, Tensor weight_scales, Tensor? weight_zero_points, ScalarType? dtype=None, *, Tensor(a!) out) -> Tensor(a!)
+  variants: function
+  kernels:
+    - arg_meta: null
+      kernel_name: torch::executor::quantized_mixed_linear_out
+
 - func: quantized_decomposed::quantize_per_tensor.out(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype, *, Tensor(a!) out) -> Tensor(a!)
   variants: function
   kernels:
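For orientation, a direct call to the registered kernel from a test could look like the sketch below. It assumes ExecuTorch's TensorFactory test utility and the generated NativeFunctions.h header; neither is part of this commit, the header path is an assumption, and call_mixed_linear is a hypothetical name.

#include <executorch/kernels/quantized/NativeFunctions.h> // assumed generated header
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>

using namespace torch::executor;

void call_mixed_linear() {
  testing::TensorFactory<ScalarType::Float> tf;
  testing::TensorFactory<ScalarType::Char> tw;

  // 1 x 4 float input, 2 x 4 int8 weight, one scale per output channel.
  Tensor in = tf.make({1, 4}, {1.0, 2.0, 3.0, 4.0});
  Tensor weight = tw.make({2, 4}, {1, -1, 2, 2, 0, 1, -2, 1});
  Tensor scales = tf.make({2}, {0.5, 1.0});
  Tensor out = tf.zeros({1, 2});

  RuntimeContext ctx{};
  // Zero points are not supported yet, so pass an empty optional.
  native::quantized_mixed_linear_out(
      ctx,
      in,
      weight,
      scales,
      exec_aten::optional<Tensor>(),
      ScalarType::Float,
      out);
  // Expected: out = [[6.5, 0.0]]
}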
