Skip to content

Commit 8ffdea1

Browse files
authored
Helpers to create ones tensor.
Differential Revision: D74020938 Pull Request resolved: #10642
1 parent 9a85b06 commit 8ffdea1

File tree

3 files changed

+138
-0
lines changed

3 files changed

+138
-0
lines changed

extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -839,4 +839,67 @@ __attribute__((deprecated("This API is experimental.")))
839839

840840
@end
841841

842+
#pragma mark - Ones Category
843+
844+
@interface ExecuTorchTensor (Ones)
845+
846+
/**
847+
* Creates a tensor filled with ones, with the specified shape, data type, and shape dynamism.
848+
*
849+
* @param shape An NSArray of NSNumber objects representing the desired shape.
850+
* @param dataType An ExecuTorchDataType value specifying the element type.
851+
* @param shapeDynamism An ExecuTorchShapeDynamism value specifying whether the shape is static or dynamic.
852+
* @return A new ExecuTorchTensor instance filled with ones.
853+
*/
854+
+ (instancetype)onesTensorWithShape:(NSArray<NSNumber *> *)shape
855+
dataType:(ExecuTorchDataType)dataType
856+
shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism
857+
NS_SWIFT_NAME(ones(shape:dataType:shapeDynamism:));
858+
859+
/**
860+
* Creates a tensor filled with ones, with the specified shape and data type.
861+
*
862+
* @param shape An NSArray of NSNumber objects representing the desired shape.
863+
* @param dataType An ExecuTorchDataType value specifying the element type.
864+
* @return A new ExecuTorchTensor instance filled with ones.
865+
*/
866+
+ (instancetype)onesTensorWithShape:(NSArray<NSNumber *> *)shape
867+
dataType:(ExecuTorchDataType)dataType
868+
NS_SWIFT_NAME(ones(shape:dataType:));
869+
870+
/**
871+
* Creates a tensor filled with ones similar to an existing tensor, with the specified data type and shape dynamism.
872+
*
873+
* @param tensor An existing ExecuTorchTensor instance whose shape and strides are used.
874+
* @param dataType An ExecuTorchDataType value specifying the desired element type.
875+
* @param shapeDynamism An ExecuTorchShapeDynamism value specifying whether the shape is static or dynamic.
876+
* @return A new ExecuTorchTensor instance filled with ones.
877+
*/
878+
+ (instancetype)onesTensorLikeTensor:(ExecuTorchTensor *)tensor
879+
dataType:(ExecuTorchDataType)dataType
880+
shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism
881+
NS_SWIFT_NAME(ones(like:dataType:shapeDynamism:));
882+
883+
/**
884+
* Creates a tensor filled with ones similar to an existing tensor, with the specified data type.
885+
*
886+
* @param tensor An existing ExecuTorchTensor instance whose shape and strides are used.
887+
* @param dataType An ExecuTorchDataType value specifying the desired element type.
888+
* @return A new ExecuTorchTensor instance filled with ones.
889+
*/
890+
+ (instancetype)onesTensorLikeTensor:(ExecuTorchTensor *)tensor
891+
dataType:(ExecuTorchDataType)dataType
892+
NS_SWIFT_NAME(ones(like:dataType:));
893+
894+
/**
895+
* Creates a tensor filled with ones similar to an existing tensor.
896+
*
897+
* @param tensor An existing ExecuTorchTensor instance.
898+
* @return A new ExecuTorchTensor instance filled with ones.
899+
*/
900+
+ (instancetype)onesTensorLikeTensor:(ExecuTorchTensor *)tensor
901+
NS_SWIFT_NAME(ones(like:));
902+
903+
@end
904+
842905
NS_ASSUME_NONNULL_END

extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -725,3 +725,53 @@ + (instancetype)fullTensorLikeTensor:(ExecuTorchTensor *)tensor
725725
}
726726

727727
@end
728+
729+
@implementation ExecuTorchTensor (Ones)
730+
731+
+ (instancetype)onesTensorWithShape:(NSArray<NSNumber *> *)shape
732+
dataType:(ExecuTorchDataType)dataType
733+
shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism {
734+
return [self fullTensorWithShape:shape
735+
scalar:@(1)
736+
strides:@[]
737+
dataType:dataType
738+
shapeDynamism:shapeDynamism];
739+
}
740+
741+
+ (instancetype)onesTensorWithShape:(NSArray<NSNumber *> *)shape
742+
dataType:(ExecuTorchDataType)dataType {
743+
return [self fullTensorWithShape:shape
744+
scalar:@(1)
745+
strides:@[]
746+
dataType:dataType
747+
shapeDynamism:ExecuTorchShapeDynamismDynamicBound];
748+
}
749+
750+
+ (instancetype)onesTensorLikeTensor:(ExecuTorchTensor *)tensor
751+
dataType:(ExecuTorchDataType)dataType
752+
shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism {
753+
return [self fullTensorWithShape:tensor.shape
754+
scalar:@(1)
755+
strides:tensor.strides
756+
dataType:dataType
757+
shapeDynamism:shapeDynamism];
758+
}
759+
760+
+ (instancetype)onesTensorLikeTensor:(ExecuTorchTensor *)tensor
761+
dataType:(ExecuTorchDataType)dataType {
762+
return [self fullTensorWithShape:tensor.shape
763+
scalar:@(1)
764+
strides:tensor.strides
765+
dataType:dataType
766+
shapeDynamism:ExecuTorchShapeDynamismDynamicBound];
767+
}
768+
769+
+ (instancetype)onesTensorLikeTensor:(ExecuTorchTensor *)tensor {
770+
return [self fullTensorWithShape:tensor.shape
771+
scalar:@(1)
772+
strides:tensor.strides
773+
dataType:tensor.dataType
774+
shapeDynamism:tensor.shapeDynamism];
775+
}
776+
777+
@end

extension/apple/ExecuTorch/__tests__/TensorTest.swift

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -593,4 +593,29 @@ class TensorTest: XCTestCase {
593593
}
594594
}
595595
}
596+
597+
func testOnes() {
598+
let tensor = Tensor.ones(shape: [2, 3], dataType: .float)
599+
XCTAssertEqual(tensor.shape, [2, 3])
600+
XCTAssertEqual(tensor.count, 6)
601+
tensor.bytes { pointer, count, dataType in
602+
XCTAssertEqual(dataType, .float)
603+
let buffer = UnsafeBufferPointer(start: pointer.assumingMemoryBound(to: Float.self), count: count)
604+
for value in buffer {
605+
XCTAssertEqual(value, 1.0)
606+
}
607+
}
608+
}
609+
610+
func testOnesLike() {
611+
let other = Tensor.empty(shape: [2, 4], dataType: .double)
612+
let tensor = Tensor.ones(like: other)
613+
XCTAssertEqual(tensor.shape, other.shape)
614+
tensor.bytes { pointer, count, dataType in
615+
let buffer = UnsafeBufferPointer(start: pointer.assumingMemoryBound(to: Double.self), count: count)
616+
for value in buffer {
617+
XCTAssertEqual(value, 1.0)
618+
}
619+
}
620+
}
596621
}

0 commit comments

Comments
 (0)