Skip to content

Commit c33ad16

Browse files
authored
Merge pull request #516 from NVIDIA/calib_data
fix: Transfer calibration data to gpu when the batch is not a list
2 parents a032c3a + 23739cb commit c33ad16

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

py/trtorch/ptq.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ def get_batch(self, names):
3434
# Treat the first element as input and others as targets.
3535
if isinstance(batch, list):
3636
batch = batch[0].to(self.device)
37+
else:
38+
batch = batch.to(self.device)
3739
return [batch.data_ptr()]
3840

3941

0 commit comments

Comments
 (0)