Skip to content

Commit 9a85b06

Browse files
authored
Helpers to create full tensors.
Differential Revision: D74020935 Pull Request resolved: #10641
1 parent ff0df1c commit 9a85b06

File tree

3 files changed

+192
-0
lines changed

3 files changed

+192
-0
lines changed

extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -748,4 +748,95 @@ __attribute__((deprecated("This API is experimental.")))
748748

749749
@end
750750

751+
#pragma mark - Full Category
752+
753+
@interface ExecuTorchTensor (Full)
754+
755+
/**
756+
* Creates a tensor filled with the specified scalar value, with full specification of shape, strides, data type, and shape dynamism.
757+
*
758+
* @param shape An NSArray of NSNumber objects representing the desired shape.
759+
* @param scalar An NSNumber representing the value to fill the tensor.
760+
* @param strides An NSArray of NSNumber objects representing the desired strides.
761+
* @param dataType An ExecuTorchDataType value specifying the element type.
762+
* @param shapeDynamism An ExecuTorchShapeDynamism value specifying whether the shape is static or dynamic.
763+
* @return A new ExecuTorchTensor instance filled with the scalar value.
764+
*/
765+
+ (instancetype)fullTensorWithShape:(NSArray<NSNumber *> *)shape
766+
scalar:(NSNumber *)scalar
767+
strides:(NSArray<NSNumber *> *)strides
768+
dataType:(ExecuTorchDataType)dataType
769+
shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism
770+
NS_SWIFT_NAME(full(shape:scalar:strides:dataType:shapeDynamism:));
771+
772+
/**
773+
* Creates a tensor filled with the specified scalar value, with the given shape, data type, and shape dynamism.
774+
*
775+
* @param shape An NSArray of NSNumber objects representing the desired shape.
776+
* @param scalar An NSNumber representing the value to fill the tensor.
777+
* @param dataType An ExecuTorchDataType value specifying the element type.
778+
* @param shapeDynamism An ExecuTorchShapeDynamism value specifying whether the shape is static or dynamic.
779+
* @return A new ExecuTorchTensor instance filled with the scalar value.
780+
*/
781+
+ (instancetype)fullTensorWithShape:(NSArray<NSNumber *> *)shape
782+
scalar:(NSNumber *)scalar
783+
dataType:(ExecuTorchDataType)dataType
784+
shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism
785+
NS_SWIFT_NAME(full(shape:scalar:dataType:shapeDynamism:));
786+
787+
/**
788+
* Creates a tensor filled with the specified scalar value, with the given shape and data type,
789+
* using dynamic bound shape for strides and dimension order.
790+
*
791+
* @param shape An NSArray of NSNumber objects representing the desired shape.
792+
* @param scalar An NSNumber representing the value to fill the tensor.
793+
* @param dataType An ExecuTorchDataType value specifying the element type.
794+
* @return A new ExecuTorchTensor instance filled with the scalar value.
795+
*/
796+
+ (instancetype)fullTensorWithShape:(NSArray<NSNumber *> *)shape
797+
scalar:(NSNumber *)scalar
798+
dataType:(ExecuTorchDataType)dataType
799+
NS_SWIFT_NAME(full(shape:scalar:dataType:));
800+
801+
/**
802+
* Creates a tensor filled with the specified scalar value, similar to an existing tensor, with the given data type and shape dynamism.
803+
*
804+
* @param tensr An existing ExecuTorchTensor instance whose shape and strides are used.
805+
* @param scalar An NSNumber representing the value to fill the tensor.
806+
* @param dataType An ExecuTorchDataType value specifying the desired element type.
807+
* @param shapeDynamism An ExecuTorchShapeDynamism value specifying whether the shape is static or dynamic.
808+
* @return A new ExecuTorchTensor instance filled with the scalar value.
809+
*/
810+
+ (instancetype)fullTensorLikeTensor:(ExecuTorchTensor *)tensr
811+
scalar:(NSNumber *)scalar
812+
dataType:(ExecuTorchDataType)dataType
813+
shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism
814+
NS_SWIFT_NAME(full(like:scalar:dataType:shapeDynamism:));
815+
816+
/**
817+
* Creates a tensor filled with the specified scalar value, similar to an existing tensor, with the given data type.
818+
*
819+
* @param tensr An existing ExecuTorchTensor instance whose shape and strides are used.
820+
* @param scalar An NSNumber representing the value to fill the tensor.
821+
* @param dataType An ExecuTorchDataType value specifying the desired element type.
822+
* @return A new ExecuTorchTensor instance filled with the scalar value.
823+
*/
824+
+ (instancetype)fullTensorLikeTensor:(ExecuTorchTensor *)tensr
825+
scalar:(NSNumber *)scalar
826+
dataType:(ExecuTorchDataType)dataType
827+
NS_SWIFT_NAME(full(like:scalar:dataType:));
828+
829+
/**
830+
* Creates a tensor filled with the specified scalar value, similar to an existing tensor.
831+
*
832+
* @param tensr An existing ExecuTorchTensor instance.
833+
* @param scalar An NSNumber representing the value to fill the tensor.
834+
* @return A new ExecuTorchTensor instance filled with the scalar value.
835+
*/
836+
+ (instancetype)fullTensorLikeTensor:(ExecuTorchTensor *)tensr
837+
scalar:(NSNumber *)scalar
838+
NS_SWIFT_NAME(full(like:scalar:));
839+
840+
@end
841+
751842
NS_ASSUME_NONNULL_END

extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -649,3 +649,79 @@ + (instancetype)emptyTensorLikeTensor:(ExecuTorchTensor *)tensor {
649649
}
650650

651651
@end
652+
653+
@implementation ExecuTorchTensor (Full)
654+
655+
+ (instancetype)fullTensorWithShape:(NSArray<NSNumber *> *)shape
656+
scalar:(NSNumber *)scalar
657+
strides:(NSArray<NSNumber *> *)strides
658+
dataType:(ExecuTorchDataType)dataType
659+
shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism {
660+
Scalar fillValue;
661+
ET_SWITCH_REALHBBF16_AND_UINT_TYPES(
662+
static_cast<ScalarType>(dataType), nil, "fullTensor", CTYPE, [&] {
663+
fillValue = utils::extractValue<CTYPE>(scalar);
664+
}
665+
);
666+
auto tensor = full_strided(
667+
utils::toVector<SizesType>(shape),
668+
utils::toVector<StridesType>(strides),
669+
fillValue,
670+
static_cast<ScalarType>(dataType),
671+
static_cast<TensorShapeDynamism>(shapeDynamism)
672+
);
673+
return [[self alloc] initWithNativeInstance:&tensor];
674+
}
675+
676+
+ (instancetype)fullTensorWithShape:(NSArray<NSNumber *> *)shape
677+
scalar:(NSNumber *)scalar
678+
dataType:(ExecuTorchDataType)dataType
679+
shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism {
680+
return [self fullTensorWithShape:shape
681+
scalar:scalar
682+
strides:@[]
683+
dataType:dataType
684+
shapeDynamism:shapeDynamism];
685+
}
686+
687+
+ (instancetype)fullTensorWithShape:(NSArray<NSNumber *> *)shape
688+
scalar:(NSNumber *)scalar
689+
dataType:(ExecuTorchDataType)dataType {
690+
return [self fullTensorWithShape:shape
691+
scalar:scalar
692+
strides:@[]
693+
dataType:dataType
694+
shapeDynamism:ExecuTorchShapeDynamismDynamicBound];
695+
}
696+
697+
+ (instancetype)fullTensorLikeTensor:(ExecuTorchTensor *)tensor
698+
scalar:(NSNumber *)scalar
699+
dataType:(ExecuTorchDataType)dataType
700+
shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism {
701+
return [self fullTensorWithShape:tensor.shape
702+
scalar:scalar
703+
strides:tensor.strides
704+
dataType:dataType
705+
shapeDynamism:shapeDynamism];
706+
}
707+
708+
+ (instancetype)fullTensorLikeTensor:(ExecuTorchTensor *)tensor
709+
scalar:(NSNumber *)scalar
710+
dataType:(ExecuTorchDataType)dataType {
711+
return [self fullTensorWithShape:tensor.shape
712+
scalar:scalar
713+
strides:tensor.strides
714+
dataType:dataType
715+
shapeDynamism:tensor.shapeDynamism];
716+
}
717+
718+
+ (instancetype)fullTensorLikeTensor:(ExecuTorchTensor *)tensor
719+
scalar:(NSNumber *)scalar {
720+
return [self fullTensorWithShape:tensor.shape
721+
scalar:scalar
722+
strides:tensor.strides
723+
dataType:tensor.dataType
724+
shapeDynamism:tensor.shapeDynamism];
725+
}
726+
727+
@end

extension/apple/ExecuTorch/__tests__/TensorTest.swift

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -568,4 +568,29 @@ class TensorTest: XCTestCase {
568568
XCTAssertEqual(tensor.dimensionOrder, other.dimensionOrder)
569569
XCTAssertEqual(tensor.dataType, other.dataType)
570570
}
571+
572+
func testFull() {
573+
let tensor = Tensor.full(shape: [2, 2], scalar: 7, dataType: .int)
574+
XCTAssertEqual(tensor.shape, [2, 2])
575+
XCTAssertEqual(tensor.count, 4)
576+
tensor.bytes { pointer, count, dataType in
577+
XCTAssertEqual(dataType, .int)
578+
let buffer = UnsafeBufferPointer(start: pointer.assumingMemoryBound(to: Int32.self), count: count)
579+
for value in buffer {
580+
XCTAssertEqual(value, 7)
581+
}
582+
}
583+
}
584+
585+
func testFullLike() {
586+
let other = Tensor.empty(shape: [2, 2], dataType: .int)
587+
let tensor = Tensor.full(like: other, scalar: 42, dataType: .float)
588+
XCTAssertEqual(tensor.shape, other.shape)
589+
tensor.bytes { pointer, count, dataType in
590+
let buffer = UnsafeBufferPointer(start: pointer.assumingMemoryBound(to: Float.self), count: count)
591+
for value in buffer {
592+
XCTAssertEqual(value, 42.0)
593+
}
594+
}
595+
}
571596
}

0 commit comments

Comments
 (0)