当前位置: 首页 > news >正文

从零构建深度学习推理框架-8 卷积算子实现

其实这一次课还蛮好理解的:

 首先将kernel展平:

for (uint32_t g = 0; g < groups; ++g) {std::vector<arma::fmat> kernel_matrix_arr(kernel_count_group);arma::fmat kernel_matrix_c(1, row_len * input_c_group);for (uint32_t k = 0; k < kernel_count_group; ++k) {const std::shared_ptr<Tensor<float>> &kernel =weights.at(k + g * kernel_count_group);for (uint32_t ic = 0; ic < input_c_group; ++ic) {memcpy(kernel_matrix_c.memptr() + row_len * ic,kernel->at(ic).memptr(), row_len * sizeof(float));}LOG(INFO) << "kernel展开后: " << "\n" << kernel_matrix_c;kernel_matrix_arr.at(k) = kernel_matrix_c;}

将原来的kernel放到kernel_matrix_c里面,之后如果是多个channel,也就是input_c有多个,那就按照rowlen*ic依次存放到里面。

将输入input展平:

//按照上面的图就是input = 3*9 ,4的这样一个空间arma::fmat input_matrix(input_c_group * row_len, col_len);for (uint32_t ic = 0; ic < input_c_group; ++ic) {const arma::fmat &input_channel = input_->at(ic + g * input_c_group);int current_col = 0;
//下面是以窗口滑动的顺序选取for (uint32_t w = 0; w < input_w - kernel_w + 1; w += stride_w) {for (uint32_t r = 0; r < input_h - kernel_h + 1; r += stride_h) {float *input_matrix_c_ptr =input_matrix.colptr(current_col) + ic * row_len;//对准窗口位置,比如对第一个就是对准红色, 黄色, 绿色current_col += 1;for (uint32_t kw = 0; kw < kernel_w; ++kw) {const float *region_ptr = input_channel.colptr(w + kw) + r;memcpy(input_matrix_c_ptr, region_ptr, kernel_h * sizeof(float));input_matrix_c_ptr += kernel_h;}}}}LOG(INFO)  << "input展开后: " << "\n"  << input_matrix;

对于:

 for (uint32_t kw = 0; kw < kernel_w; ++kw) {const float *region_ptr = input_channel.colptr(w + kw) + r;memcpy(input_matrix_c_ptr, region_ptr, kernel_h * sizeof(float));input_matrix_c_ptr += kernel_h;}

w+kw指向的是窗口的列,r指向的是窗口的行

然后对于每个窗口的以kernel的列为标准复制过去。

最后两个矩阵相乘就可以得到结果

http://www.lryc.cn/news/129728.html

相关文章:

  • 【Spring Boot】JdbcTemplate数据连接模板 — JdbcTemplate入门
  • 视频汇聚集中存储EasyCVR平台调用iframe地址视频无法播放,该如何解决?
  • 从今天起,重新出发
  • Java多态详解(1)
  • optee读取Arm系统寄存器的模板
  • VSCode 使用总结
  • GuLi商城-前端基础Vue-使用Vue脚手架进行模块化开发
  • LeetCode450. 删除二叉搜索树中的节点
  • Java动态调试技术原理及实践
  • Lua + Redis 实战代码
  • 类的访问限定符,实例化,对象存储方式,this指针
  • 《Linux从练气到飞升》No.15 Linux 环境变量
  • Spring Boot 重启命令
  • pdf怎么合并在一起?这几个合并方法了解一下
  • 【仿写tomcat】七、项目结构优化以及代码开源
  • 泛微E8配置自定义触发流程失败
  • Springboot整合Mybatis调用Oracle存储过程
  • 【java安全】Log4j反序列化漏洞
  • [mars3d 打包]vue3+vite,打包后mars3d找不到
  • STM32——SPI外设总线
  • BOXTRADE-天启量化分析平台 主要功能介绍
  • kaggle注册不显示验证码
  • python爬虫7:实战1
  • uniApp引入vant2
  • 如何大幅提高遥感影像分辨率(Python+MATLAB)
  • nginx php-fpm安装配置
  • 通过ip获取地理位置信息
  • 数据库索引优化策略与性能提升实践
  • 【ARM 嵌入式 编译系列 11.1 -- GCC __attribute__((aligned(x)))详细介绍】
  • 【计算机视觉|生成对抗】逐步增长的生成对抗网络(GAN)以提升质量、稳定性和变化