Client-efficient large-model federated learning via federated_select and sparse aggregation
This tutorial shows how TFF can be used to train a very large model where each client device only downloads and updates a small part of the model, using
tff.federated_select and sparse aggregation. While this tutorial is fairly self-contained, the tff.federated_select tutorial and custom FL algorithms tutorial provide good introductions to some of the techniques used here.
Concretely, in this tutorial we consider logistic regression for multi-label classification, predicting which "tags" are associated with a text string based on a bag-of-words feature representation. Importantly, communication and client-side computation costs are controlled by a fixed constant (MAX_TOKENS_SELECTED_PER_CLIENT), and do not scale with the overall vocabulary size, which could be extremely large in practical settings.
pip install --quiet --upgrade tensorflow-federated

import collections
from collections.abc import Callable
import itertools

import numpy as np
import tensorflow as tf
import tensorflow_federated as tff
Each client will federated_select the rows of the model weights for at most this many unique tokens. This upper-bounds the size of the client's local model and the amount of server -> client (federated_select) and client -> server (federated_aggregate) communication performed.
This tutorial should still run correctly even if you set this as small as 1 (ensuring not all tokens from each client are selected) or to a large value, though model convergence may be affected.
MAX_TOKENS_SELECTED_PER_CLIENT = 6
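To make the bound concrete, here is a rough back-of-the-envelope sketch of the download size per client. The vocabulary size, number of tags, and the helper name download_bytes are illustrative assumptions and do not come from this tutorial; only the value 6 for MAX_TOKENS_SELECTED_PER_CLIENT does.

```python
# Illustrative sizes only: these are assumptions, not values from the tutorial.
BYTES_PER_FLOAT32 = 4


def download_bytes(num_rows: int, num_tags: int) -> int:
    """Bytes needed to send `num_rows` model rows of `num_tags` float32 values."""
    return num_rows * num_tags * BYTES_PER_FLOAT32


max_tokens_selected_per_client = 6     # same value as in this tutorial
hypothetical_vocab_size = 10_000_000   # assumed large vocabulary
hypothetical_num_tags = 500            # assumed number of tags

full_model = download_bytes(hypothetical_vocab_size, hypothetical_num_tags)
selected = download_bytes(max_tokens_selected_per_client, hypothetical_num_tags)

print(f'full model:      {full_model:,} bytes')
print(f'selected slices: {selected:,} bytes')
```

Under these assumptions the selected slices stay the same size no matter how large the vocabulary grows, which is the property the constant is meant to guarantee.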
We also define a few constants for various types. For this colab, a token is an integer identifier for a particular word after parsing the dataset.
# There are some constraints on types
# here that will require some explicit type conversions:
#    - `tff.federated_select` requires int32
#    - `tf.SparseTensor` requires int64 indices.
TOKEN_DTYPE = np.int64
SELECT_KEY_DTYPE = np.int32

# Type for counts of token occurrences.
TOKEN_COUNT_DTYPE = np.int32

# A sparse feature vector can be thought of as a map
# from TOKEN_DTYPE to FEATURE_DTYPE.
# Our features are {0, 1} indicators, so we could potentially
# use np.int8 as an optimization.
FEATURE_DTYPE = np.int32
Setting up the problem: Dataset and Model
We construct a tiny toy dataset for easy experimentation in this tutorial. However, the format of the dataset is compatible with Federated StackOverflow, and the pre-processing and model architecture are adopted from the StackOverflow tag prediction problem of Adaptive Federated Optimization.
Dataset parsing and pre-processing
NUM_OOV_BUCKETS = 1

BatchType = collections.namedtuple('BatchType', ['tokens', 'tags'])

def build_to_ids_fn(word_vocab: list[str],
                    tag_vocab: list[str]) -> Callable[[tf.Tensor], tf.Tensor]:
  """Constructs a function mapping examples to sequences of token indices."""
  word_table_values = np.arange(len(word_vocab), dtype=np.int64)
  word_table = tf.lookup.StaticVocabularyTable(
      tf.lookup.KeyValueTensorInitializer(word_vocab, word_table_values),
      num_oov_buckets=NUM_OOV_BUCKETS)

  tag_table_values = np.arange(len(tag_vocab), dtype=np.int64)
  tag_table = tf.lookup.StaticVocabularyTable(
      tf.lookup.KeyValueTensorInitializer(tag_vocab, tag_table_values),
      num_oov_buckets=NUM_OOV_BUCKETS)

  def to_ids(example):
    """Converts a Stack Overflow example to a bag-of-words/tags format."""
    sentence = tf.strings.join([example['tokens'], example['title']],
                               separator=' ')

    # We represent the label (output tags) densely.
    raw_tags = example['tags']
    tags = tf.strings.split(raw_tags, sep='|')
    tags = tag_table.lookup(tags)
    tags, _ = tf.unique(tags)
    tags = tf.one_hot(tags, len(tag_vocab) + NUM_OOV_BUCKETS)
    tags = tf.reduce_max(tags, axis=0)

    # We represent the features as a SparseTensor of {0, 1}s.
    words = tf.strings.split(sentence)
    tokens = word_table.lookup(words)
    tokens, _ = tf.unique(tokens)
    # Note: We could choose to use the word counts as the feature vector
    # instead of just {0, 1} values (see tf.unique_with_counts).
    tokens = tf.reshape(tokens, shape=(tf.size(tokens), 1))
    tokens_st = tf.SparseTensor(
        tokens,
        tf.ones(tf.size(tokens), dtype=FEATURE_DTYPE),
        dense_shape=(len(word_vocab) + NUM_OOV_BUCKETS,))
    tokens_st = tf.sparse.reorder(tokens_st)

    return BatchType(tokens_st, tags)

  return to_ids

def build_preprocess_fn(word_vocab, tag_vocab):

  @tf.function
  def preprocess_fn(dataset):
    to_ids = build_to_ids_fn(word_vocab, tag_vocab)
    # We *don't* shuffle in order to make this colab deterministic for
    # easier testing and reproducibility.
    # But real-world training should use `.shuffle()`.
    return dataset.map(to_ids, num_parallel_calls=tf.data.experimental.AUTOTUNE)

  return preprocess_fn
A tiny toy dataset
We construct a tiny toy dataset with a global vocabulary of 12 words and 3 clients. This tiny example is useful for testing edge cases (for example, we have two clients with no more than MAX_TOKENS_SELECTED_PER_CLIENT = 6 distinct tokens, and one with more) and developing the code.
However, the real-world use cases of this approach would be global vocabularies of tens of millions of words or more, with perhaps thousands of distinct tokens appearing on each client. Because the format of the data is the same, the extension to more realistic testbed problems, e.g. the tff.simulation.datasets.stackoverflow.load_data() dataset, should be straightforward.
First, we define our word and tag vocabularies.
# Features
FRUIT_WORDS = ['apple', 'orange', 'pear', 'kiwi']
VEGETABLE_WORDS = ['carrot', 'broccoli', 'arugula', 'peas']
FISH_WORDS = ['trout', 'tuna', 'cod', 'salmon']
WORD_VOCAB = FRUIT_WORDS + VEGETABLE_WORDS + FISH_WORDS
# Labels
TAG_VOCAB = ['FRUIT', 'VEGETABLE', 'FISH']
Now, we create 3 clients with small local datasets. If you are running this tutorial in colab, it may be useful to use the "mirror cell in tab" feature to pin this cell and its output in order to interpret/check the output of the functions developed below.
preprocess_fn = build_preprocess_fn(WORD_VOCAB, TAG_VOCAB)

def make_dataset(raw):
  d = tf.data.Dataset.from_tensor_slices(
      # Matches the StackOverflow formatting
      collections.OrderedDict(
          tokens=tf.constant([t[0] for t in raw]),
          tags=tf.constant([t[1] for t in raw]),
          title=['' for _ in raw]))
  d = preprocess_fn(d)
  return d

# 4 distinct tokens
CLIENT1_DATASET = make_dataset([
    ('apple orange apple orange', 'FRUIT'),
    ('carrot trout', 'VEGETABLE|FISH'),
    ('orange apple', 'FRUIT'),
    ('orange', 'ORANGE|CITRUS')  # 2 OOV tags
])

# 6 distinct tokens
CLIENT2_DATASET = make_dataset([
    ('pear cod', 'FRUIT|FISH'),
    ('arugula peas', 'VEGETABLE'),
    ('kiwi pear', 'FRUIT'),
    ('sturgeon', 'FISH'),  # OOV word
    ('sturgeon bass', 'FISH')  # 2 OOV words
])

# A client with all possible words & tags (13 distinct tokens).
# With MAX_TOKENS_SELECTED_PER_CLIENT = 6, we won't download the model
# slices for all tokens that occur on this client.
CLIENT3_DATASET = make_dataset([
    (' '.join(WORD_VOCAB + ['oovword']), '|'.join(TAG_VOCAB)),
    # Make the OOV token and 'salmon' occur in the largest number
    # of examples on this client:
    ('salmon oovword', 'FISH|OOVTAG')
])

print('Word vocab')
for i, word in enumerate(WORD_VOCAB):
  print(f'{i:2d} {word}')

print('\nTag vocab')
for i, tag in enumerate(TAG_VOCAB):
  print(f'{i:2d} {tag}')
Word vocab
 0 apple
 1 orange
 2 pear
 3 kiwi
 4 carrot
 5 broccoli
 6 arugula
 7 peas
 8 trout
 9 tuna
10 cod
11 salmon

Tag vocab
 0 FRUIT
 1 VEGETABLE
 2 FISH
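As a plain-Python illustration of the featurization performed by the preprocessing above, the sketch below maps a sentence to its set of distinct token ids. It assumes, as with a single OOV bucket in tf.lookup.StaticVocabularyTable, that every out-of-vocabulary word maps to the one id just past the vocabulary (12 here). The names WORD_TO_ID, OOV_ID, and to_token_ids are hypothetical helpers for this sketch only.

```python
WORD_VOCAB = ['apple', 'orange', 'pear', 'kiwi',
              'carrot', 'broccoli', 'arugula', 'peas',
              'trout', 'tuna', 'cod', 'salmon']
OOV_ID = len(WORD_VOCAB)  # single OOV bucket: id 12
WORD_TO_ID = {word: i for i, word in enumerate(WORD_VOCAB)}


def to_token_ids(sentence: str) -> list[int]:
    """Returns the sorted distinct token ids in `sentence` (a bag of words)."""
    return sorted({WORD_TO_ID.get(word, OOV_ID) for word in sentence.split()})


print(to_token_ids('apple orange apple orange'))  # [0, 1]
print(to_token_ids('pear cod'))                   # [2, 10]
print(to_token_ids('sturgeon bass'))              # [12]
```

Repeated words and distinct OOV words both collapse to a single id, which is why client 2 ends up with 6 distinct tokens even though its examples contain 7 distinct words.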
Define constants for the raw numbers of input features (tokens/words) and labels (post tags). Our actual input/output spaces are NUM_OOV_BUCKETS = 1 larger because we add an OOV token / tag.
NUM_WORDS = len(WORD_VOCAB)
NUM_TAGS = len(TAG_VOCAB)
WORD_VOCAB_SIZE = NUM_WORDS + NUM_OOV_BUCKETS
TAG_VOCAB_SIZE = NUM_TAGS + NUM_OOV_BUCKETS
Create batched versions of the datasets, and individual batches, which will be useful in testing code as we go.
batched_dataset1 = CLIENT1_DATASET.batch(2)
batched_dataset2 = CLIENT2_DATASET.batch(3)
batched_dataset3 = CLIENT3_DATASET.batch(2)
batch1 = next(iter(batched_dataset1))
batch2 = next(iter(batched_dataset2))
batch3 = next(iter(batched_dataset3))
Define a model with sparse inputs
We use a simple independent logistic regression model for each tag.
def create_logistic_model(word_vocab_size: int, vocab_tags_size: int):

  model = tf.keras.models.Sequential([
      tf.keras.layers.InputLayer(input_shape=(word_vocab_size,), sparse=True),
      tf.keras.layers.Dense(
          vocab_tags_size,
          activation='sigmoid',
          kernel_initializer=tf.keras.initializers.zeros,
          # For simplicity, don't use a bias vector; this means the model
          # is a single tensor, and we only need sparse aggregation of
          # the per-token slices of the model. Generalizing to also handle
          # other model weights that are fully updated
          # (non-dense broadcast and aggregate) would be a good exercise.
          use_bias=False),
  ])

  return model
Let's make sure it works, first by making predictions:
model = create_logistic_model(WORD_VOCAB_SIZE, TAG_VOCAB_SIZE)
p = model.predict(batch1.tokens)
print(p)
[[0.5 0.5 0.5 0.5]
 [0.5 0.5 0.5 0.5]]
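The uniform 0.5 predictions follow directly from the zero kernel initializer and the absence of a bias: every logit is 0, and sigmoid(0) = 0.5. The small standard-library sketch below checks this, along with the binary cross-entropy of a 0.5 prediction, which is ln 2 (about 0.69) for either label. The function names are local to this sketch and are not Keras APIs.

```python
import math


def sigmoid(z: float) -> float:
    """Logistic function."""
    return 1.0 / (1.0 + math.exp(-z))


def binary_cross_entropy(label: float, prob: float) -> float:
    """Cross-entropy of a single predicted probability against a {0, 1} label."""
    return -(label * math.log(prob) + (1.0 - label) * math.log(1.0 - prob))


p = sigmoid(0.0)
print(p)                                       # 0.5
print(round(binary_cross_entropy(1.0, p), 4))  # 0.6931
print(round(binary_cross_entropy(0.0, p), 4))  # 0.6931
```

This is consistent with the loss=0.69 reported for all three clients before training at the end of this tutorial.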
And some simple centralized training:
model.compile(optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.001),
loss=tf.keras.losses.BinaryCrossentropy())
model.train_on_batch(batch1.tokens, batch1.tags)
Building blocks for the federated computation
We will implement a simple version of the Federated Averaging algorithm with the key difference that each device only downloads a relevant subset of the model, and only contributes updates to that subset.
We use M as shorthand for MAX_TOKENS_SELECTED_PER_CLIENT. At a high level, one round of training involves these steps:
1. Each participating client scans over its local dataset, parsing the input strings and mapping them to the correct tokens (int indexes). This requires access to the global (large) dictionary (this could potentially be avoided using feature hashing techniques). We then sparsely count how many times each token occurs. If U unique tokens occur on device, we choose the num_actual_tokens = min(U, M) most frequent tokens to train.
2. The clients use federated_select to retrieve the model coefficients for the num_actual_tokens selected tokens from the server. Each model slice is a tensor of shape (TAG_VOCAB_SIZE, ), so the total data transmitted to the client is at most of size TAG_VOCAB_SIZE * M (see note below).
3. The clients construct a mapping global_token -> local_token where the local token (int index) is the index of the global token in the list of selected tokens.
4. The clients use a "small" version of the global model that only has coefficients for at most M tokens, from the range [0, num_actual_tokens). The global -> local mapping is used to initialize the dense parameters of this model from the selected model slices.
5. Clients train their local model using SGD on data preprocessed with the global -> local mapping.
6. Clients turn the parameters of their local model into IndexedSlices updates using the local -> global mapping to index the rows. The server aggregates these updates using a sparse sum aggregation.
7. The server takes the (dense) result of the above aggregation, divides it by the number of clients participating, and applies the resulting average update to the global model.
In this section we construct the building blocks for these steps, which will then be combined in a final federated_computation that captures the full logic of one training round.
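The steps above can be summarized as a single update rule. The notation is introduced only for this summary and is not part of TFF: $w_t$ is the global weight matrix at round $t$, $S_t$ is the set of participating clients, $K_k$ is the list of token keys selected by client $k$, $U_k$ is the number of unique tokens on client $k$, $\Delta_k$ is the change in the selected rows after local training, and $\operatorname{scatter}(K_k, \Delta_k)$ denotes a matrix with the shape of $w_t$ that is zero except for the rows indexed by $K_k$, which hold $\Delta_k$.

```latex
\begin{aligned}
|K_k| &= \min(U_k, M), \\
\Delta_k &= \operatorname{LocalSGD}\bigl(w_t[K_k, :]\bigr) - w_t[K_k, :], \\
w_{t+1} &= w_t + \frac{1}{|S_t|} \sum_{k \in S_t} \operatorname{scatter}\bigl(K_k, \Delta_k\bigr).
\end{aligned}
```

Rows of $w_t$ that no participating client selects receive a zero update and are carried over unchanged.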
Count client tokens and decide which model slices to federated_select
Each device needs to decide which "slices" of the model are relevant to its local training dataset. For our problem, we do this by (sparsely!) counting how many examples contain each token in the client training data set.
@tf.function
def token_count_fn(token_counts, batch):
  """Adds counts from `batch` to the running `token_counts` sum."""
  # Sum across the batch dimension.
  flat_tokens = tf.sparse.reduce_sum(
      batch.tokens, axis=0, output_is_sparse=True)
  flat_tokens = tf.cast(flat_tokens, dtype=TOKEN_COUNT_DTYPE)
  return tf.sparse.add(token_counts, flat_tokens)

# Simple tests
# Create the initial zero token counts using empty tensors.
initial_token_counts = tf.SparseTensor(
    indices=tf.zeros(shape=(0, 1), dtype=TOKEN_DTYPE),
    values=tf.zeros(shape=(0,), dtype=TOKEN_COUNT_DTYPE),
    dense_shape=(WORD_VOCAB_SIZE,))

client_token_counts = batched_dataset1.reduce(initial_token_counts,
                                              token_count_fn)
tokens = tf.reshape(client_token_counts.indices, (-1,)).numpy()
print('tokens:', tokens)
np.testing.assert_array_equal(tokens, [0, 1, 4, 8])
# The count is the number of *examples* in which the token/word
# occurs, not the total number of occurrences, since we still featurize
# multiple occurrences in the same example as a "1".
counts = client_token_counts.values.numpy()
print('counts:', counts)
np.testing.assert_array_equal(counts, [2, 3, 1, 1])
tokens: [0 1 4 8]
counts: [2 3 1 1]
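The same per-example counting can be sketched without TensorFlow, which may make it easier to see why the counts are [2, 3, 1, 1] rather than total occurrence counts. This is only an illustration using collections.Counter; the token id lists below are client 1's examples written out by hand.

```python
from collections import Counter

# Client 1's examples as lists of global token ids. Duplicates within an
# example are kept here to show that they only count once per example.
client1_examples = [
    [0, 1, 0, 1],  # 'apple orange apple orange'
    [4, 8],        # 'carrot trout'
    [1, 0],        # 'orange apple'
    [1],           # 'orange'
]

example_counts = Counter()
for example in client1_examples:
    # Converting to a set makes each token count at most once per example.
    example_counts.update(set(example))

tokens = sorted(example_counts)
counts = [example_counts[t] for t in tokens]
print('tokens:', tokens)  # tokens: [0, 1, 4, 8]
print('counts:', counts)  # counts: [2, 3, 1, 1]
```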
We will select the model parameters corresponding to the MAX_TOKENS_SELECTED_PER_CLIENT most frequently occurring tokens on device. If fewer than this many tokens occur on device, we pad the list to enable the use of federated_select.
Note that other strategies are possibly better, for example, randomly selecting tokens (perhaps based on their occurrence probability). This would ensure that all slices of the model (for which the client has data) have some chance of being updated.
@tf.function
def keys_for_client(client_dataset, max_tokens_per_client):
  """Computes a set of max_tokens_per_client keys."""
  initial_token_counts = tf.SparseTensor(
      indices=tf.zeros((0, 1), dtype=TOKEN_DTYPE),
      values=tf.zeros((0,), dtype=TOKEN_COUNT_DTYPE),
      dense_shape=(WORD_VOCAB_SIZE,))
  client_token_counts = client_dataset.reduce(initial_token_counts,
                                              token_count_fn)
  # Find the most-frequently occurring tokens
  tokens = tf.reshape(client_token_counts.indices, shape=(-1,))
  counts = client_token_counts.values
  perm = tf.argsort(counts, direction='DESCENDING')
  tokens = tf.gather(tokens, perm)
  counts = tf.gather(counts, perm)
  num_raw_tokens = tf.shape(tokens)[0]
  actual_num_tokens = tf.minimum(max_tokens_per_client, num_raw_tokens)
  selected_tokens = tokens[:actual_num_tokens]
  paddings = [[0, max_tokens_per_client - tf.shape(selected_tokens)[0]]]
  padded_tokens = tf.pad(selected_tokens, paddings=paddings)
  # Make sure the type is statically determined
  padded_tokens = tf.reshape(padded_tokens, shape=(max_tokens_per_client,))

  # We will pass these tokens as keys into `federated_select`, which
  # requires SELECT_KEY_DTYPE=np.int32 keys.
  padded_tokens = tf.cast(padded_tokens, dtype=SELECT_KEY_DTYPE)
  return padded_tokens, actual_num_tokens

# Simple test
# Case 1: the client has more distinct tokens (4) than max_tokens_per_client (3)
selected_tokens, actual_num_tokens = keys_for_client(batched_dataset1, 3)
assert tf.size(selected_tokens) == 3
assert actual_num_tokens == 3

# Case 2: the client has fewer distinct tokens (4) than max_tokens_per_client (10)
selected_tokens, actual_num_tokens = keys_for_client(batched_dataset1, 10)
assert tf.size(selected_tokens) == 10
assert actual_num_tokens == 4
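For intuition, the selection and padding logic can be written in a few lines of plain Python. This is a sketch rather than a drop-in replacement: select_keys is a hypothetical name, and ties between equally frequent tokens are broken by token id here, which is an assumption that may differ from the order tf.argsort produces.

```python
def select_keys(example_counts: dict[int, int], max_tokens: int):
    """Returns (padded_keys, actual_num_tokens) for the most frequent tokens."""
    # Sort by descending count; ties are broken by token id in this sketch.
    ranked = sorted(example_counts, key=lambda t: (-example_counts[t], t))
    selected = ranked[:max_tokens]
    # Pad with zeros so the key list always has exactly `max_tokens` entries.
    padded = selected + [0] * (max_tokens - len(selected))
    return padded, len(selected)


# Per-example token counts for client 1 (token id -> number of examples).
client1_counts = {0: 2, 1: 3, 4: 1, 8: 1}

print(select_keys(client1_counts, 3))   # ([1, 0, 4], 3)
print(select_keys(client1_counts, 10))  # ([1, 0, 4, 8, 0, 0, 0, 0, 0, 0], 4)
```

The padding value 0 is also a valid token id, which is why the count of real keys is returned alongside the padded list and used later to discard the padded entries.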
Map global tokens to local tokens
The above selection gives us a dense set of tokens in the range [0, actual_num_tokens) which we will use for the on-device model. However, the dataset we read has tokens from the much larger global vocabulary range [0, WORD_VOCAB_SIZE).
Thus, we need to map the global tokens to their corresponding local tokens. The
local token ids are simply given by the indexes into the selected_tokens tensor computed in the previous step.
@tf.function
def map_to_local_token_ids(client_data, client_keys):
  global_to_local = tf.lookup.StaticHashTable(
      # Note int32 -> int64 maps are not supported
      tf.lookup.KeyValueTensorInitializer(
          keys=tf.cast(client_keys, dtype=TOKEN_DTYPE),
          # Note we need to use tf.shape, not the static
          # shape client_keys.shape[0]
          values=tf.range(0, limit=tf.shape(client_keys)[0],
                          dtype=TOKEN_DTYPE)),
      # We use -1 for tokens that were not selected, which can occur for clients
      # with more than MAX_TOKENS_SELECTED_PER_CLIENT distinct tokens.
      # We will simply remove these invalid indices from the batch below.
      default_value=-1)

  def to_local_ids(sparse_tokens):
    indices_t = tf.transpose(sparse_tokens.indices)
    batch_indices = indices_t[0]  # First column
    tokens = indices_t[1]  # Second column
    tokens = tf.map_fn(
        lambda global_token_id: global_to_local.lookup(global_token_id), tokens)
    # Remove tokens that aren't actually available (looked up as -1):
    available_tokens = tokens >= 0
    tokens = tokens[available_tokens]
    batch_indices = batch_indices[available_tokens]

    updated_indices = tf.transpose(
        tf.concat([[batch_indices], [tokens]], axis=0))
    st = tf.sparse.SparseTensor(
        updated_indices,
        tf.ones(tf.size(tokens), dtype=FEATURE_DTYPE),
        # Each client has at most MAX_TOKENS_SELECTED_PER_CLIENT distinct tokens.
        dense_shape=[sparse_tokens.dense_shape[0], MAX_TOKENS_SELECTED_PER_CLIENT])
    st = tf.sparse.reorder(st)
    return st

  return client_data.map(lambda b: BatchType(to_local_ids(b.tokens), b.tags))

# Simple test
client_keys, actual_num_tokens = keys_for_client(
    batched_dataset3, MAX_TOKENS_SELECTED_PER_CLIENT)
client_keys = client_keys[:actual_num_tokens]

d = map_to_local_token_ids(batched_dataset3, client_keys)
batch = next(iter(d))
all_tokens = tf.gather(batch.tokens.indices, indices=1, axis=1)
# Confirm we have local indices in the range [0, MAX):
assert tf.math.reduce_max(all_tokens) < MAX_TOKENS_SELECTED_PER_CLIENT
assert tf.math.reduce_max(all_tokens) >= 0
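The core of the remapping is a dictionary lookup with a default of -1 followed by a filter. The plain-Python sketch below shows that logic for a single example; the key list is a hypothetical selection chosen for illustration, not the output of keys_for_client.

```python
def to_local_ids(example_tokens: list[int], client_keys: list[int]) -> list[int]:
    """Maps global token ids to local ids, dropping tokens that were not selected."""
    # The local id of a global token is its position in the selected key list.
    global_to_local = {g: i for i, g in enumerate(client_keys)}
    local = [global_to_local.get(t, -1) for t in example_tokens]
    # Tokens looked up as -1 have no model slice on this client; remove them.
    return sorted(i for i in local if i >= 0)


client_keys = [11, 12, 0, 1, 2, 3]  # hypothetical selected global token ids

print(to_local_ids([11, 12], client_keys))       # [0, 1]
print(to_local_ids([0, 5, 9, 11], client_keys))  # [0, 2]
```

In the second call, global tokens 5 and 9 were not selected, so they are silently dropped from the example's features.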
Train the local (sub)model on each client
Note federated_select will return the selected slices as a tf.data.Dataset in the same order as the selection keys. So, we first define a utility function to take such a Dataset and convert it to a single dense tensor which can be used as the model weights of the client model.
@tf.function
def slices_dataset_to_tensor(slices_dataset):
  """Convert a dataset of slices to a tensor."""
  # Use batching to gather all of the slices into a single tensor.
  d = slices_dataset.batch(MAX_TOKENS_SELECTED_PER_CLIENT,
                           drop_remainder=False)
  iter_d = iter(d)
  tensor = next(iter_d)
  # Make sure we have consumed everything
  opt = iter_d.get_next_as_optional()
  tf.Assert(tf.logical_not(opt.has_value()), data=[''], name='CHECK_EMPTY')
  return tensor

# Simple test
weights = np.random.random(
    size=(MAX_TOKENS_SELECTED_PER_CLIENT, TAG_VOCAB_SIZE)).astype(np.float32)
model_slices_as_dataset = tf.data.Dataset.from_tensor_slices(weights)
weights2 = slices_dataset_to_tensor(model_slices_as_dataset)
np.testing.assert_array_equal(weights, weights2)
We now have all the components we need to define a simple local training loop which will run on each client.
@tf.function
def client_train_fn(model, client_optimizer,
                    model_slices_as_dataset, client_data,
                    client_keys, actual_num_tokens):

  initial_model_weights = slices_dataset_to_tensor(model_slices_as_dataset)
  assert len(model.trainable_variables) == 1
  model.trainable_variables[0].assign(initial_model_weights)

  # Only keep the "real" (unpadded) keys.
  client_keys = client_keys[:actual_num_tokens]

  client_data = map_to_local_token_ids(client_data, client_keys)

  loss_fn = tf.keras.losses.BinaryCrossentropy()
  for features, labels in client_data:
    with tf.GradientTape() as tape:
      predictions = model(features)
      loss = loss_fn(labels, predictions)
    grads = tape.gradient(loss, model.trainable_variables)
    client_optimizer.apply_gradients(zip(grads, model.trainable_variables))

  model_weights_delta = model.trainable_weights[0] - initial_model_weights
  model_weights_delta = tf.slice(model_weights_delta, begin=[0, 0],
                                 size=[actual_num_tokens, -1])
  return client_keys, model_weights_delta

# Simple test
# Note if you execute this cell a second time, you need to also re-execute
# the preceding cell to avoid "tf.function-decorated function tried to
# create variables on non-first call" errors.
on_device_model = create_logistic_model(MAX_TOKENS_SELECTED_PER_CLIENT,
                                        TAG_VOCAB_SIZE)
client_optimizer = tf.keras.optimizers.SGD(learning_rate=0.001)
client_keys, actual_num_tokens = keys_for_client(
    batched_dataset2, MAX_TOKENS_SELECTED_PER_CLIENT)

model_slices_as_dataset = tf.data.Dataset.from_tensor_slices(
    np.zeros((MAX_TOKENS_SELECTED_PER_CLIENT, TAG_VOCAB_SIZE),
             dtype=np.float32))

keys, delta = client_train_fn(
    on_device_model,
    client_optimizer,
    model_slices_as_dataset,
    client_data=batched_dataset3,
    client_keys=client_keys,
    actual_num_tokens=actual_num_tokens)

print(delta)
Aggregate IndexedSlices
We use tff.federated_aggregate to construct a federated sparse sum for IndexedSlices. This simple implementation has the constraint that the
dense_shape is known statically in advance. Note also that this sum is only semi-sparse, in the sense that the client -> server communication is sparse, but the server maintains a dense representation of the sum in accumulate and merge, and outputs this dense representation.
def federated_indexed_slices_sum(slice_indices, slice_values, dense_shape):
  """
  Sums IndexedSlices@CLIENTS to a dense @SERVER Tensor.

  Intermediate aggregation is performed by converting to a dense representation,
  which may not be suitable for all applications.

  Args:
    slice_indices: An IndexedSlices.indices tensor @CLIENTS.
    slice_values: An IndexedSlices.values tensor @CLIENTS.
    dense_shape: A statically known dense shape.

  Returns:
    A dense tensor placed @SERVER representing the sum of the clients'
    IndexedSlices.
  """
  slices_dtype = slice_values.type_signature.member.dtype
  zero = tff.tensorflow.computation(
      lambda: tf.zeros(dense_shape, dtype=slices_dtype))()

  @tf.function
  def accumulate_slices(dense, client_value):
    indices, slices = client_value
    # There is no built-in way to add `IndexedSlices`, but
    # tf.convert_to_tensor is a quick way to convert to a dense representation
    # so we can add them.
    return dense + tf.convert_to_tensor(
        tf.IndexedSlices(slices, indices, dense_shape))

  return tff.federated_aggregate(
      (slice_indices, slice_values),
      zero=zero,
      accumulate=tff.tensorflow.computation(accumulate_slices),
      merge=tff.tensorflow.computation(lambda d1, d2: tf.add(d1, d2, name='merge')),
      report=tff.tensorflow.computation(lambda d: d))
Construct a minimal federated_computation as a test
dense_shape = (6, 2)
indices_type = tff.TensorType(np.int64, (None,))
values_type = tff.TensorType(np.float32, (None, 2))
client_slice_type = tff.FederatedType(
    (indices_type, values_type), tff.CLIENTS)

@tff.federated_computation(client_slice_type)
def test_sum_indexed_slices(indices_values_at_client):
  indices, values = indices_values_at_client
  return federated_indexed_slices_sum(indices, values, dense_shape)

print(test_sum_indexed_slices.type_signature)
({<int64[?],float32[?,2]>}@CLIENTS -> float32[6,2]@SERVER)
x = tf.IndexedSlices(
    values=np.array([[2., 2.1], [0., 0.1], [1., 1.1], [5., 5.1]],
                    dtype=np.float32),
    indices=[2, 0, 1, 5],
    dense_shape=dense_shape)
y = tf.IndexedSlices(
    values=np.array([[0., 0.3], [3.1, 3.2]], dtype=np.float32),
    indices=[1, 3],
    dense_shape=dense_shape)

# Sum one.
result = test_sum_indexed_slices([(x.indices, x.values)])
np.testing.assert_array_equal(tf.convert_to_tensor(x), result)

# Sum two.
expected = [[0., 0.1], [1., 1.4], [2., 2.1], [3.1, 3.2], [0., 0.], [5., 5.1]]
result = test_sum_indexed_slices([(x.indices, x.values), (y.indices, y.values)])
np.testing.assert_array_almost_equal(expected, result)
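To see what the aggregation computes independently of TFF, the following plain-Python sketch performs the same scatter-add on nested lists, using the x and y values from the test above. The function sparse_sum is a hypothetical stand-in for illustration only; it mirrors the semi-sparse design in that the inputs are sparse (indices plus rows) while the accumulator is dense.

```python
def sparse_sum(client_updates, dense_shape):
    """Sums (indices, rows) updates from several clients into a dense matrix."""
    num_rows, num_cols = dense_shape
    # The accumulator is dense, like the server-side state in the TFF version.
    dense = [[0.0] * num_cols for _ in range(num_rows)]
    for indices, rows in client_updates:
        for index, row in zip(indices, rows):
            for col, value in enumerate(row):
                dense[index][col] += value
    return dense


x = ([2, 0, 1, 5], [[2., 2.1], [0., 0.1], [1., 1.1], [5., 5.1]])
y = ([1, 3], [[0., 0.3], [3.1, 3.2]])

result = sparse_sum([x, y], dense_shape=(6, 2))
for row in result:
    print([round(v, 2) for v in row])
```

Row 1 receives contributions from both clients, row 4 receives none and stays at zero, and every other row is copied from the single client that supplied it.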
Putting it all together in a federated_computation
We now use TFF to bind together the components into a tff.federated_computation.
DENSE_MODEL_SHAPE = (WORD_VOCAB_SIZE, TAG_VOCAB_SIZE)
client_data_type = tff.SequenceType(batched_dataset1.element_spec)
model_type = tff.TensorType(np.float32, shape=DENSE_MODEL_SHAPE)
We use a basic server training function based on Federated Averaging, applying the update with a server learning rate of 1.0. It is important that we apply an update (delta) to the model, rather than simply averaging client-supplied models, as otherwise if a given slice of the model wasn't trained on by any client on a given round its coefficients could be zeroed out.
@tff.tensorflow.computation
def server_update(current_model_weights, update_sum, num_clients):
  average_update = update_sum / num_clients
  return current_model_weights + average_update
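The difference between applying an averaged delta and averaging client models can be seen on a toy example with one scalar weight per row. The numbers and function names below are made up for illustration: client 0 trains only row 0, client 1 trains rows 0 and 1, and row 2 is not selected by anyone.

```python
def apply_average_delta(weights, client_deltas):
    """Delta-based update with a server learning rate of 1.0."""
    n = len(client_deltas)
    return [w + sum(d[i] for d in client_deltas) / n
            for i, w in enumerate(weights)]


def average_models(client_models):
    """Naive alternative: average models, with unselected rows reported as 0."""
    n = len(client_models)
    return [sum(m[i] for m in client_models) / n
            for i in range(len(client_models[0]))]


weights = [0.5, -0.2, 0.9]
# Sparse deltas scattered to dense form; unselected rows contribute 0.
client_deltas = [[0.1, 0.0, 0.0], [0.3, -0.1, 0.0]]
# Sparse trained models scattered to dense form; unselected rows are 0.
client_models = [[0.6, 0.0, 0.0], [0.8, -0.3, 0.0]]

updated = apply_average_delta(weights, client_deltas)
averaged = average_models(client_models)
print('delta update: ', [round(v, 3) for v in updated])
print('model average:', [round(v, 3) for v in averaged])
```

With the delta update, row 2 keeps its value of 0.9. With naive model averaging, row 2 becomes 0.0, and row 1 is pulled toward zero because client 0 never held that slice.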
We need a couple more tff.tensorflow.computation components:
# Function to select slices from the model weights in federated_select:
select_fn = tff.tensorflow.computation(
    lambda model_weights, index: tf.gather(model_weights, index))

# We need to wrap `client_train_fn` as a `tff.tensorflow.computation`, making
# sure we do any operations that might construct `tf.Variable`s outside
# of the `tf.function` we are wrapping.
@tff.tensorflow.computation
def client_train_fn_tff(model_slices_as_dataset, client_data, client_keys,
                        actual_num_tokens):
  # Note this is smaller than the global model, using
  # MAX_TOKENS_SELECTED_PER_CLIENT which is much smaller than WORD_VOCAB_SIZE.
  # We would like a model of size `actual_num_tokens`, but we
  # can't build the model dynamically, so we will slice off the padded
  # weights at the end.
  client_model = create_logistic_model(MAX_TOKENS_SELECTED_PER_CLIENT,
                                       TAG_VOCAB_SIZE)
  client_optimizer = tf.keras.optimizers.SGD(learning_rate=0.1)
  return client_train_fn(client_model, client_optimizer,
                         model_slices_as_dataset, client_data, client_keys,
                         actual_num_tokens)

@tff.tensorflow.computation
def keys_for_client_tff(client_data):
  return keys_for_client(client_data, MAX_TOKENS_SELECTED_PER_CLIENT)
We're now ready to put all the pieces together!
@tff.federated_computation(
    tff.FederatedType(model_type, tff.SERVER),
    tff.FederatedType(client_data_type, tff.CLIENTS))
def sparse_model_update(server_model, client_data):
  max_tokens = tff.federated_value(MAX_TOKENS_SELECTED_PER_CLIENT, tff.SERVER)
  keys_at_clients, actual_num_tokens = tff.federated_map(
      keys_for_client_tff, client_data)

  model_slices = tff.federated_select(keys_at_clients, max_tokens, server_model,
                                      select_fn)

  update_keys, update_slices = tff.federated_map(
      client_train_fn_tff,
      (model_slices, client_data, keys_at_clients, actual_num_tokens))

  dense_update_sum = federated_indexed_slices_sum(update_keys, update_slices,
                                                  DENSE_MODEL_SHAPE)
  num_clients = tff.federated_sum(tff.federated_value(1.0, tff.CLIENTS))

  updated_server_model = tff.federated_map(
      server_update, (server_model, dense_update_sum, num_clients))

  return updated_server_model

print(sparse_model_update.type_signature)
(<server_model=float32[13,4]@SERVER,client_data={<tokens=<indices=int64[?,2],values=int32[?],dense_shape=int64[2]>,tags=float32[?,4]>*}@CLIENTS> -> float32[13,4]@SERVER)
Let's train a model!
Now that we have our training function, let's try it out.
server_model = create_logistic_model(WORD_VOCAB_SIZE, TAG_VOCAB_SIZE)
server_model.compile(  # Compile to make evaluation easy.
    optimizer=tf.keras.optimizers.Adagrad(learning_rate=0.0),  # Unused
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=[
        tf.keras.metrics.Precision(name='precision'),
        tf.keras.metrics.AUC(name='auc'),
        tf.keras.metrics.Recall(top_k=2, name='recall_at_2'),
    ])

def evaluate(model, dataset, name):
  metrics = model.evaluate(dataset, verbose=0)
  # Use the metric names of the model being evaluated, not a global.
  metrics_str = ', '.join([f'{k}={v:.2f}' for k, v in
                           (zip(model.metrics_names, metrics))])
  print(f'{name}: {metrics_str}')

print('Before training')
evaluate(server_model, batched_dataset1, 'Client 1')
evaluate(server_model, batched_dataset2, 'Client 2')
evaluate(server_model, batched_dataset3, 'Client 3')

model_weights = server_model.trainable_weights[0]

client_datasets = [batched_dataset1, batched_dataset2, batched_dataset3]
for _ in range(10):  # Run 10 rounds of FedAvg
  # We train on 1, 2, or 3 clients per round, selecting
  # randomly.
  cohort_size = np.random.randint(1, 4)
  clients = np.random.choice([0, 1, 2], cohort_size, replace=False)
  print('Training on clients', clients)
  model_weights = sparse_model_update(
      model_weights, [client_datasets[i] for i in clients])
server_model.set_weights([model_weights])

print('After training')
evaluate(server_model, batched_dataset1, 'Client 1')
evaluate(server_model, batched_dataset2, 'Client 2')
evaluate(server_model, batched_dataset3, 'Client 3')
Before training
Client 1: loss=0.69, precision=0.00, auc=0.50, recall_at_2=0.60
Client 2: loss=0.69, precision=0.00, auc=0.50, recall_at_2=0.50
Client 3: loss=0.69, precision=0.00, auc=0.50, recall_at_2=0.40
Training on clients [0 1]
Training on clients [0 2 1]
Training on clients [2 0]
Training on clients [1 0 2]
Training on clients [2]
Training on clients [2 0]
Training on clients [1 2 0]
Training on clients [0]
Training on clients [2]
Training on clients [1 2]
After training
Client 1: loss=0.67, precision=0.80, auc=0.91, recall_at_2=0.80
Client 2: loss=0.68, precision=0.67, auc=0.96, recall_at_2=1.00
Client 3: loss=0.65, precision=1.00, auc=0.93, recall_at_2=0.80