From 226d41ae1833c2630ddd82d4ba7b5d9dc639cd0c Mon Sep 17 00:00:00 2001 From: riddhibhagwat-db Date: Thu, 18 Jun 2026 17:36:11 +0000 Subject: [PATCH 1/6] experimental/air: add run config schema and structural validation Port the run YAML schema and its structural validation from the Python CLI's sdk/config.py: the top-level runConfig plus the environment, docker_image, code_source/snapshot/git, and permission blocks. loadRunConfig decodes a YAML file with KnownFields (mirroring pydantic extra="forbid") and runs the validation pass. "Structural" covers types, required fields, and format/cross-field rules that need no workspace access. Online checks (compute pool resolution, GPU availability), git/filesystem checks, _bases_ composition, and CLI --override handling are deferred to later milestones. Two deliberate divergences from the Python schema, both following from the training-service-only port: the compute pool fields were already dropped, and the top-level priority field is dropped here since it is a node-pool queue-ordering knob with no meaning for serverless workloads. Co-authored-by: Isaac --- experimental/air/cmd/runconfig.go | 445 +++++++++++++++++++++++++ experimental/air/cmd/runconfig_load.go | 40 +++ experimental/air/cmd/runconfig_test.go | 407 ++++++++++++++++++++++ 3 files changed, 892 insertions(+) create mode 100644 experimental/air/cmd/runconfig.go create mode 100644 experimental/air/cmd/runconfig_load.go create mode 100644 experimental/air/cmd/runconfig_test.go diff --git a/experimental/air/cmd/runconfig.go b/experimental/air/cmd/runconfig.go new file mode 100644 index 0000000000..cf29ea50aa --- /dev/null +++ b/experimental/air/cmd/runconfig.go @@ -0,0 +1,445 @@ +package aircmd + +import ( + "errors" + "fmt" + "regexp" + "slices" + "strings" + + "go.yaml.in/yaml/v3" +) + +// This file ports the run YAML schema and its structural validation from the +// Python CLI's sdk/config.py. "Structural" means types, required fields, and +// format/cross-field rules that need no workspace access. Online checks (e.g. +// GPU availability) and git/filesystem checks run at launch time and are +// intentionally not ported here. +// +// Divergences from the Python schema: compute.node_pool_id / compute.pool_name +// (see compute.go) and the top-level `priority` field are dropped because AIR +// does not support node-pool placement. priority is a pool-queue-ordering knob, +// so it goes with the pool fields. + +// REGEX_TASK_KEY_CHARS: ASCII alphanumeric, hyphen, underscore only (no periods). +// Explicit ASCII class, not \w: \w matches Unicode letters that the ASCII-only +// Jobs API task_key rejects. +var taskKeyRe = regexp.MustCompile(`^[A-Za-z0-9_-]+$`) + +// gitRefRe guards branch/remote names against command injection. Only safe ref +// characters are allowed. +var gitRefRe = regexp.MustCompile(`^[\w./-]+$`) + +// runConfig is the top-level run YAML schema: experiment_name + compute / +// environment / code_source plus the command and run options. +type runConfig struct { + ExperimentName string `yaml:"experiment_name"` + Compute *computeConfig `yaml:"compute"` + Environment *environmentConfig `yaml:"environment"` + Command *string `yaml:"command"` + EnvVariables map[string]string `yaml:"env_variables"` + Secrets map[string]string `yaml:"secrets"` + CodeSource *codeSourceConfig `yaml:"code_source"` + // MaxRetries defaults to 3 when unset; default-filling is a normalization + // concern handled at launch, so a nil pointer is left as-is here. + MaxRetries *int `yaml:"max_retries"` + TimeoutMinutes *int `yaml:"timeout_minutes"` + IdempotencyToken *string `yaml:"idempotency_token"` + Parameters map[string]any `yaml:"parameters"` + MLflowRunName *string `yaml:"mlflow_run_name"` + MLflowExperimentDirectory *string `yaml:"mlflow_experiment_directory"` + Permissions []permission `yaml:"permissions"` + UsagePolicyName *string `yaml:"usage_policy_name"` +} + +// validate runs structural validation over the whole config, returning the first +// failure. Fields are checked in declaration order to keep error output stable. +func (c *runConfig) validate() error { + if err := validateExperimentName(c.ExperimentName); err != nil { + return err + } + + if c.Compute == nil { + return errors.New("compute: section is required") + } + if err := c.Compute.validate(); err != nil { + return err + } + + if c.Environment != nil { + if err := c.Environment.validate(); err != nil { + return err + } + } + + // command is optional in the type system but required in practice, matching + // the Python validate_script_fields model validator. + if c.Command == nil { + return errors.New("command is required") + } + if err := validateCommand(*c.Command); err != nil { + return err + } + + if err := validateSecretRefs(c.Secrets); err != nil { + return err + } + + if c.CodeSource != nil { + if err := c.CodeSource.validate(); err != nil { + return err + } + } + + if c.MaxRetries != nil && *c.MaxRetries < 0 { + return fmt.Errorf("max_retries must be >= 0, got %d", *c.MaxRetries) + } + + if c.TimeoutMinutes != nil && *c.TimeoutMinutes < 1 { + return fmt.Errorf("timeout_minutes must be >= 1, got %d", *c.TimeoutMinutes) + } + + if c.IdempotencyToken != nil { + v := strings.TrimSpace(*c.IdempotencyToken) + if v == "" { + return errors.New("idempotency_token cannot be empty") + } + if len(v) > 64 { + return errors.New("idempotency_token must be 64 characters or less") + } + } + + if c.MLflowRunName != nil { + v := strings.TrimSpace(*c.MLflowRunName) + if v == "" { + return errors.New("mlflow_run_name cannot be empty") + } + if !taskKeyRe.MatchString(v) { + return fmt.Errorf("invalid mlflow_run_name %q: only alphanumeric characters, hyphens, and underscores are allowed", v) + } + } + + if c.MLflowExperimentDirectory != nil { + v := strings.TrimSpace(*c.MLflowExperimentDirectory) + if v == "" { + return errors.New("mlflow_experiment_directory cannot be empty") + } + // MLflow experiments live under the workspace tree. + if !strings.HasPrefix(v, "/Workspace") { + return fmt.Errorf("mlflow_experiment_directory must start with '/Workspace', got: %s", v) + } + } + + for i := range c.Permissions { + if err := c.Permissions[i].validate(); err != nil { + return err + } + } + + if c.UsagePolicyName != nil { + v := strings.TrimSpace(*c.UsagePolicyName) + if v == "" { + return errors.New("usage_policy_name must not be empty") + } + // 127 matches the server-side max_length on the policy name filter. + if len(v) > 127 { + return fmt.Errorf("usage_policy_name must be at most 127 characters, got %d", len(v)) + } + } + + return nil +} + +// validateExperimentName enforces the Databricks Jobs API task_key constraints: +// the experiment_name becomes a task key, which caps at 100 characters and allows +// only alphanumerics, hyphens, and underscores. +func validateExperimentName(v string) error { + if v == "" { + return errors.New("experiment_name cannot be empty") + } + if len(v) > 100 { + return fmt.Errorf("experiment_name must be 100 characters or less (got %d); this is the Jobs API task_key length limit", len(v)) + } + if !taskKeyRe.MatchString(v) { + return fmt.Errorf("invalid experiment_name %q: only alphanumeric characters, hyphens (-), and underscores (_) are allowed", v) + } + return nil +} + +// validateCommand enforces command is non-empty and within the line-count cap. +func validateCommand(v string) error { + if strings.TrimSpace(v) == "" { + return errors.New("command cannot be empty") + } + lineCount := strings.Count(v, "\n") + 1 + if lineCount > 1000 { + return fmt.Errorf("command is too long (%d lines); maximum is 1000 lines — move complex logic into a script in your code_source", lineCount) + } + return nil +} + +// validateSecretRefs checks that secret references use the "scope/key" format. +func validateSecretRefs(secrets map[string]string) error { + for varName, ref := range secrets { + parts := strings.Split(ref, "/") + if len(parts) != 2 { + return fmt.Errorf("invalid secret reference %q for variable %q: expected format 'scope/key' (e.g., my_scope/hf_token)", ref, varName) + } + if parts[0] == "" || parts[1] == "" { + return fmt.Errorf("invalid secret reference %q for variable %q: scope and key cannot be empty", ref, varName) + } + } + return nil +} + +// environmentConfig is the `environment` block: dependencies and/or a custom +// docker image. +type environmentConfig struct { + Dependencies dependencies `yaml:"dependencies"` + Version stringOrInt `yaml:"version"` + DockerImage *dockerImageConfig `yaml:"docker_image"` +} + +func (e *environmentConfig) validate() error { + // docker_image is exclusive with dependencies/version: the image already pins + // the full runtime. + if e.DockerImage != nil { + var conflicting []string + if e.Dependencies.set { + conflicting = append(conflicting, "dependencies") + } + if e.Version.set { + conflicting = append(conflicting, "version") + } + if len(conflicting) > 0 { + return fmt.Errorf("when 'docker_image' is specified under 'environment', these fields are not allowed: %s", strings.Join(conflicting, ", ")) + } + return e.DockerImage.validate() + } + + // version pins the client image version, which is only meaningful for an + // inline (list) dependency set — a requirements.yaml file carries its own. + if e.Version.set { + if e.Dependencies.set && !e.Dependencies.isList { + return errors.New("'environment.version' is only valid with inline dependencies (a list); when 'dependencies' points to a requirements.yaml file, set the version inside that file") + } + if !e.Dependencies.set { + return errors.New("'environment.version' requires inline 'dependencies' (a list of packages)") + } + } + + return nil +} + +// dependencies is environment.dependencies, which is polymorphic: a string is a +// path to a requirements.yaml file; a list is an inline package list. +type dependencies struct { + set bool + isList bool + path string + list []string +} + +func (d *dependencies) UnmarshalYAML(node *yaml.Node) error { + switch node.Kind { + case yaml.ScalarNode: + d.set, d.isList = true, false + return node.Decode(&d.path) + case yaml.SequenceNode: + d.set, d.isList = true, true + return node.Decode(&d.list) + default: + return errors.New("environment.dependencies must be a string path or a list of packages") + } +} + +// stringOrInt holds a scalar that may be a string or an integer in YAML +// (environment.version). The raw text is kept; integer-format validation is a +// launch-time concern. +type stringOrInt struct { + set bool + raw string +} + +func (s *stringOrInt) UnmarshalYAML(node *yaml.Node) error { + if node.Kind != yaml.ScalarNode { + return errors.New("environment.version must be a string or integer") + } + s.set = true + s.raw = node.Value + return nil +} + +// dockerImageConfig is environment.docker_image. +type dockerImageConfig struct { + URL string `yaml:"url"` +} + +func (d *dockerImageConfig) validate() error { + if strings.TrimSpace(d.URL) == "" { + return errors.New("docker_image.url cannot be empty") + } + return nil +} + +// codeSourceConfig is the `code_source` block. Only the "snapshot" type exists. +type codeSourceConfig struct { + Type string `yaml:"type"` + Snapshot *snapshotSourceConfig `yaml:"snapshot"` +} + +func (c *codeSourceConfig) validate() error { + if c.Type != "snapshot" { + return fmt.Errorf("code_source.type must be 'snapshot', got %q", c.Type) + } + if c.Snapshot == nil { + return errors.New("code_source.type='snapshot' requires a snapshot configuration") + } + return c.Snapshot.validate() +} + +// snapshotSourceConfig describes a local directory to tar and upload. +type snapshotSourceConfig struct { + RootPath string `yaml:"root_path"` + RemoteVolume *string `yaml:"remote_volume"` + Git *gitRef `yaml:"git"` + IncludePaths []string `yaml:"include_paths"` +} + +func (s *snapshotSourceConfig) validate() error { + if strings.TrimSpace(s.RootPath) == "" { + return errors.New("code_source.snapshot.root_path cannot be empty") + } + + if s.RemoteVolume != nil && !strings.HasPrefix(*s.RemoteVolume, "/Volumes/") { + return errors.New("code_source.snapshot.remote_volume must start with '/Volumes/'") + } + + // A non-nil but empty include_paths is an explicit mistake (omit it instead). + if s.IncludePaths != nil && len(s.IncludePaths) == 0 { + return errors.New("code_source.snapshot.include_paths cannot be an empty list; either omit it or provide paths") + } + for _, p := range s.IncludePaths { + p = strings.TrimSpace(p) + if p == "" { + return errors.New("code_source.snapshot.include_paths entry cannot be empty") + } + if strings.HasPrefix(p, "/") { + return fmt.Errorf("code_source.snapshot.include_paths must be relative paths, got: %s", p) + } + // No parent traversal: snapshots must stay within root_path. + if slices.Contains(strings.Split(p, "/"), "..") { + return fmt.Errorf("code_source.snapshot.include_paths cannot contain '..' traversal, got: %s", p) + } + } + + if s.Git != nil { + return s.Git.validate() + } + return nil +} + +// gitRef pins a snapshot to a specific git ref. branch and commit are mutually +// exclusive; remote is only meaningful with branch. +type gitRef struct { + Branch *string `yaml:"branch"` + Commit *string `yaml:"commit"` + Remote gitRemote `yaml:"remote"` +} + +func (g *gitRef) validate() error { + if g.Branch != nil && !gitRefRe.MatchString(*g.Branch) { + return fmt.Errorf("invalid git.branch format %q: only alphanumeric characters, hyphens, dots, slashes, and underscores are allowed", *g.Branch) + } + if g.Remote.isString { + if g.Remote.name == "" { + return errors.New("git.remote string cannot be empty; use 'true' to auto-detect") + } + if !gitRefRe.MatchString(g.Remote.name) { + return fmt.Errorf("invalid git.remote name %q: only alphanumeric characters, hyphens, dots, slashes, and underscores are allowed", g.Remote.name) + } + } + + if g.Branch == nil && g.Commit == nil { + return errors.New("git: must specify either 'branch' or 'commit'") + } + if g.Branch != nil && g.Commit != nil { + return errors.New("git: 'branch' and 'commit' are mutually exclusive — specify only one") + } + if g.Remote.truthy() && g.Branch == nil { + return errors.New("git.remote requires git.branch (only valid with branch refs)") + } + return nil +} + +// gitRemote is git.remote: false (default, use local HEAD), true (auto-detect the +// remote), or a remote name string. +type gitRemote struct { + set bool + isString bool + name string + enabled bool +} + +func (r *gitRemote) UnmarshalYAML(node *yaml.Node) error { + if node.Kind != yaml.ScalarNode { + return errors.New("git.remote must be a boolean or a remote name string") + } + r.set = true + if node.Tag == "!!bool" { + return node.Decode(&r.enabled) + } + r.isString = true + r.name = node.Value + return nil +} + +// truthy reports whether remote requests a remote fetch (mirrors Python's +// truthiness of the bool|str union). +func (r *gitRemote) truthy() bool { + if r.isString { + return r.name != "" + } + return r.enabled +} + +// permission is a DABs-compatible permission grant: exactly one principal plus a +// level. +type permission struct { + UserName *string `yaml:"user_name"` + GroupName *string `yaml:"group_name"` + ServicePrincipalName *string `yaml:"service_principal_name"` + // Level is a databricks PermissionLevel (e.g. CAN_VIEW, CAN_MANAGE). Enum + // membership is validated server-side; here we only require it to be set. + Level string `yaml:"level"` +} + +func (p *permission) validate() error { + principals := map[string]*string{ + "user_name": p.UserName, + "group_name": p.GroupName, + "service_principal_name": p.ServicePrincipalName, + } + var set []string + for name, val := range principals { + if val != nil { + set = append(set, name) + } + } + switch len(set) { + case 0: + return errors.New("permissions: one of 'user_name', 'group_name', or 'service_principal_name' must be specified") + case 1: + name := set[0] + if strings.TrimSpace(*principals[name]) == "" { + return fmt.Errorf("permissions: '%s' cannot be empty", name) + } + default: + return errors.New("permissions: only one of 'user_name', 'group_name', or 'service_principal_name' can be specified") + } + + if strings.TrimSpace(p.Level) == "" { + return errors.New("permissions: 'level' is required") + } + return nil +} diff --git a/experimental/air/cmd/runconfig_load.go b/experimental/air/cmd/runconfig_load.go new file mode 100644 index 0000000000..4cdbd28308 --- /dev/null +++ b/experimental/air/cmd/runconfig_load.go @@ -0,0 +1,40 @@ +package aircmd + +import ( + "errors" + "fmt" + "io" + "os" + + "go.yaml.in/yaml/v3" +) + +// loadRunConfig reads a run YAML config file, decodes it into the schema, and +// runs structural validation. Unknown keys are rejected (KnownFields), mirroring +// the Python schema's extra="forbid". +// +// The `_bases_` composition feature and CLI `--override` handling are not yet +// ported; a config using `_bases_` is currently rejected as an unknown field. +func loadRunConfig(path string) (*runConfig, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + + dec := yaml.NewDecoder(f) + dec.KnownFields(true) + + var cfg runConfig + if err := dec.Decode(&cfg); err != nil { + if errors.Is(err, io.EOF) { + return nil, fmt.Errorf("config %s is empty", path) + } + return nil, fmt.Errorf("invalid config %s: %w", path, err) + } + + if err := cfg.validate(); err != nil { + return nil, err + } + return &cfg, nil +} diff --git a/experimental/air/cmd/runconfig_test.go b/experimental/air/cmd/runconfig_test.go new file mode 100644 index 0000000000..06501e6ea6 --- /dev/null +++ b/experimental/air/cmd/runconfig_test.go @@ -0,0 +1,407 @@ +package aircmd + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// writeConfig writes content to a temp YAML file and returns its path. +func writeConfig(t *testing.T, content string) string { + t.Helper() + dir := t.TempDir() + path := filepath.Join(dir, "config.yaml") + require.NoError(t, os.WriteFile(path, []byte(content), 0o600)) + return path +} + +// minimalConfig is the smallest valid config: the three required pieces. +const minimalConfig = ` +experiment_name: my-run +command: python train.py +compute: + accelerator_type: GPU_1xH100 + num_accelerators: 1 +` + +func TestLoadRunConfig_Minimal(t *testing.T) { + cfg, err := loadRunConfig(writeConfig(t, minimalConfig)) + require.NoError(t, err) + assert.Equal(t, "my-run", cfg.ExperimentName) + require.NotNil(t, cfg.Command) + assert.Equal(t, "python train.py", *cfg.Command) + require.NotNil(t, cfg.Compute) + assert.Equal(t, "GPU_1xH100", cfg.Compute.AcceleratorType) + assert.Equal(t, 1, cfg.Compute.NumAccelerators) +} + +func TestLoadRunConfig_FullFeatured(t *testing.T) { + cfg, err := loadRunConfig(writeConfig(t, ` +experiment_name: full_run +command: | + python train.py + echo done +compute: + accelerator_type: GPU_8xH100 + num_accelerators: 16 +environment: + dependencies: + - torch==2.3.0 + - numpy + version: 5 +env_variables: + FOO: bar +secrets: + HF_TOKEN: my_scope/hf_token +code_source: + type: snapshot + snapshot: + root_path: project_root/src + remote_volume: /Volumes/main/default/code + git: + branch: main + remote: origin + include_paths: + - src + - configs/train.yaml +max_retries: 5 +timeout_minutes: 120 +idempotency_token: abc-123 +mlflow_run_name: full_run_v2 +mlflow_experiment_directory: /Workspace/Users/me/exp +usage_policy_name: my-policy +permissions: + - group_name: users + level: CAN_VIEW + - user_name: alice@example.com + level: CAN_MANAGE +`)) + require.NoError(t, err) + assert.Equal(t, gpuType8xH100, gpuType(cfg.Compute.AcceleratorType)) + require.NotNil(t, cfg.Environment) + assert.True(t, cfg.Environment.Dependencies.isList) + assert.Equal(t, []string{"torch==2.3.0", "numpy"}, cfg.Environment.Dependencies.list) + assert.True(t, cfg.Environment.Version.set) + assert.Equal(t, "5", cfg.Environment.Version.raw) + require.NotNil(t, cfg.CodeSource) + require.NotNil(t, cfg.CodeSource.Snapshot) + require.NotNil(t, cfg.CodeSource.Snapshot.Git) + require.NotNil(t, cfg.CodeSource.Snapshot.Git.Branch) + assert.Equal(t, "main", *cfg.CodeSource.Snapshot.Git.Branch) + assert.True(t, cfg.CodeSource.Snapshot.Git.Remote.isString) + assert.Equal(t, "origin", cfg.CodeSource.Snapshot.Git.Remote.name) + assert.Len(t, cfg.Permissions, 2) +} + +// TestLoadRunConfig_PolymorphicFields exercises the str|list, str|int, and +// bool|str unions decoded by custom UnmarshalYAML. +func TestLoadRunConfig_PolymorphicFields(t *testing.T) { + t.Run("dependencies as string path", func(t *testing.T) { + cfg, err := loadRunConfig(writeConfig(t, minimalConfig+` +environment: + dependencies: requirements.yaml +`)) + require.NoError(t, err) + assert.True(t, cfg.Environment.Dependencies.set) + assert.False(t, cfg.Environment.Dependencies.isList) + assert.Equal(t, "requirements.yaml", cfg.Environment.Dependencies.path) + }) + + t.Run("git remote as bool true", func(t *testing.T) { + cfg, err := loadRunConfig(writeConfig(t, minimalConfig+` +code_source: + type: snapshot + snapshot: + root_path: . + git: + branch: main + remote: true +`)) + require.NoError(t, err) + r := cfg.CodeSource.Snapshot.Git.Remote + assert.False(t, r.isString) + assert.True(t, r.enabled) + assert.True(t, r.truthy()) + }) + + t.Run("git remote defaults to false when unset", func(t *testing.T) { + cfg, err := loadRunConfig(writeConfig(t, minimalConfig+` +code_source: + type: snapshot + snapshot: + root_path: . + git: + commit: deadbeef +`)) + require.NoError(t, err) + assert.False(t, cfg.CodeSource.Snapshot.Git.Remote.truthy()) + }) +} + +func TestLoadRunConfig_UnknownFieldRejected(t *testing.T) { + tests := []struct { + name string + extra string + errFrag string + }{ + {"top-level typo", "extra_field: nope\n", "extra_field"}, + // priority was intentionally dropped from the schema (pool-only concept). + {"dropped priority field", "priority: 100\n", "priority"}, + // _bases_ composition is not yet ported, so it surfaces as unknown. + {"unported _bases_", "_bases_: [base.yaml]\n", "_bases_"}, + {"nested typo", "environment:\n bogus: 1\n", "bogus"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := loadRunConfig(writeConfig(t, minimalConfig+tt.extra)) + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errFrag) + }) + } +} + +func TestLoadRunConfig_Errors(t *testing.T) { + tests := []struct { + name string + yaml string + errFrag string + }{ + { + "missing experiment_name", + "command: x\ncompute:\n accelerator_type: GPU_1xH100\n num_accelerators: 1\n", + "experiment_name cannot be empty", + }, + { + "experiment_name bad chars", + "experiment_name: my.run\ncommand: x\ncompute:\n accelerator_type: GPU_1xH100\n num_accelerators: 1\n", + "invalid experiment_name", + }, + { + "missing compute", + "experiment_name: r\ncommand: x\n", + "compute: section is required", + }, + { + "missing command", + "experiment_name: r\ncompute:\n accelerator_type: GPU_1xH100\n num_accelerators: 1\n", + "command is required", + }, + { + "bad gpu type", + "experiment_name: r\ncommand: x\ncompute:\n accelerator_type: a100\n num_accelerators: 1\n", + "invalid GPU type", + }, + { + "num_accelerators not a multiple", + "experiment_name: r\ncommand: x\ncompute:\n accelerator_type: GPU_8xH100\n num_accelerators: 3\n", + "must be a multiple of 8", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := loadRunConfig(writeConfig(t, tt.yaml)) + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errFrag) + }) + } +} + +// TestRunConfigValidate_FieldRules unit-tests validation rules directly, away +// from YAML decoding, to keep each rule's failure mode explicit. +func TestRunConfigValidate_FieldRules(t *testing.T) { + str := func(s string) *string { return &s } + intp := func(i int) *int { return &i } + base := func() *runConfig { + return &runConfig{ + ExperimentName: "r", + Command: str("x"), + Compute: &computeConfig{AcceleratorType: "GPU_1xH100", NumAccelerators: 1}, + } + } + + tests := []struct { + name string + mutate func(c *runConfig) + errFrag string + }{ + {"ok baseline", func(c *runConfig) {}, ""}, + {"empty command", func(c *runConfig) { c.Command = str(" ") }, "command cannot be empty"}, + {"negative max_retries", func(c *runConfig) { c.MaxRetries = intp(-1) }, "max_retries must be >= 0"}, + {"zero timeout", func(c *runConfig) { c.TimeoutMinutes = intp(0) }, "timeout_minutes must be >= 1"}, + {"empty idempotency", func(c *runConfig) { c.IdempotencyToken = str(" ") }, "idempotency_token cannot be empty"}, + {"long idempotency", func(c *runConfig) { c.IdempotencyToken = str(string(make([]byte, 65))) }, "64 characters or less"}, + {"bad mlflow_run_name", func(c *runConfig) { c.MLflowRunName = str("bad name") }, "invalid mlflow_run_name"}, + {"bad experiment dir", func(c *runConfig) { c.MLflowExperimentDirectory = str("/Users/me") }, "must start with '/Workspace'"}, + {"empty usage policy", func(c *runConfig) { c.UsagePolicyName = str(" ") }, "usage_policy_name must not be empty"}, + {"bad secret ref", func(c *runConfig) { c.Secrets = map[string]string{"T": "noslash"} }, "expected format 'scope/key'"}, + {"empty secret scope", func(c *runConfig) { c.Secrets = map[string]string{"T": "/key"} }, "scope and key cannot be empty"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := base() + tt.mutate(c) + err := c.validate() + if tt.errFrag == "" { + assert.NoError(t, err) + return + } + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errFrag) + }) + } +} + +func TestEnvironmentConfigValidate(t *testing.T) { + tests := []struct { + name string + env environmentConfig + errFrag string + }{ + { + "docker image alone ok", + environmentConfig{DockerImage: &dockerImageConfig{URL: "org/repo:tag"}}, + "", + }, + { + "docker image with deps conflicts", + environmentConfig{ + DockerImage: &dockerImageConfig{URL: "org/repo:tag"}, + Dependencies: dependencies{set: true, isList: true, list: []string{"torch"}}, + }, + "not allowed: dependencies", + }, + { + "empty docker url", + environmentConfig{DockerImage: &dockerImageConfig{URL: " "}}, + "docker_image.url cannot be empty", + }, + { + "version with file deps", + environmentConfig{ + Version: stringOrInt{set: true, raw: "5"}, + Dependencies: dependencies{set: true, isList: false, path: "req.yaml"}, + }, + "only valid with inline dependencies", + }, + { + "version without deps", + environmentConfig{Version: stringOrInt{set: true, raw: "5"}}, + "requires inline 'dependencies'", + }, + { + "version with inline deps ok", + environmentConfig{ + Version: stringOrInt{set: true, raw: "5"}, + Dependencies: dependencies{set: true, isList: true, list: []string{"torch"}}, + }, + "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.env.validate() + if tt.errFrag == "" { + assert.NoError(t, err) + return + } + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errFrag) + }) + } +} + +func TestGitRefValidate(t *testing.T) { + str := func(s string) *string { return &s } + tests := []struct { + name string + ref gitRef + errFrag string + }{ + {"branch only ok", gitRef{Branch: str("main")}, ""}, + {"commit only ok", gitRef{Commit: str("abc123")}, ""}, + {"branch with remote ok", gitRef{Branch: str("main"), Remote: gitRemote{set: true, enabled: true}}, ""}, + {"neither branch nor commit", gitRef{}, "must specify either 'branch' or 'commit'"}, + {"both branch and commit", gitRef{Branch: str("main"), Commit: str("abc")}, "mutually exclusive"}, + {"remote without branch", gitRef{Commit: str("abc"), Remote: gitRemote{set: true, isString: true, name: "origin"}}, "requires git.branch"}, + {"bad branch chars", gitRef{Branch: str("bad branch")}, "invalid git.branch"}, + {"empty remote string", gitRef{Branch: str("main"), Remote: gitRemote{set: true, isString: true, name: ""}}, "cannot be empty"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.ref.validate() + if tt.errFrag == "" { + assert.NoError(t, err) + return + } + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errFrag) + }) + } +} + +func TestSnapshotSourceConfigValidate(t *testing.T) { + tests := []struct { + name string + snap snapshotSourceConfig + errFrag string + }{ + {"ok", snapshotSourceConfig{RootPath: "src"}, ""}, + {"empty root_path", snapshotSourceConfig{RootPath: " "}, "root_path cannot be empty"}, + {"bad volume", snapshotSourceConfig{RootPath: "src", RemoteVolume: new("/mnt/x")}, "must start with '/Volumes/'"}, + {"empty include list", snapshotSourceConfig{RootPath: "src", IncludePaths: []string{}}, "cannot be an empty list"}, + {"absolute include", snapshotSourceConfig{RootPath: "src", IncludePaths: []string{"/etc"}}, "must be relative"}, + {"traversal include", snapshotSourceConfig{RootPath: "src", IncludePaths: []string{"../x"}}, "'..' traversal"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.snap.validate() + if tt.errFrag == "" { + assert.NoError(t, err) + return + } + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errFrag) + }) + } +} + +func TestPermissionValidate(t *testing.T) { + str := func(s string) *string { return &s } + tests := []struct { + name string + perm permission + errFrag string + }{ + {"ok user", permission{UserName: str("alice@example.com"), Level: "CAN_VIEW"}, ""}, + {"no principal", permission{Level: "CAN_VIEW"}, "must be specified"}, + {"two principals", permission{UserName: str("a"), GroupName: str("g"), Level: "CAN_VIEW"}, "only one of"}, + {"empty principal", permission{UserName: str(" "), Level: "CAN_VIEW"}, "cannot be empty"}, + {"missing level", permission{GroupName: str("users")}, "'level' is required"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.perm.validate() + if tt.errFrag == "" { + assert.NoError(t, err) + return + } + require.Error(t, err) + assert.Contains(t, err.Error(), tt.errFrag) + }) + } +} + +func TestLoadRunConfig_FileErrors(t *testing.T) { + t.Run("missing file", func(t *testing.T) { + _, err := loadRunConfig(filepath.Join(t.TempDir(), "nope.yaml")) + assert.Error(t, err) + }) + t.Run("empty file", func(t *testing.T) { + _, err := loadRunConfig(writeConfig(t, "")) + require.Error(t, err) + assert.Contains(t, err.Error(), "is empty") + }) +} From f757f6b01bb898eaea8079a32daef83b59e0451b Mon Sep 17 00:00:00 2001 From: riddhibhagwat-db Date: Thu, 18 Jun 2026 20:21:45 +0000 Subject: [PATCH 2/6] experimental/air: add run config launch accessors Flatten the validated runConfig schema into the derived values the launch path consumes (timeout seconds, retry default, docker image URL, requirements file vs inline dependencies, runtime version), replacing the Python CLI's _convert_to_run_config step. handle_run reads runConfig directly, so these are accessors rather than a separate internal config type. Co-authored-by: Isaac --- experimental/air/cmd/runconfig_launch.go | 62 ++++++++++++++ experimental/air/cmd/runconfig_launch_test.go | 80 +++++++++++++++++++ 2 files changed, 142 insertions(+) create mode 100644 experimental/air/cmd/runconfig_launch.go create mode 100644 experimental/air/cmd/runconfig_launch_test.go diff --git a/experimental/air/cmd/runconfig_launch.go b/experimental/air/cmd/runconfig_launch.go new file mode 100644 index 0000000000..6864e4c454 --- /dev/null +++ b/experimental/air/cmd/runconfig_launch.go @@ -0,0 +1,62 @@ +package aircmd + +// This file flattens the validated runConfig schema into the derived values the +// launch path consumes, replacing the Python CLI's _convert_to_run_config step. +// There is no separate internal config type: handle_run reads runConfig directly, +// using these accessors for the values that need computing rather than a plain +// field read. + +const defaultMaxRetries = 3 + +// timeoutSeconds converts timeout_minutes to seconds. Zero means the user set no +// timeout and the backend default applies. +func (c *runConfig) timeoutSeconds() int { + if c.TimeoutMinutes == nil { + return 0 + } + return *c.TimeoutMinutes * 60 +} + +// maxRetries returns the retry count, applying the schema default when unset. +func (c *runConfig) maxRetries() int { + if c.MaxRetries == nil { + return defaultMaxRetries + } + return *c.MaxRetries +} + +// dockerImageURL returns the custom docker image URL, or "" when none is set. +func (c *runConfig) dockerImageURL() string { + if c.Environment != nil && c.Environment.DockerImage != nil { + return c.Environment.DockerImage.URL + } + return "" +} + +// requirementsFile returns the path to a requirements file when +// environment.dependencies is a string, and whether it was set. +func (c *runConfig) requirementsFile() (string, bool) { + if c.Environment == nil || !c.Environment.Dependencies.set || c.Environment.Dependencies.isList { + return "", false + } + return c.Environment.Dependencies.path, true +} + +// inlineDependencies returns the inline package list when +// environment.dependencies is a list, and whether it was set. +func (c *runConfig) inlineDependencies() ([]string, bool) { + if c.Environment == nil || !c.Environment.Dependencies.set || !c.Environment.Dependencies.isList { + return nil, false + } + return c.Environment.Dependencies.list, true +} + +// runtimeVersion returns the client image version from environment.version when +// set. For a requirements-file dependency set, the version lives in that file and +// is resolved at launch, not here. +func (c *runConfig) runtimeVersion() (string, bool) { + if c.Environment == nil || !c.Environment.Version.set { + return "", false + } + return c.Environment.Version.raw, true +} diff --git a/experimental/air/cmd/runconfig_launch_test.go b/experimental/air/cmd/runconfig_launch_test.go new file mode 100644 index 0000000000..289db91c7d --- /dev/null +++ b/experimental/air/cmd/runconfig_launch_test.go @@ -0,0 +1,80 @@ +package aircmd + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRunConfigTimeoutSeconds(t *testing.T) { + c := &runConfig{} + assert.Equal(t, 0, c.timeoutSeconds()) + + c.TimeoutMinutes = new(2) + assert.Equal(t, 120, c.timeoutSeconds()) +} + +func TestRunConfigMaxRetries(t *testing.T) { + c := &runConfig{} + assert.Equal(t, defaultMaxRetries, c.maxRetries()) + + c.MaxRetries = new(0) + assert.Equal(t, 0, c.maxRetries()) + + c.MaxRetries = new(7) + assert.Equal(t, 7, c.maxRetries()) +} + +func TestRunConfigDockerImageURL(t *testing.T) { + c := &runConfig{} + assert.Empty(t, c.dockerImageURL()) + + c.Environment = &environmentConfig{} + assert.Empty(t, c.dockerImageURL()) + + c.Environment.DockerImage = &dockerImageConfig{URL: "org/repo:tag"} + assert.Equal(t, "org/repo:tag", c.dockerImageURL()) +} + +func TestRunConfigDependencies(t *testing.T) { + t.Run("unset", func(t *testing.T) { + c := &runConfig{} + _, ok := c.requirementsFile() + assert.False(t, ok) + _, ok = c.inlineDependencies() + assert.False(t, ok) + }) + + t.Run("file path", func(t *testing.T) { + c := &runConfig{Environment: &environmentConfig{ + Dependencies: dependencies{set: true, isList: false, path: "req.yaml"}, + }} + path, ok := c.requirementsFile() + assert.True(t, ok) + assert.Equal(t, "req.yaml", path) + _, ok = c.inlineDependencies() + assert.False(t, ok) + }) + + t.Run("inline list", func(t *testing.T) { + c := &runConfig{Environment: &environmentConfig{ + Dependencies: dependencies{set: true, isList: true, list: []string{"torch", "numpy"}}, + }} + list, ok := c.inlineDependencies() + assert.True(t, ok) + assert.Equal(t, []string{"torch", "numpy"}, list) + _, ok = c.requirementsFile() + assert.False(t, ok) + }) +} + +func TestRunConfigRuntimeVersion(t *testing.T) { + c := &runConfig{} + _, ok := c.runtimeVersion() + assert.False(t, ok) + + c.Environment = &environmentConfig{Version: stringOrInt{set: true, raw: "5"}} + v, ok := c.runtimeVersion() + assert.True(t, ok) + assert.Equal(t, "5", v) +} From 2e7cf854c8bc9164a7aff4f25d906ec84b74026e Mon Sep 17 00:00:00 2001 From: riddhibhagwat-db Date: Thu, 18 Jun 2026 20:37:18 +0000 Subject: [PATCH 3/6] experimental/air: wire run command for load, validate, dry-run Wire `air run`'s RunE to load and structurally validate the YAML config, and implement --dry-run (validate without submitting). The non-dry-run submission path returns "not implemented" until the submit phase lands; --override is rejected with a clear error since the override pipeline is not ported yet. Drop `run` from the not-implemented stub test now that it does real work. Co-authored-by: Isaac --- acceptance/experimental/air/run/invalid.yaml | 5 ++ acceptance/experimental/air/run/out.test.toml | 3 ++ acceptance/experimental/air/run/output.txt | 39 ++++++++++++++ acceptance/experimental/air/run/script | 17 +++++++ acceptance/experimental/air/run/test.toml | 4 ++ acceptance/experimental/air/run/valid.yaml | 5 ++ .../experimental/air/unimplemented/output.txt | 6 --- .../experimental/air/unimplemented/script | 3 -- experimental/air/cmd/run.go | 51 +++++++++++++++++-- experimental/air/cmd/stubs_test.go | 1 - 10 files changed, 121 insertions(+), 13 deletions(-) create mode 100644 acceptance/experimental/air/run/invalid.yaml create mode 100644 acceptance/experimental/air/run/out.test.toml create mode 100644 acceptance/experimental/air/run/output.txt create mode 100644 acceptance/experimental/air/run/script create mode 100644 acceptance/experimental/air/run/test.toml create mode 100644 acceptance/experimental/air/run/valid.yaml diff --git a/acceptance/experimental/air/run/invalid.yaml b/acceptance/experimental/air/run/invalid.yaml new file mode 100644 index 0000000000..c011fc81b3 --- /dev/null +++ b/acceptance/experimental/air/run/invalid.yaml @@ -0,0 +1,5 @@ +experiment_name: bad.name +command: x +compute: + accelerator_type: GPU_8xH100 + num_accelerators: 3 diff --git a/acceptance/experimental/air/run/out.test.toml b/acceptance/experimental/air/run/out.test.toml new file mode 100644 index 0000000000..d6187dcb04 --- /dev/null +++ b/acceptance/experimental/air/run/out.test.toml @@ -0,0 +1,3 @@ +Local = true +Cloud = false +EnvMatrix.DATABRICKS_BUNDLE_ENGINE = [] diff --git a/acceptance/experimental/air/run/output.txt b/acceptance/experimental/air/run/output.txt new file mode 100644 index 0000000000..180886290b --- /dev/null +++ b/acceptance/experimental/air/run/output.txt @@ -0,0 +1,39 @@ + +=== dry-run (text) +>>> [CLI] experimental air run -f valid.yaml --dry-run +Dry run: configuration for "smoke-test" is valid; not submitting. + +=== dry-run (json) +>>> [CLI] experimental air run -f valid.yaml --dry-run -o json +{ + "v": 1, + "ts": "[TIMESTAMP]", + "data": { + "status": "DRY_RUN_OK", + "dry_run": true + } +} + +=== override not yet supported +>>> [CLI] experimental air run -f valid.yaml --dry-run --override a=b +Error: --override is not yet supported + +Exit code: 1 + +=== watch not yet supported +>>> [CLI] experimental air run -f valid.yaml --dry-run --watch +Error: --watch is not yet supported + +Exit code: 1 + +=== invalid config is rejected +>>> [CLI] experimental air run -f invalid.yaml --dry-run +Error: invalid experiment_name "bad.name": only alphanumeric characters, hyphens (-), and underscores (_) are allowed + +Exit code: 1 + +=== missing --file +>>> [CLI] experimental air run --dry-run +Error: required flag(s) "file" not set + +Exit code: 1 diff --git a/acceptance/experimental/air/run/script b/acceptance/experimental/air/run/script new file mode 100644 index 0000000000..806bd321e6 --- /dev/null +++ b/acceptance/experimental/air/run/script @@ -0,0 +1,17 @@ +title "dry-run (text)" +trace $CLI experimental air run -f valid.yaml --dry-run + +title "dry-run (json)" +trace $CLI experimental air run -f valid.yaml --dry-run -o json + +title "override not yet supported" +errcode trace $CLI experimental air run -f valid.yaml --dry-run --override a=b + +title "watch not yet supported" +errcode trace $CLI experimental air run -f valid.yaml --dry-run --watch + +title "invalid config is rejected" +errcode trace $CLI experimental air run -f invalid.yaml --dry-run + +title "missing --file" +errcode trace $CLI experimental air run --dry-run diff --git a/acceptance/experimental/air/run/test.toml b/acceptance/experimental/air/run/test.toml new file mode 100644 index 0000000000..2f971c3ed2 --- /dev/null +++ b/acceptance/experimental/air/run/test.toml @@ -0,0 +1,4 @@ +# `air run --dry-run` validates the config locally and makes no workspace calls, +# so no engine matrix or server stubs are needed. +[EnvMatrix] +DATABRICKS_BUNDLE_ENGINE = [] diff --git a/acceptance/experimental/air/run/valid.yaml b/acceptance/experimental/air/run/valid.yaml new file mode 100644 index 0000000000..b82a321b05 --- /dev/null +++ b/acceptance/experimental/air/run/valid.yaml @@ -0,0 +1,5 @@ +experiment_name: smoke-test +command: python train.py +compute: + accelerator_type: GPU_1xH100 + num_accelerators: 1 diff --git a/acceptance/experimental/air/unimplemented/output.txt b/acceptance/experimental/air/unimplemented/output.txt index 66ddc34d58..21c3c891af 100644 --- a/acceptance/experimental/air/unimplemented/output.txt +++ b/acceptance/experimental/air/unimplemented/output.txt @@ -1,10 +1,4 @@ -=== run ->>> [CLI] experimental air run -Error: `air run` is not implemented yet - -Exit code: 1 - === logs >>> [CLI] experimental air logs 123 Error: `air logs` is not implemented yet diff --git a/acceptance/experimental/air/unimplemented/script b/acceptance/experimental/air/unimplemented/script index d00d045368..4c53586b16 100644 --- a/acceptance/experimental/air/unimplemented/script +++ b/acceptance/experimental/air/unimplemented/script @@ -1,8 +1,5 @@ # Each stub must fail with "not implemented"; errcode records the exit code. -title "run" -errcode trace $CLI experimental air run - title "logs" errcode trace $CLI experimental air logs 123 diff --git a/experimental/air/cmd/run.go b/experimental/air/cmd/run.go index 0bc3d1fd94..95bf360b83 100644 --- a/experimental/air/cmd/run.go +++ b/experimental/air/cmd/run.go @@ -1,10 +1,23 @@ package aircmd import ( + "errors" + "fmt" + "github.com/databricks/cli/cmd/root" + "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/flags" "github.com/spf13/cobra" ) +// runResult is the JSON payload for `air run`. +type runResult struct { + Status string `json:"status"` + DryRun bool `json:"dry_run,omitempty"` + RunID string `json:"run_id,omitempty"` + DashboardURL string `json:"dashboard_url,omitempty"` +} + func newRunCommand() *cobra.Command { var ( file string @@ -21,9 +34,6 @@ func newRunCommand() *cobra.Command { Long: `Submit a training workload to Databricks serverless GPU compute. The workload is described by a YAML config file (see --file).`, - RunE: func(cmd *cobra.Command, args []string) error { - return notImplemented("run") - }, } cmd.Flags().StringVarP(&file, "file", "f", "", "Path to the workload YAML config") @@ -31,6 +41,41 @@ The workload is described by a YAML config file (see --file).`, cmd.Flags().StringArrayVar(&overrides, "override", nil, "Override a YAML field, e.g. compute.num_accelerators=8 (repeatable)") cmd.Flags().BoolVar(&dryRun, "dry-run", false, "Validate the config without submitting") cmd.Flags().StringVar(&idempotencyKey, "idempotency-key", "", "Return the existing run if this key was already used") + _ = cmd.MarkFlagRequired("file") + + // --dry-run only validates the config locally, so it needs no workspace. + // Submission requires an authenticated client. + cmd.PreRunE = func(cmd *cobra.Command, args []string) error { + if dryRun { + return nil + } + return root.MustWorkspaceClient(cmd, args) + } + + cmd.RunE = func(cmd *cobra.Command, args []string) error { + ctx := cmd.Context() + + // --override is parsed and applied before validation; that pipeline is + // not ported yet, so reject it rather than silently ignore the flag. + if len(overrides) > 0 { + return errors.New("--override is not yet supported") + } + + cfg, err := loadRunConfig(file) + if err != nil { + return err + } + + if dryRun { + if root.OutputType(cmd) == flags.OutputText { + cmdio.LogString(ctx, fmt.Sprintf("Dry run: configuration for %q is valid; not submitting.", cfg.ExperimentName)) + return nil + } + return renderEnvelope(ctx, runResult{Status: "DRY_RUN_OK", DryRun: true}) + } + + return notImplemented("run submission") + } return cmd } diff --git a/experimental/air/cmd/stubs_test.go b/experimental/air/cmd/stubs_test.go index b9f5c330f0..4607d7d9ea 100644 --- a/experimental/air/cmd/stubs_test.go +++ b/experimental/air/cmd/stubs_test.go @@ -13,7 +13,6 @@ import ( // fails with a "not implemented" error. Drop a command here once it lands. func TestStubCommandsReturnNotImplemented(t *testing.T) { stubs := map[string]*cobra.Command{ - "run": newRunCommand(), "logs": newLogsCommand(), "cancel": newCancelCommand(), "register-image": newRegisterImageCommand(), From f46d5e0dfd5382d632eb050ce5e5b07513caabb1 Mon Sep 17 00:00:00 2001 From: riddhibhagwat-db Date: Mon, 22 Jun 2026 23:15:39 +0000 Subject: [PATCH 4/6] experimental/air: add run pre-submit resolution helpers Resolve the workspace context air run needs before uploading and submitting: the current user, the per-user workspace home (with env override), a unique cli_launch directory for a run's artifacts, the MLflow experiment path, and ensuring a custom experiment_directory exists (created if missing, matching the CLI's convention for its other artifact directories). Co-authored-by: Isaac --- experimental/air/cmd/runlaunch.go | 82 ++++++++++++++++++++++++++ experimental/air/cmd/runlaunch_test.go | 71 ++++++++++++++++++++++ 2 files changed, 153 insertions(+) create mode 100644 experimental/air/cmd/runlaunch.go create mode 100644 experimental/air/cmd/runlaunch_test.go diff --git a/experimental/air/cmd/runlaunch.go b/experimental/air/cmd/runlaunch.go new file mode 100644 index 0000000000..163655d9dd --- /dev/null +++ b/experimental/air/cmd/runlaunch.go @@ -0,0 +1,82 @@ +package aircmd + +import ( + "context" + "errors" + "fmt" + "path" + "strings" + + "github.com/databricks/cli/libs/env" + "github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/apierr" + "github.com/databricks/databricks-sdk-go/service/iam" + "github.com/databricks/databricks-sdk-go/service/workspace" + "github.com/google/uuid" +) + +// userWorkspaceDirEnv overrides the per-user workspace directory; mirrors the +// Python CLI's DATABRICKS_INTERNAL_USER_WORKSPACE_DIR escape hatch. +const userWorkspaceDirEnv = "DATABRICKS_INTERNAL_USER_WORKSPACE_DIR" + +// currentUserEmail returns the authenticated user's email (works for any domain). +func currentUserEmail(ctx context.Context, w *databricks.WorkspaceClient) (string, error) { + me, err := w.CurrentUser.Me(ctx, iam.MeRequest{}) + if err != nil { + return "", fmt.Errorf("failed to resolve current user: %w", err) + } + return me.UserName, nil +} + +// userWorkspaceDir returns the user's workspace home, honoring the env override. +func userWorkspaceDir(ctx context.Context, w *databricks.WorkspaceClient) (string, error) { + if override := env.Get(ctx, userWorkspaceDirEnv); override != "" { + return override, nil + } + email, err := currentUserEmail(ctx, w) + if err != nil { + return "", err + } + return "/Workspace/Users/" + email, nil +} + +// cliLaunchDir returns a unique workspace directory for a run's launch artifacts: +// /.air/cli_launch//_. run defaults to experiment. +func cliLaunchDir(base, experiment, run string) string { + if run == "" { + run = experiment + } + unique := strings.ReplaceAll(uuid.NewString(), "-", "")[:16] + return path.Join(base, ".air", "cli_launch", experiment, run+"_"+unique) +} + +// mlflowExperimentName builds the full MLflow experiment path. A custom directory +// is used as-is; otherwise it defaults under the user's home. +func mlflowExperimentName(experiment, experimentDir, userEmail string) string { + if experimentDir != "" { + return strings.TrimRight(experimentDir, "/") + "/" + experiment + } + return "/Users/" + userEmail + "/" + experiment +} + +// ensureExperimentDirectory creates experimentDir if it is missing, matching the +// CLI's convention for its other artifact directories. Without this, a missing +// parent surfaces only as a server-side INTERNAL_ERROR after the run is wasted. +// An empty dir means the default (/Users//...), which always exists. +func ensureExperimentDirectory(ctx context.Context, w *databricks.WorkspaceClient, experimentDir string) error { + if experimentDir == "" { + return nil + } + + info, err := w.Workspace.GetStatusByPath(ctx, experimentDir) + if errors.Is(err, apierr.ErrNotFound) { + return w.Workspace.MkdirsByPath(ctx, experimentDir) + } + if err != nil { + return fmt.Errorf("failed to check experiment_directory %q: %w", experimentDir, err) + } + if info.ObjectType != workspace.ObjectTypeDirectory { + return fmt.Errorf("experiment_directory %q is not a directory (object_type=%s)", experimentDir, info.ObjectType) + } + return nil +} diff --git a/experimental/air/cmd/runlaunch_test.go b/experimental/air/cmd/runlaunch_test.go new file mode 100644 index 0000000000..df8c3a087a --- /dev/null +++ b/experimental/air/cmd/runlaunch_test.go @@ -0,0 +1,71 @@ +package aircmd + +import ( + "strings" + "testing" + + "github.com/databricks/cli/libs/filer" + "github.com/databricks/cli/libs/testserver" + "github.com/databricks/databricks-sdk-go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestMlflowExperimentName(t *testing.T) { + assert.Equal(t, "/Users/me@example.com/exp", mlflowExperimentName("exp", "", "me@example.com")) + assert.Equal(t, "/Workspace/shared/exp", mlflowExperimentName("exp", "/Workspace/shared", "me@example.com")) + assert.Equal(t, "/Workspace/shared/exp", mlflowExperimentName("exp", "/Workspace/shared/", "me@example.com")) +} + +func TestCliLaunchDir(t *testing.T) { + dir := cliLaunchDir("/Workspace/Users/me@example.com", "my-exp", "") + assert.True(t, strings.HasPrefix(dir, "/Workspace/Users/me@example.com/.air/cli_launch/my-exp/my-exp_"), dir) + // run name overrides the leaf; the unique suffix keeps successive dirs distinct. + withRun := cliLaunchDir("/base", "exp", "run1") + assert.True(t, strings.HasPrefix(withRun, "/base/.air/cli_launch/exp/run1_"), withRun) + assert.NotEqual(t, dir, cliLaunchDir("/Workspace/Users/me@example.com", "my-exp", "")) +} + +func newFakeWorkspaceClient(t *testing.T) *databricks.WorkspaceClient { + server := testserver.New(t) + t.Cleanup(server.Close) + testserver.AddDefaultHandlers(server) + w, err := databricks.NewWorkspaceClient(&databricks.Config{Host: server.URL, Token: "token"}) + require.NoError(t, err) + return w +} + +func TestUserWorkspaceDir(t *testing.T) { + w := newFakeWorkspaceClient(t) + dir, err := userWorkspaceDir(t.Context(), w) + require.NoError(t, err) + assert.True(t, strings.HasPrefix(dir, "/Workspace/Users/"), dir) + + // The env override wins without an API call. + t.Setenv(userWorkspaceDirEnv, "/Workspace/custom") + dir, err = userWorkspaceDir(t.Context(), w) + require.NoError(t, err) + assert.Equal(t, "/Workspace/custom", dir) +} + +func TestEnsureExperimentDirectory(t *testing.T) { + ctx := t.Context() + w := newFakeWorkspaceClient(t) + + // Empty means default (always exists) — no API call, no error. + require.NoError(t, ensureExperimentDirectory(ctx, w, "")) + + // A missing path is created. + require.NoError(t, ensureExperimentDirectory(ctx, w, "/Workspace/Users/me/exp")) + + // An existing directory is accepted as-is. + require.NoError(t, w.Workspace.MkdirsByPath(ctx, "/Workspace/Users/me/existing")) + require.NoError(t, ensureExperimentDirectory(ctx, w, "/Workspace/Users/me/existing")) + + // A path that exists but is a file is rejected. + fc, err := filer.NewWorkspaceFilesClient(w, "/Workspace/Users/me") + require.NoError(t, err) + require.NoError(t, fc.Write(ctx, "afile", strings.NewReader("x"))) + err = ensureExperimentDirectory(ctx, w, "/Workspace/Users/me/afile") + require.ErrorContains(t, err, "is not a directory") +} From 185f533a05bdfca7dafc6dfd1561a65209d30c46 Mon Sep 17 00:00:00 2001 From: riddhibhagwat-db Date: Mon, 22 Jun 2026 23:18:02 +0000 Subject: [PATCH 5/6] experimental/air: upload run launch artifacts Assemble and upload the launch artifacts for a run into its cli_launch directory: the merged config (training_config.yaml, 1 MB cap), the inline command as command.sh, requirements.yaml (from a file or synthesized from inline dependencies), and hyperparameters.yaml. buildArtifacts is pure; the upload writes through a narrow fileWriter (a workspace filer in production). A TODO(DABs) marks the client-side upload path as a future candidate for reuse of DABs' file-staging (libs/sync / bundle deploy). Co-authored-by: Isaac --- experimental/air/cmd/runupload.go | 114 +++++++++++++++++++++ experimental/air/cmd/runupload_test.go | 135 +++++++++++++++++++++++++ 2 files changed, 249 insertions(+) create mode 100644 experimental/air/cmd/runupload.go create mode 100644 experimental/air/cmd/runupload_test.go diff --git a/experimental/air/cmd/runupload.go b/experimental/air/cmd/runupload.go new file mode 100644 index 0000000000..c9c4047381 --- /dev/null +++ b/experimental/air/cmd/runupload.go @@ -0,0 +1,114 @@ +package aircmd + +import ( + "bytes" + "context" + "fmt" + "io" + "os" + "path/filepath" + + "github.com/databricks/cli/libs/filer" + "go.yaml.in/yaml/v3" +) + +// Launch artifact basenames, uploaded into the run's cli_launch directory. The +// server-side launcher derives requirements.yaml / hyperparameters.yaml from the +// same directory, so these names are part of the contract. +const ( + trainingConfigName = "training_config.yaml" + commandScriptName = "command.sh" + requirementsName = "requirements.yaml" + hyperparametersName = "hyperparameters.yaml" +) + +// maxConfigYAMLBytes caps training_config.yaml. It is referenced by the Jobs +// payload and rendered on the run page, so an oversized parameters/command block +// is rejected here; full parameters still ship in hyperparameters.yaml. +const maxConfigYAMLBytes = 1024 * 1024 + +// uploadItem is a single artifact to write into the launch directory. +type uploadItem struct { + name string + data []byte +} + +// fileWriter is the subset of filer.Filer the upload path needs; a narrow +// interface keeps buildArtifacts/upload testable without a live workspace. +type fileWriter interface { + Write(ctx context.Context, name string, reader io.Reader, mode ...filer.WriteMode) error +} + +// requirementsDoc mirrors the on-disk requirements.yaml format so the worker +// parses synthesized inline dependencies identically to a user-provided file. +type requirementsDoc struct { + Version string `yaml:"version,omitempty"` + Dependencies []string `yaml:"dependencies"` +} + +// buildArtifacts assembles the files to upload for a run: the merged config, the +// inline command as a script, requirements (from a file or synthesized from +// inline dependencies), and hyperparameters. configPath is the local YAML path. +func buildArtifacts(cfg *runConfig, configPath string) ([]uploadItem, error) { + // TODO(DABs): with no _bases_/overrides ported yet, the merged config is the + // file as-is; once those land, upload the re-serialized merged YAML instead. + configData, err := os.ReadFile(configPath) + if err != nil { + return nil, fmt.Errorf("failed to read config %s: %w", configPath, err) + } + if len(configData) > maxConfigYAMLBytes { + return nil, fmt.Errorf("config YAML is %.2f MB, over the %d MB limit; reduce 'parameters' or 'command'", + float64(len(configData))/(1024*1024), maxConfigYAMLBytes/(1024*1024)) + } + + items := []uploadItem{ + {trainingConfigName, configData}, + {commandScriptName, []byte(*cfg.Command)}, + } + + switch reqPath, ok := cfg.requirementsFile(); { + case ok: + // Resolve a relative requirements path against the config's directory. + if !filepath.IsAbs(reqPath) { + reqPath = filepath.Join(filepath.Dir(configPath), reqPath) + } + data, err := os.ReadFile(reqPath) + if err != nil { + return nil, fmt.Errorf("failed to read requirements file %s: %w", reqPath, err) + } + items = append(items, uploadItem{requirementsName, data}) + default: + if deps, ok := cfg.inlineDependencies(); ok { + version, _ := cfg.runtimeVersion() + data, err := yaml.Marshal(requirementsDoc{Version: version, Dependencies: deps}) + if err != nil { + return nil, fmt.Errorf("failed to synthesize requirements.yaml: %w", err) + } + items = append(items, uploadItem{requirementsName, data}) + } + } + + if len(cfg.Parameters) > 0 { + data, err := yaml.Marshal(cfg.Parameters) + if err != nil { + return nil, fmt.Errorf("failed to serialize parameters: %w", err) + } + items = append(items, uploadItem{hyperparametersName, data}) + } + + return items, nil +} + +// uploadArtifacts writes each artifact into the launch directory, overwriting and +// creating parents as needed. +// +// TODO(DABs): this client-side upload could move onto libs/sync / a bundle deploy +// so the CLI reuses DABs' file-staging machinery instead of writing files itself. +func uploadArtifacts(ctx context.Context, w fileWriter, items []uploadItem) error { + for _, it := range items { + if err := w.Write(ctx, it.name, bytes.NewReader(it.data), filer.OverwriteIfExists, filer.CreateParentDirectories); err != nil { + return fmt.Errorf("failed to upload %s: %w", it.name, err) + } + } + return nil +} diff --git a/experimental/air/cmd/runupload_test.go b/experimental/air/cmd/runupload_test.go new file mode 100644 index 0000000000..ec700a9229 --- /dev/null +++ b/experimental/air/cmd/runupload_test.go @@ -0,0 +1,135 @@ +package aircmd + +import ( + "context" + "errors" + "io" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/databricks/cli/libs/filer" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// fakeWriter records artifact writes in place of a workspace filer. +type fakeWriter struct { + written map[string]string +} + +func (f *fakeWriter) Write(ctx context.Context, name string, reader io.Reader, mode ...filer.WriteMode) error { + if f.written == nil { + f.written = map[string]string{} + } + data, err := io.ReadAll(reader) + if err != nil { + return err + } + f.written[name] = string(data) + return nil +} + +func writeConfigFile(t *testing.T, name, content string) string { + t.Helper() + path := filepath.Join(t.TempDir(), name) + require.NoError(t, os.WriteFile(path, []byte(content), 0o600)) + return path +} + +func itemNames(items []uploadItem) []string { + names := make([]string, len(items)) + for i, it := range items { + names[i] = it.name + } + return names +} + +func TestBuildArtifacts_CommandAndConfig(t *testing.T) { + path := writeConfigFile(t, "run.yaml", minimalConfig) + cfg := &runConfig{Command: new("python train.py")} + + items, err := buildArtifacts(cfg, path) + require.NoError(t, err) + assert.Equal(t, []string{trainingConfigName, commandScriptName}, itemNames(items)) + assert.Equal(t, minimalConfig, string(items[0].data)) + assert.Equal(t, "python train.py", string(items[1].data)) +} + +func TestBuildArtifacts_InlineRequirementsAndParameters(t *testing.T) { + path := writeConfigFile(t, "run.yaml", "x: y\n") + cfg := &runConfig{ + Command: new("echo hi"), + Environment: &environmentConfig{ + Dependencies: dependencies{set: true, isList: true, list: []string{"torch", "numpy"}}, + Version: stringOrInt{set: true, raw: "5"}, + }, + Parameters: map[string]any{"lr": 0.1}, + } + + items, err := buildArtifacts(cfg, path) + require.NoError(t, err) + assert.Equal(t, []string{trainingConfigName, commandScriptName, requirementsName, hyperparametersName}, itemNames(items)) + + var reqIdx int + for i, it := range items { + if it.name == requirementsName { + reqIdx = i + } + } + req := string(items[reqIdx].data) + assert.Contains(t, req, "version: \"5\"") + assert.Contains(t, req, "- torch") +} + +func TestBuildArtifacts_RequirementsFile(t *testing.T) { + dir := t.TempDir() + require.NoError(t, os.WriteFile(filepath.Join(dir, "run.yaml"), []byte("x: y\n"), 0o600)) + require.NoError(t, os.WriteFile(filepath.Join(dir, "reqs.yaml"), []byte("version: 4\n"), 0o600)) + cfg := &runConfig{ + Command: new("echo hi"), + Environment: &environmentConfig{Dependencies: dependencies{set: true, isList: false, path: "reqs.yaml"}}, + } + + items, err := buildArtifacts(cfg, filepath.Join(dir, "run.yaml")) + require.NoError(t, err) + assert.Contains(t, itemNames(items), requirementsName) +} + +func TestBuildArtifacts_OversizeConfigRejected(t *testing.T) { + path := writeConfigFile(t, "run.yaml", strings.Repeat("a", maxConfigYAMLBytes+1)) + _, err := buildArtifacts(&runConfig{Command: new("x")}, path) + require.Error(t, err) + assert.Contains(t, err.Error(), "over the 1 MB limit") +} + +func TestUploadArtifacts(t *testing.T) { + w := &fakeWriter{} + items := []uploadItem{{trainingConfigName, []byte("cfg")}, {commandScriptName, []byte("cmd")}} + require.NoError(t, uploadArtifacts(t.Context(), w, items)) + assert.Equal(t, "cfg", w.written[trainingConfigName]) + assert.Equal(t, "cmd", w.written[commandScriptName]) +} + +// errWriter fails every Write, exercising the upload error path. +type errWriter struct{} + +func (errWriter) Write(ctx context.Context, name string, reader io.Reader, mode ...filer.WriteMode) error { + return errors.New("boom") +} + +func TestUploadArtifacts_WriteError(t *testing.T) { + err := uploadArtifacts(t.Context(), errWriter{}, []uploadItem{{trainingConfigName, []byte("x")}}) + require.ErrorContains(t, err, "failed to upload "+trainingConfigName) +} + +func TestBuildArtifacts_MissingRequirementsFile(t *testing.T) { + cfgPath := writeConfigFile(t, "run.yaml", "x: y\n") + cfg := &runConfig{ + Command: new("echo hi"), + Environment: &environmentConfig{Dependencies: dependencies{set: true, isList: false, path: "nope.yaml"}}, + } + _, err := buildArtifacts(cfg, cfgPath) + require.ErrorContains(t, err, "failed to read requirements file") +} From a5d851b7080fba99ef9374317e06be2add30bd29 Mon Sep 17 00:00:00 2001 From: riddhibhagwat-db Date: Mon, 22 Jun 2026 23:30:39 +0000 Subject: [PATCH 6/6] experimental/air: assemble and submit a training run MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Wire `air run` end to end: ensure the experiment directory, upload launch artifacts, build the native ai_runtime_task payload, and submit it via a direct POST to /api/2.2/jobs/runs/submit. The ai_runtime_task routes straight to the training service with no genai-mapi forwarding — the MAPI path is deprecated. The proto is lean: env vars and secrets are staged as co-located env_vars.json / secret_env_vars.json workspace files rather than inline, and requirements / hyperparameters are derived server-side from the command directory. The non-dry-run path resolves the workspace context, uploads, submits, and prints the run id + dashboard URL. usage_policy_name, code_source snapshots, and --watch are rejected with clear errors until their phases land. environment.docker_image is accepted by the schema as scaffolding but not conveyed (the native path has no docker field). Co-authored-by: Isaac --- experimental/air/cmd/run.go | 23 ++- experimental/air/cmd/runlaunch.go | 9 - experimental/air/cmd/runlaunch_test.go | 6 - experimental/air/cmd/runsubmit.go | 244 +++++++++++++++++++++++++ experimental/air/cmd/runsubmit_test.go | 145 +++++++++++++++ experimental/air/cmd/runupload.go | 56 ++++++ experimental/air/cmd/runupload_test.go | 20 ++ 7 files changed, 485 insertions(+), 18 deletions(-) create mode 100644 experimental/air/cmd/runsubmit.go create mode 100644 experimental/air/cmd/runsubmit_test.go diff --git a/experimental/air/cmd/run.go b/experimental/air/cmd/run.go index 95bf360b83..bd32810e9b 100644 --- a/experimental/air/cmd/run.go +++ b/experimental/air/cmd/run.go @@ -3,8 +3,10 @@ package aircmd import ( "errors" "fmt" + "strconv" "github.com/databricks/cli/cmd/root" + "github.com/databricks/cli/libs/cmdctx" "github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/flags" "github.com/spf13/cobra" @@ -55,11 +57,14 @@ The workload is described by a YAML config file (see --file).`, cmd.RunE = func(cmd *cobra.Command, args []string) error { ctx := cmd.Context() - // --override is parsed and applied before validation; that pipeline is - // not ported yet, so reject it rather than silently ignore the flag. + // These flags' pipelines are not ported yet; reject rather than silently + // ignore them. if len(overrides) > 0 { return errors.New("--override is not yet supported") } + if watch { + return errors.New("--watch is not yet supported") + } cfg, err := loadRunConfig(file) if err != nil { @@ -74,7 +79,19 @@ The workload is described by a YAML config file (see --file).`, return renderEnvelope(ctx, runResult{Status: "DRY_RUN_OK", DryRun: true}) } - return notImplemented("run submission") + w := cmdctx.WorkspaceClient(ctx) + runID, dashboardURL, err := submitWorkload(ctx, w, cfg, file, idempotencyKey) + if err != nil { + return err + } + + runIDStr := strconv.FormatInt(runID, 10) + if root.OutputType(cmd) == flags.OutputText { + cmdio.LogString(ctx, "Submitted run "+runIDStr) + cmdio.LogString(ctx, "View at: "+dashboardURL) + return nil + } + return renderEnvelope(ctx, runResult{Status: "SUBMITTED", RunID: runIDStr, DashboardURL: dashboardURL}) } return cmd diff --git a/experimental/air/cmd/runlaunch.go b/experimental/air/cmd/runlaunch.go index 163655d9dd..b2a7215e66 100644 --- a/experimental/air/cmd/runlaunch.go +++ b/experimental/air/cmd/runlaunch.go @@ -50,15 +50,6 @@ func cliLaunchDir(base, experiment, run string) string { return path.Join(base, ".air", "cli_launch", experiment, run+"_"+unique) } -// mlflowExperimentName builds the full MLflow experiment path. A custom directory -// is used as-is; otherwise it defaults under the user's home. -func mlflowExperimentName(experiment, experimentDir, userEmail string) string { - if experimentDir != "" { - return strings.TrimRight(experimentDir, "/") + "/" + experiment - } - return "/Users/" + userEmail + "/" + experiment -} - // ensureExperimentDirectory creates experimentDir if it is missing, matching the // CLI's convention for its other artifact directories. Without this, a missing // parent surfaces only as a server-side INTERNAL_ERROR after the run is wasted. diff --git a/experimental/air/cmd/runlaunch_test.go b/experimental/air/cmd/runlaunch_test.go index df8c3a087a..af6f0f70d3 100644 --- a/experimental/air/cmd/runlaunch_test.go +++ b/experimental/air/cmd/runlaunch_test.go @@ -11,12 +11,6 @@ import ( "github.com/stretchr/testify/require" ) -func TestMlflowExperimentName(t *testing.T) { - assert.Equal(t, "/Users/me@example.com/exp", mlflowExperimentName("exp", "", "me@example.com")) - assert.Equal(t, "/Workspace/shared/exp", mlflowExperimentName("exp", "/Workspace/shared", "me@example.com")) - assert.Equal(t, "/Workspace/shared/exp", mlflowExperimentName("exp", "/Workspace/shared/", "me@example.com")) -} - func TestCliLaunchDir(t *testing.T) { dir := cliLaunchDir("/Workspace/Users/me@example.com", "my-exp", "") assert.True(t, strings.HasPrefix(dir, "/Workspace/Users/me@example.com/.air/cli_launch/my-exp/my-exp_"), dir) diff --git a/experimental/air/cmd/runsubmit.go b/experimental/air/cmd/runsubmit.go new file mode 100644 index 0000000000..08f7c99399 --- /dev/null +++ b/experimental/air/cmd/runsubmit.go @@ -0,0 +1,244 @@ +package aircmd + +import ( + "context" + "errors" + "net/http" + "path" + "strconv" + "strings" + + "github.com/databricks/cli/libs/auth" + "github.com/databricks/cli/libs/env" + "github.com/databricks/cli/libs/filer" + "github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/client" + "github.com/google/uuid" +) + +// jobsRunsSubmitPath is the Jobs one-time-run endpoint. air builds the full +// payload and POSTs it here directly — the native ai_runtime_task is not modeled +// by the typed SDK, and we want no genai-mapi forwarding. +const jobsRunsSubmitPath = "/api/2.2/jobs/runs/submit" + +// dlRuntimeImageEnv overrides the default deep-learning runtime image. +const dlRuntimeImageEnv = "DATABRICKS_DL_RUNTIME_IMAGE" + +const defaultDlRuntimeImage = "CLIENT-GPU-4" + +// aiRuntimeEnvironmentKey ties the task to the serverless environment that +// carries the runtime channel. +const aiRuntimeEnvironmentKey = "default" + +// aiRuntimeCompute is a deployment's accelerator request. +type aiRuntimeCompute struct { + AcceleratorType string `json:"accelerator_type"` + AcceleratorCount int `json:"accelerator_count"` +} + +// aiRuntimeDeployment is one worker deployment of the run. +type aiRuntimeDeployment struct { + CommandPath string `json:"command_path"` + Compute aiRuntimeCompute `json:"compute"` +} + +// aiRuntimeTask is the native AI Runtime task. It routes straight to the training +// service — no genai-mapi forwarding. The proto is lean: env vars, secrets, +// requirements, and hyperparameters are staged as workspace files co-located with +// command.sh (see runupload.go), not carried inline. +type aiRuntimeTask struct { + Experiment string `json:"experiment"` + Deployments []aiRuntimeDeployment `json:"deployments"` + MlflowRun string `json:"mlflow_run,omitempty"` + MlflowExperimentDirectory string `json:"mlflow_experiment_directory,omitempty"` +} + +// environmentSpec carries the bare runtime channel ("4", "5", ...). +type environmentSpec struct { + EnvironmentVersion string `json:"environment_version"` +} + +// jobEnvironment is the serverless environment a task references for its runtime. +type jobEnvironment struct { + EnvironmentKey string `json:"environment_key"` + Spec environmentSpec `json:"spec"` +} + +// submitTask is the single task air submits: a native ai_runtime_task. +type submitTask struct { + TaskKey string `json:"task_key"` + RunIf string `json:"run_if"` + AiRuntimeTask aiRuntimeTask `json:"ai_runtime_task"` + EnvironmentKey string `json:"environment_key"` + MaxRetries int `json:"max_retries,omitempty"` + RetryOnTimeout bool `json:"retry_on_timeout,omitempty"` +} + +// jobsSubmitRun is the Jobs runs/submit payload. +type jobsSubmitRun struct { + RunName string `json:"run_name"` + TimeoutSeconds int `json:"timeout_seconds,omitempty"` + Tasks []submitTask `json:"tasks"` + Environments []jobEnvironment `json:"environments"` + BudgetPolicyID string `json:"budget_policy_id,omitempty"` + IdempotencyToken string `json:"idempotency_token,omitempty"` +} + +// dlRuntimeImage resolves the runtime channel. A config version wins; otherwise +// the env override or default applies. The CLIENT-GPU- prefix is stripped because +// the native path wants the bare channel. +func dlRuntimeImage(ctx context.Context, runtimeVersion string) string { + if runtimeVersion != "" { + return runtimeVersion + } + img := env.Get(ctx, dlRuntimeImageEnv) + if img == "" { + img = defaultDlRuntimeImage + } + return strings.TrimPrefix(img, "CLIENT-GPU-") +} + +// buildSubmitPayload assembles the runs/submit payload. commandPath is the +// workspace path of the uploaded command.sh; dlImage is the runtime channel. +func buildSubmitPayload(cfg *runConfig, commandPath, dlImage string) jobsSubmitRun { + task := aiRuntimeTask{ + Experiment: cfg.ExperimentName, + Deployments: []aiRuntimeDeployment{{ + CommandPath: commandPath, + Compute: aiRuntimeCompute{ + AcceleratorType: cfg.Compute.AcceleratorType, + AcceleratorCount: cfg.Compute.NumAccelerators, + }, + }}, + } + if cfg.MLflowRunName != nil { + task.MlflowRun = *cfg.MLflowRunName + } + if cfg.MLflowExperimentDirectory != nil { + task.MlflowExperimentDirectory = *cfg.MLflowExperimentDirectory + } + + st := submitTask{ + TaskKey: cfg.ExperimentName, + RunIf: "ALL_SUCCESS", + AiRuntimeTask: task, + EnvironmentKey: aiRuntimeEnvironmentKey, + } + // retry_on_timeout pairs with max_retries, matching the Python payload. + if r := cfg.maxRetries(); r > 0 { + st.MaxRetries = r + st.RetryOnTimeout = true + } + + return jobsSubmitRun{ + RunName: cfg.ExperimentName, + TimeoutSeconds: cfg.timeoutSeconds(), + Tasks: []submitTask{st}, + Environments: []jobEnvironment{{ + EnvironmentKey: aiRuntimeEnvironmentKey, + Spec: environmentSpec{EnvironmentVersion: dlImage}, + }}, + } +} + +// jobsSubmitClient submits one-time runs through the Jobs API. +type jobsSubmitClient struct { + c *client.DatabricksClient +} + +func newJobsSubmitClient(w *databricks.WorkspaceClient) (*jobsSubmitClient, error) { + c, err := client.New(w.Config) + if err != nil { + return nil, err + } + return &jobsSubmitClient{c: c}, nil +} + +type submitRunResponse struct { + RunID int64 `json:"run_id,omitempty"` +} + +// submit POSTs the payload to runs/submit and returns the new run_id. +func (j *jobsSubmitClient) submit(ctx context.Context, payload jobsSubmitRun) (int64, error) { + var resp submitRunResponse + if err := j.c.Do(ctx, http.MethodPost, jobsRunsSubmitPath, auth.WorkspaceIDHeaders(j.c.Config), nil, payload, &resp); err != nil { + return 0, err + } + return resp.RunID, nil +} + +// submitToken resolves the idempotency token: the --idempotency-key flag wins, +// then the config's token, else a generated one. Capped at the Jobs API's 64. +func submitToken(flag string, cfg *runConfig) string { + token := flag + if token == "" && cfg.IdempotencyToken != nil { + token = *cfg.IdempotencyToken + } + if token == "" { + token = uuid.NewString() + } + if len(token) > 64 { + token = token[:64] + } + return token +} + +// submitWorkload runs the submit happy path: ensure the experiment directory, +// upload the launch artifacts, assemble the Jobs payload, and submit it. It +// returns the new run_id and its dashboard URL. +func submitWorkload(ctx context.Context, w *databricks.WorkspaceClient, cfg *runConfig, configPath, idempotencyKey string) (int64, string, error) { + // Resolving usage_policy_name to a budget policy id and packaging a + // code_source snapshot are not ported yet; reject rather than silently drop. + if cfg.UsagePolicyName != nil { + return 0, "", errors.New("usage_policy_name is not yet supported") + } + if cfg.CodeSource != nil { + return 0, "", errors.New("code_source is not yet supported") + } + + experimentDir := "" + if cfg.MLflowExperimentDirectory != nil { + experimentDir = *cfg.MLflowExperimentDirectory + } + if err := ensureExperimentDirectory(ctx, w, experimentDir); err != nil { + return 0, "", err + } + + base, err := userWorkspaceDir(ctx, w) + if err != nil { + return 0, "", err + } + runName := "" + if cfg.MLflowRunName != nil { + runName = *cfg.MLflowRunName + } + funcDir := cliLaunchDir(base, cfg.ExperimentName, runName) + + fc, err := filer.NewWorkspaceFilesClient(w, funcDir) + if err != nil { + return 0, "", err + } + items, err := buildArtifacts(cfg, configPath) + if err != nil { + return 0, "", err + } + if err := uploadArtifacts(ctx, fc, items); err != nil { + return 0, "", err + } + + runtimeVersion, _ := cfg.runtimeVersion() + payload := buildSubmitPayload(cfg, path.Join(funcDir, commandScriptName), dlRuntimeImage(ctx, runtimeVersion)) + payload.IdempotencyToken = submitToken(idempotencyKey, cfg) + + jc, err := newJobsSubmitClient(w) + if err != nil { + return 0, "", err + } + runID, err := jc.submit(ctx, payload) + if err != nil { + return 0, "", err + } + + dashboardURL := strings.TrimRight(w.Config.Host, "/") + "/jobs/runs/" + strconv.FormatInt(runID, 10) + return runID, dashboardURL, nil +} diff --git a/experimental/air/cmd/runsubmit_test.go b/experimental/air/cmd/runsubmit_test.go new file mode 100644 index 0000000000..c43dda0fcd --- /dev/null +++ b/experimental/air/cmd/runsubmit_test.go @@ -0,0 +1,145 @@ +package aircmd + +import ( + "encoding/json" + "strings" + "testing" + + "github.com/databricks/cli/libs/testserver" + "github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/config" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDlRuntimeImage(t *testing.T) { + ctx := t.Context() + // A config runtime version wins and is used bare. + assert.Equal(t, "5", dlRuntimeImage(ctx, "5")) + // Default, with the CLIENT-GPU- prefix stripped for the GPU_* path. + assert.Equal(t, "4", dlRuntimeImage(ctx, "")) + // Env override, prefix stripped. + t.Setenv(dlRuntimeImageEnv, "CLIENT-GPU-7") + assert.Equal(t, "7", dlRuntimeImage(ctx, "")) +} + +func TestBuildSubmitPayload(t *testing.T) { + cfg := &runConfig{ + ExperimentName: "exp", + Command: new("python train.py"), + Compute: &computeConfig{AcceleratorType: "GPU_8xH100", NumAccelerators: 16}, + MaxRetries: new(2), + TimeoutMinutes: new(30), + MLflowRunName: new("run-v2"), + MLflowExperimentDirectory: new("/Workspace/Users/me/exp"), + } + + p := buildSubmitPayload(cfg, "/d/command.sh", "5") + + assert.Equal(t, "exp", p.RunName) + assert.Equal(t, 1800, p.TimeoutSeconds) + require.Len(t, p.Environments, 1) + assert.Equal(t, aiRuntimeEnvironmentKey, p.Environments[0].EnvironmentKey) + assert.Equal(t, "5", p.Environments[0].Spec.EnvironmentVersion) + + require.Len(t, p.Tasks, 1) + task := p.Tasks[0] + assert.Equal(t, "exp", task.TaskKey) + assert.Equal(t, "ALL_SUCCESS", task.RunIf) + assert.Equal(t, aiRuntimeEnvironmentKey, task.EnvironmentKey) + assert.Equal(t, 2, task.MaxRetries) + assert.True(t, task.RetryOnTimeout) + + at := task.AiRuntimeTask + assert.Equal(t, "exp", at.Experiment) + assert.Equal(t, "run-v2", at.MlflowRun) + assert.Equal(t, "/Workspace/Users/me/exp", at.MlflowExperimentDirectory) + require.Len(t, at.Deployments, 1) + assert.Equal(t, "/d/command.sh", at.Deployments[0].CommandPath) + assert.Equal(t, aiRuntimeCompute{AcceleratorType: "GPU_8xH100", AcceleratorCount: 16}, at.Deployments[0].Compute) +} + +func TestSubmitToken(t *testing.T) { + cfg := &runConfig{IdempotencyToken: new("from-config")} + assert.Equal(t, "from-flag", submitToken("from-flag", cfg)) // flag wins + assert.Equal(t, "from-config", submitToken("", cfg)) // then config + assert.NotEmpty(t, submitToken("", &runConfig{})) // else generated + assert.Len(t, submitToken(string(make([]byte, 80)), cfg), 64) // capped +} + +func TestJobsSubmitClient(t *testing.T) { + server := testserver.New(t) + t.Cleanup(server.Close) + + var got jobsSubmitRun + server.Handle("POST", "/api/2.2/jobs/runs/submit", func(req testserver.Request) any { + require.NoError(t, json.Unmarshal(req.Body, &got)) + return submitRunResponse{RunID: 999} + }) + + w := &databricks.WorkspaceClient{Config: &config.Config{Host: server.URL, Token: "token"}} + jc, err := newJobsSubmitClient(w) + require.NoError(t, err) + + runID, err := jc.submit(t.Context(), jobsSubmitRun{RunName: "exp", Tasks: []submitTask{{TaskKey: "exp"}}}) + require.NoError(t, err) + assert.Equal(t, int64(999), runID) + assert.Equal(t, "exp", got.RunName) +} + +func TestSubmitWorkload(t *testing.T) { + server := testserver.New(t) + t.Cleanup(server.Close) + testserver.AddDefaultHandlers(server) + + var got jobsSubmitRun + server.Handle("POST", "/api/2.2/jobs/runs/submit", func(req testserver.Request) any { + require.NoError(t, json.Unmarshal(req.Body, &got)) + return submitRunResponse{RunID: 777} + }) + w, err := databricks.NewWorkspaceClient(&databricks.Config{Host: server.URL, Token: "token"}) + require.NoError(t, err) + + cfgPath := writeConfigFile(t, "run.yaml", minimalConfig) + cfg, err := loadRunConfig(cfgPath) + require.NoError(t, err) + + runID, dashboardURL, err := submitWorkload(t.Context(), w, cfg, cfgPath, "idem-key") + require.NoError(t, err) + assert.Equal(t, int64(777), runID) + assert.Contains(t, dashboardURL, "/jobs/runs/777") + + // The submitted payload is a native ai_runtime_task pointing at the uploaded + // command.sh under the run's launch directory. + assert.Equal(t, "my-run", got.RunName) + assert.Equal(t, "idem-key", got.IdempotencyToken) + require.Len(t, got.Environments, 1) + require.Len(t, got.Tasks, 1) + at := got.Tasks[0].AiRuntimeTask + require.Len(t, at.Deployments, 1) + d := at.Deployments[0] + assert.True(t, strings.HasSuffix(d.CommandPath, "/"+commandScriptName), d.CommandPath) + assert.Contains(t, d.CommandPath, "/.air/cli_launch/") + assert.Equal(t, aiRuntimeCompute{AcceleratorType: "GPU_1xH100", AcceleratorCount: 1}, d.Compute) +} + +func TestSubmitWorkloadGuards(t *testing.T) { + w := newFakeWorkspaceClient(t) + cfgPath := writeConfigFile(t, "run.yaml", minimalConfig) + base, err := loadRunConfig(cfgPath) + require.NoError(t, err) + + t.Run("usage_policy_name rejected", func(t *testing.T) { + cfg := *base + cfg.UsagePolicyName = new("p") + _, _, err := submitWorkload(t.Context(), w, &cfg, cfgPath, "") + require.ErrorContains(t, err, "usage_policy_name is not yet supported") + }) + + t.Run("code_source rejected", func(t *testing.T) { + cfg := *base + cfg.CodeSource = &codeSourceConfig{Type: "snapshot"} + _, _, err := submitWorkload(t.Context(), w, &cfg, cfgPath, "") + require.ErrorContains(t, err, "code_source is not yet supported") + }) +} diff --git a/experimental/air/cmd/runupload.go b/experimental/air/cmd/runupload.go index c9c4047381..fb9ca00b98 100644 --- a/experimental/air/cmd/runupload.go +++ b/experimental/air/cmd/runupload.go @@ -3,10 +3,14 @@ package aircmd import ( "bytes" "context" + "encoding/json" "fmt" "io" + "maps" "os" "path/filepath" + "slices" + "strings" "github.com/databricks/cli/libs/filer" "go.yaml.in/yaml/v3" @@ -20,6 +24,8 @@ const ( commandScriptName = "command.sh" requirementsName = "requirements.yaml" hyperparametersName = "hyperparameters.yaml" + envVarsName = "env_vars.json" + secretEnvVarsName = "secret_env_vars.json" ) // maxConfigYAMLBytes caps training_config.yaml. It is referenced by the Jobs @@ -96,9 +102,59 @@ func buildArtifacts(cfg *runConfig, configPath string) ([]uploadItem, error) { items = append(items, uploadItem{hyperparametersName, data}) } + // The ai_runtime_task proto carries no inline env vars or secrets; stage them + // as JSON files co-located with command.sh for the server-side launcher. + if len(cfg.EnvVariables) > 0 { + data, err := json.Marshal(envVarEntries(cfg.EnvVariables)) + if err != nil { + return nil, fmt.Errorf("failed to serialize env_variables: %w", err) + } + items = append(items, uploadItem{envVarsName, data}) + } + if len(cfg.Secrets) > 0 { + data, err := json.Marshal(secretEnvVarEntries(cfg.Secrets)) + if err != nil { + return nil, fmt.Errorf("failed to serialize secrets: %w", err) + } + items = append(items, uploadItem{secretEnvVarsName, data}) + } + return items, nil } +// envVarEntry is one entry in env_vars.json. +type envVarEntry struct { + Name string `json:"name"` + Value string `json:"value"` +} + +// secretEnvVarEntry is one entry in secret_env_vars.json. The YAML side is +// {ENV_VAR: "scope/key"}; the launcher wants the split form. +type secretEnvVarEntry struct { + Name string `json:"name"` + SecretScope string `json:"secret_scope"` + SecretKey string `json:"secret_key"` +} + +// envVarEntries renders env_variables sorted by name for deterministic output. +func envVarEntries(vars map[string]string) []envVarEntry { + out := make([]envVarEntry, 0, len(vars)) + for _, name := range slices.Sorted(maps.Keys(vars)) { + out = append(out, envVarEntry{Name: name, Value: vars[name]}) + } + return out +} + +// secretEnvVarEntries renders secrets sorted by name for deterministic output. +func secretEnvVarEntries(secrets map[string]string) []secretEnvVarEntry { + out := make([]secretEnvVarEntry, 0, len(secrets)) + for _, name := range slices.Sorted(maps.Keys(secrets)) { + scope, key, _ := strings.Cut(secrets[name], "/") + out = append(out, secretEnvVarEntry{Name: name, SecretScope: scope, SecretKey: key}) + } + return out +} + // uploadArtifacts writes each artifact into the launch directory, overwriting and // creating parents as needed. // diff --git a/experimental/air/cmd/runupload_test.go b/experimental/air/cmd/runupload_test.go index ec700a9229..0c87524735 100644 --- a/experimental/air/cmd/runupload_test.go +++ b/experimental/air/cmd/runupload_test.go @@ -83,6 +83,26 @@ func TestBuildArtifacts_InlineRequirementsAndParameters(t *testing.T) { assert.Contains(t, req, "- torch") } +func TestBuildArtifacts_EnvVarsAndSecrets(t *testing.T) { + path := writeConfigFile(t, "run.yaml", "x: y\n") + cfg := &runConfig{ + Command: new("echo hi"), + EnvVariables: map[string]string{"WANDB": "demo"}, + Secrets: map[string]string{"HF_TOKEN": "myscope/hf"}, + } + + items, err := buildArtifacts(cfg, path) + require.NoError(t, err) + assert.Subset(t, itemNames(items), []string{envVarsName, secretEnvVarsName}) + + byName := map[string][]byte{} + for _, it := range items { + byName[it.name] = it.data + } + assert.JSONEq(t, `[{"name":"WANDB","value":"demo"}]`, string(byName[envVarsName])) + assert.JSONEq(t, `[{"name":"HF_TOKEN","secret_scope":"myscope","secret_key":"hf"}]`, string(byName[secretEnvVarsName])) +} + func TestBuildArtifacts_RequirementsFile(t *testing.T) { dir := t.TempDir() require.NoError(t, os.WriteFile(filepath.Join(dir, "run.yaml"), []byte("x: y\n"), 0o600))