Skip to content

Commit 122429f

Browse files
committed
chore: Apply linting
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
1 parent 70a7bb3 commit 122429f

File tree

5 files changed

+17
-12
lines changed

5 files changed

+17
-12
lines changed

cpp/src/compile_spec.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,9 @@ CompileSpec::Input::Input(at::Tensor tensor) {
294294
this->shape = tensor.sizes().vec();
295295
this->dtype = tensor.scalar_type();
296296
this->explicit_set_dtype = true;
297-
TRTORCH_ASSERT(tensor.is_contiguous(at::MemoryFormat::ChannelsLast) || tensor.is_contiguous(at::MemoryFormat::Contiguous), "Tensor does not have a supported contiguous memory format, supported formats are contiguous or channels_last");
297+
TRTORCH_ASSERT(
298+
tensor.is_contiguous(at::MemoryFormat::ChannelsLast) || tensor.is_contiguous(at::MemoryFormat::Contiguous),
299+
"Tensor does not have a supported contiguous memory format, supported formats are contiguous or channels_last");
298300
at::MemoryFormat frmt;
299301
if (tensor.is_contiguous(at::MemoryFormat::Contiguous)) {
300302
frmt = at::MemoryFormat::Contiguous;

py/trtorch/Input.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,13 @@ def _parse_format(format: Any) -> _types.TensorFormat:
199199

200200
@classmethod
201201
def _from_tensor(cls, t: torch.Tensor):
202-
if not any([t.is_contiguous(memory_format=torch.contiguous_format), t.is_contiguous(memory_format=torch.channels_last)]):
203-
raise ValueError("Tensor does not have a supported contiguous memory format, supported formats are contiguous or channels_last")
204-
frmt = torch.contiguous_format if t.is_contiguous(memory_format=torch.contiguous_format) else torch.channels_last
205-
return cls(shape=t.shape, dtype=t.dtype, format=frmt)
202+
if not any([
203+
t.is_contiguous(memory_format=torch.contiguous_format),
204+
t.is_contiguous(memory_format=torch.channels_last)
205+
]):
206+
raise ValueError(
207+
"Tensor does not have a supported contiguous memory format, supported formats are contiguous or channels_last"
208+
)
209+
frmt = torch.contiguous_format if t.is_contiguous(
210+
memory_format=torch.contiguous_format) else torch.channels_last
211+
return cls(shape=t.shape, dtype=t.dtype, format=frmt)

py/trtorch/_compile_spec.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,8 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
175175

176176
if "inputs" in compile_spec:
177177
if not all([isinstance(i, torch.Tensor) or isinstance(i, trtorch.Input) for i in compile_spec["inputs"]]):
178-
raise KeyError("Input specs should be either trtorch.Input or torch.Tensor, found types: {}".format([type(i) for i in compile_spec["inputs"]]))
178+
raise KeyError("Input specs should be either trtorch.Input or torch.Tensor, found types: {}".format(
179+
[type(i) for i in compile_spec["inputs"]]))
179180

180181
inputs = [trtorch.Input._from_tensor(i) if isinstance(i, torch.Tensor) else i for i in compile_spec["inputs"]]
181182
info.inputs = [i._to_internal() for i in inputs]

tests/cpp/test_example_tensors.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,3 @@ INSTANTIATE_TEST_SUITE_P(
2121
CompiledModuleForwardIsCloseSuite,
2222
CppAPITests,
2323
testing::Values(PathAndInSize({"tests/modules/resnet18_traced.jit.pt", {{1, 3, 224, 224}}, 2e-5})));
24-

tests/py/test_api.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -88,16 +88,13 @@ def test_from_torch_tensor(self):
8888
self.assertTrue(same < 2e-2)
8989

9090
def test_device(self):
91-
compile_spec = {
92-
"inputs": [self.input],
93-
"device": trtorch.Device("gpu:0"),
94-
"enabled_precisions": {torch.float}
95-
}
91+
compile_spec = {"inputs": [self.input], "device": trtorch.Device("gpu:0"), "enabled_precisions": {torch.float}}
9692

9793
trt_mod = trtorch.compile(self.scripted_model, compile_spec)
9894
same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
9995
self.assertTrue(same < 2e-2)
10096

97+
10198
class TestCompileHalf(ModelTestCase):
10299

103100
def setUp(self):

0 commit comments

Comments (0)