Train a TensorFlow model with Keras on Google Kubernetes Engine

This section provides an example of fine-tuning a BERT model for sequence classification using the Hugging Face transformers library with TensorFlow. The dataset is downloaded into a mounted Parallelstore-backed volume, so model training reads data directly from the volume.
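Before walking through the Job manifest, here is a minimal sketch of the data-loading pattern the training script relies on. It assumes the Parallelstore-backed volume is mounted at /data, as in the manifest below; pointing cache_dir at the mount means the first run downloads the GLUE CoLA dataset onto the volume, and later runs that mount the same volume read the cached copy directly instead of re-downloading it.

from datasets import load_dataset

# Assumption: the Parallelstore-backed volume is mounted at /data.
# The first call downloads GLUE CoLA into /data; any later run that
# mounts the same volume reuses the cached files instead of re-downloading.
dataset = load_dataset("glue", "cola", cache_dir="/data")
print(dataset["train"].num_rows) # quick sanity check that the split loaded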

Prerequisites

This example assumes that the Parallelstore CSI driver is enabled on your cluster and that a PersistentVolumeClaim named parallelstore-pvc is bound to a Parallelstore-backed volume.

Save the following YAML manifest (parallelstore-csi-job-example.yaml) for your model training Job.

apiVersion: batch/v1
kind: Job
metadata:
 name: parallelstore-csi-job-example
spec:
 template:
 metadata:
 annotations:
 gke-parallelstore/cpu-limit: "0"
 gke-parallelstore/memory-limit: "0"
 spec:
 securityContext:
 runAsUser: 1000
 runAsGroup: 100
 fsGroup: 100
 containers:
 - name: tensorflow
 image: jupyter/tensorflow-notebook@sha256:173f124f638efe870bb2b535e01a76a80a95217e66ed00751058c51c09d6d85d
 command: ["bash", "-c"]
 args:
 - |
 pip install transformers datasets
 python - <<EOF
 from datasets import load_dataset
 dataset = load_dataset("glue", "cola", cache_dir='/data')
 dataset = dataset["train"]
 from transformers import AutoTokenizer
 import numpy as np
 tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
 tokenized_data = tokenizer(dataset["sentence"], return_tensors="np", padding=True)
 tokenized_data = dict(tokenized_data)
 labels = np.array(dataset["label"])
 from transformers import TFAutoModelForSequenceClassification
 from tensorflow.keras.optimizers import Adam
 model = TFAutoModelForSequenceClassification.from_pretrained("bert-base-cased")
 model.compile(optimizer=Adam(3e-5))
 model.fit(tokenized_data, labels)
 EOF
 volumeMounts:
 - name: parallelstore-volume
 mountPath: /data
 volumes:
 - name: parallelstore-volume
 persistentVolumeClaim:
 claimName: parallelstore-pvc
 restartPolicy: Never
 backoffLimit: 1

Apply the YAML manifest to the cluster.

kubectl apply -f parallelstore-csi-job-example.yaml

Check your data loading and model training progress with the following commands:

POD_NAME=$(kubectl get pod | grep 'parallelstore-csi-job-example' | awk '{print 1ドル}')
kubectl logs -f $POD_NAME -c tensorflow
