Skip to content

Commit b446c93

Browse files
committed
fix: Torch nightly version 2
1 parent 45cbcd9 commit b446c93

File tree

1 file changed

+16
-12
lines changed

1 file changed

+16
-12
lines changed

tests/modules/custom_models.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
from typing import Dict, List, Tuple
2+
13
import torch
24
import torch.nn as nn
3-
from transformers import BertModel, BertTokenizer, BertConfig
45
import torch.nn.functional as F
5-
from typing import Tuple, List, Dict
6+
from transformers import BertConfig, BertModel, BertTokenizer
67

78

89
# Sample Pool Model (for testing plugin serialization)
@@ -174,16 +175,19 @@ def BertModule():
174175
segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]
175176
tokens_tensor = torch.tensor([indexed_tokens])
176177
segments_tensors = torch.tensor([segments_ids])
177-
config = BertConfig(
178-
vocab_size_or_config_json_file=32000,
179-
hidden_size=768,
180-
num_hidden_layers=12,
181-
num_attention_heads=12,
182-
intermediate_size=3072,
183-
torchscript=True,
184-
)
185-
model = BertModel(config)
178+
179+
model_kwargs = {
180+
"vocab_size_or_config_json_file": 32000,
181+
"hidden_size": 768,
182+
"num_hidden_layers": 12,
183+
"num_attention_heads": 12,
184+
"intermediate_size": 3072,
185+
"use_cache": False,
186+
"output_attentions": False,
187+
"output_hidden_states": False,
188+
"torchscript": True,
189+
}
190+
model = BertModel.from_pretrained(model_name, **model_kwargs)
186191
model.eval()
187-
model = BertModel.from_pretrained(model_name, torchscript=True)
188192
traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors])
189193
return traced_model

0 commit comments

Comments
 (0)