@@ -1176,7 +1176,7 @@ def func(q, k, v, score_mod, block_mask):
     def test_aot_eager_gradcheck(self, score_mod):
         make_tensor = functools.partial(
             torch.randn,
-            (2, 2, 8, 4),
+            (2, 2, 128, 4),
             device="cuda",
             dtype=torch.float64,
             requires_grad=True,
@@ -1199,7 +1199,7 @@ def test_captured_score_mod_aot_eager_gradcheck(
     ):
         make_tensor = functools.partial(
             torch.randn,
-            (2, 2, 8, 4),
+            (2, 2, 128, 4),
             device="cuda",
             dtype=torch.float64,
             requires_grad=True,
@@ -1336,7 +1336,7 @@ def test_fw_bw_graph_correctness(self):
         cnt = CompileCounterWithBackend("aot_eager")
         make_tensor = functools.partial(
             torch.randn,
-            (2, 2, 8, 4),
+            (2, 2, 128, 4),
             device="cuda",
             dtype=torch.float64,
             requires_grad=True,
@@ -1355,7 +1355,7 @@ def test_fw_bw_graph_correctness(self):
             norm_graph,
             """\
 class GraphModule(torch.nn.Module):
-    def forward(self, L_args_0_: "f64[2, 2, 8, 4]", L_args_1_: "f64[2, 2, 8, 4]", L_args_2_: "f64[2, 2, 8, 4]"):
+    def forward(self, L_args_0_: "f64[2, 2, 128, 4]", L_args_1_: "f64[2, 2, 128, 4]", L_args_2_: "f64[2, 2, 128, 4]"):
         l_args_0_ = L_args_0_
         l_args_1_ = L_args_1_
         l_args_2_ = L_args_2_
@@ -1374,8 +1374,8 @@ def forward(self, L_args_0_: "f64[2, 2, 8, 4]", L_args_1_: "f64[2, 2, 8, 4]", L_
         child_3: "i32[]" = l_args_0_.new_empty([], dtype = torch.int32)
         child_4: "i32[]" = l_args_0_.new_empty([], dtype = torch.int32)
         flex_attention_0 = self.flex_attention_0
-        flex_attention = torch.ops.higher_order.flex_attention(l_args_0_, l_args_1_, l_args_2_, flex_attention_0, (ones, zeros, ones_1, zeros_1, 8, 8), 0.5); l_args_0_ = l_args_1_ = l_args_2_ = flex_attention_0 = ones = zeros = ones_1 = zeros_1 = None
-        out: "f64[2, 2, 8, 4]" = flex_attention[0]; flex_attention = None
+        flex_attention = torch.ops.higher_order.flex_attention(l_args_0_, l_args_1_, l_args_2_, flex_attention_0, (ones, zeros, ones_1, zeros_1, 128, 128), 0.5); l_args_0_ = l_args_1_ = l_args_2_ = flex_attention_0 = ones = zeros = ones_1 = zeros_1 = None
+        out: "f64[2, 2, 128, 4]" = flex_attention[0]; flex_attention = None
         return (out,)
 
     class GraphModule(torch.nn.Module):
@@ -1405,13 +1405,13 @@ def debug_compile_fx_inner(graph, example_inputs, *args, **kwargs):
             joint_graph,
             """\
 class GraphModule(torch.nn.Module):
-    def forward(self, primals_1: "f64[2, 2, 8, 4]", primals_2: "f64[2, 2, 8, 4]", primals_3: "f64[2, 2, 8, 4]", full_default: "i32[1, 1, 1]", full_default_1: "i32[1, 1, 1, 1]", getitem: "f64[2, 2, 8, 4]", getitem_1: "f32[2, 2, 8]", tangents_1: "f64[2, 2, 8, 4]"):
+    def forward(self, primals_1: "f64[2, 2, 128, 4]", primals_2: "f64[2, 2, 128, 4]", primals_3: "f64[2, 2, 128, 4]", full_default: "i32[1, 1, 1]", full_default_1: "i32[1, 1, 1, 1]", getitem: "f64[2, 2, 128, 4]", getitem_1: "f32[2, 2, 128]", tangents_1: "f64[2, 2, 128, 4]"):
         fw_graph = self.fw_graph
         joint_graph = self.joint_graph
-        flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, getitem, getitem_1, tangents_1, fw_graph, joint_graph, (full_default, full_default_1, full_default, full_default_1, 8, 8), 0.5); primals_1 = primals_2 = primals_3 = getitem = getitem_1 = tangents_1 = fw_graph = joint_graph = full_default = full_default_1 = None
-        getitem_2: "f64[2, 2, 8, 4]" = flex_attention_backward[0]
-        getitem_3: "f64[2, 2, 8, 4]" = flex_attention_backward[1]
-        getitem_4: "f64[2, 2, 8, 4]" = flex_attention_backward[2]; flex_attention_backward = None
+        flex_attention_backward = torch.ops.higher_order.flex_attention_backward(primals_1, primals_2, primals_3, getitem, getitem_1, tangents_1, fw_graph, joint_graph, (full_default, full_default_1, full_default, full_default_1, 128, 128), 0.5); primals_1 = primals_2 = primals_3 = getitem = getitem_1 = tangents_1 = fw_graph = joint_graph = full_default = full_default_1 = None
+        getitem_2: "f64[2, 2, 128, 4]" = flex_attention_backward[0]
+        getitem_3: "f64[2, 2, 128, 4]" = flex_attention_backward[1]
+        getitem_4: "f64[2, 2, 128, 4]" = flex_attention_backward[2]; flex_attention_backward = None
         return [getitem_2, getitem_3, getitem_4]
 
     class <lambda>(torch.nn.Module):
@@ -1429,6 +1429,29 @@ def forward(self, arg0_1: "f64[]", arg1_1: "i32[]", arg2_1: "i32[]", arg3_1: "i3
 """,  # noqa: B950
         )
 
+    @supported_platform
+    def test_nyi_for_non_divisible_seq_lens(self):
+        with self.assertRaisesRegex(
+            NotImplementedError, "NYI: L must be a multiple of 128"
+        ):
+            flex_attention(
+                torch.randn((2, 3, 4)),
+                torch.randn((2, 10, 5)),
+                torch.randn((2, 10, 5)),
+                score_mod=_identity,
+            )
+
+        with self.assertRaisesRegex(
+            NotImplementedError, "NYI: L must be a multiple of 128"
+        ):
+            compiled_flex = torch.compile(flex_attention)
+            compiled_flex(
+                torch.randn((2, 3, 4)),
+                torch.randn((2, 10, 5)),
+                torch.randn((2, 10, 5)),
+                score_mod=_identity,
+            )
+
 
 common_utils.instantiate_parametrized_tests(TestFlexAttention)
 
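Note (editor's sketch, not part of the PR): the updated tests and the new test_nyi_for_non_divisible_seq_lens pin down the constraint that the sequence length passed to flex_attention must be a multiple of 128, which is why the test shapes move from (2, 2, 8, 4) to (2, 2, 128, 4). The snippet below shows a call that satisfies that constraint; the import path assumes the public API of later PyTorch releases, and noop_score_mod is a hypothetical stand-in for the _identity helper used by the tests.

# Minimal sketch, assuming torch.nn.attention.flex_attention is available
# (later public API; the module location at the time of this PR may differ).
import torch
from torch.nn.attention.flex_attention import flex_attention


def noop_score_mod(score, b, h, q_idx, kv_idx):
    # Identity score_mod: leave the attention scores unchanged,
    # mirroring the _identity helper used in the tests above.
    return score


if torch.cuda.is_available():
    # Shapes follow the updated tests: (batch, heads, seq_len, head_dim)
    # with seq_len = 128, i.e. a multiple of 128.
    q = torch.randn(2, 2, 128, 4, device="cuda")
    k = torch.randn(2, 2, 128, 4, device="cuda")
    v = torch.randn(2, 2, 128, 4, device="cuda")
    out = flex_attention(q, k, v, score_mod=noop_score_mod)
    print(out.shape)  # torch.Size([2, 2, 128, 4]), same shape as the query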