티스토리 뷰

Computer/코딩 개꿀팁

inf_iterator는 왜 쓰는걸까?

벼랑끝과학자 2025. 6. 24. 16:32

가끔보면 다음처럼 torch의 DataLoader를 inf_iterator라는 함수를 따로 만들어서 태우는 경우가 있다.

def inf_iterator(iterable):
    iterator = iterable.__iter__()
    while True:
        try:
            yield iterator.__next__()
        except StopIteration:
            iterator = iterable.__iter__()

train_iterator = inf_iterator(DataLoader(
        train_dataset, 
        batch_size=config.train.batch_size, 
        collate_fn=PaddingCollate(), 
        shuffle=True,
        num_workers=args.num_workers
    ))

근데 Prediction 모델만 학습해본 경험이 있는 나로서는 DataLoader만 적용해도 알아서 epoch를 잘 돌던 기억이 있는데 왜 굳이 inf_iterator라는걸 구현해서 사용하는건지? 이해가 가지 않았는데 굳이 찾아보지는 않았음

오늘 찾아보니까 예측모델과 달리 생성모델처럼 epoch를 설정하지 않고 100,000 step처럼 배치 데이터 단위로 학습을 설정하는 경우에는 아무래도 귀찮게 step/batch_size를 계산해서 epoch를 설정하는 것보다 그냥 배치 데이터 단위로 100,000번 학습하도록 하는 것이 더 직관적이기 때문에 inf_iterator 함수를 적용해주는 편이 더 좋다고 한다.

특히 DDPM 기반의 생성모델들은 T step이 보통 1000의 배수를 가지기 때문에 모든 스텝 t별로 학습되는 횟수의 기대값을 계산하기 더 편하지 않을까 싶다.

댓글