Computer/코딩 개꿀팁

벡터의 norm을 구하는 torch.linalg.norm

벼랑끝과학자 2025. 6. 27. 23:32

torch.linalg.norm함수는 N차원의 벡터의 norm값을 구하는 함수다.
다음과 같이 사용한다. 

torch.linalg.norm(input, ord=2, dim, keepdim)

v = torch.tensor([
    [[3.0, 4.0, 5.0], 
     [1.0, 2.0, 4.0]],
    [[0.0, 5.0, 5.0], 
     [6.0, 8.0, 5.0]],
    [[5.0, 3.0, 9.0], 
     [2.0, 1.0, 1.0]]
])  # shape = (3, 2, 3) [BS, L, D]

#### dim=0 (BS) ####
print(torch.linalg.norm(v, ord=2, dim=0, keepdim=True))
# tensor([[[ 5.8310,  7.0711, 11.4455],
#         [ 6.4031,  8.3066,  6.4807]]]) # shape = [1, L, D]

print(torch.linalg.norm(v, ord=2, dim=0, keepdim=False))
# tensor([[ 5.8310,  7.0711, 11.4455],
#        [ 6.4031,  8.3066,  6.4807]]) # shape = [L, D]

#### dim=1 (L) ####
print(torch.linalg.norm(v, ord=2, dim=1, keepdim=True))
# tensor([[[3.1623, 4.4721, 6.4031]],
#         [[6.0000, 9.4340, 7.0711]],
#         [[5.3852, 3.1623, 9.0554]]]) # shape [BS, 1, D]

print(torch.linalg.norm(v, ord=2, dim=1, keepdim=False))
# tensor([[3.1623, 4.4721, 6.4031],
#         [6.0000, 9.4340, 7.0711],
#         [5.3852, 3.1623, 9.0554]]) # shape [BS, D]

#### dim=2 (D) ####
print(torch.linalg.norm(v, ord=2, dim=2, keepdim=True))
# tensor([[[ 7.0711],
#          [ 4.5826]],
#         [[ 7.0711],
#          [11.1803]],
#         [[10.7238],
#          [ 2.4495]]]) # shape [BS, L, 1]
         
print(torch.linalg.norm(v, ord=2, dim=2, keepdim=False))
# tensor([[ 7.0711,  4.5826],
#         [ 7.0711, 11.1803],
#         [10.7238,  2.4495]]) # shape [BS, L]

여기서 input은 N차원의 벡터, ord는 L_ord norm을 구하는 옵션으로 2를 입력해 유클리디안 거리와 동일한 의미를 갖는 L2 Norm 값을 기준으로 계산하는 것이 일반적이다.

dim은 어떤 차원을 따라서 norm을 계산할지를 지정하는 것인데 이게 참 본인이 잘 이해해두지 않았으면 헷갈리고 난해하다.

동영상데이터를 다루는 사람들이 아니면 보통은 3차원 텐서까지만 다루면 되기 때문에 3차원을 기준으로 말해보자면

  • dim=0는 배치 방향으로의 텐서 [BS]
  • dim=1은 각 배치 내 샘플의 세로 방향으로의 텐서 [L]
  • dim=2는 각 배치 내 샘플의 가로 방향으로의 텐서 [D]

이렇게 생각하면 된다. 말로는 이해하기 당연히 어렵고 또 심지어 저자마다 이 순서를 본인 입맛에 맞게 변형해서 사용하는 경우도 간혹 있기에 고급 수준의 텐서를 다루는 딥러닝을 구현하기 위해서는 본인만의 텐서 차원 이해 방식을 반드시 확립해둬야한다.

또한 많은 경우에 결과 텐서를 기존 텐서에 추가로 연산하기 위해 keepdim이라는 옵션을 True로 주는 경우가 있는데, 이것에 대한 이해도 해두는 것이 좋다.

나는 다음처럼 이해하고 있다.




그림이 sum을 예시로 보이고 있지만 모든 텐서를 다루는 dim은 동일한 방향을 의미하므로 그냥 norm의 경우 저 방향대로 norm값을 구하면 될 뿐이다. sum이나 mean이나 norm이나 결과의 shape은 동일하다.

다시한번 말하지만 텐서의 차원은 다른사람의 포스팅은 참고만 하고 반드시 본인이 직접 종이에 그려가면서 dim=0,1,2가 어떤 방향인지 확립을 해둬야한다.