Skip to content

Commit 9d3e53f

Browse files
committed
fix: Torch nightly version 2
1 parent 45cbcd9 commit 9d3e53f

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

tests/modules/custom_models.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
from typing import Dict, List, Tuple
2+
13
import torch
24
import torch.nn as nn
3-
from transformers import BertModel, BertTokenizer, BertConfig
45
import torch.nn.functional as F
5-
from typing import Tuple, List, Dict
6+
from transformers import BertConfig, BertModel, BertTokenizer
67

78

89
# Sample Pool Model (for testing plugin serialization)
@@ -180,10 +181,12 @@ def BertModule():
180181
num_hidden_layers=12,
181182
num_attention_heads=12,
182183
intermediate_size=3072,
184+
use_cache=False,
185+
output_attentions=False,
186+
output_hidden_states=False,
183187
torchscript=True,
184188
)
185-
model = BertModel(config)
189+
model = BertModel.from_pretrained(model_name, config=config)
186190
model.eval()
187-
model = BertModel.from_pretrained(model_name, torchscript=True)
188191
traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors])
189192
return traced_model

tests/modules/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
timm==v0.9.2
2-
transformers==4.30.0
2+
transformers==4.33.2
33
torchvision

0 commit comments

Comments
 (0)