Build Pytorch Training Job
This section describes how to build a custom Pytorch training job.
Path
pkg/apis/training.PytorchJobBuilder
Function
func NewPytorchJobBuilder() *PytorchJobBuilder
Parameters
PytorchJobBuilder provides the following functions to customize your Pytorch training job.
function | description | matching cli option |
---|---|---|
Name(name string) *PytorchJobBuilder | specify the job name | --name |
Command(args []string) *PytorchJobBuilder | specify the job command | - |
WorkingDir(dir string) *PytorchJobBuilder | specify the working dir | --working-dir |
Envs(envs map[string]string) *PytorchJobBuilder | specify the container env | --env |
GPUCount(count int) *PytorchJobBuilder | specify the gpu count of each worker | --gpus |
Image(image string) *PytorchJobBuilder | specify the image | --image |
Tolerations(tolerations []string) *PytorchJobBuilder | specify the k8s node taint tolerations | --toleration |
ConfigFiles(files map[string]string) *PytorchJobBuilder | specify the configuration files | --config-file |
NodeSelectors(selectors map[string]string) *PytorchJobBuilder | specify the node selectors | --selector |
Annotations(annotations map[string]string) *PytorchJobBuilder | specify the instance annotations | --annotation |
Datas(volumes map[string]string) *PytorchJobBuilder | specify the data pvc | --data |
DataDirs(volumes map[string]string) *PytorchJobBuilder | specify the host path and the container path it is mapped to | --data-dir |
LogDir(dir string) *PytorchJobBuilder | specify the log dir | --logdir |
Priority(priority string) *PytorchJobBuilder | specify the priority | --priority |
EnableRDMA() *PytorchJobBuilder | enable rdma | --rdma |
SyncImage(image string) *PytorchJobBuilder | specify the sync image | --sync-image |
SyncMode(mode string) *PytorchJobBuilder | specify the sync mode (rsync, git) | --sync-mode |
SyncSource(source string) *PytorchJobBuilder | specify the code source address (e.g. a git URL or an rsync URL) | --sync-source |
EnableTensorboard() *PytorchJobBuilder | enable tensorboard | --tensorboard |
TensorboardImage(image string) *PytorchJobBuilder | specify the tensorboard image | --tensorboard-image |
ImagePullSecrets(secrets []string) *PytorchJobBuilder | specify the image pull secret | --image-pull-secret |
WorkerCount(count int) *PytorchJobBuilder | specify the worker count | --workers |
CPU(cpu string) *PytorchJobBuilder | specify the cpu limits | --cpu |
Memory(memory string) *PytorchJobBuilder | specify the memory limits | --memory |
CleanPodPolicy(policy string) *PytorchJobBuilder | specify the policy for cleaning up pods | --clean-task-policy |
Build() (*Job, error) | build the Pytorch training job | - |
Example
package main
import (
"fmt"
"time"
"github.com/kubeflow/arena/pkg/apis/arenaclient"
"github.com/kubeflow/arena/pkg/apis/training"
"github.com/kubeflow/arena/pkg/apis/types"
)
func main() {
jobName := "pytorch-test"
jobType := types.PytorchTrainingJob
// create arena client
client, err := arenaclient.NewArenaClient(types.ArenaClientArgs{
Kubeconfig: "",
LogLevel: "info",
Namespace: "default",
})
if err != nil {
fmt.Printf("failed to create arena client,reason: %v", err)
return
}
// create tfjob
/* command:
arena \
submit \
pytorchjob \
--name=pytorch-standalone-test \
--gpus=1 \
--sync-mode=git \
--tensorboard \
--sync-source=https://code.aliyun.com/370272561/mnist-pytorch.git \
--loglevel debug \
--image=registry.cn-shanghai.aliyuncs.com/ai-samples/pytorch-with-tensorboard:1.5.1-cuda10.1-cudnn7-runtime \
"python /root/code/mnist-pytorch/mnist.py --backend gloo"
*/
submitJob, err := training.NewPytorchJobBuilder().
Name(jobName).
GPUCount(1).
SyncMode("git").
SyncSource("https://code.aliyun.com/370272561/mnist-pytorch.git").
Image("registry.cn-shanghai.aliyuncs.com/ai-samples/pytorch-with-tensorboard:1.5.1-cuda10.1-cudnn7-runtime").
Command([]string{"python /root/code/mnist-pytorch/mnist.py --backend gloo"}).Build()
if err != nil {
fmt.Printf("failed to build pytorchjob,reason: %v\n", err)
return
}
// submit tfjob
if err := client.Training().Submit(submitJob); err != nil {
fmt.Printf("failed to submit job,reason: %v\n", err)
return
}
}