더북(TheBook)

코드 6-75 ResNet 모델 네트워크

class ResNet(nn.Module):
    """ResNet-style network assembled from a ``(block, n_blocks, channels)`` config.

    The model is a stem (7x7 stride-2 convolution -> batch norm -> ReLU ->
    3x3 stride-2 max-pool), four residual stages built by
    :meth:`get_resnet_layer`, an adaptive average pool to 1x1 and a final
    linear layer.

    Args:
        config: Three-item sequence ``(block, n_blocks, channels)``.
            ``block`` is the residual block class; it must expose an
            ``expansion`` class attribute and be constructible both as
            ``block(in_channels, channels, stride, downsample)`` and as
            ``block(in_channels, channels)``. ``n_blocks`` and ``channels``
            must each contain exactly four entries, one per stage.
        output_dim: Number of output features of the final linear layer.
        zero_init_residual: If true, the weight of ``bn3`` in every
            ``Bottleneck`` module and of ``bn2`` in every ``BasicBlock``
            module is initialised to zero.
    """

    def __init__(self, config, output_dim, zero_init_residual=False):
        super().__init__()

        # Unpack the configuration handed over by the caller.
        block, n_blocks, channels = config
        # Running channel count: read and updated by get_resnet_layer so that
        # each stage knows how many channels the previous one produced.
        self.in_channels = channels[0]
        # One block count and one channel width per stage, four stages total.
        assert len(n_blocks) == len(channels) == 4, (
            "n_blocks and channels must each contain exactly 4 entries"
        )

        # Stem: the input is expected to have 3 channels.
        self.conv1 = nn.Conv2d(
            3, self.in_channels, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.bn1 = nn.BatchNorm2d(self.in_channels)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Stage order matters: every call mutates self.in_channels.
        self.layer1 = self.get_resnet_layer(block, n_blocks[0], channels[0])
        self.layer2 = self.get_resnet_layer(block, n_blocks[1], channels[1], stride=2)
        self.layer3 = self.get_resnet_layer(block, n_blocks[2], channels[2], stride=2)
        self.layer4 = self.get_resnet_layer(block, n_blocks[3], channels[3], stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # self.in_channels now holds the channel count produced by layer4.
        self.fc = nn.Linear(self.in_channels, output_dim)

        if zero_init_residual:
            # Zero one batch-norm scale per residual block.
            # NOTE(review): presumably bn3 / bn2 is the last batch norm of the
            # residual branch, so each block starts out as an identity
            # mapping - confirm against the Bottleneck / BasicBlock classes.
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def get_resnet_layer(self, block, n_blocks, channels, stride=1):
        """Build one residual stage and advance ``self.in_channels``.

        Args:
            block: Residual block class (must expose ``expansion``).
            n_blocks: Number of blocks requested for the stage. The first
                block is always created; ``n_blocks - 1`` more follow.
            channels: Base channel width passed to every block of the stage.
            stride: Stride passed to the first block only.

        Returns:
            ``nn.Sequential`` holding the blocks of the stage, in order.
        """
        out_channels = block.expansion * channels
        # The first block is told to downsample only when the incoming channel
        # count differs from the stage's output width.
        # NOTE(review): ``stride`` is not part of this condition - verify the
        # block handles stride != 1 when the channel counts already match.
        downsample = self.in_channels != out_channels

        # First block: receives in_channels, channels, stride and the
        # downsample flag.
        layers = [block(self.in_channels, channels, stride, downsample)]
        # Remaining blocks use the block's default stride/downsample arguments.
        layers.extend(block(out_channels, channels) for _ in range(1, n_blocks))

        self.in_channels = out_channels
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run the network.

        Args:
            x: Input batch with 3 channels (``conv1`` is built for 3 input
                channels).

        Returns:
            Tuple ``(x, h)``: the output of the final linear layer and the
            pooled features flattened to ``(batch, features)`` that were fed
            into it.
        """
        # Size comments give the spatial size AFTER each step for a 224x224
        # input. Sizes after layer1..layer4 assume the blocks keep the
        # spatial size at stride 1 and halve it at stride 2 - TODO confirm
        # against the block implementation.
        x = self.conv1(x)    # 112x112
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)  # 56x56
        x = self.layer1(x)   # 56x56
        x = self.layer2(x)   # 28x28
        x = self.layer3(x)   # 14x14
        x = self.layer4(x)   # 7x7
        x = self.avgpool(x)  # 1x1
        h = x.view(x.shape[0], -1)
        x = self.fc(h)
        return x, h
신간 소식 구독하기
뉴스레터에 가입하시고 이메일로 신간 소식을 받아 보세요.