Source code for ddql_optimal_execution.trainer._trainer

import random
import warnings
from typing import Optional

from tqdm import tqdm

from ddql_optimal_execution.agent import DDQL
from ddql_optimal_execution.environnement import MarketEnvironnement
from ddql_optimal_execution.experience_replay import ExperienceReplay
from ddql_optimal_execution.experience_replay._warnings import (
    ExperienceReplayEmptyWarning,
)

from ._warnings import MaxStepsTooLowWarning


class Trainer:
    """This class is used to train a DDQL agent in a given environment.

    Attributes
    ----------
    agent : DDQL
        The `agent` attribute is an instance of the `DDQL` class, which is a reinforcement
        learning algorithm used for decision making in an environment.
    env : MarketEnvironnement
        The `env` attribute is an instance of the `MarketEnvironnement` class, which represents
        the environment in which the agent will interact and learn. It provides the agent with
        information about the current state of the market and allows it to take actions based on
        that information.
    exp_replay : ExperienceReplay
        The `exp_replay` attribute is an instance of the `ExperienceReplay` class, which is a
        memory buffer that stores and retrieves experiences for reinforcement learning agents.

    Methods
    -------
    pretrain(max_steps: int = 1000, batch_size: int = 32)
        Pretrains the agent by running random episodes, taking limit actions (sell everything at
        the beginning or at the end of the episode) and storing the experiences in the experience
        replay buffer.
    train(max_steps: int = 1000, batch_size: int = 32)
        Trains the agent by running episodes, taking actions based on the current state of the
        environment, and storing the experiences in the experience replay buffer.
    """
    def __init__(self, agent: DDQL, env: MarketEnvironnement, **kwargs):
        """Initialize the trainer with an agent, an environment and an experience replay buffer.

        Parameters
        ----------
        agent : DDQL
            An instance of the `DDQL` class, which is a reinforcement learning algorithm used for
            decision making in an environment.
        env : MarketEnvironnement
            An instance of the `MarketEnvironnement` class, which represents the environment in
            which the agent will interact and learn. It provides the agent with information about
            the current state of the market and allows it to take actions based on that
            information.
        **kwargs
            Optional keyword arguments: ``capacity`` (int, default 10000), the capacity of the
            experience replay buffer, and ``verbose`` (bool, default True), whether to display
            progress bars.
        """
        self.agent = agent
        self.env = env
        self.exp_replay = ExperienceReplay(capacity=kwargs.get("capacity", 10000))
        self.verbose = kwargs.get("verbose", True)
    def fill_exp_replay(self, max_steps: int = 1000, verbose: bool = True):
        """Fill the experience replay buffer with experiences from random episodes.

        Parameters
        ----------
        max_steps : int
            The maximum number of steps the function will take before stopping. It prevents the
            function from running indefinitely if the experience replay buffer never fills up.
        verbose : bool
            Whether to display a progress bar while the buffer is being filled.
        """
        if max_steps < self.exp_replay.capacity:
            # Fewer steps than the buffer capacity could never fill it: raise max_steps and warn.
            max_steps = self.exp_replay.capacity
            warnings.warn(MaxStepsTooLowWarning(max_steps))

        p_bar = (
            tqdm(
                total=min(max_steps, self.exp_replay.capacity),
                desc="Filling experience replay buffer",
            )
            if verbose
            else None
        )

        n_steps = 0
        while (not self.exp_replay.is_full) and n_steps < max_steps:
            self.__random_border_actions(p_bar=p_bar)
            n_steps += 1
    def __random_border_actions(self, p_bar: Optional[tqdm] = None):
        """Run a random episode, taking limit actions (sell everything either at the beginning or
        at the end of the episode) and storing the experiences in the experience replay buffer."""
        # Pick a random historical episode and randomly decide whether the whole inventory is
        # liquidated at the first period (sell_beginning) or at the last one.
        random_episode = random.randint(0, len(self.env.historical_data_series) - 1)
        sell_beginning = random.randint(0, 1)

        self.env.swap_episode(random_episode)
        distance2horizon = self.env.horizon

        while not self.env.done:
            current_state = self.env.state.copy()

            # Sell the full initial inventory at the chosen border period, do nothing otherwise.
            action = (
                self.env.initial_inventory
                if (
                    (distance2horizon == 1 and not sell_beginning)
                    or (distance2horizon == self.env.horizon and sell_beginning)
                )
                else 0
            )

            reward = self.env.step(action)
            distance2horizon = self.env.horizon - self.env.state["period"]

            self.exp_replay.push(
                current_state,
                action,
                reward,
                self.env.state.copy(),
                distance2horizon,
            )

            if self.verbose and p_bar is not None:
                p_bar.update(1)
    def pretrain(self, max_steps: int = 1000, batch_size: int = 32):
        """Pretrain the DDQL agent by running random episodes, taking limit actions (sell
        everything at the beginning or at the end of the episode) and storing the experiences in
        the experience replay buffer.

        Parameters
        ----------
        max_steps : int, optional
            The maximum number of steps to pretrain the agent for.
        batch_size : int, optional
            The number of experiences to sample from the experience replay buffer at each
            training step.
        """
        if not isinstance(self.agent, DDQL):
            raise TypeError("The agent must be an instance of the DDQL class.")

        p_bar = (
            tqdm(total=max_steps, desc="Pretraining agent") if self.verbose else None
        )

        for _ in range(max_steps):
            # Generate one border-action episode, then update the agent on a sampled batch.
            self.__random_border_actions()
            self.agent.learn(self.exp_replay.get_sample(batch_size))

            if self.verbose and p_bar is not None:
                p_bar.update(1)
    def train(self, max_steps: int = 1000, batch_size: int = 32):
        """Train the agent using the DDQL algorithm and the experience replay buffer.

        Parameters
        ----------
        max_steps : int, optional
            The maximum number of steps to train the agent for. Once the number of steps taken
            during training exceeds this value, the training process stops.
        batch_size : int, optional
            The number of experiences to sample from the experience replay buffer at each
            training step.
        """
        if not isinstance(self.agent, DDQL):
            raise TypeError("The agent must be an instance of the DDQL class.")

        if self.exp_replay.is_empty:
            # Training needs experiences to sample from: warn and pre-fill the buffer.
            warnings.warn(ExperienceReplayEmptyWarning())
            self.fill_exp_replay(max_steps=max_steps)

        p_bar = tqdm(total=max_steps, desc="Training agent") if self.verbose else None

        n_steps = 0
        while n_steps < max_steps:
            # Play a randomly chosen historical episode with the agent's own actions.
            episode = random.randint(0, len(self.env.historical_data_series) - 1)
            self.env.swap_episode(episode)

            while not self.env.done:
                current_state = self.env.state.copy()

                action = self.agent(current_state)
                reward = self.env.step(action)
                distance2horizon = self.env.horizon - self.env.state["period"]

                self.exp_replay.push(
                    current_state,
                    action,
                    reward,
                    self.env.state.copy(),
                    distance2horizon,
                )

                n_steps += 1
                if self.verbose and p_bar is not None:
                    p_bar.update(1)

                # Update the agent on a batch sampled from the buffer at each training step.
                self.agent.learn(self.exp_replay.get_sample(batch_size))
    def test(self, max_steps: int = 1000): ...
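A minimal usage sketch of the class above is given below. The calls to Trainer, fill_exp_replay, pretrain and train follow the signatures defined in this module; the zero-argument constructions of MarketEnvironnement and DDQL are assumptions for illustration (both classes may require data or network arguments in practice), and the public import path ddql_optimal_execution.trainer is assumed to re-export Trainer.

from ddql_optimal_execution.agent import DDQL
from ddql_optimal_execution.environnement import MarketEnvironnement
from ddql_optimal_execution.trainer import Trainer  # assumed public import path

env = MarketEnvironnement()   # assumed: may require historical market data arguments
agent = DDQL()                # assumed: may require network/hyperparameter arguments

trainer = Trainer(agent, env, capacity=10_000, verbose=True)

# Fill the replay buffer with "border" episodes (sell everything at the first
# or last period), pretrain on those experiences, then train the agent on its
# own actions.
trainer.fill_exp_replay(max_steps=10_000)
trainer.pretrain(max_steps=1_000, batch_size=32)
trainer.train(max_steps=1_000, batch_size=32)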