Skip to content

Commit d305e33

Browse files
committed
correcting ptq cases
1 parent 489e22d commit d305e33

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

tests/py/ts/ptq/test_ptq_dataloader_calibrator.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import torchvision.transforms as transforms
99
from torch.nn import functional as F
1010
from torch_tensorrt.ts.logging import *
11+
import torch_tensorrt.ts.ptq as PTQ
1112

1213

1314
def find_repo_root(max_depth=10):
@@ -76,11 +77,11 @@ def test_compile_script(self):
7677
self.testing_dataloader = torch.utils.data.DataLoader(
7778
self.testing_dataset, batch_size=1, shuffle=False, num_workers=1
7879
)
79-
self.calibrator = torchtrt.ptq.DataLoaderCalibrator(
80+
self.calibrator = PTQ.DataLoaderCalibrator(
8081
self.testing_dataloader,
8182
cache_file="./calibration.cache",
8283
use_cache=False,
83-
algo_type=torchtrt.ptq.CalibrationAlgo.ENTROPY_CALIBRATION_2,
84+
algo_type=PTQ.CalibrationAlgo.ENTROPY_CALIBRATION_2,
8485
device=torch.device("cuda:0"),
8586
)
8687

tests/py/ts/ptq/test_ptq_to_backend.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import torchvision.transforms as transforms
99
from torch.nn import functional as F
1010
from torch_tensorrt.ts.logging import *
11+
import torch_tensorrt.ts.ptq as PTQ
1112

1213

1314
def find_repo_root(max_depth=10):
@@ -76,11 +77,11 @@ def test_compile_script(self):
7677
self.testing_dataloader = torch.utils.data.DataLoader(
7778
self.testing_dataset, batch_size=1, shuffle=False, num_workers=1
7879
)
79-
self.calibrator = torchtrt.ptq.DataLoaderCalibrator(
80+
self.calibrator = PTQ.DataLoaderCalibrator(
8081
self.testing_dataloader,
8182
cache_file="./calibration.cache",
8283
use_cache=False,
83-
algo_type=torchtrt.ptq.CalibrationAlgo.ENTROPY_CALIBRATION_2,
84+
algo_type=PTQ.CalibrationAlgo.ENTROPY_CALIBRATION_2,
8485
device=torch.device("cuda:0"),
8586
)
8687

0 commit comments

Comments
 (0)