티스토리 뷰

종종 nn.Module을 상속한 class의 __init__에서 self.register_buffer라는 코드를 마주하게 되는데 몇번 공부하긴 했다만 매번 까먹어서 정리한다.

 

self.register_buffer(name, tensor)

self.register_buffer는 학습은 하지 않지만, 모델과 함께 저장/불러오기되는 텐서를 등록하는 코드이다.

register_buffer에 등록된 파라미터들은 모델의 역전파가 되는 과정에서 학습되지 않지만 모델을 저장할 때, 예를들어 state_dict등에는 등록된다. 또한 .to(device)등을 이용해 CPU나 CUDA로 옮기는 기능도 정상적으로 작동한다.

예를들어 class init 내부에 다음과 같은 코드가 있다면

class MyClass(nn.Module):

	def __init__(self,num_funcs=3):
		self.register_buffer('freq_bands', 
        	torch.FloatTensor([i+1 for i in range(num_funcs)] + [1./(i+1) for i in range(num_funcs)]))
            # [i+1 for i in range(3)] → [1, 2, 3]
			# [1./(i+1) for i in range(3)] → [1.0, 0.5, 0.333...]
            # 따라서 self.freq_bands = torch.Tensor([1, 2, 3, 1.0, 0.5, 0.333...])
        ...
        
    def forward(self, x):
    	return x * self.freq_bands

이런 형태로 사용할 수 있으며 한번 buffer에 등록된 self.freq_bands = torch.Tensor([1, 2, 3, 1.0, 0.5, 0.333...])는 backpropagation과정을 거치더라도 업데이트되지 않고 모든 iteration동안 torch.Tensor([1, 2, 3, 1.0, 0.5, 0.333...])값을 유지한 채 사용된다.

보통 다음과 같은 경우의 텐서들을 buffer로 등록해서 사용한다.

 

  • positional encoding
  • mask matrix
  • Fourier freq bands
  • adjacency matrix
  • atom type embedding index 등

 

댓글