Skip to content

Commit ec51f1d

Browse files
authored
fix: Torch nightly version constraint (#2546)
1 parent 20264a3 commit ec51f1d

File tree

4 files changed

+12
-9
lines changed

4 files changed

+12
-9
lines changed

py/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ numpy
22
packaging
33
pybind11==2.6.2
44
--extra-index-url https://download.pytorch.org/whl/nightly/cu121
5-
torch>=2.2.0.dev,<2.3.0
5+
torch>=2.2.0.dev,<=2.3.0
66
torchvision>=0.17.0.dev,<0.18.0
77
--extra-index-url https://pypi.ngc.nvidia.com
88
tensorrt==8.6.1

pyproject.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ requires = [
99
"typing-extensions>=4.7.0",
1010
"future>=0.18.3",
1111
"tensorrt>=8.6,<8.7",
12-
"torch >=2.2.0.dev,<2.3.0",
12+
"torch >=2.2.0.dev,<=2.3.0",
1313
#"torch==2.1.0.dev20230731",
1414
"pybind11==2.6.2",
1515
"numpy",
@@ -42,7 +42,7 @@ readme = {file = "py/README.md", content-type = "text/markdown"}
4242
requires-python = ">=3.8"
4343
keywords = ["pytorch", "torch", "tensorrt", "trt", "ai", "artificial intelligence", "ml", "machine learning", "dl", "deep learning", "compiler", "dynamo", "torchscript", "inference"]
4444
dependencies = [
45-
"torch >=2.2.0.dev,<2.3.0",
45+
"torch >=2.2.0.dev,<=2.3.0",
4646
#"torch==2.1.0.dev20230731",
4747
"tensorrt>=8.6,<8.7",
4848
"packaging>=23",

tests/modules/custom_models.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
from typing import Dict, List, Tuple
2+
13
import torch
24
import torch.nn as nn
3-
from transformers import BertModel, BertTokenizer, BertConfig
45
import torch.nn.functional as F
5-
from typing import Tuple, List, Dict
6+
from transformers import BertConfig, BertModel, BertTokenizer
67

78

89
# Sample Pool Model (for testing plugin serialization)
@@ -180,10 +181,12 @@ def BertModule():
180181
num_hidden_layers=12,
181182
num_attention_heads=12,
182183
intermediate_size=3072,
184+
use_cache=False,
185+
output_attentions=False,
186+
output_hidden_states=False,
183187
torchscript=True,
184188
)
185-
model = BertModel(config)
189+
model = BertModel.from_pretrained(model_name, config=config)
186190
model.eval()
187-
model = BertModel.from_pretrained(model_name, torchscript=True)
188191
traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors])
189192
return traced_model

tests/modules/requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
timm==v0.9.2
2-
transformers==4.30.0
1+
timm
2+
transformers
33
torchvision

0 commit comments

Comments
 (0)