Skip to content

Commit d87a716

Browse files
Added support for vector types
1 parent 7e9069f commit d87a716

File tree

1 file changed

+23
-4
lines changed

1 file changed

+23
-4
lines changed

lib/gc/Transforms/GPU/GpuToGpuOcl.cpp

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -415,11 +415,30 @@ struct ConvertLaunch final : ConvertOpPattern<gpu::LaunchFuncOp> {
415415
blockSize.emplace_back(gpuLaunch.getBlockSizeY());
416416
blockSize.emplace_back(gpuLaunch.getBlockSizeZ());
417417

418-
for (auto arg : adaptor.getKernelOperands()) {
418+
for (auto arg : gpuLaunch.getKernelOperands()) {
419419
auto type = arg.getType();
420-
// Assuming, that the value is either an integer or a float or a pointer.
421-
// In the latter case, the size is 0 bytes.
422-
auto size = type.isIntOrFloat() ? type.getIntOrFloatBitWidth() / 8 : 0;
420+
size_t size;
421+
if (isa<MemRefType>(type)) {
422+
size = 0; // A special case for pointers
423+
} else if (type.isIndex()) {
424+
size = helper.idxType.getIntOrFloatBitWidth() / 8;
425+
} else if (type.isIntOrFloat()) {
426+
size = type.getIntOrFloatBitWidth() / 8;
427+
} else if (auto vectorType = dyn_cast<VectorType>(type)) {
428+
type = vectorType.getElementType();
429+
if (type.isIntOrFloat()) {
430+
size = type.getIntOrFloatBitWidth();
431+
} else if (type.isIndex()) {
432+
size = helper.idxType.getIntOrFloatBitWidth();
433+
} else {
434+
llvm::errs() << "Unsupported vector element type: " << type << "\n";
435+
return false;
436+
}
437+
size *= vectorType.getNumElements() / 8;
438+
} else {
439+
llvm::errs() << "Unsupported type: " << type << "\n";
440+
return false;
441+
}
423442
argSize.emplace_back(helper.idxConstant(rewriter, loc, size));
424443
}
425444

0 commit comments

Comments
 (0)