데이터사이언스/강화학습

강화학습 - (24-1) Deep SARSA 코드예제

Johnny Yoon 2020. 12. 7. 23:13
728x90
반응형

 

 

 

 

강화학습

 

패키지 설치

다음 코드는 세가지 패키지가 선행 되어야 합니다.

sudo apt-get install ffmpeg
pip install gym
pip install gym_minigrid

gym.render() 코드가 에러를 발생할 경우, 다음 패키지를 설치하고:
sudo apt-get install xvfb

주피터 노트북을 다음 명령어를 통해 실행합니다:
xvfb-run -s "-screen 0 1400x900x24" jupyter notebook

In [1]:
import numpy as np
import pandas as pd
import random
from collections import defaultdict
import gym

import torch
import torch.nn as nn
import torch.optim as optim

import matplotlib.pyplot as plt
%matplotlib inline
 

환경

예제 코드는 Cart-Pole 예제에서 Expected SARSA 에이전트가 학습하는 코드 입니다.
Cart-Pole이란 막대를 최대한 오래 세워놓는 게임입니다.
Cart-Pole의 상태는 4차원 벡터로 각 값은 다음과 같은 의미를 갖습니다:

  • x : Cart의 가로상의 위치
  • θ : Pole의 각도
  • dx/dt : Cart의 속도
  • dθ/dt : θ의 각속도

게임이 끝나는 상황:

  1. θ가 15˚이상이 되었을 떄,
  2. 원점으로부터의 x의 거리가 2.4이상이 되었을 떄.

에이전트가 게임을 오래 지속하는 만큼 보상을 받고,
에이전트가 취할 수 있는 행동은 다음 두가지 입니다.

  1. 왼쪽으로 이동
  2. 오른쪽으로 이동
 

Deep SARSA 에이전트

Deep SARSA 에이전트 클래스를 만들어 줍니다.
Deep SARSA에이전트는 0.001의 학습율($\alpha$), 0.99의 감가율($\gamma$)을 사용합니다.
Deep SARSA 에이전트는 SARSA를 인공신경망을 사용해 함수 근사를 한 에이전트 입니다.
인공신경망의 레이어 수는 세개, 각각 32개의 노드를 사용합니다.
옵티마이저는 Adam을 사용합니다.

In [2]:
class DeepSARSA:
    def __init__(self, num_states, num_actions):
        self.num_states = num_states
        self.num_actions = num_actions
        self.alpha = 0.001
        self.gamma = 0.99

        self.model = nn.Sequential(
            nn.Linear(self.num_states, 32),
            nn.ReLU(),
            nn.Linear(32, 32),
            nn.ReLU(),
            nn.Linear(32, self.num_actions)
        )
        self.optimizer = optim.Adam(self.model.parameters(), lr=self.alpha)
        
        ...
 

입실론 그리디 정책

Deep SARSA 에이전트는 SARSA와 동일하게, 입실론 그리디 정책을 사용합니다.
입실론 그리디는 $\epsilon$값에 의해 랜덤한 행동을,
$1 - \epsilon$값에 의해 탐욕적으로 행동을 선택합니다.
(랜덤한 행동은 np.random.choice, 탐욕적 행동은 np.argmax에 의해 선정)

In [2]:
class DeepSARSA:
    ...
    
    def act(self, state):
        if np.random.rand() < self.epsilon:
            action = np.random.choice(self.num_actions)
        else:
            q_values = self.model(state)
            action = torch.argmax(q_values).item()
            
        return action
 

입실론 감쇄 (Epsilon Decay)

입실론($\epsilon$)값은 1.0부터 시작해 0.01까지 감소합니다.
매번 학습(update)을 할때마다, 0.99995씩 곱해지는 방식으로 감소합니다.

In [2]:
class DeepSARSA:
    def __init__(self, num_states, num_actions):
        ...

        self.epsilon = 1.
        self.epsilon_decay = .99995
        self.epsilon_min = 0.01
        
        ...
        
    def decrease_epsilon(self):
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay
 

Q값 학습

현재상태 (S, state), 현재행동 (A, action), 보상 (R, reward),
다음상태 (S', next_state), 다음행동 (A', next_action)
위 다섯가지 원소를 가지고 시간차(TD)를 학습합니다.
한가지 주의할 점은, 에피소드가 끝나는 시점에서는 미래의 값을 고려하지 않고 학습한다는 점입니다.

In [2]:
class DeepSARSA:
    ...
        
    def update(self, state, action, reward, next_state, next_action, done):
        self.decrease_epsilon()
        self.optimizer.zero_grad()
        
        q_value = self.model(state)[action]
        next_q_value = self.model(next_state)[next_action].detach()
        
        q_target = reward + (1 - int(done)) * self.gamma * next_q_value
        q_error = (q_target - q_value) ** 2
        
        q_error.backward()
        self.optimizer.step()
        
        return q_error.item()
    
    ...
 

환경 & 에이전트 초기화

환경과 Deep SARSA 에이전트를 초기화 합니다.

In [3]:
import gym
from gym import wrappers
env = gym.make('CartPole-v1')
env = wrappers.Monitor(env, "./video", force=True)
observation = env.reset()
agent = DeepSARSA(4,2)

observation
 
/home/johnny/anaconda3/lib/python3.7/site-packages/gym/logger.py:30: UserWarning: WARN: Box bound precision lowered by casting to float32
  warnings.warn(colorize('%s: %s'%('WARN', msg % args), 'yellow'))
Out[3]:
array([ 0.02680645,  0.00791713, -0.03561815,  0.00689243])
 

에피소드 학습

500번의 에피소드를 통해 학습합니다.
각 10번의 에피소드마다 리워드 값을 출력합니다.

In [5]:
rewards = []
for ep in range(500):
    done = False
    obs = torch.FloatTensor(env.reset())
    action = agent.act(obs)

    ep_rewards = 0
    losses = []
    while not done:
        next_obs, reward, done, info = env.step(action)
        next_obs = torch.FloatTensor(next_obs)

        next_action = agent.act(next_obs)

        loss = agent.update(obs, action, reward, next_obs, next_action, done)
        losses.append(loss)
        
        ep_rewards += reward
        obs = next_obs
        action = next_action
    rewards.append(ep_rewards)
    ep_loss = sum(losses) / len(losses)
    if (ep+1) % 10 == 0:
        print("episode: {}, eps: {:.3f}, loss: {:.1f}, rewards: {}".format(ep+1, agent.epsilon, ep_loss, ep_rewards))
env.close()
 
episode: 10, eps: 0.991, loss: 2.2, rewards: 17.0
episode: 20, eps: 0.976, loss: 2.7, rewards: 68.0
episode: 30, eps: 0.964, loss: 11.4, rewards: 27.0
episode: 40, eps: 0.952, loss: 27.2, rewards: 12.0
episode: 50, eps: 0.942, loss: 5.3, rewards: 68.0
episode: 60, eps: 0.934, loss: 10.1, rewards: 16.0
episode: 70, eps: 0.922, loss: 11.9, rewards: 26.0
episode: 80, eps: 0.913, loss: 9.6, rewards: 21.0
episode: 90, eps: 0.905, loss: 3.8, rewards: 32.0
episode: 100, eps: 0.895, loss: 13.9, rewards: 17.0
...
episode: 400, eps: 0.369, loss: 22.9, rewards: 228.0
episode: 410, eps: 0.340, loss: 3.1, rewards: 149.0
episode: 420, eps: 0.302, loss: 11.2, rewards: 500.0
episode: 430, eps: 0.283, loss: 11.5, rewards: 156.0
episode: 440, eps: 0.258, loss: 6.6, rewards: 198.0
episode: 450, eps: 0.240, loss: 1.8, rewards: 163.0
episode: 460, eps: 0.220, loss: 3.3, rewards: 229.0
episode: 470, eps: 0.181, loss: 13.3, rewards: 500.0
episode: 480, eps: 0.149, loss: 15.5, rewards: 500.0
episode: 490, eps: 0.127, loss: 66.8, rewards: 76.0
episode: 500, eps: 0.099, loss: 14.3, rewards: 500.0
 

에피소드 시각화

다음은 최종 에피소드의 영상입니다.

In [6]:
from utils import show_video

show_video()
 
 

학습 곡선 시각화

다음은 학습 시 받은 보상의 이동 평균(Moving Average)을 시각화 한 그래프 입니다.

In [7]:
pd.Series(rewards).to_csv('./logs/rewards_deepsarsa.csv')
 
 
In [8]:
plt.figure(figsize=(16, 8))
plt.plot(pd.Series(rewards).cumsum() / (pd.Series(np.arange(len(rewards)))+1))
Out[8]:
[<matplotlib.lines.Line2D at 0x7f0038125048>]
 
 

전체 코드

전체 코드는 다음 깃허브에서 확인할 수 있습니다.

참조 사이트

728x90
반응형