传知代码-基于多尺度动态卷积的图像分类

代码以及视频讲解

本文所涉及所有资源均在传知代码平台可获取

概述

在计算机视觉领域,图像分类是非常重要的任务之一。近年来,深度学习的兴起极大提升了图像分类的精度和效率。本文将介绍一种基于动态卷积网络(Dynamic Convolutional Networks)、多尺度特征融合网络(Multi-scale Feature Fusion Networks)和自适应损失函数(Adaptive Loss Functions)的智能图像分类模型,采用了PyTorch框架进行实现,并通过PyQt构建了简洁的用户图像分类界面。该模型能够处理多分类任务,并且提供了良好的可扩展性和轻量化设计,使其适用于多种不同的图像分类场景。

效果可视化

在这里插入图片描述

模型原理解读

动态卷积

传统卷积网络通常使用固定的卷积核,而动态卷积则是通过引入多个可学习的卷积核动态选择不同的卷积核进行操作。这样可以在不同的输入图像上实现不同的卷积操作,从而提高模型的表达能力。通过加入Attention模块,能对输入图像的不同特征进行加权处理,进一步增强了网络对特征的自适应能力。

常规的卷积层使用单个静态卷积核,应用于所有输入样本。而动态卷积层则通过注意力机制动态加权n个卷积核的线性组合,使得卷积操作依赖于输入样本。动态卷积操作可以定义为:
在这里插入图片描述

在这里插入图片描述

其中动态卷积的线性组合可以用这个图表示:
在这里插入图片描述

在 ODConv 中,对于卷积核 W i W_i Wi:

  1. α s i \alpha_{s_i} αsi 为 k × \times × k 空间位置的每个卷积参数(每个滤波器)分配不同的注意力标量;下图a
  2. α c i \alpha_{c_i} αci 为每个卷积滤波器 W i m W_i^m Wim c in c_{\text{in}} cin 个通道分配不同的注意力标量;下图b
  3. α f i \alpha_{f_i} αfi c out c_{\text{out}} cout 个卷积滤波器分配不同的注意力标量;下图c
  4. α w i \alpha_{w_i} αwi 为整个卷积核分配一个注意力标量。下图d

在下图中,展示了将这四种类型的注意力逐步乘以 n n n 个卷积核的过程。原则上,这四种类型的注意力是相互补充的,通过按位置、通道、滤波器和卷积核的顺序逐步将它们乘以卷积核 W i W_i Wi,使卷积操作在所有空间位置、所有输入通道、所有卷积核中都不同,针对输入 x x x 捕获丰富的上下文信息,从而提供性能保证。
在这里插入图片描述

原则上来讲,这四种类型的注意力是互补的,通过渐进式对卷积沿位置、通道、滤波器以及核等维度乘以不同的注意力将使得卷积操作对于输入存在各个维度的差异性,提供更好的性能以捕获丰富上下文信息。因此,ODCOnv可以大幅提升卷积的特征提取能力;更重要的是,采用更少卷积核的ODConv可以取得更优的性能。
代码实现

class ODConv2d(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1,
                 reduction=0.0625, kernel_num=4):
        super(ODConv2d, self).__init__()
        self.in_planes = in_planes
        self.out_planes = out_planes
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.kernel_num = kernel_num
        self.attention = Attention(in_planes, out_planes, kernel_size, groups=groups,
                                   reduction=reduction, kernel_num=kernel_num)
        self.weight = nn.Parameter(torch.randn(kernel_num, out_planes, in_planes//groups, kernel_size, kernel_size),
                                   requires_grad=True)
        self._initialize_weights()
 
        if self.kernel_size == 1 and self.kernel_num == 1:
            self._forward_impl = self._forward_impl_pw1x
        else:
            self._forward_impl = self._forward_impl_common
 
    def _initialize_weights(self):
        for i in range(self.kernel_num):
            nn.init.kaiming_normal_(self.weight[i], mode='fan_out', nonlinearity='relu')
 
    def update_temperature(self, temperature):
        self.attention.update_temperature(temperature)
 
    def _forward_impl_common(self, x):
        # Multiplying channel attention (or filter attention) to weights and feature maps are equivalent,
        # while we observe that when using the latter method the models will run faster with less gpu memory cost.
        channel_attention, filter_attention, spatial_attention, kernel_attention = self.attention(x)
        batch_size, in_planes, height, width = x.size()
        x = x * channel_attention
        x = x.reshape(1, -1, height, width)
        aggregate_weight = spatial_attention * kernel_attention * self.weight.unsqueeze(dim=0)
        aggregate_weight = torch.sum(aggregate_weight, dim=1).view(
            [-1, self.in_planes // self.groups, self.kernel_size, self.kernel_size])
        output = F.conv2d(x, weight=aggregate_weight, bias=None, stride=self.stride, padding=self.padding,
                          dilation=self.dilation, groups=self.groups * batch_size)
        output = output.view(batch_size, self.out_planes, output.size(-2), output.size(-1))
        output = output * filter_attention
        return output
 
    def _forward_impl_pw1x(self, x):
        channel_attention, filter_attention, spatial_attention, kernel_attention = self.attention(x)
        x = x * channel_attention
        output = F.conv2d(x, weight=self.weight.squeeze(dim=0), bias=None, stride=self.stride, padding=self.padding,
                          dilation=self.dilation, groups=self.groups)
        output = output * filter_attention
        return output
 
    def forward(self, x):
        return self._forward_impl(x)

多尺度特征融合网络

多尺度特征是指从图像中提取不同尺度、不同分辨率下的特征。这些特征可以捕捉图像中的局部细节信息(如纹理、边缘等)和全局结构信息(如物体形状和轮廓)。传统的卷积神经网络(CNN)一般通过逐层下采样提取深层特征,但在这个过程中,高层的语义信息虽然丰富,却丢失了低层的细节信息。多尺度特征融合通过结合不同层次的特征,弥补了这一不足。

在这里插入图片描述

如上图所示,在本文的网络设计中,多尺度特征融合通过以下几个步骤实现:

特征提取模块:模型通过不同的卷积核(例如3x3、5x5、7x7)对输入图像进行多层次的卷积操作,提取出不同尺度的特征。

特征拼接与加权融合:在融合阶段,来自不同卷积层的特征图会进行拼接或加权求和,确保网络能够根据不同的任务需求自适应地调整特征权重。例如,在分类任务中,局部信息可能对小物体的识别更有帮助,而全局信息则适用于大物体的分类。
代码实现

class MultiScaleFeatureFusion(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(MultiScaleFeatureFusion, self).__init__()
        self.conv1x1 = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.conv3x3 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        self.conv5x5 = nn.Conv2d(in_channels, out_channels, kernel_size=5, padding=2)
        self.conv7x7 = nn.Conv2d(in_channels, out_channels, kernel_size=7, padding=3)

    def forward(self, x):
        out1 = self.conv1x1(x)
        out2 = self.conv3x3(x)
        out3 = self.conv5x5(x)
        out4 = self.conv7x7(x)
        return out1 + out2 + out3 + out4  # 多尺度特征融合

自适应损失函数

在深度学习的图像分类任务中,损失函数的选择直接影响模型的训练效果。本文所设计的网络引入了自适应损失函数(Adaptive Loss Functions),这是提升分类性能的重要创新之一。传统的损失函数通常具有固定的形式和权重,不能根据数据分布和训练阶段的不同自动调整。而自适应损失函数通过动态调整损失权重和形式,能够更有效地优化模型,提升其对复杂问题的学习能力。
代码实现

class AdaptiveLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0, balance_factor=0.999):
        super(AdaptiveLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.balance_factor = balance_factor
    
    def forward(self, logits, targets):
        # 计算交叉熵损失
        ce_loss = F.cross_entropy(logits, targets, reduction='none')

模型整体结构

本文使用的模型整体结构如下图所示:
在这里插入图片描述

数据集简介

德国交通标志识别基准(GTSRB)包含43类交通标志,分为39,209个训练图像12,630个测试图像。图像具有不同的光线条件和丰富的背景。如下图所示:
在这里插入图片描述
在这里插入图片描述

实验结果

在经过动态卷积和多尺度特征提取以及自适应损失函数后在验证集上能够取得0.944的准确率。
在这里插入图片描述

其loss曲线和准确率曲线如下图所示:
在这里插入图片描述

并且本文与其他文章结果进行了比较:

模型准确率差异
ASSC[1]82.8%+11.6%
DAN[2]91.1%+3.3%
SRDA[3]93.6%+0.8%
OURS94.4%-

混淆矩阵结果

在这里插入图片描述

实现过程

版本

PyQt5                     5.15.11
seaborn                   0.13.2
torch                     2.4.0
PyQt5-Qt5                 5.15.2
numpy                     1.26.4
pandas                    1.5.0
  1. 首先对模型进行训练,保存最佳模型
python main.py
  1. 加载最佳模型进行可视化预测
python predict_gui.py

参考文献

[1] Haeusser, Philip, et al. “Associative domain adaptation.” Proceedings of the IEEE international conference on computer vision. 2017.
[2] Long, Mingsheng, et al. “Learning transferable features with deep adaptation networks.” International conference on machine learning. PMLR, 2015.
[3] Cai, Guanyu, et al. “Learning smooth representation for unsupervised domain adaptation.” IEEE Transactions on Neural Networks and Learning Systems 34.8 (2021): 4181-4195.
[4] Li, Chao, Aojun Zhou, and Anbang Yao. “Omni-dimensional dynamic convolution.” arXiv preprint arXiv:2209.07947 (2022).

源码下载

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/882325.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

计算机网络17——IM聊天系统——客户端核心处理类框架搭建

目的 拆开客户端和服务端,使用Qt实现客户端,VS实现服务端 Qt创建项目 Qt文件类型 .pro文件:配置文件,决定了哪些文件参与编译,怎样参与编译 .h .cpp .ui:画图文件 Qt编码方式 Qt使用utf-8作为编码方…

【delphi】正则判断windows完整合法文件名,包括路径

在 Delphi 中&#xff0c;可以使用正则表达式来检查 Windows 文件名称或路径是否合法。合法的文件名和路径要求符合以下几点&#xff1a; 禁止的字符&#xff1a;文件名和路径不能包含以下字符&#xff1a;<, >, :, ", /, \, |, ?, *。文件名不能以空格或点结束。…

idea多模块启动

文章目录 idea多模块启动2018版本的idea2019版本的idea idea多模块启动 2018版本的idea 1.首先看一下view> Tool Windows下有没有Run Dashboard 如果有&#xff0c;点击一下底部的窗口就会出现 如果不存在&#xff0c;执行下一步 2.查看自己项目的工作空间位置 点击 File&…

Java中的事件(动作监听-ActionListener)

&#xff08;一&#xff09;、ActionListener接口 ActionListener接口用于处理用户界面上的动作事件&#xff0c;例如&#xff1a;按钮点击、菜单选择等。实现ActionListener接口需要重写actionPerformed(ActionEvent e)方法&#xff0c;该方法会在动作发生时被调用。 &#…

Android WebView H5 Hybrid 混和开发

对于故乡&#xff0c;我忽然有了新的理解&#xff1a;人的故乡&#xff0c;并不止于一块特定的土地&#xff0c;而是一种辽阔无比的心情&#xff0c;不受空间和时间的限制&#xff1b;这心情一经唤起&#xff0c;就是你已经回到了故乡。——《记忆与印象》 前言 移动互联网发展…

Python | Leetcode Python题解之第415题字符串相加

题目&#xff1a; 题解&#xff1a; class Solution:def addStrings(self, num1: str, num2: str) -> str:res ""i, j, carry len(num1) - 1, len(num2) - 1, 0while i > 0 or j > 0:n1 int(num1[i]) if i > 0 else 0n2 int(num2[j]) if j > 0 e…

Dify创建自定义工具,调用ASP.NET Core WebAPI时的注意事项(出现错误:Reached maximum retries (3) for URL ...)

1、要配置Swagger using Microsoft.AspNetCore.Mvc; using Microsoft.OpenApi.Models;var builder WebApplication.CreateBuilder(args);builder.Services.AddCors(options > {options.AddPolicy("AllowSpecificOrigin",builder > builder.WithOrigins("…

SpringSecurity6.x整合手机短信登录授权

前言&#xff1a;如果没有看过我的这篇文章的Springboot3.x.x使用SpringSecurity6(一文包搞定)_springboot3整合springsecurity6-CSDN博客需要看下&#xff0c;大部分多是基于这篇文章的基础上实现的。 明确点我们的业务流程&#xff1a; 需要有一个发送短信的接口&#xff0…

【C++】10道经典面试题带你玩转二叉树

&#x1f984;个人主页:修修修也 &#x1f38f;所属专栏:C ⚙️操作环境:Leetcode/牛客网 目录 一.根据二叉树创建字符串 二.二叉树的层序遍历 三.二叉树的层序遍历 II 四.二叉树的最近公共祖先 五.二叉搜索树与双向链表 六.从前序与中序遍历序列构造二叉树 七.从中序与后序遍历…

基于yolov8的无人机检测系统python源码+onnx模型+评估指标曲线+精美GUI界面

【算法介绍】 基于YOLOv8的无人机检测系统是一项前沿技术&#xff0c;结合了YOLOv8深度学习模型的强大目标检测能力与无人机的灵活性。YOLOv8作为YOLO系列的最新版本&#xff0c;在检测精度和速度上均有显著提升&#xff0c;特别适用于复杂和高动态的场景。 该系统通过捕获实…

【QML 基础】QML ——描述性脚本语言,用于用户界面的编写

文章目录 1. QML 定义2. QML 1. QML 定义 &#x1f427; QML全称为Qt Meta-Object Language&#xff0c;QML是一种描述性的脚本语言&#xff0c;文件格式以.qml结尾。支持javascript形式的编程控制。QML是Qt推出的Qt Quick技术当中的一部分&#xff0c;Qt Quick是 Qt5中用户界…

C++笔记---set和map

1. 序列式容器与关联式容器 前面我们已经接触过STL中的部分容器如&#xff1a;string、vector、list、deque、array、forward_list等&#xff0c;这些容器统称为序列式容器&#xff0c;因为逻辑结构为线性序列的数据结构&#xff0c;两个位置存储的值之间一般没有紧密的关联关…

U盘格式化了怎么办?这4个工具能帮你恢复数据。

如果你思维U盘被格式化了&#xff0c;也不用太过担心&#xff0c;其实里面的数据并没有被删除&#xff0c;只是被标记为了可覆盖的状态。只要我们及时采取正确的数据恢复措施&#xff0c;就有很大的机会可以将数据找回。比如使用专业得的数据恢复软件&#xff0c;我也可以跟大家…

缓存的思考与总结

缓存的思考与总结 什么是缓存缓存命中率数据一致性旁路模式 Cache aside双写模式直写模式 write through异步写 Write Behind 旁路和双写 案例 新技术或中间的引入&#xff0c;一定是解决了亟待解决的问题或是显著提升了系统性能&#xff0c;并且这种改变所带来的增幅&#xff…

python新手的五个练习题

代码 # 1. 定义一个变量my_Number,将其设置为你的学号&#xff0c;然后输出到终端。 my_Number "20240001" # 假设你的学号是20240001 print("学号:", my_Number) # 2. 计算并输出到终端:两个数(例如3和5)的和、差、乘积和商。 num1 3 num2 5 print(&…

nodejs基于vue电子产品商城销售网站的设计与实现 _bugfu

目录 技术栈具体实现截图系统设计思路技术可行性nodejs类核心代码部分展示可行性论证研究方法解决的思路Express框架介绍源码获取/联系我 技术栈 该系统将采用B/S结构模式&#xff0c;开发软件有很多种可以用&#xff0c;本次开发用到的软件是vscode&#xff0c;用到的数据库是…

论文集搜索网站-dblp 详细使用方法

分享在dblp论文集中的两种论文搜索方式&#xff1a;关键字搜索&#xff0c;指定会议/期刊搜索。 关键字搜索 进入dblp官方网址dblp: computer science bibliography&#xff0c;直接在上方搜索栏&#xff0c;搜索关键字&#xff0c;底下会列出相关论文。 指定会议/期刊搜索 …

三菱FX5U PLC故障处理(各种出错的内容、原因及处理方法进行说明。)

对使用系统时发生的各种出错的内容、原因及处理方法进行说明。 故障排除的步骤 发生故障时&#xff0c;按以下顺序实施故障排除。 1.确认各模块是否正确安装或正确配线。 2、确认CPU模块的LED。 3.确认各智能功能模块的LED。(各模块的用户手册) 4、连接工程工具&#xff0c;启…

从数据仓库到数据中台再到数据飞轮:我了解的数据技术进化史

这里写目录标题 前言数据仓库&#xff1a;数据整合的起点数据中台&#xff1a;数据共享的桥梁数据飞轮&#xff1a;业务与数据的双向驱动结语 前言 在当今这个数据驱动的时代&#xff0c;企业发展离不开对数据的深度挖掘和高效利用。从最初的数据仓库&#xff0c;到后来的数据…

828华为云征文|华为Flexus云服务器搭建Cloudreve私人网盘

一、华为云 Flexus X 实例&#xff1a;开启高效云服务新篇&#x1f31f; 在云计算的广阔领域中&#xff0c;资源的灵活配置与卓越性能犹如璀璨星辰般闪耀。华为云 Flexus X 实例恰似一颗最为耀眼的新星&#xff0c;将云服务器技术推向了崭新的高度。 华为云 Flexus X 实例基于…