Source code for ddql_optimal_execution.state._state
import torch
class State(dict):
    """Dictionary holding the state of the environment.

    ``State`` is a ``dict`` subclass with three helpers:

    * :meth:`update_state` overwrites only keys that already exist.
    * :attr:`astensor` exposes the values as a float32 PyTorch tensor.
    * :meth:`copy` returns a ``State`` rather than a plain ``dict``.
    """

    def update_state(self, **kwargs):
        """Overwrite existing entries of the state with keyword arguments.

        Every key is validated before any value is written. A call that names
        an unknown key therefore raises without modifying the state.

        Parameters
        ----------
        **kwargs
            New values, keyed by entry name. Each key must already be present
            in the state.

        Raises
        ------
        KeyError
            If any key in ``kwargs`` is not already present in the state.
        """
        # Validate up front: raising halfway through the writes would leave
        # the state partially updated.
        for key in kwargs:
            if key not in self:
                raise KeyError(f"Key {key} not initialized in state")
        # Assign item by item (rather than dict.update) so any __setitem__
        # override in a subclass is still honoured.
        for key, value in kwargs.items():
            self[key] = value

    @property
    def astensor(self):
        """Return the state's values as a float32 PyTorch tensor.

        Values are taken in the dictionary's insertion order, wrapped in a
        list, passed to ``torch.Tensor`` and cast with ``.float()``.

        The explicit ``.float()`` guarantees float32 even if the global default
        dtype was changed (e.g. via ``torch.set_default_dtype``). With scalar
        values the result has one element per entry.

        Returns
        -------
        torch.Tensor
            Tensor of dtype ``torch.float32`` built from ``self.values()``.
        """
        return torch.Tensor(list(self.values())).float()

    def copy(self) -> "State":
        """Return a shallow copy of this state as a new ``State``.

        This overrides ``dict.copy``, which would return a plain ``dict``, so
        the copy keeps ``update_state``, ``astensor`` and ``copy``.

        Only the mapping is duplicated. Mutable values (e.g. tensors or lists)
        are shared between the original and the copy.

        Returns
        -------
        State
            A new ``State`` with the same key/value pairs.
        """
        return State(self)