병목 블록은 ResNet50, ResNet101, ResNet152에서 사용되며 1×1 합성곱층, 3×3 합성곱층, 1×1 합성곱층으로 구성됩니다.
코드 6-74 병목 블록 정의
class Bottleneck(nn.Module):
expansion = 4 ------ ResNet에서 병목 블록을 정의하기 위한 하이퍼파라미터입니다.
def __init__(self, in_channels, out_channels, stride=1, downsample=False):
super().__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, bias=False) ------ 1×1 합성곱층
self.bn1 = nn.BatchNorm2d(out_channels)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False) ------ 3×3 합성곱층
self.bn2 = nn.BatchNorm2d(out_channels)
self.conv3 = nn.Conv2d(out_channels, self.expansion*out_channels, kernel_size=1, stride=1, bias=False) ------ 1×1 합성곱층, 또한 다음 계층의 입력 채널 수와 일치하도록 self.expansion*out_channels를 합니다.
self.bn3 = nn.BatchNorm2d(self.expansion*out_channels)
self.relu = nn.ReLU(inplace=True)
if downsample:
conv = nn.Conv2d(in_channels, self.expansion*out_channels, kernel_size=1, stride=stride, bias=False)
bn = nn.BatchNorm2d(self.expansion*out_channels)
downsample = nn.Sequential(conv, bn)
else:
downsample = None
self.downsample = downsample
def forward(self, x):
i = x
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu(x)
x = self.conv3(x)
x = self.bn3(x)
if self.downsample is not None:
i = self.downsample(i)
x += i
x = self.relu(x)
return x