OpenPCDet系列 | 4.2 DataAugmentor点云数据增强模块解析
文章目录
- DataAugmentor模块解析
- 1. gt_sampling
- 2. random_world_flip
- 3. random_world_rotation
- 4. random_world_scaling
- 5. limit_period
DataAugmentor模块解析
在pointpillars算法中,具体的数据增强方法配置是在yaml中的DATA_CONFIG.DATA_AUGMENTOR进行配置,pointpillars的参考配置如下所示:
DATA_AUGMENTOR:DISABLE_AUG_LIST: ['placeholder'] # 禁用该数据增强AUG_CONFIG_LIST: # 启用的数据增强列表- NAME: gt_samplingUSE_ROAD_PLANE: True # 启用道路平面信息,将gt中心位置移动到道路平面上(abcd表示一个平面方程)DB_INFO_PATH:- kitti_dbinfos_train.pkl # 存储训练集每个类别的对象信息,每个类别用列表来存储,应该与INFO_PATH对应PREPARE: { # gt的具体采样操作filter_by_min_points: ['Car:5', 'Pedestrian:5', 'Cyclist:5'], # 表示过滤低于5个点的gt,需要int型赋值filter_by_difficulty: [-1], # 过滤的困难等级列表}# USE_SHARED_MEMORY: False # 默认不设置共享内存,若开启需要设置DB_DATA_PATHSAMPLE_GROUPS: ['Car:15','Pedestrian:15', 'Cyclist:15']NUM_POINT_FEATURES: 4DATABASE_WITH_FAKELIDAR: FalseREMOVE_EXTRA_WIDTH: [0.0, 0.0, 0.0] # 将采样box扩张大小长度LIMIT_WHOLE_SCENE: False- NAME: random_world_flipALONG_AXIS_LIST: ['x']- NAME: random_world_rotationWORLD_ROT_ANGLE: [-0.78539816, 0.78539816]- NAME: random_world_scalingWORLD_SCALE_RANGE: [0.95, 1.05]
在data_augmentor模块初始化后,在进行前向传播的数据准备时,也就是batch = next(dataloader_iter)的时候,会在kitti_dataset的__getitem__函数进行数据的准备,随机获取一个点云帧场景索引,准备好当前点云帧的点特征、gt boxes、gt类别名称、坐标系转换类calib、道路信息road_plane等数据,构建成一个data_dict传入基类的prepare_data函数中,进行数据的进一步处理,形成训练的batch数据。此外,在__getitem__函数中还会困难进行FOV点云视角的范围裁剪,也就是将全方位的点云场景只保留前视图的点云,在这一步之后当前点云帧场景的点云数量会大大降低。
在基类的prepare_data函数中,即会依次进行data_augmentor部分、point_feature_encoder部分、以及data_processor部分三大部分的数据处理,最后形成最终的batch数据输入到模型中进行训练,下面分别对着data_augmentor部分进行记录介绍。在pointpillars配置文件中,数据增强部分就包含了gt_sampling、random_world_flip、random_world_rotation、random_world_scaling四种方法。下面对其进行分别记录。
在数据增强基类DataAugmentor中提供了一些列的数据增强的方法,在初始化阶段会依次添加配置文件中制定的数据增强方式进入队列中,随后在forward函数中对data_dict数据进行依次处理。由于队列先进先出的特性,所以在配置文件中靠前的部分会按顺序优先处理。
def forward(self, data_dict): # 在prepare_data中进行"""Args:data_dict:points: (N, 3 + C_in)gt_boxes: optional, (N, 7) [x, y, z, dx, dy, dz, heading]gt_names: optional, (N), string...Returns:"""# 遍历增强队列,逐个增强器做数据增强for cur_augmentor in self.data_augmentor_queue:data_dict = cur_augmentor(data_dict=data_dict)......
1. gt_sampling
在pointpillars的数据增强方法中首要的就是gt_sampling方法。对于gt的信息格式存储在db_infos字典中,具体是在database_sampler中进行处理,对于每个类别的每个内容如下所示:
在pcdet.datasets.augmentor.database_sampler.py文件中提供了DataBaseSampler类来完成这一模块,在__init__初始化类阶段首先会对训练集数据的全部gt样本进行读取,对每个gt的点云数量较少或者是困难等级的gt进行过滤。随后对每个类别的gt重新统计信息,包含采样数量、过滤后的gt数量(变化)、以及重新分配索引(从0开始分配)。这部分的gt过滤操作通过DATA_AUGMENTOR.PREPARE部分来进行设置。以上即设置了两个过滤操作。分别是:filter_by_min_points、filter_by_difficulty。
随后,在具体的数据准备过程中具体调用的DataBaseSampler的__call__方法,这相当于是继承nn.Module时调用的forward方法,即运行对象。在训练集各类别中随机采样与当前帧点云不重叠的gt,作为当前帧额外的独立采样gt,并添加到当前的点云帧场景中,相当于是一个copy paste操作。而挑选不重叠的gt过程的挑选相当于是一个碰撞测试,避免影响到当前点云帧的原始gt,从而实现gt样本的增多,即数据增强的效果。
具体将采样的gt与当前点云帧场景进行结合,并将gt移动到道路平面上具体操作是有DataBaseSampler.add_sampled_boxes_to_scene方法实现的。主要作用就是将gt移动到道路平面上,首先去除原始的gt(这里还会将原始的gt扩张)内的点云,然后将采样的gt点云与背景点点云拼接构造成新的场景。
在对训练集gt样本进行具体体的随机才样操作是通过sample_with_fixed_number函数来具体实现,主要的实现方法是获取打乱顺序后的indices前 sample_num 个db_infos。所以,总的来说gt_sampling这一部分就是围绕着gt的采样来实现。比如对gt进行限制过滤,然后采样其他场景的gt来添加到当前点云帧场景,前提是通过碰撞测试,从而实现数据增强,类似的一个copy paste的思路。具体实现上还有FOV的范围筛选,道路平面的转移等操作细节。核心代码如下所示:
# 将采样的box扩大,sampler_cfg.REMOVE_EXTRA_WIDTH即为dx,dy和dz的放大长度
large_sampled_gt_boxes = box_utils.enlarge_box3d(sampled_gt_boxes[:, 0:7], extra_width=self.sampler_cfg.REMOVE_EXTRA_WIDTH
)# 核心代码:更新当前点云场景信息
points = box_utils.remove_points_in_boxes3d(points, large_sampled_gt_boxes) # 只保留背景点,去除前景点
points = np.concatenate([obj_points[:, :points.shape[-1]], points], axis=0) # 将采样+原始gt点云与放大后保留的背景点拼接,组成新的点云
gt_names = np.concatenate([gt_names, sampled_gt_names], axis=0) # 将类别拼接
gt_boxes = np.concatenate([gt_boxes, sampled_gt_boxes], axis=0) # 将box拼接
在进行gt‘采样后的数据增强时,data_dict数据字典中只保留了4个有效键值对,如下所示。分别是当前点云帧索引,gt类别名称,gt boxes以及加入采样gt后的新点云场景points
2. random_world_flip
对点云和gt沿x轴或者y轴按一定分布概率进行随机反转,具体的实现函数是在augmentor_utils中。
- 沿x轴进行反转的核心代码:
if enable is None:enable = np.random.choice([False, True], replace=False, p=[0.5, 0.5]) # 一半的概率选择是否翻转
if enable:gt_boxes[:, 1] = -gt_boxes[:, 1] # y坐标翻转gt_boxes[:, 6] = -gt_boxes[:, 6] # 方位角翻转,直接取负数,因为方位角定义为与x轴的夹角(这里按照顺时针的方向取角度)points[:, 1] = -points[:, 1] # 点云y坐标翻转
- 沿y轴进行反转的核心代码:
if enable is None:enable = np.random.choice([False, True], replace=False, p=[0.5, 0.5]) # 一半的概率选择是否翻转
if enable:gt_boxes[:, 0] = -gt_boxes[:, 0] # x坐标翻转gt_boxes[:, 6] = -(gt_boxes[:, 6] + np.pi) # 方位角加pi后,取负数(这里按照顺时针的方向取角度)points[:, 0] = -points[:, 0] # 点云x坐标取反
通过设置ALONG_AXIS_LIST来进行xy轴的反转设置,如果确定反转,会在data_dict中对反转的轴进行保留。
3. random_world_rotation
对点云和gt进行随机旋转,具体的实现函数是在augmentor_utils中,其会调用common_utils.rotate_points_along_z函数来实现沿z轴进行点云场景的旋转。
对于点云的旋转来说,一般都是利用旋转角度构建成沿z轴旋转的旋转矩阵来与坐标进行相乘,随后将旋转后的点云坐标再与原始的反射特征强度进行拼接,拼接成原来的点云特征维度(dim=4)。在函数实现中,这里是实现对batch数据进行数据增强操作,batch内的每帧数据都会分配一个随机选择angle,然后构建成一个batch的旋转矩阵。
- 沿z轴进行旋转的核心代码:
cosa = torch.cos(angle)
sina = torch.sin(angle)
zeros = angle.new_zeros(points.shape[0]) # [0]
ones = angle.new_ones(points.shape[0]) # [1]
rot_matrix = torch.stack(( # 根据旋转角构造沿z轴旋转的旋转矩阵cosa, sina, zeros,-sina, cosa, zeros,zeros, zeros, ones
), dim=1).view(-1, 3, 3).float()
points_rot = torch.matmul(points[:, :, 0:3], rot_matrix) # 对点云坐标进行旋转 (B, N, 3)
points_rot = torch.cat((points_rot, points[:, :, 3:]), dim=-1)
此外,这里的旋转角度是从给点的阈值范围内进行均匀分布而产生,增加了随机性。随机产生的旋转角度也会记录在data_dict中。
4. random_world_scaling
对点云和gt进行随机缩放,具体实现函数是在augmentor_utils中。
具体的缩放操作比较简单,直接将点云坐标与gt坐标和尺寸与缩放因子进行相乘即可。随机产生的缩放大小也会记录在data_dict中。核心代码如下所示:
if scale_range[1] - scale_range[0] < 1e-3: # 如果缩放的尺度过小,则直接返回原来的box和点云return gt_boxes, points
noise_scale = np.random.uniform(scale_range[0], scale_range[1]) # 在缩放范围内随机产生缩放尺度
points[:, :3] *= noise_scale
gt_boxes[:, :6] *= noise_scale # [:, :6]表示xyz,dxdydz均进行缩放
在进行了以上多种数据增强方式后,相关的记录都会记录在data_dict字典中,其中的gt boxes和points点云特征会不断进行更新,处理后的键值对如下所示:
5. limit_period
由于在gt的方向中存在大于π的偏移,这里额外将方位角限制在[-pi, pi],方位角定义为与x轴的夹角(这里按照顺时针的方向取角度)
# 功能: 将方位角限制在[-pi, pi],方位角定义为与x轴的夹角(这里按照顺时针的方向取角度)
# common_utils.py
def limit_period(val, offset=0.5, period=np.pi):val, is_numpy = check_numpy_to_torch(val) # 格式转换,其中val表示的角度ans = val - torch.floor(val / period + offset) * period # 将方位角限制在[-pi, pi]return ans.numpy() if is_numpy else ans# 将方位角限制在[-pi,pi]
# data_augmentor.py
data_dict['gt_boxes'][:, 6] = common_utils.limit_period(data_dict['gt_boxes'][:, 6], offset=0.5, period=2 * np.pi
)
在保存限制方位角后的gt,此时的data_dict即更新完毕,即将送入下一个模块中。