Deep Q Network (DQN)

API for the DQN algorithm implemented using stable_baselines3.

Original paper: https://arxiv.org/abs/1312.5602
Nature paper: https://www.nature.com/articles/nature14236

YAML parameters template

model_params:

  # Training 
  training_steps: 5000      # The number of training steps to perform

  # Save params
  save_freq: 1000
  save_prefix: dqn_model
  trained_model_name: trained_model
  save_replay_buffer: False

  # Load model params
  load_model: False
  model_name: dqn_model_5000_steps

  # Logging parameters
  log_folder: DQN_1
  log_interval: 4 # The number of episodes between logs
  reset_num_timesteps: False # If True, reset the number of timesteps to 0 at every training run

  # Custom policy parameters - only MlpPolicy is supported (used only when a new model is created)
  use_custom_policy: False
  policy_params:
    net_arch: [400, 300] # List of hidden layer sizes
    activation_fn: relu  # relu, tanh, elu or selu
    features_extractor_class: FlattenExtractor # FlattenExtractor, BaseFeaturesExtractor or CombinedExtractor
    optimizer_class: Adam # Adam, Adadelta, Adagrad, RMSprop or SGD

  # DQN parameters
  dqn_params:
    learning_rate: 0.0001
    buffer_size: 1000000
    learning_starts: 50000 
    batch_size: 32
    tau: 1.0
    gamma: 0.99
    gradient_steps: 1
    target_update_interval: 10000 
    exploration_fraction: 0.1
    exploration_initial_eps: 1.0
    exploration_final_eps: 0.05 
    max_grad_norm: 10
    train_freq:
      freq: 20
      unit: step  # episode or step
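
The dqn_params block roughly mirrors the keyword arguments of the underlying stable_baselines3 DQN constructor. The sketch below shows that correspondence for the values in the template above; how frobs_rl forwards them internally is an assumption, and env is a placeholder for your environment.

# Sketch (assumption): rough stable_baselines3 equivalent of the dqn_params block above.
from stable_baselines3 import DQN as SB3_DQN

sb3_model = SB3_DQN(
    "MlpPolicy",                   # only MlpPolicy is supported by the wrapper
    env,                           # placeholder environment
    learning_rate=0.0001,
    buffer_size=1000000,
    learning_starts=50000,
    batch_size=32,
    tau=1.0,
    gamma=0.99,
    gradient_steps=1,
    target_update_interval=10000,
    exploration_fraction=0.1,
    exploration_initial_eps=1.0,
    exploration_final_eps=0.05,
    max_grad_norm=10,
    train_freq=(20, "step"),       # freq/unit pair from the template
)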

Class docs

class dqn.DQN(env, save_model_path, log_path, load_trained=False, config_file_pkg='frobs_rl', config_filename='dqn_config.yaml', ns='/')[source]

Deep Q Network (DQN) algorithm.

Paper: https://arxiv.org/abs/1312.5602

Parameters
  • env – The environment to be used.

  • save_model_path – The path to save the model.

  • log_path – The path to save the log.

  • load_trained – Whether to load a trained model.

  • config_file_pkg – The package where the config file is located. Default: frobs_rl.

  • config_filename – The name of the config file. Default: dqn_config.yaml.

  • ns – The namespace of the ROS parameters. Default: “/”.
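
A minimal construction sketch follows; the import path of the class, the environment id, and the file paths are assumptions, and a ROS node is expected to be running before the wrapper is created.

# Hypothetical setup; adjust the import path, environment id, and paths to your project.
import gym
import rospy
from frobs_rl.models import dqn   # assumed import path

rospy.init_node("dqn_training_node")

env = gym.make("MyRobotEnv-v0")   # placeholder task environment
model = dqn.DQN(env,
                save_model_path="/tmp/dqn_models/",
                log_path="/tmp/dqn_logs/",
                config_file_pkg="frobs_rl",
                config_filename="dqn_config.yaml")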

check_env() bool

Use the stable-baselines check_env method to check the environment.

Returns

True if the environment is correct, False otherwise.

Return type

bool
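
For example, the environment can be validated once before starting a long training run (a small usage sketch, continuing from the construction example above):

# Sketch: validate the environment before training.
if model.check_env():
    model.train()
else:
    rospy.logerr("Environment failed the stable-baselines check")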

close_env() bool

Use the env close method to close the environment.

Returns

True if the environment was closed, False otherwise.

Return type

bool

predict(observation, state=None, mask=None, deterministic=False)

Get the current action based on the observation, state, or mask.

Parameters
  • observation (ndarray) – The environment observation.

  • state (ndarray) – The previous states of the environment, used in recurrent policies.

  • mask (ndarray) – The mask of the last states, used in recurrent policies.

  • deterministic (bool) – Whether or not to return deterministic actions.

Returns

The action to be taken and the next state (for recurrent policies).

Return type

ndarray, ndarray
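
A hedged evaluation-loop sketch, continuing from the construction example above; the reset/step signature below follows the classic gym API, which is an assumption about the wrapped environment.

# Sketch: run the current policy for one episode using predict().
obs = env.reset()
done = False
while not done:
    action, _state = model.predict(obs, deterministic=True)
    obs, reward, done, info = env.step(action)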

save_model() bool

Function to save the model.

Returns

True if the model was saved, False otherwise.

Return type

bool

save_replay_buffer() bool

Function to save the replay buffer. Training must be finished before calling this method, otherwise an error will be raised.

Returns

True if the replay buffer was saved, False otherwise.

Return type

bool

set_model_logger() bool

Function to set the logger of the model.

Returns

True if the logger was set, False otherwise.

Return type

bool

train() bool

Function to train the model for the number of steps specified in the ROS parameter server. The model is automatically saved after training.

Returns

True if the training was successful, False otherwise.

Return type

bool
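
A typical end-to-end flow, sketched only from the methods documented here and continuing from the construction example above:

# Sketch: train for the configured number of steps, then clean up.
model.train()               # trains for training_steps and saves the model automatically
model.save_replay_buffer()  # optional; only valid once training has finished
model.close_env()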

load_trained(model_path, env=None)[source]

Load a trained model. Use it only with the predict function, as logs will not be saved.

Parameters
  • model_path (str) – The path to the trained model.

  • env (gym.Env) – The environment to be used.

Returns

The trained model.

Return type

frobs_rl.DQN
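
A loading sketch for inference only; the model path is a placeholder, and calling load_trained directly on the class assumes it behaves as a static method, as the signature suggests.

# Sketch: load a trained model and use it for prediction only.
trained_model = dqn.DQN.load_trained("/tmp/dqn_models/trained_model", env=env)
obs = env.reset()
action, _ = trained_model.predict(obs, deterministic=True)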