일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | ||||||
2 | 3 | 4 | 5 | 6 | 7 | 8 |
9 | 10 | 11 | 12 | 13 | 14 | 15 |
16 | 17 | 18 | 19 | 20 | 21 | 22 |
23 | 24 | 25 | 26 | 27 | 28 |
Tags
- 설치
- 책
- flag
- 파이썬
- FastAPI
- 역전파
- 기울기
- Apache2
- PHP
- 코딩
- sgd
- Python Challenge
- AdaGrad
- 소프트맥스 함수
- PICO CTF
- picoCTF
- 아파치
- 신경망 학습
- 딥러닝
- 순전파
- 백준
- 리뷰
- Python
- C언어
- HTML
- 오차역전파법
- 신경망
- PostgreSQL
- 우분투
- CTF
Archives
- Today
- Total
Story of CowHacker
딥러닝 5.0 학습 관련 기술들 본문
728x90
이번에는 확률적 경사 하강법 , SGD에 대해 알아보겠다.
그림 1은 SGD를 수식으로 나타낸 것이다.
W는 갱신할 가중치 매개변수고 뒤에 빼는 분자 분모는 손실 함수의 기울기다.
이것을 파이썬으로 구현해보겠다.
SGD 파이썬 코드
class SGD:
def __init__( self, lr = 0.01 ) :
self.lr = lr
def update ( self, params, grads ) :
for key in params.keys() :
params [ key ] -= self.lr * grads [ key ]
위 코드를 해석해보면
초기화받는 인수 lr은 learninig rate ( 학습률 )를 뜻한다.
이것을 인스턴스 변수로 유지한다.
update 메서드는 SGD 과정에서 반복해서 불린다.
인수인 params와 grads는 딕셔너리 변수다.
params [ 'W1' ]. grads [ 'W1' ] 등과 같이 각각 가중치 매개변수와 기울기를 저장하고 있다.
SGD의 단점이 있다.
그림 2와 같은 함수의 최솟값을 찾는 SGD를 그래프화 시켜보겠다.
SGD그래프 코드
import sys, os
sys.path.append(os.pardir) # 부모 디렉터리의 파일을 가져올 수 있도록 설정
import numpy as np
import matplotlib.pyplot as plt
from collections import OrderedDict
from common.optimizer import *
def f(x, y):
return x**2 / 20.0 + y**2
def df(x, y):
return x / 10.0, 2.0*y
init_pos = (-7.0, 2.0)
params = {}
params['x'], params['y'] = init_pos[0], init_pos[1]
grads = {}
grads['x'], grads['y'] = 0, 0
optimizers = OrderedDict()
optimizers["SGD"] = SGD(lr=0.95)
idx = 1
for key in optimizers:
optimizer = optimizers[key]
x_history = []
y_history = []
params['x'], params['y'] = init_pos[0], init_pos[1]
for i in range(30):
x_history.append(params['x'])
y_history.append(params['y'])
grads['x'], grads['y'] = df(params['x'], params['y'])
optimizer.update(params, grads)
x = np.arange(-10, 10, 0.01)
y = np.arange(-5, 5, 0.01)
X, Y = np.meshgrid(x, y)
Z = f(X, Y)
# 외곽선 단순화
mask = Z > 7
Z[mask] = 0
# 그래프 그리기
plt.subplot(1, 1, idx)
idx += 1
plt.plot(x_history, y_history, 'o-', color="red")
plt.contour(X, Y, Z)
plt.ylim(-10, 10)
plt.xlim(-10, 10)
plt.plot(0, 0, '+')
#colorbar()
#spring()
plt.title(key)
plt.xlabel("x")
plt.ylabel("y")
plt.show()
그림 3을 보면 지그재그로 찾아가는 모습을 볼 수 있다.
이것은 비효율적이므로 SGD의 단점이라고 볼 수 있다.
728x90
'공부 > 딥러닝' 카테고리의 다른 글
딥러닝 5.2 학습 관련 기술들 (0) | 2020.08.20 |
---|---|
딥러닝 5.1 학습 관련 기술들 (0) | 2020.08.19 |
딥러닝 4.5 오차역전파법 (0) | 2020.08.19 |
딥러닝 4.4 오차역전파법 (0) | 2020.08.19 |
딥러닝 4.3 오차역전파법 (0) | 2020.08.18 |
Comments