Skip to content

Commit 65aab9e

Browse files
authored
[mlir][gpu] Generate multiple rank-specializations for tensor map cre… (#74082)
…ation The previous code was technically incorrect in that the type indicated that the memref only has 1 dimension, while the code below was happily dereferencing the size array out of bounds. Now, if the compiler doesn't get too smart about optimizations, this code *might even work*. But, if the compiler realizes that the array has 1 element it might start doing silly things. This generates a specialization for each supported rank, making sure we don't do any UB.
1 parent 9363658 commit 65aab9e

File tree

1 file changed

+40
-3
lines changed

1 file changed

+40
-3
lines changed

mlir/lib/ExecutionEngine/CudaRuntimeWrappers.cpp

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -423,9 +423,24 @@ extern "C" MLIR_CUDA_WRAPPERS_EXPORT void mgpuTensorMapEncodeTiled(
423423
elementStrides[4], interleave, swizzle, l2Promotion, oobFill);
424424
}
425425

426+
namespace {
427+
428+
template <int rank>
429+
void mgpuGetMemRefDataAndShape(void *raw_descriptor, char **addr,
430+
uint64_t *globalDim) {
431+
auto descriptor =
432+
reinterpret_cast<StridedMemRefType<char, rank> *>(raw_descriptor);
433+
*addr = descriptor->data;
434+
for (int i = 0; i < rank; ++i) {
435+
globalDim[i] = static_cast<uint64_t>(descriptor->sizes[rank - i - 1]);
436+
}
437+
}
438+
439+
} // namespace
440+
426441
extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *mgpuTensorMapEncodeTiledMemref(
427442
int64_t tensorRank, // Dimensionality of tensor
428-
StridedMemRefType<char, 1> *descriptor, // Starting address
443+
void *ranked_descriptor, // Ranked MemRef descriptor
429444
const CUtensorMapDataType tensorDataType, // Stride size (in bytes)
430445
CUtensorMapInterleave interleave, // Type of interleaved layout
431446
CUtensorMapSwizzle swizzle, // Bank swizzling pattern
@@ -435,17 +450,39 @@ extern "C" MLIR_CUDA_WRAPPERS_EXPORT void *mgpuTensorMapEncodeTiledMemref(
435450
) {
436451
CUtensorMap tensorMap;
437452

438-
auto *globalAddress = descriptor->data;
439453
uint32_t boxDim[5] = {1, 1, 1, 1, 1}, elementStrides[5] = {1, 1, 1, 1, 1};
440454
uint64_t globalDim[5] = {1, 1, 1, 1, 1}, globalStrides[5] = {0};
441455
uint32_t tensorRank32 = uint32_t(tensorRank);
442456

457+
char *globalAddress = nullptr;
458+
switch (tensorRank) {
459+
case 1:
460+
mgpuGetMemRefDataAndShape<1>(ranked_descriptor, &globalAddress, globalDim);
461+
break;
462+
case 2:
463+
mgpuGetMemRefDataAndShape<2>(ranked_descriptor, &globalAddress, globalDim);
464+
break;
465+
case 3:
466+
mgpuGetMemRefDataAndShape<3>(ranked_descriptor, &globalAddress, globalDim);
467+
break;
468+
case 4:
469+
mgpuGetMemRefDataAndShape<4>(ranked_descriptor, &globalAddress, globalDim);
470+
break;
471+
case 5:
472+
mgpuGetMemRefDataAndShape<5>(ranked_descriptor, &globalAddress, globalDim);
473+
break;
474+
default:
475+
fprintf(
476+
stderr,
477+
"'mgpuTensorMapEncodeTiledMemref' failed with 'rank is too high'\n");
478+
return NULL;
479+
}
480+
443481
static const int elementSizeInBytes[] = {1, 2, 4, 4, 8, 8, 2,
444482
4, 8, 2, 4, 4, 4};
445483
for (int64_t r = 0; r < tensorRank; ++r) {
446484
elementStrides[r] = uint32_t(1);
447485
boxDim[r] = static_cast<uint32_t>(inputBoxDims[tensorRank - r - 1]);
448-
globalDim[r] = static_cast<uint64_t>(descriptor->sizes[tensorRank - r - 1]);
449486
}
450487

451488
globalStrides[0] = globalDim[0] * elementSizeInBytes[tensorDataType];

0 commit comments

Comments
 (0)