Client-efficient large-model federated learning via federated_select and sparse aggregation


This tutorial shows how TFF can be used to train a very large model where each client device only downloads and updates a small part of the model, using tff.federated_select and sparse aggregation. While this tutorial is fairly self-contained, the tff.federated_select tutorial and custom FL algorithms tutorial provide good introductions to some of the techniques used here.

Concretely, in this tutorial we consider logistic regression for multi-label classification, predicting which "tags" are associated with a text string based on a bag-of-words feature representation. Importantly, communication and client-side computation costs are controlled by a fixed constant (MAX_TOKENS_SELECTED_PER_CLIENT), and do not scale with the overall vocabulary size, which could be extremely large in practical settings.

pip install --quiet --upgrade tensorflow-federated
import collections
from collections.abc import Callable
import itertools

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

Each client will federated_select the rows of the model weights for at most this many unique tokens. This upper-bounds the size of the client's local model and the amount of server -> client (federated_select) and client -> server (federated_aggregate) communication performed.

This tutorial should still run correctly even if you set this as small as 1 (ensuring not all tokens from each client are selected) or to a large value, though model convergence may be affected.

MAX_TOKENS_SELECTED_PER_CLIENT = 6
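To put rough numbers on this bound, the following sketch counts the float32 model coefficients involved, assuming one row of TAG_VOCAB_SIZE coefficients per selected token (as described later in this tutorial). The toy sizes (13 word ids and 4 tags, each including one OOV bucket) match the dataset built below; the 10-million-word vocabulary is a hypothetical size chosen only for comparison, and both helper names are illustrative.

```python
def floats_per_client(max_tokens: int, tag_vocab_size: int) -> int:
  """Upper bound on coefficients a client downloads per round.

  Each selected token contributes one row of `tag_vocab_size` coefficients.
  """
  return max_tokens * tag_vocab_size


def floats_full_model(word_vocab_size: int, tag_vocab_size: int) -> int:
  """Coefficients in the full (word_vocab_size x tag_vocab_size) model."""
  return word_vocab_size * tag_vocab_size


M = 6  # MAX_TOKENS_SELECTED_PER_CLIENT
TOY_WORDS, TOY_TAGS = 13, 4  # 12 words + 1 OOV, 3 tags + 1 OOV
BIG_WORDS = 10_000_000 + 1  # Hypothetical large vocabulary (+1 OOV bucket).

print(floats_per_client(M, TOY_TAGS))          # 24
print(floats_full_model(TOY_WORDS, TOY_TAGS))  # 52
print(floats_full_model(BIG_WORDS, TOY_TAGS))  # 40000004
```

The per-client figure depends only on M and the number of tags, which is the sense in which communication does not scale with the vocabulary size.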

We also define a few constants for various types. For this colab, a token is an integer identifier for a particular word after parsing the dataset.

# There are some constraints on types
# here that will require some explicit type conversions:
#    - `tff.federated_select` requires int32
#    - `tf.SparseTensor` requires int64 indices.
TOKEN_DTYPE = np.int64
SELECT_KEY_DTYPE = np.int32

# Type for counts of token occurrences.
TOKEN_COUNT_DTYPE = np.int32

# A sparse feature vector can be thought of as a map
# from TOKEN_DTYPE to FEATURE_DTYPE.
# Our features are {0, 1} indicators, so we could potentially
# use np.int8 as an optimization.
FEATURE_DTYPE = np.int32

Setting up the problem: Dataset and Model

We construct a tiny toy dataset for easy experimentation in this tutorial. However, the format of the dataset is compatible with Federated StackOverflow, and the pre-processing and model architecture are adapted from the StackOverflow tag prediction problem of Adaptive Federated Optimization.

Dataset parsing and pre-processing

NUM_OOV_BUCKETS = 1

BatchType = collections.namedtuple('BatchType', ['tokens', 'tags'])

def build_to_ids_fn(word_vocab: list[str],
                    tag_vocab: list[str]) -> Callable[[tf.Tensor], tf.Tensor]:
  """Constructs a function mapping examples to sequences of token indices."""
  word_table_values = np.arange(len(word_vocab), dtype=np.int64)
  word_table = tf.lookup.StaticVocabularyTable(
      tf.lookup.KeyValueTensorInitializer(word_vocab, word_table_values),
      num_oov_buckets=NUM_OOV_BUCKETS)

  tag_table_values = np.arange(len(tag_vocab), dtype=np.int64)
  tag_table = tf.lookup.StaticVocabularyTable(
      tf.lookup.KeyValueTensorInitializer(tag_vocab, tag_table_values),
      num_oov_buckets=NUM_OOV_BUCKETS)

  def to_ids(example):
    """Converts a Stack Overflow example to a bag-of-words/tags format."""
    sentence = tf.strings.join([example['tokens'], example['title']],
                               separator=' ')

    # We represent the label (output tags) densely.
    raw_tags = example['tags']
    tags = tf.strings.split(raw_tags, sep='|')
    tags = tag_table.lookup(tags)
    tags, _ = tf.unique(tags)
    tags = tf.one_hot(tags, len(tag_vocab) + NUM_OOV_BUCKETS)
    tags = tf.reduce_max(tags, axis=0)

    # We represent the features as a SparseTensor of {0, 1}s.
    words = tf.strings.split(sentence)
    tokens = word_table.lookup(words)
    tokens, _ = tf.unique(tokens)
    # Note: We could choose to use the word counts as the feature vector
    # instead of just {0, 1} values (see tf.unique_with_counts).
    tokens = tf.reshape(tokens, shape=(tf.size(tokens), 1))
    tokens_st = tf.SparseTensor(
        tokens,
        tf.ones(tf.size(tokens), dtype=FEATURE_DTYPE),
        dense_shape=(len(word_vocab) + NUM_OOV_BUCKETS,))
    tokens_st = tf.sparse.reorder(tokens_st)

    return BatchType(tokens_st, tags)

  return to_ids

def build_preprocess_fn(word_vocab, tag_vocab):

  @tf.function
  def preprocess_fn(dataset):
    to_ids = build_to_ids_fn(word_vocab, tag_vocab)
    # We *don't* shuffle in order to make this colab deterministic for
    # easier testing and reproducibility.
    # But real-world training should use `.shuffle()`.
    return dataset.map(to_ids, num_parallel_calls=tf.data.experimental.AUTOTUNE)

  return preprocess_fn
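The following pure-Python sketch mirrors what to_ids computes for a single example, which may help when reading the tf.lookup code above. Known words and tags map to their vocabulary index; with NUM_OOV_BUCKETS = 1, every unknown string lands in the single OOV bucket at index len(vocab). Words are de-duplicated, so features are {0, 1} indicators, and the label is a multi-hot vector. The helper name bag_of_words_example is hypothetical, and it returns a sorted list of active token ids rather than a tf.SparseTensor.

```python
def bag_of_words_example(text, tags, word_vocab, tag_vocab):
  """Returns (sorted unique token ids, multi-hot tag vector) for one example.

  Out-of-vocabulary words/tags map to index len(vocab), i.e. a single OOV
  bucket, mirroring NUM_OOV_BUCKETS = 1.
  """
  word_index = {w: i for i, w in enumerate(word_vocab)}
  tag_index = {t: i for i, t in enumerate(tag_vocab)}
  oov_word, oov_tag = len(word_vocab), len(tag_vocab)

  # Split on whitespace and de-duplicate: repeated words still count as "1".
  token_ids = sorted({word_index.get(w, oov_word) for w in text.split()})

  # Tags are '|'-separated; build a dense multi-hot label of size |tags| + 1.
  label = [0.0] * (len(tag_vocab) + 1)
  for t in tags.split('|'):
    label[tag_index.get(t, oov_tag)] = 1.0
  return token_ids, label


WORDS = ['apple', 'orange', 'pear', 'kiwi', 'carrot', 'broccoli', 'arugula',
         'peas', 'trout', 'tuna', 'cod', 'salmon']
TAGS = ['FRUIT', 'VEGETABLE', 'FISH']
# Both words are OOV, so they collapse into the single OOV token 12.
print(bag_of_words_example('sturgeon bass', 'FISH', WORDS, TAGS))
# ([12], [0.0, 0.0, 1.0, 0.0])
```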

A tiny toy dataset

We construct a tiny toy dataset with a global vocabulary of 12 words and 3 clients. This tiny example is useful for testing edge cases (for example, client 1 has fewer than MAX_TOKENS_SELECTED_PER_CLIENT = 6 distinct tokens, client 2 has exactly 6, and client 3 has more) and for developing the code.

However, real-world use cases of this approach would involve global vocabularies of tens of millions of words or more, with perhaps thousands of distinct tokens appearing on each client. Because the format of the data is the same, the extension to more realistic testbed problems, e.g. the tff.simulation.datasets.stackoverflow.load_data() dataset, should be straightforward.

First, we define our word and tag vocabularies.

# Features
FRUIT_WORDS = ['apple', 'orange', 'pear', 'kiwi']
VEGETABLE_WORDS = ['carrot', 'broccoli', 'arugula', 'peas']
FISH_WORDS = ['trout', 'tuna', 'cod', 'salmon']
WORD_VOCAB = FRUIT_WORDS + VEGETABLE_WORDS + FISH_WORDS
# Labels
TAG_VOCAB = ['FRUIT', 'VEGETABLE', 'FISH']

Now, we create 3 clients with small local datasets. If you are running this tutorial in colab, it may be useful to use the "mirror cell in tab" feature to pin this cell and its output in order to interpret/check the output of the functions developed below.

preprocess_fn = build_preprocess_fn(WORD_VOCAB, TAG_VOCAB)

def make_dataset(raw):
  d = tf.data.Dataset.from_tensor_slices(
      # Matches the StackOverflow formatting
      collections.OrderedDict(
          tokens=tf.constant([t[0] for t in raw]),
          tags=tf.constant([t[1] for t in raw]),
          title=['' for _ in raw]))
  d = preprocess_fn(d)
  return d

# 4 distinct tokens
CLIENT1_DATASET = make_dataset([
    ('apple orange apple orange', 'FRUIT'),
    ('carrot trout', 'VEGETABLE|FISH'),
    ('orange apple', 'FRUIT'),
    ('orange', 'ORANGE|CITRUS')  # 2 OOV tags
])

# 6 distinct tokens
CLIENT2_DATASET = make_dataset([
    ('pear cod', 'FRUIT|FISH'),
    ('arugula peas', 'VEGETABLE'),
    ('kiwi pear', 'FRUIT'),
    ('sturgeon', 'FISH'),  # OOV word
    ('sturgeon bass', 'FISH')  # 2 OOV words
])

# A client with all possible words & tags (13 distinct tokens).
# With MAX_TOKENS_SELECTED_PER_CLIENT = 6, we won't download the model
# slices for all tokens that occur on this client.
CLIENT3_DATASET = make_dataset([
    (' '.join(WORD_VOCAB + ['oovword']), '|'.join(TAG_VOCAB)),
    # Make the OOV token and 'salmon' occur in the largest number
    # of examples on this client:
    ('salmon oovword', 'FISH|OOVTAG')
])

print('Word vocab')
for i, word in enumerate(WORD_VOCAB):
  print(f'{i:2d} {word}')

print('\nTag vocab')
for i, tag in enumerate(TAG_VOCAB):
  print(f'{i:2d} {tag}')
Word vocab
 0 apple
 1 orange
 2 pear
 3 kiwi
 4 carrot
 5 broccoli
 6 arugula
 7 peas
 8 trout
 9 tuna
10 cod
11 salmon
Tag vocab
 0 FRUIT
 1 VEGETABLE
 2 FISH

Define constants for the raw numbers of input features (tokens/words) and labels (post tags). Our actual input/output spaces are NUM_OOV_BUCKETS = 1 larger because we add an OOV token / tag.

NUM_WORDS = len(WORD_VOCAB) 
NUM_TAGS = len(TAG_VOCAB)
WORD_VOCAB_SIZE = NUM_WORDS + NUM_OOV_BUCKETS
TAG_VOCAB_SIZE = NUM_TAGS + NUM_OOV_BUCKETS

Create batched versions of the datasets, and individual batches, which will be useful in testing code as we go.

batched_dataset1 = CLIENT1_DATASET.batch(2)
batched_dataset2 = CLIENT2_DATASET.batch(3)
batched_dataset3 = CLIENT3_DATASET.batch(2)
batch1 = next(iter(batched_dataset1))
batch2 = next(iter(batched_dataset2))
batch3 = next(iter(batched_dataset3))

Define a model with sparse inputs

We use a simple independent logistic regression model for each tag.

def create_logistic_model(word_vocab_size: int, vocab_tags_size: int):

  model = tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(word_vocab_size,), sparse=True),
      tf.keras.layers.Dense(
          vocab_tags_size,
          activation='sigmoid',
          kernel_initializer=tf.keras.initializers.zeros,
          # For simplicity, don't use a bias vector; this means the model
          # is a single tensor, and we only need sparse aggregation of
          # the per-token slices of the model. Generalizing to also handle
          # other model weights that are fully updated
          # (dense broadcast and aggregate) would be a good exercise.
          use_bias=False),
  ])

  return model
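Because the features are sparse {0, 1} indicators and there is no bias, each tag's logit is simply the sum of the kernel rows for the tokens present in the example, passed through a sigmoid. The pure-Python sketch below (predict_tags is a hypothetical helper, not part of Keras) makes this explicit and shows why a zero-initialized kernel predicts 0.5 for every tag. It also shows that a prediction reads only the rows of the active tokens, which is why sending a client just a few model slices can be sufficient.

```python
import math


def predict_tags(active_tokens, kernel):
  """Independent per-tag logistic regression with {0, 1} sparse features.

  active_tokens: token ids present in the example (feature value 1).
  kernel: list of rows, kernel[token][tag] (shape: vocab_size x num_tags).
  Returns one sigmoid probability per tag. Only rows in active_tokens are read.
  """
  num_tags = len(kernel[0])
  logits = [0.0] * num_tags
  for token in active_tokens:
    for tag in range(num_tags):
      logits[tag] += kernel[token][tag]
  return [1.0 / (1.0 + math.exp(-z)) for z in logits]


# Zero-initialized 13 x 4 kernel, as with kernel_initializer=zeros.
zero_kernel = [[0.0] * 4 for _ in range(13)]
print(predict_tags([0, 1], zero_kernel))  # [0.5, 0.5, 0.5, 0.5]
```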

Let's make sure it works, first by making predictions:

model = create_logistic_model(WORD_VOCAB_SIZE, TAG_VOCAB_SIZE)
p = model.predict(batch1.tokens)
print(p)
[[0.5 0.5 0.5 0.5]
 [0.5 0.5 0.5 0.5]]

And some simple centralized training:

model.compile(optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.001),
              loss=tf.keras.losses.BinaryCrossentropy())
model.train_on_batch(batch1.tokens, batch1.tags)

Building blocks for the federated computation

We will implement a simple version of the Federated Averaging algorithm with the key difference that each device only downloads a relevant subset of the model, and only contributes updates to that subset.

We use M as shorthand for MAX_TOKENS_SELECTED_PER_CLIENT. At a high level, one round of training involves these steps:

  1. Each participating client scans over its local dataset, parsing the input strings and mapping them to the correct tokens (int indexes). This requires access to the global (large) dictionary (this could potentially be avoided using feature hashing techniques). We then sparsely count the number of examples in which each token occurs. If U unique tokens occur on device, we choose the num_actual_tokens = min(U, M) most frequent tokens to train.

  2. The clients use federated_select to retrieve the model coefficients for the num_actual_tokens selected tokens from the server. Each model slice is a tensor of shape (TAG_VOCAB_SIZE, ), so the total data transmitted to the client is at most of size TAG_VOCAB_SIZE * M (see note below).

  3. The clients construct a mapping global_token -> local_token where the local token (int index) is the index of the global token in the list of selected tokens.

  4. The clients use a "small" version of the global model that only has coefficients for at most M tokens, from the range [0, num_actual_tokens). The global -> local mapping is used to initialize the dense parameters of this model from the selected model slices.

  5. Clients train their local model using SGD on data preprocessed with the global -> local mapping.

  6. Clients turn the parameters of their local model into IndexedSlices updates using the local -> global mapping to index the rows. The server aggregates these updates using a sparse sum aggregation.

  7. The server takes the (dense) result of the above aggregation, divides it by the number of clients participating, and applies the resulting average update to the global model.

In this section we construct the building blocks for these steps, which will then be combined in a final federated_computation that captures the full logic of one training round.

Count client tokens and decide which model slices to federated_select

Each device needs to decide which "slices" of the model are relevant to its local training dataset. For our problem, we do this by (sparsely!) counting how many examples contain each token in the client training data set.

@tf.function
def token_count_fn(token_counts, batch):
  """Adds counts from `batch` to the running `token_counts` sum."""
  # Sum across the batch dimension.
  flat_tokens = tf.sparse.reduce_sum(
      batch.tokens, axis=0, output_is_sparse=True)
  flat_tokens = tf.cast(flat_tokens, dtype=TOKEN_COUNT_DTYPE)
  return tf.sparse.add(token_counts, flat_tokens)

# Simple tests
# Create the initial zero token counts using empty tensors.
initial_token_counts = tf.SparseTensor(
    indices=tf.zeros(shape=(0, 1), dtype=TOKEN_DTYPE),
    values=tf.zeros(shape=(0,), dtype=TOKEN_COUNT_DTYPE),
    dense_shape=(WORD_VOCAB_SIZE,))

client_token_counts = batched_dataset1.reduce(initial_token_counts,
                                              token_count_fn)
tokens = tf.reshape(client_token_counts.indices, (-1,)).numpy()
print('tokens:', tokens)
np.testing.assert_array_equal(tokens, [0, 1, 4, 8])
# The count is the number of *examples* in which the token/word
# occurs, not the total number of occurrences, since we still featurize
# multiple occurrences in the same example as a "1".
counts = client_token_counts.values.numpy()
print('counts:', counts)
np.testing.assert_array_equal(counts, [2, 3, 1, 1])
tokens: [0 1 4 8]
counts: [2 3 1 1]
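The counting logic is easy to check against a plain-Python equivalent. The sketch below (count_token_examples is a hypothetical helper) takes each example as a list of global token ids, possibly with repeats, counts each token at most once per example, and reproduces the tokens and counts printed above for client 1.

```python
from collections import Counter


def count_token_examples(examples):
  """Counts, for each token id, the number of examples containing it.

  `examples` is an iterable of token-id lists. Each example contributes at
  most 1 per token, matching the {0, 1} featurization.
  """
  counts = Counter()
  for token_ids in examples:
    counts.update(set(token_ids))
  return dict(sorted(counts.items()))


# Client 1's examples as global token ids (apple=0, orange=1, carrot=4, trout=8).
client1 = [[0, 1, 0, 1], [4, 8], [1, 0], [1]]
print(count_token_examples(client1))  # {0: 2, 1: 3, 4: 1, 8: 1}
```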

We will select the model parameters corresponding to the MAX_TOKENS_SELECTED_PER_CLIENT most frequently occurring tokens on device. If fewer than this many tokens occur on device, we pad the list to a fixed length to enable the use of federated_select.

Note that other strategies are possibly better, for example, randomly selecting tokens (perhaps based on their occurrence probability). This would ensure that all slices of the model (for which the client has data) have some chance of being updated.
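As an illustration of the randomized alternative, the function below draws tokens without replacement with a preference for tokens that appear in more examples, using the Efraimidis-Spirakis weighted-sampling technique. This is an illustrative choice, not something the tutorial implements; sample_keys is a hypothetical name, and it returns the same padded (keys, actual_num_tokens) format as keys_for_client below.

```python
import random


def sample_keys(token_counts, m, seed=None):
  """Samples up to m distinct tokens, favoring tokens seen in more examples.

  Weighted sampling without replacement (Efraimidis-Spirakis): each token
  gets the random key u ** (1 / count), u ~ Uniform[0, 1), and the m tokens
  with the largest keys are kept. Returns (keys padded with 0 to length m,
  actual_num_tokens).
  """
  rng = random.Random(seed)
  keyed = [(rng.random() ** (1.0 / count), token)
           for token, count in token_counts.items()]
  chosen = [token for _, token in sorted(keyed, reverse=True)[:m]]
  return chosen + [0] * (m - len(chosen)), len(chosen)


# Client 1 counts; different seeds give different (but valid) selections.
keys, n = sample_keys({0: 2, 1: 3, 4: 1, 8: 1}, m=3, seed=0)
```

Unlike the deterministic top-M rule, every token with a nonzero count has some chance of being selected in each round.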

@tf.function
def keys_for_client(client_dataset, max_tokens_per_client):
  """Computes a set of max_tokens_per_client keys."""
  initial_token_counts = tf.SparseTensor(
      indices=tf.zeros((0, 1), dtype=TOKEN_DTYPE),
      values=tf.zeros((0,), dtype=TOKEN_COUNT_DTYPE),
      dense_shape=(WORD_VOCAB_SIZE,))
  client_token_counts = client_dataset.reduce(initial_token_counts,
                                              token_count_fn)
  # Find the most-frequently occurring tokens
  tokens = tf.reshape(client_token_counts.indices, shape=(-1,))
  counts = client_token_counts.values
  perm = tf.argsort(counts, direction='DESCENDING')
  tokens = tf.gather(tokens, perm)
  counts = tf.gather(counts, perm)
  num_raw_tokens = tf.shape(tokens)[0]
  actual_num_tokens = tf.minimum(max_tokens_per_client, num_raw_tokens)
  selected_tokens = tokens[:actual_num_tokens]
  paddings = [[0, max_tokens_per_client - tf.shape(selected_tokens)[0]]]
  padded_tokens = tf.pad(selected_tokens, paddings=paddings)
  # Make sure the type is statically determined
  padded_tokens = tf.reshape(padded_tokens, shape=(max_tokens_per_client,))

  # We will pass these tokens as keys into `federated_select`, which
  # requires SELECT_KEY_DTYPE=np.int32 keys.
  padded_tokens = tf.cast(padded_tokens, dtype=SELECT_KEY_DTYPE)
  return padded_tokens, actual_num_tokens

# Simple test
# Case 1: num_raw_tokens > max_tokens_per_client
selected_tokens, actual_num_tokens = keys_for_client(batched_dataset1, 3)
assert tf.size(selected_tokens) == 3
assert actual_num_tokens == 3
# Case 2: num_raw_tokens < max_tokens_per_client
selected_tokens, actual_num_tokens = keys_for_client(batched_dataset1, 10)
assert tf.size(selected_tokens) == 10
assert actual_num_tokens == 4
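The same top-M-plus-padding rule can be written in a few lines of plain Python, which makes the two test cases above easy to see. In this sketch (top_keys is a hypothetical helper) ties are broken by the smaller token id; tf.argsort is not stable by default, so the TensorFlow version may order tied tokens differently.

```python
def top_keys(token_counts, m):
  """Returns (keys padded to length m, actual_num_tokens).

  Keeps the m tokens that occur in the most examples; ties are broken by the
  smaller token id. Padding uses key 0, which simply re-selects row 0 of the
  model; the client ignores everything after the first actual_num_tokens keys.
  """
  ranked = sorted(token_counts, key=lambda t: (-token_counts[t], t))
  selected = ranked[:m]
  return selected + [0] * (m - len(selected)), len(selected)


counts = {0: 2, 1: 3, 4: 1, 8: 1}  # Client 1 counts from above.
print(top_keys(counts, 3))   # ([1, 0, 4], 3)
print(top_keys(counts, 10))  # ([1, 0, 4, 8, 0, 0, 0, 0, 0, 0], 4)
```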

Map global tokens to local tokens

The above selection gives us up to actual_num_tokens global token ids, which the on-device model treats as a dense local vocabulary in the range [0, actual_num_tokens). However, the dataset we read has tokens from the much larger global vocabulary range [0, WORD_VOCAB_SIZE).

Thus, we need to map the global tokens to their corresponding local tokens. The local token ids are simply given by the indexes into the selected_tokens tensor computed in the previous step.

@tf.function
def map_to_local_token_ids(client_data, client_keys):
  global_to_local = tf.lookup.StaticHashTable(
      # Note int32 -> int64 maps are not supported
      tf.lookup.KeyValueTensorInitializer(
          keys=tf.cast(client_keys, dtype=TOKEN_DTYPE),
          # Note we need to use tf.shape, not the static
          # shape client_keys.shape[0]
          values=tf.range(0, limit=tf.shape(client_keys)[0],
                          dtype=TOKEN_DTYPE)),
      # We use -1 for tokens that were not selected, which can occur for clients
      # with more than MAX_TOKENS_SELECTED_PER_CLIENT distinct tokens.
      # We will simply remove these invalid indices from the batch below.
      default_value=-1)

  def to_local_ids(sparse_tokens):
    indices_t = tf.transpose(sparse_tokens.indices)
    batch_indices = indices_t[0]  # First column
    tokens = indices_t[1]  # Second column
    tokens = tf.map_fn(
        lambda global_token_id: global_to_local.lookup(global_token_id), tokens)
    # Remove tokens that aren't actually available (looked up as -1):
    available_tokens = tokens >= 0
    tokens = tokens[available_tokens]
    batch_indices = batch_indices[available_tokens]

    updated_indices = tf.transpose(
        tf.concat([[batch_indices], [tokens]], axis=0))
    st = tf.sparse.SparseTensor(
        updated_indices,
        tf.ones(tf.size(tokens), dtype=FEATURE_DTYPE),
        # Each client has at most MAX_TOKENS_SELECTED_PER_CLIENT distinct tokens.
        dense_shape=[sparse_tokens.dense_shape[0], MAX_TOKENS_SELECTED_PER_CLIENT])
    st = tf.sparse.reorder(st)
    return st

  return client_data.map(lambda b: BatchType(to_local_ids(b.tokens), b.tags))

# Simple test
client_keys, actual_num_tokens = keys_for_client(
    batched_dataset3, MAX_TOKENS_SELECTED_PER_CLIENT)
client_keys = client_keys[:actual_num_tokens]

d = map_to_local_token_ids(batched_dataset3, client_keys)
batch = next(iter(d))
all_tokens = tf.gather(batch.tokens.indices, indices=1, axis=1)

# Confirm we have local indices in the range [0, MAX):
assert tf.math.reduce_max(all_tokens) < MAX_TOKENS_SELECTED_PER_CLIENT
assert tf.math.reduce_min(all_tokens) >= 0
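The remapping can be sketched in plain Python on the (example, token) index pairs of a sparse batch. The helper to_local_batch and the keys [11, 12, 0] are hypothetical, chosen so that one token (5) was not selected and is therefore dropped; sorting the result plays the role of tf.sparse.reorder.

```python
def to_local_batch(batch_indices, client_keys):
  """Remaps sparse (example, global_token) indices to (example, local_token).

  client_keys are the real (unpadded) selected keys; the local id of a global
  token is its position in client_keys. Tokens that were not selected
  (looked up as -1) are dropped, as in map_to_local_token_ids.
  """
  global_to_local = {k: i for i, k in enumerate(client_keys)}
  out = []
  for example, token in batch_indices:
    local = global_to_local.get(token, -1)
    if local >= 0:
      out.append((example, local))
  return sorted(out)


batch = [(0, 0), (0, 5), (0, 11), (1, 11), (1, 12)]
print(to_local_batch(batch, [11, 12, 0]))  # [(0, 0), (0, 2), (1, 0), (1, 1)]
```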

Train the local (sub)model on each client

Note that federated_select returns the selected slices as a tf.data.Dataset in the same order as the selection keys. So, we first define a utility function to take such a Dataset and convert it to a single dense tensor which can be used as the model weights of the client model.

@tf.function
def slices_dataset_to_tensor(slices_dataset):
  """Convert a dataset of slices to a tensor."""
  # Use batching to gather all of the slices into a single tensor.
  d = slices_dataset.batch(MAX_TOKENS_SELECTED_PER_CLIENT,
                           drop_remainder=False)
  iter_d = iter(d)
  tensor = next(iter_d)
  # Make sure we have consumed everything
  opt = iter_d.get_next_as_optional()
  tf.Assert(tf.logical_not(opt.has_value()), data=[''], name='CHECK_EMPTY')
  return tensor

# Simple test
weights = np.random.random(
    size=(MAX_TOKENS_SELECTED_PER_CLIENT, TAG_VOCAB_SIZE)).astype(np.float32)
model_slices_as_dataset = tf.data.Dataset.from_tensor_slices(weights)
weights2 = slices_dataset_to_tensor(model_slices_as_dataset)
np.testing.assert_array_equal(weights, weights2)

We now have all the components we need to define a simple local training loop which will run on each client.

@tf.function
def client_train_fn(model, client_optimizer,
                    model_slices_as_dataset, client_data,
                    client_keys, actual_num_tokens):

  initial_model_weights = slices_dataset_to_tensor(model_slices_as_dataset)
  assert len(model.trainable_variables) == 1
  model.trainable_variables[0].assign(initial_model_weights)

  # Only keep the "real" (unpadded) keys.
  client_keys = client_keys[:actual_num_tokens]

  client_data = map_to_local_token_ids(client_data, client_keys)

  loss_fn = tf.keras.losses.BinaryCrossentropy()
  for features, labels in client_data:
    with tf.GradientTape() as tape:
      predictions = model(features)
      loss = loss_fn(labels, predictions)
    grads = tape.gradient(loss, model.trainable_variables)
    client_optimizer.apply_gradients(zip(grads, model.trainable_variables))

  model_weights_delta = model.trainable_weights[0] - initial_model_weights
  model_weights_delta = tf.slice(model_weights_delta, begin=[0, 0],
                                 size=[actual_num_tokens, -1])
  return client_keys, model_weights_delta

# Simple test
# Note if you execute this cell a second time, you need to also re-execute
# the preceding cell to avoid "tf.function-decorated function tried to
# create variables on non-first call" errors.
on_device_model = create_logistic_model(MAX_TOKENS_SELECTED_PER_CLIENT,
                                        TAG_VOCAB_SIZE)
client_optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)
client_keys, actual_num_tokens = keys_for_client(
    batched_dataset2, MAX_TOKENS_SELECTED_PER_CLIENT)

model_slices_as_dataset = tf.data.Dataset.from_tensor_slices(
    np.zeros((MAX_TOKENS_SELECTED_PER_CLIENT, TAG_VOCAB_SIZE),
             dtype=np.float32))

keys, delta = client_train_fn(
    on_device_model,
    client_optimizer,
    model_slices_as_dataset,
    client_data=batched_dataset3,
    client_keys=client_keys,
    actual_num_tokens=actual_num_tokens)

print(delta)
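The essential shape of client_train_fn (start from the fetched slices, train a small M-row model, return only the first actual_num_tokens delta rows) can be sketched without TensorFlow. In this sketch local_train_delta is hypothetical, and training is plain mini-batch SGD on the per-tag logistic loss averaged over each batch. Keras's BinaryCrossentropy additionally averages over tags, so the sketch is not expected to reproduce the TensorFlow numbers.

```python
import math


def local_train_delta(initial_rows, batches, actual_num_tokens, lr):
  """Trains a small (M x num_tags) model and returns deltas for real rows.

  initial_rows: M rows fetched by federated_select (padded rows included).
  batches: lists of (local_token_ids, multi_hot_label) examples.
  Returns only the first actual_num_tokens delta rows; padded rows are
  sliced off, as with tf.slice in client_train_fn.
  """
  rows = [list(r) for r in initial_rows]
  num_tags = len(rows[0])
  for batch in batches:
    # Gradients are computed with the weights fixed at the start of the batch.
    grads = [[0.0] * num_tags for _ in rows]
    for tokens, label in batch:
      for tag in range(num_tags):
        z = sum(rows[t][tag] for t in tokens)
        err = 1.0 / (1.0 + math.exp(-z)) - label[tag]  # d(loss)/d(logit)
        for t in tokens:
          grads[t][tag] += err / len(batch)
    for r, g in zip(rows, grads):
      for tag in range(num_tags):
        r[tag] -= lr * g[tag]
  return [[new - old for new, old in zip(rows[i], initial_rows[i])]
          for i in range(actual_num_tokens)]


# M = 3 rows, 2 tags, but only 2 real tokens; row 2 is padding.
zeros = [[0.0, 0.0] for _ in range(3)]
batches = [[([0], [1.0, 0.0]), ([1], [0.0, 1.0])]]
print(local_train_delta(zeros, batches, actual_num_tokens=2, lr=1.0))
# [[0.25, -0.25], [-0.25, 0.25]]
```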

Aggregate IndexedSlices

We use tff.federated_aggregate to construct a federated sparse sum for IndexedSlices. This simple implementation has the constraint that the dense_shape is known statically in advance. Note also that this sum is only semi-sparse, in the sense that the client -> server communication is sparse, but the server maintains a dense representation of the sum in accumulate and merge, and outputs this dense representation.

def federated_indexed_slices_sum(slice_indices, slice_values, dense_shape):
  """
  Sums IndexedSlices@CLIENTS to a dense @SERVER Tensor.

  Intermediate aggregation is performed by converting to a dense representation,
  which may not be suitable for all applications.

  Args:
    slice_indices: An IndexedSlices.indices tensor @CLIENTS.
    slice_values: An IndexedSlices.values tensor @CLIENTS.
    dense_shape: A statically known dense shape.

  Returns:
    A dense tensor placed @SERVER representing the sum of the clients'
    IndexedSlices.
  """
  slices_dtype = slice_values.type_signature.member.dtype
  zero = tff.tensorflow.computation(
      lambda: tf.zeros(dense_shape, dtype=slices_dtype))()

  @tf.function
  def accumulate_slices(dense, client_value):
    indices, slices = client_value
    # There is no built-in way to add `IndexedSlices`, but
    # tf.convert_to_tensor is a quick way to convert to a dense representation
    # so we can add them.
    return dense + tf.convert_to_tensor(
        tf.IndexedSlices(slices, indices, dense_shape))

  return tff.federated_aggregate(
      (slice_indices, slice_values),
      zero=zero,
      accumulate=tff.tensorflow.computation(accumulate_slices),
      merge=tff.tensorflow.computation(lambda d1, d2: tf.add(d1, d2, name='merge')),
      report=tff.tensorflow.computation(lambda d: d))
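The zero/accumulate/merge/report structure can be imitated in plain Python to see exactly what the server computes. In this sketch all names are hypothetical and only the arithmetic of federated_aggregate is reproduced, not its placement or communication: each client value is a pair of row indices and value rows, accumulation densifies and adds, two partial accumulators are merged, and report is the identity. Repeated indices within one client value are summed, mirroring how tf.convert_to_tensor densifies an IndexedSlices.

```python
def make_zero(dense_shape):
  """The `zero` of the aggregation: an all-zeros dense matrix."""
  rows, cols = dense_shape
  return [[0.0] * cols for _ in range(rows)]


def accumulate(dense, client_value):
  """Densifies one client's (indices, value_rows) and adds it to `dense`."""
  indices, value_rows = client_value
  out = [list(row) for row in dense]
  for i, row in zip(indices, value_rows):
    # Repeated indices are summed.
    out[i] = [a + b for a, b in zip(out[i], row)]
  return out


def merge(d1, d2):
  """Adds two partial dense sums elementwise."""
  return [[a + b for a, b in zip(r1, r2)] for r1, r2 in zip(d1, d2)]


def sparse_sum(client_values, dense_shape):
  """Accumulates two groups of clients separately, then merges them.

  `report` is the identity, so the merged dense sum is the result.
  """
  half = len(client_values) // 2
  partials = []
  for group in (client_values[:half], client_values[half:]):
    acc = make_zero(dense_shape)
    for value in group:
      acc = accumulate(acc, value)
    partials.append(acc)
  return merge(*partials)


# The same two client values (x and y) as in the test that follows.
x = ([2, 0, 1, 5], [[2.0, 2.1], [0.0, 0.1], [1.0, 1.1], [5.0, 5.1]])
y = ([1, 3], [[0.0, 0.3], [3.1, 3.2]])
result = sparse_sum([x, y], (6, 2))
# Row 1 combines both clients (about [1.0, 1.4]); row 4 stays [0.0, 0.0].
```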

Construct a minimal federated_computation as a test

dense_shape = (6, 2)
indices_type = tff.TensorType(np.int64, (None,))
values_type = tff.TensorType(np.float32, (None, 2))
client_slice_type = tff.FederatedType(
    (indices_type, values_type), tff.CLIENTS)

@tff.federated_computation(client_slice_type)
def test_sum_indexed_slices(indices_values_at_client):
  indices, values = indices_values_at_client
  return federated_indexed_slices_sum(indices, values, dense_shape)

print(test_sum_indexed_slices.type_signature)
({<int64[?],float32[?,2]>}@CLIENTS -> float32[6,2]@SERVER)
x = tf.IndexedSlices(
    values=np.array([[2., 2.1], [0., 0.1], [1., 1.1], [5., 5.1]],
                    dtype=np.float32),
    indices=[2, 0, 1, 5],
    dense_shape=dense_shape)
y = tf.IndexedSlices(
    values=np.array([[0., 0.3], [3.1, 3.2]], dtype=np.float32),
    indices=[1, 3],
    dense_shape=dense_shape)

# Sum one.
result = test_sum_indexed_slices([(x.indices, x.values)])
np.testing.assert_array_equal(tf.convert_to_tensor(x), result)

# Sum two.
expected = [[0., 0.1], [1., 1.4], [2., 2.1], [3.1, 3.2], [0., 0.], [5., 5.1]]
result = test_sum_indexed_slices([(x.indices, x.values), (y.indices, y.values)])
np.testing.assert_array_almost_equal(expected, result)

Putting it all together in a federated_computation

We now use TFF to bind together the components into a tff.federated_computation.

DENSE_MODEL_SHAPE = (WORD_VOCAB_SIZE, TAG_VOCAB_SIZE)
client_data_type = tff.SequenceType(batched_dataset1.element_spec)
model_type = tff.TensorType(np.float32, shape=DENSE_MODEL_SHAPE)

We use a basic server training function based on Federated Averaging, applying the update with a server learning rate of 1.0. It is important that we apply an update (delta) to the model, rather than simply averaging client-supplied models; otherwise, any slice of the model that no client trained on in a given round could have its coefficients zeroed out.

@tff.tensorflow.computation
def server_update(current_model_weights, update_sum, num_clients):
  average_update = update_sum / num_clients
  return current_model_weights + average_update
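A tiny numeric sketch shows why the server applies averaged deltas instead of averaging client models. Each client only holds the slices it selected, so densifying its model leaves zeros in every other row; averaging such models wipes out a row that nobody trained this round, while adding averaged deltas leaves it unchanged. The model is flattened to one coefficient per row for brevity, and the helper names are hypothetical.

```python
def apply_average_delta(weights, update_sum, num_clients):
  """Same arithmetic as server_update: w + (sum of client deltas) / n."""
  return [w + u / num_clients for w, u in zip(weights, update_sum)]


def naive_average_models(client_models):
  """Averages densified client models (zeros where a client had no slice)."""
  n = len(client_models)
  return [sum(col) / n for col in zip(*client_models)]


# Rows 0 and 1 were trained by one client each this round; row 2 by nobody.
weights = [1.0, 2.0, 3.0]
deltas = [[0.5, 0.0, 0.0], [0.0, -1.0, 0.0]]  # densified sparse deltas
update_sum = [sum(col) for col in zip(*deltas)]
print(apply_average_delta(weights, update_sum, 2))  # [1.25, 1.5, 3.0]

# The corresponding densified sparse *models*: each holds only its own row.
client_models = [[1.5, 0.0, 0.0], [0.0, 1.0, 0.0]]
print(naive_average_models(client_models))  # [0.75, 0.5, 0.0]
```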

We need a couple more tff.tensorflow.computation components:

# Function to select slices from the model weights in federated_select:
select_fn = tff.tensorflow.computation(
    lambda model_weights, index: tf.gather(model_weights, index))


# We need to wrap `client_train_fn` as a `tff.tensorflow.computation`, making
# sure we do any operations that might construct `tf.Variable`s outside
# of the `tf.function` we are wrapping.
@tff.tensorflow.computation
def client_train_fn_tff(model_slices_as_dataset, client_data, client_keys,
                        actual_num_tokens):
  # Note this is smaller than the global model, using
  # MAX_TOKENS_SELECTED_PER_CLIENT which is much smaller than WORD_VOCAB_SIZE.
  # We would like a model of size `actual_num_tokens`, but we
  # can't build the model dynamically, so we will slice off the padded
  # weights at the end.
  client_model = create_logistic_model(MAX_TOKENS_SELECTED_PER_CLIENT,
                                       TAG_VOCAB_SIZE)
  client_optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
  return client_train_fn(client_model, client_optimizer,
                         model_slices_as_dataset, client_data, client_keys,
                         actual_num_tokens)

@tff.tensorflow.computation
def keys_for_client_tff(client_data):
  return keys_for_client(client_data, MAX_TOKENS_SELECTED_PER_CLIENT)
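Conceptually, federated_select evaluates select_fn(server_val, key) for every key a client requests and delivers the results to that client in key order. The pure-Python sketch below imitates only those semantics (simulate_federated_select and gather_row are hypothetical and ignore placement, max_key, and communication). It also shows the effect of padding: padded keys of 0 fetch extra copies of row 0, which the client later discards.

```python
def gather_row(model_weights, index):
  """Plays the role of select_fn (tf.gather): returns one row of the model."""
  return list(model_weights[index])


def simulate_federated_select(keys_per_client, server_model, select_fn):
  """For each client, returns select_fn(server_model, key) in key order."""
  return [[select_fn(server_model, key) for key in keys]
          for keys in keys_per_client]


# A 5 x 2 server model whose row r is [10 * r, 10 * r + 1].
server_model = [[float(10 * r + c) for c in range(2)] for r in range(5)]
# The second client has actual_num_tokens = 1, so its last two keys are padding.
keys_per_client = [[3, 1, 0], [4, 0, 0]]
slices = simulate_federated_select(keys_per_client, server_model, gather_row)
print(slices[0])  # [[30.0, 31.0], [10.0, 11.0], [0.0, 1.0]]
print(slices[1])  # [[40.0, 41.0], [0.0, 1.0], [0.0, 1.0]]
```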

We're now ready to put all the pieces together!

@tff.federated_computation(
    tff.FederatedType(model_type, tff.SERVER),
    tff.FederatedType(client_data_type, tff.CLIENTS))
def sparse_model_update(server_model, client_data):
  # `max_key` should bound the keys clients request. The keys are global token
  # ids in [0, WORD_VOCAB_SIZE), so we use WORD_VOCAB_SIZE here rather than
  # MAX_TOKENS_SELECTED_PER_CLIENT.
  max_key = tff.federated_value(WORD_VOCAB_SIZE, tff.SERVER)
  keys_at_clients, actual_num_tokens = tff.federated_map(
      keys_for_client_tff, client_data)
  model_slices = tff.federated_select(keys_at_clients, max_key, server_model,
                                      select_fn)

  update_keys, update_slices = tff.federated_map(
      client_train_fn_tff,
      (model_slices, client_data, keys_at_clients, actual_num_tokens))

  dense_update_sum = federated_indexed_slices_sum(update_keys, update_slices,
                                                  DENSE_MODEL_SHAPE)
  num_clients = tff.federated_sum(tff.federated_value(1.0, tff.CLIENTS))

  updated_server_model = tff.federated_map(
      server_update, (server_model, dense_update_sum, num_clients))

  return updated_server_model


print(sparse_model_update.type_signature)
(<server_model=float32[13,4]@SERVER,client_data={<tokens=<indices=int64[?,2],values=int32[?],dense_shape=int64[2]>,tags=float32[?,4]>*}@CLIENTS> -> float32[13,4]@SERVER)

Let's train a model!

Now that we have our training function, let's try it out.

server_model = create_logistic_model(WORD_VOCAB_SIZE, TAG_VOCAB_SIZE)
server_model.compile(  # Compile to make evaluation easy.
    optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.0),  # Unused
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=[
        tf.keras.metrics.Precision(name='precision'),
        tf.keras.metrics.AUC(name='auc'),
        tf.keras.metrics.Recall(top_k=2, name='recall_at_2'),
    ])

def evaluate(model, dataset, name):
  metrics = model.evaluate(dataset, verbose=0)
  metrics_str = ', '.join([f'{k}={v:.2f}' for k, v in
                           zip(model.metrics_names, metrics)])
  print(f'{name}: {metrics_str}')

print('Before training')
evaluate(server_model, batched_dataset1, 'Client 1')
evaluate(server_model, batched_dataset2, 'Client 2')
evaluate(server_model, batched_dataset3, 'Client 3')

model_weights = server_model.trainable_weights[0]

client_datasets = [batched_dataset1, batched_dataset2, batched_dataset3]
for _ in range(10):  # Run 10 rounds of FedAvg
  # We train on 1, 2, or 3 clients per round, selecting
  # randomly.
  cohort_size = np.random.randint(1, 4)
  clients = np.random.choice([0, 1, 2], cohort_size, replace=False)
  print('Training on clients', clients)
  model_weights = sparse_model_update(
      model_weights, [client_datasets[i] for i in clients])
server_model.set_weights([model_weights])

print('After training')
evaluate(server_model, batched_dataset1, 'Client 1')
evaluate(server_model, batched_dataset2, 'Client 2')
evaluate(server_model, batched_dataset3, 'Client 3')
Before training
Client 1: loss=0.69, precision=0.00, auc=0.50, recall_at_2=0.60
Client 2: loss=0.69, precision=0.00, auc=0.50, recall_at_2=0.50
Client 3: loss=0.69, precision=0.00, auc=0.50, recall_at_2=0.40
Training on clients [0 1]
Training on clients [0 2 1]
Training on clients [2 0]
Training on clients [1 0 2]
Training on clients [2]
Training on clients [2 0]
Training on clients [1 2 0]
Training on clients [0]
Training on clients [2]
Training on clients [1 2]
After training
Client 1: loss=0.67, precision=0.80, auc=0.91, recall_at_2=0.80
Client 2: loss=0.68, precision=0.67, auc=0.96, recall_at_2=1.00
Client 3: loss=0.65, precision=1.00, auc=0.93, recall_at_2=0.80

Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies.

Last updated 2024-08-21 UTC.