深入浅出Pytorch函数——torch.nn.init.xavier_uniform_

news/2025/2/23 6:02:40

分类目录:《深入浅出Pytorch函数》总目录
相关文章:
· 深入浅出Pytorch函数——torch.nn.init.calculate_gain
· 深入浅出Pytorch函数——torch.nn.init.uniform_
· 深入浅出Pytorch函数——torch.nn.init.normal_
· 深入浅出Pytorch函数——torch.nn.init.constant_
· 深入浅出Pytorch函数——torch.nn.init.ones_
· 深入浅出Pytorch函数——torch.nn.init.zeros_
· 深入浅出Pytorch函数——torch.nn.init.eye_
· 深入浅出Pytorch函数——torch.nn.init.dirac_
· 深入浅出Pytorch函数——torch.nn.init.xavier_uniform_
· 深入浅出Pytorch函数——torch.nn.init.xavier_normal_
· 深入浅出Pytorch函数——torch.nn.init.kaiming_uniform_
· 深入浅出Pytorch函数——torch.nn.init.kaiming_normal_
· 深入浅出Pytorch函数——torch.nn.init.trunc_normal_
· 深入浅出Pytorch函数——torch.nn.init.orthogonal_
· 深入浅出Pytorch函数——torch.nn.init.sparse_


torch.nn.init模块中的所有函数都用于初始化神经网络参数,因此它们都在torc.no_grad()模式下运行,autograd不会将其考虑在内。

根据Glorot, X.和Bengio, Y.在《Understanding the difficulty of training deep feedforward neural networks》中描述的方法,用一个均匀分布生成值,填充输入的张量或变量。结果张量中的值采样自 U ( − a , a ) U(-a, a) U(a,a),其中:
a = gain × 6 fan_in + fan_put a=\text{gain}\times\sqrt{\frac{6}{\text{fan\_in}+\text{fan\_put}}} a=gain×fan_in+fan_put6

这种方法也被称为Glorot initialization。

语法

torch.nn.init.xavier_uniform_(tensor, gain=1)

参数

  • tensor:[Tensor] 一个 N N N维张量torch.Tensor
  • gain :[float] 可选的缩放因子

返回值

一个torch.Tensor且参数tensor也会更新

实例

w = torch.empty(3, 5)
nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu'))

函数实现

def xavier_uniform_(tensor: Tensor, gain: float = 1.) -> Tensor:
    r"""Fills the input `Tensor` with values according to the method
    described in `Understanding the difficulty of training deep feedforward
    neural networks` - Glorot, X. & Bengio, Y. (2010), using a uniform
    distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{U}(-a, a)` where

    .. math::
        a = \text{gain} \times \sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}}

    Also known as Glorot initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        gain: an optional scaling factor

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu'))
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
    a = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation

    return _no_grad_uniform_(tensor, -a, a)

http://www.niftyadmin.cn/n/4952504.html

相关文章

大数据课程K1——Spark概述

文章作者邮箱:yugongshiye@sina.cn 地址:广东惠州 ▲ 本章节目的 ⚪ 了解Spark的背景; ⚪ 了解Spark的特点; ⚪ 掌握Spark的生态系统模块、使用模式; ⚪ 掌握Spark的单机模式安装; 一、简介 1. 背景 Spark是UC Berkeley AMP lab (加州大学伯克利分校…

第十二章MyBatis动态SQL

if标签与where标签 if标签 test如果为true就会拼接查询条件&#xff0c;否则不会 当没有使用Param&#xff0c;test出现arg0/param1当使用Param&#xff0c;test为Param指定的值当使用Pojo&#xff0c;test为对象的属性名 select * from car where <if test"name!n…

如何快速制作解决方案PPT

如何快速制作解决方案PPT 理解客户的需求 在开始制作解决方案PPT之前&#xff0c;需要对客户的需求进行深入了解和分析。这包括客户需要解决的问题、目标、预算和时间限制等。 需求分析 客户需要解决的问题客户的目标预算限制时间限制 确定解决方案 基于客户的需求&#x…

redis 存储结构原理 2

咱们接着上一部分来进行分享&#xff0c;我们可以在如下地址下载 redis 的源码&#xff1a; https://redis.io/download 此处我下载的是 redis-6.2.5 版本的&#xff0c;xdm 可以直接下载上图中的 **redis-6.2.6 **版本&#xff0c; redis 中 hash 表的数据结构 redis hash …

unity 之Transform组件(汇总)

文章目录 理论指导结合例子 理论指导 当在Unity中处理3D场景中的游戏对象时&#xff0c;Transform 组件是至关重要的组件之一。它管理了游戏对象的位置、旋转和缩放&#xff0c;并提供了许多方法来操纵和操作这些属性。以下是关于Transform 组件的详细介绍&#xff1a; 位置&a…

tomcat结构目录有哪些?

bin 启动&#xff0c;关闭和其他脚本。这些 .sh文件&#xff08;对于Unix系统&#xff09;是这些.bat文件的功能副本&#xff08;对于 Windows系统&#xff09;。由于Win32命令行缺少某些功能&#xff0c;因此此处包含一些其他文件。 比如说&#xff1a;windows下启动tomcat用的…

SpringBoot复习:(55)在service类中的方法上加上@Transactional注解后,Spring底层是怎么生成代理对象的?

SpringBoot run方法代码如下&#xff1a; 可以看到它会调用refreshContext方法来刷新Spring容器&#xff0c;这个refreshContext方法最终会调用AbstractApplicationContext的refresh方法&#xff0c;代码如下 如上图&#xff0c;refresh方法最终会调用finisheBeanFactoryInit…

OpenCV基础知识(5)— 几何变换

前言&#xff1a;Hello大家好&#xff0c;我是小哥谈。OpenCV中的几何变换是指改变图像的几何结构&#xff0c;例如大小、角度和形状等&#xff0c;让图像呈现出缩放、翻转、旋转和透视效果。这些几何变换操作都涉及复杂、精密的计算。OpenCV将这些计算过程都封装成了非常灵活的…