
Commit 04e034f

Merge branch 'main' into export-D75228037

2 parents 6b3810f + df5e7df

File tree

9 files changed: +533, -66 lines

backends/cadence/aot/tests/test_fusion_ops_passes.py

Lines changed: 109 additions & 63 deletions
@@ -13,10 +13,6 @@
 import executorch.backends.cadence.aot.ops_registrations  # noqa
 import torch
 from executorch.backends.cadence.aot import compiler
-from executorch.backends.cadence.aot.compiler import (
-    export_to_edge,
-    quantize_and_export_to_edge,
-)
 from executorch.backends.cadence.aot.fuse_ops import (
     FuseFullThenReshapePass,
     FuseMulScalarIntoDequantPass,

@@ -336,94 +332,144 @@ def test_replace_quant_view_dequant_with_requantize(self):
         )

     def test_replace_dequant_quant_with_requantize(self):
-        class M(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-
-            def forward(self, x):
-                x = torch.ops.quantized_decomposed.dequantize_per_tensor(
-                    x, 1.2, 3, 0, 127, torch.int8
-                )
-                x = torch.permute(x, [2, 0, 1, 3])
-                x = torch.ops.quantized_decomposed.quantize_per_tensor(
-                    x, 4.5, 6, 0, 127, torch.int8
-                )
-                return x
-
-        inputs = torch.randn(2, 12, 1, 6).to(torch.int8)
-        model = M()
-        graph_module = export_to_edge(model, (inputs,)).exported_program().graph_module
-        graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(2, 12, 1, 6, dtype=torch.float32))
+        dequant = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+            args=(x, 1.2, 3, 0, 127, torch.int8),
+        )
+        quant = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            args=(dequant, 4.5, 6, 0, 127, torch.int8),
+        )
+        builder.output(quant)
+        graph_module = FuseQuantDequantToRequantizePass()(
+            builder.get_graph_module()
+        ).graph_module

         self.check_op_counts(
             graph_module,
             expected_op_counts={
-                # Verify that dequant -> permute -> quant was replaced with permute -> requantize.
+                # Verify that dequant -> quant was replaced with requantize.
                 exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 0,
                 exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0,
                 exir_ops.edge.cadence.requantize.default: 1,
             },
         )

     def test_replace_dequant_permute_quant_with_requantize(self):
-        class M(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-
-            def forward(self, x):
-                x = torch.ops.quantized_decomposed.dequantize_per_tensor(
-                    x, 1.2, 3, 0, 127, torch.int8
-                )
-                x = torch.permute(x, [2, 0, 1, 3])
-                x = torch.ops.quantized_decomposed.quantize_per_tensor(
-                    x, 4.5, 6, 0, 127, torch.int8
-                )
-                return x
-
-        inputs = torch.randn(2, 12, 1, 6).to(torch.int8)
-        model = M()
-        graph_module = export_to_edge(model, (inputs,)).exported_program().graph_module
-        graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(2, 12, 1, 6, dtype=torch.float32))
+        dequant = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+            args=(x, 1.2, 3, 0, 127, torch.int8),
+        )
+        permute = builder.call_operator(
+            op=exir_ops.edge.aten.permute_copy.default, args=(dequant, [2, 0, 1, 3])
+        )
+        quant = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            args=(permute, 4.5, 6, 0, 127, torch.int8),
+        )
+        builder.output(quant)
+        graph_module = FuseQuantDequantToRequantizePass()(
+            builder.get_graph_module()
+        ).graph_module

         self.check_op_counts(
             graph_module,
             expected_op_counts={
                 # Verify that dequant -> permute -> quant was replaced with permute -> requantize.
                 exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 0,
                 exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 0,
+                exir_ops.edge.aten.permute_copy.default: 1,
                 exir_ops.edge.cadence.requantize.default: 1,
             },
         )

     def test_remove_nop_dequant_quant(self):
-        class M(torch.nn.Module):
-            def __init__(self):
-                super(M, self).__init__()
-                self.lin1 = torch.nn.Linear(6, 12, bias=False)
-                self.lin2 = torch.nn.Linear(12, 24, bias=False)
+        LEADING_DIMS: Final[int] = 12
+        IN_DIM: Final[int] = 6
+        OUT_DIM: Final[int] = 12

-            def forward(self, x):
-                x = self.lin1(x)
-                # redundant dequant+quant will be created around this permute
-                x = torch.permute(x, [0, 2, 1, 3])
-                x = self.lin2(x)
-                return x
-
-        inputs = torch.randn(2, 12, 1, 6)
-        model = M()
-        graph_module = (
-            quantize_and_export_to_edge(model, (inputs,))
-            .exported_program()
-            .graph_module
+        builder = GraphBuilder()
+        x = builder.placeholder(
+            "x", torch.randn(LEADING_DIMS, IN_DIM, dtype=torch.float32)
+        )
+        quant1 = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            args=(x, 4.5, 6, 0, 127, torch.int8),
+        )
+        weights = builder.call_operator(
+            op=exir_ops.edge.aten.full.default, args=([OUT_DIM, IN_DIM], 1)
+        )
+        bias = builder.call_operator(
+            op=exir_ops.edge.aten.full.default, args=([OUT_DIM], 1)
+        )
+        weight_zero_point = builder.call_operator(
+            op=exir_ops.edge.aten.full.default, args=([IN_DIM], 0)
+        )
+        out_multiplier = builder.call_operator(
+            op=exir_ops.edge.aten.full.default, args=([OUT_DIM], 1)
+        )
+        out_shift = builder.call_operator(
+            op=exir_ops.edge.aten.full.default, args=([OUT_DIM], 0)
         )
-        graph_module = FuseQuantDequantToRequantizePass()(graph_module).graph_module
+        linear1 = builder.call_operator(
+            op=exir_ops.edge.cadence.quantized_linear.default,
+            args=(
+                quant1,
+                weights,
+                bias,
+                0,  # src_zero_point
+                weight_zero_point,
+                out_multiplier,
+                out_shift,
+                0,  # out_zero_point
+                None,
+            ),
+        )
+        dequant1 = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+            args=(linear1, 1.2, 3, 0, 127, torch.int8),
+        )
+        permute = builder.call_operator(
+            op=exir_ops.edge.aten.permute_copy.default, args=(dequant1, [1, 0])
+        )
+        quant2 = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+            args=(permute, 4.5, 6, 0, 127, torch.int8),
+        )
+        linear2 = builder.call_operator(
+            op=exir_ops.edge.cadence.quantized_linear.default,
+            args=(
+                quant2,
+                weights,
+                bias,
+                0,  # src_zero_point
+                weight_zero_point,
+                out_multiplier,
+                out_shift,
+                0,  # out_zero_point
+                None,
+            ),
+        )
+        dequant2 = builder.call_operator(
+            op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+            args=(linear2, 1.2, 3, 0, 127, torch.int8),
+        )
+        builder.output(dequant2)
+        graph_module = FuseQuantDequantToRequantizePass()(
+            builder.get_graph_module()
+        ).graph_module
         self.check_op_counts(
             graph_module,
             expected_op_counts={
-                # Verify that one dequant/quant pair was removed
-                # Expect 1 quantize ops: 1 input
+                # Verify that one dequant/quant pair was removed from chain:
+                # quant->linear->dequant->permute->quant->linear->dequant
+                # gets converted to:
+                # quant->linear->permute->linear->dequant
                 exir_ops.edge.quantized_decomposed.quantize_per_tensor.default: 1,
-                # Expect 1 dequant op at the end (output of second linear)
                 exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: 1,
             },
         )
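
The tests above now build their graphs directly with GraphBuilder instead of exporting an nn.Module. As a reference, here is a minimal standalone sketch of that pattern; the import paths for GraphBuilder and FuseQuantDequantToRequantizePass are assumptions based on the modules this diff touches, and inspecting the node targets directly plays the role that check_op_counts plays in the test class.

# Hypothetical standalone sketch, not part of this commit. Import paths are assumed.
import torch
from executorch.backends.cadence.aot.fuse_ops import FuseQuantDequantToRequantizePass
from executorch.backends.cadence.aot.graph_builder import GraphBuilder  # assumed module path
from executorch.exir.dialects._ops import ops as exir_ops

# Build a float placeholder followed by a dequantize -> quantize pair.
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(2, 12, 1, 6, dtype=torch.float32))
dequant = builder.call_operator(
    op=exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
    args=(x, 1.2, 3, 0, 127, torch.int8),
)
quant = builder.call_operator(
    op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
    args=(dequant, 4.5, 6, 0, 127, torch.int8),
)
builder.output(quant)

# Run the fusion pass; it returns a result whose graph_module holds the rewritten graph.
graph_module = FuseQuantDequantToRequantizePass()(builder.get_graph_module()).graph_module

# The dequant/quant pair should now be a single cadence.requantize node.
targets = [n.target for n in graph_module.graph.nodes if n.op == "call_function"]
assert exir_ops.edge.cadence.requantize.default in targets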

examples/qualcomm/executor_runner/qnn_executor_runner.cpp

Lines changed: 6 additions & 0 deletions
@@ -481,6 +481,12 @@ int main(int argc, char** argv) {

       ++inference_index;
     }
+    ET_LOG(
+        Info,
+        "%d inference took %f ms, avg %f ms",
+        inference_index,
+        elapsed_time,
+        elapsed_time / inference_index);
   } else {
     // if no input is provided, fill the inputs with default values
     auto inputs = prepare_input_tensors(*method);
Lines changed: 90 additions & 0 deletions
@@ -0,0 +1,90 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+@_exported import ExecuTorch
+
+/// A protocol that types conform to in order to be used as tensor element types.
+/// Provides the mapping from the Swift type to the underlying `DataType`.
+@available(*, deprecated, message: "This API is experimental.")
+protocol Scalar {
+  /// The `DataType` corresponding to this scalar type.
+  static var dataType: DataType { get }
+}
+
+@available(*, deprecated, message: "This API is experimental.")
+extension UInt8: Scalar { static var dataType: DataType { .byte } }
+@available(*, deprecated, message: "This API is experimental.")
+extension Int8: Scalar { static var dataType: DataType { .char } }
+@available(*, deprecated, message: "This API is experimental.")
+extension Int16: Scalar { static var dataType: DataType { .short } }
+@available(*, deprecated, message: "This API is experimental.")
+extension Int32: Scalar { static var dataType: DataType { .int } }
+@available(*, deprecated, message: "This API is experimental.")
+extension Int64: Scalar { static var dataType: DataType { .long } }
+@available(*, deprecated, message: "This API is experimental.")
+extension Int: Scalar { static var dataType: DataType { .long } }
+@available(macOS 11.0, *)
+@available(*, deprecated, message: "This API is experimental.")
+extension Float16: Scalar { static var dataType: DataType { .half } }
+@available(*, deprecated, message: "This API is experimental.")
+extension Float: Scalar { static var dataType: DataType { .float } }
+@available(*, deprecated, message: "This API is experimental.")
+extension Double: Scalar { static var dataType: DataType { .double } }
+@available(*, deprecated, message: "This API is experimental.")
+extension Bool: Scalar { static var dataType: DataType { .bool } }
+@available(*, deprecated, message: "This API is experimental.")
+extension UInt16: Scalar { static var dataType: DataType { .uInt16 } }
+@available(*, deprecated, message: "This API is experimental.")
+extension UInt32: Scalar { static var dataType: DataType { .uInt32 } }
+@available(*, deprecated, message: "This API is experimental.")
+extension UInt64: Scalar { static var dataType: DataType { .uInt64 } }
+@available(*, deprecated, message: "This API is experimental.")
+extension UInt: Scalar { static var dataType: DataType { .uInt64 } }
+
+@available(*, deprecated, message: "This API is experimental.")
+extension Tensor {
+  /// Calls the closure with a typed, immutable buffer pointer over the tensor’s elements.
+  ///
+  /// - Parameter body: A closure that receives an `UnsafeBufferPointer<T>` bound to the tensor’s data.
+  /// - Returns: The value returned by `body`.
+  /// - Throws: `Error(code: .invalidArgument)` if `T.dataType` doesn’t match the tensor’s `dataType`,
+  ///           or any error thrown by `body`.
+  func withUnsafeBytes<T: Scalar, R>(_ body: (UnsafeBufferPointer<T>) throws -> R) throws -> R {
+    guard dataType == T.dataType else { throw Error(code: .invalidArgument) }
+    var result: Result<R, Error>?
+    bytes { pointer, count, _ in
+      result = Result { try body(
+        UnsafeBufferPointer(
+          start: pointer.assumingMemoryBound(to: T.self),
+          count: count
+        )
+      ) }
+    }
+    return try result!.get()
+  }
+
+  /// Calls the closure with a typed, mutable buffer pointer over the tensor’s elements.
+  ///
+  /// - Parameter body: A closure that receives an `UnsafeMutableBufferPointer<T>` bound to the tensor’s data.
+  /// - Returns: The value returned by `body`.
+  /// - Throws: `Error(code: .invalidArgument)` if `T.dataType` doesn’t match the tensor’s `dataType`,
+  ///           or any error thrown by `body`.
+  func withUnsafeMutableBytes<T: Scalar, R>(_ body: (UnsafeMutableBufferPointer<T>) throws -> R) throws -> R {
+    guard dataType == T.dataType else { throw Error(code: .invalidArgument) }
+    var result: Result<R, Error>?
+    mutableBytes { pointer, count, _ in
+      result = Result { try body(
+        UnsafeMutableBufferPointer(
+          start: pointer.assumingMemoryBound(to: T.self),
+          count: count
+        )
+      ) }
+    }
+    return try result!.get()
+  }
+}

extension/apple/ExecuTorch/__tests__/TensorTest.swift

Lines changed: 49 additions & 1 deletion
@@ -148,6 +148,54 @@ class TensorTest: XCTestCase {
     }
   }

+  func testWithUnsafeBytes() throws {
+    var data: [Float] = [1, 2, 3, 4, 5, 6]
+    let tensor = data.withUnsafeMutableBytes {
+      Tensor(bytesNoCopy: $0.baseAddress!, shape: [2, 3], dataType: .float)
+    }
+    let array: [Float] = try tensor.withUnsafeBytes { Array($0) }
+    XCTAssertEqual(array, data)
+  }
+
+  func testWithUnsafeMutableBytes() throws {
+    var data = [1, 2, 3, 4]
+    let tensor = data.withUnsafeMutableBytes {
+      Tensor(bytes: $0.baseAddress!, shape: [4], dataType: .long)
+    }
+    try tensor.withUnsafeMutableBytes { (buffer: UnsafeMutableBufferPointer<Int>) in
+      for i in buffer.indices {
+        buffer[i] *= 2
+      }
+    }
+    try tensor.withUnsafeBytes { buffer in
+      XCTAssertEqual(Array(buffer), [2, 4, 6, 8])
+    }
+  }
+
+  func testWithUnsafeBytesFloat16() throws {
+    var data: [Float16] = [1, 2, 3, 4, 5, 6]
+    let tensor = data.withUnsafeMutableBytes {
+      Tensor(bytesNoCopy: $0.baseAddress!, shape: [6], dataType: .half)
+    }
+    let array: [Float16] = try tensor.withUnsafeBytes { Array($0) }
+    XCTAssertEqual(array, data)
+  }
+
+  func testWithUnsafeMutableBytesFloat16() throws {
+    var data: [Float16] = [1, 2, 3, 4]
+    let tensor = data.withUnsafeMutableBytes { buffer in
+      Tensor(bytes: buffer.baseAddress!, shape: [4], dataType: .half)
+    }
+    try tensor.withUnsafeMutableBytes { (buffer: UnsafeMutableBufferPointer<Float16>) in
+      for i in buffer.indices {
+        buffer[i] *= 2
+      }
+    }
+    try tensor.withUnsafeBytes { buffer in
+      XCTAssertEqual(Array(buffer), data.map { $0 * 2 })
+    }
+  }
+
   func testInitWithTensor() {
     var data: [Int] = [10, 20, 30, 40]
     let tensor1 = data.withUnsafeMutableBytes {

@@ -618,7 +666,7 @@ class TensorTest: XCTestCase {
       }
     }
   }
-
+
  func testZeros() {
    let tensor = Tensor.zeros(shape: [2, 3], dataType: .double)
    XCTAssertEqual(tensor.shape, [2, 3])
kernels/aten/functions.yaml

Lines changed: 2 additions & 0 deletions
@@ -201,6 +201,8 @@

 - op: index_put.out

+- op: index_put_
+
 - op: index_select.out

 - op: index.Tensor_out
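
For context, index_put_ is the in-place variant of index_put; the new entry above registers its ATen kernel alongside the existing index_put.out entry. A quick sketch of its semantics in plain PyTorch (not part of this diff):

import torch

t = torch.zeros(4)
idx = torch.tensor([0, 2])
vals = torch.tensor([1.0, 3.0])

# In-place write of `vals` at the given indices.
t.index_put_((idx,), vals)                   # t is now [1., 0., 3., 0.]

# With accumulate=True the values are added to the existing entries instead.
t.index_put_((idx,), vals, accumulate=True)  # t is now [2., 0., 6., 0.]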
