Skip to content

Commit 6f9d5c5

Browse files
committed
Fix incorrect outputs and improve performance of commonMemSetLargePattern
Change the implementation of commonMemSetLargePattern to use the largest pattern word size supported by the backend into which the pattern can be divided. That is, use 4-byte words if the pattern size is a multiple of 4, 2-byte words for even sizes and 1-byte words for odd sizes. Keep the idea of filling the entire destination region with the first word, and only start strided fill from the second, but implement it correctly. The previous implementation produced incorrect results for any pattern size which wasn't a multiple of 4. For HIP, the strided fill remains to be always in 1-byte increments because HIP API doesn't provide strided multi-byte memset functions like CUDA does. For CUDA, both the initial memset and the strided ones use the largest possible word size. Add a new optimisation skipping the strided fills completely if the pattern is equal to the first word repeated throughout. This is most commonly the case for a pattern of all zeros, but other cases are possible. This optimisation is implemented in both CUDA and HIP adapters.
1 parent 3a5b23c commit 6f9d5c5

File tree

2 files changed

+123
-44
lines changed

2 files changed

+123
-44
lines changed

source/adapters/cuda/enqueue.cpp

Lines changed: 61 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -961,35 +961,71 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopyRect(
961961

962962
// CUDA has no memset functions that allow setting values more than 4 bytes. UR
963963
// API lets you pass an arbitrary "pattern" to the buffer fill, which can be
964-
// more than 4 bytes. We must break up the pattern into 1 byte values, and set
965-
// the buffer using multiple strided calls. The first 4 patterns are set using
966-
// cuMemsetD32Async then all subsequent 1 byte patterns are set using
967-
// cuMemset2DAsync which is called for each pattern.
964+
// more than 4 bytes. We must break up the pattern into 1, 2 or 4-byte values
965+
// and set the buffer using multiple strided calls.
968966
ur_result_t commonMemSetLargePattern(CUstream Stream, uint32_t PatternSize,
969967
size_t Size, const void *pPattern,
970968
CUdeviceptr Ptr) {
971-
// Calculate the number of patterns, stride, number of times the pattern
972-
// needs to be applied, and the number of times the first 32 bit pattern
973-
// needs to be applied.
974-
auto NumberOfSteps = PatternSize / sizeof(uint8_t);
975-
auto Pitch = NumberOfSteps * sizeof(uint8_t);
976-
auto Height = Size / NumberOfSteps;
977-
auto Count32 = Size / sizeof(uint32_t);
978-
979-
// Get 4-byte chunk of the pattern and call cuMemsetD32Async
980-
auto Value = *(static_cast<const uint32_t *>(pPattern));
981-
UR_CHECK_ERROR(cuMemsetD32Async(Ptr, Value, Count32, Stream));
982-
for (auto step = 4u; step < NumberOfSteps; ++step) {
983-
// take 1 byte of the pattern
984-
Value = *(static_cast<const uint8_t *>(pPattern) + step);
985-
986-
// offset the pointer to the part of the buffer we want to write to
987-
auto OffsetPtr = Ptr + (step * sizeof(uint8_t));
988-
989-
// set all of the pattern chunks
990-
UR_CHECK_ERROR(cuMemsetD2D8Async(OffsetPtr, Pitch, Value, sizeof(uint8_t),
991-
Height, Stream));
969+
// Find the largest supported word size into which the pattern can be divided
970+
auto BackendWordSize = PatternSize % 4u == 0u ? 4u
971+
: PatternSize % 2u == 0u ? 2u
972+
: 1u;
973+
974+
// Calculate the number of words in the pattern, the stride, and the number of
975+
// times the pattern needs to be applied
976+
auto NumberOfSteps = PatternSize / BackendWordSize;
977+
auto Pitch = NumberOfSteps * BackendWordSize;
978+
auto Height = Size / PatternSize;
979+
980+
// Same implementation works for any pattern word type (uint8_t, uint16_t,
981+
// uint32_t)
982+
auto memsetImpl = [BackendWordSize, NumberOfSteps, Pitch, Height, Size, Ptr,
983+
&Stream](const auto *pPatternWords,
984+
auto &&continuousMemset, auto &&stridedMemset) {
985+
// If the pattern is 1 word or the first word is repeated throughout, a fast
986+
// continuous fill can be used without the need for slower strided fills
987+
bool UseOnlyFirstValue{true};
988+
for (auto Step{1u}; (Step < NumberOfSteps) && UseOnlyFirstValue; ++Step) {
989+
if (*(pPatternWords + Step) != *pPatternWords) {
990+
UseOnlyFirstValue = false;
991+
}
992+
}
993+
auto OptimizedNumberOfSteps{UseOnlyFirstValue ? 1u : NumberOfSteps};
994+
995+
// Fill the pattern in steps of BackendWordSize bytes. Use a continuous
996+
// fill in the first step because it's faster than a strided fill. Then,
997+
// overwrite the other values in subsequent steps.
998+
for (auto Step{0u}; Step < OptimizedNumberOfSteps; ++Step) {
999+
if (Step == 0) {
1000+
UR_CHECK_ERROR(continuousMemset(Ptr, *(pPatternWords),
1001+
Size / BackendWordSize, Stream));
1002+
} else {
1003+
UR_CHECK_ERROR(stridedMemset(Ptr + Step * BackendWordSize, Pitch,
1004+
*(pPatternWords + Step), 1u, Height,
1005+
Stream));
1006+
}
1007+
}
1008+
};
1009+
1010+
// Apply the implementation to the chosen pattern word type
1011+
switch (BackendWordSize) {
1012+
case 4u: {
1013+
memsetImpl(static_cast<const uint32_t *>(pPattern), cuMemsetD32Async,
1014+
cuMemsetD2D32Async);
1015+
break;
1016+
}
1017+
case 2u: {
1018+
memsetImpl(static_cast<const uint16_t *>(pPattern), cuMemsetD16Async,
1019+
cuMemsetD2D16Async);
1020+
break;
1021+
}
1022+
default: {
1023+
memsetImpl(static_cast<const uint8_t *>(pPattern), cuMemsetD8Async,
1024+
cuMemsetD2D8Async);
1025+
break;
9921026
}
1027+
}
1028+
9931029
return UR_RESULT_SUCCESS;
9941030
}
9951031

source/adapters/hip/enqueue.cpp

Lines changed: 62 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -712,25 +712,22 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueMemBufferCopyRect(
712712

713713
static inline void memsetRemainPattern(hipStream_t Stream, uint32_t PatternSize,
714714
size_t Size, const void *pPattern,
715-
hipDeviceptr_t Ptr) {
715+
hipDeviceptr_t Ptr,
716+
uint32_t StartOffset) {
717+
// Calculate the number of times the pattern needs to be applied
718+
auto Height = Size / PatternSize;
716719

717-
// Calculate the number of patterns, stride and the number of times the
718-
// pattern needs to be applied.
719-
auto NumberOfSteps = PatternSize / sizeof(uint8_t);
720-
auto Pitch = NumberOfSteps * sizeof(uint8_t);
721-
auto Height = Size / NumberOfSteps;
722-
723-
for (auto step = 4u; step < NumberOfSteps; ++step) {
720+
for (auto step = StartOffset; step < PatternSize; ++step) {
724721
// take 1 byte of the pattern
725722
auto Value = *(static_cast<const uint8_t *>(pPattern) + step);
726723

727724
// offset the pointer to the part of the buffer we want to write to
728-
auto OffsetPtr = reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(Ptr) +
729-
(step * sizeof(uint8_t)));
725+
auto OffsetPtr =
726+
reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(Ptr) + step);
730727

731728
// set all of the pattern chunks
732-
UR_CHECK_ERROR(hipMemset2DAsync(OffsetPtr, Pitch, Value, sizeof(uint8_t),
733-
Height, Stream));
729+
UR_CHECK_ERROR(
730+
hipMemset2DAsync(OffsetPtr, PatternSize, Value, 1u, Height, Stream));
734731
}
735732
}
736733

@@ -743,11 +740,55 @@ static inline void memsetRemainPattern(hipStream_t Stream, uint32_t PatternSize,
743740
ur_result_t commonMemSetLargePattern(hipStream_t Stream, uint32_t PatternSize,
744741
size_t Size, const void *pPattern,
745742
hipDeviceptr_t Ptr) {
743+
// Find the largest supported word size into which the pattern can be divided
744+
auto BackendWordSize = PatternSize % 4u == 0u ? 4u
745+
: PatternSize % 2u == 0u ? 2u
746+
: 1u;
747+
748+
// Calculate the number of patterns
749+
auto NumberOfSteps = PatternSize / BackendWordSize;
750+
751+
// If the pattern is 1 word or the first word is repeated throughout, a fast
752+
// continuous fill can be used without the need for slower strided fills
753+
bool UseOnlyFirstValue{true};
754+
auto checkIfFirstWordRepeats = [&UseOnlyFirstValue,
755+
NumberOfSteps](const auto *pPatternWords) {
756+
for (auto Step{1u}; (Step < NumberOfSteps) && UseOnlyFirstValue; ++Step) {
757+
if (*(pPatternWords + Step) != *pPatternWords) {
758+
UseOnlyFirstValue = false;
759+
}
760+
}
761+
};
746762

747-
// Get 4-byte chunk of the pattern and call hipMemsetD32Async
748-
auto Count32 = Size / sizeof(uint32_t);
749-
auto Value = *(static_cast<const uint32_t *>(pPattern));
750-
UR_CHECK_ERROR(hipMemsetD32Async(Ptr, Value, Count32, Stream));
763+
// Use a continuous fill for the first word in the pattern because it's faster
764+
// than a strided fill. Then, overwrite the other values in subsequent steps.
765+
switch (BackendWordSize) {
766+
case 4u: {
767+
auto *pPatternWords = static_cast<const uint32_t *>(pPattern);
768+
checkIfFirstWordRepeats(pPatternWords);
769+
UR_CHECK_ERROR(
770+
hipMemsetD32Async(Ptr, *pPatternWords, Size / BackendWordSize, Stream));
771+
break;
772+
}
773+
case 2u: {
774+
auto *pPatternWords = static_cast<const uint16_t *>(pPattern);
775+
checkIfFirstWordRepeats(pPatternWords);
776+
UR_CHECK_ERROR(
777+
hipMemsetD16Async(Ptr, *pPatternWords, Size / BackendWordSize, Stream));
778+
break;
779+
}
780+
default: {
781+
auto *pPatternWords = static_cast<const uint8_t *>(pPattern);
782+
checkIfFirstWordRepeats(pPatternWords);
783+
UR_CHECK_ERROR(
784+
hipMemsetD8Async(Ptr, *pPatternWords, Size / BackendWordSize, Stream));
785+
break;
786+
}
787+
}
788+
789+
if (UseOnlyFirstValue) {
790+
return UR_RESULT_SUCCESS;
791+
}
751792

752793
// There is a bug in ROCm prior to 6.0.0 version which causes hipMemset2D
753794
// to behave incorrectly when acting on host pinned memory.
@@ -761,7 +802,7 @@ ur_result_t commonMemSetLargePattern(hipStream_t Stream, uint32_t PatternSize,
761802
// we need to check that isManaged attribute is false.
762803
if (ptrAttribs.hostPointer && !ptrAttribs.isManaged) {
763804
const auto NumOfCopySteps = Size / PatternSize;
764-
const auto Offset = sizeof(uint32_t);
805+
const auto Offset = BackendWordSize;
765806
const auto LeftPatternSize = PatternSize - Offset;
766807
const auto OffsetPatternPtr = reinterpret_cast<const void *>(
767808
reinterpret_cast<const uint8_t *>(pPattern) + Offset);
@@ -776,10 +817,12 @@ ur_result_t commonMemSetLargePattern(hipStream_t Stream, uint32_t PatternSize,
776817
Stream));
777818
}
778819
} else {
779-
memsetRemainPattern(Stream, PatternSize, Size, pPattern, Ptr);
820+
memsetRemainPattern(Stream, PatternSize, Size, pPattern, Ptr,
821+
BackendWordSize);
780822
}
781823
#else
782-
memsetRemainPattern(Stream, PatternSize, Size, pPattern, Ptr);
824+
memsetRemainPattern(Stream, PatternSize, Size, pPattern, Ptr,
825+
BackendWordSize);
783826
#endif
784827
return UR_RESULT_SUCCESS;
785828
}

0 commit comments

Comments
 (0)