Computer/이게 왜 안되지?

[이왜안?] [RuntimeError]: Expected floating point type for target with class probabilities, got Long

벼랑끝과학자 2023. 7. 10. 18:06

torch의 nn module 내부의 loss function을 이용해서 loss를 구할 때에는 반드시 input의 type을 float으로 수정해준다.

# 이하 코드는 input의 type이 int(long tensor)였기 때문에 이러한 에러가 난다.

따라서 input에 .float()를 붙여서 float type으로 바꿔주면 문제없이 실행이 된다.