无锁队列实现(Michael Scott),伪代码与c++实现
一、Michael & Scott 原版伪代码实现
structure pointer_t {ptr: pointer to node_t, count: unsigned integer}
structure node_t {value: data type, next: pointer_t}
structure queue_t {Head: pointer_t, Tail: pointer_t}

initialize(Q: pointer to queue_t)
    node = new_node()                  // Allocate a free node
    node->next.ptr = NULL              // Make it the only node in the linked list
    Q->Head.ptr = Q->Tail.ptr = node   // Both Head and Tail point to it

enqueue(Q: pointer to queue_t, value: data type)
 E1:  node = new_node()                // Allocate a new node from the free list
 E2:  node->value = value              // Copy enqueued value into node
 E3:  node->next.ptr = NULL            // Set next pointer of node to NULL
 E4:  loop                             // Keep trying until Enqueue is done
 E5:     tail = Q->Tail                // Read Tail.ptr and Tail.count together
 E6:     next = tail.ptr->next         // Read next ptr and count fields together
 E7:     if tail == Q->Tail            // Are tail and next consistent?
            // Was Tail pointing to the last node?
 E8:        if next.ptr == NULL
               // Try to link node at the end of the linked list
 E9:           if CAS(&tail.ptr->next, next, <node, next.count+1>)
E10:              break                // Enqueue is done. Exit loop
E11:           endif
E12:        else                       // Tail was not pointing to the last node
               // Try to swing Tail to the next node
E13:           CAS(&Q->Tail, tail, <next.ptr, tail.count+1>)
E14:        endif
E15:     endif
E16:  endloop
      // Enqueue is done. Try to swing Tail to the inserted node
E17:  CAS(&Q->Tail, tail, <node, tail.count+1>)

dequeue(Q: pointer to queue_t, pvalue: pointer to data type): boolean
 D1:  loop                             // Keep trying until Dequeue is done
 D2:     head = Q->Head                // Read Head
 D3:     tail = Q->Tail                // Read Tail
 D4:     next = head.ptr->next         // Read Head.ptr->next
 D5:     if head == Q->Head            // Are head, tail, and next consistent?
 D6:        if head.ptr == tail.ptr    // Is queue empty or Tail falling behind?
 D7:           if next.ptr == NULL     // Is queue empty?
 D8:              return FALSE         // Queue is empty, couldn't dequeue
 D9:           endif
               // Tail is falling behind. Try to advance it
D10:           CAS(&Q->Tail, tail, <next.ptr, tail.count+1>)
D11:        else                       // No need to deal with Tail
               // Read value before CAS
               // Otherwise, another dequeue might free the next node
D12:           *pvalue = next.ptr->value
               // Try to swing Head to the next node
D13:           if CAS(&Q->Head, head, <next.ptr, head.count+1>)
D14:              break                // Dequeue is done. Exit loop
D15:           endif
D16:        endif
D17:     endif
D18:  endloop
D19:  free(head.ptr)                   // It is safe now to free the old node
D20:  return TRUE                      // Queue was not empty, dequeue succeeded
二、C++实现
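C++ 的实现直接对照上述伪代码逐步实现即可,其中用到的 atomic 操作也可以参考我之前写的博客。需要注意:伪代码依靠带计数(count)的指针和空闲链表(free list)来避免 ABA 问题和"释放后使用"。如果 C++ 实现中的指针不带计数,又直接 delete 出队的节点,就必须另外提供安全的内存回收机制,例如 hazard pointer、引用计数或延迟回收。否则并发出队时,其他线程可能访问到已经被释放的节点。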
/// Unbounded multi-producer / multi-consumer FIFO queue after Michael & Scott.
///
/// The list always starts with a dummy node referenced by `head`; the first
/// real element lives in `head->next`. `enqueue` appends by CAS on the last
/// node's `next` and then tries to swing `tail`; `dequeue` swings `head` to
/// its successor, which becomes the new dummy.
///
/// Memory reclamation: the pseudocode relies on counted pointers plus a free
/// list. Here, instead, every operation increments `threads_in_op` on entry,
/// and a node that has been unlinked from the front is only deleted when the
/// retiring thread observes that it is the only operation in flight;
/// otherwise it is parked on `to_be_deleted` and freed later by such a thread.
/// Because a node is never freed while another operation might still hold a
/// pointer to it, its address cannot be reused under a stale CAS (no ABA).
/// Trade-off: under continuous contention retired nodes accumulate until a
/// moment when only one operation is running, or until destruction.
///
/// All atomics use the default sequentially consistent ordering; the
/// reclamation argument relies on it.
///
/// enqueue/dequeue may be called concurrently from any number of threads.
/// Construction and destruction must not overlap any other operation.
/// Requirements on T: copy-constructible (enqueue) and move-assignable (dequeue).
template<typename T>
class LockFreeQueue {
private:
    struct Node {
        // Payload. It is written before the node is linked into the queue and
        // never reassigned afterwards, so concurrent readers see a stable value.
        // It is empty for the initial dummy.
        std::shared_ptr<T> data;
        // Successor in the queue; nullptr for the last node.
        std::atomic<Node*> next;
        // Link in the pending-deletion list. This is deliberately separate from
        // `next`, which lagging threads may still read after the node is
        // unlinked.
        Node* next_retired;

        Node() : next(nullptr), next_retired(nullptr) {}
    };

    std::atomic<Node*> head{nullptr};           // dummy node in front of the first element
    std::atomic<Node*> tail{nullptr};           // last node, or transiently the one before it
    std::atomic<unsigned> threads_in_op{0};     // enqueue/dequeue calls currently in progress
    std::atomic<Node*> to_be_deleted{nullptr};  // unlinked nodes waiting for a safe point

    // Deletes a chain of nodes linked through next_retired.
    static void delete_chain(Node* nodes) {
        while (nodes != nullptr) {
            Node* following = nodes->next_retired;
            delete nodes;
            nodes = following;
        }
    }

    // Pushes the chain first..last (linked via next_retired) onto to_be_deleted.
    void chain_pending_nodes(Node* first, Node* last) {
        last->next_retired = to_be_deleted.load();
        while (!to_be_deleted.compare_exchange_weak(last->next_retired, first)) {
            // On failure last->next_retired is refreshed to the current list head; retry.
        }
    }

    // Pushes a whole next_retired-linked chain back onto to_be_deleted.
    void chain_pending_nodes(Node* nodes) {
        Node* last = nodes;
        while (Node* following = last->next_retired) {
            last = following;
        }
        chain_pending_nodes(nodes, last);
    }

    // Ends an operation started with ++threads_in_op.
    // `retired` (may be nullptr) is a node this operation just unlinked from the
    // front. No new operation can reach it: head and tail are both past it. It
    // is deleted right away if no other operation is in flight, otherwise it
    // is deferred.
    void finish_operation(Node* retired) {
        if (threads_in_op.load() == 1) {
            // We are alone: claim everything retired so far.
            Node* pending = to_be_deleted.exchange(nullptr);
            if (--threads_in_op == 0) {
                // Nobody that could still reference the pending nodes is left.
                delete_chain(pending);
            } else if (pending != nullptr) {
                // Someone entered meanwhile and might hold one of them; put them back.
                chain_pending_nodes(pending);
            }
            delete retired;  // no-op for nullptr
        } else {
            if (retired != nullptr) {
                chain_pending_nodes(retired, retired);
            }
            --threads_in_op;
        }
    }

public:
    LockFreeQueue() {
        Node* dummy = new Node();
        head.store(dummy);
        tail.store(dummy);
    }

    /// Frees every node still owned by the queue: the live chain starting at
    /// the current dummy, plus nodes whose deletion was deferred.
    /// Must not run concurrently with any other operation.
    ~LockFreeQueue() {
        delete_chain(to_be_deleted.exchange(nullptr));
        Node* node = head.load();
        while (node != nullptr) {
            Node* following = node->next.load();
            delete node;
            node = following;
        }
    }

    // Copying would share nodes between two owners; forbid it.
    LockFreeQueue(const LockFreeQueue&) = delete;
    LockFreeQueue& operator=(const LockFreeQueue&) = delete;

    /// Appends a copy of `value` to the back of the queue.
    void enqueue(const T& value) {
        // Allocate the payload first: if `new Node` throws, the shared_ptr cleans up.
        std::shared_ptr<T> payload = std::make_shared<T>(value);
        Node* new_node = new Node();
        // The payload is stored BEFORE the node becomes reachable, so a dequeuer
        // that finds this node through `next` always sees its data.
        new_node->data = payload;

        ++threads_in_op;  // nodes we dereference below may not be freed meanwhile
        Node* old_tail = nullptr;
        while (true) {
            old_tail = tail.load();
            Node* next = old_tail->next.load();
            if (old_tail != tail.load()) {
                continue;  // tail moved while reading; snapshot is inconsistent
            }
            if (next == nullptr) {
                // old_tail really is the last node: try to link after it.
                if (old_tail->next.compare_exchange_weak(next, new_node)) {
                    break;
                }
            } else {
                // Tail is lagging behind a completed link: help advance it.
                tail.compare_exchange_weak(old_tail, next);
            }
        }
        // Try once to swing tail to the new node. If this fails, another thread
        // has already advanced it. A strong CAS avoids leaving tail behind
        // because of a spurious failure.
        tail.compare_exchange_strong(old_tail, new_node);
        finish_operation(nullptr);
    }

    /// Removes the front element into `value`.
    /// Returns false, and leaves `value` untouched, if the queue is empty.
    bool dequeue(T& value) {
        ++threads_in_op;  // nodes we dereference below may not be freed meanwhile
        while (true) {
            Node* old_head = head.load();
            Node* old_tail = tail.load();
            Node* next = old_head->next.load();
            if (old_head != head.load()) {
                continue;  // head moved while reading; snapshot is inconsistent
            }
            if (old_head == old_tail) {
                if (next == nullptr) {
                    finish_operation(nullptr);
                    return false;  // queue is empty
                }
                // Tail is lagging behind a completed link: help advance it.
                tail.compare_exchange_weak(old_tail, next);
                continue;
            }
            // Take a reference to the payload before swinging head. Losing
            // threads only copy the shared_ptr and never touch *payload, so only
            // the CAS winner accesses the element itself.
            std::shared_ptr<T> payload = next->data;
            if (head.compare_exchange_weak(old_head, next)) {
                // `next` is now the dummy. The old dummy is unreachable for new
                // operations and can be retired.
                finish_operation(old_head);
                value = std::move(*payload);  // safe: we hold our own reference
                return true;
            }
        }
    }
};
三、测试
测试代码如下:
int main() {LockFreeQueue<int> queue;std::thread t1([&queue](){for(int i=0; i<100; i++) {queue.enqueue(i);std::this_thread::sleep_for(std::chrono::milliseconds(100));}});std::thread t2([&queue](){for(int i=100; i<200; i++) {queue.enqueue(i);std::this_thread::sleep_for(std::chrono::milliseconds(100));}});std::thread t3([&queue]() {while (true) {int i = 0;if (queue.dequeue(i)) {std::cout << "t3:" << i << std::endl;}}});std::thread t4([&queue]() {while (true) {int i = 0;if (queue.dequeue(i)) {std::cout << "t4:" << i << std::endl;}}});t1.join();t2.join();t3.join();t4.join();
}
可以看到,多线程下可以顺利地插入与取出数据。注意:这样的演示只能说明基本功能可用,并不能证明无锁实现在并发下是正确的;内存回收、ABA 等问题需要结合算法本身仔细论证,或借助 ThreadSanitizer 等工具检查。