@@ -120,3 +120,72 @@ TEST_F(TensorParserTest, TestModuleAddHalf) {
       torch::executor::ScalarType::Half,
       sizeof(torch::executor::Half));
 }
+
+TEST_F(TensorParserTest, TestMutableState) {
+  // Load the serialized ModuleSimpleTrain data.
+  const char* path = std::getenv("ET_MODULE_SIMPLE_TRAIN_PATH");
+  Result<FileDataLoader> train_loader = FileDataLoader::from(path);
+  ASSERT_EQ(train_loader.error(), Error::Ok);
+
+  Result<Program> program =
+      Program::load(&train_loader.get(), Program::Verification::Minimal);
+  EXPECT_EQ(program.error(), Error::Ok);
+
+  ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
+  ManagedMemoryManager mmm_copy(
+      kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes);
+
+  const executorch_flatbuffer::Program* internal_program =
+      ProgramTestFriend::GetInternalProgram(&program.get());
+  executorch_flatbuffer::ExecutionPlan* execution_plan =
+      internal_program->execution_plan()->GetMutableObject(0);
+  auto flatbuffer_values = execution_plan->values();
+
+  size_t num_mutable_tensors = 0;
+  for (size_t i = 0; i < flatbuffer_values->size(); ++i) {
+    auto serialization_value = flatbuffer_values->Get(i);
+    if (serialization_value->val_type() ==
+            executorch_flatbuffer::KernelTypes::Tensor &&
+        serialization_value->val_as_Tensor()->allocation_info() != nullptr &&
+        serialization_value->val_as_Tensor()->data_buffer_idx() > 0) {
+      num_mutable_tensors++;
+      Result<torch::executor::Tensor> tensor = parseTensor(
+          &program.get(), &mmm.get(), serialization_value->val_as_Tensor());
+      torch::executor::Tensor t = tensor.get();
+      float loaded_value = t.const_data_ptr<float>()[0];
+      ASSERT_NE(nullptr, t.const_data_ptr());
+      ASSERT_NE(t.mutable_data_ptr<float>()[0], 0.5);
+      t.mutable_data_ptr<float>()[0] = 0.5;
+      ASSERT_EQ(
+          t.mutable_data_ptr<float>()[0],
+          0.5); // 0.5 can be represented perfectly by float so EQ and NE work
+                // fine here. Any power of 2 rational can be perfectly
+                // represented. See dyadic rationals for more info.
+
+      // Load the same tensor using the same mem manager and show the value is
+      // updated again.
+      Result<torch::executor::Tensor> tensor1_alias = parseTensor(
+          &program.get(), &mmm.get(), serialization_value->val_as_Tensor());
+      torch::executor::Tensor t2 = tensor1_alias.get();
+      ASSERT_NE(t2.mutable_data_ptr<float>()[0], 0.5);
+
+      // Show the tensors are equivalent
+      ASSERT_EQ(t.const_data_ptr(), t2.const_data_ptr());
+      // Set mutable tensor value back to 0.5 since it got overwritten by second
+      // parse.
+      t.mutable_data_ptr<float>()[0] = 0.5;
+
+      // Load the same tensor using a different mem manager and show the value
+      // is not the same as t.
+      Result<torch::executor::Tensor> tensor_new = parseTensor(
+          &program.get(),
+          &mmm_copy.get(),
+          serialization_value->val_as_Tensor());
+      torch::executor::Tensor t3 = tensor_new.get();
+      ASSERT_NE(t3.mutable_data_ptr<float>()[0], 0.5);
+      ASSERT_NE(t3.const_data_ptr(), t.const_data_ptr());
+      ASSERT_EQ(loaded_value, t3.const_data_ptr<float>()[0]);
+    }
+  }
+  ASSERT_EQ(num_mutable_tensors, 2);
+}