This repository was archived by the owner on Aug 7, 2024. It is now read-only.

Commit 13d3198

vkuzo authored and facebook-github-bot committed
unify linear test cases (#307)
Summary:
Pull Request resolved: #307

cleaning up some light tech debt

Reviewed By: drisspg

Differential Revision: D59521200

fbshipit-source-id: 768e26db0b0ac461d9112484fb1858c9d0a8853a
1 parent 13f2c26 commit 13d3198

File tree

1 file changed: +5 -45 lines changed


test/test_base.py

Lines changed: 5 additions & 45 deletions
@@ -232,58 +232,18 @@ def _test_linear_impl(
     @pytest.mark.parametrize(
         "scaling_type_dL_dY", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
     )
+    @pytest.mark.parametrize("linear_dtype", [torch.bfloat16, torch.float32])
+    @pytest.mark.parametrize("linear_bias", [False, True])
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
-    def test_linear_nobias(
+    def test_linear(
         self,
         x_shape,
         emulate: bool,
         scaling_type_x: TensorScalingType,
         scaling_type_w: TensorScalingType,
         scaling_type_dL_dY: TensorScalingType,
-    ):
-        if not emulate:
-            if not torch.cuda.is_available():
-                warnings.warn("CUDA not available")
-                pytest.skip()
-            elif torch.cuda.get_device_capability() < (9, 0):
-                warnings.warn(
-                    f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
-                )
-                pytest.skip()
-        x = torch.randn(*x_shape, device="cuda")
-        m_ref = nn.Linear(16, 32, bias=False, device="cuda")
-        self._test_linear_impl(
-            x,
-            m_ref,
-            emulate,
-            scaling_type_x,
-            scaling_type_w,
-            scaling_type_dL_dY,
-        )
-
-    @pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True])
-    @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
-    @pytest.mark.parametrize(
-        "scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
-    )
-    @pytest.mark.parametrize(
-        "scaling_type_w", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
-    )
-    @pytest.mark.parametrize(
-        "scaling_type_dL_dY", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
-    )
-    @pytest.mark.parametrize(
-        "linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
-    )
-    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
-    def test_linear_bias(
-        self,
-        x_shape,
-        scaling_type_x: TensorScalingType,
-        scaling_type_w: TensorScalingType,
-        scaling_type_dL_dY: TensorScalingType,
-        emulate: bool,
         linear_dtype: torch.dtype,
+        linear_bias: bool,
     ):
         if not emulate:
             if not torch.cuda.is_available():
@@ -295,7 +255,7 @@ def test_linear_bias(
                 )
                 pytest.skip()
         x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
-        m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype)
+        m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype)
         self._test_linear_impl(
             x,
             m_ref,
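For readers skimming the diff: the unification works because stacked @pytest.mark.parametrize decorators expand into the cross-product of their argument values, so a single test_linear parameterized by linear_bias (and linear_dtype) covers both the old bias and no-bias tests. Below is a minimal, self-contained sketch of that pattern; the test name, shapes, and assertions are illustrative and are not taken from test/test_base.py.

    import pytest
    import torch
    import torch.nn as nn

    # Sketch of the stacked-parametrize pattern used in the commit above.
    # Each (dtype, bias) combination becomes its own test case: 2 x 2 = 4 cases.
    @pytest.mark.parametrize("linear_dtype", [torch.bfloat16, torch.float32])
    @pytest.mark.parametrize("linear_bias", [False, True])
    def test_linear_sketch(linear_dtype, linear_bias):
        m_ref = nn.Linear(16, 32, bias=linear_bias, dtype=linear_dtype)
        x = torch.randn(4, 16, dtype=linear_dtype)
        y = m_ref(x)
        assert y.shape == (4, 32)
        assert y.dtype == linear_dtype

This mirrors how the removed test_linear_nobias / test_linear_bias pair collapses into one parameterized test in the diff.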
