지난시간에 오버 피팅 문제를 해결하기 위해 텐서 플로우의 Callback 클래스를 사용한다.
우선 텐서 플로우를 이용해 Callback 클래스를 만든다.
#val_accuracy가 88%가 넘으면 멈추도록 하고 싶다.
# Callback 클래스 만들기
class myCallback(tf.keras.callbacks.Callback):
def on_epoch_end(self,epoch, logs={}):
if logs['val_accuracy']>0.88:
print('\n내가 정한 정확도에 도달했으니, 학습을 멈춘다.')
self.model.stop_training =True
my_cd = myCallback()
Callback 클래스를 학습할때 넣고 모델로 학습을 시킨다.
def build_model():
model = Sequential()
model.add( Flatten() )
model.add( Dense(128, 'relu') )
model.add( Dense(64, 'relu') )
model.add( Dense(10, 'softmax'))
model.compile('adam', 'sparse_categorical_crossentropy', ['accuracy'])
return model
model = build_model()
# callback 클래스를 학습할때 실행
epoch_history=model.fit(X_train, y_train, epochs = 30,validation_split=0.2,
callbacks = [my_cd])
학습시 오버피팅이 일어나기 전에 학습을 멈춘다.
반응형
'딥러닝' 카테고리의 다른 글
딥러닝(10) 에포크시마다 테스트를 하는, 벨리데이션 데이터를 처리하는 방법중 validation_data 파라미터 사용법 (0) | 2022.12.29 |
---|---|
딥러닝(9)Flatten 라이브러리 없이, 이미지를 평탄화 하는 방법 (0) | 2022.12.29 |
딥러닝(8)에포크, 학습데이터/벨리데이션데이터와 오버피팅 (0) | 2022.12.29 |
딥러닝(7) flatten 라이브러리 사용하는 이유, 액티베이션함수 소프트맥스, 분류의 문제에서 loss셋팅 (0) | 2022.12.29 |
딥러닝(6) validation split (0) | 2022.12.29 |