
TSPNet Code Analysis

An analysis of the official code for the paper "Realigning Confidence with Temporal Saliency Information for Point-Level Weakly-Supervised Temporal Action Localization".

Paper Walkthrough

Code Analysis

Start with the training process, which begins by running main:

    if __name__ == '__main__':
        exp = Exp()
        if exp.config.mode == 'eval':
            exp.test()
        else:
            exp.train()

First, an Exp instance is created:

    class Exp(object):
        def __init__(self, exp_type='THUMOS14'):
            self.config = self._get_config(exp_type)
            if self.config.seed != -1:
                self._setup_seed()
            self.device = self._get_device()

        def train(self):
            train_dataset, train_loader = self._get_data(subset='train')
            test_dataset, test_loader = self._get_data(subset='test')
            model = self._get_model().to(self.device)
            criterion = self._get_criterion()
            optimizer = self._get_optimizer(model)
            loader = iter(train_loader)
            for itr in tqdm(range(1, self.config.num_itr + 1), total=self.config.num_itr):
                if (itr - 1) % (len(train_loader) // self.config.batch_size) == 0:
                    loader = iter(train_loader)
                train_one_proposal_batch(model, self.device, loader, criterion, optimizer, self.config.batch_size)
                if itr % self.config.update_fre == 0:
                    update_label(dataset=train_dataset, dataloader=train_loader, model=model, device=self.device, up_threshold=self.config.up_threshold)
                if itr % 100 == 0:
                    test_proposal(self.config, model, self.device, test_loader, itr)

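As shown, Exp first loads the configuration and sets the random seed and device; main then calls train() unless the mode is 'eval'. Inside train(), the loop runs for num_itr iterations, updates the point labels every update_fre iterations, and evaluates on the test set every 100 iterations.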
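The iterator restart condition in the loop above can be tried in isolation. The following is a minimal sketch, not the official code: a plain list stands in for train_loader, and it assumes that train_one_proposal_batch draws batch_size samples from the iterator with next() on every iteration (that function's body is not shown in this post). The name run_iterations is hypothetical.

```python
def run_iterations(samples, num_itr, batch_size):
    """Draw `batch_size` samples per iteration, restarting the iterator
    with the same condition used in Exp.train()."""
    # Number of iterations before the iterator is re-created.
    # Note: this is 0 (and the modulo below fails) if len(samples) < batch_size.
    itrs_per_pass = len(samples) // batch_size
    loader = iter(samples)
    batches = []
    for itr in range(1, num_itr + 1):
        if (itr - 1) % itrs_per_pass == 0:
            loader = iter(samples)  # start a new pass over the data
        # Assumed behaviour of train_one_proposal_batch: pull samples one by one.
        batches.append([next(loader) for _ in range(batch_size)])
    return batches


batches = run_iterations(samples=list(range(7)), num_itr=5, batch_size=2)
print(batches)  # [[0, 1], [2, 3], [4, 5], [0, 1], [2, 3]]
```

Under this assumption, the floor division means a trailing partial batch (sample 6 here) is skipped in each pass unless the underlying loader shuffles.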
train() first calls self._get_data, which instantiates the dataset:

    def _get_data(self, subset):
        dataset = PTAL_Dataset(data_path=self.config.data_path,
                               subset=subset,
                               modality=self.config.modality,
                               num_classes=self.config.num_classes,
                               feature_fps=self.config.feature_fps,
                               soft_value=self.config.soft_value)

    class PTAL_Dataset(Dataset):
        def __init__(self,
                     data_path: str,
                     subset: str = 'test',
                     modality: str = 'both',
                     num_classes: int = 20,
                     feature_fps: int = 25,
                     soft_value: float = 0.4):
            self.data_path = data_path
            self.subset = subset
            self.modality = modality
            self.feature_fps = feature_fps
            self.dataset = self.data_path.split('/')[-1]
            self.cls_dict = json.load(open('./data/dataset_cls_dict.json', 'rb'))[self.dataset]
            self.num_classes = num_classes
            self.soft_value = soft_value

            # Load label files
            self.gt = json.load(open(os.path.join(self.data_path, 'gt.json'), 'rb'))
            self.p_label = pd.read_csv(os.path.join(self.data_path, 'train_df_ts_in_gt.csv')).groupby('video_id')
            self.fps_dict = json.load(open(os.path.join(self.data_path, 'fps.json'), 'rb'))
            self.delta_dict = {}

            # Get video names
            self.vid_names = self._get_vidname()

            # Get proposals
            self.proposals, \
            self.proposals_point, \
            self.proposals_center_label, \
            self.proposals_multi_flag, \
            self.proposals_point_id = self._get_proposals()
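One detail in the constructor is worth checking by hand: self.dataset is taken from the last path component of data_path and is then used as a key into dataset_cls_dict.json and in the proposal file name. A small sketch of that behaviour follows. The path './data/THUMOS14' is an assumed example (the post only shows the default exp_type 'THUMOS14'), and both helper names are hypothetical.

```python
import os


def dataset_name(data_path):
    # Same expression as in PTAL_Dataset.__init__
    return data_path.split('/')[-1]


def dataset_name_robust(data_path):
    # Hypothetical alternative that tolerates a trailing slash
    return os.path.basename(os.path.normpath(data_path))


print(dataset_name('./data/THUMOS14'))          # THUMOS14
print(repr(dataset_name('./data/THUMOS14/')))   # ''
print(dataset_name_robust('./data/THUMOS14/'))  # THUMOS14
```

With the original expression, a data_path that ends in a slash yields an empty dataset name, so the class-dictionary lookup would likely fail with a KeyError.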

The key function here is _get_proposals(), which both initializes the proposals and updates them later:

    def _get_proposals(self, delta_point_dict=None):
        """get proposals and generate the center labels from the original points or the updated saliency points"""
        history_points = []
        proposals_file = json.load(open(f'{self.data_path}/LAC_proposal_{self.dataset}_{self.subset}.json'))['results']
        proposals = {}
        proposals_point = {}
        proposals_center_label = {}
        proposals_multi_flag = {}
        proposals_point_id = {}
        proposals_mask = {}
        t_factor = self.feature_fps / 16.0
        act, bg, multi = 0, 0, 0
        for idx, name in enumerate(self.vi
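The listing is cut off at this point, so how the original code applies t_factor is not shown. The expression feature_fps / 16.0 suggests that one feature snippet covers 16 frames, which would make t_factor the number of snippets per second. A minimal sketch under that assumption (the function name seconds_to_snippet and the parameter frames_per_snippet are hypothetical):

```python
def seconds_to_snippet(t_seconds, feature_fps=25, frames_per_snippet=16):
    """Convert a timestamp in seconds to a (fractional) snippet index,
    assuming one feature snippet is extracted per `frames_per_snippet` frames."""
    t_factor = feature_fps / frames_per_snippet  # snippets per second
    return t_seconds * t_factor


print(seconds_to_snippet(1.0))   # 1.5625
print(seconds_to_snippet(16.0))  # 25.0
```

With the default feature_fps of 25, one second of video corresponds to 1.5625 snippets, so a proposal boundary given in seconds would be scaled by that factor before indexing into the feature sequence.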