Loading [MathJax]/jax/output/CommonHTML/jax.js

카테고리 없음

강화학습 - (19-1) SARSA 코드예제

_금융덕후_ 2020. 11. 1. 23:23
728x90
반응형

 

 

강화학습

SARSA 코드예제

 

 

패키지 설치

다음 코드는 세가지 패키지의 설치가 선행 되어야 합니다.

sudo apt-get install ffmpeg
pip install gym
pip install gym_minigrid

In [1]:
import numpy as np
import pandas as pd
import random
from collections import defaultdict
import gym
import gym_minigrid
import matplotlib.pyplot as plt
%matplotlib inline
 

환경

예제 코드는 그리드월드 예제에서 SARSA에이전트가 학습하는 코드 입니다.
에이전트가 최종 지점에 도달하면 보상을 받고,
에이전트가 취할 수 있는 행동은 다음 세가지 입니다.

  1. 왼쪽으로 회전
  2. 오른쪽으로 회전
  3. 전진
 

SARSA 에이전트

먼저 SARSA 에이전트 클래스를 만들어 줍니다.
SARSA에이전트는 0.01의 학습율(α), 0.9의 감가율(γ),
그리고 0.2의 입실론(ϵ)값을 가지고 학습합니다.
그리고 경험에 의해 학습될 q값을 저장합니다.

In [2]:
class SARSA:
    def __init__(self, actions, agent_indicator=10):
        self.actions = actions
        self.agent_indicator = agent_indicator
        self.alpha = 0.01
        self.gamma = 0.9
        self.epsilon = 0.2
        self.q_values = defaultdict(lambda: [0.0] * actions)
    
    ...
 

입실론 그리디 정책

SARSA 에이전트는 입실론 그리디 정책을 사용합니다.
입실론 그리디는 ϵ값에 의해 랜덤한 행동을,
1ϵ값에 의해 탐욕적으로 행동을 선택합니다.
(랜덤한 행동은 np.random.choice, 탐욕적 행동은 np.argmax에 의해 선정)

In [3]:
class SARSA:
    ...
    
    def act(self, state):
        if np.random.rand() < self.epsilon:
            action = np.random.choice(self.actions)
        else:
            state = self._convert_state(state)
            q_values = self.q_values[state]
            action = np.argmax(q_values)
        return action
    
    ...
 

Q값 학습

현재상태 (S, state), 현재행동 (A, action), 보상 (R, reward),
다음상태 (S', next_state), 다음행동 (A', next_action)
위 다섯가지 원소를 가지고 시간차(TD)를 학습합니다.

In [3]:
class SARSA:
	...
    
    def update(self, state, action, reward, next_state, next_action):
        state = self._convert_state(state)
        next_state = self._convert_state(next_state)
        
        q_value = self.q_values[state][action]
        next_q_value = self.q_values[next_state][action]
        
        td_error = reward + self.gamma * next_q_value - q_value
        self.q_values[state][action] = q_value + self.alpha * td_error
 

환경 & 에이전트 초기화

환경과 에이전트를 초기화 합니다.
환경을 SARSA에이전트가 더 쉽게 학습하게 하기 위해 warpping을 시키는 코드가 추가되었습니다.

In [4]:
from utils import gen_wrapped_env, show_video
env = gen_wrapped_env('MiniGrid-Empty-6x6-v0')
obs = env.reset()

agent_position = obs[0]

agent = SARSA(3, agent_position)
 

에피소드 학습

5000번의 에피소드를 통해 학습합니다.
각 20번의 에피소드마다 리워드 값을 출력합니다.

In [5]:
rewards = []
for ep in range(5000):
    done = False
    obs = env.reset()
    action = agent.act(obs)
    
    ep_rewards = 0
    while not done:
        next_obs, reward, done, info = env.step(action)

        next_action = agent.act(next_obs)

        agent.update(obs, action, reward, next_obs, next_action)
        
        ep_rewards += reward
        obs = next_obs
        action = next_action
    rewards.append(ep_rewards)
    if (ep+1) % 20 == 0:
        print("episode: {}, rewards: {}".format(ep+1, ep_rewards))
env.close()
 
episode: 20, rewards: 0
episode: 40, rewards: 0
episode: 60, rewards: 0
episode: 80, rewards: 0
episode: 100, rewards: 0
episode: 120, rewards: 0
episode: 140, rewards: 0
episode: 160, rewards: 0
episode: 180, rewards: 0
episode: 200, rewards: 0
(생략)...
episode: 4800, rewards: 0
episode: 4820, rewards: 0
episode: 4840, rewards: 0
episode: 4860, rewards: 0.6579999999999999
episode: 4880, rewards: 0.622
episode: 4900, rewards: 0.784
episode: 4920, rewards: 0.17199999999999993
episode: 4940, rewards: 0
episode: 4960, rewards: 0.262
episode: 4980, rewards: 0.694
episode: 5000, rewards: 0.694
 

학습된 Q값 출력

에이전트가 학습한 각 상태/행동별 Q값을 출력해봅니다.
이 때 dict의 key 수는 상태의 수, 즉 grid의 수가 되고,
각 key 별로 행동 3개의 리스트를 가지게 됩니다.

In [6]:
agent.q_values
Out[6]:
defaultdict(<function __main__.SARSA.__init__.<locals>.<lambda>()>,
            {0: [0.0, 0.0, 0.03517484588889315],
             3: [0.0, 0.0, 0.03282974415750312],
             12: [0.0, 0.0, 0.06023929749704925],
             24: [0.0, 0.0, 0.08754691985883382],
             6: [0.0, 0.0, 0.03703603725724147],
             15: [0.0, 0.0, 0.08661513519226519],
             18: [0.0, 0.0, 0.12158279334285689],
             36: [0.0, 0.0, 0.11922002175307364],
             27: [0.0, 0.0, 0.1947269953539534],
             30: [0.0, 0.0, 0.26623815668244954],
             9: [0.0, 0.0, 0.040482184625553294],
             21: [0.0, 0.0, 0.1771600029058558],
             39: [0.0, 0.0, 0.4153081220438313],
             42: [0.0, 0.0, 0.541164906538886],
             33: [0.0, 0.0, 0.32816956124526775],
             45: [0.0, 0.0, 0.0]})
 

에피소드 시각화

다음은 최종 에피소드의 영상입니다.

In [7]:
show_video()
 
 
 

학습 곡선 시각화

다음은 학습 시 받은 보상의 이동 평균(Moving Average)을 시각화 한 그래프 입니다.

In [8]:
plt.figure(figsize=(16, 8))
plt.plot(pd.Series(rewards).cumsum() / (pd.Series(np.arange(len(rewards)))+1))
 
 
Out[8]:
[<matplotlib.lines.Line2D at 0x7f05afa5b400>]
 
728x90
반응형