딥러닝

딥러닝(9) 텐서플로우의 콜백클래스를 이용해서, 원하는 조건이 되면 학습을 멈추게 하는 코드

개발연습자1 2022. 12. 29. 17:52

지난시간에 오버 피팅 문제를 해결하기 위해 텐서 플로우의 Callback 클래스를 사용한다. 

 

우선 텐서 플로우를 이용해 Callback 클래스를 만든다.

#val_accuracy가 88%가 넘으면 멈추도록 하고 싶다.

# Callback 클래스 만들기
class myCallback(tf.keras.callbacks.Callback):
  def on_epoch_end(self,epoch, logs={}):
    if logs['val_accuracy']>0.88:
      print('\n내가 정한 정확도에 도달했으니, 학습을 멈춘다.')
      self.model.stop_training =True
 
my_cd = myCallback()

 

Callback 클래스를 학습할때 넣고 모델로 학습을 시킨다.

 

def build_model():
  model = Sequential()
  model.add( Flatten()  )
  model.add( Dense(128, 'relu') )
  model.add( Dense(64, 'relu') )
  model.add( Dense(10, 'softmax'))
  model.compile('adam', 'sparse_categorical_crossentropy', ['accuracy'])
  return model
  
  model = build_model()
  
  # callback 클래스를 학습할때 실행
  epoch_history=model.fit(X_train, y_train, epochs = 30,validation_split=0.2,
                        callbacks = [my_cd])

 

학습시 오버피팅이 일어나기 전에 학습을 멈춘다.

반응형