File tree Expand file tree Collapse file tree 3 files changed +57
-0
lines changed
extension/apple/ExecuTorch Expand file tree Collapse file tree 3 files changed +57
-0
lines changed Original file line number Diff line number Diff line change @@ -544,4 +544,26 @@ __attribute__((deprecated("This API is experimental.")))
544
544
545
545
@end
546
546
547
+ @interface ExecuTorchTensor (Scalar)
548
+
549
+ /* *
550
+ * Initializes a tensor with a single scalar value and a specified data type.
551
+ *
552
+ * @param scalar An NSNumber representing the scalar value.
553
+ * @param dataType An ExecuTorchDataType value specifying the element type.
554
+ * @return An initialized ExecuTorchTensor instance representing the scalar.
555
+ */
556
+ - (instancetype )initWithScalar : (NSNumber *)scalar
557
+ dataType : (ExecuTorchDataType)dataType NS_SWIFT_NAME(init(_:dataType:));
558
+
559
+ /* *
560
+ * Initializes a tensor with a single scalar value, automatically deducing its data type.
561
+ *
562
+ * @param scalar An NSNumber representing the scalar value.
563
+ * @return An initialized ExecuTorchTensor instance representing the scalar.
564
+ */
565
+ - (instancetype )initWithScalar : (NSNumber *)scalar NS_SWIFT_NAME(init(_:));
566
+
567
+ @end
568
+
547
569
NS_ASSUME_NONNULL_END
Original file line number Diff line number Diff line change @@ -455,3 +455,26 @@ - (instancetype)initWithScalars:(NSArray<NSNumber *> *)scalars {
455
455
}
456
456
457
457
@end
458
+
459
+ @implementation ExecuTorchTensor (Scalar)
460
+
461
+ - (instancetype )initWithScalar : (NSNumber *)scalar
462
+ dataType : (ExecuTorchDataType)dataType {
463
+ return [self initWithScalars: @[scalar]
464
+ shape: @[]
465
+ strides: @[]
466
+ dimensionOrder: @[]
467
+ dataType: dataType
468
+ shapeDynamism: ExecuTorchShapeDynamismDynamicBound];
469
+ }
470
+
471
+ - (instancetype )initWithScalar : (NSNumber *)scalar {
472
+ return [self initWithScalars: @[scalar]
473
+ shape: @[]
474
+ strides: @[]
475
+ dimensionOrder: @[]
476
+ dataType: static_cast <ExecuTorchDataType>(utils: :deduceType (scalar))
477
+ shapeDynamism: ExecuTorchShapeDynamismDynamicBound];
478
+ }
479
+
480
+ @end
Original file line number Diff line number Diff line change @@ -392,4 +392,16 @@ class TensorTest: XCTestCase {
392
392
XCTAssertEqual ( Array ( UnsafeBufferPointer ( start: pointer. assumingMemoryBound ( to: UInt . self) , count: count) ) , data)
393
393
}
394
394
}
395
+
396
+ func testInitFloat( ) {
397
+ let tensor = Tensor ( Float ( 42.0 ) as NSNumber )
398
+ XCTAssertEqual ( tensor. dataType, . float)
399
+ XCTAssertEqual ( tensor. shape, [ ] )
400
+ XCTAssertEqual ( tensor. strides, [ ] )
401
+ XCTAssertEqual ( tensor. dimensionOrder, [ ] )
402
+ XCTAssertEqual ( tensor. count, 1 )
403
+ tensor. bytes { pointer, count, dataType in
404
+ XCTAssertEqual ( UnsafeBufferPointer ( start: pointer. assumingMemoryBound ( to: Float . self) , count: count) . first, 42.0 )
405
+ }
406
+ }
395
407
}
You can’t perform that action at this time.
0 commit comments