코드 9-32 네트워크 생성
class binaryClassification(nn.Module):
def __init__(self):
super(binaryClassification, self).__init__()
self.layer_1 = nn.Linear(8, 64, bias=True) ------ 칼럼이 여덟 개이므로 입력 크기는 8을 사용
self.layer_2 = nn.Linear(64, 64, bias=True)
self.layer_out = nn.Linear(64, 1, bias=True) ------ 출력으로는 당뇨인지 아닌지를 나타내는 0과 1의 값만 가지므로 1을 사용
self.relu = nn.ReLU()
self.dropout = nn.Dropout(p=0.1)
self.batchnorm1 = nn.BatchNorm1d(64)
self.batchnorm2 = nn.BatchNorm1d(64)
def forward(self, inputs):
x = self.relu(self.layer_1(inputs))
x = self.batchnorm1(x)
x = self.relu(self.layer_2(x))
x = self.batchnorm2(x)
x = self.dropout(x)
x = self.layer_out(x)
return x