Commit 38892ac

dbort authored and facebook-github-bot committed
Clean up non-exec_aten references to tensor types (#5254)
Summary:
Pull Request resolved: #5254

Most code should use the exec_aten:: types: only aten-mode-aware code should refer directly to e.g. torch::executor::Tensor. ArrayRef is a little different because it's also a core type, but in that case it should be called executorch::runtime::ArrayRef.

Reviewed By: mergennachin

Differential Revision: D62476845

fbshipit-source-id: 869fb15ce342873697823425271955ca7ed4c14d
1 parent bcd156b · commit 38892ac
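A minimal sketch of the convention this commit enforces, assuming only the exec_aten.h and array_ref.h headers that the diffs below already reference; the two helper functions and their names are hypothetical illustrations, not part of the change:

#include <executorch/runtime/core/exec_aten/exec_aten.h>
#include <executorch/runtime/core/array_ref.h>

// Preferred: the exec_aten:: aliases resolve to the ATen types in aten mode
// and to the portable torch::executor types otherwise, so shared code
// compiles unchanged in both modes. (Hypothetical helper for illustration.)
bool is_float_tensor(const exec_aten::Tensor& t) {
  return t.scalar_type() == exec_aten::ScalarType::Float;
}

// ArrayRef is also a core runtime type; when used as one, spell it
// executorch::runtime::ArrayRef rather than torch::executor::ArrayRef.
// (Hypothetical helper for illustration.)
int64_t numel_from_sizes(executorch::runtime::ArrayRef<int64_t> sizes) {
  int64_t n = 1;
  for (const auto s : sizes) {
    n *= s; // product of dimension sizes
  }
  return n;
}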

File tree

18 files changed: +75 −80 lines

backends/qualcomm/runtime/QnnExecuTorch.h

Lines changed: 1 addition & 1 deletion
@@ -44,7 +44,7 @@ struct CustomMemTensorInfo {
   size_t tensor_bytes;
   uint32_t* shape;
   uint32_t rank;
-  torch::executor::ScalarType dtype;
+  exec_aten::ScalarType dtype;
 };

 /// Allocate specific tensors (usually graph inputs and outputs) on shared

backends/qualcomm/runtime/SharedBuffer.cpp

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ std::size_t std::hash<CustomMemTensorInfo>::operator()(
     hash_val ^= info.shape[i];
   }
   hash_val ^= std::hash<uint32_t>()(info.rank);
-  hash_val ^= std::hash<torch::executor::ScalarType>()(info.dtype);
+  hash_val ^= std::hash<exec_aten::ScalarType>()(info.dtype);
   return hash_val;
 }

backends/vulkan/test/utils/test_utils.cpp

Lines changed: 2 additions & 2 deletions
@@ -8,7 +8,7 @@

 #include <executorch/backends/vulkan/test/utils/test_utils.h>

-#include <executorch/runtime/core/portable_type/half.h>
+#include <executorch/runtime/core/exec_aten/exec_aten.h>

 #include <executorch/backends/vulkan/runtime/graph/ops/impl/utils/TensorUtils.h>

@@ -322,7 +322,7 @@ void record_reference_matmul(
   _(uint8_t, Byte)               \
   _(int8_t, Char)                \
   _(int32_t, Int)                \
-  _(torch::executor::Half, Half) \
+  _(exec_aten::Half, Half)       \
   _(float, Float)                \
   _(int8_t, QInt8)

backends/vulkan/test/vulkan_compute_api_test.cpp

Lines changed: 4 additions & 4 deletions
@@ -11,7 +11,7 @@
 #include <utility>
 #include <vector>

-#include <executorch/runtime/core/portable_type/half.h>
+#include <executorch/runtime/core/exec_aten/exec_aten.h>

 #include <executorch/backends/vulkan/runtime/api/api.h>

@@ -485,7 +485,7 @@ TEST_F(VulkanComputeAPITest, test_buffer_float16) {
   if (!context()->adapter_ptr()->has_full_float16_buffers_support()) {
     GTEST_SKIP();
   }
-  test_storage_buffer_type<torch::executor::Half, vkapi::kHalf>(16);
+  test_storage_buffer_type<exec_aten::Half, vkapi::kHalf>(16);
 }

 TEST_F(VulkanComputeAPITest, test_buffer_int8) {

@@ -567,7 +567,7 @@ TEST_F(VulkanComputeAPITest, buffer_tensor_sanity_check) {
       run_buffer_tensor_sanity_check<float>(a);
       break;
     case vkapi::kHalf:
-      run_buffer_tensor_sanity_check<torch::executor::Half>(a);
+      run_buffer_tensor_sanity_check<exec_aten::Half>(a);
       break;
     case vkapi::kChar:
       run_buffer_tensor_sanity_check<int8_t>(a);

@@ -2395,7 +2395,7 @@ TEST(VulkanToFromGPUShaderTest, round_trip_tests) {

   for (auto& sizes : to_test) {
     RUN_TESTS(float, vkapi::kFloat)
-    RUN_TESTS(torch::executor::Half, vkapi::kHalf)
+    RUN_TESTS(exec_aten::Half, vkapi::kHalf)
   }

   for (auto& sizes : to_test_int8) {

examples/models/flamingo/cross_attention/cross_attention_mask_test.cpp

Lines changed: 3 additions & 3 deletions
@@ -11,9 +11,9 @@
 #include <gtest/gtest.h>

 using namespace ::testing;
-using torch::executor::ScalarType;
-using torch::executor::Tensor;
-using torch::executor::TensorImpl;
+using exec_aten::ScalarType;
+using exec_aten::Tensor;
+using exec_aten::TensorImpl;

 TEST(CrossAttentxnMaskTest, TestCrossAttentionMask) {
   std::vector<int> tokens = {

examples/qualcomm/oss_scripts/llama2/runner/runner.cpp

Lines changed: 1 addition & 1 deletion
@@ -135,7 +135,7 @@ int32_t Runner::logitsToToken(const exec_aten::Tensor& logits_tensor) {

 // Given an input token. Set up the inputs for the model and execute a single
 // step. Returning the logits tensor.
-Result<torch::executor::Tensor> Runner::run_model_step(
+Result<exec_aten::Tensor> Runner::run_model_step(
     int64_t input_token,
     TensorPtr& token,
     TensorPtr& start_pos,

examples/qualcomm/oss_scripts/llama2/runner/runner.h

Lines changed: 1 addition & 1 deletion
@@ -120,7 +120,7 @@ class IoMemMgr {
   std::vector<uint32_t> shape;
   uint32_t rank;
   size_t element_size;
-  torch::executor::ScalarType dtype;
+  exec_aten::ScalarType dtype;
 };

 struct IoInfo {

extension/android/jni/jni_layer_constants.h

Lines changed: 2 additions & 2 deletions
@@ -8,7 +8,7 @@

 #include <unordered_map>

-#include <executorch/runtime/core/portable_type/scalar_type.h>
+#include <executorch/runtime/core/exec_aten/exec_aten.h>

 namespace executorch::extension {

@@ -37,7 +37,7 @@ constexpr static int kTensorDTypeBits4x2 = 20;
 constexpr static int kTensorDTypeBits8 = 21;
 constexpr static int kTensorDTypeBits16 = 22;

-using torch::executor::ScalarType;
+using exec_aten::ScalarType;

 const std::unordered_map<ScalarType, int> scalar_type_to_java_dtype = {
     {ScalarType::Byte, kTensorDTypeUInt8},

extension/android/src/main/java/org/pytorch/executorch/DType.java

Lines changed: 23 additions & 23 deletions
@@ -17,51 +17,51 @@ public enum DType {
   // NOTE: "jniCode" must be kept in sync with scalar_type.h.
   // NOTE: Never serialize "jniCode", because it can change between releases.

-  /** Code for dtype torch::executor::Byte */
+  /** Code for dtype ScalarType::Byte */
   UINT8(0),
-  /** Code for dtype torch::executor::Char */
+  /** Code for dtype ScalarType::Char */
   INT8(1),
-  /** Code for dtype torch::executor::Short */
+  /** Code for dtype ScalarType::Short */
   INT16(2),
-  /** Code for dtype torch::executor::Int */
+  /** Code for dtype ScalarType::Int */
   INT32(3),
-  /** Code for dtype torch::executor::Long */
+  /** Code for dtype ScalarType::Long */
   INT64(4),
-  /** Code for dtype torch::executor::Half */
+  /** Code for dtype ScalarType::Half */
   HALF(5),
-  /** Code for dtype torch::executor::Float */
+  /** Code for dtype ScalarType::Float */
   FLOAT(6),
-  /** Code for dtype torch::executor::Double */
+  /** Code for dtype ScalarType::Double */
   DOUBLE(7),
-  /** Code for dtype torch::executor::ComplexHalf */
+  /** Code for dtype ScalarType::ComplexHalf */
   COMPLEX_HALF(8),
-  /** Code for dtype torch::executor::ComplexFloat */
+  /** Code for dtype ScalarType::ComplexFloat */
   COMPLEX_FLOAT(9),
-  /** Code for dtype torch::executor::ComplexDouble */
+  /** Code for dtype ScalarType::ComplexDouble */
   COMPLEX_DOUBLE(10),
-  /** Code for dtype torch::executor::Bool */
+  /** Code for dtype ScalarType::Bool */
   BOOL(11),
-  /** Code for dtype torch::executor::QInt8 */
+  /** Code for dtype ScalarType::QInt8 */
   QINT8(12),
-  /** Code for dtype torch::executor::QUInt8 */
+  /** Code for dtype ScalarType::QUInt8 */
   QUINT8(13),
-  /** Code for dtype torch::executor::QInt32 */
+  /** Code for dtype ScalarType::QInt32 */
   QINT32(14),
-  /** Code for dtype torch::executor::BFloat16 */
+  /** Code for dtype ScalarType::BFloat16 */
   BFLOAT16(15),
-  /** Code for dtype torch::executor::QUInt4x2 */
+  /** Code for dtype ScalarType::QUInt4x2 */
   QINT4X2(16),
-  /** Code for dtype torch::executor::QUInt2x4 */
+  /** Code for dtype ScalarType::QUInt2x4 */
   QINT2X4(17),
-  /** Code for dtype torch::executor::Bits1x8 */
+  /** Code for dtype ScalarType::Bits1x8 */
   BITS1X8(18),
-  /** Code for dtype torch::executor::Bits2x4 */
+  /** Code for dtype ScalarType::Bits2x4 */
   BITS2X4(19),
-  /** Code for dtype torch::executor::Bits4x2 */
+  /** Code for dtype ScalarType::Bits4x2 */
   BITS4X2(20),
-  /** Code for dtype torch::executor::Bits8 */
+  /** Code for dtype ScalarType::Bits8 */
   BITS8(21),
-  /** Code for dtype torch::executor::Bits16 */
+  /** Code for dtype ScalarType::Bits16 */
   BITS16(22),
   ;

kernels/optimized/blas/CPUBlas.cpp

Lines changed: 2 additions & 1 deletion
@@ -24,7 +24,8 @@ extern "C" void sgemm_(char *transa, char *transb, int *m, int *n, int *k, float
 namespace executorch {
 namespace cpublas {

-// using Half = exec_aten::Half;
+using exec_aten::BFloat16;
+using exec_aten::Half;

 #ifdef ET_BUILD_WITH_BLAS
 #ifdef ET_BUILD_FOR_APPLE

kernels/optimized/blas/CPUBlas.h

Lines changed: 10 additions & 13 deletions
@@ -17,9 +17,6 @@
 namespace executorch {
 namespace cpublas {

-using BFloat16 = torch::executor::BFloat16;
-using Half = torch::executor::Half;
-
 enum class TransposeType {
   NoTranspose,
   Transpose,

@@ -100,20 +97,20 @@ void gemm(
 void gemm(
     TransposeType transa, TransposeType transb,
     int64_t m, int64_t n, int64_t k,
-    const Half alpha,
-    const Half *a, int64_t lda,
-    const Half *b, int64_t ldb,
-    const Half beta,
-    Half *c, int64_t ldc);
+    const exec_aten::Half alpha,
+    const exec_aten::Half *a, int64_t lda,
+    const exec_aten::Half *b, int64_t ldb,
+    const exec_aten::Half beta,
+    exec_aten::Half *c, int64_t ldc);

 void gemm(
     TransposeType transa, TransposeType transb,
     int64_t m, int64_t n, int64_t k,
-    const BFloat16 alpha,
-    const BFloat16 *a, int64_t lda,
-    const BFloat16 *b, int64_t ldb,
-    const BFloat16 beta,
-    BFloat16 *c, int64_t ldc);
+    const exec_aten::BFloat16 alpha,
+    const exec_aten::BFloat16 *a, int64_t lda,
+    const exec_aten::BFloat16 *b, int64_t ldb,
+    const exec_aten::BFloat16 beta,
+    exec_aten::BFloat16 *c, int64_t ldc);
 // clang-format on

 // clang-format off

kernels/optimized/cpu/op_exp.cpp

Lines changed: 4 additions & 4 deletions
@@ -27,8 +27,8 @@ template <
     typename CTYPE_OUT,
     typename std::enable_if<
         std::is_same<CTYPE_IN, CTYPE_OUT>::value &&
-            !std::is_same<CTYPE_IN, torch::executor::Half>::value &&
-            !std::is_same<CTYPE_OUT, torch::executor::Half>::value,
+            !std::is_same<CTYPE_IN, exec_aten::Half>::value &&
+            !std::is_same<CTYPE_OUT, exec_aten::Half>::value,
         int>::type = 0>
 void exp_data(
     const CTYPE_IN* in_data,

@@ -47,8 +47,8 @@ template <
     typename CTYPE_OUT,
     typename std::enable_if<
         !std::is_same<CTYPE_IN, CTYPE_OUT>::value ||
-            std::is_same<CTYPE_IN, torch::executor::Half>::value ||
-            std::is_same<CTYPE_OUT, torch::executor::Half>::value,
+            std::is_same<CTYPE_IN, exec_aten::Half>::value ||
+            std::is_same<CTYPE_OUT, exec_aten::Half>::value,
         int>::type = 0>
 void exp_data(
     const CTYPE_IN* in_data,

kernels/portable/cpu/op_empty.cpp

Lines changed: 1 addition & 1 deletion
@@ -26,7 +26,7 @@ using exec_aten::Tensor;
 Tensor& empty_out(
     KernelRuntimeContext& context,
     IntArrayRef size,
-    torch::executor::optional<torch::executor::MemoryFormat> memory_format,
+    exec_aten::optional<exec_aten::MemoryFormat> memory_format,
     Tensor& out) {
   (void)context;

kernels/portable/cpu/test/scalar_utils_test.cpp

Lines changed: 1 addition & 1 deletion
@@ -16,7 +16,7 @@ struct promote_type_with_scalar_type_is_valid
       (std::is_same<T2, torch::executor::internal::B1>::value ||
        std::is_same<T2, torch::executor::internal::I8>::value ||
        std::is_same<T2, torch::executor::internal::F8>::value) &&
-      !std::is_same<T1, torch::executor::BFloat16>::value &&
+      !std::is_same<T1, exec_aten::BFloat16>::value &&
      !torch::executor::is_qint_type<T1>::value &&
      !torch::executor::is_bits_type<T1>::value> {};

kernels/portable/cpu/util/math_util.h

Lines changed: 4 additions & 6 deletions
@@ -96,9 +96,8 @@ INT_T max_override(INT_T a, INT_T b) {

 template <
     typename T,
-    typename std::enable_if<
-        std::is_same<T, torch::executor::Half>::value,
-        bool>::type = true>
+    typename std::enable_if<std::is_same<T, exec_aten::Half>::value, bool>::
+        type = true>
 T min_override(T a, T b) {
   const auto float_a = static_cast<float>(a);
   if (std::isnan(float_a)) {

@@ -117,9 +116,8 @@ T min_override(T a, T b) {

 template <
     typename T,
-    typename std::enable_if<
-        std::is_same<T, torch::executor::Half>::value,
-        bool>::type = true>
+    typename std::enable_if<std::is_same<T, exec_aten::Half>::value, bool>::
+        type = true>
 T max_override(T a, T b) {
   const auto float_a = static_cast<float>(a);
   if (std::isnan(float_a)) {

kernels/portable/cpu/util/test/broadcast_test.cpp

Lines changed: 6 additions & 6 deletions
@@ -20,8 +20,8 @@
 using namespace ::testing;
 using exec_aten::ScalarType;
 using exec_aten::Tensor;
-using torch::executor::ArrayRef;
-using torch::executor::testing::TensorFactory;
+using executorch::runtime::ArrayRef;
+using executorch::runtime::testing::TensorFactory;

 TEST(BroadcastUtilTest, BroadcastTensor) {
   TensorFactory<ScalarType::Int> tf;

@@ -112,17 +112,17 @@ TEST(BroadcastUtilTest, GetBroadcastTargetSize) {
   Tensor a = tf.zeros({2, 1});
   Tensor b = tf.zeros({5, 1, 2});

-  get_broadcast_target_size(
+  executorch::runtime::Error err = get_broadcast_target_size(
       a,
       b,
       expected_output_size,
      torch::executor::kTensorDimensionLimit,
      &expected_output_dim);
+  EXPECT_EQ(err, torch::executor::Error::Ok);

   EXPECT_TRUE(
-      torch::executor::ArrayRef<Tensor::SizesType>(
-          expected_output_size, expected_output_dim)
-          .equals(torch::executor::ArrayRef<Tensor::SizesType>({5, 2, 2})));
+      ArrayRef<Tensor::SizesType>(expected_output_size, expected_output_dim)
+          .equals(ArrayRef<Tensor::SizesType>({5, 2, 2})));
 }

 size_t linearize_indexes(size_t* indexes, size_t indexes_len, const Tensor& t) {

kernels/portable/test/op_mul_test.cpp

Lines changed: 6 additions & 7 deletions
@@ -36,19 +36,18 @@ class OpMulOutKernelTest : public OperatorTest {
 TEST_F(OpMulOutKernelTest, UnhandledDtypeDies) {
   // mul_out() doesn't handle QInt8.
   // TensorFactory cannot be used with ScalarType::QInt8 since
-  // torch::executor::qint8 does not have a default constructor. It must be
+  // exec_aten::qint8 does not have a default constructor. It must be
   // initialized with an explicit value. So, we need to manually create the
   // underlying data without default construction and then the tensors from that
   // data via TensorImpl.

   std::vector<SizesType> sizes = {2, 2};

-  std::vector<torch::executor::qint8> a_data{};
-  std::generate_n(std::back_inserter(a_data), 4, []() {
-    return torch::executor::qint8{0};
-  });
-  std::vector<torch::executor::qint8> b_data(a_data);
-  std::vector<torch::executor::qint8> out_data(a_data);
+  std::vector<exec_aten::qint8> a_data{};
+  std::generate_n(
+      std::back_inserter(a_data), 4, []() { return exec_aten::qint8{0}; });
+  std::vector<exec_aten::qint8> b_data(a_data);
+  std::vector<exec_aten::qint8> out_data(a_data);

   auto a_impl = torch::executor::TensorImpl(
       ScalarType::QInt8, 2, sizes.data(), a_data.data());

kernels/prim_ops/register_prim_ops.cpp

Lines changed: 3 additions & 3 deletions
@@ -12,7 +12,6 @@
 #include <executorch/runtime/kernel/kernel_includes.h>
 #include <executorch/runtime/kernel/operator_registry.h>

-using KernelArrayRef = ::torch::executor::ArrayRef<::torch::executor::Kernel>;
 using torch::executor::function::et_copy_index;

 namespace torch {

@@ -294,13 +293,14 @@ static Kernel prim_ops[] = {

 };

-static KernelArrayRef kernel_array_ref(
+executorch::runtime::Span<const executorch::runtime::Kernel> kernel_span(
     prim_ops,
     prim_ops + sizeof(prim_ops) / sizeof(Kernel));

 // Return value not used. Keep the static variable assignment to register
 // operators in static initialization time.
-static auto success_with_kernel_reg = register_kernels(kernel_array_ref);
+auto success_with_kernel_reg =
+    executorch::runtime::register_kernels(kernel_span);

 } // namespace
 } // namespace function
