코드 6-40 VGG 모델 정의
class VGG(nn.Module):
def __init__(self, features, output_dim):
super().__init__()
self.features = features ------ VGG 모델에 대한 매개변수에서 받아 온 features 값을 self.features에 넣어 줍니다.
self.avgpool = nn.AdaptiveAvgPool2d(7)
self.classifier = nn.Sequential(
nn.Linear(512*7*7, 4096),
nn.ReLU(inplace=True),
nn.Dropout(0.5),
nn.Linear(4096, 4096),
nn.ReLU(inplace=True),
nn.Dropout(0.5),
nn.Linear(4096, output_dim)
) ------ 완전연결층과 출력층 정의
def forward(self, x):
x = self.features(x)
x = self.avgpool(x)
h = x.view(x.shape[0], -1)
x = self.classifier(h)
return x, h