@@ -232,58 +232,18 @@ def _test_linear_impl(
     @pytest.mark.parametrize(
         "scaling_type_dL_dY", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
     )
+    @pytest.mark.parametrize("linear_dtype", [torch.bfloat16, torch.float32])
+    @pytest.mark.parametrize("linear_bias", [False, True])
     @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
-    def test_linear_nobias(
+    def test_linear(
         self,
         x_shape,
         emulate: bool,
         scaling_type_x: TensorScalingType,
         scaling_type_w: TensorScalingType,
         scaling_type_dL_dY: TensorScalingType,
-    ):
-        if not emulate:
-            if not torch.cuda.is_available():
-                warnings.warn("CUDA not available")
-                pytest.skip()
-            elif torch.cuda.get_device_capability() < (9, 0):
-                warnings.warn(
-                    f"CUDA capability {torch.cuda.get_device_capability()} < (9.0)"
-                )
-                pytest.skip()
-        x = torch.randn(*x_shape, device="cuda")
-        m_ref = nn.Linear(16, 32, bias=False, device="cuda")
-        self._test_linear_impl(
-            x,
-            m_ref,
-            emulate,
-            scaling_type_x,
-            scaling_type_w,
-            scaling_type_dL_dY,
-        )
-
-    @pytest.mark.parametrize("emulate", [True, False] if is_H100 else [True])
-    @pytest.mark.parametrize("x_shape", [(16, 16), (2, 16, 16), (3, 2, 16, 16)])
-    @pytest.mark.parametrize(
-        "scaling_type_x", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
-    )
-    @pytest.mark.parametrize(
-        "scaling_type_w", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
-    )
-    @pytest.mark.parametrize(
-        "scaling_type_dL_dY", [TensorScalingType.DELAYED, TensorScalingType.DYNAMIC]
-    )
-    @pytest.mark.parametrize(
-        "linear_dtype", [torch.float16, torch.bfloat16, torch.float32]
-    )
-    @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available")
-    def test_linear_bias(
-        self,
-        x_shape,
-        scaling_type_x: TensorScalingType,
-        scaling_type_w: TensorScalingType,
-        scaling_type_dL_dY: TensorScalingType,
-        emulate: bool,
         linear_dtype: torch.dtype,
+        linear_bias: bool,
     ):
         if not emulate:
             if not torch.cuda.is_available():
@@ -295,7 +255,7 @@ def test_linear_bias(
                 )
                 pytest.skip()
         x = torch.randn(*x_shape, device="cuda", dtype=linear_dtype)
-        m_ref = nn.Linear(16, 32, bias=True, device="cuda", dtype=linear_dtype)
+        m_ref = nn.Linear(16, 32, bias=linear_bias, device="cuda", dtype=linear_dtype)
         self._test_linear_impl(
             x,
             m_ref,
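For context on the pattern used here: the diff collapses the separate test_linear_nobias and test_linear_bias tests into a single test_linear by stacking @pytest.mark.parametrize decorators, so pytest runs the cross product of the bias and dtype values. A minimal standalone sketch of the same idea follows; the test body is a hypothetical stand-in for the repository's _test_linear_impl, not code from this PR.

# Minimal sketch: stacked parametrize decorators make pytest run the
# cross product of all listed values, so one test function covers every
# (linear_bias, linear_dtype) combination.
import pytest
import torch
import torch.nn as nn


@pytest.mark.parametrize("linear_dtype", [torch.bfloat16, torch.float32])
@pytest.mark.parametrize("linear_bias", [False, True])
def test_linear_variants(linear_bias: bool, linear_dtype: torch.dtype):
    # Hypothetical stand-in for _test_linear_impl: check that a plain
    # nn.Linear runs in the requested configuration on CPU.
    m_ref = nn.Linear(16, 32, bias=linear_bias, dtype=linear_dtype)
    x = torch.randn(2, 16, dtype=linear_dtype)
    y = m_ref(x)
    assert y.shape == (2, 32)
    assert y.dtype == linear_dtype

Each combination appears as its own test case, so a single bias/dtype variant can still be selected with pytest's -k filter.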