diff --git a/acceptance/experimental/air/run/invalid.yaml b/acceptance/experimental/air/run/invalid.yaml new file mode 100644 index 0000000000..c011fc81b3 --- /dev/null +++ b/acceptance/experimental/air/run/invalid.yaml @@ -0,0 +1,5 @@ +experiment_name: bad.name +command: x +compute: + accelerator_type: GPU_8xH100 + num_accelerators: 3 diff --git a/acceptance/experimental/air/run/out.test.toml b/acceptance/experimental/air/run/out.test.toml new file mode 100644 index 0000000000..d6187dcb04 --- /dev/null +++ b/acceptance/experimental/air/run/out.test.toml @@ -0,0 +1,3 @@ +Local = true +Cloud = false +EnvMatrix.DATABRICKS_BUNDLE_ENGINE = [] diff --git a/acceptance/experimental/air/run/output.txt b/acceptance/experimental/air/run/output.txt new file mode 100644 index 0000000000..180886290b --- /dev/null +++ b/acceptance/experimental/air/run/output.txt @@ -0,0 +1,39 @@ + +=== dry-run (text) +>>> [CLI] experimental air run -f valid.yaml --dry-run +Dry run: configuration for "smoke-test" is valid; not submitting. 
+ +=== dry-run (json) +>>> [CLI] experimental air run -f valid.yaml --dry-run -o json +{ + "v": 1, + "ts": "[TIMESTAMP]", + "data": { + "status": "DRY_RUN_OK", + "dry_run": true + } +} + +=== override not yet supported +>>> [CLI] experimental air run -f valid.yaml --dry-run --override a=b +Error: --override is not yet supported + +Exit code: 1 + +=== watch not yet supported +>>> [CLI] experimental air run -f valid.yaml --dry-run --watch +Error: --watch is not yet supported + +Exit code: 1 + +=== invalid config is rejected +>>> [CLI] experimental air run -f invalid.yaml --dry-run +Error: invalid experiment_name "bad.name": only alphanumeric characters, hyphens (-), and underscores (_) are allowed + +Exit code: 1 + +=== missing --file +>>> [CLI] experimental air run --dry-run +Error: required flag(s) "file" not set + +Exit code: 1 diff --git a/acceptance/experimental/air/run/script b/acceptance/experimental/air/run/script new file mode 100644 index 0000000000..806bd321e6 --- /dev/null +++ b/acceptance/experimental/air/run/script @@ -0,0 +1,17 @@ +title "dry-run (text)" +trace $CLI experimental air run -f valid.yaml --dry-run + +title "dry-run (json)" +trace $CLI experimental air run -f valid.yaml --dry-run -o json + +title "override not yet supported" +errcode trace $CLI experimental air run -f valid.yaml --dry-run --override a=b + +title "watch not yet supported" +errcode trace $CLI experimental air run -f valid.yaml --dry-run --watch + +title "invalid config is rejected" +errcode trace $CLI experimental air run -f invalid.yaml --dry-run + +title "missing --file" +errcode trace $CLI experimental air run --dry-run diff --git a/acceptance/experimental/air/run/test.toml b/acceptance/experimental/air/run/test.toml new file mode 100644 index 0000000000..2f971c3ed2 --- /dev/null +++ b/acceptance/experimental/air/run/test.toml @@ -0,0 +1,4 @@ +# `air run --dry-run` validates the config locally and makes no workspace calls, +# so no engine matrix or server stubs are needed. 
+[EnvMatrix] +DATABRICKS_BUNDLE_ENGINE = [] diff --git a/acceptance/experimental/air/run/valid.yaml b/acceptance/experimental/air/run/valid.yaml new file mode 100644 index 0000000000..b82a321b05 --- /dev/null +++ b/acceptance/experimental/air/run/valid.yaml @@ -0,0 +1,5 @@ +experiment_name: smoke-test +command: python train.py +compute: + accelerator_type: GPU_1xH100 + num_accelerators: 1 diff --git a/acceptance/experimental/air/unimplemented/output.txt b/acceptance/experimental/air/unimplemented/output.txt index 66ddc34d58..21c3c891af 100644 --- a/acceptance/experimental/air/unimplemented/output.txt +++ b/acceptance/experimental/air/unimplemented/output.txt @@ -1,10 +1,4 @@ -=== run ->>> [CLI] experimental air run -Error: `air run` is not implemented yet - -Exit code: 1 - === logs >>> [CLI] experimental air logs 123 Error: `air logs` is not implemented yet diff --git a/acceptance/experimental/air/unimplemented/script b/acceptance/experimental/air/unimplemented/script index d00d045368..4c53586b16 100644 --- a/acceptance/experimental/air/unimplemented/script +++ b/acceptance/experimental/air/unimplemented/script @@ -1,8 +1,5 @@ # Each stub must fail with "not implemented"; errcode records the exit code. -title "run" -errcode trace $CLI experimental air run - title "logs" errcode trace $CLI experimental air logs 123 diff --git a/experimental/air/cmd/run.go b/experimental/air/cmd/run.go index 0bc3d1fd94..bd32810e9b 100644 --- a/experimental/air/cmd/run.go +++ b/experimental/air/cmd/run.go @@ -1,10 +1,25 @@ package aircmd import ( + "errors" + "fmt" + "strconv" + "github.com/databricks/cli/cmd/root" + "github.com/databricks/cli/libs/cmdctx" + "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/flags" "github.com/spf13/cobra" ) +// runResult is the JSON payload for `air run`. 
+type runResult struct { + Status string `json:"status"` + DryRun bool `json:"dry_run,omitempty"` + RunID string `json:"run_id,omitempty"` + DashboardURL string `json:"dashboard_url,omitempty"` +} + func newRunCommand() *cobra.Command { var ( file string @@ -21,9 +36,6 @@ func newRunCommand() *cobra.Command { Long: `Submit a training workload to Databricks serverless GPU compute. The workload is described by a YAML config file (see --file).`, - RunE: func(cmd *cobra.Command, args []string) error { - return notImplemented("run") - }, } cmd.Flags().StringVarP(&file, "file", "f", "", "Path to the workload YAML config") @@ -31,6 +43,56 @@ The workload is described by a YAML config file (see --file).`, cmd.Flags().StringArrayVar(&overrides, "override", nil, "Override a YAML field, e.g. compute.num_accelerators=8 (repeatable)") cmd.Flags().BoolVar(&dryRun, "dry-run", false, "Validate the config without submitting") cmd.Flags().StringVar(&idempotencyKey, "idempotency-key", "", "Return the existing run if this key was already used") + _ = cmd.MarkFlagRequired("file") + + // --dry-run only validates the config locally, so it needs no workspace. + // Submission requires an authenticated client. + cmd.PreRunE = func(cmd *cobra.Command, args []string) error { + if dryRun { + return nil + } + return root.MustWorkspaceClient(cmd, args) + } + + cmd.RunE = func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + + // These flags' pipelines are not ported yet; reject rather than silently + // ignore them. 
+ if len(overrides) > 0 { + return errors.New("--override is not yet supported") + } + if watch { + return errors.New("--watch is not yet supported") + } + + cfg, err := loadRunConfig(file) + if err != nil { + return err + } + + if dryRun { + if root.OutputType(cmd) == flags.OutputText { + cmdio.LogString(ctx, fmt.Sprintf("Dry run: configuration for %q is valid; not submitting.", cfg.ExperimentName)) + return nil + } + return renderEnvelope(ctx, runResult{Status: "DRY_RUN_OK", DryRun: true}) + } + + w := cmdctx.WorkspaceClient(ctx) + runID, dashboardURL, err := submitWorkload(ctx, w, cfg, file, idempotencyKey) + if err != nil { + return err + } + + runIDStr := strconv.FormatInt(runID, 10) + if root.OutputType(cmd) == flags.OutputText { + cmdio.LogString(ctx, "Submitted run "+runIDStr) + cmdio.LogString(ctx, "View at: "+dashboardURL) + return nil + } + return renderEnvelope(ctx, runResult{Status: "SUBMITTED", RunID: runIDStr, DashboardURL: dashboardURL}) + } return cmd } diff --git a/experimental/air/cmd/runconfig_launch.go b/experimental/air/cmd/runconfig_launch.go new file mode 100644 index 0000000000..6864e4c454 --- /dev/null +++ b/experimental/air/cmd/runconfig_launch.go @@ -0,0 +1,62 @@ +package aircmd + +// This file flattens the validated runConfig schema into the derived values the +// launch path consumes, replacing the Python CLI's _convert_to_run_config step. +// There is no separate internal config type: handle_run reads runConfig directly, +// using these accessors for the values that need computing rather than a plain +// field read. + +const defaultMaxRetries = 3 + +// timeoutSeconds converts timeout_minutes to seconds. Zero means the user set no +// timeout and the backend default applies. +func (c *runConfig) timeoutSeconds() int { + if c.TimeoutMinutes == nil { + return 0 + } + return *c.TimeoutMinutes * 60 +} + +// maxRetries returns the retry count, applying the schema default when unset. 
+func (c *runConfig) maxRetries() int { + if c.MaxRetries == nil { + return defaultMaxRetries + } + return *c.MaxRetries +} + +// dockerImageURL returns the custom docker image URL, or "" when none is set. +func (c *runConfig) dockerImageURL() string { + if c.Environment != nil && c.Environment.DockerImage != nil { + return c.Environment.DockerImage.URL + } + return "" +} + +// requirementsFile returns the path to a requirements file when +// environment.dependencies is a string, and whether it was set. +func (c *runConfig) requirementsFile() (string, bool) { + if c.Environment == nil || !c.Environment.Dependencies.set || c.Environment.Dependencies.isList { + return "", false + } + return c.Environment.Dependencies.path, true +} + +// inlineDependencies returns the inline package list when +// environment.dependencies is a list, and whether it was set. +func (c *runConfig) inlineDependencies() ([]string, bool) { + if c.Environment == nil || !c.Environment.Dependencies.set || !c.Environment.Dependencies.isList { + return nil, false + } + return c.Environment.Dependencies.list, true +} + +// runtimeVersion returns the client image version from environment.version when +// set. For a requirements-file dependency set, the version lives in that file and +// is resolved at launch, not here. 
+func (c *runConfig) runtimeVersion() (string, bool) { + if c.Environment == nil || !c.Environment.Version.set { + return "", false + } + return c.Environment.Version.raw, true +} diff --git a/experimental/air/cmd/runconfig_launch_test.go b/experimental/air/cmd/runconfig_launch_test.go new file mode 100644 index 0000000000..289db91c7d --- /dev/null +++ b/experimental/air/cmd/runconfig_launch_test.go @@ -0,0 +1,80 @@ +package aircmd + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRunConfigTimeoutSeconds(t *testing.T) { + c := &runConfig{} + assert.Equal(t, 0, c.timeoutSeconds()) + + c.TimeoutMinutes = new(2) + assert.Equal(t, 120, c.timeoutSeconds()) +} + +func TestRunConfigMaxRetries(t *testing.T) { + c := &runConfig{} + assert.Equal(t, defaultMaxRetries, c.maxRetries()) + + c.MaxRetries = new(0) + assert.Equal(t, 0, c.maxRetries()) + + c.MaxRetries = new(7) + assert.Equal(t, 7, c.maxRetries()) +} + +func TestRunConfigDockerImageURL(t *testing.T) { + c := &runConfig{} + assert.Empty(t, c.dockerImageURL()) + + c.Environment = &environmentConfig{} + assert.Empty(t, c.dockerImageURL()) + + c.Environment.DockerImage = &dockerImageConfig{URL: "org/repo:tag"} + assert.Equal(t, "org/repo:tag", c.dockerImageURL()) +} + +func TestRunConfigDependencies(t *testing.T) { + t.Run("unset", func(t *testing.T) { + c := &runConfig{} + _, ok := c.requirementsFile() + assert.False(t, ok) + _, ok = c.inlineDependencies() + assert.False(t, ok) + }) + + t.Run("file path", func(t *testing.T) { + c := &runConfig{Environment: &environmentConfig{ + Dependencies: dependencies{set: true, isList: false, path: "req.yaml"}, + }} + path, ok := c.requirementsFile() + assert.True(t, ok) + assert.Equal(t, "req.yaml", path) + _, ok = c.inlineDependencies() + assert.False(t, ok) + }) + + t.Run("inline list", func(t *testing.T) { + c := &runConfig{Environment: &environmentConfig{ + Dependencies: dependencies{set: true, isList: true, list: []string{"torch", "numpy"}}, + }} 
+ list, ok := c.inlineDependencies() + assert.True(t, ok) + assert.Equal(t, []string{"torch", "numpy"}, list) + _, ok = c.requirementsFile() + assert.False(t, ok) + }) +} + +func TestRunConfigRuntimeVersion(t *testing.T) { + c := &runConfig{} + _, ok := c.runtimeVersion() + assert.False(t, ok) + + c.Environment = &environmentConfig{Version: stringOrInt{set: true, raw: "5"}} + v, ok := c.runtimeVersion() + assert.True(t, ok) + assert.Equal(t, "5", v) +} diff --git a/experimental/air/cmd/runlaunch.go b/experimental/air/cmd/runlaunch.go new file mode 100644 index 0000000000..b2a7215e66 --- /dev/null +++ b/experimental/air/cmd/runlaunch.go @@ -0,0 +1,73 @@ +package aircmd + +import ( + "context" + "errors" + "fmt" + "path" + "strings" + + "github.com/databricks/cli/libs/env" + "github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/apierr" + "github.com/databricks/databricks-sdk-go/service/iam" + "github.com/databricks/databricks-sdk-go/service/workspace" + "github.com/google/uuid" +) + +// userWorkspaceDirEnv overrides the per-user workspace directory; mirrors the +// Python CLI's DATABRICKS_INTERNAL_USER_WORKSPACE_DIR escape hatch. +const userWorkspaceDirEnv = "DATABRICKS_INTERNAL_USER_WORKSPACE_DIR" + +// currentUserEmail returns the authenticated user's email (works for any domain). +func currentUserEmail(ctx context.Context, w *databricks.WorkspaceClient) (string, error) { + me, err := w.CurrentUser.Me(ctx, iam.MeRequest{}) + if err != nil { + return "", fmt.Errorf("failed to resolve current user: %w", err) + } + return me.UserName, nil +} + +// userWorkspaceDir returns the user's workspace home, honoring the env override. 
+func userWorkspaceDir(ctx context.Context, w *databricks.WorkspaceClient) (string, error) { + if override := env.Get(ctx, userWorkspaceDirEnv); override != "" { + return override, nil + } + email, err := currentUserEmail(ctx, w) + if err != nil { + return "", err + } + return "/Workspace/Users/" + email, nil +} + +// cliLaunchDir returns a unique workspace directory for a run's launch artifacts: +// /.air/cli_launch//_. run defaults to experiment. +func cliLaunchDir(base, experiment, run string) string { + if run == "" { + run = experiment + } + unique := strings.ReplaceAll(uuid.NewString(), "-", "")[:16] + return path.Join(base, ".air", "cli_launch", experiment, run+"_"+unique) +} + +// ensureExperimentDirectory creates experimentDir if it is missing, matching the +// CLI's convention for its other artifact directories. Without this, a missing +// parent surfaces only as a server-side INTERNAL_ERROR after the run is wasted. +// An empty dir means the default (/Users//...), which always exists. 
+func ensureExperimentDirectory(ctx context.Context, w *databricks.WorkspaceClient, experimentDir string) error { + if experimentDir == "" { + return nil + } + + info, err := w.Workspace.GetStatusByPath(ctx, experimentDir) + if errors.Is(err, apierr.ErrNotFound) { + return w.Workspace.MkdirsByPath(ctx, experimentDir) + } + if err != nil { + return fmt.Errorf("failed to check experiment_directory %q: %w", experimentDir, err) + } + if info.ObjectType != workspace.ObjectTypeDirectory { + return fmt.Errorf("experiment_directory %q is not a directory (object_type=%s)", experimentDir, info.ObjectType) + } + return nil +} diff --git a/experimental/air/cmd/runlaunch_test.go b/experimental/air/cmd/runlaunch_test.go new file mode 100644 index 0000000000..af6f0f70d3 --- /dev/null +++ b/experimental/air/cmd/runlaunch_test.go @@ -0,0 +1,65 @@ +package aircmd + +import ( + "strings" + "testing" + + "github.com/databricks/cli/libs/filer" + "github.com/databricks/cli/libs/testserver" + "github.com/databricks/databricks-sdk-go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCliLaunchDir(t *testing.T) { + dir := cliLaunchDir("/Workspace/Users/me@example.com", "my-exp", "") + assert.True(t, strings.HasPrefix(dir, "/Workspace/Users/me@example.com/.air/cli_launch/my-exp/my-exp_"), dir) + // run name overrides the leaf; the unique suffix keeps successive dirs distinct. 
+ withRun := cliLaunchDir("/base", "exp", "run1") + assert.True(t, strings.HasPrefix(withRun, "/base/.air/cli_launch/exp/run1_"), withRun) + assert.NotEqual(t, dir, cliLaunchDir("/Workspace/Users/me@example.com", "my-exp", "")) +} + +func newFakeWorkspaceClient(t *testing.T) *databricks.WorkspaceClient { + server := testserver.New(t) + t.Cleanup(server.Close) + testserver.AddDefaultHandlers(server) + w, err := databricks.NewWorkspaceClient(&databricks.Config{Host: server.URL, Token: "token"}) + require.NoError(t, err) + return w +} + +func TestUserWorkspaceDir(t *testing.T) { + w := newFakeWorkspaceClient(t) + dir, err := userWorkspaceDir(t.Context(), w) + require.NoError(t, err) + assert.True(t, strings.HasPrefix(dir, "/Workspace/Users/"), dir) + + // The env override wins without an API call. + t.Setenv(userWorkspaceDirEnv, "/Workspace/custom") + dir, err = userWorkspaceDir(t.Context(), w) + require.NoError(t, err) + assert.Equal(t, "/Workspace/custom", dir) +} + +func TestEnsureExperimentDirectory(t *testing.T) { + ctx := t.Context() + w := newFakeWorkspaceClient(t) + + // Empty means default (always exists) — no API call, no error. + require.NoError(t, ensureExperimentDirectory(ctx, w, "")) + + // A missing path is created. + require.NoError(t, ensureExperimentDirectory(ctx, w, "/Workspace/Users/me/exp")) + + // An existing directory is accepted as-is. + require.NoError(t, w.Workspace.MkdirsByPath(ctx, "/Workspace/Users/me/existing")) + require.NoError(t, ensureExperimentDirectory(ctx, w, "/Workspace/Users/me/existing")) + + // A path that exists but is a file is rejected. 
+ fc, err := filer.NewWorkspaceFilesClient(w, "/Workspace/Users/me") + require.NoError(t, err) + require.NoError(t, fc.Write(ctx, "afile", strings.NewReader("x"))) + err = ensureExperimentDirectory(ctx, w, "/Workspace/Users/me/afile") + require.ErrorContains(t, err, "is not a directory") +} diff --git a/experimental/air/cmd/runsubmit.go b/experimental/air/cmd/runsubmit.go new file mode 100644 index 0000000000..08f7c99399 --- /dev/null +++ b/experimental/air/cmd/runsubmit.go @@ -0,0 +1,244 @@ +package aircmd + +import ( + "context" + "errors" + "net/http" + "path" + "strconv" + "strings" + + "github.com/databricks/cli/libs/auth" + "github.com/databricks/cli/libs/env" + "github.com/databricks/cli/libs/filer" + "github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/client" + "github.com/google/uuid" +) + +// jobsRunsSubmitPath is the Jobs one-time-run endpoint. air builds the full +// payload and POSTs it here directly — the native ai_runtime_task is not modeled +// by the typed SDK, and we want no genai-mapi forwarding. +const jobsRunsSubmitPath = "/api/2.2/jobs/runs/submit" + +// dlRuntimeImageEnv overrides the default deep-learning runtime image. +const dlRuntimeImageEnv = "DATABRICKS_DL_RUNTIME_IMAGE" + +const defaultDlRuntimeImage = "CLIENT-GPU-4" + +// aiRuntimeEnvironmentKey ties the task to the serverless environment that +// carries the runtime channel. +const aiRuntimeEnvironmentKey = "default" + +// aiRuntimeCompute is a deployment's accelerator request. +type aiRuntimeCompute struct { + AcceleratorType string `json:"accelerator_type"` + AcceleratorCount int `json:"accelerator_count"` +} + +// aiRuntimeDeployment is one worker deployment of the run. +type aiRuntimeDeployment struct { + CommandPath string `json:"command_path"` + Compute aiRuntimeCompute `json:"compute"` +} + +// aiRuntimeTask is the native AI Runtime task. It routes straight to the training +// service — no genai-mapi forwarding. 
// The proto is lean: env vars, secrets,
// requirements, and hyperparameters are staged as workspace files co-located with
// command.sh (see runupload.go), not carried inline.
//
// Fields tagged omitempty are dropped from the JSON payload when left at their
// zero value.
type aiRuntimeTask struct {
	// Experiment is filled from the config's experiment name.
	Experiment string `json:"experiment"`
	// Deployments lists the worker deployments; buildSubmitPayload emits exactly one.
	Deployments []aiRuntimeDeployment `json:"deployments"`
	// MlflowRun is set only when the config names an MLflow run.
	MlflowRun string `json:"mlflow_run,omitempty"`
	// MlflowExperimentDirectory is set only when the config provides one.
	MlflowExperimentDirectory string `json:"mlflow_experiment_directory,omitempty"`
}

// environmentSpec carries the bare runtime channel ("4", "5", ...).
type environmentSpec struct {
	EnvironmentVersion string `json:"environment_version"`
}

// jobEnvironment is the serverless environment a task references for its runtime.
// A task selects it by repeating EnvironmentKey in its own environment_key.
type jobEnvironment struct {
	EnvironmentKey string          `json:"environment_key"`
	Spec           environmentSpec `json:"spec"`
}

// submitTask is the single task air submits: a native ai_runtime_task.
type submitTask struct {
	TaskKey string `json:"task_key"`
	RunIf   string `json:"run_if"`
	// AiRuntimeTask is the task body; it is always serialized (no omitempty).
	AiRuntimeTask aiRuntimeTask `json:"ai_runtime_task"`
	// EnvironmentKey must match a jobEnvironment in the enclosing payload.
	EnvironmentKey string `json:"environment_key"`
	// MaxRetries and RetryOnTimeout are omitted from the JSON when zero/false.
	MaxRetries     int  `json:"max_retries,omitempty"`
	RetryOnTimeout bool `json:"retry_on_timeout,omitempty"`
}

// jobsSubmitRun is the Jobs runs/submit payload.
type jobsSubmitRun struct {
	RunName string `json:"run_name"`
	// TimeoutSeconds is omitted from the JSON when 0.
	TimeoutSeconds int              `json:"timeout_seconds,omitempty"`
	Tasks          []submitTask     `json:"tasks"`
	Environments   []jobEnvironment `json:"environments"`
	// BudgetPolicyID is never populated by the code in this file as written:
	// submitWorkload rejects usage_policy_name before a payload is built.
	BudgetPolicyID string `json:"budget_policy_id,omitempty"`
	// IdempotencyToken is assigned by the caller after buildSubmitPayload returns.
	IdempotencyToken string `json:"idempotency_token,omitempty"`
}

// dlRuntimeImage resolves the runtime channel. A config version wins; otherwise
// the env override or default applies. The CLIENT-GPU- prefix is stripped because
// the native path wants the bare channel.
+func dlRuntimeImage(ctx context.Context, runtimeVersion string) string { + if runtimeVersion != "" { + return runtimeVersion + } + img := env.Get(ctx, dlRuntimeImageEnv) + if img == "" { + img = defaultDlRuntimeImage + } + return strings.TrimPrefix(img, "CLIENT-GPU-") +} + +// buildSubmitPayload assembles the runs/submit payload. commandPath is the +// workspace path of the uploaded command.sh; dlImage is the runtime channel. +func buildSubmitPayload(cfg *runConfig, commandPath, dlImage string) jobsSubmitRun { + task := aiRuntimeTask{ + Experiment: cfg.ExperimentName, + Deployments: []aiRuntimeDeployment{{ + CommandPath: commandPath, + Compute: aiRuntimeCompute{ + AcceleratorType: cfg.Compute.AcceleratorType, + AcceleratorCount: cfg.Compute.NumAccelerators, + }, + }}, + } + if cfg.MLflowRunName != nil { + task.MlflowRun = *cfg.MLflowRunName + } + if cfg.MLflowExperimentDirectory != nil { + task.MlflowExperimentDirectory = *cfg.MLflowExperimentDirectory + } + + st := submitTask{ + TaskKey: cfg.ExperimentName, + RunIf: "ALL_SUCCESS", + AiRuntimeTask: task, + EnvironmentKey: aiRuntimeEnvironmentKey, + } + // retry_on_timeout pairs with max_retries, matching the Python payload. + if r := cfg.maxRetries(); r > 0 { + st.MaxRetries = r + st.RetryOnTimeout = true + } + + return jobsSubmitRun{ + RunName: cfg.ExperimentName, + TimeoutSeconds: cfg.timeoutSeconds(), + Tasks: []submitTask{st}, + Environments: []jobEnvironment{{ + EnvironmentKey: aiRuntimeEnvironmentKey, + Spec: environmentSpec{EnvironmentVersion: dlImage}, + }}, + } +} + +// jobsSubmitClient submits one-time runs through the Jobs API. 
+type jobsSubmitClient struct { + c *client.DatabricksClient +} + +func newJobsSubmitClient(w *databricks.WorkspaceClient) (*jobsSubmitClient, error) { + c, err := client.New(w.Config) + if err != nil { + return nil, err + } + return &jobsSubmitClient{c: c}, nil +} + +type submitRunResponse struct { + RunID int64 `json:"run_id,omitempty"` +} + +// submit POSTs the payload to runs/submit and returns the new run_id. +func (j *jobsSubmitClient) submit(ctx context.Context, payload jobsSubmitRun) (int64, error) { + var resp submitRunResponse + if err := j.c.Do(ctx, http.MethodPost, jobsRunsSubmitPath, auth.WorkspaceIDHeaders(j.c.Config), nil, payload, &resp); err != nil { + return 0, err + } + return resp.RunID, nil +} + +// submitToken resolves the idempotency token: the --idempotency-key flag wins, +// then the config's token, else a generated one. Capped at the Jobs API's 64. +func submitToken(flag string, cfg *runConfig) string { + token := flag + if token == "" && cfg.IdempotencyToken != nil { + token = *cfg.IdempotencyToken + } + if token == "" { + token = uuid.NewString() + } + if len(token) > 64 { + token = token[:64] + } + return token +} + +// submitWorkload runs the submit happy path: ensure the experiment directory, +// upload the launch artifacts, assemble the Jobs payload, and submit it. It +// returns the new run_id and its dashboard URL. +func submitWorkload(ctx context.Context, w *databricks.WorkspaceClient, cfg *runConfig, configPath, idempotencyKey string) (int64, string, error) { + // Resolving usage_policy_name to a budget policy id and packaging a + // code_source snapshot are not ported yet; reject rather than silently drop. 
+ if cfg.UsagePolicyName != nil { + return 0, "", errors.New("usage_policy_name is not yet supported") + } + if cfg.CodeSource != nil { + return 0, "", errors.New("code_source is not yet supported") + } + + experimentDir := "" + if cfg.MLflowExperimentDirectory != nil { + experimentDir = *cfg.MLflowExperimentDirectory + } + if err := ensureExperimentDirectory(ctx, w, experimentDir); err != nil { + return 0, "", err + } + + base, err := userWorkspaceDir(ctx, w) + if err != nil { + return 0, "", err + } + runName := "" + if cfg.MLflowRunName != nil { + runName = *cfg.MLflowRunName + } + funcDir := cliLaunchDir(base, cfg.ExperimentName, runName) + + fc, err := filer.NewWorkspaceFilesClient(w, funcDir) + if err != nil { + return 0, "", err + } + items, err := buildArtifacts(cfg, configPath) + if err != nil { + return 0, "", err + } + if err := uploadArtifacts(ctx, fc, items); err != nil { + return 0, "", err + } + + runtimeVersion, _ := cfg.runtimeVersion() + payload := buildSubmitPayload(cfg, path.Join(funcDir, commandScriptName), dlRuntimeImage(ctx, runtimeVersion)) + payload.IdempotencyToken = submitToken(idempotencyKey, cfg) + + jc, err := newJobsSubmitClient(w) + if err != nil { + return 0, "", err + } + runID, err := jc.submit(ctx, payload) + if err != nil { + return 0, "", err + } + + dashboardURL := strings.TrimRight(w.Config.Host, "/") + "/jobs/runs/" + strconv.FormatInt(runID, 10) + return runID, dashboardURL, nil +} diff --git a/experimental/air/cmd/runsubmit_test.go b/experimental/air/cmd/runsubmit_test.go new file mode 100644 index 0000000000..c43dda0fcd --- /dev/null +++ b/experimental/air/cmd/runsubmit_test.go @@ -0,0 +1,145 @@ +package aircmd + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/databricks/cli/libs/testserver" + "github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDlRuntimeImage(t 
*testing.T) { + ctx := t.Context() + // A config runtime version wins and is used bare. + assert.Equal(t, "5", dlRuntimeImage(ctx, "5")) + // Default, with the CLIENT-GPU- prefix stripped for the GPU_* path. + assert.Equal(t, "4", dlRuntimeImage(ctx, "")) + // Env override, prefix stripped. + t.Setenv(dlRuntimeImageEnv, "CLIENT-GPU-7") + assert.Equal(t, "7", dlRuntimeImage(ctx, "")) +} + +func TestBuildSubmitPayload(t *testing.T) { + cfg := &runConfig{ + ExperimentName: "exp", + Command: new("python train.py"), + Compute: &computeConfig{AcceleratorType: "GPU_8xH100", NumAccelerators: 16}, + MaxRetries: new(2), + TimeoutMinutes: new(30), + MLflowRunName: new("run-v2"), + MLflowExperimentDirectory: new("/Workspace/Users/me/exp"), + } + + p := buildSubmitPayload(cfg, "/d/command.sh", "5") + + assert.Equal(t, "exp", p.RunName) + assert.Equal(t, 1800, p.TimeoutSeconds) + require.Len(t, p.Environments, 1) + assert.Equal(t, aiRuntimeEnvironmentKey, p.Environments[0].EnvironmentKey) + assert.Equal(t, "5", p.Environments[0].Spec.EnvironmentVersion) + + require.Len(t, p.Tasks, 1) + task := p.Tasks[0] + assert.Equal(t, "exp", task.TaskKey) + assert.Equal(t, "ALL_SUCCESS", task.RunIf) + assert.Equal(t, aiRuntimeEnvironmentKey, task.EnvironmentKey) + assert.Equal(t, 2, task.MaxRetries) + assert.True(t, task.RetryOnTimeout) + + at := task.AiRuntimeTask + assert.Equal(t, "exp", at.Experiment) + assert.Equal(t, "run-v2", at.MlflowRun) + assert.Equal(t, "/Workspace/Users/me/exp", at.MlflowExperimentDirectory) + require.Len(t, at.Deployments, 1) + assert.Equal(t, "/d/command.sh", at.Deployments[0].CommandPath) + assert.Equal(t, aiRuntimeCompute{AcceleratorType: "GPU_8xH100", AcceleratorCount: 16}, at.Deployments[0].Compute) +} + +func TestSubmitToken(t *testing.T) { + cfg := &runConfig{IdempotencyToken: new("from-config")} + assert.Equal(t, "from-flag", submitToken("from-flag", cfg)) // flag wins + assert.Equal(t, "from-config", submitToken("", cfg)) // then config + 
assert.NotEmpty(t, submitToken("", &runConfig{})) // else generated + assert.Len(t, submitToken(string(make([]byte, 80)), cfg), 64) // capped +} + +func TestJobsSubmitClient(t *testing.T) { + server := testserver.New(t) + t.Cleanup(server.Close) + + var got jobsSubmitRun + server.Handle("POST", "/api/2.2/jobs/runs/submit", func(req testserver.Request) any { + require.NoError(t, json.Unmarshal(req.Body, &got)) + return submitRunResponse{RunID: 999} + }) + + w := &databricks.WorkspaceClient{Config: &config.Config{Host: server.URL, Token: "token"}} + jc, err := newJobsSubmitClient(w) + require.NoError(t, err) + + runID, err := jc.submit(t.Context(), jobsSubmitRun{RunName: "exp", Tasks: []submitTask{{TaskKey: "exp"}}}) + require.NoError(t, err) + assert.Equal(t, int64(999), runID) + assert.Equal(t, "exp", got.RunName) +} + +func TestSubmitWorkload(t *testing.T) { + server := testserver.New(t) + t.Cleanup(server.Close) + testserver.AddDefaultHandlers(server) + + var got jobsSubmitRun + server.Handle("POST", "/api/2.2/jobs/runs/submit", func(req testserver.Request) any { + require.NoError(t, json.Unmarshal(req.Body, &got)) + return submitRunResponse{RunID: 777} + }) + w, err := databricks.NewWorkspaceClient(&databricks.Config{Host: server.URL, Token: "token"}) + require.NoError(t, err) + + cfgPath := writeConfigFile(t, "run.yaml", minimalConfig) + cfg, err := loadRunConfig(cfgPath) + require.NoError(t, err) + + runID, dashboardURL, err := submitWorkload(t.Context(), w, cfg, cfgPath, "idem-key") + require.NoError(t, err) + assert.Equal(t, int64(777), runID) + assert.Contains(t, dashboardURL, "/jobs/runs/777") + + // The submitted payload is a native ai_runtime_task pointing at the uploaded + // command.sh under the run's launch directory. 
+ assert.Equal(t, "my-run", got.RunName) + assert.Equal(t, "idem-key", got.IdempotencyToken) + require.Len(t, got.Environments, 1) + require.Len(t, got.Tasks, 1) + at := got.Tasks[0].AiRuntimeTask + require.Len(t, at.Deployments, 1) + d := at.Deployments[0] + assert.True(t, strings.HasSuffix(d.CommandPath, "/"+commandScriptName), d.CommandPath) + assert.Contains(t, d.CommandPath, "/.air/cli_launch/") + assert.Equal(t, aiRuntimeCompute{AcceleratorType: "GPU_1xH100", AcceleratorCount: 1}, d.Compute) +} + +func TestSubmitWorkloadGuards(t *testing.T) { + w := newFakeWorkspaceClient(t) + cfgPath := writeConfigFile(t, "run.yaml", minimalConfig) + base, err := loadRunConfig(cfgPath) + require.NoError(t, err) + + t.Run("usage_policy_name rejected", func(t *testing.T) { + cfg := *base + cfg.UsagePolicyName = new("p") + _, _, err := submitWorkload(t.Context(), w, &cfg, cfgPath, "") + require.ErrorContains(t, err, "usage_policy_name is not yet supported") + }) + + t.Run("code_source rejected", func(t *testing.T) { + cfg := *base + cfg.CodeSource = &codeSourceConfig{Type: "snapshot"} + _, _, err := submitWorkload(t.Context(), w, &cfg, cfgPath, "") + require.ErrorContains(t, err, "code_source is not yet supported") + }) +} diff --git a/experimental/air/cmd/runupload.go b/experimental/air/cmd/runupload.go new file mode 100644 index 0000000000..fb9ca00b98 --- /dev/null +++ b/experimental/air/cmd/runupload.go @@ -0,0 +1,170 @@ +package aircmd + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "maps" + "os" + "path/filepath" + "slices" + "strings" + + "github.com/databricks/cli/libs/filer" + "go.yaml.in/yaml/v3" +) + +// Launch artifact basenames, uploaded into the run's cli_launch directory. The +// server-side launcher derives requirements.yaml / hyperparameters.yaml from the +// same directory, so these names are part of the contract. 
+const ( + trainingConfigName = "training_config.yaml" + commandScriptName = "command.sh" + requirementsName = "requirements.yaml" + hyperparametersName = "hyperparameters.yaml" + envVarsName = "env_vars.json" + secretEnvVarsName = "secret_env_vars.json" +) + +// maxConfigYAMLBytes caps training_config.yaml. It is referenced by the Jobs +// payload and rendered on the run page, so an oversized parameters/command block +// is rejected here; full parameters still ship in hyperparameters.yaml. +const maxConfigYAMLBytes = 1024 * 1024 + +// uploadItem is a single artifact to write into the launch directory. +type uploadItem struct { + name string + data []byte +} + +// fileWriter is the subset of filer.Filer the upload path needs; a narrow +// interface keeps buildArtifacts/upload testable without a live workspace. +type fileWriter interface { + Write(ctx context.Context, name string, reader io.Reader, mode ...filer.WriteMode) error +} + +// requirementsDoc mirrors the on-disk requirements.yaml format so the worker +// parses synthesized inline dependencies identically to a user-provided file. +type requirementsDoc struct { + Version string `yaml:"version,omitempty"` + Dependencies []string `yaml:"dependencies"` +} + +// buildArtifacts assembles the files to upload for a run: the merged config, the +// inline command as a script, requirements (from a file or synthesized from +// inline dependencies), and hyperparameters. configPath is the local YAML path. +func buildArtifacts(cfg *runConfig, configPath string) ([]uploadItem, error) { + // TODO(DABs): with no _bases_/overrides ported yet, the merged config is the + // file as-is; once those land, upload the re-serialized merged YAML instead. 
+ configData, err := os.ReadFile(configPath) + if err != nil { + return nil, fmt.Errorf("failed to read config %s: %w", configPath, err) + } + if len(configData) > maxConfigYAMLBytes { + return nil, fmt.Errorf("config YAML is %.2f MB, over the %d MB limit; reduce 'parameters' or 'command'", + float64(len(configData))/(1024*1024), maxConfigYAMLBytes/(1024*1024)) + } + + items := []uploadItem{ + {trainingConfigName, configData}, + {commandScriptName, []byte(*cfg.Command)}, + } + + switch reqPath, ok := cfg.requirementsFile(); { + case ok: + // Resolve a relative requirements path against the config's directory. + if !filepath.IsAbs(reqPath) { + reqPath = filepath.Join(filepath.Dir(configPath), reqPath) + } + data, err := os.ReadFile(reqPath) + if err != nil { + return nil, fmt.Errorf("failed to read requirements file %s: %w", reqPath, err) + } + items = append(items, uploadItem{requirementsName, data}) + default: + if deps, ok := cfg.inlineDependencies(); ok { + version, _ := cfg.runtimeVersion() + data, err := yaml.Marshal(requirementsDoc{Version: version, Dependencies: deps}) + if err != nil { + return nil, fmt.Errorf("failed to synthesize requirements.yaml: %w", err) + } + items = append(items, uploadItem{requirementsName, data}) + } + } + + if len(cfg.Parameters) > 0 { + data, err := yaml.Marshal(cfg.Parameters) + if err != nil { + return nil, fmt.Errorf("failed to serialize parameters: %w", err) + } + items = append(items, uploadItem{hyperparametersName, data}) + } + + // The ai_runtime_task proto carries no inline env vars or secrets; stage them + // as JSON files co-located with command.sh for the server-side launcher. 
+ if len(cfg.EnvVariables) > 0 { + data, err := json.Marshal(envVarEntries(cfg.EnvVariables)) + if err != nil { + return nil, fmt.Errorf("failed to serialize env_variables: %w", err) + } + items = append(items, uploadItem{envVarsName, data}) + } + if len(cfg.Secrets) > 0 { + data, err := json.Marshal(secretEnvVarEntries(cfg.Secrets)) + if err != nil { + return nil, fmt.Errorf("failed to serialize secrets: %w", err) + } + items = append(items, uploadItem{secretEnvVarsName, data}) + } + + return items, nil +} + +// envVarEntry is one entry in env_vars.json. +type envVarEntry struct { + Name string `json:"name"` + Value string `json:"value"` +} + +// secretEnvVarEntry is one entry in secret_env_vars.json. The YAML side is +// {ENV_VAR: "scope/key"}; the launcher wants the split form. +type secretEnvVarEntry struct { + Name string `json:"name"` + SecretScope string `json:"secret_scope"` + SecretKey string `json:"secret_key"` +} + +// envVarEntries renders env_variables sorted by name for deterministic output. +func envVarEntries(vars map[string]string) []envVarEntry { + out := make([]envVarEntry, 0, len(vars)) + for _, name := range slices.Sorted(maps.Keys(vars)) { + out = append(out, envVarEntry{Name: name, Value: vars[name]}) + } + return out +} + +// secretEnvVarEntries renders secrets sorted by name for deterministic output. +func secretEnvVarEntries(secrets map[string]string) []secretEnvVarEntry { + out := make([]secretEnvVarEntry, 0, len(secrets)) + for _, name := range slices.Sorted(maps.Keys(secrets)) { + scope, key, _ := strings.Cut(secrets[name], "/") + out = append(out, secretEnvVarEntry{Name: name, SecretScope: scope, SecretKey: key}) + } + return out +} + +// uploadArtifacts writes each artifact into the launch directory, overwriting and +// creating parents as needed. +// +// TODO(DABs): this client-side upload could move onto libs/sync / a bundle deploy +// so the CLI reuses DABs' file-staging machinery instead of writing files itself. 
+func uploadArtifacts(ctx context.Context, w fileWriter, items []uploadItem) error { + for _, it := range items { + if err := w.Write(ctx, it.name, bytes.NewReader(it.data), filer.OverwriteIfExists, filer.CreateParentDirectories); err != nil { + return fmt.Errorf("failed to upload %s: %w", it.name, err) + } + } + return nil +} diff --git a/experimental/air/cmd/runupload_test.go b/experimental/air/cmd/runupload_test.go new file mode 100644 index 0000000000..0c87524735 --- /dev/null +++ b/experimental/air/cmd/runupload_test.go @@ -0,0 +1,155 @@ +package aircmd + +import ( + "context" + "errors" + "io" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/databricks/cli/libs/filer" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// fakeWriter records artifact writes in place of a workspace filer. +type fakeWriter struct { + written map[string]string +} + +func (f *fakeWriter) Write(ctx context.Context, name string, reader io.Reader, mode ...filer.WriteMode) error { + if f.written == nil { + f.written = map[string]string{} + } + data, err := io.ReadAll(reader) + if err != nil { + return err + } + f.written[name] = string(data) + return nil +} + +func writeConfigFile(t *testing.T, name, content string) string { + t.Helper() + path := filepath.Join(t.TempDir(), name) + require.NoError(t, os.WriteFile(path, []byte(content), 0o600)) + return path +} + +func itemNames(items []uploadItem) []string { + names := make([]string, len(items)) + for i, it := range items { + names[i] = it.name + } + return names +} + +func TestBuildArtifacts_CommandAndConfig(t *testing.T) { + path := writeConfigFile(t, "run.yaml", minimalConfig) + cfg := &runConfig{Command: new("python train.py")} + + items, err := buildArtifacts(cfg, path) + require.NoError(t, err) + assert.Equal(t, []string{trainingConfigName, commandScriptName}, itemNames(items)) + assert.Equal(t, minimalConfig, string(items[0].data)) + assert.Equal(t, "python train.py", 
string(items[1].data)) +} + +func TestBuildArtifacts_InlineRequirementsAndParameters(t *testing.T) { + path := writeConfigFile(t, "run.yaml", "x: y\n") + cfg := &runConfig{ + Command: new("echo hi"), + Environment: &environmentConfig{ + Dependencies: dependencies{set: true, isList: true, list: []string{"torch", "numpy"}}, + Version: stringOrInt{set: true, raw: "5"}, + }, + Parameters: map[string]any{"lr": 0.1}, + } + + items, err := buildArtifacts(cfg, path) + require.NoError(t, err) + assert.Equal(t, []string{trainingConfigName, commandScriptName, requirementsName, hyperparametersName}, itemNames(items)) + + var reqIdx int + for i, it := range items { + if it.name == requirementsName { + reqIdx = i + } + } + req := string(items[reqIdx].data) + assert.Contains(t, req, "version: \"5\"") + assert.Contains(t, req, "- torch") +} + +func TestBuildArtifacts_EnvVarsAndSecrets(t *testing.T) { + path := writeConfigFile(t, "run.yaml", "x: y\n") + cfg := &runConfig{ + Command: new("echo hi"), + EnvVariables: map[string]string{"WANDB": "demo"}, + Secrets: map[string]string{"HF_TOKEN": "myscope/hf"}, + } + + items, err := buildArtifacts(cfg, path) + require.NoError(t, err) + assert.Subset(t, itemNames(items), []string{envVarsName, secretEnvVarsName}) + + byName := map[string][]byte{} + for _, it := range items { + byName[it.name] = it.data + } + assert.JSONEq(t, `[{"name":"WANDB","value":"demo"}]`, string(byName[envVarsName])) + assert.JSONEq(t, `[{"name":"HF_TOKEN","secret_scope":"myscope","secret_key":"hf"}]`, string(byName[secretEnvVarsName])) +} + +func TestBuildArtifacts_RequirementsFile(t *testing.T) { + dir := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(dir, "run.yaml"), []byte("x: y\n"), 0o600)) + require.NoError(t, os.WriteFile(filepath.Join(dir, "reqs.yaml"), []byte("version: 4\n"), 0o600)) + cfg := &runConfig{ + Command: new("echo hi"), + Environment: &environmentConfig{Dependencies: dependencies{set: true, isList: false, path: "reqs.yaml"}}, + } + + 
items, err := buildArtifacts(cfg, filepath.Join(dir, "run.yaml")) + require.NoError(t, err) + assert.Contains(t, itemNames(items), requirementsName) +} + +func TestBuildArtifacts_OversizeConfigRejected(t *testing.T) { + path := writeConfigFile(t, "run.yaml", strings.Repeat("a", maxConfigYAMLBytes+1)) + _, err := buildArtifacts(&runConfig{Command: new("x")}, path) + require.Error(t, err) + assert.Contains(t, err.Error(), "over the 1 MB limit") +} + +func TestUploadArtifacts(t *testing.T) { + w := &fakeWriter{} + items := []uploadItem{{trainingConfigName, []byte("cfg")}, {commandScriptName, []byte("cmd")}} + require.NoError(t, uploadArtifacts(t.Context(), w, items)) + assert.Equal(t, "cfg", w.written[trainingConfigName]) + assert.Equal(t, "cmd", w.written[commandScriptName]) +} + +// errWriter fails every Write, exercising the upload error path. +type errWriter struct{} + +func (errWriter) Write(ctx context.Context, name string, reader io.Reader, mode ...filer.WriteMode) error { + return errors.New("boom") +} + +func TestUploadArtifacts_WriteError(t *testing.T) { + err := uploadArtifacts(t.Context(), errWriter{}, []uploadItem{{trainingConfigName, []byte("x")}}) + require.ErrorContains(t, err, "failed to upload "+trainingConfigName) +} + +func TestBuildArtifacts_MissingRequirementsFile(t *testing.T) { + cfgPath := writeConfigFile(t, "run.yaml", "x: y\n") + cfg := &runConfig{ + Command: new("echo hi"), + Environment: &environmentConfig{Dependencies: dependencies{set: true, isList: false, path: "nope.yaml"}}, + } + _, err := buildArtifacts(cfg, cfgPath) + require.ErrorContains(t, err, "failed to read requirements file") +} diff --git a/experimental/air/cmd/stubs_test.go b/experimental/air/cmd/stubs_test.go index b9f5c330f0..4607d7d9ea 100644 --- a/experimental/air/cmd/stubs_test.go +++ b/experimental/air/cmd/stubs_test.go @@ -13,7 +13,6 @@ import ( // fails with a "not implemented" error. Drop a command here once it lands. 
func TestStubCommandsReturnNotImplemented(t *testing.T) { stubs := map[string]*cobra.Command{ - "run": newRunCommand(), "logs": newLogsCommand(), "cancel": newCancelCommand(), "register-image": newRegisterImageCommand(),