티스토리 뷰

해당 오류는 pytorch 1.12 이하 버전에서는 제공하지 않는 Transformer Masking기능을 사용하려 했기 때문에 발생한다.

pytorch 버전을 1.13버전 이상으로 업데이트 해주면 해결된다.

아래 링크를 참고하고, 본인은 해당 코드로 토치를 싸그리 업데이트해주니 해결되었다.

conda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.6 -c pytorch -c nvidia

 

https://discuss.pytorch.org/t/runtimeerror-mask-shape-should-match-input-shape-transformer-mask-is-not-supported-in-the-fallback-case/169690

댓글