|
@@ -1,8 +1,9 @@
+from typing import Dict, List, Tuple
+
 import torch
 import torch.nn as nn
-from transformers import BertModel, BertTokenizer, BertConfig
 import torch.nn.functional as F
-from typing import Tuple, List, Dict
+from transformers import BertConfig, BertModel, BertTokenizer
 
 
 # Sample Pool Model (for testing plugin serialization)
|
@@ -174,16 +175,19 @@ def BertModule():
     segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]
     tokens_tensor = torch.tensor([indexed_tokens])
     segments_tensors = torch.tensor([segments_ids])
-    config = BertConfig(
-        vocab_size_or_config_json_file=32000,
-        hidden_size=768,
-        num_hidden_layers=12,
-        num_attention_heads=12,
-        intermediate_size=3072,
-        torchscript=True,
-    )
-    model = BertModel(config)
+
+    model_kwargs = {
+        "vocab_size_or_config_json_file": 32000,
+        "hidden_size": 768,
+        "num_hidden_layers": 12,
+        "num_attention_heads": 12,
+        "intermediate_size": 3072,
+        "use_cache": False,
+        "output_attentions": False,
+        "output_hidden_states": False,
+        "torchscript": True,
+    }
+    model = BertModel.from_pretrained(model_name, **model_kwargs)
     model.eval()
-    model = BertModel.from_pretrained(model_name, torchscript=True)
     traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors])
     return traced_model
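For context, a minimal usage sketch of the traced module this change produces. It assumes `BertModule()` is importable, that `model_name` (defined outside the hunks shown) resolves to a BERT checkpoint compatible with the overrides in `model_kwargs`, and the output path is hypothetical:

    # Illustrative sketch only (not part of this diff): round-trip the traced
    # module through TorchScript serialization and run it on fresh inputs.
    # Assumes BertModule() above; the file path here is a made-up example.
    import torch

    traced_model = BertModule()
    torch.jit.save(traced_model, "bert_traced.pt")  # hypothetical path

    # torch.jit.load restores the module without the original Python class.
    loaded = torch.jit.load("bert_traced.pt")

    # A traced graph is specialized to its example inputs, so new inputs
    # should keep the traced shape: batch size 1, sequence length 14.
    tokens = torch.randint(0, 32000, (1, 14))
    segments = torch.tensor([[0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]])
    with torch.no_grad():
        outputs = loaded(tokens, segments)

Because the serialized artifact carries its own graph, the reloaded module runs without the transformers library installed, which is what makes TorchScript a reasonable vehicle for the plugin-serialization test above.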
|