当前位置: 首页 > news >正文

论文阅读笔记:《Dataset Condensation with Distribution Matching》

论文阅读笔记:《Dataset Condensation with Distribution Matching》

      • 1.解决了什么问题?(Motivation)
      • 2.关键方法与创新点(Key Method & Innovation)
      • 3.实验结果与贡献 (Experiments & Contributions)
      • 4.个人思考与启发
      • 主体代码
      • 算法逻辑总结

WACV23 github

在这里插入图片描述

核心思想一句话总结:

用少量可学习的合成图像,通过多组随机网络上的分布匹配(MMD),高效地“蒸馏”出与原始大数据集等价的训练集。


1.解决了什么问题?(Motivation)

训练大型数据集耗时且昂贵,现有“核心集”只能数据、“蒸馏”常需双层优化都各有局限。本工作旨在:

  • 用少量合成图像(每类几十到几百)
  • 保持模型在测试集上的性能
  • 且避免繁重的bi-level优化

2.关键方法与创新点(Key Method & Innovation)

  • 分布匹配视角:首次用最大均值差异(MMD)在特征空间对齐合成与真实数据分布,而非仅作子集选择或梯度匹配。
  • 随机网络嵌入:不用预训练模型,随机初始化多个同构网络ψθψ_θψθ作为多种“看法”,增强合成数据集的泛化。
  • 单层优化:只对合成图像本身求梯度、SGD更新网络权重固定,省去双层优化开销。
  • 可微分西雅姆增强 (DSA):对真实和合成样本做相同随机变换,提升分布估计稳定性。

3.实验结果与贡献 (Experiments & Contributions)

  • 在 CIFAR-10/100、TinyImageNet、ImageNet-1K 上:
    • 每类仅 10–50 张合成图即可训练出接近原始数据的模型精度(如 CIFAR-10 10 张时 ≈70%+)。
    • 合成速度比 Gradient Matching 提升 ∼45×。
  • 下游任务验证:
    • 持续学习:更小的记忆库即可保持准确率。
    • 神经架构搜索:用代理合成集显著加速搜索且不损失性能。
  • 开源代码与可视化结果:每隔若干迭代保存合成图像演化,便于直观对比。

4.个人思考与启发

  • 高效vs. 代表性:只匹配特征均值简单有效,但或许忽略高阶统计和类内多样性
  • 生成质量 vs. 训练效果:无需最求”图像好看“,只要”训练有用“;但在某些任务中是否要兼顾真实的视觉特征。

主体代码

''' initialize the synthetic data '''
image_syn = torch.randn(size=(num_classes*args.ipc, channel, im_size[0], im_size[1]), dtype=torch.float, requires_grad=True, device=args.device)
label_syn = torch.tensor([np.ones(args.ipc)*i for i in range(num_classes)], dtype=torch.long, requires_grad=False, deviceargs=.device).view(-1) # [0,0,0, 1,1,1, ..., 9,9,9]if args.init == 'real':print('initialize synthetic data from random real images')for c in range(num_classes):image_syn.data[c*args.ipc:(c+1)*args.ipc] = get_images(c, args.ipc).detach().dataelse:print('initialize synthetic data from random noise')''' training '''# 只更新image_synoptimizer_img = torch.optim.SGD([image_syn, ], lr=args.lr_img, momentum=0.5) # optimizer_img for synthetic dataoptimizer_img.zero_grad()print('%s training begins'%get_time())for it in range(args.Iteration+1):''' Evaluate synthetic data '''if it in eval_it_pool:for model_eval in model_eval_pool:print('-------------------------\nEvaluation\nmodel_train = %s, model_eval = %s, iteration = %d'%(args.model, model_eval, it))print('DSA augmentation strategy: \n', args.dsa_strategy)print('DSA augmentation parameters: \n', args.dsa_param.__dict__)accs = []for it_eval in range(args.num_eval):# 每次用新的随机初始化网络来评估合成集net_eval = get_network(model_eval, channel, num_classes, im_size).to(args.device) # get a random model# 深拷贝合成数据集image_syn_eval, label_syn_eval = copy.deepcopy(image_syn.detach()), copy.deepcopy(label_syn.detach()) # avoid any unaware modification# 测试与评估_, acc_train, acc_test = evaluate_synset(it_eval, net_eval, image_syn_eval, label_syn_eval, testloader, args)accs.append(acc_test)print('Evaluate %d random %s, mean = %.4f std = %.4f\n-------------------------'%(len(accs), model_eval, np.mean(accs), np.std(accs)))if it == args.Iteration: # record the final resultsaccs_all_exps[model_eval] += accs''' visualize and save '''save_name = os.path.join(args.save_path, 'vis_%s_%s_%s_%dipc_exp%d_iter%d.png'%(args.method, args.dataset, args.model, args.ipc, exp, it))image_syn_vis = copy.deepcopy(image_syn.detach().cpu())for ch in range(channel):image_syn_vis[:, ch] = image_syn_vis[:, ch]  * std[ch] + mean[ch]image_syn_vis[image_syn_vis<0] = 0.0image_syn_vis[image_syn_vis>1] = 1.0save_image(image_syn_vis, save_name, nrow=args.ipc) # Trying normalize = True/False may get better visual effects.''' Train synthetic data '''# --- 用当前合成数据计算损失并更新(核心:分布匹配) ---# 新的随机网络(视角embedding)net = get_network(args.model, channel, num_classes, im_size).to(args.device) # get a random modelnet.train()# 合成数据训练时冻结网络参数(只优化合成图像)for param in list(net.parameters()):param.requires_grad = False# 多GPU支持,如果使用了DataParallel,embed在module下面embed = net.module.embed if torch.cuda.device_count() > 1 else net.embed # for GPU parallelloss_avg = 0  # 记录各类 loss 平均(后面除法)''' update synthetic data '''# --- 计算合成图像和真实图像在embedding space 上的均值差(即MMD的简化版本)---if 'BN' not in args.model: # for ConvNet 没有batch norm的网络loss = torch.tensor(0.0).to(args.device)for c in range(num_classes):# 每类分别取真实图和合成图img_real = get_images(c, args.batch_real)img_syn = image_syn[c*args.ipc:(c+1)*args.ipc].reshape((args.ipc, channel, im_size[0], im_size[1]))# 可微分增强(DSA):对real/syn做同样的随机变换以稳定分布估计if args.dsa:seed = int(time.time() * 1000) % 100000  # 保证 real 和 syn 用同样的 seedimg_real = DiffAugment(img_real, args.dsa_strategy, seed=seed, param=args.dsa_param)img_syn = DiffAugment(img_syn, args.dsa_strategy, seed=seed, param=args.dsa_param)# 投影到embedding空间output_real = embed(img_real).detach() # 真实特征不方向传播output_syn = embed(img_syn)  # 合成特征是要参与梯度的# 均值匹配(特征均值之差平方和)loss += torch.sum((torch.mean(output_real, dim=0) - torch.mean(output_syn, dim=0))**2)else: # for ConvNetBN   BatchNorm 的 ConvNeimages_real_all = []images_syn_all = []loss = torch.tensor(0.0).to(args.device)for c in range(num_classes):img_real = get_images(c, args.batch_real)img_syn = image_syn[c*args.ipc:(c+1)*args.ipc].reshape((args.ipc, channel, im_size[0], im_size[1]))if args.dsa:seed = int(time.time() * 1000) % 100000img_real = DiffAugment(img_real, args.dsa_strategy, seed=seed, param=args.dsa_param)img_syn = DiffAugment(img_syn, args.dsa_strategy, seed=seed, param=args.dsa_param)images_real_all.append(img_real)images_syn_all.append(img_syn)# 把每类真实/合成拼成一个大 batch,送进 embedding 一次得到所有类的特征images_real_all = torch.cat(images_real_all, dim=0)images_syn_all = torch.cat(images_syn_all, dim=0)output_real = embed(images_real_all).detach()output_syn = embed(images_syn_all)# reshape 以便按类计算均值,再做平方差累加loss += torch.sum((torch.mean(output_real.reshape(num_classes, args.batch_real, -1), dim=1) - torch.mean(output_syn.reshape(num_classes, args.ipc, -1), dim=1))**2)# 梯度累积与更新 synthetic imagesoptimizer_img.zero_grad()loss.backward()optimizer_img.step()loss_avg += loss.item()loss_avg /= (num_classes)        # 梯度累积与更新 synthetic imagesif it%10 == 0:print('%s iter = %05d, loss = %.4f' % (get_time(), it, loss_avg))if it == args.Iteration: # only record the final resultsdata_save.append([copy.deepcopy(image_syn.detach().cpu()), copy.deepcopy(label_syn.detach().cpu())])torch.save({'data': data_save, 'accs_all_exps': accs_all_exps, }, os.path.join(args.save_path, 'res_%s_%s_%s_%dipc.pt'%(args.method, args.dataset, args.model, args.ipc)))

算法逻辑总结

  1. 准备:假设要蒸馏一个3类数据集,每类只想保留5张合成图。
  2. 多次”看法“:每次随机初始化一个小网络,把真实图和合成图都送进去提取特征。
  3. 测差异:对每个类别,计算真实图和和冲突在该网络特征空间的平均差距。
  4. 更新合成图:把所有类别的平均差距累加成一个损失,反向梯度作用到图像像素上,轻微调整它们,让下次”看“更像真实图。
  5. 重复:多次切换网络、多次迭代,合成图不断逼近真实数据的”分布“。
  6. 评估:在最终合成图训练几个随机网络,验证它们在测试集上的准确率,确认蒸馏效果
http://www.lryc.cn/news/609064.html

相关文章:

  • 学习方法论
  • React中的Hooks
  • 【深度学习新浪潮】谷歌新推出的AlphaEarth是款什么产品?
  • Unity_数据持久化_XML序列化与反序列化
  • 05.Redis 图形工具RDM
  • 6.1、Redis多级缓存原理和优化、Redis部分参数优化调整
  • Python篇--- Python 的加载、缓存、覆盖机制
  • 01.Redis 概述
  • 复现cacti的RCE
  • ELECTRICAL靶机
  • [硬件电路-145]:模拟电路 - 常见的电源芯片、用途、管脚定义
  • Spring+K8s+AI实战:3全栈开发指南
  • LeetCode Hot 100,快速学习,不断更
  • Steam饥荒联机版多人服务器搭建全解析 -- 阿里云Linux系统构建云服务器
  • kafka创建topic报错解决思路之一
  • 常⻅框架漏洞
  • 计算用户日活:从数据设计到可视化的全流程(高频场景题)
  • sqli-labs:Less-26关卡详细解析
  • C++-二叉树OJ题
  • 外设数据到昇腾310推理卡 之五 3403ATU
  • react native中markdown添加数学公式的支持
  • Java ++i 与 i++ 底层原理
  • 机器翻译入门:定义、发展简史与核心价值
  • [自动化Adapt] 录制引擎
  • MVCC:数据库事务隔离的 “时空魔法”
  • mysql管理
  • 【Linux系统】进程间通信:匿名管道
  • http://localhost:8080/photos/xxx.png的本地图片访问方案
  • 常见的框架漏洞(Thinkphp,spring,Shiro)
  • io_submit系统调用及示例