Why weight initialization is needed

In deep learning, good weight initialization speeds up model convergence. When we define a layer such as nn.Conv2d, its weights are initialized automatically. In some cases, however, we need a custom initialization. For example, DCGAN (a generative adversarial network) specifies that all model weights should be randomly initialized from a normal distribution with mean 0 and standard deviation 0.02.
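
As a quick check that layers really do come pre-initialized (a minimal sketch; the layer sizes here are arbitrary and not from any particular model), a freshly constructed nn.Conv2d already carries PyTorch's default initialization, and nn.init.normal_ is what we use to override it with the DCGAN recipe:

import torch
import torch.nn as nn

conv = nn.Conv2d(3, 16, kernel_size=3)            # constructing the layer already initializes its weights
print(conv.weight.mean().item(), conv.weight.std().item())  # some nonzero spread from the default scheme

nn.init.normal_(conv.weight, mean=0.0, std=0.02)  # DCGAN-style re-initialization: N(0, 0.02**2)
print(conv.weight.std().item())                   # now close to 0.02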

The approach to weight initialization

  1. First, build your network class, e.g. the Generator class below.
  2. Define the weight-initialization function weights_init(m). This function is generic and can also be reused to customize the initialization of other networks.
  3. Instantiate the Generator class and call its apply method (a compact end-to-end sketch follows this list).
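
In outline (a minimal sketch with a placeholder one-layer network; the full Generator class and weights_init definitions follow later in this post), the three steps fit together like this:

import torch.nn as nn

# Step 1: the network class (a trivial stand-in here; the real Generator is defined below)
class Generator(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 3, 1, 1)

    def forward(self, x):
        return self.conv1(x)

# Step 2: the reusable initialization function (full version below)
def weights_init(m):
    if m.__class__.__name__.find("Conv") != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)

# Step 3: instantiate the network and call apply()
netG = Generator()
netG.apply(weights_init)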

A detailed look at the custom weight-initialization function

Define the custom model-initialization function.

import torch.nn as nn

# Initialize model weights
def weights_init(m):
    """
    Custom weight initialization applied to netG and netD.
    :param m: e.g. when called via netG.apply, m is each module (layer) inside netG
    :return:
    """
    classname = m.__class__.__name__  # classname is a string holding the class name of the instance m
    if classname.find("Conv") != -1:
        # .find is a string method: it returns the index of the first occurrence of "Conv",
        # or -1 if the substring is not found
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm") != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)  # zero the bias

nn.init provides many methods for initializing weights.
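
Besides the two used above, a few other commonly used nn.init functions (each works in place on a tensor; the sizes below are arbitrary):

import torch
import torch.nn as nn

w = torch.empty(64, 32)
nn.init.uniform_(w, a=-0.1, b=0.1)                # uniform on [a, b]
nn.init.xavier_uniform_(w)                        # Glorot / Xavier uniform
nn.init.kaiming_normal_(w, nonlinearity="relu")   # He / Kaiming normal
nn.init.zeros_(w)                                 # all zeros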

1. nn.init.constant_

torch.nn.init.constant_(tensor, val)

Fills the tensor with the value val.

Example:

The following example fills a plain tensor; exactly the same call can be used on a layer's parameter tensors.

import torch

m = torch.empty(3, 2)            # create an uninitialized tensor of shape (3, 2)
print("m: ", m)
torch.nn.init.constant_(m, 0.5)  # fill every element of m with 0.5
print("m: ", m)

Output

m:  tensor([[5.7241e-06, 1.2738e-42],
[0.0000e+00, 0.0000e+00],
[0.0000e+00, 0.0000e+00]])

m: tensor([[0.5000, 0.5000],
[0.5000, 0.5000],
[0.5000, 0.5000]])
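
The same call works directly on a layer's parameter tensors, for example zeroing a convolution's bias just like weights_init does for BatchNorm layers (a small sketch; the layer sizes are arbitrary):

import torch.nn as nn

conv = nn.Conv2d(3, 8, kernel_size=3)
nn.init.constant_(conv.bias.data, 0.0)   # zero the bias in place
print(conv.bias.data)                    # a tensor of eight zeros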

2. nn.init.normal_

torch.nn.init.normal_(tensor, mean=0.0, std=1.0, generator=None)

Fills the tensor with values drawn from a normal distribution.

Note: std is the standard deviation, not the variance!

Example:

import torch

m = torch.empty(3, 2)              # uninitialized tensor of shape (3, 2)
print("m: ", m)
torch.nn.init.normal_(m, 0, 0.02)  # fill with samples from N(0, 0.02**2)
print("m: ", m)

Output

m:  tensor([[0., 0.],
[0., 0.],
[0., 0.]])

m: tensor([[-0.0108, -0.0127],
[ 0.0045, 0.0057],
[-0.0214, 0.0267]])
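
Since std is the standard deviation, filling a tensor with std=0.02 gives a variance of roughly 0.02 ** 2 = 0.0004, which is easy to verify on a large sample (sketch):

import torch

t = torch.empty(1_000_000)
torch.nn.init.normal_(t, mean=0.0, std=0.02)
print(t.std().item())   # roughly 0.02
print(t.var().item())   # roughly 0.0004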

3. m.weight.data, m.bias.data

m.weight.data and m.bias.data expose a module's weight and bias values as plain tensors (type torch.Tensor), whereas m.weight and m.bias themselves are of type torch.nn.parameter.Parameter.

import torch

m = torch.nn.Conv2d(3, 64, 3, 2, 1)  # in_channels=3, out_channels=64, kernel_size=3, stride=2, padding=1
print("m.weight ->", m.weight)
print("type m.weight ->", type(m.weight))
print("m.weight.data -> ", m.weight.data)
print("type m.weight.data -> ", type(m.weight.data))

print("m.bias -> ", m.bias)
print("type m.bias -> ", type(m.bias))
print("m.bias.data -> ", m.bias.data)
print("type m.bias.data -> ", type(m.bias.data))

Output

m.weight -> Parameter containing:
tensor([[[[-0.1054, 0.0425, -0.1113],
[-0.0216, 0.0641, -0.0979],
[-0.0539, -0.1248, -0.0053]],
...
[[-0.0481, -0.0629, 0.0049],
[ 0.1842, -0.0309, -0.1135],
[ 0.0841, 0.1651, 0.0954]]]], requires_grad=True)
type m.weight -> <class 'torch.nn.parameter.Parameter'>
m.weight.data ->  tensor([[[[-0.1054,  0.0425, -0.1113],
[-0.0216, 0.0641, -0.0979],
[-0.0539, -0.1248, -0.0053]],
...
[[-0.0481, -0.0629, 0.0049],
[ 0.1842, -0.0309, -0.1135],
[ 0.0841, 0.1651, 0.0954]]]])
type m.weight.data ->  <class 'torch.Tensor'>
m.bias -> Parameter containing:
tensor([-0.0249, -0.0707, 0.0400, -0.1284, -0.1467, -0.0915, 0.0905, -0.0056,
0.0740, -0.0885, 0.0225, -0.0648, 0.1001, 0.1633, 0.1573, -0.0152,
0.0349, -0.1705, 0.0416, -0.0819, -0.1141, 0.1561, -0.1807, -0.0014,
-0.1340, 0.1697, 0.0537, -0.1431, -0.0957, 0.1814, -0.0127, 0.0360,
-0.1791, -0.1035, 0.1016, 0.0863, 0.1119, -0.0765, 0.0651, 0.1523,
0.0911, 0.1534, 0.0462, -0.1833, -0.0205, -0.0581, -0.1704, 0.1281,
-0.1156, -0.0847, 0.0631, -0.0384, -0.0543, 0.0445, -0.1642, -0.0424,
0.1099, 0.0096, -0.1660, 0.0808, 0.1525, -0.0996, -0.0960, 0.0144],
requires_grad=True)
type m.bias ->  <class 'torch.nn.parameter.Parameter'>
m.bias.data ->  tensor([-0.0249, -0.0707,  0.0400, -0.1284, -0.1467, -0.0915,  0.0905, -0.0056,
0.0740, -0.0885, 0.0225, -0.0648, 0.1001, 0.1633, 0.1573, -0.0152,
0.0349, -0.1705, 0.0416, -0.0819, -0.1141, 0.1561, -0.1807, -0.0014,
-0.1340, 0.1697, 0.0537, -0.1431, -0.0957, 0.1814, -0.0127, 0.0360,
-0.1791, -0.1035, 0.1016, 0.0863, 0.1119, -0.0765, 0.0651, 0.1523,
0.0911, 0.1534, 0.0462, -0.1833, -0.0205, -0.0581, -0.1704, 0.1281,
-0.1156, -0.0847, 0.0631, -0.0384, -0.0543, 0.0445, -0.1642, -0.0424,
0.1099, 0.0096, -0.1660, 0.0808, 0.1525, -0.0996, -0.0960, 0.0144])
type m.bias.data ->  <class 'torch.Tensor'>
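
Note that the nn.init functions also accept the Parameter itself (they modify the tensor under torch.no_grad() internally), so writing through m.weight or through m.weight.data touches the same storage; a quick check (sketch):

import torch
import torch.nn as nn

m = nn.Conv2d(3, 64, 3, 2, 1)
nn.init.constant_(m.weight, 0.5)   # initialize through the Parameter ...
print(m.weight.data.unique())      # ... and the change is visible via .data: tensor([0.5000])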

Generator class code

import torch.nn as nn


class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()  # initialize the parent class
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 64, 1, 1),
            nn.ReLU(),
            nn.BatchNorm2d(64),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 128, 3, 2, 1),
            nn.ReLU(),
            nn.BatchNorm2d(128),
        )
        self.conv3 = nn.Sequential(
            nn.Conv2d(128, 258, 3, 2, 1),
            nn.ReLU(),
            nn.BatchNorm2d(258),
        )
        self.conv4 = nn.Sequential(
            nn.Conv2d(258, 512, 3, 2, 1),
            nn.ReLU(),
            nn.BatchNorm2d(512),
        )
        self.conv5 = nn.Sequential(
            nn.Conv2d(512, 1024, 3, 2, 1),
            nn.ReLU(),
            nn.BatchNorm2d(1024),
        )

        self.dconv1 = nn.Sequential(
            nn.ConvTranspose2d(1024, 512, 2, 2),
            nn.ReLU(),
            nn.BatchNorm2d(512),
        )
        self.dconv2 = nn.Sequential(
            nn.ConvTranspose2d(512, 258, 2, 2),
            nn.ReLU(),
            nn.BatchNorm2d(258),
        )
        self.dconv3 = nn.Sequential(
            nn.ConvTranspose2d(258, 128, 2, 2),
            nn.ReLU(),
            nn.BatchNorm2d(128),
        )
        self.dconv4 = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 2, 2),
            nn.ReLU(),
            nn.BatchNorm2d(64),
        )

        self.conv6 = nn.Sequential(
            nn.Conv2d(64, 3, 1, 1),
            nn.ReLU(),
            nn.BatchNorm2d(3),
        )

    def forward(self, x):
        y1 = self.conv1(x)
        y2 = self.conv2(y1)  # 1/2 spatial size
        y3 = self.conv3(y2)  # 1/4
        y4 = self.conv4(y3)  # 1/8
        y5 = self.conv5(y4)  # 1/16

        out1 = self.dconv1(y5) + y4  # upsample and add the skip connection
        out2 = self.dconv2(out1) + y3
        out3 = self.dconv3(out2) + y2
        out4 = self.dconv4(out3) + y1
        output = self.conv6(out4)

        return output
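
As a quick sanity check (a sketch; the 64x64 input is an arbitrary choice, any spatial size divisible by 16 works with these strides), the skip connections line up and the output keeps the input's shape:

import torch

netG = Generator()
x = torch.randn(1, 3, 64, 64)
print(netG(x).shape)   # torch.Size([1, 3, 64, 64])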

Calling apply on the instantiated object

This customizes the initialization of the entire network's parameters.

The apply function recursively walks through every module in the network and applies the function passed as its argument to each of them.

The .apply(fn) method of torch.nn.Module:

apply(fn)

Recursively applies fn to every submodule (as returned by .children()) as well as to the module itself.
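
A tiny demonstration of the recursion (sketch): apply visits the children first and then the module itself, so fn also reaches layers nested inside Sequential blocks:

import torch.nn as nn

net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
net.apply(lambda m: print(m.__class__.__name__))
# Conv2d
# BatchNorm2d
# Sequential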

Usage:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

netG = Generator().to(device)

netG.apply(weights_init)  # recursively applies weights_init to every submodule of netG
print(netG)
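
To confirm that the initialization took effect, one can inspect the statistics of a convolution inside the network afterwards (sketch; conv1[0] indexes the Conv2d inside the first Sequential block):

w = netG.conv1[0].weight.data
print(w.mean().item(), w.std().item())   # mean close to 0.0 and std close to 0.02 after weights_init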