-
Notifications
You must be signed in to change notification settings - Fork 2k
Use MPS device when available #951
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@qgallouedec could you test this PR (do make pytest
) on a MPS enabled machine? (best would be to test sb3 contrib too)
We should probably add a warning in the doc about the minimum pytorch version? (or in the code)
Not only the pytest failed, but it caused a Python Fatal Error:
(env) quentingallouedec@MacBook-Pro-de-Quentin stable-baselines3 % pytest tests/test_cnn.py
=========================================== test session starts ============================================
platform darwin -- Python 3.9.13, pytest-7.1.2, pluggy-1.0.0
rootdir: /Users/quentingallouedec/stable-baselines3, configfile: setup.cfg
plugins: xdist-2.5.0, forked-1.4.0, env-0.6.2, typeguard-2.13.3, cov-3.0.0
collected 14 items
tests/test_cnn.py Fatal Python error: Aborted
Current thread 0x0000000101308580 (most recent call first):
File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 453 in _conv_forward
File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 457 in forward
File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130 in _call_impl
File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/torch/nn/modules/container.py", line 139 in forward
File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130 in _call_impl
File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/common/torch_layers.py", line 93 in forward
File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130 in _call_impl
File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/common/policies.py", line 129 in extract_features
File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/common/policies.py", line 588 in forward
File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130 in _call_impl
File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/common/on_policy_algorithm.py", line 167 in collect_rollouts
File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/common/on_policy_algorithm.py", line 248 in learn
File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/a2c/a2c.py", line 197 in learn
File "/Users/quentingallouedec/stable-baselines3/tests/test_cnn.py", line 33 in test_cnn
File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/python.py", line 192 in pytest_pyfunc_call
File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_callers.py", line 39 in _multicall
File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_manager.py", line 80 in _hookexec
File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_hooks.py", line 265 in __call__
File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/python.py", line 1761 in runtest
File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/runner.py", line 166 in pytest_runtest_call
File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_callers.py", line 39 in _multicall
File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_manager.py", line 80 in _hookexec
File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_hooks.py", line 265 in __call__
File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/runner.py", line 259 in <lambda>
File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/runner.py", line 338 in from_call
File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/runner.py", line 258 in call_runtest_hook
File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/runner.py", line 219 in call_and_report
File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/runner.py", line 130 in runtestprotocol
File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/runner.py", line 111 in pytest_runtest_protocol
File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_callers.py", line 39 in _multicall
File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_manager.py", line 80 in _hookexec
File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_hooks.py", line 265 in __call__
File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/main.py", line 347 in pytest_runtestloop
File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_callers.py", line 39 in _multicall
File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_manager.py", line 80 in _hookexec
File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_hooks.py", line 265 in __call__
File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/main.py", line 322 in _main
File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/main.py", line 268 in wrap_session
File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/main.py", line 315 in pytest_cmdline_main
File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_callers.py", line 39 in _multicall
File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_manager.py", line 80 in _hookexec
File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_hooks.py", line 265 in __call__
File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/config/__init__.py", line 164 in main
File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/config/__init__.py", line 187 in console_main
File "/Users/quentingallouedec/stable-baselines3/env/bin/pytest", line 8 in <module>
zsh: abort pytest tests/test_cnn.py
Don't know what it is. I will investigate.
Well, I'm pretty sure the problem comes from the fact that the observation is transposed before passing into the CNN of the feature extractor, and this seems to cause some more bugs: pytorch/pytorch#81557
To reproduce:
from stable_baselines3 import A2C from stable_baselines3.common.envs import FakeImageEnv env = FakeImageEnv() model = A2C("CnnPolicy", env).learn(250)
It causes fatal error in this line:
without traceback, but with this error message:
Assertion failed: (mapIt != _jitValueTypes.end()), function getStaticType, file MPSRuntime_Project.h, line 435.
zsh: abort /Users/quentingallouedec/stable-baselines3/env/bin/python
But more generally, there are still some features missing, such as support for the multinomial distribution (pytorch/pytorch#80760) for SB3 to work fully on the mps device
So we still have to be a bit patient.
Thanks for testing =)
Pytorch 1.13 is out. MPS is still not fully supported and causes bugs in SB3.
To keep track of MPS op coverage, see pytorch/pytorch#77764
kulinseth
commented
Nov 17, 2022
@qgallouedec , can you please provide which Ops are missing ?
Also if there is any Functional issue , can you provide a repro case? We will take a look.
kulinseth
commented
Nov 17, 2022
Not only the pytest failed, but it caused a Python Fatal Error:
(env) quentingallouedec@MacBook-Pro-de-Quentin stable-baselines3 % pytest tests/test_cnn.py =========================================== test session starts ============================================ platform darwin -- Python 3.9.13, pytest-7.1.2, pluggy-1.0.0 rootdir: /Users/quentingallouedec/stable-baselines3, configfile: setup.cfg plugins: xdist-2.5.0, forked-1.4.0, env-0.6.2, typeguard-2.13.3, cov-3.0.0 collected 14 items tests/test_cnn.py Fatal Python error: Aborted Current thread 0x0000000101308580 (most recent call first): File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 453 in _conv_forward File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 457 in forward File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130 in _call_impl File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/torch/nn/modules/container.py", line 139 in forward File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130 in _call_impl File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/common/torch_layers.py", line 93 in forward File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130 in _call_impl File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/common/policies.py", line 129 in extract_features File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/common/policies.py", line 588 in forward File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130 in _call_impl File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/common/on_policy_algorithm.py", line 167 in collect_rollouts File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/common/on_policy_algorithm.py", line 248 in learn File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/a2c/a2c.py", line 197 in learn File "/Users/quentingallouedec/stable-baselines3/tests/test_cnn.py", line 33 in test_cnn File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/python.py", line 192 in pytest_pyfunc_call File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_callers.py", line 39 in _multicall File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_manager.py", line 80 in _hookexec File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_hooks.py", line 265 in __call__ File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/python.py", line 1761 in runtest File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/runner.py", line 166 in pytest_runtest_call File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_callers.py", line 39 in _multicall File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_manager.py", line 80 in _hookexec File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_hooks.py", line 265 in __call__ File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/runner.py", line 259 in <lambda> File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/runner.py", line 338 in from_call File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/runner.py", line 258 in call_runtest_hook File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/runner.py", line 219 in call_and_report File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/runner.py", line 130 in runtestprotocol File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/runner.py", line 111 in pytest_runtest_protocol File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_callers.py", line 39 in _multicall File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_manager.py", line 80 in _hookexec File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_hooks.py", line 265 in __call__ File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/main.py", line 347 in pytest_runtestloop File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_callers.py", line 39 in _multicall File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_manager.py", line 80 in _hookexec File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_hooks.py", line 265 in __call__ File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/main.py", line 322 in _main File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/main.py", line 268 in wrap_session File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/main.py", line 315 in pytest_cmdline_main File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_callers.py", line 39 in _multicall File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_manager.py", line 80 in _hookexec File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/pluggy/_hooks.py", line 265 in __call__ File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/config/__init__.py", line 164 in main File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/_pytest/config/__init__.py", line 187 in console_main File "/Users/quentingallouedec/stable-baselines3/env/bin/pytest", line 8 in <module> zsh: abort pytest tests/test_cnn.py
Don't know what it is. I will investigate.
Is this still happening in latest nightly cc @qgallouedec ?
With the latest nightly:
% /Users/quentingallouedec/stable-baselines3/env/bin/python /Users/quentingallouedec/stable-baselines3/test_mps.py [W NNPACK.cpp:64] Could not initialize NNPACK! Reason: Unsupported hardware. Traceback (most recent call last): File "/Users/quentingallouedec/stable-baselines3/test_mps.py", line 5, in <module> model = A2C("CnnPolicy", env).learn(250) File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/a2c/a2c.py", line 193, in learn return super().learn( File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/common/on_policy_algorithm.py", line 248, in learn continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps) File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/common/on_policy_algorithm.py", line 166, in collect_rollouts actions, values, log_probs = self.policy(obs_tensor) File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1427, in _call_impl return forward_call(*input, **kwargs) File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/common/policies.py", line 576, in forward log_prob = distribution.log_prob(actions) File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/common/distributions.py", line 279, in log_prob return self.distribution.log_prob(actions) File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/torch/distributions/categorical.py", line 123, in log_prob self._validate_sample(value) File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/torch/distributions/distribution.py", line 298, in _validate_sample valid = support.check(value) File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/torch/distributions/constraints.py", line 257, in check return (value % 1 == 0) & (self.lower_bound <= value) & (value <= self.upper_bound) NotImplementedError: The operator 'aten::remainder.Tensor_out' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
EDIT: tested with PyTorch 2.0.0.dev20221220
kulinseth
commented
Nov 17, 2022
With the latest nightly:
% /Users/quentingallouedec/stable-baselines3/env/bin/python /Users/quentingallouedec/stable-baselines3/test_mps.py [W NNPACK.cpp:64] Could not initialize NNPACK! Reason: Unsupported hardware. Traceback (most recent call last): File "/Users/quentingallouedec/stable-baselines3/test_mps.py", line 5, in <module> model = A2C("CnnPolicy", env).learn(250) File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/a2c/a2c.py", line 193, in learn return super().learn( File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/common/on_policy_algorithm.py", line 248, in learn continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps) File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/common/on_policy_algorithm.py", line 166, in collect_rollouts actions, values, log_probs = self.policy(obs_tensor) File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1427, in _call_impl return forward_call(*input, **kwargs) File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/common/policies.py", line 576, in forward log_prob = distribution.log_prob(actions) File "/Users/quentingallouedec/stable-baselines3/stable_baselines3/common/distributions.py", line 279, in log_prob return self.distribution.log_prob(actions) File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/torch/distributions/categorical.py", line 123, in log_prob self._validate_sample(value) File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/torch/distributions/distribution.py", line 298, in _validate_sample valid = support.check(value) File "/Users/quentingallouedec/stable-baselines3/env/lib/python3.9/site-packages/torch/distributions/constraints.py", line 257, in check return (value % 1 == 0) & (self.lower_bound <= value) & (value <= self.upper_bound) NotImplementedError: The operator 'aten::remainder.Tensor_out' is not currently implemented for the MPS device. If you want this op to be added in priority during the prototype phase of this feature, please comment on https://github.com/pytorch/pytorch/issues/77764. As a temporary fix, you can set the environment variable `PYTORCH_ENABLE_MPS_FALLBACK=1` to use the CPU as a fallback for this op. WARNING: this will be slower than running natively on MPS.
Its in PR. Will try to priortize the merge.
pytorch/pytorch#87582
Is there any progress on this? Is mps usable in any way already?
@BasLaa you can already give it a try by passing device="mps"
to the constructor and using latest pytorch version (pytorch nightly is probably even better).
It should work at least partially (but you might need to use the cpu fallback), please report any issue here.
@qgallouedec how is the support with PyTorch 2.1.0?
The number of errors decreases. Here's one a them:
TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
Is double precision a feature of sb3 or should single precision be forced systematically?
Is double precision a feature of sb3 or should single precision be forced systematically?
I think we don't really support float64... (mainly to avoid issues when using CUDA)
there are several places where we already force float32 anyway (#1572), including preprocessing if I recall.
tty666
commented
Oct 29, 2023
If you need someone to test something please tell me I could with my Mac because this PR is there for a while now and nobody comes with a solution or a review ...
Just tell me what to do and I will perform the testing for you to deliver this MPS support ...
@tty666 thank you for the proposal. Feel free to test and provide your feedback if any. As far as I remember, there are still some issues related to dtype (float64 instead of float32), see #951 (comment). As soon as all the CI passes, we can consider this PR as ready to be merged
ArthurMynl
commented
Dec 20, 2023
Any news regarding this PR? Is someone working on it?
Any news regarding this PR? Is someone working on it?
lsibilla
commented
Apr 17, 2024
Hello!
I just tried this out, out of curiosity and it seems to work. The small snippet above and another project I have been working on recently work very similarly with and without MPS.
I can see GPU going to 100% with asitop and no crashes.
Performance-wise it's not as good as we might expect but that might related to my particular use-case.
lsibilla
commented
Apr 18, 2024
Hi. I see the tests are still failing. I'll try to give a bit more details on my setup.
First, I'm running a MacBook Pro M1 Pro. The test from yesterday was running with Python 3.12.
This morning, I cloned the repo, switched to the feat/mps-support branch, created a Python 3.11 venv and ran test_cnn.py
:
(.venv-3.11) ➜ stable-baselines3 git:(feat/mps-support) ✗ pytest tests/test_cnn.py
======================================= test session starts =======================================
platform darwin -- Python 3.11.8, pytest-8.1.1, pluggy-1.4.0
rootdir: /Users/lsibilla/src/lab/stable-baselines3
configfile: pyproject.toml
plugins: cov-5.0.0, anyio-4.3.0, env-1.1.3, xdist-3.5.0
collected 29 items
tests/test_cnn.py .............s............... [100%]
======================================== warnings summary =========================================
tests/test_cnn.py: 25 warnings
/Users/lsibilla/src/lab/stable-baselines3/stable_baselines3/common/utils.py:524: UserWarning: 'has_mps' is deprecated, please use 'torch.backends.mps.is_built()'
if hasattr(th, "has_mps") and th.backends.mps.is_built():
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
====================== 28 passed, 1 skipped, 25 warnings in 76.15s (0:01:16) ======================
(.venv-3.11) ➜ stable-baselines3 git:(feat/mps-support) ✗
hi 👋 i would like to help move this pr forward, i see there hasnt been much progress in past few months, i have an m1 mac studio where i'm testing this branch with this setup:
- conda environment with:
python 3.11.9
- installed pytorch nightly with:
conda install pytorch torchvision torchaudio -c pytorch-nightly
- installed dependencies with:
pip install -e .[docs,tests,extra]
- tested test_cnn.py first with:
python3 -m pytest -v tests/test_cnn.py
which passed - then i run the full test suite with:
make pytest
and got 45 failing tests:
======================================================================================================================================================== short test summary info =========================================================================================================================================================
FAILED tests/test_dict_env.py::test_dict_vec_framestack[False-PPO] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
FAILED tests/test_dict_env.py::test_dict_vec_framestack[False-A2C] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
FAILED tests/test_dict_env.py::test_dict_vec_framestack[False-DQN] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
FAILED tests/test_dict_env.py::test_dict_vec_framestack[False-DDPG] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
FAILED tests/test_dict_env.py::test_dict_vec_framestack[False-SAC] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
FAILED tests/test_dict_env.py::test_dict_vec_framestack[False-TD3] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
FAILED tests/test_dict_env.py::test_dict_vec_framestack[True-PPO] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
FAILED tests/test_dict_env.py::test_dict_vec_framestack[True-A2C] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
FAILED tests/test_dict_env.py::test_dict_vec_framestack[True-DQN] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
FAILED tests/test_dict_env.py::test_dict_vec_framestack[True-DDPG] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
FAILED tests/test_dict_env.py::test_dict_vec_framestack[True-SAC] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
FAILED tests/test_dict_env.py::test_dict_vec_framestack[True-TD3] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
FAILED tests/test_envs.py::test_bit_flipping[kwargs1] - OverflowError: Python integer 128 out of bounds for int8
FAILED tests/test_envs.py::test_bit_flipping[kwargs2] - OverflowError: Python integer 255 out of bounds for int8
FAILED tests/test_envs.py::test_bit_flipping[kwargs3] - OverflowError: Python integer 255 out of bounds for int8
FAILED tests/test_her.py::test_her[True-SAC] - OverflowError: Python integer 255 out of bounds for int8
FAILED tests/test_her.py::test_her[True-TD3] - OverflowError: Python integer 255 out of bounds for int8
FAILED tests/test_her.py::test_her[True-DDPG] - OverflowError: Python integer 255 out of bounds for int8
FAILED tests/test_her.py::test_her[True-DQN] - OverflowError: Python integer 255 out of bounds for int8
FAILED tests/test_her.py::test_multiprocessing[True-TD3] - EOFError
FAILED tests/test_her.py::test_multiprocessing[True-DQN] - EOFError
FAILED tests/test_her.py::test_save_load[True-SAC] - ValueError: Expected parameter scale (Tensor of shape (64, 4)) of distribution Normal(loc: torch.Size([64, 4]), scale: torch.Size([64, 4])) to satisfy the constraint GreaterThan(lower_bound=0.0), but found invalid values:
FAILED tests/test_spaces.py::test_float64_action_space[action_space0-obs_space1-SAC] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
FAILED tests/test_spaces.py::test_float64_action_space[action_space0-obs_space1-TD3] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
FAILED tests/test_spaces.py::test_float64_action_space[action_space0-obs_space1-PPO] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
FAILED tests/test_spaces.py::test_float64_action_space[action_space0-obs_space1-DDPG] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
FAILED tests/test_spaces.py::test_float64_action_space[action_space0-obs_space1-A2C] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
FAILED tests/test_spaces.py::test_float64_action_space[action_space0-obs_space3-SAC] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
FAILED tests/test_spaces.py::test_float64_action_space[action_space0-obs_space3-TD3] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
FAILED tests/test_spaces.py::test_float64_action_space[action_space0-obs_space3-PPO] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
FAILED tests/test_spaces.py::test_float64_action_space[action_space0-obs_space3-DDPG] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
FAILED tests/test_spaces.py::test_float64_action_space[action_space0-obs_space3-A2C] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
FAILED tests/test_spaces.py::test_float64_action_space[action_space1-obs_space1-SAC] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
FAILED tests/test_spaces.py::test_float64_action_space[action_space1-obs_space1-TD3] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
FAILED tests/test_spaces.py::test_float64_action_space[action_space1-obs_space1-PPO] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
FAILED tests/test_spaces.py::test_float64_action_space[action_space1-obs_space1-DDPG] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
FAILED tests/test_spaces.py::test_float64_action_space[action_space1-obs_space1-A2C] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
FAILED tests/test_spaces.py::test_float64_action_space[action_space1-obs_space3-SAC] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
FAILED tests/test_spaces.py::test_float64_action_space[action_space1-obs_space3-TD3] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
FAILED tests/test_spaces.py::test_float64_action_space[action_space1-obs_space3-PPO] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
FAILED tests/test_spaces.py::test_float64_action_space[action_space1-obs_space3-DDPG] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
FAILED tests/test_spaces.py::test_float64_action_space[action_space1-obs_space3-A2C] - TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.
FAILED tests/test_train_eval_mode.py::test_td3_train_with_batch_norm - AssertionError: assert ~tensor(True, device='mps:0')
FAILED tests/test_vec_normalize.py::test_get_original - AssertionError: assert dtype('float32') == dtype('float64')
FAILED tests/test_vec_normalize.py::test_get_original_dict - AssertionError: assert dtype('float32') == dtype('float64')
=========================================================================================================================== 45 failed, 713 passed, 28 skipped, 1 deselected, 6 warnings in 1381.13s (0:23:01) ============================================================================================================================
can someone point me in the right direction for the changes that i need to do to make the tests pass? i seen in this pr only 3 files have been changed but i didn't find examples fixes of these issues
edit: i tried my best to do things with common sense and fixed all tests, have a look at this pr #2005
Description
Add support for MPS device (uses it if available) and save cloudpickle version (important to debug saving/loading issues).
DO NOT MERGE: this PR must be tested on a MPS device first
closes #914
Motivation and Context
Types of changes
Checklist:
make format
(required)make check-codestyle
andmake lint
(required)make pytest
andmake type
both pass. (required)make doc
(required)Note: You can run most of the checks using
make commit-checks
.Note: we are using a maximum length of 127 characters per line