
Commit 6b3de99

jorgep31415 authored and facebook-github-bot committed
Stylize check_fn in codegen (#4044)
Summary:
Pull Request resolved: #4044

Separate change because this one is real messy.

## Before
```
void check_add_Tensor(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1) {
    if (test_dtype == at::kHalf) {
    if (!graph->context()->adapter_ptr()->has_full_float16_buffers_support()) {
      GTEST_SKIP();}
    }
    at::Tensor out = at::add(self, other, alpha);
    IOValueRef self_ref = graph->add_input_tensor(self.sizes().vec(), from_at_scalartype(self.scalar_type()));
    IOValueRef other_ref = graph->add_input_tensor(other.sizes().vec(), from_at_scalartype(other.scalar_type()));
    ValueRef alpha_ref = graph->add_scalar<double>(alpha.toDouble());
    ValueRef out_ref = graph->add_tensor(out.sizes().vec(), from_at_scalartype(out.scalar_type()));
    VK_GET_OP_FN("aten.add.Tensor")(*graph, {self_ref.value, other_ref.value, alpha_ref, out_ref});
    ValueRef out_ref_staging = graph->set_output_tensor(out_ref);
    graph->prepare();
    graph->encode_prepack();
    graph->prepack();
    graph->encode_execute();
    {
    graph->get_tensor(self_ref.value)->virtual_resize(self.sizes().vec());
    graph->copy_into_staging(self_ref.staging, self.const_data_ptr(), self.numel());
    graph->get_tensor(other_ref.value)->virtual_resize(other.sizes().vec());
    graph->copy_into_staging(other_ref.staging, other.const_data_ptr(), other.numel());
    graph->propagate_resize();
    graph->execute();
    at::Tensor vk_out_ref = at::empty_like(out).contiguous();
    graph->copy_from_staging(out_ref_staging, vk_out_ref.mutable_data_ptr(), vk_out_ref.numel());
    EXPECT_TRUE(check_close(out, vk_out_ref, rtol, atol));
    }

  }
```

## After
```
void check_add_Tensor(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha=1) {
  if (test_dtype == at::kHalf) {
    if (!graph->context()->adapter_ptr()->has_full_float16_buffers_support()) {
      GTEST_SKIP();
    }
  }

  at::Tensor out = at::add(self, other, alpha);
  IOValueRef self_ref = graph->add_input_tensor(self.sizes().vec(), from_at_scalartype(self.scalar_type()));
  IOValueRef other_ref = graph->add_input_tensor(other.sizes().vec(), from_at_scalartype(other.scalar_type()));
  ValueRef alpha_ref = graph->add_scalar<double>(alpha.toDouble());
  ValueRef out_ref = graph->add_tensor(out.sizes().vec(), from_at_scalartype(out.scalar_type()));
  VK_GET_OP_FN("aten.add.Tensor")(*graph, {self_ref.value, other_ref.value, alpha_ref, out_ref});
  ValueRef out_ref_staging = graph->set_output_tensor(out_ref);
  graph->prepare();
  graph->encode_prepack();
  graph->prepack();
  graph->encode_execute();

  {
    graph->get_tensor(self_ref.value)->virtual_resize(self.sizes().vec());
    graph->copy_into_staging(self_ref.staging, self.const_data_ptr(), self.numel());
    graph->get_tensor(other_ref.value)->virtual_resize(other.sizes().vec());
    graph->copy_into_staging(other_ref.staging, other.const_data_ptr(), other.numel());
    graph->propagate_resize();
    graph->execute();
    at::Tensor vk_out_ref = at::empty_like(out).contiguous();
    graph->copy_from_staging(out_ref_staging, vk_out_ref.mutable_data_ptr(), vk_out_ref.numel());
    EXPECT_TRUE(check_close(out, vk_out_ref, rtol, atol));
  }
}
```

bypass-github-export-checks
bypass-github-pytorch-ci-checks
bypass-github-executorch-ci-checks

Reviewed By: copyrightly

Differential Revision: D58954595

fbshipit-source-id: 4112697eb25bdb65b77cbaa0b5d48142d782444d
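The mechanics behind the cleanup, visible in the diff below, are one small string idiom: build a block, indent every line with `re.sub` in multiline mode, then wrap it in braces. A minimal standalone sketch of that pattern — `indent` and `check_op` are illustrative names here, not the actual codegen API:

```python
import re

def indent(block: str, prefix: str = "  ") -> str:
    # With re.M, "^" matches at the start of every line, so this
    # prepends `prefix` to each line of the block.
    return re.sub(r"^", prefix, block, flags=re.M)

# Build the execution section, then indent it and wrap it in braces,
# mirroring what the new gen_graph_exec_code does.
graph_exec = "graph->propagate_resize();\ngraph->execute();"
graph_exec = "{\n" + indent(graph_exec) + "\n}"

# The whole body is indented once more before the closing brace is
# appended, mirroring gen_op_check_fn.
print("void check_op() {\n" + indent(graph_exec) + "\n}")
# void check_op() {
#   {
#     graph->propagate_resize();
#     graph->execute();
#   }
# }
```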
1 parent 6f06316 commit 6b3de99

File tree

1 file changed (+30, -24 lines)

backends/vulkan/test/op_tests/utils/codegen.py

Lines changed: 30 additions & 24 deletions
```diff
@@ -310,13 +310,15 @@ def create_value_for(self, ref: ValueRefList) -> str:  # noqa: C901
             ret_str = f"std::vector<IOValueRef> {ref.name}_io_value_refs;\n"
             ret_str += f"std::vector<ValueRef> {ref.name}_value_refs;\n"
             ret_str += f"for (int i=0; i < {ref.src_cpp_name}.size(); i++) {{\n"
-            ret_str += f"  {cpp_type} io_value_ref = {self.graph}{self.dot}add_input_tensor(\n"
-            ret_str += f"    {ref.src_cpp_name}[i].sizes().vec(),\n"
             ret_str += (
-                f"    from_at_scalartype({ref.src_cpp_name}[i].scalar_type())); \n"
+                f"  {cpp_type} io_value_ref = {self.graph}{self.dot}add_input_tensor(\n"
             )
-            ret_str += f"  {ref.name}_value_refs.emplace_back(io_value_ref.value);\n"
-            ret_str += f"  {ref.name}_io_value_refs.emplace_back(io_value_ref);\n"
+            ret_str += f"    {ref.src_cpp_name}[i].sizes().vec(),\n"
+            ret_str += (
+                f"    from_at_scalartype({ref.src_cpp_name}[i].scalar_type())); \n"
+            )
+            ret_str += f"  {ref.name}_value_refs.emplace_back(io_value_ref.value);\n"
+            ret_str += f"  {ref.name}_io_value_refs.emplace_back(io_value_ref);\n"
             ret_str += "}\n"
             ret_str += f"ValueRef {ref.name} = {self.graph}{self.dot}add_value_list(std::move({ref.name}_value_refs));\n"
             return ret_str
```
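Before and after, this branch emits identical C++; only the Python line-wrapping changes (presumably to keep the generator's lines within the formatter's limit). Run standalone with stand-in values for the generator's state, the f-strings above produce something like:

```python
# Stand-ins for the generator's state -- self.graph, self.dot, and the
# ref fields are attributes on the real generator class, not globals.
graph, dot = "graph", "->"
name = src_cpp_name = "others"
cpp_type = "IOValueRef"

ret_str = f"std::vector<IOValueRef> {name}_io_value_refs;\n"
ret_str += f"std::vector<ValueRef> {name}_value_refs;\n"
ret_str += f"for (int i=0; i < {src_cpp_name}.size(); i++) {{\n"
ret_str += f"  {cpp_type} io_value_ref = {graph}{dot}add_input_tensor(\n"
ret_str += f"    {src_cpp_name}[i].sizes().vec(),\n"
ret_str += f"    from_at_scalartype({src_cpp_name}[i].scalar_type())); \n"
ret_str += f"  {name}_value_refs.emplace_back(io_value_ref.value);\n"
ret_str += f"  {name}_io_value_refs.emplace_back(io_value_ref);\n"
ret_str += "}\n"
print(ret_str)
# std::vector<IOValueRef> others_io_value_refs;
# std::vector<ValueRef> others_value_refs;
# for (int i=0; i < others.size(); i++) {
#   IOValueRef io_value_ref = graph->add_input_tensor(
#     others[i].sizes().vec(),
#     from_at_scalartype(others[i].scalar_type()));
#   others_value_refs.emplace_back(io_value_ref.value);
#   others_io_value_refs.emplace_back(io_value_ref);
# }
```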
```diff
@@ -428,7 +430,9 @@ def virtual_resize(self, ref: ValueRefList) -> str:
         elif ref.src_cpp_type == AT_TENSOR_LIST:
             ret_str = ""
             ret_str += f"for (int i=0; i < {ref.name}_io_value_refs.size(); i++) {{\n"
-            ret_str += f"  {self.graph}{self.dot}get_tensor({ref.name}_io_value_refs[i].value)"
+            ret_str += (
+                f"  {self.graph}{self.dot}get_tensor({ref.name}_io_value_refs[i].value)"
+            )
             ret_str += f"->virtual_resize({ref.src_cpp_name}[i].sizes().vec());\n"
             ret_str += "}\n"
         else:
@@ -451,7 +455,7 @@ def copy_into_staging(self, ref: ValueRefList) -> str:
         elif ref.src_cpp_type == AT_TENSOR_LIST:
             ret_str = ""
             ret_str += f"for (int i=0; i < {ref.name}_io_value_refs.size(); i++) {{\n"
-            ret_str += f" {self.graph}{self.dot}copy_into_staging("
+            ret_str += f"  {self.graph}{self.dot}copy_into_staging("
             ret_str += f"{ref.name}_io_value_refs[i].staging, "
             ret_str += f"{ref.src_cpp_name}[i].const_data_ptr(), "
             ret_str += f"{ref.src_cpp_name}[i].numel());\n"
@@ -522,7 +526,9 @@ def check_graph_out(self, ref: ValueRefList) -> str:
             """
             return ret_str
 
-        return f"EXPECT_TRUE(check_close({ref.src_cpp_name}, vk_{ref.name}, rtol, atol));\n"
+        return (
+            f"EXPECT_TRUE(check_close({ref.src_cpp_name}, vk_{ref.name}, rtol, atol));"
+        )
 
     ## Top level code generation
 
@@ -541,13 +547,11 @@ def gen_graph_build_code(self) -> str:
         graph_build += f"{self.graph}{self.dot}prepack();\n"
         graph_build += f"{self.graph}{self.dot}encode_execute();\n"
 
+        graph_build += "\n"
         return graph_build
 
-    def gen_graph_exec_code(self, loop_range: int = 1) -> str:
+    def gen_graph_exec_code(self) -> str:
         graph_exec = ""
-        if loop_range > 1:
-            graph_exec += f"for (int i = 0; i < {loop_range} ; ++i) "
-            graph_exec += "{\n"
         for aten_arg in self.args:
             ref = self.refs[aten_arg.name]
             if ref.is_in:
@@ -560,17 +564,20 @@ def gen_graph_exec_code(self, loop_range: int = 1) -> str:
         graph_exec += self.declare_vk_out_for(self.refs["out"])
         graph_exec += self.copy_from_staging(self.refs["out"])
         graph_exec += self.check_graph_out(self.refs["out"])
-        graph_exec += "}\n"
+
+        graph_exec = re.sub(r"^", "  ", graph_exec, flags=re.M)
+        graph_exec = "{\n" + graph_exec + "\n}"
 
         return graph_exec
 
     def gen_conditional_skips(self) -> str:
         fp16_skip = f"if (!{self.graph}{self.dot}context()->adapter_ptr()->has_full_float16_buffers_support()) {{\n"
-        fp16_skip += "  GTEST_SKIP();"
-        fp16_skip += "}\n"
+        fp16_skip += "  GTEST_SKIP();\n"
+        fp16_skip += "}"
+        fp16_skip = re.sub(r"^", "  ", fp16_skip, flags=re.M) + "\n"
 
         int8_skip = f"if (!{self.graph}{self.dot}context()->adapter_ptr()->has_full_int8_buffers_support()) {{\n"
-        int8_skip += "  GTEST_SKIP();"
+        int8_skip += "  GTEST_SKIP();\n"
         int8_skip += "}\n"
 
         skips = ""
@@ -584,24 +591,24 @@ def gen_conditional_skips(self) -> str:
                 skips += int8_skip
                 continue
 
+        skips += "\n"
         return skips
 
     def gen_op_check_fn(self) -> str:
         op_name = self.f.func.name.unambiguous_name()
         op_check_fn = self.gen_decl(f"check_{op_name}") + " {\n"
         if self.should_prepack:
-            op_check_fn = self.gen_decl(f"prepacked_check_{op_name}") + " {"
+            op_check_fn = self.gen_decl(f"prepacked_check_{op_name}") + " {\n"
 
         op_check_fn_body = ""
         op_check_fn_body += self.gen_conditional_skips()
         op_check_fn_body += self.gen_graph_build_code()
         op_check_fn_body += self.gen_graph_exec_code()
 
-        # Add two level of indent for readability
-        op_check_fn_body = re.sub(r"^", "    ", op_check_fn_body, flags=re.M)
+        op_check_fn_body = re.sub(r"^", "  ", op_check_fn_body, flags=re.M)
 
-        op_check_fn += op_check_fn_body + "\n"
-        op_check_fn += "  }\n"
+        op_check_fn += op_check_fn_body
+        op_check_fn += "\n}"
 
         return op_check_fn
 
@@ -639,8 +646,6 @@ class GeneratedOpsTest_{op_name} : public ::testing::TestWithParam< ::std::tuple
   }}
 
   {check_fn}
-
-  {prepacked_check_fn}
 }};
 """
 
@@ -660,11 +665,12 @@ def generate_fixture_cpp(self) -> str:
         if self.suite_def.supports_prepack():
            self.generator.should_prepack = True
            prepacked_check_fn = self.generator.gen_op_check_fn()
+           check_fn += "\n\n  "
+           check_fn += prepacked_check_fn
 
         return test_fixture_template.format(
            op_name=self.op_name,
            check_fn=check_fn,
-           prepacked_check_fn=prepacked_check_fn,
            rtol=self.suite_def.rtol,
            atol=self.suite_def.atol,
         )
```
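As a sanity check of the new skip generation, the fp16 branch can be run standalone. The outer dtype guard is paraphrased from the Before/After samples above (that part of gen_conditional_skips is not shown in this diff), and `graph`/`->` stand in for `self.graph`/`self.dot`:

```python
import re

graph, dot = "graph", "->"

# New-style fp16 skip: newline after GTEST_SKIP(), closing brace on its
# own line, then the whole block indented one level, per the diff above.
fp16_skip = f"if (!{graph}{dot}context()->adapter_ptr()->has_full_float16_buffers_support()) {{\n"
fp16_skip += "  GTEST_SKIP();\n"
fp16_skip += "}"
fp16_skip = re.sub(r"^", "  ", fp16_skip, flags=re.M) + "\n"

# Outer dtype guard, paraphrased from the generated samples above.
skips = "if (test_dtype == at::kHalf) {\n" + fp16_skip + "}\n"
print(skips)
# if (test_dtype == at::kHalf) {
#   if (!graph->context()->adapter_ptr()->has_full_float16_buffers_support()) {
#     GTEST_SKIP();
#   }
# }
```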
