深入浅出Pytorch函数——torch.nn.init.sparse_

news/2025/2/23 5:37:46

分类目录:《深入浅出Pytorch函数》总目录
相关文章:
· 深入浅出Pytorch函数——torch.nn.init.calculate_gain
· 深入浅出Pytorch函数——torch.nn.init.uniform_
· 深入浅出Pytorch函数——torch.nn.init.normal_
· 深入浅出Pytorch函数——torch.nn.init.constant_
· 深入浅出Pytorch函数——torch.nn.init.ones_
· 深入浅出Pytorch函数——torch.nn.init.zeros_
· 深入浅出Pytorch函数——torch.nn.init.eye_
· 深入浅出Pytorch函数——torch.nn.init.dirac_
· 深入浅出Pytorch函数——torch.nn.init.xavier_uniform_
· 深入浅出Pytorch函数——torch.nn.init.xavier_normal_
· 深入浅出Pytorch函数——torch.nn.init.kaiming_uniform_
· 深入浅出Pytorch函数——torch.nn.init.kaiming_normal_
· 深入浅出Pytorch函数——torch.nn.init.trunc_normal_
· 深入浅出Pytorch函数——torch.nn.init.orthogonal_
· 深入浅出Pytorch函数——torch.nn.init.sparse_


torch.nn.init模块中的所有函数都用于初始化神经网络参数,因此它们都在torc.no_grad()模式下运行,autograd不会将其考虑在内。

根据Martens, J等人在《Deep learning via Hessian-free optimization》中描述的方法,将2维的输入张量或变量当做稀疏矩阵填充,其中非零元素生成自 N ( 0 , std 2 ) N(0, \text{std}^2) N(0,std2)

语法

torch.nn.init.sparse_(tensor, sparsity, std=0.01)

参数

  • tensor:[Tensor] 一个 N N N维张量torch.Tensor
  • sparsity:每列中需要被设置成零的元素比例
  • std:用于生成非零值的正态分布的标准差

返回值

一个torch.Tensor且参数tensor也会更新

实例

w = torch.empty(3, 5)
nn.init.sparse_(w, sparsity=0.1)

函数实现

def sparse_(tensor, sparsity, std=0.01):
    r"""Fills the 2D input `Tensor` as a sparse matrix, where the
    non-zero elements will be drawn from the normal distribution
    :math:`\mathcal{N}(0, 0.01)`, as described in `Deep learning via
    Hessian-free optimization` - Martens, J. (2010).

    Args:
        tensor: an n-dimensional `torch.Tensor`
        sparsity: The fraction of elements in each column to be set to zero
        std: the standard deviation of the normal distribution used to generate
            the non-zero values

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.sparse_(w, sparsity=0.1)
    """
    if tensor.ndimension() != 2:
        raise ValueError("Only tensors with 2 dimensions are supported")

    rows, cols = tensor.shape
    num_zeros = int(math.ceil(sparsity * rows))

    with torch.no_grad():
        tensor.normal_(0, std)
        for col_idx in range(cols):
            row_indices = torch.randperm(rows)
            zero_indices = row_indices[:num_zeros]
            tensor[zero_indices, col_idx] = 0
    return tensor

http://www.niftyadmin.cn/n/4951951.html

相关文章

LeetCode 160.相交链表

文章目录 💡题目分析💡解题思路🚩步骤一:找尾节点🚩步骤二:判断尾节点是否相等🚩步骤三:找交点🍄思路1🍄思路2 🔔接口源码 题目链接👉…

【Redis】Redis中的布隆过滤器

【Redis】Redis中的布隆过滤器 前言 在实际开发中,会遇到很多要判断一个元素是否在某个集合中的业务场景,类似于垃圾邮件的识别,恶意IP地址的访问,缓存穿透等情况。类似于缓存穿透这种情况,有许多的解决方法&#xf…

LVS-DR模式以及其中ARP问题

目录 LVS_DR LVS_DR数据包流向分析 LVS-DR中ARP问题 问题一 问题二 解决ARP的两个问题的设置方法 LVS-DR特点 LVS-DR优缺点 优点 缺点 LVS-DR集群构建 1.配置负载调度器 2.部署共享存储 3.配置节点服务器 4.测试 LVS 群集 LVS_DR LVS_DR数据包流向分析 客户端…

高效使用ChatGPT之ChatGPT客户端

ChatGPT客户端,支持Mac, Windows, and Linux 下载地址见文章结尾 软件截图 Windows: Mac: 说明 chatgpt桌面版,相比于网页版的chatgpt,最大的特色是支持历史聊天对话记录导出,且支持三种格式:PNG、PDF、…

SpreadJS 16.2 and SpreadJS 16.1.5 Crack

SpreadJS 16.2 SpreadJS SpreadJS 是一个完整的企业 JavaScript 电子表格解决方案,用于创建财务报告和仪表板、预算和预测模型、科学、工程、医疗保健、教育、科学实验室笔记本和其他类似的 JavaScript 业务应用程序。利用高速计算引擎和 19 种语言的 500 多个 Exce…

Vulnhub系列靶机 Hackadmeic.RTB1

系列:Hackademic(此系列共2台) 难度:初级 信息收集 主机发现 netdiscover -r 192.168.80.0/24端口扫描 nmap -A -p- 192.168.80.143访问80端口 使用指纹识别插件查看是WordPress 根据首页显示的内容,点击target 点击…

Linux笔试题(4)

67、在局域网络内的某台主机用ping命令测试网络连接时发现网络内部的主机都可以连同,而不能与公网连通,问题可能是__C_ A.主机ip设置有误 B.没有设置连接局域网的网关 C.局域网的网关或主机的网关设置有误 D.局域网DNS服务器设置有误 解析:在局域网络内的某台主…

过来,我告诉你个秘密:送给程序员男友最好的礼物,快教你对象学习磁盘分区啦!小点声哈,别让其他人学会了!

[原文连接:来自给点知识](过来,我告诉你个秘密:送给程序员男友最好的礼物,快教你对象学习磁盘分区啦!小点声哈,别让其他人学会了!) 再唱不出那样的歌曲 听到都会红着脸躲避 虽然会经常忘了我依然爱着你 …