@@ -1094,3 +1094,223 @@ TEST_P(LocalMemoryMultiUpdateTest, UpdateWithoutBlocking) {
1094
1094
uint32_t *new_Y = (uint32_t *)shared_ptrs[4 ];
1095
1095
Validate (new_output, new_X, new_Y, new_A, global_size, local_size);
1096
1096
}
1097
+
1098
+ struct LocalMemoryUpdateTestBaseOutOfOrder : LocalMemoryUpdateTestBase {
1099
+ virtual void SetUp () override {
1100
+ program_name = " saxpy_usm_local_mem" ;
1101
+ UUR_RETURN_ON_FATAL_FAILURE (
1102
+ urUpdatableCommandBufferExpExecutionTest::SetUp ());
1103
+
1104
+ if (backend == UR_PLATFORM_BACKEND_LEVEL_ZERO) {
1105
+ GTEST_SKIP ()
1106
+ << " Local memory argument update not supported on Level Zero." ;
1107
+ }
1108
+
1109
+ // HIP has extra args for local memory so we define an offset for arg
1110
+ // indices here for updating
1111
+ hip_arg_offset = backend == UR_PLATFORM_BACKEND_HIP ? 3 : 0 ;
1112
+ ur_device_usm_access_capability_flags_t shared_usm_flags;
1113
+ ASSERT_SUCCESS (
1114
+ uur::GetDeviceUSMSingleSharedSupport (device, shared_usm_flags));
1115
+ if (!(shared_usm_flags & UR_DEVICE_USM_ACCESS_CAPABILITY_FLAG_ACCESS)) {
1116
+ GTEST_SKIP () << " Shared USM is not supported." ;
1117
+ }
1118
+
1119
+ const size_t allocation_size =
1120
+ sizeof (uint32_t ) * global_size * local_size;
1121
+ for (auto &shared_ptr : shared_ptrs) {
1122
+ ASSERT_SUCCESS (urUSMSharedAlloc (context, device, nullptr , nullptr ,
1123
+ allocation_size, &shared_ptr));
1124
+ ASSERT_NE (shared_ptr, nullptr );
1125
+
1126
+ std::vector<uint8_t > pattern (allocation_size);
1127
+ uur::generateMemFillPattern (pattern);
1128
+ std::memcpy (shared_ptr, pattern.data (), allocation_size);
1129
+ }
1130
+
1131
+ std::array<size_t , 12 > index_order{};
1132
+ if (backend != UR_PLATFORM_BACKEND_HIP) {
1133
+ index_order = {3 , 2 , 4 , 5 , 1 , 0 };
1134
+ } else {
1135
+ index_order = {9 , 8 , 10 , 11 , 4 , 5 , 6 , 7 , 0 , 1 , 2 , 3 };
1136
+ }
1137
+ size_t current_index = 0 ;
1138
+
1139
+ // Index 3 is A
1140
+ ASSERT_SUCCESS (urKernelSetArgValue (kernel, index_order[current_index++],
1141
+ sizeof (A), nullptr , &A));
1142
+ // Index 2 is output
1143
+ ASSERT_SUCCESS (urKernelSetArgPointer (
1144
+ kernel, index_order[current_index++], nullptr , shared_ptrs[0 ]));
1145
+
1146
+ // Index 4 is X
1147
+ ASSERT_SUCCESS (urKernelSetArgPointer (
1148
+ kernel, index_order[current_index++], nullptr , shared_ptrs[1 ]));
1149
+ // Index 5 is Y
1150
+ ASSERT_SUCCESS (urKernelSetArgPointer (
1151
+ kernel, index_order[current_index++], nullptr , shared_ptrs[2 ]));
1152
+
1153
+ // Index 1 is local_mem_b arg
1154
+ ASSERT_SUCCESS (urKernelSetArgLocal (kernel, index_order[current_index++],
1155
+ local_mem_b_size, nullptr ));
1156
+ if (backend == UR_PLATFORM_BACKEND_HIP) {
1157
+ ASSERT_SUCCESS (urKernelSetArgValue (
1158
+ kernel, index_order[current_index++], sizeof (hip_local_offset),
1159
+ nullptr , &hip_local_offset));
1160
+ ASSERT_SUCCESS (urKernelSetArgValue (
1161
+ kernel, index_order[current_index++], sizeof (hip_local_offset),
1162
+ nullptr , &hip_local_offset));
1163
+ ASSERT_SUCCESS (urKernelSetArgValue (
1164
+ kernel, index_order[current_index++], sizeof (hip_local_offset),
1165
+ nullptr , &hip_local_offset));
1166
+ }
1167
+
1168
+ // Index 0 is local_mem_a arg
1169
+ ASSERT_SUCCESS (urKernelSetArgLocal (kernel, index_order[current_index++],
1170
+ local_mem_a_size, nullptr ));
1171
+
1172
+ // Hip has extra args for local mem at index 1-3
1173
+ if (backend == UR_PLATFORM_BACKEND_HIP) {
1174
+ ASSERT_SUCCESS (urKernelSetArgValue (
1175
+ kernel, index_order[current_index++], sizeof (hip_local_offset),
1176
+ nullptr , &hip_local_offset));
1177
+ ASSERT_SUCCESS (urKernelSetArgValue (
1178
+ kernel, index_order[current_index++], sizeof (hip_local_offset),
1179
+ nullptr , &hip_local_offset));
1180
+ ASSERT_SUCCESS (urKernelSetArgValue (
1181
+ kernel, index_order[current_index++], sizeof (hip_local_offset),
1182
+ nullptr , &hip_local_offset));
1183
+ }
1184
+ }
1185
+ };
1186
+
1187
+ struct LocalMemoryUpdateTestOutOfOrder : LocalMemoryUpdateTestBaseOutOfOrder {
1188
+ void SetUp () override {
1189
+ UUR_RETURN_ON_FATAL_FAILURE (
1190
+ LocalMemoryUpdateTestBaseOutOfOrder::SetUp ());
1191
+
1192
+ // Append kernel command to command-buffer and close command-buffer
1193
+ ASSERT_SUCCESS (urCommandBufferAppendKernelLaunchExp (
1194
+ updatable_cmd_buf_handle, kernel, n_dimensions, &global_offset,
1195
+ &global_size, &local_size, 0 , nullptr , 0 , nullptr , 0 , nullptr ,
1196
+ nullptr , nullptr , &command_handle));
1197
+ ASSERT_NE (command_handle, nullptr );
1198
+
1199
+ ASSERT_SUCCESS (urCommandBufferFinalizeExp (updatable_cmd_buf_handle));
1200
+ }
1201
+
1202
+ void TearDown () override {
1203
+ if (command_handle) {
1204
+ EXPECT_SUCCESS (urCommandBufferReleaseCommandExp (command_handle));
1205
+ }
1206
+
1207
+ UUR_RETURN_ON_FATAL_FAILURE (
1208
+ LocalMemoryUpdateTestBaseOutOfOrder::TearDown ());
1209
+ }
1210
+
1211
+ ur_exp_command_buffer_command_handle_t command_handle = nullptr ;
1212
+ };
1213
+
1214
+ UUR_INSTANTIATE_DEVICE_TEST_SUITE_P (LocalMemoryUpdateTestOutOfOrder);
1215
+
1216
+ // Test updating A,X,Y parameters to new values and local memory to larger
1217
+ // values when the kernel arguments were added out of order.
1218
+ TEST_P (LocalMemoryUpdateTestOutOfOrder, UpdateAllParameters) {
1219
+ // Run command-buffer prior to update and verify output
1220
+ ASSERT_SUCCESS (urCommandBufferEnqueueExp (updatable_cmd_buf_handle, queue, 0 ,
1221
+ nullptr , nullptr ));
1222
+ ASSERT_SUCCESS (urQueueFinish (queue));
1223
+
1224
+ uint32_t *output = (uint32_t *)shared_ptrs[0 ];
1225
+ uint32_t *X = (uint32_t *)shared_ptrs[1 ];
1226
+ uint32_t *Y = (uint32_t *)shared_ptrs[2 ];
1227
+ Validate (output, X, Y, A, global_size, local_size);
1228
+
1229
+ // Update inputs
1230
+ std::array<ur_exp_command_buffer_update_pointer_arg_desc_t , 2 >
1231
+ new_input_descs;
1232
+ std::array<ur_exp_command_buffer_update_value_arg_desc_t , 3 >
1233
+ new_value_descs;
1234
+
1235
+ size_t new_local_size = local_size * 4 ;
1236
+ size_t new_local_mem_a_size = new_local_size * sizeof (uint32_t );
1237
+
1238
+ // New local_mem_a at index 0
1239
+ new_value_descs[0 ] = {
1240
+ UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype
1241
+ nullptr , // pNext
1242
+ 0 , // argIndex
1243
+ new_local_mem_a_size, // argSize
1244
+ nullptr , // pProperties
1245
+ nullptr , // hArgValue
1246
+ };
1247
+
1248
+ // New local_mem_b at index 1
1249
+ new_value_descs[1 ] = {
1250
+ UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype
1251
+ nullptr , // pNext
1252
+ 1 + hip_arg_offset, // argIndex
1253
+ local_mem_b_size, // argSize
1254
+ nullptr , // pProperties
1255
+ nullptr , // hArgValue
1256
+ };
1257
+
1258
+ // New A at index 3
1259
+ uint32_t new_A = 33 ;
1260
+ new_value_descs[2 ] = {
1261
+ UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype
1262
+ nullptr , // pNext
1263
+ 3 + (2 * hip_arg_offset), // argIndex
1264
+ sizeof (new_A), // argSize
1265
+ nullptr , // pProperties
1266
+ &new_A, // hArgValue
1267
+ };
1268
+
1269
+ // New X at index 4
1270
+ new_input_descs[0 ] = {
1271
+ UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_POINTER_ARG_DESC, // stype
1272
+ nullptr , // pNext
1273
+ 4 + (2 * hip_arg_offset), // argIndex
1274
+ nullptr , // pProperties
1275
+ &shared_ptrs[3 ], // pArgValue
1276
+ };
1277
+
1278
+ // New Y at index 5
1279
+ new_input_descs[1 ] = {
1280
+ UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_POINTER_ARG_DESC, // stype
1281
+ nullptr , // pNext
1282
+ 5 + (2 * hip_arg_offset), // argIndex
1283
+ nullptr , // pProperties
1284
+ &shared_ptrs[4 ], // pArgValue
1285
+ };
1286
+
1287
+ // Update kernel inputs
1288
+ ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = {
1289
+ UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype
1290
+ nullptr , // pNext
1291
+ kernel, // hNewKernel
1292
+ 0 , // numNewMemObjArgs
1293
+ new_input_descs.size (), // numNewPointerArgs
1294
+ new_value_descs.size (), // numNewValueArgs
1295
+ n_dimensions, // newWorkDim
1296
+ nullptr , // pNewMemObjArgList
1297
+ new_input_descs.data (), // pNewPointerArgList
1298
+ new_value_descs.data (), // pNewValueArgList
1299
+ nullptr , // pNewGlobalWorkOffset
1300
+ nullptr , // pNewGlobalWorkSize
1301
+ nullptr , // pNewLocalWorkSize
1302
+ };
1303
+
1304
+ // Update kernel and enqueue command-buffer again
1305
+ ASSERT_SUCCESS (
1306
+ urCommandBufferUpdateKernelLaunchExp (command_handle, &update_desc));
1307
+ ASSERT_SUCCESS (urCommandBufferEnqueueExp (updatable_cmd_buf_handle, queue, 0 ,
1308
+ nullptr , nullptr ));
1309
+ ASSERT_SUCCESS (urQueueFinish (queue));
1310
+
1311
+ // Verify that update occurred correctly
1312
+ uint32_t *new_output = (uint32_t *)shared_ptrs[0 ];
1313
+ uint32_t *new_X = (uint32_t *)shared_ptrs[3 ];
1314
+ uint32_t *new_Y = (uint32_t *)shared_ptrs[4 ];
1315
+ Validate (new_output, new_X, new_Y, new_A, global_size, local_size);
1316
+ }
0 commit comments