Story of CowHacker

딥러닝 5.0 학습 관련 기술들 본문

공부/딥러닝

딥러닝 5.0 학습 관련 기술들

Cow_Hacker 2020. 8. 19. 17:19
728x90

이번에는 확률적 경사 하강법 , SGD에 대해 알아보겠다.

 

 

 

 

 

그림1

 

그림 1은 SGD를 수식으로 나타낸 것이다.

 

 

W는 갱신할 가중치 매개변수고 뒤에 빼는 분자 분모는 손실 함수의 기울기다.

 

이것을 파이썬으로 구현해보겠다.

 

 

 

SGD 파이썬 코드

class SGD:
    def __init__( self, lr = 0.01 ) :
        self.lr = lr

    def update ( self, params, grads ) :
        for key in params.keys() :
            params [ key ] -= self.lr * grads [ key ]

 

위 코드를 해석해보면

초기화받는 인수 lr은 learninig rate ( 학습률 )를 뜻한다.

이것을 인스턴스 변수로 유지한다.

 

update 메서드는 SGD 과정에서 반복해서 불린다.

인수인 params와 grads는 딕셔너리 변수다.

 

params [ 'W1' ]. grads [ 'W1' ] 등과 같이 각각 가중치 매개변수와 기울기를 저장하고 있다.

 

 

 

 

그림2

 

SGD의 단점이 있다.

그림 2와 같은 함수의 최솟값을 찾는 SGD를 그래프화 시켜보겠다.

 

 

 

 

 

 

 

SGD그래프 코드

 

import sys, os
sys.path.append(os.pardir)  # 부모 디렉터리의 파일을 가져올 수 있도록 설정
import numpy as np
import matplotlib.pyplot as plt
from collections import OrderedDict
from common.optimizer import *


def f(x, y):
    return x**2 / 20.0 + y**2


def df(x, y):
    return x / 10.0, 2.0*y

init_pos = (-7.0, 2.0)
params = {}
params['x'], params['y'] = init_pos[0], init_pos[1]
grads = {}
grads['x'], grads['y'] = 0, 0


optimizers = OrderedDict()
optimizers["SGD"] = SGD(lr=0.95)
idx = 1

for key in optimizers:
    optimizer = optimizers[key]
    x_history = []
    y_history = []
    params['x'], params['y'] = init_pos[0], init_pos[1]
    
    for i in range(30):
        x_history.append(params['x'])
        y_history.append(params['y'])
        
        grads['x'], grads['y'] = df(params['x'], params['y'])
        optimizer.update(params, grads)
    

    x = np.arange(-10, 10, 0.01)
    y = np.arange(-5, 5, 0.01)
    
    X, Y = np.meshgrid(x, y) 
    Z = f(X, Y)
    
    # 외곽선 단순화
    mask = Z > 7
    Z[mask] = 0
    
    # 그래프 그리기
    plt.subplot(1, 1, idx)
    idx += 1
    plt.plot(x_history, y_history, 'o-', color="red")
    plt.contour(X, Y, Z)
    plt.ylim(-10, 10)
    plt.xlim(-10, 10)
    plt.plot(0, 0, '+')
    #colorbar()
    #spring()
    plt.title(key)
    plt.xlabel("x")
    plt.ylabel("y")
    
plt.show()

 

 

 

그림3

그림 3을 보면 지그재그로 찾아가는 모습을 볼 수 있다.

이것은 비효율적이므로 SGD의 단점이라고 볼 수 있다.

 

 

 

 

 

 

 

 

 

 

 

728x90

'공부 > 딥러닝' 카테고리의 다른 글

딥러닝 5.2 학습 관련 기술들  (0) 2020.08.20
딥러닝 5.1 학습 관련 기술들  (0) 2020.08.19
딥러닝 4.5 오차역전파법  (0) 2020.08.19
딥러닝 4.4 오차역전파법  (0) 2020.08.19
딥러닝 4.3 오차역전파법  (0) 2020.08.18
Comments