Custom method practice in PyTorch
References
[1] https://pytorch.org/tutorials/intermediate/custom_function_conv_bn_tutorial.html
- custom method practice
- fusing convolution and batch norm
- convolution backward function
- batch norm backward function
- Fusing convolution and batch norm
- modularize
custom method practice
- As one example of a custom method, we practice fusing a convolution layer and a batch norm layer into a single layer.
- During the forward pass, convolution and batch norm save tensors such as the input and the weight so that gradients can be computed in the backward pass. The deeper and wider the layers, the more memory this consumes.
- If we assume the two layers always come back to back, the input tensor only needs to be saved once, which saves a significant amount of memory; the sketch below makes this concrete.
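- To visualize the memory point, here is a small illustrative sketch (my own, not from the tutorial) that uses torch.autograd.graph.saved_tensors_hooks (available in PyTorch 1.10+) to record what an unfused Conv2d + BatchNorm2d pair saves for backward:

import torch
import torch.nn as nn
from torch.autograd.graph import saved_tensors_hooks

saved_shapes = []
def pack(t):
    # Record the shape of every tensor the graph saves for backward
    saved_shapes.append(tuple(t.shape))
    return t

conv = nn.Conv2d(8, 8, 3, bias=False)
bn = nn.BatchNorm2d(8, affine=False, track_running_stats=False)
x = torch.randn(2, 8, 32, 32, requires_grad=True)
with saved_tensors_hooks(pack, lambda t: t):
    y = bn(conv(x))
# Both the conv input and the conv output (the bn input) are kept alive
print(saved_shapes)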
fusing convolution and batch norm
- The convolution is computed first, followed by the batch norm.
- For simplicity, the convolution is restricted to bias=False, stride=1, padding=0, dilation=1, and groups=1.
- Likewise, batch norm is restricted to eps=1e-3, momentum=0.1, affine=False, and track_running_stats=False; a sketch of this unfused baseline follows.
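- Under these restrictions the unfused computation is just F.conv2d followed by a per-channel normalization. The following is a minimal sketch of that baseline (reference_conv_bn is a hypothetical name; note the tutorial's simplified formula adds eps outside the square root, unlike nn.BatchNorm2d, which adds it to the variance inside):

import torch
import torch.nn.functional as F

def reference_conv_bn(X, conv_weight, eps=1e-3):
    # Conv2d with bias=False, stride=1, padding=0, dilation=1, groups=1
    Y = F.conv2d(X, conv_weight)
    # Batch norm with affine=False and no running stats reduces to
    # out = (Y - mean(Y)) / (sqrt(var(Y)) + eps), per channel
    mean = Y.mean(dim=(0, 2, 3), keepdim=True)
    var = Y.var(unbiased=True, dim=(0, 2, 3), keepdim=True)
    return (Y - mean) / (torch.sqrt(var) + eps)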
convolution backward function
- Define the backward pass of the convolution.
- The convolution backward function is shown below; a quick sanity check against autograd follows the code.
import torch
import torch.nn.functional as F

def convolution_backward(grad_out, X, weight):
    # Gradient w.r.t. the weight: convolve the input with the output gradient
    grad_w = F.conv2d(X.transpose(0, 1), grad_out.transpose(0, 1)).transpose(0, 1)
    # Gradient w.r.t. the input: transposed convolution of the output gradient
    # with the weight
    grad_X = F.conv_transpose2d(grad_out, weight)
    return grad_X, grad_w
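- As a quick sanity check (my own, not part of the tutorial), the hand-written gradients can be compared against the ones autograd computes for F.conv2d; double precision keeps the comparison tight:

X = torch.randn(2, 3, 8, 8, dtype=torch.double, requires_grad=True)
weight = torch.randn(4, 3, 3, 3, dtype=torch.double, requires_grad=True)
out = F.conv2d(X, weight)
grad_out = torch.randn_like(out)
out.backward(grad_out)
grad_X, grad_w = convolution_backward(grad_out, X.detach(), weight.detach())
assert torch.allclose(grad_X, X.grad)
assert torch.allclose(grad_w, weight.grad)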
batch norm backward function
- Define the backward pass of batch norm.
- The batch norm backward function is shown below (together with the unsqueeze_all helper from [1]); a sanity check against autograd follows the code.
def unsqueeze_all(t):
    # Helper from [1]: broadcast a C-sized tensor over the N, H, W dimensions
    return t[None, :, None, None]

def batch_norm_backward(grad_out, X, sum, sqrt_var, N, eps):
    # We use the formula: out = (X - mean(X)) / (sqrt(var(X)) + eps)
    tmp = ((X - unsqueeze_all(sum) / N) * grad_out).sum(dim=(0, 2, 3))
    tmp *= -1
    d_denom = tmp / (sqrt_var + eps)**2  # d_denom = -num / denom**2
    d_var = d_denom / (2 * sqrt_var)  # denom = torch.sqrt(var) + eps
    # Compute d_mean_dx before allocating the final NCHW-sized grad_input buffer
    d_mean_dx = grad_out / unsqueeze_all(sqrt_var + eps)
    d_mean_dx = unsqueeze_all(-d_mean_dx.sum(dim=(0, 2, 3)) / N)
    # d_mean_dx has already been reassigned to a C-sized buffer so no need to worry

    # (1) unbiased_var(x) = ((X - unsqueeze_all(mean))**2).sum(dim=(0, 2, 3)) / (N - 1)
    grad_input = X * unsqueeze_all(d_var * N)
    grad_input += unsqueeze_all(-d_var * sum)
    grad_input *= 2 / ((N - 1) * N)
    # (2) mean (see above)
    grad_input += d_mean_dx
    # (3) Add 'grad_out / <factor>' without allocating an extra buffer
    grad_input *= unsqueeze_all(sqrt_var + eps)
    grad_input += grad_out
    grad_input /= unsqueeze_all(sqrt_var + eps)  # sqrt_var + eps > 0!
    return grad_input
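- As with the convolution, the hand-derived gradient can be checked against autograd on the same formula (an illustrative test of my own, not from the tutorial):

import torch

X = torch.randn(2, 3, 4, 4, dtype=torch.double, requires_grad=True)
eps = 1e-3
N = X.numel() / X.size(1)
sqrt_var = torch.sqrt(X.var(unbiased=True, dim=(0, 2, 3)))
out = (X - unsqueeze_all(X.mean(dim=(0, 2, 3)))) / unsqueeze_all(sqrt_var + eps)
grad_out = torch.randn_like(out)
out.backward(grad_out)
grad_in = batch_norm_backward(grad_out, X.detach(), X.detach().sum(dim=(0, 2, 3)),
                              sqrt_var.detach(), N, eps)
assert torch.allclose(grad_in, X.grad)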
Fusing convolution and batch norm
- Define a new Function that runs convolution followed by batch norm.
- Only a single input buffer needs to be saved for the backward pass; a gradcheck example follows the code.
class FusedConvBN2DFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, X, conv_weight, eps=1e-3):
        assert X.ndim == 4  # N, C, H, W
        # Only need to save this single buffer for backward!
        ctx.save_for_backward(X, conv_weight)
        # Exact same Conv2D forward from example above
        X = F.conv2d(X, conv_weight)
        # Exact same BatchNorm2D forward from example above
        sum = X.sum(dim=(0, 2, 3))
        var = X.var(unbiased=True, dim=(0, 2, 3))
        N = X.numel() / X.size(1)
        sqrt_var = torch.sqrt(var)
        ctx.eps = eps
        ctx.sum = sum
        ctx.N = N
        ctx.sqrt_var = sqrt_var
        mean = sum / N
        denom = sqrt_var + eps
        # Try to do as many things in-place as possible
        # Instead of `out = (X - a) / b`, doing `out = X - a; out /= b`
        # avoids allocating one extra NCHW-sized buffer here
        out = X - unsqueeze_all(mean)
        out /= unsqueeze_all(denom)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        X, conv_weight = ctx.saved_tensors
        # (4) Batch norm backward
        # (5) We need to recompute conv
        X_conv_out = F.conv2d(X, conv_weight)
        grad_out = batch_norm_backward(grad_out, X_conv_out, ctx.sum,
                                       ctx.sqrt_var, ctx.N, ctx.eps)
        # (6) Conv2d backward
        grad_X, grad_input = convolution_backward(grad_out, X, conv_weight)
        # Returning more gradients than forward inputs is allowed as long as
        # the extras are all None
        return grad_X, grad_input, None, None, None, None, None
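- The tutorial verifies the fused Function with torch.autograd.gradcheck; a minimal version of that check (double precision and small shapes, as gradcheck recommends) looks like:

import torch
from torch.autograd import gradcheck

weight = torch.rand(5, 3, 3, 3, dtype=torch.double, requires_grad=True)
X = torch.rand(2, 3, 4, 4, dtype=torch.double, requires_grad=True)
# gradcheck compares the analytical backward above against numerical gradients
assert gradcheck(FusedConvBN2DFunction.apply, (X, weight))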
modularize
- Wrap FusedConvBN2DFunction in an nn.Module so it can be used like an ordinary layer; a usage sketch follows the code.
import torch.nn as nn
import math

class FusedConvBN(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size,
                 exp_avg_factor=0.1, eps=1e-3, device=None, dtype=None):
        super(FusedConvBN, self).__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        # Conv parameters
        weight_shape = (out_channels, in_channels, kernel_size, kernel_size)
        self.conv_weight = nn.Parameter(torch.empty(*weight_shape, **factory_kwargs))
        # Batch norm parameters
        num_features = out_channels
        self.num_features = num_features
        self.eps = eps
        # Initialize
        self.reset_parameters()

    def forward(self, X):
        return FusedConvBN2DFunction.apply(X, self.conv_weight, self.eps)

    def reset_parameters(self) -> None:
        nn.init.kaiming_uniform_(self.conv_weight, a=math.sqrt(5))
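- A short usage sketch (my own): the module behaves like a Conv2d + BatchNorm2d pair under the restrictions above, and gradients flow through the custom Function:

import torch

fused = FusedConvBN(3, 16, kernel_size=3)
X = torch.randn(8, 3, 32, 32)
out = fused(X)            # shape: (8, 16, 30, 30)
out.sum().backward()      # backward runs FusedConvBN2DFunction.backward
print(out.shape, fused.conv_weight.grad.shape)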