Skip to content
This repository was archived by the owner on Mar 28, 2023. It is now read-only.

[LIT][SYCL] Avoid fp64 support requirement #1410

Merged
merged 1 commit into from
Nov 23, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 73 additions & 68 deletions SYCL/Plugin/interop-level-zero-thread.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,96 +178,101 @@ sycl::event operation(sycl::queue q) {
}

// NOTE(review): this span is a rendered diff of main(), not compilable source.
// The page header reports 73 additions and 68 deletions; the +/- markers and
// the original indentation are missing, so most statements appear twice. Within
// each duplicated group the first copy appears to be the pre-change line and
// the second copy its post-change replacement (the post-change side adds the
// try/catch wrapper and the 0.5f literal) — confirm against the real file.
//
// What main() does, as far as this view shows: reads optional overrides from
// argv, creates an in-order profiling-enabled queue, allocates two device
// buffers per kernel, starts a worker thread, runs num_iters x kernel_num
// rounds of (fill kernel, operation(q), update kernel), then frees the
// buffers and stops the worker.
int main(int argc, char *argv[]) {
size_t count = 100;
// Post-change side only: the body of main is wrapped in a try block whose
// handler is at the end of the function.
try {
size_t count = 100;

int size = 0;
int rank = 0;
int size = 0;
int rank = 0;

size_t num_iters = 20;
size_t kernel_num = 3;
size_t num_iters = 20;
size_t kernel_num = 3;

// Optional positional arguments override kernel_num, count and num_iters, in
// that order; each is parsed with atoi.
if (argc > 1)
kernel_num = atoi(argv[1]);
if (argc > 2)
count = atoi(argv[2]);
if (argc > 3)
num_iters = atoi(argv[3]);
if (argc > 1)
kernel_num = atoi(argv[1]);
if (argc > 2)
count = atoi(argv[2]);
if (argc > 3)
num_iters = atoi(argv[3]);

// 4 bytes per element — presumably sizeof(float), since the allocations below
// are cast to float*; TODO confirm.
size_t byte_count = count * 4;
size_t byte_count = count * 4;

sycl::property_list props{sycl::property::queue::in_order{},
sycl::property::queue::enable_profiling{}};
sycl::queue q{props};
sycl::property_list props{sycl::property::queue::in_order{},
sycl::property::queue::enable_profiling{}};
sycl::queue q{props};

// init() is defined outside this view.
init();
init();

// Store allocated mem ptrs to free them later
std::vector<std::pair<float *, float *>> ptrs(kernel_num);
// allocate all the buffers
for (size_t i = 0; i < kernel_num; i++) {
float *weight_buf = (float *)sycl::malloc_device(byte_count, q);
float *weight_allreduce_buf = (float *)sycl::malloc_device(byte_count, q);
ptrs[i] = {weight_buf, weight_allreduce_buf};
}
// Store allocated mem ptrs to free them later
std::vector<std::pair<float *, float *>> ptrs(kernel_num);
// allocate all the buffers
for (size_t i = 0; i < kernel_num; i++) {
float *weight_buf = (float *)sycl::malloc_device(byte_count, q);
float *weight_allreduce_buf = (float *)sycl::malloc_device(byte_count, q);
ptrs[i] = {weight_buf, weight_allreduce_buf};
}

// One (submit, update) event pair per (iteration, kernel) slot.
std::vector<std::tuple<sycl::event, sycl::event>> kernel_events(num_iters *
kernel_num);
std::vector<std::tuple<sycl::event, sycl::event>> kernel_events(num_iters *
kernel_num);

std::vector<sycl::event> barrier_events;
std::vector<sycl::event> barrier_events;

// worker is defined outside this view.
std::thread worker_thread(worker);
std::thread worker_thread(worker);

for (size_t i = 0; i < num_iters; ++i) {
std::cout << "Running iteration " << i << std::endl;
for (size_t i = 0; i < num_iters; ++i) {
std::cout << "Running iteration " << i << std::endl;

for (size_t j = 0; j < kernel_num; j++) {
size_t num = i * kernel_num + j;
float *weight_buf = ptrs[j].first;
float *weight_allreduce_buf = ptrs[j].second;
for (size_t j = 0; j < kernel_num; j++) {
size_t num = i * kernel_num + j;
float *weight_buf = ptrs[j].first;
float *weight_allreduce_buf = ptrs[j].second;

// Step1: FWK kernel submission
sycl::event submit_event;
if (i == 0) {
submit_event = q.submit([&](auto &h) {
h.parallel_for(count, [=](auto id) {
// Initial weight in first iteration
weight_buf[id] = j * (rank + 1);
// Step1: FWK kernel submission
sycl::event submit_event;
if (i == 0) {
submit_event = q.submit([&](auto &h) {
h.parallel_for(count, [=](auto id) {
// Initial weight in first iteration
weight_buf[id] = j * (rank + 1);
});
});
});
} else {
submit_event = q.submit([&](auto &h) {
h.parallel_for(count, [=](auto id) {
// Make weight differ in each iteration
weight_buf[id] = weight_buf[id] + (j * (rank + 1));
} else {
submit_event = q.submit([&](auto &h) {
h.parallel_for(count, [=](auto id) {
// Make weight differ in each iteration
weight_buf[id] = weight_buf[id] + (j * (rank + 1));
});
});
});
}
}

// operation(q) is defined outside this view (only its signature and closing
// brace are visible); the event it returns is collected in barrier_events.
barrier_events.push_back(operation(q));
barrier_events.push_back(operation(q));

// Step3: Weight update
auto update_event = q.submit([&](auto &h) {
h.parallel_for(count, [=](auto id) {
// Update weight in each iteration
weight_buf[id] = weight_allreduce_buf[id] * 0.5;
// Step3: Weight update
auto update_event = q.submit([&](auto &h) {
h.parallel_for(count, [=](auto id) {
// Update weight in each iteration
// The copy above multiplies by the double literal 0.5; this copy uses the
// float literal 0.5f, so the kernel arithmetic stays in float — presumably
// the "avoid fp64 support requirement" change named in the PR title.
weight_buf[id] = weight_allreduce_buf[id] * 0.5f;
});
});
});

kernel_events[num] = {submit_event, update_event};
kernel_events[num] = {submit_event, update_event};
}
q.wait();
}
q.wait();
}

// Make sure there is no exceptions in the queue
q.wait_and_throw();
// Make sure there is no exceptions in the queue
q.wait_and_throw();

for (auto p : ptrs) {
sycl::free(p.first, q);
sycl::free(p.second, q);
}
for (auto p : ptrs) {
sycl::free(p.first, q);
sycl::free(p.second, q);
}

// stop_worker and cv are declared outside this view; the flag is set and the
// condition variable notified before the worker thread is joined.
stop_worker = true;
cv.notify_all();
worker_thread.join();
stop_worker = true;
cv.notify_all();
worker_thread.join();
// Post-change side only: a std::exception escaping the try block is reported
// via what() on stdout and main returns 1.
} catch (std::exception &E) {
std::cout << E.what() << std::endl;
return 1;
}
return 0;
}