Skip to content

Commit e45ab89

Browse files
JacobSzwejbkafacebook-github-bot
authored andcommitted
MethodMeta
Summary: There is a growing class of usecases that want to be able to inspect meta data about methods without paying the full init cost. These classes provide a safe and cheap way to view this information Differential Revision: D48039273 fbshipit-source-id: 855c59b28f508edd0b45c1e26fba905809b8ef04
1 parent 0fa2da2 commit e45ab89

File tree

7 files changed

+523
-0
lines changed

7 files changed

+523
-0
lines changed

runtime/executor/method_meta.cpp

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/runtime/core/error.h>
10+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
11+
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
12+
#include <executorch/runtime/core/result.h>
13+
#include <executorch/runtime/core/span.h>
14+
#include <executorch/runtime/core/tag.h>
15+
#include <executorch/runtime/executor/method_meta.h>
16+
#include <executorch/schema/program_generated.h>
17+
18+
namespace torch {
19+
namespace executor {
20+
21+
TensorInfo::TensorInfo(
22+
const Span<const int32_t>& sizes,
23+
const Span<const uint8_t>& dim_order,
24+
exec_aten::ScalarType scalar_type) noexcept
25+
: sizes_(sizes), dim_order_(dim_order), scalar_type_(scalar_type) {}
26+
27+
const Span<const int32_t>& TensorInfo::sizes() const noexcept {
28+
return sizes_;
29+
}
30+
31+
const Span<const uint8_t>& TensorInfo::dim_order() const noexcept {
32+
return dim_order_;
33+
}
34+
35+
exec_aten::ScalarType TensorInfo::scalar_type() const noexcept {
36+
return scalar_type_;
37+
}
38+
39+
size_t TensorInfo::nbytes() const noexcept {
40+
ssize_t n = 1;
41+
for (ssize_t i = 0; i < sizes_.size(); i++) {
42+
n *= sizes_[i];
43+
}
44+
return n * sizeof_scalar_type(scalar_type_);
45+
}
46+
47+
MethodMeta::MethodMeta(
48+
const executorch_flatbuffer::ExecutionPlan* s_plan) noexcept
49+
: s_plan_(s_plan) {}
50+
51+
const char* MethodMeta::name() const noexcept {
52+
return s_plan_->name()->c_str();
53+
}
54+
55+
namespace {
56+
Result<Tag> get_tag(
57+
flatbuffers::Vector<flatbuffers::Offset<executorch_flatbuffer::EValue>>::
58+
return_type serialization_value,
59+
size_t index) {
60+
switch (serialization_value->val_type()) {
61+
case executorch_flatbuffer::KernelTypes::Int: {
62+
return Tag::Int;
63+
} break;
64+
case executorch_flatbuffer::KernelTypes::Double: {
65+
return Tag::Double;
66+
} break;
67+
case executorch_flatbuffer::KernelTypes::Bool: {
68+
return Tag::Bool;
69+
} break;
70+
case executorch_flatbuffer::KernelTypes::String: {
71+
return Tag::String;
72+
} break;
73+
case executorch_flatbuffer::KernelTypes::Tensor: {
74+
return Tag::Tensor;
75+
} break;
76+
default:
77+
ET_CHECK_OR_RETURN_ERROR(
78+
false,
79+
Internal,
80+
"Invalid tag: %zu input: %zu",
81+
(size_t)serialization_value->val_type(),
82+
index);
83+
}
84+
}
85+
} // namespace
86+
87+
size_t MethodMeta::num_inputs() const noexcept {
88+
return s_plan_->inputs()->size();
89+
}
90+
91+
Result<Tag> MethodMeta::input_tag(size_t index) const noexcept {
92+
auto num_inputs = this->num_inputs();
93+
ET_CHECK_OR_RETURN_ERROR(
94+
index >= 0 && index < num_inputs,
95+
InvalidArgument,
96+
"index %zu out of range. num_inputs: %zu",
97+
index,
98+
num_inputs);
99+
auto input_index = s_plan_->inputs()->Get(index);
100+
auto serialization_value = s_plan_->values()->Get(input_index);
101+
return get_tag(serialization_value, index);
102+
}
103+
104+
Result<TensorInfo> MethodMeta::input_tensor_meta(size_t index) const noexcept {
105+
auto tag = this->input_tag(index);
106+
if (!tag.ok()) {
107+
return tag.error();
108+
}
109+
ET_CHECK_OR_RETURN_ERROR(
110+
tag.get() == Tag::Tensor,
111+
InvalidArgument,
112+
"Tag: %zu input: %zu is not Tensor",
113+
(size_t)tag.get(),
114+
index);
115+
auto input_index = s_plan_->inputs()->Get(index);
116+
auto tensor_value = s_plan_->values()->Get(input_index)->val_as_Tensor();
117+
return TensorInfo(
118+
Span<const int32_t>(
119+
tensor_value->sizes()->data(), tensor_value->sizes()->size()),
120+
Span<const uint8_t>(
121+
tensor_value->dim_order()->data(), tensor_value->dim_order()->size()),
122+
static_cast<exec_aten::ScalarType>(tensor_value->scalar_type()));
123+
}
124+
125+
size_t MethodMeta::num_outputs() const noexcept {
126+
return s_plan_->outputs()->size();
127+
}
128+
129+
Result<Tag> MethodMeta::output_tag(size_t index) const noexcept {
130+
auto num_outputs = this->num_outputs();
131+
ET_CHECK_OR_RETURN_ERROR(
132+
index >= 0 && index < num_outputs,
133+
InvalidArgument,
134+
"index %zu out of range. num_outputs: %zu",
135+
index,
136+
num_outputs);
137+
auto input_index = s_plan_->outputs()->Get(index);
138+
auto serialization_value = s_plan_->values()->Get(input_index);
139+
return get_tag(serialization_value, index);
140+
}
141+
142+
Result<TensorInfo> MethodMeta::output_tensor_meta(size_t index) const noexcept {
143+
auto tag = this->output_tag(index);
144+
if (!tag.ok()) {
145+
return tag.error();
146+
}
147+
ET_CHECK_OR_RETURN_ERROR(
148+
tag.get() == Tag::Tensor,
149+
InvalidArgument,
150+
"Tag: %zu output: %zu is not Tensor",
151+
(size_t)tag.get(),
152+
index);
153+
auto input_index = s_plan_->outputs()->Get(index);
154+
auto tensor_value = s_plan_->values()->Get(input_index)->val_as_Tensor();
155+
return TensorInfo(
156+
Span<const int32_t>(
157+
tensor_value->sizes()->data(), tensor_value->sizes()->size()),
158+
Span<const uint8_t>(
159+
tensor_value->dim_order()->data(), tensor_value->dim_order()->size()),
160+
static_cast<exec_aten::ScalarType>(tensor_value->scalar_type()));
161+
}
162+
163+
size_t MethodMeta::num_non_const_buffers() const noexcept {
164+
return s_plan_->non_const_buffer_sizes()->size();
165+
}
166+
167+
Result<int64_t> MethodMeta::non_const_buffer_size(size_t index) const noexcept {
168+
auto num_buffers = this->num_non_const_buffers();
169+
ET_CHECK_OR_RETURN_ERROR(
170+
index >= 0 && index < num_buffers,
171+
InvalidArgument,
172+
"index %zu out of range. num_buffers: %zu",
173+
index,
174+
num_buffers);
175+
return s_plan_->non_const_buffer_sizes()->Get(index);
176+
}
177+
178+
} // namespace executor
179+
} // namespace torch

runtime/executor/method_meta.h

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
12+
#include <executorch/runtime/core/result.h>
13+
#include <executorch/runtime/core/span.h>
14+
#include <executorch/runtime/core/tag.h>
15+
16+
// Forward declare flatbuffer types. This is a public header and must not
17+
// include the generated flatbuffer header.
18+
namespace executorch_flatbuffer {
19+
struct ExecutionPlan;
20+
} // namespace executorch_flatbuffer
21+
22+
namespace torch {
23+
namespace executor {
24+
25+
/// Metadata about a specific Tensor input/output of an Executorch Program
26+
/// The program used to create the MethodMeta object that created this
27+
/// TensorInfo must outlive this TensorInfo.
28+
class TensorInfo final {
29+
public:
30+
TensorInfo() noexcept = delete;
31+
TensorInfo(const TensorInfo&) noexcept = default;
32+
TensorInfo(TensorInfo&&) noexcept = default;
33+
TensorInfo& operator=(const TensorInfo&) noexcept = default;
34+
TensorInfo& operator=(TensorInfo&& other) noexcept = default;
35+
~TensorInfo() noexcept = default;
36+
37+
/**
38+
* Get the sizes of the input/output.
39+
*
40+
* @returns The sizes of the tensor.
41+
*/
42+
const Span<const int32_t>& sizes() const noexcept;
43+
44+
/**
45+
* Get the dim order of the input/output.
46+
*
47+
* @returns The dim order of the tensor.
48+
*/
49+
const Span<const uint8_t>& dim_order() const noexcept;
50+
51+
/**
52+
* Get the scalar type of the input/output.
53+
*
54+
* @returns The scalar type of the tensor.
55+
*/
56+
exec_aten::ScalarType scalar_type() const noexcept;
57+
58+
/**
59+
* The size of the input/output in bytes.
60+
*
61+
* @returns The size of the tensor in bytes.
62+
*/
63+
size_t nbytes() const noexcept;
64+
65+
private:
66+
// Let MethodMeta create IOMeta.
67+
friend class MethodMeta;
68+
69+
TensorInfo(
70+
const Span<const int32_t>& sizes,
71+
const Span<const uint8_t>& dim_order,
72+
exec_aten::ScalarType scalar_type) noexcept;
73+
74+
/// Sizes of the tensor.
75+
Span<const int32_t> sizes_;
76+
77+
/// Dim order of the tensor.
78+
Span<const uint8_t> dim_order_;
79+
80+
/// Scalar type of the tensor.
81+
exec_aten::ScalarType scalar_type_;
82+
};
83+
84+
/**
85+
* Manages metadata about a method in an Executorch program.
86+
87+
* The program used to create a MethodMeta object must outlive the MethodMeta.
88+
* Separate from Method so that this information can be accessed without paying
89+
* the initialization cost of loading the full Method.
90+
*/
91+
class MethodMeta final {
92+
public:
93+
MethodMeta() noexcept = delete;
94+
MethodMeta(const MethodMeta&) noexcept = default;
95+
MethodMeta(MethodMeta&&) noexcept = default;
96+
MethodMeta& operator=(const MethodMeta&) noexcept = default;
97+
MethodMeta& operator=(MethodMeta&& other) noexcept = default;
98+
~MethodMeta() noexcept = default;
99+
100+
/**
101+
* Get the name of this method.
102+
*
103+
* @returns The method name.
104+
*/
105+
const char* name() const noexcept;
106+
107+
/**
108+
* Get the number of inputs to this method.
109+
*
110+
* @returns The number of inputs.
111+
*/
112+
size_t num_inputs() const noexcept;
113+
114+
/**
115+
* Get the tag of the specified input.
116+
*
117+
* @param[in] index The index of the input to look up.
118+
* @returns The tag of input, can only be [Tensor, Int, Bool, Double, String].
119+
*/
120+
Result<Tag> input_tag(size_t index) const noexcept;
121+
122+
/**
123+
* Get metadata about the specified input.
124+
*
125+
* @param[in] index The index of the input to look up.
126+
* @returns The metadata on success, or an error on failure. Only valid for
127+
* tag::Tensor
128+
*/
129+
Result<TensorInfo> input_tensor_meta(size_t index) const noexcept;
130+
131+
/**
132+
* Get the number of outputs to this method.
133+
*
134+
* @returns The number of outputs.
135+
*/
136+
size_t num_outputs() const noexcept;
137+
138+
/**
139+
* Get the tag of the specified output.
140+
*
141+
* @param[in] index The index of the output to look up.
142+
* @returns The tag of output, can only be [Tensor, Int, Bool, Double,
143+
* String].
144+
*/
145+
Result<Tag> output_tag(size_t index) const noexcept;
146+
147+
/**
148+
* Get metadata about the specified output.
149+
*
150+
* @param[in] index The index of the output to look up.
151+
* @returns The metadata on success, or an error on failure. Only valid for
152+
* tag::Tensor
153+
*/
154+
Result<TensorInfo> output_tensor_meta(size_t index) const noexcept;
155+
156+
/**
157+
* Get the number of non-constant buffers this method requires.
158+
*
159+
* @returns The number of non-constant buffers.
160+
*/
161+
size_t num_non_const_buffers() const noexcept;
162+
163+
/**
164+
* Get the size of the specified non-constant buffer.
165+
*
166+
* @param[in] index The index of the buffer to look up.
167+
* @returns The size in bytes on success, or an error on failure.
168+
*/
169+
Result<int64_t> non_const_buffer_size(size_t index) const noexcept;
170+
171+
private:
172+
// Let Program create MethodMeta.
173+
friend class Program;
174+
175+
explicit MethodMeta(
176+
const executorch_flatbuffer::ExecutionPlan* s_plan) noexcept;
177+
178+
/// Source of truth for method information
179+
const executorch_flatbuffer::ExecutionPlan* s_plan_;
180+
};
181+
182+
} // namespace executor
183+
} // namespace torch

runtime/executor/program.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,18 @@ Result<Method> Program::load_method(
206206
return Error::InvalidArgument;
207207
}
208208

209+
Result<MethodMeta> Program::method_meta(const char* method_name) const {
210+
EXECUTORCH_SCOPE_PROF("Program::method_meta");
211+
auto execution_plans = internal_program_->execution_plan();
212+
for (size_t i = 0; i < execution_plans->size(); i++) {
213+
auto serialization_plan = execution_plans->GetMutableObject(i);
214+
if (std::strcmp(serialization_plan->name()->c_str(), method_name) == 0) {
215+
return MethodMeta(serialization_plan);
216+
}
217+
}
218+
return Error::InvalidArgument;
219+
}
220+
209221
const void* Program::get_constant_buffer_data(size_t buffer_idx) const {
210222
ET_CHECK(is_valid());
211223
auto internal_program =

0 commit comments

Comments
 (0)