Skip to content

Commit 29bc273

Browse files
Tensor constructor to create with a single scalar. (#9695)
Summary: #8366 Reviewed By: bsoyluoglu Differential Revision: D71930917 Co-authored-by: Anthony Shoumikhin <[email protected]>
1 parent cdf9cc5 commit 29bc273

File tree

3 files changed

+57
-0
lines changed

3 files changed

+57
-0
lines changed

extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -544,4 +544,26 @@ __attribute__((deprecated("This API is experimental.")))
544544

545545
@end
546546

547+
@interface ExecuTorchTensor (Scalar)
548+
549+
/**
550+
* Initializes a tensor with a single scalar value and a specified data type.
551+
*
552+
* @param scalar An NSNumber representing the scalar value.
553+
* @param dataType An ExecuTorchDataType value specifying the element type.
554+
* @return An initialized ExecuTorchTensor instance representing the scalar.
555+
*/
556+
- (instancetype)initWithScalar:(NSNumber *)scalar
557+
dataType:(ExecuTorchDataType)dataType NS_SWIFT_NAME(init(_:dataType:));
558+
559+
/**
560+
* Initializes a tensor with a single scalar value, automatically deducing its data type.
561+
*
562+
* @param scalar An NSNumber representing the scalar value.
563+
* @return An initialized ExecuTorchTensor instance representing the scalar.
564+
*/
565+
- (instancetype)initWithScalar:(NSNumber *)scalar NS_SWIFT_NAME(init(_:));
566+
567+
@end
568+
547569
NS_ASSUME_NONNULL_END

extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,3 +455,26 @@ - (instancetype)initWithScalars:(NSArray<NSNumber *> *)scalars {
455455
}
456456

457457
@end
458+
459+
@implementation ExecuTorchTensor (Scalar)
460+
461+
- (instancetype)initWithScalar:(NSNumber *)scalar
462+
dataType:(ExecuTorchDataType)dataType {
463+
return [self initWithScalars:@[scalar]
464+
shape:@[]
465+
strides:@[]
466+
dimensionOrder:@[]
467+
dataType:dataType
468+
shapeDynamism:ExecuTorchShapeDynamismDynamicBound];
469+
}
470+
471+
- (instancetype)initWithScalar:(NSNumber *)scalar {
472+
return [self initWithScalars:@[scalar]
473+
shape:@[]
474+
strides:@[]
475+
dimensionOrder:@[]
476+
dataType:static_cast<ExecuTorchDataType>(utils::deduceType(scalar))
477+
shapeDynamism:ExecuTorchShapeDynamismDynamicBound];
478+
}
479+
480+
@end

extension/apple/ExecuTorch/__tests__/TensorTest.swift

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,4 +392,16 @@ class TensorTest: XCTestCase {
392392
XCTAssertEqual(Array(UnsafeBufferPointer(start: pointer.assumingMemoryBound(to: UInt.self), count: count)), data)
393393
}
394394
}
395+
396+
func testInitFloat() {
397+
let tensor = Tensor(Float(42.0) as NSNumber)
398+
XCTAssertEqual(tensor.dataType, .float)
399+
XCTAssertEqual(tensor.shape, [])
400+
XCTAssertEqual(tensor.strides, [])
401+
XCTAssertEqual(tensor.dimensionOrder, [])
402+
XCTAssertEqual(tensor.count, 1)
403+
tensor.bytes { pointer, count, dataType in
404+
XCTAssertEqual(UnsafeBufferPointer(start: pointer.assumingMemoryBound(to: Float.self), count: count).first, 42.0)
405+
}
406+
}
395407
}

0 commit comments

Comments
 (0)