GPU device plugins

TensorFlow's pluggable device architecture adds new device support as separate plug-in packages that are installed alongside the official TensorFlow package.

The mechanism requires no device-specific changes to the TensorFlow code. Instead, plug-ins communicate with the TensorFlow binary through stable C APIs. Plug-in developers maintain separate code repositories and distribution packages for their plug-ins and are responsible for testing their devices.
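How does TensorFlow find an installed plug-in without any code changes? As we understand TensorFlow's PluggableDevice design, a plug-in wheel places its shared library in a `tensorflow-plugins` directory under site-packages, and `import tensorflow` loads the shared libraries it finds there. The standard-library sketch below mimics that discovery step so you can inspect what would be picked up. The directory name and the suffix list (`.so`, `.dylib`, `.dll`) are assumptions based on that design, and `find_plugin_libraries` is a hypothetical helper, not a TensorFlow API.

```python
import os
import site
from typing import Iterable, List, Optional

# Assumption: the directory name TensorFlow scans for plug-ins at import time.
PLUGIN_DIR_NAME = "tensorflow-plugins"
# Assumption: simplified shared-library suffixes for Linux, macOS, and Windows.
SHARED_LIB_SUFFIXES = (".so", ".dylib", ".dll")


def find_plugin_libraries(site_dirs: Optional[Iterable[str]] = None) -> List[str]:
    """Hypothetical helper: list shared libraries under <site_dir>/tensorflow-plugins."""
    if site_dirs is None:
        # Default to the interpreter's site-packages directories.
        site_dirs = site.getsitepackages()
    found = []
    for base in site_dirs:
        plugin_dir = os.path.join(base, PLUGIN_DIR_NAME)
        if not os.path.isdir(plugin_dir):
            continue  # No plug-ins installed under this site directory.
        for name in sorted(os.listdir(plugin_dir)):
            if name.endswith(SHARED_LIB_SUFFIXES):
                found.append(os.path.join(plugin_dir, name))
    return found
```

Calling `find_plugin_libraries()` with no arguments scans the current interpreter's site-packages; passing explicit directories makes it easy to inspect a specific environment.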

Use device plugins

To use a particular device the same way you would use a native TensorFlow device, users only need to install the plug-in package for that device. The following code snippet shows how the plug-in for a new demonstration device, Awesome Processing Unit (APU), is installed and used. For simplicity, this sample APU plug-in has only one custom kernel, for ReLU:

# Install the APU example plug-in package
$ pip install tensorflow-apu-0.0.1-cp36-cp36m-linux_x86_64.whl
...
Successfully installed tensorflow-apu-0.0.1
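Before importing TensorFlow, you can confirm from Python that the plug-in distribution is installed. This is a minimal sketch using the standard library's `importlib.metadata` (Python 3.8 or later); `installed_version` is a hypothetical helper. Whether `"tensorflow-apu"` or `"tensorflow_apu"` matches can depend on how your Python version normalizes distribution names, so pass the name exactly as `pip list` shows it.

```python
from importlib import metadata
from typing import Optional


def installed_version(dist_name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if it is absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        # The distribution is not installed in this environment.
        return None
```

After the install above, `installed_version("tensorflow-apu")` should report the plug-in's version; it returns `None` in an environment without the plug-in.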

With the plug-in installed, test that the device is visible and run an operation on the new APU device:

>>> import tensorflow as tf  # TensorFlow registers PluggableDevices here.
>>> tf.config.list_physical_devices()  # APU device is visible to TensorFlow.
[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'), PhysicalDevice(name='/physical_device:APU:0', device_type='APU')]
>>> a = tf.random.normal(shape=[5], dtype=tf.float32)  # Runs on CPU.
>>> b = tf.nn.relu(a)  # Runs on APU.
>>> with tf.device("/APU:0"):  # Users can also use 'with tf.device' syntax.
...     c = tf.nn.relu(a)  # Runs on APU.
...
>>> with tf.device("/CPU:0"):
...     c = tf.nn.relu(a)  # Runs on CPU.
...
>>> @tf.function  # Defining a tf.function
... def run():
...     d = tf.random.uniform(shape=[100], dtype=tf.float32)  # Runs on CPU.
...     e = tf.nn.relu(d)  # Runs on APU.
...
>>> run()  # PluggableDevices also work with tf.function and graph mode.
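Because a plug-in may or may not be installed on a given machine, scripts that pin ops with `tf.device` can choose the device string at run time. The sketch below is a hypothetical helper, `pick_device`, that accepts the list returned by `tf.config.list_physical_devices()` (each entry exposes `device_type`, as in the output above) and falls back to the CPU when the preferred device type is missing. The `"APU"` default simply mirrors the demonstration device.

```python
from typing import Iterable


def pick_device(physical_devices: Iterable, preferred: str = "APU") -> str:
    """Return a tf.device() string for `preferred` if present, else "/CPU:0".

    `physical_devices` is any iterable of objects with a `device_type`
    attribute, such as the PhysicalDevice entries returned by
    tf.config.list_physical_devices().
    """
    device_types = {d.device_type for d in physical_devices}
    if preferred in device_types:
        return f"/{preferred}:0"  # First device of the preferred type.
    return "/CPU:0"  # Fallback when the plug-in is not installed.


# Example usage with TensorFlow (not executed here):
#   import tensorflow as tf
#   with tf.device(pick_device(tf.config.list_physical_devices())):
#       c = tf.nn.relu(tf.random.normal(shape=[5]))
```

Explicit placement is optional: as the session above shows, TensorFlow already places ops that have a plug-in kernel, such as ReLU, on the APU by default. A helper like this is mainly useful when you want to pin a whole block of code to one device.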

Available devices

Metal PluggableDevice for macOS GPUs
DirectML PluggableDevice for Windows and WSL (preview)
Intel® Extension for TensorFlow PluggableDevice for Linux and WSL

Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies.

Last updated 2024-07-25 UTC.