In PyTorch, when defining a nn.Module, attributes can either be trainable parameters (like weights and biases) or non-trainable values (such as constants, buffers, or pre-computed values).

Trainable Parameters vs Non-Trainable Attributes:

  1. Trainable Parameters:
    • These are parameters that the model updates during training via backpropagation (e.g., weights in a neural network).
    • PyTorch stores these parameters using nn.Parameter and registers them to the model.
    • Example: weights and biases in layers like nn.Linear or nn.Conv2d.
    self.weight = nn.Parameter(torch.randn(10, 10))  # Trainable
  2. Non-Trainable Attributes:
    • These are attributes that do not change during training. They are useful for storing constants, lookup tables, pre-initialized matrices, etc.
    • If you don’t want these values to be updated via backpropagation, you typically register them as a buffer or store them as regular attributes of the module.
    • Example: a normalization constant, a precomputed matrix, or a codebook in vector quantization.
    self.constant = torch.randn(10, 10)  # Non-trainable, regular attribute

register_buffer:

  • PyTorch provides register_buffer to store non-trainable tensors in a model. This is useful because buffers will automatically be moved to the correct device (e.g., GPU) when the model is moved, but they won’t be updated during training.
  • However, if you don’t want or need this specific behavior, you can just store non-trainable values as regular attributes.
        def __init__(block_size):
            self.register_buffer("mask", torch.tril(torch.ones(block_size, block_size))
                    .view(1, 1, block_size, block_size)) 
        def forward(x): 
            B, T, C = x.size() 
            self.mask[:,:,:T,:T]

Intro

 

딥러닝 모델을 설계하고 개발할 때 중요한 부분 중 하나인 데이터 로더 만드는 방법에 대해서 정리하고자 한다.

 

특별히, 음성이나 음악 등 연속적인 데이터를 이용하는 모델을 구축하고자 한다.

 

신호를 가지고 할 수 있는 것들이 많이 있지만, 우선 Keyword Spotting 알고리즘을 만드는 것을 목표로 놓고, 그에 맞는 데이터 로더를 만들어 가보도록 하자.

 

먼저 shell 환경에서 다음과 같이 tensorflow speech command dataset을 다운로드 받자. 

!wget https://storage.cloud.google.com/download.tensorflow.org/data/speech_commands_v0.02.tar.gz

 

torch.utils.data.Dataset 은 데이터셋을 나타내는 추상 클래스이다. 우리가 만드는 Custom Dataset Class 는 torch.utils.data.Dataset 을 상속하고, 다음 3가지 멤버함수들을 오버라이드 해야 한다. 

import torch

class CustomDataset(torch.utils.data.Dataset):

  def __init__(self, ...):
  
  def __len__(self):
  
  def __getitem__(self, idx):

멤버함수 __init__() 는 클래스 인스턴스 생성시 파라미터로 들어오는 정보로 원하는 데이터셋 정보를 초기화 해야 한다.

len(dataset) 에서 호출되는, 멤버함수 __len__ 은 데이터 셋의 크기를 return 해야 한다.

dataset[i] 에서 호출되는, __getitem__ 은 i 번째 샘플을 찾는데 사용된다.

 

그럼 Custom Dataset Class인 SpeechCommandsDataset 를 만들어 보자.

import torch
import os
import numpy as np
import librosa

CLASSES = 'unknown, silence, yes, no, up, down, left, right, on, off, stop, go'.split(', ')

class SpeechCommandsDataset(torch.utils.data.Dataset):
    """Google speech commands dataset. Only 'yes', 'no', 'up', 'down', 'left',
    'right', 'on', 'off', 'stop' and 'go' are treated as known classes.
    All other classes are used as 'unknown' samples.
    """

    def __init__(self, folder, transform=None, classes=CLASSES, silence_percentage=0.1):
        """
          Args:
          folder (string): Path folder.
          transform (callable, optional): Optional transform to be applied
          on a sample.
          class (string): list

        """
        all_classes = [d for d in os.listdir(folder) if os.path.isdir(os.path.join(folder, d)) and not d.startswith('_')]
        
        class_to_idx = {classes[i]: i for i in range(len(classes))}
        for c in all_classes:
            if c not in class_to_idx:
                class_to_idx[c] = 0

        data = []
        for c in all_classes:
            d = os.path.join(folder, c)
            target = class_to_idx[c]
            for f in os.listdir(d):
                path = os.path.join(d, f)
                data.append((path, target))

        # add silence
        target = class_to_idx['silence']
        data += [('', target)] * int(len(data) * silence_percentage)

        self.classes = classes
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        path, target = self.data[index]
        data = {'path_wave': path, 'target': target}

        if self.transform is not None:
            data = self.transform(data)
        return data

 

__init__ 을 사용해서 폴더 안에 있는 데이터들의 Path 를  읽지만, __getitem__ 을 이용해서 그 path에 해당하는 waveform 데이터를 읽어드린다 . 이 방법은 모든 파일을 메모리에 저장하지 않고 필요할때마다 읽기 때문에 메모리를 효율적으로 사용한다.

 

데이터셋의 샘플은 {'path': path_wave, 'target': label} 의 사전 형태를 갖는다. 선택적 인자인 transform 을 통해 필요한 전처리 과정을 샘플에 적용할 수 있다. transform 에 대해서는 뒷부분에서 조금 더 자세히 살펴보기로 한다.

 

클래스를 인스턴스화 하고, 데이터 샘플을 통해서 반복한다. 첫번째 4개의 크기를 출력하고, 샘플들의 wave와 target을 보여준 것이다.

path_dataset = "~/data_speech_commands_v0.02"

dataset = SpeechCommandsDataset(path_dataset)
                                         
for i in range(len(dataset)):
    sample = dataset[i]

    print(i, sample['path_wave'], sample['target'])

Out:

0 /Users/Downloads/data_speech_commands_v0.02/right/8e523821_nohash_2.wav 7
1 /Users/Downloads/data_speech_commands_v0.02/right/bb05582b_nohash_3.wav 7
2 /Users/Downloads/data_speech_commands_v0.02/right/988e2f9a_nohash_0.wav 7
3 /Users/Downloads/data_speech_commands_v0.02/right/a69b9b3e_nohash_0.wav 7
...

 

Transform

 

뉴럴 네트워크 학습을 위해서 우리는 다양한 형태의 데이터 변환이 필요할 수 있다. 예를들어, 음성 신호처리에서는 다음과 같은 다양한 transforms이 필요하다.

 

  1. 음성 신호를 time domain 혹은 frequency domain 에서 분석해야 한다.
  2. 모든 wave 파일의 길이가 상이한 특성 때문에 파일들의 길이를 Fix 하여 데이터를 재구성 한 후, 원하는 뉴럴네트워크 모델을 학습해야 한다.
  3. Data augmentation 적용 ( Time Streching / Shift / Add Noise )

모든 transform 은 클래스로 작성하여 클래스가 호출될 때마다 Transform의 매개변수가 전달 되지 않아도 되게 만드는 것이 좋다. 이를 위해 __call__ 함수와 __init__ 함수를 포함한 목적에 맞는 클래스를 구현한다.

 

이 페이지에서는 4가지 Transform 클래스를 구현한다.

  1. LoadAudio : Audio를 librosa library를 사용하여 time domain data로 로드한다.
  2. FixAudioLength : time domain audio 신호를 1초를 기준으로 zero-padding 하거나, truncates 시켜서 fixed length로 변환한다.
  3. ToMelSpectrogram : time domain 신호로부터 freqency domain log Mel filterbank 특징벡터로 변경한다.
  4. ToTensor : numpy 벡터를 torch tensor type 으로 변경한다.
import torch
import librosa
import numpy as np

class LoadAudio(object):
    """Loads an audio into a numpy array."""

    def __init__(self, sample_rate=16000):
        self.sample_rate = sample_rate

    def __call__(self, data):
        path = data['path']
        if path:
            samples, sample_rate = librosa.load(path, self.sample_rate)
        else:
            # silence
            sample_rate = self.sample_rate
            samples = np.zeros(sample_rate, dtype=np.float32)
        data['samples'] = samples
        data['sample_rate'] = sample_rate
        return data

class FixAudioLength(object):
    """Either pads or truncates an audio into a fixed length."""

    def __init__(self, time=1):
        self.time = time

    def __call__(self, data):
        samples = data['samples']
        sample_rate = data['sample_rate']
        length = int(self.time * sample_rate)
        if length < len(samples):
            data['samples'] = samples[:length]
        elif length > len(samples):
            data['samples'] = np.pad(samples, (0, length - len(samples)), "constant")
        return data
        
class ToMelSpectrogram(object):
    """Creates the mel spectrogram from an audio. The result is a 32x32 matrix."""

    def __init__(self, n_mels=32):
        self.n_mels = n_mels

    def __call__(self, data):
        samples = data['samples']
        sample_rate = data['sample_rate']
        s = librosa.feature.melspectrogram(samples, sr=sample_rate, n_mels=self.n_mels)
        data['mel_spectrogram'] = librosa.power_to_db(s, ref=np.max)
        return data

class ToTensor(object):
    """Converts into a tensor."""

    def __init__(self, np_name, tensor_name, normalize=None):
        self.np_name = np_name
        self.tensor_name = tensor_name
        self.normalize = normalize

    def __call__(self, data):
        tensor = torch.FloatTensor(data[self.np_name])
        if self.normalize is not None:
            mean, std = self.normalize
            tensor -= mean
            tensor /= std
        data[self.tensor_name] = tensor
        return data

 

Transform 구성(Compose)

Trasform 들이 잘 작성되었나 확인해보자.

 

위에 정의한 클래스들을 사용하여, path_wave로부터 raw audio data를 메모리로 load 하고 1초 기준으로 zero-padding/truncate 하고, MelSpectrogram으로 변환 후 Torch Tensor 타입으로 변경한다.

 

torchvision.transforms.Compose 는 위의 클래스에 정의 일을 하도록 호출할 수 있는 클래스이다.

from torchvision import transforms

n_mels=40
batch_size=4
use_gpu=False
num_dataloader_workers_=1

path_dataset = "/Users/Downloads/data_speech_commands_v0.02"

dataset = SpeechCommandsDataset(path_dataset)

_loadaudio = LoadAudio()
_fixaudiolength = FixAudioLength()
_melSpec = ToMelSpectrogram()

composed = transforms.Compose([LoadAudio(),
                                FixAudioLength(),
                                ToMelSpectrogram(n_mels=n_mels), 
                                ToTensor('mel_spectrogram', 'input')])


sample = dataset[0]

transformed_sample = _loadaudio(sample)
print(transformed_sample['samples'])

transformed_sample = _fixaudiolength(sample)
print(transformed_sample['samples'])

transformed_sample = _melSpec(sample)
print(transformed_sample['mel_spectrogram'])

transformed_sample = composed(sample)
print(transformed_sample['path_wave'], transformed_sample['input'].size(), transformed_sample['target'])

Out:

[ 0.0000000e+00  0.0000000e+00 -3.0517578e-05 ... -6.1035156e-05
 -6.1035156e-05 -6.1035156e-05]
[ 0.0000000e+00  0.0000000e+00 -3.0517578e-05 ... -6.1035156e-05
 -6.1035156e-05 -6.1035156e-05]
[[-80.      -80.      -80.      ... -80.      -80.      -80.     ]
 [-76.06169 -78.95192 -80.      ... -80.      -80.      -80.     ]
 [-80.      -80.      -80.      ... -80.      -80.      -80.     ]
 ...
 [-80.      -80.      -80.      ... -80.      -80.      -80.     ]
 [-80.      -80.      -80.      ... -80.      -80.      -80.     ]
 [-80.      -80.      -80.      ... -80.      -80.      -80.     ]]
/Users/Downloads/data_speech_commands_v0.02/right/8e523821_nohash_2.wav torch.Size([40, 32]) 7

 

Custom dataset에 적용

 

이제 Transform 클래스들을 기본 custom dataset 클래스에 적용해보자.

 

from torchvision import transforms
from tqdm import tqdm

n_mels=40
batch_size=4
use_gpu=False
num_dataloader_workers=1

# path_dataset = "~/data_speech_commands_v0.02"
path_dataset = "/Users/Downloads/data_speech_commands_v0.02"

dataset = SpeechCommandsDataset(path_dataset,
                                transforms.Compose([LoadAudio(),
                                         		FixAudioLength(),
                                         		ToMelSpectrogram(n_mels=n_mels), 
                                         		ToTensor('mel_spectrogram', 'input')]))
                                         

dataloader = torch.utils.data.DataLoader(dataset,
			batch_size=batch_size,
            		shuffle=False,
           		pin_memory=use_gpu, 
            		num_workers=num_dataloader_workers)
                    

for batch in tqdm(dataloader, unit="audios", unit_scale=dataloader.batch_size):
    inputs = batch['input']
    targets = batch['target']
    print(inputs.size(), targets.size())

 

요약

wave 파일 전체를 메모리에 올리지 않고 필요할 때마다 로드 시키고, 원하는 형태의 데이터로 transform을 메모리상에서 할 수 있다.

 

전체 코드 공유

import torch
import os
import librosa
import numpy as np
from torchvision import transforms
from tqdm import tqdm

CLASSES = 'unknown, silence, yes, no, up, down, left, right, on, off, stop, go'.split(', ')

class SpeechCommandsDataset(torch.utils.data.Dataset):
    """Google speech commands dataset. Only 'yes', 'no', 'up', 'down', 'left',
    'right', 'on', 'off', 'stop' and 'go' are treated as known classes.
    All other classes are used as 'unknown' samples.
    """

    def __init__(self, folder, transform=None, classes=CLASSES, silence_percentage=0.1):
        """
          Args:
          folder (string): Path folder.
          transform (callable, optional): Optional transform to be applied
          on a sample.
          class (string): list

        """
        all_classes = [d for d in os.listdir(folder) if os.path.isdir(os.path.join(folder, d)) and not d.startswith('_')]
        
        class_to_idx = {classes[i]: i for i in range(len(classes))}
        for c in all_classes:
            if c not in class_to_idx:
                class_to_idx[c] = 0

        data = []
        for c in all_classes:
            d = os.path.join(folder, c)
            target = class_to_idx[c]
            for f in os.listdir(d):
                path = os.path.join(d, f)
                data.append((path, target))

        # add silence
        target = class_to_idx['silence']
        data += [('', target)] * int(len(data) * silence_percentage)

        self.classes = classes
        self.data = data
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        path, target = self.data[index]
        data = {'path_wave': path, 'target': target}

        if self.transform is not None:
            data = self.transform(data)
        return data
  
class LoadAudio(object):
    """Loads an audio into a numpy array."""

    def __init__(self, sample_rate=16000):
        self.sample_rate = sample_rate

    def __call__(self, data):
        path = data['path_wave']
        if path:
            samples, sample_rate = librosa.load(path, self.sample_rate)
        else:
            # silence
            sample_rate = self.sample_rate
            samples = np.zeros(sample_rate, dtype=np.float32)
        data['samples'] = samples
        data['sample_rate'] = sample_rate
        return data

class FixAudioLength(object):
    """Either pads or truncates an audio into a fixed length."""

    def __init__(self, time=1):
        self.time = time

    def __call__(self, data):
        samples = data['samples']
        sample_rate = data['sample_rate']
        length = int(self.time * sample_rate)
        if length < len(samples):
            data['samples'] = samples[:length]
        elif length > len(samples):
            data['samples'] = np.pad(samples, (0, length - len(samples)), "constant")
        return data
        
class ToMelSpectrogram(object):
    """Creates the mel spectrogram from an audio. The result is a 32x32 matrix."""

    def __init__(self, n_mels=32):
        self.n_mels = n_mels

    def __call__(self, data):
        samples = data['samples']
        sample_rate = data['sample_rate']
        s = librosa.feature.melspectrogram(samples, sr=sample_rate, n_mels=self.n_mels)
        data['mel_spectrogram'] = librosa.power_to_db(s, ref=np.max)
        return data

class ToTensor(object):
    """Converts into a tensor."""

    def __init__(self, np_name, tensor_name, normalize=None):
        self.np_name = np_name
        self.tensor_name = tensor_name
        self.normalize = normalize

    def __call__(self, data):
        tensor = torch.FloatTensor(data[self.np_name])
        if self.normalize is not None:
            mean, std = self.normalize
            tensor -= mean
            tensor /= std
        data[self.tensor_name] = tensor
        return data
        
        
        
        


n_mels=40
batch_size=4
use_gpu=False
num_dataloader_workers=1

path_dataset = "/Users/Downloads/data_speech_commands_v0.02"

dataset = SpeechCommandsDataset(path_dataset,
                                transforms.Compose([LoadAudio(),
                                         		FixAudioLength(),
                                         		ToMelSpectrogram(n_mels=n_mels), 
                                         		ToTensor('mel_spectrogram', 'input')]))
                                         

dataloader = torch.utils.data.DataLoader(dataset,
			batch_size=batch_size,
            		shuffle=False,
           		pin_memory=use_gpu, 
            		num_workers=num_dataloader_workers)
                    

for batch in tqdm(dataloader, unit="audios", unit_scale=dataloader.batch_size):
    inputs = batch['input']
    targets = batch['target']
    print(inputs.size(), targets.size())
    
    

Pytorch 라이브러리 개요

파이토치는 딥러닝 프로젝트를 빌드(build)하는 데 도움을 주는 파이썬 프로그램용 라이브러리/프레임워크 이다.

 

파이토치는 대부분 C++과 CUDA 언어를 기반으로 만들어졌다.

 

파이토치는 수학적 연산을 가속화 하고자 코어 데이터 구조인 텐서(Tensor)를 제공한다. 이는 NumPy 배열(array)과 비슷한 다차원 배열이고, CPU 또는 GPU에서 연산이 가능하다.

 

파이토치에서 모든 연산은 텐서에서 기초적으로 제공되고, torch.autograd에서 정제된다.

 

파이토치에서 신경망을 구성하기 위한 대부분의 모듈은 코어 모듈(torch.nn)의 하위 모듈로 [nn.Conv1d/nn. ReLU/nn.MSELoss] 등과 같이 여러 Layers/Activation/Loss 를 지원한다. 모델의 최적화를 위해서는 torch.optim을 지원한다.

 

효율적인 학습을 위한 데이터 핸들링에 필요한 기능들 torch.utils.data.dataloader 를 통해 지원한다.

 

Multi-Node / Multi-GPU 머신을 활용한 분산/병렬 처리를 데이터 로딩과 학습 연산에 사용할 수 있도록 torch.nn.DataParallel torch.distributed를 지원한다. 관련하여 다른 글에서 정리해보고 싶다!! [관련 내용]

 

파이토치에서는 torch.utils.tensorboard 를 통해 Tensorboard를 지원한다.

 

또한 각각 도메인에서 활용가능한 PyTorch Library들이 존재 한다.

파이토치는 딥러닝 모델 성능 향상을 위해서 여러가지 Quantization 방법을 지원한다.

 

마지막으로 파이토치는 배포 환경을 고려하여 high performance inference 위한 TorchScript를 제공한다. TorchScript는 Python 인터프리터의 비용을 줄이고 Python 런타임으로부터 독립적으로 모델을 실행시키기 위한 방법이다. 효율적 연산을 위해 Just in Time(JIT)을 지원한다.


Rerference

 

[1] www.pytorch.org/

[2] www.yytorch.org/assets/deep-learning/Deep-Learning-with-PyTorch.pdf

Tensorflow framework로 학습한 모델을 C++에서 불러와서 Inference를 하기 위해서는 ckpt 혹은 h5 형식 파일을 pb 형식 파일로 변경을 먼저 해야한다. 

 

다시 말해서 모델을 재학습 하기 위한 다른 메타 데이터는 제외하고, 추론만을 위해 필요한 모델의 graph variable & operation 과 실제 가중치 값들만을 가지고 모델을 protocal buffer 형식으로 변환(freeze)해야 한다.

 

텐서플로우에서는 ckpt 파일을 pb파일로 변환하는 code를 제공한다. 텐서플로우가 이미 빌드되어 있는 상태라면, 다음과 같이 명령어를 입력하면 된다.

 

 freeze_graph	--input_graph=model/dnn.pbtxt \
 			--input_checkpoint=model/dnn.ckpt \
 			--output_graph=model/dnn.pb \
 			--output_node_names=output_name

 

개인적으로 확인하고 싶은 부분이 있다면 공식 레포지토리를 뜯어보시라!

https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/tools/freeze_graph.py

 

 

나는 나만의 freeze_graph.py 함수를 다음과 같이 떼어내서 따로 관리 한다. 

import sys, os, argparse
import tensorflow as tf
# for LSTMBlockFusedCell(), https://github.com/tensorflow/tensorflow/issues/23369
# tf.contrib.rnn
# for QRNN
# ?try: import qrnn
# except: sys.stderr.write('import qrnn, failed\n')

'''
source is from https://gist.github.com/morgangiraud/249505f540a5e53a48b0c1a869d370bf#file-medium-tffreeze-1-py
'''

# The original freeze_graph function
# from tensorflow.python.tools.freeze_graph import freeze_graph 

# dir = os.path.dirname(os.path.realpath(__file__))

def modify_op(graph_def):
    """
    reference : https://github.com/onnx/tensorflow-onnx/issues/77#issuecomment-445066091 
    """
    for node in graph_def.node:
        if node.op == 'Assign':
            node.op = 'Identity'
            if 'use_locking' in node.attr: del node.attr['use_locking']
            if 'validate_shape' in node.attr: del node.attr['validate_shape']
            if len(node.input) == 2:
                # input0: ref: Should be from a Variable node. May be uninitialized.
                # input1: value: The value to be assigned to the variable.
                node.input[0] = node.input[1]
                del node.input[1]
    return graph_def

def freeze_graph(model_dir, output_node_names, frozen_model_name, optimize_graph_def=0):
    """Extract the sub graph defined by the output nodes and convert 
    all its variables into constant 
    Args:
        model_dir: the root folder containing the checkpoint state file
        output_node_names: a string, containing all the output node's names, 
                            comma separated
        frozen_model_name: a string, the name of the frozen model
        optimize_graph_def: int, 1 for optimizing graph_def via tensorRT
    """
    if not tf.gfile.Exists(model_dir):
        raise AssertionError(
            "Export directory doesn't exists. Please specify an export "
            "directory: %s" % model_dir)

    if not output_node_names:
        print("You need to supply the name of a node to --output_node_names.")
        return -1

    # We retrieve our checkpoint fullpath
    checkpoint = tf.train.get_checkpoint_state(model_dir)
    input_checkpoint = checkpoint.model_checkpoint_path
    
    # We precise the file fullname of our freezed graph
    absolute_model_dir = "/".join(input_checkpoint.split('/')[:-1])
    output_graph_path = absolute_model_dir + "/" + frozen_model_name

    # We clear devices to allow TensorFlow to control on which device it will load operations
    clear_devices = True

    # We start a session using a temporary fresh Graph
    with tf.Session(graph=tf.Graph()) as sess:
        # We import the meta graph in the current default Graph
        saver = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=clear_devices)

        # We restore the weights
        saver.restore(sess, input_checkpoint)

        # We use a built-in TF helper to export variables to constants
        output_graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
            sess, # The session is used to retrieve the weights
            tf.get_default_graph().as_graph_def(), # The graph_def is used to retrieve the nodes 
            output_node_names.split(',') # The output node names are used to select the usefull nodes
        )

        # Modify for 'float_ref'
        output_graph_def = modify_op(output_graph_def)

        # Optimize graph_def via tensorRT
        if optimize_graph_def:
            from tensorflow.contrib import tensorrt as trt
            # get optimized graph_def
            trt_graph_def = trt.create_inference_graph(
              input_graph_def=output_graph_def,
              outputs=output_node_names.split(','),
              max_batch_size=128,
              max_workspace_size_bytes=1 << 30,
              precision_mode='FP16',  # TRT Engine precision "FP32","FP16" or "INT8"
              minimum_segment_size=3  # minimum number of nodes in an engine
            )
            output_graph_def = trt_graph_def 

        # Finally we serialize and dump the output graph to the filesystem
        with tf.gfile.GFile(output_graph_path, "wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(output_graph_def.node))
        print("Saved const_graph in %s"%model_dir)
    return output_graph_def
    
    
freeze_graph("model/", "output_name", "dnn.pb")

 

모델을 freeze 할 때, 제일 신경써야할 부분은 "output_node_names" 옵션이다. 모델을 freeze시킨 후, output으로 받고 싶은 출력 노드 이름을 이곳에 명시해줘야 한다. 이 부분을 확인할 수 있는 부분은 ".pbtxt" 파일이다. 이곳에 모든 노드 정보가 기록되므로, 이 파일을 열어서 내가 원하는 노드 네임을 찾아서 넣으면 된다. 

 

개인적으로 노드 이름을 간단하게 확인하기 위해서, 다음 코드를 자주 활용한다.

print([print(n.name) for n in tf.get_default_graph().as_graph_def().node])

 

마지막으로, 또 하나의 팁을 남기자면, 모델 설계 시, 모델의 input node는 꼭 "input"로 이름을 지정해주고, 마지막 노드 또한 해당하는 텐서에 이름을 "output"으로 꼭 설정해주거나, 그렇지 못할 경우에 모델 output 부분에

tf.identity(x, name="output")

다음 코드를 추가해주면, output_node를 "output"으로 일관되게 유지할 수 있다.

 

튜토리얼이 될만한 레포지토리는 다음과 같다. 

https://github.com/JackyTung/tensorgraph

나도 나중에 여유될 때 음성 관련 기본 모델로 튜토리얼 코드 작성 해봐야겠다.

 

끝.

 

Tensorflow 를 이용한 학습 모델을 저장하는 방법에는 2가지가 있다.

 

1. ".meta" 파일과 ".data" 파일을 나누어서 저장하는 방식

 

Tensorflow model을 표현하기 위해서 크게 2가지 컴포넌트(Meta Graph, Checkpoint)가 필요하다.

 

  • Meta graph
    • .meta
      • tensorflow graph structure 에 대한 정보
      • Variables, Collection and Operations
  • Checkpoint
    • A protocol buffer with a list of recent checkpoints
    • weight, biases, gradients, all the variables 값
    • .data files: training variables 값에 대한 정보; model.ckpt.data-00000-of-00001
    • .index files:**checkpoint 에 대한 정보(index)

 

TF 1.x 버전 tf.Session()을 통해서 모델을 저장하는 코드는 다음과 같다.

 

with tf.Session() as sess:
    # Initializes all the variables.
    sess.run(init_all_op)
    # Runs to logit.
    sess.run(logits)
    # Creates a saver.
    saver = tf.train.Saver()
    # Save both checkpoints and meta-graph
    saver.save(sess, 'my-save-dir/my-model-10000')            
    # Generates MetaGraphDef.
    saver.export_meta_graph('my-save-dir/my-model-10000.meta') #change this line

 

 

2. 전체 모델을 HDF5 파일 하나에 저장하는 방식

가중치, 모델 구성, 옵티마이저에 지정한 설정까지 파일 하나에 모두 포함된다.

 

model = create_model()

model.fit(train_images, train_labels, epochs=5)

# 전체 모델을 HDF5 파일로 저장합니다
model.save('my_model.h5')

# 가중치와 옵티마이저를 포함하여 정확히 동일한 모델을 다시 생성합니다
new_model = keras.models.load_model('my_model.h5')
new_model.summary()

 

고수준 API인 tf.keras 를 통해서 모델을 저장하고 로드하는 것은 이곳에 잘 정리되어 있다.

https://www.tensorflow.org/tutorials/keras/save_and_restore_models

 

 

끝.

+ Recent posts