Skip to content

Commit 30d7882

Browse files
More tests for scalar constructor. (#9697)
Summary: #8366 Reviewed By: bsoyluoglu Differential Revision: D71932495 Co-authored-by: Anthony Shoumikhin <[email protected]>
1 parent 25ff7c8 commit 30d7882

File tree

1 file changed

+144
-0
lines changed

1 file changed

+144
-0
lines changed

extension/apple/ExecuTorch/__tests__/TensorTest.swift

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,114 @@ class TensorTest: XCTestCase {
393393
}
394394
}
395395

396+
func testInitInt8() {
397+
let tensor = Tensor(Int8(42))
398+
XCTAssertEqual(tensor.dataType, .char)
399+
XCTAssertEqual(tensor.shape, [])
400+
XCTAssertEqual(tensor.strides, [])
401+
XCTAssertEqual(tensor.dimensionOrder, [])
402+
XCTAssertEqual(tensor.count, 1)
403+
tensor.bytes { pointer, count, dataType in
404+
XCTAssertEqual(UnsafeBufferPointer(start: pointer.assumingMemoryBound(to: Int8.self), count: count).first, 42)
405+
}
406+
}
407+
408+
func testInitInt16() {
409+
let tensor = Tensor(Int16(42))
410+
XCTAssertEqual(tensor.dataType, .short)
411+
XCTAssertEqual(tensor.shape, [])
412+
XCTAssertEqual(tensor.strides, [])
413+
XCTAssertEqual(tensor.dimensionOrder, [])
414+
XCTAssertEqual(tensor.count, 1)
415+
tensor.bytes { pointer, count, dataType in
416+
XCTAssertEqual(UnsafeBufferPointer(start: pointer.assumingMemoryBound(to: Int16.self), count: count).first, 42)
417+
}
418+
}
419+
420+
func testInitInt32() {
421+
let tensor = Tensor(Int32(42))
422+
XCTAssertEqual(tensor.dataType, .int)
423+
XCTAssertEqual(tensor.shape, [])
424+
XCTAssertEqual(tensor.strides, [])
425+
XCTAssertEqual(tensor.dimensionOrder, [])
426+
XCTAssertEqual(tensor.count, 1)
427+
tensor.bytes { pointer, count, dataType in
428+
XCTAssertEqual(UnsafeBufferPointer(start: pointer.assumingMemoryBound(to: Int32.self), count: count).first, 42)
429+
}
430+
}
431+
432+
func testInitInt64() {
433+
let tensor = Tensor(Int64(42))
434+
XCTAssertEqual(tensor.dataType, .long)
435+
XCTAssertEqual(tensor.shape, [])
436+
XCTAssertEqual(tensor.strides, [])
437+
XCTAssertEqual(tensor.dimensionOrder, [])
438+
XCTAssertEqual(tensor.count, 1)
439+
tensor.bytes { pointer, count, dataType in
440+
XCTAssertEqual(UnsafeBufferPointer(start: pointer.assumingMemoryBound(to: Int64.self), count: count).first, 42)
441+
}
442+
}
443+
444+
func testInitUInt8() {
445+
let tensor = Tensor(UInt8(42))
446+
XCTAssertEqual(tensor.dataType, .byte)
447+
XCTAssertEqual(tensor.shape, [])
448+
XCTAssertEqual(tensor.strides, [])
449+
XCTAssertEqual(tensor.dimensionOrder, [])
450+
XCTAssertEqual(tensor.count, 1)
451+
tensor.bytes { pointer, count, dataType in
452+
XCTAssertEqual(UnsafeBufferPointer(start: pointer.assumingMemoryBound(to: UInt8.self), count: count).first, 42)
453+
}
454+
}
455+
456+
func testInitUInt16() {
457+
let tensor = Tensor(UInt16(42))
458+
XCTAssertEqual(tensor.dataType, .uInt16)
459+
XCTAssertEqual(tensor.shape, [])
460+
XCTAssertEqual(tensor.strides, [])
461+
XCTAssertEqual(tensor.dimensionOrder, [])
462+
XCTAssertEqual(tensor.count, 1)
463+
tensor.bytes { pointer, count, dataType in
464+
XCTAssertEqual(UnsafeBufferPointer(start: pointer.assumingMemoryBound(to: UInt16.self), count: count).first, 42)
465+
}
466+
}
467+
468+
func testInitUInt32() {
469+
let tensor = Tensor(UInt32(42))
470+
XCTAssertEqual(tensor.dataType, .uInt32)
471+
XCTAssertEqual(tensor.shape, [])
472+
XCTAssertEqual(tensor.strides, [])
473+
XCTAssertEqual(tensor.dimensionOrder, [])
474+
XCTAssertEqual(tensor.count, 1)
475+
tensor.bytes { pointer, count, dataType in
476+
XCTAssertEqual(UnsafeBufferPointer(start: pointer.assumingMemoryBound(to: UInt32.self), count: count).first, 42)
477+
}
478+
}
479+
480+
func testInitUInt64() {
481+
let tensor = Tensor(UInt64(42))
482+
XCTAssertEqual(tensor.dataType, .uInt64)
483+
XCTAssertEqual(tensor.shape, [])
484+
XCTAssertEqual(tensor.strides, [])
485+
XCTAssertEqual(tensor.dimensionOrder, [])
486+
XCTAssertEqual(tensor.count, 1)
487+
tensor.bytes { pointer, count, dataType in
488+
XCTAssertEqual(UnsafeBufferPointer(start: pointer.assumingMemoryBound(to: UInt64.self), count: count).first, 42)
489+
}
490+
}
491+
492+
func testInitBool() {
493+
let tensor = Tensor(true)
494+
XCTAssertEqual(tensor.dataType, .bool)
495+
XCTAssertEqual(tensor.shape, [])
496+
XCTAssertEqual(tensor.strides, [])
497+
XCTAssertEqual(tensor.dimensionOrder, [])
498+
XCTAssertEqual(tensor.count, 1)
499+
tensor.bytes { pointer, count, dataType in
500+
XCTAssertEqual(UnsafeBufferPointer(start: pointer.assumingMemoryBound(to: Bool.self), count: count).first, true)
501+
}
502+
}
503+
396504
func testInitFloat() {
397505
let tensor = Tensor(Float(42.0))
398506
XCTAssertEqual(tensor.dataType, .float)
@@ -404,4 +512,40 @@ class TensorTest: XCTestCase {
404512
XCTAssertEqual(UnsafeBufferPointer(start: pointer.assumingMemoryBound(to: Float.self), count: count).first, 42.0)
405513
}
406514
}
515+
516+
func testInitDouble() {
517+
let tensor = Tensor(42.0)
518+
XCTAssertEqual(tensor.dataType, .double)
519+
XCTAssertEqual(tensor.shape, [])
520+
XCTAssertEqual(tensor.strides, [])
521+
XCTAssertEqual(tensor.dimensionOrder, [])
522+
XCTAssertEqual(tensor.count, 1)
523+
tensor.bytes { pointer, count, dataType in
524+
XCTAssertEqual(UnsafeBufferPointer(start: pointer.assumingMemoryBound(to: Double.self), count: count).first, 42.0)
525+
}
526+
}
527+
528+
func testInitInt() {
529+
let tensor = Tensor(42)
530+
XCTAssertEqual(tensor.dataType, .long)
531+
XCTAssertEqual(tensor.shape, [])
532+
XCTAssertEqual(tensor.strides, [])
533+
XCTAssertEqual(tensor.dimensionOrder, [])
534+
XCTAssertEqual(tensor.count, 1)
535+
tensor.bytes { pointer, count, dataType in
536+
XCTAssertEqual(UnsafeBufferPointer(start: pointer.assumingMemoryBound(to: Int.self), count: count).first, 42)
537+
}
538+
}
539+
540+
func testInitUInt() {
541+
let tensor = Tensor(UInt(42))
542+
XCTAssertEqual(tensor.dataType, .uInt64)
543+
XCTAssertEqual(tensor.shape, [])
544+
XCTAssertEqual(tensor.strides, [])
545+
XCTAssertEqual(tensor.dimensionOrder, [])
546+
XCTAssertEqual(tensor.count, 1)
547+
tensor.bytes { pointer, count, dataType in
548+
XCTAssertEqual(UnsafeBufferPointer(start: pointer.assumingMemoryBound(to: UInt.self), count: count).first, 42)
549+
}
550+
}
407551
}

0 commit comments

Comments
 (0)