• 作者:老汪软件技巧
  • 发表时间:2024-10-06 11:01
  • 浏览量:

前情提要

本文是传知代码平台中的相关前沿知识与技术的分享~

接下来我们即将进入一个全新的空间,对技术有一个全新的视角~

本文所涉及所有资源均在传知代码平台可获取

以下的内容一定会让你对AI 赋能时代有一个颠覆性的认识哦!!!

以下内容干货满满,跟上步伐吧~

本章重点一. 论文概述

本文复现论文 Revisiting Consistency Regularization for Deep Partial Label Learning[1] 提出的偏标记学习方法。程序基于Pytorch,会保存完整的训练日志,并生成损失变化图和准确度变化图。

偏标记学习(Partial Label Learning)是一个经典的弱监督问题。在偏标记学习中,每个样例的监督信息为一个包含多个标签的候选标签集合。目前的偏标记方法大多基于自监督或者对比学习范式,或多或少地会遇到低性能或低效率的问题。该论文基于一致性正则化的思想,改进基于自监督的偏标记学习方法。具体地,该论文所提出的方法设计了两个训练目标。其中第一个训练目标为最小化非候选标签的预测输出,第二个目标最大化不同视图的预测输出之间的一致性。

总的来说,该论文所提出的方法着眼于将模型对同一图像不同增强视图的预测输出对齐,以提升模型输出的可靠性和对标签的消歧能力,这一方法同样能给其他弱监督学习任务带来提升。

二. 算法原理

首先,论文所提出方法的第一项损失(监督损失)如下:

其中,当事件 A 为真时,I(A)= 1 否则 I(A)= 0,f(.)表示模型的输出概率。

然后,论文所提出方法的第二项损失(一致性损失)如下:

其在训练过程中通过所有增强视图预测结果的几何平均来更新标签分布:

由于数据增强的不稳定性,该论文通过叠加 K 个不同的增强视图的一致性损失来提升方法性能。

最后,考虑到训练初期模型的预测准确率较低,一致性损失的权重被设置为从零开始随着训练轮数的增加逐渐提高:

综上所述,模型的总损失函数如下:

三.核心逻辑

具体的核心逻辑如下所示:

def dpll_sup_loss(probs, partial_labels):
    loss = -torch.sum(torch.log(1 + 1e-6 - probs) * (1 - partial_labels), dim=-1)
    loss_avg = torch.mean(loss)
    return loss_avg
def dpll_cont_loss(logits, targets):
    logits_log = torch.log_softmax(logits, dim=-1)
    loss = F.kl_div(logits_log, targets, reduction='batchmean')
    return loss
def train():
    # main loops
    for epoch_id in range(total_epochs):
        # train
        model.train()
        for batch in train_dataloader:
            optimizer.zero_grad()
            ids = batch['ids']
            data1 = batch['data1'].to(device)
            data2 = batch['data2'].to(device)
            data3 = batch['data3'].to(device)
            partial_labels = batch['partial_labels'].to(device)
            targets = train_targets[ids].to(device)
            logits1 = model(data1)
            logits2 = model(data2)
            logits3 = model(data3)
            probs1 = F.softmax(logits1, dim=-1)
            # update targets
            with torch.no_grad():
                probs2 = F.softmax(logits2.detach(), dim=-1)
                probs3 = F.softmax(logits3.detach(), dim=-1)
                new_targets = torch.pow(probs1.detach() * probs2 * probs3, 1 / 3)
                new_targets = F.normalize(new_targets * partial_labels, p=1, dim=-1)
                train_targets[ids] = new_targets.cpu()
            # dynamic weight
            balancing_weight = max_weight * (epoch_id + 1) / max_weight_epoch
            balancing_weight = min(max_weight, balancing_weight)
			# supervised loss
            loss_sup = dpll_sup_loss(probs1, partial_labels)
            # consistency regularization loss
            loss_cont1 = dpll_cont_loss(logits1, targets)
            loss_cont2 = dpll_cont_loss(logits2, targets)
            loss_cont3 = dpll_cont_loss(logits3, targets)
            # all loss
            loss = loss_sup + balancing_weight * (loss_cont1 + loss_cont2 + loss_cont3)
            loss.backward()
            optimizer.step()
        if epoch_id in lr_decay_epochs:
            lr_scheduler.step()

四.效果演示

本文基于网络 Wide-ResNet[2] 和数据集 CIFAR-10[3] 进行实验,偏标记的随机翻转概率为0.1。当然,本文所提供的程序不仅仅提供了上述的实验设置,同时也可以直接基于CIFAR-100(100类图像分类数据集),SVHN(数字号牌识别数据集),Fashion-MNIST(时装识别数据集),Kuzushiji-MNIST(日本古草体识别数据集)进行实验。仅仅需要替换运行命令的对应部分即可(使用说明见下文)

总结

综上,我们基本了解了“一项全新的技术啦” :lollipop: ~~

恭喜你的内功又双叒叕得到了提高!!!

感谢你们的阅读:satisfied:

后续还会继续更新:heartbeat:,欢迎持续关注:pushpin:哟~

:dizzy:如果有错误❌,欢迎指正呀:dizzy:

:sparkles:如果觉得收获满满,可以点点赞支持一下哟~:sparkles: