강화학습
패키지 설치¶
다음 코드는 세가지 패키지가 선행 되어야 합니다.
sudo apt-get install ffmpeg
pip install gym
pip install gym_minigrid
gym.render() 코드가 에러를 발생할 경우, 다음 패키지를 설치하고:
sudo apt-get install xvfb
주피터 노트북을 다음 명령어를 통해 실행합니다:
xvfb-run -s "-screen 0 1400x900x24" jupyter notebook
import numpy as np
import pandas as pd
import random
from collections import defaultdict
import gym
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
%matplotlib inline
환경¶
예제 코드는 Cart-Pole 예제에서 Expected SARSA 에이전트가 학습하는 코드 입니다.
Cart-Pole이란 막대를 최대한 오래 세워놓는 게임입니다.
Cart-Pole의 상태는 4차원 벡터로 각 값은 다음과 같은 의미를 갖습니다:
- x : Cart의 가로상의 위치
- θ : Pole의 각도
- dx/dt : Cart의 속도
- dθ/dt : θ의 각속도
게임이 끝나는 상황:
- θ가 15˚이상이 되었을 떄,
- 원점으로부터의 x의 거리가 2.4이상이 되었을 떄.
에이전트가 게임을 오래 지속하는 만큼 보상을 받고,
에이전트가 취할 수 있는 행동은 다음 두가지 입니다.
- 왼쪽으로 이동
- 오른쪽으로 이동
Deep SARSA 에이전트¶
Deep SARSA 에이전트 클래스를 만들어 줍니다.
Deep SARSA에이전트는 0.001의 학습율($\alpha$), 0.99의 감가율($\gamma$)을 사용합니다.
Deep SARSA 에이전트는 SARSA를 인공신경망을 사용해 함수 근사를 한 에이전트 입니다.
인공신경망의 레이어 수는 세개, 각각 32개의 노드를 사용합니다.
옵티마이저는 Adam을 사용합니다.
class DeepSARSA:
def __init__(self, num_states, num_actions):
self.num_states = num_states
self.num_actions = num_actions
self.alpha = 0.001
self.gamma = 0.99
self.model = nn.Sequential(
nn.Linear(self.num_states, 32),
nn.ReLU(),
nn.Linear(32, 32),
nn.ReLU(),
nn.Linear(32, self.num_actions)
)
self.optimizer = optim.Adam(self.model.parameters(), lr=self.alpha)
...
입실론 그리디 정책¶
Deep SARSA 에이전트는 SARSA와 동일하게, 입실론 그리디 정책을 사용합니다.
입실론 그리디는 $\epsilon$값에 의해 랜덤한 행동을,
$1 - \epsilon$값에 의해 탐욕적으로 행동을 선택합니다.
(랜덤한 행동은 np.random.choice, 탐욕적 행동은 np.argmax에 의해 선정)
class DeepSARSA:
...
def act(self, state):
if np.random.rand() < self.epsilon:
action = np.random.choice(self.num_actions)
else:
q_values = self.model(state)
action = torch.argmax(q_values).item()
return action
입실론 감쇄 (Epsilon Decay)¶
입실론($\epsilon$)값은 1.0부터 시작해 0.01까지 감소합니다.
매번 학습(update)을 할때마다, 0.99995씩 곱해지는 방식으로 감소합니다.
class DeepSARSA:
def __init__(self, num_states, num_actions):
...
self.epsilon = 1.
self.epsilon_decay = .99995
self.epsilon_min = 0.01
...
def decrease_epsilon(self):
if self.epsilon > self.epsilon_min:
self.epsilon *= self.epsilon_decay
Q값 학습¶
현재상태 (S, state), 현재행동 (A, action), 보상 (R, reward),
다음상태 (S', next_state), 다음행동 (A', next_action)
위 다섯가지 원소를 가지고 시간차(TD)를 학습합니다.
한가지 주의할 점은, 에피소드가 끝나는 시점에서는 미래의 값을 고려하지 않고 학습한다는 점입니다.
class DeepSARSA:
...
def update(self, state, action, reward, next_state, next_action, done):
self.decrease_epsilon()
self.optimizer.zero_grad()
q_value = self.model(state)[action]
next_q_value = self.model(next_state)[next_action].detach()
q_target = reward + (1 - int(done)) * self.gamma * next_q_value
q_error = (q_target - q_value) ** 2
q_error.backward()
self.optimizer.step()
return q_error.item()
...
환경 & 에이전트 초기화¶
환경과 Deep SARSA 에이전트를 초기화 합니다.
import gym
from gym import wrappers
env = gym.make('CartPole-v1')
env = wrappers.Monitor(env, "./video", force=True)
observation = env.reset()
agent = DeepSARSA(4,2)
observation
에피소드 학습¶
500번의 에피소드를 통해 학습합니다.
각 10번의 에피소드마다 리워드 값을 출력합니다.
rewards = []
for ep in range(500):
done = False
obs = torch.FloatTensor(env.reset())
action = agent.act(obs)
ep_rewards = 0
losses = []
while not done:
next_obs, reward, done, info = env.step(action)
next_obs = torch.FloatTensor(next_obs)
next_action = agent.act(next_obs)
loss = agent.update(obs, action, reward, next_obs, next_action, done)
losses.append(loss)
ep_rewards += reward
obs = next_obs
action = next_action
rewards.append(ep_rewards)
ep_loss = sum(losses) / len(losses)
if (ep+1) % 10 == 0:
print("episode: {}, eps: {:.3f}, loss: {:.1f}, rewards: {}".format(ep+1, agent.epsilon, ep_loss, ep_rewards))
env.close()
에피소드 시각화¶
다음은 최종 에피소드의 영상입니다.
from utils import show_video
show_video()
학습 곡선 시각화¶
다음은 학습 시 받은 보상의 이동 평균(Moving Average)을 시각화 한 그래프 입니다.
pd.Series(rewards).to_csv('./logs/rewards_deepsarsa.csv')
plt.figure(figsize=(16, 8))
plt.plot(pd.Series(rewards).cumsum() / (pd.Series(np.arange(len(rewards)))+1))
전체 코드¶
전체 코드는 다음 깃허브에서 확인할 수 있습니다.
참조 사이트¶
'데이터사이언스 > 강화학습' 카테고리의 다른 글
강화학습 - (26) 정책 경사 (0) | 2020.12.16 |
---|---|
강화학습 - (25) 정책 근사 (0) | 2020.12.16 |
강화학습 - (24) 시간차 가치 근사 (0) | 2020.11.21 |
강화학습 - (23) 몬테카를로 가치 근사 (0) | 2020.11.17 |
강화학습 - (22) 가치함수의 근사 (0) | 2020.11.15 |