
Commit bc9a3bf

Merge branch 'pytorch:main' into velapin
2 parents 4ba8996 + 5aa4cac


43 files changed: +2905 -189 lines

backends/arm/operators/op_conv2d.py

Lines changed: 35 additions & 0 deletions
@@ -28,6 +28,24 @@ class Conv2dVisitor(NodeVisitor):
     def __init__(self, *args):
         super().__init__(*args)
 
+    # torch.nn.Conv2d does not require that the result of
+    # `(input + 2 * pad - dilation * (weight - 1) - 1) / stride`
+    # be an integer, but TOSA currently strictly requires this property.
+    # This function adjusts the pad value to meet the requirement.
+    def adjust_pad_if_needed(self, input, weight, stride, pad, dilation):
+        mod_remainder = (input + 2 * pad - dilation * (weight - 1) - 1) % stride
+
+        # No need to adjust
+        if mod_remainder == 0:
+            return pad
+
+        if mod_remainder > pad:
+            raise RuntimeError(
+                f"ignoring input elements is not currently supported, got a large stride {stride}"
+            )
+
+        return pad - mod_remainder
+
     def define_node(
         self,
         node: torch.fx.Node,
@@ -52,6 +70,23 @@ def define_node(
         pad_attr = [val for val in pad.special for _ in (0, 1)]
         stride_attr = stride.special
         dilation_attr = dilation.special
+
+        # Adjust the pad value if needed to meet the strict convolution output shape calculation.
+        pad_attr[1] = self.adjust_pad_if_needed(
+            input.shape[2],
+            weight.shape[2],
+            stride_attr[0],
+            pad_attr[1],
+            dilation_attr[0],
+        )
+        pad_attr[3] = self.adjust_pad_if_needed(
+            input.shape[3],
+            weight.shape[3],
+            stride_attr[1],
+            pad_attr[3],
+            dilation_attr[1],
+        )
+
         attr.ConvAttribute(
             pad=pad_attr,
             stride=stride_attr,
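To make the arithmetic concrete, here is a standalone sketch of the same adjustment (the free-function form and the asserts are illustrative, not part of the patch). TOSA needs `(input + 2 * pad - dilation * (kernel - 1) - 1)` to divide evenly by `stride`; shrinking only the trailing pad (`pad_attr[1]` and `pad_attr[3]`) by the remainder makes the division exact while leaving the output size identical to the floor division torch performs.

```python
def adjust_pad_if_needed(input_size, kernel_size, stride, pad, dilation):
    """Shrink the trailing pad so the TOSA output-shape division is exact."""
    remainder = (input_size + 2 * pad - dilation * (kernel_size - 1) - 1) % stride
    if remainder == 0:
        return pad  # already divisible; nothing to do
    if remainder > pad:
        # Would require dropping input elements, which the backend rejects.
        raise RuntimeError(
            f"ignoring input elements is not currently supported, got a large stride {stride}"
        )
    return pad - remainder

# 12x12 input, 3x3 kernel, stride 2, pad 1: (12 + 2 - 2 - 1) % 2 == 1,
# so the trailing pad drops from 1 to 0.
assert adjust_pad_if_needed(12, 3, 2, 1, 1) == 0
# With stride 1 the division is always exact and the pad is unchanged.
assert adjust_pad_if_needed(12, 3, 1, 1, 1) == 1
```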

backends/arm/test/test_models.py

Lines changed: 25 additions & 1 deletion
@@ -1,4 +1,4 @@
-# Copyright 2023 Arm Limited and/or its affiliates.
+# Copyright 2023-2024 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -214,6 +214,30 @@ def forward(self, x):
         x = self.conv2d(x)
         return x
 
+# A test where `(input + 2 * pad - dilation * (weight - 1) - 1) / stride` is not an integer.
+@register_test
+class simple_conv2d_3x3_1x3x12x12_st2_pad1(torch.nn.Module):
+    data = torch.ones(1, 3, 12, 12)
+    inputs = {
+        TosaProfile.BI: (data,),
+        TosaProfile.MI: (data,),
+    }
+
+    def __init__(self):
+        super().__init__()
+        self.conv2d = torch.nn.Conv2d(
+            in_channels=3, out_channels=4, kernel_size=3, stride=2, padding=1
+        )
+        with torch.no_grad():
+            self.conv2d.weight.copy_(
+                rand_test_integers(low=1, high=4, size=(4, 3, 3, 3))
+            )
+            self.conv2d.bias.copy_(rand_test_integers(low=1, high=4, size=(4)))
+
+    def forward(self, x):
+        x = self.conv2d(x)
+        return x
+
 @register_test
 class simple_conv2d_1x1_1x2x128x128_stride1(torch.nn.Module):
     data = torch.from_numpy(
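The test name encodes the case it provokes: a 1x3x12x12 input through a 3x3 kernel with stride 2 and padding 1 gives `(12 + 2*1 - 1*(3-1) - 1) / 2 = 11 / 2 = 5.5`, exactly the non-integer case that the new `adjust_pad_if_needed` handles. A quick sanity check of that arithmetic (a sketch, not part of the test file):

```python
import torch

input_size, kernel, stride, pad, dilation = 12, 3, 2, 1, 1
numerator = input_size + 2 * pad - dilation * (kernel - 1) - 1
print(numerator / stride)  # 5.5 -> non-integer, so the pad must be adjusted

# torch itself accepts this configuration and floors the division:
conv = torch.nn.Conv2d(3, 4, kernel_size=3, stride=2, padding=1)
print(conv(torch.ones(1, 3, 12, 12)).shape)  # torch.Size([1, 4, 6, 6])
```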

backends/vulkan/runtime/VulkanBackend.cpp

Lines changed: 4 additions & 3 deletions
@@ -6,12 +6,13 @@
  * LICENSE file in the root directory of this source tree.
  */
 
-#include <executorch/backends/vulkan/runtime/graph/Graph.h>
-#include <executorch/backends/vulkan/runtime/graph/OperatorRegistry.h>
-
 #include <executorch/backends/vulkan/runtime/VulkanDelegateHeader.h>
 #include <executorch/backends/vulkan/schema_generated.h>
 
+#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
+
+#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
+
 #include <executorch/runtime/backend/interface.h>
 #include <executorch/runtime/core/error.h>
 #include <executorch/runtime/core/evalue.h>

backends/vulkan/runtime/graph/Graph.cpp renamed to backends/vulkan/runtime/graph/ComputeGraph.cpp

Lines changed: 2 additions & 64 deletions
@@ -6,76 +6,14 @@
  * LICENSE file in the root directory of this source tree.
  */
 
-#include <executorch/backends/vulkan/runtime/graph/Graph.h>
+#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
 
-#include <executorch/backends/vulkan/runtime/graph/ops/Staging.h>
+#include <executorch/backends/vulkan/runtime/graph/ops/impl/Staging.h>
 
 namespace at {
 namespace native {
 namespace vulkan {
 
-//
-// SharedObject
-//
-
-void SharedObject::add_user(ComputeGraph* const graph, const ValueRef idx) {
-  vTensor& t = graph->get_val(idx).toTensor();
-
-  //
-  // Aggregate Memory Requirements
-  //
-
-  const VkMemoryRequirements mem_reqs = t.get_memory_requirements();
-  aggregate_memory_requirements.size =
-      std::max(mem_reqs.size, aggregate_memory_requirements.size);
-  aggregate_memory_requirements.alignment =
-      std::max(mem_reqs.alignment, aggregate_memory_requirements.alignment);
-  aggregate_memory_requirements.memoryTypeBits |= mem_reqs.memoryTypeBits;
-
-  //
-  // Aggregate Allocation Create Info
-  //
-
-  const VmaAllocationCreateInfo create_info = t.get_allocation_create_info();
-  // Clear out CREATE_STRATEGY bit flags in case of conflict
-  VmaAllocationCreateFlags clear_mask = ~VMA_ALLOCATION_CREATE_STRATEGY_MASK;
-  VmaAllocationCreateFlags create_flags = create_info.flags & clear_mask;
-  // Use the default allocation strategy
-  aggregate_create_info.flags = create_flags | api::DEFAULT_ALLOCATION_STRATEGY;
-
-  // Set the usage flag if it is currently not set
-  if (aggregate_create_info.usage == VMA_MEMORY_USAGE_UNKNOWN) {
-    aggregate_create_info.usage = create_info.usage;
-  }
-  // Otherwise check that there is no conflict regarding usage
-  VK_CHECK_COND(aggregate_create_info.usage == create_info.usage);
-  aggregate_create_info.requiredFlags |= create_info.requiredFlags;
-  aggregate_create_info.preferredFlags |= create_info.preferredFlags;
-
-  users.emplace_back(idx);
-}
-
-void SharedObject::allocate(ComputeGraph* const graph) {
-  if (aggregate_memory_requirements.size == 0) {
-    return;
-  }
-  allocation = graph->context()->adapter_ptr()->vma().create_allocation(
-      aggregate_memory_requirements, aggregate_create_info);
-}
-
-void SharedObject::bind_users(ComputeGraph* const graph) {
-  if (users.empty()) {
-    return;
-  }
-  for (const ValueRef idx : users) {
-    graph->get_val(idx).toTensor().bind_allocation(allocation);
-  }
-}
-
-//
-// ComputeGraph
-//
-
 ComputeGraph::ComputeGraph(GraphConfig config)
     : config_{config},
       context_{new api::Context(

backends/vulkan/runtime/graph/Graph.h renamed to backends/vulkan/runtime/graph/ComputeGraph.h

Lines changed: 7 additions & 75 deletions
@@ -16,86 +16,18 @@
 #include <ATen/native/vulkan/api/Tensor.h>
 #include <ATen/native/vulkan/api/Types.h>
 
-#include <executorch/backends/vulkan/runtime/graph/Config.h>
-#include <executorch/backends/vulkan/runtime/graph/Value.h>
+#include <executorch/backends/vulkan/runtime/graph/GraphConfig.h>
+
+#include <executorch/backends/vulkan/runtime/graph/containers/SharedObject.h>
+#include <executorch/backends/vulkan/runtime/graph/containers/Value.h>
+
+#include <executorch/backends/vulkan/runtime/graph/ops/ExecuteNode.h>
+#include <executorch/backends/vulkan/runtime/graph/ops/PrepackNode.h>
 
 namespace at {
 namespace native {
 namespace vulkan {
 
-using ValueRef = int32_t;
-
-struct IOValueRef {
-  ValueRef value;
-  ValueRef staging;
-};
-
-class ComputeGraph;
-
-/*
- * Represents a single prepacking op in a ML model. In graph mode, ops will be
- * implemented in a derived class that implements encode, which will implement
- * encoding of shaders transferring necessary data (such as weights and biases)
- * to the GPU.
- */
-class PrepackNode {
-  friend class ComputeGraph;
-
- public:
-  PrepackNode(ValueRef tref, ValueRef packed) : tref_{tref}, packed_{packed} {}
-
-  virtual ~PrepackNode() = default;
-
- protected:
-  ValueRef tref_;
-  ValueRef packed_;
-
- public:
-  virtual void encode(ComputeGraph* graph) const = 0;
-};
-
-/*
- * Represents a single execution op in a ML model. In graph mode, ops will be
- * implemented in a derived class that implements encode, which will implement
- * encoding of the shader corresponding to the op into the command buffer of a
- * ComputeGraph.
- */
-class ExecuteNode {
-  friend class ComputeGraph;
-
- public:
-  ExecuteNode(ValueRef input, ValueRef output)
-      : inputs_{input}, outputs_{output} {}
-  ExecuteNode(
-      const std::vector<ValueRef>& inputs,
-      const std::vector<ValueRef>& outputs)
-      : inputs_(inputs), outputs_(outputs) {}
-
-  virtual ~ExecuteNode() = default;
-
- protected:
-  std::vector<ValueRef> inputs_;
-  std::vector<ValueRef> outputs_;
-
- public:
-  virtual void encode(ComputeGraph* graph) const = 0;
-};
-
-struct SharedObject {
-  friend class ComputeGraph;
-
-  explicit SharedObject() = default;
-
-  VkMemoryRequirements aggregate_memory_requirements;
-  VmaAllocationCreateInfo aggregate_create_info;
-  std::vector<ValueRef> users;
-  api::MemoryAllocation allocation;
-
-  void add_user(ComputeGraph* const graph, const ValueRef idx);
-  void allocate(ComputeGraph* const graph);
-  void bind_users(ComputeGraph* const graph);
-};
-
 /*
  * This is the core data structure used to execute Vulkan models in graph mode.
  * As opposed to ATen/eager mode where a command buffer is encoded every
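The classes deleted here are not gone: `PrepackNode`, `ExecuteNode`, and `SharedObject` move into the dedicated headers this file now includes (`ops/PrepackNode.h`, `ops/ExecuteNode.h`, `containers/SharedObject.h`). The pattern itself is unchanged: a node stores `ValueRef` indices into the owning graph, and a concrete subclass overrides `encode` to record its shader dispatch into the graph's command buffer. A rough Python sketch of that shape, assuming a hypothetical `graph.record` call purely for illustration:

```python
from abc import ABC, abstractmethod

class ExecuteNode(ABC):
    # Mirrors the C++ class: nodes hold ValueRef indices (plain ints)
    # into the owning ComputeGraph rather than tensors themselves.
    def __init__(self, inputs, outputs):
        self.inputs = list(inputs)
        self.outputs = list(outputs)

    @abstractmethod
    def encode(self, graph):
        """Record this op into the graph's command buffer."""

class AddNode(ExecuteNode):
    def encode(self, graph):
        # A real subclass would bind tensor memory and dispatch a shader;
        # graph.record is a stand-in for that encoding step.
        graph.record(op="add", inputs=self.inputs, outputs=self.outputs)
```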

backends/vulkan/runtime/graph/Constant.cpp renamed to backends/vulkan/runtime/graph/containers/Constant.cpp

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
  * LICENSE file in the root directory of this source tree.
  */
 
-#include <executorch/backends/vulkan/runtime/graph/Constant.h>
+#include <executorch/backends/vulkan/runtime/graph/containers/Constant.h>
 
 namespace at {
 namespace native {
backends/vulkan/runtime/graph/containers/SharedObject.cpp (new file)

Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <executorch/backends/vulkan/runtime/graph/containers/SharedObject.h>
+
+#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
+
+namespace at {
+namespace native {
+namespace vulkan {
+
+void SharedObject::add_user(ComputeGraph* const graph, const ValueRef idx) {
+  vTensor& t = graph->get_val(idx).toTensor();
+
+  //
+  // Aggregate Memory Requirements
+  //
+
+  const VkMemoryRequirements mem_reqs = t.get_memory_requirements();
+  aggregate_memory_requirements.size =
+      std::max(mem_reqs.size, aggregate_memory_requirements.size);
+  aggregate_memory_requirements.alignment =
+      std::max(mem_reqs.alignment, aggregate_memory_requirements.alignment);
+  aggregate_memory_requirements.memoryTypeBits |= mem_reqs.memoryTypeBits;
+
+  //
+  // Aggregate Allocation Create Info
+  //
+
+  const VmaAllocationCreateInfo create_info = t.get_allocation_create_info();
+  // Clear out CREATE_STRATEGY bit flags in case of conflict
+  VmaAllocationCreateFlags clear_mask = ~VMA_ALLOCATION_CREATE_STRATEGY_MASK;
+  VmaAllocationCreateFlags create_flags = create_info.flags & clear_mask;
+  // Use the default allocation strategy
+  aggregate_create_info.flags = create_flags | api::DEFAULT_ALLOCATION_STRATEGY;
+
+  // Set the usage flag if it is currently not set
+  if (aggregate_create_info.usage == VMA_MEMORY_USAGE_UNKNOWN) {
+    aggregate_create_info.usage = create_info.usage;
+  }
+  // Otherwise check that there is no conflict regarding usage
+  VK_CHECK_COND(aggregate_create_info.usage == create_info.usage);
+  aggregate_create_info.requiredFlags |= create_info.requiredFlags;
+  aggregate_create_info.preferredFlags |= create_info.preferredFlags;
+
+  users.emplace_back(idx);
+}
+
+void SharedObject::allocate(ComputeGraph* const graph) {
+  if (aggregate_memory_requirements.size == 0) {
+    return;
+  }
+  allocation = graph->context()->adapter_ptr()->vma().create_allocation(
+      aggregate_memory_requirements, aggregate_create_info);
+}
+
+void SharedObject::bind_users(ComputeGraph* const graph) {
+  if (users.empty()) {
+    return;
+  }
+  for (const ValueRef idx : users) {
+    graph->get_val(idx).toTensor().bind_allocation(allocation);
+  }
+}
+
+} // namespace vulkan
+} // namespace native
+} // namespace at
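The aggregation logic rewards a closer read: a `SharedObject` lets several tensors share a single device allocation, so `add_user` folds each tensor's requirements into a running aggregate: the maximum of the sizes, the maximum of the alignments (the strictest, since Vulkan alignments are powers of two), and the bitwise OR of the `memoryTypeBits` masks. A minimal sketch of that folding, with illustrative names standing in for the Vulkan structs:

```python
from dataclasses import dataclass

@dataclass
class MemReqs:
    size: int = 0
    alignment: int = 0
    memory_type_bits: int = 0

def add_user(aggregate: MemReqs, reqs: MemReqs) -> None:
    # One allocation must be able to back every user in turn.
    aggregate.size = max(aggregate.size, reqs.size)
    aggregate.alignment = max(aggregate.alignment, reqs.alignment)
    aggregate.memory_type_bits |= reqs.memory_type_bits  # OR, as in add_user above

agg = MemReqs()
add_user(agg, MemReqs(size=4096, alignment=64, memory_type_bits=0b0011))
add_user(agg, MemReqs(size=1024, alignment=256, memory_type_bits=0b0110))
assert (agg.size, agg.alignment, agg.memory_type_bits) == (4096, 256, 0b0111)
```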
