|
25 | 25 |
|
26 | 26 | import platform |
27 | 27 | import shutil |
28 | | -fromsubprocessimport Popen, PIPE |
| 28 | +import subprocess |
29 | 29 | import os |
30 | 30 |
|
31 | 31 |
|
32 | 32 | def gpu_count(): |
| 33 | + nvidia_smi = shutil.which('nvidia-smi') |
| 34 | + if nvidia_smi is None and platform.system() == "Windows": |
| 35 | + nvidia_smi = f'{os.environ["systemdrive"]}\\Program Files\\NVIDIA Corporation\\NVSMI\\nvidia-smi.exe' |
| 36 | + if nvidia_smi is None: |
| 37 | + return 0 |
33 | 38 | try: |
34 | | - if platform.system() == "Windows": |
35 | | - nvidia_smi = shutil.which('nvidia-smi') |
36 | | - if nvidia_smi is None: |
37 | | - nvidia_smi = ( |
38 | | - "%s\\Program Files\\NVIDIA Corporation\\NVSMI\\nvidia-smi.exe" |
39 | | - % os.environ['systemdrive'] |
40 | | - ) |
41 | | - else: |
42 | | - nvidia_smi = "nvidia-smi" |
43 | | - |
44 | | - p = Popen( |
| 39 | + p = subprocess.run( |
45 | 40 | [nvidia_smi, "--query-gpu=name", "--format=csv,noheader,nounits"], |
46 | | - stdout=PIPE, |
| 41 | + stdout=subprocess.PIPE, |
| 42 | + text=True, |
47 | 43 | ) |
48 | | - stdout, stderror = p.communicate() |
49 | | - |
50 | | - output = stdout.decode('UTF-8') |
51 | | - lines = output.split(os.linesep) |
52 | | - num_devices = len(lines) - 1 |
53 | | - return num_devices |
54 | | - except: |
| 44 | + except (OSError, UnicodeDecodeError): |
55 | 45 | return 0 |
| 46 | + return len(p.stdout.splitlines()) |
0 commit comments