PyTorch layers come with default initialization functions, but what should you do if you want to set custom initial values for the parameters yourself?

1. The default initialization functions

In PyCharm you can open the definition of nn.Conv2d and search for 'reset_parameters' to see its initialization function. The same approach works for nn.BatchNorm2d and nn.Linear.
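As a quick sketch (assuming a recent PyTorch version), you can also verify the default behavior in code. For nn.Conv2d, reset_parameters() calls kaiming_uniform_ with a=sqrt(5), which draws the weights from U(-b, b) where b = 1/sqrt(fan_in):

```python
import math
import torch
import torch.nn as nn

conv = nn.Conv2d(3, 32, 3)  # fan_in = 3 * 3 * 3 = 27

# Default init: kaiming_uniform_(a=sqrt(5)), i.e. uniform in (-b, b), b = 1/sqrt(fan_in)
bound = 1 / math.sqrt(27)
print(conv.weight.abs().max().item() <= bound)  # True

# Calling reset_parameters() again re-draws the weights from the same distribution
old = conv.weight.clone()
conv.reset_parameters()
print(torch.equal(old, conv.weight))  # almost surely False
```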

2. How to add a custom initialization function?

Below is a simple VGG-style network example:

import torch
import torch.nn as nn
import torch.nn.functional as F

class VGG(nn.Module):
    def __init__(self):
        super(VGG, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, 1, 1, bias=False)
        self.conv2_1 = nn.Conv2d(32, 64, 3, 1, 1, bias=False)
        self.conv2_2 = nn.Conv2d(64, 64, 3, 1, 1, bias=False)
        self.conv3_1 = nn.Conv2d(64, 128, 3, 1, 1, bias=False)
        self.conv3_2 = nn.Conv2d(128, 128, 3, 1, 1, bias=False)
        self.conv3_3 = nn.Conv2d(128, 128, 3, 1, 1, bias=False)
        self.fc1 = nn.Linear(2048, 128, bias=False)
        self.fc2 = nn.Linear(128, 10, bias=False)
 
    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = F.max_pool2d(out, 2)
        out = F.relu(self.conv2_1(out))
        out = F.relu(self.conv2_2(out))
        out = F.max_pool2d(out, 2)
        out = F.relu(self.conv3_1(out))
        out = F.relu(self.conv3_2(out))
        out = F.relu(self.conv3_3(out))
        out = F.max_pool2d(out, 2)
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out))
        out = self.fc2(out)
        return out

# Define the initialization function
def weights_init(m):
    if isinstance(m, nn.Conv2d):
        nn.init.constant_(m.weight, 0.5)
        # nn.init.constant_(m.bias, 0)  # only if the layer has a bias
        # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        # nn.init.xavier_normal_(m.weight)
    elif isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        # nn.init.constant_(m.bias, 0)  # only if the layer has a bias

3. How to apply the initialization?

net = VGG()
# apply() recursively visits every submodule of the network
# and calls the given function on each one.
net.apply(weights_init)

4. How to inspect the initialized weights?

# Show the keys of the network's parameter dict
net.state_dict().keys()
# Print the weights of the conv1 layer
net.state_dict()['conv1.weight']
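Putting it together, here is a minimal self-contained check (using a tiny Sequential stand-in instead of the full VGG above, so the snippet runs on its own) confirming that apply() really did overwrite every conv weight with the constant 0.5:

```python
import torch
import torch.nn as nn

def weights_init(m):
    # same rules as the weights_init defined above
    if isinstance(m, nn.Conv2d):
        nn.init.constant_(m.weight, 0.5)
    elif isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)

# a tiny stand-in network, just to demonstrate the check
net = nn.Sequential(
    nn.Conv2d(3, 8, 3, bias=False),
    nn.Flatten(),
    nn.Linear(8, 2, bias=False),
)
net.apply(weights_init)

print(list(net.state_dict().keys()))            # ['0.weight', '2.weight']
print(torch.all(net[0].weight == 0.5).item())   # True
```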
