@@ -20,15 +20,15 @@ int main(int argc, const char* argv[]) {
20
20
21
21
torch::jit::script::Module mod;
22
22
try {
23
- // Deserialize the ScriptModule from a file using torch::jit::load().
23
+ // / Deserialize the ScriptModule from a file using torch::jit::load().
24
24
mod = torch::jit::load (argv[1 ]);
25
25
}
26
26
catch (const c10::Error& e) {
27
27
std::cerr << " error loading the model\n " ;
28
28
return -1 ;
29
29
}
30
30
31
- // Create the calibration dataset
31
+ // / Create the calibration dataset
32
32
const std::string data_dir = std::string (argv[2 ]);
33
33
auto calibration_dataset = datasets::CIFAR10 (data_dir, datasets::CIFAR10::Mode::kTest )
34
34
.use_subset (320 )
@@ -42,24 +42,23 @@ int main(int argc, const char* argv[]) {
42
42
std::string calibration_cache_file = " /tmp/vgg16_TRT_ptq_calibration.cache" ;
43
43
44
44
auto calibrator = trtorch::ptq::make_int8_calibrator (std::move (calibration_dataloader), calibration_cache_file, true );
45
- // auto calibrator = trtorch::ptq::make_int8_cache_calibrator(calibration_cache_file);
46
45
47
46
48
47
std::vector<std::vector<int64_t >> input_shape = {{32 , 3 , 32 , 32 }};
49
- // Configure settings for compilation
48
+ // / Configure settings for compilation
50
49
auto extra_info = trtorch::ExtraInfo ({input_shape});
51
- // Set operating precision to INT8
50
+ // / Set operating precision to INT8
52
51
extra_info.op_precision = torch::kI8 ;
53
- // Use the TensorRT Entropy Calibrator
52
+ // / Use the TensorRT Entropy Calibrator
54
53
extra_info.ptq_calibrator = calibrator;
55
- // Set max batch size for the engine
54
+ // / Set max batch size for the engine
56
55
extra_info.max_batch_size = 32 ;
57
- // Set a larger workspace
56
+ // / Set a larger workspace
58
57
extra_info.workspace_size = 1 << 28 ;
59
58
60
59
mod.eval ();
61
60
62
- // Dataloader moved into calibrator so need another for inference
61
+ // / Dataloader moved into calibrator so need another for inference
63
62
auto eval_dataset = datasets::CIFAR10 (data_dir, datasets::CIFAR10::Mode::kTest )
64
63
.map (torch::data::transforms::Normalize<>({0.4914 , 0.4822 , 0.4465 },
65
64
{0.2023 , 0.1994 , 0.2010 }))
@@ -68,7 +67,7 @@ int main(int argc, const char* argv[]) {
68
67
.batch_size (32 )
69
68
.workers (2 ));
70
69
71
- // Check the FP32 accuracy in JIT
70
+ // / Check the FP32 accuracy in JIT
72
71
float correct = 0.0 , total = 0.0 ;
73
72
for (auto batch : *eval_dataloader) {
74
73
auto images = batch.data .to (torch::kCUDA );
@@ -82,19 +81,19 @@ int main(int argc, const char* argv[]) {
82
81
}
83
82
std::cout << " Accuracy of JIT model on test set: " << 100 * (correct / total) << " %" << std::endl;
84
83
85
- // Compile Graph
84
+ // / Compile Graph
86
85
std::cout << " Compiling and quantizing module" << std::endl;
87
86
auto trt_mod = trtorch::CompileGraph (mod, extra_info);
88
87
89
- // Check the INT8 accuracy in TRT
88
+ // / Check the INT8 accuracy in TRT
90
89
correct = 0.0 ;
91
90
total = 0.0 ;
92
91
for (auto batch : *eval_dataloader) {
93
92
auto images = batch.data .to (torch::kCUDA );
94
93
auto targets = batch.target .to (torch::kCUDA );
95
94
96
95
if (images.sizes ()[0 ] < 32 ) {
97
- // To handle smaller batches util Optimization profiles work with Int8
96
+ // / To handle smaller batches util Optimization profiles work with Int8
98
97
auto diff = 32 - images.sizes ()[0 ];
99
98
auto img_padding = torch::zeros ({diff, 3 , 32 , 32 }, {torch::kCUDA });
100
99
auto target_padding = torch::zeros ({diff}, {torch::kCUDA });
@@ -107,7 +106,7 @@ int main(int argc, const char* argv[]) {
107
106
predictions = predictions.reshape (predictions.sizes ()[0 ]);
108
107
109
108
if (predictions.sizes ()[0 ] != targets.sizes ()[0 ]) {
110
- // To handle smaller batches util Optimization profiles work with Int8
109
+ // / To handle smaller batches util Optimization profiles work with Int8
111
110
predictions = predictions.slice (0 , 0 , targets.sizes ()[0 ]);
112
111
}
113
112
@@ -116,7 +115,7 @@ int main(int argc, const char* argv[]) {
116
115
}
117
116
std::cout << " Accuracy of quantized model on test set: " << 100 * (correct / total) << " %" << std::endl;
118
117
119
- // Time execution in INT8
118
+ // / Time execution in JIT-FP32 and TRT- INT8
120
119
auto execution_timer = timers::PreciseCPUTimer ();
121
120
auto images = (*(*eval_dataloader).begin ()).data .to (torch::kCUDA );
122
121
0 commit comments