티스토리 뷰

torch의 nn module 내부의 loss function을 이용해서 loss를 구할 때에는 반드시 input의 type을 float으로 수정해준다.

# 이하 코드는 input의 type이 int(long tensor)였기 때문에 이러한 에러가 난다.

따라서 input에 .float()를 붙여서 float type으로 바꿔주면 문제없이 실행이 된다.

댓글