Computer/이게 왜 안되지?
[이왜안?] [RuntimeError]: Expected floating point type for target with class probabilities, got Long
벼랑끝과학자
2023. 7. 10. 18:06
torch의 nn module 내부의 loss function을 이용해서 loss를 구할 때에는 반드시 input의 type을 float으로 수정해준다.
# 이하 코드는 input의 type이 int(long tensor)였기 때문에 이러한 에러가 난다.
따라서 input에 .float()를 붙여서 float type으로 바꿔주면 문제없이 실행이 된다.