Composing Learning Algorithms


Before you start

Before you start, please run the following to make sure that your environment is correctly set up. If you don't see a greeting, please refer to the Installation guide for instructions.

# @test {"skip": true}
pip install --quiet --upgrade federated_language
pip install --quiet --upgrade tensorflow-federated

import collections
import federated_language
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

Composing Learning Algorithms

The Building Your Own Federated Learning Algorithm Tutorial used TFF's federated core to directly implement a version of the Federated Averaging (FedAvg) algorithm.

In this tutorial, you will use federated learning components in TFF's API to build federated learning algorithms in a modular manner, without having to re-implement everything from scratch.

For the purposes of this tutorial, you will implement a variant of FedAvg that employs gradient clipping during local training.

Learning Algorithm Building Blocks

At a high level, many learning algorithms can be broken down into 4 components, referred to as building blocks. These are as follows:

  1. Distributor (i.e., server-to-client communication)
  2. Client work (i.e., local client computation)
  3. Aggregator (i.e., client-to-server communication)
  4. Finalizer (i.e., server computation using aggregated client outputs)

While the Building Your Own Federated Learning Algorithm Tutorial implemented all of these building blocks from scratch, this is often unnecessary. Instead, you can re-use building blocks from similar algorithms.

In this case, to implement FedAvg with gradient clipping, you only need to modify the client work building block. The remaining blocks can be identical to what is used in "vanilla" FedAvg.

Implementing the Client Work

First, let's write TF logic that does local model training with gradient clipping. For simplicity, gradients will be clipped to have norm at most 1.
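Before the full TF logic below, here is a quick numeric illustration of global-norm clipping (the gradient values are made up for this example and are not part of the tutorial's code):

# Illustrative values only: two gradients with global norm sqrt(3^2 + 4^2) = 5.
grads = [tf.constant([3.0]), tf.constant([4.0])]
gradient_norm = tf.linalg.global_norm(grads)  # 5.0
if gradient_norm > 1:
  grads = [g / gradient_norm for g in grads]  # [[0.6], [0.8]], new norm = 1.0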

TF Logic

@tf.function
def client_update(
    model: tff.learning.models.FunctionalModel,
    dataset: tf.data.Dataset,
    initial_weights: tff.learning.models.ModelWeights,
    client_optimizer: tff.learning.optimizers.Optimizer,
):
  """Performs training (using the initial server model weights) on the client's dataset."""
  # Keep track of the number of examples.
  num_examples = 0.0
  # Use the client_optimizer to update the local model.
  trainable_weights, non_trainable_weights = (
      initial_weights.trainable,
      initial_weights.non_trainable,
  )
  optimizer_state = client_optimizer.initialize(
      tf.nest.map_structure(
          lambda x: tf.TensorSpec(x.shape, x.dtype), trainable_weights
      )
  )
  for batch in dataset:
    x, y = batch
    with tf.GradientTape() as tape:
      tape.watch(trainable_weights)
      logits = model.predict_on_batch(
          model_weights=(trainable_weights, non_trainable_weights),
          x=x,
          training=True,
      )
      num_examples += tf.cast(tf.shape(y)[0], tf.float32)
      loss = model.loss(output=logits, label=y)

    # Compute the corresponding gradient.
    grads = tape.gradient(loss, trainable_weights)
    # Compute the gradient norm and clip.
    gradient_norm = tf.linalg.global_norm(grads)
    if gradient_norm > 1:
      grads = tf.nest.map_structure(lambda x: x / gradient_norm, grads)
    # Apply the gradient using a client optimizer.
    optimizer_state, trainable_weights = client_optimizer.next(
        optimizer_state, trainable_weights, grads
    )

  # Compute the difference between the initial weights and the client weights.
  client_update = tf.nest.map_structure(
      tf.subtract, trainable_weights, initial_weights[0]
  )
  return tff.learning.templates.ClientResult(
      update=client_update, update_weight=num_examples
  )

There are a few important points about the code above. First, it keeps track of the number of examples seen, as this will constitute the weight of the client update (when computing an average across clients).

Second, it uses tff.learning.templates.ClientResult to package the output. This return type is used to standardize client work building blocks in tff.learning.
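For instance, a client that processed 20 examples and produced zero-valued weight deltas would package its result as below (the shapes are purely illustrative, chosen to match the 784-to-10 model used later in this tutorial):

# Hypothetical example of a packaged client result: the update is a structure
# matching the trainable weights, and update_weight is the number of examples.
example_result = tff.learning.templates.ClientResult(
    update=[np.zeros((784, 10), np.float32), np.zeros((10,), np.float32)],
    update_weight=20.0,
)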

Creating a ClientWorkProcess

While the TF logic above will do local training with clipping, it still needs to be wrapped in TFF code in order to create the necessary building block.

Specifically, each of the 4 building blocks is represented as a tff.templates.MeasuredProcess. This means that all 4 blocks have both an initialize and a next function used to instantiate and run the computation.

This allows each building block to keep track of its own state (stored at the server) as needed to perform its operations. While it will not be used in this tutorial, it can be used for things like tracking how many iterations have occurred, or keeping track of optimizer states.
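For example, each of the specialized building block classes used in this tutorial derives from tff.templates.MeasuredProcess, so they all expose initialize and next. A quick check, assuming these class names from the TFF API:

# Each building block type should report True here.
for process_cls in (
    tff.learning.templates.DistributionProcess,
    tff.learning.templates.ClientWorkProcess,
    tff.templates.AggregationProcess,
    tff.learning.templates.FinalizerProcess,
):
  print(process_cls.__name__, issubclass(process_cls, tff.templates.MeasuredProcess))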

Client work TF logic should generally be wrapped as a tff.learning.templates.ClientWorkProcess, which codifies the expected types going into and out of the client's local training. It can be parameterized by a model and optimizer, as below.

def build_gradient_clipping_client_work(
    model: tff.learning.models.FunctionalModel,
    optimizer: tff.learning.optimizers.Optimizer,
) -> tff.learning.templates.ClientWorkProcess:
  """Creates a client work process that uses gradient clipping."""
  data_type = federated_language.SequenceType(
      tff.tensorflow.to_type(model.input_spec)
  )
  model_weights_type = federated_language.to_type(
      tf.nest.map_structure(
          lambda arr: federated_language.TensorType(
              shape=arr.shape, dtype=arr.dtype
          ),
          tff.learning.models.ModelWeights(*model.initial_weights),
      )
  )

  @federated_language.federated_computation
  def initialize_fn():
    return federated_language.federated_value((), federated_language.SERVER)

  @tff.tensorflow.computation(model_weights_type, data_type)
  def client_update_computation(model_weights, dataset):
    return client_update(model, dataset, model_weights, optimizer)

  @federated_language.federated_computation(
      initialize_fn.type_signature.result,
      federated_language.FederatedType(
          model_weights_type, federated_language.CLIENTS
      ),
      federated_language.FederatedType(data_type, federated_language.CLIENTS),
  )
  def next_fn(state, model_weights, client_dataset):
    client_result = federated_language.federated_map(
        client_update_computation, (model_weights, client_dataset)
    )
    # Return empty measurements, though a more complete algorithm might
    # measure something here.
    measurements = federated_language.federated_value(
        (), federated_language.SERVER
    )
    return tff.templates.MeasuredProcessOutput(
        state, client_result, measurements
    )

  return tff.learning.templates.ClientWorkProcess(initialize_fn, next_fn)

Composing a Learning Algorithm

Let's put the client work above into a full-fledged algorithm. First, let's set up our data and model.

Preparing the input data

Load and preprocess the EMNIST dataset included in TFF. For more details, see the image classification tutorial.

emnist_train,emnist_test=tff.simulation.datasets.emnist.load_data()

In order to feed the dataset into our model, the data is flattened and converted into tuples of the form (flattened_image_vector, label).

Let's select a small number of clients, and apply the preprocessing above to their datasets.

NUM_CLIENTS = 10
BATCH_SIZE = 20


def preprocess(dataset):

  def batch_format_fn(element):
    """Flatten a batch of EMNIST data and return a (features, label) tuple."""
    return (
        tf.reshape(element['pixels'], [-1, 784]),
        tf.reshape(element['label'], [-1, 1]),
    )

  return dataset.batch(BATCH_SIZE).map(batch_format_fn)


client_ids = sorted(emnist_train.client_ids)[:NUM_CLIENTS]
federated_train_data = [
    preprocess(emnist_train.create_tf_dataset_for_client(x)) for x in client_ids
]
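As an optional sanity check (not part of the original tutorial), you can inspect the element spec of one preprocessed client dataset; it should be a (features, label) tuple with shapes [None, 784] and [None, 1]:

# Expect: (TensorSpec(shape=(None, 784), ...), TensorSpec(shape=(None, 1), ...))
print(federated_train_data[0].element_spec)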

Preparing the model

This uses the same model as in the image classification tutorial. This model (implemented via tf.keras) has a single dense layer, followed by a softmax layer. In order to use this model in TFF, the Keras model is wrapped as a tff.learning.models.FunctionalModel. This allows us to perform the model's forward pass within TFF, and extract model outputs.

initializer = tf.keras.initializers.GlorotNormal(seed=0)
keras_model = tf.keras.models.Sequential([
    tf.keras.layers.Input(shape=(784,)),
    tf.keras.layers.Dense(10, kernel_initializer=initializer),
    tf.keras.layers.Softmax(),
])

tff_model = tff.learning.models.functional_model_from_keras(
    keras_model,
    loss_fn=tf.keras.losses.SparseCategoricalCrossentropy(),
    input_spec=federated_train_data[0].element_spec,
    metrics_constructor=collections.OrderedDict(
        accuracy=tf.keras.metrics.SparseCategoricalAccuracy
    ),
)
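If you want to confirm the wrapping worked, the functional model's initial_weights is a (trainable, non_trainable) tuple; the trainable part should contain the dense layer's (784, 10) kernel and (10,) bias. A small check, not in the original tutorial:

trainable, non_trainable = tff_model.initial_weights
print([w.shape for w in trainable])  # Expect [(784, 10), (10,)]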

Preparing the optimizers

Just as in tff.learning.algorithms.build_weighted_fed_avg, there are two optimizers here: A client optimizer, and a server optimizer. For simplicity, the optimizers will be SGD with different learning rates.

client_optimizer = tff.learning.optimizers.build_sgdm(learning_rate=0.01)
server_optimizer = tff.learning.optimizers.build_sgdm(learning_rate=1.0)
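These are tff.learning.optimizers.Optimizer instances, exposing the same functional initialize/next API used inside client_update above. A tiny standalone sketch of that API, with made-up values for illustration:

# SGD with learning rate 0.01 computes weight - 0.01 * gradient.
spec = [tf.TensorSpec([2], tf.float32)]
opt_state = client_optimizer.initialize(spec)
opt_state, new_weights = client_optimizer.next(
    opt_state, [tf.constant([1.0, 2.0])], [tf.constant([0.5, 0.5])]
)
print(new_weights)  # Expect approximately [0.995, 1.995]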

Defining the building blocks

Now that the client work building block, data, model, and optimizers are set up, it remains to create building blocks for the distributor, the aggregator, and the finalizer. This can be done by borrowing default implementations available in TFF, the same ones used by FedAvg.

@tff.tensorflow.computation
def initial_model_weights_fn():
  return tff.learning.models.ModelWeights(*tff_model.initial_weights)


model_weights_type = initial_model_weights_fn.type_signature.result

distributor = tff.learning.templates.build_broadcast_process(model_weights_type)
client_work = build_gradient_clipping_client_work(tff_model, client_optimizer)

# TFF aggregators use a factory pattern, which creates an aggregator based on
# the output type of the client work. This one also uses a float (the number
# of examples) to govern each client's weight in the average being computed.
aggregator_factory = tff.aggregators.MeanFactory()
aggregator = aggregator_factory.create(
    model_weights_type.trainable, federated_language.TensorType(np.float32)
)
finalizer = tff.learning.templates.build_apply_optimizer_finalizer(
    server_optimizer, model_weights_type
)
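If you're curious how these pieces line up, each building block is a process whose next computation carries a federated type signature; printing one (the exact formatting varies across TFF versions) can help debug type mismatches before composing:

# For example, the aggregator's next expects client updates and weights and
# returns a server-placed result.
print(aggregator.next.type_signature)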

Composing the building blocks

Finally, you can use a built-in composer in TFF for putting the building blocks together. This one is a relatively simple composer, which takes the 4 building blocks above and wires their types together.

fed_avg_with_clipping = tff.learning.templates.compose_learning_process(
    initial_model_weights_fn, distributor, client_work, aggregator, finalizer
)

Running the algorithm

Now that the algorithm is done, let's run it. First, initialize the algorithm. The state of this algorithm has a component for each building block, along with one for the global model weights.

state = fed_avg_with_clipping.initialize()
state.client_work
()
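The same state object also carries the other components; in particular, the global model weights live in state.global_model_weights (field name assumed from tff.learning.templates.LearningAlgorithmState):

# Expect the dense layer's shapes: [(784, 10), (10,)].
print([w.shape for w in state.global_model_weights.trainable])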

As expected, the client work has an empty state (remember the client work code above!). However, other building blocks may have non-empty state. For example, the finalizer keeps track of the server optimizer's state. Since the server optimizer here is plain SGD, that state is just its (fixed) learning rate.

state.finalizer
OrderedDict([('learning_rate', 1.0)])

Now run a training round.

learning_process_output = fed_avg_with_clipping.next(
 state, federated_train_data
)

The output of this (tff.learning.templates.LearningProcessOutput) has both a .state and .metrics output. Let's look at both.

learning_process_output.state.finalizer
OrderedDict([('learning_rate', 1.0)])

The finalizer state is unchanged after one round of .next, since plain SGD with a fixed learning rate carries no state that evolves across rounds; a stateful server optimizer (for example, one with momentum) would show an updated state here.

learning_process_output.metrics
OrderedDict([('distributor', ()), ('client_work', ()), ('aggregator', OrderedDict([('mean_value', ()), ('mean_weight', ())])), ('finalizer', OrderedDict([('update_non_finite', 0)]))])

While the metrics here are largely empty, for more complex and practical algorithms they'll generally be full of useful information.
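To train for more than one round, you can loop over next, threading the state through each call. A minimal sketch (the round count here is arbitrary):

state = fed_avg_with_clipping.initialize()
for round_num in range(5):
  output = fed_avg_with_clipping.next(state, federated_train_data)
  state = output.state
  print(f'Round {round_num}, metrics: {output.metrics}')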

Conclusion

By using the building block/composers framework above, you can create entirely new learning algorithms, without having to re-do everything from scratch. However, this is only the starting point. This framework makes it much easier to express algorithms as simple modifications of FedAvg. For more algorithms, see tff.learning.algorithms, which contains algorithms such as FedProx and FedAvg with client learning rate scheduling. These APIs can even aid implementations of entirely new algorithms, such as federated k-means clustering.
