Skip to content

Commit d3fb566

Browse files
committed
[CUDA][HIP] Fix kernel arguments being overriden when added out of order
In the Cuda and Hip adapter, when kernel arguments are added out of order (e.g. argument at index 1 is added before argument at index 0), the existing arguments are currently being overwritten. This happens because some of the argument sizes might not be known when adding them out of order and the code relies on those sizes to choose where to store the argument. This commit avoids this issue by storing the arguments in the same order that they are added and accessing them using pointer offsets.
1 parent 3087543 commit d3fb566

File tree

4 files changed

+377
-38
lines changed

4 files changed

+377
-38
lines changed

source/adapters/cuda/kernel.hpp

Lines changed: 19 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,8 @@ struct ur_kernel_handle_t_ {
6868
args_size_t ParamSizes;
6969
/// Byte offset into /p Storage allocation for each parameter.
7070
args_index_t Indices;
71+
/// Position in the Storage array where the next argument should added.
72+
size_t InsertPos = 0;
7173
/// Aligned size in bytes for each local memory parameter after padding has
7274
/// been added. Zero if the argument at the index isn't a local memory
7375
/// argument.
@@ -101,6 +103,7 @@ struct ur_kernel_handle_t_ {
101103
/// Implicit offset argument is kept at the back of the indices collection.
102104
void addArg(size_t Index, size_t Size, const void *Arg,
103105
size_t LocalSize = 0) {
106+
// Expand storage to accommodate this Index if needed.
104107
if (Index + 2 > Indices.size()) {
105108
// Move implicit offset argument index with the end
106109
Indices.resize(Index + 2, Indices.back());
@@ -109,14 +112,21 @@ struct ur_kernel_handle_t_ {
109112
AlignedLocalMemSize.resize(Index + 1);
110113
OriginalLocalMemSize.resize(Index + 1);
111114
}
112-
ParamSizes[Index] = Size;
113-
// calculate the insertion point on the array
114-
size_t InsertPos = std::accumulate(std::begin(ParamSizes),
115-
std::begin(ParamSizes) + Index, 0);
116-
// Update the stored value for the argument
117-
std::memcpy(&Storage[InsertPos], Arg, Size);
118-
Indices[Index] = &Storage[InsertPos];
119-
AlignedLocalMemSize[Index] = LocalSize;
115+
116+
// Copy new argument to storage if it hasn't been added before.
117+
if (ParamSizes[Index] == 0) {
118+
ParamSizes[Index] = Size;
119+
std::memcpy(&Storage[InsertPos], Arg, Size);
120+
Indices[Index] = &Storage[InsertPos];
121+
AlignedLocalMemSize[Index] = LocalSize;
122+
InsertPos += Size;
123+
}
124+
// Otherwise, update the existing argument.
125+
else {
126+
std::memcpy(Indices[Index], Arg, Size);
127+
AlignedLocalMemSize[Index] = LocalSize;
128+
assert(Size == ParamSizes[Index]);
129+
}
120130
}
121131

122132
/// Returns the padded size and offset of a local memory argument.
@@ -177,10 +187,7 @@ struct ur_kernel_handle_t_ {
177187
AlignedLocalMemSize[SuccIndex] = SuccAlignedLocalSize;
178188

179189
// Store new offset into local data
180-
const size_t InsertPos =
181-
std::accumulate(std::begin(ParamSizes),
182-
std::begin(ParamSizes) + SuccIndex, size_t{0});
183-
std::memcpy(&Storage[InsertPos], &SuccAlignedLocalOffset,
190+
std::memcpy(Indices[SuccIndex], &SuccAlignedLocalOffset,
184191
sizeof(size_t));
185192
}
186193
}

source/adapters/hip/kernel.hpp

Lines changed: 39 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ struct ur_kernel_handle_t_ {
6363
args_size_t ParamSizes;
6464
/// Byte offset into /p Storage allocation for each parameter.
6565
args_index_t Indices;
66+
/// Position in the Storage array where the next argument should added.
67+
size_t InsertPos = 0;
6668
/// Aligned size in bytes for each local memory parameter after padding has
6769
/// been added. Zero if the argument at the index isn't a local memory
6870
/// argument.
@@ -95,22 +97,30 @@ struct ur_kernel_handle_t_ {
9597
/// Implicit offset argument is kept at the back of the indices collection.
9698
void addArg(size_t Index, size_t Size, const void *Arg,
9799
size_t LocalSize = 0) {
100+
// Expand storage to accommodate this Index if needed.
98101
if (Index + 2 > Indices.size()) {
99-
// Move implicit offset argument Index with the end
102+
// Move implicit offset argument index with the end
100103
Indices.resize(Index + 2, Indices.back());
101104
// Ensure enough space for the new argument
102105
ParamSizes.resize(Index + 1);
103106
AlignedLocalMemSize.resize(Index + 1);
104107
OriginalLocalMemSize.resize(Index + 1);
105108
}
106-
ParamSizes[Index] = Size;
107-
// calculate the insertion point on the array
108-
size_t InsertPos = std::accumulate(std::begin(ParamSizes),
109-
std::begin(ParamSizes) + Index, 0);
110-
// Update the stored value for the argument
111-
std::memcpy(&Storage[InsertPos], Arg, Size);
112-
Indices[Index] = &Storage[InsertPos];
113-
AlignedLocalMemSize[Index] = LocalSize;
109+
110+
// Copy new argument to storage if it hasn't been added before.
111+
if (ParamSizes[Index] == 0) {
112+
ParamSizes[Index] = Size;
113+
std::memcpy(&Storage[InsertPos], Arg, Size);
114+
Indices[Index] = &Storage[InsertPos];
115+
AlignedLocalMemSize[Index] = LocalSize;
116+
InsertPos += Size;
117+
}
118+
// Otherwise, update the existing argument.
119+
else {
120+
std::memcpy(Indices[Index], Arg, Size);
121+
AlignedLocalMemSize[Index] = LocalSize;
122+
assert(Size == ParamSizes[Index]);
123+
}
114124
}
115125

116126
/// Returns the padded size and offset of a local memory argument.
@@ -151,20 +161,11 @@ struct ur_kernel_handle_t_ {
151161
return std::make_pair(AlignedLocalSize, AlignedLocalOffset);
152162
}
153163

154-
void addLocalArg(size_t Index, size_t Size) {
155-
// Get the aligned argument size and offset into local data
156-
auto [AlignedLocalSize, AlignedLocalOffset] =
157-
calcAlignedLocalArgument(Index, Size);
158-
159-
// Store argument details
160-
addArg(Index, sizeof(size_t), (const void *)&(AlignedLocalOffset),
161-
AlignedLocalSize);
162-
163-
// For every existing local argument which follows at later argument
164-
// indices, update the offset and pointer into the kernel local memory.
165-
// Required as padding will need to be recalculated.
164+
// Iterate over all existing local argument which follows StartIndex
165+
// index, update the offset and pointer into the kernel local memory.
166+
void updateLocalArgOffset(size_t StartIndex) {
166167
const size_t NumArgs = Indices.size() - 1; // Accounts for implicit arg
167-
for (auto SuccIndex = Index + 1; SuccIndex < NumArgs; SuccIndex++) {
168+
for (auto SuccIndex = StartIndex; SuccIndex < NumArgs; SuccIndex++) {
168169
const size_t OriginalLocalSize = OriginalLocalMemSize[SuccIndex];
169170
if (OriginalLocalSize == 0) {
170171
// Skip if successor argument isn't a local memory arg
@@ -179,14 +180,26 @@ struct ur_kernel_handle_t_ {
179180
AlignedLocalMemSize[SuccIndex] = SuccAlignedLocalSize;
180181

181182
// Store new offset into local data
182-
const size_t InsertPos =
183-
std::accumulate(std::begin(ParamSizes),
184-
std::begin(ParamSizes) + SuccIndex, size_t{0});
185-
std::memcpy(&Storage[InsertPos], &SuccAlignedLocalOffset,
183+
std::memcpy(Indices[SuccIndex], &SuccAlignedLocalOffset,
186184
sizeof(size_t));
187185
}
188186
}
189187

188+
void addLocalArg(size_t Index, size_t Size) {
189+
// Get the aligned argument size and offset into local data
190+
auto [AlignedLocalSize, AlignedLocalOffset] =
191+
calcAlignedLocalArgument(Index, Size);
192+
193+
// Store argument details
194+
addArg(Index, sizeof(size_t), (const void *)&(AlignedLocalOffset),
195+
AlignedLocalSize);
196+
197+
// For every existing local argument which follows at later argument
198+
// indices, update the offset and pointer into the kernel local memory.
199+
// Required as padding will need to be recalculated.
200+
updateLocalArgOffset(Index + 1);
201+
}
202+
190203
void addMemObjArg(int Index, ur_mem_handle_t hMem, ur_mem_flags_t Flags) {
191204
assert(hMem && "Invalid mem handle");
192205
// To avoid redundancy we are not storing mem obj with index i at index

test/conformance/exp_command_buffer/update/local_memory_update.cpp

Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1094,3 +1094,223 @@ TEST_P(LocalMemoryMultiUpdateTest, UpdateWithoutBlocking) {
10941094
uint32_t *new_Y = (uint32_t *)shared_ptrs[4];
10951095
Validate(new_output, new_X, new_Y, new_A, global_size, local_size);
10961096
}
1097+
1098+
struct LocalMemoryUpdateTestBaseOutOfOrder : LocalMemoryUpdateTestBase {
1099+
virtual void SetUp() override {
1100+
program_name = "saxpy_usm_local_mem";
1101+
UUR_RETURN_ON_FATAL_FAILURE(
1102+
urUpdatableCommandBufferExpExecutionTest::SetUp());
1103+
1104+
if (backend == UR_PLATFORM_BACKEND_LEVEL_ZERO) {
1105+
GTEST_SKIP()
1106+
<< "Local memory argument update not supported on Level Zero.";
1107+
}
1108+
1109+
// HIP has extra args for local memory so we define an offset for arg
1110+
// indices here for updating
1111+
hip_arg_offset = backend == UR_PLATFORM_BACKEND_HIP ? 3 : 0;
1112+
ur_device_usm_access_capability_flags_t shared_usm_flags;
1113+
ASSERT_SUCCESS(
1114+
uur::GetDeviceUSMSingleSharedSupport(device, shared_usm_flags));
1115+
if (!(shared_usm_flags & UR_DEVICE_USM_ACCESS_CAPABILITY_FLAG_ACCESS)) {
1116+
GTEST_SKIP() << "Shared USM is not supported.";
1117+
}
1118+
1119+
const size_t allocation_size =
1120+
sizeof(uint32_t) * global_size * local_size;
1121+
for (auto &shared_ptr : shared_ptrs) {
1122+
ASSERT_SUCCESS(urUSMSharedAlloc(context, device, nullptr, nullptr,
1123+
allocation_size, &shared_ptr));
1124+
ASSERT_NE(shared_ptr, nullptr);
1125+
1126+
std::vector<uint8_t> pattern(allocation_size);
1127+
uur::generateMemFillPattern(pattern);
1128+
std::memcpy(shared_ptr, pattern.data(), allocation_size);
1129+
}
1130+
1131+
std::array<size_t, 12> index_order{};
1132+
if (backend != UR_PLATFORM_BACKEND_HIP) {
1133+
index_order = {3, 2, 4, 5, 1, 0};
1134+
} else {
1135+
index_order = {9, 8, 10, 11, 4, 5, 6, 7, 0, 1, 2, 3};
1136+
}
1137+
size_t current_index = 0;
1138+
1139+
// Index 3 is A
1140+
ASSERT_SUCCESS(urKernelSetArgValue(kernel, index_order[current_index++],
1141+
sizeof(A), nullptr, &A));
1142+
// Index 2 is output
1143+
ASSERT_SUCCESS(urKernelSetArgPointer(
1144+
kernel, index_order[current_index++], nullptr, shared_ptrs[0]));
1145+
1146+
// Index 4 is X
1147+
ASSERT_SUCCESS(urKernelSetArgPointer(
1148+
kernel, index_order[current_index++], nullptr, shared_ptrs[1]));
1149+
// Index 5 is Y
1150+
ASSERT_SUCCESS(urKernelSetArgPointer(
1151+
kernel, index_order[current_index++], nullptr, shared_ptrs[2]));
1152+
1153+
// Index 1 is local_mem_b arg
1154+
ASSERT_SUCCESS(urKernelSetArgLocal(kernel, index_order[current_index++],
1155+
local_mem_b_size, nullptr));
1156+
if (backend == UR_PLATFORM_BACKEND_HIP) {
1157+
ASSERT_SUCCESS(urKernelSetArgValue(
1158+
kernel, index_order[current_index++], sizeof(hip_local_offset),
1159+
nullptr, &hip_local_offset));
1160+
ASSERT_SUCCESS(urKernelSetArgValue(
1161+
kernel, index_order[current_index++], sizeof(hip_local_offset),
1162+
nullptr, &hip_local_offset));
1163+
ASSERT_SUCCESS(urKernelSetArgValue(
1164+
kernel, index_order[current_index++], sizeof(hip_local_offset),
1165+
nullptr, &hip_local_offset));
1166+
}
1167+
1168+
// Index 0 is local_mem_a arg
1169+
ASSERT_SUCCESS(urKernelSetArgLocal(kernel, index_order[current_index++],
1170+
local_mem_a_size, nullptr));
1171+
1172+
// Hip has extra args for local mem at index 1-3
1173+
if (backend == UR_PLATFORM_BACKEND_HIP) {
1174+
ASSERT_SUCCESS(urKernelSetArgValue(
1175+
kernel, index_order[current_index++], sizeof(hip_local_offset),
1176+
nullptr, &hip_local_offset));
1177+
ASSERT_SUCCESS(urKernelSetArgValue(
1178+
kernel, index_order[current_index++], sizeof(hip_local_offset),
1179+
nullptr, &hip_local_offset));
1180+
ASSERT_SUCCESS(urKernelSetArgValue(
1181+
kernel, index_order[current_index++], sizeof(hip_local_offset),
1182+
nullptr, &hip_local_offset));
1183+
}
1184+
}
1185+
};
1186+
1187+
struct LocalMemoryUpdateTestOutOfOrder : LocalMemoryUpdateTestBaseOutOfOrder {
1188+
void SetUp() override {
1189+
UUR_RETURN_ON_FATAL_FAILURE(
1190+
LocalMemoryUpdateTestBaseOutOfOrder::SetUp());
1191+
1192+
// Append kernel command to command-buffer and close command-buffer
1193+
ASSERT_SUCCESS(urCommandBufferAppendKernelLaunchExp(
1194+
updatable_cmd_buf_handle, kernel, n_dimensions, &global_offset,
1195+
&global_size, &local_size, 0, nullptr, 0, nullptr, 0, nullptr,
1196+
nullptr, nullptr, &command_handle));
1197+
ASSERT_NE(command_handle, nullptr);
1198+
1199+
ASSERT_SUCCESS(urCommandBufferFinalizeExp(updatable_cmd_buf_handle));
1200+
}
1201+
1202+
void TearDown() override {
1203+
if (command_handle) {
1204+
EXPECT_SUCCESS(urCommandBufferReleaseCommandExp(command_handle));
1205+
}
1206+
1207+
UUR_RETURN_ON_FATAL_FAILURE(
1208+
LocalMemoryUpdateTestBaseOutOfOrder::TearDown());
1209+
}
1210+
1211+
ur_exp_command_buffer_command_handle_t command_handle = nullptr;
1212+
};
1213+
1214+
UUR_INSTANTIATE_DEVICE_TEST_SUITE_P(LocalMemoryUpdateTestOutOfOrder);
1215+
1216+
// Test updating A,X,Y parameters to new values and local memory to larger
1217+
// values when the kernel arguments were added out of order.
1218+
TEST_P(LocalMemoryUpdateTestOutOfOrder, UpdateAllParameters) {
1219+
// Run command-buffer prior to update and verify output
1220+
ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0,
1221+
nullptr, nullptr));
1222+
ASSERT_SUCCESS(urQueueFinish(queue));
1223+
1224+
uint32_t *output = (uint32_t *)shared_ptrs[0];
1225+
uint32_t *X = (uint32_t *)shared_ptrs[1];
1226+
uint32_t *Y = (uint32_t *)shared_ptrs[2];
1227+
Validate(output, X, Y, A, global_size, local_size);
1228+
1229+
// Update inputs
1230+
std::array<ur_exp_command_buffer_update_pointer_arg_desc_t, 2>
1231+
new_input_descs;
1232+
std::array<ur_exp_command_buffer_update_value_arg_desc_t, 3>
1233+
new_value_descs;
1234+
1235+
size_t new_local_size = local_size * 4;
1236+
size_t new_local_mem_a_size = new_local_size * sizeof(uint32_t);
1237+
1238+
// New local_mem_a at index 0
1239+
new_value_descs[0] = {
1240+
UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype
1241+
nullptr, // pNext
1242+
0, // argIndex
1243+
new_local_mem_a_size, // argSize
1244+
nullptr, // pProperties
1245+
nullptr, // hArgValue
1246+
};
1247+
1248+
// New local_mem_b at index 1
1249+
new_value_descs[1] = {
1250+
UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype
1251+
nullptr, // pNext
1252+
1 + hip_arg_offset, // argIndex
1253+
local_mem_b_size, // argSize
1254+
nullptr, // pProperties
1255+
nullptr, // hArgValue
1256+
};
1257+
1258+
// New A at index 3
1259+
uint32_t new_A = 33;
1260+
new_value_descs[2] = {
1261+
UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_VALUE_ARG_DESC, // stype
1262+
nullptr, // pNext
1263+
3 + (2 * hip_arg_offset), // argIndex
1264+
sizeof(new_A), // argSize
1265+
nullptr, // pProperties
1266+
&new_A, // hArgValue
1267+
};
1268+
1269+
// New X at index 4
1270+
new_input_descs[0] = {
1271+
UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_POINTER_ARG_DESC, // stype
1272+
nullptr, // pNext
1273+
4 + (2 * hip_arg_offset), // argIndex
1274+
nullptr, // pProperties
1275+
&shared_ptrs[3], // pArgValue
1276+
};
1277+
1278+
// New Y at index 5
1279+
new_input_descs[1] = {
1280+
UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_POINTER_ARG_DESC, // stype
1281+
nullptr, // pNext
1282+
5 + (2 * hip_arg_offset), // argIndex
1283+
nullptr, // pProperties
1284+
&shared_ptrs[4], // pArgValue
1285+
};
1286+
1287+
// Update kernel inputs
1288+
ur_exp_command_buffer_update_kernel_launch_desc_t update_desc = {
1289+
UR_STRUCTURE_TYPE_EXP_COMMAND_BUFFER_UPDATE_KERNEL_LAUNCH_DESC, // stype
1290+
nullptr, // pNext
1291+
kernel, // hNewKernel
1292+
0, // numNewMemObjArgs
1293+
new_input_descs.size(), // numNewPointerArgs
1294+
new_value_descs.size(), // numNewValueArgs
1295+
n_dimensions, // newWorkDim
1296+
nullptr, // pNewMemObjArgList
1297+
new_input_descs.data(), // pNewPointerArgList
1298+
new_value_descs.data(), // pNewValueArgList
1299+
nullptr, // pNewGlobalWorkOffset
1300+
nullptr, // pNewGlobalWorkSize
1301+
nullptr, // pNewLocalWorkSize
1302+
};
1303+
1304+
// Update kernel and enqueue command-buffer again
1305+
ASSERT_SUCCESS(
1306+
urCommandBufferUpdateKernelLaunchExp(command_handle, &update_desc));
1307+
ASSERT_SUCCESS(urCommandBufferEnqueueExp(updatable_cmd_buf_handle, queue, 0,
1308+
nullptr, nullptr));
1309+
ASSERT_SUCCESS(urQueueFinish(queue));
1310+
1311+
// Verify that update occurred correctly
1312+
uint32_t *new_output = (uint32_t *)shared_ptrs[0];
1313+
uint32_t *new_X = (uint32_t *)shared_ptrs[3];
1314+
uint32_t *new_Y = (uint32_t *)shared_ptrs[4];
1315+
Validate(new_output, new_X, new_Y, new_A, global_size, local_size);
1316+
}

0 commit comments

Comments
 (0)