Skip to content

Commit 08ce3ed

Browse files
Overloads for scalar constructor. (#9696)
Summary: #8366 Reviewed By: bsoyluoglu Differential Revision: D71931436 Co-authored-by: Anthony Shoumikhin <[email protected]>
1 parent 29bc273 commit 08ce3ed

File tree

3 files changed

+222
-1
lines changed

3 files changed

+222
-1
lines changed

extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -564,6 +564,110 @@ __attribute__((deprecated("This API is experimental.")))
564564
*/
565565
- (instancetype)initWithScalar:(NSNumber *)scalar NS_SWIFT_NAME(init(_:));
566566

567+
/**
568+
* Initializes a tensor with a byte scalar value.
569+
*
570+
* @param scalar A uint8_t value.
571+
* @return An initialized ExecuTorchTensor instance.
572+
*/
573+
- (instancetype)initWithByte:(uint8_t)scalar NS_SWIFT_NAME(init(_:));
574+
575+
/**
576+
* Initializes a tensor with a char scalar value.
577+
*
578+
* @param scalar An int8_t value.
579+
* @return An initialized ExecuTorchTensor instance.
580+
*/
581+
- (instancetype)initWithChar:(int8_t)scalar NS_SWIFT_NAME(init(_:));
582+
583+
/**
584+
* Initializes a tensor with a short scalar value.
585+
*
586+
* @param scalar An int16_t value.
587+
* @return An initialized ExecuTorchTensor instance.
588+
*/
589+
- (instancetype)initWithShort:(int16_t)scalar NS_SWIFT_NAME(init(_:));
590+
591+
/**
592+
* Initializes a tensor with an int scalar value.
593+
*
594+
* @param scalar An int32_t value.
595+
* @return An initialized ExecuTorchTensor instance.
596+
*/
597+
- (instancetype)initWithInt:(int32_t)scalar NS_SWIFT_NAME(init(_:));
598+
599+
/**
600+
* Initializes a tensor with a long scalar value.
601+
*
602+
* @param scalar An int64_t value.
603+
* @return An initialized ExecuTorchTensor instance.
604+
*/
605+
- (instancetype)initWithLong:(int64_t)scalar NS_SWIFT_NAME(init(_:));
606+
607+
/**
608+
* Initializes a tensor with a float scalar value.
609+
*
610+
* @param scalar A float value.
611+
* @return An initialized ExecuTorchTensor instance.
612+
*/
613+
- (instancetype)initWithFloat:(float)scalar NS_SWIFT_NAME(init(_:));
614+
615+
/**
616+
* Initializes a tensor with a double scalar value.
617+
*
618+
* @param scalar A double value.
619+
* @return An initialized ExecuTorchTensor instance.
620+
*/
621+
- (instancetype)initWithDouble:(double)scalar NS_SWIFT_NAME(init(_:));
622+
623+
/**
624+
* Initializes a tensor with a boolean scalar value.
625+
*
626+
* @param scalar A BOOL value.
627+
* @return An initialized ExecuTorchTensor instance.
628+
*/
629+
- (instancetype)initWithBool:(BOOL)scalar NS_SWIFT_NAME(init(_:));
630+
631+
/**
632+
* Initializes a tensor with a uint16 scalar value.
633+
*
634+
* @param scalar A uint16_t value.
635+
* @return An initialized ExecuTorchTensor instance.
636+
*/
637+
- (instancetype)initWithUInt16:(uint16_t)scalar NS_SWIFT_NAME(init(_:));
638+
639+
/**
640+
* Initializes a tensor with a uint32 scalar value.
641+
*
642+
* @param scalar A uint32_t value.
643+
* @return An initialized ExecuTorchTensor instance.
644+
*/
645+
- (instancetype)initWithUInt32:(uint32_t)scalar NS_SWIFT_NAME(init(_:));
646+
647+
/**
648+
* Initializes a tensor with a uint64 scalar value.
649+
*
650+
* @param scalar A uint64_t value.
651+
* @return An initialized ExecuTorchTensor instance.
652+
*/
653+
- (instancetype)initWithUInt64:(uint64_t)scalar NS_SWIFT_NAME(init(_:));
654+
655+
/**
656+
* Initializes a tensor with an NSInteger scalar value.
657+
*
658+
* @param scalar An NSInteger value.
659+
* @return An initialized ExecuTorchTensor instance.
660+
*/
661+
- (instancetype)initWithInteger:(NSInteger)scalar NS_SWIFT_NAME(init(_:));
662+
663+
/**
664+
* Initializes a tensor with an NSUInteger scalar value.
665+
*
666+
* @param scalar An NSUInteger value.
667+
* @return An initialized ExecuTorchTensor instance.
668+
*/
669+
- (instancetype)initWithUnsignedInteger:(NSUInteger)scalar NS_SWIFT_NAME(init(_:));
670+
567671
@end
568672

569673
NS_ASSUME_NONNULL_END

extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -477,4 +477,121 @@ - (instancetype)initWithScalar:(NSNumber *)scalar {
477477
shapeDynamism:ExecuTorchShapeDynamismDynamicBound];
478478
}
479479

480+
- (instancetype)initWithByte:(uint8_t)scalar {
481+
return [self initWithBytes:&scalar
482+
shape:@[]
483+
strides:@[]
484+
dimensionOrder:@[]
485+
dataType:ExecuTorchDataTypeByte
486+
shapeDynamism:ExecuTorchShapeDynamismDynamicBound];
487+
}
488+
489+
- (instancetype)initWithChar:(int8_t)scalar {
490+
return [self initWithBytes:&scalar
491+
shape:@[]
492+
strides:@[]
493+
dimensionOrder:@[]
494+
dataType:ExecuTorchDataTypeChar
495+
shapeDynamism:ExecuTorchShapeDynamismDynamicBound];
496+
}
497+
498+
- (instancetype)initWithShort:(int16_t)scalar {
499+
return [self initWithBytes:&scalar
500+
shape:@[]
501+
strides:@[]
502+
dimensionOrder:@[]
503+
dataType:ExecuTorchDataTypeShort
504+
shapeDynamism:ExecuTorchShapeDynamismDynamicBound];
505+
}
506+
507+
- (instancetype)initWithInt:(int32_t)scalar {
508+
return [self initWithBytes:&scalar
509+
shape:@[]
510+
strides:@[]
511+
dimensionOrder:@[]
512+
dataType:ExecuTorchDataTypeInt
513+
shapeDynamism:ExecuTorchShapeDynamismDynamicBound];
514+
}
515+
516+
- (instancetype)initWithLong:(int64_t)scalar {
517+
return [self initWithBytes:&scalar
518+
shape:@[]
519+
strides:@[]
520+
dimensionOrder:@[]
521+
dataType:ExecuTorchDataTypeLong
522+
shapeDynamism:ExecuTorchShapeDynamismDynamicBound];
523+
}
524+
525+
- (instancetype)initWithFloat:(float)scalar {
526+
return [self initWithBytes:&scalar
527+
shape:@[]
528+
strides:@[]
529+
dimensionOrder:@[]
530+
dataType:ExecuTorchDataTypeFloat
531+
shapeDynamism:ExecuTorchShapeDynamismDynamicBound];
532+
}
533+
534+
- (instancetype)initWithDouble:(double)scalar {
535+
return [self initWithBytes:&scalar
536+
shape:@[]
537+
strides:@[]
538+
dimensionOrder:@[]
539+
dataType:ExecuTorchDataTypeDouble
540+
shapeDynamism:ExecuTorchShapeDynamismDynamicBound];
541+
}
542+
543+
- (instancetype)initWithBool:(BOOL)scalar {
544+
return [self initWithBytes:&scalar
545+
shape:@[]
546+
strides:@[]
547+
dimensionOrder:@[]
548+
dataType:ExecuTorchDataTypeBool
549+
shapeDynamism:ExecuTorchShapeDynamismDynamicBound];
550+
}
551+
552+
- (instancetype)initWithUInt16:(uint16_t)scalar {
553+
return [self initWithBytes:&scalar
554+
shape:@[]
555+
strides:@[]
556+
dimensionOrder:@[]
557+
dataType:ExecuTorchDataTypeUInt16
558+
shapeDynamism:ExecuTorchShapeDynamismDynamicBound];
559+
}
560+
561+
- (instancetype)initWithUInt32:(uint32_t)scalar {
562+
return [self initWithBytes:&scalar
563+
shape:@[]
564+
strides:@[]
565+
dimensionOrder:@[]
566+
dataType:ExecuTorchDataTypeUInt32
567+
shapeDynamism:ExecuTorchShapeDynamismDynamicBound];
568+
}
569+
570+
- (instancetype)initWithUInt64:(uint64_t)scalar {
571+
return [self initWithBytes:&scalar
572+
shape:@[]
573+
strides:@[]
574+
dimensionOrder:@[]
575+
dataType:ExecuTorchDataTypeUInt64
576+
shapeDynamism:ExecuTorchShapeDynamismDynamicBound];
577+
}
578+
579+
- (instancetype)initWithInteger:(NSInteger)scalar {
580+
return [self initWithBytes:&scalar
581+
shape:@[]
582+
strides:@[]
583+
dimensionOrder:@[]
584+
dataType:(sizeof(scalar) == 8 ? ExecuTorchDataTypeLong : ExecuTorchDataTypeInt)
585+
shapeDynamism:ExecuTorchShapeDynamismDynamicBound];
586+
}
587+
588+
- (instancetype)initWithUnsignedInteger:(NSUInteger)scalar {
589+
return [self initWithBytes:&scalar
590+
shape:@[]
591+
strides:@[]
592+
dimensionOrder:@[]
593+
dataType:(sizeof(scalar) == 8 ? ExecuTorchDataTypeUInt64 : ExecuTorchDataTypeUInt32)
594+
shapeDynamism:ExecuTorchShapeDynamismDynamicBound];
595+
}
596+
480597
@end

extension/apple/ExecuTorch/__tests__/TensorTest.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,7 @@ class TensorTest: XCTestCase {
394394
}
395395

396396
func testInitFloat() {
397-
let tensor = Tensor(Float(42.0) as NSNumber)
397+
let tensor = Tensor(Float(42.0))
398398
XCTAssertEqual(tensor.dataType, .float)
399399
XCTAssertEqual(tensor.shape, [])
400400
XCTAssertEqual(tensor.strides, [])

0 commit comments

Comments
 (0)