Solving Blackjack with Q-Learning

[Figure: agent-environment diagram]

In this tutorial, we'll explore and solve the Blackjack-v1 environment.

Blackjack is one of the most popular casino card games and is also infamous for being beatable under certain conditions. This version of the game uses an infinite deck (we draw the cards with replacement), so counting cards won't be a viable strategy in our simulated game. The full documentation can be found at https://gymnasium.farama.org/environments/toy_text/blackjack.

Goal: To win, your card sum should be greater than the dealer's without exceeding 21.

Actions: The agent can choose between two actions:
  • stick (0): the player takes no more cards

  • hit (1): the player will be given another card; however, the player could go over 21 and bust

Approach: To solve this environment by yourself, you can pick your favorite discrete RL algorithm. The presented solution uses *Q-learning* (a model-free RL algorithm).

Imports and Environment Setup

# Author: Till Zemann
# License: MIT License

from __future__ import annotations

from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from matplotlib.patches import Patch
from tqdm import tqdm

import gymnasium as gym


# Let's start by creating the blackjack environment.
# Note: We are going to follow the rules from Sutton & Barto.
# Other versions of the game can be found below for you to experiment.

env = gym.make("Blackjack-v1", sab=True)
# Other possible environment configurations are:

# env = gym.make('Blackjack-v1', natural=True, sab=False)
# Whether to give an additional reward for starting with a natural blackjack, i.e. starting with an ace and ten (sum is 21).

# env = gym.make('Blackjack-v1', natural=False, sab=False)
# Whether to follow the exact rules outlined in the book by Sutton and Barto. If `sab` is `True`, the keyword argument `natural` will be ignored.

Observing the environment

First of all, we call env.reset() to start an episode. This function resets the environment to a starting position and returns an initial observation. We usually also set done = False. This variable will be useful later to check whether the game has terminated (i.e. the player wins or loses).

# reset the environment to get the first observation
done = False
observation, info = env.reset()

# observation = (16, 9, False)

Note that our observation is a 3-tuple consisting of 3 values (unpacked in the small sketch after this list):

  • The player's current sum

  • The value of the dealer's face-up card

  • Whether or not the player holds a usable ace (an ace is usable if it counts as 11 without busting)
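
As a quick illustration, the observation tuple can be unpacked directly; the variable names below are just illustrative:

# unpack the observation tuple (assuming observation = (16, 9, False) as above)
player_sum, dealer_card, usable_ace = observation
# player_sum = 16, dealer_card = 9, usable_ace = False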

Executing an action

After receiving our first observation, we are only going to use the env.step(action) function to interact with the environment. This function takes an action as input and executes it in the environment. Because that action changes the state of the environment, it returns five useful variables to us. These are:

  • next_state: This is the observation that the agent will receive after taking the action.

  • reward: This is the reward that the agent will receive after taking the action.

  • terminated: This is a boolean variable that indicates whether or not the environment has terminated.

  • truncated: This is a boolean variable that also indicates whether the episode ended by early truncation, i.e. a time limit was reached.

  • info: This is a dictionary that might contain additional information about the environment.

The next_state, reward, terminated and truncated variables are self-explanatory, but the info variable requires some additional explanation. This variable contains a dictionary that might have some extra information about the environment, but in the Blackjack-v1 environment you can ignore it. For example, in Atari environments the info dictionary has an ale.lives key that tells us how many lives the agent has left. If the agent has 0 lives, then the episode is over.

Note that it is not a good idea to call env.render() in your training loop because rendering slows down training by a lot. Rather, try to build an extra loop to evaluate and showcase the agent after training.
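
Below is a minimal sketch of such a separate showcase loop (assumptions not taken from this tutorial: the "rgb_array" render mode of Blackjack-v1, and random actions standing in for the trained agent's policy; swap in the trained agent's action selection once the agent has been trained):

# sketch of a showcase loop kept outside of training (random actions as a stand-in policy)
showcase_env = gym.make("Blackjack-v1", sab=True, render_mode="rgb_array")
obs, info = showcase_env.reset()
done = False
while not done:
    action = showcase_env.action_space.sample()  # replace with the trained agent's action
    obs, reward, terminated, truncated, info = showcase_env.step(action)
    frame = showcase_env.render()  # returns an RGB array of the current frame
    done = terminated or truncated
showcase_env.close()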

# sample a random action from all valid actions
action = env.action_space.sample()
# action=1

# execute the action in our environment and receive infos from the environment
observation, reward, terminated, truncated, info = env.step(action)

# observation=(24, 10, False)
# reward=-1.0
# terminated=True
# truncated=False
# info={}

Once terminated = True or truncated = True, we should stop the current episode and begin a new one with env.reset(). If you continue executing actions without resetting the environment, it will still respond, but the output won't be useful for training (it might even be harmful if the agent learns from invalid data).
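
In code, the pattern described above is simply:

# start a new episode once the previous one has ended
if terminated or truncated:
    observation, info = env.reset()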

Building an agent

Let's build a Q-learning agent to solve Blackjack-v1! We'll need some functions for picking an action and updating the agent's action values. To ensure that the agent explores the environment, one possible solution is the epsilon-greedy strategy, where we pick a random action with probability epsilon and the greedy action (the one currently valued as the best) with probability 1 - epsilon.

class BlackjackAgent:
    def __init__(
        self,
        env,
        learning_rate: float,
        initial_epsilon: float,
        epsilon_decay: float,
        final_epsilon: float,
        discount_factor: float = 0.95,
    ):
        """Initialize a Reinforcement Learning agent with an empty dictionary
        of state-action values (q_values), a learning rate and an epsilon.

        Args:
            learning_rate: The learning rate
            initial_epsilon: The initial epsilon value
            epsilon_decay: The decay for epsilon
            final_epsilon: The final epsilon value
            discount_factor: The discount factor for computing the Q-value
        """
        self.q_values = defaultdict(lambda: np.zeros(env.action_space.n))

        self.lr = learning_rate
        self.discount_factor = discount_factor

        self.epsilon = initial_epsilon
        self.epsilon_decay = epsilon_decay
        self.final_epsilon = final_epsilon

        self.training_error = []

    def get_action(self, env, obs: tuple[int, int, bool]) -> int:
        """
        Returns the best action with probability (1 - epsilon)
        otherwise a random action with probability epsilon to ensure exploration.
        """
        # with probability epsilon return a random action to explore the environment
        if np.random.random() < self.epsilon:
            return env.action_space.sample()

        # with probability (1 - epsilon) act greedily (exploit)
        else:
            return int(np.argmax(self.q_values[obs]))

    def update(
        self,
        obs: tuple[int, int, bool],
        action: int,
        reward: float,
        terminated: bool,
        next_obs: tuple[int, int, bool],
    ):
        """Updates the Q-value of an action."""
        future_q_value = (not terminated) * np.max(self.q_values[next_obs])
        temporal_difference = (
            reward + self.discount_factor * future_q_value - self.q_values[obs][action]
        )

        self.q_values[obs][action] = (
            self.q_values[obs][action] + self.lr * temporal_difference
        )
        self.training_error.append(temporal_difference)

    def decay_epsilon(self):
        self.epsilon = max(self.final_epsilon, self.epsilon - self.epsilon_decay)
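
For reference, the update method above implements the standard tabular Q-learning rule, with the future Q-value zeroed out on terminal transitions (here α corresponds to self.lr and γ to self.discount_factor):

$$Q(s, a) \leftarrow Q(s, a) + \alpha \left[ r + \gamma \max_{a'} Q(s', a') - Q(s, a) \right]$$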

To train the agent, we will let it play one episode at a time (one complete game is called an episode) and update its Q-values after each step (a single action in a game is called a step).

The agent has to experience a lot of episodes to explore the environment sufficiently.

Now we should be ready to build the training loop.

# hyperparameters
learning_rate = 0.01
n_episodes = 100_000
start_epsilon = 1.0
epsilon_decay = start_epsilon / (n_episodes / 2)  # reduce the exploration over time
final_epsilon = 0.1

agent = BlackjackAgent(
    env=env,
    learning_rate=learning_rate,
    initial_epsilon=start_epsilon,
    epsilon_decay=epsilon_decay,
    final_epsilon=final_epsilon,
)

Great, let's train!

Info: The current hyperparameters are set to quickly train a decent agent. If you want to converge to the optimal policy, try increasing n_episodes by 10x and lowering the learning rate (e.g. to 0.001).
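
For instance, such a slower but more thorough configuration could look like the following (illustrative values only, left commented out so that the quick setup above stays in effect):

# learning_rate = 0.001
# n_episodes = 1_000_000
# epsilon_decay = start_epsilon / (n_episodes / 2)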

env = gym.wrappers.RecordEpisodeStatistics(env, deque_size=n_episodes)
for episode in tqdm(range(n_episodes)):
    obs, info = env.reset()
    done = False

    # play one episode
    while not done:
        action = agent.get_action(env, obs)
        next_obs, reward, terminated, truncated, info = env.step(action)

        # update the agent
        agent.update(obs, action, reward, terminated, next_obs)

        # update if the environment is done and the current obs
        done = terminated or truncated
        obs = next_obs

    agent.decay_epsilon()

Visualising the training

rolling_length = 500
fig, axs = plt.subplots(ncols=3, figsize=(12, 5))
axs[0].set_title("Episode rewards")
# compute and assign a rolling average of the data to provide a smoother graph
reward_moving_average = (
    np.convolve(
        np.array(env.return_queue).flatten(), np.ones(rolling_length), mode="valid"
    )
    / rolling_length
)
axs[0].plot(range(len(reward_moving_average)), reward_moving_average)
axs[1].set_title("Episode lengths")
length_moving_average = (
    np.convolve(
        np.array(env.length_queue).flatten(), np.ones(rolling_length), mode="same"
    )
    / rolling_length
)
axs[1].plot(range(len(length_moving_average)), length_moving_average)
axs[2].set_title("Training Error")
training_error_moving_average = (
    np.convolve(np.array(agent.training_error), np.ones(rolling_length), mode="same")
    / rolling_length
)
axs[2].plot(range(len(training_error_moving_average)), training_error_moving_average)
plt.tight_layout()
plt.show()

[Figure blackjack_training_plots.png: rolling averages of episode rewards, episode lengths, and training error]

Visualising the policy

def create_grids(agent, usable_ace=False):
    """Create value and policy grid given an agent."""
    # convert our state-action values to state values
    # and build a policy dictionary that maps observations to actions
    state_value = defaultdict(float)
    policy = defaultdict(int)
    for obs, action_values in agent.q_values.items():
        state_value[obs] = float(np.max(action_values))
        policy[obs] = int(np.argmax(action_values))

    player_count, dealer_count = np.meshgrid(
        # players count, dealers face-up card
        np.arange(12, 22),
        np.arange(1, 11),
    )

    # create the value grid for plotting
    value = np.apply_along_axis(
        lambda obs: state_value[(obs[0], obs[1], usable_ace)],
        axis=2,
        arr=np.dstack([player_count, dealer_count]),
    )
    value_grid = player_count, dealer_count, value

    # create the policy grid for plotting
    policy_grid = np.apply_along_axis(
        lambda obs: policy[(obs[0], obs[1], usable_ace)],
        axis=2,
        arr=np.dstack([player_count, dealer_count]),
    )
    return value_grid, policy_grid


def create_plots(value_grid, policy_grid, title: str):
    """Creates a plot using a value and policy grid."""
    # create a new figure with 2 subplots (left: state values, right: policy)
    player_count, dealer_count, value = value_grid
    fig = plt.figure(figsize=plt.figaspect(0.4))
    fig.suptitle(title, fontsize=16)

    # plot the state values
    ax1 = fig.add_subplot(1, 2, 1, projection="3d")
    ax1.plot_surface(
        player_count,
        dealer_count,
        value,
        rstride=1,
        cstride=1,
        cmap="viridis",
        edgecolor="none",
    )
    plt.xticks(range(12, 22), range(12, 22))
    plt.yticks(range(1, 11), ["A"] + list(range(2, 11)))
    ax1.set_title(f"State values: {title}")
    ax1.set_xlabel("Player sum")
    ax1.set_ylabel("Dealer showing")
    ax1.zaxis.set_rotate_label(False)
    ax1.set_zlabel("Value", fontsize=14, rotation=90)
    ax1.view_init(20, 220)

    # plot the policy
    fig.add_subplot(1, 2, 2)
    ax2 = sns.heatmap(policy_grid, linewidth=0, annot=True, cmap="Accent_r", cbar=False)
    ax2.set_title(f"Policy: {title}")
    ax2.set_xlabel("Player sum")
    ax2.set_ylabel("Dealer showing")
    ax2.set_xticklabels(range(12, 22))
    ax2.set_yticklabels(["A"] + list(range(2, 11)), fontsize=12)

    # add a legend
    legend_elements = [
        Patch(facecolor="lightgreen", edgecolor="black", label="Hit"),
        Patch(facecolor="grey", edgecolor="black", label="Stick"),
    ]
    ax2.legend(handles=legend_elements, bbox_to_anchor=(1.3, 1))
    return fig


# state values & policy with usable ace (ace counts as 11)
value_grid, policy_grid = create_grids(agent, usable_ace=True)
fig1 = create_plots(value_grid, policy_grid, title="With usable ace")
plt.show()

[Figure blackjack_with_usable_ace.png: state values and policy with a usable ace]

# state values & policy without usable ace (ace counts as 1)
value_grid, policy_grid = create_grids(agent, usable_ace=False)
fig2 = create_plots(value_grid, policy_grid, title="Without usable ace")
plt.show()

[Figure blackjack_without_usable_ace.png: state values and policy without a usable ace]

It is good practice to call env.close() at the end of your script, so that any resources used by the environment are released.
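
For this tutorial, that is simply:

env.close()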

Think you can do better?

# You can visualize the environment using the play function
# and try to win a few games.

Hopefully this tutorial helped you get a grip on how to interact with OpenAI-Gym environments and set you on a journey to solve many more RL challenges.

It is recommended that you solve this environment by yourself (project-based learning is really effective!). You can apply your favorite discrete RL algorithm or give Monte Carlo ES a try (covered in section 5.3 of Sutton & Barto), so that you can compare your results directly to the book.

Have fun!