티스토리 뷰
https://biomadscientist.tistory.com/61
지난 포스팅에서는 RBM의 기본적인 구조와 RBM에서 각 입력벡터 P(x)의 발생 확률분포를 얻는 방법을 공부하였다. 이번 포스팅에서는 얻어진 P(x)값을 이용해서 어떻게 RBM을 학습시킬지에 대해서 포스팅해보려 한다.
2. RBM 학습
RBM의 목적은 훈련집합 X = {x1, x2, ... , xn}에 속한 특징벡터는 높은 확률로 발생시키고 X에 없는 샘플은 낮은 확률로 발생시키는 것이다. 따라서 RBM은 레이블 정보가 없어도 학습되는 비지도 학습이다.
훈련집합에 속한 특징 벡터를 높은 확률로 발생시킨다는게 무슨말인지 이해가 안갈 것이다. 조금 더 심도있게 RBM의 학습이 어떻게 이뤄지는지 다뤄보자
목적함수
이전 포스팅에서 다뤘던 예제를 다시한번 생각해보자
예제를 통해 우리는 현재 상태의 RBM에서 각각의 가시노드 발생 확률P(x)를 다음과 같이 구할 수 있었다.
이때 만약 우리가 가지고있는 훈련집합이 X = {(1, 0, 0)T, (1, 0, 1)T, (1, 1, 0)T, (1, 1, 1)T}라고 해보자. 이때 예제의 RBM에서는 훈련집합에 속하지 않는 데이터들도 발생확률이 비슷하다. 따라서 예제문제에서 제시된 RBM의 파라미터는 우리가 가지고있는 훈련집합의 데이터를 높은 확률로 발생시키기에는 부적절하다.
그럼 이런 RBM은 어떤가?
구조는 완전히 동일하지만, 각각의 노드를 연결한 가중치의 값과 bias값들이 바뀐 새로운 RBM구조를 이용해 계산해보면 다음과 같다.
훈련집합에 속하는 벡터들의 발생확률이 훨씬 높은 RBM임을 확인할 수 있다. 이렇게 훈련집합에 속한 데이터를 얻을 확률이 더 높은 RBM을 얻는것이 우리의 목표이며 가중치 집합 Θ= {W, a, b} 임을 이전 포스팅에서 공부하였다.
결론적으로 RBM의 목적함수는 다음과 같다.
거대한 파이(π) 기호는 product로써 훈련집합 X에 속하는 모든 input vector x들의 각각의 발생확률의 곱을 뜻한다. 그러나 확률의 정의에 의해 각각의 P(x)값은 반드시 0 ~ 1 사이의 값을 가지며, 대체로 훈련집합에 속하는 데이터의 수 n은 최소 수십개에서 수천 수만개까지 존재할 수 있기 때문에 이것을 모두 곱하여 나가면 그 수가 너무 작아져서 훈련이 불가능한 경우가 많다. 따라서 주로 product 연산을 그대로 사용하기보다는 로그를 취해 곱연산을 합연산으로 바꿔 훈련한다.
RBM의 학습 알고리즘은 목적함수 J(Θ)를 이용해 최적의 파라미터 집합을 찾아내야 하는데 수식으로 나타내면 다음과 같다.
대조 발산 알고리즘(Contrastive Diverence algorithm, CD)
RBM의 목적함수 학습 알고리즘 10.19에는 치명적인 약점이 있다. 다시한번 예제에 제시된 RBM의 구조를 살펴보자 노드가 겨우 5개인데 우리는 partition function(Z) 값을 구하기 위해 총 25 = 32개의 energy 값을 계산했어야 했다. 그럼 이제 노드가 50개로 증가한 경우를 살펴보자 partition function(Z)값을 구하기 위해 우리는 250 = 1,125,899,906,842,624개의 energy값을 계산해야한다. time complexity가 exponentially 증가하므로 O(2n)이다.
알고리즘에 대한 지식이 조금 있다면 이해가 쉽겠지만 시간복잡도에 대한 개념이 없다면 그냥 내가 짠 모델의 시간복잡도가 빨간색 영역에 속하는 알고리즘은 사실상 사용이 불가능하다고 생각하면 된다.
따라서 이러한 학습 알고리즘을 통해서는 최적의 RBM 파라미터 집합을 학습하는건 불가능하므로 근사해를 구하는 알고리즘을 힌튼교수님이 제시하는데 그것이 바로 대조 발산 알고리즘(CD)이다.
대조 발산 알고리즘도 MLP를 학습하는 Gradient Descent 방식과 유사하긴 하지만 목적함수 자체가 Loss여서 이를 최소화 해야하는 GD 방식과는 달리 RBM의 목적함수는 P(x)값이고 이를 최대화 하는것이 목적이므로 Gradient Ascent(GA)방식을 사용한다는 차이가 존재한다.
RBM의 가중치 변화량은 다음과 같이 대조 발산을 통해 계산한다.
여기서 < >기호는 KL divergence를 의미한다. 즉, 대조 발산은 data와 model분포의 KLD값의 차이를 의미한다. 식 10.20을 이용해서 가중치를 업데이트하는 경사 상승 알고리즘을 이용하면 가중치는 다음과 같이 업데이트된다.
식은 알겠는데 정작 대조발산값을 어떻게 구하는가? <>data는 뭐고 <>model은 무엇인가?
대조발산의 값을 구하기 위해 RBM의 샘플 발생 능력을 이용하며 특히 깁스 샘플링(Gibbs sampling)방식을 이용한다.
깁스 샘플링(Gibbs sampling)
깁스 샘플링은 일단 RBM의 가시층에 샘플하나를 무작위로 입력하면서 시작된다. 지난 포스팅에서 다루었던 간단한 예제를 이용해 깁스샘플링과 대조발산값에 대한 이해를 해보자
먼저 깁스 샘플링에 대해서 이해하기 위해서 그림에서 가시노드와 은닉 노드가 각자 3개와 2개인것처럼 보이지만 각각의 노드가 d개, m개 있다고 가정하자.
1. 입력 노드에 랜덤 샘플 x = (x1, x2 ... ,xd)T가 입력되었다고 생각하자. 가시 노드의 벡터가 고정되었기 때문에 은닉노드의 값을 결정할 수 있다. 깁스 샘플링 방식에 의해 은닉노드 hj의 값은 다음과 같이 결정된다.
식 10.22에서 구한 p(hj=1|x)값이 0 ~ 1사이의 랜덤 난수보다 크면 hj=1이되고 그렇지 않으면 0이된다.
2. 식 10.22~10.23을 이용해 h1, h2, ... hm까지의 은닉노드를 샘플링한다.
3. 모든 은닉 노드가 결정(샘플링)되었으면 반대로 샘플링된 은닉노드를 이용해 이번에는 반대로 가시노드를 샘플링한다.
4. 식 10.24~10.25를 이용해 x1, x2, ... xd까지의 가시노드를 샘플링한다.
5. 입력된 샘플로부터 새로운 샘플을 얻었다. 깁스샘플링에 의해 샘플링되는 모든 노드들은 서로 독립적으로 샘플링된다.
이번엔 실제로 예제에서 제공한 RBM을 이용해 깁스샘플링을 해보자.
1. x = (1, 0, 0)T가 입력되었다고 가정하자 먼저 각각의 은닉노드 hj가 1일 확률을 계산하면 다음과 같다.
P(h1 = 1|x) = sigmoid( 0.1 + (1 x 0.1 + 0 x 0.0 + 0 x 0.2) ) = 0.549834
P(h2 = 1|x) = sigmoid( 0.2 + (1 x (-0.2) + 0 x (-0.1) + 0 x 0.1) ) = 0.5
2. 첫번째 난수로 0.48, 두번째 난수로 0.64가 얻어졌다면 h1 = 1 , h2 = 0이 된다.
3. 얻어진 은닉노드로부터 반대로 가시노드 xi를 샘플링한다
P(x1 = 1|h) = sigmoid( 0.2 + (1 x 0.1 + 0 x (-0.2)) ) = 0.574443
P(x2 = 1|h) = sigmoid( -0.1 + (1 x 0.0 + 0 x (-0.1)) ) = 0.524979
P(x3 = 1|h) = sigmoid( 0.0 + (1 x 0.2 + 0 x 0.1) ) = 0.549834
4. 3개의 난수값이 0.24, 0.61, 0.52가 나왔다고 하자, x1 = 1, x2 = 0, x3 = 1가 된다. 즉 x = (1, 0, 1)T값이 샘플링되었다. 다시말해 해당 예제의 RBM이 x = (1, 0, 0)T을 입력받은 결과 깁스샘플링을 통해 (1, 0, 1)T가 샘플링되었다.
마지막으로 RBM 학습을 위한 대조 발산 알고리즘은 다음과 같다.
입력 : 훈련집합 X = {x1, x2 ... ,xn}
출력 : 학습된 RBM
repeat
for x_k in X:
sampling hidden vector h with 10.22 ~ 10.23
sampling visible vector x` with 10.24 ~ 10.25
sampling hidden vector h` with 10.22 ~ 10.23
calc <x_ih_j>_data with x_k and h
calc <x_ih_j>_model with x` and h`
update w_ij with 10.21
until stop signal
다음 포스팅에서는 RBM 구조를 쌓아올린 DBN의 개념에 대해서 설명할 것이다. 추가로 KL Divergence 개념에 대해서도 나중에 포스팅을 해볼 예정에 있다.
https://biomadscientist.tistory.com/67
↓ 내용이 혹시나 도움되셨다면 좋아요 눌러주세요 꾸준한 포스팅에 매우 큰 응원이 됩니다 🥰
'Background > Math' 카테고리의 다른 글
[오일석 기계학습] 10.4 - 확률 그래피컬 모델 RBM과 DBN (3) (0) | 2023.05.06 |
---|---|
[오일석 기계학습] 2.2 수학 - 가우시안 분포/ 베르누이/ 이항분포 (1) | 2023.04.29 |
[오일석 기계학습] 10.4 - 확률 그래피컬 모델 RBM과 DBN (1) (0) | 2023.04.26 |
[오일석 기계학습] 2.2 수학 - 확률과 통계 - 최대우도법 (0) | 2023.04.26 |
Xavier initializer (0) | 2023.04.16 |
- Total
- Today
- Yesterday
- kl divergence
- kld
- Matrix algebra
- 선형대수
- 인공지능
- 최대우도추정
- 파이썬
- manim
- variational autoencoder
- ai인공지능
- Manimlibrary
- manimtutorial
- ai신약개발
- 3b1b
- eigenvector
- marginal likelihood
- 제한볼츠만머신
- 오일석기계학습
- vae
- 백준
- 기계학습
- 3B1B따라잡기
- MorganCircularfingerprint
- MatrixAlgebra
- elementry matrix
- MLE
- eigenvalue
- manim library
- 이왜안
- 베이즈정리
일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | |||||
3 | 4 | 5 | 6 | 7 | 8 | 9 |
10 | 11 | 12 | 13 | 14 | 15 | 16 |
17 | 18 | 19 | 20 | 21 | 22 | 23 |
24 | 25 | 26 | 27 | 28 | 29 | 30 |