강화학습
패키지 설치¶
다음 코드는 세가지 패키지가 선행 되어야 합니다.
sudo apt-get install ffmpeg
pip install gym
pip install gym_minigrid
gym.render() 코드가 에러를 발생할 경우, 다음 패키지를 설치하고:
sudo apt-get install xvfb
주피터 노트북을 다음 명령어를 통해 실행합니다:
xvfb-run -s "-screen 0 1400x900x24" jupyter notebook
import warnings; warnings.filterwarnings('ignore')
import numpy as np
import pandas as pd
import random
import gym
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
%matplotlib inline
환경¶
예제 코드는 그리드월드 예제에서 Deep SARSA 에이전트가 학습하는 코드 입니다.
에이전트가 최종 지점에 도달하면 보상을 받고,
에이전트가 취할 수 있는 행동은 다음 세가지 입니다.
- 왼쪽으로 회전
- 오른쪽으로 회전
- 전진
Deep SARSA 에이전트¶
Deep SARSA 에이전트 클래스를 만들어 줍니다.
Deep SARSA에이전트는 0.001의 학습율($\alpha$), 0.99의 감가율($\gamma$)을 사용합니다.
Deep SARSA 에이전트는 SARSA를 인공신경망을 사용해 함수 근사를 한 에이전트 입니다.
인공신경망의 레이어 수는 세개, 각각 32개의 노드를 사용합니다.
옵티마이저는 Adam을 사용합니다.
class DeepSARSA:
def __init__(self, num_states, num_actions):
self.num_states = num_states
self.num_actions = num_actions
self.alpha = 0.001
self.gamma = 0.99
self.model = nn.Sequential(
nn.Linear(self.num_states, 32),
nn.ReLU(),
nn.Linear(32, 32),
nn.ReLU(),
nn.Linear(32, self.num_actions)
)
self.optimizer = optim.Adam(self.model.parameters(), lr=self.alpha)
...
입실론 그리디 정책¶
Deep SARSA 에이전트는 SARSA와 동일하게, 입실론 그리디 정책을 사용합니다.
입실론 그리디는 $\epsilon$값에 의해 랜덤한 행동을,
$1 - \epsilon$값에 의해 탐욕적으로 행동을 선택합니다.
(랜덤한 행동은 np.random.choice, 탐욕적 행동은 np.argmax에 의해 선정)
...
def act(self, state):
if np.random.rand() < self.epsilon:
action = np.random.choice(self.actions)
else:
state = self._convert_state(state)
q_values = self.q_values[state]
action = np.argmax(q_values)
return action
...
입실론 감쇄 (Epsilon Decay)¶
입실론($\epsilon$)값은 1.0부터 시작해 0.2까지 감소합니다.
매번 학습(update)을 할때마다, 0.99995씩 곱해지는 방식으로 감소합니다.
여기서 epsilon_min을 0.2로 설정하는 이유는,
비교 분석을 진행할 SARSA 및, Q러닝 에이전트와 동일한 설정을 위해서 입니다.
class DeepSARSA:
def __init__(self, num_states, num_actions):
...
self.epsilon = 1.
self.epsilon_decay = .99995
self.epsilon_min = 0.2
...
def decrease_epsilon(self):
if self.epsilon > self.epsilon_min:
self.epsilon *= self.epsilon_decay
Q값 학습¶
현재상태 (S, state), 현재행동 (A, action), 보상 (R, reward),
다음상태 (S', next_state), 다음행동 (A', next_action)
위 다섯가지 원소를 가지고 시간차(TD)를 학습합니다.
한가지 주의할 점은, 에피소드가 끝나는 시점에서는 미래의 값을 고려하지 않고 학습한다는 점입니다.
class DeepSARSA:
...
def update(self, state, action, reward, next_state, next_action, done):
self.decrease_epsilon()
self.optimizer.zero_grad()
q_value = self.model(state)[action]
next_q_value = self.model(next_state)[next_action].detach()
q_target = reward + (1 - int(done)) * self.gamma * next_q_value
q_error = (q_target - q_value) ** 2
q_error.backward()
self.optimizer.step()
return q_error.item()
...
환경 & 에이전트 초기화¶
환경과 Deep SARSA 에이전트를 초기화 합니다.
env = gen_wrapped_env('MiniGrid-Empty-6x6-v0')
obs = env.reset()
agent = DeepSARSA(obs.shape[0], 3)
에피소드 학습¶
5000번의 에피소드를 통해 학습합니다.
각 10번의 에피소드마다 리워드 값을 출력합니다.
rewards = []
for ep in range(5000):
done = False
obs = torch.FloatTensor(env.reset())
action = agent.act(obs)
ep_rewards = 0
losses = []
while not done:
next_obs, reward, done, info = env.step(action)
next_obs = torch.FloatTensor(next_obs)
next_action = agent.act(next_obs)
loss = agent.update(obs, action, reward, next_obs, next_action, done)
losses.append(loss)
ep_rewards += reward
obs = next_obs
action = next_action
rewards.append(ep_rewards)
ep_loss = sum(losses) / len(losses)
if (ep+1) % 10 == 0:
print("episode: {}, eps: {:.3f}, loss: {:.3f}, rewards: {}".format(ep+1, agent.epsilon, ep_loss, ep_rewards))
env.close()
에피소드 시각화¶
다음은 최종 에피소드의 영상입니다.
show_video()
에피소드 시각화¶
다음은 가장 리워드를 높게 받은 에피소드의 영상입니다.
show_video()
학습 곡선 시각화¶
다음은 학습 시 받은 보상의 이동 평균(Moving Average)을 시각화 한 그래프 입니다.
비교 분석을 위해 SARSA, ExpectedSARSA, 그리고 Q러닝의 결과가 포함되었고,
두 기법 모두5번의 시도 중 가장 좋은 그래프를 시각화 하였습니다.
결과에서 볼 수 있는 것과 같이, DeepSARSA는 SARSA기반의 에이전트들보다는,
학습 속도나 높은 평균 리워드 면에서 더 뛰어난 것을 알 수 있습니다.
하지만 아쉽게도 Q러닝 에이전트 보다는 뒤쳐지는 것을 확인할 수 있습니다.
pd.Series(rewards).to_csv('./logs/rewards_deepsarsa_gridworld.csv')
sarsa_logs = pd.read_csv('./logs/rewards_sarsa.csv', index_col=False).iloc[:, 1]
q_logs = pd.read_csv('./logs/rewards_qlearning.csv', index_col=False).iloc[:, 1]
exp_sarsa_logs = pd.read_csv('./logs/rewards_expectedsarsa.csv', index_col=False).iloc[:, 1]
deepsarsa_logs = pd.read_csv('./logs/rewards_deepsarsa_gridworld.csv', index_col=False).iloc[:, 1]
plt.figure(figsize=(16, 8))
plt.plot(deepsarsa_logs.cumsum() / (pd.Series(np.arange(exp_sarsa_logs.shape[0]))+1), label="DeepSARSA")
plt.plot(q_logs.cumsum() / (pd.Series(np.arange(q_logs.shape[0]))+1), label="QLearning")
plt.plot(sarsa_logs.cumsum() / (pd.Series(np.arange(sarsa_logs.shape[0]))+1), label="SARSA")
plt.plot(exp_sarsa_logs.cumsum() / (pd.Series(np.arange(exp_sarsa_logs.shape[0]))+1), label="ExpectedSARSA")
plt.legend()
전체 코드¶
전체 코드는 다음 깃허브에서 확인할 수 있습니다.