Commit 54a24b3

fix(//cpp/ptq): Tracing model in eval mode wrecks accuracy in Libtorch
HACK: WYA tracing without being in eval mode and ignoring the warning; will follow up with the PyTorch team and test after script mode support lands

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
1 parent cd24f26 commit 54a24b3

File tree: 2 files changed, +131 −7 lines
cpp/ptq/README.md

Lines changed: 126 additions & 1 deletion
@@ -1,10 +1,125 @@
# ptq

## How to create your own PTQ application

Post Training Quantization (PTQ) is a technique to reduce the required computational resources for inference while still preserving the accuracy of your model by mapping the traditional FP32 activation space to a reduced INT8 space. TensorRT uses a calibration step which executes your model with sample data from the target domain and tracks the activations in FP32 to calibrate a mapping to INT8 that minimizes the information loss between FP32 inference and INT8 inference.
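To make that mapping concrete, here is a rough illustration of the per-tensor scale scheme INT8 inference is built around (an assumption-level sketch, not TensorRT's actual implementation); choosing `scale` well is exactly what calibration does:

```C++
#include <algorithm>
#include <cmath>
#include <cstdint>

// Illustration only: INT8 inference replaces an FP32 activation x with
// q = clamp(round(x / scale), -128, 127) and recovers an approximation
// of x as q * scale. Calibration's job is to pick `scale` per tensor so
// this round trip loses as little information as possible.
int8_t quantize(float x, float scale) {
    return static_cast<int8_t>(std::max(-128.0f, std::min(127.0f, std::round(x / scale))));
}

float dequantize(int8_t q, float scale) {
    return static_cast<float>(q) * scale;
}
```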
Users writing TensorRT applications are required to set up a calibrator class which will provide sample data to the TensorRT calibrator. With TRTorch we look to leverage existing infrastructure in PyTorch to make implementing calibrators easier.

LibTorch provides a `Dataloader` and `Dataset` API which streamlines preprocessing and batching input data. TRTorch uses Dataloaders as the base of a generic calibrator implementation, so you will be able to reuse or quickly implement a `torch::Dataset` for your target domain, place it in a Dataloader and create an INT8 Calibrator from it which you can provide to TRTorch to run INT8 Calibration during compilation of your module.

### Code

Here is an example interface of a `torch::Dataset` class for CIFAR10:

```C++
//cpp/ptq/datasets/cifar10.h
#pragma once

#include "torch/data/datasets/base.h"
#include "torch/data/example.h"
#include "torch/types.h"

#include <cstddef>
#include <string>

namespace datasets {
// The CIFAR10 Dataset
class CIFAR10 : public torch::data::datasets::Dataset<CIFAR10> {
public:
    // The mode in which the dataset is loaded
    enum class Mode { kTrain, kTest };

    // Loads CIFAR10 from an un-tarred file
    // Dataset can be found at https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz
    // Root path should be the directory that contains the content of the tarball
    explicit CIFAR10(const std::string& root, Mode mode = Mode::kTrain);

    // Returns the pair at index in the dataset
    torch::data::Example<> get(size_t index) override;

    // The size of the dataset
    c10::optional<size_t> size() const override;

    // The mode the dataset is in
    bool is_train() const noexcept;

    // Returns all images stacked into a single tensor
    const torch::Tensor& images() const;

    // Returns all targets stacked into a single tensor
    const torch::Tensor& targets() const;

    // Trims the dataset to the first n pairs
    CIFAR10&& use_subset(int64_t new_size);

private:
    Mode mode_;
    torch::Tensor images_, targets_;
};
} // namespace datasets
```

This class's implementation reads from the binary distribution of the CIFAR10 dataset and builds two tensors which hold the images and labels.
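As a rough sketch of that loading logic (the helper name and constants here are illustrative assumptions, not the actual code in `//cpp/ptq/datasets/cifar10.cpp`), each record in the CIFAR10 binary files is one label byte followed by a 3x32x32 image:

```C++
#include <cstdint>
#include <fstream>
#include <string>
#include <utility>
#include <vector>

#include "torch/types.h"

// Hypothetical helper: each CIFAR10 binary record is 1 label byte
// followed by 3 * 32 * 32 image bytes; each file holds 10000 records
constexpr size_t kImageSize = 3 * 32 * 32;
constexpr size_t kRecordSize = 1 + kImageSize;
constexpr size_t kRecordsPerFile = 10000;

std::pair<torch::Tensor, torch::Tensor> read_batch_file(const std::string& path) {
    std::ifstream data(path, std::ios::binary);
    std::vector<char> buf(kRecordSize * kRecordsPerFile);
    data.read(buf.data(), buf.size());

    auto images = torch::empty({(int64_t)kRecordsPerFile, 3, 32, 32}, torch::kF32);
    auto targets = torch::empty({(int64_t)kRecordsPerFile}, torch::kInt64);
    for (size_t i = 0; i < kRecordsPerFile; i++) {
        auto* record = reinterpret_cast<uint8_t*>(buf.data()) + i * kRecordSize;
        targets[i] = (int64_t)record[0];
        // Wrap the raw pixel bytes, then copy out as FP32 scaled to [0, 1]
        images[i] = torch::from_blob(record + 1, {3, 32, 32}, torch::kU8)
                        .to(torch::kF32)
                        .div_(255);
    }
    return {images, targets};
}
```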
Then we select a subset of the dataset to use for calibration, since we don't need the full dataset for calibration (and calibration takes time), then define the preprocessing to apply to the images in the dataset and create a Dataloader from the dataset which will batch the data:
```C++
auto calibration_dataset = datasets::CIFAR10(data_dir, datasets::CIFAR10::Mode::kTest)
                               .use_subset(320)
                               .map(torch::data::transforms::Normalize<>({0.4914, 0.4822, 0.4465},
                                                                         {0.2023, 0.1994, 0.2010}))
                               .map(torch::data::transforms::Stack<>());
auto calibration_dataloader = torch::data::make_data_loader(std::move(calibration_dataset),
                                                            torch::data::DataLoaderOptions()
                                                                .batch_size(32)
                                                                .workers(2));
```
Next we create a calibrator from the `calibration_dataloader` using the calibrator factory:

```C++
auto calibrator = trtorch::ptq::make_int8_calibrator(std::move(calibration_dataloader), calibration_cache_file, true);
```
Here we also define a location to write a calibration cache file, which lets us reuse the calibration data without needing the dataset, and whether or not we should use the cache file if it exists. There is also a `trtorch::ptq::make_int8_cache_calibrator` factory which creates a calibrator that uses the cache only, for cases where you may do engine building on a machine that has limited storage (i.e. no space for a dataset) or to have a simpler deployment application.
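For instance, a cache-only engine build might look like this (a sketch assuming the cache file was written by an earlier calibration run):

```C++
// Sketch of the cache-only path: no dataset or dataloader is needed at
// engine-build time, only the cache written by a previous calibration run
auto cache_calibrator = trtorch::ptq::make_int8_cache_calibrator(calibration_cache_file);
```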
The calibrator factories create a calibrator that inherits from a `nvinfer1::IInt8Calibrator` virtual class (`nvinfer1::IInt8EntropyCalibrator2` by default) which defines the calibration algorithm used when calibrating. You can explicitly select the calibration algorithm like this:

```C++
// MinMax Calibrator is geared more towards NLP tasks
auto calibrator = trtorch::ptq::make_int8_calibrator<nvinfer1::IInt8MinMaxCalibrator>(std::move(calibration_dataloader), calibration_cache_file, true);
```
Then all that's required to set up the module for INT8 calibration is to set the following compile settings in the `trtorch::ExtraInfo` struct and compile the module:

```C++
std::vector<std::vector<int64_t>> input_shape = {{32, 3, 32, 32}};
/// Configure settings for compilation
auto extra_info = trtorch::ExtraInfo({input_shape});
/// Set operating precision to INT8
extra_info.op_precision = torch::kI8;
/// Use the TensorRT Entropy Calibrator
extra_info.ptq_calibrator = calibrator;
/// Set a larger workspace (you may get better performance from doing so)
extra_info.workspace_size = 1 << 28;

auto trt_mod = trtorch::CompileGraph(mod, extra_info);
```
If you have an existing calibrator implementation for TensorRT, you may directly set the `ptq_calibrator` field with a pointer to your calibrator and it will work as well.
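For example (a sketch; `MyCalibrator` is a hypothetical user-defined class, not something TRTorch provides):

```C++
// MyCalibrator stands in for a user-defined implementation of one of
// TensorRT's calibrator interfaces (e.g. nvinfer1::IInt8EntropyCalibrator2)
MyCalibrator my_calibrator(/*...*/);
extra_info.ptq_calibrator = &my_calibrator;
```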
From here, not much changes in terms of how execution works. You are still able to fully use LibTorch as the sole interface for inference. Data should remain in FP32 precision when it's passed into `trt_mod.forward`.
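For instance, running a batch through the compiled module could look like the following sketch (the shape matches the `input_shape` set above):

```C++
// Inputs stay in FP32 on the GPU; INT8 quantization happens inside the engine
auto in = torch::randn({32, 3, 32, 32}, torch::kCUDA).to(torch::kF32);
auto out = trt_mod.forward({in});
```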
## Running the Example Application
This is a short example application that shows how to use TRTorch to perform post-training quantization for a module.

## Prerequisites

1. Download CIFAR10 Dataset Binary version ([https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz](https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz))
2. Train a network on CIFAR10 (see `training/` for a VGG16 recipe)
3. Export model to torchscript
@@ -26,6 +141,16 @@ bazel build //cpp/ptq --compilation_mode=dbg
ptq <path-to-module> <path-to-cifar10>
```

## Example Output

```
Accuracy of JIT model on test set: 92.1%
Compiling and quantizing module
Accuracy of quantized model on test set: 91.0044%
Latency of JIT model FP32 (Batch Size 32): 1.73497ms
Latency of quantized model (Batch Size 32): 0.365737ms
```
## Citations

```

cpp/ptq/training/vgg16/export_ckpt.py

Lines changed: 5 additions & 6 deletions
@@ -18,7 +18,6 @@ def test(model, dataloader, crit):
     loss = 0.0
     class_probs = []
     class_preds = []
-    model.eval()

     with torch.no_grad():
         for data, labels in dataloader:
@@ -54,9 +53,12 @@ def test(model, dataloader, crit):
 weights = new_state_dict

 model.load_state_dict(weights)
-model.eval()
+
+# Setting eval here causes both JIT and TRT accuracy to tank in LibTorch will follow up with PyTorch Team
+#model.eval()

 jit_model = torch.jit.trace(model, torch.rand([32, 3, 32, 32]).to("cuda"))
+jit_model.eval()

 testing_dataset = datasets.CIFAR10(root='./data', train=False, download=True,
                                    transform=transforms.Compose([
@@ -68,10 +70,7 @@ def test(model, dataloader, crit):
                                          shuffle=False, num_workers=2)

 crit = torch.nn.CrossEntropyLoss()
-test_loss, test_acc = test(model, testing_dataloader, crit)
-print("[PTH] Test Loss: {:.5f} Test Acc: {:.2f}%".format(test_loss, 100 * test_acc))

-torch.jit.save(jit_model, "trained_vgg16.jit.pt")
-jit_model = torch.jit.load("trained_vgg16.jit.pt")
 test_loss, test_acc = test(jit_model, testing_dataloader, crit)
 print("[JIT] Test Loss: {:.5f} Test Acc: {:.2f}%".format(test_loss, 100 * test_acc))
+torch.jit.save(jit_model, "trained_vgg16.jit.pt")
