Train a TensorFlow model with Keras on Google Kubernetes Engine
This section provides an example of fine-tuning a BERT model for sequence classification using the Hugging Face Transformers library with TensorFlow. The dataset is downloaded into a mounted Parallelstore-backed volume, which lets the model training read data directly from the volume.
Prerequisites
- Ensure your node has at least 8 GiB of memory available.
- Create a PersistentVolumeClaim that requests a Parallelstore-backed volume (a minimal example follows this list).
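For reference, the PersistentVolumeClaim might look like the following minimal sketch. The claim name parallelstore-pvc is the one referenced by the Job manifest below; the storage class name, access mode, and requested capacity are assumptions that you should replace with values matching your Parallelstore setup.

apiVersion: v1
kind: PersistentVolumeClaim
metadata:
  name: parallelstore-pvc                # referenced by the Job manifest below
spec:
  accessModes:
  - ReadWriteMany                        # assumption: shared read/write access
  storageClassName: parallelstore-class  # assumption: replace with your Parallelstore StorageClass
  resources:
    requests:
      storage: 12Ti                      # assumption: match your provisioned Parallelstore capacity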
Save the following YAML manifest (parallelstore-csi-job-example.yaml) for your model training Job.
apiVersion: batch/v1
kind: Job
metadata:
  name: parallelstore-csi-job-example
spec:
  template:
    metadata:
      annotations:
        gke-parallelstore/cpu-limit: "0"
        gke-parallelstore/memory-limit: "0"
    spec:
      securityContext:
        runAsUser: 1000
        runAsGroup: 100
        fsGroup: 100
      containers:
      - name: tensorflow
        image: jupyter/tensorflow-notebook@sha256:173f124f638efe870bb2b535e01a76a80a95217e66ed00751058c51c09d6d85d
        command: ["bash", "-c"]
        args:
        - |
          pip install transformers datasets
          python - <<EOF
          # Download the GLUE CoLA dataset into the Parallelstore-backed volume mounted at /data.
          from datasets import load_dataset
          dataset = load_dataset("glue", "cola", cache_dir='/data')
          dataset = dataset["train"]
          # Tokenize the sentences with the BERT tokenizer.
          from transformers import AutoTokenizer
          import numpy as np
          tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
          tokenized_data = tokenizer(dataset["sentence"], return_tensors="np", padding=True)
          tokenized_data = dict(tokenized_data)
          labels = np.array(dataset["label"])
          # Fine-tune a BERT sequence classification model with Keras.
          from transformers import TFAutoModelForSequenceClassification
          from tensorflow.keras.optimizers import Adam
          model = TFAutoModelForSequenceClassification.from_pretrained("bert-base-cased")
          model.compile(optimizer=Adam(3e-5))
          model.fit(tokenized_data, labels)
          EOF
        volumeMounts:
        - name: parallelstore-volume
          mountPath: /data
      volumes:
      - name: parallelstore-volume
        persistentVolumeClaim:
          claimName: parallelstore-pvc
      restartPolicy: Never
  backoffLimit: 1
Apply the YAML manifest to the cluster.
kubectl apply -f parallelstore-csi-job-example.yaml
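Optionally, verify that the Job and its Pod were created. These are standard kubectl commands; the job-name label is added automatically by the Job controller to the Pods it creates.

kubectl get jobs
kubectl get pods -l job-name=parallelstore-csi-job-example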
Check your data loading and model training progress with the following commands:
POD_NAME=$(kubectl get pod | grep 'parallelstore-csi-job-example' | awk '{print $1}')
kubectl logs -f $POD_NAME -c tensorflow
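If you prefer to block until training finishes, you can wait on the Job's Complete condition with a standard kubectl wait command (the timeout value here is an arbitrary example):

kubectl wait --for=condition=complete --timeout=60m job/parallelstore-csi-job-example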