Skip to content

Commit

Permalink
Merge pull request #12 from DeveloperPaul123/fix/spurious-unit-test-f…
Browse files Browse the repository at this point in the history
…ailures
  • Loading branch information
DeveloperPaul123 authored May 14, 2022
2 parents 3a10912 + 7e362b1 commit 17b3cfe
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 16 deletions.
37 changes: 22 additions & 15 deletions include/thread_pool/thread_pool.h
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#pragma once

#include <atomic>
#include <concepts>
#include <deque>
#include <functional>
Expand All @@ -24,25 +25,29 @@ namespace dp {
for (std::size_t i = 0; i < number_of_threads; ++i) {
threads_.emplace_back([&, id = i](const std::stop_token &stop_tok) {
do {
// invoke the task
while (auto task = tasks_[id].tasks.pop()) {
try {
std::invoke(std::move(task.value()));
} catch (...) {
}
}
// wait until signaled
tasks_[id].signal.acquire();

// try to steal a task
for (std::size_t j = 1; j < tasks_.size(); ++j) {
const std::size_t index = (id + j) % tasks_.size();
if (auto task = tasks_[index].tasks.steal()) {
std::invoke(std::move(task.value()));
do {
// invoke the task
while (auto task = tasks_[id].tasks.pop()) {
try {
pending_tasks_.fetch_sub(1, std::memory_order_release);
std::invoke(std::move(task.value()));
} catch (...) {
}
}
}

// no tasks, so we wait instead of spinning
tasks_[id].signal.acquire();
// try to steal a task
for (std::size_t j = 1; j < tasks_.size(); ++j) {
const std::size_t index = (id + j) % tasks_.size();
if (auto task = tasks_[index].tasks.steal()) {
pending_tasks_.fetch_sub(1, std::memory_order_release);
std::invoke(std::move(task.value()));
}
}

} while (pending_tasks_.load(std::memory_order_acquire) > 0);
} while (!stop_tok.stop_requested());
});
}
Expand Down Expand Up @@ -127,6 +132,7 @@ namespace dp {
template <typename Function>
void enqueue_task(Function &&f) {
const std::size_t i = count_++ % tasks_.size();
pending_tasks_.fetch_add(1, std::memory_order_relaxed);
tasks_[i].tasks.push(std::forward<Function>(f));
tasks_[i].signal.release();
}
Expand All @@ -139,6 +145,7 @@ namespace dp {
std::vector<std::jthread> threads_;
std::deque<task_item> tasks_;
std::size_t count_{};
std::atomic_int_fast64_t pending_tasks_{};
};

/**
Expand Down
2 changes: 1 addition & 1 deletion test/source/thread_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ TEST_CASE("Support params of different types") {

TEST_CASE("Ensure work completes upon destruction") {
std::atomic<int> counter;
constexpr auto total_tasks = 20;
constexpr auto total_tasks = 30;
{
dp::thread_pool pool(4);
for (auto i = 0; i < total_tasks; i++) {
Expand Down

0 comments on commit 17b3cfe

Please sign in to comment.