4.1 Tensor 与自动微分

Tensor 和 autograd 是 PyTorch 的两块基石——前者决定了”数据怎么存、怎么算”,后者决定了”梯度怎么来”。本文从 Tensor 的创建、索引、变形、内存布局讲起,再深入 autograd 的计算图机制、梯度累积与常见踩坑,帮你建立扎实的底层认知,为后续的模型搭建、分布式训练和性能优化打好地基。

📑 目录


1. Tensor:一切计算的载体

打个比方:如果深度学习是盖房子,Tensor 就是砖头。模型的权重是砖头,输入数据是砖头,中间的计算结果也是砖头——PyTorch 里一切皆 Tensor。

正式定义:Tensor(张量)是一种多维数组数据结构,可以视为 NumPy ndarray 的超集——不仅支持 GPU 加速运算,还能接入自动微分引擎,是深度学习计算的基本单位。

1.1 创建 Tensor 的常用方式

PyTorch 提供了多种创建 Tensor 的方法,可以分为三大类:

从已有数据创建

1
2
3
4
5
6
7
8
9
10
11
12
13
14
import torch
import numpy as np

# 从 Python 列表
a = torch.tensor([1.0, 2.0, 3.0])

# 从 NumPy(共享底层内存,修改一个会影响另一个)
np_arr = np.array([[1, 2], [3, 4]], dtype=np.float32)
b = torch.from_numpy(np_arr)
np_arr[0, 0] = 99
print(b[0, 0]) # tensor(99.) ← 内存共享

# 从 NumPy 创建独立副本(不共享内存)
c = torch.tensor(np_arr) # tensor() 总是拷贝数据

用固定值/分布创建

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import torch

zeros = torch.zeros(3, 4) # 3×4 全零
ones = torch.ones(2, 3) # 2×3 全一
full = torch.full((2, 3), fill_value=7) # 2×3 全部填 7
eye = torch.eye(3) # 3×3 单位矩阵

rand_uniform = torch.rand(3, 4) # [0, 1) 均匀分布
rand_normal = torch.randn(3, 4) # 标准正态分布 N(0, 1)
rand_int = torch.randint(0, 10, (3, 4)) # [0, 10) 整数均匀分布

# 序列
seq = torch.arange(0, 10, 2) # tensor([0, 2, 4, 6, 8])
lin = torch.linspace(0, 1, 5) # tensor([0.00, 0.25, 0.50, 0.75, 1.00])

# 未初始化(只分配内存,不赋值,速度最快但值为内存残留的垃圾数据)
empty = torch.empty(3, 4)

按已有 Tensor 的形状创建

1
2
3
4
5
import torch

x = torch.randn(3, 4)
y = torch.zeros_like(x) # 形状、dtype、device 都与 x 一致
z = torch.randn_like(x) # 同形状的正态随机数

1.2 Tensor 的核心属性

每个 Tensor 都有四个核心属性,任何时候拿到一个 Tensor,先看这几样:

属性 含义 示例
shape / size() 各维度的大小 torch.Size([2, 3])
dtype 数据类型 torch.float32
device 所在设备 device(type='cuda', index=0)
requires_grad 是否追踪梯度 True / False
1
2
3
4
5
6
7
8
import torch

x = torch.randn(2, 3, 4, device='cuda' if torch.cuda.is_available() else 'cpu',
requires_grad=True)
print(x.shape) # torch.Size([2, 3, 4]) ← 形状
print(x.dtype) # torch.float32 ← 数据类型
print(x.device) # cuda:0 或 cpu ← 所在设备
print(x.requires_grad) # True ← 是否追踪梯度

💡 提示:在调试模型时,90% 的错误可以通过检查 shape、dtype、device 这三样来定位。形状不匹配、dtype 不一致、设备不同——这是 PyTorch 开发中最常见的三类报错。

1.3 基本运算

Tensor 支持完整的数学运算,和 NumPy 的用法高度一致:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([4.0, 5.0, 6.0])

# 逐元素运算
print(a + b) # tensor([5., 7., 9.])
print(a * b) # tensor([4., 10., 18.]) 逐元素乘
print(a ** 2) # tensor([1., 4., 9.])

# 矩阵乘法(AI 中最核心的运算)
A = torch.randn(2, 3)
B = torch.randn(3, 4)
C = A @ B # (2, 3) × (3, 4) → (2, 4)
C = torch.matmul(A, B) # 等价写法

# 聚合运算
x = torch.randn(3, 4)
print(x.sum()) # 所有元素求和 → 标量
print(x.mean(dim=1)) # 沿 dim=1 求均值 → (3,)
print(x.max(dim=0)) # 沿 dim=0 取最大值 → (4,),同时返回索引

AI Infra 视角:深度学习中绝大部分计算归结为矩阵乘法(GEMM)。理解 Tensor 运算的本质,能帮你在后续 CUDA 编程和算子优化中更精准地定位计算瓶颈。


2. 索引与切片

Tensor 的索引方式和 NumPy 几乎相同,但在 AI 场景下有几种高频用法值得单独掌握。

2.1 基础索引

1
2
3
4
5
6
7
8
9
10
11
import torch

x = torch.arange(12).reshape(3, 4)
# tensor([[ 0, 1, 2, 3],
# [ 4, 5, 6, 7],
# [ 8, 9, 10, 11]])

print(x[0]) # tensor([0, 1, 2, 3]) 第 0 行
print(x[1, 2]) # tensor(6) 第 1 行第 2 列
print(x[:, 1]) # tensor([1, 5, 9]) 所有行的第 1 列
print(x[0:2, 1:3]) # tensor([[1, 2], [5, 6]]) 切片

2.2 布尔索引与花式索引

布尔索引在数据过滤和掩码操作中极为常用——Attention 的 mask 本质上就是布尔索引的一种应用。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
import torch

x = torch.randn(4, 5)
mask = x > 0
print(x[mask]) # 取出所有正数(返回一维 Tensor)

# 花式索引:用索引列表同时取多个位置
indices = torch.tensor([0, 2, 3])
print(x[indices]) # 取出第 0、2、3 行

# gather:按指定维度收集元素(Softmax 取 top-k 时常用)
scores = torch.randn(2, 5) # (batch, vocab_size)
top_indices = scores.topk(3, dim=1).indices # (batch, 3)
selected = scores.gather(1, top_indices) # (batch, 3)

2.3 索引与内存:视图 vs 副本

⚠️ 注意:基础索引(切片)返回的是视图(view),和原 Tensor 共享内存;高级索引(花式索引、布尔索引)返回的是副本(copy)。这个区别在处理大 Tensor 时直接影响显存占用。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
import torch

x = torch.arange(6).reshape(2, 3)

# 切片 → 视图(共享内存)
y = x[0]
y[0] = 99
print(x[0, 0]) # tensor(99) ← x 也被改了

# 花式索引 → 副本(独立内存)
x = torch.arange(6).reshape(2, 3)
z = x[[0, 1]]
z[0, 0] = 99
print(x[0, 0]) # tensor(0) ← x 没有变

3. 形状变换

形状变换是 Tensor 操作中最频繁的动作之一。在 Transformer 的实现里,几乎每一步都伴随着 view、permute、reshape——搞不清楚它们的区别,读模型代码就像看天书。

3.1 view 与 reshape:看似相同,内在有别

打个比方:一本 12 页的书,你可以说它是”3 章 × 4 页”,也可以说是”4 章 × 3 页”——内容(数据)没变,只是目录(形状)变了。viewreshape 做的就是这件事。

1
2
3
4
5
6
7
import torch

x = torch.arange(12) # 一维,12 个元素

a = x.view(3, 4) # 改成 3×4
b = x.reshape(3, 4) # 看似等价
c = x.view(-1, 4) # -1 表示自动推断(这里推断为 3)

区别在哪里?关键词是内存连续(contiguous):

  • view 要求底层内存连续,如果不连续会直接报错
  • reshape 在内存连续时等价于 view,不连续时会自动拷贝一份连续内存再变形
1
2
3
4
5
6
7
8
9
import torch

x = torch.arange(12).reshape(3, 4)
y = x.t() # 转置后内存不连续
print(y.is_contiguous()) # False

# y.view(12) # 报错!view 要求连续
z = y.reshape(12) # 正常,reshape 内部做了 contiguous
z = y.contiguous().view(12) # 等价的显式写法

💡 提示:日常开发中,如果你不确定内存是否连续,用 reshape 更安全。但在性能敏感的场景下(比如写 CUDA 内核),需要明确知道 view 不会触发内存拷贝。

3.2 permute 与 transpose:维度换位

permute 是 Transformer 代码里的常客——Multi-Head Attention 需要在 (batch, seq, heads, dim)(batch, heads, seq, dim) 之间反复切换。

1
2
3
4
5
6
7
8
9
10
11
import torch

# Multi-Head Attention 中的典型操作
x = torch.randn(2, 128, 8, 64) # (batch, seq_len, num_heads, head_dim)
x = x.permute(0, 2, 1, 3) # (batch, num_heads, seq_len, head_dim)
# 现在每个 head 有自己独立的 (seq_len, head_dim) 矩阵,可以并行做 Attention

# transpose 只能交换两个维度
y = torch.randn(3, 4)
print(y.transpose(0, 1).shape) # torch.Size([4, 3])
print(y.t().shape) # 等价简写(仅限 2D)

⚠️ 注意permutetranspose 都不拷贝数据,只是改变了 stride(步长),所以返回的 Tensor 通常不是 contiguous 的。如果后续操作要求连续内存,需要显式调用 .contiguous()

3.3 多头注意力中的典型变形

在 Transformer 实现中,最经典的变形操作就是将线性投影结果拆分成多个注意力头。这个 reshape → permute 的组合在 Transformer 代码中几乎无处不在:

1
2
3
4
5
6
7
8
9
10
11
12
13
import torch

batch_size, seq_len, d_model = 4, 128, 512
num_heads, head_dim = 8, 64 # d_model = num_heads × head_dim

# 线性投影输出:(batch, seq_len, d_model)
proj = torch.randn(batch_size, seq_len, d_model)

# 拆分成多头:(batch, seq_len, d_model) → (batch, seq_len, num_heads, head_dim)
# → permute → (batch, num_heads, seq_len, head_dim)
multi_head = proj.reshape(batch_size, seq_len, num_heads, head_dim).permute(0, 2, 1, 3)
print(multi_head.shape) # torch.Size([4, 8, 128, 64])
# 现在每个 head 独立拥有 (seq_len, head_dim) 的矩阵,可以并行做 Attention

3.4 squeeze / unsqueeze / expand

这三兄弟负责维度的”增删扩”:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import torch

# squeeze:去掉大小为 1 的维度
x = torch.randn(1, 3, 1, 4)
print(x.squeeze().shape) # torch.Size([3, 4]) 去掉所有 1 维
print(x.squeeze(0).shape) # torch.Size([3, 1, 4]) 只去掉第 0 维

# unsqueeze:在指定位置插入一个大小为 1 的维度
y = torch.randn(3, 4)
print(y.unsqueeze(0).shape) # torch.Size([1, 3, 4]) 加 batch 维
print(y.unsqueeze(-1).shape) # torch.Size([3, 4, 1]) 加末尾维

# expand:沿大小为 1 的维度"广播"到指定大小(不拷贝内存)
z = torch.randn(1, 4)
print(z.expand(3, 4).shape) # torch.Size([3, 4])
# expand 不分配新内存,只是改变 stride 让同一行数据被"复用"三次

3.5 contiguous 与 stride:理解内存布局

这部分稍微底层一些,但对于 AI Infra 工程师来说很有价值——理解 stride 能帮你判断哪些操作是”零拷贝”的,哪些会偷偷分配新内存。

打个比方:想象一本书的页码。正常翻页时,每翻一页,物理位置跳 1(stride=1)。如果你把书拆成左右两栏来读,”翻一栏”跳的物理位置就不再是 1 了——这就是 stride 的本质:从一个元素到下一个元素,需要在内存中跳多少步

1
2
3
4
5
6
7
8
9
import torch

x = torch.arange(12).reshape(3, 4)
print(x.stride()) # (4, 1) ← 行方向跳 4 个元素,列方向跳 1 个元素
print(x.is_contiguous()) # True

y = x.t() # 转置
print(y.stride()) # (1, 4) ← 行方向跳 1,列方向跳 4
print(y.is_contiguous()) # False ← 行内元素在内存中不连续

📌 关键点viewpermutetransposeexpand 等操作只改变 shape 和 stride,不会拷贝数据。只有当后续操作需要连续内存(如 view)或者你显式调用 .contiguous() 时,才会触发真正的内存拷贝。


4. 广播机制

广播(Broadcasting)是 NumPy 和 PyTorch 都遵循的规则——当两个形状不完全匹配的 Tensor 做运算时,自动将较小的 Tensor “扩展”到兼容的形状。

4.1 广播规则

规则很简单,从右往左逐维比较:

  1. 如果维度大小相同,直接对齐
  2. 如果其中一个大小为 1,将其扩展到另一个的大小
  3. 如果维度个数不同,在较少维度的 Tensor 前面补 1
1
2
3
4
5
6
7
8
9
10
11
12
13
import torch

a = torch.randn(3, 4) # (3, 4)
b = torch.randn(4) # (4,) → 自动视为 (1, 4) → 广播为 (3, 4)
c = a + b # (3, 4) ✅

a = torch.randn(3, 4) # (3, 4)
b = torch.randn(3, 1) # (3, 1) → 广播为 (3, 4)
c = a * b # (3, 4) ✅

a = torch.randn(3, 4) # (3, 4)
b = torch.randn(3) # (3,) → (1, 3) → 无法广播到 (3, 4)
# c = a + b # ❌ 报错!最右维 4 ≠ 3 且都不是 1

4.2 实际应用场景

广播在深度学习中无处不在。Attention 分数加 mask、BatchNorm 减均值、偏置相加——这些都是广播操作:

1
2
3
4
5
6
7
8
9
10
11
12
import torch

# 场景 1:给 Attention 分数矩阵加 mask
scores = torch.randn(2, 8, 128, 128) # (batch, heads, seq, seq)
mask = torch.zeros(1, 1, 128, 128) # (1, 1, seq, seq) → 广播到全部 batch 和 head
masked_scores = scores + mask

# 场景 2:对 batch 数据做标准化
x = torch.randn(64, 768) # (batch, features)
mean = x.mean(dim=0, keepdim=True) # (1, 768)
std = x.std(dim=0, keepdim=True) # (1, 768)
x_normed = (x - mean) / (std + 1e-6) # 广播:(64, 768) 和 (1, 768)

⚠️ 注意:广播不分配新内存,但它可能会掩盖形状错误。如果你不小心传了一个形状略有偏差的 Tensor,广播不会报错,而是默默算出一个形状更大的结果——这种 bug 极其隐蔽。养成习惯:关键运算前 assert 一下形状。


5. 设备管理与数据类型

5.1 CPU ↔ GPU 搬运

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import torch

# 检查 GPU 是否可用
print(torch.cuda.is_available())
print(torch.cuda.device_count()) # GPU 数量

# 推荐写法:用 device 对象统一管理
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# 三种搬运方式
x = torch.randn(3, 4)
x_gpu = x.to(device) # 推荐:通用写法
x_gpu = x.cuda() # 简写(默认 cuda:0)
x_cpu = x_gpu.cpu() # GPU → CPU

# 直接在 GPU 上创建(省一次搬运)
y = torch.randn(3, 4, device=device)

AI Infra 视角:CPU→GPU 搬运走 PCIe 总线(Gen4 x16 单向 ≈ 32 GB/s),而 GPU 显存带宽远高于此(A100 HBM2e 约 2 TB/s,H100 HBM3 约 3.35 TB/s),两者差距可达 60~100 倍。频繁的 .cpu().cuda() 调用是常见的性能杀手。正确做法:数据搬上 GPU 后尽量留在 GPU 上完成所有计算。

设备间数据传输带宽参考:

路径 典型带宽
GPU 显存(HBM)内部 2-5 TB/s(A100: 2 TB/s, H100: 3.35 TB/s, H200: 4.8 TB/s)
PCIe Gen4 x16 ~32 GB/s
PCIe Gen5 x16 ~64 GB/s
NVLink(GPU 间) 600-1800 GB/s(A100: 600, H100: 900, B200: 1800)

5.2 设备一致性规则

所有参与运算的 Tensor 必须在同一个设备上。 这是 PyTorch 最常见的报错之一:

1
2
3
4
5
6
import torch

a = torch.randn(3, device='cpu')
b = torch.randn(3, device='cuda')

# a + b # ❌ RuntimeError: Expected all tensors to be on the same device

良好的习惯是在模型代码中统一设备管理:

1
2
3
4
5
6
7
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = MyModel().to(device)
for batch in dataloader:
inputs = batch['input'].to(device)
labels = batch['label'].to(device)
output = model(inputs)

5.3 多 GPU 场景下的设备指定

1
2
3
4
5
6
7
8
9
10
11
import torch

# 指定使用第 1 块 GPU
x = torch.randn(3, 4, device='cuda:1')

# 模型放到指定 GPU
model = torch.nn.Linear(100, 10).to('cuda:0')

# 模型和数据必须在同一设备上,否则报错
data = torch.randn(32, 100, device='cuda:0')
output = model(data) # ✅ 同在 cuda:0

5.4 dtype:精度决定显存和速度

不同的数据类型直接影响显存占用和计算吞吐:

dtype 别名 位宽 典型用途
torch.float32 torch.float 32 默认浮点类型,优化器状态
torch.float16 torch.half 16 混合精度训练(需 GradScaler)
torch.bfloat16 16 大模型训练首选(指数范围与 fp32 相同)
torch.float64 torch.double 64 科学计算,深度学习极少用
torch.int64 torch.long 64 默认整型,token ID、索引
torch.int32 torch.int 32 较短的索引
torch.int8 8 量化推理
torch.bool 8 掩码(mask)
torch.float8_e4m3fn 8 FP8 训练/推理(H100+ GPU)

白话理解:fp32 像高清原图,精度好但占空间大;bf16 像智能压缩——体积减半,保留了 fp32 的指数范围(不容易溢出),但有效精度从 24 位降到了 8 位;fp16 有效精度 11 位(比 bf16 高),但指数范围只有 5 位,训练大模型时容易出现梯度下溢,需要配合 GradScaler 使用。

1
2
3
4
5
6
7
8
9
10
11
import torch

# 类型转换
x = torch.randn(1000, 1000) # 默认 fp32
x_bf16 = x.to(torch.bfloat16) # 转 bf16
x_half = x.half() # 转 fp16
x_back = x_bf16.float() # bf16 → fp32

# 显存对比
print(f"fp32: {x.nelement() * x.element_size() / 1024**2:.1f} MB") # 3.8 MB
print(f"bf16: {x_bf16.nelement() * x_bf16.element_size() / 1024**2:.1f} MB") # 1.9 MB

6. 自动微分 autograd

如果说 Tensor 是 PyTorch 的”骨骼”,autograd 就是”神经系统”——它自动帮你算梯度,让训练神经网络从”手动推导每一层的偏导数”变成”调一行 .backward()“。

6.1 计算图:autograd 的核心数据结构

比喻:想象你在烘焙蛋糕。面粉和鸡蛋搅拌成面糊,面糊进烤箱变成蛋糕——如果蛋糕味道不对(loss 太大),你需要反向追溯:是烤箱温度的问题,还是面粉和鸡蛋的比例不对?计算图就是 PyTorch 帮你记下来的”烘焙流程”:它记录每一步操作,这样就能从最终结果(蛋糕/loss)反向推导出每种原料(参数)对结果的影响程度(梯度)。

正式定义:计算图(Computational Graph)是一个有向无环图(DAG)。节点是 Tensor,边是运算操作(如加法、乘法、矩阵乘法)。PyTorch 在前向传播时动态构建这张图,在反向传播时沿着图的反方向计算梯度,然后立即释放整张图。

PyTorch 的特色是动态图(Define-by-Run):图是在每次前向传播时”现场”搭建的,因此天然支持 if/else、for 循环等 Python 控制流——这是它比静态图框架更灵活的根本原因。

6.2 requires_grad 与叶节点

并不是所有 Tensor 都需要计算梯度。requires_grad=True 就像给 Tensor 贴了一张标签:”请记录我参与的所有运算,以便之后算梯度。”

1
2
3
4
5
6
7
8
9
10
11
12
13
14
import torch

# 叶节点:用户直接创建的、需要梯度的 Tensor
w = torch.tensor([2.0, 3.0], requires_grad=True) # 叶节点 ✅
x = torch.tensor([1.0, 4.0]) # 叶节点(但不追踪梯度)

# 非叶节点:由运算产生的 Tensor
y = w * x # y 由 w 和 x 运算得到,是非叶节点
z = y.sum()

print(w.is_leaf) # True ← 用户创建
print(y.is_leaf) # False ← 运算产生
print(y.grad_fn) # <MulBackward0> ← 记录了"我是怎么来的"
print(z.grad_fn) # <SumBackward0>

📌 关键点:只有叶节点.grad 属性才会在 backward() 后被填充。非叶节点的梯度在计算完成后会被释放(除非显式调用 y.retain_grad())。这个设计是出于显存效率考虑——模型可能有数十亿参数,保留所有中间梯度会占用巨大的显存。

上面的例子对应这样一张计算图——前向传播从叶子到输出(从上到下),反向传播从输出到叶子(从下到上),沿着 grad_fn 的链条追溯:


graph TD
W["w (叶子, requires_grad=True)"] --> MUL["× (MulBackward0)"]
X["x (叶子, requires_grad=False)"] --> MUL
MUL --> Y["y (非叶子)"]
Y --> SUM["sum (SumBackward0)"]
SUM --> Z["z (标量)"]

6.3 backward() 与链式法则

backward() 做的事情就是沿着计算图逆向执行链式法则,把梯度从输出一路传回到每个叶节点。

从数学角度看,假设有运算链:

$$
z = f(y), \quad y = g(w)
$$

链式法则告诉我们:

$$
\frac{\partial z}{\partial w} = \frac{\partial z}{\partial y} \cdot \frac{\partial y}{\partial w}
$$

backward() 就是自动完成这个逐层求导的过程。

1
2
3
4
5
6
7
8
import torch

x = torch.tensor([2.0, 3.0], requires_grad=True)
y = x ** 2 # y = [4, 9]
z = y.sum() # z = 13

z.backward() # 反向传播
print(x.grad) # tensor([4., 6.]) ← dz/dx = 2x = [4, 6]

6.4 用数学手动验证 autograd

来做一个稍复杂的验证,确保对 autograd 的理解是准确的:

$$
f(x) = x^3 + 2x^2 - 5x + 1
$$

$$
f’(x) = 3x^2 + 4x - 5
$$

当 $x = 2$ 时,$f’(2) = 3 \times 4 + 4 \times 2 - 5 = 15$。

1
2
3
4
5
6
import torch

x = torch.tensor(2.0, requires_grad=True)
f = x**3 + 2*x**2 - 5*x + 1
f.backward()
print(x.grad) # tensor(15.) ✓ 与手算结果一致

这个例子验证了 autograd 能正确处理多项式的链式求导——无论函数多复杂,backward() 都能自动算出精确的梯度。

几条重要规则

  1. backward() 只能对标量调用。如果 z 不是标量,需要传入一个与 z 同形状的 gradient 参数
  2. 计算图用完即释放。第二次 backward() 会报错——除非创建图时指定 retain_graph=True
  3. 梯度流经 requires_grad=True 的路径。如果某条路径上所有 Tensor 都不需要梯度,这条路径不会被计算
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torch

x = torch.tensor([1.0, 2.0], requires_grad=True)
y = x * 3
z = y.sum()
z.backward()
print(x.grad) # tensor([3., 3.])

# 再调用一次会报错(图已释放)
# z.backward() # RuntimeError: graph already freed

# 如果需要多次反向传播(如某些高阶求导场景)
x = torch.tensor([1.0, 2.0], requires_grad=True)
y = x * 3
z = y.sum()
z.backward(retain_graph=True) # 保留图
z.backward() # 第二次 backward 可以执行(梯度会累加)
print(x.grad) # tensor([6., 6.]) ← 3+3 累加

6.5 梯度累积与清零

白话理解:PyTorch 的梯度就像一个只进不出的存钱罐——每次 backward() 往里塞钱(累加梯度),但它不会自动清空。标准训练循环里,每个 step 开始前必须手动清零,否则上一轮的”存款”会混进来,搞乱优化方向。

这个设计不是 bug,而是有意为之:梯度累积是显存不够时模拟大 batch 的经典技巧——累积 N 个小 batch 的梯度再统一更新,效果等价于一次性跑一个 N 倍大的 batch。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
import torch

x = torch.tensor([1.0], requires_grad=True)

# 第一次 backward
(x * 2).sum().backward()
print(x.grad) # tensor([2.])

# 第二次 backward(梯度被累加)
(x * 3).sum().backward()
print(x.grad) # tensor([5.]) ← 2 + 3

# 手动清零
x.grad.zero_()
(x * 4).sum().backward()
print(x.grad) # tensor([4.]) ← 清零后干净了

实际训练中用 optimizer.zero_grad() 一次性清零所有参数的梯度:

1
2
3
4
5
# 标准训练循环中的梯度处理
optimizer.zero_grad() # ① 清零梯度
loss = criterion(model(data), labels)
loss.backward() # ② 计算梯度(自动累加到 .grad)
optimizer.step() # ③ 用梯度更新参数

梯度累积实战示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torch
import torch.nn as nn

model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = nn.MSELoss()

accumulation_steps = 4 # 累积 4 个 mini-batch 的梯度

optimizer.zero_grad()
for i, (data, target) in enumerate(dataloader):
loss = criterion(model(data), target)
loss = loss / accumulation_steps # 除以累积步数取平均
loss.backward() # 梯度被累加

if (i + 1) % accumulation_steps == 0:
optimizer.step() # 每 4 步才真正更新一次参数
optimizer.zero_grad() # 更新后清零

6.6 torch.no_grad() 与 inference_mode

推理(inference)和评估(eval)时不需要计算梯度,关掉梯度追踪可以省显存、加速度。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
import torch
import torch.nn as nn

model = nn.Linear(100, 10)
x = torch.randn(32, 100)

# 方式 1:torch.no_grad()(最常用)
with torch.no_grad():
output = model(x) # 不构建计算图,不追踪梯度
print(output.requires_grad) # False

# 方式 2:torch.inference_mode()(PyTorch 1.9+ 推荐)
with torch.inference_mode():
output = model(x) # 更激进的优化,比 no_grad 更快

两者的区别:

特性 torch.no_grad() torch.inference_mode()
禁用梯度计算
可以对结果做 in-place 操作
结果可以在外部参与梯度计算
性能 更快

💡 提示:如果你确定推理结果不会再参与任何梯度计算,用 inference_mode() 能获得更好的性能。在训练循环的验证阶段,用 no_grad() 更稳妥。

6.7 常见踩坑指南

踩坑 1:Device Mismatch——设备不一致

1
2
3
4
5
6
7
8
import torch

# ❌ 常见错误:模型在 GPU,数据在 CPU
model = model.cuda()
output = model(input_tensor) # RuntimeError: 如果 input_tensor 在 CPU

# ✅ 正确做法:统一设备
output = model(input_tensor.to(next(model.parameters()).device))

踩坑 2:in-place 操作破坏计算图

1
2
3
4
5
6
7
import torch

x = torch.tensor([1.0, 2.0], requires_grad=True)
y = x * 2

# y += 1 # ❌ 报错!in-place 操作修改了计算图中的节点
y = y + 1 # ✅ 创建新 Tensor,不影响原计算图

⚠️ 注意:PyTorch 中以 _ 结尾的方法(如 add_mul_zero_)都是 in-place 操作。对需要梯度追踪的 Tensor 使用 in-place 操作,可能导致计算图被破坏、梯度计算出错。规则很简单:如果 Tensor 参与了 autograd,就不要做 in-place 操作zero_grad() 除外,它是特殊处理的)。

踩坑 3:忘记清零梯度导致训练不收敛

1
2
3
4
5
6
7
8
9
10
11
12
13
# 错误写法:忘了 zero_grad
for data, target in dataloader:
loss = criterion(model(data), target)
loss.backward()
optimizer.step()
# ← 梯度在累积,每一步的梯度越来越离谱

# 正确写法
for data, target in dataloader:
optimizer.zero_grad() # 每步清零
loss = criterion(model(data), target)
loss.backward()
optimizer.step()

踩坑 4:.item()print(tensor) 导致隐式同步

在 CUDA 的异步执行模型中,Python 端提交 kernel 后不会等它执行完就继续往下跑。但 .item() 需要把 GPU 上的值搬到 CPU,这迫使 CPU 等待 GPU 完成所有排队的 kernel——相当于在 GPU 流水线上人为插了一道”路障”,打断了异步执行的节奏。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
import torch

# 不好:每步都取标量值,触发 CPU-GPU 同步
for batch in dataloader:
loss = train_step(batch)
total_loss += loss.item() # 每步同步!CPU 等 GPU 完成所有排队 kernel

# 好:用 Tensor 累积,最后再取值
running_loss = torch.tensor(0.0, device='cuda')
for i, batch in enumerate(dataloader):
loss = train_step(batch)
running_loss += loss.detach() # Tensor 加法,不触发同步
if (i + 1) % 100 == 0:
print(f"step {i+1}, avg loss: {running_loss.item() / 100:.4f}")
running_loss.zero_()

不过 .item() 也有正面用途——用它提取标量值可以避免 Tensor 累加导致的计算图泄漏:

1
2
3
4
5
# ❌ 不推荐:用 Tensor 做 Python 累加,会保留整个计算图链条,导致显存泄漏
total_loss += loss # loss 是 Tensor,累加会保留计算图!

# ✅ 推荐:用 .item() 提取 Python 标量,或用 .detach()
total_loss += loss.item() # 提取纯 Python float,不保留计算图

踩坑 5:detach() 的正确使用场景

1
2
3
4
5
6
7
8
9
10
11
12
13
14
import torch

x = torch.tensor([1.0], requires_grad=True)
y = x * 2

# detach() 从计算图中"断开",返回一个共享数据但不追踪梯度的新 Tensor
y_detached = y.detach()
print(y_detached.requires_grad) # False

# 典型场景:记录 loss 但不让它参与梯度计算
loss_history = []
for batch in dataloader:
loss = train_step(batch)
loss_history.append(loss.detach().cpu()) # 断开图 + 搬回 CPU

💡 提示detach() 在需要把中间结果”截断”梯度流时很有用,比如实现 stop-gradient 操作。另一个常见场景是将 GPU Tensor 转为 NumPy 数组——必须先 detach 再 cpu 再 numpy,这三步缺一不可:

1
2
# Tensor → NumPy 的标准操作
np_array = some_tensor.detach().cpu().numpy()

7. 实战:从零实现梯度下降

把上面学到的 Tensor 操作和 autograd 串起来,手动实现一个最简单的线性回归训练,不用任何 nn.Moduleoptim

目标:学习函数 $y = 3x + 1$

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch

# 生成训练数据
torch.manual_seed(42)
x_train = torch.linspace(0, 5, 50) # 50 个均匀分布的 x
y_train = 3 * x_train + 1 + torch.randn(50) * 0.5 # 加噪声

# 初始化参数(需要梯度)
w = torch.tensor([0.0], requires_grad=True)
b = torch.tensor([0.0], requires_grad=True)
lr = 0.01

# 训练循环
for epoch in range(200):
# 前向传播
y_pred = w * x_train + b # 广播:(1,) * (50,) → (50,)
loss = ((y_pred - y_train) ** 2).mean() # MSE Loss

# 反向传播
loss.backward() # 计算 dL/dw 和 dL/db

# 参数更新(不能被 autograd 追踪)
with torch.no_grad():
w -= lr * w.grad
b -= lr * b.grad

# 梯度清零
w.grad.zero_()
b.grad.zero_()

if (epoch + 1) % 50 == 0:
print(f"Epoch {epoch+1:3d} | Loss: {loss.item():.4f} | "
f"w={w.item():.3f}, b={b.item():.3f}")

print(f"\n学到的模型: y = {w.item():.2f}x + {b.item():.2f}")
# 输出接近: y = 3.00x + 1.00

这段代码完整展示了 Tensor 的创建、广播运算、autograd 的前向/反向传播、梯度读取与清零,以及 torch.no_grad() 在参数更新中的使用——把本文的核心知识点串成了一条完整的链路。


📝 总结

本文覆盖了 PyTorch 的两大基石:

模块 核心内容 关键概念
Tensor 创建 多种创建方式、属性三件套 shape, dtype, device
索引与切片 基础索引、布尔索引、花式索引 视图 vs 副本
形状变换 view/reshape/permute/squeeze contiguous, stride
广播机制 自动形状扩展规则 从右往左逐维比较
设备与精度 CPU↔GPU 搬运、dtype 选择 fp32/fp16/bf16
autograd 计算图、backward、梯度累积 叶节点、requires_grad、链式法则

这些是 PyTorch 编程的”内功心法”——后续无论是搭建模型(4.2)、做性能调优(4.3),还是分布式训练,都建立在对这些基础概念的透彻理解之上。


🎯 自我检验清单

完成本文学习后,你应该能够:

  • 能用至少 3 种方式创建指定形状和数据类型的 Tensor
  • 能解释 viewreshape 的区别,以及 contiguous() 何时需要调用
  • 能说明 permute 返回的 Tensor 为什么通常不是 contiguous 的(提示:stride 变了但数据没动)
  • 能手写多头注意力中 reshape → permute 的维度变换过程
  • 能手写广播规则的三条判断逻辑,并判断两个给定形状能否广播
  • 能解释 from_numpy()tensor() 在内存共享上的区别
  • 能画出一段简单 PyTorch 代码的计算图,标出叶节点和 grad_fn
  • 能手动验证 autograd 对多项式函数求导的正确性
  • 能解释为什么训练循环中必须调用 optimizer.zero_grad(),以及梯度累积的工作原理
  • 能写出不使用任何 nn.Module 的手动梯度下降训练代码
  • 能说明 torch.no_grad()torch.inference_mode() 的区别及适用场景
  • 能用 detach().cpu().numpy() 将 GPU Tensor 转为 NumPy 数组
  • 能识别并修复 Device Mismatch、in-place 操作破坏计算图、.item() 导致隐式同步等常见问题

📚 参考资料