
Commit ceebb4b

Merge branch 'main' into support_oss_models_2
2 parents f1caefb + 879eee0 commit ceebb4b

Large commits have some content hidden by default; only a subset of the changed files is shown below.

65 files changed, +1220 -1048 lines

.github/scripts/label_utils.py

Lines changed: 1 addition & 3 deletions
@@ -22,9 +22,7 @@
 
 LABEL_ERR_MSG_TITLE = "This PR needs a `release notes:` label"
 LABEL_ERR_MSG = f"""# {LABEL_ERR_MSG_TITLE}
-If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with `release notes:`.
-
-If not, please add the `release notes: none` label.
+If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with `release notes:`. This helps us keep track and include your important work in the next release notes.
 
 To add a label, you can comment to pytorchbot, for example
 `@pytorchbot label "release notes: none"`

.github/scripts/trymerge.py

Lines changed: 1 addition & 9 deletions
@@ -59,12 +59,7 @@
     patterns_to_regex,
     retries_decorator,
 )
-from label_utils import (
-    gh_add_labels,
-    gh_remove_label,
-    has_required_labels,
-    LABEL_ERR_MSG,
-)
+from label_utils import gh_add_labels, gh_remove_label
 from trymerge_explainer import get_revert_message, TryMergeExplainer
 
 # labels
@@ -2116,9 +2111,6 @@ def merge(
     # Check for approvals
     find_matching_merge_rule(pr, repo, skip_mandatory_checks=True)
 
-    if not has_required_labels(pr):
-        raise RuntimeError(LABEL_ERR_MSG.lstrip(" #"))
-
     if ignore_current:
         checks = pr.get_checkrun_conclusions()
         _, failing, _ = categorize_checks(

.github/workflows/check-labels.yml

Lines changed: 1 addition & 1 deletion
@@ -51,4 +51,4 @@ jobs:
           PR_NUM: ${{ github.event.number || github.event.inputs.pr_number }}
         run: |
           set -ex
-          python3 .github/scripts/check_labels.py --exit-non-zero "${PR_NUM}"
+          python3 .github/scripts/check_labels.py "${PR_NUM}"

backends/cadence/aot/replace_ops.py

Lines changed: 7 additions & 5 deletions
@@ -2065,11 +2065,10 @@ def call_operator(
         return super().call_operator(op, args, kwargs, meta)
 
 
-@register_cadence_pass(CadencePassAttribute(opt_level=2))
-class ReplaceGeluWithApproximateGeluPass(ExportPass):
+@register_cadence_pass(CadencePassAttribute(opt_level=0))
+class ReplaceAtenApproxGeluWithApproxGeluPass(ExportPass):
     """
-    Replace the gelu op with an approximate gelu op. The approximate gelu op
-    is more efficient on DSP backends.
+    Replace the aten gelu op with an approximate arg with an approximate gelu op.
     """
 
     def call_operator(
@@ -2079,6 +2078,9 @@ def call_operator(
         kwargs: Dict[str, Argument],
         meta: NodeMetadata,
     ) -> ProxyValue:
+        if "approximate" not in kwargs:
+            return super().call_operator(op, args, kwargs, meta)
+
         if op not in {
             exir_ops.edge.aten.gelu.default,
         }:
@@ -2414,7 +2416,7 @@ class CadenceReplaceOpsInGraph:
             ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass,
             ReplaceAtenAvgPoolWithJarvisAvgPoolPass,
             ReplaceWhereWithFullArgsWithWhereScalar,
-            ReplaceGeluWithApproximateGeluPass,
+            ReplaceAtenApproxGeluWithApproxGeluPass,
             ReplaceSplitWithSlicePass,
             ReplacePowWithMulPass,
         ]

backends/cadence/aot/tests/test_replace_ops_passes.py

Lines changed: 32 additions & 8 deletions
@@ -26,13 +26,13 @@
     ForceChannelLastForConvPass,
     MakeSliceAndCatDimOutermostPass,
     ReplaceAddMMWithLinearPass,
+    ReplaceAtenApproxGeluWithApproxGeluPass,
     ReplaceAtenConvolutionWithJarvisConvolutionPass,
     ReplaceConstantPadNdWithSlicePass,
     ReplaceConvolutionOptionalArgsWithConcreteArgsPass,
     ReplaceConvWithIm2RowAndLinear,
     ReplaceEmptyTensorsWithFullPass,
     ReplaceFunctionallyEquivalentOpTargets,
-    ReplaceGeluWithApproximateGeluPass,
     ReplaceIm2RowWithViewPass,
     ReplaceLinearWithFullyConnectedOpPass,
     ReplaceMatmulWithTransposedMatmulPass,
@@ -1287,17 +1287,41 @@ def forward(self, cond: torch.Tensor):
             1,
         )
 
-    def test_replace_aten_gelu_with_approximate_gelu(self):
-        class Gelu(torch.nn.Module):
-            def forward(self, input):
-                return torch.nn.functional.gelu(input)
+    def test_no_replace_aten_gelu_with_approximate_gelu(self):
+        inputs = torch.randn(2, 1, 64)
+
+        gm = single_op_builder(
+            placeholders=(inputs,),
+            op=exir_ops.edge.aten.gelu.default,
+            args=(inputs,),
+        )
+        gm = ExportPass().call(gm).graph_module
+
+        p = ReplaceAtenApproxGeluWithApproxGeluPass()
+        graph_after_passes = p.call(gm).graph_module
 
+        # Assert that aten.gelu op was not decomposed, since it didn't have an approximate argument
+        self.assertEqual(
+            count_node(
+                graph_after_passes,
+                exir_ops.edge.aten.gelu.default,
+            ),
+            1,
+        )
+
+    def test_replace_aten_approximate_gelu_with_approximate_gelu(self):
         inputs = torch.randn(2, 1, 64)
 
-        graph_module = export_to_edge(Gelu(), (inputs,)).exported_program().graph_module
+        gm = single_op_builder(
+            placeholders=(inputs,),
+            op=exir_ops.edge.aten.gelu.default,
+            args=(inputs,),
+            kwargs={"approximate": "tanh"},
+        )
+        gm = ExportPass().call(gm).graph_module
 
-        p = ReplaceGeluWithApproximateGeluPass()
-        graph_after_passes = cast(PassResult, p(graph_module)).graph_module
+        p = ReplaceAtenApproxGeluWithApproxGeluPass()
+        graph_after_passes = p.call(gm).graph_module
 
         # Assert that aten.gelu op was decomposed
         self.assertEqual(

backends/vulkan/runtime/VulkanBackend.cpp

Lines changed: 7 additions & 0 deletions
@@ -499,6 +499,8 @@ class VulkanBackend final : public ::executorch::runtime::BackendInterface {
   compute_graph->encode_prepack();
   compute_graph->prepack();
 
+  // TODO(ssjia): remove this once we can batch compile compute pipelines
+  // during prepare().
   compute_graph->encode_execute();
 
   return Error::Ok;
@@ -567,9 +569,14 @@
     }
   }
 
+  // propagate_resize() will re-encode the command buffer so that push
+  // constants are updated and DynamicDispatchNode can update the compute
+  // shader, global workgroup size, and local workgroup size to perform the
+  // model inference.
   if (should_propagate_resize) {
     compute_graph->propagate_resize();
   }
+
  compute_graph->execute();
 
  for (size_t i = 0; i < compute_graph->outputs().size(); i++) {
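
Taken together with the ComputeGraph changes below, this is the new dynamic-shape flow: resizing an input re-runs each node's resize function and re-encodes the command buffer before the next submission. A minimal sketch of that flow, assuming a graph with a single input (the index and sizes here are illustrative, not from this commit):

  // Resize input 0, then re-run resize functions and re-encode.
  compute_graph->resize_input(0, {1, 3, 128, 128});
  compute_graph->propagate_resize(); // now calls encode_execute() internally

  // Submit the freshly encoded command buffer; execute() also increments
  // the new execute_count_ counter, readable via execute_count().
  compute_graph->execute();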

backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 20 additions & 2 deletions
@@ -492,14 +492,24 @@ vkapi::BufferBindInfo ComputeGraph::get_or_create_int_param_buffer(
     const ValueRef idx) {
   if (values_.at(idx).isInt()) {
     const int32_t val = extract_scalar<int32_t>(idx);
-    create_params_buffer(val);
+    return create_params_buffer(val);
   } else if (values_.at(idx).isSymInt()) {
     SymIntPtr symint = get_symint(idx);
     return vkapi::BufferBindInfo(symint->gpu_buffer.buffer());
   }
   VK_THROW("Cannot create a int param buffer for the given value");
 }
 
+vkapi::BufferBindInfo ComputeGraph::get_or_create_int_param_buffer(
+    const ValueRef idx,
+    const int32_t default_val) {
+  if (values_.at(idx).isNone()) {
+    return create_params_buffer(default_val);
+  } else {
+    return get_or_create_int_param_buffer(idx);
+  }
+}
+
 void ComputeGraph::set_symint(const ValueRef idx, const int32_t val) {
   get_symint(idx)->set(val);
 }
@@ -678,11 +688,12 @@ void ComputeGraph::encode_execute() {
   }
 }
 
-void ComputeGraph::execute() const {
+void ComputeGraph::execute() {
   vkapi::VulkanFence fence = context_->fences().get_fence();
   context_->submit_cmd_to_gpu(fence.get_submit_handle());
   fence.wait();
   context_->fences().return_fence(fence);
+  execute_count_++;
 }
 
 void ComputeGraph::resize_input(
@@ -692,10 +703,17 @@
   get_tensor(io_val.value)->virtual_resize(new_sizes);
 }
 
+void ComputeGraph::virtual_resize(
+    const ValueRef idx,
+    const std::vector<int64_t>& new_sizes) {
+  get_tensor(idx)->virtual_resize(new_sizes);
+}
+
 void ComputeGraph::propagate_resize() {
   for (std::unique_ptr<ExecuteNode>& node : execute_nodes_) {
     node->trigger_resize(this);
   }
+  encode_execute();
 }
 
 } // namespace vkcompute
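
Note that the one-argument overload also gains a previously missing return for the Int case, which used to discard the freshly created params buffer. The new two-argument overload lets an op implementation supply a fallback when the referenced value is None; a hedged usage sketch (the dim_ref variable and surrounding caller are hypothetical, not from this commit):

  // Bind an optional integer argument; None falls back to 0, while Int
  // and SymInt values resolve through the one-argument overload.
  const vkapi::BufferBindInfo dim_params =
      graph->get_or_create_int_param_buffer(dim_ref, /*default_val=*/0);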

backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 26 additions & 1 deletion
@@ -187,6 +187,7 @@ class ComputeGraph final {
 
  protected:
   size_t values_in_use_ = 0;
+  size_t execute_count_ = 0;
 
  public:
   //
@@ -397,6 +398,19 @@
   std::optional<T> extract_optional_scalar(const ValueRef idx) {
     if (val_is_none(idx)) {
       return ::std::nullopt;
+    } else if (val_is_symint(idx)) {
+      return utils::safe_downcast<T>(read_symint(idx));
+    } else {
+      return extract_scalar<T>(idx);
+    }
+  }
+
+  template <typename T>
+  T extract_optional_scalar(const ValueRef idx, const T default_val) {
+    if (val_is_none(idx)) {
+      return default_val;
+    } else if (val_is_symint(idx)) {
+      return utils::safe_downcast<T>(read_symint(idx));
     } else {
       return extract_scalar<T>(idx);
     }
@@ -608,6 +622,10 @@
    */
   vkapi::BufferBindInfo get_or_create_int_param_buffer(const ValueRef idx);
 
+  vkapi::BufferBindInfo get_or_create_int_param_buffer(
+      const ValueRef idx,
+      const int32_t default_value);
+
   void set_symint(const ValueRef idx, const int32_t val);
 
   int32_t read_symint(const ValueRef idx);
@@ -745,13 +763,16 @@
   //
 
   void encode_execute();
-  void execute() const;
+  void execute();
 
   //
   // Dynamic Shape support
   //
 
   void resize_input(const int64_t idx, const std::vector<int64_t>& new_sizes);
+  void virtual_resize(
+      const ValueRef idx,
+      const std::vector<int64_t>& new_sizes);
   void propagate_resize();
 
   //
@@ -762,6 +783,10 @@
     return context_->adapter_ptr()->supports_int16_shader_types();
   }
 
+  inline size_t execute_count() const {
+    return execute_count_;
+  }
+
   /*
    * Check whether the GPU supports 8 bit buffers.
    */
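
extract_optional_scalar now also resolves SymInt values (via read_symint) and gains a defaulted variant, sparing callers an explicit val_is_none check. A small sketch, with a hypothetical alpha_ref value:

  // Returns the stored Int or SymInt as int64_t, or 1 when alpha_ref is None.
  const int64_t alpha = graph->extract_optional_scalar<int64_t>(alpha_ref, 1);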

backends/vulkan/runtime/graph/ops/DispatchNode.cpp

Lines changed: 14 additions & 12 deletions
@@ -46,15 +46,7 @@ void DispatchNode::encode(ComputeGraph* graph) {
 
   std::unique_lock<std::mutex> cmd_lock = context->dispatch_lock();
 
-  std::array<uint8_t, kMaxPushConstantSize> push_constants_data;
-  uint32_t push_constants_offset = 0;
-
-  for (const auto& push_constant : push_constants_) {
-    push_constants_offset += push_constant.write(
-        push_constants_data.data(),
-        push_constants_offset,
-        kMaxPushConstantSize);
-  }
+  write_push_constant_data();
 
   context->report_shader_dispatch_start(
       shader_.kernel_name,
@@ -63,7 +55,7 @@
       node_id_);
 
   vkapi::DescriptorSet descriptor_set = context->get_descriptor_set(
-      shader_, local_workgroup_size_, spec_vars_, push_constants_offset);
+      shader_, local_workgroup_size_, spec_vars_, push_constants_offset_);
 
   uint32_t idx = 0;
   idx = bind_values_to_descriptor_set(
@@ -76,10 +68,20 @@
       pipeline_barrier,
       shader_,
       global_workgroup_size_,
-      push_constants_data.data(),
-      push_constants_offset);
+      push_constants_data_.data(),
+      push_constants_offset_);
 
   context->report_shader_dispatch_end();
 }
 
+void DispatchNode::write_push_constant_data() {
+  push_constants_offset_ = 0;
+  for (const auto& push_constant : push_constants_) {
+    push_constants_offset_ += push_constant.write(
+        push_constants_data_.data(),
+        push_constants_offset_,
+        kMaxPushConstantSize);
+  }
+}
+
 } // namespace vkcompute

backends/vulkan/runtime/graph/ops/DispatchNode.h

Lines changed: 6 additions & 0 deletions
@@ -50,6 +50,12 @@ class DispatchNode : public ExecuteNode {
   const vkapi::SpecVarList spec_vars_;
   const std::vector<PushConstantDataInfo> push_constants_;
 
+  // For push constants
+  std::array<uint8_t, kMaxPushConstantSize> push_constants_data_{};
+  uint32_t push_constants_offset_ = 0;
+
+  void write_push_constant_data();
+
  public:
   operator bool() const {
     return shader_;
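
The point of hoisting push_constants_data_ and push_constants_offset_ out of encode() into members is that the packed bytes now persist between encodes: when ComputeGraph::propagate_resize() re-encodes the graph, each node can repack its (possibly resized) push constants over the same storage. A condensed view of the resulting encode-time pattern, using only names from this diff:

  // Inside DispatchNode::encode():
  write_push_constant_data(); // repack push_constants_ into
                              // push_constants_data_, updating
                              // push_constants_offset_
  // get_descriptor_set() and the shader dispatch then both read the cached
  // push_constants_data_.data() / push_constants_offset_ pair.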
