Reinforcement Learning for Board Games

Reinforcement Learning for Board Games

RL for board games is a project under the McGill AI Lab dedicated to developing and advancing reinforcement learning (RL) models aimed at mastering board games. By tackling state representation, action-space modeling, and multi-agent RL, the team aims to create RL agents capable of learning to play common board games. All the source code for the project can be found in the GitHub repository epicgamer17/rl-research.

Project Focus

The project's central goal is to create RL agents that can master both simple and complex board games. While early successes include solving classic control environments like CartPole and Mountain Car, the primary focus now is on board game environments such as Tic-Tac-Toe, Checkers, and Connect Four. Future ambitions include creating RL bots that pioneer solutions for more complex games like Scrabble, Risk, Monopoly, and Catan.

Key challenges involve designing efficient state representations, modeling large action spaces, and training multi-agent systems. The team actively reads and recreates RL papers, encouraging collaboration and iterative learning, with an emphasis on adapting and expanding advanced models such as AlphaZero, MuZero, and NFSP to these board game challenges.

Implemented Algorithms

The repository includes a range of RL algorithms, such as DQN-based architectures (Double DQN, Dueling DQN, Rainbow DQN) and Actor-Critic methods like A2C and PPO. More advanced models, including AlphaZero, MuZero, and NFSP, provide a foundation for tackling complex strategy games. The modular design allows for experimentation with different configurations, fostering innovation and learning.

Custom Environments

To support RL agent development, the project provides custom OpenAI Gym environments tailored to board games:

  • Tic Tac Toe: A classic two-player game for testing basic RL models.
  • Connect 4: A strategic game that introduces complexity and planning depth.

Furthermore, to investigate imperfect-information and multi-agent settings, we have implemented the following environments:

  • Leduc Hold'em: A simplified poker variant for studying imperfect-information games.
  • Mississippi Marbles: A custom environment with unique dynamics and strategies.

These environments allow for iterative training and evaluation, ensuring that agents can generalize across different strategic challenges.

Example: RainbowDQN on CartPole, training episodes 4, 79, and 154

Example: RainbowDQN on Acrobot, training episodes 4, 79, and 154

Example: RainbowDQN on LunarLander, training episodes 4, 79, and 154

Example: Train RainbowDQN on Pac-Man

import gymnasium as gym
import sys
import random
from collections import defaultdict
import copy
import math
from operator import itemgetter
import os
import matplotlib
from torch.optim.sgd import SGD
from torch.optim.adam import Adam
matplotlib.use("Agg")
from matplotlib import pyplot as plt
import scipy
import pickle
from typing import Iterable, Tuple
from datetime import datetime
import torch
from torch import nn, Tensor
import numpy as np
import numpy.typing as npt
import itertools
from hyperopt import space_eval
import pandas as pd
# from ....replay_buffers.base_replay_buffer import Game
# from replay_buffers.segment_tree import SumSegmentTree
def normalize_policies(policies: torch.float32):
    """Rescale policies so that each one sums to 1 along the last axis.

    Args:
        policies: NumPy array or torch tensor of policy weights; the last axis
            indexes actions.

    Returns:
        A new array/tensor of the same shape in which every last-axis slice has
        been divided by its own sum.
    """
    totals = policies.sum(axis=-1, keepdims=True)
    return policies / totals
def action_mask(
    actions: Tensor, legal_moves, mask_value: float = 0, device="cpu"
) -> Tensor:
    """
    Replace every non-legal entry of ``actions`` with ``mask_value``.

    actions: Tensor, probabilities of actions or q-values. Its first axis must
        have one row per entry of ``legal_moves``.
    legal_moves: list with one entry per row, each the index (or indices) of
        the legal actions in that row.
    mask_value: value written into the illegal positions.
    device: device the mask and the result are moved to.

    Returns a new tensor; ``actions`` itself is not modified.
    """
    assert isinstance(
        legal_moves, list
    ), "Legal moves should be a list got {} of type {}".format(
        legal_moves, type(legal_moves)
    )
    assert (
        len(legal_moves) == actions.shape[0]
    ), "Legal moves should be the same length as the batch size"

    # Start with everything marked illegal, then switch on each row's legal moves.
    is_legal = torch.zeros_like(actions, dtype=torch.bool).to(device)
    for row, row_legal in enumerate(legal_moves):
        is_legal[row, row_legal] = True

    fill = torch.tensor(mask_value).to(device)
    return torch.where(is_legal, actions, fill).to(device)
def clip_low_prob_actions(actions: Tensor, low_prob: float = 0.01) -> Tensor:
    """
    Set every action probability strictly below ``low_prob`` to 0.

    actions: Tensor, probabilities of actions
    low_prob: threshold. A value of 0 disables clipping, and ``actions`` is
        returned as-is (the very same object).

    Returns a new tensor when clipping is applied. The result is not
    renormalized.
    """
    if low_prob == 0:
        return actions
    too_unlikely = actions < low_prob
    return torch.where(too_unlikely, 0.0, actions)
def get_legal_moves(info: dict | list[dict]):
    # print(info)
    if isinstance(info, dict):
        return [info["legal_moves"] if "legal_moves" in info else None]
    else:
        return [(i["legal_moves"] if "legal_moves" in i else None) for i in info]
def normalize_images(image: Tensor) -> Tensor:
    """Preprocessing step to normalize an 8-bit (0-255) color image to [0, 1].

    Floating-point (or complex) tensors are divided by 255 **in place**: the
    original tensor is modified and that same tensor is returned.

    Integer tensors (e.g. ``torch.uint8``) cannot store the fractional result
    in place; in-place true division raises a RuntimeError for them. For those,
    a new floating-point tensor is returned and the input is left untouched.

    Args:
        image (Tensor): An 8-bit color image
    Returns:
        Tensor: The tensor divided by 255
    """
    if image.is_floating_point() or image.is_complex():
        return image.div_(255)
    # Out-of-place true division promotes integer dtypes to the default float dtype.
    return image.div(255)
def make_stack(item: Tensor) -> Tensor:
    """Prepend an axis of size 1, turning shape (*) into (1, *).

    No data is copied: the result is a view that shares storage with ``item``.

    Args:
        item (Tensor): tensor of any shape.
    Returns:
        Tensor: A view of ``item`` with an extra leading dimension.
    """
    stacked_shape = (1,) + tuple(item.shape)
    return item.view(stacked_shape)
def update_per_beta(
    per_beta: float, per_beta_final: float, per_beta_steps: int, initial_per_beta: float
):
    """Advance the prioritized-replay beta by one step of its linear schedule.

    Each call moves ``per_beta`` by
    ``(per_beta_final - initial_per_beta) / per_beta_steps``. The result is
    clamped so it never passes ``per_beta_final``.

    Args:
        per_beta: current value.
        per_beta_final: value the schedule anneals towards.
        per_beta_steps: number of calls needed to go from the initial to the
            final value. A non-positive value jumps straight to the final value.
        initial_per_beta: value at the start of the schedule. It fixes both the
            step size and the direction of the schedule.

    Returns:
        The updated value.
    """
    if per_beta_steps <= 0:
        # Degenerate schedule: avoid dividing by zero.
        return per_beta_final
    # The clamp direction depends on the schedule (initial -> final), not on the
    # current value. Deciding on ``per_beta < per_beta_final`` would select
    # ``max`` once beta reached its target. An increasing schedule would then
    # keep growing past ``per_beta_final`` forever.
    clamp_func = min if initial_per_beta < per_beta_final else max
    step = (per_beta_final - initial_per_beta) / per_beta_steps
    return clamp_func(per_beta_final, per_beta + step)
def update_linear_schedule(
    final_value: float,
    total_steps: int,
    initial_value: float,
    current_step: int,
):
    """Linearly interpolate from ``initial_value`` to ``final_value``.

    At ``current_step`` the value is
    ``initial_value + (final_value - initial_value) * (current_step / total_steps)``.
    It is clamped so it never goes past ``final_value`` once ``current_step``
    exceeds ``total_steps``.
    """
    progress = current_step / total_steps
    interpolated = initial_value + ((final_value - initial_value) * progress)
    if initial_value < final_value:
        return min(final_value, interpolated)
    return max(final_value, interpolated)
def update_inverse_sqrt_schedule(
    initial_value: float = None,
    current_step: int = None,
):
    """Decay ``initial_value`` as ``initial_value / sqrt(current_step + 1)``.

    Both arguments must be supplied despite their ``None`` defaults; with
    ``None`` the arithmetic raises a TypeError.
    """
    denominator = math.sqrt(current_step + 1)
    return initial_value / denominator
def default_plot_func(
    axs, key: str, values: list[dict], targets: dict, row: int, col: int
):
    """Fallback plotter: draw ``values`` as a plain line on ``axs[row][col]``.

    ``values`` is plotted directly (so it is expected to hold numbers) at
    x = 1..len(values). The title shows ``key`` and the mean of the last five
    values. When ``targets`` has a non-None entry for ``key``, a red dashed
    horizontal line marks it.
    """
    ax = axs[row][col]
    ax.set_title("{} | rolling average: {}".format(key, np.mean(values[-5:])))
    ax.plot(np.arange(1, len(values) + 1), values)
    target = targets[key] if key in targets else None
    if target is not None:
        ax.axhline(y=target, color="r", linestyle="--")
def plot_scores(axs, key: str, values: list[dict], targets: dict, row: int, col: int):
    """Plot per-game scores on ``axs[row][col]``.

    Each element of ``values`` is a dict with a ``"score"`` entry. It may also
    contain:
      * ``"max_score"`` and ``"min_score"``, drawn as a shaded band. Both or
        neither must be present.
      * ``"target_model_updated"`` / ``"model_updated"`` booleans, drawn as
        dotted vertical lines (black / gray).

    Entry ``i`` (0-based) is drawn at x = ``i + 1``. A linear best-fit line is
    added when there are at least two scores. The target from ``targets[key]``
    is drawn when it is not None. Does nothing for an empty ``values``.
    """
    if len(values) == 0:
        return
    ax = axs[row][col]
    scores = [value["score"] for value in values]
    x = np.arange(1, len(values) + 1)
    ax.plot(x, scores)

    has_max_scores = "max_score" in values[0]
    has_min_scores = "min_score" in values[0]
    assert (
        has_max_scores == has_min_scores
    ), "Both max_scores and min_scores must be provided or not provided"
    if has_max_scores:
        max_scores = [value["max_score"] for value in values]
        min_scores = [value["min_score"] for value in values]
        ax.fill_between(x, min_scores, max_scores, alpha=0.5)

    for update_key, color in (
        ("target_model_updated", "black"),
        ("model_updated", "gray"),
    ):
        if update_key not in values[0]:
            continue
        # Use the same 1-based x coordinate as the plotted score. A 0-based
        # index would put each marker one game too early.
        for game, value in enumerate(values, start=1):
            if value[update_key]:
                ax.axvline(x=game, color=color, linestyle="dotted")

    ax.set_title(
        f"{key} | rolling average: {np.mean(scores[-5:])} | latest: {scores[-1]}"
    )
    ax.set_xlabel("Game")
    ax.set_ylabel("Score")
    ax.set_xlim(1, len(values))
    if len(scores) > 1:
        slope, intercept = np.polyfit(x, scores, 1)
        ax.plot(
            x,
            slope * x + intercept,
            color="g",
            label="Best Fit Line",
            linestyle="dotted",
        )
    if key in targets and targets[key] is not None:
        ax.axhline(
            y=targets[key],
            color="r",
            linestyle="dashed",
            label="Target Score: {}".format(targets[key]),
        )
    ax.legend()
def plot_loss(axs, key: str, values: list[dict], targets: dict, row: int, col: int):
    """Plot a loss curve on ``axs[row][col]``.

    Each element of ``values`` is a dict with a ``"loss"`` entry. It may also
    contain ``"target_model_updated"`` / ``"model_updated"`` booleans, drawn as
    dotted vertical lines (black / gray) at the matching time step. Entry ``i``
    (0-based) is drawn at x = ``i + 1``. The target from ``targets[key]`` is
    drawn when it is not None. Does nothing for an empty ``values``.
    """
    # Same guard as plot_scores / plot_exploitability. Without it, values[0]
    # below raises IndexError before any loss has been recorded.
    if len(values) == 0:
        return
    ax = axs[row][col]
    loss = [value["loss"] for value in values]
    x = np.arange(1, len(values) + 1)
    ax.plot(x, loss)

    for update_key, color in (
        ("target_model_updated", "black"),
        ("model_updated", "gray"),
    ):
        if update_key not in values[0]:
            continue
        # Use the same 1-based x coordinate as the plotted loss value.
        for time_step, value in enumerate(values, start=1):
            if value[update_key]:
                ax.axvline(x=time_step, color=color, linestyle="dotted")

    ax.set_title(
        f"{key} | rolling average: {np.mean(loss[-5:])} | latest: {loss[-1]}"
    )
    ax.set_xlabel("Time Step")
    ax.set_ylabel("Loss")
    ax.set_xlim(1, len(values))
    if key in targets and targets[key] is not None:
        ax.axhline(
            y=targets[key],
            color="r",
            linestyle="dashed",
            label="Target Loss: {}".format(targets[key]),
        )
    ax.legend()
def plot_exploitability(
    axs, key: str, values: list[dict], targets: dict, row: int, col: int
):
    """Plot exploitability over games on ``axs[row][col]`` using log-log axes.

    Each element of ``values`` is a dict with an ``"exploitability"`` entry;
    its absolute value is plotted. It may also contain
    ``"target_model_updated"`` / ``"model_updated"`` booleans, drawn as dotted
    vertical lines (black / gray).

    Two curves are drawn: first a trailing mean over the last (up to) 5
    entries, then the raw values. A linear best-fit of the trailing mean is
    added when there are at least two entries. The target from ``targets[key]``
    is drawn when it is not None. Entry ``i`` (0-based) is drawn at
    x = ``i + 1``. Does nothing for an empty ``values``.
    """
    if len(values) == 0:
        return
    ax = axs[row][col]
    exploitability = [abs(value["exploitability"]) for value in values]
    # The slice [i - 5, i) ends at entry i - 1, so each mean covers up to the
    # last 5 entries including the current one.
    rolling_averages = [
        np.mean(exploitability[max(0, i - 5) : i])
        for i in range(1, len(exploitability) + 1)
    ]
    x = np.arange(1, len(values) + 1)
    ax.plot(x, rolling_averages)
    ax.plot(x, exploitability)

    for update_key, color in (
        ("target_model_updated", "black"),
        ("model_updated", "gray"),
    ):
        if update_key not in values[0]:
            continue
        # Use the same 1-based x coordinate as the plotted values. A 0-based
        # index is also invalid (x=0) on the log-scaled x axis set below.
        for game, value in enumerate(values, start=1):
            if value[update_key]:
                ax.axvline(x=game, color=color, linestyle="dotted")

    if len(rolling_averages) > 1:
        slope, intercept = np.polyfit(x, rolling_averages, 1)
        ax.plot(
            x,
            slope * x + intercept,
            color="g",
            label="Best Fit Line",
            linestyle="dotted",
        )
    ax.set_title(
        f"{key} | rolling average: {np.mean(exploitability[-5:])} | latest: {exploitability[-1]}"
    )
    ax.set_xlabel("Game")
    ax.set_ylabel("Exploitability (rolling average)")
    ax.set_xscale("log")
    ax.set_yscale("log")
    ax.set_xlim(1, len(values))
    if key in targets and targets[key] is not None:
        ax.axhline(
            y=targets[key],
            color="r",
            linestyle="dashed",
            label="Target Exploitability: {}".format(targets[key]),
        )
    ax.legend()
def plot_trials(
    scores: list,
    file_name: str,
    final_trial: int = 0,
    title: str = "Score of Hyperopt trials over time for Rainbow DQN on CartPole-v1",
):
    """Scatter-plot hyperopt trial scores with a linear trend and save it as PNG.

    Args:
        scores: one score per trial, in trial order.
        file_name: output name. The figure is saved to
            ``./graphs/{file_name}.png``, and that directory must already exist.
        final_trial: if > 0, only the first ``final_trial`` scores are plotted.
        title: figure title. The default is the title that was previously
            hard-coded.
    """
    if final_trial > 0:
        scores = scores[:final_trial]
    # Derive x from the (possibly truncated) scores. This keeps both the same
    # length even when fewer than ``final_trial`` scores exist.
    x = np.arange(1, len(scores) + 1)
    fig, axs = plt.subplots(
        1,
        1,
        figsize=(10, 5),
        squeeze=False,
    )
    try:
        ax = axs[0][0]
        ax.scatter(x, scores)
        # A degree-1 fit needs at least two points.
        if len(scores) > 1:
            slope, intercept = np.polyfit(x, scores, 1)
            ax.plot(
                x,
                slope * x + intercept,
                color="g",
                label="Best Fit Line",
                linestyle="dotted",
            )
        fig.suptitle(title)
        ax.set_xlabel("Trial")
        ax.set_ylabel("Score")
        plt.savefig(f"./graphs/{file_name}.png")
        plt.show()
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close(fig)
# Maps a stat name to the function that plots it. plot_graphs and
# plot_comparisons fall back to default_plot_func for names not listed here.
stat_keys_to_plot_funcs = {
    **dict.fromkeys(("test_score", "score"), plot_scores),
    **dict.fromkeys(
        ("policy_loss", "value_loss", "l2_loss", "loss", "rl_loss", "sl_loss"),
        plot_loss,
    ),
    # Open question from the original author: should this be plot_scores?
    "exploitability": plot_exploitability,
}
def plot_graphs(
    stats: dict,
    targets: dict,
    step: int,
    frames_seen: int,
    time_taken: float,
    model_name: str,
    dir: str = "./checkpoints/graphs",
):
    """Render every training stat into one square grid figure and save it as PNG.

    Args:
        stats: maps a stat name to its recorded values. Names found in
            ``stat_keys_to_plot_funcs`` use their dedicated plotter; all others
            use ``default_plot_func``.
        targets: optional target value per stat name, forwarded to the plotters.
        step: training step shown in the figure title.
        frames_seen: number of frames shown in the figure title.
        time_taken: elapsed seconds, shown as hours/minutes/seconds.
        model_name: file name (without extension) of the saved figure.
        dir: output directory. It is created if it does not exist.
    """
    num_plots = len(stats)
    # Smallest square grid that fits every plot. It is at least 1x1, so an
    # empty ``stats`` still yields a valid figure.
    grid_size = max(1, math.ceil(np.sqrt(num_plots)))
    fig, axs = plt.subplots(
        grid_size,
        grid_size,
        figsize=(10 * grid_size, 5 * grid_size),
        squeeze=False,
    )
    try:
        hours = int(time_taken // 3600)
        minutes = int((time_taken % 3600) // 60)
        seconds = int(time_taken % 60)
        fig.suptitle(
            "training stats | training step {} | frames seen {} | time taken {} hours {} minutes {} seconds".format(
                step, frames_seen, hours, minutes, seconds
            )
        )
        for i, (key, values) in enumerate(stats.items()):
            row, col = divmod(i, grid_size)
            plot_func = stat_keys_to_plot_funcs.get(key, default_plot_func)
            plot_func(axs, key, values, targets, row, col)
        # Remove the unused cells of the square grid.
        for i in range(num_plots, grid_size**2):
            row, col = divmod(i, grid_size)
            fig.delaxes(axs[row][col])
        # Create the directory (as plot_comparisons does) rather than asserting
        # that it exists. Asserts are stripped under ``python -O``.
        os.makedirs(dir, exist_ok=True)
        plt.savefig("{}/{}.png".format(dir, model_name))
    finally:
        plt.close(fig)
def plot_comparisons(
    stats: list[dict],
    model_name: str,
    dir: str = "./checkpoints/graphs",
):
    """Overlay the same training stats from several runs in one grid figure.

    Args:
        stats: one stats dict per run. The keys of the first dict decide which
            subplots are drawn, and every other dict must contain those keys.
        model_name: file name (without extension) of the saved figure.
        dir: output directory. It is created if it does not exist.
    """
    num_plots = len(stats[0])
    grid_size = max(1, math.ceil(np.sqrt(num_plots)))
    fig, axs = plt.subplots(
        grid_size,
        grid_size,
        figsize=(10 * grid_size, 5 * grid_size),
        squeeze=False,
    )
    try:
        fig.suptitle("Comparison of training stats")
        for i, key in enumerate(stats[0]):
            row, col = divmod(i, grid_size)
            plot_func = stat_keys_to_plot_funcs.get(key)
            for run_stats in stats:
                if plot_func is None:
                    default_plot_func(axs, key, run_stats[key], {}, row, col)
                else:
                    plot_func(axs, key, run_stats[key], {}, row, col)
            if plot_func is not None:
                # The dedicated plotters each set their own x range. Widen it
                # so the longest run fits.
                longest_run = max(len(run_stats[key]) for run_stats in stats)
                axs[row][col].set_xlim(0, longest_run)
        # Remove the unused cells of the square grid.
        for i in range(num_plots, grid_size**2):
            row, col = divmod(i, grid_size)
            fig.delaxes(axs[row][col])
        os.makedirs(dir, exist_ok=True)
        plt.savefig("{}/{}.png".format(dir, model_name))
    finally:
        plt.close(fig)
def prepare_kernel_initializers(kernel_initializer: str, output_layer: bool = False):
    """Map an initializer name to a weight-initialization callable.

    Args:
        kernel_initializer: one of "pytorch_default", "glorot_uniform",
            "glorot_normal", "he_uniform", "he_normal", "orthogonal",
            "variance_baseline", "variance_0.1", "variance_0.3",
            "variance_0.8", "variance_3", "variance_5", "variance_10".
        output_layer: accepted for API compatibility; not used.

    Returns:
        One of the following:
          * ``None`` for "pytorch_default" (keep PyTorch's own init).
          * a ``torch.nn.init`` function for the named initializers.
          * a new ``VarianceScaling`` instance for the "variance_*" names.

    Raises:
        ValueError: if the name is not recognized.
    """
    if kernel_initializer == "pytorch_default":
        return None

    torch_initializers = {
        "glorot_uniform": nn.init.xavier_uniform_,
        "glorot_normal": nn.init.xavier_normal_,
        "he_uniform": nn.init.kaiming_uniform_,
        "he_normal": nn.init.kaiming_normal_,
        "orthogonal": nn.init.orthogonal_,
    }
    if kernel_initializer in torch_initializers:
        return torch_initializers[kernel_initializer]

    if kernel_initializer == "variance_baseline":
        return VarianceScaling()
    variance_scales = {
        "variance_0.1": 0.1,
        "variance_0.3": 0.3,
        "variance_0.8": 0.8,
        "variance_3": 3,
        "variance_5": 5,
        "variance_10": 10,
    }
    if kernel_initializer in variance_scales:
        return VarianceScaling(scale=variance_scales[kernel_initializer])

    # TODO: "lecun_uniform" / "lecun_normal" are not implemented yet.
    raise ValueError(f"Invalid kernel initializer: {kernel_initializer}")
def prepare_activations(activation: str):
    """Build a new activation module from its name.

    Supported names: "linear" (identity), "relu", "relu6", "sigmoid",
    "softplus", "soft_sign", "silu" / "swish", "tanh", "hard_sigmoid",
    "elu", "selu", "gelu". ("log_sigmoid", "hard_silu"/"hard_swish",
    "hard_tanh", "celu" and "glu" are not enabled.)

    Returns:
        A newly constructed ``torch.nn`` module on every call.

    Raises:
        ValueError: if the name is not recognized.
    """
    activation_classes = {
        "linear": nn.Identity,
        "relu": nn.ReLU,
        "relu6": nn.ReLU6,
        "sigmoid": nn.Sigmoid,
        "softplus": nn.Softplus,
        "soft_sign": nn.Softsign,
        "silu": nn.SiLU,
        "swish": nn.SiLU,
        "tanh": nn.Tanh,
        "hard_sigmoid": nn.Hardsigmoid,
        "elu": nn.ELU,
        "selu": nn.SELU,
        "gelu": nn.GELU,
    }
    activation_class = activation_classes.get(activation)
    if activation_class is None:
        raise ValueError(f"Activation {activation} not recognized")
    return activation_class()
def epsilon_greedy_policy(
    q_values: list[float], info: dict, epsilon: float, wrapper=np.argmax
):
    """Pick an action epsilon-greedily.

    With probability ``epsilon``, a uniformly random action is returned. It is
    one of ``info["legal_moves"]`` when that key is present; otherwise it is an
    index into the flattened ``q_values``, which must then support
    ``.reshape`` (e.g. a NumPy array).

    Otherwise the greedy choice is ``wrapper(q_values, info)``. The default
    ``np.argmax`` is instead called as ``np.argmax(q_values)``.

    Args:
        q_values: action values.
        info: environment info dict, optionally with a ``"legal_moves"`` list.
        epsilon: exploration probability.
        wrapper: greedy selector, called as ``wrapper(q_values, info)``.

    Returns:
        The chosen action.
    """
    if np.random.rand() < epsilon:
        if "legal_moves" in info:
            return random.choice(info["legal_moves"])
        q_values = q_values.reshape(-1)
        return random.choice(range(len(q_values)))
    if wrapper is np.argmax:
        # np.argmax's second positional parameter is ``axis``. Passing the info
        # dict there raises TypeError, so the default wrapper gets q_values only.
        return np.argmax(q_values)
    return wrapper(q_values, info)
def add_dirichlet_noise(
    policy: list[float], dirichlet_alpha: float, exploration_fraction: float
):
    """Blend Dirichlet noise into ``policy``.

    Entry ``i`` becomes
    ``(1 - exploration_fraction) * policy[i] + exploration_fraction * noise[i]``,
    where ``noise`` is a single draw from a symmetric Dirichlet distribution
    with concentration ``dirichlet_alpha``.

    Note: ``policy`` is modified in place, and that same object is returned.
    """
    # TODO (original note): make AlphaZero use this helper.
    frac = exploration_fraction
    keep = 1 - frac
    samples = np.random.dirichlet([dirichlet_alpha] * len(policy))
    for idx, sample in enumerate(samples):
        policy[idx] = keep * policy[idx] + frac * sample
    return policy
def augment_game(game, flip_y: bool = False, flip_x: bool = False, rot90: bool = False):
    """Create symmetry-augmented copies of a recorded game.

    Every augmented game is a deep copy of ``game`` in which
    ``observation_history`` and ``policy_history`` are replaced by lists of the
    transformed entries. Each entry must be an array ``np.rot90`` / ``np.flip*``
    can handle; boards and policies are presumably 2-D — confirm per game.
    The original ``game`` is always appended last and is not modified.

    Transform sets, in output order:
      * ``rot90`` with ``flip_y`` or ``flip_x``: rotate 90, rotate 180,
        rotate 270, flipud, flipud(rot90), flipud(rot180), rot90(flipud).
      * ``rot90`` only: rotate 90, rotate 180, rotate 270.
      * ``flip_y`` and ``flip_x`` (no ``rot90``): flipud, fliplr, both flips.
      * ``flip_y`` only: flipud.
      * ``flip_x`` only: fliplr.
      * no flag set: no augmentation; the result is ``[game]``.

    Returns:
        list of games whose last element is ``game`` itself.
    """
    if rot90 and (flip_y or flip_x):
        transforms = [
            np.rot90,
            lambda a: np.rot90(a, 2),
            lambda a: np.rot90(a, 3),
            np.flipud,
            lambda a: np.flipud(np.rot90(a)),
            lambda a: np.flipud(np.rot90(a, 2)),
            lambda a: np.rot90(np.flipud(a)),
        ]
    elif rot90:
        transforms = [np.rot90, lambda a: np.rot90(a, 2), lambda a: np.rot90(a, 3)]
    elif flip_y and flip_x:
        transforms = [np.flipud, np.fliplr, lambda a: np.flipud(np.fliplr(a))]
    elif flip_y:
        transforms = [np.flipud]
    elif flip_x:
        transforms = [np.fliplr]
    else:
        transforms = []

    augmented_games = []
    for transform in transforms:
        augmented = copy.deepcopy(game)
        # Transform every recorded entry. This does not assume both histories
        # have the same length.
        augmented.observation_history = [
            transform(board) for board in game.observation_history
        ]
        augmented.policy_history = [
            transform(policy) for policy in game.policy_history
        ]
        augmented_games.append(augmented)
    augmented_games.append(game)
    return augmented_games
def augment_board(
    board, policy, flip_y: bool = False, flip_x: bool = False, rot90: bool = False
):
    """Create symmetry-augmented versions of one board and its policy.

    The same transform is applied to ``board`` and ``policy``. Both must be
    arrays ``np.rot90`` / ``np.flip*`` can handle; presumably 2-D and
    board-shaped — confirm per game. The originals are always appended last.

    Transform sets, in output order:
      * ``rot90`` with ``flip_y`` or ``flip_x``: rotate 90, rotate 180,
        rotate 270, flipud, flipud(rot90), flipud(rot180), rot90(flipud).
      * ``rot90`` only: rotate 90, rotate 180, rotate 270.
      * ``flip_y`` and ``flip_x`` (no ``rot90``): flipud, fliplr, both flips.
      * ``flip_y`` only: flipud.
      * ``flip_x`` only: fliplr.
      * no flag set: no augmentation.

    Returns:
        ``(boards, policies)``: two parallel lists whose last elements are
        ``board`` and ``policy`` themselves. The other entries are NumPy views
        of the inputs, not copies.
    """
    if rot90 and (flip_y or flip_x):
        transforms = [
            np.rot90,
            lambda a: np.rot90(a, 2),
            lambda a: np.rot90(a, 3),
            np.flipud,
            lambda a: np.flipud(np.rot90(a)),
            lambda a: np.flipud(np.rot90(a, 2)),
            lambda a: np.rot90(np.flipud(a)),
        ]
    elif rot90:
        transforms = [np.rot90, lambda a: np.rot90(a, 2), lambda a: np.rot90(a, 3)]
    elif flip_y and flip_x:
        transforms = [np.flipud, np.fliplr, lambda a: np.flipud(np.fliplr(a))]
    elif flip_y:
        transforms = [np.flipud]
    elif flip_x:
        transforms = [np.fliplr]
    else:
        transforms = []

    augmented_boards = [transform(board) for transform in transforms]
    augmented_policies = [transform(policy) for transform in transforms]
    augmented_boards.append(board)
    augmented_policies.append(policy)
    return augmented_boards, augmented_policies
def sample_by_random_indices(
    max_index_or_1darray, batch_size: int, with_replacement=False
) -> npt.NDArray[np.int64]:
    """
    Draw ``batch_size`` uniformly random samples using ``np.random.choice``.

    ``max_index_or_1darray`` is either an int ``n`` (sample from ``range(n)``)
    or a 1-D array to sample elements from. Sampling is without replacement
    unless ``with_replacement`` is True.
    """
    return np.random.choice(
        max_index_or_1darray, size=batch_size, replace=with_replacement
    )
def sample_by_indices_probability(
    max_index_or_1darray, batch_size: int, probabilities: npt.NDArray[np.float64]
) -> npt.NDArray[np.int64]:
    """
    Draw ``batch_size`` samples (with replacement) weighted by ``probabilities``.

    ``max_index_or_1darray`` is either an int ``n`` (sample from ``range(n)``)
    or a 1-D array of candidates. ``probabilities`` gives one weight per
    candidate; ``np.random.choice`` requires them to sum to 1.
    """
    return np.random.choice(max_index_or_1darray, size=batch_size, p=probabilities)
def sample_tree_proportional(
    tree, batch_size: int, max_size: int
) -> npt.NDArray[np.int64]:
    """
    Stratified proportional sampling from a sum segment tree, as used in
    prioritized experience replay.

    ``tree.sum(0, max_size - 1)`` is split into ``batch_size`` equal-width
    segments. One value is drawn uniformly inside each segment and mapped to
    an index with ``tree.retrieve``.

    tree: object exposing ``sum(start, end)`` and ``retrieve(value)``
        (presumably a SumSegmentTree — confirm).
    Returns an int64 array of ``batch_size`` indices, one per segment, in order.
    """
    segment_width = tree.sum(0, max_size - 1) / batch_size
    sampled = np.zeros(batch_size, dtype=np.int64)
    for slot in range(batch_size):
        low = segment_width * slot
        high = segment_width * (slot + 1)
        sampled[slot] = tree.retrieve(np.random.uniform(low, high))
    return sampled
def reward_clipping(reward: float, lower_bound: float = -1, upper_bound: float = 1):
    """Clamp ``reward`` into ``[lower_bound, upper_bound]``.

    The lower bound is checked first: a reward below ``lower_bound`` yields
    ``lower_bound`` even if ``lower_bound > upper_bound``.
    """
    if reward < lower_bound:
        return lower_bound
    return min(reward, upper_bound)
def discounted_cumulative_sums(x, discount):
    """Return discounted cumulative sums of ``x`` along axis 0.

    ``y[t] = x[t] + discount * x[t + 1] + discount**2 * x[t + 2] + ...``.
    This is used for rewards-to-go and advantage estimates. It is computed as
    a first-order IIR filter (``scipy.signal.lfilter``) over the reversed input.
    """
    # Import the submodule explicitly. With a bare ``import scipy``, the
    # ``scipy.signal`` attribute is only available on SciPy versions that
    # lazy-load submodules (or once something else has imported it).
    from scipy import signal

    return signal.lfilter([1], [1, float(-discount)], x[::-1], axis=0)[::-1]
def to_lists(l: list[Iterable]) -> list[Tuple]:
    """Transpose a list of iterables into a list of tuples.

    Args:
        l (list[Iterable]): A list of iterables, e.g. [(1,1,1),(2,2,2),(3,3,3)]
    Returns:
        list[Tuple]: one tuple per position, e.g. [(1,2,3), (1,2,3), (1,2,3)].
            Like ``zip``, it stops at the shortest iterable.
    """
    return [*zip(*l)]
def current_timestamp():
    """Return the current time as a POSIX timestamp (seconds, float)."""
    now = datetime.now()
    return now.timestamp()
# Small constant that keeps probabilities away from 0 (and 1) before taking logs.
_epsilon = 1e-7


def categorical_crossentropy(predicted: torch.Tensor, target: torch.Tensor, axis=-1):
    """Cross-entropy between ``target`` and (possibly unnormalized) ``predicted``.

    ``predicted`` is first rescaled to sum to 1 along ``axis``. It is then
    clamped to ``[_epsilon, 1 - _epsilon]`` so the log stays finite. The result
    is ``-sum(target * log(predicted))`` with ``axis`` reduced away.
    """
    probs = predicted / predicted.sum(dim=axis, keepdim=True)
    probs = probs.clamp(_epsilon, 1.0 - _epsilon)
    return -(probs.log() * target).sum(dim=axis)
class CategoricalCrossentropyLoss:
    """Callable wrapper around ``categorical_crossentropy`` with a fixed axis.

    NOTE(review): ``from_logits`` is stored but never read. Predictions are
    always treated as (unnormalized) probabilities — confirm before passing logits.
    """

    def __init__(self, from_logits=False, axis=-1):
        self.from_logits = from_logits  # stored only; __call__ ignores it
        self.axis = axis  # axis that is normalized over and reduced away

    def __call__(self, predicted, target):
        """Return the cross-entropy of ``predicted`` vs ``target`` along ``self.axis``."""
        reduce_axis = self.axis
        return categorical_crossentropy(predicted, target, reduce_axis)
def kl_divergence(predicted: torch.Tensor, target: torch.Tensor, axis=-1):
    """KL(target || predicted), reduced along ``axis``.

    ``predicted`` is rescaled to sum to 1 along ``axis``. Both tensors are then
    clamped to ``[_epsilon, 1]`` so the ratio and its log stay finite.
    """
    q = predicted / predicted.sum(dim=axis, keepdim=True)
    q = q.clamp(_epsilon, 1.0)
    p = target.clamp(_epsilon, 1.0)
    return (p * (p / q).log()).sum(dim=axis)
class KLDivergenceLoss:
    """Callable wrapper around ``kl_divergence`` with a fixed axis.

    NOTE(review): ``from_logits`` is stored but never read. Predictions are
    always treated as (unnormalized) probabilities — confirm before passing logits.
    """

    def __init__(self, from_logits=False, axis=-1):
        self.from_logits = from_logits  # stored only; __call__ ignores it
        self.axis = axis  # axis that is normalized over and reduced away

    def __call__(self, predicted, target):
        """Return KL(target || predicted) along ``self.axis``."""
        reduce_axis = self.axis
        return kl_divergence(predicted, target, reduce_axis)
def huber(predicted: torch.Tensor, target: torch.Tensor, axis=-1, delta: float = 1.0):
    """Element-wise Huber loss, flattened to 1-D.

    The loss is quadratic (``0.5 * d**2``) where ``d = |predicted - target| < delta``
    and linear (``delta * (d - 0.5 * delta)``) elsewhere. ``axis`` is accepted
    but unused; no reduction is performed.
    """
    abs_err = (predicted - target).abs()
    quadratic = 0.5 * abs_err**2
    linear = delta * (abs_err - 0.5 * delta)
    return torch.where(abs_err < delta, quadratic, linear).view(-1)
class HuberLoss:
    """Callable wrapper around ``huber`` with fixed ``axis`` and ``delta``."""

    def __init__(self, axis=-1, delta: float = 1.0):
        self.axis = axis  # forwarded to huber (which does not use it)
        self.delta = delta  # switch point between quadratic and linear regions

    def __call__(self, predicted, target):
        """Return the flattened element-wise Huber loss."""
        return huber(predicted, target, delta=self.delta, axis=self.axis)
def mse(predicted: torch.Tensor, target: torch.Tensor):
    """Element-wise squared error; no reduction is applied."""
    residual = predicted - target
    return residual**2
class MSELoss:
    """Callable wrapper around ``mse`` (element-wise squared error, no reduction)."""

    def __call__(self, predicted, target):
        """Return ``(predicted - target) ** 2`` element-wise."""
        return mse(predicted, target)
from typing import Callable
# Type of the loss callables used with configs: (predicted, target) -> loss tensor.
Loss = Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
def calculate_padding(i: int, k: int, s: int) -> Tuple[int, int]:
    """Split the padding needed to keep a 1-D length unchanged under a strided kernel.

    The total padding ``p`` solves ``(i + p - k) / s + 1 == i``.

    Args:
        i (int): input length
        k (int): kernel length
        s (int): stride
    Returns:
        (before, after): ``before == after`` when ``p`` is even, otherwise the
        extra unit goes to ``after`` (``before == after - 1``).
    """
    total = k + (i - 1) * s - i
    before = total // 2
    after = (total + 1) // 2
    return before, after
def generate_layer_widths(widths: list[int], max_num_layers: int) -> list[Tuple[int]]:
    """List every multiset of ``widths`` with 0 up to ``max_num_layers - 1`` layers.

    Combinations are produced with replacement, grouped by increasing depth
    (the empty tuple first). Note the upper bound is exclusive.
    """
    return [
        layer_widths
        for depth in range(max_num_layers)
        for layer_widths in itertools.combinations_with_replacement(widths, depth)
    ]
def hyperopt_analysis(
    data_dir: str,
    file_name: str,
    viable_trial_threshold: int,
    step: int,
    final_trial: int = 0,
    eval_method: str = "final_score",
):
    """Summarise a pickled hyperopt ``Trials`` object together with per-trial test stats.

    Reads ``{data_dir}/{file_name}.p`` and, for every non-failed trial, the
    stats pickle at
    ``{data_dir}/checkpoints/<trial dir>/step_{step}/graphs_stats/stats.pkl``.
    Checkpoint directories are sorted by the integer after their last ``_``
    and matched, in order, to the non-failed trials. Scores are plotted with
    ``plot_trials`` and aggregate statistics are printed.

    Args:
        data_dir: directory containing the trials pickle and ``checkpoints/``.
        file_name: trials pickle name without the ``.p`` extension.
        viable_trial_threshold: trials scoring above this are "viable".
        step: training step whose saved stats are read for each trial.
        final_trial: if > 0, only the first ``final_trial`` trials are analysed.
        eval_method: ``"final_score"`` (negated hyperopt loss), or
            ``"rolling_average"`` / ``"final_score_rolling_average"`` (the last
            recorded test score).

    Raises:
        ValueError: if ``eval_method`` is not one of the values above.
    """
    if eval_method not in (
        "final_score",
        "rolling_average",
        "final_score_rolling_average",
    ):
        raise ValueError(f"Unknown eval_method: {eval_method}")

    # NOTE: pickle.load can execute arbitrary code; only analyse trusted files.
    with open(f"{data_dir}/{file_name}.p", "rb") as f:
        trials = pickle.load(f)
    num_trials = final_trial if final_trial > 0 else len(trials.trials)
    print("Number of trials: {}".format(num_trials))

    checkpoints = [
        name
        for name in os.listdir(f"{data_dir}/checkpoints")
        if name not in ("videos", ".DS_Store")
    ]
    checkpoints.sort(key=lambda x: int(x.split("_")[-1]))
    if final_trial > 0:
        checkpoints = checkpoints[:final_trial]

    viable_throughout_trials = []
    final_rolling_averages = []
    final_std_devs = []
    scores = []
    losses = []
    failed_trials = 0
    for i, trial in enumerate(trials.trials):
        if final_trial > 0 and i >= final_trial:
            break
        loss = trial["result"]["loss"]
        losses.append(loss)
        if trial["result"]["status"] == "fail":
            failed_trials += 1
            # failed trials have no checkpoint; their loss stands in for every metric
            final_rolling_averages.append(loss)
            scores.append(loss)
            final_std_devs.append(loss)
            continue
        # checkpoints exist only for successful trials, hence the offset
        stats_path = f"{data_dir}/checkpoints/{checkpoints[i - failed_trials]}/step_{step}/graphs_stats/stats.pkl"
        with open(stats_path, "rb") as f:
            stats = pickle.load(f)
        test_scores = stats["test_score"]
        last_scores = [stat_dict["score"] for stat_dict in test_scores[-5:]]
        final_rolling_averages.append(np.around(np.mean(last_scores), 1))
        final_std_devs.append(np.around(np.std(last_scores), 1))
        # best max_score seen during training, floored at 0
        max_score = max([0] + [stat_dict["max_score"] for stat_dict in test_scores])
        if max_score > viable_trial_threshold:
            viable_throughout_trials.append(max_score)
        if eval_method == "final_score":
            scores.append(-loss)
        else:
            scores.append(test_scores[-1]["score"])

    if not scores:
        print("No trials to analyse")
        return

    plot_trials(
        scores,
        file_name,
        final_trial=final_trial,
    )
    # sort every per-trial metric together by loss
    losses, scores, final_rolling_averages, final_std_devs = (
        list(column)
        for column in zip(
            *sorted(
                zip(losses, scores, final_rolling_averages, final_std_devs),
                key=itemgetter(0),
            )
        )
    )
    num_scored = len(scores)
    viable_mask = [score > viable_trial_threshold for score in scores]
    viable_trials = [score for score, ok in zip(scores, viable_mask) if ok]
    print("Failed trials: ~{}%".format(round(failed_trials / num_scored * 100)))
    print(
        "Viable trials (based on final score): ~{}%".format(
            round(len(viable_trials) / num_scored * 100)
        )
    )
    print(
        "Viable trials (throughout training): ~{}%".format(
            round(len(viable_throughout_trials) / num_scored * 100)
        )
    )
    print("Losses: {}".format(losses))
    print("Scores: {}".format(scores))
    print("Final rolling averages: {}".format(final_rolling_averages))
    print("Final standard deviations: {}".format(final_std_devs))
    print("Max loss: {}".format(max(losses)))
    print("Max score: {}".format(max(scores)))
    print("Max final rolling average: {}".format(max(final_rolling_averages)))
    print("Max final standard deviation: {}".format(max(final_std_devs)))
    print("Average loss: {}".format(np.mean(losses)))
    print("Average score: {}".format(np.mean(scores)))
    print("Average final rolling average: {}".format(np.mean(final_rolling_averages)))
    print("Average final standard deviation: {}".format(np.mean(final_std_devs)))
    viable_final_rolling_averages = [
        value for value, ok in zip(final_rolling_averages, viable_mask) if ok
    ]
    viable_std_devs = [value for value, ok in zip(final_std_devs, viable_mask) if ok]
    print(
        "Average score of viable trials (based on final score): {}".format(
            np.mean(viable_trials)
        )
    )
    print(
        "Average final rolling average of viable trials (based on final score): {}".format(
            np.mean(viable_final_rolling_averages)
        )
    )
    print(
        "Average final standard deviation of viable trials (based on final score): {}".format(
            np.mean(viable_std_devs)
        )
    )
def graph_hyperparameter_importance(
    data_dir: str, trials_file: str, search_space_file: str, viable_trial_threshold: int
):
    """Scatter-plot each hyperparameter against trial score for viable trials.

    Loads hyperopt trials from ``{data_dir}/{trials_file}`` and the search space
    from ``./search_spaces/{search_space_file}``. For every hyperparameter
    except ``loss_function``, it plots score vs. value for trials scoring
    above ``viable_trial_threshold``. The per-value mean score is overlaid,
    as a line or, for ``kernel_initializer``/``activation``, as points. The
    figures are created but not shown or saved here.
    """
    # NOTE: pickle.load can execute arbitrary code; only use trusted files.
    with open(f"{data_dir}/{trials_file}", "rb") as f:
        trials = pickle.load(f)
    print(trials)
    with open(f"./search_spaces/{search_space_file}", "rb") as f:
        search_space = pickle.load(f)
    values_dict = defaultdict(list)
    scores = []
    for trial in trials.trials:
        for key, value in space_eval(trial["misc"]["vals"], search_space).items():
            values_dict[key].append(value[0])
        scores.append(-trial["result"]["loss"])
    df = pd.DataFrame(values_dict)
    # snapshot the hyperparameter columns before adding the score column
    x_cols = list(df.columns)
    df["scores"] = scores
    df = df[df["scores"] > viable_trial_threshold]
    for col in x_cols:
        if col == "loss_function":
            continue
        ax = df.plot(x=col, y="scores", kind="scatter")
        means = df.groupby(col)["scores"].mean()
        if col in ("kernel_initializer", "activation"):
            # categorical values: a connecting line would be meaningless
            ax.scatter(means.index, means.values, c="#00FFFF")
        else:
            ax.plot(means.index, means.values, color="#00FFFF")
def calc_units(shape):
    """Return ``(fan_in, fan_out)`` for a weight tensor of the given shape.

    Uses PyTorch's weight layout, the same convention as
    ``torch.nn.init._calculate_fan_in_and_fan_out``:

    * 1-D ``(n,)`` -> ``(n, n)``
    * 2-D ``(out_features, in_features)`` (``nn.Linear``) -> ``(in, out)``
    * N-D ``(out_channels, in_channels, *kernel)`` (conv) ->
      ``(in * prod(kernel), out * prod(kernel))``

    Raises:
        ValueError: for a 0-d (scalar) shape.
    """
    shape = tuple(shape)
    if len(shape) == 0:
        raise ValueError("cannot compute fan-in/fan-out of a scalar shape")
    if len(shape) == 1:
        return shape + shape
    # A dense weight is the conv case with an empty receptive field. Both are
    # stored as (out, in, ...), so the input units live in dimension 1.
    receptive_field = math.prod(shape[2:])
    return (shape[1] * receptive_field, shape[0] * receptive_field)
class VarianceScaling:
    """Initialiser filling a tensor from ``U(-limit, limit)``.

    ``limit = sqrt(3 * scale / n)``, where ``n`` is the fan-in, fan-out or their
    mean (chosen by ``mode``), as computed by ``calc_units``.

    Args:
        scale: variance scale factor.
        mode: ``"fan_in"``, ``"fan_out"`` or ``"fan_avg"``.
        distribution: only ``"uniform"`` is supported.

    Raises:
        ValueError: for an unsupported ``mode`` or ``distribution``.
    """
    def __init__(self, scale=0.1, mode="fan_in", distribution="uniform"):
        # explicit checks rather than asserts, which are stripped under -O
        if mode not in ("fan_in", "fan_out", "fan_avg"):
            raise ValueError(f"unsupported mode: {mode}")
        if distribution != "uniform":
            raise ValueError("only uniform distribution is supported")
        self.scale = scale
        self.mode = mode
        self.distribution = distribution
    def __call__(self, tensor: Tensor) -> Tensor:
        """Fill ``tensor`` in place (without tracking gradients) and return it."""
        with torch.no_grad():
            in_units, out_units = calc_units(tensor.shape)
            if self.mode == "fan_in":
                fan = in_units
            elif self.mode == "fan_out":
                fan = out_units
            else:
                fan = (in_units + out_units) / 2
            scale = self.scale / fan
            limit = math.sqrt(3.0 * scale)
            return tensor.uniform_(-limit, limit)
def isiterable(o):
    """Return True when ``iter(o)`` succeeds, False when it raises TypeError."""
    try:
        iter(o)
    except TypeError:
        return False
    else:
        return True
def tointlists(list):
    """Recursively convert a nested iterable into nested lists of ``int``.

    Iterable items are recursed into; every other item is passed to ``int``.
    """
    converted = []
    for item in list:
        try:
            iter(item)
        except TypeError:
            # leaf value
            converted.append(int(item))
        else:
            converted.append(tointlists(item))
    return converted
import time
from collections import deque
class StoppingCriteria:
    """Interface for training-termination rules; this base rule never stops."""
    def should_stop(self, details: dict) -> bool:
        """Return whether training should stop given the progress ``details``."""
        return False
class TimeStoppingCriteria(StoppingCriteria):
    """Stop once a wall-clock deadline, fixed at construction time, has passed.

    Args:
        max_runtime_sec: seconds from construction until the deadline (default 10 minutes).
    """
    def __init__(self, max_runtime_sec=60 * 10):
        deadline = time.time() + max_runtime_sec
        self.stop_time = deadline
    def should_stop(self, details: dict) -> bool:
        now = time.time()
        return now > self.stop_time
class TrainingStepStoppingCritiera(StoppingCriteria):
    """Stop once ``details["training_step"]`` is strictly greater than ``max_training_steps``."""
    def __init__(self, max_training_steps=100000):
        self.max_training_steps = max_training_steps
    def should_stop(self, details: dict) -> bool:
        current_step = details["training_step"]
        return current_step > self.max_training_steps
class EpisodesStoppingCriteria(StoppingCriteria):
    """Stop once the value under ``details["max_episodes"]`` exceeds ``max_episodes``.

    NOTE(review): the key read from ``details`` is ``"max_episodes"``; confirm
    callers store the current episode count under that key.
    """
    def __init__(self, max_episodes=100000):
        self.max_episodes = max_episodes
    def should_stop(self, details: dict) -> bool:
        episodes_so_far = details["max_episodes"]
        return episodes_so_far > self.max_episodes
class AverageScoreStoppingCritera(StoppingCriteria):
    """Stop when the mean of the most recent scores falls below ``min_avg_score``.

    Scores are pushed with ``add_score`` into a window of ``last_scores_length``
    entries. The rule never fires until that window is full.
    """
    def __init__(self, min_avg_score: float, last_scores_length: int):
        self.min_avg_score = min_avg_score
        self.last_scores_length = last_scores_length
        self.last_scores = deque(maxlen=last_scores_length)
    def add_score(self, score: float):
        """Record one score; the oldest is evicted once the window is full."""
        self.last_scores.append(score)
    def should_stop(self, details: dict) -> bool:
        window_full = len(self.last_scores) >= self.last_scores_length
        return window_full and np.average(self.last_scores) < self.min_avg_score
class ApexLearnerStoppingCriteria(StoppingCriteria):
    """Composite stopping rule with hard-coded thresholds.

    Training stops when any of the following holds:

    * the 1.5 h wall-clock budget is exhausted (checked at every step);
    * after the first 10000 training steps, the step limit (10000) is exceeded;
    * after the first 10000 training steps, the mean of the last 10 scores
      is below 15.
    """
    def __init__(self):
        self.criterias: dict[str, StoppingCriteria] = {
            "time": TimeStoppingCriteria(max_runtime_sec=1.5 * 60 * 60),
            "training_step": TrainingStepStoppingCritiera(max_training_steps=10000),
            "avg_score": AverageScoreStoppingCritera(
                min_avg_score=15, last_scores_length=10
            ),
        }
    def should_stop(self, details: dict) -> bool:
        if self.criterias["time"].should_stop(details):
            return True
        # warm-up period: only the time budget may end training before step 10000
        if details["training_step"] < 10000:
            return False
        step_limit_hit = self.criterias["training_step"].should_stop(details)
        return step_limit_hit or self.criterias["avg_score"].should_stop(details)
    def add_score(self, score: float):
        """Forward a score to the average-score rule."""
        avg_rule: AverageScoreStoppingCritera = self.criterias["avg_score"]
        avg_rule.add_score(score)
import gc
import os
from pathlib import Path
import numpy as np
import torch
import gymnasium as gym
import copy
import pickle
from torch.optim import Optimizer
from torch.nn import Module
from utils import make_stack, plot_graphs
# Every model should have:
# 1. A network
# 2. An optimizer
# 3. A loss function
# 4. A training method
#       this method should have training iterations, minibatches, and training steps
# 5. A select_action method
# 6. A predict method
import torch
import yaml
class ConfigBase:
    """Dictionary-backed configuration base.

    ``parse_field`` reads one key from ``config_dict``, logs the value it
    chose, optionally wraps it, and records the key so that repeated lookups
    of the same field emit a warning.
    """
    def parse_field(
        self, field_name, default=None, wrapper=None, required=True, dtype=None
    ):
        """Read ``field_name`` from ``self.config_dict``, falling back to ``default``.

        Args:
            field_name: key to look up.
            default: value used when the key is absent; ``None`` means "no default".
            wrapper: optional callable applied to the found value or the default.
            required: if True, a missing key with no default raises.
            dtype: unused; accepted for backward compatibility.

        Returns:
            The (wrapped) configured value or the (wrapped) default. When an
            optional field is absent and has no default, returns None without
            calling ``wrapper``.

        Raises:
            ValueError: if a required field is missing and has no default.
        """
        # Track every lookup up front so the duplicate warning applies to all
        # paths, not only to absent optional fields.
        if field_name in self._parsed_fields:
            print("warning: duplicate field: ", field_name)
        self._parsed_fields.add(field_name)
        if field_name in self.config_dict:
            val = self.config_dict[field_name]
            print(f"Using         {field_name:30}: {val}")
            return val if wrapper is None else wrapper(val)
        if default is not None:
            print(f"Using default {field_name:30}: {default}")
            return default if wrapper is None else wrapper(default)
        if required:
            raise ValueError(
                f"Missing required field without default value: {field_name}"
            )
        print(f"Using         {field_name:30}: {default}")
        return None
    def __init__(self, config_dict: dict):
        self.config_dict = config_dict
        # names already passed to parse_field, used for the duplicate warning
        self._parsed_fields = set()
    @classmethod
    def load(cls, filepath: str):
        """Rebuild a config from a YAML file written by ``dump``."""
        # NOTE(review): yaml.Loader can construct arbitrary Python objects;
        # only load trusted files.
        with open(filepath, "r") as f:
            loaded = yaml.load(f, yaml.Loader)
            print(loaded)
            config = cls(config_dict=loaded["config_dict"])
        return config
    def dump(self, filepath: str):
        """Write ``config_dict`` to ``filepath`` as YAML under the ``config_dict`` key."""
        to_dump = dict(config_dict=self.config_dict)
        with open(filepath, "w") as f:
            yaml.dump(to_dump, f, yaml.Dumper)
class GameConfig:
    """Static description of a game: score range, action/observation properties and players.

    Two configs compare equal when every constructor field is equal. Because
    ``__eq__`` is defined without ``__hash__``, instances are unhashable.
    """
    def __init__(
        self,
        max_score,
        min_score,
        is_discrete,
        is_image,
        is_deterministic,
        has_legal_moves,
        perfect_information,
        multi_agent,
        num_players,
    ):
        # score range
        self.max_score = max_score
        self.min_score = min_score
        # NOTE: is_discrete could be derived from the env's action space type
        # instead of being set manually
        self.is_discrete = is_discrete
        self.is_image = is_image
        self.is_deterministic = is_deterministic
        self.has_legal_moves = has_legal_moves
        self.perfect_information = perfect_information
        # player setup
        self.multi_agent = multi_agent
        self.num_players = num_players
    def __eq__(self, o: object) -> bool:
        if not isinstance(o, GameConfig):
            return False
        compared_fields = (
            "max_score",
            "min_score",
            "is_discrete",
            "is_image",
            "is_deterministic",
            "has_legal_moves",
            "perfect_information",
            "multi_agent",
            "num_players",
        )
        return all(getattr(self, name) == getattr(o, name) for name in compared_fields)
class AtariConfig(GameConfig):
    """GameConfig preset for single-player Atari environments."""
    def __init__(self):
        super().__init__(
            # score bounds taken from the Categorical DQN paper
            max_score=10,
            min_score=-10,
            is_discrete=True,
            is_image=True,
            # frameskip makes the environment stochastic; without it, it is deterministic
            is_deterministic=False,
            has_legal_moves=False,
            # close enough to deterministic to be treated as perfect information
            perfect_information=True,
            multi_agent=False,
            num_players=1,
        )
class Config(ConfigBase):
    """Agent configuration: the GameConfig being played plus the training
    hyperparameters common to agents, parsed from ``config_dict``.

    ``__init__`` calls ``_verify_game`` before parsing any field. The base
    implementation raises NotImplementedError, so subclasses must override it.
    """
    @classmethod
    def load(cls, filepath: str):
        """Rebuild a config from a YAML file written by ``dump``."""
        # NOTE(review): yaml.Loader can construct arbitrary Python objects;
        # only load trusted files.
        with open(filepath, "r") as f:
            o = yaml.load(f, yaml.Loader)
            print(o)
            a = cls(config_dict=o["config_dict"], game_config=o["game"])
        return a
    def dump(self, filepath: str):
        """Write ``config_dict`` and the game config to ``filepath`` as YAML."""
        to_dump = dict(config_dict=self.config_dict, game=self.game)
        with open(filepath, "w") as f:
            yaml.dump(to_dump, f, yaml.Dumper)
    def __init__(self, config_dict: dict, game_config: GameConfig) -> None:
        super().__init__(config_dict)
        # could take in a game config and set an action space and observation shape here
        # OR DO THAT IN BASE AGENT?
        self.game = game_config
        self._verify_game()
        # not hyperparameters but utility things
        self.save_intermediate_weights: bool = self.parse_field(
            "save_intermediate_weights", False
        )
        # ADD LEARNING RATE SCHEDULES
        self.training_steps: int = self.parse_field("training_steps", 10000, wrapper=int)
        self.adam_epsilon: float = self.parse_field("adam_epsilon", 1e-6)
        self.momentum = self.parse_field("momentum", 0.9)
        self.learning_rate: float = self.parse_field("learning_rate", 0.001)
        # NOTE(review): 0 presumably disables gradient-norm clipping — confirm in the agents
        self.clipnorm: int = self.parse_field("clipnorm", 0)
        # the default is the optimizer class itself (not an instance)
        self.optimizer: torch.optim.Optimizer = self.parse_field(
            "optimizer", torch.optim.Adam
        )
        self.weight_decay: float = self.parse_field("weight_decay", 0.0)
        # no default: parse_field raises ValueError if this key is missing
        self.loss_function: Loss = self.parse_field("loss_function", required=True)
        self.activation = self.parse_field(
            "activation", "relu", wrapper=prepare_activations
        )
        # optional: when absent, parse_field returns None without calling the wrapper
        self.kernel_initializer = self.parse_field(
            "kernel_initializer",
            None,
            required=False,
            wrapper=kernel_initializer_wrapper,
        )
        self.minibatch_size: int = self.parse_field("minibatch_size", 64, wrapper=int)
        self.replay_buffer_size: int = self.parse_field(
            "replay_buffer_size", 5000, wrapper=int
        )
        # defaults to one minibatch, so this must be parsed after minibatch_size
        self.min_replay_buffer_size: int = self.parse_field(
            "min_replay_buffer_size", self.minibatch_size, wrapper=int
        )
        self.num_minibatches: int = self.parse_field("num_minibatches", 1, wrapper=int)
        self.training_iterations: int = self.parse_field(
            "training_iterations", 1, wrapper=int
        )
        self.print_interval: int = self.parse_field("print_interval", 100, wrapper=int)
    def _verify_game(self):
        # Subclass hook, run from __init__ right after self.game is set.
        # Presumably it checks that the game suits the agent (confirm in subclasses).
        raise NotImplementedError
class BaseAgent:
    """Shared scaffolding for the agents in this project.

    Provides environment/action-space introspection, observation
    preprocessing, checkpoint save/load and an evaluation loop. Subclasses
    supply ``self.model``, ``self.optimizer``, ``self.replay_buffer``,
    ``self.stats`` (with a ``"test_score"`` list) and ``self.targets``, and
    implement ``predict`` and ``select_actions``.
    """
    def __init__(
        self,
        env: gym.Env,
        config: Config,
        name,
        device: torch.device = (
            torch.device("cuda")
            if torch.cuda.is_available()
            # MPS is sometimes useful for M2 instances, but only for large models/matrix multiplications otherwise CPU is faster
            else (
                torch.device("mps")
                if torch.backends.mps.is_available() and torch.backends.mps.is_built()
                else torch.device("cpu")
            )
        ),
        from_checkpoint=False,
    ):
        """
        Args:
            env: training environment; a deep copy is used for evaluation.
            config: agent configuration (``training_steps`` is read here).
            name: model name, used for checkpoint and video paths.
            device: torch device; defaults to CUDA, then MPS, then CPU.
            from_checkpoint: passed as True by ``load_from_checkpoint``.
        """
        if from_checkpoint:
            # only set when True; code reading it must use hasattr/getattr
            self.from_checkpoint = True
        self.model: Module = None
        self.optimizer: Optimizer = None
        self.model_name = name
        self.config = config
        self.device = device
        self.training_time = 0
        self.training_step = 0
        self.total_environment_steps = 0
        self.training_steps = self.config.training_steps
        self.checkpoint_interval = max(self.training_steps // 30, 1)
        self.checkpoint_trials = 5
        self.env = env
        self.test_env = self.make_test_env(env)
        self.observation_dimensions = self.determine_observation_dimensions(env)
        print("observation_dimensions: ", self.observation_dimensions)
        if isinstance(env.action_space, gym.spaces.Discrete):
            self.num_actions = env.action_space.n
            self.discrete_action_space = True
        else:
            self.num_actions = env.action_space.shape[0]
            self.discrete_action_space = False
        print("num_actions: ", self.num_actions)
    def make_test_env(self, env: gym.Env):
        """Return a deep copy of ``env`` for evaluation.

        When ``env.render_mode == "rgb_array"``, the copy is wrapped in
        ``RecordVideo``; otherwise a warning is printed.
        """
        if hasattr(env, "render_mode") and env.render_mode == "rgb_array":
            return gym.wrappers.RecordVideo(
                copy.deepcopy(env),
                ".",
                name_prefix="{}".format(self.model_name),
            )
        else:
            print(
                "Warning: test_env will not record videos as render_mode is not 'rgb_array'"
            )
            return copy.deepcopy(env)
    def determine_observation_dimensions(self, env: gym.Env):
        """Return the observation shape for Box, Discrete and Tuple spaces.

        Raises:
            ValueError: for any other observation space type.
        """
        if isinstance(env.observation_space, gym.spaces.Box):
            return env.observation_space.shape
        elif isinstance(env.observation_space, gym.spaces.Discrete):
            return (1,)
        elif isinstance(env.observation_space, gym.spaces.Tuple):
            return (len(env.observation_space.spaces),)  # for tuple of discretes
        else:
            raise ValueError("Observation space not supported")
    def train(self):
        """Base training hook: only announces resumption. Subclasses implement training."""
        # training_steps is the total budget; a non-zero *current* step means
        # training is being resumed
        if self.training_step != 0:
            self.print_resume_training()
    def preprocess(self, states) -> torch.Tensor:
        """Convert one observation or a batch of observations to a float32 tensor on ``self.device``.

        Does not alter the input; a new tensor is created.

        Args:
            states (Any): a state or a list of states returned from ``self.env.step``.
        Returns:
            Tensor: the preprocessed states. A single observation (one whose
            shape equals ``self.observation_dimensions``) is passed through
            ``make_stack``. Presumably this adds a batch dimension of 1;
            confirm in utils.
        """
        # always convert to np.array first for performance, as recommended by PyTorch
        # special case: a list of compressed images (LazyFrames). Scalar
        # observations (e.g. from a Discrete space) cannot be indexed, so they
        # skip that check.
        if not np.isscalar(states) and isinstance(
            states[0], gym.wrappers.frame_stack.LazyFrames
        ):
            np_states = np.array([np.array(state) for state in states])
        else:
            # single observation, could be compressed or not compressed
            np_states = np.array(states)
        prepared_state = (
            torch.from_numpy(
                np_states,
            )
            .to(torch.float32)
            .to(self.device)
        )
        # if the state is a single number, add a dimension (not the batch dimension!, just wrapping it in []s basically)
        if prepared_state.shape == torch.Size([]):
            prepared_state = prepared_state.unsqueeze(0)
        if prepared_state.shape == self.observation_dimensions:
            prepared_state = make_stack(prepared_state)
        return prepared_state
    def predict(
        self, state: torch.Tensor, *args
    ) -> torch.Tensor:  # args is for info for player counts or legal move masks
        """Run inference on 1 or a batch of environment states, applying necessary preprocessing steps
        Returns:
            Tensor: The predicted values, e.g. Q values for DQN or Q distributions for Categorical DQN
        """
        raise NotImplementedError
    def select_actions(self, predicted, info, mask_actions=False) -> torch.Tensor:
        """Return actions chosen from the model output, applying postprocessing such as masking first.

        Args:
            predicted: output of ``predict``.
            info: the environment's info dict for the current state.
            mask_actions: whether to mask illegal actions. ``test`` passes
                ``config.game.has_legal_moves`` here.
        Raises:
            NotImplementedError: must be implemented by subclasses.
        Returns:
            Tensor: the selected action(s).
        """
        raise NotImplementedError
    def learn(self):
        """Update the model; the base implementation does nothing."""
        pass
    def load_optimizer_state(self, checkpoint):
        self.optimizer.load_state_dict(checkpoint["optimizer"])
    def load_replay_buffers(self, checkpoint):
        self.replay_buffer = checkpoint["replay_buffer"]
    def load_model_weights(self, checkpoint):
        self.model.load_state_dict(checkpoint["model"])
    def checkpoint_base(self, checkpoint):
        """Store the counters and model name that ``load_from_checkpoint`` reads back."""
        checkpoint["training_time"] = self.training_time
        checkpoint["training_step"] = self.training_step
        checkpoint["total_environment_steps"] = self.total_environment_steps
        checkpoint["model_name"] = self.model_name
        return checkpoint
    def checkpoint_environment(self, checkpoint):
        # the misspelled key is what load_from_checkpoint reads; keep it stable
        checkpoint["enviroment"] = self.env
        return checkpoint
    def checkpoint_optimizer_state(self, checkpoint):
        checkpoint["optimizer"] = self.optimizer.state_dict()
        return checkpoint
    def checkpoint_replay_buffers(self, checkpoint):
        checkpoint["replay_buffer"] = self.replay_buffer
        return checkpoint
    def checkpoint_model_weights(self, checkpoint):
        checkpoint["model"] = self.model.state_dict()
        return checkpoint
    def checkpoint_extra(self, checkpoint) -> dict:
        """Hook for subclasses to add extra entries to the checkpoint dict."""
        return checkpoint
    @classmethod
    def load(cls, *args, **kwargs):
        """Load an agent from a checkpoint by forwarding to ``load_from_checkpoint``.

        NOTE(review): this sets ``loaded_from_checkpoint`` on the class itself,
        so the flag is visible on every instance of ``cls``.
        """
        cls.loaded_from_checkpoint = True
        return cls.load_from_checkpoint(*args, **kwargs)
    @staticmethod
    def load_from_checkpoint(agent_class, config_class, dir: str, training_step):
        """Rebuild an agent from the files written by ``save_checkpoint``.

        Args:
            agent_class: BaseAgent subclass to instantiate.
            config_class: Config subclass whose ``load`` reads ``configs/config.yaml``.
            dir: the model's checkpoint directory (``save_checkpoint`` writes
                to ``checkpoints/<model_name>``).
            training_step: step whose ``step_<n>`` sub-directory is loaded.

        Returns:
            The agent with weights, optimizer state, replay buffer, counters,
            stats and targets restored.
        """
        training_step_dir = Path(dir, f"step_{training_step}")
        weights_path = str(Path(training_step_dir, "model_weights/weights.keras"))
        config = config_class.load(Path(dir, "configs/config.yaml"))
        # NOTE(review): torch.load unpickles arbitrary objects (env, replay
        # buffer); only load trusted checkpoints
        checkpoint = torch.load(weights_path)
        env = checkpoint["enviroment"]
        # checkpoints written before the model name was stored fall back to
        # the directory name, which save_checkpoint derives from model_name
        model_name = checkpoint.get("model_name", Path(dir).name)
        agent = agent_class(env, config, model_name, from_checkpoint=True)
        agent.training_time = checkpoint["training_time"]
        agent.training_step = checkpoint["training_step"]
        agent.total_environment_steps = checkpoint["total_environment_steps"]
        agent.load_model_weights(checkpoint)
        agent.load_optimizer_state(checkpoint)
        agent.load_replay_buffers(checkpoint)
        with open(Path(training_step_dir, "graphs_stats/stats.pkl"), "rb") as f:
            agent.stats = pickle.load(f)
        with open(Path(training_step_dir, "graphs_stats/targets.pkl"), "rb") as f:
            agent.targets = pickle.load(f)
        return agent
    def save_checkpoint(
        self,
        frames_seen=None,
        training_step=None,
        time_taken=None,
    ):
        """Save, evaluate and plot the agent under ``checkpoints/<model_name>``.

        Steps: optionally save the full checkpoint dict (when
        ``config.save_intermediate_weights`` is set), dump the config, run
        ``test`` and append its result to ``self.stats["test_score"]``, pickle
        stats/targets, and plot the graphs.

        The three arguments are deprecated overrides for the values plotted;
        update the corresponding attributes instead.
        """
        if frames_seen is not None:
            print(
                "warning: frames_seen option is deprecated, update self.total_environment_steps instead"
            )
        if time_taken is not None:
            print(
                "warning: time_taken option is deprecated, update self.training_time instead"
            )
        if training_step is not None:
            print(
                "warning: training_step option is deprecated, update self.training_step instead"
            )
        model_dir = Path("checkpoints", self.model_name)
        training_step_dir = Path(model_dir, f"step_{self.training_step}")
        os.makedirs(model_dir, exist_ok=True)
        # save the model state
        if self.config.save_intermediate_weights:
            weights_path = str(Path(training_step_dir, "model_weights/weights.keras"))
            os.makedirs(Path(training_step_dir, "model_weights"), exist_ok=True)
            checkpoint = self.make_checkpoint_dict()
            torch.save(checkpoint, weights_path)
        if getattr(self.env, "render_mode", None) == "rgb_array":
            os.makedirs(Path(training_step_dir, "videos"), exist_ok=True)
        # save config
        os.makedirs(Path(model_dir, "configs"), exist_ok=True)
        self.config.dump(f"{model_dir}/configs/config.yaml")
        # test model
        test_score = self.test(
            self.checkpoint_trials, self.training_step, training_step_dir
        )
        self.stats["test_score"].append(test_score)
        # save the graph stats and targets
        os.makedirs(Path(training_step_dir, "graphs_stats"), exist_ok=True)
        with open(Path(training_step_dir, "graphs_stats/stats.pkl"), "wb") as f:
            pickle.dump(self.stats, f)
        with open(Path(training_step_dir, "graphs_stats/targets.pkl"), "wb") as f:
            pickle.dump(self.targets, f)
        # to periodically clear uneeded memory, if it is drastically slowing down training you can comment this out, checkpoint less often, or do less trials
        gc.collect()
        # plot the graphs (and save the graph)
        print(self.stats)
        print(self.targets)
        os.makedirs(Path(model_dir, "graphs"), exist_ok=True)
        plot_graphs(
            self.stats,
            self.targets,
            self.training_step if training_step is None else training_step,
            self.total_environment_steps if frames_seen is None else frames_seen,
            self.training_time if time_taken is None else time_taken,
            self.model_name,
            f"{model_dir}/graphs",
        )
    def make_checkpoint_dict(self):
        """Assemble the dict saved by ``save_checkpoint`` from the ``checkpoint_*`` hooks."""
        checkpoint = self.checkpoint_base({})
        checkpoint = self.checkpoint_environment(checkpoint)
        checkpoint = self.checkpoint_optimizer_state(checkpoint)
        checkpoint = self.checkpoint_replay_buffers(checkpoint)
        checkpoint = self.checkpoint_model_weights(checkpoint)
        checkpoint = self.checkpoint_extra(checkpoint)
        return checkpoint
    def test(self, num_trials, step, dir="./checkpoints") -> dict | None:
        """Run ``num_trials`` evaluation episodes on ``self.test_env`` without gradients.

        When the test env renders ``rgb_array`` frames, episodes whose index
        ``x`` satisfies ``(x + 1) % num_trials == 0`` are recorded to
        ``{dir}/videos/{model_name}/{step}``.

        Returns:
            ``{"score": mean, "max_score": best, "min_score": worst}`` over the
            episodes, or None when ``num_trials`` is 0.
        """
        if num_trials == 0:
            return None
        with torch.no_grad():
            average_score = 0
            max_score = float("-inf")
            min_score = float("inf")
            if getattr(self.test_env, "render_mode", None) == "rgb_array":
                self.test_env.episode_trigger = lambda x: (x + 1) % num_trials == 0
                self.test_env.video_folder = "{}/videos/{}/{}".format(
                    dir, self.model_name, step
                )
                if not os.path.exists(self.test_env.video_folder):
                    os.makedirs(self.test_env.video_folder)
            for _ in range(num_trials):
                state, info = self.test_env.reset()
                done = False
                score = 0
                while not done:
                    prediction = self.predict(
                        state, info, env=self.test_env
                    )  # env = self.test_env is there for alpha_zero which needs to use the test env here instead of the normal env for the tree search (might be able to just use the regular env still)
                    action = self.select_actions(
                        prediction, info, self.config.game.has_legal_moves
                    ).item()
                    next_state, reward, terminated, truncated, info = self.test_env.step(
                        action
                    )
                    done = terminated or truncated
                    state = next_state
                    # multi-agent envs may return a list of rewards; count the first
                    score += reward[0] if isinstance(reward, list) else reward
                average_score += score
                max_score = max(max_score, score)
                min_score = min(min_score, score)
                print("score: ", score)
            average_score /= num_trials
            return {
                "score": average_score,
                "max_score": max_score,
                "min_score": min_score,
            }
    def print_training_progress(self):
        print(f"Training step: {self.training_step + 1}/{self.training_steps}")
    def print_resume_training(self):
        print(
            f"Resuming training at step {self.training_step + 1} / {self.training_steps}"
        )
    def print_stats(self):
        print("")
def unpack(x: int | Tuple):
    if isinstance(x, Tuple):
        assert len(x) == 2
        return x
    else:
        try:
            x = int(x)
            return x, x
        except Exception as e:
            print(f"error converting {x} to int: ", e)
class Conv2dStack(nn.Module):
    @staticmethod
    def calculate_same_padding(i, k, s) -> Tuple[None | Tuple[int], None | str | Tuple]:
        """Calculate pytorch inputs for same padding
        Args:
            i (int, int) or int: (h, w) or (w, w)
            k (int, int) or int: (k_h, k_w) or (k, k)
            s (int, int) or int: (s_h, s_w) or (s, s)
        Returns:
            Tuple[manual_pad_padding, torch_conv2d_padding_input]: Either the manual padding that must be applied (first element of tuple) or the input to the torch padding argument of the Conv2d layer
        """
        if s == 1:
            return None, "same"
        h, w = unpack(i)
        k_h, k_w = unpack(k)
        s_h, s_w = unpack(s)
        p_h = calculate_padding(h, k_h, s_h)
        p_w = calculate_padding(w, k_w, s_w)
        if p_h[0] == p_h[1] and p_w[0] == p_w[1]:
            return None, (p_h[0], p_w[0])
        else:
            # not torch compatiable, manually pad with torch.nn.functional.pad
            return (*p_w, *p_h), None
    def __init__(
        self,
        input_shape: tuple[int],
        filters: list[int],
        kernel_sizes: list[int | Tuple[int, int]],
        strides: list[int | Tuple[int, int]],
        activation: nn.Module = nn.ReLU(),
        noisy_sigma: float = 0,
    ):
        """A sequence of convolution layers with the activation function applied after each layer.
        Always applies the minimum zero-padding that ensures the output shape is equal to the input shape.
        Input shape in "BCHW" form, i.e. (batch_size, input_channels, height, width)
        """
        super(Conv2dStack, self).__init__()
        self.conv_layers = nn.ModuleList()
        self.activation = activation
        # [B, C_in, H, W]
        assert len(input_shape) == 4
        assert len(filters) == len(kernel_sizes) == len(strides)
        assert len(filters) > 0
        self.noisy = noisy_sigma != 0
        if self.noisy:
            print("warning: Noisy convolutions not implemented yet")
            # raise NotImplementedError("")
        current_input_channels = input_shape[1]
        for i in range(len(filters)):
            h, w = input_shape[2], input_shape[3]
            manual_padding, torch_padding = self.calculate_same_padding(
                (h, w), kernel_sizes[i], strides[i]
            )
            if not torch_padding is None:
                layer = nn.Conv2d(
                    in_channels=current_input_channels,
                    out_channels=filters[i],
                    kernel_size=kernel_sizes[i],
                    stride=strides[i],
                    padding=torch_padding,
                )
            else:
                layer = nn.Sequential(
                    nn.ZeroPad2d(manual_padding),
                    nn.Conv2d(
                        in_channels=current_input_channels,
                        out_channels=filters[i],
                        kernel_size=kernel_sizes[i],
                        stride=strides[i],
                    ),
                )
            self.conv_layers.append(layer)
            current_input_channels = filters[i]
        self._output_len = current_input_channels
    def initialize(self, initializer: Callable[[Tensor], None]) -> None:
        def initialize_if_conv(m: nn.Module):
            if isinstance(m, nn.Conv2d):
                initializer(m.weight)
        self.apply(initialize_if_conv)
    def forward(self, inputs):
        x = inputs
        for layer in self.conv_layers:
            x = self.activation(layer(x))
        return x
    def reset_noise(self):
        assert self.noisy
        # noisy not implemented
        # for layer in self.conv_layers:
        #     # layer.reset_noise()
        # return
    def remove_noise(self):
        assert self.noisy
        # noisy not implemented
        # for layer in self.conv_layers:
        #     # layer.reset_noise()
        # return
    @property
    def output_channels(self):
        return self._output_len
from torch import nn, Tensor, functional
class Dense(nn.Module):
    """Plain fully connected layer: a thin wrapper over ``nn.Linear`` that adds
    an ``initialize`` hook matching the other layer builders in this module."""

    def __init__(
        self, in_features: int, out_features: int, bias: bool = True, *args, **kwargs
    ):
        super().__init__(*args, **kwargs)
        self.layer = nn.Linear(in_features, out_features, bias=bias)

    def initialize(self, initializer: Callable[[Tensor], None]) -> None:
        """Run ``initializer`` on the weight matrix (the bias is left untouched)."""
        weight = self.layer.weight
        initializer(weight)

    def forward(self, inputs: Tensor) -> Tensor:
        linear = self.layer
        return linear(inputs)

    def extra_repr(self) -> str:
        # Reuse nn.Linear's description so printing the module stays informative.
        inner = self.layer
        return inner.extra_repr()
class NoisyDense(nn.Module):
    """Linear layer with learnable parametric noise (NoisyNet).

    See https://arxiv.org/pdf/1706.10295. The effective weight is
    ``mu_w + sigma_w * eps_w`` (likewise for the bias). ``eps_w`` / ``eps_b``
    are non-trainable buffers: :meth:`reset_noise` resamples them and
    :meth:`remove_noise` zeroes them.
    """

    @staticmethod
    def f(x: Tensor):
        # Noise-shaping function for factorised noise: sgn(x) * sqrt(|x|).
        return x.sgn() * (x.abs().sqrt())

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        initial_sigma: float = 0.5,
        use_factorized: bool = True,
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.initial_sigma = initial_sigma
        self.use_factorized = use_factorized
        self.use_bias = bias
        self.mu_w = nn.Parameter(torch.empty(out_features, in_features))
        self.sigma_w = nn.Parameter(torch.empty(out_features, in_features))
        # register_buffer returns None; its result must not be assigned back
        # (doing so overwrote the freshly registered buffer with None).
        self.register_buffer("eps_w", torch.empty(out_features, in_features))
        if self.use_bias:
            self.mu_b = nn.Parameter(torch.empty(out_features))
            self.sigma_b = nn.Parameter(torch.empty(out_features))
            self.register_buffer("eps_b", torch.empty(out_features))
        else:
            self.register_parameter("mu_b", None)
            self.register_parameter("sigma_b", None)
            self.register_buffer("eps_b", None)
        self.reset_parameters()
        self.reset_noise()

    def _randn(self, *shape: int) -> Tensor:
        """Standard-normal sample created directly on mu_w's device and dtype."""
        return torch.randn(*shape, device=self.mu_w.device, dtype=self.mu_w.dtype)

    def reset_noise(self) -> None:
        """Resample the noise buffers."""
        if self.use_factorized:
            # Factorised Gaussian noise: eps_w = f(eps_out) f(eps_in)^T.
            eps_in = self.f(self._randn(1, self.in_features))
            eps_out = self.f(self._randn(self.out_features, 1))
            self.eps_w = eps_out @ eps_in
            # Also set when there is no bias, matching the original buffer
            # layout so existing checkpoints still load.
            self.eps_b = eps_out.reshape(self.out_features)
        else:
            # Independent Gaussian noise: every entry is an unshaped unit
            # Gaussian sample (f is only used for factorised noise).
            self.eps_w = self._randn(self.out_features, self.in_features)
            if self.use_bias:
                self.eps_b = self._randn(self.out_features)

    def remove_noise(self) -> None:
        """Zero the noise so the layer behaves like a deterministic linear layer."""
        self.eps_w = torch.zeros_like(self.mu_w)
        if self.use_bias:
            self.eps_b = torch.zeros_like(self.mu_b)

    def reset_parameters(self) -> None:
        """Initialize mu/sigma following the NoisyNet paper's schemes."""
        p = self.in_features
        if self.use_factorized:
            mu_init = 1.0 / (p**0.5)
            sigma_init = self.initial_sigma / (p**0.5)
        else:
            mu_init = (3.0 / p) ** 0.5
            sigma_init = 0.017
        nn.init.constant_(self.sigma_w, sigma_init)
        nn.init.uniform_(self.mu_w, -mu_init, mu_init)
        if self.use_bias:
            nn.init.constant_(self.sigma_b, sigma_init)
            nn.init.uniform_(self.mu_b, -mu_init, mu_init)

    @property
    def weight(self):
        return self.mu_w + self.sigma_w * self.eps_w

    @property
    def bias(self):
        if self.use_bias:
            return self.mu_b + self.sigma_b * self.eps_b
        return None

    def initialize(self, initializer: Callable[[Tensor], None]) -> None:
        # Intentionally a no-op: parameters are set by reset_parameters().
        pass

    def forward(self, input: Tensor) -> Tensor:
        return nn.functional.linear(input, self.weight, self.bias)

    def extra_repr(self) -> str:
        return f"in_features={self.in_features}, out_features={self.out_features}, bias={self.use_bias}, initial_sigma={self.initial_sigma}, use_factorized={self.use_factorized}"
def build_dense(in_features: int, out_features: int, sigma: float = 0):
    """Build a fully connected layer, noisy when ``sigma`` is non-zero.

    Args:
        in_features: input width.
        out_features: output width.
        sigma: ``0`` selects a plain ``Dense`` layer. Any other value selects a
            ``NoisyDense`` layer and is forwarded as its ``initial_sigma``.
            Previously the value was dropped, so every noisy layer silently
            used NoisyDense's default of 0.5.
    """
    if sigma == 0:
        return Dense(in_features, out_features)
    return NoisyDense(in_features, out_features, initial_sigma=sigma)
class DenseStack(nn.Module):
    """Chain of fully connected layers (noisy when ``noisy_sigma != 0``), each
    followed by ``activation``."""

    def __init__(
        self,
        initial_width: int,
        widths: list[int],
        activation: nn.Module = nn.ReLU(),
        noisy_sigma: float = 0,
    ):
        super(DenseStack, self).__init__()
        self.dense_layers: nn.ModuleList = nn.ModuleList()
        self.activation = activation
        assert len(widths) > 0
        self.noisy = noisy_sigma != 0
        in_width = initial_width
        for out_width in widths:
            self.dense_layers.append(
                build_dense(in_features=in_width, out_features=out_width, sigma=noisy_sigma)
            )
            in_width = out_width
        self.initial_width = initial_width
        # Width produced by the final layer.
        self._output_len = in_width

    def initialize(self, initializer: Callable[[Tensor], None]) -> None:
        """Delegate ``initializer`` to every layer's own ``initialize`` hook."""
        for dense in self.dense_layers:
            dense.initialize(initializer)

    def forward(self, inputs: Tensor) -> Tensor:
        hidden = inputs
        for dense in self.dense_layers:
            hidden = self.activation(dense(hidden))
        return hidden

    def reset_noise(self) -> None:
        """Resample noise in every layer (only valid for noisy stacks)."""
        assert self.noisy
        for dense in self.dense_layers:
            dense.reset_noise()

    def remove_noise(self) -> None:
        """Zero noise in every layer (only valid for noisy stacks)."""
        assert self.noisy
        for dense in self.dense_layers:
            dense.remove_noise()

    def extra_repr(self) -> str:
        return f"in_features={self.initial_width}, out_width={self.output_width}, noisy={self.noisy}"

    @property
    def output_width(self):
        """Number of features emitted by the last layer."""
        return self._output_len
class ResidualStack(nn.Module):
    """A stack of ``Residual`` blocks, each followed by ``activation``."""

    def __init__(
        self,
        input_shape: tuple[int],
        filters: list[int],
        kernel_sizes: list[int | Tuple[int, int]],
        strides: list[int | Tuple[int, int]],
        activation: nn.Module = nn.ReLU(),
        noisy_sigma: float = 0,
    ):
        """A sequence of residual layers with the activation function applied after each layer.
        Every convolution inside the blocks uses padding="same", so height and width are
        preserved (torch only accepts that padding mode for stride 1).
        Input shape in "BCHW" form, i.e. (batch_size, input_channels, height, width)
        """
        super(ResidualStack, self).__init__()
        self.residual_layers = nn.ModuleList()
        self.activation = activation
        # [B, C_in, H, W]
        assert len(input_shape) == 4, "input_shape must be (B, C, H, W), got {}".format(
            input_shape
        )
        assert (
            len(filters) == len(kernel_sizes) == len(strides)
        ), "filters, kernel_sizes and strides must have the same length"
        assert len(filters) > 0, "at least one residual layer is required"
        self.noisy = noisy_sigma != 0
        if self.noisy:
            print("warning: Noisy convolutions not implemented yet")
        current_input_channels = input_shape[1]
        # (Removed a stray debug print of the channel count from this loop.)
        for out_channels, kernel_size, stride in zip(filters, kernel_sizes, strides):
            layer = Residual(
                in_channels=current_input_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=stride,
            )
            self.residual_layers.append(layer)
            current_input_channels = out_channels
        self._output_len = current_input_channels

    def initialize(self, initializer: Callable[[Tensor], None]) -> None:
        """Apply ``initializer`` to the weight of every Conv2d in the stack."""

        def initialize_if_conv(m: nn.Module):
            if isinstance(m, nn.Conv2d):
                initializer(m.weight)

        self.apply(initialize_if_conv)

    def forward(self, inputs):
        x = inputs
        for layer in self.residual_layers:
            x = self.activation(layer(x))
        return x

    def reset_noise(self):
        # Noisy convolutions are not implemented; only checks the flag.
        assert self.noisy

    def remove_noise(self):
        # Noisy convolutions are not implemented; only checks the flag.
        assert self.noisy

    @property
    def output_channels(self):
        """Number of channels produced by the last residual block."""
        return self._output_len
class Residual(nn.Module):
    """Residual block: conv -> BN -> ReLU -> conv -> BN, added to a shortcut,
    then ReLU. The shortcut is a conv + BN projection when the channel count
    changes, otherwise the identity."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
    ):
        super(Residual, self).__init__()
        shared = dict(kernel_size=kernel_size, stride=stride, padding="same")
        # Module creation order is kept fixed so parameter init / state_dict
        # ordering is stable.
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, **shared)
        # REGULARIZATION?
        self.bn1 = nn.BatchNorm2d(num_features=out_channels)
        self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels, **shared)
        # REGULARIZATION?
        self.bn2 = nn.BatchNorm2d(num_features=out_channels)
        self.relu = nn.ReLU()
        self.downsample = None
        if in_channels == out_channels:
            return
        self.downsample = nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                padding="same",
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
        )

    def initialize(self, initializer: Callable[[Tensor], None]) -> None:
        """Apply ``initializer`` to every Conv2d weight in the block."""

        def _init_conv_weight(module: nn.Module):
            if isinstance(module, nn.Conv2d):
                initializer(module.weight)

        self.apply(_init_conv_weight)

    def forward(self, inputs):
        shortcut = inputs if self.downsample is None else self.downsample(inputs)
        out = self.relu(self.bn1(self.conv1(inputs)))
        out = self.bn2(self.conv2(out))
        return self.relu(out + shortcut)
def kernel_initializer_wrapper(x):
    """Normalize a kernel-initializer spec.

    ``None`` passes through, a string is resolved via
    ``prepare_kernel_initializers``, and anything else must already be callable.
    """
    if x is None:
        return None
    if isinstance(x, str):
        return prepare_kernel_initializers(x)
    assert callable(x)
    return x
class RainbowConfig(Config):
    """Hyperparameters for the Rainbow DQN agent.

    Every field is read through ``Config.parse_field`` with the default shown.
    ``v_min``/``v_max`` (the distributional support bounds) come from the
    game config.
    """

    def __init__(self, config_dict: dict, game_config):
        super(RainbowConfig, self).__init__(config_dict, game_config)
        print("RainbowConfig")
        self.residual_layers: list = self.parse_field("residual_layers", [])
        self.conv_layers: list = self.parse_field("conv_layers", [])
        self.dense_layer_widths: list[int] = self.parse_field(
            "dense_layer_widths", [128], tointlists
        )
        self.value_hidden_layer_widths: list[int] = self.parse_field(
            "value_hidden_layer_widths", [], tointlists
        )
        self.advantage_hidden_layer_widths: list[int] = self.parse_field(
            "advantage_hidden_layer_widths", [], tointlists
        )
        self.noisy_sigma: float = self.parse_field("noisy_sigma", 0.5)
        self.eg_epsilon: float = self.parse_field("eg_epsilon", 0.00)
        self.eg_epsilon_final: float = self.parse_field("eg_epsilon_final", 0.00)
        self.eg_epsilon_decay_type: str = self.parse_field(
            "eg_epsilon_decay_type", "linear"
        )
        self.eg_epsilon_final_step: int = self.parse_field(
            "eg_epsilon_final_step", self.training_steps
        )
        self.dueling: bool = self.parse_field("dueling", True)
        self.discount_factor: float = self.parse_field("discount_factor", 0.99)
        self.soft_update: bool = self.parse_field("soft_update", False)
        self.transfer_interval: int = self.parse_field(
            "transfer_interval", 512, wrapper=int
        )
        self.ema_beta: float = self.parse_field("ema_beta", 0.99)
        self.replay_interval: int = self.parse_field("replay_interval", 1, wrapper=int)
        self.per_alpha: float = self.parse_field("per_alpha", 0.6)
        self.per_beta: float = self.parse_field("per_beta", 0.5)
        self.per_beta_final: float = self.parse_field("per_beta_final", 1.0)
        self.per_epsilon: float = self.parse_field("per_epsilon", 1e-6)
        self.n_step: int = self.parse_field("n_step", 3)
        self.atom_size: int = self.parse_field("atom_size", 51, wrapper=int)
        # assert (
        #     self.atom_size > 1
        # ), "Atom size must be greater than 1, as softmax and Q distribution to Q value calculation requires more than 1 atom"
        # assert not (
        #     self.game.is_image
        #     and len(self.conv_layers) == 0
        #     and len(self.residual_layers) == 0
        # ), "Convolutional layers must be defined for image based games"
        # Each conv layer spec is (filters, kernel_size, stride); validate all
        # of them, not just the first.
        for layer_spec in self.conv_layers:
            assert (
                len(layer_spec) == 3
            ), "conv layer spec must be (filters, kernel_size, stride), got {}".format(
                layer_spec
            )
        # maybe don't use a game config, since if tuning for multiple games this should be the same regardless of the game <- (it is really a hyper parameter if you are tuning for multiple games or a game with unknown bounds)
        # could use a MuZero min-max config and just constantly update the suport size (would this break the model?) <- might mean this is not in the config but just a part of the model
        self.v_min = game_config.min_score
        self.v_max = game_config.max_score
        if self.atom_size != 1:
            assert (
                self.v_min is not None and self.v_max is not None
            ), "distributional Rainbow (atom_size != 1) needs game min_score and max_score"

    def _verify_game(self):
        assert self.game.is_discrete, "Rainbow only supports discrete action spaces"
class RainbowNetwork(nn.Module):
    """Rainbow DQN network.

    The trunk is optional residual -> conv -> dense layers. It feeds either a
    dueling head (value + advantage) or a plain distribution head.
    ``forward`` returns ``(B, output_size, atom_size)`` softmax probabilities
    over atoms, or ``(B, output_size)`` values when ``atom_size == 1``.
    """

    def __init__(
        self,
        config: RainbowConfig,
        output_size: int,
        input_shape: Tuple[int],
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.config = config
        self.has_residual_layers = len(config.residual_layers) > 0
        self.has_conv_layers = len(config.conv_layers) > 0
        self.has_dense_layers = len(config.dense_layer_widths) > 0
        assert (
            self.has_conv_layers or self.has_dense_layers or self.has_residual_layers
        ), "At least one of the layers should be present."
        self.has_value_hidden_layers = len(config.value_hidden_layer_widths) > 0
        self.has_advantage_hidden_layers = len(config.advantage_hidden_layer_widths) > 0
        if not self.config.dueling:
            assert not (
                self.has_value_hidden_layers or self.has_advantage_hidden_layers
            ), "Value or Advantage hidden layers are only used in dueling networks"
        self.output_size = output_size
        current_shape = input_shape
        B = current_shape[0]
        if self.has_residual_layers:
            assert (
                len(input_shape) == 4
            ), "Input shape should be (B, C, H, W), got {}".format(input_shape)
            filters, kernel_sizes, strides = to_lists(config.residual_layers)
            # (B, C_in, H, W) -> (B, C_out H, W)
            self.residual_layers = ResidualStack(
                input_shape=input_shape,
                filters=filters,
                kernel_sizes=kernel_sizes,
                strides=strides,
                activation=self.config.activation,
                noisy_sigma=config.noisy_sigma,
            )
            current_shape = (
                B,
                self.residual_layers.output_channels,
                current_shape[2],
                current_shape[3],
            )
        if self.has_conv_layers:
            assert (
                len(current_shape) == 4
            ), "Input shape should be (B, C, H, W), got {}".format(input_shape)
            filters, kernel_sizes, strides = to_lists(config.conv_layers)
            # (B, C_in, H, W) -> (B, C_out H, W)
            # Built from current_shape, not input_shape: when residual layers
            # precede it, the conv stack receives their output channels.
            self.conv_layers = Conv2dStack(
                input_shape=current_shape,
                filters=filters,
                kernel_sizes=kernel_sizes,
                strides=strides,
                activation=self.config.activation,
                noisy_sigma=config.noisy_sigma,
            )
            current_shape = (
                B,
                self.conv_layers.output_channels,
                current_shape[2],
                current_shape[3],
            )
        if self.has_dense_layers:
            # (B, width_in) -> (B, width_out)
            self.dense_layers = DenseStack(
                initial_width=self._flat_width(current_shape),
                widths=self.config.dense_layer_widths,
                activation=self.config.activation,
                noisy_sigma=self.config.noisy_sigma,
            )
            current_shape = (
                B,
                self.dense_layers.output_width,
            )
        initial_width = self._flat_width(current_shape)
        if self.config.dueling:
            if self.has_value_hidden_layers:
                # (B, width_in) -> (B, value_in_features) -> (B, atom_size)
                self.value_hidden_layers = DenseStack(
                    initial_width=initial_width,
                    widths=self.config.value_hidden_layer_widths,
                    activation=self.config.activation,
                    noisy_sigma=self.config.noisy_sigma,
                )
                value_in_features = self.value_hidden_layers.output_width
            else:
                value_in_features = initial_width
            # (B, value_in_features) -> (B, atom_size)
            self.value_layer = build_dense(
                in_features=value_in_features,
                out_features=config.atom_size,
                sigma=config.noisy_sigma,
            )
            if self.has_advantage_hidden_layers:
                # (B, width_in) -> (B, advantage_in_features)
                self.advantage_hidden_layers = DenseStack(
                    initial_width=initial_width,
                    widths=self.config.advantage_hidden_layer_widths,
                    activation=self.config.activation,
                    noisy_sigma=self.config.noisy_sigma,
                )
                advantage_in_features = self.advantage_hidden_layers.output_width
            else:
                advantage_in_features = initial_width
            # (B, advantage_in_features) -> (B, output_size * atom_size)
            self.advantage_layer = build_dense(
                in_features=advantage_in_features,
                out_features=output_size * config.atom_size,
                sigma=self.config.noisy_sigma,
            )
        else:
            self.distribution_layer = build_dense(
                in_features=initial_width,
                out_features=self.output_size * self.config.atom_size,
                sigma=self.config.noisy_sigma,
            )

    @staticmethod
    def _flat_width(shape) -> int:
        """Per-sample feature count once ``shape`` is flattened past the batch dim."""
        if len(shape) == 4:
            return shape[1] * shape[2] * shape[3]
        assert len(shape) == 2, "Input shape should be (B, width), got {}".format(shape)
        return shape[1]

    def _components(self):
        """Yield every sub-network in construction order, including the
        output head (value/advantage layers, or the distribution layer)."""
        if self.has_residual_layers:
            yield self.residual_layers
        if self.has_conv_layers:
            yield self.conv_layers
        if self.has_dense_layers:
            yield self.dense_layers
        if self.has_value_hidden_layers:
            yield self.value_hidden_layers
        if self.has_advantage_hidden_layers:
            yield self.advantage_hidden_layers
        if self.config.dueling:
            yield self.value_layer
            yield self.advantage_layer
        else:
            # Previously skipped, leaving the non-dueling head uninitialized
            # and with frozen noise.
            yield self.distribution_layer

    def initialize(self, initializer: Callable[[Tensor], None]) -> None:
        """Apply ``initializer`` to every component, output head included."""
        for component in self._components():
            component.initialize(initializer)

    def forward(self, inputs: Tensor) -> Tensor:
        if self.has_conv_layers or self.has_residual_layers:
            assert inputs.dim() == 4, "expected (B, C, H, W) input, got {}".format(
                tuple(inputs.shape)
            )
        # (B, *)
        S = inputs
        # (B, C_in, H, W) -> (B, C_out, H, W)
        if self.has_residual_layers:
            S = self.residual_layers(S)
        # (B, C_in, H, W) -> (B, C_out, H, W)
        if self.has_conv_layers:
            S = self.conv_layers(S)
        # (B, *) -> (B, dense_features_in)
        S = S.flatten(1, -1)
        # (B, dense_features_in) -> (B, dense_features_out)
        if self.has_dense_layers:
            S = self.dense_layers(S)
        if self.config.dueling:
            # (B, value_hidden_in) -> (B, value_hidden_out)
            if self.has_value_hidden_layers:
                v = self.value_hidden_layers(S)
            else:
                v = S
            # (B, value_hidden_in || dense_features_out) -> (B, atom_size) -> (B, 1, atom_size)
            v: Tensor = self.value_layer(v).view(-1, 1, self.config.atom_size)
            # (B, adv_hidden_in) -> (B, adv_hidden_out)
            if self.has_advantage_hidden_layers:
                A = self.advantage_hidden_layers(S)
            else:
                A = S
            # (B, adv_hidden_out || dense_features_out) -> (B, output_size * atom_size) -> (B, output_size, atom_size)
            A: Tensor = self.advantage_layer(A).view(
                -1, self.output_size, self.config.atom_size
            )
            # (B, output_size, atom_size) -[mean(1)]-> (B, 1, atom_size)
            a_mean = A.mean(1, keepdim=True)
            # (B, 1, atom_size) + (B, output_size, atom_size) - (B, 1, atom_size)
            # is a valid broadcasting operation.
            Q = v + A - a_mean
        else:
            # (B, dense_features_out) -> (B, output_size, atom_size)
            Q = self.distribution_layer(S).view(
                -1, self.output_size, self.config.atom_size
            )
        if self.config.atom_size == 1:
            return Q.squeeze(-1)
        # softmax over the atom dimension turns logits into a valid p.m.f.
        return Q.softmax(dim=-1)

    def reset_noise(self):
        """Resample noise in every component (no-op when noisy_sigma == 0)."""
        if self.config.noisy_sigma != 0:
            for component in self._components():
                component.reset_noise()

    def remove_noise(self):
        """Zero noise in every component (no-op when noisy_sigma == 0)."""
        if self.config.noisy_sigma != 0:
            for component in self._components():
                component.remove_noise()
from time import time
import numpy as np
import operator
from typing import Callable
class SegmentTree:
    """Create SegmentTree.
    Taken from OpenAI baselines github repository:
    https://github.com/openai/baselines/blob/master/baselines/common/segment_tree.py

    Storage is 1-indexed in ``self.tree``: node ``n`` has children ``2n`` and
    ``2n + 1``, and leaf ``i`` lives at ``capacity + i``.

    Attributes:
        capacity (int): number of leaves (a positive power of two)
        tree (list): flat node storage of length ``2 * capacity``
        operation (function): binary operation combining two children
    """

    def __init__(self, capacity: int, operation: Callable, init_value: float):
        """Initialization.
        Args:
            capacity (int): number of leaves; must be a positive power of two
            operation (function): binary operation used to combine nodes
            init_value (float): starting value of every node (typically the
                identity element of ``operation``)
        """
        assert (
            capacity > 0 and capacity & (capacity - 1) == 0
        ), "capacity must be positive and a power of 2."
        self.capacity = capacity
        self.tree = [init_value for _ in range(2 * capacity)]
        self.operation = operation

    def _operate_helper(
        self, start: int, end: int, node: int, node_start: int, node_end: int
    ) -> float:
        """Returns result of operation over the inclusive leaf range [start, end]."""
        if start == node_start and end == node_end:
            return self.tree[node]
        mid = (node_start + node_end) // 2
        if end <= mid:
            return self._operate_helper(start, end, 2 * node, node_start, mid)
        else:
            if mid + 1 <= start:
                return self._operate_helper(start, end, 2 * node + 1, mid + 1, node_end)
            else:
                return self.operation(
                    self._operate_helper(start, mid, 2 * node, node_start, mid),
                    self._operate_helper(mid + 1, end, 2 * node + 1, mid + 1, node_end),
                )

    def operate(self, start: int = 0, end: int = 0) -> float:
        """Returns result of applying `self.operation` over leaves start .. end-1.

        ``end <= 0`` is taken relative to ``capacity`` (so the default 0 means
        "through the last leaf").
        """
        if end <= 0:
            end += self.capacity
        end -= 1
        return self._operate_helper(start, end, 1, 0, self.capacity - 1)

    def __setitem__(self, idx: int, val: float):
        """Set leaf ``idx`` to ``val`` and recompute all of its ancestors."""
        # Without this check a negative index lands on an internal node
        # (capacity + idx < capacity) and silently corrupts the aggregates.
        assert 0 <= idx < self.capacity, "index {} out of range [0, {})".format(
            idx, self.capacity
        )
        idx += self.capacity
        self.tree[idx] = val
        idx //= 2
        while idx >= 1:
            self.tree[idx] = self.operation(self.tree[2 * idx], self.tree[2 * idx + 1])
            idx //= 2

    def __getitem__(self, idx: int) -> float:
        """Get real value in leaf node of tree."""
        assert 0 <= idx < self.capacity
        return self.tree[self.capacity + idx]
class SumSegmentTree(SegmentTree):
    """Create SumSegmentTree.
    Taken from OpenAI baselines github repository:
    https://github.com/openai/baselines/blob/master/baselines/common/segment_tree.py
    """

    def __init__(self, capacity: int):
        """Initialization.
        Args:
            capacity (int): number of leaves (positive power of two)
        """
        super(SumSegmentTree, self).__init__(
            capacity=capacity, operation=operator.add, init_value=0.0
        )

    def sum(self, start: int = 0, end: int = 0) -> float:
        """Returns arr[start] + ... + arr[end - 1] (``end <= 0`` counts from capacity)."""
        return super(SumSegmentTree, self).operate(start, end)

    def retrieve(self, upperbound: float) -> int:
        """Find the leaf ``i`` whose prefix-sum interval contains ``upperbound``.

        Normally returns ``i`` with sum(arr[:i]) <= upperbound < sum(arr[:i + 1]).
        When ``upperbound`` reaches the total (allowed by the assert) or float
        drift pushes it past a subtree's sum, the descent never enters a
        subtree whose sum is zero. So an empty (zero-valued) leaf is not
        returned while a positive one exists.
        """
        assert 0 <= upperbound <= self.sum() + 1e-5, "upperbound: {}".format(upperbound)
        idx = 1
        while idx < self.capacity:  # while non-leaf
            left = 2 * idx
            right = left + 1
            if self.tree[left] > upperbound or self.tree[right] <= 0:
                idx = left
            else:
                upperbound -= self.tree[left]
                idx = right
        return idx - self.capacity
class MinSegmentTree(SegmentTree):
    """Segment tree answering range-minimum queries.

    Adapted from the OpenAI baselines github repository:
    https://github.com/openai/baselines/blob/master/baselines/common/segment_tree.py
    Unset leaves hold ``float("inf")``.
    """

    def __init__(self, capacity: int):
        """Build an empty tree with ``capacity`` leaves (positive power of two)."""
        identity = float("inf")
        super(MinSegmentTree, self).__init__(
            capacity=capacity, init_value=identity, operation=min
        )

    def min(self, start: int = 0, end: int = 0) -> float:
        """Return min(arr[start], ..., arr[end - 1]); ``end <= 0`` counts from capacity."""
        parent = super(MinSegmentTree, self)
        return parent.operate(start, end)
class FastSumTree(object):
    """Array-backed sum tree for proportional prioritized sampling.

    Based on
    https://medium.com/free-code-camp/improvements-in-deep-q-learning-dueling-double-dqn-prioritized-experience-replay-and-fixed-58b130cc5682

    Layout: ``tree`` has ``2 * capacity - 1`` entries. Index 0 is the root,
    and node ``i`` has children ``2i + 1`` and ``2i + 2``. The ``capacity``
    leaves occupy ``tree[capacity - 1:]``, so leaf ``idx`` is
    ``tree[idx + capacity - 1]``.
    """

    def __init__(self, capacity: int):
        # Number of leaf nodes (one per stored experience).
        self.capacity = capacity
        self.tree = np.zeros(2 * self.capacity - 1)

    def add(self, idx: int, val: float):
        """Set leaf ``idx`` to ``val`` and propagate the change to the root."""
        tree_index = idx + self.capacity - 1
        self.update(tree_index, val)

    def __getitem__(self, idx: int) -> float:
        """Return the value stored at leaf ``idx``."""
        assert 0 <= idx < self.capacity
        # Same leaf offset as ``add``: leaves start at ``capacity - 1``.
        return self.tree[idx + self.capacity - 1]

    def update(self, tree_index: int, val: float):
        """Set ``tree[tree_index] = val`` and add the delta to every ancestor.

        ``tree_index`` is a raw array index (leaf ``i`` is ``i + capacity - 1``).
        """
        change = val - self.tree[tree_index]
        self.tree[tree_index] = val
        while tree_index != 0:
            tree_index = (tree_index - 1) // 2
            self.tree[tree_index] += change

    def retrieve(self, v: float):
        """Find the leaf whose cumulative-priority interval contains ``v``.

        Goes left when ``v`` fits in the left subtree, otherwise subtracts the
        left sum and goes right. It also avoids descending into a child
        subtree with zero mass while its sibling has mass. This covers ``v``
        at or slightly above ``total_priority`` (float drift) and ``v == 0``
        next to an empty left subtree. In those cases the walk does not end
        on an empty slot. In every other case the chosen leaf is unchanged.

        Returns:
            ``(tree_index, priority)``, where ``tree_index`` is the raw array
            index of the leaf (subtract ``capacity - 1`` for the leaf index).
        """
        parent_index = 0
        tree_size = len(self.tree)
        while True:
            left_child_index = 2 * parent_index + 1
            if left_child_index >= tree_size:
                # ``parent_index`` has no children: it is a leaf.
                return parent_index, self.tree[parent_index]
            right_child_index = left_child_index + 1
            left_sum = self.tree[left_child_index]
            right_sum = self.tree[right_child_index]
            if (v <= left_sum and left_sum > 0) or right_sum <= 0:
                parent_index = left_child_index
            else:
                v -= left_sum
                parent_index = right_child_index

    @property
    def total_priority(self):
        """Sum of all leaf priorities (the root value)."""
        return self.tree[0]
class BaseReplayBuffer:
    """Abstract interface shared by the replay buffers in this module.

    Concrete buffers override ``clear``, which is called from ``__init__``
    and must at least set ``self.size`` to 0. They also override ``store``
    and ``sample``, and, where supported, ``sample_from_indices`` and
    ``load``.
    """

    def __init__(
        self,
        max_size: int,
        batch_size: int = None,
        compressed_observations: bool = False,
    ):
        """Record sizing options and initialise storage through ``clear``.

        Args:
            max_size: capacity of the buffer; must be positive.
            batch_size: items per ``sample`` call; ``None`` falls back to
                ``max_size``.
            compressed_observations: flag kept for subclasses to consult.
        """
        self.max_size = max_size
        self.batch_size = max_size if batch_size is None else batch_size
        self.compressed_observations = compressed_observations
        # Subclass hook: allocate storage and reset ``self.size``.
        self.clear()
        assert self.size == 0, "Replay buffer should be empty at initialization"
        assert self.max_size > 0, "Replay buffer should have a maximum size"
        assert self.batch_size > 0, "Replay buffer batch size should be greater than 0"

    def store(self, *args, **kwargs):
        """Add data to the buffer (implemented by subclasses)."""
        raise NotImplementedError

    def sample(self, *args, **kwargs):
        """Draw a batch from the buffer (implemented by subclasses)."""
        raise NotImplementedError

    def sample_from_indices(self, indices: list[int]):
        """Gather the items at ``indices`` (implemented by subclasses)."""
        raise NotImplementedError

    def clear(self):
        """Reset storage and ``self.size`` (implemented by subclasses)."""
        raise NotImplementedError

    def load(self, path):
        """Restore buffer contents from ``path`` (implemented by subclasses)."""
        raise NotImplementedError

    def __len__(self):
        """Number of items currently held, as tracked in ``self.size``."""
        return self.size
class Game:
    """Per-move record of one played game (observations, rewards, policies, ...)."""

    def __init__(self, num_players: int):
        """Create an empty game record.

        Args:
            num_players: number of players taking turns. ``set_rewards``
                attributes move ``i`` to player ``i % num_players``.
        """
        self.length = 0
        self.observation_history = []
        self.rewards = []
        self.policy_history = []
        self.value_history = []
        self.action_history = []
        self.info_history = []
        self.num_players = num_players

    def append(
        self,
        observation,
        reward: int,
        policy,
        value=None,
        action=None,
        info=None,
    ):
        """Append one move to the history.

        ``observation`` is deep-copied, so later in-place mutation of the
        caller's object does not alter the stored history.
        """
        self.observation_history.append(copy.deepcopy(observation))
        self.rewards.append(reward)
        self.policy_history.append(policy)
        self.value_history.append(value)
        self.action_history.append(action)
        self.info_history.append(info)
        self.length += 1

    def set_rewards(self):
        """Replace every move's reward with its mover's entry of the final reward.

        The last stored reward must be indexable by player
        (``final_reward[p]``). Move ``i`` receives
        ``final_reward[i % num_players]``. The list is updated in place. An
        empty game is left unchanged, where previously it raised
        ``IndexError``.
        """
        if self.length == 0:
            return
        final_reward = self.rewards[-1]
        self.rewards[: self.length] = [
            final_reward[i % self.num_players] for i in range(self.length)
        ]

    def __len__(self):
        """Number of moves appended so far."""
        return self.length
class BaseGameReplayBuffer(BaseReplayBuffer):
    """FIFO buffer of whole ``Game`` records, sampled proportionally to game length."""

    def __init__(
        self,
        max_size: int,
        batch_size: int,
    ):
        """
        Args:
            max_size: maximum number of games kept; the oldest is evicted first.
            batch_size: number of ``(game, move_index)`` pairs per ``sample``.
        """
        super().__init__(max_size=max_size, batch_size=batch_size)

    def store(self, game: Game):
        """Append ``game``, evicting the oldest game once ``max_size`` is reached."""
        if len(self.buffer) >= self.max_size:
            self.buffer.pop(0)
        self.buffer.append(game)
        # Derive the size from the list so ``len(self)`` never exceeds
        # ``max_size`` after evictions (a plain ``+= 1`` kept counting).
        self.size = len(self.buffer)

    def sample(self):
        """Sample games (with replacement) weighted by move count, plus a move in each.

        Returns:
            ``batch_size`` tuples ``(game, move_index)``. ``move_index`` is
            drawn uniformly from ``range(len(game))``.
        """
        move_sum = float(sum(len(game) for game in self.buffer))
        games = np.random.choice(
            self.buffer,
            self.batch_size,
            p=[len(game) / move_sum for game in self.buffer],
        )
        return [(game, np.random.randint(len(game))) for game in games]

    def clear(self):
        """Drop every stored game."""
        self.buffer: list[Game] = []
        self.size = 0
class BaseDQNReplayBuffer(BaseReplayBuffer):
    """Fixed-size circular buffer of one-step transitions for DQN-family agents.

    Each field lives in its own preallocated NumPy array indexed by slot.
    ``store`` overwrites the slot at ``self.pointer`` and advances it modulo
    ``max_size``.
    """

    def __init__(
        self,
        observation_dimensions: tuple,
        observation_dtype: np.dtype,
        max_size: int,
        batch_size: int = 32,
        compressed_observations: bool = False,
    ):
        """
        Args:
            observation_dimensions: per-observation shape; ``clear`` prepends
                ``max_size`` to size the observation arrays.
            observation_dtype: dtype of the observation arrays. It is not used
                when ``compressed_observations`` is set, because object arrays
                are allocated instead.
            max_size: number of transition slots.
            batch_size: transitions returned per ``sample()`` call.
            compressed_observations: store observations in object arrays.
        """
        # Must be set before the base constructor, which calls ``clear``.
        self.observation_dimensions = observation_dimensions
        self.observation_dtype = observation_dtype
        super().__init__(
            max_size=max_size,
            batch_size=batch_size,
            compressed_observations=compressed_observations,
        )

    def store(
        self,
        observation,
        info: dict,
        action,
        reward: float,
        next_observation,
        next_info: dict,
        done: bool,
        id=None,
    ):
        """Write one transition into slot ``self.pointer`` and advance it circularly.

        ``self.size`` grows by one per call until it saturates at ``max_size``.
        """
        self.id_buffer[self.pointer] = id
        self.observation_buffer[self.pointer] = observation
        self.action_buffer[self.pointer] = action
        self.reward_buffer[self.pointer] = reward
        self.next_observation_buffer[self.pointer] = next_observation
        self.done_buffer[self.pointer] = done
        self.info_buffer[self.pointer] = info
        self.next_info_buffer[self.pointer] = next_info
        self.pointer = (self.pointer + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def clear(self):
        """(Re)allocate every storage array and reset the pointer and size."""
        if self.compressed_observations:
            self.observation_buffer = np.zeros(self.max_size, dtype=np.object_)
            self.next_observation_buffer = np.zeros(self.max_size, dtype=np.object_)
        else:
            observation_buffer_shape = (self.max_size,) + self.observation_dimensions
            self.observation_buffer = np.zeros(
                observation_buffer_shape, self.observation_dtype
            )
            self.next_observation_buffer = np.zeros(
                observation_buffer_shape, dtype=self.observation_dtype
            )
        self.id_buffer = np.zeros(self.max_size, dtype=np.object_)
        # NOTE(review): actions are stored as uint8, so only action ids in
        # [0, 255] fit. Rewards are float16 (about 3 significant digits).
        # Confirm both are adequate for every environment using this buffer.
        self.action_buffer = np.zeros(self.max_size, dtype=np.uint8)
        self.reward_buffer = np.zeros(self.max_size, dtype=np.float16)
        self.done_buffer = np.zeros(self.max_size, dtype=np.bool_)
        self.info_buffer = np.zeros(self.max_size, dtype=np.object_)
        self.next_info_buffer = np.zeros(self.max_size, dtype=np.object_)
        self.pointer = 0
        self.size = 0

    def sample(self):
        """Uniformly sample ``batch_size`` distinct stored transitions.

        Returns the same keys as ``sample_from_indices`` (``infos`` /
        ``next_infos``). The ``info`` / ``next_info`` aliases are also
        included for backward compatibility with existing callers.
        """
        indices = np.random.choice(self.size, self.batch_size, replace=False)
        samples = self.sample_from_indices(indices)
        samples["info"] = samples["infos"]
        samples["next_info"] = samples["next_infos"]
        return samples

    def sample_from_indices(self, indices: list[int]):
        """Gather the transitions stored at ``indices`` into a dict of arrays."""
        return dict(
            observations=self.observation_buffer[indices],
            next_observations=self.next_observation_buffer[indices],
            actions=self.action_buffer[indices],
            rewards=self.reward_buffer[indices],
            dones=self.done_buffer[indices],
            ids=self.id_buffer[indices],
            infos=self.info_buffer[indices],
            next_infos=self.next_info_buffer[indices],
        )

    def __check_id__(self, index: int, id: str) -> bool:
        """Return whether slot ``index`` still holds the transition tagged ``id``."""
        return self.id_buffer[index] == id
class BasePPOReplayBuffer(BaseReplayBuffer):
    """On-policy rollout storage for PPO with GAE advantage estimation.

    ``store`` fills preallocated arrays one step at a time.
    ``finish_trajectory`` computes advantages and returns for the steps
    stored since its previous call. ``sample`` normalises the advantages and
    returns the whole arrays.
    """

    def __init__(
        self,
        observation_dimensions: tuple,
        observation_dtype: np.dtype,
        max_size: int,
        gamma: float = 0.99,
        gae_lambda: float = 0.95,
        compressed_observations: bool = False,
    ):
        """
        Args:
            observation_dimensions: per-observation shape; ``clear`` prepends
                ``max_size`` to it.
            observation_dtype: dtype of the observation arrays.
            max_size: number of step slots. ``batch_size`` is not passed, so
                the base class sets it to ``max_size``.
            gamma: discount factor used by ``finish_trajectory``.
            gae_lambda: GAE lambda used by ``finish_trajectory``.
            compressed_observations: store observations in object arrays.
        """
        # Set before the base constructor, which calls ``clear``.
        self.observation_dimensions = observation_dimensions
        self.observation_dtype = observation_dtype
        self.gamma = gamma
        self.gae_lambda = gae_lambda
        super().__init__(
            max_size=max_size, compressed_observations=compressed_observations
        )

    def store(
        self,
        observation,
        info: dict,
        action,
        value: float,
        log_probability: float,
        reward: float,
        id=None,
    ):
        """Write one step into slot ``self.pointer`` and advance it modulo ``max_size``.

        ``id`` is accepted for interface compatibility but is not stored.
        """
        self.observation_buffer[self.pointer] = observation
        self.action_buffer[self.pointer] = action
        self.reward_buffer[self.pointer] = reward
        self.value_buffer[self.pointer] = value
        self.log_probability_buffer[self.pointer] = log_probability
        self.info_buffer[self.pointer] = info
        self.pointer = (self.pointer + 1) % self.max_size
        self.size = min(self.size + 1, self.max_size)

    def sample(self) -> dict:
        """Normalise advantages in place and return every buffer array.

        Side effects:
        - ``pointer`` and ``trajectory_start_index`` are reset to 0, but
          ``size`` is not.
        - ``self.advantage_buffer`` is replaced by its normalised version, so
          calling ``sample`` twice normalises twice.

        NOTE(review): the full ``max_size``-length arrays are returned
        regardless of how many slots were filled. Presumably callers only
        sample once the buffer is full; confirm.
        """
        self.pointer, self.trajectory_start_index = 0, 0
        advantage_mean = np.mean(self.advantage_buffer)
        advantage_std = np.std(self.advantage_buffer)
        # NOTE(review): with float16 buffers, the 1e-10 epsilon below may
        # underflow to 0 depending on NumPy's promotion rules; confirm.
        self.advantage_buffer = (self.advantage_buffer - advantage_mean) / (
            advantage_std + 1e-10
        )  # avoid division by zero
        return dict(
            observations=self.observation_buffer,
            actions=self.action_buffer,
            advantages=self.advantage_buffer,
            returns=self.return_buffer,
            log_probabilities=self.log_probability_buffer,
            infos=self.info_buffer,
        )

    def clear(self):
        """(Re)allocate every storage array and reset the pointers and size.

        ``next_observation_buffer`` is allocated, but ``store`` in this class
        never writes it.
        """
        if self.compressed_observations:
            self.observation_buffer = np.zeros(self.max_size, dtype=np.object_)
            self.next_observation_buffer = np.zeros(self.max_size, dtype=np.object_)
        else:
            observation_buffer_shape = (self.max_size,) + self.observation_dimensions
            self.observation_buffer = np.zeros(
                observation_buffer_shape, self.observation_dtype
            )
            self.next_observation_buffer = np.zeros(
                observation_buffer_shape, dtype=self.observation_dtype
            )
        # NOTE(review): actions are int8 (ids in [-128, 127]), and the
        # numeric buffers are float16 (about 3 significant digits). Confirm
        # these are adequate.
        self.action_buffer = np.zeros(self.max_size, dtype=np.int8)
        self.reward_buffer = np.zeros(self.max_size, dtype=np.float16)
        self.advantage_buffer = np.zeros(self.max_size, dtype=np.float16)
        self.return_buffer = np.zeros(self.max_size, dtype=np.float16)
        self.value_buffer = np.zeros(self.max_size, dtype=np.float16)
        self.log_probability_buffer = np.zeros(self.max_size, dtype=np.float16)
        self.info_buffer = np.zeros(self.max_size, dtype=np.object_)
        self.pointer = 0
        self.trajectory_start_index = 0
        self.size = 0

    def finish_trajectory(self, last_value: float = 0):
        """Compute GAE advantages and discounted returns for the current trajectory.

        Operates on slots ``[trajectory_start_index, pointer)``: the steps
        stored since the last call, or since ``sample``/``clear`` reset the
        indices. ``last_value`` (default 0) is appended as the bootstrap value
        after the final stored step.

        NOTE(review): the slice assumes ``pointer`` has not wrapped past
        ``max_size`` since ``trajectory_start_index``. A wrapped pointer
        yields an empty or wrong slice.
        """
        path_slice = slice(self.trajectory_start_index, self.pointer)
        # Append the bootstrap value to both rewards and values.
        rewards = np.append(self.reward_buffer[path_slice], last_value)
        values = np.append(self.value_buffer[path_slice], last_value)
        # TD residuals: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
        deltas = rewards[:-1] + self.gamma * values[1:] - values[:-1]
        # NOTE(review): ``discounted_cumulative_sums`` is defined elsewhere;
        # presumably out[t] = sum_k discount**k * x[t + k]. Confirm.
        self.advantage_buffer[path_slice] = discounted_cumulative_sums(
            deltas, self.gamma * self.gae_lambda
        )
        # Discounted reward-to-go (including the bootstrap), dropping the bootstrap slot.
        self.return_buffer[path_slice] = discounted_cumulative_sums(rewards, self.gamma)[
            :-1
        ]
        self.trajectory_start_index = self.pointer
class NStepReplayBuffer(BaseDQNReplayBuffer):
    """DQN replay buffer that stores n-step transitions.

    Incoming one-step transitions are queued per player in a
    ``deque(maxlen=n_step)``. Once a player's queue is full, the oldest
    queued transition is combined with the discounted rewards of the rest
    and written to the underlying circular storage.
    """

    def __init__(
        self,
        observation_dimensions: tuple,
        observation_dtype: np.dtype,
        max_size: int,
        batch_size: int = 32,
        n_step: int = 1,
        gamma: float = 0.99,
        compressed_observations: bool = False,
        num_players: int = 1,
    ):
        """
        Args:
            n_step: length of each player's look-ahead window (the deque maxlen).
            gamma: per-step discount applied inside the window.
            num_players: number of independent windows, one per player.
            The remaining arguments are forwarded to ``BaseDQNReplayBuffer``.
        """
        # ``clear`` (called by the base constructor) reads n_step and num_players.
        self.n_step, self.gamma, self.num_players = n_step, gamma, num_players
        super().__init__(
            observation_dimensions=observation_dimensions,
            observation_dtype=observation_dtype,
            max_size=max_size,
            batch_size=batch_size,
            compressed_observations=compressed_observations,
        )

    def store(
        self,
        observation,
        info: dict,
        action,
        reward: float,
        next_observation,
        next_info: dict,
        done: bool,
        id=None,
        player: int = 0,
    ):
        """Queue a one-step transition for ``player``; emit an n-step one when ready.

        Returns:
            ``None`` while the player's window holds fewer than ``n_step``
            transitions. Otherwise, the tuple
            ``(observation, info, action, n_step_reward, next_observation,
            next_info, done)`` built from the oldest queued transition. That
            tuple is also written to storage, where the reward is the n-step
            return discounted with ``gamma``.
        """
        window = self.n_step_buffers[player]
        window.append(
            (observation, info, action, reward, next_observation, next_info, done)
        )
        if len(window) < self.n_step:
            return None

        n_reward, n_next_observation, n_next_info, n_done = self._get_n_step_info(
            player
        )
        head_observation, head_info, head_action = window[0][:3]
        n_step_transition = (
            head_observation,
            head_info,
            head_action,
            n_reward,
            n_next_observation,
            n_next_info,
            n_done,
        )
        super().store(*n_step_transition, id=id)
        return n_step_transition

    def clear(self):
        """Reset storage and start an empty n-step window for every player."""
        super().clear()
        self.n_step_buffers = [
            deque(maxlen=self.n_step) for _player in range(self.num_players)
        ]

    def _get_n_step_info(self, player: int = 0):
        """Fold the player's window into a discounted n-step return.

        Walks from the newest transition back to the oldest, computing
        ``r_k + gamma * G * (1 - done_k)``. When an earlier transition is
        terminal, the bootstrap target (next observation, next info, done)
        moves back to that transition, so the return does not cross an
        episode end.
        """
        queued = list(self.n_step_buffers[player])
        reward, next_observation, next_info, done = queued[-1][-4:]
        # ``queued[-2::-1]`` visits the remaining transitions newest-to-oldest.
        for r, n_o, n_i, d in (entry[-4:] for entry in queued[-2::-1]):
            reward = r + self.gamma * reward * (1 - d)
            if d:
                next_observation, next_info, done = n_o, n_i, d
        return reward, next_observation, next_info, done
class PrioritizedNStepReplayBuffer(NStepReplayBuffer):
    """N-step replay buffer with proportional prioritized sampling.

    The sum tree drives sampling, and the min tree supplies the
    importance-sampling normaliser. ``self.max_priority`` holds a *raw*
    priority (before the ``alpha`` exponent); both trees hold
    ``priority ** alpha``.
    """

    def __init__(
        self,
        observation_dimensions,
        observation_dtype: np.dtype,
        max_size: int,
        batch_size: int = 32,
        max_priority: float = 1.0,
        alpha: float = 0.6,
        beta: float = 0.4,
        n_step: int = 1,
        gamma: float = 0.99,
        compressed_observations: bool = False,
        num_players: int = 1,
    ):
        """
        Args:
            observation_dimensions, observation_dtype, max_size, batch_size,
            compressed_observations, num_players: see ``NStepReplayBuffer``.
            max_priority: raw priority given to new transitions until
                ``update_priorities`` records a larger one; restored by ``clear``.
            alpha: priority exponent in [0, 1]. With 0, ``sample`` falls back
                to uniform sampling.
            beta: importance-sampling exponent in [0, 1].
            n_step: look-ahead length. It must be an integer because it is
                used as a ``deque`` maxlen.
            gamma: discount factor in (0, 1].
        """
        assert 0 <= alpha <= 1
        assert 0 <= beta <= 1
        assert n_step >= 1
        assert 0 < gamma <= 1
        # Read by ``clear``, which runs inside the base constructor.
        self.initial_max_priority = max_priority
        super().__init__(
            observation_dimensions,
            observation_dtype,
            max_size,
            batch_size,
            n_step=n_step,
            gamma=gamma,
            compressed_observations=compressed_observations,
            num_players=num_players,
        )
        # Trade-off between greedy prioritisation (1) and uniform sampling (0).
        self.alpha = alpha
        self.beta = beta

    def store(
        self,
        observation,
        info: dict,
        action,
        reward: float,
        next_observation,
        next_info: dict,
        done: bool,
        id=None,
        priority: float = None,
        player: int = 0,
    ):
        """Store a transition and assign a priority to the resulting n-step transition.

        Args:
            priority: raw priority for the new n-step transition. ``None``
                uses ``self.max_priority``. The trees receive
                ``priority ** alpha``; the exponent is applied exactly once.

        Returns:
            The n-step transition written to storage, or ``None`` if the
            player's n-step window is not full yet.
        """
        transition = super().store(
            observation,
            info,
            action,
            reward,
            next_observation,
            next_info,
            done,
            id,
            player=player,
        )
        if priority is None:
            # ``max_priority`` is raw. Exponentiating it here and again below
            # would under-weight new data and make ``max_priority`` drift.
            priority = self.max_priority
        if transition:
            scaled_priority = priority**self.alpha
            self.sum_tree[self.tree_pointer] = scaled_priority
            self.min_tree[self.tree_pointer] = scaled_priority
            self.tree_pointer = (self.tree_pointer + 1) % self.max_size
        return transition

    def set_beta(self, beta: float):
        """Set the importance-sampling exponent (e.g. for annealing towards 1)."""
        self.beta = beta

    def store_batch(self, batch):
        """Store a batch given as parallel sequences.

        ``batch`` is ``(observations, infos, actions, rewards,
        next_observations, next_infos, dones, ids, priorities)``. Every item
        is stored for player 0.
        """
        (
            observations,
            infos,
            actions,
            rewards,
            next_observations,
            next_infos,
            dones,
            ids,
            priorities,
        ) = batch
        for i in range(len(observations)):
            self.store(
                observations[i],
                infos[i],
                actions[i],
                rewards[i],
                next_observations[i],
                next_infos[i],
                dones[i],
                ids[i],
                priorities[i],
            )

    def sample(self, throw_exception=True) -> dict:
        """Sample ``batch_size`` transitions with importance-sampling weights.

        Returns:
            The ``sample_from_indices`` dict plus ``weights`` (normalised by
            the largest possible weight) and ``indices``. Returns ``None``
            when the buffer is too small and ``throw_exception`` is False.

        Raises:
            ValueError: fewer than ``batch_size`` transitions are stored and
                ``throw_exception`` is True.
        """
        if len(self) < self.batch_size:
            if throw_exception:
                # Raising a bare string is itself a TypeError in Python 3.
                raise ValueError(
                    "Only {} elements in buffer expected at least {}".format(
                        len(self), self.batch_size
                    )
                )
            return None
        if self.alpha != 0.0:
            indices = self._sample_proportional()
        else:
            indices = np.random.choice(self.size, size=self.batch_size, replace=False)
        weights = np.array([self._calculate_weight(i) for i in indices])
        n_step_samples = self.sample_from_indices(indices)
        n_step_samples.update(dict(weights=weights, indices=indices))
        return n_step_samples

    def clear(self):
        """Reset storage, priority trees, tree pointer and ``max_priority``."""
        super().clear()
        self.max_priority = self.initial_max_priority  # (initial) priority
        self.tree_pointer = 0
        # Segment trees need a power-of-two number of leaves.
        tree_capacity = 1
        while tree_capacity < self.max_size:
            tree_capacity *= 2
        self.sum_tree = SumSegmentTree(tree_capacity)
        self.min_tree = MinSegmentTree(tree_capacity)

    def _set_priority(self, index: int, priority: float):
        """Write ``priority ** alpha`` into both trees and track the raw maximum."""
        scaled_priority = priority**self.alpha
        self.sum_tree[index] = scaled_priority
        self.min_tree[index] = scaled_priority
        self.max_priority = max(self.max_priority, priority)

    def update_priorities(self, indices: list[int], priorities: list[float], ids=None):
        """Set new raw priorities for previously sampled transitions.

        Args:
            indices: buffer slots, e.g. ``sample()["indices"]``.
            priorities: positive raw priorities, one per index.
            ids: optional transition ids (needed for the shared replay buffer).
                When given, a slot is only updated if its stored id still
                equals the given id.

        Returns:
            ``priorities ** alpha``. Plain lists and tuples are converted to a
            NumPy array first.
        """
        if isinstance(priorities, (list, tuple)):
            priorities = np.asarray(priorities)
        if ids is not None:
            assert len(priorities) == len(ids) == len(indices)
            assert priorities.shape == ids.shape == indices.shape
            for index, id, priority in zip(indices, ids, priorities):
                assert priority > 0, "Non-positive priority: {} \n All priorities {}".format(
                    priority, priorities
                )
                assert 0 <= index < len(self)
                if self.id_buffer[index] != id:
                    # Slot now holds a different transition; skip it.
                    continue
                self._set_priority(index, priority)
        else:
            assert len(indices) == len(priorities)
            for index, priority in zip(indices, priorities):
                assert priority > 0, "Non-positive priority: {}".format(priority)
                assert 0 <= index < len(self)
                self._set_priority(index, priority)
        return priorities**self.alpha

    def _sample_proportional(self):
        """Stratified proportional sampling.

        The total priority mass is split into ``batch_size`` equal segments,
        and one index is drawn from each.
        """
        indices = []
        # Unfilled slots hold 0 in the sum tree, so the whole-tree sum is the
        # mass of the stored transitions. This holds whatever range convention
        # ``SegmentTree.operate`` uses. It is also the normaliser used by
        # ``_calculate_weight``.
        total_priority = self.sum_tree.sum()
        priority_segment = total_priority / self.batch_size
        for i in range(self.batch_size):
            a = priority_segment * i
            b = priority_segment * (i + 1)
            upperbound = np.random.uniform(a, b)
            indices.append(self.sum_tree.retrieve(upperbound))
        return indices

    def _calculate_weight(self, index: int):
        """Importance-sampling weight of slot ``index``, normalised by the max weight."""
        total_priority = self.sum_tree.sum()
        stored = len(self)
        min_probability = self.min_tree.min() / total_priority
        max_weight = (min_probability * stored) ** (-self.beta)
        sample_probability = self.sum_tree[index] / total_priority
        weight = (sample_probability * stored) ** (-self.beta)
        return weight / max_weight
class FastPrioritizedReplayBuffer(NStepReplayBuffer):
    """Prioritized n-step buffer backed by the array-based ``FastSumTree``.

    ``max_priority`` and ``min_priority`` are tracked *after* the ``alpha``
    exponent (``update_priorities`` stores ``priority ** alpha``), and new
    transitions get ``max_priority`` as-is. ``min_priority`` only ever
    decreases, so the weight normaliser uses the smallest scaled priority
    seen since the last ``clear``.
    """

    def __init__(
        self,
        observation_dimensions,
        max_size: int,
        batch_size: int = 32,
        max_priority: float = 1.0,
        alpha: float = 0.6,
        beta: float = 0.4,
        n_step: int = 1,
        gamma: float = 0.99,
        observation_dtype: np.dtype = np.float32,
        compressed_observations: bool = False,
        num_players: int = 1,
    ):
        """
        Args:
            observation_dimensions: per-observation shape.
            max_size: number of transition slots.
            batch_size: transitions per ``sample()``.
            max_priority: initial value of ``max_priority`` and ``min_priority``.
            alpha: priority exponent in [0, 1].
            beta: importance-sampling exponent in [0, 1].
            n_step: look-ahead length (integer).
            gamma: discount factor in (0, 1].
            observation_dtype, compressed_observations, num_players: trailing
                keywords forwarded to ``NStepReplayBuffer``. They are passed
                by keyword so ``max_size``/``batch_size`` land in the right
                base-class parameters.
        """
        assert 0 <= alpha <= 1
        assert 0 <= beta <= 1
        assert n_step >= 1
        assert 0 < gamma <= 1
        # Read by ``clear``, which runs inside the base constructor.
        self.initial_max_priority = max_priority
        # Trade-off between greedy prioritisation (1) and uniform sampling (0).
        self.alpha = alpha
        self.beta = beta
        super().__init__(
            observation_dimensions=observation_dimensions,
            observation_dtype=observation_dtype,
            max_size=max_size,
            batch_size=batch_size,
            n_step=n_step,
            gamma=gamma,
            compressed_observations=compressed_observations,
            num_players=num_players,
        )

    def clear(self):
        """Reset storage, the sum tree, the tree pointer and the priority bounds."""
        super().clear()
        self.max_priority = self.initial_max_priority  # (initial) priority
        self.min_priority = self.initial_max_priority
        self.tree_pointer = 0
        self.tree = FastSumTree(self.max_size)

    def store(
        self,
        observation,
        action,
        reward: float,
        next_observation,
        done: bool,
        info: dict = None,
        next_info: dict = None,
        id=None,
        player: int = 0,
    ):
        """Store a transition; a completed n-step transition gets ``max_priority``.

        The first five positional parameters keep their original order.
        ``info``/``next_info``/``id``/``player`` are optional trailing
        keywords, forwarded to ``NStepReplayBuffer.store``.

        Returns:
            The n-step transition written to storage, or ``None`` if the
            player's window is not full yet.
        """
        transition = super().store(
            observation,
            info,
            action,
            reward,
            next_observation,
            next_info,
            done,
            id=id,
            player=player,
        )
        if transition:
            self.tree.add(self.tree_pointer, self.max_priority)
            self.tree_pointer = (self.tree_pointer + 1) % self.max_size
        return transition

    def sample(self):
        """Stratified proportional sample of ``batch_size`` transitions.

        Returns:
            dict with ``observations``, ``next_observations``, ``actions``,
            ``rewards``, ``dones``, ``weights`` (shape ``(batch_size, 1)``)
            and ``indices`` (leaf indices for ``update_priorities``).
        """
        assert len(self) >= self.batch_size
        total_priority = self.tree.total_priority
        priority_segment = total_priority / self.batch_size
        indices = np.empty((self.batch_size,), dtype=np.int32)
        weights = np.empty((self.batch_size, 1), dtype=np.float32)
        for i in range(self.batch_size):
            a, b = priority_segment * i, priority_segment * (i + 1)
            value = np.random.uniform(a, b)
            tree_index, priority = self.tree.retrieve(value)
            sampling_probability = priority / total_priority
            weights[i, 0] = (len(self) * sampling_probability) ** -self.beta
            # Convert the raw tree index back into a leaf/buffer index.
            indices[i] = tree_index - self.tree.capacity + 1
        # Largest possible weight, from the smallest scaled priority seen.
        max_weight = (len(self) * self.min_priority / total_priority) ** -self.beta
        weights = weights / max_weight
        return dict(
            observations=self.observation_buffer[indices],
            next_observations=self.next_observation_buffer[indices],
            actions=self.action_buffer[indices],
            rewards=self.reward_buffer[indices],
            dones=self.done_buffer[indices],
            weights=weights,
            indices=indices,
        )

    def update_priorities(self, indices: list[int], priorities: list[float]):
        """Set ``priority ** alpha`` for each leaf index and update the priority bounds."""
        assert len(indices) == len(priorities)
        for index, priority in zip(indices, priorities):
            assert priority > 0, "Negative priority: {}".format(priority)
            scaled_priority = priority**self.alpha
            self.max_priority = max(self.max_priority, scaled_priority)
            self.min_priority = min(self.min_priority, scaled_priority)
            self.tree.update(index + self.tree.capacity - 1, scaled_priority)
class RainbowAgent(BaseAgent):
    """Rainbow-style DQN agent.

    As implemented in this class, it combines:
      * double Q-learning targets (the next action is chosen by the online
        network and evaluated by the target network),
      * a categorical value distribution over ``atom_size`` atoms spanning
        ``[v_min, v_max]`` (plain scalar Q-learning when ``atom_size == 1``),
      * a prioritized n-step replay buffer,
      * network noise resampling via ``reset_noise()`` after every update.
    """

    def __init__(
        self,
        env,
        config: RainbowConfig,
        name=f"rainbow_{current_timestamp():.1f}",
        device: torch.device = (
            torch.device("cuda")
            if torch.cuda.is_available()
            # MPS is sometimes useful for M2 instances, but only for large models/matrix multiplications otherwise CPU is faster
            else (
                torch.device("mps")
                if torch.backends.mps.is_available() and torch.backends.mps.is_built()
                else torch.device("cpu")
            )
        ),
        from_checkpoint=False,
    ):
        # NOTE(review): the default ``name`` (and ``device``) are evaluated once
        # when the class is defined, so all default-named agents share a name.
        super(RainbowAgent, self).__init__(env, config, name, device=device)
        input_shape = (self.config.minibatch_size,) + self.observation_dimensions
        self.model = RainbowNetwork(
            config=config,
            output_size=self.num_actions,
            input_shape=input_shape,
        )
        self.target_model = RainbowNetwork(
            config=config,
            output_size=self.num_actions,
            input_shape=input_shape,
        )
        if self.config.kernel_initializer is not None:
            self.model.initialize(self.config.kernel_initializer)
        self.model.to(device)
        self.target_model.to(device)
        # The target network starts as an exact copy of the online network.
        self.target_model.load_state_dict(self.model.state_dict())
        self.target_model.eval()
        if self.config.optimizer == Adam:
            self.optimizer: torch.optim.Optimizer = self.config.optimizer(
                params=self.model.parameters(),
                lr=self.config.learning_rate,
                eps=self.config.adam_epsilon,
                weight_decay=self.config.weight_decay,
            )
        elif self.config.optimizer == SGD:
            print("Warning: SGD does not use adam_epsilon param")
            self.optimizer: torch.optim.Optimizer = self.config.optimizer(
                params=self.model.parameters(),
                lr=self.config.learning_rate,
                momentum=self.config.momentum,
                weight_decay=self.config.weight_decay,
            )
        self.replay_buffer = PrioritizedNStepReplayBuffer(
            observation_dimensions=self.observation_dimensions,
            observation_dtype=self.env.observation_space.dtype,
            max_size=self.config.replay_buffer_size,
            batch_size=self.config.minibatch_size,
            max_priority=1.0,
            alpha=self.config.per_alpha,
            beta=self.config.per_beta,
            n_step=self.config.n_step,
            gamma=self.config.discount_factor,
            compressed_observations=(
                self.env.lz4_compress if hasattr(self.env, "lz4_compress") else False
            ),
            num_players=self.config.game.num_players,
        )
        # Atom locations of the value distribution, Tensor(atom_size).
        # (Could use a MuZero-style min-max and keep updating the support;
        # unclear whether that would break the model.)
        self.support = torch.linspace(
            self.config.v_min,
            self.config.v_max,
            self.config.atom_size,
            device=device,
        )
        self.eg_epsilon = self.config.eg_epsilon
        self.stats = {
            "score": [],
            "loss": [],
            "test_score": [],
        }
        self.targets = {
            "score": self.env.spec.reward_threshold,
            "test_score": self.env.spec.reward_threshold,
        }

    def checkpoint_model_weights(self, checkpoint):
        """Add the target network weights to the base checkpoint dict.

        Returns:
            The checkpoint dict produced by the base class, extended with a
            ``"target_model"`` entry.
        """
        checkpoint = super().checkpoint_model_weights(checkpoint)
        checkpoint["target_model"] = self.target_model.state_dict()
        return checkpoint

    def load_model_weights(self, checkpoint):
        """Restore online and target network weights from ``checkpoint``."""
        self.model.load_state_dict(checkpoint["model"])
        self.target_model.load_state_dict(checkpoint["target_model"])
        self.target_model.eval()

    def predict(self, states, *args, **kwargs) -> torch.Tensor:
        """Run the online network on preprocessed ``states``.

        Extra positional/keyword arguments are accepted and ignored.
        """
        state_input = self.preprocess(states)
        q_distribution: torch.Tensor = self.model(state_input)
        return q_distribution

    def predict_target(self, states) -> torch.Tensor:
        """Run the target network on preprocessed ``states``."""
        state_input = self.preprocess(states)
        q_distribution: torch.Tensor = self.target_model(state_input)
        return q_distribution

    def select_actions(self, distribution, info: dict = None, mask_actions: bool = True):
        """Return the greedy action for each batch row.

        Args:
            distribution: network output. When ``atom_size > 1`` it is reduced
                over dim 2 against ``self.support`` to get Q-values; otherwise
                it is used directly as Q-values (batch on dim 0, actions on dim 1).
            info: environment info used to look up legal moves; required when
                ``mask_actions`` is True.
            mask_actions: if True, illegal actions are masked to ``-inf``.

        Returns:
            Tensor of argmax action indices, one per batch row.
        """
        assert info is not None if mask_actions else True, "Need info to mask actions"
        if self.config.atom_size > 1:
            q_values = distribution * self.support
            q_values = q_values.sum(2, keepdim=False)
        else:
            q_values = distribution
        if mask_actions:
            legal_moves = get_legal_moves(info)
            q_values = action_mask(
                q_values, legal_moves, mask_value=-float("inf"), device=self.device
            )
        # Ties are broken by argmax (lowest index), not randomly.
        selected_actions = q_values.argmax(1, keepdim=False)
        return selected_actions

    def learn(self) -> np.ndarray:
        """Run ``config.training_iterations`` updates, each on a fresh sample.

        Returns:
            Array with the mean weighted loss of every iteration.
        """
        losses = np.zeros(self.config.training_iterations)
        for i in range(self.config.training_iterations):
            samples = self.replay_buffer.sample()
            losses[i] = self.learn_from_sample(samples)
        return losses

    def _compute_scalar_targets(self, samples: dict) -> torch.Tensor:
        """Double-DQN n-step targets for the scalar (``atom_size == 1``) head.

        target = r + gamma**n_step * (not done) * Q_target(s', argmax_a Q_online(s', a))

        Computed without autograd: targets must not receive gradients.
        """
        with torch.no_grad():
            # The replay buffer is n-step, so bootstrap with gamma**n, matching
            # compute_target_distributions.
            discount_factor = self.config.discount_factor**self.config.n_step
            next_observations, rewards, dones = (
                torch.from_numpy(samples["next_observations"]).to(self.device),
                torch.from_numpy(samples["rewards"]).to(self.device),
                torch.from_numpy(samples["dones"]).to(self.device),
            )
            next_q_target = self.predict_target(next_observations)
            # Double DQN: the online network picks the action, the target
            # network evaluates it.
            next_actions = self.select_actions(
                self.predict(next_observations),
                info=samples["next_infos"],
                mask_actions=self.config.game.has_legal_moves,
            )
            next_q = next_q_target[range(self.config.minibatch_size), next_actions]
            return rewards + discount_factor * (~dones) * next_q

    def learn_from_sample(self, samples: dict):
        """One prioritized, importance-weighted gradient step on ``samples``.

        Also writes back ``loss + per_epsilon`` as the new replay priorities
        and resamples network noise.

        Args:
            samples: dict produced by ``replay_buffer.sample()``.

        Returns:
            The mean importance-weighted loss as a Python float.
        """
        weights = samples["weights"]
        actions = torch.from_numpy(samples["actions"]).to(self.device).long()
        # (B, outputs[, atom_size]) -[index action dimension by actions]> (B[, atom_size])
        online_predictions = self.predict(samples["observations"])[
            range(self.config.minibatch_size), actions
        ]
        if self.config.atom_size > 1:
            assert isinstance(
                self.config.loss_function, (KLDivergenceLoss, CategoricalCrossentropyLoss)
            ), "Only KLDivergenceLoss and CategoricalCrossentropyLoss are supported for atom_size > 1, recieved {}".format(
                self.config.loss_function
            )
            target_predictions = self.compute_target_distributions(samples)
        else:
            assert isinstance(
                self.config.loss_function, (HuberLoss, MSELoss)
            ), "Only HuberLoss or MSELoss are supported for atom_size = 1, recieved {}".format(
                self.config.loss_function
            )
            target_predictions = self._compute_scalar_targets(samples)
        weights_cuda = torch.from_numpy(weights).to(torch.float32).to(self.device)
        # (B)
        elementwise_loss = self.config.loss_function(
            online_predictions, target_predictions
        )
        # Check the values themselves (not a reduced bool). The priorities
        # written back below are ``loss + per_epsilon`` and must be > 0, so
        # fail here, before the optimizer step, if that cannot hold.
        assert torch.all(
            elementwise_loss + self.config.per_epsilon > 0
        ), "Elementwise Loss: {}".format(elementwise_loss)
        assert (
            elementwise_loss.shape == weights_cuda.shape
        ), "Loss Shape: {}, Weights Shape: {}".format(
            elementwise_loss.shape, weights_cuda.shape
        )
        loss = elementwise_loss * weights_cuda
        self.optimizer.zero_grad()
        loss.mean().backward()
        if self.config.clipnorm > 0:
            clip_grad_norm_(self.model.parameters(), self.config.clipnorm)
        self.optimizer.step()
        self.update_replay_priorities(
            samples=samples,
            priorities=elementwise_loss.detach().to("cpu").numpy()
            + self.config.per_epsilon,
        )
        self.model.reset_noise()
        self.target_model.reset_noise()
        return loss.detach().to("cpu").mean().item()

    def update_replay_priorities(self, samples, priorities):
        """Forward new priorities for the sampled indices to the replay buffer."""
        self.replay_buffer.update_priorities(samples["indices"], priorities)

    def compute_target_distributions(self, samples):
        """Project the double-DQN n-step target distribution onto ``self.support``.

        Categorical (C51-style) projection: each shifted atom
        ``Tz = r + gamma**n * (not done) * z`` is clamped to ``[v_min, v_max]``
        and its probability is split between the two neighbouring atoms.

        Returns:
            (B, atom_size) tensor of projected target probabilities.
        """
        with torch.no_grad():
            discount_factor = self.config.discount_factor**self.config.n_step
            delta_z = (self.config.v_max - self.config.v_min) / (
                self.config.atom_size - 1
            )
            next_observations, rewards, dones = (
                samples["next_observations"],
                torch.from_numpy(samples["rewards"]).to(self.device).view(-1, 1),
                torch.from_numpy(samples["dones"]).to(self.device).view(-1, 1),
            )
            online_distributions = self.predict(next_observations)
            target_distributions = self.predict_target(next_observations)
            # Double DQN: the online network picks a', the target network scores it.
            next_actions = self.select_actions(
                online_distributions,
                info=samples["next_infos"],
                mask_actions=self.config.game.has_legal_moves,
            )
            # (B, outputs, atom_size) -[index by [0..B-1, a_0..a_B-1]]> (B, atom_size)
            probabilities = target_distributions[
                range(self.config.minibatch_size), next_actions
            ]
            # (B, 1) + k * (B, 1) * (atom_size) -> (B, atom_size)
            Tz = (rewards + discount_factor * (~dones) * self.support).clamp(
                self.config.v_min, self.config.v_max
            )
            # Fractional atom position of every shifted atom and its neighbours.
            b: torch.Tensor = (Tz - self.config.v_min) / delta_z
            l, u = (
                torch.clamp(b.floor().long(), 0, self.config.atom_size - 1),
                torch.clamp(b.ceil().long(), 0, self.config.atom_size - 1),
            )
            # Fix disappearing probability mass when l = b = u (b is int)
            l[(u > 0) * (l == u)] -= 1
            u[(l < (self.config.atom_size - 1)) * (l == u)] += 1
            m = torch.zeros_like(probabilities)
            m.scatter_add_(dim=1, index=l, src=probabilities * ((u.float()) - b))
            m.scatter_add_(dim=1, index=u, src=probabilities * ((b - l.float())))
            return m

    def fill_replay_buffer(self):
        """Store uniformly random transitions while the buffer holds fewer
        than ``config.min_replay_buffer_size`` entries."""
        print("replay buffer size:", self.replay_buffer.size)
        with torch.no_grad():
            state, info = self.env.reset()
            target_size = self.config.min_replay_buffer_size
            # Log about every 1% of progress; max(1, ...) avoids a modulo by
            # zero when target_size < 100.
            progress_interval = max(1, target_size // 100)
            while self.replay_buffer.size < target_size:
                if self.replay_buffer.size % progress_interval == 0:
                    print(
                        f"filling replay buffer: {self.replay_buffer.size} / ({target_size})"
                    )
                action = self.env.action_space.sample()
                next_state, reward, terminated, truncated, next_info = self.env.step(
                    action
                )
                done = terminated or truncated
                self.replay_buffer.store(
                    state, info, action, reward, next_state, next_info, done
                )
                state = next_state
                info = next_info
                if done:
                    state, info = self.env.reset()

    def update_target_model(self):
        """Sync the target network with the online network.

        With ``config.soft_update`` the target weights are Polyak-averaged,
        ``wt <- ema_beta * wt + (1 - ema_beta) * wp``; otherwise the online
        weights are copied over wholesale.
        """
        if self.config.soft_update:
            # In-place writes to parameters that require grad must run outside
            # autograd; otherwise PyTorch raises on the leaf tensors.
            with torch.no_grad():
                for wt, wp in zip(
                    self.target_model.parameters(), self.model.parameters()
                ):
                    wt.copy_(
                        self.config.ema_beta * wt + (1 - self.config.ema_beta) * wp
                    )
        else:
            self.target_model.load_state_dict(self.model.state_dict())

    def update_eg_epsilon(self, training_step):
        """Recompute the epsilon-greedy epsilon for ``training_step``.

        Raises:
            ValueError: if ``config.eg_epsilon_decay_type`` is not
                ``"linear"`` or ``"inverse_sqrt"``.
        """
        if self.config.eg_epsilon_decay_type == "linear":
            self.eg_epsilon = update_linear_schedule(
                self.config.eg_epsilon_final,
                self.config.eg_epsilon_final_step,
                self.config.eg_epsilon,
                training_step,
            )
        elif self.config.eg_epsilon_decay_type == "inverse_sqrt":
            self.eg_epsilon = update_inverse_sqrt_schedule(
                self.config.eg_epsilon,
                training_step,
            )
        else:
            raise ValueError(
                "Invalid epsilon decay type: {}".format(self.config.eg_epsilon_decay_type)
            )

    def train(self):
        """Main training loop, run until ``config.training_steps`` is reached.

        Each training step:
          1. acts ``replay_interval`` times with the epsilon-greedy policy,
             storing transitions and updating PER beta;
          2. updates epsilon;
          3. runs up to ``num_minibatches`` learning passes once the buffer
             holds ``min_replay_buffer_size`` transitions;
          4. syncs the target network every ``transfer_interval`` steps and
             checkpoints every ``checkpoint_interval`` steps.
        """
        super().train()
        start_time = time() - self.training_time
        score = 0
        target_model_updated = (False, False)  # (score, loss)
        self.fill_replay_buffer()
        state, info = self.env.reset()
        while self.training_step < self.config.training_steps:
            if self.training_step % self.config.print_interval == 0:
                self.print_training_progress()
            with torch.no_grad():
                for _ in range(self.config.replay_interval):
                    values = self.predict(state)
                    action = epsilon_greedy_policy(
                        values,
                        info,
                        self.eg_epsilon,
                        wrapper=lambda values, info: self.select_actions(
                            values, info
                        ).item(),
                    )
                    next_state, reward, terminated, truncated, next_info = self.env.step(
                        action
                    )
                    done = terminated or truncated
                    self.replay_buffer.store(
                        state, info, action, reward, next_state, next_info, done
                    )
                    state = next_state
                    info = next_info
                    score += reward
                    # NOTE(review): ``self.training_steps`` is presumably set by
                    # BaseAgent (the loop bound reads ``config.training_steps``);
                    # confirm it exists.
                    self.replay_buffer.set_beta(
                        update_per_beta(
                            self.replay_buffer.beta,
                            self.config.per_beta_final,
                            self.training_steps,
                            self.config.per_beta,
                        )
                    )
                    if done:
                        state, info = self.env.reset()
                        score_dict = {
                            "score": score,
                            "target_model_updated": target_model_updated[0],
                        }
                        self.stats["score"].append(score_dict)
                        target_model_updated = (False, target_model_updated[1])
                        score = 0
            self.update_eg_epsilon(self.training_step + 1)
            for _ in range(self.config.num_minibatches):
                if len(self.replay_buffer) < self.config.min_replay_buffer_size:
                    break
                losses = self.learn()
                # could do things other than taking the mean here
                loss_mean = losses.mean()
                self.stats["loss"].append(
                    {"loss": loss_mean, "target_model_updated": target_model_updated[1]}
                )
                target_model_updated = (target_model_updated[0], False)
            if self.training_step % self.config.transfer_interval == 0:
                target_model_updated = (True, True)
                self.update_target_model()
            if self.training_step % self.checkpoint_interval == 0:
                self.training_time = time() - start_time
                self.total_environment_steps = (
                    self.training_step * self.config.replay_interval
                )
                self.save_checkpoint()
            self.training_step += 1
        self.training_time = time() - start_time
        self.total_environment_steps = self.training_step * self.config.replay_interval
        self.save_checkpoint()
        self.env.close()
from gymnasium.wrappers import AtariPreprocessing, FrameStack
import numpy as np
# Rainbow hyper-parameters for the MsPacman run below.
config_dict = dict(
    # Convolutional trunk; entries presumably (filters, kernel_size, stride) —
    # confirm against RainbowNetwork.
    conv_layers=[
        (32, 8, 4),
        (64, 4, 2),
        (64, 3, 1),
    ],
    dense_layers_widths=[512],
    value_hidden_layers_widths=[],
    # NOTE(review): key is spelled "advatage"; confirm RainbowConfig reads this
    # exact spelling, otherwise the setting may be silently ignored.
    advatage_hidden_layers_widths=[],
    adam_epsilon=1.5e-4,
    learning_rate=0.00025 / 4,
    # Each training step takes replay_interval (4) environment steps, so this
    # is 200M environment steps in total.
    training_steps=50000000,
    per_epsilon=1e-6,
    per_alpha=0.5,
    per_beta=0.4,
    minibatch_size=32,
    replay_buffer_size=1000000,
    min_replay_buffer_size=80000,
    transfer_interval=32000,
    n_step=3,
    kernel_initializer="orthogonal",
    loss_function=KLDivergenceLoss(),
    clipnorm=0.0,  # RainbowAgent only clips gradients when clipnorm > 0
    discount_factor=0.99,
    atom_size=51,
    replay_interval=4,
)
game_config = AtariConfig()
config = RainbowConfig(config_dict, game_config)
class ClipReward(gym.RewardWrapper):
    """Reward wrapper that clips every reward into ``[min_reward, max_reward]``."""

    def __init__(self, env, min_reward, max_reward):
        super().__init__(env)
        self.min_reward, self.max_reward = min_reward, max_reward
        self.reward_range = (min_reward, max_reward)

    def reward(self, reward):
        """Return ``reward`` clipped to the configured bounds with ``np.clip``."""
        lower, upper = self.min_reward, self.max_reward
        return np.clip(reward, lower, upper)
# MsPacman without built-in frame skipping; AtariPreprocessing handles the
# preprocessing (ending episodes on life loss), and the last 4 observations
# are stacked with lz4 compression.
env = FrameStack(
    AtariPreprocessing(
        gym.make(
            "MsPacmanNoFrameskip-v4",
            render_mode="rgb_array",
            max_episode_steps=108000,
        ),
        terminal_on_life_loss=True,
    ),
    4,
    lz4_compress=True,
)
agent = RainbowAgent(env, config, name="Rainbow_Atari_MsPacmanNoFrameskip-v4")
agent.checkpoint_interval = 1000  # save a checkpoint every 1000 training steps
agent.train()