딥러닝
딥러닝(9) 텐서플로우의 콜백클래스를 이용해서, 원하는 조건이 되면 학습을 멈추게 하는 코드
개발연습자1
2022. 12. 29. 17:52
지난시간에 오버 피팅 문제를 해결하기 위해 텐서 플로우의 Callback 클래스를 사용한다.
우선 텐서 플로우를 이용해 Callback 클래스를 만든다.
#val_accuracy가 88%가 넘으면 멈추도록 하고 싶다.
# Callback 클래스 만들기
class myCallback(tf.keras.callbacks.Callback):
def on_epoch_end(self,epoch, logs={}):
if logs['val_accuracy']>0.88:
print('\n내가 정한 정확도에 도달했으니, 학습을 멈춘다.')
self.model.stop_training =True
my_cd = myCallback()
Callback 클래스를 학습할때 넣고 모델로 학습을 시킨다.
def build_model():
model = Sequential()
model.add( Flatten() )
model.add( Dense(128, 'relu') )
model.add( Dense(64, 'relu') )
model.add( Dense(10, 'softmax'))
model.compile('adam', 'sparse_categorical_crossentropy', ['accuracy'])
return model
model = build_model()
# callback 클래스를 학습할때 실행
epoch_history=model.fit(X_train, y_train, epochs = 30,validation_split=0.2,
callbacks = [my_cd])
학습시 오버피팅이 일어나기 전에 학습을 멈춘다.
반응형