深入浅出Pytorch函数——torch.nn.init.kaiming_uniform_

news/2025/2/23 5:42:08

分类目录:《深入浅出Pytorch函数》总目录
相关文章:
· 深入浅出Pytorch函数——torch.nn.init.calculate_gain
· 深入浅出Pytorch函数——torch.nn.init.uniform_
· 深入浅出Pytorch函数——torch.nn.init.normal_
· 深入浅出Pytorch函数——torch.nn.init.constant_
· 深入浅出Pytorch函数——torch.nn.init.ones_
· 深入浅出Pytorch函数——torch.nn.init.zeros_
· 深入浅出Pytorch函数——torch.nn.init.eye_
· 深入浅出Pytorch函数——torch.nn.init.dirac_
· 深入浅出Pytorch函数——torch.nn.init.xavier_uniform_
· 深入浅出Pytorch函数——torch.nn.init.xavier_normal_
· 深入浅出Pytorch函数——torch.nn.init.kaiming_uniform_
· 深入浅出Pytorch函数——torch.nn.init.kaiming_normal_
· 深入浅出Pytorch函数——torch.nn.init.trunc_normal_
· 深入浅出Pytorch函数——torch.nn.init.orthogonal_
· 深入浅出Pytorch函数——torch.nn.init.sparse_


torch.nn.init模块中的所有函数都用于初始化神经网络参数,因此它们都在torc.no_grad()模式下运行,autograd不会将其考虑在内。

根据He, K等人于2015年在《Delving deep into rectifiers: Surpassing human-level performance on ImageNet classification》中描述的方法,用一个均匀分布生成值,填充输入的张量或变量。结果张量中的值采样自 U ( − bound , bound ) U(-\text{bound}, \text{bound}) U(bound,bound),其中:
bound = gain × 3 fan_mode \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}} bound=gain×fan_mode3

这种方法也被称为He initialisation。

语法

torch.nn.init.kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu')

参数

  • tensor:[Tensor] 一个 N N N维张量torch.Tensor
  • a:[float] 这层之后使用的rectifier的斜率系数(ReLU的默认值为0)
  • mode:[str] 可以为fan_infan_out。若为fan_in则保留前向传播时权值方差的量级,若为fan_out则保留反向传播时的量级,默认值为fan_in
  • nonlinearity:[str] 一个非线性函数,即一个nn.functional的名称,推荐使用relu或者leaky_relu,默认值为leaky_relu

返回值

一个torch.Tensor且参数tensor也会更新

实例

w = torch.empty(3, 5)
nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')

函数实现

def kaiming_uniform_(
    tensor: Tensor, a: float = 0, mode: str = 'fan_in', nonlinearity: str = 'leaky_relu'
):
    r"""Fills the input `Tensor` with values according to the method
    described in `Delving deep into rectifiers: Surpassing human-level
    performance on ImageNet classification` - He, K. et al. (2015), using a
    uniform distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{U}(-\text{bound}, \text{bound})` where

    .. math::
        \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}

    Also known as He initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        a: the negative slope of the rectifier used after this layer (only
            used with ``'leaky_relu'``)
        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
            backwards pass.
        nonlinearity: the non-linear function (`nn.functional` name),
            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
    """
    if torch.overrides.has_torch_function_variadic(tensor):
        return torch.overrides.handle_torch_function(
            kaiming_uniform_,
            (tensor,),
            tensor=tensor,
            a=a,
            mode=mode,
            nonlinearity=nonlinearity)

    if 0 in tensor.shape:
        warnings.warn("Initializing zero-element tensors is a no-op")
        return tensor
    fan = _calculate_correct_fan(tensor, mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    bound = math.sqrt(3.0) * std  # Calculate uniform bounds from standard deviation
    with torch.no_grad():
        return tensor.uniform_(-bound, bound)

http://www.niftyadmin.cn/n/4952157.html

相关文章

PySpark-核心编程

2. PySpark——RDD编程入门 文章目录 2. PySpark——RDD编程入门2.1 程序执行入口SparkContext对象2.2 RDD的创建2.2.1 并行化创建2.2.2 获取RDD分区数2.2.3 读取文件创建 2.3 RDD算子2.4 常用Transformation算子2.4.1 map算子2.4.2 flatMap算子2.4.3 reduceByKey算子2.4.4 Wor…

linux RabbitMQ-3.8.5 安装

软件版本操作系统CentOS Linux release 7.9.2009erlangerlang-23.0.2-1.el7.x86_64rabbitMQrabbitmq-server-3.8.5-1.el7 RabbitMQ的安装首先需要安装Erlang,因为它是基于Erlang的VM运行的。 RabbitMQ安装需要依赖:socat和logrotate,logrotate操作系统已经存在了&…

【机器学习】— 2 图神经网络GNN

一、说明 在本文中,我们探讨了图神经网络(GNN)在推荐系统中的潜力,强调了它们相对于传统矩阵完成方法的优势。GNN为利用图论来改进推荐系统提供了一个强大的框架。在本文中,我们将在推荐系统的背景下概述图论和图神经网…

基于STM32标准库智能风扇设计

目录 一,前言 二,系统方案选择 三,实体展示 工程分类 四,相关代码 PWM.c PWM.h AD.c AD.h 电机驱动程序 舵机驱动 一,前言 当今生活中,风扇已成为人们解暑的重要工具,然而使用风扇缓解…

Spring 整合RabbitMQ,笔记整理

1.创建生产者工程 spring-rabbitmq-producer 2.pom.xml添加依赖 <dependencies><dependency><groupId>org.springframework</groupId><artifactId>spring-context</artifactId><version>5.1.7.RELEASE</version></dep…

ansible(2)-- ansible常用模块

部署ansible&#xff1a;ansible&#xff08;1&#xff09;-- 部署ansible连接被控端_luo_guibin的博客-CSDN博客 目录 一、ansible常用模块 1.1 ping 1.2 command 1.3 raw 1.4 shell 1.5 script 1.6 copy 1.7 template 1.8 yum 11.0.1.13 主控端(ansible)11.0.1.12 被控端(k8s…

iptables防火墙(SNAT与DNAT)

目录 1 SNAT 1.1 SNAT原理与应用 1.2 SNAT工作原理 1.3 SNAT转换前提条件 2 SNAT示例 ​编辑 2.1 网关服务器配置 2.1.1 网关服务器配置网卡 2.1.2 开启SNAT命令 2.2 内网服务器端配置 2.3 外网服务器端配置 2.4 网卡服务器端添加规则 2.5 SNAT 测试 3 DNAT 3.1 网卡…

Seaborn数据可视化(一)

目录 1.seaborn简介 2.Seaborn绘图风格设置 21.参数说明&#xff1a; 2.2 示例&#xff1a; 1.seaborn简介 Seaborn是一个用于数据可视化的Python库&#xff0c;它是建立在Matplotlib之上的高级绘图库。Seaborn的目标是使绘图任务变得简单&#xff0c;同时产生美观且具有信…