当前位置: 首页 > news >正文

从零构建深度学习推理框架-3 手写算子relu

Relu介绍:

f(x) = \left\{\begin{matrix}x , x>thresh & & \\0,x<thresh & & \end{matrix}\right.

 relu是一个非线性激活函数,可以避免梯度消失,过拟合等情况。我们一般将thresh设为0。

operator类:

#ifndef KUIPER_COURSE_INCLUDE_OPS_OP_HPP_
#define KUIPER_COURSE_INCLUDE_OPS_OP_HPP_
namespace kuiper_infer {
enum class OpType {kOperatorUnknown = -1,kOperatorRelu = 0,
};class Operator {public:OpType op_type_ = OpType::kOperatorUnknown; //不是一个具体节点 制定为unknownvirtual ~Operator() = default; //explicit Operator(OpType op_type);
};

这里的  kOperatorUnknown = -1 , kOperatorRelu = 0分别是他们的代号

operator是一个父类,我们的relu就要继承于这个父类

class ReluOperator : public Operator {public:~ReluOperator() override = default;explicit ReluOperator(float thresh);void set_thresh(float thresh);float get_thresh() const;private:// 需要传递到reluLayer中,怎么传递?float thresh_ = 0.f; // 用于过滤tensor<float>值当中大于thresh的部分// relu存的变量只有thresh// stride padding kernel_size 这些是到时候convOperator需要的// operator起到了属性存储、变量的作用// operator所有子类不负责具体运算// 具体运算由另外一个类Layer类负责// y =x  , if x >=0 y = 0 if x < 0};

 operator起到了属性存储、变量的作用
 operator所有子类不负责具体运算
 具体运算由另外一个类Layer类负责

layer类:

class Layer {public:explicit Layer(const std::string &layer_name);virtual void Forwards(const std::vector<std::shared_ptr<Tensor<float>>> &inputs,std::vector<std::shared_ptr<Tensor<float>>> &outputs);// reluLayer中 inputs 等于 x , outputs 等于 y= x,if x>0// 计算得到的结果放在y当中,x是输入,放在inputs中virtual ~Layer() = default;private:std::string layer_name_; //relu layer "relu"
};

父类只保留了一个layer_name属性和两个方法。

具体的在relu_layer这个class中

class ReluLayer : public Layer {public:~ReluLayer() override = default;// 通过这里,把relu_op中的thresh告知给relu layer, 因为计算的时候要用到explicit ReluLayer(const std::shared_ptr<Operator> &op);// 执行relu 操作的具体函数Forwardsvoid Forwards(const std::vector<std::shared_ptr<Tensor<float>>> &inputs,std::vector<std::shared_ptr<Tensor<float>>> &outputs) override;// 下节的内容,不用管static std::shared_ptr<Layer> CreateInstance(const std::shared_ptr<Operator> &op);private:std::unique_ptr<ReluOperator> op_;
};

具体的方法实现:

ReluLayer::ReluLayer(const std::shared_ptr<Operator> &op) : Layer("Relu") {CHECK(op->op_type_ == OpType::kOperatorRelu) << "Operator has a wrong type: " << int(op->op_type_);// dynamic_cast是什么意思? 就是判断一下op指针是不是指向一个relu_op类的指针// 这边的op不是ReluOperator类型的指针,就报错// 我们这里只接受ReluOperator类型的指针// 父类指针必须指向子类ReluOperator类型的指针// 为什么不讲构造函数设置为const std::shared_ptr<ReluOperator> &op?// 为了接口统一,具体下节会说到ReluOperator *relu_op = dynamic_cast<ReluOperator *>(op.get());CHECK(relu_op != nullptr) << "Relu operator is empty";// 一个op实例和一个layer 一一对应 这里relu op对一个relu layer// 对应关系this->op_ = std::make_unique<ReluOperator>(relu_op->get_thresh());
}void ReluLayer::Forwards(const std::vector<std::shared_ptr<Tensor<float>>> &inputs,std::vector<std::shared_ptr<Tensor<float>>> &outputs) {// relu 操作在哪里,这里!// 我需要该节点信息的时候 直接这么做// 实行了属性存储和运算过程的分离!!!!!!!!!!!!!!!!!!!!!!!!//x就是inputs y = outputsCHECK(this->op_ != nullptr);CHECK(this->op_->op_type_ == OpType::kOperatorRelu);const uint32_t batch_size = inputs.size(); //一批x,放在vec当中,理解为batchsize数量的tensor,需要进行relu操作for (int i = 0; i < batch_size; ++i) {CHECK(!inputs.at(i)->empty());const std::shared_ptr<Tensor<float>> &input_data = inputs.at(i); //取出批次当中的一个张量//对张量中的每一个元素进行运算,进行relu运算input_data->data().transform([&](float value) {// 对张良中的没一个元素进行运算// 从operator中得到存储的属性float thresh = op_->get_thresh();//x >= threshif (value >= thresh) {return value; // return x} else {// x<= thresh return 0.f;return 0.f;}});// 把结果y放在outputs中outputs.push_back(input_data);}
}

http://www.lryc.cn/news/105921.html

相关文章:

  • 想做上位机,学C#还是QT?
  • Ansible —— playbook 剧本
  • ARM寻址方式
  • 【JAVA】String ,StringBuffer 和 StringBuilder 三者有何联系?
  • 关于计数以及Index返回订单号升级版(控制字符长度,控制年月标记)
  • 【计算机网络】11、网桥(bridge)、集线器(hub)、交换机(switch)、路由器(router)、网关(gateway)
  • 第九篇-自我任务数据准备
  • 2023.8.1号论文阅读
  • webpack优化前端框架性能
  • Unity UGUI的Outline(描边)组件的介绍及使用
  • 爆改vue3 setup naiveui可编辑table
  • 功率放大器的种类有哪三种类型
  • HDFS 分布式存储 spark storm HBase
  • Vue3文字实现左右和上下滚动
  • Docker Sybase修改中文编码
  • 【SpringCloud Alibaba】(六)使用 Sentinel 实现服务限流与容错
  • mysql的主从复制
  • 【Golang 接口自动化03】 解析接口返回XML
  • Java+bcprov库实现对称和非对称加密算法
  • 国内最大Llama开源社区发布首个预训练中文版Llama2
  • Qt应用开发(基础篇)——滑块类 QSlider、QScrollBar、QDial
  • 【3-D深度学习:肺肿瘤分割】创建和训练 V-Net 神经网络,并从 3D 医学图像中对肺肿瘤进行语义分割研究(Matlab代码实现)
  • MongoDB文档--架构体系
  • GEE学习03-Geemap配置与安装,arcgis pro自带命令提示符位置等
  • 软件测试面试总结——http协议相关面试题
  • 大数据与okcc呼叫中心融合的几种方式
  • WAF绕过-工具特征-菜刀+冰蝎+哥斯拉
  • 使代码减半的5个Python装饰器
  • 线程池的线程回收问题
  • 盘点那些不想骑车的原因和借口。