当前位置: 首页 > news >正文

李宏毅2023机器学习HW15-Few-shot Classification

文章目录

  • Link
  • Task: Few-shot Classification
  • Baseline
    • Simple—transfer learning
    • Medium — FO-MAML
    • Strong — MAML

Link

Kaggle

Task: Few-shot Classification

The Omniglot dataset

  • background set: 30 alphabets
  • evaluation set: 20 alphabets
  • Problem setup: 5-way 1-shot classification

Omniglot数据集

  • 背景集:30个字母
  • 评估集:20个字母
  • 问题设置:5-way 1-shot分类
    Definition of support set and query set

Baseline

Simple—transfer learning

直接把sample code运行即可

  • traing:
    对随机选择的5个任务进行正常分类训练验证/测试
  • validation / testing:
    对五个 Support Images 进行微调,并对Query Images进行推理

Slover首先从训练集中选择5个任务,然后对选择的5个任务进行正常分类训练。在推理中,模型在支持集support set图像上微调inner_train_step步骤,然后在查询集Query Set图像上进行推理。
为了与元学习Slover保持一致,基本Slover具有与元学习Slover完全相同的输入输出格式

def BaseSolver(model,optimizer,x,n_way,k_shot,q_query,loss_fn,inner_train_step=1,inner_lr=0.4,train=True,return_labels=False,
):criterion, task_loss, task_acc = loss_fn, [], []labels = []for meta_batch in x:# Get datasupport_set = meta_batch[: n_way * k_shot]query_set = meta_batch[n_way * k_shot :]if train:""" training loop """# Use the support set to calculate losslabels = create_label(n_way, k_shot).to(device)logits = model.forward(support_set)loss = criterion(logits, labels)task_loss.append(loss)task_acc.append(calculate_accuracy(logits, labels))else:""" validation / testing loop """# First update model with support set images for `inner_train_step` stepsfast_weights = OrderedDict(model.named_parameters())for inner_step in range(inner_train_step):# Simply trainingtrain_label = create_label(n_way, k_shot).to(device)logits = model.functional_forward(support_set, fast_weights)loss = criterion(logits, train_label)grads = torch.autograd.grad(loss, fast_weights.values(), create_graph=True)# Perform SGDfast_weights = OrderedDict((name, param - inner_lr * grad)for ((name, param), grad) in zip(fast_weights.items(), grads))if not return_labels:""" validation """val_label = create_label(n_way, q_query).to(device)logits = model.functional_forward(query_set, fast_weights)loss = criterion(logits, val_label)task_loss.append(loss)task_acc.append(calculate_accuracy(logits, val_label))else:""" testing """logits = model.functional_forward(query_set, fast_weights)labels.extend(torch.argmax(logits, -1).cpu().numpy())if return_labels:return labelsbatch_loss = torch.stack(task_loss).mean()task_acc = np.mean(task_acc)if train:# Update modelmodel.train()optimizer.zero_grad()batch_loss.backward()optimizer.step()return batch_loss, task_acc

Medium — FO-MAML

FOMAML(First-Order MAML)是MAML(Model-Agnostic Meta-Learning)的一种简化版本。MAML是一种元学习算法,旨在通过训练模型使其能够在少量新数据上快速适应新任务。FOMAML通过忽略二阶导数来简化MAML的计算过程,从而提高计算效率。它在许多情况下表现良好,尤其是在计算资源有限的情况下。然而,它也可能在某些任务上表现不如完整的MAML。

MAML的核心思想是通过在多个任务上进行训练,使得模型能够在面对新任务时,只需少量数据就能快速收敛到一个好的参数配置。具体来说,MAML的训练过程包括两个层次的优化:

  • 内层优化(Inner Loop):在每个任务上进行少量的梯度更新,以适应该任务。

  • 外层优化(Outer Loop):在所有任务上进行梯度更新,以优化模型的初始参数,使得模型在面对新任务时能够快速适应。

""" Inner Loop Update """
grads = torch.autograd.grad(loss, fast_weights.values(), create_graph=False) # create_graph=False:这个参数表示在计算梯度时不创建计算图。在FOMAML中,我们只关心一阶导数,因此不需要创建计算图fast_weights = OrderedDict((name, param - inner_lr * grad)for ((name, param), grad) in zip(fast_weights.items(), grads))""" Outer Loop Update """# TODO: Finish the outer loop update# raise NotimplementedErrormeta_batch_loss.backward()optimizer.step()

Strong — MAML

""" Inner Loop Update """
grads = torch.autograd.grad(loss, fast_weights.values(), create_graph=True)fast_weights = OrderedDict((name, param - inner_lr * grad)for ((name, param), grad) in zip(fast_weights.items(), grads))""" Outer Loop Update """# TODO: Finish the outer loop update# raise NotimplementedErrormeta_batch_loss.backward()optimizer.step()

相关文章:

  • 北京网站建设多少钱?
  • 辽宁网页制作哪家好_网站建设
  • 高端品牌网站建设_汉中网站制作
  • Python3网络爬虫开发实战(17)爬虫的管理和部署(第一版)
  • 如何重置企业/媒体/组织/个体户类型管理员微信号
  • 吴恩达深度学习笔记:卷积神经网络(Foundations of Convolutional Neural Networks)2.3-2.4
  • 408选择题笔记|自用|随笔记录
  • Python3网络爬虫开发实战(15)Scrapy 框架的使用(第一版)
  • Dockerfile部署xxljob
  • kotlin——设计模式之责任链模式
  • Spring Boot 学习之路 -- 处理 HTTP 请求
  • 20240924软考架构-------软考191-195答案解析
  • 英飞凌TC3xx -- Bootstrap Loader分析
  • 基于FPGA+GPU异构平台的遥感图像切片解决方案
  • 汽车端到端自动驾驶系统的关键技术与发展趋势
  • EasyGBD国标GB28181设备端,支持GB28181-2016、GB28181-2022
  • 基于C#+SQL Server2008实现(CS界面)学生宿舍管理系统
  • 【Docker】深入理解 Docker Compose 文件:构建和管理多容器应用的指南
  • 【Leetcode】101. 对称二叉树
  • AzureCon上微软宣布了哪些容器相关的重磅消息
  • Codepen 每日精选(2018-3-25)
  • java第三方包学习之lombok
  • nodejs调试方法
  • Spark RDD学习: aggregate函数
  • web标准化(下)
  • 分享几个不错的工具
  • 机器学习 vs. 深度学习
  • 基于遗传算法的优化问题求解
  • 深度解析利用ES6进行Promise封装总结
  • 提升用户体验的利器——使用Vue-Occupy实现占位效果
  • 因为阿里,他们成了“杭漂”
  • 转载:[译] 内容加速黑科技趣谈
  • Salesforce和SAP Netweaver里数据库表的元数据设计
  • $.each()与$(selector).each()
  • (52)只出现一次的数字III
  • (done) 声音信号处理基础知识(2) (重点知识:pitch)(Sound Waveforms)
  • (el-Date-Picker)操作(不使用 ts):Element-plus 中 DatePicker 组件的使用及输出想要日期格式需求的解决过程
  • (附源码)ssm本科教学合格评估管理系统 毕业设计 180916
  • (附源码)ssm户外用品商城 毕业设计 112346
  • (附源码)计算机毕业设计ssm高校《大学语文》课程作业在线管理系统
  • (黑马C++)L06 重载与继承
  • (六)DockerCompose安装与配置
  • (十三)Flink SQL
  • (四)【Jmeter】 JMeter的界面布局与组件概述
  • (微服务实战)预付卡平台支付交易系统卡充值业务流程设计
  • (已解决)Bootstrap精美弹出框模态框modal,实现js向modal传递数据
  • .bat批处理(七):PC端从手机内复制文件到本地
  • .net core Redis 使用有序集合实现延迟队列
  • .NET MAUI Sqlite程序应用-数据库配置(一)
  • .NET 给NuGet包添加Readme
  • .NET/C# 项目如何优雅地设置条件编译符号?
  • .net6解除文件上传限制。Multipart body length limit 16384 exceeded
  • .net安装_还在用第三方安装.NET?Win10自带.NET3.5安装
  • .net开源工作流引擎ccflow表单数据返回值Pop分组模式和表格模式对比
  • .NET面试题解析(11)-SQL语言基础及数据库基本原理
  • .net用HTML开发怎么调试,如何使用ASP.NET MVC在调试中查看控制器生成的html?
  • /usr/bin/perl:bad interpreter:No such file or directory 的解决办法
  • @javax.ws.rs Webservice注解