c++

Thread Pool

Posted by keming on August 16, 2021

Implementation of a thread pool (线程池的实现)

  • Header file `b.h` (头文件)
#ifndef B_H
#define B_H

#include <atomic>
#include <condition_variable>
#include <functional>
#include <mutex>
#include <queue>
#include <thread>
#include <vector>

using namespace std;

// Priority of a task submitted to ThreadPool; a larger enumerator runs first.
// (The misspelled name is kept for source compatibility with existing callers.)
enum PriortyLevel { kLevelLow, kLevelMid, kLevelHigh };

// Fixed-size pool of worker threads that execute submitted tasks in priority order.
//
// Usage: call Start() once, submit work with AddTask(), and call Stop() before the
// pool is destroyed -- there is no destructor that joins the workers, and the
// thread objects in threads_ are only released by Stop(). Tasks still queued when
// Stop() runs may never execute.
class ThreadPool {
  public:
    using Task = std::function<void()>;
    using TaskPair = std::pair<Task, PriortyLevel>;

    // "Less than" for std::priority_queue (a max-heap): the pair with the highest
    // PriortyLevel ends up on top. Equal priorities have no guaranteed order.
    struct TaskPairCmp {
        bool operator()(const TaskPair &tp1, const TaskPair &tp2) const { return tp1.second < tp2.second; }
    };

    // Spawns `thread_num` (must be > 0) worker threads, each running Loop().
    void Start(int thread_num);
    // Signals the workers to exit, joins them and deletes the thread objects.
    void Stop();

    // Submits `task` with priority `level` to the queue.
    void AddTask(const Task &task, PriortyLevel level);
    // Worker body executed by every pool thread (the code that actually runs tasks).
    void Loop();

  private:
    // Blocks until a task is queued or the pool stops, then pops the top task.
    Task GetOneTask();

    std::vector<std::thread *> threads_;  // owned; joined and deleted in Stop()
    // Atomic because Loop() reads the flag without holding `m`; initialised to
    // false so a pool that was never started reads as stopped, not indeterminate.
    std::atomic<bool> is_running_{false};
    std::priority_queue<TaskPair, std::vector<TaskPair>, TaskPairCmp> tasks_;
    std::mutex m;                // guards tasks_ and transitions of is_running_
    std::condition_variable cv;  // notified when a task is queued or the pool stops
};

#endif
  • Implementation (实现)
#include "b.h"
#include "assert.h"

void ThreadPool::Start(int thread_num) {
    assert(thread_num > 0);
    for (int i = 0; i < thread_num; i++) {
        threads_.emplace_back(new thread(bind(&ThreadPool::Loop, this)));   // avoid static
    }
}

// Stops the pool: clears the running flag, wakes every worker, joins them and
// deletes the thread objects. Safe to call again (threads_ is then empty).
// Must not be called from a task running on this pool (a thread cannot join itself).
void ThreadPool::Stop() {
    {
        std::lock_guard<std::mutex> lock(m);
        is_running_ = false;
    }
    // The lock is released before notifying and joining: a worker blocked in
    // cv.wait() has to re-acquire `m` to return, so joining while holding `m`
    // would deadlock.
    cv.notify_all();

    for (std::thread *worker : threads_) {
        if (worker->joinable()) {
            worker->join();
        }
        delete worker;
    }
    threads_.clear();
}

// Blocks until a task is queued or the pool is stopped, then pops and returns the
// highest-priority task. Returns an empty Task when the pool is stopped and the
// queue is empty, so callers must check the result before invoking it.
ThreadPool::Task ThreadPool::GetOneTask() {
    std::unique_lock<std::mutex> lock(m);
    // The predicate overload re-checks the condition after spurious wakeups.
    cv.wait(lock, [this] { return !is_running_ || !tasks_.empty(); });
    if (tasks_.empty()) {
        // Stopped with nothing queued: calling top()/pop() here would be UB.
        return Task();
    }
    Task task = tasks_.top().first;  // top() is const, so the task is copied out
    tasks_.pop();
    return task;
}

void ThreadPool::Loop() {
    while (is_running_) {
        auto task = GetOneTask();
        task();
    }
}

// Queues `task` with priority `level` and wakes one waiting worker.
// Empty tasks are ignored: a worker invoking an empty std::function would throw
// std::bad_function_call and terminate the process.
void ThreadPool::AddTask(const ThreadPool::Task &task, PriortyLevel level) {
    if (!task) {
        return;
    }
    {
        std::lock_guard<std::mutex> lock(m);
        tasks_.emplace(task, level);
    }
    // Notify after releasing the lock so the woken worker does not immediately
    // block on `m`.
    cv.notify_one();
}