# 코드 6-75 ResNet 모델 네트워크 (Code 6-75: ResNet model network)
class ResNet(nn.Module):
    """ResNet network: stem, four residual stages, global pooling, linear head.

    The stem is a 7x7/stride-2 conv -> BatchNorm -> ReLU -> 3x3/stride-2
    max-pool. It is followed by four stages built with ``get_resnet_layer``
    (stages 2-4 use stride 2), adaptive average pooling to 1x1 and a
    fully connected layer.

    Args:
        config: 3-item sequence ``(block, n_blocks, channels)``.
            ``block`` is the residual block class; it must expose an
            ``expansion`` attribute and be constructible as
            ``block(in_channels, channels, stride, downsample)`` and
            ``block(in_channels, channels)``.
            ``n_blocks`` / ``channels`` give, for each of the 4 stages, the
            number of blocks and the block's base channel width.
        output_dim: ``out_features`` of the final ``nn.Linear``.
        zero_init_residual: if True, set to 0 the weight of ``bn3`` in every
            ``Bottleneck`` and of ``bn2`` in every ``BasicBlock`` submodule.
    """

    def __init__(self, config, output_dim, zero_init_residual=False):
        super().__init__()
        # Unpack the config handed to ResNet into block type, blocks per stage
        # and channel width per stage.
        block, n_blocks, channels = config
        # Channel count entering the next stage; get_resnet_layer advances it.
        self.in_channels = channels[0]
        # One block count and one width per stage: exactly 4 stages.
        assert len(n_blocks) == len(channels) == 4

        # Stem (expects 3 input channels).
        self.conv1 = nn.Conv2d(3, self.in_channels, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channels)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual stages; stages 2-4 halve the spatial resolution.
        self.layer1 = self.get_resnet_layer(block, n_blocks[0], channels[0])
        self.layer2 = self.get_resnet_layer(block, n_blocks[1], channels[1], stride=2)
        self.layer3 = self.get_resnet_layer(block, n_blocks[2], channels[2], stride=2)
        self.layer4 = self.get_resnet_layer(block, n_blocks[3], channels[3], stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # After layer4, self.in_channels == block.expansion * channels[3].
        self.fc = nn.Linear(self.in_channels, output_dim)

        # (1) Optional zero-init of the last BN weight in each residual block.
        # NOTE(review): presumably bn3 / bn2 is the final normalization of the
        # residual branch in Bottleneck / BasicBlock (defined elsewhere), so each
        # block starts out outputting roughly its shortcut — confirm there.
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def get_resnet_layer(self, block, n_blocks, channels, stride=1):
        """Build one residual stage and advance ``self.in_channels``.

        The first block maps ``self.in_channels`` to
        ``block.expansion * channels`` with the given ``stride``, and is told
        to downsample its shortcut when the channel count or the stride
        changes the output shape. The remaining ``n_blocks - 1`` blocks keep
        width and resolution.

        Returns:
            ``nn.Sequential`` of the stage's blocks.
        """
        out_channels = block.expansion * channels
        # The identity shortcut only matches the residual branch when neither
        # the channel count nor the spatial size (stride) changes.
        downsample = stride != 1 or self.in_channels != out_channels
        # First block: pass in_channels, channels, stride and the downsample flag.
        layers = [block(self.in_channels, channels, stride, downsample)]
        # Remaining blocks (n_blocks - 1 of them).
        layers.extend(block(out_channels, channels) for _ in range(1, n_blocks))
        self.in_channels = out_channels
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the network.

        Spatial sizes in the comments assume a 224x224 input.

        Returns:
            ``(x, h)`` — ``x`` is the output of ``self.fc``; ``h`` is the
            flattened pooled feature vector fed into it.
        """
        x = self.conv1(x)    # 224x224 -> 112x112
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)  # 112x112 -> 56x56
        x = self.layer1(x)   # 56x56
        x = self.layer2(x)   # 28x28
        x = self.layer3(x)   # 14x14
        x = self.layer4(x)   # 7x7
        x = self.avgpool(x)  # 1x1
        h = x.view(x.shape[0], -1)
        x = self.fc(h)
        return x, h