Skip to content

Commit 6f06316

Browse files
jorgep31415 authored and facebook-github-bot committed
Standardize 2-space tabs in codegen (#4043)
Summary: Pull Request resolved: #4043 for consistency with the non-generated ET-VK code. The new perf codegen will build off this existing correctness codegen and I want the generated code to be readable. ghstack-source-id: 231368098 exported-using-ghexport bypass-github-export-checks bypass-github-pytorch-ci-checks bypass-github-executorch-ci-checks Reviewed By: copyrightly Differential Revision: D58954596 fbshipit-source-id: 11fab7c2c8a3d426bcae143a6c5e5a99aa34a2ab
1 parent 39e17e4 commit 6f06316

File tree

2 files changed

+72
-75
lines changed

2 files changed

+72
-75
lines changed

backends/vulkan/test/op_tests/utils/codegen.py

Lines changed: 60 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -612,36 +612,35 @@ def gen_op_check_fn(self) -> str:
612612

613613
test_fixture_template = """
614614
class GeneratedOpsTest_{op_name} : public ::testing::TestWithParam< ::std::tuple<at::ScalarType, api::StorageType, api::GPUMemoryLayout>> {{
615-
protected:
616-
ComputeGraph* graph;
617-
at::ScalarType test_dtype = at::kFloat;
618-
float rtol = {rtol};
619-
float atol = {atol};
620-
621-
void SetUp() override {{
622-
GraphConfig config;
623-
api::StorageType default_storage_type;
624-
api::GPUMemoryLayout default_memory_layout;
625-
std::tie(test_dtype, default_storage_type, default_memory_layout) = GetParam();
626-
config.set_storage_type_override(default_storage_type);
627-
config.set_memory_layout_override(default_memory_layout);
628-
graph = new ComputeGraph(config);
629-
630-
if (test_dtype == at::kHalf) {{
631-
rtol = 1e-2;
632-
atol = 1e-2;
633-
}}
615+
protected:
616+
ComputeGraph* graph;
617+
at::ScalarType test_dtype = at::kFloat;
618+
float rtol = {rtol};
619+
float atol = {atol};
620+
621+
void SetUp() override {{
622+
GraphConfig config;
623+
api::StorageType default_storage_type;
624+
api::GPUMemoryLayout default_memory_layout;
625+
std::tie(test_dtype, default_storage_type, default_memory_layout) = GetParam();
626+
config.set_storage_type_override(default_storage_type);
627+
config.set_memory_layout_override(default_memory_layout);
628+
graph = new ComputeGraph(config);
629+
630+
if (test_dtype == at::kHalf) {{
631+
rtol = 1e-2;
632+
atol = 1e-2;
634633
}}
634+
}}
635635
636-
void TearDown() override {{
637-
delete graph;
638-
graph = nullptr;
639-
}}
640-
641-
{check_fn}
636+
void TearDown() override {{
637+
delete graph;
638+
graph = nullptr;
639+
}}
642640
643-
{prepacked_check_fn}
641+
{check_fn}
644642
643+
{prepacked_check_fn}
645644
}};
646645
"""
647646

@@ -676,13 +675,13 @@ def gen_parameterization(self) -> str:
676675
layouts = self.suite_def.layouts
677676

678677
return f"""
679-
INSTANTIATE_TEST_SUITE_P(
680-
Combos_{self.op_name},
681-
GeneratedOpsTest_{self.op_name},
682-
::testing::Combine(
683-
::testing::Values({', '.join(dtypes)}),
684-
::testing::Values({', '.join(storage_types)}),
685-
::testing::Values({', '.join(layouts)})));
678+
INSTANTIATE_TEST_SUITE_P(
679+
Combos_{self.op_name},
680+
GeneratedOpsTest_{self.op_name},
681+
::testing::Combine(
682+
::testing::Values({', '.join(dtypes)}),
683+
::testing::Values({', '.join(storage_types)}),
684+
::testing::Values({', '.join(layouts)})));
686685
"""
687686

688687

@@ -701,41 +700,41 @@ def gen_parameterization(self) -> str:
701700
using TensorOptions = at::TensorOptions;
702701
703702
api::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) {
704-
switch(at_scalartype) {
705-
case c10::kFloat:
706-
return api::kFloat;
707-
case c10::kHalf:
708-
return api::kHalf;
709-
case c10::kInt:
710-
return api::kInt;
711-
case c10::kLong:
712-
return api::kInt;
713-
case c10::kChar:
714-
return api::kChar;
715-
default:
716-
VK_THROW("Unsupported at::ScalarType!");
717-
}
703+
switch (at_scalartype) {
704+
case c10::kFloat:
705+
return api::kFloat;
706+
case c10::kHalf:
707+
return api::kHalf;
708+
case c10::kInt:
709+
return api::kInt;
710+
case c10::kLong:
711+
return api::kInt;
712+
case c10::kChar:
713+
return api::kChar;
714+
default:
715+
VK_THROW("Unsupported at::ScalarType!");
716+
}
718717
}
719718
720719
#ifdef USE_VULKAN_FP16_INFERENCE
721720
bool check_close(at::Tensor& t1, at::Tensor& t2, float rtol=1e-2, float atol=1e-2) {
722721
#else
723722
bool check_close(at::Tensor& t1, at::Tensor& t2, float rtol=1e-5, float atol=1e-5) {
724723
#endif
725-
// Skip checking index tensors
726-
if (t1.scalar_type() == at::kLong || t2.scalar_type() == at::kLong) {
727-
return true;
728-
}
729-
bool is_close = at::allclose(t1, t2, rtol, atol);
730-
if (!is_close && t1.numel() < 500) {
731-
std::cout << "reference: " << std::endl;
732-
print(t1, 150);
733-
std::cout << std::endl;
734-
std::cout << "vulkan: " << std::endl;
735-
print(t2, 150);
736-
std::cout << std::endl;
737-
}
738-
return is_close;
724+
// Skip checking index tensors
725+
if (t1.scalar_type() == at::kLong || t2.scalar_type() == at::kLong) {
726+
return true;
727+
}
728+
bool is_close = at::allclose(t1, t2, rtol, atol);
729+
if (!is_close && t1.numel() < 500) {
730+
std::cout << "reference: " << std::endl;
731+
print(t1, 150);
732+
std::cout << std::endl;
733+
std::cout << "vulkan: " << std::endl;
734+
print(t2, 150);
735+
std::cout << std::endl;
736+
}
737+
return is_close;
739738
}
740739
"""
741740

backends/vulkan/test/op_tests/utils/codegen_base.py

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ def gen_create_ref_data(self, inputs: List[Any]) -> str:
243243
arg_data = get_or_return_default(arg, inputs, i)
244244
ref_code += self.create_input_data(arg, arg_data)
245245

246-
ref_code = re.sub(r"^", " ", ref_code, flags=re.M)
246+
ref_code = re.sub(r"^", " ", ref_code, flags=re.M)
247247
return ref_code
248248

249249
def gen_create_and_check_out(self, prepack=False) -> str:
@@ -254,7 +254,7 @@ def gen_create_and_check_out(self, prepack=False) -> str:
254254
arg = binding.argument
255255
test_str += f"{arg.name}, "
256256
test_str = test_str[:-2] + ");"
257-
test_str = re.sub(r"^", " ", test_str, flags=re.M)
257+
test_str = re.sub(r"^", " ", test_str, flags=re.M)
258258
return test_str
259259

260260
def gen_parameterization(self) -> str:
@@ -272,7 +272,7 @@ def generate_case_cpp(self, inputs, prepack=False) -> str:
272272
)
273273

274274
def generate_suite_cpp(self) -> str:
275-
suite_cpp = self.generate_fixture_cpp() + "\n"
275+
suite_cpp = self.generate_fixture_cpp()
276276
for inputs in self.suite_def.input_cases:
277277
if not self.suite_def.requires_prepack:
278278
suite_cpp += self.generate_case_cpp(inputs)
@@ -295,20 +295,19 @@ def generate_suite_cpp(self) -> str:
295295
{preamble}
296296
297297
at::Tensor make_rand_tensor(
298-
std::vector<int64_t> sizes,
299-
at::ScalarType dtype = at::kFloat,
300-
float low = 0.0,
301-
float high = 1.0) {{
302-
if (high == 1.0 && low == 0.0)
303-
return at::rand(sizes, at::device(at::kCPU).dtype(dtype));
298+
std::vector<int64_t> sizes,
299+
at::ScalarType dtype = at::kFloat,
300+
float low = 0.0,
301+
float high = 1.0) {{
302+
if (high == 1.0 && low == 0.0)
303+
return at::rand(sizes, at::device(at::kCPU).dtype(dtype));
304304
305-
if (dtype == at::kChar)
306-
return at::randint(high, sizes, at::device(at::kCPU).dtype(dtype));
305+
if (dtype == at::kChar)
306+
return at::randint(high, sizes, at::device(at::kCPU).dtype(dtype));
307307
308-
return at::rand(sizes, at::device(at::kCPU).dtype(dtype)) * (high - low) + low;
308+
return at::rand(sizes, at::device(at::kCPU).dtype(dtype)) * (high - low) + low;
309309
}}
310310
311-
312311
at::Tensor make_seq_tensor(
313312
std::vector<int64_t> sizes,
314313
at::ScalarType dtype = at::kFloat,
@@ -331,7 +330,6 @@ def generate_suite_cpp(self) -> str:
331330
return at::from_blob(values.data(), sizes, at::kFloat).toType(dtype).detach().clone();
332331
}}
333332
334-
335333
at::Tensor make_index_tensor(std::vector<int64_t> indices) {{
336334
at::ScalarType dtype = at::kInt;
337335
std::vector<int64_t> sizes = {{static_cast<int64_t>(indices.size())}};

0 commit comments

Comments
 (0)