Proximal Policy Optimization (PPO)
API for the PPO algorithm, as implemented in Stable-Baselines3 (stable_baselines3).
Original paper: https://arxiv.org/abs/1707.06347
OpenAI blog post: https://blog.openai.com/openai-baselines-ppo/
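For reference, here is the clipped surrogate objective from the original paper, in the paper's notation. The reward is written R_t here because the paper's r_t would clash with the probability ratio r_t(θ). Several symbols appear to correspond to keys in the YAML template: ε to clip_range, γ and λ to gamma and gae_lambda, c_1 and c_2 to vf_coef and ent_coef, and T to the rollout length n_steps. That mapping is a reading of the Stable-Baselines3 parameter names, not something this page states.

```latex
r_t(\theta) = \frac{\pi_\theta(a_t \mid s_t)}{\pi_{\theta_{\text{old}}}(a_t \mid s_t)}

L^{\text{CLIP}}_t(\theta) = \min\!\left(r_t(\theta)\,\hat{A}_t,\;
  \operatorname{clip}\!\left(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon\right)\hat{A}_t\right),
\qquad L^{\text{CLIP}}(\theta) = \hat{\mathbb{E}}_t\!\left[L^{\text{CLIP}}_t(\theta)\right]

L^{\text{CLIP+VF+S}}(\theta) = \hat{\mathbb{E}}_t\!\left[L^{\text{CLIP}}_t(\theta)
  - c_1 \left(V_\theta(s_t) - V_t^{\text{targ}}\right)^2 + c_2\, S[\pi_\theta](s_t)\right]

\delta_t = R_t + \gamma V(s_{t+1}) - V(s_t),
\qquad \hat{A}_t = \sum_{l=0}^{T-t-1} (\gamma\lambda)^l\, \delta_{t+l}
```

The clip term removes any incentive to move the ratio r_t(θ) outside [1 − ε, 1 + ε] within one update. This is why several epochs (n_epochs) can safely reuse the same rollout.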
YAML parameters template
model_params:
  # Training
  training_steps: 5000  # The number of training steps to perform
  # Save params
  save_freq: 1000
  save_prefix: ppo_model
  trained_model_name: trained_model
  save_replay_buffer: False  # PPO does not support saving a replay buffer
  # Load model params
  load_model: False
  model_name: ppo_model_5000_steps
  # Logging parameters
  log_folder: PPO_1
  log_interval: 4  # The number of training iterations (rollout/update cycles) between logs
  reset_num_timesteps: False  # If True, the timestep counter is reset to 0 every time training is run
  # Use custom policy - only MlpPolicy is supported (only used when a new model is created)
  use_custom_policy: False
  policy_params:
    net_arch: [400, 300]  # List of hidden layer sizes
    activation_fn: relu  # relu, tanh, elu or selu
    features_extractor_class: FlattenExtractor  # FlattenExtractor, BaseFeaturesExtractor or CombinedExtractor
    optimizer_class: Adam  # Adam, Adadelta, Adagrad, RMSprop or SGD
  # Use gSDE (generalized State-Dependent Exploration)
  use_sde: False
  sde_params:
    sde_sample_freq: -1  # -1: sample the exploration noise only at the start of each rollout
  # PPO parameters
  ppo_params:
    learning_rate: 0.0003
    n_steps: 100  # Steps to run per environment per update (the rollout buffer holds n_steps * n_envs transitions, where n_envs is the number of environment copies running in parallel)
    batch_size: 100  # Minibatch size
    n_epochs: 5  # Number of epochs when optimizing the surrogate loss
    gamma: 0.99
    gae_lambda: 0.95
    clip_range: 0.2
    ent_coef: 0.0
    vf_coef: 0.5
    max_grad_norm: 0.5
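The n_steps comment implies some arithmetic about how much work one training run does. The sketch below estimates, for a single call to train(), how many rollout/update iterations PPO runs, how many gradient steps each update takes, and how many times it logs. It rests on two assumptions. First, train() passes training_steps to Stable-Baselines3's learn() as total_timesteps on a freshly created model. Second, the rollout and logging loops follow Stable-Baselines3's on-policy implementation. ppo_schedule and PPOSchedule are hypothetical helpers, not part of frobs_rl.

```python
import math
from dataclasses import dataclass


@dataclass
class PPOSchedule:
    """Derived quantities for one training run (hypothetical helper type)."""
    rollout_size: int               # transitions per rollout buffer: n_steps * n_envs
    iterations: int                 # rollout/update cycles needed to reach training_steps
    timesteps_collected: int        # may overshoot training_steps by up to one rollout
    minibatches_per_epoch: int      # ceil(rollout_size / batch_size); last batch may be partial
    gradient_steps_per_update: int  # n_epochs * minibatches_per_epoch
    log_events: int                 # iterations where iteration % log_interval == 0


def ppo_schedule(training_steps: int, n_steps: int, batch_size: int,
                 n_epochs: int, log_interval: int, n_envs: int = 1) -> PPOSchedule:
    """Estimate how the ppo_params of the YAML template play out (assumed SB3 semantics)."""
    rollout_size = n_steps * n_envs
    # Whole rollouts are collected until at least training_steps timesteps have been seen.
    iterations = math.ceil(training_steps / rollout_size)
    # Each epoch sweeps the whole buffer once in minibatches of batch_size.
    minibatches = math.ceil(rollout_size / batch_size)
    return PPOSchedule(
        rollout_size=rollout_size,
        iterations=iterations,
        timesteps_collected=iterations * rollout_size,
        minibatches_per_epoch=minibatches,
        gradient_steps_per_update=n_epochs * minibatches,
        # Iterations are counted from 1, so logs happen at log_interval, 2 * log_interval, ...
        log_events=iterations // log_interval,
    )


# Values from the template, with a single environment copy.
sched = ppo_schedule(training_steps=5000, n_steps=100, batch_size=100,
                     n_epochs=5, log_interval=4)
print(sched.iterations, sched.gradient_steps_per_update, sched.log_events)  # 50 5 12
```

With the default values, the rollout buffer holds exactly one minibatch, so each update is n_epochs = 5 gradient steps on the same 100 transitions. A batch_size that does not divide n_steps * n_envs leaves a smaller final minibatch, and Stable-Baselines3 warns about this.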
Class docs
- class ppo.PPO(env, save_model_path, log_path, load_trained=False, config_file_pkg='frobs_rl', config_filename='ppo_config.yaml', ns='/')
Proximal Policy Optimization (PPO) algorithm.
Paper: https://arxiv.org/abs/1707.06347
- Parameters
env – The environment to be used.
save_model_path – The path to save the model.
log_path – The path to save the log.
load_trained – Whether to load a trained model or not.
config_file_pkg – The package where the config file is located. Default: frobs_rl.
config_filename – The name of the config file. Default: ppo_config.yaml.
ns – The namespace of the ROS parameters. Default: “/”.
- check_env() → bool
Use the Stable-Baselines3 check_env method to check that the environment follows the Gym API.
- Returns
True if the environment is correct, False otherwise.
- Return type
bool
- close_env() → bool
Use the env close method to close the environment.
- Returns
True if the environment was closed, False otherwise.
- Return type
bool
- predict(observation, state=None, mask=None, deterministic=False)
Get the action for the given observation (and, for recurrent policies, the state and mask).
- Parameters
observation (ndarray) – The environment observation.
state (ndarray) – The last hidden state of the policy, used in recurrent policies.
mask (ndarray) – The mask of the last states, used in recurrent policies.
deterministic (bool) – Whether or not to return deterministic actions.
- Returns
The action to be taken and the next hidden state (for recurrent policies).
- Return type
ndarray, ndarray
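The deterministic flag decides between two behaviours. Either an action is sampled from the policy's action distribution, or the most likely action (the distribution's mode) is returned. The sketch below shows this for a discrete action space, in plain Python rather than through the library. softmax and select_action are hypothetical helpers, and the logits stand in for the policy network's output. For continuous (Gaussian) policies, the deterministic action would be the distribution's mean rather than an argmax.

```python
import math
import random
from typing import List, Optional, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax: turns policy logits into action probabilities."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def select_action(logits: Sequence[float], deterministic: bool = False,
                  rng: Optional[random.Random] = None) -> int:
    """Hypothetical helper mirroring the meaning of predict(..., deterministic=...).

    deterministic=True  -> return the mode (most probable action).
    deterministic=False -> sample an action from the categorical distribution.
    """
    probs = softmax(logits)
    if deterministic:
        return max(range(len(probs)), key=probs.__getitem__)
    if rng is None:
        rng = random.Random()
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]


logits = [0.1, 2.0, -1.0]
print(select_action(logits, deterministic=True))  # 1
print(select_action(logits, rng=random.Random(0)))  # usually 1, but may be 0 or 2
```

Deterministic actions are the usual choice when evaluating a trained policy. Stochastic actions keep exploring, which is what PPO relies on during training.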
- save_model() → bool
Function to save the model.
- Returns
True if the model was saved, False otherwise.
- Return type
bool
- save_replay_buffer() → bool
Function to save the replay buffer. It can only be called after training has finished; otherwise an error is raised.
- Returns
True if the replay buffer was saved, False otherwise.
- Return type
bool
- set_model_logger() → bool
Function to set the logger of the model.
- Returns
True if the logger was set, False otherwise.
- Return type
bool
- train() → bool
Function to train the model for the number of steps specified on the ROS parameter server. The model is saved automatically after training.
- Returns
True if the training was successful, False otherwise.
- Return type
bool
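The template's save_freq, save_prefix and default model_name (ppo_model_5000_steps) suggest how checkpoints are named. They appear to be named after the prefix and the step count, in the style of Stable-Baselines3's CheckpointCallback. The sketch below lists the checkpoint names one training run would produce under three assumptions:

- That naming convention holds.
- There is a single environment, so save_freq counts timesteps.
- training_steps is a multiple of n_steps.

These are file stems; Stable-Baselines3 appends a .zip extension when saving. Any of them could presumably be set as model_name when load_model is True. checkpoint_names is a hypothetical helper, not part of frobs_rl.

```python
from typing import List


def checkpoint_names(training_steps: int, save_freq: int, save_prefix: str) -> List[str]:
    """Hypothetical helper: expected checkpoint stems, assuming the
    '{prefix}_{steps}_steps' naming of SB3's CheckpointCallback and n_envs=1."""
    return [f"{save_prefix}_{step}_steps"
            for step in range(save_freq, training_steps + 1, save_freq)]


# Defaults from the YAML template: training_steps=5000, save_freq=1000, save_prefix=ppo_model.
for name in checkpoint_names(5000, 1000, "ppo_model"):
    print(name)  # ppo_model_1000_steps ... ppo_model_5000_steps
```

Under these assumptions, the last checkpoint coincides with the template's default model_name. The final model saved at the end of train() would be named separately (trained_model_name).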