@@ -47,24 +47,30 @@ def do_fp8_matmul(A, B, fp8_dtype, out_dtype):
     A_fp8 = A.to(fp8_dtype)
     B_fp8 = B.to(fp8_dtype).t()  # view

-    A_pad = pad_tensor_for_matmul(A_fp8)  # mem copy
-    B_pad = pad_tensor_for_matmul(B_fp8, both=True).contiguous().t()  # mem copy
+    scale_a = torch.tensor([1], device="cuda", dtype=torch.float32)
+    scale_b = torch.tensor([1], device="cuda", dtype=torch.float32)

-    return torch._scaled_mm(A_pad, B_pad, out_dtype=out_dtype)[0][
+    A_pad = pad_tensor_for_matmul(A_fp8, dims=1)  # mem copy
+    B_pad = pad_tensor_for_matmul(B_fp8, dims=[0, 1]).contiguous().t()  # mem copy
+
+    return torch._scaled_mm(A_pad, B_pad, scale_a, scale_b, out_dtype=out_dtype)[
         : A.shape[0], : B.shape[1]
     ]


 def do_fp8_pad_first_matmul(A, B, fp8_dtype, out_dtype):
-    A_pad = pad_tensor_for_matmul(A)  # mem copy
-    B_pad = pad_tensor_for_matmul(B, both=True)  # mem copy
+    A_pad = pad_tensor_for_matmul(A, dims=1)  # mem copy
+    B_pad = pad_tensor_for_matmul(B, dims=[0, 1])  # mem copy
+
+    scale_a = torch.tensor([1], device="cuda", dtype=torch.float32)
+    scale_b = torch.tensor([1], device="cuda", dtype=torch.float32)

     A_pad = A_pad.to(fp8_dtype)  # mem copy
     B_pad = B_pad.to(fp8_dtype)  # mem copy

     B_pad = B_pad.t().contiguous().t()  # mem copy

-    return torch._scaled_mm(A_pad, B_pad, out_dtype=out_dtype)[0][
+    return torch._scaled_mm(A_pad, B_pad, scale_a, scale_b, out_dtype=out_dtype)[
         : A.shape[0], : B.shape[1]
     ]

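The pad_tensor_for_matmul calls in this hunk take a dims argument instead of the old both=True flag. As a rough illustration only, a helper with that shape could look like the sketch below; pad_dims_to_multiple_of_16 is a hypothetical stand-in (the real pad_tensor_for_matmul lives in this repo and may differ), and the only behavior shown is rounding the requested dimensions up to the next multiple of 16, the alignment torch._scaled_mm expects.

import torch
import torch.nn.functional as F

def pad_dims_to_multiple_of_16(x: torch.Tensor, dims) -> torch.Tensor:
    # Illustrative stand-in for pad_tensor_for_matmul(x, dims=...):
    # zero-pad the requested dims up to the next multiple of 16.
    if isinstance(dims, int):
        dims = [dims]
    # F.pad lists (left, right) pairs starting from the *last* dimension.
    pad = [0, 0] * x.dim()
    for d in dims:
        pad[2 * (x.dim() - 1 - d) + 1] = (-x.shape[d]) % 16
    return F.pad(x, pad)

# e.g. a (64, 255) tensor becomes (64, 256) for dims=1, and stays (64, 256)
# for dims=[0, 1] since dim 0 is already a multiple of 16.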
@@ -86,8 +92,8 @@ def __iter__(self):


 def gen_configs():
-    shapes = [(8192, 2500, 5000), (4096, 10, 4096)]
-    output_dtype = torch.float32
+    shapes = [(8192, 2500, 5000), (64, 255, 4096)]
+    output_dtype = torch.bfloat16
     fp8_dtype = torch.float8_e4m3fn
     return [Experiment_config(*shape, output_dtype, fp8_dtype) for shape in shapes]

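For context on the new call signature: both matmul paths now pass per-tensor scales positionally, and the [0] indexing is dropped because torch._scaled_mm returns the result tensor directly in the PyTorch build this change targets. The sketch below is a minimal, hedged usage example, not this benchmark's code; it assumes an FP8-capable CUDA GPU, a PyTorch version with this private-API signature, and operand dimensions already padded to multiples of 16 (which is why gen_configs includes the non-aligned (64, 255, 4096) shape to exercise the padding paths above).

import torch

M, K, N = 64, 256, 4096  # all multiples of 16, as _scaled_mm requires

A = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)
# The second operand must be column-major, hence building it as (N, K) and transposing.
B = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).t()

# Per-tensor scales of 1.0, mirroring the benchmark above.
scale_a = torch.tensor([1], device="cuda", dtype=torch.float32)
scale_b = torch.tensor([1], device="cuda", dtype=torch.float32)

out = torch._scaled_mm(A, B, scale_a, scale_b, out_dtype=torch.bfloat16)
print(out.shape)  # torch.Size([64, 4096])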