
Commit 2bf8dba

enable int8 LSTM on latest cpu-device (#692)
* int8 lstm graph rewrite
* integrate oneDNN int8 lstm
* use scale and zp of input to be those of output for lstm
* add UT for int8 lstm
* rename var to maybe_quantized_lstm
* add assertion for input scalar type
* only get input scalar type once
* add doxygen spec for pack_qlstm_weight
* add doxygen spec for quantized_lstm
* use inline utils function to get scale and zero point of input and weight
1 parent c28e621 commit 2bf8dba

12 files changed: +597 −67 lines changed


intel_extension_for_pytorch/csrc/aten/cpu/RNN.cpp

Lines changed: 280 additions & 46 deletions
Large diffs are not rendered by default.

intel_extension_for_pytorch/csrc/aten/cpu/RNN.h

Lines changed: 15 additions & 15 deletions
@@ -40,7 +40,10 @@ class IPEXLSTMOp : public torch::autograd::Function<IPEXLSTMOp> {
       bool has_biases,
       bool bidirectional,
       bool batch_first,
-      bool train);
+      bool train,
+      double scale,
+      int64_t zp,
+      int64_t dtype);
   static std::vector<at::Tensor> forward(
       torch::autograd::AutogradContext* ctx,
       const at::Tensor& input,
@@ -58,7 +61,10 @@ class IPEXLSTMOp : public torch::autograd::Function<IPEXLSTMOp> {
       bool has_biases,
       bool bidirectional,
       bool batch_first,
-      bool train);
+      bool train,
+      double scale,
+      int64_t zp,
+      int64_t dtype);

   static torch::autograd::tensor_list backward(
       torch::autograd::AutogradContext* ctx,
@@ -81,7 +87,10 @@ std::vector<at::Tensor> ipex_lstm_layer(
     bool has_biases,
     bool bidirectional,
     bool batch_first,
-    bool train);
+    bool train,
+    double scale,
+    int64_t zp,
+    int64_t dtype);

 std::vector<at::Tensor> ipex_lstm_layer_backward(
     const at::Tensor& input,
@@ -124,18 +133,9 @@ std::vector<at::Tensor> ipex_lstm_layer_forward(
     bool has_biases,
     bool bidirectional,
     bool batch_first,
-    bool train);
-
-static std::tuple<at::Tensor, at::Tensor, at::Tensor> ipex_lstm(
-    const at::Tensor& input,
-    std::vector<at::Tensor> hx,
-    std::vector<at::Tensor> params,
-    bool has_biases,
-    int64_t num_layers,
-    double dropout_p,
     bool train,
-    bool bidirectional,
-    bool batch_first);
-
+    double scale,
+    int64_t zp,
+    int64_t dtype);
 } // namespace cpu
 } // namespace torch_ipex

intel_extension_for_pytorch/csrc/cpu/ideep/IDeepConversions.cpp

Lines changed: 4 additions & 2 deletions
@@ -94,8 +94,10 @@ ideep::tensor itensor_view_from_dense(
       "itensor_view_from_dense expects dense tensor input");
   TORCH_CHECK(
       tensor.scalar_type() == at::ScalarType::Float ||
-          tensor.scalar_type() == at::ScalarType::BFloat16,
-      "itensor_view_from_dense expects float or bfloat16 tensor input");
+          tensor.scalar_type() == at::ScalarType::BFloat16 ||
+          tensor.scalar_type() == at::ScalarType::QInt8 ||
+          tensor.scalar_type() == at::ScalarType::QUInt8,
+      "itensor_view_from_dense expects float, bfloat16 or int8 tensor input");
   return {desc, tensor.data_ptr()};
 }

intel_extension_for_pytorch/csrc/cpu/ideep/ideep/operators/lstm.hpp

Lines changed: 15 additions & 3 deletions
@@ -18,6 +18,10 @@ struct lstm_forward_inference : public dnnl::lstm_forward {
       tensor& dst_iter_c,
       const bool reverse = false,
       const prop_kind aprop = prop_kind::forward_inference,
+      const float scale = -1.,
+      const int32_t zp = -1,
+      const int weights_scale_mask = -1,
+      const std::vector<float>& weights_scales = scale_t(),
       const engine& aengine = engine::cpu_engine()) {
     auto direction = reverse ? rnn_direction::unidirectional_right2left
                              : rnn_direction::unidirectional_left2right;
@@ -30,13 +34,21 @@ struct lstm_forward_inference : public dnnl::lstm_forward {
     auto weights_layer_desc = weights_layer.get_desc().to_format_any();
     auto weights_iter_desc = weights_iter.get_desc().to_format_any();

+    attr_t op_attr;
+    if (src_layer.get_data_type() == data_type::u8) {
+      weights_layer_desc = weights_layer_desc.to_type(data_type::s8);
+      weights_iter_desc = weights_iter_desc.to_type(data_type::s8);
+
+      op_attr.set_rnn_data_qparams(scale, zp);
+      op_attr.set_rnn_weights_qparams(weights_scale_mask, weights_scales);
+    }
+
     auto bias_desc = bias.get_desc();
     auto dst_layer_desc = dst_layer.get_desc();
     auto dst_iter_desc = dst_iter.get_desc();
     auto dst_iter_c_desc = dst_iter_c.get_desc();

     // Use user mode scratchpad
-    auto op_attr = dnnl::primitive_attr();
     op_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);

     auto pd = primitive_desc(
@@ -55,9 +67,9 @@ struct lstm_forward_inference : public dnnl::lstm_forward {
         aengine);

     auto expected_weights_layer =
-        weights_layer.reorder_if_differ_in(pd.weights_layer_desc());
+        weights_layer.reorder_if_differ_in(pd.weights_layer_desc(), op_attr);
     auto expected_weights_iter =
-        weights_iter.reorder_if_differ_in(pd.weights_iter_desc());
+        weights_iter.reorder_if_differ_in(pd.weights_iter_desc(), op_attr);
     tensor scratchpad(pd.scratchpad_desc());

     super(pd).execute(
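Note: the hunk above wires calibration results into oneDNN through the primitive attributes. As a minimal, self-contained sketch (assuming `scale`, `zp`, `weights_scale_mask`, and `weights_scales` come from calibration, as in the diff), the attribute setup looks like this; the same attr must reach both the primitive descriptor and the weight reorders so the f32-to-s8 reorder quantizes with the scales the primitive expects:

// Hedged sketch, not part of the diff: assembling the dnnl::primitive_attr
// for an int8 LSTM. All four parameters are assumed to come from calibration.
#include <vector>
#include "dnnl.hpp"

dnnl::primitive_attr make_int8_lstm_attr(
    float scale,                               // data scale for the u8 source
    float zp,                                  // data shift (zero point)
    int weights_scale_mask,                    // oneDNN scaling-mask convention
    const std::vector<float>& weights_scales) {
  dnnl::primitive_attr attr;
  attr.set_rnn_data_qparams(scale, zp);        // qparams of the u8 src/dst data
  attr.set_rnn_weights_qparams(weights_scale_mask, weights_scales);
  attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
  // Passing this same attr to reorder_if_differ_in(), as the hunk does,
  // keeps the quantized weights consistent with the primitive.
  return attr;
}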

intel_extension_for_pytorch/csrc/jit/codegen/onednn/quantization_patterns.h

Lines changed: 4 additions & 0 deletions
@@ -1,5 +1,6 @@
 #include <torch/csrc/jit/ir/ir.h>
 #include <torch/csrc/jit/ir/subgraph_matcher.h>
+#include <torch/csrc/jit/jit_log.h>
 #include <torch/csrc/jit/passes/subgraph_rewrite.h>
 #include <string>
 #include "csrc/jit/cpu/passes/graph_rewrite.h"
@@ -70,8 +71,11 @@ void IpexQuantFusion(std::shared_ptr<Graph>& graph) {
     rewriter.RegisterRewritePattern(info.pattern, info.replacement);
     rewriter.runOnGraph(graph, info.filters);
   }
+  GRAPH_DUMP("Before IpexQuantFusion", graph);
   graph_rewrite::replaceEmbeddingBagWithQEmbeddingBag(graph);
   graph_rewrite::replaceInteractionWithQInteraction(graph);
+  graph_rewrite::replaceLstmWithQLstm(graph);
+  GRAPH_DUMP("After IpexQuantFusion", graph);
 }

 } // namespace jit
intel_extension_for_pytorch/csrc/jit/cpu/kernels/RNN.h (new file)

Lines changed: 81 additions & 0 deletions
@@ -0,0 +1,81 @@
+#pragma once
+
+#include <ATen/Tensor.h>
+
+#include <c10/core/Scalar.h>
+#include <torch/csrc/jit/runtime/custom_operator.h>
+
+#include "csrc/cpu/ideep/ideep.hpp"
+
+namespace torch_ipex {
+namespace cpu {
+
+//! function: quantized_lstm
+/*!
+ * Compute a quantized LSTM for INT8 input, INT8 weights and FP32 initial
+ * hidden and cell states; returns INT8 output along with the FP32 final
+ * hidden and cell states.
+ * \param input: INT8 tensor of shape :math:`(L, N, H_{in})` when
+ *    ``batch_first=False`` or :math:`(N, L, H_{in})` when ``batch_first=True``,
+ *    containing the features of the input sequence.
+ * \param hx: list of the FP32 initial hidden state and cell state:
+ *    hx[0]: FP32 tensor of shape :math:`(D * \text{num\_layers}, N, H_{out})`
+ *    containing the initial hidden state for the input sequence batch.
+ *    hx[1]: FP32 tensor of shape :math:`(D * \text{num\_layers}, N, H_{out})`
+ *    containing the initial cell state for the input sequence batch.
+ * \param weights: list of INT8 weights and FP32 biases.
+ * \param has_biases: if ``False``, the layer does not use bias weights
+ *    `b_ih` and `b_hh`.
+ * \param num_layers: the number of LSTM layers.
+ * \param dropout_p: if non-zero, introduces a `Dropout` layer on the outputs
+ *    of each RNN layer except the last layer, with dropout probability equal
+ *    to :attr:`dropout` when the model is in training state.
+ * \param train: whether the model is in training state.
+ * \param bidirectional: if ``True``, becomes a bidirectional LSTM.
+ * \param batch_first: if ``True``, the input and output tensors are provided
+ *    as `(batch, seq, feature)` instead of `(seq, batch, feature)`. Note that
+ *    this does not apply to hidden or cell states.
+ * \param scale: the calibration scale of the output, as a double.
+ * \param zp: the calibration zero point of the output, as an int64_t.
+ * \param dtype: the calibration data type of the output.
+ * \return: tuple of output tensors:
+ *    output[0]: INT8 tensor of shape :math:`(L, N, D * H_{out})` when
+ *    ``batch_first=False`` or :math:`(N, L, D * H_{out})` when
+ *    ``batch_first=True``, containing the output features `(h_t)` from the
+ *    last layer of the LSTM, for each `t`.
+ *    output[1]: FP32 tensor of shape :math:`(D * \text{num\_layers}, N, H_{out})`
+ *    containing the final hidden state for each element in the batch.
+ *    output[2]: FP32 tensor of shape :math:`(D * \text{num\_layers}, N, H_{out})`
+ *    containing the final cell state for each element in the batch.
+ * where:
+ *
+ * .. math::
+ *    \begin{aligned}
+ *        N ={} & \text{batch size} \\
+ *        L ={} & \text{sequence length} \\
+ *        D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\
+ *        H_{in} ={} & \text{input\_size} \\
+ *        H_{out} ={} & \text{hidden\_size}
+ *    \end{aligned}
+ */
+std::tuple<at::Tensor, at::Tensor, at::Tensor> quantized_lstm(
+    const at::Tensor& input,
+    c10::List<at::Tensor> hx,
+    c10::List<at::Tensor> weights,
+    bool has_biases,
+    int64_t num_layers,
+    double dropout_p,
+    bool train,
+    bool bidirectional,
+    bool batch_first,
+    double scale,
+    int64_t zp,
+    int64_t dtype);
+
+} // namespace cpu
+} // namespace torch_ipex
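Note: a hypothetical call sketch for the declaration above; all tensor names and shapes are illustrative, not taken from the commit. Single-layer, unidirectional LSTM with L=4, N=2, H_in=8, H_out=16:

// Hypothetical usage sketch (assumes the header above is included).
#include <tuple>
#include <ATen/ATen.h>
#include <c10/util/List.h>
// #include "csrc/jit/cpu/kernels/RNN.h"   // the header declared above

std::tuple<at::Tensor, at::Tensor, at::Tensor> run_quantized_lstm(
    const at::Tensor& input,       // quint8, shape (L=4, N=2, H_in=8)
    c10::List<at::Tensor> hx,      // {h0, c0}: fp32, each (1, 2, 16)
    c10::List<at::Tensor> weights, // {w_ih qint8, w_hh qint8, b_ih fp32, b_hh fp32}
    double out_scale,              // calibration scale of the output
    int64_t out_zp) {              // calibration zero point of the output
  return torch_ipex::cpu::quantized_lstm(
      input, hx, weights,
      /*has_biases=*/true,
      /*num_layers=*/1,
      /*dropout_p=*/0.0,
      /*train=*/false,
      /*bidirectional=*/false,
      /*batch_first=*/false,
      out_scale, out_zp,
      /*dtype=*/static_cast<int64_t>(at::ScalarType::QUInt8));
  // returns {output quint8 (4, 2, 16), hy fp32 (1, 2, 16), cy fp32 (1, 2, 16)}
}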

intel_extension_for_pytorch/csrc/jit/cpu/passes/graph_rewrite.cpp

Lines changed: 72 additions & 0 deletions
@@ -453,6 +453,78 @@ void replaceInteractionWithQInteraction(std::shared_ptr<Graph>& graph) {
   }
 }

+void replaceLstmWithQLstm(std::shared_ptr<Graph>& graph) {
+  std::vector<std::string> patterns;
+  std::vector<std::string> replacements;
+
+  for (auto* n : graph->block()->nodes()) {
+    if (n->kind() == aten::lstm) {
+      std::string weight_pattern = "";
+      std::vector<std::string> ListConstruct;
+      std::vector<std::string> header;
+
+      size_t id = 0;
+      auto weights_ListConstructNode = n->input(2)->node();
+
+      bool maybe_quantized_lstm = std::any_of(
+          weights_ListConstructNode->inputs().begin(),
+          weights_ListConstructNode->inputs().end(),
+          [](auto& v) {
+            return v->node()->kind() == Symbol::aten("dequantize");
+          });
+
+      if (!maybe_quantized_lstm)
+        return;
+
+      for (auto input : weights_ListConstructNode->inputs()) {
+        if (input->node()->kind() == Symbol::aten("dequantize")) {
+          std::string dequant = "%dq_out_" + std::to_string(id) +
+              " : Tensor = aten::dequantize(" + "%dq_in_" + std::to_string(id) +
+              ")";
+          weight_pattern.append(dequant);
+
+          header.push_back("%dq_in_" + std::to_string(id));
+          ListConstruct.push_back("%dq_out_" + std::to_string(id));
+        } else {
+          header.push_back("%bias_in_" + std::to_string(id));
+          ListConstruct.push_back("%bias_in_" + std::to_string(id));
+        }
+        ++id;
+      }
+
+      std::string complete_header =
+          "graph(%quantized_input, %h, %has_biases, %num_layers, %dropout_p, %train, %bidirectional, %batch_fist, %scale, %zp, %dtype," +
+          c10::Join(", ", header) + R"(
+        ): )";
+      std::string complete_LC = "%weights = prim::ListConstruct(" +
+          c10::Join(", ", ListConstruct) + ")";
+
+      std::string QLstmPattern = complete_header + R"(
+        %input : Tensor = aten::dequantize(%quantized_input) )" +
+          weight_pattern + complete_LC + R"(
+        %output, %hy, %cy = aten::lstm(%input, %h, %weights, %has_biases, %num_layers, %dropout_p, %train, %bidirectional, %batch_fist)
+        %quantized_output = aten::quantize_per_tensor(%output, %scale, %zp, %dtype)
+        return (%quantized_output, %hy, %cy) )";
+
+      std::string QLstmReplacement = complete_header + R"(
+        %quantized_weights : Tensor[] = prim::ListConstruct( )" +
+          c10::Join(", ", header) + R"(
+        )
+        %quantized_output, %hy, %cy = ipex::quantized_lstm(%quantized_input, %h, %quantized_weights, %has_biases, %num_layers, %dropout_p, %train, %bidirectional, %batch_fist, %scale, %zp, %dtype)
+        return (%quantized_output, %hy, %cy) )";
+
+      patterns.push_back(QLstmPattern);
+      replacements.push_back(QLstmReplacement);
+    }
+  }
+
+  SubgraphRewriter rewriter;
+  for (size_t i = 0; i < patterns.size(); i++) {
+    rewriter.RegisterRewritePattern(patterns[i], replacements[i]);
+    rewriter.runOnGraph(graph);
+  }
+}
+
 void fuseBmmAdd(std::shared_ptr<Graph>& graph) {
   std::array<std::string, 2> add_operators = {"add", "add_"};

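Note: to make the string assembly concrete, for a single-layer unidirectional LSTM with biases the weights `prim::ListConstruct` has four inputs (two dequantized weights, two fp32 biases), so the loop generates roughly this pattern (whitespace simplified; `%batch_fist` is spelled as in the source):

graph(%quantized_input, %h, %has_biases, %num_layers, %dropout_p, %train,
      %bidirectional, %batch_fist, %scale, %zp, %dtype,
      %dq_in_0, %dq_in_1, %bias_in_2, %bias_in_3):
  %input : Tensor = aten::dequantize(%quantized_input)
  %dq_out_0 : Tensor = aten::dequantize(%dq_in_0)
  %dq_out_1 : Tensor = aten::dequantize(%dq_in_1)
  %weights = prim::ListConstruct(%dq_out_0, %dq_out_1, %bias_in_2, %bias_in_3)
  %output, %hy, %cy = aten::lstm(%input, %h, %weights, %has_biases, %num_layers, %dropout_p, %train, %bidirectional, %batch_fist)
  %quantized_output = aten::quantize_per_tensor(%output, %scale, %zp, %dtype)
  return (%quantized_output, %hy, %cy)

The replacement keeps the same header but forwards the still-quantized weights directly to the fused op:

  %quantized_weights : Tensor[] = prim::ListConstruct(%dq_in_0, %dq_in_1, %bias_in_2, %bias_in_3)
  %quantized_output, %hy, %cy = ipex::quantized_lstm(%quantized_input, %h, %quantized_weights, %has_biases, %num_layers, %dropout_p, %train, %bidirectional, %batch_fist, %scale, %zp, %dtype)
  return (%quantized_output, %hy, %cy)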
intel_extension_for_pytorch/csrc/jit/cpu/passes/graph_rewrite.h

Lines changed: 1 addition & 0 deletions
@@ -36,6 +36,7 @@ void replaceAtenBatchNormWithIpexBatchNorm(std::shared_ptr<Graph>& graph);
 void replaceAtenLayerNormWithIpexLayerNorm(std::shared_ptr<Graph>& graph);
 void replaceEmbeddingBagWithQEmbeddingBag(std::shared_ptr<Graph>& graph);
 void replaceInteractionWithQInteraction(std::shared_ptr<Graph>& graph);
+void replaceLstmWithQLstm(std::shared_ptr<Graph>& graph);

 void replaceFrozenIPEXConvWithAtenConv(std::shared_ptr<Graph>& graph);
 void insertPrePackedConvOp(std::shared_ptr<Graph>& graph);

intel_extension_for_pytorch/csrc/jit/cpu/passes/register_dnnl_jit_ops.cpp

Lines changed: 28 additions & 0 deletions
@@ -16,6 +16,7 @@
 #include "csrc/jit/cpu/kernels/MaxPool2D.h"
 #include "csrc/jit/cpu/kernels/Mha.h"
 #include "csrc/jit/cpu/kernels/OpContext.h"
+#include "csrc/jit/cpu/kernels/RNN.h"
 #include "csrc/jit/cpu/kernels/Shuffle.h"
 #include "csrc/jit/cpu/kernels/Softmax.h"

@@ -767,6 +768,33 @@ RegisterOperators op({
         },
         aliasAnalysisFromSchema()),

+    Operator(
+        "ipex::quantized_lstm(Tensor quantized_input, Tensor[] hx, Tensor[] quantized_weights, bool has_biases, int num_layers, float dropout_p, bool train, bool bidirectional, bool batch_first, float scale, int zp, int dtype) -> (Tensor, Tensor, Tensor)",
+        [](const Node* node) -> Operation {
+          return [](Stack* stack) {
+            auto result = quantized_lstm(
+                (std::move(peek(stack, 0, 12))).toTensor(),
+                (std::move(peek(stack, 1, 12))).toTensorList(),
+                (std::move(peek(stack, 2, 12))).toTensorList(),
+                (std::move(peek(stack, 3, 12))).toBool(),
+                (std::move(peek(stack, 4, 12))).toInt(),
+                (std::move(peek(stack, 5, 12))).toDouble(),
+                (std::move(peek(stack, 6, 12))).toBool(),
+                (std::move(peek(stack, 7, 12))).toBool(),
+                (std::move(peek(stack, 8, 12))).toBool(),
+                (std::move(peek(stack, 9, 12))).toDouble(),
+                (std::move(peek(stack, 10, 12))).toInt(),
+                (std::move(peek(stack, 11, 12))).toInt());
+            drop(stack, 12);
+
+            pack(stack, std::move(std::get<0>(result)));
+            pack(stack, std::move(std::get<1>(result)));
+            pack(stack, std::move(std::get<2>(result)));
+            return 0;
+          };
+        },
+        aliasAnalysisFromSchema()),
+
     Operator(
         "ipex::shuffle_2d("
         "    Tensor input,"

intel_extension_for_pytorch/csrc/quantization/AutoCast.cpp

Lines changed: 2 additions & 1 deletion
@@ -474,7 +474,8 @@ std::tuple<Tensor, Tensor, Tensor> lstm(
       auto w = at::cat({_params[i], _params[i + 1]}, 1);
       weights.push_back(w);
     }
-    calibrate({input}, weights, {output}, "lstm", op_id, OP_TYPE_DEFAULT);
+    // oneDNN LSTM: input and output share the same scale and zero_point
+    calibrate({input}, weights, {input}, "lstm", op_id, OP_TYPE_DEFAULT);
     return std::make_tuple(output, hy, cy);
   }
   params p = get_params(op_id);
Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+#pragma once
+
+#include <ATen/ATen.h>
+
+namespace torch_ipex {
+namespace int8 {
+namespace utils {
+
+inline std::tuple<double, int64_t> get_mkldnn_input_scale_zp(
+    const at::Tensor& input) {
+  TORCH_CHECK(
+      input.qscheme() == c10::QScheme::PER_TENSOR_AFFINE,
+      "should use per_tensor_affine quantization for input of LSTM");
+
+  double scale = input.q_scale();
+
+  // PyTorch scale: (max - min) / (qmax - qmin)
+  // oneDNN scale: (qmax - qmin) / (max - min)
+  double mkldnn_scale = 1. / scale;
+
+  int64_t zp = input.q_zero_point();
+  return std::make_tuple(mkldnn_scale, zp);
+}
+
+inline at::Tensor get_weight_scale_tensor(const at::Tensor& weight) {
+  TORCH_CHECK(
+      weight.qscheme() == c10::QScheme::PER_CHANNEL_AFFINE,
+      "should use per_channel_affine quantization for weight of LSTM");
+  at::Tensor weight_scales_tensor = weight.q_per_channel_scales();
+  TORCH_CHECK(
+      weight_scales_tensor.dim() == 1,
+      "expect weight_scales tensor to be 1d, got dim = ",
+      weight_scales_tensor.dim());
+  return weight_scales_tensor;
+}
+
+} // namespace utils
+} // namespace int8
+} // namespace torch_ipex
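Note: a small worked example of the reciprocal scale convention (illustrative only; assumes this new header is on the include path). PyTorch stores the dequantization multiplier, while oneDNN expects its reciprocal, so a q_scale of 0.05 maps to a oneDNN data scale of 1 / 0.05 = 20:

#include <tuple>
#include <ATen/ATen.h>
// plus the utils header added in this commit

void scale_zp_example() {
  at::Tensor x = at::quantize_per_tensor(
      at::randn({4, 2, 8}), /*scale=*/0.05, /*zero_point=*/128, at::kQUInt8);
  double mkldnn_scale;
  int64_t zp;
  std::tie(mkldnn_scale, zp) =
      torch_ipex::int8::utils::get_mkldnn_input_scale_zp(x);
  // mkldnn_scale == 20.0, zp == 128
}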
