Commit 7bebf8e

dbort authored and facebook-github-bot committed
Migrate backends/cadence away from deprecated namespaces (#5905)
Summary:
Pull Request resolved: #5905

Stop using the `torch::` namespace where possible. For now, ops still live under `torch::executor::`.

Reviewed By: Gasoonjia, zonglinpeng

Differential Revision: D63924099

fbshipit-source-id: e1132f889bfdeccb56e55a5bad6be937cce366e3
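The mechanical pattern across all fifteen files is the same: swap the deprecated spellings for the canonical namespaces without moving the ops themselves. A condensed before/after of the using-declarations touched in this diff (`exec_aten::` and `torch::executor::util::` are the deprecated aliases):

// Before (deprecated):
using namespace torch::executor;
using Tensor = exec_aten::Tensor;
using ScalarType = exec_aten::ScalarType;
using torch::executor::Error;
using torch::executor::Result;

// After (canonical):
using executorch::aten::Tensor;
using executorch::aten::ScalarType;
using executorch::runtime::Error;
using executorch::runtime::Result;
using executorch::extension::BufferDataLoader; // was torch::executor::util::BufferDataLoader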
1 parent 50f70f3 · commit 7bebf8e

15 files changed: +47 −48 lines

backends/cadence/executor_runner.cpp

Lines changed: 14 additions & 15 deletions
@@ -37,7 +37,6 @@
 
 static uint8_t method_allocator_pool[18 * 1024U]; // 18 KB
 
-using namespace torch::executor;
 #include <xtensa/config/core.h>
 
 #define APP_MU MUB
@@ -48,8 +47,8 @@ using namespace torch::executor;
 /* How many message is used to test message sending */
 #define MSG_LENGTH 32U
 
-using torch::executor::Error;
-using torch::executor::Result;
+using executorch::runtime::Error;
+using executorch::runtime::Result;
 
 void LED_INIT();
 void LED_TOGGLE();
@@ -106,13 +105,13 @@ int main(int argc, char** argv) {
   BOARD_InitDebugConsole();
   ET_LOG(Info, "Booted up in DSP.");
 
-  torch::executor::runtime_init();
+  executorch::runtime::runtime_init();
 
   auto loader =
-      torch::executor::util::BufferDataLoader(model_pte, sizeof(model_pte));
+      executorch::extension::BufferDataLoader(model_pte, sizeof(model_pte));
 
-  Result<torch::executor::Program> program =
-      torch::executor::Program::load(&loader);
+  Result<executorch::runtime::Program> program =
+      executorch::runtime::Program::load(&loader);
   if (!program.ok()) {
     ET_LOG(
         Error,
@@ -132,7 +131,7 @@ int main(int argc, char** argv) {
   }
   ET_LOG(Info, "ET: Running method %s", method_name);
 
-  Result<torch::executor::MethodMeta> method_meta =
+  Result<executorch::runtime::MethodMeta> method_meta =
       program->method_meta(method_name);
   if (!method_meta.ok()) {
     ET_LOG(
@@ -142,12 +141,12 @@ int main(int argc, char** argv) {
         (unsigned int)method_meta.error());
   }
 
-  torch::executor::MemoryAllocator method_allocator{
-      torch::executor::MemoryAllocator(
+  executorch::runtime::MemoryAllocator method_allocator{
+      executorch::runtime::MemoryAllocator(
           sizeof(method_allocator_pool), method_allocator_pool)};
 
   std::vector<std::unique_ptr<uint8_t[]>> planned_buffers; // Owns the memory
-  std::vector<torch::executor::Span<uint8_t>>
+  std::vector<executorch::runtime::Span<uint8_t>>
       planned_spans; // Passed to the allocator
   size_t num_memory_planned_buffers = method_meta->num_memory_planned_buffers();
@@ -161,13 +160,13 @@ int main(int argc, char** argv) {
     planned_spans.push_back({planned_buffers.back().get(), buffer_size});
   }
 
-  torch::executor::HierarchicalAllocator planned_memory(
+  executorch::runtime::HierarchicalAllocator planned_memory(
      {planned_spans.data(), planned_spans.size()});
 
-  torch::executor::MemoryManager memory_manager(
+  executorch::runtime::MemoryManager memory_manager(
      &method_allocator, &planned_memory);
 
-  Result<torch::executor::Method> method =
+  Result<executorch::runtime::Method> method =
      program->load_method(method_name, &memory_manager);
   if (!method.ok()) {
     ET_LOG(
@@ -178,7 +177,7 @@ int main(int argc, char** argv) {
  }
 
  ET_LOG(Info, "Method loaded.");
-  torch::executor::util::prepare_input_tensors(*method);
+  executorch::extension::prepare_input_tensors(*method);
  ET_LOG(Info, "Starting the model execution...");
 
  Error status = method->execute();
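For orientation, the end-to-end flow this file migrates condenses into a short sketch under the new namespaces. This is a minimal illustration, not the full runner: board bring-up, logging, and the memory-planned buffer setup are elided, and the header paths are assumed to be the standard ExecuTorch ones.

#include <cstddef>
#include <cstdint>

#include <executorch/extension/data_loader/buffer_data_loader.h>
#include <executorch/runtime/core/memory_allocator.h>
#include <executorch/runtime/executor/memory_manager.h>
#include <executorch/runtime/executor/program.h>
#include <executorch/runtime/platform/runtime.h>

using executorch::extension::BufferDataLoader;
using executorch::runtime::Error;
using executorch::runtime::MemoryAllocator;
using executorch::runtime::MemoryManager;
using executorch::runtime::Method;
using executorch::runtime::Program;
using executorch::runtime::Result;

static uint8_t method_allocator_pool[18 * 1024U]; // 18 KB, as in the runner

// Load a .pte buffer and run one method; returns the first error hit.
Error load_and_run(const void* pte, size_t pte_len, const char* method_name) {
  executorch::runtime::runtime_init();

  BufferDataLoader loader(pte, pte_len);
  Result<Program> program = Program::load(&loader);
  if (!program.ok()) {
    return program.error();
  }

  MemoryAllocator method_allocator(
      sizeof(method_allocator_pool), method_allocator_pool);
  // NOTE: a real runner also builds a HierarchicalAllocator over the
  // memory-planned buffers (see the diff above); with the nullptr default
  // here, methods that need planned memory will fail to load.
  MemoryManager memory_manager(&method_allocator);

  Result<Method> method = program->load_method(method_name, &memory_manager);
  if (!method.ok()) {
    return method.error();
  }
  return method->execute();
}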

backends/cadence/hifi/operators/dequantize_per_tensor.cpp

Lines changed: 2 additions & 2 deletions
@@ -14,9 +14,9 @@ namespace impl {
 namespace HiFi {
 namespace native {
 
-using Tensor = exec_aten::Tensor;
+using executorch::aten::ScalarType;
+using executorch::aten::Tensor;
 using executorch::runtime::KernelRuntimeContext;
-using ScalarType = exec_aten::ScalarType;
 
 void dequantize_per_tensor_out(
     KernelRuntimeContext& context,

backends/cadence/hifi/operators/quantize_per_tensor.cpp

Lines changed: 2 additions & 2 deletions
@@ -14,9 +14,9 @@ namespace impl {
 namespace HiFi {
 namespace native {
 
-using Tensor = exec_aten::Tensor;
+using executorch::aten::ScalarType;
+using executorch::aten::Tensor;
 using executorch::runtime::KernelRuntimeContext;
-using ScalarType = exec_aten::ScalarType;
 
 // Quantize the input tensor (PT2 version). Note that quant_<min,max> are not
 // used in any computation.

backends/cadence/hifi/operators/quantized_layer_norm.cpp

Lines changed: 4 additions & 4 deletions
@@ -12,7 +12,7 @@
 #include <cmath>
 #include <tuple>
 
-using Tensor = exec_aten::Tensor;
+using executorch::aten::Tensor;
 using executorch::runtime::KernelRuntimeContext;
 
 namespace impl {
@@ -119,14 +119,14 @@ void quantized_layer_norm_out(
     const Tensor& input,
     const Tensor& in_scale,
     const Tensor& in_zero_point,
-    const exec_aten::IntArrayRef normalized_shape,
+    const executorch::aten::IntArrayRef normalized_shape,
     const Tensor& weight,
     const Tensor& bias,
     double eps,
     double output_scale,
     int64_t output_zero_point,
     Tensor& out) {
-  if (input.scalar_type() == exec_aten::ScalarType::Byte) {
+  if (input.scalar_type() == executorch::aten::ScalarType::Byte) {
     quantized_layer_norm_<uint8_t>(
         input,
         in_scale,
@@ -137,7 +137,7 @@ void quantized_layer_norm_out(
         output_scale,
         output_zero_point,
         out);
-  } else if (input.scalar_type() == exec_aten::ScalarType::Char) {
+  } else if (input.scalar_type() == executorch::aten::ScalarType::Char) {
     quantized_layer_norm_<int8_t>(
         input,
         in_scale,
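The Byte/Char branches this hunk rewrites follow a standard kernel-dispatch pattern: compare the tensor's runtime ScalarType and instantiate the templated kernel at the matching C++ type. A generic, self-contained sketch of that pattern, where kernel_impl is a hypothetical stand-in for quantized_layer_norm_<T>:

#include <cstdint>

#include <executorch/runtime/core/exec_aten/exec_aten.h>

using executorch::aten::ScalarType;
using executorch::aten::Tensor;

// Hypothetical typed kernel; the real one walks input.const_data_ptr<T>().
template <typename T>
void kernel_impl(const Tensor& input, Tensor& out) {}

void kernel_out(const Tensor& input, Tensor& out) {
  if (input.scalar_type() == ScalarType::Byte) { // Byte maps to uint8_t
    kernel_impl<uint8_t>(input, out);
  } else if (input.scalar_type() == ScalarType::Char) { // Char maps to int8_t
    kernel_impl<int8_t>(input, out);
  }
}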

backends/cadence/hifi/operators/quantized_linear_out.cpp

Lines changed: 2 additions & 2 deletions
@@ -15,7 +15,7 @@ namespace impl {
 namespace HiFi {
 namespace native {
 
-using Tensor = exec_aten::Tensor;
+using executorch::aten::Tensor;
 using executorch::runtime::KernelRuntimeContext;
 
 void quantized_linear_out(
@@ -28,7 +28,7 @@ void quantized_linear_out(
     const Tensor& out_multiplier,
     const Tensor& out_shift,
     int64_t out_zero_point,
-    const exec_aten::optional<Tensor>& offset,
+    const executorch::aten::optional<Tensor>& offset,
     Tensor& out) {
   // input comes in shape [leading_dims, in_dim]
   // weight comes in shape [out_dim, in_dim]

backends/cadence/reference/operators/dequantize_per_tensor.cpp

Lines changed: 2 additions & 2 deletions
@@ -13,9 +13,9 @@ namespace impl {
 namespace reference {
 namespace native {
 
-using Tensor = exec_aten::Tensor;
+using executorch::aten::ScalarType;
+using executorch::aten::Tensor;
 using executorch::runtime::KernelRuntimeContext;
-using ScalarType = exec_aten::ScalarType;
 
 void dequantize_per_tensor_out(
     KernelRuntimeContext& context,

backends/cadence/reference/operators/op_embedding.cpp

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@ namespace torch {
 namespace executor {
 namespace native {
 
-using Tensor = exec_aten::Tensor;
+using executorch::aten::Tensor;
 using executorch::runtime::KernelRuntimeContext;
 
 void embedding_out(

backends/cadence/reference/operators/op_full.cpp

Lines changed: 2 additions & 2 deletions
@@ -13,8 +13,8 @@ namespace torch {
 namespace executor {
 namespace native {
 
-using Tensor = exec_aten::Tensor;
-using ScalarType = exec_aten::ScalarType;
+using executorch::aten::ScalarType;
+using executorch::aten::Tensor;
 
 Tensor& full_out(
     KernelRuntimeContext& ctx,

backends/cadence/reference/operators/op_view_copy.cpp

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@ namespace torch {
 namespace executor {
 namespace native {
 
-using Tensor = exec_aten::Tensor;
+using executorch::aten::Tensor;
 using executorch::runtime::KernelRuntimeContext;
 
 Tensor& view_copy_out(

backends/cadence/reference/operators/quantize_per_tensor.cpp

Lines changed: 2 additions & 2 deletions
@@ -13,9 +13,9 @@ namespace impl {
 namespace reference {
 namespace native {
 
-using Tensor = exec_aten::Tensor;
+using executorch::aten::ScalarType;
+using executorch::aten::Tensor;
 using executorch::runtime::KernelRuntimeContext;
-using ScalarType = exec_aten::ScalarType;
 
 // Quantize the input tensor (PT2 version). Note that quant_<min,max> are not
 // used in any computation.

backends/cadence/reference/operators/quantized_conv_out.cpp

Lines changed: 4 additions & 4 deletions
@@ -14,7 +14,7 @@ namespace impl {
 namespace reference {
 namespace native {
 
-using Tensor = exec_aten::Tensor;
+using executorch::aten::Tensor;
 using executorch::runtime::KernelRuntimeContext;
 
 // This implements a generic 2d conv kernel that operates on raw pointers.
@@ -158,9 +158,9 @@ void quantized_conv_out(
     const Tensor& input,
     const Tensor& weight,
     const Tensor& bias,
-    exec_aten::IntArrayRef stride,
-    exec_aten::IntArrayRef padding,
-    exec_aten::IntArrayRef dilation,
+    executorch::aten::IntArrayRef stride,
+    executorch::aten::IntArrayRef padding,
+    executorch::aten::IntArrayRef dilation,
     int64_t groups,
     int64_t in_zero_point,
     const Tensor& weight_zero_point,

backends/cadence/reference/operators/quantized_layer_norm.cpp

Lines changed: 3 additions & 3 deletions
@@ -115,14 +115,14 @@ void quantized_layer_norm_out(
     const Tensor& input,
     const Tensor& in_scale,
     const Tensor& in_zero_point,
-    const exec_aten::IntArrayRef normalized_shape,
+    const executorch::aten::IntArrayRef normalized_shape,
     const Tensor& weight,
     const Tensor& bias,
     double eps,
     double output_scale,
     int64_t output_zero_point,
     Tensor& out) {
-  if (input.scalar_type() == exec_aten::ScalarType::Byte) {
+  if (input.scalar_type() == executorch::aten::ScalarType::Byte) {
     quantized_layer_norm_<uint8_t>(
         input,
         in_scale,
@@ -133,7 +133,7 @@ void quantized_layer_norm_out(
         output_scale,
         output_zero_point,
         out);
-  } else if (input.scalar_type() == exec_aten::ScalarType::Char) {
+  } else if (input.scalar_type() == executorch::aten::ScalarType::Char) {
     quantized_layer_norm_<int8_t>(
         input,
         in_scale,

backends/cadence/reference/operators/quantized_linear_out.cpp

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ void quantized_linear_out(
     const Tensor& out_multiplier,
     const Tensor& out_shift,
     int64_t out_zero_point,
-    const exec_aten::optional<Tensor>& offset,
+    const executorch::aten::optional<Tensor>& offset,
     Tensor& out) {
   // Assuming uint8_t for now, but needs to be updated for other quantization
   // types

backends/cadence/reference/operators/quantized_matmul_out.cpp

Lines changed: 4 additions & 4 deletions
@@ -60,7 +60,7 @@ void inline _typed_quantized_matmul(
     int64_t X_zero_point,
     const Tensor& Y,
     int64_t Y_zero_point,
-    const exec_aten::optional<Tensor>& bias,
+    const executorch::aten::optional<Tensor>& bias,
     int64_t out_multiplier,
     int64_t out_shift,
     int64_t out_zero_point,
@@ -114,13 +114,13 @@ void quantized_matmul_out(
     int64_t X_zero_point,
     const Tensor& Y,
     int64_t Y_zero_point,
-    const exec_aten::optional<Tensor>& bias,
+    const executorch::aten::optional<Tensor>& bias,
     int64_t out_multiplier,
     int64_t out_shift,
     int64_t out_zero_point,
     bool transposed,
     Tensor& out) {
-  if (out.scalar_type() == exec_aten::ScalarType::Byte) {
+  if (out.scalar_type() == executorch::aten::ScalarType::Byte) {
     _typed_quantized_matmul<uint8_t>(
         X,
         X_zero_point,
@@ -132,7 +132,7 @@ void quantized_matmul_out(
         out_zero_point,
         transposed,
         out);
-  } else if (out.scalar_type() == exec_aten::ScalarType::Char) {
+  } else if (out.scalar_type() == executorch::aten::ScalarType::Char) {
     _typed_quantized_matmul<int8_t>(
         X,
         X_zero_point,

backends/cadence/reference/operators/quantized_relu_out.cpp

Lines changed: 3 additions & 3 deletions
@@ -13,7 +13,7 @@ namespace impl {
 namespace reference {
 namespace native {
 
-using Tensor = exec_aten::Tensor;
+using executorch::aten::Tensor;
 using executorch::runtime::KernelRuntimeContext;
 
 template <typename T>
@@ -51,15 +51,15 @@ void quantized_relu_out(
     const Tensor& out_multiplier,
     const Tensor& out_shift,
     Tensor& output) {
-  if (input.scalar_type() == exec_aten::ScalarType::Byte) {
+  if (input.scalar_type() == executorch::aten::ScalarType::Byte) {
     quantized_relu_<uint8_t>(
         input,
         in_zero_point,
         out_zero_point,
         out_multiplier,
         out_shift,
         output);
-  } else if (input.scalar_type() == exec_aten::ScalarType::Char) {
+  } else if (input.scalar_type() == executorch::aten::ScalarType::Char) {
     quantized_relu_<int8_t>(
         input,
         in_zero_point,
