Tensorflow framework로 학습한 모델을 C++에서 불러와서 Inference를 하기 위해서는 ckpt 혹은 h5 형식 파일을 pb 형식 파일로 변경을 먼저 해야한다. 

 

다시 말해서 모델을 재학습 하기 위한 다른 메타 데이터는 제외하고, 추론만을 위해 필요한 모델의 graph variable & operation 과 실제 가중치 값들만을 가지고 모델을 protocal buffer 형식으로 변환(freeze)해야 한다.

 

텐서플로우에서는 ckpt 파일을 pb파일로 변환하는 code를 제공한다. 텐서플로우가 이미 빌드되어 있는 상태라면, 다음과 같이 명령어를 입력하면 된다.

 

 freeze_graph	--input_graph=model/dnn.pbtxt \
 			--input_checkpoint=model/dnn.ckpt \
 			--output_graph=model/dnn.pb \
 			--output_node_names=output_name

 

개인적으로 확인하고 싶은 부분이 있다면 공식 레포지토리를 뜯어보시라!

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py

 

 

나는 나만의 freeze_graph.py 함수를 다음과 같이 떼어내서 따로 관리 한다. 

import sys, os, argparse
import tensorflow as tf
# for LSTMBlockFusedCell(), https://github.com/tensorflow/tensorflow/issues/23369
# tf.contrib.rnn
# for QRNN
# ?try: import qrnn
# except: sys.stderr.write('import qrnn, failed\n')

'''
source is from https://gist.github.com/morgangiraud/249505f540a5e53a48b0c1a869d370bf#file-medium-tffreeze-1-py
'''

# The original freeze_graph function
# from tensorflow.python.tools.freeze_graph import freeze_graph 

# dir = os.path.dirname(os.path.realpath(__file__))

def modify_op(graph_def):
    """
    reference : https://github.com/onnx/tensorflow-onnx/issues/77#issuecomment-445066091 
    """
    for node in graph_def.node:
        if node.op == 'Assign':
            node.op = 'Identity'
            if 'use_locking' in node.attr: del node.attr['use_locking']
            if 'validate_shape' in node.attr: del node.attr['validate_shape']
            if len(node.input) == 2:
                # input0: ref: Should be from a Variable node. May be uninitialized.
                # input1: value: The value to be assigned to the variable.
                node.input[0] = node.input[1]
                del node.input[1]
    return graph_def

def freeze_graph(model_dir, output_node_names, frozen_model_name, optimize_graph_def=0):
    """Extract the sub graph defined by the output nodes and convert 
    all its variables into constant 
    Args:
        model_dir: the root folder containing the checkpoint state file
        output_node_names: a string, containing all the output node's names, 
                            comma separated
        frozen_model_name: a string, the name of the frozen model
        optimize_graph_def: int, 1 for optimizing graph_def via tensorRT
    """
    if not tf.gfile.Exists(model_dir):
        raise AssertionError(
            "Export directory doesn't exists. Please specify an export "
            "directory: %s" % model_dir)

    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    # We retrieve our checkpoint fullpath
    checkpoint = tf.train.get_checkpoint_state(model_dir)
    input_checkpoint = checkpoint.model_checkpoint_path
    
    # We precise the file fullname of our freezed graph
    absolute_model_dir = "/".join(input_checkpoint.split('/')[:-1])
    output_graph_path = absolute_model_dir + "/" + frozen_model_name

    # We clear devices to allow TensorFlow to control on which device it will load operations
    clear_devices = True

    # We start a session using a temporary fresh Graph
    with tf.Session(graph=tf.Graph()) as sess:
        # We import the meta graph in the current default Graph
        saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=clear_devices)

        # We restore the weights
        saver.restore(sess, input_checkpoint)

        # We use a built-in TF helper to export variables to constants
        output_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
            sess, # The session is used to retrieve the weights
            tf.get_default_graph().as_graph_def(), # The graph_def is used to retrieve the nodes 
            output_node_names.split(',') # The output node names are used to select the usefull nodes
        )

        # Modify for 'float_ref'
        output_graph_def = modify_op(output_graph_def)

        # Optimize graph_def via tensorRT
        if optimize_graph_def:
            from tensorflow.contrib import tensorrt as trt
            # get optimized graph_def
            trt_graph_def = trt.create_inference_graph(
              input_graph_def=output_graph_def,
              outputs=output_node_names.split(','),
              max_batch_size=128,
              max_workspace_size_bytes=1 << 30,
              precision_mode='FP16',  # TRT Engine precision "FP32","FP16" or "INT8"
              minimum_segment_size=3  # minimum number of nodes in an engine
            )
            output_graph_def = trt_graph_def 

        # Finally we serialize and dump the output graph to the filesystem
        with tf.gfile.GFile(output_graph_path, "wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(output_graph_def.node))
        print("Saved const_graph in %s"%model_dir)
    return output_graph_def
    
    
freeze_graph("model/", "output_name", "dnn.pb")

 

모델을 freeze 할 때, 제일 신경써야할 부분은 "output_node_names" 옵션이다. 모델을 freeze시킨 후, output으로 받고 싶은 출력 노드 이름을 이곳에 명시해줘야 한다. 이 부분을 확인할 수 있는 부분은 ".pbtxt" 파일이다. 이곳에 모든 노드 정보가 기록되므로, 이 파일을 열어서 내가 원하는 노드 네임을 찾아서 넣으면 된다. 

 

개인적으로 노드 이름을 간단하게 확인하기 위해서, 다음 코드를 자주 활용한다.

print([print(n.name) for n in tf.get_default_graph().as_graph_def().node])

 

마지막으로, 또 하나의 팁을 남기자면, 모델 설계 시, 모델의 input node는 꼭 "input"로 이름을 지정해주고, 마지막 노드 또한 해당하는 텐서에 이름을 "output"으로 꼭 설정해주거나, 그렇지 못할 경우에 모델 output 부분에

tf.identity(x, name="output")

다음 코드를 추가해주면, output_node를 "output"으로 일관되게 유지할 수 있다.

 

튜토리얼이 될만한 레포지토리는 다음과 같다. 

https://github.com/JackyTung/tensorgraph

나도 나중에 여유될 때 음성 관련 기본 모델로 튜토리얼 코드 작성 해봐야겠다.

 

끝.

 

+ Recent posts