티스토리 뷰

class ToDataset(Dataset):
    def __init__(self, df):
        self.df = df
        # self.src = df['src']
        # self.tgt = df['tgt']
        # self.label = df['pKa']

    def __len__(self):
        return len(self.df)

    def __getitem__(self, index):
        src = self.df.iloc[index]['src']
        src = src_embedder(torch.tensor(src, requires_grad=False))

        tgt = self.df.iloc[index]['tgt']
        tgt = tgt_embedder(torch.tensor(tgt, requires_grad=False))

        label = self.df.iloc[index]['pKa']

        return src, tgt, label

 

이렇게 Dataset으로 변환하는 class를 만들고 제대로 변환되었는지 확인하기 위해 아래와 같이 확인하려는데

for i, (src, tgt, label) in enumerate(train_generator):
    # print(src, tgt, label)
    break

 

심상찮은 에러가 뜬다.

 

뭐라는지 모르겠고 중요한건 맨 마지막 줄이겠지, 읽어봤지만 생소하다. 위에서 코드를 보면 알겠지만, requires grad가 붙어있어서 문제가 되는것 같아서 requires_grad 옵션을 False로 줬음에도 불구하고 같은 에러가 떴다.

 

해결법은, 아예 grad에 대한 기록 자체를 detach()시킨 데이터를 return해야 불러올 수 있었다.

 

즉 class ToDataset에서 return할 데이터는 다음과 같이 tensor에서 grad는 detach시켜야한다.

class ToDataset(Dataset):
	
	...

	def __getitem__(self, index):
			
            ...
        	
            return src.deatch(), tgt.detach(), label

 

댓글