Skip to content

Commit 22d2558

Browse files
authored
Helpers to create random integer tensor. (#10649)
Summary: #8366 Reviewed By: kirklandsign Differential Revision: D74020940
1 parent 444d7f9 commit 22d2558

File tree

3 files changed

+213
-0
lines changed

3 files changed

+213
-0
lines changed

extension/apple/ExecuTorch/Exported/ExecuTorchTensor.h

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1063,4 +1063,111 @@ __attribute__((deprecated("This API is experimental.")))
10631063

10641064
@end
10651065

1066+
#pragma mark - RandomInteger Category
1067+
1068+
@interface ExecuTorchTensor (RandomInteger)
1069+
1070+
/**
1071+
* Creates a tensor with random integer values in the specified range,
1072+
* with full specification of shape, strides, data type, and shape dynamism.
1073+
*
1074+
* @param low An NSInteger specifying the inclusive lower bound of random values.
1075+
* @param high An NSInteger specifying the exclusive upper bound of random values.
1076+
* @param shape An NSArray of NSNumber objects representing the desired shape.
1077+
* @param strides An NSArray of NSNumber objects representing the desired strides.
1078+
* @param dataType An ExecuTorchDataType value specifying the element type.
1079+
* @param shapeDynamism An ExecuTorchShapeDynamism value specifying whether the shape is static or dynamic.
1080+
* @return A new ExecuTorchTensor instance filled with random integer values.
1081+
*/
1082+
+ (instancetype)randomIntegerTensorWithLow:(NSInteger)low
1083+
high:(NSInteger)high
1084+
shape:(NSArray<NSNumber *> *)shape
1085+
strides:(NSArray<NSNumber *> *)strides
1086+
dataType:(ExecuTorchDataType)dataType
1087+
shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism
1088+
NS_SWIFT_NAME(randint(low:high:shape:strides:dataType:shapeDynamism:));
1089+
1090+
/**
1091+
* Creates a tensor with random integer values in the specified range,
1092+
* with the given shape and data type.
1093+
*
1094+
* @param low An NSInteger specifying the inclusive lower bound of random values.
1095+
* @param high An NSInteger specifying the exclusive upper bound of random values.
1096+
* @param shape An NSArray of NSNumber objects representing the desired shape.
1097+
* @param dataType An ExecuTorchDataType value specifying the element type.
1098+
* @param shapeDynamism An ExecuTorchShapeDynamism value specifying whether the shape is static or dynamic.
1099+
* @return A new ExecuTorchTensor instance filled with random integer values.
1100+
*/
1101+
+ (instancetype)randomIntegerTensorWithLow:(NSInteger)low
1102+
high:(NSInteger)high
1103+
shape:(NSArray<NSNumber *> *)shape
1104+
dataType:(ExecuTorchDataType)dataType
1105+
shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism
1106+
NS_SWIFT_NAME(randint(low:high:shape:dataType:shapeDynamism:));
1107+
1108+
/**
1109+
* Creates a tensor with random integer values in the specified range,
1110+
* with the given shape (using dynamic bound shape) and data type.
1111+
*
1112+
* @param low An NSInteger specifying the inclusive lower bound of random values.
1113+
* @param high An NSInteger specifying the exclusive upper bound of random values.
1114+
* @param shape An NSArray of NSNumber objects representing the desired shape.
1115+
* @param dataType An ExecuTorchDataType value specifying the element type.
1116+
* @return A new ExecuTorchTensor instance filled with random integer values.
1117+
*/
1118+
+ (instancetype)randomIntegerTensorWithLow:(NSInteger)low
1119+
high:(NSInteger)high
1120+
shape:(NSArray<NSNumber *> *)shape
1121+
dataType:(ExecuTorchDataType)dataType
1122+
NS_SWIFT_NAME(randint(low:high:shape:dataType:));
1123+
1124+
/**
1125+
* Creates a tensor with random integer values in the specified range, similar to an existing tensor,
1126+
* with the given data type and shape dynamism.
1127+
*
1128+
* @param tensor An existing ExecuTorchTensor instance whose shape and strides are used.
1129+
* @param low An NSInteger specifying the inclusive lower bound of random values.
1130+
* @param high An NSInteger specifying the exclusive upper bound of random values.
1131+
* @param dataType An ExecuTorchDataType value specifying the element type.
1132+
* @param shapeDynamism An ExecuTorchShapeDynamism value specifying whether the shape is static or dynamic.
1133+
* @return A new ExecuTorchTensor instance filled with random integer values.
1134+
*/
1135+
+ (instancetype)randomIntegerTensorLikeTensor:(ExecuTorchTensor *)tensor
1136+
low:(NSInteger)low
1137+
high:(NSInteger)high
1138+
dataType:(ExecuTorchDataType)dataType
1139+
shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism
1140+
NS_SWIFT_NAME(randint(like:low:high:dataType:shapeDynamism:));
1141+
1142+
/**
1143+
* Creates a tensor with random integer values in the specified range, similar to an existing tensor,
1144+
* with the given data type.
1145+
*
1146+
* @param tensor An existing ExecuTorchTensor instance whose shape and strides are used.
1147+
* @param low An NSInteger specifying the inclusive lower bound of random values.
1148+
* @param high An NSInteger specifying the exclusive upper bound of random values.
1149+
* @param dataType An ExecuTorchDataType value specifying the element type.
1150+
* @return A new ExecuTorchTensor instance filled with random integer values.
1151+
*/
1152+
+ (instancetype)randomIntegerTensorLikeTensor:(ExecuTorchTensor *)tensor
1153+
low:(NSInteger)low
1154+
high:(NSInteger)high
1155+
dataType:(ExecuTorchDataType)dataType
1156+
NS_SWIFT_NAME(randint(like:low:high:dataType:));
1157+
1158+
/**
1159+
* Creates a tensor with random integer values in the specified range, similar to an existing tensor.
1160+
*
1161+
* @param tensor An existing ExecuTorchTensor instance.
1162+
* @param low An NSInteger specifying the inclusive lower bound of random values.
1163+
* @param high An NSInteger specifying the exclusive upper bound of random values.
1164+
* @return A new ExecuTorchTensor instance filled with random integer values.
1165+
*/
1166+
+ (instancetype)randomIntegerTensorLikeTensor:(ExecuTorchTensor *)tensor
1167+
low:(NSInteger)low
1168+
high:(NSInteger)high
1169+
NS_SWIFT_NAME(randint(like:low:high:));
1170+
1171+
@end
1172+
10661173
NS_ASSUME_NONNULL_END

extension/apple/ExecuTorch/Exported/ExecuTorchTensor.mm

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -891,3 +891,85 @@ + (instancetype)randomNormalTensorLikeTensor:(ExecuTorchTensor *)tensor {
891891
}
892892

893893
@end
894+
895+
@implementation ExecuTorchTensor (RandomInteger)
896+
897+
+ (instancetype)randomIntegerTensorWithLow:(NSInteger)low
898+
high:(NSInteger)high
899+
shape:(NSArray<NSNumber *> *)shape
900+
strides:(NSArray<NSNumber *> *)strides
901+
dataType:(ExecuTorchDataType)dataType
902+
shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism {
903+
auto tensor = randint_strided(
904+
low,
905+
high,
906+
utils::toVector<SizesType>(shape),
907+
utils::toVector<StridesType>(strides),
908+
static_cast<ScalarType>(dataType),
909+
static_cast<TensorShapeDynamism>(shapeDynamism)
910+
);
911+
return [[self alloc] initWithNativeInstance:&tensor];
912+
}
913+
914+
+ (instancetype)randomIntegerTensorWithLow:(NSInteger)low
915+
high:(NSInteger)high
916+
shape:(NSArray<NSNumber *> *)shape
917+
dataType:(ExecuTorchDataType)dataType
918+
shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism {
919+
return [self randomIntegerTensorWithLow:low
920+
high:high
921+
shape:shape
922+
strides:@[]
923+
dataType:dataType
924+
shapeDynamism:shapeDynamism];
925+
}
926+
927+
+ (instancetype)randomIntegerTensorWithLow:(NSInteger)low
928+
high:(NSInteger)high
929+
shape:(NSArray<NSNumber *> *)shape
930+
dataType:(ExecuTorchDataType)dataType {
931+
return [self randomIntegerTensorWithLow:low
932+
high:high
933+
shape:shape
934+
strides:@[]
935+
dataType:dataType
936+
shapeDynamism:ExecuTorchShapeDynamismDynamicBound];
937+
}
938+
939+
+ (instancetype)randomIntegerTensorLikeTensor:(ExecuTorchTensor *)tensor
940+
low:(NSInteger)low
941+
high:(NSInteger)high
942+
dataType:(ExecuTorchDataType)dataType
943+
shapeDynamism:(ExecuTorchShapeDynamism)shapeDynamism {
944+
return [self randomIntegerTensorWithLow:low
945+
high:high
946+
shape:tensor.shape
947+
strides:tensor.strides
948+
dataType:dataType
949+
shapeDynamism:shapeDynamism];
950+
}
951+
952+
+ (instancetype)randomIntegerTensorLikeTensor:(ExecuTorchTensor *)tensor
953+
low:(NSInteger)low
954+
high:(NSInteger)high
955+
dataType:(ExecuTorchDataType)dataType {
956+
return [self randomIntegerTensorWithLow:low
957+
high:high
958+
shape:tensor.shape
959+
strides:tensor.strides
960+
dataType:dataType
961+
shapeDynamism:tensor.shapeDynamism];
962+
}
963+
964+
+ (instancetype)randomIntegerTensorLikeTensor:(ExecuTorchTensor *)tensor
965+
low:(NSInteger)low
966+
high:(NSInteger)high {
967+
return [self randomIntegerTensorWithLow:low
968+
high:high
969+
shape:tensor.shape
970+
strides:tensor.strides
971+
dataType:tensor.dataType
972+
shapeDynamism:tensor.shapeDynamism];
973+
}
974+
975+
@end

extension/apple/ExecuTorch/__tests__/TensorTest.swift

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -655,4 +655,28 @@ class TensorTest: XCTestCase {
655655
XCTAssertEqual(tensor.shape, other.shape)
656656
XCTAssertEqual(tensor.count, other.count)
657657
}
658+
659+
func testRandomInteger() {
660+
let tensor = Tensor.randint(low: 10, high: 20, shape: [5], dataType: .int)
661+
XCTAssertEqual(tensor.shape, [5])
662+
XCTAssertEqual(tensor.count, 5)
663+
tensor.bytes { pointer, count, dataType in
664+
XCTAssertEqual(dataType, .int)
665+
let buffer = UnsafeBufferPointer(start: pointer.assumingMemoryBound(to: Int32.self), count: count)
666+
for value in buffer {
667+
XCTAssertTrue(value >= 10 && value < 20)
668+
}
669+
}
670+
}
671+
672+
func testRandomIntegerLike() {
673+
let other = Tensor.ones(shape: [5], dataType: .int)
674+
let tensor = Tensor.randint(like: other, low: 100, high: 200)
675+
tensor.bytes { pointer, count, dataType in
676+
let buffer = UnsafeBufferPointer(start: pointer.assumingMemoryBound(to: Int32.self), count: count)
677+
for value in buffer {
678+
XCTAssertTrue(value >= 100 && value < 200)
679+
}
680+
}
681+
}
658682
}

0 commit comments

Comments
 (0)