From 37d0c2c8e615cbfb01c853ac7e039507dededa59 Mon Sep 17 00:00:00 2001 From: cwz Date: Fri, 12 Jun 2026 22:17:03 +0800 Subject: [PATCH 1/2] hostman: probe http get in pod netns --- pkg/hostman/container/prober/prober.go | 96 ++++++ .../container/prober/prober_manager.go | 34 ++- pkg/hostman/container/prober/worker.go | 15 +- pkg/hostman/guestman/container/runtime.go | 5 + pkg/hostman/guestman/pod.go | 14 +- pkg/hostman/guestman/pod_netns.go | 283 ++++++++++++++++++ pkg/llm/drivers/llm_container/ollama.go | 1 + pkg/llm/drivers/llm_container/sglang.go | 6 +- .../drivers/llm_container/startup_probe.go | 26 ++ pkg/llm/drivers/llm_container/vllm.go | 6 +- pkg/util/probe/http/doc.go | 16 + pkg/util/probe/http/http.go | 121 ++++++++ 12 files changed, 616 insertions(+), 7 deletions(-) create mode 100644 pkg/hostman/guestman/pod_netns.go create mode 100644 pkg/llm/drivers/llm_container/startup_probe.go create mode 100644 pkg/util/probe/http/doc.go create mode 100644 pkg/util/probe/http/http.go diff --git a/pkg/hostman/container/prober/prober.go b/pkg/hostman/container/prober/prober.go index 57326435ef2..783419491c9 100644 --- a/pkg/hostman/container/prober/prober.go +++ b/pkg/hostman/container/prober/prober.go @@ -31,8 +31,11 @@ limitations under the License. package prober import ( + "context" "fmt" "io" + "net" + nethttp "net/http" "strings" "time" @@ -46,6 +49,7 @@ import ( "yunion.io/x/onecloud/pkg/util/exec" "yunion.io/x/onecloud/pkg/util/probe" execprobe "yunion.io/x/onecloud/pkg/util/probe/exec" + httpprobe "yunion.io/x/onecloud/pkg/util/probe/http" tcpprobe "yunion.io/x/onecloud/pkg/util/probe/tcp" ) @@ -53,6 +57,7 @@ const maxProbeRetries = 3 // Prober helps to check the liveness of a container. 
type prober struct { + http httpprobe.Prober exec execprobe.Prober tcp tcpprobe.Prober runner container.CommandRunner @@ -60,6 +65,7 @@ type prober struct { func newProber(runner container.CommandRunner) *prober { return &prober{ + http: httpprobe.New(), exec: execprobe.New(), tcp: tcpprobe.New(), runner: runner, @@ -124,6 +130,78 @@ func (pb *prober) runProbeWithRetries(probeType apis.ContainerProbeType, p *apis return result, output, err } +func (pb *prober) runProbeInPodNetNS(pod IPod, run func() (probe.Result, string, error)) (probe.Result, string, error) { + netNSRunner, ok := pb.runner.(container.PodNetNSRunner) + if !ok { + log.Infof("[startup-probe-trace] run probe without pod netns pod=%s", pod.GetId()) + return run() + } + var result probe.Result + var output string + log.Infof("[startup-probe-trace] enter pod netns pod=%s", pod.GetId()) + err := netNSRunner.RunInPodNetNS(pod.GetId(), func() error { + var runErr error + result, output, runErr = run() + return runErr + }) + if err != nil { + log.Errorf("[startup-probe-trace] pod netns probe error pod=%s error=%v", pod.GetId(), err) + return probe.Unknown, "", err + } + log.Infof("[startup-probe-trace] leave pod netns pod=%s result=%s output=%q", pod.GetId(), result, output) + return result, output, nil +} + +func (pb *prober) shouldRunProbeInPodNetNS() bool { + _, ok := pb.runner.(container.PodNetNSRunner) + return ok +} + +func (pb *prober) newPodNetNSDialContext(pod IPod) httpprobe.DialContextFunc { + netNSRunner, ok := pb.runner.(container.PodNetNSRunner) + if !ok { + return nil + } + return func(ctx context.Context, network string, address string) (net.Conn, error) { + var conn net.Conn + dialer := &net.Dialer{} + log.Infof("[startup-probe-trace] enter pod netns dial pod=%s network=%s address=%s", pod.GetId(), network, address) + err := netNSRunner.RunInPodNetNS(pod.GetId(), func() error { + var dialErr error + conn, dialErr = dialer.DialContext(ctx, network, address) + return dialErr + }) + if err != 
nil { + log.Errorf("[startup-probe-trace] pod netns dial error pod=%s network=%s address=%s error=%v", pod.GetId(), network, address, err) + return nil, err + } + log.Infof("[startup-probe-trace] leave pod netns dial pod=%s network=%s address=%s", pod.GetId(), network, address) + return conn, nil + } +} + +func getProbeHost(explicitHost string, pod IPod) (string, error) { + if explicitHost != "" { + return explicitHost, nil + } + for _, nic := range pod.GetDesc().Nics { + if nic.Ip != "" { + return nic.Ip, nil + } + } + return "", errors.Errorf("not found guest ip") +} + +func (pb *prober) getProbeHost(explicitHost string, pod IPod) (string, error) { + if explicitHost != "" { + return explicitHost, nil + } + if pb.shouldRunProbeInPodNetNS() { + return "127.0.0.1", nil + } + return getProbeHost(explicitHost, pod) +} + func (pb *prober) runProbe(probeType apis.ContainerProbeType, p *apis.ContainerProbe, pod IPod, container *hostapi.ContainerDesc) (probe.Result, string, error) { timeout := time.Duration(p.TimeoutSeconds) * time.Second if p.Exec != nil { @@ -147,6 +225,24 @@ func (pb *prober) runProbe(probeType apis.ContainerProbeType, p *apis.ContainerP // log.Debugf("TCP-Probe Host: %v, Port: %v, Timeout: %v", host, port, timeout) return pb.tcp.Probe(host, port, timeout) } + if p.HTTPGet != nil { + host, err := pb.getProbeHost(p.HTTPGet.Host, pod) + if err != nil { + return probe.Unknown, "", err + } + headers := nethttp.Header{} + for _, header := range p.HTTPGet.HTTPHeaders { + headers.Add(header.Name, header.Value) + } + log.Infof("[startup-probe-trace] http probe pod=%s container=%s scheme=%s host=%s port=%d path=%s timeout=%s in_pod_netns=%v", pod.GetId(), container.Id, p.HTTPGet.Scheme, host, p.HTTPGet.Port, p.HTTPGet.Path, timeout, pb.shouldRunProbeInPodNetNS()) + httpProber := pb.http + if dialContext := pb.newPodNetNSDialContext(pod); dialContext != nil { + httpProber = httpprobe.NewWithDialContext(dialContext) + } + result, output, err := 
httpProber.Probe(string(p.HTTPGet.Scheme), host, p.HTTPGet.Port, p.HTTPGet.Path, headers, timeout) + log.Infof("[startup-probe-trace] http probe result pod=%s container=%s result=%s output=%q error=%v", pod.GetId(), container.Id, result, output, err) + return result, output, err + } errMsg := fmt.Sprintf("Failed to find probe builder for pod %v, container: %v", pod.GetName(), container.Name) log.Warningf("%s", errMsg) return probe.Unknown, "", errors.Error(errMsg) diff --git a/pkg/hostman/container/prober/prober_manager.go b/pkg/hostman/container/prober/prober_manager.go index 556bc32f39c..0ab2363d967 100644 --- a/pkg/hostman/container/prober/prober_manager.go +++ b/pkg/hostman/container/prober/prober_manager.go @@ -39,6 +39,7 @@ import ( "yunion.io/x/pkg/util/wait" "yunion.io/x/onecloud/pkg/apis" + computeapi "yunion.io/x/onecloud/pkg/apis/compute" "yunion.io/x/onecloud/pkg/apis/host" "yunion.io/x/onecloud/pkg/hostman/container/prober/results" "yunion.io/x/onecloud/pkg/hostman/container/status" @@ -87,6 +88,8 @@ type Manager interface { Start() SetDirtyContainer(ctrId string, reason string) + + GetContainerStartupStatus(ctrId string) (string, bool) } type manager struct { @@ -138,11 +141,33 @@ func (m *manager) cleanDirtyContainer(ctrId string) { m.dirtyContainers.Delete(ctrId) } +func (m *manager) GetContainerStartupStatus(ctrId string) (string, bool) { + if _, ok := m.dirtyContainers.Load(ctrId); ok { + return computeapi.CONTAINER_STATUS_PROBING, true + } + result, ok := m.startupManager.Get(ctrId) + if !ok { + return "", false + } + switch result.Result { + case results.Success: + return computeapi.CONTAINER_STATUS_RUNNING, true + case results.Failure: + if result.IsNetFailedError() { + return computeapi.CONTAINER_STATUS_NET_FAILED, true + } + return computeapi.CONTAINER_STATUS_PROBE_FAILED, true + default: + return computeapi.CONTAINER_STATUS_PROBING, true + } +} + // Start syncing probe status. This should only be called once. 
func (m *manager) Start() { // start syncing readiness. //go wait.Forever(m.updateReadiness, 0) // start syncing startup. + log.Infof("[startup-probe-trace] manager start") go wait.Forever(m.updateStartup, 0) } @@ -151,14 +176,17 @@ func (m *manager) AddPod(pod IPod) { defer m.workerLock.Unlock() key := probeKey{podUid: pod.GetId()} - for _, c := range pod.GetContainers() { + containers := pod.GetContainers() + log.Infof("[startup-probe-trace] AddPod pod=%s name=%s containers=%d", pod.GetId(), pod.GetName(), len(containers)) + for _, c := range containers { key.containerName = c.Name if c.Spec.StartupProbe != nil { key.probeType = apis.ContainerProbeTypeStartup if _, ok := m.workers[key]; ok { - log.Errorf("Startup probe already exists: %s:%s", pod.GetName(), c.Name) - return + log.Infof("[startup-probe-trace] startup worker exists pod=%s container=%s", pod.GetId(), c.Id) + continue } + log.Infof("[startup-probe-trace] create startup worker pod=%s container=%s name=%s period=%d timeout=%d failure=%d success=%d", pod.GetId(), c.Id, c.Name, c.Spec.StartupProbe.PeriodSeconds, c.Spec.StartupProbe.TimeoutSeconds, c.Spec.StartupProbe.FailureThreshold, c.Spec.StartupProbe.SuccessThreshold) w := newWorker(m, key.probeType, pod, c) m.workers[key] = w go w.run() diff --git a/pkg/hostman/container/prober/worker.go b/pkg/hostman/container/prober/worker.go index 53af67743b5..a9df2b194b7 100644 --- a/pkg/hostman/container/prober/worker.go +++ b/pkg/hostman/container/prober/worker.go @@ -75,6 +75,8 @@ type worker struct { lastResult results.Result // How many times in a row the probe has returned the same result. resultRun int + // Total probe attempts made by this worker. + probeAttempts int // If set, skip probing onHold bool @@ -112,14 +114,17 @@ func newWorker( // run periodically probes the container. 
func (w *worker) run() { probeTickerPeriod := time.Duration(w.spec.PeriodSeconds) * time.Second + jitter := time.Duration(rand.Float64() * float64(probeTickerPeriod)) + log.Infof("[startup-probe-trace] worker start pod=%s container=%s name=%s type=%s period=%s jitter=%s", w.pod.GetId(), w.container.Id, w.container.Name, w.probeType, probeTickerPeriod, jitter) // If host restarted the probes could be started in rapid succession. // Let the worker wait for a random portion of tickerPeriod before probing. - time.Sleep(time.Duration(rand.Float64() * float64(probeTickerPeriod))) + time.Sleep(jitter) probeTicker := time.NewTicker(probeTickerPeriod) defer func() { + log.Infof("[startup-probe-trace] worker stop pod=%s container=%s name=%s type=%s attempts=%d", w.pod.GetId(), w.container.Id, w.container.Name, w.probeType, w.probeAttempts) // Clean up. probeTicker.Stop() if len(w.containerId) != 0 { @@ -159,6 +164,11 @@ func (w *worker) doProbe() (keepGoing bool) { keepGoing = true }) + w.probeAttempts++ + if w.probeAttempts <= 3 { + log.Infof("[startup-probe-trace] probe attempt=%d pod=%s container=%s name=%s type=%s", w.probeAttempts, w.pod.GetId(), w.container.Id, w.container.Name, w.probeType) + } + prevResult := w.lastResult result, err := w.probeManager.prober.probe(w.probeType, w.pod, w.container) if err != nil { log.Errorf("probe: %s, pod: %s, container: %s, error: %v", w.probeType, w.pod.GetId(), w.container.Id, err) @@ -172,6 +182,9 @@ func (w *worker) doProbe() (keepGoing bool) { w.lastResult = result.Result w.resultRun = 1 } + if w.probeAttempts <= 3 || prevResult != result.Result { + log.Infof("[startup-probe-trace] probe result pod=%s container=%s name=%s type=%s result=%s run=%d reason=%q", w.pod.GetId(), w.container.Id, w.container.Name, w.probeType, result.Result, w.resultRun, result.Reason) + } _, isContainerDirty := w.probeManager.dirtyContainers.Load(w.container.Id) if (result.Result == results.Failure && w.resultRun < int(w.spec.FailureThreshold)) || 
diff --git a/pkg/hostman/guestman/container/runtime.go b/pkg/hostman/guestman/container/runtime.go index f668db1cc9b..1248273b885 100644 --- a/pkg/hostman/guestman/container/runtime.go +++ b/pkg/hostman/guestman/container/runtime.go @@ -23,3 +23,8 @@ type CommandRunner interface { // RunInContainer synchronously executes the command in the container, and returns the output. RunInContainer(podId string, containerId string, cmd []string, timeout time.Duration) ([]byte, error) } + +type PodNetNSRunner interface { + // RunInPodNetNS synchronously executes run in the pod network namespace. + RunInPodNetNS(podId string, run func() error) error +} diff --git a/pkg/hostman/guestman/pod.go b/pkg/hostman/guestman/pod.go index 1a148ef3010..d76c4e4830a 100644 --- a/pkg/hostman/guestman/pod.go +++ b/pkg/hostman/guestman/pod.go @@ -1362,6 +1362,13 @@ func (s *sPodGuestInstance) StartLocalContainer(ctx context.Context, userCred mc return ret, nil } +func (s *sPodGuestInstance) ensureContainerProbeStarted(ctrId string, spec *hostapi.ContainerSpec, reason string) { + if spec != nil && spec.NeedProbe() { + s.getProbeManager().SetDirtyContainer(ctrId, reason) + s.getProbeManager().AddPod(s) + } +} + func (s *sPodGuestInstance) StartContainer(ctx context.Context, userCred mcclient.TokenCredential, ctrId string, input *hostapi.ContainerCreateInput) (jsonutils.JSONObject, error) { _, hasCtr := s.containers[ctrId] needRecreate := false @@ -1409,6 +1416,7 @@ func (s *sPodGuestInstance) StartContainer(ctx context.Context, userCred mcclien if err := s.getCRI().StartContainer(ctx, criId); err != nil { return nil, errors.Wrap(err, "CRI.StartContainer") } + s.ensureContainerProbeStarted(ctrId, input.Spec, "container started") // 如果是 cgroup v2,设备规则已经通过 containerd API 在 CreateContainer 时设置,跳过 eBPF 方式 // 如果是 cgroup v1,继续使用原有的 eBPF 方式 @@ -2669,7 +2677,11 @@ func (s *sPodGuestInstance) getContainerStatus(ctx context.Context, ctrId string return "", cs, errors.Wrapf(httperrors.ErrNotFound, "not 
found container by id %s", ctrId) } if ctr.Spec.NeedProbe() { - status = computeapi.CONTAINER_STATUS_PROBING + if probeStatus, ok := s.getProbeManager().GetContainerStartupStatus(ctrId); ok { + status = probeStatus + } else { + status = computeapi.CONTAINER_STATUS_PROBING + } } } if status == computeapi.CONTAINER_STATUS_EXITED && resp.Status.ExitCode != 0 { diff --git a/pkg/hostman/guestman/pod_netns.go b/pkg/hostman/guestman/pod_netns.go new file mode 100644 index 00000000000..462dc9d0b8c --- /dev/null +++ b/pkg/hostman/guestman/pod_netns.go @@ -0,0 +1,283 @@ +package guestman + +import ( + "context" + "encoding/json" + "fmt" + goruntime "runtime" + "strconv" + "strings" + + "github.com/vishvananda/netns" + runtimeapi "k8s.io/cri-api/pkg/apis/runtime/v1" + + "yunion.io/x/log" + "yunion.io/x/pkg/errors" +) + +func (cr *containerRunner) RunInPodNetNS(podId string, run func() error) error { + srv, ok := cr.manager.GetServer(podId) + if !ok { + return errors.Errorf("server %s not found", podId) + } + pod, ok := srv.(*sPodGuestInstance) + if !ok { + return errors.Errorf("server %s is not a pod instance", podId) + } + netNSPath, err := pod.getPodNetNSPath(context.Background()) + if err != nil { + return errors.Wrap(err, "get pod netns path") + } + log.Infof("[startup-probe-trace] pod netns path pod=%s path=%s", podId, netNSPath) + return runInNetNSPath(netNSPath, run) +} + +func (s *sPodGuestInstance) getPodNetNSPath(ctx context.Context) (string, error) { + criId := s.GetCRIId() + if criId == "" { + return "", errors.Errorf("pod %s missing cri id", s.GetId()) + } + status, err := s.getCRI().GetRuntimeClient().PodSandboxStatus(ctx, &runtimeapi.PodSandboxStatusRequest{ + PodSandboxId: criId, + Verbose: true, + }) + if err != nil { + return "", errors.Wrapf(err, "PodSandboxStatus %s", criId) + } + return sandboxNetNSPathFromStatus(status) +} + +func runInNetNSPath(netNSPath string, run func() error) (retErr error) { + netNSPath = strings.TrimSpace(netNSPath) + if netNSPath 
== "" { + return errors.Errorf("netns path is empty") + } + + goruntime.LockOSThread() + defer goruntime.UnlockOSThread() + + origin, err := netns.Get() + if err != nil { + return errors.Wrap(err, "get current netns") + } + defer origin.Close() + + target, err := netns.GetFromPath(netNSPath) + if err != nil { + return errors.Wrapf(err, "get target netns %s", netNSPath) + } + defer target.Close() + + if !origin.Equal(target) { + if err := netns.Set(target); err != nil { + return errors.Wrapf(err, "set netns %s", netNSPath) + } + defer func() { + if err := netns.Set(origin); err != nil && retErr == nil { + retErr = errors.Wrap(err, "restore original netns") + } + }() + } + + return run() +} + +func sandboxNetNSPathFromStatus(status *runtimeapi.PodSandboxStatusResponse) (string, error) { + if status == nil { + return "", errors.Errorf("pod sandbox status is nil") + } + finder := &sandboxNetNSFinder{} + for key, value := range status.GetInfo() { + if decoded, ok := decodeSandboxInfoJSON(value); ok { + finder.inspectTopLevel(key, decoded) + } else { + finder.inspectTopLevel(key, value) + } + } + if finder.pid > 0 { + return fmt.Sprintf("/proc/%d/ns/net", finder.pid), nil + } + if finder.path != "" { + return finder.path, nil + } + return "", errors.Errorf("pod sandbox status does not include netns path or sandbox pid") +} + +type sandboxNetNSFinder struct { + path string + pid int +} + +func (f *sandboxNetNSFinder) inspectTopLevel(key string, value interface{}) { + f.inspectScalar(key, value, true) + switch typedValue := value.(type) { + case map[string]interface{}: + if isNetworkNamespaceEntry(typedValue) { + if pathValue, ok := namespaceEntryPath(typedValue); ok { + f.path = pathValue + return + } + } + for nestedKey, nestedValue := range typedValue { + f.inspect(nestedKey, nestedValue, true) + } + case []interface{}: + for _, nestedValue := range typedValue { + f.inspect("", nestedValue, false) + } + } +} + +func (f *sandboxNetNSFinder) inspect(key string, value 
interface{}, allowPID bool) { + if f.path != "" && f.pid > 0 { + return + } + + f.inspectScalar(key, value, allowPID) + + switch typedValue := value.(type) { + case map[string]interface{}: + if isNetworkNamespaceEntry(typedValue) { + if pathValue, ok := namespaceEntryPath(typedValue); ok { + f.path = pathValue + return + } + } + for nestedKey, nestedValue := range typedValue { + f.inspect(nestedKey, nestedValue, false) + } + case []interface{}: + for _, nestedValue := range typedValue { + f.inspect("", nestedValue, false) + } + } +} + +func (f *sandboxNetNSFinder) inspectScalar(key string, value interface{}, allowPID bool) { + normalizedKey := normalizeSandboxInfoKey(key) + if strValue, ok := sandboxInfoString(value); ok { + if isSandboxNetNSPathKey(normalizedKey) || (normalizedKey == "path" && strings.Contains(strValue, "/ns/net")) { + f.path = strValue + return + } + if allowPID && isSandboxPIDKey(normalizedKey) { + if pid, ok := sandboxInfoPID(strValue); ok { + f.pid = pid + } + } + } + if allowPID && isSandboxPIDKey(normalizedKey) { + if pid, ok := sandboxInfoPID(value); ok { + f.pid = pid + } + } +} + +func decodeSandboxInfoJSON(value string) (interface{}, bool) { + decoder := json.NewDecoder(strings.NewReader(value)) + decoder.UseNumber() + var decoded interface{} + if err := decoder.Decode(&decoded); err != nil { + return nil, false + } + return decoded, true +} + +func normalizeSandboxInfoKey(key string) string { + key = strings.ToLower(key) + key = strings.ReplaceAll(key, "_", "") + key = strings.ReplaceAll(key, "-", "") + key = strings.ReplaceAll(key, ".", "") + return key +} + +func isSandboxNetNSPathKey(key string) bool { + switch key { + case "netnspath", + "netnamespacepath", + "networknamespacepath", + "sandboxnetnspath", + "sandboxnetnamespacepath": + return true + } + return false +} + +func isSandboxPIDKey(key string) bool { + switch key { + case "pid", "sandboxpid", "processpid": + return true + } + return false +} + +func sandboxInfoString(value 
interface{}) (string, bool) { + switch typedValue := value.(type) { + case string: + if typedValue == "" { + return "", false + } + return typedValue, true + case json.Number: + return typedValue.String(), true + } + return "", false +} + +func sandboxInfoPID(value interface{}) (int, bool) { + switch typedValue := value.(type) { + case int: + return positivePID(typedValue) + case int64: + return positivePID(int(typedValue)) + case float64: + return positivePID(int(typedValue)) + case json.Number: + pid, err := typedValue.Int64() + if err != nil { + return 0, false + } + return positivePID(int(pid)) + case string: + pid, err := strconv.Atoi(strings.TrimSpace(typedValue)) + if err != nil { + return 0, false + } + return positivePID(pid) + } + return 0, false +} + +func positivePID(pid int) (int, bool) { + if pid <= 0 { + return 0, false + } + return pid, true +} + +func isNetworkNamespaceEntry(value map[string]interface{}) bool { + typeValue, ok := mapStringValue(value, "type") + if !ok { + return false + } + typeValue = strings.ToLower(typeValue) + return typeValue == "network" || typeValue == "net" +} + +func namespaceEntryPath(value map[string]interface{}) (string, bool) { + pathValue, ok := mapStringValue(value, "path") + if !ok || pathValue == "" { + return "", false + } + return pathValue, true +} + +func mapStringValue(value map[string]interface{}, key string) (string, bool) { + for candidateKey, candidateValue := range value { + if normalizeSandboxInfoKey(candidateKey) != normalizeSandboxInfoKey(key) { + continue + } + return sandboxInfoString(candidateValue) + } + return "", false +} diff --git a/pkg/llm/drivers/llm_container/ollama.go b/pkg/llm/drivers/llm_container/ollama.go index b807376aee7..ff17aa2ef6b 100644 --- a/pkg/llm/drivers/llm_container/ollama.go +++ b/pkg/llm/drivers/llm_container/ollama.go @@ -78,6 +78,7 @@ func (o *ollama) GetContainerSpec(ctx context.Context, llm *models.SLLM, image * ImageCredentialId: image.CredentialId, EnableLxcfs: true, 
AlwaysRestart: true, + StartupProbe: newLLMHTTPStartupProbe(api.LLM_OLLAMA_DEFAULT_PORT, "/api/tags"), }, } diff --git a/pkg/llm/drivers/llm_container/sglang.go b/pkg/llm/drivers/llm_container/sglang.go index 94154c87eae..f0f2f62e13d 100644 --- a/pkg/llm/drivers/llm_container/sglang.go +++ b/pkg/llm/drivers/llm_container/sglang.go @@ -223,7 +223,8 @@ func (s *sglang) GetContainerSpec(ctx context.Context, llm *models.SLLM, image * if sku != nil { backendParameters = sku.BackendParameters } - startScript := buildSGLangEntrypointScript(len(postOverlays) > 0, tensorParallelSize, backendParameters, effSpec) + hasMountedModels := len(postOverlays) > 0 + startScript := buildSGLangEntrypointScript(hasMountedModels, tensorParallelSize, backendParameters, effSpec) envs := []*commonapi.ContainerKeyValue{ { Key: "HUGGING_FACE_HUB_CACHE", @@ -245,6 +246,9 @@ func (s *sglang) GetContainerSpec(ctx context.Context, llm *models.SLLM, image * Envs: envs, }, } + if hasMountedModels { + spec.StartupProbe = newLLMHTTPStartupProbe(api.LLM_SGLANG_DEFAULT_PORT, "/v1/models") + } effDevs := models.GetEffectiveDevices(llm, sku) if len(devices) == 0 && effDevs != nil && len(*effDevs) > 0 { diff --git a/pkg/llm/drivers/llm_container/startup_probe.go b/pkg/llm/drivers/llm_container/startup_probe.go new file mode 100644 index 00000000000..98e1bc1d395 --- /dev/null +++ b/pkg/llm/drivers/llm_container/startup_probe.go @@ -0,0 +1,26 @@ +package llm_container + +import commonapi "yunion.io/x/onecloud/pkg/apis" + +const ( + llmStartupProbeTimeoutSeconds int32 = 5 + llmStartupProbePeriodSeconds int32 = 10 + llmStartupProbeSuccessThreshold int32 = 1 + llmStartupProbeFailureThreshold int32 = 360 +) + +func newLLMHTTPStartupProbe(port int, probePath string) *commonapi.ContainerProbe { + return &commonapi.ContainerProbe{ + ContainerProbeHandler: commonapi.ContainerProbeHandler{ + HTTPGet: &commonapi.ContainerProbeHTTPGetAction{ + Path: probePath, + Port: port, + Scheme: commonapi.URISchemeHTTP, + }, + }, 
+ TimeoutSeconds: llmStartupProbeTimeoutSeconds, + PeriodSeconds: llmStartupProbePeriodSeconds, + SuccessThreshold: llmStartupProbeSuccessThreshold, + FailureThreshold: llmStartupProbeFailureThreshold, + } +} diff --git a/pkg/llm/drivers/llm_container/vllm.go b/pkg/llm/drivers/llm_container/vllm.go index 2ecefc010af..53349e1a9b8 100644 --- a/pkg/llm/drivers/llm_container/vllm.go +++ b/pkg/llm/drivers/llm_container/vllm.go @@ -431,7 +431,8 @@ func (v *vllm) GetContainerSpec(ctx context.Context, llm *models.SLLM, image *mo if sku != nil { backendParameters = sku.BackendParameters } - startScript := buildVLLMEntrypointScript(len(postOverlays) > 0, tensorParallelSize, backendParameters, effSpec) + hasMountedModels := len(postOverlays) > 0 + startScript := buildVLLMEntrypointScript(hasMountedModels, tensorParallelSize, backendParameters, effSpec) envs := []*commonapi.ContainerKeyValue{ { Key: "HUGGING_FACE_HUB_CACHE", @@ -463,6 +464,9 @@ func (v *vllm) GetContainerSpec(ctx context.Context, llm *models.SLLM, image *mo Envs: envs, }, } + if hasMountedModels { + spec.StartupProbe = newLLMHTTPStartupProbe(api.LLM_VLLM_DEFAULT_PORT, "/v1/models") + } // GPU Devices effDevs := models.GetEffectiveDevices(llm, sku) diff --git a/pkg/util/probe/http/doc.go b/pkg/util/probe/http/doc.go new file mode 100644 index 00000000000..27982258f52 --- /dev/null +++ b/pkg/util/probe/http/doc.go @@ -0,0 +1,16 @@ +// Copyright 2019 Yunion +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Package http contains HTTP probe helpers. +package http diff --git a/pkg/util/probe/http/http.go b/pkg/util/probe/http/http.go new file mode 100644 index 00000000000..9b0b0475fb2 --- /dev/null +++ b/pkg/util/probe/http/http.go @@ -0,0 +1,121 @@ +// Copyright 2019 Yunion +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package http + +import ( + "context" + "fmt" + "net" + nethttp "net/http" + "net/url" + "strconv" + "strings" + "time" + + "yunion.io/x/onecloud/pkg/util/probe" +) + +func New() Prober { + return httpProber{} +} + +func NewWithDialContext(dialContext DialContextFunc) Prober { + return httpProber{dialContext: dialContext} +} + +type DialContextFunc func(context.Context, string, string) (net.Conn, error) + +type Prober interface { + Probe(scheme string, host string, port int, reqPath string, headers nethttp.Header, timeout time.Duration) (probe.Result, string, error) +} + +type httpProber struct { + dialContext DialContextFunc +} + +func (pr httpProber) Probe(scheme string, host string, port int, reqPath string, headers nethttp.Header, timeout time.Duration) (probe.Result, string, error) { + reqURL, err := buildProbeURL(scheme, host, port, reqPath) + if err != nil { + return probe.Failure, err.Error(), nil + } + return DoHTTPProbeWithDialContext(reqURL, headers, timeout, pr.dialContext) +} + +func buildProbeURL(scheme string, host string, port int, reqPath string) (string, error) { + normalizedScheme := 
strings.ToLower(strings.TrimSpace(scheme)) + if normalizedScheme == "" { + normalizedScheme = "http" + } + switch normalizedScheme { + case "http", "https": + default: + return "", fmt.Errorf("unsupported HTTP probe scheme %q", scheme) + } + + host = strings.TrimSpace(host) + if host == "" { + return "", fmt.Errorf("HTTP probe host is empty") + } + + if reqPath == "" { + reqPath = "/" + } else if !strings.HasPrefix(reqPath, "/") { + reqPath = "/" + reqPath + } + + u := url.URL{ + Scheme: normalizedScheme, + Host: net.JoinHostPort(host, strconv.Itoa(port)), + Path: reqPath, + } + return u.String(), nil +} + +func DoHTTPProbe(reqURL string, headers nethttp.Header, timeout time.Duration) (probe.Result, string, error) { + return DoHTTPProbeWithDialContext(reqURL, headers, timeout, nil) +} + +func DoHTTPProbeWithDialContext(reqURL string, headers nethttp.Header, timeout time.Duration, dialContext DialContextFunc) (probe.Result, string, error) { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + req, err := nethttp.NewRequestWithContext(ctx, nethttp.MethodGet, reqURL, nil) + if err != nil { + return probe.Failure, err.Error(), nil + } + for key, values := range headers { + for _, value := range values { + req.Header.Add(key, value) + } + } + + transport := nethttp.DefaultTransport.(*nethttp.Transport).Clone() + if dialContext != nil { + transport.DialContext = dialContext + } + defer transport.CloseIdleConnections() + + client := &nethttp.Client{Timeout: timeout, Transport: transport} + resp, err := client.Do(req) + if err != nil { + return probe.Failure, err.Error(), nil + } + defer resp.Body.Close() + + if resp.StatusCode >= nethttp.StatusOK && resp.StatusCode < nethttp.StatusBadRequest { + return probe.Success, resp.Status, nil + } + return probe.Failure, resp.Status, nil +} From e53b31e2a0abc5f145d6178c7f565c0acdfd2f1e Mon Sep 17 00:00:00 2001 From: cwz Date: Fri, 12 Jun 2026 22:16:42 +0800 Subject: [PATCH 2/2] llm: watch pod 
status for service readiness --- pkg/apis/llm/llm_const.go | 2 + pkg/llm/models/llm_pod_status_watcher.go | 176 ++++++++++++++++++++++ pkg/llm/models/llm_service_ready.go | 102 +++++++++++++ pkg/llm/models/llm_status_resolver.go | 105 +++++++++++++ pkg/llm/service/background_workers.go | 17 +++ pkg/llm/service/service.go | 1 + pkg/llm/tasks/llm/llm_create_task.go | 4 +- pkg/llm/tasks/llm/llm_restart_task.go | 4 + pkg/llm/tasks/llm/llm_start_task.go | 4 +- pkg/llm/tasks/llm/llm_sync_status_task.go | 8 +- 10 files changed, 419 insertions(+), 4 deletions(-) create mode 100644 pkg/llm/models/llm_pod_status_watcher.go create mode 100644 pkg/llm/models/llm_service_ready.go create mode 100644 pkg/llm/models/llm_status_resolver.go create mode 100644 pkg/llm/service/background_workers.go diff --git a/pkg/apis/llm/llm_const.go b/pkg/apis/llm/llm_const.go index 2e18e751e94..df93dd41a2a 100644 --- a/pkg/apis/llm/llm_const.go +++ b/pkg/apis/llm/llm_const.go @@ -13,6 +13,8 @@ const ( /* 启动失败 */ LLM_STATUS_START_FAIL = "start_fail" + /* 探测中 */ + LLM_STATUS_PROBING = "probing" /* 停机失败 */ LLM_STATUS_STOP_FAILED = "stop_fail" diff --git a/pkg/llm/models/llm_pod_status_watcher.go b/pkg/llm/models/llm_pod_status_watcher.go new file mode 100644 index 00000000000..d312ca91e06 --- /dev/null +++ b/pkg/llm/models/llm_pod_status_watcher.go @@ -0,0 +1,176 @@ +package models + +import ( + "context" + "database/sql" + "fmt" + "sort" + "strings" + + "yunion.io/x/jsonutils" + "yunion.io/x/log" + "yunion.io/x/pkg/errors" + + computeapi "yunion.io/x/onecloud/pkg/apis/compute" + "yunion.io/x/onecloud/pkg/mcclient" + "yunion.io/x/onecloud/pkg/mcclient/auth" + "yunion.io/x/onecloud/pkg/mcclient/informer" + "yunion.io/x/onecloud/pkg/mcclient/modules/compute" +) + +func StartLLMPodStatusWatcher(ctx context.Context, region string) { + session := auth.GetAdminSession(ctx, region) + informer.NewWatchManagerBySessionBg(session, func(watchMan *informer.SWatchManager) error { + handler := 
&llmPodStatusEventHandler{ + userCred: session.GetToken(), + } + if err := watchMan.For(compute.Servers).AddEventHandler(ctx, handler); err != nil { + return errors.Wrap(err, "watch compute servers for llm pod status") + } + return nil + }) +} + +type llmPodStatusEventHandler struct { + userCred mcclient.TokenCredential +} + +func (h *llmPodStatusEventHandler) OnAdd(obj *jsonutils.JSONDict) { + if !serverWatchStatusChanged(nil, obj) { + return + } + h.handleServerStatus(context.Background(), obj) +} + +func (h *llmPodStatusEventHandler) OnUpdate(oldObj, newObj *jsonutils.JSONDict) { + if !serverWatchStatusChanged(oldObj, newObj) { + return + } + h.handleServerStatus(context.Background(), newObj) +} + +func (h *llmPodStatusEventHandler) OnDelete(obj *jsonutils.JSONDict) { +} + +func (h *llmPodStatusEventHandler) handleServerStatus(ctx context.Context, obj *jsonutils.JSONDict) { + serverId := watchEventStringField(obj, "id") + if serverId == "" { + log.Warningf("LLM pod status watcher: server event missing id: %s", obj.String()) + return + } + + llm, err := fetchLLMByCmpId(serverId) + if err != nil { + if errors.Cause(err) != sql.ErrNoRows { + log.Warningf("LLM pod status watcher: fetch llm by cmp_id %s: %s", serverId, err) + } + return + } + + server, err := llm.GetServer(ctx) + if err != nil { + log.Warningf("LLM pod status watcher: fetch server %s for llm %s: %s", serverId, llm.Name, err) + return + } + if server.Hypervisor != computeapi.HYPERVISOR_POD { + return + } + + resolved, err := ResolveLLMStatusFromServerDetails(ctx, llm, server) + if err != nil { + log.Warningf("LLM pod status watcher: resolve status for llm %s: %s", llm.Name, err) + return + } + if !resolved.Update { + return + } + if err := llm.SetStatus(ctx, h.userCred, resolved.Status, resolved.Reason); err != nil { + log.Warningf("LLM pod status watcher: set llm %s status %s: %s", llm.Name, resolved.Status, err) + } +} + +func fetchLLMByCmpId(cmpId string) (*SLLM, error) { + llm := &SLLM{} + if err 
:= GetLLMManager().Query().Equals("cmp_id", cmpId).First(llm); err != nil {
+		return nil, err
+	}
+	llm.SetModelManager(GetLLMManager(), llm)
+	return llm, nil
+}
+
+// getLLMPrimaryContainerStatus returns the status of the container that the
+// LLM's container driver selects as primary. An empty container list or a nil
+// primary container yields an empty status and no error.
+func getLLMPrimaryContainerStatus(ctx context.Context, llm *SLLM, containers []*computeapi.PodContainerDesc) (string, error) {
+	if len(containers) == 0 {
+		return "", nil
+	}
+	primary, err := llm.GetLLMContainerDriver().GetPrimaryContainer(ctx, llm, containers)
+	switch {
+	case err != nil:
+		return "", err
+	case primary == nil:
+		return "", nil
+	}
+	return primary.Status, nil
+}
+
+// serverWatchStatusChanged reports whether a watch event is worth handling.
+// An event without a status is never handled; one without a previous object
+// always is. Otherwise the server status or the container status signature
+// must differ between oldObj and newObj.
+func serverWatchStatusChanged(oldObj *jsonutils.JSONDict, newObj *jsonutils.JSONDict) bool {
+	if newObj == nil {
+		return false
+	}
+	current := watchEventStringField(newObj, "status")
+	switch {
+	case current == "":
+		return false
+	case oldObj == nil:
+		return true
+	case watchEventStringField(oldObj, "status") != current:
+		return true
+	}
+	return watchEventContainerStatusSignature(oldObj) != watchEventContainerStatusSignature(newObj)
+}
+
+// watchEventStringField reads key from obj as a string, returning "" when obj
+// is nil or the lookup fails.
+func watchEventStringField(obj *jsonutils.JSONDict, key string) string {
+	if obj != nil {
+		if text, err := obj.GetString(key); err == nil {
+			return text
+		}
+	}
+	return ""
+}
+
+// watchEventContainerStatusSignature builds a sorted "key=status;..." string
+// from the "containers" array of obj. Each container is keyed by its id, then
+// its name, then its array index. A nil obj or a missing or empty array
+// yields "".
+func watchEventContainerStatusSignature(obj *jsonutils.JSONDict) string {
+	if obj == nil {
+		return ""
+	}
+	items, err := obj.GetArray("containers")
+	if err != nil || len(items) == 0 {
+		return ""
+	}
+	entries := make([]string, len(items))
+	for i, item := range items {
+		label := watchEventJSONObjectStringField(item, "id")
+		if label == "" {
+			label = watchEventJSONObjectStringField(item, "name")
+		}
+		if label == "" {
+			label = fmt.Sprintf("%d", i)
+		}
+		entries[i] = fmt.Sprintf("%s=%s", label, watchEventJSONObjectStringField(item, "status"))
+	}
+	// Sorting makes the signature independent of array order.
+	sort.Strings(entries)
+	return strings.Join(entries, ";")
+}
+
+func watchEventJSONObjectStringField(obj jsonutils.JSONObject, key string) string {
+	if obj == nil {
+		return ""
+	}
+	value, err := 
obj.GetString(key)
+	if err != nil {
+		return ""
+	}
+	return value
+}
diff --git a/pkg/llm/models/llm_service_ready.go b/pkg/llm/models/llm_service_ready.go
new file mode 100644
index 00000000000..816f7c2d208
--- /dev/null
+++ b/pkg/llm/models/llm_service_ready.go
@@ -0,0 +1,102 @@
+package models
+
+import (
+	"context"
+	"time"
+
+	"yunion.io/x/pkg/errors"
+
+	computeapi "yunion.io/x/onecloud/pkg/apis/compute"
+	"yunion.io/x/onecloud/pkg/httperrors"
+	llmutils "yunion.io/x/onecloud/pkg/llm/utils"
+	"yunion.io/x/onecloud/pkg/mcclient"
+)
+
+// LLMServiceReadyTimeoutSeconds is the timeout, in seconds, applied when a
+// caller of WaitServiceReady passes a non-positive timeout.
+const LLMServiceReadyTimeoutSeconds = 3600
+
+// errLLMServiceProbing is the sentinel recognised by IsLLMServiceProbingError.
+var errLLMServiceProbing = errors.Error("llm service probing")
+
+// llmServiceReadyServerStatuses lists the server statuses that end the server
+// wait inside WaitServiceReady; only VM_RUNNING lets the wait continue.
+var llmServiceReadyServerStatuses = []string{
+	computeapi.VM_RUNNING,
+	computeapi.POD_STATUS_CRASH_LOOP_BACK_OFF,
+	computeapi.POD_STATUS_CONTAINER_EXITED,
+	computeapi.POD_STATUS_UPLOADING_STATUS_FAILED,
+}
+
+// llmServiceReadyContainerStatuses lists container statuses relevant to the
+// readiness wait. NOTE(review): no use is visible in this file - confirm.
+var llmServiceReadyContainerStatuses = []string{
+	computeapi.CONTAINER_STATUS_RUNNING,
+	computeapi.CONTAINER_STATUS_PROBING,
+	computeapi.CONTAINER_STATUS_PROBE_FAILED,
+	computeapi.CONTAINER_STATUS_NET_FAILED,
+	computeapi.CONTAINER_STATUS_CRASH_LOOP_BACK_OFF,
+	computeapi.CONTAINER_STATUS_EXITED,
+}
+
+// isLLMServiceReadyContainerStatus reports whether status is the running
+// container status.
+func isLLMServiceReadyContainerStatus(status string) bool {
+	ready := status == computeapi.CONTAINER_STATUS_RUNNING
+	return ready
+}
+
+// newLLMServiceProbingError wraps the probing sentinel with the container
+// status that was observed.
+func newLLMServiceProbingError(status string) error {
+	return errors.Wrapf(errLLMServiceProbing, "container status %s", status)
+}
+
+// IsLLMServiceProbingError reports whether the cause of err is the probing
+// sentinel produced by newLLMServiceProbingError.
+func IsLLMServiceProbingError(err error) bool {
+	cause := errors.Cause(err)
+	return cause == errLLMServiceProbing
+}
+
+// isLLMServiceFailedContainerStatus reports whether status is one of the
+// probe-failed, net-failed, crash-loop or exited container statuses.
+func isLLMServiceFailedContainerStatus(status string) bool {
+	switch status {
+	case computeapi.CONTAINER_STATUS_PROBE_FAILED,
+		computeapi.CONTAINER_STATUS_NET_FAILED,
+		computeapi.CONTAINER_STATUS_CRASH_LOOP_BACK_OFF,
+		computeapi.CONTAINER_STATUS_EXITED:
+		return true
+	}
+	return false
+}
+
+// WaitServiceReady waits for the LLM container to report running, without a
+// probing callback.
+func (llm *SLLM) WaitServiceReady(ctx context.Context, userCred mcclient.TokenCredential, timeoutSecs int) (*computeapi.SContainer, error) {
+	return 
llm.WaitServiceReadyWithProbingCallback(ctx, userCred, timeoutSecs, nil)
+}
+
+// WaitServiceReadyWithProbingCallback blocks until the container returned by
+// GetLLMContainer reports the running status, then returns that container.
+//
+// A non-positive timeoutSecs is replaced by LLMServiceReadyTimeoutSeconds.
+// The server is awaited first, using llmServiceReadyServerStatuses; any
+// resulting status other than VM_RUNNING is returned as ErrInvalidStatus.
+// The container is then polled once per second, against a deadline that
+// starts after the server wait returns.
+//
+// onProbing, when non-nil, is called at most once, the first time the
+// container is seen in the probing status; its error aborts the wait.
+//
+// Errors: ErrInvalidStatus for a failed container status, ErrTimeout when the
+// deadline passes, and the wrapped ctx error when ctx is done between polls.
+func (llm *SLLM) WaitServiceReadyWithProbingCallback(ctx context.Context, userCred mcclient.TokenCredential, timeoutSecs int, onProbing func() error) (*computeapi.SContainer, error) {
+	if timeoutSecs <= 0 {
+		timeoutSecs = LLMServiceReadyTimeoutSeconds
+	}
+
+	server, err := llm.WaitServerStatus(ctx, userCred, llmServiceReadyServerStatuses, timeoutSecs)
+	if err != nil {
+		return nil, errors.Wrap(err, "WaitServerStatus")
+	}
+	if server.Status != computeapi.VM_RUNNING {
+		return nil, errors.Wrapf(errors.ErrInvalidStatus, "server status %s", server.Status)
+	}
+
+	llmCtr, err := llm.GetLLMContainer()
+	if err != nil {
+		return nil, errors.Wrap(err, "GetLLMContainer")
+	}
+
+	expire := time.Now().Add(time.Second * time.Duration(timeoutSecs))
+	probingNotified := false
+	for time.Now().Before(expire) {
+		ctr, err := llmutils.GetContainer(ctx, llmCtr.CmpId)
+		if err != nil {
+			return nil, errors.Wrap(err, "GetContainer")
+		}
+		if isLLMServiceReadyContainerStatus(ctr.Status) {
+			return ctr, nil
+		}
+		if isLLMServiceFailedContainerStatus(ctr.Status) {
+			return nil, errors.Wrapf(errors.ErrInvalidStatus, "container status %s", ctr.Status)
+		}
+		if ctr.Status == computeapi.CONTAINER_STATUS_PROBING && onProbing != nil && !probingNotified {
+			if err := onProbing(); err != nil {
+				return nil, errors.Wrap(err, "on probing")
+			}
+			probingNotified = true
+		}
+		// Wait one poll interval, but stop as soon as the caller gives up:
+		// a bare time.Sleep would keep polling long after ctx is cancelled.
+		select {
+		case <-ctx.Done():
+			return nil, errors.Wrap(ctx.Err(), "wait llm service ready")
+		case <-time.After(time.Second):
+		}
+	}
+	return nil, errors.Wrapf(httperrors.ErrTimeout, "wait llm service ready timeout")
+}
diff --git a/pkg/llm/models/llm_status_resolver.go b/pkg/llm/models/llm_status_resolver.go
new file mode 100644
index 00000000000..60ebea50b71
--- /dev/null
+++ b/pkg/llm/models/llm_status_resolver.go
@@ -0,0 +1,105 @@
+package models
+
+import (
+	"context"
+	"fmt"
+
+	commonapi "yunion.io/x/onecloud/pkg/apis"
+	computeapi "yunion.io/x/onecloud/pkg/apis/compute"
+	api 
"yunion.io/x/onecloud/pkg/apis/llm"
+)
+
+// LLMStatusResolution is the outcome of mapping pod state onto an LLM status.
+// Status and Reason are meaningful only when Update is true.
+type LLMStatusResolution struct {
+	Status string
+	Reason string
+	Update bool
+}
+
+// llmStatusResolution is the package-internal alias of LLMStatusResolution.
+type llmStatusResolution = LLMStatusResolution
+
+// ResolveLLMStatusFromServerDetails derives the status an LLM should take from
+// its server details. A nil llm or server gives a zero resolution and no
+// error; a failure to pick the primary container is returned as the error.
+func ResolveLLMStatusFromServerDetails(ctx context.Context, llm *SLLM, server *computeapi.ServerDetails) (LLMStatusResolution, error) {
+	var noop LLMStatusResolution
+	if llm == nil || server == nil {
+		return noop, nil
+	}
+	containerStatus, err := getLLMPrimaryContainerStatus(ctx, llm, server.Containers)
+	if err != nil {
+		return noop, err
+	}
+	resolution := resolveLLMStatusFromPod(llm.Status, server.Status, containerStatus)
+	return resolution, nil
+}
+
+// resolveLLMStatusFromPod maps the pod status and primary container status to
+// a target LLM status. It returns a zero resolution when no rule matches, when
+// the target equals currentStatus, or when canWatchUpdateLLMStatus rejects the
+// transition.
+func resolveLLMStatusFromPod(currentStatus string, serverStatus string, primaryContainerStatus string) llmStatusResolution {
+	var noChange llmStatusResolution
+	podRunning := serverStatus == computeapi.VM_RUNNING
+
+	target := ""
+	switch {
+	case isLLMPodCrashLoopStatus(serverStatus, primaryContainerStatus) || isLLMStartupProbeFailedStatus(primaryContainerStatus):
+		// A crash loop or failed probe during creation is a create failure.
+		target = api.LLM_STATUS_START_FAIL
+		if currentStatus == commonapi.STATUS_CREATING {
+			target = api.LLM_STATUS_CREATE_FAIL
+		}
+	case serverStatus == computeapi.POD_STATUS_CONTAINER_EXITED || primaryContainerStatus == computeapi.CONTAINER_STATUS_EXITED:
+		target = api.LLM_STATUS_START_FAIL
+	case podRunning && primaryContainerStatus == computeapi.CONTAINER_STATUS_PROBING:
+		target = api.LLM_STATUS_PROBING
+	case podRunning && isPrimaryContainerRunning(primaryContainerStatus):
+		target = api.LLM_STATUS_RUNNING
+	case serverStatus == computeapi.VM_READY:
+		target = api.LLM_STATUS_READY
+	default:
+		return noChange
+	}
+
+	if target == currentStatus || !canWatchUpdateLLMStatus(currentStatus, target) {
+		return noChange
+	}
+	why := fmt.Sprintf("pod status=%s primary_container_status=%s", serverStatus, primaryContainerStatus)
+	return llmStatusResolution{
+		Status: target,
+		Reason: why,
+		Update: 
true,
+	}
+}
+
+// isLLMPodCrashLoopStatus reports whether either the pod or its primary
+// container is in the crash-loop-back-off status.
+func isLLMPodCrashLoopStatus(serverStatus string, primaryContainerStatus string) bool {
+	if serverStatus == computeapi.POD_STATUS_CRASH_LOOP_BACK_OFF {
+		return true
+	}
+	return primaryContainerStatus == computeapi.CONTAINER_STATUS_CRASH_LOOP_BACK_OFF
+}
+
+// isPrimaryContainerRunning reports whether status is the running container
+// status.
+func isPrimaryContainerRunning(status string) bool {
+	running := status == computeapi.CONTAINER_STATUS_RUNNING
+	return running
+}
+
+// isLLMStartupProbeFailedStatus reports whether status is the probe-failed or
+// net-failed container status.
+func isLLMStartupProbeFailedStatus(status string) bool {
+	if status == computeapi.CONTAINER_STATUS_PROBE_FAILED {
+		return true
+	}
+	return status == computeapi.CONTAINER_STATUS_NET_FAILED
+}
+
+// canWatchUpdateLLMStatus reports whether a status derived from pod state may
+// replace currentStatus. From creating, only create-fail and probing are
+// accepted; from start-restart, only probing. The statuses listed in the
+// switch accept any target; every other current status is left alone.
+func canWatchUpdateLLMStatus(currentStatus string, targetStatus string) bool {
+	toProbing := targetStatus == api.LLM_STATUS_PROBING
+	if currentStatus == commonapi.STATUS_CREATING && (targetStatus == api.LLM_STATUS_CREATE_FAIL || toProbing) {
+		return true
+	}
+	if currentStatus == api.LLM_STATUS_START_RESTART && toProbing {
+		return true
+	}
+
+	switch currentStatus {
+	case api.LLM_STATUS_READY,
+		api.LLM_STATUS_RUNNING,
+		api.LLM_STATUS_PROBING,
+		api.LLM_STATUS_UNKNOWN,
+		api.LLM_STATUS_START_SYNCSTATUS,
+		api.LLM_STATUS_SYNCSTATUS,
+		api.LLM_STATUS_CREATE_FAIL,
+		api.LLM_STATUS_START_FAIL:
+		return true
+	default:
+		return false
+	}
+}
diff --git a/pkg/llm/service/background_workers.go b/pkg/llm/service/background_workers.go
new file mode 100644
index 00000000000..1778b630d16
--- /dev/null
+++ b/pkg/llm/service/background_workers.go
@@ -0,0 +1,17 @@
+package service
+
+import (
+	"context"
+
+	"yunion.io/x/onecloud/pkg/llm/models"
+	"yunion.io/x/onecloud/pkg/llm/options"
+)
+
+// startLLMPodStatusWatcher is the function used to launch the pod status
+// watcher. NOTE(review): presumably a variable so it can be replaced in
+// tests - confirm.
+var startLLMPodStatusWatcher = models.StartLLMPodStatusWatcher
+
+// startBackgroundWorkers launches the pod status watcher for opts.Region.
+// It does nothing when opts is nil or the node is a slave node.
+func startBackgroundWorkers(ctx context.Context, opts *options.LLMOptions) {
+	if opts != nil && !opts.IsSlaveNode {
+		startLLMPodStatusWatcher(ctx, opts.Region)
+	}
+}
diff --git a/pkg/llm/service/service.go b/pkg/llm/service/service.go
index b4975bc3d79..f5a3946961e 100644
--- a/pkg/llm/service/service.go
+++ 
b/pkg/llm/service/service.go @@ -51,6 +51,7 @@ func StartService() { opts.ModelCatalogURL, time.Duration(opts.LLMCatalogRefreshIntervalMinutes)*time.Minute, ) + startBackgroundWorkers(context.Background(), opts) /* starts the LLM pod status watcher; startBackgroundWorkers returns early for nil opts or a slave node */ // if !opts.IsSlaveNode { // models.InitializeCronjobs(app.GetContext()) diff --git a/pkg/llm/tasks/llm/llm_create_task.go b/pkg/llm/tasks/llm/llm_create_task.go index ebbe10414e1..79e2e4685e3 100644 --- a/pkg/llm/tasks/llm/llm_create_task.go +++ b/pkg/llm/tasks/llm/llm_create_task.go @@ -162,9 +162,9 @@ func (task *LLMCreateTask) OnLLMRefreshStatusComplete(ctx context.Context, llm * task.taskFailed(ctx, llm, errors.Wrap(err, "WaitServerStatus VM_RUNNING")) return } - _, err = llm.WaitContainerStatus(ctx, task.GetUserCred(), []string{computeapi.CONTAINER_STATUS_RUNNING}, 120) + _, err = llm.WaitServiceReady(ctx, task.GetUserCred(), 0) /* a timeout of 0 makes WaitServiceReady fall back to LLMServiceReadyTimeoutSeconds */ if err != nil { - task.taskFailed(ctx, llm, errors.Wrap(err, "WaitContainerStatus")) + task.taskFailed(ctx, llm, errors.Wrap(err, "WaitServiceReady")) return } err = llm.GetLLMContainerDriver().StartLLM(ctx, task.GetUserCred(), llm) diff --git a/pkg/llm/tasks/llm/llm_restart_task.go b/pkg/llm/tasks/llm/llm_restart_task.go index 66722fbe4ca..cf86dae6b22 100644 --- a/pkg/llm/tasks/llm/llm_restart_task.go +++ b/pkg/llm/tasks/llm/llm_restart_task.go @@ -451,6 +451,10 @@ func (task *LLMRestartTask) OnResetDiskComplete(ctx context.Context, obj db.ISta func (task *LLMRestartTask) OnStartComplete(ctx context.Context, obj db.IStandaloneModel, body jsonutils.JSONObject) { llm := obj.(*models.SLLM) + if _, err := llm.WaitServiceReady(ctx, task.GetUserCred(), 0); err != nil { /* the restart is marked failed, not complete, when the service does not become ready */ + task.taskFailed(ctx, llm, errors.Wrap(err, "WaitServiceReady").Error()) + return + } task.taskComplete(ctx, llm) } diff --git a/pkg/llm/tasks/llm/llm_start_task.go b/pkg/llm/tasks/llm/llm_start_task.go index 2c9b02baa70..96625463554 100644 --- a/pkg/llm/tasks/llm/llm_start_task.go +++ b/pkg/llm/tasks/llm/llm_start_task.go @@ -32,7 +32,9 @@ func (task *LLMStartTask) 
taskFailed(ctx context.Context, llm *models.SLLM, err } func (task *LLMStartTask) taskComplete(ctx context.Context, llm *models.SLLM) { - llm.SetStatus(ctx, task.GetUserCred(), api.LLM_STATUS_RUNNING, "start complete") + if !task.HasParentTask() { /* the running status is written here only when this task has no parent task */ + llm.SetStatus(ctx, task.GetUserCred(), api.LLM_STATUS_RUNNING, "start complete") + } // llm.NotifyRequest(ctx, task.GetUserCred(), notify.ActionStart, nil, true) task.SetStageComplete(ctx, nil) } diff --git a/pkg/llm/tasks/llm/llm_sync_status_task.go b/pkg/llm/tasks/llm/llm_sync_status_task.go index 33d63514eaa..5f11f43843f 100644 --- a/pkg/llm/tasks/llm/llm_sync_status_task.go +++ b/pkg/llm/tasks/llm/llm_sync_status_task.go @@ -109,7 +109,13 @@ func (task *LLMSyncStatusTask) OnInit(ctx context.Context, obj db.IStandaloneMod } task.setLLMStatus(ctx, llm, srv.Status, "stop server") } else { - task.setLLMStatus(ctx, llm, srv.Status, "WaitServerStatus") + resolved, err := models.ResolveLLMStatusFromServerDetails(ctx, llm, srv) /* derive the LLM status from the server details instead of copying srv.Status */ + if err != nil { + return nil, errors.Wrap(err, "ResolveLLMStatusFromServerDetails") + } + if resolved.Update { /* leave the LLM status untouched when the resolver reports no update */ + task.setLLMStatus(ctx, llm, resolved.Status, resolved.Reason) + } } volume, _ := llm.GetVolume()