텐서플로 2로 ResNet을 구현해 보겠습니다. 이 예제에서는 아이덴티티 블록과 합성곱 블록만 확인해 보겠습니다. 필요한 라이브러리와 데이터를 호출하고, 생성된 모델로 훈련하는 부분은 앞에서 충분히 살펴보았기 때문에 네트워크 부분만 집중해서 살펴보겠습니다.
먼저 아이텐티티 블록에 대한 코드입니다.
코드 6-19 아이덴티티 블록
def res_identity(x, filters):
x_skip = x ------ 레지듀얼 블록을 추가하는 데 사용
f1, f2 = filters
x = Conv2D(f1, kernel_size=(1,1), strides=(1,1), padding='valid',
kernel_regularizer=l2(0.001))(x)
x = BatchNormalization()(x) ------ ①
x = Activation(activations.relu)(x) ------ 첫 번째 블록
x = Conv2D(f1, kernel_size=(3,3), strides=(1,1), padding='same',
kernel_regularizer=l2(0.001))(x)
x = BatchNormalization()(x)
x = Activation(activations.relu)(x) ------ 두 번째 블록
x = Conv2D(f2, kernel_size=(1,1), strides=(1,1), padding='valid',
kernel_regularizer=l2(0.001))(x)
x = BatchNormalization()(x) ------ 세 번째 블록
x = Add()([x, x_skip]) ------ 숏컷
x = Activation(activations.relu)(x)
return x