Masked_Filled随机置列为零

news/2025/1/15 15:29:57 标签: 深度学习, pytorch, 人工智能

文章目录

  • 1. softmax
  • 2. python 方法

1. softmax

在计算损失函数的时候,我们需要将我们填充为0的地方概率置为0,以免参与损失计算,我们一般会将需要置为0的位置上面通过masked_filled函数将为True的位置置为一个非常小的值1e-9,这样经过F.softmax函数后,其值为0。这里用到两个函数,

  • 第一个是F.softmax,主要负责归一化处理,将值转换为0-1内,并且其和为1,转换成概率值。
  • 第二个是Masked_fill 函数,可以通过提供一个同等大小的BOOL矩阵,将为True的地方,填充为自己喜欢的值。
  • 第三个是填充的方式,在transformer中,我们把为0的位置的值填充为负无穷,这样经过为softmax后为零,但是transofrmer中填充的方式为在一个行向量中的末尾填充零,以行向量作为样本向量,列向量为特征向量,根据MIT麻神理工的思路,矩阵A以列向量表示更适合参数学习,所以我们希望通过随机掩码不同位置的列向量,这样通过学习样本的特征维来表示矩阵,所以我们引入一种列向量掩码方式。

2. python 方法

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.set_printoptions(precision=3, sci_mode=False)
torch.manual_seed(333512)

if __name__ == "__main__":
    run_code = 0
    row = 4
    column = 5
    scores = torch.randn(row, column)
    masked = torch.randint(0, 2, (1, column)).to(torch.bool)
    masked_scores = scores.masked_fill(masked, -1e9)
    scores_softmax = F.softmax(masked_scores, dim=-1)
    print(f"scores=\n{scores}")
    print(f"masked=\n{masked}")
    print(f"masked_scores=\n{masked_scores}")
    print(f"scores_softmax=\n{scores_softmax}")
  • 结果:
scores=
tensor([[-0.786,  1.136,  1.624,  0.417,  1.366],
        [-0.520, -0.127, -0.219, -0.489,  0.276],
        [-0.937, -0.734,  1.221, -0.305,  1.020],
        [ 2.252, -0.042, -1.098,  1.135, -0.075]])
masked=
tensor([[False,  True,  True, False,  True]])
masked_scores=
tensor([[    -0.786, -1000000000.000, -1000000000.000,      0.417, -1000000000.000],
        [    -0.520, -1000000000.000, -1000000000.000,     -0.489, -1000000000.000],
        [    -0.937, -1000000000.000, -1000000000.000,     -0.305, -1000000000.000],
        [     2.252, -1000000000.000, -1000000000.000,      1.135, -1000000000.000]])
scores_softmax=
tensor([[0.231, 0.000, 0.000, 0.769, 0.000],
        [0.492, 0.000, 0.000, 0.508, 0.000],
        [0.347, 0.000, 0.000, 0.653, 0.000],
        [0.754, 0.000, 0.000, 0.246, 0.000]])

http://www.niftyadmin.cn/n/5824191.html

相关文章

API接口技术开发小红书笔记详情api采集笔记图片视频参数解析

小红书笔记详情 API 采集笔记图片视频参数解析如下: 获取 API 访问权限 注册账号:填写用户名、邮箱、密码等必要信息完成注册,并登录进入开发者控制台。创建应用并申请接口权限:填写应用名称、描述、应用类型等信息并提交审核。审…

大语言模型训练的基本步骤解析

一、引言 大语言模型(LLMs)在当今人工智能领域取得了令人瞩目的成就,从智能聊天机器人到文本生成、语言翻译等广泛应用,深刻改变着人们与信息交互的方式。这些模型展现出强大的语言理解和生成能力背后,是一套复杂而精妙…

STM32从零开始深入学习

STM32项目创建 1.新建项目文件夹1.1 Drivers1.2 Middlewares1.3 Output1.4 Projects1.5 User 2.新建项目工程2.1 项目新建2.2 项目文件添加 3.魔术棒设置3.1 Target3.2 Output3.3 Listing3.4 C/C3.5 Debug 4.下载调试4.1 主程序写入4.2 编译和下载烧录 1.新建项目文件夹 新建一…

重新定义数据分析:LLM如何让人专注真正的思考

重新定义数据分析:LLM如何让人专注真正的思考 LLM重塑智能数据分析:从DIKW到智能Agent的演进智能数据分析的技术突破智能数据分析的未来图景 还记得第一次用Excel做数据分析的场景吗?选数据、找公式、画图表…每一步都像在破解密码。 现在&am…

python项目结构,PyCharm 调试Debug模式配置

经常使用java开发转到python项目有些差异。在 Python 中,项目的组织结构和 Java 有一些不同。Java 在创建项目时通常会先定义包(package),然后在包下创建源代码文件(.java)。而在 Python 中,虽然…

【大厂面试AI算法题中的知识点】方向涉及:ML/DL/CV/NLP/大数据...本篇介绍自动驾驶检测模型如何针对corner case 优化?

【大厂面试AI算法题中的知识点】方向涉及:ML/DL/CV/NLP/大数据…本篇介绍自动驾驶检测模型如何针对corner case 优化? 【大厂面试AI算法题中的知识点】方向涉及:ML/DL/CV/NLP/大数据…本篇介绍自动驾驶检测模型如何针对corner case 优化&…

小游戏前端地区获取

目前前端获取除了太平洋,没有其它的了。 //在JS中都是使用的UTF-8,然而requst请求后显示GBK却是乱码,对传入的GBK字符串,要用数据流接收,responseType: "arraybuffer" tt.request({url: "https://whoi…

常见的php框架有哪几个?

一直以来,PHP作为一种广泛使用的编程语言,拥有着许多优秀的框架来帮助开发人员快速构建稳定的Web应用程序。本文降为大家介绍几种常见的PHP的主流框架,以及它们相关的特点和使用场景。如有问题,欢迎指正! 1.Laravel&a…