Skip to content

Commit 06c05d8

Browse files
committed
Threadpool implementation
1 parent 6e76c98 commit 06c05d8

File tree

7 files changed

+324
-44
lines changed

7 files changed

+324
-44
lines changed

source/adapters/native_cpu/device.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,8 @@ UR_APIEXPORT ur_result_t UR_APICALL urDeviceGetInfo(ur_device_handle_t hDevice,
9898
case UR_DEVICE_INFO_LINKER_AVAILABLE:
9999
return ReturnValue(bool{false});
100100
case UR_DEVICE_INFO_MAX_COMPUTE_UNITS:
101-
return ReturnValue(uint32_t{256});
101+
return ReturnValue(static_cast<uint32_t>(
102+
hDevice->tp.num_threads()));
102103
case UR_DEVICE_INFO_PARTITION_MAX_SUB_DEVICES:
103104
return ReturnValue(uint32_t{0});
104105
case UR_DEVICE_INFO_SUPPORTED_PARTITIONS:

source/adapters/native_cpu/device.hpp

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,17 @@
1111
#pragma once
1212

1313
#include <ur/ur.hpp>
14+
#include "threadpool.hpp"
1415

1516
struct ur_device_handle_t_ {
16-
ur_device_handle_t_(ur_platform_handle_t ArgPlt) : Platform(ArgPlt) {}
17+
native_cpu::threadpool_t tp;
18+
ur_device_handle_t_(ur_platform_handle_t ArgPlt) : Platform(ArgPlt) {
19+
tp.start();
20+
}
21+
22+
~ur_device_handle_t_() {
23+
tp.stop();
24+
}
1725

1826
ur_platform_handle_t Platform;
1927
};

source/adapters/native_cpu/enqueue.cpp

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
//===----------- enqueue.cpp - NATIVE CPU Adapter -------------------------===//
22
//
3-
// Copyright (C) 2023 Intel Corporation
4-
//
5-
// Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM
6-
// Exceptions. See LICENSE.TXT
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
75
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
86
//
97
//===----------------------------------------------------------------------===//
@@ -15,6 +13,8 @@
1513
#include "common.hpp"
1614
#include "kernel.hpp"
1715
#include "memory.hpp"
16+
#include "threadpool.hpp"
17+
#include "queue.hpp"
1818

1919
namespace native_cpu {
2020
struct NDRDescT {
@@ -61,14 +61,27 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
6161

6262
// TODO: add proper error checking
6363
// TODO: add proper event dep management
64-
native_cpu::NDRDescT ndr(workDim, pGlobalWorkOffset, pGlobalWorkSize,
65-
pLocalWorkSize);
66-
hKernel->handleLocalArgs();
64+
native_cpu::NDRDescT ndr(workDim, pGlobalWorkOffset, pGlobalWorkSize, pLocalWorkSize);
65+
auto& tp = hQueue->device->tp;
66+
const size_t numParallelThreads = tp.num_threads();
67+
hKernel->updateMemPool(numParallelThreads);
68+
std::vector<std::future<void>> futures;
69+
auto numWG0 = ndr.GlobalSize[0] / ndr.LocalSize[0];
70+
auto numWG1 = ndr.GlobalSize[1] / ndr.LocalSize[1];
71+
auto numWG2 = ndr.GlobalSize[2] / ndr.LocalSize[2];
72+
bool isLocalSizeOne =
73+
ndr.LocalSize[0] == 1 && ndr.LocalSize[1] == 1 && ndr.LocalSize[2] == 1;
74+
6775

6876
native_cpu::state state(ndr.GlobalSize[0], ndr.GlobalSize[1],
6977
ndr.GlobalSize[2], ndr.LocalSize[0], ndr.LocalSize[1],
7078
ndr.LocalSize[2], ndr.GlobalOffset[0],
7179
ndr.GlobalOffset[1], ndr.GlobalOffset[2]);
80+
if (isLocalSizeOne) {
81+
// If the local size is one, we make the assumption that we are running a
82+
// parallel_for over a sycl::range Todo: we could add compiler checks and
83+
// kernel properties for this (e.g. check that no barriers are called, no
84+
// local memory args).
7285

7386
auto numWG0 = ndr.GlobalSize[0] / ndr.LocalSize[0];
7487
auto numWG1 = ndr.GlobalSize[1] / ndr.LocalSize[1];
@@ -92,6 +105,9 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueKernelLaunch(
92105
}
93106
}
94107
}
108+
109+
for (auto &f : futures)
110+
f.get();
95111
// TODO: we should avoid calling clear here by avoiding using push_back
96112
// in setKernelArgs.
97113
hKernel->_args.clear();
@@ -537,3 +553,4 @@ UR_APIEXPORT ur_result_t UR_APICALL urEnqueueWriteHostPipe(
537553

538554
DIE_NO_IMPLEMENTATION;
539555
}
556+

source/adapters/native_cpu/kernel.hpp

Lines changed: 33 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
//===--------------- kernel.hpp - Native CPU Adapter ----------------------===//
22
//
3-
// Copyright (C) 2023 Intel Corporation
4-
//
5-
// Part of the Unified-Runtime Project, under the Apache License v2.0 with LLVM
6-
// Exceptions. See LICENSE.TXT
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
75
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
86
//
97
//===----------------------------------------------------------------------===//
@@ -42,50 +40,53 @@ struct ur_kernel_handle_t_ : RefCounted {
4240
ur_kernel_handle_t_(const char *name, nativecpu_task_t subhandler)
4341
: _name{name}, _subhandler{std::move(subhandler)} {}
4442

45-
const char *_name;
46-
nativecpu_task_t _subhandler;
47-
std::vector<native_cpu::NativeCPUArgDesc> _args;
48-
std::vector<local_arg_info_t> _localArgInfo;
49-
50-
// To be called before enqueing the kernel.
51-
void handleLocalArgs() {
52-
updateMemPool();
53-
size_t offset = 0;
54-
for (auto &entry : _localArgInfo) {
55-
_args[entry.argIndex].MPtr =
56-
reinterpret_cast<char *>(_localMemPool) + offset;
57-
// update offset in the memory pool
58-
// Todo: update this offset computation when we have work-group
59-
// level parallelism.
60-
offset += entry.argSize;
61-
}
43+
ur_kernel_handle_t_(const ur_kernel_handle_t_& other) : _name(other._name), _subhandler(other._subhandler),
44+
_args(other._args), _localArgInfo(other._localArgInfo), _localMemPool(other._localMemPool), _localMemPoolSize(other._localMemPoolSize) {
45+
incrementReferenceCount();
6246
}
6347

6448
~ur_kernel_handle_t_() {
65-
if (_localMemPool) {
49+
decrementReferenceCount();
50+
if (_refCount == 0) {
6651
free(_localMemPool);
6752
}
53+
6854
}
6955

70-
private:
71-
void updateMemPool() {
56+
const char *_name;
57+
nativecpu_task_t _subhandler;
58+
std::vector<native_cpu::NativeCPUArgDesc> _args;
59+
std::vector<local_arg_info_t> _localArgInfo;
60+
61+
// To be called before enqueing the kernel.
62+
void updateMemPool(size_t numParallelThreads) {
7263
// compute requested size.
73-
// Todo: currently we execute only one work-group at a time, so for each
74-
// local arg we can allocate just 1 * argSize local arg. When we implement
75-
// work-group level parallelism we should allocate N * argSize where N is
76-
// the number of work groups being executed in parallel (e.g. number of
77-
// threads in the thread pool).
7864
size_t reqSize = 0;
7965
for (auto &entry : _localArgInfo) {
80-
reqSize += entry.argSize;
66+
reqSize += entry.argSize * numParallelThreads;
8167
}
8268
if (reqSize == 0 || reqSize == _localMemPoolSize) {
8369
return;
8470
}
8571
// realloc handles nullptr case
86-
_localMemPool = realloc(_localMemPool, reqSize);
72+
_localMemPool = (char*)realloc(_localMemPool, reqSize);
8773
_localMemPoolSize = reqSize;
8874
}
89-
void *_localMemPool = nullptr;
75+
76+
// To be called before executing a work group
77+
void handleLocalArgs(size_t numParallelThread, size_t threadId) {
78+
// For each local argument we have size*numthreads
79+
size_t offset = 0;
80+
for (auto &entry : _localArgInfo) {
81+
_args[entry.argIndex].MPtr =
82+
_localMemPool + offset + (entry.argSize * threadId);
83+
// update offset in the memory pool
84+
offset += entry.argSize * numParallelThread;
85+
}
86+
}
87+
88+
private:
89+
char* _localMemPool = nullptr;
9090
size_t _localMemPoolSize = 0;
9191
};
92+

source/adapters/native_cpu/queue.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,10 @@ UR_APIEXPORT ur_result_t UR_APICALL urQueueCreate(
3535
std::ignore = hDevice;
3636
std::ignore = pProperties;
3737

38-
auto Queue = new ur_queue_handle_t_();
38+
auto Queue = new ur_queue_handle_t_(hDevice);
3939
*phQueue = Queue;
4040

41-
CONTINUE_NO_IMPLEMENTATION;
41+
return UR_RESULT_SUCCESS;
4242
}
4343

4444
UR_APIEXPORT ur_result_t UR_APICALL urQueueRetain(ur_queue_handle_t hQueue) {

source/adapters/native_cpu/queue.hpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,5 +9,11 @@
99
//===----------------------------------------------------------------------===//
1010
#pragma once
1111
#include "common.hpp"
12+
#include "device.hpp"
1213

13-
struct ur_queue_handle_t_ : RefCounted {};
14+
struct ur_queue_handle_t_ : RefCounted {
15+
ur_device_handle_t_ *device;
16+
17+
ur_queue_handle_t_(ur_device_handle_t_ *device) : device(device) {}
18+
19+
};

0 commit comments

Comments
 (0)