
Commit a83b3b8

bmind7 and maryamziaa authored

[Bugfix] Fix CUDA/CPU mismatch in threaded training (#6245)

* Ensure tensors use default device in torch policy and utils

Co-authored-by: maryam-zia <maryam.zia@unity3d.com>
1 parent a277771 commit a83b3b8
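The commit message is terse, so here is a minimal sketch of the failure mode and of the fix pattern the diffs below apply. It is illustrative only: the default_device() shown here is a stand-in for mlagents.torch_utils.default_device, which the changed files import, and the tensor shapes are made up.

import torch

def default_device() -> torch.device:
    # Stand-in for mlagents.torch_utils.default_device(); the real helper returns
    # the device the trainer was configured to use.
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = torch.nn.Linear(128, 1).to(default_device())

# Before the fix: factory calls such as torch.zeros() default to the CPU, so on a
# CUDA run the threaded trainer mixed CPU and CUDA tensors and PyTorch raised an
# error along the lines of "Expected all tensors to be on the same device".
memory_cpu = torch.zeros((1, 1, 128))

# After the fix: allocate directly on the trainer's device.
memory = torch.zeros((1, 1, 128), device=default_device())
out = model(memory)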

File tree

6 files changed: +34 -22 lines

ml-agents/mlagents/trainers/optimizer/torch_optimizer.py
ml-agents/mlagents/trainers/poca/optimizer_torch.py
ml-agents/mlagents/trainers/policy/torch_policy.py
ml-agents/mlagents/trainers/torch_entities/components/reward_providers/gail_reward_provider.py
ml-agents/mlagents/trainers/torch_entities/networks.py
ml-agents/mlagents/trainers/torch_entities/utils.py


‎ml-agents/mlagents/trainers/optimizer/torch_optimizer.py‎

Lines changed: 2 additions & 2 deletions

@@ -1,5 +1,5 @@
 from typing import Dict, Optional, Tuple, List
-from mlagents.torch_utils import torch
+from mlagents.torch_utils import torch, default_device
 import numpy as np
 from collections import defaultdict

@@ -162,7 +162,7 @@ def get_trajectory_value_estimates(
             memory = self.critic_memory_dict[agent_id]
         else:
             memory = (
-                torch.zeros((1, 1, self.critic.memory_size))
+                torch.zeros((1, 1, self.critic.memory_size), device=default_device())
                 if self.policy.use_recurrent
                 else None
             )

‎ml-agents/mlagents/trainers/poca/optimizer_torch.py‎

Lines changed: 2 additions & 2 deletions

@@ -608,12 +608,12 @@ def get_trajectory_and_baseline_value_estimates(
             _init_baseline_mem = self.baseline_memory_dict[agent_id]
         else:
             _init_value_mem = (
-                torch.zeros((1, 1, self.critic.memory_size))
+                torch.zeros((1, 1, self.critic.memory_size), device=default_device())
                 if self.policy.use_recurrent
                 else None
             )
             _init_baseline_mem = (
-                torch.zeros((1, 1, self.critic.memory_size))
+                torch.zeros((1, 1, self.critic.memory_size), device=default_device())
                 if self.policy.use_recurrent
                 else None
             )

‎ml-agents/mlagents/trainers/policy/torch_policy.py‎

Lines changed: 11 additions & 6 deletions

@@ -69,13 +69,17 @@ def export_memory_size(self) -> int:
         return self._export_m_size

     def _extract_masks(self, decision_requests: DecisionSteps) -> np.ndarray:
+        device = default_device()
         mask = None
         if self.behavior_spec.action_spec.discrete_size > 0:
             num_discrete_flat = np.sum(self.behavior_spec.action_spec.discrete_branches)
-            mask = torch.ones([len(decision_requests), num_discrete_flat])
+            mask = torch.ones(
+                [len(decision_requests), num_discrete_flat], device=device
+            )
             if decision_requests.action_mask is not None:
                 mask = torch.as_tensor(
-                    1 - np.concatenate(decision_requests.action_mask, axis=1)
+                    1 - np.concatenate(decision_requests.action_mask, axis=1),
+                    device=device,
                 )
         return mask

@@ -91,11 +95,12 @@ def evaluate(
         """
         obs = decision_requests.obs
         masks = self._extract_masks(decision_requests)
-        tensor_obs = [torch.as_tensor(np_ob) for np_ob in obs]
+        device = default_device()
+        tensor_obs = [torch.as_tensor(np_ob, device=device) for np_ob in obs]

-        memories = torch.as_tensor(self.retrieve_memories(global_agent_ids)).unsqueeze(
-            0
-        )
+        memories = torch.as_tensor(
+            self.retrieve_memories(global_agent_ids), device=device
+        ).unsqueeze(0)
         with torch.no_grad():
             action, run_out, memories = self.actor.get_action_and_stats(
                 tensor_obs, masks=masks, memories=memories
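For context on the evaluate() hunk above: observations arrive from the environment as numpy arrays, and torch.as_tensor() without a device argument leaves them on the CPU. A minimal illustration of the difference (the shapes and the device variable are invented for this sketch, not taken from the file):

import numpy as np
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

np_obs = np.random.rand(1, 84).astype(np.float32)   # observation from the env
cpu_obs = torch.as_tensor(np_obs)                    # stays on the CPU
dev_obs = torch.as_tensor(np_obs, device=device)     # copied to the trainer's device

# Only the second form can be fed to a CUDA-resident actor network without a
# device-mismatch error.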

‎ml-agents/mlagents/trainers/torch_entities/components/reward_providers/gail_reward_provider.py‎

Lines changed: 9 additions & 7 deletions

@@ -143,7 +143,7 @@ def compute_estimate(
         if self._settings.use_actions:
             actions = self.get_action_input(mini_batch)
             dones = torch.as_tensor(
-                mini_batch[BufferKey.DONE], dtype=torch.float
+                mini_batch[BufferKey.DONE], dtype=torch.float, device=default_device()
             ).unsqueeze(1)
             action_inputs = torch.cat([actions, dones], dim=1)
             hidden, _ = self.encoder(inputs, action_inputs)

@@ -162,7 +162,7 @@ def compute_loss(
         """
         Given a policy mini_batch and an expert mini_batch, computes the loss of the discriminator.
         """
-        total_loss = torch.zeros(1)
+        total_loss = torch.zeros(1, device=default_device())
         stats_dict: Dict[str, np.ndarray] = {}
         policy_estimate, policy_mu = self.compute_estimate(
             policy_batch, use_vail_noise=True

@@ -219,21 +219,23 @@ def compute_gradient_magnitude(
         expert_inputs = self.get_state_inputs(expert_batch)
         interp_inputs = []
         for policy_input, expert_input in zip(policy_inputs, expert_inputs):
-            obs_epsilon = torch.rand(policy_input.shape)
+            obs_epsilon = torch.rand(policy_input.shape, device=policy_input.device)
             interp_input = obs_epsilon * policy_input + (1 - obs_epsilon) * expert_input
             interp_input.requires_grad = True  # For gradient calculation
             interp_inputs.append(interp_input)
         if self._settings.use_actions:
             policy_action = self.get_action_input(policy_batch)
             expert_action = self.get_action_input(expert_batch)
-            action_epsilon = torch.rand(policy_action.shape)
+            action_epsilon = torch.rand(
+                policy_action.shape, device=policy_action.device
+            )
             policy_dones = torch.as_tensor(
-                policy_batch[BufferKey.DONE], dtype=torch.float
+                policy_batch[BufferKey.DONE], dtype=torch.float, device=default_device()
             ).unsqueeze(1)
             expert_dones = torch.as_tensor(
-                expert_batch[BufferKey.DONE], dtype=torch.float
+                expert_batch[BufferKey.DONE], dtype=torch.float, device=default_device()
             ).unsqueeze(1)
-            dones_epsilon = torch.rand(policy_dones.shape)
+            dones_epsilon = torch.rand(policy_dones.shape, device=policy_dones.device)
             action_inputs = torch.cat(
                 [
                     action_epsilon * policy_action
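One detail worth noting in the gradient-penalty hunk: the random interpolation coefficients are allocated on the device of the tensor they perturb (policy_input.device, policy_action.device, policy_dones.device) rather than via default_device(), so they stay colocated with their inputs however those were created. A rough, self-contained sketch of that pattern (the tensors below are invented for illustration):

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
policy_input = torch.rand(64, 32, device=device)   # stand-in for a policy batch
expert_input = torch.rand(64, 32, device=device)   # stand-in for an expert batch

# torch.rand() defaults to the CPU; taking the device from the tensor being
# perturbed keeps the interpolation on the same device as the inputs.
epsilon = torch.rand(policy_input.shape, device=policy_input.device)
interp = epsilon * policy_input + (1 - epsilon) * expert_input
interp.requires_grad = True  # needed for the gradient-penalty backward pass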

‎ml-agents/mlagents/trainers/torch_entities/networks.py‎

Lines changed: 4 additions & 2 deletions

@@ -1,7 +1,7 @@
 from typing import Callable, List, Dict, Tuple, Optional, Union, Any
 import abc

-from mlagents.torch_utils import torch, nn
+from mlagents.torch_utils import torch, nn, default_device

 from mlagents_envs.base_env import ActionSpec, ObservationSpec, ObservationType
 from mlagents.trainers.torch_entities.action_model import ActionModel

@@ -87,7 +87,9 @@ def update_normalization(self, buffer: AgentBuffer) -> None:
         obs = ObsUtil.from_buffer(buffer, len(self.processors))
         for vec_input, enc in zip(obs, self.processors):
             if isinstance(enc, VectorInput):
-                enc.update_normalization(torch.as_tensor(vec_input.to_ndarray()))
+                enc.update_normalization(
+                    torch.as_tensor(vec_input.to_ndarray(), device=default_device())
+                )

     def copy_normalization(self, other_encoder: "ObservationEncoder") -> None:
         if self.normalize:

‎ml-agents/mlagents/trainers/torch_entities/utils.py‎

Lines changed: 6 additions & 3 deletions

@@ -1,5 +1,5 @@
 from typing import List, Optional, Tuple, Dict
-from mlagents.torch_utils import torch, nn
+from mlagents.torch_utils import torch, nn, default_device
 from mlagents.trainers.torch_entities.layers import LinearEncoder, Initialization
 import numpy as np

@@ -233,7 +233,8 @@ def list_to_tensor(
         Converts a list of numpy arrays into a tensor. MUCH faster than
         calling as_tensor on the list directly.
         """
-        return torch.as_tensor(np.asanyarray(ndarray_list), dtype=dtype)
+        device = default_device()
+        return torch.as_tensor(np.asanyarray(ndarray_list), dtype=dtype, device=device)

     @staticmethod
     def list_to_tensor_list(

@@ -243,8 +244,10 @@ def list_to_tensor_list(
         Converts a list of numpy arrays into a list of tensors. MUCH faster than
         calling as_tensor on the list directly.
         """
+        device = default_device()
         return [
-            torch.as_tensor(np.asanyarray(_arr), dtype=dtype) for _arr in ndarray_list
+            torch.as_tensor(np.asanyarray(_arr), dtype=dtype, device=device)
+            for _arr in ndarray_list
         ]

     @staticmethod

0 commit comments
