Skip to content

Commit 96166da

Browse files
SS-JIAfacebook-github-bot
authored andcommitted
Replace c10::ScalarType with native equivalent
Summary: X-link: pytorch/pytorch#117181 ## Context This change is part of a set of changes that removes all references to the `c10` library in the `api/`, `graph/`, and `impl/` folders of the PyTorch Vulkan codebase. This is to ensure that these components can be built as a standalone library such that they can be used as the foundations of a Android GPU delegate for ExecuTorch. ## Notes for Reviewers This changeset introduces `api::ScalarType` in `api/Types.h`, which is intended to function the same as `c10::ScalarType`; thus `api/Types.h` is the primary file of interest. The rest of the changes are straightforward replacements of `c10::ScalarType` with `api::ScalarType`. ghstack-source-id: 211899066 exported-using-ghexport Reviewed By: yipjustin, liuk22 Differential Revision: D52662237 fbshipit-source-id: 824b59c4595731562a64de798f4f36c9ac6065c1
1 parent 7a6abb4 commit 96166da

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

backends/vulkan/VulkanBackend.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,11 @@ class VulkanBackend final : public PyTorchBackendInterface {
5151
}
5252
}
5353

54-
c10::ScalarType get_scalar_type(
54+
at::native::vulkan::api::ScalarType get_scalar_type(
5555
const at::vulkan::delegate::VkDatatype& vk_datatype) const {
5656
switch (vk_datatype) {
5757
case (at::vulkan::delegate::VkDatatype::vk_datatype_fp32): {
58-
return c10::kFloat;
58+
return at::native::vulkan::api::kFloat;
5959
}
6060
}
6161
}
@@ -87,7 +87,7 @@ class VulkanBackend final : public PyTorchBackendInterface {
8787
"Only constant buffers are supported when adding tensors to compute graph (indicated by constant_buffer_idx == 0), but got constant_buffer_idx of %d",
8888
vk_tensor->constant_buffer_idx());
8989

90-
const c10::ScalarType& tensor_dtype =
90+
const at::native::vulkan::api::ScalarType& tensor_dtype =
9191
get_scalar_type(vk_tensor->datatype());
9292

9393
const flatbuffers_fbsource::Vector<uint32_t>* tensor_dims_fb =
@@ -177,7 +177,7 @@ class VulkanBackend final : public PyTorchBackendInterface {
177177
input_id,
178178
input_vk_tensor->constant_buffer_idx());
179179

180-
const c10::ScalarType& input_dtype =
180+
const at::native::vulkan::api::ScalarType& input_dtype =
181181
get_scalar_type(input_vk_tensor->datatype());
182182

183183
const flatbuffers_fbsource::Vector<uint32_t>* input_dims_fb =

0 commit comments

Comments
 (0)