How to add custom done flags to done_spec of auto-resetting environments? #2344

Unanswered
ErcBunny asked this question in Q&A

Thank you for viewing my question!

I am writing my own environment, MyEnv, by subclassing EnvBase. My environment spawns an Isaac Gym environment, which is auto-resetting, so I am wrapping it with the auto-reset transform as follows:

my_env = AutoResetEnv(MyEnv(), AutoResetTransform())

However, I am stuck with the error "A 'done' key was a string but a tuple was expected." when I put additional fields in the done_spec, as in the following:

binary_spec = BinaryDiscreteTensorSpec(
    n=1,
    shape=torch.Size([num_envs, 1]),
    device=self.device,
    dtype=torch.bool,
)
self.done_spec = CompositeSpec(
    {
        "truncated": binary_spec,
        "terminated": binary_spec,
        "done": binary_spec,
        "extra": binary_spec,
    },
    device=self.device,
    shape=self.batch_size,
)

The error goes away if I remove "extra" from the spec, because the spec then falls under the _simple_done case. But I would really like to keep the extra field for my application. Any idea how to do this correctly? Thank you!


Replies: 2 comments 8 replies


Why don't you add the "extra" key to the observation spec instead? The following code does that, and AutoResetTransform seems to work fine with it. I hope it helps.

from typing import Optional

import torch
from tensordict.tensordict import TensorDict, TensorDictBase
from torchrl.data import (
    CompositeSpec,
    DiscreteTensorSpec,
    UnboundedContinuousTensorSpec,
)
from torchrl.data.utils import DEVICE_TYPING
from torchrl.envs import EnvBase


class ExampleEnv(EnvBase):
    def __init__(
        self,
        num_envs: int = 1,
        device: DEVICE_TYPING = None,
    ):
        super().__init__(
            device=device,
            batch_size=torch.Size([num_envs]),
        )
        self.shape_obs = (40, 40, 3)
        self.num_envs = num_envs
        # Define specs
        self._set_specs()

    def _reset(self, tensordict: TensorDictBase = None, **kwargs) -> TensorDictBase:
        # Write your logic here
        # Create reset tensordict
        reset_tensordict = TensorDict(
            {
                "pixels": torch.zeros(self.num_envs, *self.shape_obs),
                "done": torch.zeros(
                    self.num_envs, 1, device=self.device, dtype=torch.bool
                ),
                "truncated": torch.zeros(
                    self.num_envs, 1, device=self.device, dtype=torch.bool
                ),
                "terminated": torch.zeros(
                    self.num_envs, 1, device=self.device, dtype=torch.bool
                ),
                "extra": torch.zeros(
                    self.num_envs, 1, device=self.device, dtype=torch.bool
                ),
            },
            device=self.device,
            batch_size=self.batch_size,
        )

        return reset_tensordict

    def _step(self, tensordict: TensorDictBase) -> TensorDictBase:
        # Get actions
        actions = tensordict.get("action")
        # Write your logic here
        # Create next_tensordict
        next_tensordict = TensorDict(
            {
                "pixels": torch.zeros(self.num_envs, *self.shape_obs),
                "done": torch.zeros(
                    self.num_envs, 1, device=self.device, dtype=torch.bool
                ),
                "truncated": torch.zeros(
                    self.num_envs, 1, device=self.device, dtype=torch.bool
                ),
                "terminated": torch.zeros(
                    self.num_envs, 1, device=self.device, dtype=torch.bool
                ),
                "extra": torch.zeros(
                    self.num_envs, 1, device=self.device, dtype=torch.bool
                ),
                "reward": torch.zeros(
                    self.num_envs, 1, device=self.device, dtype=torch.float32
                ),
            },
            device=self.device,
            batch_size=self.batch_size,
        )
        return next_tensordict

    def _set_seed(self, seed: Optional[int] = -1) -> None:
        torch.manual_seed(seed)

    def _set_specs(self) -> None:
        self.observation_spec = CompositeSpec(
            {
                "pixels": UnboundedContinuousTensorSpec(
                    shape=torch.Size(self.shape_obs),
                    dtype=torch.float32,
                    device=self.device,
                ),
                "extra": DiscreteTensorSpec(
                    n=2, dtype=torch.bool, device=self.device
                ).unsqueeze(-1),
            }
        ).expand(self.num_envs)
        self.action_spec = CompositeSpec(
            {
                "action": DiscreteTensorSpec(
                    n=10, dtype=torch.uint8, device=self.device
                )
            }
        ).expand(self.num_envs)
        self.reward_spec = CompositeSpec(
            {
                "reward": UnboundedContinuousTensorSpec(
                    shape=(1,),
                    dtype=torch.float32,
                    device=self.device,
                )
            }
        ).expand(self.num_envs)
        self.done_spec = (
            CompositeSpec(
                {
                    "done": DiscreteTensorSpec(
                        n=2, dtype=torch.bool, device=self.device
                    ),
                    "truncated": DiscreteTensorSpec(
                        n=2, dtype=torch.bool, device=self.device
                    ),
                    "terminated": DiscreteTensorSpec(
                        n=2, dtype=torch.bool, device=self.device
                    ),
                }
            )
            .expand(self.num_envs)
            .unsqueeze(-1)
        )

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}"


if __name__ == "__main__":
    from torchrl.envs.utils import check_env_specs
    from torchrl.envs.transforms import AutoResetEnv, TransformedEnv, AutoResetTransform

    env = ExampleEnv(num_envs=1)
    check_env_specs(env)
    rollout = env.rollout(max_steps=256)
    print(rollout)
    env = AutoResetEnv(env)
    env = env.append_transform(AutoResetTransform())
    rollout = env.rollout(max_steps=256)
    print(rollout)

I just find it cleaner to define the unbatched shapes first and simply call expand at the end for all the specs, but the result should be the same.


I also noticed that your BinaryDiscreteTensorSpec has n=1. That defines the variable as always False.


"I also noticed that your BinaryDiscreteTensorSpec has n=1. That defines the variable as always False."

Thank you for pointing it out. But I am a bit confused by the documentation here that says:

n (int) – length of the binary vector.

which is a bit different from what is said here:

n (int) – number of possible outcomes.

So I am wondering how the documentation translates to "That defines the variable as always False."?


I agree it is confusing. An issue should be opened to clarify that, and probably another one about not being able to add an "extra" key to the done_spec.


If you create

spec = DiscreteTensorSpec(n=1, dtype=torch.bool)

and repeatedly call

spec.rand()

you will see that it always samples False. With n=2 it samples both True and False, and with n>2 it raises a RuntimeError.
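To see why n=1 can only ever produce False, here is a minimal pure-Python analogy. It assumes the spec draws an integer uniformly from [0, n) and casts it to bool; that sampling model is my assumption rather than something confirmed in this thread, and sample_flag is a hypothetical helper, not a torchrl function. The sketch does not reproduce the RuntimeError reported for n>2.

```python
import random


def sample_flag(n, rng):
    """Hypothetical stand-in for spec.rand(): draw an int in [0, n), cast to bool.

    Assumption: the real spec samples an integer below n before casting.
    """
    return bool(rng.randrange(n))


rng = random.Random(0)

# With n=1 the only integer that can be drawn is 0, so the flag is always False.
samples_n1 = {sample_flag(1, rng) for _ in range(1000)}

# With n=2 both 0 and 1 can be drawn, so both False and True show up.
samples_n2 = {sample_flag(2, rng) for _ in range(1000)}

print("n=1:", samples_n1)
print("n=2:", samples_n2)
```

Under that assumption, n behaves as the number of possible outcomes, which is why a boolean flag needs n=2.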


Since my extra keys only provide additional information on top of the conventional "done", "terminated", and "truncated" flags, I think a workaround is to subclass AutoResetTransform and force it to use the simple_done logic when correcting the auto-reset values.

@albertbou92 Do you think it is okay to do so without introducing other problems? Thank you.


I wanted to look into it today, but did not have much time. I think your workaround should work.

However, to allow additional done-like keys, I guess the solution should be to allow for a broader definition of _simple_done in https://github.com/pytorch/rl/blob/main/torchrl/envs/common.py#L2796 (allowing extra flags there). @vmoens would that break anything?

A False _simple_done seems to be intended for multi-agent environments, so it is important that the extra keys do not conflict with that case.

Maybe:

_simple_done = {'terminated', 'done', 'truncated'}.issubset(key_set) or {'terminated', 'done'}.issubset(key_set)
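The difference between the two checks can be sketched in plain Python. This assumes the existing check is an exact comparison of the top-level done keys against the standard sets; I have not copied it from torchrl, so treat simple_done_exact as a hypothetical stand-in. simple_done_subset mirrors the expression proposed above, and the nested tuple keys in the last case are only an illustrative guess at what a multi-agent key set might look like.

```python
STANDARD_3 = {"done", "terminated", "truncated"}
STANDARD_2 = {"done", "terminated"}


def simple_done_exact(key_set):
    """Assumed current behaviour: the key set must match a standard set exactly."""
    return key_set == STANDARD_3 or key_set == STANDARD_2


def simple_done_subset(key_set):
    """Proposed relaxation: extra keys are allowed next to the standard ones."""
    return STANDARD_3.issubset(key_set) or STANDARD_2.issubset(key_set)


cases = {
    "standard": {"done", "terminated", "truncated"},
    "with extra": {"done", "terminated", "truncated", "extra"},
    # Hypothetical multi-agent layout with nested (tuple) keys only.
    "nested only": {("agents", "done"), ("agents", "terminated")},
}

for name, keys in cases.items():
    print(name, simple_done_exact(keys), simple_done_subset(keys))
```

Note that the first clause of the subset version is implied by the second, so it reduces to requiring "done" and "terminated" at the top level. A key set made only of nested keys still evaluates to False under both checks in this sketch.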
