Skip to content

Commit b881c0b

Browse files
committed
Improve solution
Iterate on previous solution so that the local argument offsets at following inidices are updated when an earlier local argument is updated
1 parent dc65a88 commit b881c0b

File tree

9 files changed

+411
-103
lines changed

9 files changed

+411
-103
lines changed

source/adapters/cuda/command_buffer.cpp

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -522,9 +522,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
522522
DepsList.data(), DepsList.size(),
523523
&NodeParams));
524524

525-
if (LocalSize != 0)
526-
hKernel->clearLocalSize();
527-
528525
// Add signal node if external return event is used.
529526
CUgraphNode SignalNode = nullptr;
530527
if (phEvent) {
@@ -1396,22 +1393,14 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferUpdateKernelLaunchExp(
13961393

13971394
CUDA_KERNEL_NODE_PARAMS &Params = KernelCommandHandle->Params;
13981395

1399-
const auto LocalSize = KernelCommandHandle->Kernel->getLocalSize();
1400-
if (LocalSize != 0) {
1401-
// Clean the local size, otherwise calling updateKernelArguments() in
1402-
// future updates with local arguments will incorrectly increase the
1403-
// size further.
1404-
KernelCommandHandle->Kernel->clearLocalSize();
1405-
}
1406-
14071396
Params.func = CuFunc;
1408-
Params.gridDimX = static_cast<unsigned int>(BlocksPerGrid[0]);
1409-
Params.gridDimY = static_cast<unsigned int>(BlocksPerGrid[1]);
1410-
Params.gridDimZ = static_cast<unsigned int>(BlocksPerGrid[2]);
1411-
Params.blockDimX = static_cast<unsigned int>(ThreadsPerBlock[0]);
1412-
Params.blockDimY = static_cast<unsigned int>(ThreadsPerBlock[1]);
1413-
Params.blockDimZ = static_cast<unsigned int>(ThreadsPerBlock[2]);
1414-
Params.sharedMemBytes = LocalSize;
1397+
Params.gridDimX = BlocksPerGrid[0];
1398+
Params.gridDimY = BlocksPerGrid[1];
1399+
Params.gridDimZ = BlocksPerGrid[2];
1400+
Params.blockDimX = ThreadsPerBlock[0];
1401+
Params.blockDimY = ThreadsPerBlock[1];
1402+
Params.blockDimZ = ThreadsPerBlock[2];
1403+
Params.sharedMemBytes = KernelCommandHandle->Kernel->getLocalSize();
14151404
Params.kernelParams =
14161405
const_cast<void **>(KernelCommandHandle->Kernel->getArgIndices().data());
14171406

source/adapters/cuda/enqueue.cpp

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -493,9 +493,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
493493
ThreadsPerBlock[0], ThreadsPerBlock[1], ThreadsPerBlock[2], LocalSize,
494494
CuStream, const_cast<void **>(ArgIndices.data()), nullptr));
495495

496-
if (LocalSize != 0)
497-
hKernel->clearLocalSize();
498-
499496
if (phEvent) {
500497
UR_CHECK_ERROR(RetImplEvent->record());
501498
*phEvent = RetImplEvent.release();
@@ -673,9 +670,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunchCustomExp(
673670
const_cast<void **>(ArgIndices.data()),
674671
nullptr));
675672

676-
if (LocalSize != 0)
677-
hKernel->clearLocalSize();
678-
679673
if (phEvent) {
680674
UR_CHECK_ERROR(RetImplEvent->record());
681675
*phEvent = RetImplEvent.release();

source/adapters/cuda/kernel.hpp

Lines changed: 80 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,22 @@ struct ur_kernel_handle_t_ {
6161
using args_t = std::array<char, MaxParamBytes>;
6262
using args_size_t = std::vector<size_t>;
6363
using args_index_t = std::vector<void *>;
64+
/// Storage shared by all args which is mem copied into when adding a new
65+
/// argument.
6466
args_t Storage;
67+
/// Aligned size of each parameter, including padding.
6568
args_size_t ParamSizes;
69+
/// Byte offset into /p Storage allocation for each parameter.
6670
args_index_t Indices;
67-
args_size_t OffsetPerIndex;
71+
/// Aligned size in bytes for each local memory parameter after padding has
72+
/// been added. Zero if the argument at the index isn't a local memory
73+
/// argument.
74+
args_size_t AlignedLocalMemSize;
75+
/// Original size in bytes for each local memory parameter, prior to being
76+
/// padded to appropriate alignment. Zero if the argument at the index
77+
/// isn't a local memory argument.
78+
args_size_t OriginalLocalMemSize;
79+
6880
// A struct to keep track of memargs so that we can do dependency analysis
6981
// at urEnqueueKernelLaunch
7082
struct mem_obj_arg {
@@ -93,7 +105,8 @@ struct ur_kernel_handle_t_ {
93105
Indices.resize(Index + 2, Indices.back());
94106
// Ensure enough space for the new argument
95107
ParamSizes.resize(Index + 1);
96-
OffsetPerIndex.resize(Index + 1);
108+
AlignedLocalMemSize.resize(Index + 1);
109+
OriginalLocalMemSize.resize(Index + 1);
97110
}
98111
ParamSizes[Index] = Size;
99112
// calculate the insertion point on the array
@@ -102,28 +115,83 @@ struct ur_kernel_handle_t_ {
102115
// Update the stored value for the argument
103116
std::memcpy(&Storage[InsertPos], Arg, Size);
104117
Indices[Index] = &Storage[InsertPos];
105-
OffsetPerIndex[Index] = LocalSize;
118+
AlignedLocalMemSize[Index] = LocalSize;
106119
}
107120

108-
void addLocalArg(size_t Index, size_t Size) {
109-
size_t LocalOffset = this->getLocalSize();
121+
/// Returns the padded size and offset of a local memory argument.
122+
/// Local memory arguments need to be padded if the alignment for the size
123+
/// doesn't match the current offset into the kernel local data.
124+
/// @param Index Kernel arg index.
125+
/// @param Size User passed size of local parameter.
126+
/// @return Tuple of (Aligned size, Aligned offset into local data).
127+
std::pair<size_t, size_t> calcAlignedLocalArgument(size_t Index,
128+
size_t Size) {
129+
// Store the unpadded size of the local argument
130+
if (Index + 2 > Indices.size()) {
131+
AlignedLocalMemSize.resize(Index + 1);
132+
OriginalLocalMemSize.resize(Index + 1);
133+
}
134+
OriginalLocalMemSize[Index] = Size;
135+
136+
// Calculate the current starting offset into local data
137+
const size_t LocalOffset = std::accumulate(
138+
std::begin(AlignedLocalMemSize),
139+
std::next(std::begin(AlignedLocalMemSize), Index), size_t{0});
110140

111-
// maximum required alignment is the size of the largest vector type
141+
// Maximum required alignment is the size of the largest vector type
112142
const size_t MaxAlignment = sizeof(double) * 16;
113143

114-
// for arguments smaller than the maximum alignment simply align to the
144+
// For arguments smaller than the maximum alignment simply align to the
115145
// size of the argument
116146
const size_t Alignment = std::min(MaxAlignment, Size);
117147

118-
// align the argument
148+
// Align the argument
119149
size_t AlignedLocalOffset = LocalOffset;
120-
size_t Pad = LocalOffset % Alignment;
150+
const size_t Pad = LocalOffset % Alignment;
121151
if (Pad != 0) {
122152
AlignedLocalOffset += Alignment - Pad;
123153
}
124154

155+
const size_t AlignedLocalSize = Size + (AlignedLocalOffset - LocalOffset);
156+
return std::make_pair(AlignedLocalSize, AlignedLocalOffset);
157+
}
158+
159+
void addLocalArg(size_t Index, size_t Size) {
160+
// Get the aligned argument size and offset into local data
161+
size_t AlignedLocalSize, AlignedLocalOffset;
162+
std::tie(AlignedLocalSize, AlignedLocalOffset) =
163+
calcAlignedLocalArgument(Index, Size);
164+
165+
// Store argument details
125166
addArg(Index, sizeof(size_t), (const void *)&(AlignedLocalOffset),
126-
Size + (AlignedLocalOffset - LocalOffset));
167+
AlignedLocalSize);
168+
169+
// For every existing local argument which follows at later argument
170+
// indices, updated the offset and pointer into the kernel local memory.
171+
// Required as padding will need to be recalculated.
172+
const size_t NumArgs = Indices.size() - 1; // Accounts for implicit arg
173+
for (auto SuccIndex = Index + 1; SuccIndex < NumArgs; SuccIndex++) {
174+
const size_t OriginalLocalSize = OriginalLocalMemSize[SuccIndex];
175+
if (OriginalLocalSize == 0) {
176+
// Skip if successor argument isn't a local memory arg
177+
continue;
178+
}
179+
180+
// Recalculate alignment
181+
size_t SuccAlignedLocalSize, SuccAlignedLocalOffset;
182+
std::tie(SuccAlignedLocalSize, SuccAlignedLocalOffset) =
183+
calcAlignedLocalArgument(SuccIndex, OriginalLocalSize);
184+
185+
// Store new local memory size
186+
AlignedLocalMemSize[SuccIndex] = SuccAlignedLocalSize;
187+
188+
// Store new offset into local data
189+
const size_t InsertPos =
190+
std::accumulate(std::begin(ParamSizes),
191+
std::begin(ParamSizes) + SuccIndex, size_t{0});
192+
std::memcpy(&Storage[InsertPos], &SuccAlignedLocalOffset,
193+
sizeof(size_t));
194+
}
127195
}
128196

129197
void addMemObjArg(int Index, ur_mem_handle_t hMem, ur_mem_flags_t Flags) {
@@ -145,15 +213,11 @@ struct ur_kernel_handle_t_ {
145213
std::memcpy(ImplicitOffsetArgs, ImplicitOffset, Size);
146214
}
147215

148-
void clearLocalSize() {
149-
std::fill(std::begin(OffsetPerIndex), std::end(OffsetPerIndex), 0);
150-
}
151-
152216
const args_index_t &getIndices() const noexcept { return Indices; }
153217

154218
uint32_t getLocalSize() const {
155-
return std::accumulate(std::begin(OffsetPerIndex),
156-
std::end(OffsetPerIndex), 0);
219+
return std::accumulate(std::begin(AlignedLocalMemSize),
220+
std::end(AlignedLocalMemSize), 0);
157221
}
158222
} Args;
159223

@@ -240,7 +304,5 @@ struct ur_kernel_handle_t_ {
240304

241305
uint32_t getLocalSize() const noexcept { return Args.getLocalSize(); }
242306

243-
void clearLocalSize() { Args.clearLocalSize(); }
244-
245307
size_t getRegsPerThread() const noexcept { return RegsPerThread; };
246308
};

source/adapters/hip/command_buffer.cpp

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -396,9 +396,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urCommandBufferAppendKernelLaunchExp(
396396
DepsList.data(), DepsList.size(),
397397
&NodeParams));
398398

399-
if (LocalSize != 0)
400-
hKernel->clearLocalSize();
401-
402399
// Get sync point and register the node with it.
403400
auto SyncPoint = hCommandBuffer->addSyncPoint(GraphNode);
404401
if (pSyncPoint) {

source/adapters/hip/enqueue.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -324,8 +324,6 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
324324
ThreadsPerBlock[0], ThreadsPerBlock[1], ThreadsPerBlock[2],
325325
hKernel->getLocalSize(), HIPStream, ArgIndices.data(), nullptr));
326326

327-
hKernel->clearLocalSize();
328-
329327
if (phEvent) {
330328
UR_CHECK_ERROR(RetImplEvent->record());
331329
*phEvent = RetImplEvent.release();

source/adapters/hip/kernel.hpp

Lines changed: 81 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,22 @@ struct ur_kernel_handle_t_ {
5656
using args_t = std::array<char, MAX_PARAM_BYTES>;
5757
using args_size_t = std::vector<size_t>;
5858
using args_index_t = std::vector<void *>;
59+
/// Storage shared by all args which is mem copied into when adding a new
60+
/// argument.
5961
args_t Storage;
62+
/// Aligned size of each parameter, including padding.
6063
args_size_t ParamSizes;
64+
/// Byte offset into /p Storage allocation for each parameter.
6165
args_index_t Indices;
62-
args_size_t OffsetPerIndex;
66+
/// Aligned size in bytes for each local memory parameter after padding has
67+
/// been added. Zero if the argument at the index isn't a local memory
68+
/// argument.
69+
args_size_t AlignedLocalMemSize;
70+
/// Original size in bytes for each local memory parameter, prior to being
71+
/// padded to appropriate alignment. Zero if the argument at the index
72+
/// isn't a local memory argument.
73+
args_size_t OriginalLocalMemSize;
74+
6375
// A struct to keep track of memargs so that we can do dependency analysis
6476
// at urEnqueueKernelLaunch
6577
struct mem_obj_arg {
@@ -88,7 +100,8 @@ struct ur_kernel_handle_t_ {
88100
Indices.resize(Index + 2, Indices.back());
89101
// Ensure enough space for the new argument
90102
ParamSizes.resize(Index + 1);
91-
OffsetPerIndex.resize(Index + 1);
103+
AlignedLocalMemSize.resize(Index + 1);
104+
OriginalLocalMemSize.resize(Index + 1);
92105
}
93106
ParamSizes[Index] = Size;
94107
// calculate the insertion point on the array
@@ -97,28 +110,83 @@ struct ur_kernel_handle_t_ {
97110
// Update the stored value for the argument
98111
std::memcpy(&Storage[InsertPos], Arg, Size);
99112
Indices[Index] = &Storage[InsertPos];
100-
OffsetPerIndex[Index] = LocalSize;
113+
AlignedLocalMemSize[Index] = LocalSize;
101114
}
102115

103-
void addLocalArg(size_t Index, size_t Size) {
104-
size_t LocalOffset = this->getLocalSize();
116+
/// Returns the padded size and offset of a local memory argument.
117+
/// Local memory arguments need to be padded if the alignment for the size
118+
/// doesn't match the current offset into the kernel local data.
119+
/// @param Index Kernel arg index.
120+
/// @param Size User passed size of local parameter.
121+
/// @return Tuple of (Aligned size, Aligned offset into local data).
122+
std::pair<size_t, size_t> calcAlignedLocalArgument(size_t Index,
123+
size_t Size) {
124+
// Store the unpadded size of the local argument
125+
if (Index + 2 > Indices.size()) {
126+
AlignedLocalMemSize.resize(Index + 1);
127+
OriginalLocalMemSize.resize(Index + 1);
128+
}
129+
OriginalLocalMemSize[Index] = Size;
105130

106-
// maximum required alignment is the size of the largest vector type
131+
// Calculate the current starting offset into local data
132+
const size_t LocalOffset = std::accumulate(
133+
std::begin(AlignedLocalMemSize),
134+
std::next(std::begin(AlignedLocalMemSize), Index), size_t{0});
135+
136+
// Maximum required alignment is the size of the largest vector type
107137
const size_t MaxAlignment = sizeof(double) * 16;
108138

109-
// for arguments smaller than the maximum alignment simply align to the
139+
// For arguments smaller than the maximum alignment simply align to the
110140
// size of the argument
111141
const size_t Alignment = std::min(MaxAlignment, Size);
112142

113-
// align the argument
143+
// Align the argument
114144
size_t AlignedLocalOffset = LocalOffset;
115-
size_t Pad = LocalOffset % Alignment;
145+
const size_t Pad = LocalOffset % Alignment;
116146
if (Pad != 0) {
117147
AlignedLocalOffset += Alignment - Pad;
118148
}
119149

120-
addArg(Index, sizeof(size_t), (const void *)&AlignedLocalOffset,
121-
Size + AlignedLocalOffset - LocalOffset);
150+
const size_t AlignedLocalSize = Size + (AlignedLocalOffset - LocalOffset);
151+
return std::make_pair(AlignedLocalSize, AlignedLocalOffset);
152+
}
153+
154+
void addLocalArg(size_t Index, size_t Size) {
155+
// Get the aligned argument size and offset into local data
156+
size_t AlignedLocalSize, AlignedLocalOffset;
157+
std::tie(AlignedLocalSize, AlignedLocalOffset) =
158+
calcAlignedLocalArgument(Index, Size);
159+
160+
// Store argument details
161+
addArg(Index, sizeof(size_t), (const void *)&(AlignedLocalOffset),
162+
AlignedLocalSize);
163+
164+
// For every existing local argument which follows at later argument
165+
// indices, updated the offset and pointer into the kernel local memory.
166+
// Required as padding will need to be recalculated.
167+
const size_t NumArgs = Indices.size() - 1; // Accounts for implicit arg
168+
for (auto SuccIndex = Index + 1; SuccIndex < NumArgs; SuccIndex++) {
169+
const size_t OriginalLocalSize = OriginalLocalMemSize[SuccIndex];
170+
if (OriginalLocalSize == 0) {
171+
// Skip if successor argument isn't a local memory arg
172+
continue;
173+
}
174+
175+
// Recalculate alignment
176+
size_t SuccAlignedLocalSize, SuccAlignedLocalOffset;
177+
std::tie(SuccAlignedLocalSize, SuccAlignedLocalOffset) =
178+
calcAlignedLocalArgument(SuccIndex, OriginalLocalSize);
179+
180+
// Store new local memory size
181+
AlignedLocalMemSize[SuccIndex] = SuccAlignedLocalSize;
182+
183+
// Store new offset into local data
184+
const size_t InsertPos =
185+
std::accumulate(std::begin(ParamSizes),
186+
std::begin(ParamSizes) + SuccIndex, size_t{0});
187+
std::memcpy(&Storage[InsertPos], &SuccAlignedLocalOffset,
188+
sizeof(size_t));
189+
}
122190
}
123191

124192
void addMemObjArg(int Index, ur_mem_handle_t hMem, ur_mem_flags_t Flags) {
@@ -140,15 +208,11 @@ struct ur_kernel_handle_t_ {
140208
std::memcpy(ImplicitOffsetArgs, ImplicitOffset, Size);
141209
}
142210

143-
void clearLocalSize() {
144-
std::fill(std::begin(OffsetPerIndex), std::end(OffsetPerIndex), 0);
145-
}
146-
147211
const args_index_t &getIndices() const noexcept { return Indices; }
148212

149213
uint32_t getLocalSize() const {
150-
return std::accumulate(std::begin(OffsetPerIndex),
151-
std::end(OffsetPerIndex), 0);
214+
return std::accumulate(std::begin(AlignedLocalMemSize),
215+
std::end(AlignedLocalMemSize), 0);
152216
}
153217
} Args;
154218

@@ -220,6 +284,4 @@ struct ur_kernel_handle_t_ {
220284
}
221285

222286
uint32_t getLocalSize() const noexcept { return Args.getLocalSize(); }
223-
224-
void clearLocalSize() { Args.clearLocalSize(); }
225287
};

0 commit comments

Comments
 (0)