티스토리 뷰

nn.Embedding에 대한 이해가 제대로 되어있지 않아 여러가지 문제가 있었다.

다음과 같은 문제들이 맞는지 확인해보자

 

일단 nn.Embedding(n, m)은 (n개의 정수값을 인덱스로 하는, m차원의 테이블)을 만드는 역할을 한다고 생각하면 된다.

 

예를들어 temp_embedding = nn.Embedding(2, 4)라고 하면

0 0.1123 -0.5321 1.1232 0.8737
1 -2.0012 1.2231 0.6653 0.5531

처럼 0~1 (2개)의 인덱스를 가지는 4차원의 테이블이 만들어진다. 이것을 보통 lookup table이라 표현한다.

다음으로 내가 가진 데이터를 저 인덱스를 통해서 vector의 형태로 불러오는 것이다.

예를들어 내가 가지고 있는 텐서 temp = [0, 0 ,1, 0, 1] 이었다면 이 temp텐서를 nn.embedding으로 만들어둔 embedding table에 전달하면 (temp_embedding(temp)을 하면) 다음과 같은 (5, 4)차원의 embedded vector가 얻어진다.

(맨 왼쪽의 column은 temp_embedding(temp)에는 포함되지 않고 이해를 위해 추가로 넣은 열이다.)

0 0.1123 -0.5321 1.1232 0.8737
0 0.1123 -0.5321 1.1232 0.8737
1 -2.0012 1.2231 0.6653 0.5531
0 0.1123 -0.5321 1.1232 0.8737
1 -2.0012 1.2231 0.6653 0.5531

 

1. nn.Embedding은 반드시 정수값을 전달해야한다.

temp = torch.rand(size=(3,4))
print(temp)

"""
tensor([[0.5088, 0.6054, 0.4537, 0.7441],
        [0.5973, 0.3509, 0.2135, 0.6883],
        [0.0675, 0.3632, 0.0307, 0.7911]])
"""
# temp = temp.long() # temp를 LongTensor형태로 전환하면 에러가 해결된다.

temp_emb = nn.Embedding(12, 5)
temp_emb(temp)

"""RuntimeError: Expected tensor for argument #1 'indices'
to have one of the following scalar types: Long, Int; but got torch.FloatTensor instead 
(while checking arguments for embedding)"""

 

위처럼 내가 Embedding으로 전달하는 값들이 LongTensor형태의 정수가 아니라면 런타임 에러를 발생한다. 당연히 정수를 인덱스로 하는 lookup table을 만들어 놓고 0.271번 인덱스의 값을 불러와달라 하는 꼴이니 컴퓨터 입장에서는 에러를 출력할 수 밖에 없다.

비유하자면, 중국집에 가서 사장님 짜장면 0.271 인분 주세요 하는꼴이다. (당연히 안주겠죠?)

 

2. 첫 번째 인자 num_embeddings는 input되는 Tensor의 원소 범위보다 커야한다.

 

temp = torch.randint(5, size=(3,4)) 
print(temp)
"""
tensor([[3, 0, 4, 2],
        [0, 2, 2, 1],
        [2, 2, 2, 3]])
tensor 내부 요소 범위 [0,1,2,3,4] 5개가 존재한다.
그러면 Embedding을 위한 embedding table의 index는 최소한 0~4보다는 커야한다.
"""

# 따라서 num_embeddings는 최소 5보다는 커야한다.

temp_emb = nn.Embedding(num_embeddings=5, embedding_dim=4)
temp_emb(temp)
"""
tensor([[[-0.8807, -1.2626,  1.3101],
         [ 1.6741,  0.5796,  0.9555],
         [-0.8807, -1.2626,  1.3101],
         [-1.4088, -1.1644,  2.2974]],

        [[ 1.6741,  0.5796,  0.9555],
         [-0.8807, -1.2626,  1.3101],
         [ 1.6741,  0.5796,  0.9555],
         [-0.8807, -1.2626,  1.3101]],

        [[-2.0173, -0.8481, -0.6244],
         [-0.8807, -1.2626,  1.3101],
         [-1.3328,  2.9311,  0.6453],
         [-0.8807, -1.2626,  1.3101]]])
"""

##########################################################

temp = torch.randint(5, size=(3,4)) # 0,1,2,3,4 -> 5보다는 커야한다.
print(temp)
"""
tensor([[2, 2, 0, 1],
        [3, 2, 4, 1],
        [1, 0, 3, 4]])
"""

temp_emb = nn.Embedding(num_embeddings=3, embedding_dim=3) 
# 0~2까지만 존재하는 embedding table이 만들어지므로 temp의 3과 4를 embedding 시킬 index가 없다.
# 따라서 index out of range Error가 발생한다.
temp_emb(temp)

# IndexError: index out of range in self

 

이것도 자명하다. 내 lookup table에는 0~3까지밖에 없는데 이 범위를 넘어서는 4라는 정수를 인덱스로 하는 값을 불러오라 하는 꼴이니 컴퓨터 입장에서는 이해를 못한다.

비유하자면 중국집에 가서 숯불닭갈비 주세요 하는꼴이다. (당연히 메뉴에 없으니 안주겠죠?)

댓글