3
3
import torch .nn .functional as F
4
4
import torchvision .models as models
5
5
import timm
6
+ from transformers import BertModel , BertTokenizer , BertConfig
6
7
7
8
# Monkey-patch torch.hub's private fork-validation hook to always succeed,
# effectively disabling the "not a forked repo" check when downloading hub
# models. NOTE(review): this is a private torch API (leading underscore) and
# may disappear or change signature across torch releases — confirm it still
# exists on the pinned torch version.
torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
8
9
@@ -189,3 +190,31 @@ def forward(self, x):
189
190
conditional_model = FallbackIf ().eval ().cuda ()
190
191
conditional_script_model = torch .jit .script (conditional_model )
191
192
torch .jit .save (conditional_script_model , "conditional_scripted.jit.pt" )
193
# ---------------------------------------------------------------------------
# Trace bert-base-uncased with torch.jit.trace and save the artifact.
#
# Builds a fixed tokenized example (one masked token, two segments) as the
# tracing input, loads the pretrained checkpoint with torchscript=True so the
# model returns trace-friendly tuples, and writes the traced module to disk.
# ---------------------------------------------------------------------------
enc = BertTokenizer.from_pretrained("bert-base-uncased")
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = enc.tokenize(text)

# Mask one token so the example matches the classic masked-LM demo input.
masked_index = 8
tokenized_text[masked_index] = "[MASK]"
indexed_tokens = enc.convert_tokens_to_ids(tokenized_text)
# Segment ids: first sentence (7 tokens) -> 0, second sentence (7 tokens) -> 1.
segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]

tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])
dummy_input = [tokens_tensor, segments_tensors]

# NOTE: the original code first built a BertModel from a hand-written
# BertConfig and then immediately shadowed it with the pretrained checkpoint,
# so the locally-constructed model was dead code. It also used
# `vocab_size_or_config_json_file`, a pre-1.0 pytorch-pretrained-bert kwarg
# that modern `transformers.BertConfig` does not recognize (the current
# parameter is `vocab_size`). Only the pretrained, torchscript-enabled model
# is kept here.
model = BertModel.from_pretrained("bert-base-uncased", torchscript=True)
model.eval()

# Trace with the concrete (tokens, segments) pair; tracing fixes the example
# shapes into the exported graph.
traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors])
# Fixed filename typo: the original wrote "...jit..pt" (double dot).
torch.jit.save(traced_model, "bert_base_uncased_traced.jit.pt")
0 commit comments