Skip to content

Commit 8139da9

Browse files
committed
chore: Add BERT to the model set
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent e63908b commit 8139da9

File tree

3 files changed

+32
-2
lines changed

3 files changed

+32
-2
lines changed

noxfile.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def install_deps(session):
1818

1919
def download_models(session, use_host_env=False):
2020
print("Downloading test models")
21-
session.install('timm')
21+
session.install("-r", os.path.join(TOP_DIR, "tests", "modules", "requirements.txt"))
2222
print(TOP_DIR)
2323
session.chdir(os.path.join(TOP_DIR, "tests", "modules"))
2424
if use_host_env:

tests/modules/hub.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import torch.nn.functional as F
44
import torchvision.models as models
55
import timm
6+
from transformers import BertModel, BertTokenizer, BertConfig
67

78
torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
89

@@ -189,3 +190,31 @@ def forward(self, x):
189190
conditional_model = FallbackIf().eval().cuda()
190191
conditional_script_model = torch.jit.script(conditional_model)
191192
torch.jit.save(conditional_script_model, "conditional_scripted.jit.pt")
193+
194+
195+
enc = BertTokenizer.from_pretrained("bert-base-uncased")
196+
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
197+
tokenized_text = enc.tokenize(text)
198+
masked_index = 8
199+
tokenized_text[masked_index] = "[MASK]"
200+
indexed_tokens = enc.convert_tokens_to_ids(tokenized_text)
201+
segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]
202+
tokens_tensor = torch.tensor([indexed_tokens])
203+
segments_tensors = torch.tensor([segments_ids])
204+
dummy_input = [tokens_tensor, segments_tensors]
205+
206+
config = BertConfig(
207+
vocab_size_or_config_json_file=32000,
208+
hidden_size=768,
209+
num_hidden_layers=12,
210+
num_attention_heads=12,
211+
intermediate_size=3072,
212+
torchscript=True,
213+
)
214+
215+
model = BertModel(config)
216+
model.eval()
217+
model = BertModel.from_pretrained("bert-base-uncased", torchscript=True)
218+
219+
traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors])
220+
torch.jit.save(traced_model, "bert_base_uncased_traced.jit..pt")

tests/modules/requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
-f https://download.pytorch.org/whl/torch_stable.html
2+
#torch==1.10.0+cu113
23
timm==v0.4.12
3-
torch==1.10.0+cu113
4+
transformers==4.17.0

0 commit comments

Comments
 (0)