@@ -645,31 +645,32 @@ def _test_qd8_per_token_weight_per_channel_group_int4(
645
645
bl_sizes = [32 , 32 , 32 , 64 ]
646
646
N_sizes = [2 , 17 , 92 , 128 ]
647
647
648
- for use_bias in [True , False ]:
649
- for M , K , bl , N in zip (M_sizes , K_sizes , bl_sizes , N_sizes ):
650
- lin_mod = BaseLinear (
651
- in_size = M ,
652
- input_channels = K ,
653
- output_channels = N ,
654
- dtype = dtype ,
655
- use_bias = use_bias ,
656
- )
648
+ for input_rank in range (2 , 4 ):
649
+ for use_bias in [True , False ]:
650
+ for M , K , bl , N in zip (M_sizes , K_sizes , bl_sizes , N_sizes ):
651
+ lin_mod = BaseLinear (
652
+ in_size = M ,
653
+ input_channels = K ,
654
+ output_channels = N ,
655
+ dtype = dtype ,
656
+ use_bias = use_bias ,
657
+ )
657
658
658
- inputs = lin_mod .get_inputs ()
659
- # Half requires slightly higher atol, but if you look at error it is not that bad:
660
- # Difference: max: 0.00140380859375, abs: 0.00140380859375, mean abs error: 0.00042724609375.
661
- # -- Model vs. Reference --
662
- # Numel: 4, 4
663
- # Median: -0.05023193359375, -0.0516357421875
664
- # Mean: 0.2373046875, 0.237060546875
665
- # Max: 1.0078125, 1.0078125
666
- # Min: -0.08465576171875, -0.08441162109375
667
- atol = (
668
- 1e-2 if dtype == torch .half else 5e-3
669
- ) # TODO(T212995726): Investigate right atol for rand[n] inputs
670
- self ._test_groupwise_dq_linear (
671
- lin_mod , inputs , group_size = bl , use_bias = use_bias , atol = atol
672
- )
659
+ inputs = lin_mod .get_inputs (rank = input_rank )
660
+ # Half requires slightly higher atol, but if you look at error it is not that bad:
661
+ # Difference: max: 0.00140380859375, abs: 0.00140380859375, mean abs error: 0.00042724609375.
662
+ # -- Model vs. Reference --
663
+ # Numel: 4, 4
664
+ # Median: -0.05023193359375, -0.0516357421875
665
+ # Mean: 0.2373046875, 0.237060546875
666
+ # Max: 1.0078125, 1.0078125
667
+ # Min: -0.08465576171875, -0.08441162109375
668
+ atol = (
669
+ 1e-2 if dtype == torch .half else 5e-3
670
+ ) # TODO(T212995726): Investigate right atol for rand[n] inputs
671
+ self ._test_groupwise_dq_linear (
672
+ lin_mod , inputs , group_size = bl , use_bias = use_bias , atol = atol
673
+ )
673
674
674
675
def test_fp16_linear (self ):
675
676
for use_bias in (True , False ):
0 commit comments