텐서플로우 모델을 만들시 구글코랩에서 만든 모델이면 임시로 구글에서 서버를 잠깐 빌려주는 식이라 빌려준 서버와의 연결이 끊기면 모델은 저장되지 않고 날아간다. 이러한 사태를 방지 하기 위해 모델을 저장하는 방법을 알아보자.
첫번째 방법은 딥러닝 했던 내용을 폴더로 저장하는 방법이다.
# 전체 네트워크와 웨이트를 통으로 저장하고 불러오기
# 폴더 구조로 저장.
model.save('fashion_mnist_model')
# 저장된 인공지능을 불러오는 코드.
tf.keras.models.load_model('fashion_mnist_model')
코드를 실행하면 구글서버에 폴더 형태로 저장되며 세션이 끊기면 사라지지만 세션이 끊기기 전에 폴더를 다운로드하여 저장할 수 있다.
두번째 방법은 h5파일로 저장하는 방법이다.
# 모델을,파일 하나로 저장하는 방법
#파일을 저장하는 코드
model.save('fashion_mnist_model.h5')
#파일을 불러오는 코드
model3=tf.keras.models.load_model('fashion_mnist_model.h5')
코드를 실행하면 구글서버에 h5파일 형태로 저장되며 세션이 끊기면 사라지지만 세션이 끊기기 전에 파일을 다운로드하여 저장할 수 있다.
세번째 방법은 모델의 네트워크와 웨이트 파일을 각각 따로 저장하는 방법이다.
이방법으로 저장할경우 네트워크와 웨이트이 2개가 있어야 모델이 구성이되고 작동할수 있다.
# 네트워크만 저장하고 불러오기
model.to_json()
fashion_mnist_network=model.to_json()
# 네트워크를 json 파일로 저장하는 코드
fashion_mnist_network=model.to_json()
with open('fashion_mnist_network.json','w')as file:
file.write(fashion_mnist_network)
# 저장된 네트워크를 읽어오는 코드
with open('fashion_mnist_network.json','r')as file:
fashion_net = file.read()
# 위의 네트워크로부터 모델을 만들고 싶으면
model4=tf.keras.models.model_from_json(fashion_net)
# 웨이트를 저장하는 코드
model.save_weights('fashion_mnist_weight.h5')
# 웨이트를 읽어오는 코드
model4.load_weights('fashion_mnist_weight.h5')
위와 마찬가지로 네트워크,웨이트 둘 다 다운로드 받을수 있다.
반응형
'딥러닝' 카테고리의 다른 글
딥러닝(13) pooling 이란 (0) | 2023.01.02 |
---|---|
딥러닝(12)CNN의 구조 (0) | 2022.12.30 |
딥러닝(10) 에포크시마다 테스트를 하는, 벨리데이션 데이터를 처리하는 방법중 validation_data 파라미터 사용법 (0) | 2022.12.29 |
딥러닝(9)Flatten 라이브러리 없이, 이미지를 평탄화 하는 방법 (0) | 2022.12.29 |
딥러닝(9) 텐서플로우의 콜백클래스를 이용해서, 원하는 조건이 되면 학습을 멈추게 하는 코드 (0) | 2022.12.29 |