#!/bin/python3
import os
from datetime import datetime
from frobs_rl.models.utils import get_policy_kwargs, get_action_noise
# ROS packages required
import rospy
# SB3 Callbacks
from stable_baselines3.common.callbacks import CheckpointCallback
# Logger
from stable_baselines3.common.logger import configure
class BasicModel:
    """
    Base class for all the algorithms supported by the frobs_rl library.

    The class itself never creates ``self.model``; it is initialised to ``None``
    and must be assigned (e.g. by a subclass) before calling :meth:`train`,
    :meth:`save_model`, :meth:`save_replay_buffer`, :meth:`set_model_logger`
    or :meth:`predict`.

    All tunable values are read from the ROS parameter server under
    ``<ns>/model_params/...``.

    :param env: The environment to be used.
    :param save_model_path: The path to save the model. It is concatenated directly
        with the model name, so it is used as a plain string prefix.
    :param log_path: The path to save the log. It is concatenated directly with the
        log folder name.
    :param ns: The namespace of the parameters.
    :param load_trained: Whether or not to load a trained model.
    """

    def __init__(self, env, save_model_path, log_path, ns="/", load_trained=False) -> None:
        """
        BasicModel constructor.

        When ``load_trained`` is ``False`` the policy kwargs, the action noise and the
        checkpoint callback are built from the ROS parameter server. Otherwise they
        are left as ``None``.
        """
        self.env = env
        self.ns = ns
        self.save_model_path = save_model_path
        self.log_path = log_path
        self.save_trained_model_path = None
        self.model = None

        # Always define the training helpers so later attribute access (e.g. in
        # train()) never fails with AttributeError when a trained model is loaded.
        self.policy_kwargs = None
        self.action_noise = None
        self.checkpoint_callback = None

        # NOTE(review): the nesting of the statements below was reconstructed from a
        # listing without indentation; all three setup steps are assumed to belong to
        # the "not loading a trained model" branch -- confirm against the original.
        if load_trained is False:
            #--- Policy kwargs
            self.policy_kwargs = get_policy_kwargs(ns=ns)

            #--- Noise kwargs (sized by the last dimension of the action space shape)
            self.action_noise = get_action_noise(self.env.action_space.shape[-1], ns=ns)

            #--- Callback
            save_freq = rospy.get_param(ns + "/model_params/save_freq")
            save_prefix = rospy.get_param(ns + "/model_params/save_prefix")
            self.checkpoint_callback = CheckpointCallback(
                save_freq=save_freq,
                save_path=save_model_path,
                name_prefix=save_prefix,
            )

    def train(self) -> bool:
        """
        Function to train the model the number of steps specified in the ROS parameter server.
        The function will automatically save the model after training.

        Reads ``training_steps``, ``log_interval`` and ``reset_num_timesteps`` from
        ``<ns>/model_params``.

        :return: True if the training was successful. Failures surface as exceptions.
        :rtype: bool
        """
        training_steps = rospy.get_param(self.ns + "/model_params/training_steps")
        learn_log_int = rospy.get_param(self.ns + "/model_params/log_interval")
        learn_reset_num_tm = rospy.get_param(self.ns + "/model_params/reset_num_timesteps")

        # When the timestep counter is not reset, switch to the environment held by
        # the model and reset it before learning continues.
        if learn_reset_num_tm is False:
            self.env = self.model.get_env()
            self.env.reset()

        self.model.learn(
            total_timesteps=int(training_steps),
            callback=self.checkpoint_callback,
            log_interval=learn_log_int,
            reset_num_timesteps=learn_reset_num_tm,
        )
        self.save_model()
        return True

    def save_model(self) -> bool:
        """
        Function to save the model.

        The target is ``save_model_path + trained_model_name``. If a ``.zip`` file with
        that name already exists, a ``_<dd_mm_YYYY_HH_MM_SS>`` suffix is appended so the
        existing file is not overwritten. The replay buffer is saved afterwards through
        :meth:`save_replay_buffer`.

        :return: True if the model was saved. Failures surface as exceptions.
        :rtype: bool
        """
        #--- Model name
        trained_model_name = rospy.get_param(self.ns + "/model_params/trained_model_name")

        # If file exists, name the new model with a suffix
        self.save_trained_model_path = self.save_model_path + trained_model_name
        if os.path.isfile(self.save_trained_model_path + ".zip"):
            dt_string = datetime.now().strftime("%d_%m_%Y_%H_%M_%S")
            self.save_trained_model_path = self.save_trained_model_path + "_" + dt_string
            rospy.logwarn("Trained model name already exists, saving as: " + trained_model_name + "_" + dt_string)

        self.model.save(self.save_trained_model_path)
        self.save_replay_buffer()
        return True

    def save_replay_buffer(self) -> bool:
        """
        Function to save the replay buffer, to be used the training must be finished or an error will be raised.

        The buffer is only written when the ROS parameter
        ``<ns>/model_params/save_replay_buffer`` is truthy; it is stored at
        ``<saved model path>_replay_buffer``.

        :return: True if the replay buffer was saved, False otherwise.
        :rtype: bool
        :raises ValueError: If no model has been saved yet (the saved model path is unset).
        """
        if self.save_trained_model_path is None:
            raise ValueError("Model not trained yet, cannot save replay buffer")

        if not rospy.get_param(self.ns + "/model_params/save_replay_buffer"):
            return False

        rospy.logwarn("Saving replay buffer")
        self.model.save_replay_buffer(self.save_trained_model_path + '_replay_buffer')
        return True

    def set_model_logger(self) -> bool:
        """
        Function to set the logger of the model.

        The log directory is ``log_path + <ns>/model_params/log_folder`` and the logger
        writes to stdout, csv and tensorboard.

        :return: True if the logger was set. Failures surface as exceptions.
        :rtype: bool
        :raises AssertionError: If the log folder already exists.
        """
        log_folder = rospy.get_param(self.ns + "/model_params/log_folder")
        log_path = self.log_path + log_folder

        # Raised explicitly (instead of using an ``assert`` statement) so the check
        # is still enforced when Python runs with -O; the exception type is unchanged.
        if os.path.exists(log_path):
            raise AssertionError("Log folder already exists, to log into that folder first delete it.")

        new_logger = configure(log_path + '/', ["stdout", "csv", "tensorboard"])
        self.model.set_logger(new_logger)
        return True

    def close_env(self) -> bool:
        """
        Use the env close method to close the environment.

        :return: True if the environment was closed. Failures surface as exceptions.
        :rtype: bool
        """
        self.env.close()
        return True

    def check_env(self) -> bool:
        """
        Check the environment by calling the ``check_env`` method of the environment.

        :return: True if the check finished without raising.
        :rtype: bool
        """
        self.env.check_env()
        return True

    def predict(self, observation, state=None, mask=None, deterministic=False):
        """
        Get the current action based on the observation, state or mask

        :param observation: The environment observation
        :type observation: ndarray
        :param state: The previous states of the environment, used in recurrent policies.
        :type state: ndarray
        :param mask: The mask of the last states, used in recurrent policies.
        :type mask: ndarray
        :param deterministic: Whether or not to return deterministic actions.
        :type deterministic: bool
        :return: The action to be taken and the next state(for recurrent policies)
        :rtype: ndarray, ndarray
        """
        # NOTE(review): newer Stable-Baselines3 releases appear to name this keyword
        # ``episode_start`` instead of ``mask`` -- confirm against the pinned version.
        return self.model.predict(observation, state=state, mask=mask, deterministic=deterministic)