Drivers


Introduction

A common pattern in reinforcement learning is to execute a policy in an environment for a specified number of steps or episodes. This happens, for example, during data collection, during evaluation, and when generating a video of the agent.

While this is relatively straightforward to write in Python, it is much more complex to write and debug in TensorFlow because it involves tf.while_loop, tf.cond, and tf.control_dependencies. Therefore, we abstract this notion of a run loop into a class called a driver, and provide well-tested implementations in both Python and TensorFlow.

Additionally, the data encountered by the driver at each step is saved in a named tuple called Trajectory and broadcast to a set of observers, such as replay buffers and metrics. This data includes the observation from the environment, the action recommended by the policy, any extra information emitted by the policy, the reward obtained, the discount, and the step types of the current and next steps.
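Here is a minimal sketch of that data flow that uses only the Python standard library. `ToyTrajectory`, `StepCounter`, and `broadcast` are hypothetical names. `ToyTrajectory` only copies the seven field names that appear in the printed trajectories later in this tutorial; it is not the TF-Agents Trajectory class. The sketch shows that an observer is just a callable that receives each trajectory. That is why a list's append method can serve as a simple replay buffer.

```python
from collections import namedtuple

# Hypothetical stand-in that copies the field names of TF-Agents' Trajectory.
ToyTrajectory = namedtuple(
    "ToyTrajectory",
    ["step_type", "observation", "action", "policy_info",
     "next_step_type", "reward", "discount"])

FIRST, MID, LAST = 0, 1, 2  # Step type codes, matching the printed output.


class StepCounter:
  """Hypothetical observer: counts every trajectory it receives."""

  def __init__(self):
    self.count = 0

  def __call__(self, traj):
    self.count += 1


def broadcast(traj, observers):
  # The same trajectory is handed to every observer in turn.
  for observer in observers:
    observer(traj)


replay_buffer = []  # A plain list acts as an observer through its append method.
counter = StepCounter()
observers = [replay_buffer.append, counter]

traj = ToyTrajectory(FIRST, [0.0, 0.0, 0.0, 0.0], 1, (), MID, 1.0, 1.0)
broadcast(traj, observers)
```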

Setup

If you haven't installed tf-agents or gym yet, run:

pip install tf-agents
pip install tf-keras
# __future__ imports must come first in a module; on Python 3 they have no effect.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
# Keep using keras-2 (tf-keras) rather than keras-3 (keras).
# Set this before TensorFlow is imported.
os.environ['TF_USE_LEGACY_KERAS'] = '1'
import tensorflow as tf
from tf_agents.environments import suite_gym
from tf_agents.environments import tf_py_environment
from tf_agents.policies import random_py_policy
from tf_agents.policies import random_tf_policy
from tf_agents.metrics import py_metrics
from tf_agents.metrics import tf_metrics
from tf_agents.drivers import py_driver
from tf_agents.drivers import dynamic_episode_driver

Python Drivers

The PyDriver class takes a Python environment, a Python policy, and a list of observers to update at each step. Its main method is run(), which steps the environment using actions from the policy until at least one of the following termination criteria is met: the number of steps reaches max_steps, or the number of episodes reaches max_episodes.
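The termination rule can be isolated as a small predicate. `should_continue` is a hypothetical helper. It is written to match the `max_steps or np.inf` idiom in the implementation sketch that follows, where a limit of None (or 0) means "no limit". It uses math.inf so that the example needs only the standard library.

```python
import math


def should_continue(num_steps, num_episodes, max_steps=None, max_episodes=None):
  """Hypothetical helper: True while neither the step nor the episode limit is reached.

  A limit of None (or 0) is treated as "no limit", like `max_steps or np.inf`.
  """
  step_limit = max_steps or math.inf
  episode_limit = max_episodes or math.inf
  return num_steps < step_limit and num_episodes < episode_limit


# Using the limits from the CartPole example (max_steps=20, max_episodes=1):
print(should_continue(5, 0, max_steps=20, max_episodes=1))   # True
print(should_continue(20, 0, max_steps=20, max_episodes=1))  # False: step limit reached
print(should_continue(18, 1, max_steps=20, max_episodes=1))  # False: episode limit reached
```

Because the two checks are joined with `and`, the loop stops as soon as either limit is hit.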

The implementation is roughly as follows:

import numpy as np
from tf_agents.trajectories import trajectory
class PyDriver(object):
  def __init__(self, env, policy, observers, max_steps=1, max_episodes=1):
    self._env = env
    self._policy = policy
    self._observers = observers or []
    # A limit of None (or 0) means "no limit".
    self._max_steps = max_steps or np.inf
    self._max_episodes = max_episodes or np.inf
  def run(self, time_step, policy_state=()):
    num_steps = 0
    num_episodes = 0
    while num_steps < self._max_steps and num_episodes < self._max_episodes:
      # Compute an action using the policy for the given time_step
      action_step = self._policy.action(time_step, policy_state)
      # Apply the action to the environment and get the next step
      next_time_step = self._env.step(action_step.action)
      # Package information into a trajectory
      traj = trajectory.Trajectory(
          time_step.step_type,
          time_step.observation,
          action_step.action,
          action_step.info,
          next_time_step.step_type,
          next_time_step.reward,
          next_time_step.discount)
      # Broadcast the trajectory to every observer
      for observer in self._observers:
        observer(traj)
      # Update statistics to check termination
      num_episodes += np.sum(traj.is_last())
      num_steps += np.sum(~traj.is_boundary())
      time_step = next_time_step
      policy_state = action_step.state
    return time_step, policy_state
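To see why the sketch counts steps with `~traj.is_boundary()` but episodes with `traj.is_last()`, here is a pure-Python toy that replays the same bookkeeping. `ToyEnv` and `run_toy_driver` are hypothetical, and a constant action stands in for the policy. `ToyEnv` assumes an auto-resetting environment, so calling step() after a LAST time step returns a FIRST time step. This produces a LAST to FIRST boundary transition of the same form as the final trajectory printed in the CartPole output. The codes 0, 1, and 2 stand for FIRST, MID, and LAST, as in that output.

```python
from collections import namedtuple

FIRST, MID, LAST = 0, 1, 2  # Step type codes, as in the printed output.

ToyTimeStep = namedtuple("ToyTimeStep", ["step_type", "reward", "observation"])


class ToyEnv:
  """Hypothetical auto-resetting environment whose episodes last `length` steps."""

  def __init__(self, length=3):
    self.length = length
    self.t = 0
    self.done = False

  def reset(self):
    self.t, self.done = 0, False
    return ToyTimeStep(FIRST, 0.0, self.t)

  def step(self, action):
    if self.done:  # Stepping after LAST starts a new episode; the action is ignored.
      return self.reset()
    self.t += 1
    self.done = self.t == self.length
    return ToyTimeStep(LAST if self.done else MID, 1.0, self.t)


def run_toy_driver(env, time_step, max_steps, max_episodes):
  """Hypothetical loop using the same counting rules as the PyDriver sketch."""
  num_steps = num_episodes = 0
  transitions = []
  while num_steps < max_steps and num_episodes < max_episodes:
    next_time_step = env.step(action=0)  # Constant action instead of a policy.
    transitions.append((time_step.step_type, next_time_step.step_type))
    num_episodes += next_time_step.step_type == LAST  # like traj.is_last()
    num_steps += time_step.step_type != LAST          # like ~traj.is_boundary()
    time_step = next_time_step
  return time_step, num_steps, num_episodes, transitions


env = ToyEnv(length=3)
_, steps, episodes, transitions = run_toy_driver(
    env, env.reset(), max_steps=100, max_episodes=2)
print(steps, episodes)  # 6 2
print(transitions)      # [(0, 1), (1, 1), (1, 2), (2, 0), (0, 1), (1, 1), (1, 2)]
```

With two episodes of length 3, the loop records seven transitions but counts only six steps. The boundary transition between the episodes is still passed along, but it is not counted as an environment step.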

Now let's walk through an example: running a random policy on the CartPole environment, saving the results to a replay buffer (here, a plain Python list), and computing the average return.

env = suite_gym.load('CartPole-v0')
policy = random_py_policy.RandomPyPolicy(time_step_spec=env.time_step_spec(),
                                         action_spec=env.action_spec())
replay_buffer = []
metric = py_metrics.AverageReturnMetric()
observers = [replay_buffer.append, metric]
driver = py_driver.PyDriver(
    env, policy, observers, max_steps=20, max_episodes=1)
initial_time_step = env.reset()
final_time_step, _ = driver.run(initial_time_step)
print('Replay Buffer:')
for traj in replay_buffer:
  print(traj)
print('Average Return: ', metric.result())
Replay Buffer:
Trajectory(
{'step_type': array(0, dtype=int32),
 'observation': array([ 0.00374074, -0.02818722, -0.02798625, -0.0196638 ], dtype=float32),
 'action': array(1),
 'policy_info': (),
 'next_step_type': array(1, dtype=int32),
 'reward': array(1., dtype=float32),
 'discount': array(1., dtype=float32)})
Trajectory(
{'step_type': array(1, dtype=int32),
 'observation': array([ 0.00317699, 0.16732468, -0.02837953, -0.3210437 ], dtype=float32),
 'action': array(0),
 'policy_info': (),
 'next_step_type': array(1, dtype=int32),
 'reward': array(1., dtype=float32),
 'discount': array(1., dtype=float32)})
Trajectory(
{'step_type': array(1, dtype=int32),
 'observation': array([ 0.00652349, -0.02738187, -0.0348004 , -0.03744393], dtype=float32),
 'action': array(0),
 'policy_info': (),
 'next_step_type': array(1, dtype=int32),
 'reward': array(1., dtype=float32),
 'discount': array(1., dtype=float32)})
Trajectory(
{'step_type': array(1, dtype=int32),
 'observation': array([ 0.00597585, -0.22198795, -0.03554928, 0.24405919], dtype=float32),
 'action': array(0),
 'policy_info': (),
 'next_step_type': array(1, dtype=int32),
 'reward': array(1., dtype=float32),
 'discount': array(1., dtype=float32)})
Trajectory(
{'step_type': array(1, dtype=int32),
 'observation': array([ 0.00153609, -0.41658458, -0.0306681 , 0.5253204 ], dtype=float32),
 'action': array(0),
 'policy_info': (),
 'next_step_type': array(1, dtype=int32),
 'reward': array(1., dtype=float32),
 'discount': array(1., dtype=float32)})
Trajectory(
{'step_type': array(1, dtype=int32),
 'observation': array([-0.0067956 , -0.61126184, -0.02016169, 0.80818397], dtype=float32),
 'action': array(0),
 'policy_info': (),
 'next_step_type': array(1, dtype=int32),
 'reward': array(1., dtype=float32),
 'discount': array(1., dtype=float32)})
Trajectory(
{'step_type': array(1, dtype=int32),
 'observation': array([-0.01902084, -0.8061018 , -0.00399801, 1.0944574 ], dtype=float32),
 'action': array(1),
 'policy_info': (),
 'next_step_type': array(1, dtype=int32),
 'reward': array(1., dtype=float32),
 'discount': array(1., dtype=float32)})
Trajectory(
{'step_type': array(1, dtype=int32),
 'observation': array([-0.03514287, -0.6109274 , 0.01789114, 0.8005227 ], dtype=float32),
 'action': array(0),
 'policy_info': (),
 'next_step_type': array(1, dtype=int32),
 'reward': array(1., dtype=float32),
 'discount': array(1., dtype=float32)})
Trajectory(
{'step_type': array(1, dtype=int32),
 'observation': array([-0.04736142, -0.8062901 , 0.03390159, 1.0987796 ], dtype=float32),
 'action': array(1),
 'policy_info': (),
 'next_step_type': array(1, dtype=int32),
 'reward': array(1., dtype=float32),
 'discount': array(1., dtype=float32)})
Trajectory(
{'step_type': array(1, dtype=int32),
 'observation': array([-0.06348722, -0.61163044, 0.05587719, 0.816923 ], dtype=float32),
 'action': array(1),
 'policy_info': (),
 'next_step_type': array(1, dtype=int32),
 'reward': array(1., dtype=float32),
 'discount': array(1., dtype=float32)})
Trajectory(
{'step_type': array(1, dtype=int32),
 'observation': array([-0.07571983, -0.41731614, 0.07221565, 0.54232585], dtype=float32),
 'action': array(0),
 'policy_info': (),
 'next_step_type': array(1, dtype=int32),
 'reward': array(1., dtype=float32),
 'discount': array(1., dtype=float32)})
Trajectory(
{'step_type': array(1, dtype=int32),
 'observation': array([-0.08406615, -0.61337477, 0.08306216, 0.8568603 ], dtype=float32),
 'action': array(0),
 'policy_info': (),
 'next_step_type': array(1, dtype=int32),
 'reward': array(1., dtype=float32),
 'discount': array(1., dtype=float32)})
Trajectory(
{'step_type': array(1, dtype=int32),
 'observation': array([-0.09633365, -0.8095243 , 0.10019937, 1.1744623 ], dtype=float32),
 'action': array(1),
 'policy_info': (),
 'next_step_type': array(1, dtype=int32),
 'reward': array(1., dtype=float32),
 'discount': array(1., dtype=float32)})
Trajectory(
{'step_type': array(1, dtype=int32),
 'observation': array([-0.11252414, -0.6158369 , 0.12368862, 0.91479784], dtype=float32),
 'action': array(0),
 'policy_info': (),
 'next_step_type': array(1, dtype=int32),
 'reward': array(1., dtype=float32),
 'discount': array(1., dtype=float32)})
Trajectory(
{'step_type': array(1, dtype=int32),
 'observation': array([-0.12484087, -0.8123951 , 0.14198457, 1.2436544 ], dtype=float32),
 'action': array(1),
 'policy_info': (),
 'next_step_type': array(1, dtype=int32),
 'reward': array(1., dtype=float32),
 'discount': array(1., dtype=float32)})
Trajectory(
{'step_type': array(1, dtype=int32),
 'observation': array([-0.14108877, -0.61935145, 0.16685766, 0.9986062 ], dtype=float32),
 'action': array(1),
 'policy_info': (),
 'next_step_type': array(1, dtype=int32),
 'reward': array(1., dtype=float32),
 'discount': array(1., dtype=float32)})
Trajectory(
{'step_type': array(1, dtype=int32),
 'observation': array([-0.1534758 , -0.42680538, 0.18682979, 0.7626272 ], dtype=float32),
 'action': array(1),
 'policy_info': (),
 'next_step_type': array(1, dtype=int32),
 'reward': array(1., dtype=float32),
 'discount': array(1., dtype=float32)})
Trajectory(
{'step_type': array(1, dtype=int32),
 'observation': array([-0.1620119 , -0.23468053, 0.20208232, 0.5340639 ], dtype=float32),
 'action': array(0),
 'policy_info': (),
 'next_step_type': array(2, dtype=int32),
 'reward': array(1., dtype=float32),
 'discount': array(0., dtype=float32)})
Trajectory(
{'step_type': array(2, dtype=int32),
 'observation': array([-0.16670552, -0.43198496, 0.21276361, 0.8830067 ], dtype=float32),
 'action': array(1),
 'policy_info': (),
 'next_step_type': array(0, dtype=int32),
 'reward': array(0., dtype=float32),
 'discount': array(1., dtype=float32)})
Average Return: 18.0
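The reported return can be checked by hand against the printed trajectories. The first 18 transitions each carry a reward of 1.0, the 18th has next_step_type 2 (LAST), and the final LAST to FIRST boundary transition carries 0.0. The sketch below recomputes episodic returns from (step_type, next_step_type, reward) triples. It starts a new sum at a FIRST step and records the sum when the next step is LAST. `episode_returns` is a hypothetical helper that follows roughly the same idea as py_metrics.AverageReturnMetric, but it is not that metric's code.

```python
FIRST, MID, LAST = 0, 1, 2


def episode_returns(transitions):
  """Hypothetical helper: sum rewards per episode from (step_type, next_step_type, reward)."""
  returns, current = [], 0.0
  for step_type, next_step_type, reward in transitions:
    if step_type == FIRST:
      current = 0.0            # A new episode starts here.
    current += reward
    if next_step_type == LAST:
      returns.append(current)  # The episode just ended.
  return returns


# Step types and rewards of the 19 trajectories printed above.
transitions = ([(FIRST, MID, 1.0)] + [(MID, MID, 1.0)] * 16
               + [(MID, LAST, 1.0), (LAST, FIRST, 0.0)])
returns = episode_returns(transitions)
print(sum(returns) / len(returns))  # 18.0
```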

TensorFlow Drivers

We also have drivers in TensorFlow that are functionally similar to the Python drivers but use TF environments, TF policies, TF observers, etc. We currently have two TensorFlow drivers: DynamicStepDriver, which terminates after a given number of (valid) environment steps, and DynamicEpisodeDriver, which terminates after a given number of episodes. Let us look at an example of the DynamicEpisodeDriver in action.

env = suite_gym.load('CartPole-v0')
tf_env = tf_py_environment.TFPyEnvironment(env)
tf_policy = random_tf_policy.RandomTFPolicy(action_spec=tf_env.action_spec(),
                                            time_step_spec=tf_env.time_step_spec())
num_episodes = tf_metrics.NumberOfEpisodes()
env_steps = tf_metrics.EnvironmentSteps()
observers = [num_episodes, env_steps]
driver = dynamic_episode_driver.DynamicEpisodeDriver(
    tf_env, tf_policy, observers, num_episodes=2)
# Initial driver.run will reset the environment and initialize the policy.
final_time_step, policy_state = driver.run()
print('final_time_step', final_time_step)
print('Number of Steps: ', env_steps.result().numpy())
print('Number of Episodes: ', num_episodes.result().numpy())
final_time_step TimeStep(
{'step_type': <tf.Tensor: shape=(1,), dtype=int32, numpy=array([0], dtype=int32)>,
 'reward': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([0.], dtype=float32)>,
 'discount': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([1.], dtype=float32)>,
 'observation': <tf.Tensor: shape=(1, 4), dtype=float32, numpy=
array([[-0.0367443 , 0.00652178, 0.04001181, -0.00376746]],
 dtype=float32)>})
Number of Steps: 34
Number of Episodes: 2
# Continue running from previous state
final_time_step, _ = driver.run(final_time_step, policy_state)
print('final_time_step', final_time_step)
print('Number of Steps: ', env_steps.result().numpy())
print('Number of Episodes: ', num_episodes.result().numpy())
final_time_step TimeStep(
{'step_type': <tf.Tensor: shape=(1,), dtype=int32, numpy=array([0], dtype=int32)>,
 'reward': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([0.], dtype=float32)>,
 'discount': <tf.Tensor: shape=(1,), dtype=float32, numpy=array([1.], dtype=float32)>,
 'observation': <tf.Tensor: shape=(1, 4), dtype=float32, numpy=
array([[-0.04702466, -0.04836502, 0.01751254, -0.00393545]],
 dtype=float32)>})
Number of Steps: 63
Number of Episodes: 4
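The two metrics are observers that keep their own state across calls to driver.run(), so the second readings are running totals, not per-run counts. Taking the differences between the printed values recovers what each run contributed. This is plain arithmetic on the numbers shown above; nothing here calls TF-Agents.

```python
# Cumulative (Number of Steps, Number of Episodes) printed after each run above.
readings = [(34, 2), (63, 4)]

per_run = []
prev_steps, prev_episodes = 0, 0
for steps, episodes in readings:
  per_run.append((steps - prev_steps, episodes - prev_episodes))
  prev_steps, prev_episodes = steps, episodes

for run_steps, run_episodes in per_run:
  print(f"{run_steps} steps over {run_episodes} episodes, "
        f"{run_steps / run_episodes:.1f} steps per episode")
# 34 steps over 2 episodes, 17.0 steps per episode
# 29 steps over 2 episodes, 14.5 steps per episode
```

Each call to the episode driver added exactly two episodes, as requested by num_episodes=2. The number of steps varies from run to run because a random policy ends CartPole episodes after varying lengths.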

Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies.

Last updated 2024-03-09 UTC.