@@ -310,13 +310,15 @@ def create_value_for(self, ref: ValueRefList) -> str: # noqa: C901
310
310
ret_str = f"std::vector<IOValueRef> { ref .name } _io_value_refs;\n "
311
311
ret_str += f"std::vector<ValueRef> { ref .name } _value_refs;\n "
312
312
ret_str += f"for (int i=0; i < { ref .src_cpp_name } .size(); i++) {{\n "
313
- ret_str += f" { cpp_type } io_value_ref = { self .graph } { self .dot } add_input_tensor(\n "
314
- ret_str += f" { ref .src_cpp_name } [i].sizes().vec(),\n "
315
313
ret_str += (
316
- f" from_at_scalartype( { ref . src_cpp_name } [i].scalar_type())); \n "
314
+ f" { cpp_type } io_value_ref = { self . graph } { self . dot } add_input_tensor( \n "
317
315
)
318
- ret_str += f" { ref .name } _value_refs.emplace_back(io_value_ref.value);\n "
319
- ret_str += f" { ref .name } _io_value_refs.emplace_back(io_value_ref);\n "
316
+ ret_str += f" { ref .src_cpp_name } [i].sizes().vec(),\n "
317
+ ret_str += (
318
+ f" from_at_scalartype({ ref .src_cpp_name } [i].scalar_type())); \n "
319
+ )
320
+ ret_str += f" { ref .name } _value_refs.emplace_back(io_value_ref.value);\n "
321
+ ret_str += f" { ref .name } _io_value_refs.emplace_back(io_value_ref);\n "
320
322
ret_str += "}\n "
321
323
ret_str += f"ValueRef { ref .name } = { self .graph } { self .dot } add_value_list(std::move({ ref .name } _value_refs));\n "
322
324
return ret_str
@@ -428,7 +430,9 @@ def virtual_resize(self, ref: ValueRefList) -> str:
428
430
elif ref .src_cpp_type == AT_TENSOR_LIST :
429
431
ret_str = ""
430
432
ret_str += f"for (int i=0; i < { ref .name } _io_value_refs.size(); i++) {{\n "
431
- ret_str += f" { self .graph } { self .dot } get_tensor({ ref .name } _io_value_refs[i].value)"
433
+ ret_str += (
434
+ f" { self .graph } { self .dot } get_tensor({ ref .name } _io_value_refs[i].value)"
435
+ )
432
436
ret_str += f"->virtual_resize({ ref .src_cpp_name } [i].sizes().vec());\n "
433
437
ret_str += "}\n "
434
438
else :
@@ -451,7 +455,7 @@ def copy_into_staging(self, ref: ValueRefList) -> str:
451
455
elif ref .src_cpp_type == AT_TENSOR_LIST :
452
456
ret_str = ""
453
457
ret_str += f"for (int i=0; i < { ref .name } _io_value_refs.size(); i++) {{\n "
454
- ret_str += f" { self .graph } { self .dot } copy_into_staging("
458
+ ret_str += f" { self .graph } { self .dot } copy_into_staging("
455
459
ret_str += f"{ ref .name } _io_value_refs[i].staging, "
456
460
ret_str += f"{ ref .src_cpp_name } [i].const_data_ptr(), "
457
461
ret_str += f"{ ref .src_cpp_name } [i].numel());\n "
@@ -522,7 +526,9 @@ def check_graph_out(self, ref: ValueRefList) -> str:
522
526
"""
523
527
return ret_str
524
528
525
- return f"EXPECT_TRUE(check_close({ ref .src_cpp_name } , vk_{ ref .name } , rtol, atol));\n "
529
+ return (
530
+ f"EXPECT_TRUE(check_close({ ref .src_cpp_name } , vk_{ ref .name } , rtol, atol));"
531
+ )
526
532
527
533
## Top level code generation
528
534
@@ -541,13 +547,11 @@ def gen_graph_build_code(self) -> str:
541
547
graph_build += f"{ self .graph } { self .dot } prepack();\n "
542
548
graph_build += f"{ self .graph } { self .dot } encode_execute();\n "
543
549
550
+ graph_build += "\n "
544
551
return graph_build
545
552
546
- def gen_graph_exec_code (self , loop_range : int = 1 ) -> str :
553
+ def gen_graph_exec_code (self ) -> str :
547
554
graph_exec = ""
548
- if loop_range > 1 :
549
- graph_exec += f"for (int i = 0; i < { loop_range } ; ++i) "
550
- graph_exec += "{\n "
551
555
for aten_arg in self .args :
552
556
ref = self .refs [aten_arg .name ]
553
557
if ref .is_in :
@@ -560,17 +564,20 @@ def gen_graph_exec_code(self, loop_range: int = 1) -> str:
560
564
graph_exec += self .declare_vk_out_for (self .refs ["out" ])
561
565
graph_exec += self .copy_from_staging (self .refs ["out" ])
562
566
graph_exec += self .check_graph_out (self .refs ["out" ])
563
- graph_exec += "}\n "
567
+
568
+ graph_exec = re .sub (r"^" , " " , graph_exec , flags = re .M )
569
+ graph_exec = "{\n " + graph_exec + "\n }"
564
570
565
571
return graph_exec
566
572
567
573
def gen_conditional_skips (self ) -> str :
568
574
fp16_skip = f"if (!{ self .graph } { self .dot } context()->adapter_ptr()->has_full_float16_buffers_support()) {{\n "
569
- fp16_skip += " GTEST_SKIP();"
570
- fp16_skip += "}\n "
575
+ fp16_skip += " GTEST_SKIP();\n "
576
+ fp16_skip += "}"
577
+ fp16_skip = re .sub (r"^" , " " , fp16_skip , flags = re .M ) + "\n "
571
578
572
579
int8_skip = f"if (!{ self .graph } { self .dot } context()->adapter_ptr()->has_full_int8_buffers_support()) {{\n "
573
- int8_skip += " GTEST_SKIP();"
580
+ int8_skip += " GTEST_SKIP();\n "
574
581
int8_skip += "}\n "
575
582
576
583
skips = ""
@@ -584,24 +591,24 @@ def gen_conditional_skips(self) -> str:
584
591
skips += int8_skip
585
592
continue
586
593
594
+ skips += "\n "
587
595
return skips
588
596
589
597
def gen_op_check_fn (self ) -> str :
590
598
op_name = self .f .func .name .unambiguous_name ()
591
599
op_check_fn = self .gen_decl (f"check_{ op_name } " ) + " {\n "
592
600
if self .should_prepack :
593
- op_check_fn = self .gen_decl (f"prepacked_check_{ op_name } " ) + " {"
601
+ op_check_fn = self .gen_decl (f"prepacked_check_{ op_name } " ) + " {\n "
594
602
595
603
op_check_fn_body = ""
596
604
op_check_fn_body += self .gen_conditional_skips ()
597
605
op_check_fn_body += self .gen_graph_build_code ()
598
606
op_check_fn_body += self .gen_graph_exec_code ()
599
607
600
- # Add two level of indent for readability
601
- op_check_fn_body = re .sub (r"^" , " " , op_check_fn_body , flags = re .M )
608
+ op_check_fn_body = re .sub (r"^" , " " , op_check_fn_body , flags = re .M )
602
609
603
- op_check_fn += op_check_fn_body + " \n "
604
- op_check_fn += " } \n "
610
+ op_check_fn += op_check_fn_body
611
+ op_check_fn += "\n } "
605
612
606
613
return op_check_fn
607
614
@@ -639,8 +646,6 @@ class GeneratedOpsTest_{op_name} : public ::testing::TestWithParam< ::std::tuple
639
646
}}
640
647
641
648
{check_fn}
642
-
643
- {prepacked_check_fn}
644
649
}};
645
650
"""
646
651
@@ -660,11 +665,12 @@ def generate_fixture_cpp(self) -> str:
660
665
if self .suite_def .supports_prepack ():
661
666
self .generator .should_prepack = True
662
667
prepacked_check_fn = self .generator .gen_op_check_fn ()
668
+ check_fn += "\n \n "
669
+ check_fn += prepacked_check_fn
663
670
664
671
return test_fixture_template .format (
665
672
op_name = self .op_name ,
666
673
check_fn = check_fn ,
667
- prepacked_check_fn = prepacked_check_fn ,
668
674
rtol = self .suite_def .rtol ,
669
675
atol = self .suite_def .atol ,
670
676
)
0 commit comments