【Transformer】detr之encoder逐行梳理(二)

every blog every motto: You can do more than you think.
https://blog.csdn.net/weixin_39190382?type=blog

0. 前言

detr之encoder逐行梳理

1. 整体

encoder由encoder layer构成

输入进encoder的特征shape:(hw,b,c),后文将给出说明

class Transformer(nn.Module):

    def __init__(self, d_model=512, nhead=8, num_encoder_layers=6,
                 num_decoder_layers=6, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False,
                 return_intermediate_dec=False):
        super().__init__()
        # encoder layer
        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        # encoder 部分
        self.encoder = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        ... # 略

    def forward(self, src, mask, query_embed, pos_embed):
        # flatten bxCxHxW to HWxbxC
        bs, c, h, w = src.shape
        # (b,c,h,w) ->(b,c,hw) -> (hw,b,c) 
        src = src.flatten(2).permute(2, 0, 1)
        # (b,c,h,w) ->(b,c,hw) -> (hw,b,c) 
        pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
        # (b,h,w) -> (b,hw)
        mask = mask.flatten(1)

        memory = self.encoder(src, src_key_padding_mask=mask, pos=pos_embed)

        ... # 略

2. 部分

2.1 get_clone

用于对指定的层进行复制

def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])

2.2 Encoder

串联多个layer,输出作为输入

20240422143301

class TransformerEncoder(nn.Module):

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        # 对指定的层进行复制
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src,
                mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        output = src

        for layer in self.layers:
            # 输出作为输入
            output = layer(output, src_mask=mask,
                           src_key_padding_mask=src_key_padding_mask, pos=pos)

        if self.norm is not None:
            output = self.norm(output)

        return output

2.3 EncoderLayer

结构:

最开始的输入是backone的输出,即,src,后续的输入是上一层的输出
20240422150225

其中forward包含forward_post和forward_pre两个函数,主要区别是最开始进行标准化还是最后进行标准化。

由于self.normalize_before默认是False,所以默认是forward_post ,如下方的局部代码所示


q 和 k = backbone输出特征图 + 位置编码

这里对query和key增加位置编码 是因为需要在图像特征中各个位置之间计算相似度/相关性, 而value作为原图像的特征 和 相关性矩阵加权,
从而得到各个位置结合了全局相关性(增强后)的特征表示,所以q 和 k这种计算需要+位置编码 而v代表原图像不需要加位置编码


其中注意力计算主要涉及到两个参数:

  • key_padding_mask: 这部分就是我们在backbone中获取的mask,
    记录backbone生成的特征图中哪些是原始图像pad的部分 这部分是没有意义的
    计算注意力会被填充为-inf,这样最终生成注意力经过softmax时输出就趋向于0,相当于忽略不计。

  • attn_mask: 是在Transformer中用来“防作弊”的,即遮住当前预测位置之后的位置,忽略这些位置,不计算与其相关的注意力权重
    在encoder中通常为None,不使用,因为要计算全局的相关性。 decoder中才使用

forward_post局部代码:

def with_pos_embed(self, tensor, pos: Optional[Tensor]):
    return tensor if pos is None else tensor + pos

def forward_post(self,
                    src,
                    src_mask: Optional[Tensor] = None,
                    src_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None):

    q = k = self.with_pos_embed(src, pos) # q,k都添加位置编码
    # 计算
    # key_padding_mask: 记录backbone生成的特征图中哪些是原始图像pad的部分 这部分是没有意义的
    #                   计算注意力会被填充为-inf,这样最终生成注意力经过softmax时输出就趋向于0,相当于忽略不计
    # attn_mask: 是在Transformer中用来“防作弊”的,即遮住当前预测位置之后的位置,忽略这些位置,不计算与其相关的注意力权重
    #            而在encoder中通常为None 不适用  decoder中才使用
    src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
                            key_padding_mask=src_key_padding_mask)[0]
    # 残差连接
    src = src + self.dropout1(src2)
    # 标准化
    src = self.norm1(src)
    # FFN
    src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
    # 残差连接
    src = src + self.dropout2(src2)
    # 最后进行标准化
    src = self.norm2(src) 
    
    return src

默认batch_first = False
20240422152811

所以输入的形式是(l,batch,d),即我们最开始看到的(hw,b,c)
20240422152750

输出两个值,第一个是计算结果,第二个是权重。只需要第一个所以上面用了[0]

20240422153424

EncoderLayer完整代码:

class TransformerEncoderLayer(nn.Module):

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self,
                     src,
                     src_mask: Optional[Tensor] = None,
                     src_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None):
        q = k = self.with_pos_embed(src, pos)
        src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src) # 最后进行标准化
        return src

    def forward_pre(self, src,
                    src_mask: Optional[Tensor] = None,
                    src_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None):
        
        src2 = self.norm1(src) # 最开始进行标准化
        q = k = self.with_pos_embed(src2, pos)
        src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src = src + self.dropout2(src2)
        return src

    def forward(self, src,
                src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        # 默认是False
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)

        return self.forward_post(src, src_mask, src_key_padding_mask, pos)

参考

  1. https://blog.csdn.net/weixin_39190382/article/details/137905915?spm=1001.2014.3001.5502
  2. https://hukai.blog.csdn.net/article/details/127616634

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/567307.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

Nacos分布式配置中心

<?xml version"1.0" encoding"UTF-8"?> <project xmlns"http://maven.apache.org/POM/4.0.0" xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance"xsi:schemaLocation"http://maven.apache.org/POM/4.0.0 https://…

常见的数据抽取工具对比

1.什么是ETL? ETL&#xff0c;是英文Extract-Transform-Load的缩写&#xff0c;用来描述将数据从来源端经过抽取&#xff08;extract&#xff09;、转换&#xff08;transform&#xff09;、加载&#xff08;load&#xff09;至目的端的过程&#xff0c;是数据仓库的生命线。 …

C#仿QQ抽屉式窗体的设计方法:创建特殊窗体

目录 1.WindowFromPoint函数 2.GetParent函数 3.实例 &#xff08;1&#xff09; 图片集合编辑器 &#xff08;2&#xff09;Form1.Designer.cs &#xff08;3&#xff09;Form1.cs 4.生成效果 QQ软件对于绝大多数的人来说再熟悉不过了&#xff0c;它以使用方便、界面美…

Scala 05 —— 函数式编程底层逻辑

Scala 05 —— 函数式编程底层逻辑 该文章来自2023/1/14的清华大学交叉信息学院助理教授——袁洋演讲。 文章目录 Scala 05 —— 函数式编程底层逻辑函数式编程假如...副作用是必须的&#xff1f;函数的定义函数是数据的函数&#xff0c;不是数字的函数如何把业务逻辑做成纯函…

多因素不同水平的正交表设计(并列法)

文章目录 一、问题提出二、举例说明 一、问题提出 参考高等教育课本《实验设计与数据处理》 很多时候&#xff0c;我们要考察的因素水平数不尽相同&#xff0c;这时候一般采用混合水平正交表或者对普通的正交表作修改&#xff0c;其中&#xff0c;混合水平正交表由于水平数不规…

JAVA程序设计-对象设计

无论是根据某马还是某谷的适配教程做项目时候,发现了大部分都是重复的crud,大部分只要做好笔记复习即可,但是却往往忘记了编码设计,所以这里开始复习编码设计,对象设计中,长期使用Mp的那一套导致就是Service Mapper,一套梭哈完了,这样很容易忘记基本功夫 POJO&#xff1a; 简单…

Java、Spring、Dubbo三者SPI机制原理与区别

Java、Spring、Dubbo三者SPI机制原理与区别 什么是SPI SPI全称为Service Provider Interface&#xff0c;是一种动态替换发现的机制&#xff0c;一种解耦非常优秀的思想&#xff0c;SPI可以很灵活的让接口和实现分离&#xff0c;让api提供者只提供接口&#xff0c;第三方来实…

刷题训练之二分查找

> 作者&#xff1a;დ旧言~ > 座右铭&#xff1a;松树千年终是朽&#xff0c;槿花一日自为荣。 > 目标&#xff1a;熟练掌握二分查找算法 > 毒鸡汤&#xff1a;学习&#xff0c;学习&#xff0c;再学习 ! 学&#xff0c;然后知不足。 > 专栏选自&#xff1a;刷题…

网卡技术解密:理解网卡背后的原理

✍✍在这个信息爆炸的时代&#xff0c;网卡承载着无数数据的流动&#xff0c;是我们日常生活和工作不可或缺的一部分。但是&#xff0c;您是否曾经好奇过&#xff0c;这些小小的硬件是如何在瞬息万变的网络世界中稳定地发挥作用的呢&#xff1f; 想象一下&#xff0c;每当我们…

2024中国内燃机展-北京汽车发动机零部件展

2024第二十三届中国国际内燃机与零部件展览会 由中国内燃机工业协会主办、中国机床专用技术设备有限公司、汽车工艺装备成套开发集团协办的2024中国国际内燃机及动力装备博览会&#xff08;简称“动博会”&#xff09;将于2024年10月11日-13日在亦创国际会展中心隆重举办。本届…

智能时代 | 合合信息Embedding模型荣获C-MTEB榜单第一

目录 前言 1. MTEB与C-MTEB 2. acge模型的优势 3. Embedding模型应用 4. 大模型发展的关键技术 结语 前言 随着人工智能的不断发展&#xff0c;大语言模型吸引着社会各界的广泛关注&#xff0c;支撑模型应用落地的Embedding模型成为业内的焦点&#xff0c;大模型的发展给…

Electron 30.0.0 发布,升级 Node 和 V8 引擎

近日&#xff0c;Electron 30.0.0 正式发布&#xff01;你可以通过 npm install electronlatest 进行安装&#xff0c;或者从 Electron 的发布网站下载&#xff0c;继续阅读了解此版本的详细信息。 &#x1f525; 主要更新 Windows 上支持 ASAR 完整性融合。如果未正确配置&am…

【后端】python与django的开发环境搭建指南

安装Git 双击Git 客户端安装文件&#xff0c;在安装页面&#xff0c;单击“Next” 在安装路径选择页面&#xff0c;保持默认&#xff0c;单击“Next” 在功能组件选择页面&#xff0c;保持默认&#xff0c;单击“Next” 在开始菜单文件夹设置页面&#xff0c;保持默认&am…

AI交互数字人对教育领域有何优势?

AI交互数字人不仅能够跨越物理距离的限制&#xff0c;以数字人形象为学生提供“面对面”教学互动体验&#xff0c;还能根据学生的具体需求提供个性化的知识解答。如天津大学推出了数字人老师&#xff0c;以刘艳丽教授形象1&#xff1a;1仿真打造的2.5D数字人&#xff0c;能够应…

png图片如何缩小体积?这个方法效果不错

图片压缩是我们生活中经常都会遇到的问题。在日常工作中图片体积过大的话&#xff0c;在使用过程中就会收到影响&#xff0c;比如加载过慢等。那么&#xff0c;当我们想要对png图片进行压缩处理的时候&#xff0c;要怎么操作呢&#xff1f;很简单&#xff0c;使用图片在线压缩&…

单链表逆置(头插法,递归,数据结构栈的应用)

链表逆置就是把最后一个数据提到最前面&#xff0c;倒数第二个放到第二个……依次类推&#xff0c;直到第一个到最后一个。 由于链表没有下标&#xff0c;所以不能借助下标来实行数据的逆置&#xff0c;要靠空间的转移来完成链表的逆置&#xff0c;这里采用没有头节点的链表来实…

Ansible安装基本原理及操作(初识)

作者主页&#xff1a;点击&#xff01; Ansible专栏&#xff1a;点击&#xff01; 创作时间&#xff1a;2024年4月23日15点18分 Ansible 是一款功能强大且易于使用的IT自动化工具&#xff0c;可用于配置管理、应用程序部署和云端管理。它使用无代理模式&#xff08;agentles…

学习笔记:Vue2高级篇

Vue2 学习笔记&#xff1a;Vue2基础篇_ljtxy.love的博客-CSDN博客学习笔记&#xff1a;Vue2中级篇_ljtxy.love的博客-CSDN博客学习笔记&#xff1a;Vue2高级篇_ljtxy.love的博客-CSDN博客 Vue3 学习笔记&#xff1a;Vue3_ljtxy.love的博客&#xff09;-CSDN博客 文章目录 7.…

STM32 HAL库F103系列之DAC实验(一)

DAC输出实验 原理图 DAC数据格式 DAC输出电压 DORX - 数据输出寄存器 Vref 3.3V 实验简要 1&#xff0c;功能描述 通过DAC1通道1(PA4)输出预设电压&#xff0c; 然后由ADC1通道1 (PA1) 采集&#xff0c;最后显示ADC转换的数字量及换算后的电压值 2&#xff0c;关闭通道1…

【已解决】三菱PLC与电脑通信步骤

前言 现场弄了一下一台三菱FX5U的PLC结果试了半天都没有连接上&#xff0c;后来琢磨了一下终于算是连接上了。报错的截图如下图所示&#xff1a; 解决步骤 第一步&#xff1a;先将自己电脑的IP地址设置到与PLC的IP地址在同一个网段下&#xff08;前三个是一样&#xff0c;最…
最新文章