Skip to content

Commit 62a13c1

Browse files
SS-JIAfacebook-github-bot
authored and committed
Handle scalar tensor and mutable buffer inputs in Vulkan delegate runtime (#5930)
Summary: Pull Request resolved: #5930 ## Context * Handle scalar tensor inputs by adding them to the graph as symbolic ints * Add support for symint inputs in the Vulkan delegate * Add type checking for Vulkan delegate inputs and outputs This is needed for Transformer models, which receive an `input_pos` integer scalar tensor as an input. `input_pos` is used in KV cache updates and determines the sizes of the cache slices. Additionally, mutable buffer inputs/outputs, which appear as `TensorRef` to the Vulkan graph, are handled as well by ignoring them when copying outputs. More details in the comments. ### Why are scalar tensors added as symint? Adding scalar tensors as symints makes more sense than adding them as real tensors, since symints are commonly used to inform tensor shapes. Adding scalar tensors as symints allows them to be easily accessible by the CPU at graph encoding and resizing time, as well as easily accessible by the GPU within compute shaders. ghstack-source-id: 246752221 Reviewed By: jorgep31415 Differential Revision: D63979312 fbshipit-source-id: ce76993d65c9b5af8de98e4f131c5a6f475900ab
1 parent a4b88a3 commit 62a13c1

File tree

2 files changed

+85
-19
lines changed

2 files changed

+85
-19
lines changed

backends/vulkan/runtime/VulkanBackend.cpp

Lines changed: 77 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -312,10 +312,16 @@ class GraphBuilder {
312312
add_value_to_graph(fb_id, value);
313313
}
314314

315-
// Parse the inputs
315+
// Parse the inputs, which will be tensors most of the time but can also be
316+
// symints and tensorrefs (which will be the case if the original graph had
317+
// mutable buffers).
316318
for (const uint32_t fb_id : *flatbuffer_->input_ids()) {
317319
const ValueRef ref = get_fb_id_valueref(fb_id);
318-
compute_graph_->set_input_tensor(ref);
320+
if (compute_graph_->val_is_tensor(ref)) {
321+
compute_graph_->set_input_tensor(ref);
322+
} else {
323+
compute_graph_->set_val_as_input(ref);
324+
}
319325
}
320326

321327
// Parse the operators
@@ -354,10 +360,15 @@ class GraphBuilder {
354360
}
355361
}
356362

357-
// Parse the outputs
363+
// Parse the outputs, which will be mostly tensors. For some reason,
364+
// mutable buffers are shown to be returned in the fx.Graph but do not get
365+
// returned by the delegate; this may be an implementation detail of how the
366+
// executorch emitter handles mutable buffers.
358367
for (const uint32_t fb_id : *flatbuffer_->output_ids()) {
359368
const ValueRef ref = get_fb_id_valueref(fb_id);
360-
compute_graph_->set_output_tensor(ref);
369+
if (compute_graph_->val_is_tensor(ref)) {
370+
compute_graph_->set_output_tensor(ref);
371+
}
361372
}
362373
}
363374
};
@@ -401,6 +412,26 @@ bool maybe_resize_input(
401412
return should_resize;
402413
}
403414

415+
/*
 * Synchronizes the symint at `ref` with the value currently held by the
 * scalar tensor `scalar_tensor_src`.
 *
 * Reads the symint's current value via ComputeGraph::read_symint and compares
 * it against the tensor's (single) element; if they differ, the symint is
 * updated with ComputeGraph::set_symint.
 *
 * @param graph              Graph owning the symint value.
 * @param ref                ValueRef of the symint to (possibly) update.
 * @param scalar_tensor_src  Scalar tensor supplied by the caller as the
 *                           delegate input. Only Int and Long dtypes are
 *                           supported.
 * @return true if the symint value was changed, false otherwise (including
 *         when the tensor's dtype is unsupported).
 */
bool maybe_update_scalar_tensor(
    ComputeGraph* graph,
    const ValueRef ref,
    executorch::aten::Tensor& scalar_tensor_src) {
  const int32_t cur_val = graph->read_symint(ref);

  const exec_aten::ScalarType dtype = scalar_tensor_src.scalar_type();
  int32_t scalar_tensor_val = 0;
  if (dtype == exec_aten::ScalarType::Int) {
    scalar_tensor_val = *scalar_tensor_src.const_data_ptr<int32_t>();
  } else if (dtype == exec_aten::ScalarType::Long) {
    // Narrowing conversion; assumes the value fits in int32_t since symints
    // are stored as 32-bit integers.
    scalar_tensor_val =
        static_cast<int32_t>(*scalar_tensor_src.const_data_ptr<int64_t>());
  } else {
    // Previously, an unsupported dtype silently fell through with
    // scalar_tensor_val == 0, which could clobber the symint with 0. Leave
    // the symint untouched instead and report no update.
    return false;
  }

  if (scalar_tensor_val == cur_val) {
    return false;
  }
  graph->set_symint(ref, scalar_tensor_val);
  return true;
}
434+
404435
void maybe_resize_output(
405436
ComputeGraph* graph,
406437
const size_t output_i,
@@ -487,7 +518,8 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
487518

488519
Error err = compileModel(processed->data(), compute_graph);
489520

490-
// This backend does not need its processed data after compiling the model.
521+
// This backend does not need its processed data after compiling the
522+
// model.
491523
processed->Free();
492524

493525
if (err != Error::Ok) {
@@ -508,13 +540,31 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
508540
const size_t num_inputs = compute_graph->inputs().size();
509541
bool should_propagate_resize = false;
510542
for (size_t i = 0; i < num_inputs; i++) {
511-
bool was_resized =
512-
maybe_resize_input(compute_graph, i, args[i]->toTensor());
513-
should_propagate_resize = should_propagate_resize || was_resized;
514-
compute_graph->copy_into_staging(
515-
compute_graph->inputs()[i].staging,
516-
args[i]->toTensor().const_data_ptr(),
517-
args[i]->toTensor().numel());
543+
const ValueRef iref = compute_graph->inputs()[i].value;
544+
if (compute_graph->val_is_tensor(iref)) {
545+
VK_CHECK_COND(args[i]->isTensor());
546+
bool was_resized =
547+
maybe_resize_input(compute_graph, i, args[i]->toTensor());
548+
should_propagate_resize = should_propagate_resize || was_resized;
549+
compute_graph->copy_into_staging(
550+
compute_graph->inputs()[i].staging,
551+
args[i]->toTensor().const_data_ptr(),
552+
args[i]->toTensor().numel());
553+
} else if (compute_graph->val_is_symint(iref)) {
554+
VK_CHECK_COND(
555+
args[i]->isTensor(),
556+
"Cannot handle symint arg to graph that is not derived from a "
557+
"scalar tensor at the moment.");
558+
bool was_updated = maybe_update_scalar_tensor(
559+
compute_graph, iref, args[i]->toTensor());
560+
// Since symint inputs may impact tensor's sizes, trigger a resize if
561+
// any symbolic integer shapes are updated.
562+
should_propagate_resize = should_propagate_resize || was_updated;
563+
} else {
564+
VK_THROW(
565+
"Could not handle input with type ",
566+
compute_graph->get_val_type(iref));
567+
}
518568
}
519569

520570
if (should_propagate_resize) {
@@ -523,13 +573,21 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
523573
compute_graph->execute();
524574

525575
for (size_t i = 0; i < compute_graph->outputs().size(); i++) {
526-
maybe_resize_output(compute_graph, i, args[num_inputs + i]->toTensor());
527-
// args holds inputs directly followed by outputs, so the i'th output
528-
// for compute_graph corresponds to the (i + num_inputs)'th arg
529-
compute_graph->copy_from_staging(
530-
compute_graph->outputs()[i].staging,
531-
args[num_inputs + i]->toTensor().mutable_data_ptr(),
532-
args[num_inputs + i]->toTensor().numel());
576+
const ValueRef oref = compute_graph->outputs()[i].value;
577+
if (compute_graph->val_is_tensor(oref)) {
578+
VK_CHECK_COND(args[i]->isTensor());
579+
maybe_resize_output(compute_graph, i, args[num_inputs + i]->toTensor());
580+
// args holds inputs directly followed by outputs, so the i'th output
581+
// for compute_graph corresponds to the (i + num_inputs)'th arg
582+
compute_graph->copy_from_staging(
583+
compute_graph->outputs()[i].staging,
584+
args[num_inputs + i]->toTensor().mutable_data_ptr(),
585+
args[num_inputs + i]->toTensor().numel());
586+
} else {
587+
VK_THROW(
588+
"Could not handle output with type ",
589+
compute_graph->get_val_type(oref));
590+
}
533591
}
534592

535593
#ifdef ET_EVENT_TRACER_ENABLED

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -555,6 +555,14 @@ class ComputeGraph final {
555555

556556
int32_t read_symint(const ValueRef idx);
557557

558+
/*
 * Registers the value at `idx` as a graph input without allocating a staging
 * buffer for it (the staging slot is filled with kDummyValueRef). Intended
 * for non-tensor inputs such as symints, which do not need staging transfers.
 */
inline void set_val_as_input(const ValueRef idx) {
  inputs_.push_back({idx, kDummyValueRef});
}
561+
562+
/*
 * Registers the value at `idx` as a graph output without allocating a staging
 * buffer for it (the staging slot is filled with kDummyValueRef). Mirrors
 * set_val_as_input for non-tensor outputs.
 */
inline void set_val_as_output(const ValueRef idx) {
  outputs_.push_back({idx, kDummyValueRef});
}
565+
558566
/*
559567
* Convenience function to add an input tensor along with its staging buffer
560568
*/

0 commit comments

Comments
 (0)