@@ -23,6 +23,29 @@ class TensorImplPtrTest : public ::testing::Test {
23
23
}
24
24
};
25
25
26
+ TEST_F (TensorImplPtrTest, ScalarTensorCreation) {
27
+ float scalar_data = 3 .14f ;
28
+ auto tensor_impl =
29
+ make_tensor_impl_ptr (exec_aten::ScalarType::Float, {}, &scalar_data);
30
+
31
+ EXPECT_EQ (tensor_impl->numel (), 1 );
32
+ EXPECT_EQ (tensor_impl->dim (), 0 );
33
+ EXPECT_EQ (tensor_impl->sizes ().size (), 0 );
34
+ EXPECT_EQ (tensor_impl->strides ().size (), 0 );
35
+ EXPECT_EQ ((float *)tensor_impl->data (), &scalar_data);
36
+ EXPECT_EQ (((float *)tensor_impl->data ())[0 ], 3 .14f );
37
+ }
38
+
39
+ TEST_F (TensorImplPtrTest, ScalarTensorOwningData) {
40
+ auto tensor_impl = make_tensor_impl_ptr ({}, {3 .14f });
41
+
42
+ EXPECT_EQ (tensor_impl->numel (), 1 );
43
+ EXPECT_EQ (tensor_impl->dim (), 0 );
44
+ EXPECT_EQ (tensor_impl->sizes ().size (), 0 );
45
+ EXPECT_EQ (tensor_impl->strides ().size (), 0 );
46
+ EXPECT_EQ (((float *)tensor_impl->data ())[0 ], 3 .14f );
47
+ }
48
+
26
49
TEST_F (TensorImplPtrTest, TensorImplCreation) {
27
50
float data[20 ] = {2 };
28
51
auto tensor_impl = make_tensor_impl_ptr (
@@ -34,8 +57,8 @@ TEST_F(TensorImplPtrTest, TensorImplCreation) {
34
57
EXPECT_EQ (tensor_impl->strides ()[0 ], 5 );
35
58
EXPECT_EQ (tensor_impl->strides ()[1 ], 1 );
36
59
EXPECT_EQ (tensor_impl->data (), data);
37
- EXPECT_EQ (tensor_impl->mutable_data (), data);
38
- EXPECT_EQ (((float *)tensor_impl->mutable_data ())[0 ], 2 );
60
+ EXPECT_EQ (tensor_impl->data (), data);
61
+ EXPECT_EQ (((float *)tensor_impl->data ())[0 ], 2 );
39
62
}
40
63
41
64
TEST_F (TensorImplPtrTest, TensorImplSharedOwnership) {
0 commit comments