티스토리 뷰

 

위와 같이 torch의 DataLoader를 사용해서 데이터를 iteration 할때 매번 뭔가 이상하게 KeyError 값이 변하면서 데이터 확인이 안되고 에러가 나고 있다.

이때 만약 내 데이터가 Pandas의 DataFrame 타입이라면 torch.Dataset으로 train_dataset으로 만들어주는 코드에서 iloc를 빼먹지 않았는지 확인해보자

class BinaryDataset(Dataset):
    def __init__(self, feature, target):
        self.feature = feature
        self.target = target
    
    def __len__(self):
        return len(self.target)
    
    def __getitem__(self, idx):
        return self.feature[idx], self.target[idx] # iloc를 빼먹은 코드
        
 class BinaryDataset(Dataset):
    
    ...
    
    def __getitem__(self, idx):
        return self.feature.iloc[idx], self.target.iloc[idx] # iloc를 정상적으로 추가한 코드
댓글