Using trained models

After the user has trained a model, they can use it without an algorithm YAML parameter file and without the processes that only matter during training, such as saving the replay buffer or calculating the loss function.

To use a trained model, the user only needs to import the class of the model's algorithm from the FRobs_RL library and call its load_trained function. Once the trained model is loaded, the predict function returns the action for a given observation.

An example in which a trained TD3 model is loaded and used for two episodes is shown below.

from kobuki_maze_rl.task_env import kobuki_maze
from frobs_rl.common import ros_gazebo, ros_node
import gym
import rospy
import rospkg
import sys

# Import the same wrappers that were used during training
from frobs_rl.wrappers.NormalizeActionWrapper import NormalizeActionWrapper
from frobs_rl.wrappers.NormalizeObservWrapper import NormalizeObservWrapper
from frobs_rl.wrappers.TimeLimitWrapper import TimeLimitWrapper

# Import TD3 algorithm
from frobs_rl.models.td3 import TD3

if __name__ == '__main__':
    # Kill all processes related to previous runs
    ros_node.ros_kill_all_processes()

    # Launch Gazebo
    ros_gazebo.launch_Gazebo(paused=True, gui=False)

    # Start node
    rospy.logwarn("Start")
    rospy.init_node('kobuki_maze_train')

    # Launch the task environment
    env = gym.make('KobukiMazeEnv-v0')

    #--- Normalize action space (same wrapper as in training)
    env = NormalizeActionWrapper(env)

    #--- Normalize observation space (same wrapper as in training)
    env = NormalizeObservWrapper(env)

    #--- Set max steps
    env = TimeLimitWrapper(env, max_steps=15000)
    env.reset()

    #--- Set the path where the trained model was saved
    rospack = rospkg.RosPack()
    pkg_path = rospack.get_path("kobuki_maze_rl")
    save_path = pkg_path + "/models/dynamic/td3/"

    #--- Load the trained TD3 model (no YAML parameter file needed)
    model = TD3.load_trained(save_path + "trained_model")

    #--- Run two episodes with deterministic actions
    obs = env.reset()
    episodes = 2
    epi_count = 0
    while epi_count < episodes:
        action, _states = model.predict(obs, deterministic=True)
        obs, _, dones, info = env.step(action)
        if dones:
            epi_count += 1
            rospy.logwarn("Episode: " + str(epi_count))
            obs = env.reset()

    env.close()
    sys.exit()
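The evaluation loop in the example above can be factored into a reusable helper that also records the total reward of each episode. This is a hedged sketch that is not part of FRobs_RL: run_episodes, CountdownEnv and ConstantModel are hypothetical names, and the environment is assumed to follow the classic Gym API in which step() returns four values (obs, reward, done, info), as in the script.

```python
class CountdownEnv:
    """Hypothetical toy environment: each episode lasts `length` steps
    and every step yields a reward of 1.0 (classic 4-tuple Gym step API)."""

    def __init__(self, length=3):
        self.length = length
        self.t = 0

    def reset(self):
        self.t = 0
        return self.t

    def step(self, action):
        self.t += 1
        done = self.t >= self.length
        return self.t, 1.0, done, {}


class ConstantModel:
    """Hypothetical model whose predict() always returns action 0."""

    def predict(self, observation, deterministic=False):
        return 0, None


def run_episodes(model, env, episodes=2):
    """Run `episodes` episodes with deterministic actions and return
    the total reward collected in each one."""
    returns = []
    total = 0.0
    obs = env.reset()
    while len(returns) < episodes:
        action, _states = model.predict(obs, deterministic=True)
        obs, reward, done, _info = env.step(action)
        total += reward
        if done:
            # Episode finished: store its return and start a new one.
            returns.append(total)
            total = 0.0
            obs = env.reset()
    return returns


episode_returns = run_episodes(ConstantModel(), CountdownEnv(length=3), episodes=2)
print(episode_returns)  # [3.0, 3.0]
```

With the real environment, run_episodes(model, env, episodes=2) could stand in for the while loop of the script, assuming KobukiMazeEnv-v0 keeps the four-value step() signature used there.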

basic_model.BasicModel.predict(self, observation, state=None, mask=None, deterministic=False)

Get the current action based on the observation, state or mask

Parameters
  • observation (ndarray) – The environment observation.

  • state (ndarray) – The previous states of the environment, used in recurrent policies.

  • mask (ndarray) – The mask of the last states, used in recurrent policies.

  • deterministic (bool) – Whether or not to return deterministic actions.

Returns

The action to be taken and the next state (for recurrent policies)

Return type

ndarray, ndarray
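For recurrent policies the second return value matters: it has to be passed back as state on the next call, and mask tells the policy where an episode boundary occurred. In the TD3 example the second value is simply ignored (_states). The sketch below illustrates the calling pattern with a hypothetical ToyRecurrentPolicy that is not a FRobs_RL class; treating a true mask as "reset the hidden state" is an assumption about the mask semantics.

```python
import numpy as np


class ToyRecurrentPolicy:
    """Hypothetical recurrent policy: the hidden state is a running sum
    of observations and the action is the sign of that sum."""

    def predict(self, observation, state=None, mask=None, deterministic=False):
        obs = np.asarray(observation, dtype=float)
        if state is None:
            state = np.zeros_like(obs)
        # Assumption: a true mask means the previous step ended an episode,
        # so the hidden state is reset before being used.
        if mask is not None and np.any(mask):
            state = np.zeros_like(obs)
        new_state = state + obs
        action = np.sign(new_state)
        return action, new_state


policy = ToyRecurrentPolicy()
state, done = None, False
actions = []
# Each tuple is (observation, whether this step ends the episode).
for obs, episode_ended in [([1.0], False), ([-3.0], True), ([2.0], False)]:
    # Feed back the returned state and the done flag of the previous step.
    action, state = policy.predict(obs, state=state, mask=done)
    done = episode_ended
    actions.append(float(action[0]))
print(actions)  # [1.0, -1.0, 1.0]
```

Without the reset on the third call the running sum would be 0.0, so the action would be 0.0 instead of 1.0; this is why the state and mask must be threaded through consecutive calls.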