当前位置: 首页 > news >正文

NeRF基础代码解析

embedders

对 position（采样点坐标）和 view direction（视角方向）做频率编码（embedding），即把输入 x 映射为 [x, sin(2^k·x), cos(2^k·x)]（k = 0, …, multi_res−1）拼接后的向量。

class FreqEmbedder(nn.Module):
    """Frequency (positional) encoding for xyz positions or view directions.

    Each input ``x`` of shape ``[..., in_dim]`` is mapped to the concatenation
    ``[x, sin(f_0 x), cos(f_0 x), ..., sin(f_{K-1} x), cos(f_{K-1} x)]`` along
    the last axis, where ``K = multi_res`` and ``f_k`` are the frequency bands.

    Args:
        in_dim: size of the last axis of the input (3 for xyz / view direction).
        multi_res: number of frequency bands ``K``.
        use_log_bands: if True the bands are ``2**0, 2**1, ..., 2**(K-1)``;
            otherwise ``K`` bands linearly spaced in ``[2**0, 2**(K-1)]``.
        include_input: if True the raw input is prepended to the encoding.
    """

    def __init__(self, in_dim=3, multi_res=10, use_log_bands=True, include_input=True):
        super().__init__()
        self.in_dim = in_dim
        self.num_freqs = multi_res
        # The highest band is 2**(multi_res - 1), e.g. 512 for multi_res=10.
        # Using multi_res here would spread exponents 0..multi_res over
        # multi_res steps and produce non-power-of-two frequencies.
        self.max_freq_log2 = multi_res - 1
        self.use_log_bands = use_log_bands
        self.periodic_fns = [torch.sin, torch.cos]
        self.include_input = include_input
        self.embed_fns = None
        self.out_dim = None
        self.num_embed_fns = None
        self.create_embedding_fn()

    def create_embedding_fn(self):
        """Build ``self.embed_fns`` and compute ``self.out_dim``."""
        self.embed_fns = []
        # e.g. 10 bands * (sin, cos) * 3 dims = 60
        self.out_dim = self.num_freqs * len(self.periodic_fns) * self.in_dim
        if self.include_input:
            self.embed_fns.append(lambda x: x)
            self.out_dim += self.in_dim  # e.g. 63
        if self.use_log_bands:
            freq_bands = 2. ** torch.linspace(0., self.max_freq_log2, steps=self.num_freqs)
        else:
            freq_bands = torch.linspace(2. ** 0, 2. ** self.max_freq_log2, steps=self.num_freqs)
        for freq in freq_bands:
            for p_fn in self.periodic_fns:
                # Bind p_fn / freq as defaults so each lambda keeps its own
                # values instead of the loop's last ones (late binding).
                self.embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
        self.num_embed_fns = len(self.embed_fns)

    def forward(self, x):
        """Encode ``x``.

        Args:
            x: ``[..., in_dim]`` tensor, xyz or view direction.

        Returns:
            ``[..., out_dim]`` tensor ordered as
            ``[x, sin(f_0 x), cos(f_0 x), ..., sin(f_{K-1} x), cos(f_{K-1} x)]``
            (the leading ``x`` only when ``include_input`` is True).
        """
        return torch.cat([embed_fn(x) for embed_fn in self.embed_fns], dim=-1)

NeRFBackbone

position 和 view direction 经过 embedding 后得到特征向量，再连同条件特征 cond 一起输入 NeRFBackbone 网络，输出 rgb（颜色）和 sigma（体密度），拼接为 4 维结果。

class NeRFBackbone(nn.Module):
    """NeRF MLP mapping (encoded position, condition, encoded view) to (rgb, sigma).

    The density branch is ``num_density_linears`` ReLU layers over
    ``cat([pos, cond])``. The input is re-concatenated after every layer whose
    index is in ``skip_layer_indices``. The color branch takes the density
    features plus the view encoding.

    Args:
        pos_dim: size of the encoded position.
        cond_dim: size of the condition feature.
        view_dim: size of the encoded view direction.
        hid_dim: hidden width of the density branch (color branch uses half).
        num_density_linears: number of hidden layers in the density branch.
        num_color_linears: number of hidden layers in the color branch.
        skip_layer_indices: density-layer indices after which the input is
            concatenated back in.
    """

    def __init__(self, pos_dim=3, cond_dim=64, view_dim=3, hid_dim=128, num_density_linears=8,
                 num_color_linears=3, skip_layer_indices=(4,)):
        # Must run before any submodule is assigned to self.
        super().__init__()
        self.pos_dim = pos_dim
        self.cond_dim = cond_dim
        self.view_dim = view_dim
        self.hid_dim = hid_dim
        self.out_dim = 4  # rgb + sigma
        self.num_density_linears = num_density_linears
        self.num_color_linears = num_color_linears
        # Copy so callers' lists are never aliased / a shared default mutated.
        self.skip_layer_indices = list(skip_layer_indices)

        density_input_dim = pos_dim + cond_dim
        # Layer k+1 consumes the output of layer k; if k is a skip index that
        # output has the density input concatenated, hence the wider Linear.
        self.density_linears = nn.ModuleList(
            [nn.Linear(density_input_dim, hid_dim)]
            + [nn.Linear(hid_dim, hid_dim) if i not in self.skip_layer_indices
               else nn.Linear(hid_dim + density_input_dim, hid_dim)
               for i in range(num_density_linears - 1)])
        self.density_out_linear = nn.Linear(hid_dim, 1)

        color_input_dim = view_dim + hid_dim
        self.color_linears = nn.ModuleList(
            [nn.Linear(color_input_dim, hid_dim // 2)]
            + [nn.Linear(hid_dim // 2, hid_dim // 2) for _ in range(num_color_linears - 1)])
        self.color_out_linear = nn.Linear(hid_dim // 2, 3)

    def forward(self, pos, cond, view):
        """Run the MLP.

        Args:
            pos: ``[bs, n_sample, pos_dim]`` encoding of position.
            cond: condition features, ``[cond_dim]`` (shared by all samples),
                ``[bs, cond_dim]`` (shared per batch item) or already
                ``[bs, n_sample, cond_dim]``.
            view: ``[bs, view_dim]`` encoding of view direction.

        Returns:
            ``[bs, n_sample, 4]``: raw rgb (3 channels) followed by raw sigma;
            no output activation is applied here.
        """
        bs, n_sample, _ = pos.shape
        if cond.dim() == 1:  # [cond_dim]
            cond = cond[None, None, :].expand(bs, n_sample, self.cond_dim)
        elif cond.dim() == 2:  # [bs, cond_dim]
            cond = cond[:, None, :].expand(bs, n_sample, self.cond_dim)
        view = view[:, None, :].expand(bs, n_sample, self.view_dim)

        density_linear_input = torch.cat([pos, cond], dim=-1)
        h = density_linear_input
        for i, layer in enumerate(self.density_linears):
            h = F.relu(layer(h))
            if i in self.skip_layer_indices:
                h = torch.cat([density_linear_input, h], dim=-1)
        sigma = self.density_out_linear(h)

        h = torch.cat([h, view], dim=-1)
        for layer in self.color_linears:
            h = F.relu(layer(h))
        rgb = self.color_out_linear(h)
        return torch.cat([rgb, sigma], dim=-1)

Ray Sampler

设一张图的 height = 1280、width = 720：从相机原点出发、穿过每个像素的光线共有 H×W 条，每次从中随机采样 4096 条光线（ray）。

def _pixel_grid(H, W):
    """Return float row / column index grids, each of shape ``[H, W]``.

    ``rows[j, i] == j`` and ``cols[j, i] == i`` ("ij" layout). This uses
    broadcasting instead of ``torch.meshgrid``, so there is no dependence on
    the meshgrid ``indexing`` argument.
    """
    dtype = torch.get_default_dtype()
    rows = torch.arange(H, dtype=dtype)[:, None].expand(H, W)
    cols = torch.arange(W, dtype=dtype)[None, :].expand(H, W)
    return rows, cols


def get_rays(H, W, focal, c2w, cx=None, cy=None):
    """Get the rays emitted from the camera through every pixel, in world coordinates.

    Args:
        H: height of the image in pixels.
        W: width of the image in pixels.
        focal: focal length of the camera in pixels.
        c2w: 3x4 (or 4x4) camera-to-world matrix::

            [[r11, r12, r13, t1],
             [r21, r22, r23, t2],
             [r31, r32, r33, t3]]

        cx: principal point along the width axis (defaults to ``W / 2``).
        cy: principal point along the height axis (defaults to ``H / 2``).

    Returns:
        rays_o: ``[H, W, 3]`` ray origins (the camera center, repeated).
        rays_d: ``[H, W, 3]`` ray directions (not normalized). Points on a
            ray are ``rays_o + rays_d * z_val``.
    """
    j_pixels, i_pixels = _pixel_grid(H, W)
    if cx is None:
        cx = W * 0.5
    if cy is None:
        cy = H * 0.5
    # Camera-frame directions: x to the right, y up (the row index grows
    # downward, hence the minus sign), looking along -z.
    directions = torch.stack(
        [(i_pixels - cx) / focal, -(j_pixels - cy) / focal, -torch.ones_like(i_pixels)],
        dim=-1)  # [H, W, 3]
    # Rotate ray directions from the camera frame to the world frame.
    rays_d = torch.sum(directions[..., None, :] * c2w[:3, :3], dim=-1)
    # All rays start at the camera center expressed in world coordinates.
    rays_o = c2w[:3, -1].expand(rays_d.shape)
    return rays_o, rays_d


class BaseRaySampler:
    """Select a subset of the per-pixel rays of an H x W image.

    Subclasses implement ``sample_rays(H, W, ...)`` returning an ``[N, 2]``
    long tensor of (row, col) pixel coordinates.
    """

    def __init__(self, N_rays):
        super(BaseRaySampler, self).__init__()
        self.N_rays = N_rays

    def __call__(self, H, W, focal, c2w):
        """Return ``(rays_o [N, 3], rays_d [N, 3], selected_coords [N, 2])``."""
        rays_o, rays_d = get_rays(H, W, focal, c2w)
        selected_coords = self.sample_rays(H, W)
        rows, cols = selected_coords[:, 0], selected_coords[:, 1]
        return rays_o[rows, cols], rays_d[rows, cols], selected_coords

    def sample_rays(self, H, W, **kwargs):
        raise NotImplementedError


class UniformRaySampler(BaseRaySampler):
    """Sample rays uniformly without replacement over the image pixels."""

    def __init__(self, N_rays=None):
        super().__init__(N_rays=N_rays)

    def sample_rays(self, H, W, n_rays=None, rect=None, in_rect_percent=0.9, **kwargs):
        """Return ``[n_rays, 2]`` long (row, col) coordinates drawn without replacement.

        ``n_rays`` defaults to ``self.N_rays``. Rect-restricted sampling
        (``rect`` not None) is not implemented: its body was elided in the
        source this was taken from, and the ``rect`` format is not shown.
        """
        if n_rays is None:
            n_rays = self.N_rays
        if rect is not None:
            # TODO: restore the in-rect / out-of-rect sampling branch
            # (in_rect_percent of rays inside rect, the rest outside).
            raise NotImplementedError("rect-based ray sampling is not implemented")
        rows, cols = _pixel_grid(H, W)
        coords = torch.stack([rows, cols], dim=-1).reshape(-1, 2)  # [H * W, 2]
        selected_inds = np.random.choice(coords.shape[0], size=[n_rays], replace=False)
        return coords[selected_inds].long()

    # Alias kept for callers that use the singular method name.
    sample_ray = sample_rays

    def __call__(self, H, W, focal, c2w, n_rays=None, selected_coords=None, rect=None,
                 in_rect_percent=0.9, **kwargs):
        """Return ``(rays_o, rays_d, selected_coords)`` for sampled or given pixels.

        If ``selected_coords`` (``[N, 2]`` long, (row, col)) is given it is used
        as-is; otherwise coordinates are drawn with ``sample_rays``.
        """
        rays_o, rays_d = get_rays(H, W, focal, c2w)
        if selected_coords is None:
            selected_coords = self.sample_rays(H, W, n_rays, rect, in_rect_percent)
        rows, cols = selected_coords[:, 0], selected_coords[:, 1]
        return rays_o[rows, cols], rays_d[rows, cols], selected_coords

    def sample_pixels_from_img_with_select_coords(self, img, select_coords):
        """Gather ``img[row, col]`` for each (row, col) in ``select_coords``."""
        return img[select_coords[:, 0], select_coords[:, 1]]
http://www.lryc.cn/news/118456.html

相关文章:

  • 职场新星:Java面试干货让你笑傲求职路(三)
  • 获取指定收获地址的信息
  • 突破笔试:力扣全排列(medium)
  • gitlab 503 错误的解决方案
  • 智能离子风棒联网监控静电消除器的主要功能和特点
  • matplotlib 设置legend的位置在轴最上方,长度与图的长度相同
  • Docker-Compose 安装rabbitmq
  • leetcode357- 2812. 找出最安全路径
  • Oracle连接数据库提示 ORA-12638:身份证明检索失败
  • 在 Linux 中使用 systemd 注册服务
  • (03)Unity HTC VRTK 基于 URP 开发记录
  • .bit域名调研
  • Vue数组变更方法和替换方法
  • Centos-6.3安装使用MongoDB
  • Mysql 复杂查询丨联表查询
  • C语言进阶第二课-----------指针的进阶----------升级版
  • 若依vue -【 111 ~ 更 ~ 127 完 】
  • vue-pc端实现按钮防抖处理-自定义指令
  • python解决8皇后问题
  • xcode打包导出ipa
  • 更优雅地调试SwiftUI—借助LLDB
  • 2.4 网络安全新技术
  • 人生天地之间,若白驹之过隙,忽然而已
  • MySQL — MVCC
  • Android模板设计模式之 - 构建整个应用的BaseActivity
  • 浏览器缓存技术--localStorage和sessionStorage原理与使用
  • 无涯教程-Perl - endservent函数
  • MRO工业品采购过程中,采购人员要注意哪些事项
  • Jaeger 教程,OpenTelemetry 教程
  • P1597 语句解析