Skip to content

Commit b205f2f

Browse files
authored
Merge branch 'main' into export-D75228037
2 parents e7c9290 + a6e2961 commit b205f2f

File tree

6 files changed

+60
-52
lines changed

6 files changed

+60
-52
lines changed

backends/qualcomm/_passes/remove_redundancy.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def __init__(self):
2222
exir_ops.edge.aten.clone.default: self._default_condition,
2323
torch.ops.aten.alias.default: self._default_condition,
2424
exir_ops.edge.aten.alias.default: self._default_condition,
25+
exir_ops.edge.aten.alias_copy.default: self._default_condition,
2526
exir_ops.edge.aten.lift_fresh_copy.default: self._default_condition,
2627
# remove this target if '_skip_dim_order' is set to False
2728
exir_ops.edge.dim_order_ops._to_dim_order_copy.default: self._dim_order_op_condition,

backends/qualcomm/tests/models.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,17 @@
1010
# module with related operator only
1111

1212

13+
# Ensure alias_copy is removed in remove_redundancy pass
14+
class Alias(torch.nn.Module):
15+
def __init__(self):
16+
super().__init__()
17+
self.relu = torch.nn.ReLU()
18+
19+
def forward(self, x):
20+
alias_x = torch.ops.aten.alias.default(x)
21+
return self.relu(alias_x)
22+
23+
1324
class And(torch.nn.Module):
1425
def __init__(self, pos, neg):
1526
super().__init__()

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,11 @@ def test_qnn_backend_adaptive_avg_pool2d(self):
124124
sample_input = (torch.randn(1, 512, 7, 7),)
125125
self.lower_module_and_test_output(module, sample_input)
126126

127+
def test_qnn_backend_alias(self):
128+
module = Alias() # noqa: F405
129+
sample_input = (torch.randn(1, 10),)
130+
self.lower_module_and_test_output(module, sample_input)
131+
127132
def test_qnn_backend_amax(self):
128133
modules = [AMax(dim=1, keepdim=False), AMax(dim=1, keepdim=True)] # noqa: F405
129134
sample_input = (torch.randn(4, 4),)
@@ -1162,6 +1167,12 @@ def test_qnn_backend_adaptive_avg_pool2d(self):
11621167
module = self.get_qdq_module(module, sample_input)
11631168
self.lower_module_and_test_output(module, sample_input)
11641169

1170+
def test_qnn_backend_alias(self):
1171+
module = Alias() # noqa: F405
1172+
sample_input = (torch.randn(1, 10),)
1173+
module = self.get_qdq_module(module, sample_input)
1174+
self.lower_module_and_test_output(module, sample_input)
1175+
11651176
def test_qnn_backend_amax(self):
11661177
modules = [AMax(dim=1, keepdim=False), AMax(dim=1, keepdim=True)] # noqa: F405
11671178
sample_input = (torch.randn(4, 4),)

extension/apple/ExecuTorch/Exported/ExecuTorch+Tensor.swift

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,43 +11,40 @@
1111
/// A protocol that types conform to in order to be used as tensor element types.
1212
/// Provides the mapping from the Swift type to the underlying `DataType`.
1313
@available(*, deprecated, message: "This API is experimental.")
14-
protocol Scalar {
14+
public protocol Scalar {
1515
/// The `DataType` corresponding to this scalar type.
1616
static var dataType: DataType { get }
1717
}
1818

1919
@available(*, deprecated, message: "This API is experimental.")
20-
extension UInt8: Scalar { static var dataType: DataType { .byte } }
20+
extension UInt8: Scalar { public static var dataType: DataType { .byte } }
2121
@available(*, deprecated, message: "This API is experimental.")
22-
extension Int8: Scalar { static var dataType: DataType { .char } }
22+
extension Int8: Scalar { public static var dataType: DataType { .char } }
2323
@available(*, deprecated, message: "This API is experimental.")
24-
extension Int16: Scalar { static var dataType: DataType { .short } }
24+
extension Int16: Scalar { public static var dataType: DataType { .short } }
2525
@available(*, deprecated, message: "This API is experimental.")
26-
extension Int32: Scalar { static var dataType: DataType { .int } }
26+
extension Int32: Scalar { public static var dataType: DataType { .int } }
2727
@available(*, deprecated, message: "This API is experimental.")
28-
extension Int64: Scalar { static var dataType: DataType { .long } }
28+
extension Int64: Scalar { public static var dataType: DataType { .long } }
2929
@available(*, deprecated, message: "This API is experimental.")
30-
extension Int: Scalar { static var dataType: DataType { .long } }
31-
@available(macOS 11.0, *)
30+
extension Int: Scalar { public static var dataType: DataType { .long } }
3231
@available(*, deprecated, message: "This API is experimental.")
33-
extension Float16: Scalar { static var dataType: DataType { .half } }
32+
extension Float: Scalar { public static var dataType: DataType { .float } }
3433
@available(*, deprecated, message: "This API is experimental.")
35-
extension Float: Scalar { static var dataType: DataType { .float } }
34+
extension Double: Scalar { public static var dataType: DataType { .double } }
3635
@available(*, deprecated, message: "This API is experimental.")
37-
extension Double: Scalar { static var dataType: DataType { .double } }
36+
extension Bool: Scalar { public static var dataType: DataType { .bool } }
3837
@available(*, deprecated, message: "This API is experimental.")
39-
extension Bool: Scalar { static var dataType: DataType { .bool } }
38+
extension UInt16: Scalar { public static var dataType: DataType { .uInt16 } }
4039
@available(*, deprecated, message: "This API is experimental.")
41-
extension UInt16: Scalar { static var dataType: DataType { .uInt16 } }
40+
extension UInt32: Scalar { public static var dataType: DataType { .uInt32 } }
4241
@available(*, deprecated, message: "This API is experimental.")
43-
extension UInt32: Scalar { static var dataType: DataType { .uInt32 } }
42+
extension UInt64: Scalar { public static var dataType: DataType { .uInt64 } }
4443
@available(*, deprecated, message: "This API is experimental.")
45-
extension UInt64: Scalar { static var dataType: DataType { .uInt64 } }
46-
@available(*, deprecated, message: "This API is experimental.")
47-
extension UInt: Scalar { static var dataType: DataType { .uInt64 } }
44+
extension UInt: Scalar { public static var dataType: DataType { .uInt64 } }
4845

4946
@available(*, deprecated, message: "This API is experimental.")
50-
extension Tensor {
47+
public extension Tensor {
5148
/// Calls the closure with a typed, immutable buffer pointer over the tensor’s elements.
5249
///
5350
/// - Parameter body: A closure that receives an `UnsafeBufferPointer<T>` bound to the tensor’s data.

extension/apple/ExecuTorch/__tests__/TensorTest.swift

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ class TensorTest: XCTestCase {
153153
let tensor = data.withUnsafeMutableBytes {
154154
Tensor(bytesNoCopy: $0.baseAddress!, shape: [2, 3], dataType: .float)
155155
}
156-
let array: [Float] = try tensor.withUnsafeBytes { Array($0) }
156+
let array = try tensor.withUnsafeBytes([Float].init)
157157
XCTAssertEqual(array, data)
158158
}
159159

@@ -172,30 +172,6 @@ class TensorTest: XCTestCase {
172172
}
173173
}
174174

175-
func testWithUnsafeBytesFloat16() throws {
176-
var data: [Float16] = [1, 2, 3, 4, 5, 6]
177-
let tensor = data.withUnsafeMutableBytes {
178-
Tensor(bytesNoCopy: $0.baseAddress!, shape: [6], dataType: .half)
179-
}
180-
let array: [Float16] = try tensor.withUnsafeBytes { Array($0) }
181-
XCTAssertEqual(array, data)
182-
}
183-
184-
func testWithUnsafeMutableBytesFloat16() throws {
185-
var data: [Float16] = [1, 2, 3, 4]
186-
let tensor = data.withUnsafeMutableBytes { buffer in
187-
Tensor(bytes: buffer.baseAddress!, shape: [4], dataType: .half)
188-
}
189-
try tensor.withUnsafeMutableBytes { (buffer: UnsafeMutableBufferPointer<Float16>) in
190-
for i in buffer.indices {
191-
buffer[i] *= 2
192-
}
193-
}
194-
try tensor.withUnsafeBytes { buffer in
195-
XCTAssertEqual(Array(buffer), data.map { $0 * 2 })
196-
}
197-
}
198-
199175
func testInitWithTensor() {
200176
var data: [Int] = [10, 20, 30, 40]
201177
let tensor1 = data.withUnsafeMutableBytes {

scripts/create_frameworks.sh

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -138,15 +138,27 @@ create_xcframework() {
138138
echo "No .swiftmodule file found in ${module_source_dir}"
139139
exit 1
140140
fi
141-
142-
local dir_suffix
143-
dir_suffix=$(echo "$dir" | cut -d'/' -f1 | tr '[:upper:]' '[:lower:]' | sed 's/[\/\.~]/_/g')
144-
for slice_path in "${xcframework}/${dir_suffix}-"*; do
145-
if [ -d "${slice_path}/Headers" ]; then
146-
echo " - Copying ${swiftmodule_file##*/} to ${slice_path}/Headers/${swift_module}.swiftmodule"
147-
cp "${swiftmodule_file}" "${slice_path}/Headers/${swift_module}.swiftmodule"
148-
fi
149-
done
141+
local base=$(basename "$swiftmodule_file" .swiftmodule)
142+
local arch="${base%%-*}"
143+
local rest="${base#*-apple-}"
144+
local platform_tag
145+
local variant
146+
if [[ "$rest" == *-simulator ]]; then
147+
platform_tag="${rest%-simulator}"
148+
variant="-simulator"
149+
else
150+
platform_tag="$rest"
151+
variant=""
152+
fi
153+
local slice_name="${platform_tag}-${arch}${variant}"
154+
local slice_path="${xcframework}/${slice_name}"
155+
if [ ! -d "$slice_path" ]; then
156+
echo "Warning: slice '${slice_name}' not found in ${xcframework}, skipping"
157+
continue
158+
fi
159+
echo " - Copying ${swift_module}.swiftmodule into slice ${slice_name}"
160+
cp "$swiftmodule_file" "${slice_path}/${swift_module}.swiftmodule"
161+
ln -sf "../${swift_module}.swiftmodule" "${slice_path}/Headers/${swift_module}.swiftmodule"
150162
done
151163
fi
152164

0 commit comments

Comments
 (0)