Computer/코딩 개꿀팁
[꿀팁] DataLoader Batch결과의 모양이 이상하면 이것부터 체크해보자!
벼랑끝과학자
2023. 7. 24. 05:50
한줄 요약 : 이상한 형태로 batchify가 되는 데이터의 타입을 확인해보고 nd.array가 아니라 list라면 nd.array로 변경하자.
class ProteinSequenceDataset(Dataset):
def __init__(self, df, tokenizer):
self.df = df
self.tokenizer = tokenizer
def __len__(self):
return len(self.df)
def __getitem__(self, item):
p = self.df.iloc[item]['Target Sequence']
s = self.df.iloc[item]['SPS']
d = self.df.iloc[item]['SMILES']
p_v = protein2emb_encoder(p)
# print(type(p_v))
s_v, s_l = SPS2emb_encoder(s)
# print(type(s_v))
d_v, d_l = drug2emb_encoder(d)
# print(type(d_v))
y = self.df.iloc[item]['Label']
return p_v, s_v, d_v, y
# return {
# 'protein_sequence': sequence,
# 'input_ids': encoding['input_ids'],
# 'attention_mask': encoding['attention_mask'],
# 'targets': torch.tensor(target, dtype=torch.long)
# }
Batch size를 4를 준 상태로 tokenization을 하는 DataLoader를 만들고 만들어진 Data를 확인해보았다. 대충 보면 알겠지만, 지금 batch의 방향이 엉뚱하게 되어있다. [21, 8, 10 ,20 ... ] 이렇게 4개가 되어야 할 것이 잘못된 방향으로 batchify를 하고 있다.
다른 친구들은 batchify가 잘 되고 있는데 유독 하나만 그렇길래 왜그런지 한참을 고민했는데 이럴때는 엉뚱하게 batchify 되어 출력되는 데이터의 type을 확인해보자 혹시 nd.array가 아니라 list형태로 되어있다면 이런 잘못된 방향으로의 batchify가 나타날 수 있다.
해당 문제를 일으키는 데이터의 타입을 nd.array로 변경한 뒤 다시 시도해보니 정상적으로 batchify가 되는 것을 확인할 수 있었다.