纯CPU场景下C++的分布式模型训练框架设计思路
0. 参数分配
- 稠密参数 → MPI 集合通信(All-Reduce / Broadcast / Reduce-Scatter)。
- 稀疏参数 → brpc Parameter Server 异步推拉。
完全去掉 NCCL/GPU 相关部分。
1. 整体拓扑
┌----------------┐ ┌----------------┐
│ Worker-0 │ │ PS-0 │
│ Worker-1 │◄------► │ PS-1 │
│ ... │ brpc │ ... │
│ Worker-N │ │ PS-M │
└----------------┘ └----------------┘▲│MPI(TCP/InfiniBand)▼
MPI_COMM_WORLD(稠密参数)
- 稠密梯度:通过 MPI 标准集合操作(
MPI_Allreduce
、MPI_Bcast
等)实现同步。 - 稀疏参数:Worker 与 PS 之间用 brpc + protobuf 通信,异步推拉。
2. 关键模块(C++)
cpu_dist/
├── common/
│ ├── tensor.h // 纯 CPU 张量(FP32/FP64)
│ └── mpi_context.h // MPI_Init / Finalize 封装
├── dense/
│ ├── mpi_allreduce.h // MPI All-Reduce 封装
│ └── optimizer.h // 本地 SGD / AdamW
├── sparse/
│ ├── ps_server.h/cc // brpc Parameter Server
│ ├── ps_client.h/cc // brpc Client
│ └── table.h // 稀疏表(unordered_map + 锁)
├── proto/
│ └── message.proto // protobuf 消息
└── launcher.cc // 主进程入口
3. MPI 通信层(稠密参数)
3.1 封装 MPI All-Reduce
// dense/mpi_allreduce.h
class MPIAllReduce {public:explicit MPIAllReduce(MPI_Comm comm) : comm_(comm) {}template <typename T>void AllReduceSum(std::vector<T>& buf) {std::vector<T> recv(buf.size());MPI_Allreduce(buf.data(), recv.data(), buf.size(),GetMPIType<T>(), MPI_SUM, comm_);buf.swap(recv);}private:MPI_Comm comm_;
};
- 支持 float / double / int。
- 支持 In-place All-Reduce(
MPI_IN_PLACE
)。
4. brpc Parameter Server(稀疏参数)
与之前设计一致,仅通信后端为 brpc:
- proto 定义不变(
PullRequest
,PushRequest
)。 - PS 端 实现
brpc::Service
,用brpc::Server
启动。 - Worker 端 用
brpc::Channel
连接 PS,支持 轮询/一致性哈希 负载均衡。
5. 主进程结构(launcher.cc)
int main(int argc, char* argv[]) {MPI_Init(&argc, &argv);int rank, size;MPI_Comm_rank(MPI_COMM_WORLD, &rank);MPI_Comm_size(MPI_COMM_WORLD, &size);bool is_ps = (rank >= FLAGS_worker_num);if (!is_ps) {// WorkerMPIAllReduce ar(MPI_COMM_WORLD);PSClient ps(FLAGS_ps_list);WorkerLoop(ar, ps);} else {// Parameter ServerPSServer server;server.Start(FLAGS_ps_port);}MPI_Finalize();
}
6. Worker 主循环
void WorkerLoop(MPIAllReduce& ar, PSClient& ps) {Model model;DataLoader dl(FLAGS_data_path);for (int step = 0; step < FLAGS_max_step; ++step) {auto batch = dl.Next();std::vector<float> dense_grad;std::vector<int64_t> sparse_keys;std::vector<float> sparse_grad;// 前向 & 反向model.Backward(batch, &dense_grad, &sparse_keys, &sparse_grad);// 1. 稠密梯度 MPI All-Reducear.AllReduceSum(dense_grad);// 2. 稀疏梯度异步 Pushps.PushAsync(0, sparse_keys, sparse_grad);// 3. 稀疏参数 Pullstd::vector<float> sparse_emb;ps.Pull(0, sparse_keys, &sparse_emb);// 4. 参数更新model.Update(dense_grad, sparse_emb);}
}
7. 部署与运行
7.1 启动脚本(OpenMPI)
# 4 worker + 2 ps
mpirun -np 6 \-x LD_LIBRARY_PATH \./launcher \--worker_num 4 \--ps_list "0.0.0.0:8000,0.0.0.0:8001"
- worker_num 前
rank 0~3
为 Worker,后rank 4~5
为 PS。 - MPI 负责稠密通信,brpc 负责稀疏通信,两者互不干扰。
8. 性能调优
项 | 建议 |
---|---|
MPI | 使用 OpenMPI 4.x 或 Intel MPI(CPU 亲和、NUMA 优化)。 |
brpc | 配置 轮询 + 批处理(64~256 key/RPC),开启 8bit 量化压缩。 |
线程 | MPI 与 brpc 线程分离,brpc 用 bthread ,避免与 MPI 线程冲突。 |
至此,“CPU + MPI(稠密) + brpc Parameter Server(稀疏)” 的完整框架已就绪。