Skip to content

Commit b0f0125

Browse files
authored
Ensure transformers is up to date (#1479)
We fixed the tests to adjust for deprecated methods, and we added a line to ensure the transformers package is kept up to date.
1 parent af1ee04 commit b0f0125

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

kaggle_requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ torchao
130130
torchinfo
131131
torchmetrics
132132
torchtune
133+
transformers>=4.51.0
133134
triton
134135
tsfresh
135136
vtk

tests/test_transformers.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import unittest
22

33
import torch
4-
from transformers import AdamW
4+
import torch.optim as optim
55
import transformers.pipelines # verify this import works
66

77

@@ -10,13 +10,12 @@ def assertListAlmostEqual(self, list1, list2, tol):
1010
self.assertEqual(len(list1), len(list2))
1111
for a, b in zip(list1, list2):
1212
self.assertAlmostEqual(a, b, delta=tol)
13-
1413
def test_adam_w(self):
1514
w = torch.tensor([0.1, -0.2, -0.1], requires_grad=True)
1615
target = torch.tensor([0.4, 0.2, -0.5])
1716
criterion = torch.nn.MSELoss()
1817
# No warmup, constant schedule, no gradient clipping
19-
optimizer = AdamW(params=[w], lr=2e-1, weight_decay=0.0)
18+
optimizer = optim.AdamW(params=[w], lr=2e-1, weight_decay=0.0)
2019
for _ in range(100):
2120
loss = criterion(w, target)
2221
loss.backward()

0 commit comments

Comments
 (0)