ML Engineering/Tensorflow and Torch
[Tensorflow] TF 모델 저장방법 2가지 (ckpt, h5)
KeepPersistStay
2020. 6. 20. 15:34
Tensorflow 를 이용한 학습 모델을 저장하는 방법에는 2가지가 있다.
1. ".meta" 파일과 ".data" 파일을 나누어서 저장하는 방식
Tensorflow model을 표현하기 위해서 크게 2가지 컴포넌트(Meta Graph, Checkpoint)가 필요하다.
- Meta graph
- .meta
- tensorflow graph structure 에 대한 정보
- Variables, Collection and Operations
- .meta
- Checkpoint
- A protocol buffer with a list of recent checkpoints
- weight, biases, gradients, all the variables 값
- .data files: training variables 값에 대한 정보; model.ckpt.data-00000-of-00001
- .index files:**checkpoint 에 대한 정보(index)
TF 1.x 버전 tf.Session()을 통해서 모델을 저장하는 코드는 다음과 같다.
with tf.Session() as sess:
# Initializes all the variables.
sess.run(init_all_op)
# Runs to logit.
sess.run(logits)
# Creates a saver.
saver = tf.train.Saver()
# Save both checkpoints and meta-graph
saver.save(sess, 'my-save-dir/my-model-10000')
# Generates MetaGraphDef.
saver.export_meta_graph('my-save-dir/my-model-10000.meta') #change this line
2. 전체 모델을 HDF5 파일 하나에 저장하는 방식
가중치, 모델 구성, 옵티마이저에 지정한 설정까지 파일 하나에 모두 포함된다.
model = create_model()
model.fit(train_images, train_labels, epochs=5)
# 전체 모델을 HDF5 파일로 저장합니다
model.save('my_model.h5')
# 가중치와 옵티마이저를 포함하여 정확히 동일한 모델을 다시 생성합니다
new_model = keras.models.load_model('my_model.h5')
new_model.summary()
고수준 API인 tf.keras 를 통해서 모델을 저장하고 로드하는 것은 이곳에 잘 정리되어 있다.
https://www.tensorflow.org/tutorials/keras/save_and_restore_models
끝.