Soft Actor-Critic (SAC)
API for the SAC algorithm implemented in stable_baselines3.
Original paper: https://arxiv.org/abs/1801.01290
OpenAI Spinning Up reference: https://spinningup.openai.com/en/latest/algorithms/sac.html
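For reference, the objective SAC maximizes (from the paper above) is the expected return augmented with a policy-entropy term:

```latex
J(\pi) = \sum_{t=0}^{T} \mathbb{E}_{(s_t, a_t) \sim \rho_\pi}
\left[ r(s_t, a_t) + \alpha \, \mathcal{H}\big(\pi(\cdot \mid s_t)\big) \right]
```

Here alpha is the entropy temperature, which corresponds to `ent_coef` in the parameters below; with `ent_coef: auto`, stable_baselines3 tunes it automatically toward `target_entropy`.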
YAML parameters template
model_params:
  # Training parameters
  training_steps: 5000 # Number of training steps to perform
  # Saving parameters
  save_freq: 1000
  save_prefix: sac_model
  trained_model_name: trained_model
  save_replay_buffer: False
  # Model loading parameters
  load_model: False
  model_name: sac_model_5000_steps
  # Logging parameters
  log_folder: SAC_1
  log_interval: 4 # Number of episodes between logs
  reset_num_timesteps: False # If True, reset the number of timesteps to 0 at every training run
  # Custom policy - only MlpPolicy is supported (used only when a new model is created)
  use_custom_policy: False
  policy_params:
    net_arch: [400, 300] # List of hidden layer sizes
    activation_fn: relu # relu, tanh, elu or selu
    features_extractor_class: FlattenExtractor # FlattenExtractor, BaseFeaturesExtractor or CombinedExtractor
    optimizer_class: Adam # Adam, Adadelta, Adagrad, RMSprop or SGD
  # Action noise
  use_action_noise: True # Currently only Gaussian noise is supported
  action_noise:
    mean: 0.0
    sigma: 0.01
  # State-Dependent Exploration (SDE)
  use_sde: True
  sde_params:
    sde_sample_freq: -1
    use_sde_at_warmup: False
  # SAC parameters
  sac_params:
    learning_rate: 0.0003
    buffer_size: 1000000
    learning_starts: 100
    batch_size: 256
    tau: 0.005
    gamma: 0.99
    gradient_steps: 1
    ent_coef: auto
    target_update_interval: 1
    target_entropy: auto
    train_freq:
      freq: 20
      unit: step # episode or step
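A minimal sketch of how the `sac_params` block maps onto keyword arguments for the underlying stable_baselines3 `SAC` constructor, assuming the template above has already been parsed into a nested dict (e.g. with PyYAML or read from the ROS parameter server). The `build_sac_kwargs` helper is hypothetical, not part of frobs_rl:

```python
def build_sac_kwargs(model_params):
    """Map the 'sac_params' section of the YAML template onto
    stable_baselines3 SAC keyword arguments (hypothetical helper)."""
    sac = dict(model_params["sac_params"])
    # stable_baselines3 expects train_freq as an int or a (freq, unit)
    # tuple, where unit is "step" or "episode".
    tf = sac.pop("train_freq")
    sac["train_freq"] = (tf["freq"], tf["unit"])
    return sac

config = {
    "sac_params": {
        "learning_rate": 0.0003,
        "buffer_size": 1_000_000,
        "batch_size": 256,
        "tau": 0.005,
        "gamma": 0.99,
        "ent_coef": "auto",
        "train_freq": {"freq": 20, "unit": "step"},
    }
}

kwargs = build_sac_kwargs(config)
print(kwargs["train_freq"])  # (20, 'step')
```

The resulting dict could then be splatted into the constructor, e.g. `SAC("MlpPolicy", env, **kwargs)`.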
Class docs
- class sac.SAC(env, save_model_path, log_path, load_trained=False, config_file_pkg='frobs_rl', config_filename='sac_config.yaml', ns='/')[source]
Soft Actor-Critic (SAC) algorithm.
Paper: https://arxiv.org/abs/1801.01290
- Parameters
env – The environment to be used.
save_model_path – The path to save the model.
log_path – The path to save the log.
load_trained – Whether to load a trained model or not.
config_file_pkg – The package where the config file is located. Default: frobs_rl.
config_filename – The name of the config file. Default: sac_config.yaml.
ns – The namespace of the ROS parameters. Default: “/”.
- check_env() bool
Use the stable_baselines3 check_env method to check the environment.
- Returns
True if the environment is correct, False otherwise.
- Return type
bool
- close_env() bool
Use the env close method to close the environment.
- Returns
True if the environment was closed, False otherwise.
- Return type
bool
- predict(observation, state=None, mask=None, deterministic=False)
Get the current action based on the observation, state, or mask.
- Parameters
observation (ndarray) – The environment observation.
state (ndarray) – The previous states of the environment, used in recurrent policies.
mask (ndarray) – The mask of the last states, used in recurrent policies.
deterministic (bool) – Whether or not to return deterministic actions.
- Returns
The action to be taken and the next state (for recurrent policies).
- Return type
ndarray, ndarray
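For intuition on the `deterministic` flag: SAC's policy is a tanh-squashed Gaussian, so `deterministic=True` returns the squashed distribution mean while `deterministic=False` samples around it. A toy single-dimension illustration of that behavior, not the library implementation:

```python
import math
import random

def squashed_action(mean, log_std, deterministic=False, rng=random):
    """Toy version of SAC's tanh-squashed Gaussian action selection."""
    if deterministic:
        pre_tanh = mean  # use the distribution mean
    else:
        pre_tanh = rng.gauss(mean, math.exp(log_std))  # sample around it
    return math.tanh(pre_tanh)  # squash into [-1, 1]

a_det = squashed_action(0.3, math.log(0.2), deterministic=True)
print(round(a_det, 4))  # 0.2913, i.e. tanh(0.3)
```

The squashing is why SAC actions are always bounded in [-1, 1] before being rescaled to the environment's action space.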
- save_model() bool
Function to save the model.
- Returns
True if the model was saved, False otherwise.
- Return type
bool
- save_replay_buffer() bool
Function to save the replay buffer. Training must be finished before calling this method, or an error will be raised.
- Returns
True if the replay buffer was saved, False otherwise.
- Return type
bool
- set_model_logger() bool
Function to set the logger of the model.
- Returns
True if the logger was set, False otherwise.
- Return type
bool
- train() bool
Function to train the model for the number of steps specified in the ROS parameter server. The model is automatically saved after training.
- Returns
True if the training was successful, False otherwise.
- Return type
bool
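During each gradient step, SAC softly updates its target critic by Polyak averaging, controlled by the `tau` parameter from `sac_params` (0.005 in the template above). A pure-Python sketch of that update rule, illustrative rather than the library code:

```python
def polyak_update(params, target_params, tau):
    """Soft target-network update: target <- tau * online + (1 - tau) * target."""
    return [tau * p + (1.0 - tau) * tp for p, tp in zip(params, target_params)]

online = [1.0, 2.0]   # stand-ins for online critic parameters
target = [0.0, 0.0]   # stand-ins for target critic parameters
target = polyak_update(online, target, tau=0.005)
print(target)  # [0.005, 0.01]
```

With a small `tau`, the target network trails the online network slowly, which stabilizes the bootstrapped Q-value targets; `target_update_interval` controls how often this update runs.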