diff --git a/controllers/hosts.go b/controllers/hosts.go index 6cace29cc..0f194bddb 100644 --- a/controllers/hosts.go +++ b/controllers/hosts.go @@ -70,6 +70,10 @@ func hostHandlers(r *mux.Router) { Methods(http.MethodPut) r.HandleFunc("/api/v1/host/{hostid}/peer_info", AuthorizeHost(http.HandlerFunc(getHostPeerInfo))). Methods(http.MethodGet) + r.HandleFunc("/api/v1/host/{hostid}/posture_status", AuthorizeHost(http.HandlerFunc(getHostPostureStatus))). + Methods(http.MethodGet) + r.HandleFunc("/api/v1/host/{hostid}/posture_status/ui", logic.SecurityCheck(true, http.HandlerFunc(getHostPostureStatus))). + Methods(http.MethodGet) r.HandleFunc("/api/v1/pending_hosts", logic.SecurityCheck(true, http.HandlerFunc(getPendingHosts))). Methods(http.MethodGet) r.HandleFunc("/api/v1/pending_hosts/approve/{id}", logic.SecurityCheck(true, http.HandlerFunc(approvePendingHost))). @@ -1643,6 +1647,92 @@ func getHostPeerInfo(w http.ResponseWriter, r *http.Request) { logic.ReturnSuccessResponseWithJson(w, r, peerInfo, "fetched host peer info") } +// @Summary Get the host's last-evaluated posture status +// @Router /api/v1/host/{hostid}/posture_status [get] +// @Tags Hosts +// @Security oauth +// @Produce json +// @Param hostid path string true "Host ID" +// @Success 200 {object} models.HostPostureStatus +// @Failure 400 {object} models.ErrorResponse +func getHostPostureStatus(w http.ResponseWriter, r *http.Request) { + hostIDStr := mux.Vars(r)["hostid"] + hostID, err := uuid.Parse(hostIDStr) + if err != nil { + logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("failed to parse host id: %w", err), logic.BadReq)) + return + } + + host := &schema.Host{ID: hostID} + if err := host.Get(r.Context()); err != nil { + logic.ReturnErrorResponse(w, r, models.ErrorResponse{Code: http.StatusBadRequest, Message: err.Error()}) + return + } + + if strings.Contains(r.URL.Path, "/posture_status/ui") { + if err := logic.CheckUIHostReadAccess(r, host); err != nil { + logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.Forbidden)) + return + } + } + + resp := models.HostPostureStatus{ + HostID: hostIDStr, + Networks: []models.NetworkPostureStatus{}, + } + + // MDM block - best-effort: only populated if an integration is configured and + // a sync row exists for the host. + mdmIntg := &schema.Integration{Type: "mdm"} + if mdmIntegrations, err := mdmIntg.ListByType(r.Context()); err == nil && len(mdmIntegrations) > 0 { + state := &schema.DeviceMDMState{HostID: hostIDStr, Provider: mdmIntegrations[0].ID} + if err := state.Get(r.Context()); err == nil { + resp.MDM = &models.HostMDMStatus{ + Provider: state.Provider, + MatchedBy: state.MatchedBy, + Enrolled: state.Enrolled, + Compliant: state.Compliant, + LastSyncedAt: state.LastSyncedAt, + } + } + } + + // Per-network status - copy from already-evaluated nodes belonging to the + // host. No new posture computation happens on this read path (v1). + nodes, err := logic.GetAllNodes() + if err != nil { + logic.ReturnErrorResponse(w, r, models.ErrorResponse{Code: http.StatusInternalServerError, Message: err.Error()}) + return + } + var latest time.Time + for _, n := range nodes { + if n.HostID != hostID || n.IsStatic { + continue + } + entry := models.NetworkPostureStatus{ + NetworkID: n.Network, + NodeID: n.ID.String(), + Severity: n.PostureCheckViolationSeverityLevel, + Violations: append([]models.Violation{}, n.PostureChecksViolations...), + } + switch { + case len(entry.Violations) == 0: + entry.Status = models.PostureStatusPass + case entry.Severity >= schema.SeverityHigh: + entry.Status = models.PostureStatusFail + default: + entry.Status = models.PostureStatusWarn + } + if n.LastEvaluatedAt.After(latest) { + latest = n.LastEvaluatedAt + } + resp.Networks = append(resp.Networks, entry) + } + resp.EvaluatedAt = latest + + logic.ReturnSuccessResponseWithJson(w, r, resp, "fetched posture status") +} + // @Summary List pending hosts in a network // @Router /api/v1/pending_hosts [get] // @Tags Hosts diff --git a/controllers/server.go b/controllers/server.go index bbb432231..87e74feaf 100644 --- a/controllers/server.go +++ b/controllers/server.go @@ -232,7 +232,6 @@ func getSettings(w http.ResponseWriter, r *http.Request) { if scfg.OktaAPIToken != "" { scfg.OktaAPIToken = logic.Mask() } - logic.ReturnSuccessResponseWithJson(w, r, scfg, "fetched server settings successfully") } diff --git a/logic/hosts.go b/logic/hosts.go index e18d6f7a9..924a3c2a5 100644 --- a/logic/hosts.go +++ b/logic/hosts.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "net" + "net/http" "sort" "strings" "sync" @@ -43,6 +44,17 @@ var GetPostureCheckDeviceInfoByNode = func(node *models.Node) (d models.PostureC return } +// SyncHostMDMState refreshes MDM posture state for a host (no-op in community). +var SyncHostMDMState = func(ctx context.Context, hostID string) error { + return nil +} + +// CheckUIHostReadAccess verifies a dashboard user may read the given host. +// Overridden in Pro to enforce network-scoped host access. +var CheckUIHostReadAccess = func(r *http.Request, host *schema.Host) error { + return nil +} + const ( maxPort = 1<<16 - 1 minPort = 1025 @@ -142,6 +154,9 @@ func UpdateHost(newHost, currentHost *schema.Host) { newHost.Nodes = currentHost.Nodes newHost.PublicKey = currentHost.PublicKey newHost.TrafficKeyPublic = currentHost.TrafficKeyPublic + newHost.EntraDeviceID = currentHost.EntraDeviceID + newHost.SerialNumber = currentHost.SerialNumber + newHost.HardwareUUID = currentHost.HardwareUUID // changeable fields if len(newHost.Version) == 0 { newHost.Version = currentHost.Version @@ -245,6 +260,18 @@ func UpdateHostFromClient(newHost, currHost *schema.Host) (isEndpointChanged, se if newHost.Interface != "" { currHost.Interface = newHost.Interface } + // MDM device-matching identifiers: only overwrite if the netclient reported + // a non-empty value, so we don't clobber a previously reported identifier + // when a later check-in omits the field. + if newHost.EntraDeviceID != "" { + currHost.EntraDeviceID = newHost.EntraDeviceID + } + if newHost.SerialNumber != "" { + currHost.SerialNumber = newHost.SerialNumber + } + if newHost.HardwareUUID != "" { + currHost.HardwareUUID = newHost.HardwareUUID + } if isEndpointChanged || currHost.Location == "" || currHost.CountryCode == "" { var nodeIP net.IP if currHost.EndpointIP != nil { @@ -358,6 +385,10 @@ func RemoveHost(h *schema.Host, forceDelete bool) error { slog.Error("failed to delete node", "node", node.ID, "host", h.ID, "error", err) } } + mdmState := &schema.DeviceMDMState{HostID: h.ID.String()} + if err := mdmState.DeleteByHostID(db.WithContext(context.TODO())); err != nil { + slog.Error("failed to delete mdm state for host", "host", h.ID, "error", err) + } return h.Delete(db.WithContext(context.TODO())) } diff --git a/models/api_host.go b/models/api_host.go index 15bd1a0c1..ba91b6d32 100644 --- a/models/api_host.go +++ b/models/api_host.go @@ -144,5 +144,8 @@ func (a *ApiHost) ConvertAPIHostToNMHost(currentHost *schema.Host) *schema.Host h.EnableFlowLogs = a.EnableFlowLogs h.Location = currentHost.Location h.CountryCode = currentHost.CountryCode + h.EntraDeviceID = currentHost.EntraDeviceID + h.SerialNumber = currentHost.SerialNumber + h.HardwareUUID = currentHost.HardwareUUID return &h } diff --git a/models/host.go b/models/host.go index 08b44327a..f9954aa82 100644 --- a/models/host.go +++ b/models/host.go @@ -80,6 +80,13 @@ type Host struct { Location string `json:"location"` // Format: "lat,lon" CountryCode string `json:"country_code"` EnableFlowLogs bool `json:"enable_flow_logs" yaml:"enable_flow_logs"` + + // MDM device-matching identifiers. Reported by netclient on host check-in + // and consumed by the MDM sync worker to match a Netmaker host to its + // upstream MDM-managed device record. + EntraDeviceID string `json:"entra_device_id" yaml:"entra_device_id"` + SerialNumber string `json:"serial_number" yaml:"serial_number"` + HardwareUUID string `json:"hardware_uuid" yaml:"hardware_uuid"` } // FormatBool converts a boolean to a [yes|no] string diff --git a/models/posture_status.go b/models/posture_status.go new file mode 100644 index 000000000..c70f5121d --- /dev/null +++ b/models/posture_status.go @@ -0,0 +1,42 @@ +package models + +import ( + "time" + + "github.com/gravitl/netmaker/schema" +) + +// HostPostureStatus is the netclient-facing summary of a host's last evaluated +// posture state. Returned by GET /api/v1/host/{hostid}/posture_status. +type HostPostureStatus struct { + HostID string `json:"host_id"` + EvaluatedAt time.Time `json:"evaluated_at"` + MDM *HostMDMStatus `json:"mdm,omitempty"` + Networks []NetworkPostureStatus `json:"networks"` +} + +// HostMDMStatus is the current MDM sync snapshot for the host's configured +// MDM provider (if any). +type HostMDMStatus struct { + Provider string `json:"provider"` + MatchedBy string `json:"matched_by"` + Enrolled bool `json:"enrolled"` + Compliant bool `json:"compliant"` + LastSyncedAt time.Time `json:"last_synced_at"` +} + +// NetworkPostureStatus describes posture state for a single (host, network). +type NetworkPostureStatus struct { + NetworkID string `json:"network_id"` + NodeID string `json:"node_id"` + Severity schema.Severity `json:"severity"` + Status string `json:"status"` // pass | warn | fail + Violations []Violation `json:"violations"` +} + +// Network posture status values. +const ( + PostureStatusPass = "pass" + PostureStatusWarn = "warn" + PostureStatusFail = "fail" +) diff --git a/models/structs.go b/models/structs.go index 541b6e58c..51f4a5408 100644 --- a/models/structs.go +++ b/models/structs.go @@ -476,6 +476,11 @@ type PostureCheckDeviceInfo struct { Tags map[TagID]struct{} IsUser bool UserGroups map[schema.UserGroupID]struct{} + // HostID is the Netmaker host's UUID; used to look up MDM state. + HostID string + // MDMState is the most recent sync snapshot for the configured MDM + // provider; nil if MDM is not configured or the host hasn't synced yet. + MDMState *schema.DeviceMDMState } type Violation struct { diff --git a/mq/handlers.go b/mq/handlers.go index 0d18045d6..a17ce4f1a 100644 --- a/mq/handlers.go +++ b/mq/handlers.go @@ -414,6 +414,29 @@ func HandleHostCheckin(h, currentHost *schema.Host) bool { slog.Info("updated host after check-in", "name", currentHost.Name, "id", currentHost.ID) } + // Persist MDM device-matching identifiers if the netclient reported any + // new values. These don't affect peers, so don't roll into ifaceDelta. + mdmChanged := false + if h.EntraDeviceID != "" && h.EntraDeviceID != currentHost.EntraDeviceID { + currentHost.EntraDeviceID = h.EntraDeviceID + mdmChanged = true + } + if h.SerialNumber != "" && h.SerialNumber != currentHost.SerialNumber { + currentHost.SerialNumber = h.SerialNumber + mdmChanged = true + } + if h.HardwareUUID != "" && h.HardwareUUID != currentHost.HardwareUUID { + currentHost.HardwareUUID = h.HardwareUUID + mdmChanged = true + } + if mdmChanged { + if err := logic.UpsertHost(currentHost); err != nil { + slog.Error("failed to update mdm identifiers after check-in", "name", h.Name, "id", h.ID, "error", err) + } else if currentHost.EntraDeviceID != "" { + go logic.SyncHostMDMState(context.Background(), currentHost.ID.String()) + } + } + slog.Info("check-in processed for host", "name", h.Name, "id", h.ID) return ifaceDelta } diff --git a/pro/controllers/integrations.go b/pro/controllers/integrations.go index 8ba3667a7..6d43dbea6 100644 --- a/pro/controllers/integrations.go +++ b/pro/controllers/integrations.go @@ -8,11 +8,15 @@ import ( "net/http" "github.com/gorilla/mux" - "github.com/gravitl/netmaker/grpc/siem" + "github.com/gravitl/netmaker/db" + grpcs "github.com/gravitl/netmaker/grpc/siem" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" + "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/mq" "github.com/gravitl/netmaker/pro/integration" + mdmpkg "github.com/gravitl/netmaker/pro/integration/mdm" + siempkg "github.com/gravitl/netmaker/pro/integration/siem" logic2 "github.com/gravitl/netmaker/pro/logic" "github.com/gravitl/netmaker/schema" "google.golang.org/protobuf/types/known/structpb" @@ -21,14 +25,19 @@ import ( ) func IntegrationHandlers(r *mux.Router) { + r.HandleFunc("/api/v1/integrations/mdm/providers", + logic.SecurityCheck(true, http.HandlerFunc(listMDMProviders))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/integrations/mdm/sync", + logic.SecurityCheck(true, http.HandlerFunc(triggerMDMSync))).Methods(http.MethodPost) + r.HandleFunc("/api/v1/integrations/mdm/device_state", + logic.SecurityCheck(true, http.HandlerFunc(listMDMDeviceState))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/integrations/{type}", logic.SecurityCheck(true, http.HandlerFunc(getIntegration))).Methods(http.MethodGet) r.HandleFunc("/api/v1/integrations/{type}/{id}", logic.SecurityCheck(true, http.HandlerFunc(upsertIntegration))).Methods(http.MethodPut) r.HandleFunc("/api/v1/integrations/{type}/{id}", logic.SecurityCheck(true, http.HandlerFunc(deleteIntegration))).Methods(http.MethodDelete) r.HandleFunc("/api/v1/integrations/{type}/{id}/test", logic.SecurityCheck(true, http.HandlerFunc(testIntegration))).Methods(http.MethodPost) } -// extractAndValidateIntegration pulls {type} and {id} from the URL -// and validates both against the provider registry. func extractAndValidateIntegration(w http.ResponseWriter, r *http.Request) (integration.Type, integration.ProviderID, bool) { vars := mux.Vars(r) intType := integration.Type(vars["type"]) @@ -43,29 +52,23 @@ func extractAndValidateIntegration(w http.ResponseWriter, r *http.Request) (inte } // @Summary Get an integration -// @Router /api/v1/integrations/{type}/{id} [get] +// @Router /api/v1/integrations/{type} [get] // @Tags Integrations // @Security oauth // @Produce json -// @Param type path string true "Integration type (e.g. siem)" +// @Param type path string true "Integration type (e.g. siem, mdm)" // @Success 200 {object} schema.Integration // @Failure 400 {object} models.ErrorResponse // @Failure 404 {object} models.ErrorResponse // @Failure 500 {object} models.ErrorResponse func getIntegration(w http.ResponseWriter, r *http.Request) { intType := integration.Type(mux.Vars(r)["type"]) - - // hardcoding a correct provider id to do use the same function for validating integration type is siem. - // TODO: change provider when other integration types are introduced. - _, err := integration.Lookup(intType, integration.ProviderDatadog) - if err != nil { - logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.BadReq)) + if !integration.TypeExists(intType) { + logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("unknown integration type '%s'", intType), logic.BadReq)) return } - intg := &schema.Integration{ - Type: string(intType), - } + intg := &schema.Integration{Type: string(intType)} integrations, err := intg.ListByType(r.Context()) if err != nil { logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.Internal)) @@ -98,9 +101,9 @@ func getIntegration(w http.ResponseWriter, r *http.Request) { // @Security oauth // @Accept json // @Produce json -// @Param type path string true "Integration type (e.g. siem)" -// @Param id path string true "Provider ID (e.g. splunk)" -// @Param body body schema.Integration true "Integration config" +// @Param type path string true "Integration type (e.g. siem, mdm)" +// @Param id path string true "Provider ID (e.g. splunk, intune)" +// @Param body body object true "Integration config" // @Success 200 {object} schema.Integration // @Failure 400 {object} models.ErrorResponse // @Failure 500 {object} models.ErrorResponse @@ -110,9 +113,7 @@ func upsertIntegration(w http.ResponseWriter, r *http.Request) { return } - intg := &schema.Integration{ - Type: string(intType), - } + intg := &schema.Integration{Type: string(intType)} integrations, err := intg.ListByType(r.Context()) if err != nil { logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.Internal)) @@ -122,8 +123,8 @@ func upsertIntegration(w http.ResponseWriter, r *http.Request) { if len(integrations) > 0 { var isUpsert bool if len(integrations) == 1 { - intg := integrations[0] - if intg.ID == string(id) && intg.Type == string(intType) { + existing := integrations[0] + if existing.ID == string(id) && existing.Type == string(intType) { isUpsert = true } } @@ -141,7 +142,15 @@ func upsertIntegration(w http.ResponseWriter, r *http.Request) { return } - provider, _ := integration.Lookup(intType, id) // already validated above + if intType == integration.TypeMDM { + config, err = mergeMDMConfig(r.Context(), string(id), config, len(integrations) == 1) + if err != nil { + logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.BadReq)) + return + } + } + + provider, _ := integration.Lookup(intType, id) err = provider.Validate(config) if err != nil { logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.BadReq)) @@ -160,30 +169,9 @@ func upsertIntegration(w http.ResponseWriter, r *http.Request) { return } - go func(configBytes json.RawMessage) { - config := make(map[string]interface{}) - err = json.Unmarshal(configBytes, &config) - if err != nil { - logger.Log(0, fmt.Sprintf("error unmarshaling config: %s", err.Error())) - return - } - - configStruct, err := structpb.NewStruct(config) - if err != nil { - logger.Log(0, fmt.Sprintf("error constructing struct val: %s", err.Error())) - return - } - - err = siem.Client().Init(context.Background(), string(id), configStruct) - if err != nil { - logger.Log(0, fmt.Sprintf("error upserting siem integration %s on exporter: %v", id, err)) - - err = mq.PublishIntegrationUpsert(string(id)) - if err != nil { - logger.Log(0, fmt.Sprintf("error publishing siem integration upsert event %s on exporter: %v", id, err)) - } - } - }(config) + if intType == integration.TypeSIEM { + go initSIEMExporter(string(id), config) + } err = redactConfig(intg) if err != nil { @@ -195,19 +183,44 @@ func upsertIntegration(w http.ResponseWriter, r *http.Request) { logic.ReturnSuccessResponseWithJson(w, r, intg, "integration saved") } +func initSIEMExporter(id string, configBytes json.RawMessage) { + config := make(map[string]interface{}) + err := json.Unmarshal(configBytes, &config) + if err != nil { + logger.Log(0, fmt.Sprintf("error unmarshaling config: %s", err.Error())) + return + } + + configStruct, err := structpb.NewStruct(config) + if err != nil { + logger.Log(0, fmt.Sprintf("error constructing struct val: %s", err.Error())) + return + } + + err = grpcs.Client().Init(context.Background(), id, configStruct) + if err != nil { + logger.Log(0, fmt.Sprintf("error upserting siem integration %s on exporter: %v", id, err)) + + err = mq.PublishIntegrationUpsert(id) + if err != nil { + logger.Log(0, fmt.Sprintf("error publishing siem integration upsert event %s on exporter: %v", id, err)) + } + } +} + // @Summary Delete an integration // @Router /api/v1/integrations/{type}/{id} [delete] // @Tags Integrations // @Security oauth // @Produce json -// @Param type path string true "Integration type (e.g. siem)" -// @Param id path string true "Provider ID (e.g. splunk)" +// @Param type path string true "Integration type (e.g. siem, mdm)" +// @Param id path string true "Provider ID" // @Success 200 {object} models.SuccessResponse // @Failure 400 {object} models.ErrorResponse // @Failure 404 {object} models.ErrorResponse // @Failure 500 {object} models.ErrorResponse func deleteIntegration(w http.ResponseWriter, r *http.Request) { - _, id, ok := extractAndValidateIntegration(w, r) + intType, id, ok := extractAndValidateIntegration(w, r) if !ok { return } @@ -229,17 +242,19 @@ func deleteIntegration(w http.ResponseWriter, r *http.Request) { return } - go func() { - err := siem.Client().Terminate(context.Background()) - if err != nil { - logger.Log(0, fmt.Sprintf("error terminating siem integration %s on exporter: %v", id, err)) - - err = mq.PublishIntegrationDelete(string(id)) + if intType == integration.TypeSIEM { + go func() { + err := grpcs.Client().Terminate(context.Background()) if err != nil { - logger.Log(0, fmt.Sprintf("error publishing siem integration delete event %s on exporter: %v", id, err)) + logger.Log(0, fmt.Sprintf("error terminating siem integration %s on exporter: %v", id, err)) + + err = mq.PublishIntegrationDelete(string(id)) + if err != nil { + logger.Log(0, fmt.Sprintf("error publishing siem integration delete event %s on exporter: %v", id, err)) + } } - } - }() + }() + } logic2.SkipPushToSiem() logic.ReturnSuccessResponse(w, r, "integration deleted") @@ -251,8 +266,8 @@ func deleteIntegration(w http.ResponseWriter, r *http.Request) { // @Security oauth // @Accept json // @Produce json -// @Param type path string true "Integration type (e.g. siem)" -// @Param id path string true "Provider ID (e.g. splunk)" +// @Param type path string true "Integration type (e.g. siem, mdm)" +// @Param id path string true "Provider ID" // @Param body body object true "Provider config to test (not saved)" // @Success 200 {object} models.SuccessResponse // @Failure 400 {object} models.ErrorResponse @@ -270,45 +285,166 @@ func testIntegration(w http.ResponseWriter, r *http.Request) { return } - provider, _ := integration.Lookup(intType, id) // already validated above + if intType == integration.TypeMDM { + active, _ := mdmpkg.GetActive(r.Context()) + hasExisting := active != nil && active.ID == string(id) + config, err = mergeMDMConfig(r.Context(), string(id), config, hasExisting) + if err != nil { + logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.BadReq)) + return + } + } + + provider, _ := integration.Lookup(intType, id) err = provider.Validate(config) if err != nil { + if intType == integration.TypeMDM { + logMDMVerifyEvent(r, string(id), false, err.Error()) + } logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.BadReq)) return } err = provider.Test(config) if err != nil { + if intType == integration.TypeMDM { + logMDMVerifyEvent(r, string(id), false, err.Error()) + } logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("integration test failed: %w", err), logic.BadReq)) return } + if intType == integration.TypeMDM { + logMDMVerifyEvent(r, string(id), true, "") + } + logic.ReturnSuccessResponse(w, r, "integration test passed") } +// @Summary List built-in MDM provider types +// @Router /api/v1/integrations/mdm/providers [get] +// @Tags Integrations +// @Security oauth +// @Produce json +// @Success 200 {array} mdmpkg.ProviderType +func listMDMProviders(w http.ResponseWriter, r *http.Request) { + logic.ReturnSuccessResponseWithJson(w, r, mdmpkg.ListProviderTypes(), "fetched mdm provider types") +} + +// @Summary Trigger an out-of-cycle MDM sync +// @Router /api/v1/integrations/mdm/sync [post] +// @Tags Integrations +// @Security oauth +// @Produce json +// @Success 202 {object} models.SuccessResponse +// @Failure 400 {object} models.ErrorResponse +func triggerMDMSync(w http.ResponseWriter, r *http.Request) { + active, err := mdmpkg.GetActive(r.Context()) + if err != nil { + logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.Internal)) + return + } + if active == nil { + logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("no MDM integration configured"), logic.BadReq)) + return + } + + syncCtx := db.WithContext(context.Background()) + go func() { + if err := mdmpkg.RunMDMSyncForce(syncCtx); err != nil { + logger.Log(0, "mdm: manual sync failed:", err.Error()) + } + }() + + logic.LogEvent(&models.Event{ + Action: schema.MDMSync, + TriggeredBy: r.Header.Get("user"), + Source: models.Subject{ + ID: r.Header.Get("user"), + Name: r.Header.Get("user"), + Type: schema.UserSub, + }, + Target: models.Subject{ + ID: active.ID, + Name: active.ID, + Type: schema.MDMSub, + }, + Origin: schema.Dashboard, + Diff: models.Diff{ + New: map[string]interface{}{"status": "queued", "provider": active.ID}, + }, + }) + logic.ReturnSuccessResponseWithJson(w, r, map[string]any{"queued": true}, "mdm sync queued") +} + +// @Summary List synced MDM device states +// @Router /api/v1/integrations/mdm/device_state [get] +// @Tags Integrations +// @Security oauth +// @Produce json +// @Param host_id query string false "Filter by host UUID" +// @Param provider query string false "Filter by provider name" +// @Success 200 {array} schema.DeviceMDMState +func listMDMDeviceState(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + hostID := r.URL.Query().Get("host_id") + provider := r.URL.Query().Get("provider") + state := &schema.DeviceMDMState{HostID: hostID, Provider: provider} + var out []schema.DeviceMDMState + var err error + switch { + case hostID != "" && provider != "": + err = state.Get(ctx) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("mdm device state not found"), logic.NotFound)) + return + } + logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.Internal)) + return + } + out = []schema.DeviceMDMState{*state} + case hostID != "": + out, err = state.ListByHost(ctx) + case provider != "": + out, err = state.ListByProvider(ctx) + default: + out, err = state.ListAll(ctx) + } + if err != nil { + logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.Internal)) + return + } + logic.ReturnSuccessResponseWithJson(w, r, out, "fetched mdm device states") +} + func redactConfig(intg *schema.Integration) error { - switch integration.ProviderID(intg.ID) { - case integration.ProviderDatadog: - var config integration.DatadogConfig - err := json.Unmarshal(intg.Config, &config) + if intg.Type == string(integration.TypeMDM) { + redacted, err := mdmpkg.RedactConfig(intg.ID, json.RawMessage(intg.Config)) if err != nil { return err } + intg.Config = datatypes.JSON(redacted) + return nil + } + switch integration.ProviderID(intg.ID) { + case integration.ProviderDatadog: + var config siempkg.DatadogConfig + if err := json.Unmarshal(intg.Config, &config); err != nil { + return err + } config.APIKey = logic.Mask() configBytes, err := json.Marshal(config) if err != nil { return err } - intg.Config = configBytes case integration.ProviderElastic: - var config integration.ElasticConfig - err := json.Unmarshal(intg.Config, &config) - if err != nil { + var config siempkg.ElasticConfig + if err := json.Unmarshal(intg.Config, &config); err != nil { return err } - if config.APIKey != "" { config.APIKey = logic.Mask() } @@ -319,37 +455,114 @@ func redactConfig(intg *schema.Integration) error { if err != nil { return err } - intg.Config = configBytes case integration.ProviderSentinel: - var config integration.SentinelConfig - err := json.Unmarshal(intg.Config, &config) - if err != nil { + var config siempkg.SentinelConfig + if err := json.Unmarshal(intg.Config, &config); err != nil { return err } - config.SharedKey = logic.Mask() configBytes, err := json.Marshal(config) if err != nil { return err } - intg.Config = configBytes case integration.ProviderSplunk: - var config integration.SplunkConfig - err := json.Unmarshal(intg.Config, &config) - if err != nil { + var config siempkg.SplunkConfig + if err := json.Unmarshal(intg.Config, &config); err != nil { return err } - config.HECToken = logic.Mask() configBytes, err := json.Marshal(config) if err != nil { return err } - intg.Config = configBytes } - return nil } + +func mergeMDMConfig(ctx context.Context, providerID string, incoming json.RawMessage, hasExisting bool) (json.RawMessage, error) { + var patch map[string]json.RawMessage + if err := json.Unmarshal(incoming, &patch); err != nil { + return nil, fmt.Errorf("invalid request body: %w", err) + } + changed := false + for _, field := range []string{"client_secret", "api_token"} { + merged, ok, err := mergeMDMSecretField(ctx, providerID, patch, field, hasExisting) + if err != nil { + return nil, err + } + if ok { + patch = merged + changed = true + } + } + if !changed { + return incoming, nil + } + return json.Marshal(patch) +} + +func mergeMDMSecretField( + ctx context.Context, + providerID string, + patch map[string]json.RawMessage, + field string, + hasExisting bool, +) (map[string]json.RawMessage, bool, error) { + secret, ok := patch[field] + if !ok { + return patch, false, nil + } + var secretStr string + if err := json.Unmarshal(secret, &secretStr); err != nil { + return patch, false, nil + } + if !isMaskedSecret(secretStr) || !hasExisting { + return patch, false, nil + } + + existing := &schema.Integration{ID: providerID} + if err := existing.Get(ctx); err != nil { + return patch, false, nil + } + + var stored map[string]json.RawMessage + if err := json.Unmarshal(existing.Config, &stored); err != nil { + return nil, false, err + } + storedSecret, ok := stored[field] + if !ok { + return patch, false, nil + } + patch[field] = storedSecret + return patch, true, nil +} + +func isMaskedSecret(s string) bool { + return s == logic.Mask() || s == "********" +} + +func logMDMVerifyEvent(r *http.Request, providerID string, ok bool, errMsg string) { + diff := map[string]interface{}{"status": "ok", "provider": providerID} + if !ok { + diff = map[string]interface{}{"status": "failed", "error": errMsg} + } + logic.LogEvent(&models.Event{ + Action: schema.MDMVerify, + TriggeredBy: r.Header.Get("user"), + Source: models.Subject{ + ID: r.Header.Get("user"), + Name: r.Header.Get("user"), + Type: schema.UserSub, + }, + Target: models.Subject{ + ID: providerID, + Name: providerID, + Type: schema.MDMSub, + }, + Origin: schema.Dashboard, + Diff: models.Diff{New: diff}, + }) +} diff --git a/pro/controllers/posture_check.go b/pro/controllers/posture_check.go index bba0c3330..a1514b9c8 100644 --- a/pro/controllers/posture_check.go +++ b/pro/controllers/posture_check.go @@ -26,6 +26,7 @@ func PostureCheckHandlers(r *mux.Router) { r.HandleFunc("/api/v1/posture_check", logic.SecurityCheck(true, http.HandlerFunc(listPostureChecks))).Methods(http.MethodGet) r.HandleFunc("/api/v1/posture_check", logic.SecurityCheck(true, http.HandlerFunc(updatePostureCheck))).Methods(http.MethodPut) r.HandleFunc("/api/v1/posture_check", logic.SecurityCheck(true, http.HandlerFunc(deletePostureCheck))).Methods(http.MethodDelete) + r.HandleFunc("/api/v1/posture_check/run", logic.SecurityCheck(true, http.HandlerFunc(triggerPostureChecks))).Methods(http.MethodPost) r.HandleFunc("/api/v1/posture_check/attrs", logic.SecurityCheck(true, http.HandlerFunc(listPostureChecksAttrs))).Methods(http.MethodGet) r.HandleFunc("/api/v1/posture_check/violations", logic.SecurityCheck(true, http.HandlerFunc(listPostureCheckViolatedNodes))).Methods(http.MethodGet) } @@ -79,6 +80,7 @@ func createPostureCheck(w http.ResponseWriter, r *http.Request) { UserGroups: req.UserGroups, Attribute: req.Attribute, Values: req.Values, + Config: req.Config, Severity: req.Severity, Status: true, CreatedBy: r.Header.Get("user"), @@ -192,14 +194,15 @@ func updatePostureCheck(w http.ResponseWriter, r *http.Request) { return } - if err := proLogic.ValidatePostureCheck(&updatePc); err != nil { + pc := schema.PostureCheck{ID: updatePc.ID} + err = pc.Get(db.WithContext(r.Context())) + if err != nil { logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } - pc := schema.PostureCheck{ID: updatePc.ID} - err = pc.Get(db.WithContext(r.Context())) - if err != nil { + proLogic.MergePostureCheckUpdate(&pc, &updatePc) + if err := proLogic.ValidatePostureCheck(&updatePc); err != nil { logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } @@ -231,6 +234,7 @@ func updatePostureCheck(w http.ResponseWriter, r *http.Request) { pc.UserGroups = updatePc.UserGroups pc.Attribute = updatePc.Attribute pc.Values = updatePc.Values + pc.Config = updatePc.Config pc.Description = updatePc.Description pc.Name = updatePc.Name pc.Severity = updatePc.Severity @@ -310,6 +314,48 @@ func deletePostureCheck(w http.ResponseWriter, r *http.Request) { logic.ReturnSuccessResponseWithJson(w, r, pc, "deleted posture check") } +// @Summary Trigger an out-of-cycle posture check evaluation +// @Router /api/v1/posture_check/run [post] +// @Tags Posture Check +// @Security oauth +// @Produce json +// @Success 202 {object} models.SuccessResponse +// @Failure 400 {object} models.ErrorResponse +// @Failure 401 {object} models.ErrorResponse +func triggerPostureChecks(w http.ResponseWriter, r *http.Request) { + if !proLogic.GetFeatureFlags().EnablePostureChecks { + logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("posture checks are not enabled on your plan"), logic.BadReq)) + return + } + + go func() { + if err := proLogic.RunPostureChecks(); err != nil { + logger.Log(0, "posture check: manual run failed:", err.Error()) + } + mq.PublishPeerUpdate(false) + }() + + logic.LogEvent(&models.Event{ + Action: schema.Sync, + TriggeredBy: r.Header.Get("user"), + Source: models.Subject{ + ID: r.Header.Get("user"), + Name: r.Header.Get("user"), + Type: schema.UserSub, + }, + Target: models.Subject{ + ID: string(schema.AllPostureCheckRsrcID), + Name: "all", + Type: schema.PostureCheckSub, + }, + Origin: schema.Dashboard, + Diff: models.Diff{ + New: map[string]interface{}{"status": "queued"}, + }, + }) + logic.ReturnSuccessResponseWithJson(w, r, map[string]any{"queued": true}, "posture checks queued") +} + // @Summary List Posture Check violated Nodes // @Router /api/v1/posture_check/violations [get] // @Tags Posture Check diff --git a/pro/initialize.go b/pro/initialize.go index 26bbea938..8c0c13918 100644 --- a/pro/initialize.go +++ b/pro/initialize.go @@ -21,6 +21,13 @@ import ( "github.com/gravitl/netmaker/pro/email" "github.com/gravitl/netmaker/pro/license" proLogic "github.com/gravitl/netmaker/pro/logic" + // Blank-import MDM provider packages so their init() registers with + // the integration/mdm registry. Add new providers by appending another import. + mdmpkg "github.com/gravitl/netmaker/pro/integration/mdm" + _ "github.com/gravitl/netmaker/pro/integration/mdm/intune" + _ "github.com/gravitl/netmaker/pro/integration/mdm/iru" + _ "github.com/gravitl/netmaker/pro/integration/mdm/jamf" + _ "github.com/gravitl/netmaker/pro/integration/mdm/jumpcloud" "github.com/gravitl/netmaker/pro/orchestrator/extensions" "github.com/gravitl/netmaker/schema" "github.com/gravitl/netmaker/servercfg" @@ -212,6 +219,8 @@ func InitPro() { logic.ValidateEgressReq = proLogic.ValidateEgressReq logic.CheckPostureViolations = proLogic.CheckPostureViolations logic.GetPostureCheckDeviceInfoByNode = proLogic.GetPostureCheckDeviceInfoByNode + logic.SyncHostMDMState = mdmpkg.SyncHostMDMState + logic.CheckUIHostReadAccess = proLogic.CheckUIHostReadAccess logic.StartFlowCleanupLoop = proLogic.StartFlowCleanupLoop logic.StopFlowCleanupLoop = proLogic.StopFlowCleanupLoop // Expose JIT functions diff --git a/pro/integration/mdm/active.go b/pro/integration/mdm/active.go new file mode 100644 index 000000000..f317b1ff6 --- /dev/null +++ b/pro/integration/mdm/active.go @@ -0,0 +1,51 @@ +package mdm + +import ( + "context" + "encoding/json" + "errors" + + "github.com/gravitl/netmaker/schema" +) + +const integrationType = "mdm" + +// GetActive returns the configured MDM integration row, or nil if none exists. +func GetActive(ctx context.Context) (*schema.Integration, error) { + intg := &schema.Integration{Type: integrationType} + list, err := intg.ListByType(ctx) + if err != nil { + return nil, err + } + if len(list) == 0 { + return nil, nil + } + if len(list) > 1 { + return nil, errors.New("multiple mdm integrations configured") + } + return &list[0], nil +} + +// BuildActive builds the provider for the active MDM integration. +func BuildActive(ctx context.Context) (Provider, error) { + intg, err := GetActive(ctx) + if err != nil { + return nil, err + } + if intg == nil { + return nil, nil + } + return Build(intg.ID, json.RawMessage(intg.Config)) +} + +// ActiveProviderID returns the provider id of the active MDM integration, or "" if none. +func ActiveProviderID(ctx context.Context) (string, error) { + intg, err := GetActive(ctx) + if err != nil { + return "", err + } + if intg == nil { + return "", nil + } + return intg.ID, nil +} diff --git a/pro/integration/mdm/config.go b/pro/integration/mdm/config.go new file mode 100644 index 000000000..a9cf26c7f --- /dev/null +++ b/pro/integration/mdm/config.go @@ -0,0 +1,225 @@ +package mdm + +import ( + "encoding/json" + "fmt" + "strings" + + "github.com/gravitl/netmaker/logic" +) + +const ( + ProviderIntune = "intune" + ProviderJamf = "jamf" + ProviderJumpCloud = "jumpcloud" + ProviderIru = "iru" +) + +// SyncSettings are shared across MDM provider configs. +type SyncSettings struct { + SyncEnabled bool `json:"sync_enabled"` + SyncIntervalMinutes int `json:"sync_interval_minutes"` +} + +// IntuneConfig is stored in integrations_v1.config for the intune provider. +type IntuneConfig struct { + SyncSettings + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` + TenantID string `json:"tenant_id"` +} + +// JamfConfig is stored in integrations_v1.config for the jamf provider. +// +// ComplianceVendors optionally limits device-trust evaluation to named +// complianceVendor values from Jamf Conditional Access (e.g. "Jamf", "Intune"). +// When empty, every applicable compliance record for the device must be COMPLIANT. +type JamfConfig struct { + SyncSettings + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` + BaseURL string `json:"base_url"` + ComplianceVendors []string `json:"compliance_vendors,omitempty"` +} + +// JumpCloudConfig is stored in integrations_v1.config for the jumpcloud provider. +// Auth uses a JumpCloud service account (client_id + client_secret) against +// admin-oauth.id.jumpcloud.com with Basic auth. BaseURL defaults to https://console.jumpcloud.com. +// +// CompliancePolicyIDs optionally limits device-trust evaluation to specific JumpCloud +// policy object IDs. When empty, all policy statuses returned for each system must pass. +type JumpCloudConfig struct { + SyncSettings + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` + BaseURL string `json:"base_url"` + CompliancePolicyIDs []string `json:"compliance_policy_ids,omitempty"` +} + +// IruConfig is stored in integrations_v1.config for the iru provider (Iru Endpoint +// Management, formerly Kandji). APIURL is the tenant API hostname from Settings +// (e.g. https://acme.api.iru.com or https://acme.api.kandji.io). +// +// ComplianceLibraryItemIDs optionally limits compliance evaluation to specific +// library item IDs from GET /devices/{id}/status. When empty, all parameters +// and library items must have status PASS. +type IruConfig struct { + SyncSettings + APIURL string `json:"api_url"` + APIToken string `json:"api_token"` + ComplianceLibraryItemIDs []string `json:"compliance_library_item_ids,omitempty"` +} + +// ValidateConfig validates provider config JSON for the given provider id. +func ValidateConfig(providerID string, configJSON json.RawMessage) error { + switch providerID { + case ProviderIntune: + var cfg IntuneConfig + if err := json.Unmarshal(configJSON, &cfg); err != nil { + return fmt.Errorf("invalid intune config: %w", err) + } + if cfg.TenantID == "" { + return fmt.Errorf("tenant_id is required") + } + if cfg.ClientID == "" { + return fmt.Errorf("client_id is required") + } + if cfg.ClientSecret == "" { + return fmt.Errorf("client_secret is required") + } + if cfg.SyncIntervalMinutes < 0 { + return fmt.Errorf("sync_interval_minutes must be >= 0") + } + return nil + case ProviderJamf: + var cfg JamfConfig + if err := json.Unmarshal(configJSON, &cfg); err != nil { + return fmt.Errorf("invalid jamf config: %w", err) + } + if strings.TrimSpace(cfg.BaseURL) == "" { + return fmt.Errorf("base_url is required") + } + if cfg.ClientID == "" { + return fmt.Errorf("client_id is required") + } + if cfg.ClientSecret == "" { + return fmt.Errorf("client_secret is required") + } + if cfg.SyncIntervalMinutes < 0 { + return fmt.Errorf("sync_interval_minutes must be >= 0") + } + return nil + case ProviderJumpCloud: + var cfg JumpCloudConfig + if err := json.Unmarshal(configJSON, &cfg); err != nil { + return fmt.Errorf("invalid jumpcloud config: %w", err) + } + if cfg.ClientID == "" { + return fmt.Errorf("client_id is required") + } + if cfg.ClientSecret == "" { + return fmt.Errorf("client_secret is required") + } + if cfg.SyncIntervalMinutes < 0 { + return fmt.Errorf("sync_interval_minutes must be >= 0") + } + return nil + case ProviderIru: + var cfg IruConfig + if err := json.Unmarshal(configJSON, &cfg); err != nil { + return fmt.Errorf("invalid iru config: %w", err) + } + apiURL := strings.TrimSpace(cfg.APIURL) + if apiURL == "" { + return fmt.Errorf("api_url is required") + } + if !strings.HasPrefix(strings.ToLower(apiURL), "https://") { + return fmt.Errorf("api_url must use https") + } + if cfg.APIToken == "" { + return fmt.Errorf("api_token is required") + } + if cfg.SyncIntervalMinutes < 0 { + return fmt.Errorf("sync_interval_minutes must be >= 0") + } + return nil + default: + return fmt.Errorf("unknown mdm provider %q", providerID) + } +} + +// ParseSyncSettings extracts sync settings from stored integration config. +func ParseSyncSettings(providerID string, configJSON json.RawMessage) (SyncSettings, error) { + switch providerID { + case ProviderIntune: + var cfg IntuneConfig + if err := json.Unmarshal(configJSON, &cfg); err != nil { + return SyncSettings{}, err + } + return cfg.SyncSettings, nil + case ProviderJamf: + var cfg JamfConfig + if err := json.Unmarshal(configJSON, &cfg); err != nil { + return SyncSettings{}, err + } + return cfg.SyncSettings, nil + case ProviderJumpCloud: + var cfg JumpCloudConfig + if err := json.Unmarshal(configJSON, &cfg); err != nil { + return SyncSettings{}, err + } + return cfg.SyncSettings, nil + case ProviderIru: + var cfg IruConfig + if err := json.Unmarshal(configJSON, &cfg); err != nil { + return SyncSettings{}, err + } + return cfg.SyncSettings, nil + default: + return SyncSettings{}, fmt.Errorf("unknown mdm provider %q", providerID) + } +} + +// RedactConfig returns config JSON with secrets masked for API responses. +func RedactConfig(providerID string, configJSON json.RawMessage) (json.RawMessage, error) { + switch providerID { + case ProviderIntune: + var cfg IntuneConfig + if err := json.Unmarshal(configJSON, &cfg); err != nil { + return nil, err + } + if cfg.ClientSecret != "" { + cfg.ClientSecret = logic.Mask() + } + return json.Marshal(cfg) + case ProviderJamf: + var cfg JamfConfig + if err := json.Unmarshal(configJSON, &cfg); err != nil { + return nil, err + } + if cfg.ClientSecret != "" { + cfg.ClientSecret = logic.Mask() + } + return json.Marshal(cfg) + case ProviderJumpCloud: + var cfg JumpCloudConfig + if err := json.Unmarshal(configJSON, &cfg); err != nil { + return nil, err + } + if cfg.ClientSecret != "" { + cfg.ClientSecret = logic.Mask() + } + return json.Marshal(cfg) + case ProviderIru: + var cfg IruConfig + if err := json.Unmarshal(configJSON, &cfg); err != nil { + return nil, err + } + if cfg.APIToken != "" { + cfg.APIToken = logic.Mask() + } + return json.Marshal(cfg) + default: + return configJSON, nil + } +} diff --git a/pro/integration/mdm/errors.go b/pro/integration/mdm/errors.go new file mode 100644 index 000000000..ee28fa901 --- /dev/null +++ b/pro/integration/mdm/errors.go @@ -0,0 +1,28 @@ +package mdm + +import "errors" + +// Posture-facing error codes returned by Entra-keyed MDM lookups. +var ( + ErrDeviceNotRegisteredInEntra = errors.New("device_not_registered_in_entra") + ErrDeviceNotEnrolledInIntune = errors.New("device_not_enrolled_in_intune") + ErrDeviceNotFoundInMDM = errors.New("device_not_found_in_mdm") +) + +// LookupErrorCode maps a lookup error to a stable posture violation code. +// Returns "" for non-lookup failures (e.g. network errors). +func LookupErrorCode(err error) string { + if err == nil { + return "" + } + if errors.Is(err, ErrDeviceNotRegisteredInEntra) { + return ErrDeviceNotRegisteredInEntra.Error() + } + if errors.Is(err, ErrDeviceNotEnrolledInIntune) { + return ErrDeviceNotEnrolledInIntune.Error() + } + if errors.Is(err, ErrDeviceNotFoundInMDM) { + return ErrDeviceNotFoundInMDM.Error() + } + return "" +} diff --git a/pro/integration/mdm/intune/intune.go b/pro/integration/mdm/intune/intune.go new file mode 100644 index 000000000..00a7aad21 --- /dev/null +++ b/pro/integration/mdm/intune/intune.go @@ -0,0 +1,242 @@ +// Package intune implements an MDM provider backed by Microsoft Intune via +// Microsoft Graph. Self-registers with pro/integration/mdm in init(). +package intune + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/url" + "strings" + "sync" + "time" + + mdmpkg "github.com/gravitl/netmaker/pro/integration/mdm" +) + +const ( + providerName = mdmpkg.ProviderIntune + providerDisplay = "Microsoft Intune" + + tokenURLFmt = "https://login.microsoftonline.com/%s/oauth2/v2.0/token" + tokenScope = "https://graph.microsoft.com/.default" + entraDevicesURL = "https://graph.microsoft.com/v1.0/devices" + devicesURL = "https://graph.microsoft.com/v1.0/deviceManagement/managedDevices" + deviceSelect = "id,azureADDeviceId,serialNumber,hardwareInformation,deviceName,userPrincipalName,managementState,deviceRegistrationState,enrolledDateTime,complianceState,lastSyncDateTime" +) + +func init() { + mdmpkg.Register(providerName, providerDisplay, New) + mdmpkg.RegisterCapabilities(providerName, mdmpkg.Capabilities{ReportsCompliant: true}) +} + +// New builds an Intune provider from integration config JSON. +func New(configJSON json.RawMessage) (mdmpkg.Provider, error) { + var cfg mdmpkg.IntuneConfig + if err := json.Unmarshal(configJSON, &cfg); err != nil { + return nil, fmt.Errorf("invalid intune config: %w", err) + } + if err := mdmpkg.ValidateConfig(providerName, configJSON); err != nil { + return nil, err + } + return &Client{ + tenantID: cfg.TenantID, + clientID: cfg.ClientID, + clientSecret: cfg.ClientSecret, + http: &http.Client{Timeout: 30 * time.Second}, + }, nil +} + +// Client implements mdmpkg.Provider against Microsoft Graph. +type Client struct { + tenantID string + clientID string + clientSecret string + http *http.Client + + tokenMu sync.Mutex + token string + tokenExp time.Time +} + +func (c *Client) Name() string { return providerName } + +func (c *Client) Capabilities() mdmpkg.Capabilities { + return mdmpkg.Capabilities{ReportsCompliant: true} +} + +func (c *Client) Verify(ctx context.Context) error { + tok, err := c.accessToken(ctx) + if err != nil { + return err + } + u := entraDevicesURL + "?$top=1&$select=" + url.QueryEscape("id") + req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil) + if err != nil { + return err + } + req.Header.Set("Authorization", "Bearer "+tok) + req.Header.Set("Accept", "application/json") + resp, err := c.http.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + var body struct { + Error errorBody `json:"error"` + } + _ = json.NewDecoder(resp.Body).Decode(&body) + if resp.StatusCode >= 400 || body.Error.Code != "" { + if body.Error.Message != "" { + return fmt.Errorf("intune verify failed: %s", body.Error.Message) + } + return fmt.Errorf("intune verify failed: http %d", resp.StatusCode) + } + return nil +} + +// ListManagedDevices is not used for Intune posture checks; hosts are resolved +// per entra_device_id via LookupByEntraDeviceID (/devices, then managedDevices +// only when /devices returns no match). +func (c *Client) ListManagedDevices(ctx context.Context) ([]mdmpkg.ManagedDevice, error) { + return nil, nil +} + +func (c *Client) accessToken(ctx context.Context) (string, error) { + c.tokenMu.Lock() + defer c.tokenMu.Unlock() + if c.token != "" && time.Until(c.tokenExp) > time.Minute { + return c.token, nil + } + form := url.Values{} + form.Set("grant_type", "client_credentials") + form.Set("client_id", c.clientID) + form.Set("client_secret", c.clientSecret) + form.Set("scope", tokenScope) + req, err := http.NewRequestWithContext( + ctx, http.MethodPost, + fmt.Sprintf(tokenURLFmt, c.tenantID), + strings.NewReader(form.Encode()), + ) + if err != nil { + return "", err + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + resp, err := c.http.Do(req) + if err != nil { + return "", errors.New("intune token: " + err.Error()) + } + defer resp.Body.Close() + var body tokenResponse + if err := json.NewDecoder(resp.Body).Decode(&body); err != nil { + return "", err + } + if body.AccessToken == "" { + if body.Error != "" { + return "", fmt.Errorf("intune token: %s: %s", body.Error, body.ErrorDescription) + } + return "", errors.New("intune token: empty response") + } + c.token = body.AccessToken + if body.ExpiresIn > 0 { + c.tokenExp = time.Now().Add(time.Duration(body.ExpiresIn) * time.Second) + } else { + c.tokenExp = time.Now().Add(50 * time.Minute) + } + return c.token, nil +} + +// intuneComplianceCompliant is true only when Graph reports complianceState +// as "compliant" (case-insensitive). +func intuneComplianceCompliant(state string) bool { + return strings.EqualFold(strings.TrimSpace(state), "compliant") +} + +// intuneDeviceEnrolled reports whether a managedDevices row represents an +// enrolled Intune device. +func intuneDeviceEnrolled(d managedDevice) bool { + if d.ManagementState != "" && !strings.EqualFold(d.ManagementState, "discovered") { + return true + } + if strings.EqualFold(d.DeviceRegistrationState, "registered") { + return true + } + if strings.TrimSpace(d.EnrolledDateTime) != "" { + return true + } + return false +} + +func normalize(d managedDevice, entraByName map[string]string) mdmpkg.ManagedDevice { + last, _ := time.Parse(time.RFC3339, d.LastSyncDateTime) + azureAD := d.AzureADDeviceID + if azureAD == "" && entraByName != nil { + if id, ok := entraByName[strings.ToLower(strings.TrimSpace(d.DeviceName))]; ok { + azureAD = id + } + } + return mdmpkg.ManagedDevice{ + ProviderDeviceID: d.ID, + AzureADDeviceID: azureAD, + SerialNumber: d.SerialNumber, + HardwareUUID: d.HardwareInformation.SerialNumber, + DeviceName: d.DeviceName, + UserPrincipalName: d.UserPrincipalName, + Enrolled: intuneDeviceEnrolled(d), + Compliant: intuneComplianceCompliant(d.ComplianceState), + LastSeenAt: last, + } +} + +type tokenResponse struct { + AccessToken string `json:"access_token"` + ExpiresIn int `json:"expires_in"` + Error string `json:"error"` + ErrorDescription string `json:"error_description"` +} + +type managedDevicesPage struct { + Value []managedDevice `json:"value"` + NextLink string `json:"@odata.nextLink"` + Error errorBody `json:"error"` +} + +type managedDevice struct { + ID string `json:"id"` + AzureADDeviceID string `json:"azureADDeviceId"` + SerialNumber string `json:"serialNumber"` + DeviceName string `json:"deviceName"` + UserPrincipalName string `json:"userPrincipalName"` + ManagementState string `json:"managementState"` + DeviceRegistrationState string `json:"deviceRegistrationState"` + EnrolledDateTime string `json:"enrolledDateTime"` + ComplianceState string `json:"complianceState"` + LastSyncDateTime string `json:"lastSyncDateTime"` + HardwareInformation hardwareInformation `json:"hardwareInformation"` +} + +type entraDevice struct { + ID string `json:"id"` + DeviceID string `json:"deviceId"` + DisplayName string `json:"displayName"` + OperatingSystem string `json:"operatingSystem"` + TrustType string `json:"trustType"` + IsManaged bool `json:"isManaged"` + IsCompliant bool `json:"isCompliant"` +} + +type entraDevicesPage struct { + Value []entraDevice `json:"value"` + Error errorBody `json:"error"` +} + +type hardwareInformation struct { + SerialNumber string `json:"serialNumber"` +} + +type errorBody struct { + Code string `json:"code"` + Message string `json:"message"` +} diff --git a/pro/integration/mdm/intune/intune_test.go b/pro/integration/mdm/intune/intune_test.go new file mode 100644 index 000000000..18b0037a7 --- /dev/null +++ b/pro/integration/mdm/intune/intune_test.go @@ -0,0 +1,142 @@ +package intune + +import "testing" + +func TestManagedFromEntraDevice(t *testing.T) { + got := managedFromEntraDevice(entraDevice{ + ID: "ee155866-00b2-4476-9cba-c0dfa37f0224", + DisplayName: "WIN-PMV0N6INPC6", + IsManaged: true, + IsCompliant: true, + }, "32f5f9ec-cd23-41e0-94e8-6b372232ff40") + if !got.Enrolled || !got.Compliant { + t.Fatalf("expected managed compliant entra device, got enrolled=%v compliant=%v", + got.Enrolled, got.Compliant) + } + if got.AzureADDeviceID != "32f5f9ec-cd23-41e0-94e8-6b372232ff40" { + t.Fatalf("unexpected azure ad device id: %q", got.AzureADDeviceID) + } + if got.ProviderDeviceID != "ee155866-00b2-4476-9cba-c0dfa37f0224" { + t.Fatalf("unexpected provider device id: %q", got.ProviderDeviceID) + } +} + +func TestManagedFromManagedDeviceBackup(t *testing.T) { + got := managedFromManagedDeviceBackup(managedDevice{ + ID: "ee155866-00b2-4476-9cba-c0dfa37f0224", + DeviceName: "WIN-PMV0N6INPC6", + AzureADDeviceID: "32f5f9ec-cd23-41e0-94e8-6b372232ff40", + ComplianceState: "compliant", + ManagementState: "managed", + }, "32f5f9ec-cd23-41e0-94e8-6b372232ff40") + if !got.Enrolled || !got.Compliant { + t.Fatalf("expected enrolled compliant backup device, got enrolled=%v compliant=%v", + got.Enrolled, got.Compliant) + } + if got.DeviceName != "WIN-PMV0N6INPC6" { + t.Fatalf("unexpected device name: %q", got.DeviceName) + } + if got.ProviderDeviceID != "ee155866-00b2-4476-9cba-c0dfa37f0224" { + t.Fatalf("unexpected provider device id: %q", got.ProviderDeviceID) + } +} + +func TestManagedFromManagedDeviceBackup_NotEnrolled(t *testing.T) { + got := managedFromManagedDeviceBackup(managedDevice{ + DeviceName: "WIN-PMV0N6INPC6", + ComplianceState: "compliant", + ManagementState: "discovered", + }, "32f5f9ec-cd23-41e0-94e8-6b372232ff40") + if got.Enrolled { + t.Fatal("discovered device should not be enrolled") + } +} + +func TestIntuneComplianceCompliant(t *testing.T) { + tests := []struct { + state string + want bool + }{ + {state: "compliant", want: true}, + {state: "Compliant", want: true}, + {state: "inGracePeriod", want: false}, + {state: "configManager", want: false}, + {state: "noncompliant", want: false}, + {state: "unknown", want: false}, + {state: "conflict", want: false}, + } + for _, tc := range tests { + if got := intuneComplianceCompliant(tc.state); got != tc.want { + t.Fatalf("intuneComplianceCompliant(%q) = %v, want %v", tc.state, got, tc.want) + } + } +} + +func TestIntuneDeviceEnrolled(t *testing.T) { + tests := []struct { + name string + d managedDevice + want bool + }{ + { + name: "managed", + d: managedDevice{ManagementState: "managed"}, + want: true, + }, + { + name: "discovered", + d: managedDevice{ManagementState: "discovered"}, + want: false, + }, + { + name: "registered without management state", + d: managedDevice{DeviceRegistrationState: "registered"}, + want: true, + }, + { + name: "enrolled datetime fallback", + d: managedDevice{EnrolledDateTime: "2024-01-01T00:00:00Z"}, + want: true, + }, + { + name: "empty signals", + d: managedDevice{}, + want: false, + }, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + if got := intuneDeviceEnrolled(tc.d); got != tc.want { + t.Fatalf("intuneDeviceEnrolled() = %v, want %v", got, tc.want) + } + }) + } +} + +func TestNormalize_EntraDeviceIDFallback(t *testing.T) { + entraByName := map[string]string{ + "win-pmv0n6inpc6": "32f5f9ec-cd23-41e0-94e8-6b372232ff40", + } + d := managedDevice{ + ID: "d56cf8ec-e8a3-4ccb-9c01-821cc8fb38bd", + DeviceName: "WIN-PMV0N6INPC6", + } + got := normalize(d, entraByName) + if got.AzureADDeviceID != "32f5f9ec-cd23-41e0-94e8-6b372232ff40" { + t.Fatalf("expected entra deviceId fallback, got %q", got.AzureADDeviceID) + } +} + +func TestNormalize_PrefersIntuneAzureADDeviceID(t *testing.T) { + entraByName := map[string]string{ + "laptop": "from-entra", + } + d := managedDevice{ + DeviceName: "laptop", + AzureADDeviceID: "from-intune", + } + got := normalize(d, entraByName) + if got.AzureADDeviceID != "from-intune" { + t.Fatalf("expected intune azureADDeviceId, got %q", got.AzureADDeviceID) + } +} diff --git a/pro/integration/mdm/intune/lookup.go b/pro/integration/mdm/intune/lookup.go new file mode 100644 index 000000000..8e847fc94 --- /dev/null +++ b/pro/integration/mdm/intune/lookup.go @@ -0,0 +1,143 @@ +package intune + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" + + mdmpkg "github.com/gravitl/netmaker/pro/integration/mdm" +) + +const ( + entraDeviceSelect = "id,deviceId,displayName,operatingSystem,trustType,isManaged,isCompliant" + managedDeviceBackupSelect = "id,deviceName,azureADDeviceId,complianceState,managementState,deviceRegistrationState,enrolledDateTime" +) + +// LookupByEntraDeviceID resolves a host using entra_device_id as Graph devices.deviceId. +// 1. GET /v1.0/devices?$filter=deviceId eq '' — isManaged / isCompliant +// 2. If no match, GET /deviceManagement/managedDevices?$filter=azureADDeviceId eq '' +// — complianceState == "compliant" +func (c *Client) LookupByEntraDeviceID(ctx context.Context, entraDeviceID string) (mdmpkg.ManagedDevice, error) { + entraDeviceID = normalizeEntraGUID(entraDeviceID) + if entraDeviceID == "" { + return mdmpkg.ManagedDevice{}, mdmpkg.ErrDeviceNotRegisteredInEntra + } + + tok, err := c.accessToken(ctx) + if err != nil { + return mdmpkg.ManagedDevice{}, err + } + + entraDevices, err := c.queryEntraDevicesByDeviceID(ctx, tok, entraDeviceID) + if err != nil { + return mdmpkg.ManagedDevice{}, err + } + if len(entraDevices) > 0 { + return managedFromEntraDevice(entraDevices[0], entraDeviceID), nil + } + + // /devices returned no row — managedDevices is fallback only, never called above. + return c.lookupManagedDeviceFallback(ctx, tok, entraDeviceID) +} + +func (c *Client) lookupManagedDeviceFallback(ctx context.Context, tok, entraDeviceID string) (mdmpkg.ManagedDevice, error) { + managed, err := c.queryManagedDevicesBackup(ctx, tok, entraDeviceID) + if err != nil { + return mdmpkg.ManagedDevice{}, err + } + if len(managed) == 0 { + return mdmpkg.ManagedDevice{}, mdmpkg.ErrDeviceNotRegisteredInEntra + } + return managedFromManagedDeviceBackup(managed[0], entraDeviceID), nil +} + +func managedFromEntraDevice(e entraDevice, entraDeviceID string) mdmpkg.ManagedDevice { + deviceID := entraDeviceID + if e.DeviceID != "" { + deviceID = normalizeEntraGUID(e.DeviceID) + } + return mdmpkg.ManagedDevice{ + ProviderDeviceID: e.ID, + AzureADDeviceID: deviceID, + DeviceName: e.DisplayName, + Enrolled: e.IsManaged, + Compliant: e.IsCompliant, + } +} + +func managedFromManagedDeviceBackup(d managedDevice, entraDeviceID string) mdmpkg.ManagedDevice { + return mdmpkg.ManagedDevice{ + ProviderDeviceID: d.ID, + AzureADDeviceID: entraDeviceID, + DeviceName: d.DeviceName, + Enrolled: intuneDeviceEnrolled(d), + Compliant: intuneComplianceCompliant(d.ComplianceState), + } +} + +func (c *Client) queryEntraDevicesByDeviceID(ctx context.Context, tok, deviceID string) ([]entraDevice, error) { + filter := "deviceId eq '" + odataQuote(deviceID) + "'" + u := entraDevicesURL + "?$filter=" + url.QueryEscape(filter) + + "&$select=" + url.QueryEscape(entraDeviceSelect) + var page entraDevicesPage + if err := c.graphGet(ctx, tok, u, &page); err != nil { + return nil, err + } + if page.Error.Code != "" { + return nil, fmt.Errorf("entra list devices: %s", page.Error.Message) + } + return page.Value, nil +} + +func (c *Client) queryManagedDevicesBackup(ctx context.Context, tok, entraDeviceID string) ([]managedDevice, error) { + filter := "azureADDeviceId eq '" + odataQuote(entraDeviceID) + "'" + u := devicesURL + "?$filter=" + url.QueryEscape(filter) + + "&$select=" + url.QueryEscape(managedDeviceBackupSelect) + var page managedDevicesPage + if err := c.graphGet(ctx, tok, u, &page); err != nil { + return nil, err + } + if page.Error.Code != "" { + return nil, fmt.Errorf("intune list devices: %s", page.Error.Message) + } + return page.Value, nil +} + +func (c *Client) graphGet(ctx context.Context, tok, u string, dest any) error { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil) + if err != nil { + return err + } + req.Header.Set("Authorization", "Bearer "+tok) + req.Header.Set("Accept", "application/json") + resp, err := c.http.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode >= 400 { + var body struct { + Error errorBody `json:"error"` + } + _ = json.NewDecoder(resp.Body).Decode(&body) + if body.Error.Message != "" { + return fmt.Errorf("graph http %d: %s", resp.StatusCode, body.Error.Message) + } + return fmt.Errorf("graph http %d", resp.StatusCode) + } + return json.NewDecoder(resp.Body).Decode(dest) +} + +func normalizeEntraGUID(id string) string { + id = strings.TrimSpace(id) + id = strings.TrimPrefix(id, "{") + id = strings.TrimSuffix(id, "}") + return id +} + +func odataQuote(s string) string { + return strings.ReplaceAll(s, "'", "''") +} diff --git a/pro/integration/mdm/iru/compliance.go b/pro/integration/mdm/iru/compliance.go new file mode 100644 index 000000000..079b36abc --- /dev/null +++ b/pro/integration/mdm/iru/compliance.go @@ -0,0 +1,63 @@ +package iru + +import ( + "strings" +) + +// deviceCompliant reports whether an Iru device status satisfies the configured +// baseline. All parameters and library_items must have status PASS unless +// compliance_library_item_ids limits evaluation to specific library items. +func deviceCompliant(status iruDeviceStatus, filterLibraryItemIDs map[string]struct{}) bool { + if len(status.Parameters) == 0 && len(status.LibraryItems) == 0 { + return true + } + for _, p := range status.Parameters { + if statusItemFailed(p.Status) { + return false + } + } + checked := 0 + for _, item := range status.LibraryItems { + if len(filterLibraryItemIDs) > 0 { + if _, ok := filterLibraryItemIDs[item.ItemID]; !ok { + continue + } + } + checked++ + if statusItemFailed(item.Status) { + return false + } + } + if len(filterLibraryItemIDs) > 0 && checked == 0 && len(status.LibraryItems) > 0 { + return false + } + return true +} + +func statusItemFailed(status string) bool { + return !strings.EqualFold(strings.TrimSpace(status), "PASS") +} + +func libraryItemIDSet(ids []string) map[string]struct{} { + if len(ids) == 0 { + return nil + } + out := make(map[string]struct{}, len(ids)) + for _, id := range ids { + id = strings.TrimSpace(id) + if id != "" { + out[id] = struct{}{} + } + } + return out +} + +type iruDeviceStatus struct { + Parameters []iruStatusItem `json:"parameters"` + LibraryItems []iruStatusItem `json:"library_items"` +} + +type iruStatusItem struct { + ItemID string `json:"item_id"` + Status string `json:"status"` +} diff --git a/pro/integration/mdm/iru/compliance_test.go b/pro/integration/mdm/iru/compliance_test.go new file mode 100644 index 000000000..c5d488844 --- /dev/null +++ b/pro/integration/mdm/iru/compliance_test.go @@ -0,0 +1,58 @@ +package iru + +import "testing" + +func TestDeviceCompliant(t *testing.T) { + filter := map[string]struct{}{"lib-1": {}} + + tests := []struct { + name string + status iruDeviceStatus + filter map[string]struct{} + want bool + }{ + {name: "empty status", status: iruDeviceStatus{}, want: true}, + {name: "all pass", status: iruDeviceStatus{ + Parameters: []iruStatusItem{{Status: "PASS"}}, + LibraryItems: []iruStatusItem{{ItemID: "lib-1", Status: "PASS"}}, + }, want: true}, + {name: "parameter fail", status: iruDeviceStatus{ + Parameters: []iruStatusItem{{Status: "FAIL"}}, + }, want: false}, + {name: "library item fail", status: iruDeviceStatus{ + LibraryItems: []iruStatusItem{{ItemID: "lib-1", Status: "FAIL"}}, + }, want: false}, + {name: "filter match pass", status: iruDeviceStatus{ + LibraryItems: []iruStatusItem{ + {ItemID: "lib-1", Status: "PASS"}, + {ItemID: "lib-2", Status: "FAIL"}, + }, + }, filter: filter, want: true}, + {name: "filter no match", status: iruDeviceStatus{ + LibraryItems: []iruStatusItem{{ItemID: "lib-2", Status: "PASS"}}, + }, filter: filter, want: false}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := deviceCompliant(tc.status, tc.filter) + if got != tc.want { + t.Fatalf("deviceCompliant() = %v, want %v", got, tc.want) + } + }) + } +} + +func TestStatusItemFailed(t *testing.T) { + if statusItemFailed("PASS") { + t.Fatal("PASS should not fail") + } + if statusItemFailed("pass") { + t.Fatal("pass should not fail") + } + if !statusItemFailed("FAIL") { + t.Fatal("FAIL should fail") + } + if !statusItemFailed("") { + t.Fatal("empty status should fail") + } +} diff --git a/pro/integration/mdm/iru/iru.go b/pro/integration/mdm/iru/iru.go new file mode 100644 index 000000000..6374d1477 --- /dev/null +++ b/pro/integration/mdm/iru/iru.go @@ -0,0 +1,241 @@ +// Package iru implements an MDM provider backed by Iru Endpoint Management +// (formerly Kandji). Self-registers with pro/integration/mdm in init(). +package iru + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" + + mdmpkg "github.com/gravitl/netmaker/pro/integration/mdm" +) + +const ( + providerName = mdmpkg.ProviderIru + providerDisplay = "Iru" + devicesListPath = "/api/v1/devices" + deviceStatusPathFmt = "/api/v1/devices/%s/status" + defaultPageSize = 300 + maxStatusFetches = 8 +) + +func init() { + mdmpkg.Register(providerName, providerDisplay, New) + mdmpkg.RegisterCapabilities(providerName, mdmpkg.Capabilities{ReportsCompliant: true}) +} + +// New builds an Iru provider from integration config JSON. +func New(configJSON json.RawMessage) (mdmpkg.Provider, error) { + var cfg mdmpkg.IruConfig + if err := json.Unmarshal(configJSON, &cfg); err != nil { + return nil, fmt.Errorf("invalid iru config: %w", err) + } + if err := mdmpkg.ValidateConfig(providerName, configJSON); err != nil { + return nil, err + } + return &Client{ + baseURL: strings.TrimRight(strings.TrimSpace(cfg.APIURL), "/"), + apiToken: cfg.APIToken, + complianceLibraryIDs: libraryItemIDSet(cfg.ComplianceLibraryItemIDs), + http: &http.Client{Timeout: 60 * time.Second}, + }, nil +} + +// Client implements mdmpkg.Provider against the Iru Endpoint Management API. +type Client struct { + baseURL string + apiToken string + complianceLibraryIDs map[string]struct{} + http *http.Client +} + +func (c *Client) Name() string { return providerName } + +func (c *Client) Capabilities() mdmpkg.Capabilities { + return mdmpkg.Capabilities{ReportsCompliant: true} +} + +func (c *Client) Verify(ctx context.Context) error { + _, err := c.listDevices(ctx, 0, 1) + if err != nil { + return fmt.Errorf("iru verify failed: %w", err) + } + return nil +} + +func (c *Client) ListManagedDevices(ctx context.Context) ([]mdmpkg.ManagedDevice, error) { + var devices []iruDevice + for offset := 0; ; offset += defaultPageSize { + page, err := c.listDevices(ctx, offset, defaultPageSize) + if err != nil { + return nil, err + } + devices = append(devices, page...) + if len(page) < defaultPageSize { + break + } + } + complianceByID, err := c.fetchDeviceCompliance(ctx, devices) + if err != nil { + return nil, err + } + out := make([]mdmpkg.ManagedDevice, 0, len(devices)) + for _, d := range devices { + compliant, ok := complianceByID[d.DeviceID] + if !ok { + compliant = false + } + out = append(out, normalize(d, compliant)) + } + return out, nil +} + +func (c *Client) fetchDeviceCompliance(ctx context.Context, devices []iruDevice) (map[string]bool, error) { + out := make(map[string]bool, len(devices)) + if len(devices) == 0 { + return out, nil + } + sem := make(chan struct{}, maxStatusFetches) + type result struct { + id string + compliant bool + err error + } + ch := make(chan result, len(devices)) + for _, d := range devices { + d := d + go func() { + sem <- struct{}{} + defer func() { <-sem }() + status, err := c.getDeviceStatus(ctx, d.DeviceID) + if err != nil { + ch <- result{id: d.DeviceID, err: err} + return + } + ch <- result{ + id: d.DeviceID, + compliant: deviceCompliant(status, c.complianceLibraryIDs), + } + }() + } + var firstErr error + success := 0 + for range devices { + r := <-ch + if r.err != nil { + if firstErr == nil { + firstErr = fmt.Errorf("device %s status: %w", r.id, r.err) + } + out[r.id] = false + continue + } + out[r.id] = r.compliant + success++ + } + if success == 0 && firstErr != nil { + return out, firstErr + } + return out, nil +} + +func (c *Client) listDevices(ctx context.Context, offset, limit int) ([]iruDevice, error) { + u, err := url.Parse(c.baseURL + devicesListPath) + if err != nil { + return nil, err + } + q := u.Query() + q.Set("offset", fmt.Sprintf("%d", offset)) + q.Set("limit", fmt.Sprintf("%d", limit)) + u.RawQuery = q.Encode() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+c.apiToken) + req.Header.Set("Accept", "application/json") + + resp, err := c.http.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode >= 400 { + return nil, fmt.Errorf("iru list devices: http %d", resp.StatusCode) + } + + var page devicesListResponse + if err := json.Unmarshal(body, &page); err == nil && len(page.Devices) > 0 { + return page.Devices, nil + } + var devices []iruDevice + if err := json.Unmarshal(body, &devices); err != nil { + return nil, err + } + return devices, nil +} + +func (c *Client) getDeviceStatus(ctx context.Context, deviceID string) (iruDeviceStatus, error) { + u := c.baseURL + fmt.Sprintf(deviceStatusPathFmt, url.PathEscape(deviceID)) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil) + if err != nil { + return iruDeviceStatus{}, err + } + req.Header.Set("Authorization", "Bearer "+c.apiToken) + req.Header.Set("Accept", "application/json") + + resp, err := c.http.Do(req) + if err != nil { + return iruDeviceStatus{}, err + } + defer resp.Body.Close() + body, _ := io.ReadAll(resp.Body) + if resp.StatusCode >= 400 { + return iruDeviceStatus{}, fmt.Errorf("http %d", resp.StatusCode) + } + var status iruDeviceStatus + if err := json.Unmarshal(body, &status); err != nil { + return iruDeviceStatus{}, err + } + return status, nil +} + +func normalize(d iruDevice, compliant bool) mdmpkg.ManagedDevice { + last, _ := time.Parse(time.RFC3339, d.LastCheckIn) + email := d.UserEmail + if email == "" && d.User != nil { + email = d.User.Email + } + return mdmpkg.ManagedDevice{ + ProviderDeviceID: d.DeviceID, + SerialNumber: d.SerialNumber, + DeviceName: d.DeviceName, + UserPrincipalName: email, + Enrolled: true, + Compliant: compliant, + LastSeenAt: last, + } +} + +type devicesListResponse struct { + Devices []iruDevice `json:"devices"` +} + +type iruDevice struct { + DeviceID string `json:"device_id"` + DeviceName string `json:"device_name"` + SerialNumber string `json:"serial_number"` + LastCheckIn string `json:"last_check_in"` + UserEmail string `json:"user_email"` + User *iruUser `json:"user"` +} + +type iruUser struct { + Email string `json:"email"` +} diff --git a/pro/integration/mdm/jamf/compliance.go b/pro/integration/mdm/jamf/compliance.go new file mode 100644 index 000000000..3f442bd09 --- /dev/null +++ b/pro/integration/mdm/jamf/compliance.go @@ -0,0 +1,55 @@ +package jamf + +import ( + "strings" +) + +// jamfDeviceTrustCompliant aggregates Jamf Conditional Access compliance records +// (GET /v1/conditional-access/device-compliance-information/{computer|mobile}/{id}). +// Every applicable record must be COMPLIANT; NON_COMPLIANT or UNKNOWN fails. +func jamfDeviceTrustCompliant(records []deviceComplianceInfo, filterVendors map[string]struct{}) bool { + applicable := 0 + for _, r := range records { + if !r.Applicable { + continue + } + if len(filterVendors) > 0 { + vendor := strings.ToLower(strings.TrimSpace(r.ComplianceVendor)) + if _, ok := filterVendors[vendor]; !ok { + continue + } + } + applicable++ + switch strings.ToUpper(strings.TrimSpace(r.ComplianceState)) { + case "COMPLIANT": + continue + default: + return false + } + } + if len(filterVendors) > 0 && applicable == 0 { + return false + } + return true +} + +func complianceVendorSet(vendors []string) map[string]struct{} { + if len(vendors) == 0 { + return nil + } + out := make(map[string]struct{}, len(vendors)) + for _, v := range vendors { + v = strings.ToLower(strings.TrimSpace(v)) + if v != "" { + out[v] = struct{}{} + } + } + return out +} + +type deviceComplianceInfo struct { + DeviceID string `json:"deviceId"` + Applicable bool `json:"applicable"` + ComplianceState string `json:"complianceState"` + ComplianceVendor string `json:"complianceVendor"` +} diff --git a/pro/integration/mdm/jamf/compliance_test.go b/pro/integration/mdm/jamf/compliance_test.go new file mode 100644 index 000000000..873844956 --- /dev/null +++ b/pro/integration/mdm/jamf/compliance_test.go @@ -0,0 +1,44 @@ +package jamf + +import "testing" + +func TestJamfDeviceTrustCompliant(t *testing.T) { + filter := map[string]struct{}{"jamf": {}} + + tests := []struct { + name string + records []deviceComplianceInfo + filter map[string]struct{} + want bool + }{ + {name: "no records", want: true}, + {name: "not applicable ignored", records: []deviceComplianceInfo{ + {Applicable: false, ComplianceState: "NON_COMPLIANT"}, + }, want: true}, + {name: "compliant", records: []deviceComplianceInfo{ + {Applicable: true, ComplianceState: "COMPLIANT", ComplianceVendor: "Jamf"}, + }, want: true}, + {name: "non compliant", records: []deviceComplianceInfo{ + {Applicable: true, ComplianceState: "COMPLIANT"}, + {Applicable: true, ComplianceState: "NON_COMPLIANT"}, + }, want: false}, + {name: "unknown fails", records: []deviceComplianceInfo{ + {Applicable: true, ComplianceState: "UNKNOWN"}, + }, want: false}, + {name: "vendor filter pass", records: []deviceComplianceInfo{ + {Applicable: true, ComplianceState: "COMPLIANT", ComplianceVendor: "Jamf"}, + {Applicable: true, ComplianceState: "NON_COMPLIANT", ComplianceVendor: "Intune"}, + }, filter: filter, want: true}, + {name: "vendor filter missing", records: []deviceComplianceInfo{ + {Applicable: true, ComplianceState: "COMPLIANT", ComplianceVendor: "Intune"}, + }, filter: filter, want: false}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := jamfDeviceTrustCompliant(tc.records, tc.filter) + if got != tc.want { + t.Fatalf("jamfDeviceTrustCompliant() = %v, want %v", got, tc.want) + } + }) + } +} diff --git a/pro/integration/mdm/jamf/jamf.go b/pro/integration/mdm/jamf/jamf.go new file mode 100644 index 000000000..de59a7eeb --- /dev/null +++ b/pro/integration/mdm/jamf/jamf.go @@ -0,0 +1,421 @@ +// Package jamf implements an MDM provider backed by Jamf Pro. Self-registers +// with pro/integration/mdm in init(). +package jamf + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "sync" + "time" + + mdmpkg "github.com/gravitl/netmaker/pro/integration/mdm" +) + +const ( + providerName = mdmpkg.ProviderJamf + providerDisplay = "Jamf Pro" + + tokenPath = "/api/oauth/token" + computersPath = "/api/v1/computers-inventory" + computerCompliancePath = "/api/v1/conditional-access/device-compliance-information/computer/" + mobileDevPath = "/api/v2/mobile-devices/detail" + mobileCompliancePath = "/api/v1/conditional-access/device-compliance-information/mobile/" + computerSects = "GENERAL,HARDWARE,USER_AND_LOCATION" + defaultPageSz = 200 + maxComplianceFetches = 8 +) + +func init() { + mdmpkg.Register(providerName, providerDisplay, New) + mdmpkg.RegisterCapabilities(providerName, mdmpkg.Capabilities{ReportsCompliant: true}) +} + +// New builds a Jamf Pro provider from integration config JSON. +func New(configJSON json.RawMessage) (mdmpkg.Provider, error) { + var cfg mdmpkg.JamfConfig + if err := json.Unmarshal(configJSON, &cfg); err != nil { + return nil, fmt.Errorf("invalid jamf config: %w", err) + } + if err := mdmpkg.ValidateConfig(providerName, configJSON); err != nil { + return nil, err + } + return &Client{ + baseURL: strings.TrimRight(cfg.BaseURL, "/"), + clientID: cfg.ClientID, + clientSecret: cfg.ClientSecret, + complianceVendors: complianceVendorSet(cfg.ComplianceVendors), + http: &http.Client{Timeout: 60 * time.Second}, + }, nil +} + +// Client implements mdmpkg.Provider against Jamf Pro. +type Client struct { + baseURL string + clientID string + clientSecret string + complianceVendors map[string]struct{} + http *http.Client + + tokenMu sync.Mutex + token string + tokenExp time.Time +} + +func (c *Client) Name() string { return providerName } + +func (c *Client) Capabilities() mdmpkg.Capabilities { + return mdmpkg.Capabilities{ReportsCompliant: true} +} + +func (c *Client) Verify(ctx context.Context) error { + tok, err := c.accessToken(ctx) + if err != nil { + return err + } + u := fmt.Sprintf("%s%s?page-size=1&page=0§ion=GENERAL", c.baseURL, computersPath) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil) + if err != nil { + return err + } + req.Header.Set("Authorization", "Bearer "+tok) + req.Header.Set("Accept", "application/json") + resp, err := c.http.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode >= 400 { + return fmt.Errorf("jamf verify failed: http %d", resp.StatusCode) + } + return nil +} + +func (c *Client) ListManagedDevices(ctx context.Context) ([]mdmpkg.ManagedDevice, error) { + tok, err := c.accessToken(ctx) + if err != nil { + return nil, err + } + out := []mdmpkg.ManagedDevice{} + computers, err := c.listComputers(ctx, tok) + if err != nil { + return nil, err + } + compliance, err := c.fetchComputerCompliance(ctx, tok, computers) + if err != nil { + return nil, err + } + for _, r := range computers { + compliant := compliance[r.ID] + out = append(out, normalizeComputer(r, compliant)) + } + mobiles, err := c.listMobileDevices(ctx, tok) + if err != nil { + return out, fmt.Errorf("jamf list mobile-devices: %w", err) + } + mobileCompliance, err := c.fetchMobileCompliance(ctx, tok, mobiles) + if err != nil { + return out, fmt.Errorf("jamf mobile compliance: %w", err) + } + for _, r := range mobiles { + compliant := mobileCompliance[r.ID] + out = append(out, normalizeMobile(r, compliant)) + } + return out, nil +} + +func (c *Client) listComputers(ctx context.Context, tok string) ([]computerInventory, error) { + var out []computerInventory + for pageNum := 0; ; pageNum++ { + u := fmt.Sprintf("%s%s?page=%d&page-size=%d§ion=%s", + c.baseURL, computersPath, pageNum, defaultPageSz, computerSects) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+tok) + req.Header.Set("Accept", "application/json") + resp, err := c.http.Do(req) + if err != nil { + return nil, err + } + body, readErr := io.ReadAll(resp.Body) + _ = resp.Body.Close() + if resp.StatusCode >= 400 { + return nil, fmt.Errorf("jamf computers-inventory: http %d", resp.StatusCode) + } + if readErr != nil { + return nil, readErr + } + var pageBody computerInventoryPage + if err := json.Unmarshal(body, &pageBody); err != nil { + return nil, err + } + out = append(out, pageBody.Results...) + if len(pageBody.Results) < defaultPageSz { + break + } + } + return out, nil +} + +func (c *Client) listMobileDevices(ctx context.Context, tok string) ([]mobileDevice, error) { + var out []mobileDevice + for pageNum := 0; ; pageNum++ { + u := fmt.Sprintf("%s%s?page=%d&page-size=%d", + c.baseURL, mobileDevPath, pageNum, defaultPageSz) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+tok) + req.Header.Set("Accept", "application/json") + resp, err := c.http.Do(req) + if err != nil { + return nil, err + } + body, readErr := io.ReadAll(resp.Body) + _ = resp.Body.Close() + if resp.StatusCode >= 400 { + return nil, fmt.Errorf("jamf mobile-devices: http %d", resp.StatusCode) + } + if readErr != nil { + return nil, readErr + } + var pageBody mobileDevicesPage + if err := json.Unmarshal(body, &pageBody); err != nil { + return nil, err + } + out = append(out, pageBody.Results...) + if len(pageBody.Results) < defaultPageSz { + break + } + } + return out, nil +} + +func (c *Client) fetchComputerCompliance(ctx context.Context, tok string, computers []computerInventory) (map[string]bool, error) { + return c.fetchCompliance(ctx, tok, len(computers), func(i int) (string, error) { + return computers[i].ID, nil + }, computerCompliancePath) +} + +func (c *Client) fetchMobileCompliance(ctx context.Context, tok string, mobiles []mobileDevice) (map[string]bool, error) { + return c.fetchCompliance(ctx, tok, len(mobiles), func(i int) (string, error) { + return mobiles[i].ID, nil + }, mobileCompliancePath) +} + +func (c *Client) fetchCompliance( + ctx context.Context, + tok string, + n int, + deviceID func(int) (string, error), + pathPrefix string, +) (map[string]bool, error) { + out := make(map[string]bool, n) + if n == 0 { + return out, nil + } + sem := make(chan struct{}, maxComplianceFetches) + type result struct { + id string + compliant bool + err error + } + ch := make(chan result, n) + for i := 0; i < n; i++ { + i := i + go func() { + sem <- struct{}{} + defer func() { <-sem }() + id, err := deviceID(i) + if err != nil { + ch <- result{err: err} + return + } + records, err := c.getDeviceCompliance(ctx, tok, pathPrefix, id) + if err != nil { + ch <- result{id: id, err: err} + return + } + ch <- result{ + id: id, + compliant: jamfDeviceTrustCompliant(records, c.complianceVendors), + } + }() + } + var firstErr error + success := 0 + for range n { + r := <-ch + if r.err != nil { + if firstErr == nil { + firstErr = fmt.Errorf("device %s compliance: %w", r.id, r.err) + } + if r.id != "" { + out[r.id] = false + } + continue + } + out[r.id] = r.compliant + success++ + } + if success == 0 && firstErr != nil { + return out, firstErr + } + return out, nil +} + +func (c *Client) getDeviceCompliance(ctx context.Context, tok, pathPrefix, deviceID string) ([]deviceComplianceInfo, error) { + u := c.baseURL + pathPrefix + url.PathEscape(deviceID) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+tok) + req.Header.Set("Accept", "application/json") + resp, err := c.http.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + respBody, _ := io.ReadAll(resp.Body) + if resp.StatusCode == http.StatusNotFound { + // Conditional Access compliance not configured for this device. + return nil, nil + } + if resp.StatusCode >= 400 { + return nil, fmt.Errorf("http %d: %s", resp.StatusCode, strings.TrimSpace(string(respBody))) + } + var records []deviceComplianceInfo + if err := json.Unmarshal(respBody, &records); err != nil { + return nil, err + } + return records, nil +} + +func (c *Client) accessToken(ctx context.Context) (string, error) { + c.tokenMu.Lock() + defer c.tokenMu.Unlock() + if c.token != "" && time.Until(c.tokenExp) > time.Minute { + return c.token, nil + } + form := url.Values{} + form.Set("grant_type", "client_credentials") + form.Set("client_id", c.clientID) + form.Set("client_secret", c.clientSecret) + req, err := http.NewRequestWithContext( + ctx, http.MethodPost, + c.baseURL+tokenPath, + strings.NewReader(form.Encode()), + ) + if err != nil { + return "", err + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + resp, err := c.http.Do(req) + if err != nil { + return "", errors.New("jamf token: " + err.Error()) + } + defer resp.Body.Close() + if resp.StatusCode >= 400 { + return "", fmt.Errorf("jamf token: http %d", resp.StatusCode) + } + var body tokenResponse + if err := json.NewDecoder(resp.Body).Decode(&body); err != nil { + return "", err + } + if body.AccessToken == "" { + return "", errors.New("jamf token: empty response") + } + c.token = body.AccessToken + if body.ExpiresIn > 0 { + c.tokenExp = time.Now().Add(time.Duration(body.ExpiresIn) * time.Second) + } else { + c.tokenExp = time.Now().Add(20 * time.Minute) + } + return c.token, nil +} + +func normalizeComputer(r computerInventory, compliant bool) mdmpkg.ManagedDevice { + last, _ := time.Parse(time.RFC3339, r.General.LastContactTime) + return mdmpkg.ManagedDevice{ + ProviderDeviceID: r.ID, + SerialNumber: r.Hardware.SerialNumber, + HardwareUUID: r.General.UDID, + DeviceName: r.General.Name, + UserPrincipalName: r.UserAndLocation.EmailAddress, + Enrolled: true, + Compliant: compliant, + LastSeenAt: last, + } +} + +func normalizeMobile(r mobileDevice, compliant bool) mdmpkg.ManagedDevice { + last, _ := time.Parse(time.RFC3339, r.LastInventoryUpdateDate) + return mdmpkg.ManagedDevice{ + ProviderDeviceID: r.ID, + SerialNumber: r.SerialNumber, + HardwareUUID: r.UDID, + DeviceName: r.Name, + UserPrincipalName: r.Location.EmailAddress, + Enrolled: true, + Compliant: compliant, + LastSeenAt: last, + } +} + +type tokenResponse struct { + AccessToken string `json:"access_token"` + ExpiresIn int `json:"expires_in"` +} + +type computerInventoryPage struct { + TotalCount int `json:"totalCount"` + Results []computerInventory `json:"results"` +} + +type computerInventory struct { + ID string `json:"id"` + General computerGeneral `json:"general"` + Hardware computerHardware `json:"hardware"` + UserAndLocation computerUserAndLocation `json:"userAndLocation"` +} + +type computerGeneral struct { + Name string `json:"name"` + UDID string `json:"udid"` + LastContactTime string `json:"lastContactTime"` +} + +type computerHardware struct { + SerialNumber string `json:"serialNumber"` +} + +type computerUserAndLocation struct { + EmailAddress string `json:"email"` +} + +type mobileDevicesPage struct { + TotalCount int `json:"totalCount"` + Results []mobileDevice `json:"results"` +} + +type mobileDevice struct { + ID string `json:"id"` + Name string `json:"name"` + UDID string `json:"udid"` + SerialNumber string `json:"serialNumber"` + LastInventoryUpdateDate string `json:"lastInventoryUpdateDate"` + Location mobileLocation `json:"location"` +} + +type mobileLocation struct { + EmailAddress string `json:"emailAddress"` +} diff --git a/pro/integration/mdm/jumpcloud/compliance.go b/pro/integration/mdm/jumpcloud/compliance.go new file mode 100644 index 000000000..91f5f74c9 --- /dev/null +++ b/pro/integration/mdm/jumpcloud/compliance.go @@ -0,0 +1,67 @@ +package jumpcloud + +import ( + "strings" +) + +const systemPolicyStatusesPath = "/api/v2/systems/" + +// deviceTrustCompliant reports whether a system satisfies the configured device-trust +// baseline using JumpCloud policy statuses (GET /api/v2/systems/{id}/policystatuses). +// JumpCloud does not expose a single "device trust" field; policy results are the +// supported API signal for whether bound policies (including trust/security baselines) +// are successfully applied on the device. +func deviceTrustCompliant(results []policyResult, filterPolicyIDs map[string]struct{}) bool { + if len(results) == 0 { + // No bound policies with results — nothing failed on the device. + return true + } + checked := 0 + for _, r := range results { + if len(filterPolicyIDs) > 0 { + if _, ok := filterPolicyIDs[r.PolicyID]; !ok { + continue + } + } + checked++ + if policyResultFailed(r) { + return false + } + } + if len(filterPolicyIDs) > 0 && checked == 0 { + // Filtered policy IDs are not bound to this system. + return false + } + return true +} + +func policyResultFailed(r policyResult) bool { + if r.Success != nil && !*r.Success { + return true + } + switch strings.ToLower(strings.TrimSpace(r.State)) { + case "failed", "error": + return true + } + return false +} + +func policyIDSet(ids []string) map[string]struct{} { + if len(ids) == 0 { + return nil + } + out := make(map[string]struct{}, len(ids)) + for _, id := range ids { + id = strings.TrimSpace(id) + if id != "" { + out[id] = struct{}{} + } + } + return out +} + +type policyResult struct { + PolicyID string `json:"policy_id"` + State string `json:"state"` + Success *bool `json:"success"` +} diff --git a/pro/integration/mdm/jumpcloud/compliance_test.go b/pro/integration/mdm/jumpcloud/compliance_test.go new file mode 100644 index 000000000..29a465c0d --- /dev/null +++ b/pro/integration/mdm/jumpcloud/compliance_test.go @@ -0,0 +1,46 @@ +package jumpcloud + +import "testing" + +func TestDeviceTrustCompliant(t *testing.T) { + ok := true + no := false + filter := map[string]struct{}{"p1": {}} + + tests := []struct { + name string + results []policyResult + filter map[string]struct{} + want bool + }{ + {name: "no results", results: nil, want: true}, + {name: "all success", results: []policyResult{ + {PolicyID: "p1", State: "success", Success: &ok}, + }, want: true}, + {name: "explicit failure", results: []policyResult{ + {PolicyID: "p1", State: "success", Success: &ok}, + {PolicyID: "p2", State: "failed", Success: &no}, + }, want: false}, + {name: "failed state", results: []policyResult{ + {PolicyID: "p1", State: "failed"}, + }, want: false}, + {name: "pending ignored", results: []policyResult{ + {PolicyID: "p1", State: "pending"}, + }, want: true}, + {name: "filter match pass", results: []policyResult{ + {PolicyID: "p1", State: "success", Success: &ok}, + {PolicyID: "p2", State: "failed", Success: &no}, + }, filter: filter, want: true}, + {name: "filter no match", results: []policyResult{ + {PolicyID: "p2", State: "success", Success: &ok}, + }, filter: filter, want: false}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := deviceTrustCompliant(tc.results, tc.filter) + if got != tc.want { + t.Fatalf("deviceTrustCompliant() = %v, want %v", got, tc.want) + } + }) + } +} diff --git a/pro/integration/mdm/jumpcloud/jumpcloud.go b/pro/integration/mdm/jumpcloud/jumpcloud.go new file mode 100644 index 000000000..7490b6059 --- /dev/null +++ b/pro/integration/mdm/jumpcloud/jumpcloud.go @@ -0,0 +1,343 @@ +// Package jumpcloud implements an MDM provider backed by JumpCloud. Self-registers +// with pro/integration/mdm in init(). +package jumpcloud + +import ( + "context" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "sync" + "time" + + mdmpkg "github.com/gravitl/netmaker/pro/integration/mdm" +) + +const ( + providerName = mdmpkg.ProviderJumpCloud + providerDisplay = "JumpCloud" + + tokenURL = "https://admin-oauth.id.jumpcloud.com/oauth2/token" + defaultBaseURL = "https://console.jumpcloud.com" + systemsListPath = "/api/systems" + defaultPageSize = 100 + tokenScope = "api" + maxPolicyFetches = 8 +) + +func init() { + mdmpkg.Register(providerName, providerDisplay, New) + mdmpkg.RegisterCapabilities(providerName, mdmpkg.Capabilities{ReportsCompliant: true}) +} + +// New builds a JumpCloud provider from integration config JSON. +func New(configJSON json.RawMessage) (mdmpkg.Provider, error) { + var cfg mdmpkg.JumpCloudConfig + if err := json.Unmarshal(configJSON, &cfg); err != nil { + return nil, fmt.Errorf("invalid jumpcloud config: %w", err) + } + if err := mdmpkg.ValidateConfig(providerName, configJSON); err != nil { + return nil, err + } + baseURL := strings.TrimRight(cfg.BaseURL, "/") + if baseURL == "" { + baseURL = defaultBaseURL + } + return &Client{ + baseURL: baseURL, + clientID: cfg.ClientID, + clientSecret: cfg.ClientSecret, + compliancePolicyIDs: policyIDSet(cfg.CompliancePolicyIDs), + http: &http.Client{Timeout: 60 * time.Second}, + }, nil +} + +// Client implements mdmpkg.Provider against JumpCloud. +type Client struct { + baseURL string + clientID string + clientSecret string + compliancePolicyIDs map[string]struct{} + http *http.Client + + tokenMu sync.Mutex + token string + tokenExp time.Time +} + +func (c *Client) Name() string { return providerName } + +func (c *Client) Capabilities() mdmpkg.Capabilities { + return mdmpkg.Capabilities{ReportsCompliant: true} +} + +func (c *Client) Verify(ctx context.Context) error { + tok, err := c.accessToken(ctx) + if err != nil { + return err + } + _, err = c.listSystems(ctx, tok, 0, 1) + if err != nil { + return fmt.Errorf("jumpcloud verify failed: %w", err) + } + return nil +} + +func (c *Client) ListManagedDevices(ctx context.Context) ([]mdmpkg.ManagedDevice, error) { + tok, err := c.accessToken(ctx) + if err != nil { + return nil, err + } + var systems []jumpcloudSystem + for skip := 0; ; skip += defaultPageSize { + page, err := c.listSystems(ctx, tok, skip, defaultPageSize) + if err != nil { + return nil, err + } + systems = append(systems, page...) + if len(page) < defaultPageSize { + break + } + } + complianceByID, err := c.fetchDeviceTrustCompliance(ctx, tok, systems) + if err != nil { + return nil, err + } + out := make([]mdmpkg.ManagedDevice, 0, len(systems)) + for _, s := range systems { + compliant, ok := complianceByID[s.ID] + if !ok { + compliant = false + } + out = append(out, normalize(s, compliant)) + } + return out, nil +} + +func (c *Client) fetchDeviceTrustCompliance(ctx context.Context, tok string, systems []jumpcloudSystem) (map[string]bool, error) { + out := make(map[string]bool, len(systems)) + if len(systems) == 0 { + return out, nil + } + sem := make(chan struct{}, maxPolicyFetches) + type result struct { + id string + compliant bool + err error + } + ch := make(chan result, len(systems)) + for _, s := range systems { + s := s + go func() { + sem <- struct{}{} + defer func() { <-sem }() + statuses, err := c.listSystemPolicyStatuses(ctx, tok, s.ID) + if err != nil { + ch <- result{id: s.ID, err: err} + return + } + ch <- result{ + id: s.ID, + compliant: deviceTrustCompliant(statuses, c.compliancePolicyIDs), + } + }() + } + var firstErr error + success := 0 + for range systems { + r := <-ch + if r.err != nil { + if firstErr == nil { + firstErr = fmt.Errorf("system %s policy statuses: %w", r.id, r.err) + } + out[r.id] = false + continue + } + out[r.id] = r.compliant + success++ + } + if success == 0 && firstErr != nil { + return out, firstErr + } + return out, nil +} + +func (c *Client) listSystemPolicyStatuses(ctx context.Context, tok, systemID string) ([]policyResult, error) { + u := c.baseURL + systemPolicyStatusesPath + url.PathEscape(systemID) + "/policystatuses" + req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+tok) + req.Header.Set("Accept", "application/json") + req.Header.Set("Content-Type", "application/json") + resp, err := c.http.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + respBody, _ := io.ReadAll(resp.Body) + if resp.StatusCode >= 400 { + return nil, fmt.Errorf("http %d", resp.StatusCode) + } + var results []policyResult + if err := json.Unmarshal(respBody, &results); err != nil { + return nil, err + } + return results, nil +} + +func (c *Client) listSystems(ctx context.Context, tok string, skip, limit int) ([]jumpcloudSystem, error) { + u, err := url.Parse(c.baseURL + systemsListPath) + if err != nil { + return nil, err + } + q := u.Query() + q.Set("skip", fmt.Sprintf("%d", skip)) + q.Set("limit", fmt.Sprintf("%d", limit)) + u.RawQuery = q.Encode() + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil) + if err != nil { + return nil, err + } + req.Header.Set("Authorization", "Bearer "+tok) + req.Header.Set("Accept", "application/json") + req.Header.Set("Content-Type", "application/json") + resp, err := c.http.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + respBody, _ := io.ReadAll(resp.Body) + if resp.StatusCode >= 400 { + return nil, fmt.Errorf("jumpcloud systems list: http %d", resp.StatusCode) + } + + var wrapped systemsListResponse + if err := json.Unmarshal(respBody, &wrapped); err == nil { + return wrapped.Results, nil + } + // Some responses may return a bare array. + var systems []jumpcloudSystem + if err := json.Unmarshal(respBody, &systems); err != nil { + return nil, err + } + return systems, nil +} + +func (c *Client) accessToken(ctx context.Context) (string, error) { + c.tokenMu.Lock() + defer c.tokenMu.Unlock() + if c.token != "" && time.Until(c.tokenExp) > time.Minute { + return c.token, nil + } + form := url.Values{} + form.Set("grant_type", "client_credentials") + form.Set("scope", tokenScope) + req, err := http.NewRequestWithContext( + ctx, http.MethodPost, + tokenURL, + strings.NewReader(form.Encode()), + ) + if err != nil { + return "", err + } + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + req.Header.Set( + "Authorization", + "Basic "+base64.StdEncoding.EncodeToString([]byte(c.clientID+":"+c.clientSecret)), + ) + resp, err := c.http.Do(req) + if err != nil { + return "", errors.New("jumpcloud token: " + err.Error()) + } + defer resp.Body.Close() + respBody, _ := io.ReadAll(resp.Body) + if resp.StatusCode >= 400 { + return "", fmt.Errorf("jumpcloud token: http %d", resp.StatusCode) + } + var body tokenResponse + if err := json.Unmarshal(respBody, &body); err != nil { + return "", err + } + if body.AccessToken == "" { + if body.Error != "" { + return "", fmt.Errorf("jumpcloud token: %s: %s", body.Error, body.ErrorDescription) + } + return "", errors.New("jumpcloud token: empty response") + } + c.token = body.AccessToken + if body.ExpiresIn > 0 { + c.tokenExp = time.Now().Add(time.Duration(body.ExpiresIn) * time.Second) + } else { + c.tokenExp = time.Now().Add(50 * time.Minute) + } + return c.token, nil +} + +func normalize(s jumpcloudSystem, compliant bool) mdmpkg.ManagedDevice { + // Prefer hostname for host matching; display names are often user-friendly labels. + name := s.Hostname + if name == "" { + name = s.DisplayName + } + last := time.Time{} + for _, raw := range []string{s.LastContact, s.Modified, s.Created} { + if raw == "" { + continue + } + if t, err := time.Parse(time.RFC3339, raw); err == nil { + last = t + break + } + } + email := "" + if s.PrimarySystemUser != nil { + email = s.PrimarySystemUser.Email + } + return mdmpkg.ManagedDevice{ + ProviderDeviceID: s.ID, + SerialNumber: s.SerialNumber, + HardwareUUID: s.HardwareUUID, + DeviceName: name, + UserPrincipalName: email, + Enrolled: s.Active, + Compliant: compliant, + LastSeenAt: last, + } +} + +type tokenResponse struct { + AccessToken string `json:"access_token"` + ExpiresIn int `json:"expires_in"` + Error string `json:"error"` + ErrorDescription string `json:"error_description"` +} + +type systemsListResponse struct { + Results []jumpcloudSystem `json:"results"` + TotalCount int `json:"totalCount"` +} + +type jumpcloudSystem struct { + ID string `json:"id"` + DisplayName string `json:"displayName"` + Hostname string `json:"hostname"` + SerialNumber string `json:"serialNumber"` + HardwareUUID string `json:"hardwareUuid"` + Active bool `json:"active"` + LastContact string `json:"lastContact"` + Created string `json:"created"` + Modified string `json:"modified"` + PrimarySystemUser *primarySystemUser `json:"primarySystemUser"` +} + +type primarySystemUser struct { + Email string `json:"email"` +} diff --git a/pro/integration/mdm/lookup.go b/pro/integration/mdm/lookup.go new file mode 100644 index 000000000..df4a2d99b --- /dev/null +++ b/pro/integration/mdm/lookup.go @@ -0,0 +1,93 @@ +package mdm + +import ( + "context" + "encoding/json" + "time" + + "github.com/google/uuid" + "github.com/gravitl/netmaker/db" + "github.com/gravitl/netmaker/logger" + "github.com/gravitl/netmaker/schema" +) + +// EntraDeviceLookup is implemented by MDM providers that resolve a host using +// host.entra_device_id as Graph devices.deviceId. Intune queries GET /v1.0/devices +// first, then GET /deviceManagement/managedDevices when /devices returns no match. +// +// Providers without EntraDeviceLookup (Iru, Jamf, JumpCloud) are synced via +// ListManagedDevices and serial_number matching in sync.go. +type EntraDeviceLookup interface { + LookupByEntraDeviceID(ctx context.Context, entraDeviceID string) (ManagedDevice, error) +} + +// SyncHostMDMState refreshes MDM posture state for one host. When the host has +// entra_device_id and the active provider supports Entra-keyed lookup, Graph is +// queried directly; otherwise this is a no-op. +func SyncHostMDMState(ctx context.Context, hostID string) error { + intg, err := GetActive(ctx) + if err != nil { + return err + } + if intg == nil { + return nil + } + p, err := Build(intg.ID, json.RawMessage(intg.Config)) + if err != nil { + return err + } + lookup, ok := p.(EntraDeviceLookup) + if !ok { + return nil + } + id, err := uuid.Parse(hostID) + if err != nil { + return err + } + h := &schema.Host{ID: id} + if err := h.Get(db.WithContext(ctx)); err != nil { + return err + } + if h.EntraDeviceID == "" { + return nil + } + return upsertHostMDMFromEntraLookup(ctx, intg.ID, lookup, *h) +} + +func upsertHostMDMFromEntraLookup( + ctx context.Context, + providerID string, + lookup EntraDeviceLookup, + h schema.Host, +) error { + device, err := lookup.LookupByEntraDeviceID(ctx, h.EntraDeviceID) + now := time.Now().UTC() + state := schema.DeviceMDMState{ + HostID: h.ID.String(), + Provider: providerID, + MatchedBy: schema.MDMMatchEntraDeviceID, + LastSyncedAt: now, + } + if code := LookupErrorCode(err); code != "" { + state.LastError = code + state.Enrolled = false + state.Compliant = false + if upsertErr := state.Upsert(db.WithContext(ctx)); upsertErr != nil { + return upsertErr + } + return nil + } + if err != nil { + return err + } + state.MDMDeviceID = device.ProviderDeviceID + state.Enrolled = device.Enrolled + state.Compliant = device.Compliant + state.LastSeenAt = device.LastSeenAt + state.LastError = "" + if upsertErr := state.Upsert(db.WithContext(ctx)); upsertErr != nil { + return upsertErr + } + logger.Log(2, "mdm sync: entra lookup matched host", h.ID.String(), "device", device.ProviderDeviceID) + return nil +} diff --git a/pro/integration/mdm/registry.go b/pro/integration/mdm/registry.go new file mode 100644 index 000000000..8e87cfbee --- /dev/null +++ b/pro/integration/mdm/registry.go @@ -0,0 +1,111 @@ +// Package mdm defines the pluggable MDM provider interface and registry used by +// the Netmaker MDM posture-check feature. Concrete providers (Intune, Jamf, +// future Iru/JumpCloud/etc.) live in sibling packages and self-register via +// init(). +package mdm + +import ( + "context" + "encoding/json" + "fmt" + "time" +) + +// ManagedDevice is the normalised, provider-agnostic view of a device that an +// MDM Provider returns. Fields that a given provider can't fill are left as +// their zero value. +type ManagedDevice struct { + // ProviderDeviceID is the primary key in the upstream MDM. + ProviderDeviceID string + // AzureADDeviceID is filled by Intune; non-Entra MDMs leave it blank. + AzureADDeviceID string + + SerialNumber string + HardwareUUID string + DeviceName string + UserPrincipalName string // user email + + Enrolled bool + Compliant bool + LastSeenAt time.Time +} + +// Capabilities advertises optional provider features so callers (UI / API) +// know what to surface. +type Capabilities struct { + // ReportsCompliant is true if the provider populates ManagedDevice.Compliant + // with a meaningful value derived from upstream compliance state. When + // false, callers should treat Compliant as "unknown" rather than "false". + ReportsCompliant bool +} + +// Provider is the minimal contract every MDM integration must satisfy. +type Provider interface { + // Name returns the stable identifier of this provider (matches integrations_v1.id). + Name() string + // Capabilities advertises optional provider features. + Capabilities() Capabilities + // Verify confirms credentials and connectivity against the upstream MDM. + Verify(ctx context.Context) error + // ListManagedDevices returns every device known to the upstream MDM. + ListManagedDevices(ctx context.Context) ([]ManagedDevice, error) +} + +// ProviderType describes a provider implementation available at compile time. +type ProviderType struct { + Name string `json:"name"` + Display string `json:"display"` + ReportsCompliant bool `json:"reports_compliant"` +} + +// Factory builds a Provider instance from integration config JSON. +type Factory func(config json.RawMessage) (Provider, error) + +type providerEntry struct { + display string + factory Factory +} + +var providers = map[string]providerEntry{} + +// Register binds a provider implementation to its stable name. +func Register(name, display string, f Factory) { + providers[name] = providerEntry{display: display, factory: f} +} + +// ListProviderTypes returns the registered providers with capability flags. +func ListProviderTypes() []ProviderType { + out := make([]ProviderType, 0, len(providers)) + for name, entry := range providers { + pt := ProviderType{Name: name, Display: entry.display} + if c, ok := capabilityHints[name]; ok { + pt.ReportsCompliant = c.ReportsCompliant + } + out = append(out, pt) + } + return out +} + +var capabilityHints = map[string]Capabilities{} + +// RegisterCapabilities records the static capability profile of a provider. +func RegisterCapabilities(name string, c Capabilities) { + capabilityHints[name] = c +} + +// CapabilitiesFor returns the registered capability profile for a provider id. +func CapabilitiesFor(name string) Capabilities { + if c, ok := capabilityHints[name]; ok { + return c + } + return Capabilities{} +} + +// Build constructs a provider by explicit name from config JSON. +func Build(name string, config json.RawMessage) (Provider, error) { + entry, ok := providers[name] + if !ok { + return nil, fmt.Errorf("unknown mdm provider %q", name) + } + return entry.factory(config) +} diff --git a/pro/integration/mdm/sync.go b/pro/integration/mdm/sync.go new file mode 100644 index 000000000..eb3a7e6f1 --- /dev/null +++ b/pro/integration/mdm/sync.go @@ -0,0 +1,196 @@ +package mdm + +import ( + "context" + "encoding/json" + "errors" + "strings" + "sync" + "time" + + "github.com/gravitl/netmaker/db" + "github.com/gravitl/netmaker/logger" + "github.com/gravitl/netmaker/schema" +) + +var ( + syncMu sync.Mutex + lastSync time.Time +) + +// RunMDMSync refreshes DeviceMDMState for hosts via the active provider. +// Intune uses Entra-keyed lookup; other providers list devices and match serial_number. +// Honours sync_interval_minutes from integration config as an optional per-tick +// rate-limit hint. Returns nil (no-op) if MDM is not configured. +func RunMDMSync(ctx context.Context) error { + intg, err := GetActive(ctx) + if err != nil { + return err + } + if intg == nil { + return nil + } + sync, err := ParseSyncSettings(intg.ID, json.RawMessage(intg.Config)) + if err != nil { + return err + } + if !sync.SyncEnabled { + return nil + } + return runSyncLocked(ctx, intg, false) +} + +// RunMDMSyncForce ignores the rate-limit hint and triggers a fresh sync. +func RunMDMSyncForce(ctx context.Context) error { + intg, err := GetActive(ctx) + if err != nil { + return err + } + if intg == nil { + return errors.New("no MDM integration configured") + } + return runSyncLocked(ctx, intg, true) +} + +func runSyncLocked(ctx context.Context, intg *schema.Integration, force bool) error { + syncMu.Lock() + defer syncMu.Unlock() + + sync, err := ParseSyncSettings(intg.ID, json.RawMessage(intg.Config)) + if err != nil { + return err + } + if !force && sync.SyncIntervalMinutes > 0 && + !lastSync.IsZero() && + time.Since(lastSync) < time.Duration(sync.SyncIntervalMinutes)*time.Minute { + return nil + } + + p, err := Build(intg.ID, json.RawMessage(intg.Config)) + if err != nil { + logger.Log(0, "mdm sync: build provider:", err.Error()) + return err + } + + hosts, err := (&schema.Host{}).ListAll(db.WithContext(ctx)) + if err != nil { + logger.Log(0, "mdm sync: list hosts:", err.Error()) + return err + } + + matched := 0 + if lookup, ok := p.(EntraDeviceLookup); ok { + for i := range hosts { + if hosts[i].EntraDeviceID == "" { + if err := clearHostMDMState(ctx, intg.ID, hosts[i].ID.String()); err != nil { + logger.Log(0, "mdm sync: clear stale state for host", hosts[i].ID.String(), ":", err.Error()) + } + continue + } + if err := upsertHostMDMFromEntraLookup(ctx, intg.ID, lookup, hosts[i]); err != nil { + logger.Log(0, "mdm sync: entra lookup for host", hosts[i].ID.String(), ":", err.Error()) + continue + } + matched++ + } + lastSync = time.Now().UTC() + logger.Log(2, "mdm sync: provider=", p.Name(), "matched=", itoa(matched)) + return nil + } + + devices, err := p.ListManagedDevices(ctx) + if err != nil { + logger.Log(0, "mdm sync: list devices:", err.Error()) + return err + } + for i := range hosts { + if strings.TrimSpace(hosts[i].SerialNumber) == "" { + if err := clearHostMDMState(ctx, intg.ID, hosts[i].ID.String()); err != nil { + logger.Log(0, "mdm sync: clear stale state for host", hosts[i].ID.String(), ":", err.Error()) + } + continue + } + found := false + for _, d := range devices { + if !MatchHostToMDMDeviceBySerial(hosts[i], d) { + continue + } + state := schema.DeviceMDMState{ + HostID: hosts[i].ID.String(), + Provider: intg.ID, + MDMDeviceID: d.ProviderDeviceID, + Enrolled: d.Enrolled, + Compliant: d.Compliant, + MatchedBy: schema.MDMMatchSerialNumber, + LastSyncedAt: time.Now().UTC(), + LastSeenAt: d.LastSeenAt, + } + if err := state.Upsert(db.WithContext(ctx)); err != nil { + logger.Log(0, "mdm sync: upsert state for host", hosts[i].ID.String(), ":", err.Error()) + continue + } + matched++ + found = true + break + } + if !found { + if err := upsertUnmatchedHostMDMState(ctx, intg.ID, hosts[i].ID.String(), schema.MDMMatchSerialNumber); err != nil { + logger.Log(0, "mdm sync: clear state for host", hosts[i].ID.String(), ":", err.Error()) + continue + } + } + } + + lastSync = time.Now().UTC() + logger.Log(2, "mdm sync: provider=", p.Name(), "devices=", itoa(len(devices)), "matched=", itoa(matched)) + return nil +} + +func upsertUnmatchedHostMDMState(ctx context.Context, providerID, hostID, matchedBy string) error { + state := schema.DeviceMDMState{ + HostID: hostID, + Provider: providerID, + Enrolled: false, + Compliant: false, + MatchedBy: matchedBy, + LastSyncedAt: time.Now().UTC(), + LastError: ErrDeviceNotFoundInMDM.Error(), + } + return state.Upsert(db.WithContext(ctx)) +} + +func clearHostMDMState(ctx context.Context, providerID, hostID string) error { + state := &schema.DeviceMDMState{HostID: hostID, Provider: providerID} + return state.Delete(db.WithContext(ctx)) +} + +// MatchHostToMDMDeviceBySerial matches a host to an MDM device by serial number only. +func MatchHostToMDMDeviceBySerial(h schema.Host, d ManagedDevice) bool { + hostSerial := strings.TrimSpace(h.SerialNumber) + deviceSerial := strings.TrimSpace(d.SerialNumber) + return hostSerial != "" && deviceSerial != "" && + strings.EqualFold(hostSerial, deviceSerial) +} + +func itoa(i int) string { + if i == 0 { + return "0" + } + neg := false + if i < 0 { + neg = true + i = -i + } + var buf [20]byte + pos := len(buf) + for i > 0 { + pos-- + buf[pos] = byte('0' + i%10) + i /= 10 + } + if neg { + pos-- + buf[pos] = '-' + } + return string(buf[pos:]) +} diff --git a/pro/integration/mdm/sync_test.go b/pro/integration/mdm/sync_test.go new file mode 100644 index 000000000..bf9449f2b --- /dev/null +++ b/pro/integration/mdm/sync_test.go @@ -0,0 +1,38 @@ +package mdm + +import ( + "testing" + + "github.com/google/uuid" + "github.com/gravitl/netmaker/schema" +) + +func TestMatchHostToMDMDeviceBySerial(t *testing.T) { + host := schema.Host{ + ID: uuid.New(), + SerialNumber: " ABC123 ", + } + device := ManagedDevice{SerialNumber: "abc123"} + if !MatchHostToMDMDeviceBySerial(host, device) { + t.Fatal("expected serial match") + } +} + +func TestMatchHostToMDMDeviceBySerial_NoMatch(t *testing.T) { + host := schema.Host{ + ID: uuid.New(), + SerialNumber: "ABC123", + } + device := ManagedDevice{SerialNumber: "XYZ999"} + if MatchHostToMDMDeviceBySerial(host, device) { + t.Fatal("expected no match") + } +} + +func TestMatchHostToMDMDeviceBySerial_EmptyHostSerial(t *testing.T) { + host := schema.Host{ID: uuid.New()} + device := ManagedDevice{SerialNumber: "ABC123"} + if MatchHostToMDMDeviceBySerial(host, device) { + t.Fatal("expected no match when host serial is empty") + } +} diff --git a/pro/integration/mdm_provider.go b/pro/integration/mdm_provider.go new file mode 100644 index 000000000..cb162cb68 --- /dev/null +++ b/pro/integration/mdm_provider.go @@ -0,0 +1,31 @@ +package integration + +import ( + "context" + "encoding/json" + + mdmpkg "github.com/gravitl/netmaker/pro/integration/mdm" +) + +type mdmProvider struct { + id ProviderID +} + +func (m *mdmProvider) Validate(configJSON json.RawMessage) error { + return mdmpkg.ValidateConfig(string(m.id), configJSON) +} + +func (m *mdmProvider) Test(configJSON json.RawMessage) error { + if err := m.Validate(configJSON); err != nil { + return err + } + p, err := mdmpkg.Build(string(m.id), configJSON) + if err != nil { + return err + } + return p.Verify(context.Background()) +} + +func newMDMProvider(id ProviderID) Provider { + return &mdmProvider{id: id} +} diff --git a/pro/integration/providers.go b/pro/integration/providers.go index bc690d4d3..8b218bd25 100644 --- a/pro/integration/providers.go +++ b/pro/integration/providers.go @@ -3,6 +3,8 @@ package integration import ( "encoding/json" "fmt" + + "github.com/gravitl/netmaker/pro/integration/siem" ) type Type string @@ -11,6 +13,7 @@ type ProviderID string const ( TypeSIEM Type = "siem" + TypeMDM Type = "mdm" ) const ( @@ -18,6 +21,10 @@ const ( ProviderElastic ProviderID = "elastic" ProviderSentinel ProviderID = "sentinel" ProviderSplunk ProviderID = "splunk" + ProviderIntune ProviderID = "intune" + ProviderJamf ProviderID = "jamf" + ProviderJumpCloud ProviderID = "jumpcloud" + ProviderIru ProviderID = "iru" ) type Provider interface { @@ -27,10 +34,16 @@ type Provider interface { var registry = map[Type]map[ProviderID]Provider{ TypeSIEM: { - ProviderSplunk: &splunkProvider{}, - ProviderDatadog: &datadogProvider{}, - ProviderElastic: &elasticProvider{}, - ProviderSentinel: &sentinelProvider{}, + ProviderDatadog: siem.DatadogProvider(), + ProviderElastic: siem.ElasticProvider(), + ProviderSentinel: siem.SentinelProvider(), + ProviderSplunk: siem.SplunkProvider(), + }, + TypeMDM: { + ProviderIntune: newMDMProvider(ProviderIntune), + ProviderJamf: newMDMProvider(ProviderJamf), + ProviderJumpCloud: newMDMProvider(ProviderJumpCloud), + ProviderIru: newMDMProvider(ProviderIru), }, } @@ -45,3 +58,9 @@ func Lookup(intType Type, id ProviderID) (Provider, error) { } return p, nil } + +// TypeExists reports whether the integration type is registered. +func TypeExists(intType Type) bool { + _, ok := registry[intType] + return ok +} diff --git a/pro/integration/siem.go b/pro/integration/siem/client.go similarity index 59% rename from pro/integration/siem.go rename to pro/integration/siem/client.go index d2b197f4c..62fdde505 100644 --- a/pro/integration/siem.go +++ b/pro/integration/siem/client.go @@ -1,7 +1,7 @@ -package integration +package siem import "context" -type SIEMClient interface { +type Client interface { Export(ctx context.Context, events []any) error } diff --git a/pro/integration/siem_datadog.go b/pro/integration/siem/datadog.go similarity index 87% rename from pro/integration/siem_datadog.go rename to pro/integration/siem/datadog.go index 97f355f85..ed06a55b8 100644 --- a/pro/integration/siem_datadog.go +++ b/pro/integration/siem/datadog.go @@ -1,4 +1,4 @@ -package integration +package siem import ( "bytes" @@ -58,18 +58,18 @@ func (d *datadogProvider) Test(configJSON json.RawMessage) error { testEvent := map[string]any{ "message": "netmaker siem integration test", } - return NewDatadogSIEMClient(cfg).Export(context.Background(), []any{testEvent}) + return NewDatadogClient(cfg).Export(context.Background(), []any{testEvent}) } -type DatadogSIEMClient struct { +type DatadogClient struct { DatadogConfig } -func NewDatadogSIEMClient(config DatadogConfig) *DatadogSIEMClient { +func NewDatadogClient(config DatadogConfig) *DatadogClient { if config.Site == "" { config.Site = "datadoghq.com" } - return &DatadogSIEMClient{DatadogConfig: config} + return &DatadogClient{DatadogConfig: config} } type datadogLogItem struct { @@ -79,7 +79,7 @@ type datadogLogItem struct { DDTags string `json:"ddtags,omitempty"` } -func (d *DatadogSIEMClient) Export(ctx context.Context, events []any) error { +func (d *DatadogClient) Export(ctx context.Context, events []any) error { items := make([]datadogLogItem, 0, len(events)) for _, e := range events { msg, _ := json.Marshal(e) @@ -123,3 +123,5 @@ func (d *DatadogSIEMClient) Export(ctx context.Context, events []any) error { } return nil } + +func DatadogProvider() *datadogProvider { return &datadogProvider{} } diff --git a/pro/integration/siem_elastic.go b/pro/integration/siem/elastic.go similarity index 87% rename from pro/integration/siem_elastic.go rename to pro/integration/siem/elastic.go index b6f637c04..0f3582fa6 100644 --- a/pro/integration/siem_elastic.go +++ b/pro/integration/siem/elastic.go @@ -1,4 +1,4 @@ -package integration +package siem import ( "bytes" @@ -54,18 +54,18 @@ func (e *elasticProvider) Test(configJSON json.RawMessage) error { testEvent := map[string]any{ "message": "netmaker siem integration test", } - return NewElasticSIEMClient(cfg).Export(context.Background(), []any{testEvent}) + return NewElasticClient(cfg).Export(context.Background(), []any{testEvent}) } -type ElasticSIEMClient struct { +type ElasticClient struct { ElasticConfig } -func NewElasticSIEMClient(config ElasticConfig) *ElasticSIEMClient { - return &ElasticSIEMClient{ElasticConfig: config} +func NewElasticClient(config ElasticConfig) *ElasticClient { + return &ElasticClient{ElasticConfig: config} } -func (e *ElasticSIEMClient) Export(ctx context.Context, events []any) error { +func (e *ElasticClient) Export(ctx context.Context, events []any) error { metaLine, _ := json.Marshal(map[string]any{"index": map[string]any{"_index": e.Index}}) var buf bytes.Buffer for _, ev := range events { @@ -110,3 +110,5 @@ func (e *ElasticSIEMClient) Export(ctx context.Context, events []any) error { } return nil } + +func ElasticProvider() *elasticProvider { return &elasticProvider{} } diff --git a/pro/integration/siem_sentinel.go b/pro/integration/siem/sentinel.go similarity index 88% rename from pro/integration/siem_sentinel.go rename to pro/integration/siem/sentinel.go index 785071f79..941482da3 100644 --- a/pro/integration/siem_sentinel.go +++ b/pro/integration/siem/sentinel.go @@ -1,4 +1,4 @@ -package integration +package siem import ( "bytes" @@ -55,21 +55,21 @@ func (s *sentinelProvider) Test(configJSON json.RawMessage) error { testEvent := map[string]any{ "message": "netmaker siem integration test", } - return NewSentinelSIEMClient(cfg).Export(context.Background(), []any{testEvent}) + return NewSentinelClient(cfg).Export(context.Background(), []any{testEvent}) } -type SentinelSIEMClient struct { +type SentinelClient struct { SentinelConfig } -func NewSentinelSIEMClient(config SentinelConfig) *SentinelSIEMClient { +func NewSentinelClient(config SentinelConfig) *SentinelClient { if config.LogType == "" { config.LogType = "NetmakerSIEM" } - return &SentinelSIEMClient{SentinelConfig: config} + return &SentinelClient{SentinelConfig: config} } -func (s *SentinelSIEMClient) Export(ctx context.Context, events []any) error { +func (s *SentinelClient) Export(ctx context.Context, events []any) error { enriched := make([]map[string]any, 0, len(events)) for _, ev := range events { var evMap map[string]any @@ -113,3 +113,5 @@ func (s *SentinelSIEMClient) Export(ctx context.Context, events []any) error { } return nil } + +func SentinelProvider() *sentinelProvider { return &sentinelProvider{} } diff --git a/pro/integration/siem_splunk.go b/pro/integration/siem/splunk.go similarity index 85% rename from pro/integration/siem_splunk.go rename to pro/integration/siem/splunk.go index 5560a2a52..077301563 100644 --- a/pro/integration/siem_splunk.go +++ b/pro/integration/siem/splunk.go @@ -1,4 +1,4 @@ -package integration +package siem import ( "bytes" @@ -46,21 +46,21 @@ func (s *splunkProvider) Test(configJSON json.RawMessage) error { testEvent := map[string]any{ "message": "netmaker siem integration test", } - return NewSplunkSIEMClient(cfg).Export(context.Background(), []any{testEvent}) + return NewSplunkClient(cfg).Export(context.Background(), []any{testEvent}) } -type SplunkSIEMClient struct { +type SplunkClient struct { SplunkConfig } -func NewSplunkSIEMClient(config SplunkConfig) *SplunkSIEMClient { +func NewSplunkClient(config SplunkConfig) *SplunkClient { if config.SourceType == "" { config.SourceType = "_json" } - return &SplunkSIEMClient{SplunkConfig: config} + return &SplunkClient{SplunkConfig: config} } -func (s *SplunkSIEMClient) Export(ctx context.Context, events []any) error { +func (s *SplunkClient) Export(ctx context.Context, events []any) error { var buf bytes.Buffer for _, e := range events { payload := map[string]any{ @@ -98,3 +98,5 @@ func (s *SplunkSIEMClient) Export(ctx context.Context, events []any) error { } return nil } + +func SplunkProvider() *splunkProvider { return &splunkProvider{} } diff --git a/pro/logic/posture_check.go b/pro/logic/posture_check.go index 1111f8356..b69af60a8 100644 --- a/pro/logic/posture_check.go +++ b/pro/logic/posture_check.go @@ -15,6 +15,7 @@ import ( "github.com/gravitl/netmaker/db" "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/models" + mdmpkg "github.com/gravitl/netmaker/pro/integration/mdm" "github.com/gravitl/netmaker/schema" "gorm.io/datatypes" ) @@ -63,6 +64,10 @@ func RunPostureChecks() error { } postureCheckMutex.Lock() defer postureCheckMutex.Unlock() + // Refresh MDM device state before evaluating; a no-op when no provider + // is configured. Errors are already logged inside; we don't want a + // remote-API hiccup to block the rest of the posture cycle. + _ = mdmpkg.RunMDMSync(db.WithContext(context.TODO())) nets, err := (&schema.Network{}).ListAll(db.WithContext(context.TODO())) if err != nil { return err @@ -87,12 +92,13 @@ func RunPostureChecks() error { if nodeI.IsStatic && !nodeI.IsUserNode { continue } + deviceInfo := logic.GetPostureCheckDeviceInfoByNode(&nodeI) var postureChecksViolations []models.Violation var postureCheckVolationSeverityLevel schema.Severity if noChecks { postureCheckVolationSeverityLevel = schema.SeverityUnknown } else { - postureChecksViolations, postureCheckVolationSeverityLevel = GetPostureCheckViolations(pcLi, logic.GetPostureCheckDeviceInfoByNode(&nodeI)) + postureChecksViolations, postureCheckVolationSeverityLevel = GetPostureCheckViolations(pcLi, deviceInfo) } if nodeI.IsUserNode { extclient, err := logic.GetExtClient(nodeI.StaticNode.ClientID, nodeI.StaticNode.Network) @@ -100,6 +106,7 @@ func RunPostureChecks() error { if noChecks && len(extclient.PostureChecksViolations) == 0 { continue } + emitNewMDMViolationEvents(extclient.PostureChecksViolations, postureChecksViolations, deviceInfo, schema.NetworkID(netI.Name)) extclient.PostureChecksViolations = postureChecksViolations extclient.PostureCheckVolationSeverityLevel = postureCheckVolationSeverityLevel extclient.LastEvaluatedAt = time.Now().UTC() @@ -109,6 +116,7 @@ func RunPostureChecks() error { if noChecks && len(nodeI.PostureChecksViolations) == 0 { continue } + emitNewMDMViolationEvents(nodeI.PostureChecksViolations, postureChecksViolations, deviceInfo, schema.NetworkID(netI.Name)) _node := &schema.Node{ ID: nodeI.ID.String(), @@ -173,6 +181,8 @@ func GetPostureCheckViolations(checks []schema.PostureCheck, d models.PostureChe // Check if posture check has wildcard tag - applies to all devices if _, hasWildcard := c.Tags["*"]; hasWildcard { // Wildcard tag matches all devices, continue to evaluate the check + } else if c.Attribute == schema.MDMCompliance && len(c.Tags) == 0 { + // Legacy MDM checks saved before wildcard default; apply to all hosts. } else if len(c.Tags) > 0 { // Check has specific tags - device must have at least one matching tag if len(d.Tags) == 0 { @@ -339,6 +349,15 @@ func GetPostureCheckDeviceInfoByNode(node *models.Node) models.PostureCheckDevic KernelVersion: h.KernelVersion, AutoUpdate: h.AutoUpdate, Tags: node.Tags, + HostID: h.ID.String(), + } + // Attach the MDM snapshot for the active integration provider, if any. + ctx := db.WithContext(context.TODO()) + if providerID, err := mdmpkg.ActiveProviderID(ctx); err == nil && providerID != "" { + state := &schema.DeviceMDMState{HostID: h.ID.String(), Provider: providerID} + if err := state.Get(ctx); err == nil { + deviceInfo.MDMState = state + } } } else if node.IsUserNode { deviceInfo = models.PostureCheckDeviceInfo{ @@ -478,6 +497,32 @@ func evaluatePostureCheck(check *schema.PostureCheck, d models.PostureCheckDevic if !required && d.AutoUpdate { return true, "auto update must be disabled" } + + // ------------------------ + // 6. MDM compliance check + // Config: {require_enrolled, require_compliant, max_state_age_hours} + // ------------------------ + case schema.MDMCompliance: + cfg := ParseMDMComplianceConfig(check.Config) + if d.MDMState == nil { + return true, "no_mdm_state_for_host" + } + if d.MDMState.LastError != "" { + return true, d.MDMState.LastError + } + if cfg.RequireEnrolled && !d.MDMState.Enrolled { + return true, "device_not_mdm_enrolled" + } + if cfg.RequireCompliant { + providerID, _ := mdmpkg.ActiveProviderID(db.WithContext(context.TODO())) + if mdmpkg.CapabilitiesFor(providerID).ReportsCompliant && !d.MDMState.Compliant { + return true, "device_not_mdm_compliant" + } + } + if cfg.MaxStateAgeHours > 0 && + time.Since(d.MDMState.LastSyncedAt) > time.Duration(cfg.MaxStateAgeHours)*time.Hour { + return true, "mdm_state_stale" + } } return false, "" @@ -538,6 +583,28 @@ func PopulatePostureCheckGroupNames(pcs []schema.PostureCheck) { } } +// MergePostureCheckUpdate fills in fields omitted from an update payload using +// the existing stored posture check. Clients that toggle status often omit +// attribute-specific Config; without this merge validation would see empty +// MDM flags and reject the request. +func MergePostureCheckUpdate(existing, update *schema.PostureCheck) { + if update.Attribute != schema.MDMCompliance || existing.Config == nil { + return + } + if update.Config == nil { + update.Config = existing.Config + return + } + merged := datatypes.JSONMap{} + for k, v := range existing.Config { + merged[k] = v + } + for k, v := range update.Config { + merged[k] = v + } + update.Config = merged +} + func ValidatePostureCheck(pc *schema.PostureCheck) error { if pc.Name == "" { return errors.New("name cannot be empty") @@ -546,10 +613,29 @@ func ValidatePostureCheck(pc *schema.PostureCheck) error { if err != nil { return errors.New("invalid network") } - allowedAttrvaluesMap, ok := schema.PostureCheckAttrValuesMap[pc.Attribute] + _, ok := schema.PostureCheckAttrValuesMap[pc.Attribute] if !ok { return errors.New("unkown attribute") } + // MDMCompliance uses Config, not Values. Validate the Config payload and + // short-circuit the Values flow (Values is set to a placeholder so the + // rest of the system stays happy). + if pc.Attribute == schema.MDMCompliance { + if err := validateMDMComplianceConfig(pc); err != nil { + return err + } + pc.Values = datatypes.JSONSlice[string]{"mdm"} + if len(pc.Tags) == 0 { + // MDM checks apply to all hosts in the network unless scoped + // to specific tags; empty tags would otherwise be skipped. + pc.Tags = datatypes.JSONMap{"*": "*"} + } + if len(pc.UserGroups) == 0 { + pc.UserGroups = make(datatypes.JSONMap) + } + return nil + } + allowedAttrvaluesMap := schema.PostureCheckAttrValuesMap[pc.Attribute] if len(pc.Values) == 0 { return errors.New("attribute value cannot be empty") } @@ -619,3 +705,157 @@ func CountryNameFromISO(code string) string { } return c.Info().Name } + +// MDMComplianceConfig is the typed view of PostureCheck.Config when +// Attribute == MDMCompliance. +type MDMComplianceConfig struct { + RequireEnrolled bool + RequireCompliant bool + MaxStateAgeHours int +} + +// ParseMDMComplianceConfig decodes the JSONMap stored on PostureCheck.Config +// into a typed MDMComplianceConfig. Unknown keys are ignored. +func ParseMDMComplianceConfig(cfg datatypes.JSONMap) MDMComplianceConfig { + out := MDMComplianceConfig{} + if cfg == nil { + return out + } + if v, ok := cfg["require_enrolled"]; ok { + out.RequireEnrolled = asBool(v) + } + if v, ok := cfg["require_compliant"]; ok { + out.RequireCompliant = asBool(v) + } + if v, ok := cfg["max_state_age_hours"]; ok { + out.MaxStateAgeHours = asInt(v) + } + return out +} + +func asBool(v interface{}) bool { + switch x := v.(type) { + case bool: + return x + case string: + return strings.EqualFold(x, "true") + case float64: + return x != 0 + case int: + return x != 0 + } + return false +} + +func asInt(v interface{}) int { + switch x := v.(type) { + case int: + return x + case int64: + return int(x) + case float64: + return int(x) + case string: + if i, err := strconv.Atoi(x); err == nil { + return i + } + } + return 0 +} + +// emitNewMDMViolationEvents emits a posture_check_failed audit event for every +// MDM compliance violation that is newly present (not in oldVi) in newVi. +// Old violations don't re-fire; cleared violations are also ignored here. +func emitNewMDMViolationEvents(oldVi, newVi []models.Violation, d models.PostureCheckDeviceInfo, network schema.NetworkID) { + if len(newVi) == 0 { + return + } + prev := make(map[string]struct{}, len(oldVi)) + for _, v := range oldVi { + prev[v.CheckID+"|"+v.Message] = struct{}{} + } + providerID, _ := mdmpkg.ActiveProviderID(db.WithContext(context.TODO())) + for _, v := range newVi { + if v.Attribute != string(schema.MDMCompliance) { + continue + } + if _, ok := prev[v.CheckID+"|"+v.Message]; ok { + continue + } + diff := models.Diff{ + Old: nil, + New: map[string]interface{}{ + "event": "posture_check_failed", + "type": string(schema.MDMCompliance), + "host_id": d.HostID, + "check_id": v.CheckID, + "check": v.Name, + "reason": v.Message, + "severity": v.Severity, + "provider": providerID, + "enrolled": mdmStateEnrolled(d.MDMState), + "compliant": mdmStateCompliant(d.MDMState), + }, + } + logic.LogEvent(&models.Event{ + Action: schema.PostureCheckFailed, + Source: models.Subject{ + ID: d.HostID, + Name: d.HostID, + Type: schema.DeviceSub, + }, + TriggeredBy: "system", + Target: models.Subject{ + ID: v.CheckID, + Name: v.Name, + Type: schema.PostureCheckSub, + }, + NetworkID: network, + Origin: schema.Api, + Diff: diff, + }) + } +} + +func mdmStateEnrolled(s *schema.DeviceMDMState) bool { + if s == nil { + return false + } + return s.Enrolled +} + +func mdmStateCompliant(s *schema.DeviceMDMState) bool { + if s == nil { + return false + } + return s.Compliant +} + +// validateMDMComplianceConfig enforces the MDMCompliance posture-check +// invariants: an MDM integration must be configured, at least one of +// require_enrolled/require_compliant must be true when the check is enabled, +// and max_state_age_hours must be non-negative. +func validateMDMComplianceConfig(pc *schema.PostureCheck) error { + active, err := mdmpkg.GetActive(db.WithContext(context.TODO())) + if err != nil { + return err + } + if active == nil { + return errors.New("no MDM integration configured; configure via Integrations > MDM") + } + cfg := ParseMDMComplianceConfig(pc.Config) + if pc.Status && !cfg.RequireEnrolled && !cfg.RequireCompliant { + return errors.New("at least one of require_enrolled or require_compliant must be true") + } + if cfg.MaxStateAgeHours < 0 { + return errors.New("max_state_age_hours must be >= 0") + } + // Normalise the Config map so it's always present after validation. + if pc.Config == nil { + pc.Config = make(datatypes.JSONMap) + } + pc.Config["require_enrolled"] = cfg.RequireEnrolled + pc.Config["require_compliant"] = cfg.RequireCompliant + pc.Config["max_state_age_hours"] = cfg.MaxStateAgeHours + return nil +} diff --git a/pro/logic/posture_check_mdm_test.go b/pro/logic/posture_check_mdm_test.go new file mode 100644 index 000000000..f931a9663 --- /dev/null +++ b/pro/logic/posture_check_mdm_test.go @@ -0,0 +1,64 @@ +package logic + +import ( + "testing" + + "github.com/gravitl/netmaker/schema" + "gorm.io/datatypes" +) + +func TestMergePostureCheckUpdatePreservesMDMConfig(t *testing.T) { + existing := &schema.PostureCheck{ + Attribute: schema.MDMCompliance, + Config: datatypes.JSONMap{ + "require_enrolled": true, + "require_compliant": false, + "max_state_age_hours": 24, + }, + } + update := &schema.PostureCheck{ + Attribute: schema.MDMCompliance, + Status: false, + } + + MergePostureCheckUpdate(existing, update) + + if update.Config == nil { + t.Fatal("expected config to be merged from existing check") + } + if !asBool(update.Config["require_enrolled"]) { + t.Fatal("expected require_enrolled to be preserved") + } + if asBool(update.Config["require_compliant"]) { + t.Fatal("expected require_compliant to remain false") + } + if asInt(update.Config["max_state_age_hours"]) != 24 { + t.Fatalf("expected max_state_age_hours 24, got %d", asInt(update.Config["max_state_age_hours"])) + } +} + +func TestMergePostureCheckUpdateOverlayConfig(t *testing.T) { + existing := &schema.PostureCheck{ + Attribute: schema.MDMCompliance, + Config: datatypes.JSONMap{ + "require_enrolled": true, + "require_compliant": true, + "max_state_age_hours": 24, + }, + } + update := &schema.PostureCheck{ + Attribute: schema.MDMCompliance, + Config: datatypes.JSONMap{ + "require_compliant": false, + }, + } + + MergePostureCheckUpdate(existing, update) + + if !asBool(update.Config["require_enrolled"]) { + t.Fatal("expected require_enrolled from existing config") + } + if asBool(update.Config["require_compliant"]) { + t.Fatal("expected require_compliant to be overridden by update") + } +} diff --git a/pro/logic/security.go b/pro/logic/security.go index b7a79f654..a55f4326e 100644 --- a/pro/logic/security.go +++ b/pro/logic/security.go @@ -291,3 +291,48 @@ func GetSaaSNMUIHost() string { func GetSaaSNMUIHostWithVersion() string { return fmt.Sprintf("%s/%s", GetSaaSNMUIHost(), servercfg.GetVersion()) } + +// CheckUIHostReadAccess ensures a dashboard user may read posture data for a host +// by verifying network-scoped host read permission on at least one host network. +func CheckUIHostReadAccess(r *http.Request, host *schema.Host) error { + username := r.Header.Get("user") + if username == logic.MasterUser { + return nil + } + user := &schema.User{Username: username} + if err := user.Get(r.Context()); err != nil { + return err + } + userRole := &schema.UserRole{ID: user.PlatformRoleID} + if err := userRole.Get(r.Context()); err != nil { + return errors.New("access denied") + } + if userRole.FullAccess && !PlatformRoleRequiresGroupEnforcement(user.PlatformRoleID) { + return nil + } + + networks := map[string]struct{}{} + for _, nodeID := range host.Nodes { + node, err := logic.GetNodeByID(nodeID) + if err != nil { + continue + } + networks[node.Network] = struct{}{} + } + if len(networks) == 0 { + return errors.New("access denied") + } + + for netID := range networks { + req := r.Clone(r.Context()) + req.Header.Set("IS_GLOBAL_ACCESS", "no") + req.Header.Set("NET_ID", netID) + req.Header.Set("TARGET_RSRC", schema.HostRsrc.String()) + req.Header.Set("TARGET_RSRC_ID", host.ID.String()) + req.Method = http.MethodGet + if err := NetworkPermissionsCheck(username, req); err == nil { + return nil + } + } + return errors.New("access denied") +} diff --git a/schema/event.go b/schema/event.go index 80e210d82..f1bc88fe6 100644 --- a/schema/event.go +++ b/schema/event.go @@ -54,6 +54,9 @@ const ( DisableUser Action = "DISABLE_USER" EnableAclPolicy Action = "ENABLE_ACL_POLICY" DisableAclPolicy Action = "DISABLE_ACL_POLICY" + MDMVerify Action = "MDM_VERIFY" + MDMSync Action = "MDM_SYNC" + PostureCheckFailed Action = "POSTURE_CHECK_FAILED" ) type SubjectType string @@ -79,6 +82,7 @@ const ( NameserverSub SubjectType = "NAMESERVER" PostureCheckSub SubjectType = "POSTURE_CHECK" JITSub SubjectType = "JIT" + MDMSub SubjectType = "MDM" ) func (sub SubjectType) String() string { diff --git a/schema/hosts.go b/schema/hosts.go index 3024c58cd..b1042f5a5 100644 --- a/schema/hosts.go +++ b/schema/hosts.go @@ -157,8 +157,16 @@ type Host struct { Location string `json:"location" yaml:"location"` // Format: "lat,lon" CountryCode string `json:"country_code" yaml:"country_code"` EnableFlowLogs bool `json:"enable_flow_logs" yaml:"enable_flow_logs"` - CreatedAt time.Time `json:"created_at" yaml:"created_at"` - UpdatedAt time.Time `json:"updated_at" yaml:"updated_at"` + + // MDM device-matching identifiers. Reported by netclient on host check-in + // and consumed by the MDM sync worker to match a Netmaker host to its + // upstream MDM-managed device record. + EntraDeviceID string `json:"entra_device_id" yaml:"entra_device_id"` + SerialNumber string `json:"serial_number" yaml:"serial_number"` + HardwareUUID string `json:"hardware_uuid" yaml:"hardware_uuid"` + + CreatedAt time.Time `json:"created_at" yaml:"created_at"` + UpdatedAt time.Time `json:"updated_at" yaml:"updated_at"` } func (h *Host) TableName() string { diff --git a/schema/mdm_device_state.go b/schema/mdm_device_state.go new file mode 100644 index 000000000..13a10cd73 --- /dev/null +++ b/schema/mdm_device_state.go @@ -0,0 +1,89 @@ +package schema + +import ( + "context" + "time" + + "github.com/gravitl/netmaker/db" +) + +const deviceMDMStateTable = "device_mdm_state_v1" + +// MatchedBy* identify how a host was matched to an MDM-managed device record. +const ( + MDMMatchEntraDeviceID = "entra_device_id" + MDMMatchSerialNumber = "serial_number" + MDMMatchHardwareUUID = "hardware_uuid" + MDMMatchHostname = "hostname" +) + +// DeviceMDMState is the per-host snapshot of an MDM provider's view of a device. +// Composite PK (HostID, Provider) allows a host to carry historical state for +// multiple providers (e.g. after switching MDMs). +type DeviceMDMState struct { + HostID string `gorm:"primaryKey;column:host_id" json:"host_id"` + Provider string `gorm:"primaryKey;column:provider" json:"provider"` + MDMDeviceID string `gorm:"column:mdm_device_id" json:"mdm_device_id"` + Enrolled bool `gorm:"column:enrolled" json:"enrolled"` + Compliant bool `gorm:"column:compliant" json:"compliant"` + MatchedBy string `gorm:"column:matched_by" json:"matched_by"` + LastSyncedAt time.Time `gorm:"column:last_synced_at" json:"last_synced_at"` + LastSeenAt time.Time `gorm:"column:last_seen_at" json:"last_seen_at"` + // LastError holds a stable lookup failure code (e.g. device_not_registered_in_entra). + LastError string `gorm:"column:last_error" json:"last_error,omitempty"` +} + +func (s *DeviceMDMState) TableName() string { + return deviceMDMStateTable +} + +// Get loads the row identified by (HostID, Provider). +func (s *DeviceMDMState) Get(ctx context.Context) error { + return db.FromContext(ctx).Model(&DeviceMDMState{}). + Where("host_id = ? AND provider = ?", s.HostID, s.Provider). + First(s).Error +} + +// Upsert inserts the row or updates the existing one keyed by (HostID, Provider). +func (s *DeviceMDMState) Upsert(ctx context.Context) error { + return db.FromContext(ctx).Save(s).Error +} + +// Delete removes the row for (HostID, Provider). +func (s *DeviceMDMState) Delete(ctx context.Context) error { + return db.FromContext(ctx).Model(&DeviceMDMState{}). + Where("host_id = ? AND provider = ?", s.HostID, s.Provider). + Delete(&DeviceMDMState{}).Error +} + +// DeleteByHostID removes all MDM state rows for a given host (used when a host is deleted). +func (s *DeviceMDMState) DeleteByHostID(ctx context.Context) error { + return db.FromContext(ctx).Model(&DeviceMDMState{}). + Where("host_id = ?", s.HostID). + Delete(&DeviceMDMState{}).Error +} + +// ListByHost returns all provider states for the host in s.HostID. +func (s *DeviceMDMState) ListByHost(ctx context.Context) ([]DeviceMDMState, error) { + var out []DeviceMDMState + err := db.FromContext(ctx).Model(&DeviceMDMState{}). + Where("host_id = ?", s.HostID). + Find(&out).Error + return out, err +} + +// ListByProvider returns all host states for the provider in s.Provider. +func (s *DeviceMDMState) ListByProvider(ctx context.Context) ([]DeviceMDMState, error) { + var out []DeviceMDMState + err := db.FromContext(ctx).Model(&DeviceMDMState{}). + Where("provider = ?", s.Provider). + Find(&out).Error + return out, err +} + +// ListAll returns every MDM state row. +func (s *DeviceMDMState) ListAll(ctx context.Context) ([]DeviceMDMState, error) { + var out []DeviceMDMState + err := db.FromContext(ctx).Model(&DeviceMDMState{}).Find(&out).Error + return out, err +} diff --git a/schema/models.go b/schema/models.go index e50870f8f..c11297b11 100644 --- a/schema/models.go +++ b/schema/models.go @@ -21,6 +21,7 @@ func ListModels() []interface{} { &UserInvite{}, &Node{}, &PostureCheckViolation{}, + &DeviceMDMState{}, &Integration{}, } } diff --git a/schema/posture_check.go b/schema/posture_check.go index dd4479276..a5fd8a8b8 100644 --- a/schema/posture_check.go +++ b/schema/posture_check.go @@ -21,6 +21,10 @@ const ( AutoUpdate Attribute = "auto_update" ClientVersion Attribute = "client_version" ClientLocation Attribute = "client_location" + // MDMCompliance evaluates the host's posture against the MDM provider + // configured in ServerSettings. Config payload (JSONMap): + // {"require_enrolled": bool, "require_compliant": bool, "max_state_age_hours": int} + MDMCompliance Attribute = "mdm_compliance" ) const ( @@ -39,6 +43,15 @@ var PostureCheckAttrs = []Attribute{ OSFamily, KernelVersion, AutoUpdate, + MDMCompliance, +} + +// MDMComplianceConfigKeys lists the supported keys in PostureCheck.Config when +// Attribute == MDMCompliance. +var MDMComplianceConfigKeys = []string{ + "require_enrolled", + "require_compliant", + "max_state_age_hours", } var PostureCheckAttrValuesMap = map[Attribute]map[string]struct{}{ @@ -77,6 +90,10 @@ var PostureCheckAttrValuesMap = map[Attribute]map[string]struct{}{ "true": {}, "false": {}, }, + // MDMCompliance is configured via PostureCheck.Config, not Values. + MDMCompliance: { + "mdm": {}, + }, } var PostureCheckAttrValues = map[Attribute][]string{ @@ -87,6 +104,7 @@ var PostureCheckAttrValues = map[Attribute][]string{ OSFamily: {"linux-debian", "linux-redhat", "linux-suse", "linux-arch", "linux-gentoo", "linux-other", "darwin", "windows", "ios", "android"}, KernelVersion: {"any_valid_semantic_version"}, AutoUpdate: {"true", "false"}, + MDMCompliance: {"mdm"}, } type PostureCheck struct { @@ -96,13 +114,17 @@ type PostureCheck struct { Description string `gorm:"description" json:"description"` Attribute Attribute `gorm:"attribute" json:"attribute"` Values datatypes.JSONSlice[string] `gorm:"values" json:"values"` - Severity Severity `gorm:"severity" json:"severity"` - Tags datatypes.JSONMap `gorm:"tags" json:"tags"` - UserGroups datatypes.JSONMap `gorm:"user_groups" json:"user_groups"` - Status bool `gorm:"status" json:"status"` - CreatedBy string `gorm:"created_by" json:"created_by"` - CreatedAt time.Time `gorm:"created_at" json:"created_at"` - UpdatedAt time.Time `gorm:"updated_at" json:"updated_at"` + // Config holds attribute-specific structured options. Used by MDMCompliance + // for {require_enrolled, require_compliant, max_state_age_hours}; null for + // legacy attributes that rely on Values. + Config datatypes.JSONMap `gorm:"config" json:"config"` + Severity Severity `gorm:"severity" json:"severity"` + Tags datatypes.JSONMap `gorm:"tags" json:"tags"` + UserGroups datatypes.JSONMap `gorm:"user_groups" json:"user_groups"` + Status bool `gorm:"status" json:"status"` + CreatedBy string `gorm:"created_by" json:"created_by"` + CreatedAt time.Time `gorm:"created_at" json:"created_at"` + UpdatedAt time.Time `gorm:"updated_at" json:"updated_at"` } func (p *PostureCheck) Get(ctx context.Context) error {