
Commit 3d24313

fadara01 authored and pytorchmergebot committed
Pass ideep::lowp_kind to matmul_forward::compute on cache misses (pytorch#135058)

Optimized dynamic quantization for aarch64 was enabled by pytorch#126687 and pytorch#134897. This PR fixes an issue for aarch64 where, on a [cache miss](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp#L592) (e.g. if input dimensions change), [`ideep::matmul_forward::compute`](https://github.com/intel/ideep/blob/pytorch-rls-v3.5.3-2/include/ideep/operators/matmul.hpp#L160) wrongly runs with the [default lowp_kind (u8s8)](https://github.com/intel/ideep/blob/pytorch-rls-v3.5.3-2/include/ideep/operators/matmul.hpp#L174), which is not supported by oneDNN+ACL (Arm Compute Library), causing the workload to fall back to a much slower oneDNN `gemm:jit` kernel.

Example:

```python
import torch

DIM = 4096
INPUT_SIZE1 = 32
INPUT_SIZE2 = 16

class LinearNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(DIM, DIM, bias=False)

    def forward(self, x):
        x = self.fc1(x)
        return x

input1 = torch.randn(size=(INPUT_SIZE1, DIM))
input2 = torch.randn(size=(INPUT_SIZE2, DIM))

with torch.no_grad():
    model = LinearNet()
    model = torch.ao.quantization.quantize_dynamic(model, {torch.nn.Linear})
    model(input1)  # this goes to ACL lowp_gemm
    print("=" * 50)
    model(input2)  # this goes to gemm:jit without this PR, and to ACL with this PR
```

In the code snippet above:

- The matmul from `model(input1)` goes to oneDNN+ACL (in both cases, with and without the PR).
- The matmul from `model(input2)`: **without this PR**, there is a cache miss (different input shapes) and `matmul_forward::compute` runs with the default lowp_kind (u8s8), so the matmul falls back to `gemm:jit` in oneDNN. **With this PR**, the matmul goes to oneDNN+ACL, which is around 10x faster than oneDNN+jit.

Pull Request resolved: pytorch#135058
Approved by: https://github.com/jondea, https://github.com/malfet
1 parent cd472bb commit 3d24313

File tree

1 file changed (+15 −4 lines)


aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp

Lines changed: 15 additions & 4 deletions
```diff
@@ -590,10 +590,21 @@ at::Tensor PackedLinearWeightsOnednn::apply_dynamic_impl(
     LinearParams& params = get_cache().get_param();
     ideep::matmul_forward::compute(params, x, w, b, y, src_scales, src_zero_point);
   } else {
-    ideep::matmul_forward::compute(x, w, b, y,
-        src_scales, weights_scales, ideep::scale_t(),
-        src_zero_point, ideep::zero_point_t(),
-        1.0f, 1.0f, op_attr);
+    ideep::matmul_forward::compute(
+        x,
+        w,
+        b,
+        y,
+        src_scales,
+        weights_scales,
+        ideep::scale_t(),
+        src_zero_point,
+        ideep::zero_point_t(),
+        1.0f,
+        1.0f,
+        op_attr,
+        ideep::tensor::data_type::undef,
+        std::is_signed_v<input_qtype> ? ideep::s8s8 : ideep::u8s8);
   }
   auto out_sizes = input.sizes().vec();
   out_sizes.back() = w.get_dim(1);
```
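The core of the change is the final argument: on a cache miss, the lowp_kind is now chosen from the signedness of the input's quantized type instead of silently defaulting to u8s8. A minimal pure-Python sketch of that selection logic (the function name here is illustrative, not part of the ideep API):

```python
def select_lowp_kind(input_is_signed: bool) -> str:
    """Mirror of the C++ ternary added by this patch:

        std::is_signed_v<input_qtype> ? ideep::s8s8 : ideep::u8s8

    The returned name encodes the (activation, weight) integer types:
    s8s8 means signed 8-bit activations and weights; u8s8 means unsigned
    8-bit activations with signed 8-bit weights. Passing the kind
    explicitly keeps oneDNN+ACL on its fast lowp_gemm path, whereas the
    implicit u8s8 default is unsupported by ACL and falls back to
    gemm:jit.
    """
    return "s8s8" if input_is_signed else "u8s8"

# Signed activations (e.g. qint8) select s8s8; unsigned (quint8) select u8s8.
print(select_lowp_kind(True))   # s8s8
print(select_lowp_kind(False))  # u8s8
```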
