目录

融合卷积和批量标准化的原理

PyTorch 的实现

def fuse_conv_bn_eval(conv, bn, transpose=False):
    """Return a copy of ``conv`` with the batch-norm ``bn`` folded into it.

    Both modules must be in eval mode; otherwise an ``AssertionError`` is
    raised. ``conv`` itself is left untouched: the fused weight and bias
    are written to a deep copy, which is returned.

    Args:
        conv: Convolution module providing ``weight`` and ``bias``.
        bn: Batch-norm module providing ``running_mean``, ``running_var``,
            ``eps``, ``weight`` and ``bias``.
        transpose: Forwarded unchanged to ``fuse_conv_bn_weights``.

    Returns:
        The deep-copied convolution carrying the fused parameters.
    """
    assert not conv.training and not bn.training, "Fusion only for eval!"

    # Work on a copy so the caller's convolution keeps its original parameters.
    result = copy.deepcopy(conv)

    new_weight, new_bias = fuse_conv_bn_weights(
        result.weight,
        result.bias,
        bn.running_mean,
        bn.running_var,
        bn.eps,
        bn.weight,
        bn.bias,
        transpose,
    )
    result.weight = new_weight
    result.bias = new_bias

    return result

def fuse_conv_bn_weights(conv_w, conv_b, bn_rm, bn_rv, bn_eps, bn_w, bn_b, transpose=False):
    """Fold batch-norm statistics and affine terms into convolution parameters.

    With ``scale = bn_w / sqrt(bn_rv + bn_eps)`` the fused parameters are::

        w' = conv_w * scale          (scale broadcast over one channel dim)
        b' = (conv_b - bn_rm) * scale + bn_b

    Args:
        conv_w: Convolution weight tensor.
        conv_b: Convolution bias tensor, or ``None`` (treated as zeros).
        bn_rm: Batch-norm running mean.
        bn_rv: Batch-norm running variance.
        bn_eps: Batch-norm epsilon added to the variance.
        bn_w: Batch-norm scale, or ``None`` (treated as ones).
        bn_b: Batch-norm shift, or ``None`` (treated as zeros).
        transpose: If ``True`` the per-channel scale is broadcast along
            dim 1 of ``conv_w`` instead of dim 0 (presumably for transposed
            convolutions, whose weights keep output channels on dim 1 --
            confirm against the caller).

    Returns:
        A ``(weight, bias)`` tuple of ``torch.nn.Parameter``. The weight keeps
        the dtype of ``conv_w``; the bias keeps the dtype of ``conv_b`` (or of
        ``conv_w`` when ``conv_b`` is ``None``).
    """
    # Remember the convolution's own dtypes: the arithmetic below may be
    # promoted to the batch-norm dtype (e.g. float32 weights with float64
    # statistics), and the fused parameters must still match the conv module.
    conv_weight_dtype = conv_w.dtype
    conv_bias_dtype = conv_b.dtype if conv_b is not None else conv_weight_dtype

    if conv_b is None:
        conv_b = torch.zeros_like(bn_rm)
    if bn_w is None:
        bn_w = torch.ones_like(bn_rm)
    if bn_b is None:
        bn_b = torch.zeros_like(bn_rm)
    bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps)

    # Shape that broadcasts the per-channel scale over every other weight dim.
    trailing = [1] * (len(conv_w.shape) - 2)
    if transpose:
        shape = [1, -1] + trailing
    else:
        shape = [-1, 1] + trailing

    fused_w = (conv_w * (bn_w * bn_var_rsqrt).reshape(shape)).to(dtype=conv_weight_dtype)
    fused_b = ((conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b).to(dtype=conv_bias_dtype)

    return torch.nn.Parameter(fused_w), torch.nn.Parameter(fused_b)

这里使用预训练模型 ResNet18 的前两层(conv1 和 bn1)进行测试,比较融合前后输出的相对误差:

import torch
import torchvision

# Inference-only demo: disable autograd globally so no graph is recorded.
torch.set_grad_enabled(False)
x = torch.randn(16, 3, 256, 256)

# Take the first convolution and batch-norm of a pretrained ResNet18 and
# switch the model to eval mode before fusing.
rn18 = torchvision.models.resnet18(pretrained=True)
rn18.eval()
net = torch.nn.Sequential(
    rn18.conv1,
    rn18.bn1,
)

# Reference output: conv followed by batch-norm. Modules are invoked by
# calling them rather than via .forward(), so registered hooks still run.
y1 = net(x)

# Output of the single fused convolution.
fused_conv = torch.nn.utils.fusion.fuse_conv_bn_eval(net[0], net[1])
y2 = fused_conv(x)

# Relative error between the two outputs: ||y1 - y2|| / ||y1||.
d = (y1 - y2).norm().div(y1.norm()).item()
print(f"error: {d:.8f}")
error: 0.00000022

性能测量(🚀推理速度提升约 25%)

import timeit

# Time 100 forward passes of the unfused conv + batch-norm pair.
# A plain loop drops each output right away; collecting the results in a
# list would keep all 100 output tensors alive for the whole measurement.
starttime = timeit.default_timer()
for _ in range(100):
    net(x)
print("融合前推理100次的时间 :", timeit.default_timer() - starttime)


# Time 100 forward passes of the fused convolution on the same input.
starttime = timeit.default_timer()
for _ in range(100):
    fused_conv(x)
print("融合后推理100次的时间 :", timeit.default_timer() - starttime)
融合前推理100次的时间 : 2.601980792125687
融合后推理100次的时间 : 2.0703182069119066

参考资料