diff --git a/auth/host_session.go b/auth/host_session.go index d2b9efa06..c8bd9f38b 100644 --- a/auth/host_session.go +++ b/auth/host_session.go @@ -215,7 +215,7 @@ func SessionHandler(conn *websocket.Conn) { if err = conn.WriteMessage(messageType, responseData); err != nil { logger.Log(0, "error during message writing:", err.Error()) } - go CheckNetRegAndHostUpdate(models.EnrollmentKey{Networks: netsToAdd}, &host, result.User) + go CheckNetRegAndHostUpdate(schema.EnrollmentKey{Networks: netsToAdd}, &host, result.User) case <-timeout: // the read from req.answerCh has timed out logger.Log(0, "timeout signal recv,exiting oauth socket conn") break @@ -229,14 +229,12 @@ func SessionHandler(conn *websocket.Conn) { } // CheckNetRegAndHostUpdate - run through networks and send a host update -func CheckNetRegAndHostUpdate(key models.EnrollmentKey, host *schema.Host, username string) { +func CheckNetRegAndHostUpdate(key schema.EnrollmentKey, host *schema.Host, username string) { // publish host update through MQ featureFlags := logic.GetFeatureFlags() keyTags := make(map[models.TagID]struct{}) - if len(key.Groups) > 0 { - for _, tagI := range key.Groups { - keyTags[tagI] = struct{}{} - } + for _, tagI := range key.Tags { + keyTags[models.TagID(tagI)] = struct{}{} } for _, netID := range key.Networks { network := &schema.Network{Name: netID} @@ -329,7 +327,7 @@ func CheckNetRegAndHostUpdate(key models.EnrollmentKey, host *schema.Host, usern Action: schema.JoinHostToNet, Source: models.Subject{ ID: key.Value, - Name: key.Tags[0], + Name: key.Name, Type: schema.EnrollmentKeySub, }, TriggeredBy: username, diff --git a/controllers/enrollmentkeys.go b/controllers/enrollmentkeys.go index edce55563..7cd3dd26b 100644 --- a/controllers/enrollmentkeys.go +++ b/controllers/enrollmentkeys.go @@ -5,20 +5,22 @@ import ( "errors" "fmt" "net/http" + "strconv" "time" "github.com/go-playground/validator/v10" "github.com/google/uuid" "github.com/gorilla/mux" "github.com/gravitl/netmaker/schema" + "golang.org/x/exp/slog" "github.com/gravitl/netmaker/auth" + dbtypes "github.com/gravitl/netmaker/db/types" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/mq" "github.com/gravitl/netmaker/servercfg" - "golang.org/x/exp/slog" ) func enrollmentKeyHandlers(r *mux.Router) { @@ -26,6 +28,8 @@ func enrollmentKeyHandlers(r *mux.Router) { Methods(http.MethodPost) r.HandleFunc("/api/v1/enrollment-keys", logic.SecurityCheck(true, http.HandlerFunc(getEnrollmentKeys))). Methods(http.MethodGet) + r.HandleFunc("/api/v2/enrollment-keys", logic.SecurityCheck(true, http.HandlerFunc(listEnrollmentKeys))). + Methods(http.MethodGet) r.HandleFunc("/api/v1/enrollment-keys/network/{network}/default", logic.SecurityCheck(true, http.HandlerFunc(getDefaultEnrollmentKeyForNetwork))). Methods(http.MethodGet) r.HandleFunc("/api/v1/enrollment-keys/{keyID}/regenerate-token", logic.SecurityCheck(true, http.HandlerFunc(regenerateEnrollmentKeyToken))). @@ -38,35 +42,96 @@ func enrollmentKeyHandlers(r *mux.Router) { Methods(http.MethodPut) } -// @Summary Lists all EnrollmentKeys for admins +// @Summary Lists all EnrollmentKeys // @Router /api/v1/enrollment-keys [get] // @Tags EnrollmentKeys // @Security oauth // @Produce json -// @Success 200 {array} models.EnrollmentKey +// @Success 200 {array} schema.EnrollmentKey // @Failure 500 {object} models.ErrorResponse func getEnrollmentKeys(w http.ResponseWriter, r *http.Request) { - keys, err := logic.GetAllEnrollmentKeys() + keys, err := (&schema.EnrollmentKey{}).ListAll(r.Context()) if err != nil { - logger.Log(0, r.Header.Get("user"), "failed to fetch enrollment keys: ", err.Error()) + logger.Log(0, r.Header.Get("user"), "failed to fetch enrollment keys:", err.Error()) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } - ret := []*models.EnrollmentKey{} - for _, key := range keys { - key := key - if err = logic.Tokenize(&key, servercfg.GetAPIHost()); err != nil { - logger.Log(0, r.Header.Get("user"), "failed to get token values for keys:", err.Error()) + for i := range keys { + if err = logic.Tokenize(r.Context(), &keys[i], servercfg.GetAPIHost()); err != nil { + logger.Log(0, r.Header.Get("user"), "failed to tokenize enrollment key:", err.Error()) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } - ret = append(ret, &key) } - // return JSON/API formatted keys + logger.Log(2, r.Header.Get("user"), "fetched enrollment keys") - w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(ret) + logic.ReturnSuccessResponseWithJson(w, r, keys, "fetched enrollment keys") +} + +// @Summary Lists EnrollmentKeys (paginated) +// @Router /api/v2/enrollment-keys [get] +// @Tags EnrollmentKeys +// @Security oauth +// @Produce json +// @Param page query int false "Page number (default 1)" +// @Param per_page query int false "Items per page (default 10, max 100)" +// @Param q query string false "Search across name, networks and tags" +// @Success 200 {object} models.PaginatedResponse +// @Failure 500 {object} models.ErrorResponse +func listEnrollmentKeys(w http.ResponseWriter, r *http.Request) { + page, _ := strconv.Atoi(r.URL.Query().Get("page")) + if page < 1 { + page = 1 + } + pageSize, _ := strconv.Atoi(r.URL.Query().Get("per_page")) + if pageSize < 1 || pageSize > 100 { + pageSize = 10 + } + q := r.URL.Query().Get("q") + + var filters, queryOptions []dbtypes.Option + if q != "" { + filters = append(filters, dbtypes.WithSearchQuery(q, "name", "networks", "tags")) + } + queryOptions = append(queryOptions, filters...) + queryOptions = append(queryOptions, dbtypes.InAscOrder("created_at")) + queryOptions = append(queryOptions, dbtypes.WithPagination(page, pageSize)) + + keys, err := (&schema.EnrollmentKey{}).ListAll(r.Context(), queryOptions...) + if err != nil { + logger.Log(0, r.Header.Get("user"), "failed to fetch enrollment keys:", err.Error()) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) + return + } + total, err := (&schema.EnrollmentKey{}).Count(r.Context(), filters...) + if err != nil { + logger.Log(0, r.Header.Get("user"), "failed to fetch enrollment keys:", err.Error()) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) + return + } + + for i := range keys { + if err = logic.Tokenize(r.Context(), &keys[i], servercfg.GetAPIHost()); err != nil { + logger.Log(0, r.Header.Get("user"), "failed to tokenize enrollment key:", err.Error()) + logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) + return + } + } + + totalPages := (total + pageSize - 1) / pageSize + if totalPages == 0 { + totalPages = 1 + } + + logger.Log(2, r.Header.Get("user"), "fetched enrollment keys") + logic.ReturnSuccessResponseWithJson(w, r, models.PaginatedResponse{ + Data: keys, + Page: page, + PerPage: pageSize, + Total: total, + TotalPages: totalPages, + }, "fetched enrollment keys") } // @Summary Get the default enrollment key for a network @@ -75,7 +140,7 @@ func getEnrollmentKeys(w http.ResponseWriter, r *http.Request) { // @Security oauth // @Param network path string true "Network name" // @Produce json -// @Success 200 {object} models.EnrollmentKey +// @Success 200 {object} schema.EnrollmentKey // @Failure 404 {object} models.ErrorResponse // @Failure 500 {object} models.ErrorResponse func getDefaultEnrollmentKeyForNetwork(w http.ResponseWriter, r *http.Request) { @@ -91,7 +156,7 @@ func getDefaultEnrollmentKeyForNetwork(w http.ResponseWriter, r *http.Request) { return } - key, err := logic.GetDefaultEnrollmentKeyForNetwork(network) + key, err := logic.GetDefaultEnrollmentKeyForNetwork(r.Context(), network) if err != nil { if errors.Is(err, logic.EnrollmentErrors.NoKeyFound) { logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) @@ -102,7 +167,7 @@ func getDefaultEnrollmentKeyForNetwork(w http.ResponseWriter, r *http.Request) { return } - if err = logic.Tokenize(&key, servercfg.GetAPIHost()); err != nil { + if err = logic.Tokenize(r.Context(), key, servercfg.GetAPIHost()); err != nil { logger.Log(0, r.Header.Get("user"), "failed to tokenize default enrollment key:", err.Error()) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return @@ -110,7 +175,7 @@ func getDefaultEnrollmentKeyForNetwork(w http.ResponseWriter, r *http.Request) { logger.Log(2, r.Header.Get("user"), "fetched default enrollment key for network", network) w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(&key) + json.NewEncoder(w).Encode(key) } // @Summary Regenerate an enrollment key token @@ -118,9 +183,9 @@ func getDefaultEnrollmentKeyForNetwork(w http.ResponseWriter, r *http.Request) { // @Router /api/v1/enrollment-keys/{keyID}/regenerate-token [post] // @Tags EnrollmentKeys // @Security oauth -// @Param keyID path string true "Enrollment Key ID" +// @Param keyID path string true "Enrollment Key value" // @Produce json -// @Success 200 {object} models.EnrollmentKey +// @Success 200 {object} schema.EnrollmentKey // @Failure 400 {object} models.ErrorResponse // @Failure 500 {object} models.ErrorResponse func regenerateEnrollmentKeyToken(w http.ResponseWriter, r *http.Request) { @@ -130,29 +195,25 @@ func regenerateEnrollmentKeyToken(w http.ResponseWriter, r *http.Request) { return } - currKey, err := logic.GetEnrollmentKey(keyID) + currKey, err := logic.GetEnrollmentKey(r.Context(), keyID) if err != nil { logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } - newKey, err := logic.RegenerateEnrollmentKeyToken(keyID) + newKey, err := logic.RegenerateEnrollmentKeyToken(r.Context(), keyID) if err != nil { logger.Log(0, r.Header.Get("user"), "failed to regenerate enrollment key token:", err.Error()) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } - if err = logic.Tokenize(newKey, servercfg.GetAPIHost()); err != nil { + if err = logic.Tokenize(r.Context(), newKey, servercfg.GetAPIHost()); err != nil { logger.Log(0, r.Header.Get("user"), "failed to tokenize regenerated enrollment key:", err.Error()) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } - eventName := "" - if len(newKey.Tags) > 0 { - eventName = newKey.Tags[0] - } logic.LogEvent(&models.Event{ Action: schema.Update, Source: models.Subject{ @@ -163,7 +224,7 @@ func regenerateEnrollmentKeyToken(w http.ResponseWriter, r *http.Request) { TriggeredBy: r.Header.Get("user"), Target: models.Subject{ ID: newKey.Value, - Name: eventName, + Name: enrollmentKeyName(newKey), Type: schema.EnrollmentKeySub, }, Diff: models.Diff{ @@ -182,19 +243,17 @@ func regenerateEnrollmentKeyToken(w http.ResponseWriter, r *http.Request) { // @Router /api/v1/enrollment-keys/{keyID} [delete] // @Tags EnrollmentKeys // @Security oauth -// @Param keyID path string true "Enrollment Key ID" +// @Param keyID path string true "Enrollment Key value" // @Success 200 {string} string // @Failure 500 {object} models.ErrorResponse func deleteEnrollmentKey(w http.ResponseWriter, r *http.Request) { - params := mux.Vars(r) - keyID := params["keyID"] - key, err := logic.GetEnrollmentKey(keyID) + keyID := mux.Vars(r)["keyID"] + key, err := logic.GetEnrollmentKey(r.Context(), keyID) if err != nil { logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } - err = logic.DeleteEnrollmentKey(keyID, false) - if err != nil { + if err = logic.DeleteEnrollmentKey(r.Context(), keyID, false); err != nil { logger.Log(0, r.Header.Get("user"), "failed to remove enrollment key: ", err.Error()) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return @@ -209,7 +268,7 @@ func deleteEnrollmentKey(w http.ResponseWriter, r *http.Request) { TriggeredBy: r.Header.Get("user"), Target: models.Subject{ ID: keyID, - Name: key.Tags[0], + Name: enrollmentKeyName(key), Type: schema.EnrollmentKeySub, }, Origin: schema.Dashboard, @@ -229,87 +288,78 @@ func deleteEnrollmentKey(w http.ResponseWriter, r *http.Request) { // @Accept json // @Produce json // @Param body body models.APIEnrollmentKey true "Enrollment Key parameters" -// @Success 200 {object} models.EnrollmentKey +// @Success 200 {object} schema.EnrollmentKey // @Failure 400 {object} models.ErrorResponse // @Failure 500 {object} models.ErrorResponse func createEnrollmentKey(w http.ResponseWriter, r *http.Request) { - var enrollmentKeyBody models.APIEnrollmentKey - - err := json.NewDecoder(r.Body).Decode(&enrollmentKeyBody) - if err != nil { - logger.Log(0, r.Header.Get("user"), "error decoding request body: ", - err.Error()) + var req models.APIEnrollmentKey + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + logger.Log(0, r.Header.Get("user"), "error decoding request body:", err.Error()) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } + var newTime time.Time - if enrollmentKeyBody.Expiration > 0 { - newTime = time.Unix(enrollmentKeyBody.Expiration, 0) + if req.Expiration > 0 { + newTime = time.Unix(req.Expiration, 0) } + v := validator.New() - err = v.Struct(enrollmentKeyBody) - if err != nil { - logger.Log(0, r.Header.Get("user"), "error validating request body: ", - err.Error()) - logic.ReturnErrorResponse( - w, - r, - logic.FormatError( - fmt.Errorf("validation error: name length must be between 3 and 32: %w", err), - "badrequest", - ), - ) + if err := v.Struct(req); err != nil { + logger.Log(0, r.Header.Get("user"), "error validating request body:", err.Error()) + logic.ReturnErrorResponse(w, r, logic.FormatError( + fmt.Errorf("validation error: name length must be between 3 and 32: %w", err), "badrequest")) return } - if existingKeys, err := logic.GetAllEnrollmentKeys(); err != nil { - logger.Log(0, r.Header.Get("user"), "error validating request body: ", - err.Error()) + existingKeys, err := logic.GetAllEnrollmentKeys(r.Context()) + if err != nil { + logger.Log(0, r.Header.Get("user"), "error fetching enrollment keys:", err.Error()) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return - } else { - // check if any tags are duplicate - existingTags := make(map[string]struct{}) - for _, existingKey := range existingKeys { - for _, t := range existingKey.Tags { - existingTags[t] = struct{}{} - } + } + // check if any network names are duplicate across existing keys + existingNetworks := make(map[string]struct{}) + for _, existingKey := range existingKeys { + for _, n := range existingKey.Networks { + existingNetworks[n] = struct{}{} } - for _, t := range enrollmentKeyBody.Tags { - if _, ok := existingTags[t]; ok { - logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("key names must be unique"), "badrequest")) - return - } + } + for _, t := range req.Tags { + if _, ok := existingNetworks[t]; ok { + logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("key names must be unique"), "badrequest")) + return } } - if enrollmentKeyBody.Default && len(enrollmentKeyBody.Networks) == 0 { + if req.Default && len(req.Networks) == 0 { logic.ReturnErrorResponse(w, r, logic.FormatError( errors.New("default enrollment keys require at least one network or tag"), "badrequest")) return } relayId := uuid.Nil - if enrollmentKeyBody.Relay != "" { - relayId, err = uuid.Parse(enrollmentKeyBody.Relay) + if req.Relay != "" { + relayId, err = uuid.Parse(req.Relay) if err != nil { - logger.Log(0, r.Header.Get("user"), "error parsing relay id: ", err.Error()) + logger.Log(0, r.Header.Get("user"), "error parsing relay id:", err.Error()) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } } - newEnrollmentKey, err := logic.CreateEnrollmentKey( - enrollmentKeyBody.UsesRemaining, + newKey, err := logic.CreateEnrollmentKey( + r.Context(), + req.UsesRemaining, newTime, - enrollmentKeyBody.Networks, - enrollmentKeyBody.Tags, - enrollmentKeyBody.Groups, - enrollmentKeyBody.Unlimited, + req.Networks, + req.Tags, + req.Groups, + req.Unlimited, relayId, - enrollmentKeyBody.Default, - enrollmentKeyBody.AutoEgress, - enrollmentKeyBody.AutoAssignGateway, + req.Default, + req.AutoEgress, + req.AutoAssignGateway, ) if err != nil { logger.Log(0, r.Header.Get("user"), "failed to create enrollment key:", err.Error()) @@ -317,11 +367,12 @@ func createEnrollmentKey(w http.ResponseWriter, r *http.Request) { return } - if err = logic.Tokenize(newEnrollmentKey, servercfg.GetAPIHost()); err != nil { - logger.Log(0, r.Header.Get("user"), "failed to create enrollment key:", err.Error()) + if err = logic.Tokenize(r.Context(), newKey, servercfg.GetAPIHost()); err != nil { + logger.Log(0, r.Header.Get("user"), "failed to tokenize enrollment key:", err.Error()) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } + logic.LogEvent(&models.Event{ Action: schema.Create, Source: models.Subject{ @@ -331,15 +382,15 @@ func createEnrollmentKey(w http.ResponseWriter, r *http.Request) { }, TriggeredBy: r.Header.Get("user"), Target: models.Subject{ - ID: newEnrollmentKey.Value, - Name: newEnrollmentKey.Tags[0], + ID: newKey.Value, + Name: enrollmentKeyName(newKey), Type: schema.EnrollmentKeySub, }, Origin: schema.Dashboard, }) logger.Log(2, r.Header.Get("user"), "created enrollment key") w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(newEnrollmentKey) + json.NewEncoder(w).Encode(newKey) } // @Summary Updates an EnrollmentKey @@ -348,45 +399,44 @@ func createEnrollmentKey(w http.ResponseWriter, r *http.Request) { // @Security oauth // @Accept json // @Produce json -// @Param keyID path string true "Enrollment Key ID" +// @Param keyID path string true "Enrollment Key value" // @Param body body models.APIEnrollmentKey true "Enrollment Key parameters" -// @Success 200 {object} models.EnrollmentKey +// @Success 200 {object} schema.EnrollmentKey // @Failure 400 {object} models.ErrorResponse // @Failure 500 {object} models.ErrorResponse func updateEnrollmentKey(w http.ResponseWriter, r *http.Request) { - var enrollmentKeyBody models.APIEnrollmentKey - params := mux.Vars(r) - keyId := params["keyID"] + var req models.APIEnrollmentKey + keyId := mux.Vars(r)["keyID"] - err := json.NewDecoder(r.Body).Decode(&enrollmentKeyBody) - if err != nil { + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { slog.Error("error decoding request body", "error", err) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } - if enrollmentKeyBody.Relay != "" { - _, err = uuid.Parse(enrollmentKeyBody.Relay) - if err != nil { + if req.Relay != "" { + if _, err := uuid.Parse(req.Relay); err != nil { slog.Error("error parsing relay id", "error", err) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } } - currKey, _ := logic.GetEnrollmentKey(keyId) - newEnrollmentKey, err := logic.UpdateEnrollmentKey(keyId, &enrollmentKeyBody) + currKey, _ := logic.GetEnrollmentKey(r.Context(), keyId) + + newKey, err := logic.UpdateEnrollmentKey(r.Context(), keyId, &req) if err != nil { slog.Error("failed to update enrollment key", "error", err) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } - if err = logic.Tokenize(newEnrollmentKey, servercfg.GetAPIHost()); err != nil { - slog.Error("failed to update enrollment key", "error", err) + if err = logic.Tokenize(r.Context(), newKey, servercfg.GetAPIHost()); err != nil { + slog.Error("failed to tokenize enrollment key", "error", err) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } + logic.LogEvent(&models.Event{ Action: schema.Update, Source: models.Subject{ @@ -396,19 +446,19 @@ func updateEnrollmentKey(w http.ResponseWriter, r *http.Request) { }, TriggeredBy: r.Header.Get("user"), Target: models.Subject{ - ID: newEnrollmentKey.Value, - Name: newEnrollmentKey.Tags[0], + ID: newKey.Value, + Name: enrollmentKeyName(newKey), Type: schema.EnrollmentKeySub, }, Diff: models.Diff{ Old: currKey, - New: newEnrollmentKey, + New: newKey, }, Origin: schema.Dashboard, }) slog.Info("updated enrollment key", "id", keyId) w.WriteHeader(http.StatusOK) - json.NewEncoder(w).Encode(newEnrollmentKey) + json.NewEncoder(w).Encode(newKey) } // @Summary Handles a Netclient registration with server and add nodes accordingly @@ -422,74 +472,58 @@ func updateEnrollmentKey(w http.ResponseWriter, r *http.Request) { // @Failure 400 {object} models.ErrorResponse // @Failure 500 {object} models.ErrorResponse func handleHostRegister(w http.ResponseWriter, r *http.Request) { - params := mux.Vars(r) - token := params["token"] + token := mux.Vars(r)["token"] logger.Log(0, "received registration attempt with token", token) - // check if token exists - enrollmentKey, err := logic.DeTokenize(token) + + enrollmentKey, err := logic.DeTokenize(r.Context(), token) if err != nil { logger.Log(0, "invalid enrollment key used", token, err.Error()) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } - // get the host + var newHost schema.Host if err = json.NewDecoder(r.Body).Decode(&newHost); err != nil { - logger.Log(0, r.Header.Get("user"), "error decoding request body: ", - err.Error()) + logger.Log(0, r.Header.Get("user"), "error decoding request body:", err.Error()) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } - // check if host already exists + hostExists := false if hostExists = logic.HostExists(&newHost); hostExists && len(enrollmentKey.Networks) == 0 { - logger.Log( - 0, - "host", - newHost.ID.String(), - newHost.Name, - "attempted to re-register with no networks", - ) - logic.ReturnErrorResponse( - w, - r, - logic.FormatError(fmt.Errorf("host already exists"), "badrequest"), - ) + logger.Log(0, "host", newHost.ID.String(), newHost.Name, "attempted to re-register with no networks") + logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("host already exists"), "badrequest")) return } - // version check + if !logic.IsVersionCompatible(newHost.Version) { - err := fmt.Errorf("bad client version on register: %s", newHost.Version) - logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) + logic.ReturnErrorResponse(w, r, logic.FormatError( + fmt.Errorf("bad client version on register: %s", newHost.Version), "badrequest")) return } if newHost.TrafficKeyPublic == nil && newHost.OS != models.OS_Types.IoT { - err := fmt.Errorf("missing traffic key") - logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) + logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("missing traffic key"), "badrequest")) return } + trafficKey, keyErr := logic.RetrievePublicTrafficKey() if keyErr != nil { logger.Log(0, "error retrieving key:", keyErr.Error()) logic.ReturnErrorResponse(w, r, logic.FormatError(keyErr, "internal")) return } - // use the token - if ok := logic.TryToUseEnrollmentKey(enrollmentKey); !ok { + + if ok := logic.TryToUseEnrollmentKey(r.Context(), enrollmentKey); !ok { logger.Log(0, "host", newHost.ID.String(), newHost.Name, "failed registration") - logic.ReturnErrorResponse( - w, - r, - logic.FormatError(fmt.Errorf("invalid enrollment key"), "badrequest"), - ) + logic.ReturnErrorResponse(w, r, logic.FormatError(fmt.Errorf("invalid enrollment key"), "badrequest")) return } + keyTags := make(map[models.TagID]struct{}) - if len(enrollmentKey.Groups) > 0 { - for _, tagI := range enrollmentKey.Groups { - keyTags[tagI] = struct{}{} - } + for _, tagI := range enrollmentKey.Tags { + keyTags[models.TagID(tagI)] = struct{}{} } + var joinNetworks []string for _, netI := range enrollmentKey.Networks { violations, _ := logic.CheckPostureViolations(models.PostureCheckDeviceInfo{ @@ -508,37 +542,32 @@ func handleHostRegister(w http.ResponseWriter, r *http.Request) { } if len(joinNetworks) != len(enrollmentKey.Networks) && len(joinNetworks) == 0 { logic.ReturnErrorResponse(w, r, - logic.FormatError(errors.New("access blocked: this device doesn’t meet security requirements"), logic.Forbidden)) + logic.FormatError(errors.New("access blocked: this device doesn't meet security requirements"), logic.Forbidden)) return } - // copying the enrollment key so that edits don't end up in the enrollment key cache. + + // copy key so network edits don't mutate the stored key key := *enrollmentKey - // need to remove the networks that were skipped from the enrollment key key.Networks = joinNetworks + var host *schema.Host if !hostExists { newHost.PersistentKeepalive = models.DefaultPersistentKeepAlive - // register host _ = logic.CheckHostPorts(&newHost) - // create EMQX credentials and ACLs for host if servercfg.GetBrokerType() == servercfg.EmqxBrokerType { if err := mq.GetEmqxHandler().CreateEmqxUser(newHost.ID.String(), newHost.HostPass); err != nil { logger.Log(0, "failed to create host credentials for EMQX: ", err.Error()) return } } - if err = logic.CreateHost(&newHost); err != nil { logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } host = &newHost } else { - currHost := &schema.Host{ - ID: newHost.ID, - } - err := currHost.Get(r.Context()) - if err != nil { + currHost := &schema.Host{ID: newHost.ID} + if err = currHost.Get(r.Context()); err != nil { logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return } @@ -552,7 +581,7 @@ func handleHostRegister(w http.ResponseWriter, r *http.Request) { } host = currHost } - // ready the response + server := logic.GetServerInfo() server.TrafficKey = trafficKey response := models.RegisterResponse{ @@ -563,6 +592,16 @@ func handleHostRegister(w http.ResponseWriter, r *http.Request) { logger.Log(0, host.Name, host.ID.String(), "registered with Netmaker") w.WriteHeader(http.StatusOK) json.NewEncoder(w).Encode(&response) - // notify host of changes, peer and node updates go auth.CheckNetRegAndHostUpdate(key, host, r.Header.Get("user")) } + +// enrollmentKeyName returns a human-readable label for audit events. +func enrollmentKeyName(key *schema.EnrollmentKey) string { + if key != nil && key.Name != "" { + return key.Name + } + if key != nil { + return key.Value + } + return "" +} diff --git a/controllers/hosts.go b/controllers/hosts.go index 6cace29cc..52cc0bffd 100644 --- a/controllers/hosts.go +++ b/controllers/hosts.go @@ -1716,7 +1716,7 @@ func approvePendingHost(w http.ResponseWriter, r *http.Request) { }) return } - key := models.EnrollmentKey{} + key := schema.EnrollmentKey{} json.Unmarshal(p.EnrollmentKey, &key) network := &schema.Network{ @@ -1738,10 +1738,8 @@ func approvePendingHost(w http.ResponseWriter, r *http.Request) { } keyTags := make(map[models.TagID]struct{}) - if len(key.Groups) > 0 { - for _, tagI := range key.Groups { - keyTags[tagI] = struct{}{} - } + for _, tagI := range key.Tags { + keyTags[models.TagID(tagI)] = struct{}{} } violations, _ := logic.CheckPostureViolations( diff --git a/controllers/network.go b/controllers/network.go index 090b5ed53..e07c60588 100644 --- a/controllers/network.go +++ b/controllers/network.go @@ -382,6 +382,7 @@ func createNetwork(w http.ResponseWriter, r *http.Request) { logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } + logic.CreateDefaultNetworkEnrollmentKey(network.Name) logic.CreateDefaultNetworkRolesAndGroups(schema.NetworkID(network.Name)) logic.CreateDefaultAclNetworkPolicies(schema.NetworkID(network.Name)) logic.CreateDefaultTags(schema.NetworkID(network.Name)) diff --git a/controllers/node_test.go b/controllers/node_test.go index b251cee19..7a502df02 100644 --- a/controllers/node_test.go +++ b/controllers/node_test.go @@ -1,11 +1,12 @@ package controller import ( + "context" "net" "testing" "github.com/google/uuid" - "github.com/gravitl/netmaker/database" + "github.com/gravitl/netmaker/db" "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/schema" @@ -49,7 +50,10 @@ func TestValidateEgressGateway(t *testing.T) { } func deleteAllNodes() { - database.DeleteAllRecords(database.NODES_TABLE_NAME) + nodes, _ := (&schema.Node{}).ListAll(db.WithContext(context.TODO())) + for _, node := range nodes { + _ = node.Delete(db.WithContext(context.TODO())) + } } func createTestNode() *models.Node { diff --git a/database/database.go b/database/database.go index a448c7368..89a5ac16b 100644 --- a/database/database.go +++ b/database/database.go @@ -10,48 +10,20 @@ import ( const ( // == Table Names == - // NETWORKS_TABLE_NAME - networks table - NETWORKS_TABLE_NAME = "networks" - // NODES_TABLE_NAME - nodes table - NODES_TABLE_NAME = "nodes" - // USERS_TABLE_NAME - users table - USERS_TABLE_NAME = "users" - // USER_PERMISSIONS_TABLE_NAME - user permissions table - USER_PERMISSIONS_TABLE_NAME = "user_permissions" // DNS_TABLE_NAME - dns table DNS_TABLE_NAME = "dns" // EXT_CLIENT_TABLE_NAME - ext client table EXT_CLIENT_TABLE_NAME = "extclients" - // SERVERCONF_TABLE_NAME - stores server conf - SERVERCONF_TABLE_NAME = "serverconf" - // SERVER_UUID_TABLE_NAME - stores unique netmaker server data - SERVER_UUID_TABLE_NAME = "serveruuid" - // SERVER_UUID_RECORD_KEY - telemetry thing - SERVER_UUID_RECORD_KEY = "serveruuid" - // DATABASE_FILENAME - database file name - DATABASE_FILENAME = "netmaker.db" - // GENERATED_TABLE_NAME - stores server generated k/v - GENERATED_TABLE_NAME = "generated" // ACLS_TABLE_NAME - table for acls v2 ACLS_TABLE_NAME = "acls" // SSO_STATE_CACHE - holds sso session information for OAuth2 sign-ins SSO_STATE_CACHE = "ssostatecache" // METRICS_TABLE_NAME - stores network metrics METRICS_TABLE_NAME = "metrics" - // USER_GROUPS_TABLE_NAME - table for storing usergroups - USER_GROUPS_TABLE_NAME = "usergroups" // CACHE_TABLE_NAME - caching table CACHE_TABLE_NAME = "cache" - // HOSTS_TABLE_NAME - the table name for hosts - HOSTS_TABLE_NAME = "hosts" - // ENROLLMENT_KEYS_TABLE_NAME - table name for enrollmentkeys - ENROLLMENT_KEYS_TABLE_NAME = "enrollmentkeys" // HOST_ACTIONS_TABLE_NAME - table name for enrollmentkeys HOST_ACTIONS_TABLE_NAME = "hostactions" - // PENDING_USERS_TABLE_NAME - table name for pending users - PENDING_USERS_TABLE_NAME = "pending_users" - // USER_INVITES - table for user invites - USER_INVITES_TABLE_NAME = "user_invites" // TAG_TABLE_NAME - table for tags TAG_TABLE_NAME = "tags" // SERVER_SETTINGS - table for server settings @@ -86,13 +58,9 @@ const ( var Tables = []string{ DNS_TABLE_NAME, EXT_CLIENT_TABLE_NAME, - SERVERCONF_TABLE_NAME, - SERVER_UUID_TABLE_NAME, - GENERATED_TABLE_NAME, SSO_STATE_CACHE, METRICS_TABLE_NAME, CACHE_TABLE_NAME, - ENROLLMENT_KEYS_TABLE_NAME, HOST_ACTIONS_TABLE_NAME, TAG_TABLE_NAME, ACLS_TABLE_NAME, diff --git a/logic/auth.go b/logic/auth.go index 4c872e396..8a094a37e 100644 --- a/logic/auth.go +++ b/logic/auth.go @@ -23,10 +23,6 @@ import ( "github.com/gravitl/netmaker/models" ) -const ( - auth_key = "netmaker_auth" -) - const ( DashboardApp = "dashboard" NetclientApp = "netclient" @@ -58,7 +54,7 @@ func GetUsers() ([]models.ReturnUser, error) { // IsOauthUser - returns func IsOauthUser(user *schema.User) error { - var currentValue, err = FetchPassValue("") + var currentValue, err = FetchOAuthSecret() if err != nil { return err } @@ -66,29 +62,6 @@ func IsOauthUser(user *schema.User) error { return bCryptErr } -func FetchPassValue(newValue string) (string, error) { - - type valueHolder struct { - Value string `json:"value" bson:"value"` - } - newValueHolder := valueHolder{} - var currentValue, err = FetchAuthSecret() - if err != nil { - return "", err - } - var unmarshErr = json.Unmarshal([]byte(currentValue), &newValueHolder) - if unmarshErr != nil { - return "", unmarshErr - } - - var b64CurrentValue, b64Err = base64.StdEncoding.DecodeString(newValueHolder.Value) - if b64Err != nil { - logger.Log(0, "could not decode pass") - return "", nil - } - return string(b64CurrentValue), nil -} - // CreateUser - creates a user func CreateUser(_user *schema.User) error { // check if user exists @@ -483,33 +456,39 @@ func DeleteUser(user string) error { return (&schema.UserAccessToken{UserName: user}).DeleteAllUserTokens(db.WithContext(context.TODO())) } -func SetAuthSecret(secret string) error { - type valueHolder struct { - Value string `json:"value" bson:"value"` +func SetOAuthSecret(secret string) error { + oauthSecret := &schema.Internal{ + Key: schema.InternalKey_OAuthSecret, } - record, err := FetchAuthSecret() - if err == nil { - v := valueHolder{} - json.Unmarshal([]byte(record), &v) - if v.Value != "" { - return nil - } + err := oauthSecret.Get(db.WithContext(context.TODO())) + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + return err } - var b64NewValue = base64.StdEncoding.EncodeToString([]byte(secret)) - newValueHolder := valueHolder{ - Value: b64NewValue, + + if oauthSecret.Value != "" { + return nil } - d, _ := json.Marshal(newValueHolder) - return database.Insert(auth_key, string(d), database.GENERATED_TABLE_NAME) + + oauthSecret.Value = base64.StdEncoding.EncodeToString([]byte(secret)) + return oauthSecret.Set(db.WithContext(context.TODO())) } -// FetchAuthSecret - manages secrets for oauth -func FetchAuthSecret() (string, error) { - var record, err = database.FetchRecord(database.GENERATED_TABLE_NAME, auth_key) +// FetchOAuthSecret fetches secrets for oauth +func FetchOAuthSecret() (string, error) { + oauthSecret := &schema.Internal{ + Key: schema.InternalKey_OAuthSecret, + } + err := oauthSecret.Get(db.WithContext(context.TODO())) + if err != nil { + return "", err + } + + oauthSecretValue, err := base64.StdEncoding.DecodeString(oauthSecret.Value) if err != nil { return "", err } - return record, nil + + return string(oauthSecretValue), nil } // GetState - gets an SsoState from DB, if expired returns error diff --git a/logic/egress.go b/logic/egress.go index 8248fe016..8b0d7c1bf 100644 --- a/logic/egress.go +++ b/logic/egress.go @@ -9,7 +9,6 @@ import ( "slices" "strings" - "github.com/google/uuid" "github.com/gravitl/netmaker/db" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/schema" @@ -633,13 +632,10 @@ func RemoveNodeFromEgress(node models.Node) { } func RemoveNodeFromEnrollmentKeys(node *models.Node) { - keys, _ := GetAllEnrollmentKeys() - for _, key := range keys { - if key.Relay == node.ID { - key.Relay = uuid.Nil - _ = upsertEnrollmentKey(&key) - } + _node := &schema.Node{ + ID: node.ID.String(), } + _ = _node.ClearGatewayIDFromEnrollmentKeys(db.WithContext(context.TODO())) } func GetEgressRanges(netID schema.NetworkID) (map[string][]string, map[string]struct{}, error) { diff --git a/logic/enrollmentkey.go b/logic/enrollmentkey.go index fcdcbc7ac..b4ce9fe7e 100644 --- a/logic/enrollmentkey.go +++ b/logic/enrollmentkey.go @@ -7,14 +7,17 @@ import ( "fmt" "sort" "strings" - "sync" "time" "github.com/google/uuid" - "github.com/gravitl/netmaker/database" + "github.com/gravitl/netmaker/db" "github.com/gravitl/netmaker/models" - "github.com/gravitl/netmaker/servercfg" + "github.com/gravitl/netmaker/schema" "golang.org/x/exp/slices" + "gorm.io/datatypes" + "gorm.io/gorm" + + "context" ) // EnrollmentErrors - struct for holding EnrollmentKey error messages @@ -33,117 +36,168 @@ var EnrollmentErrors = struct { FailedToTokenize: fmt.Errorf("failed to tokenize"), FailedToDeTokenize: fmt.Errorf("failed to detokenize"), } -var ( - enrollmentkeyCacheMutex = &sync.RWMutex{} - enrollmentkeyCacheMap = make(map[string]models.EnrollmentKey) -) // CreateEnrollmentKey - creates a new enrollment key in db -func CreateEnrollmentKey(uses int, expiration time.Time, networks, +func CreateEnrollmentKey(ctx context.Context, uses int, expiration time.Time, networks, tags []string, groups []models.TagID, unlimited bool, relay uuid.UUID, - defaultKey, autoEgress, autoAssignGw bool) (*models.EnrollmentKey, error) { - newKeyID, err := getUniqueEnrollmentID() + defaultKey, autoEgress, autoAssignGw bool) (*schema.EnrollmentKey, error) { + + newKeyID, err := getUniqueEnrollmentID(ctx) if err != nil { return nil, err } - k := &models.EnrollmentKey{ - Value: newKeyID, - Expiration: time.Time{}, - UsesRemaining: 0, - Unlimited: unlimited, - Networks: []string{}, - Tags: []string{}, - Type: models.Undefined, - Relay: relay, - Groups: groups, - Default: defaultKey, - AutoEgress: autoEgress, - AutoAssignGateway: autoAssignGw, - } + + var keyType schema.EnrollmentKeyType + var exp time.Time + var usesRemaining int + if uses > 0 { - k.UsesRemaining = uses - k.Type = models.Uses + usesRemaining = uses + keyType = schema.EnrollmentKeyType_LimitedUses } else if !expiration.IsZero() { - k.Expiration = expiration - k.Type = models.TimeExpiration - } else if k.Unlimited { - k.Type = models.Unlimited + exp = expiration + keyType = schema.EnrollmentKeyType_TimedExpiry + } else if unlimited { + keyType = schema.EnrollmentKeyType_UnlimitedUses + } + + // merge networks and tags (tags are also network names) + networksSet := make(map[string]struct{}, len(networks)+len(tags)) + for _, n := range networks { + networksSet[n] = struct{}{} + } + for _, t := range tags { + networksSet[t] = struct{}{} } - if len(networks) > 0 { - k.Networks = networks + mergedNetworks := make(datatypes.JSONSlice[string], 0, len(networksSet)) + for n := range networksSet { + mergedNetworks = append(mergedNetworks, n) + } + + keyTags := make(datatypes.JSONSlice[string], 0, len(groups)) + for _, g := range groups { + keyTags = append(keyTags, g.String()) } + + var relayPtr *string + if relay != uuid.Nil { + s := relay.String() + relayPtr = &s + } + + // tags[0] is the enrollment key display name + name := "" if len(tags) > 0 { - k.Tags = tags + name = tags[0] } - if err := k.Validate(); err != nil { - return nil, err + + k := &schema.EnrollmentKey{ + ID: uuid.NewString(), + Name: name, + Value: newKeyID, + Expiration: exp, + UsesRemaining: usesRemaining, + Unlimited: unlimited, + Networks: mergedNetworks, + Tags: keyTags, + Type: keyType, + GatewayID: relayPtr, + Default: defaultKey, + AutoEgress: autoEgress, + AutoAssignGateway: autoAssignGw, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + + if !enrollmentKeyIsValid(k) { + return nil, fmt.Errorf("%w: uses remaining: %d, expiration: %s, unlimited: %t", + models.ErrInvalidEnrollmentKey, k.UsesRemaining, k.Expiration, k.Unlimited) } + if relay != uuid.Nil { relayNode, err := GetNodeByID(relay.String()) if err != nil { return nil, err } - if !slices.Contains(k.Networks, relayNode.Network) { + if !slices.Contains([]string(k.Networks), relayNode.Network) { return nil, errors.New("relay node not in key's networks") } if !relayNode.IsRelay { return nil, errors.New("relay node is not a relay") } } + if defaultKey { - if err := clearDefaultEnrollmentKeysForNetworks(networksForDefaultEnrollmentKey(k.Networks, k.Tags), ""); err != nil { + if err := clearDefaultEnrollmentKeysForNetworks(ctx, networksForDefaultEnrollmentKey(k.Networks), ""); err != nil { return nil, err } } - if err = upsertEnrollmentKey(k); err != nil { + + if err = k.Create(ctx); err != nil { return nil, err } return k, nil } -// RegenerateEnrollmentKeyToken replaces the enrollment key value, invalidating any -// previously issued registration tokens while preserving key configuration. -func RegenerateEnrollmentKeyToken(keyID string) (*models.EnrollmentKey, error) { - key, err := GetEnrollmentKey(keyID) +// CreateDefaultNetworkEnrollmentKey creates an unlimited default enrollment key for a network. +func CreateDefaultNetworkEnrollmentKey(networkName string) (*schema.EnrollmentKey, error) { + ctx := db.WithContext(context.TODO()) + value, err := getUniqueEnrollmentID(ctx) if err != nil { return nil, err } - oldValue := key.Value - newValue, err := getUniqueEnrollmentID() + key := &schema.EnrollmentKey{ + ID: uuid.NewString(), + Name: networkName, + Value: value, + Token: "", + Default: true, + Unlimited: true, + Networks: []string{networkName}, + Type: schema.EnrollmentKeyType_UnlimitedUses, + } + err = key.Create(ctx) if err != nil { return nil, err } - key.Value = newValue - key.Token = "" - if err := key.Validate(); err != nil { + return key, nil +} + +// RegenerateEnrollmentKeyToken replaces the enrollment key value, invalidating any +// previously issued registration tokens while preserving key configuration. +func RegenerateEnrollmentKeyToken(ctx context.Context, keyValue string) (*schema.EnrollmentKey, error) { + key := &schema.EnrollmentKey{Value: keyValue} + if err := key.GetByValue(ctx); err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, EnrollmentErrors.NoKeyFound + } return nil, err } - if err := upsertEnrollmentKey(&key); err != nil { + newValue, err := getUniqueEnrollmentID(ctx) + if err != nil { return nil, err } - if err := database.DeleteRecord(database.ENROLLMENT_KEYS_TABLE_NAME, oldValue); err != nil { - if !database.IsEmptyRecord(err) { - // best-effort rollback: remove the newly written key to avoid duplicates - _ = database.DeleteRecord(database.ENROLLMENT_KEYS_TABLE_NAME, key.Value) - if servercfg.CacheEnabled() { - deleteEnrollmentkeyFromCache(key.Value) - } - return nil, err - } - } - if servercfg.CacheEnabled() { - deleteEnrollmentkeyFromCache(oldValue) + + key.Value = newValue + key.Token = "" + key.UpdatedAt = time.Now() + + if err := key.Upsert(ctx); err != nil { + return nil, err } - return &key, nil + return key, nil } -// UpdateEnrollmentKey - updates an existing enrollment key's associated relay -func UpdateEnrollmentKey(keyId string, updates *models.APIEnrollmentKey) (*models.EnrollmentKey, error) { - key, err := GetEnrollmentKey(keyId) - if err != nil { +// UpdateEnrollmentKey - updates an existing enrollment key's relay and groups +func UpdateEnrollmentKey(ctx context.Context, keyValue string, updates *models.APIEnrollmentKey) (*schema.EnrollmentKey, error) { + key := &schema.EnrollmentKey{Value: keyValue} + if err := key.GetByValue(ctx); err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, EnrollmentErrors.NoKeyFound + } return nil, err } @@ -163,174 +217,99 @@ func UpdateEnrollmentKey(keyId string, updates *models.APIEnrollmentKey) (*model if !relayNode.IsRelay { return nil, errors.New("relay node is not a relay") } + updates.AutoAssignGateway = false } if relayID != uuid.Nil { - updates.AutoAssignGateway = false + s := relayID.String() + key.GatewayID = &s + } else { + key.GatewayID = nil } - key.Relay = relayID - key.Groups = updates.Groups + keyTags := make(datatypes.JSONSlice[string], 0, len(updates.Groups)) + for _, g := range updates.Groups { + keyTags = append(keyTags, g.String()) + } + key.Tags = keyTags key.AutoAssignGateway = updates.AutoAssignGateway + if !key.Default && updates.Default { - if len(key.Groups) == 0 && len(key.Networks) == 0 { + if len(key.Tags) == 0 && len(key.Networks) == 0 { return nil, errors.New("default enrollment keys require at least one network or tag") } key.Default = true - err = clearDefaultEnrollmentKeysForNetworks(networksForDefaultEnrollmentKey(key.Networks, key.Tags), "") - if err != nil { + if err := clearDefaultEnrollmentKeysForNetworks(ctx, networksForDefaultEnrollmentKey(key.Networks), ""); err != nil { return nil, err } } else if key.Default && !updates.Default { key.Default = false } - if err = upsertEnrollmentKey(&key); err != nil { - return nil, err - } - - return &key, nil -} -// GetAllEnrollmentKeys - fetches all enrollment keys from DB -func GetAllEnrollmentKeys() ([]models.EnrollmentKey, error) { - currentKeys, err := getEnrollmentKeysMap() - if err != nil { + key.UpdatedAt = time.Now() + if err := key.Upsert(ctx); err != nil { return nil, err } - var currentKeysList = []models.EnrollmentKey{} - for k := range currentKeys { - currentKeysList = append(currentKeysList, currentKeys[k]) - } - return currentKeysList, nil + return key, nil } -func enrollmentKeyAppliesToNetwork(key models.EnrollmentKey, network string) bool { - if slices.Contains(key.Networks, network) { - return true - } - return len(key.Tags) > 0 && key.Tags[0] == network +// GetAllEnrollmentKeys - fetches all enrollment keys from DB +func GetAllEnrollmentKeys(ctx context.Context) ([]schema.EnrollmentKey, error) { + return (&schema.EnrollmentKey{}).ListAll(ctx) } -func networksForDefaultEnrollmentKey(networks, tags []string) []string { - seen := make(map[string]struct{}, len(networks)+1) - out := make([]string, 0, len(networks)+1) - add := func(n string) { - if n == "" { - return - } - if _, ok := seen[n]; ok { - return +// GetEnrollmentKey - fetches a single enrollment key by value +func GetEnrollmentKey(ctx context.Context, value string) (*schema.EnrollmentKey, error) { + key := &schema.EnrollmentKey{Value: value} + if err := key.GetByValue(ctx); err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, EnrollmentErrors.NoKeyFound } - seen[n] = struct{}{} - out = append(out, n) - } - for _, n := range networks { - add(n) - } - if len(tags) > 0 { - add(tags[0]) + return nil, err } - return out + return key, nil } -// clearDefaultEnrollmentKeysForNetworks unsets Default on any existing default keys -// for the given networks, except the key identified by exceptValue (if non-empty). -func clearDefaultEnrollmentKeysForNetworks(networks []string, exceptValue string) error { - if len(networks) == 0 { - return nil - } - keys, err := GetAllEnrollmentKeys() - if err != nil { +// DeleteEnrollmentKey - deletes a given enrollment key by value +func DeleteEnrollmentKey(ctx context.Context, value string, force bool) error { + key := &schema.EnrollmentKey{Value: value} + if err := key.GetByValue(ctx); err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return EnrollmentErrors.NoKeyFound + } return err } - networkSet := make(map[string]struct{}, len(networks)) - for _, n := range networks { - networkSet[n] = struct{}{} - } - for i := range keys { - if !keys[i].Default || keys[i].Value == exceptValue { - continue - } - applies := false - for network := range networkSet { - if enrollmentKeyAppliesToNetwork(keys[i], network) { - applies = true - break - } - } - if !applies { - continue - } - keys[i].Default = false - if err := upsertEnrollmentKey(&keys[i]); err != nil { - return err - } + if key.Default && !force { + return errors.New("cannot delete default network key") } - return nil + return key.DeleteByValue(ctx) } // GetDefaultEnrollmentKeyForNetwork returns the default enrollment key for a network. -func GetDefaultEnrollmentKeyForNetwork(network string) (models.EnrollmentKey, error) { - keys, err := GetAllEnrollmentKeys() +func GetDefaultEnrollmentKeyForNetwork(ctx context.Context, network string) (*schema.EnrollmentKey, error) { + keys, err := GetAllEnrollmentKeys(ctx) if err != nil { - return models.EnrollmentKey{}, err + return nil, err } sort.Slice(keys, func(i, j int) bool { return keys[i].Value < keys[j].Value }) - for _, key := range keys { - if !key.Default { + for i := range keys { + if !keys[i].Default { continue } - if enrollmentKeyAppliesToNetwork(key, network) { - return key, nil + if enrollmentKeyAppliesToNetwork(keys[i], network) { + return &keys[i], nil } } - return models.EnrollmentKey{}, EnrollmentErrors.NoKeyFound -} - -// GetEnrollmentKey - fetches a single enrollment key -// returns nil and error if not found -func GetEnrollmentKey(value string) (key models.EnrollmentKey, err error) { - currentKeys, err := getEnrollmentKeysMap() - if err != nil { - return key, err - } - if key, ok := currentKeys[value]; ok { - return key, nil - } - return key, EnrollmentErrors.NoKeyFound -} - -func deleteEnrollmentkeyFromCache(key string) { - enrollmentkeyCacheMutex.Lock() - delete(enrollmentkeyCacheMap, key) - enrollmentkeyCacheMutex.Unlock() -} - -// DeleteEnrollmentKey - delete's a given enrollment key by value -func DeleteEnrollmentKey(value string, force bool) error { - key, err := GetEnrollmentKey(value) - if err != nil { - return err - } - if key.Default && !force { - return errors.New("cannot delete default network key") - } - err = database.DeleteRecord(database.ENROLLMENT_KEYS_TABLE_NAME, value) - if err == nil { - if servercfg.CacheEnabled() { - deleteEnrollmentkeyFromCache(value) - } - } - return err + return nil, EnrollmentErrors.NoKeyFound } // TryToUseEnrollmentKey - checks first if key can be decremented // returns true if it is decremented or isvalid -func TryToUseEnrollmentKey(k *models.EnrollmentKey) bool { - key, err := decrementEnrollmentKey(k.Value) +func TryToUseEnrollmentKey(ctx context.Context, k *schema.EnrollmentKey) bool { + key, err := decrementEnrollmentKey(ctx, k.Value) if err != nil { if errors.Is(err, EnrollmentErrors.NoUsesRemaining) { - return k.IsValid() + return enrollmentKeyIsValid(k) } } else { k.UsesRemaining = key.UsesRemaining @@ -341,7 +320,7 @@ func TryToUseEnrollmentKey(k *models.EnrollmentKey) bool { // Tokenize - tokenizes an enrollment key to be used via registration // and attaches it to the Token field on the struct -func Tokenize(k *models.EnrollmentKey, serverAddr string) error { +func Tokenize(ctx context.Context, k *schema.EnrollmentKey, serverAddr string) error { if len(serverAddr) == 0 || k == nil { return EnrollmentErrors.FailedToTokenize } @@ -359,7 +338,7 @@ func Tokenize(k *models.EnrollmentKey, serverAddr string) error { // DeTokenize - detokenizes a base64 encoded string // and finds the associated enrollment key -func DeTokenize(b64Token string) (*models.EnrollmentKey, error) { +func DeTokenize(ctx context.Context, b64Token string) (*schema.EnrollmentKey, error) { if len(b64Token) == 0 { return nil, EnrollmentErrors.FailedToDeTokenize } @@ -369,140 +348,44 @@ func DeTokenize(b64Token string) (*models.EnrollmentKey, error) { } var newToken models.EnrollmentToken - err = json.Unmarshal(tokenData, &newToken) - if err != nil { - return nil, err - } - k, err := GetEnrollmentKey(newToken.Value) - if err != nil { + if err = json.Unmarshal(tokenData, &newToken); err != nil { return nil, err } - return &k, nil -} - -// == private == - -// decrementEnrollmentKey - decrements the uses on a key if above 0 remaining -func decrementEnrollmentKey(value string) (*models.EnrollmentKey, error) { - k, err := GetEnrollmentKey(value) - if err != nil { - return nil, err - } - if k.UsesRemaining == 0 { - return nil, EnrollmentErrors.NoUsesRemaining - } - k.UsesRemaining = k.UsesRemaining - 1 - if err = upsertEnrollmentKey(&k); err != nil { - return nil, err - } - - return &k, nil -} - -func upsertEnrollmentKey(k *models.EnrollmentKey) error { - if k == nil { - return EnrollmentErrors.InvalidKey - } - data, err := json.Marshal(k) - if err != nil { - return err - } - err = database.Insert(k.Value, string(data), database.ENROLLMENT_KEYS_TABLE_NAME) - if err == nil { - if servercfg.CacheEnabled() { - storeEnrollmentkeyInCache(k.Value, *k) - } - } - return nil -} - -func getUniqueEnrollmentID() (string, error) { - currentKeys, err := getEnrollmentKeysMap() - if err != nil { - return "", err - } - newID := RandomString(models.EnrollmentKeyLength) - for _, ok := currentKeys[newID]; ok; { - newID = RandomString(models.EnrollmentKeyLength) - } - return newID, nil -} - -func getEnrollmentkeysFromCache() map[string]models.EnrollmentKey { - return enrollmentkeyCacheMap -} - -func storeEnrollmentkeyInCache(key string, enrollmentkey models.EnrollmentKey) { - enrollmentkeyCacheMutex.Lock() - enrollmentkeyCacheMap[key] = enrollmentkey - enrollmentkeyCacheMutex.Unlock() -} - -func getEnrollmentKeysMap() (map[string]models.EnrollmentKey, error) { - if servercfg.CacheEnabled() { - keys := getEnrollmentkeysFromCache() - if len(keys) != 0 { - return keys, nil - } - } - records, err := database.FetchRecords(database.ENROLLMENT_KEYS_TABLE_NAME) - if err != nil { - if !database.IsEmptyRecord(err) { - return nil, err - } - } - if records == nil { - records = make(map[string]string) - } - currentKeys := make(map[string]models.EnrollmentKey, 0) - if len(records) > 0 { - for k := range records { - var currentKey models.EnrollmentKey - if err = json.Unmarshal([]byte(records[k]), ¤tKey); err != nil { - continue - } - currentKeys[k] = currentKey - if servercfg.CacheEnabled() { - storeEnrollmentkeyInCache(currentKey.Value, currentKey) - } - } - } - return currentKeys, nil + return GetEnrollmentKey(ctx, newToken.Value) } func RemoveTagFromEnrollmentKeys(deletedTagID models.TagID) { - keys, _ := GetAllEnrollmentKeys() + keys, _ := GetAllEnrollmentKeys(db.WithContext(context.TODO())) for _, key := range keys { - newTags := []models.TagID{} + newTags := datatypes.JSONSlice[string]{} update := false - for _, tagID := range key.Groups { - if tagID == deletedTagID { + for _, tagID := range key.Tags { + if tagID == deletedTagID.String() { update = true continue } newTags = append(newTags, tagID) } if update { - key.Groups = newTags - upsertEnrollmentKey(&key) + key.Tags = newTags + key.UpdatedAt = time.Now() + _ = key.Upsert(db.WithContext(context.TODO())) } - } } func UnlinkNetworkAndTagsFromEnrollmentKeys(network string, delete bool) error { - keys, err := GetAllEnrollmentKeys() + keys, err := GetAllEnrollmentKeys(db.WithContext(context.TODO())) if err != nil { return fmt.Errorf("failed to retrieve keys: %w", err) } var errs []error for _, key := range keys { - newNetworks := []string{} - newTags := []models.TagID{} + newNetworks := datatypes.JSONSlice[string]{} + newTags := datatypes.JSONSlice[string]{} update := false - // Check and update networks for _, net := range key.Networks { if net == network { update = true @@ -511,14 +394,12 @@ func UnlinkNetworkAndTagsFromEnrollmentKeys(network string, delete bool) error { newNetworks = append(newNetworks, net) } - // Check and update tags - for _, tag := range key.Groups { - tagParts := strings.Split(tag.String(), ".") + for _, tag := range key.Tags { + tagParts := strings.Split(tag, ".") if len(tagParts) == 0 { continue } - tagNetwork := tagParts[0] - if tagNetwork == network { + if tagParts[0] == network { update = true continue } @@ -526,15 +407,16 @@ func UnlinkNetworkAndTagsFromEnrollmentKeys(network string, delete bool) error { } if update && len(newNetworks) == 0 && delete { - if err := DeleteEnrollmentKey(key.Value, true); err != nil { + if err := DeleteEnrollmentKey(db.WithContext(context.TODO()), key.Value, true); err != nil { errs = append(errs, fmt.Errorf("failed to delete key %s: %w", key.Value, err)) } continue } if update { key.Networks = newNetworks - key.Groups = newTags - if err := upsertEnrollmentKey(&key); err != nil { + key.Tags = newTags + key.UpdatedAt = time.Now() + if err := key.Upsert(db.WithContext(context.TODO())); err != nil { errs = append(errs, fmt.Errorf("failed to update key %s: %w", key.Value, err)) } } @@ -545,3 +427,104 @@ func UnlinkNetworkAndTagsFromEnrollmentKeys(network string, delete bool) error { } return nil } + +// == private == + +func enrollmentKeyIsValid(k *schema.EnrollmentKey) bool { + if k == nil { + return false + } + if k.UsesRemaining > 0 { + return true + } + if !k.Expiration.IsZero() && time.Now().Before(k.Expiration) { + return true + } + return k.Unlimited +} + +func enrollmentKeyAppliesToNetwork(key schema.EnrollmentKey, network string) bool { + return slices.Contains(key.Networks, network) +} + +func networksForDefaultEnrollmentKey(networks datatypes.JSONSlice[string]) []string { + seen := make(map[string]struct{}, len(networks)) + out := make([]string, 0, len(networks)) + for _, n := range networks { + if n == "" { + continue + } + if _, ok := seen[n]; ok { + continue + } + seen[n] = struct{}{} + out = append(out, n) + } + return out +} + +func clearDefaultEnrollmentKeysForNetworks(ctx context.Context, networks []string, exceptValue string) error { + if len(networks) == 0 { + return nil + } + keys, err := GetAllEnrollmentKeys(ctx) + if err != nil { + return err + } + networkSet := make(map[string]struct{}, len(networks)) + for _, n := range networks { + networkSet[n] = struct{}{} + } + for i := range keys { + if !keys[i].Default || keys[i].Value == exceptValue { + continue + } + applies := false + for network := range networkSet { + if enrollmentKeyAppliesToNetwork(keys[i], network) { + applies = true + break + } + } + if !applies { + continue + } + keys[i].Default = false + keys[i].UpdatedAt = time.Now() + if err := keys[i].Upsert(ctx); err != nil { + return err + } + } + return nil +} + +func decrementEnrollmentKey(ctx context.Context, value string) (*schema.EnrollmentKey, error) { + k, err := GetEnrollmentKey(ctx, value) + if err != nil { + return nil, err + } + if k.UsesRemaining == 0 { + return nil, EnrollmentErrors.NoUsesRemaining + } + k.UsesRemaining-- + k.UpdatedAt = time.Now() + if err = k.Upsert(ctx); err != nil { + return nil, err + } + return k, nil +} + +func getUniqueEnrollmentID(ctx context.Context) (string, error) { + newID := RandomString(models.EnrollmentKeyLength) + for { + key := &schema.EnrollmentKey{Value: newID} + err := key.GetByValue(ctx) + if errors.Is(err, gorm.ErrRecordNotFound) { + return newID, nil + } + if err != nil { + return "", err + } + newID = RandomString(models.EnrollmentKeyLength) + } +} diff --git a/logic/enrollmentkey_test.go b/logic/enrollmentkey_test.go deleted file mode 100644 index 7ace5e472..000000000 --- a/logic/enrollmentkey_test.go +++ /dev/null @@ -1,295 +0,0 @@ -package logic - -import ( - "testing" - "time" - - "github.com/gravitl/netmaker/db" - "github.com/gravitl/netmaker/schema" - - "github.com/google/uuid" - "github.com/gravitl/netmaker/database" - "github.com/gravitl/netmaker/models" - "github.com/stretchr/testify/assert" -) - -func TestCreateEnrollmentKey(t *testing.T) { - db.InitializeDB(schema.ListModels()...) - defer db.CloseDB() - - database.InitializeDatabase() - defer database.CloseDB() - t.Run("Can_Not_Create_Key", func(t *testing.T) { - newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, nil, false, uuid.Nil, false, false, false) - assert.Nil(t, newKey) - assert.NotNil(t, err) - assert.ErrorIs(t, err, models.ErrInvalidEnrollmentKey) - }) - t.Run("Can_Create_Key_Uses", func(t *testing.T) { - newKey, err := CreateEnrollmentKey(1, time.Time{}, nil, nil, nil, false, uuid.Nil, false, false, false) - assert.Nil(t, err) - assert.Equal(t, 1, newKey.UsesRemaining) - assert.True(t, newKey.IsValid()) - }) - t.Run("Can_Create_Key_Time", func(t *testing.T) { - newKey, err := CreateEnrollmentKey(0, time.Now().Add(time.Minute), nil, nil, nil, false, uuid.Nil, false, false, false) - assert.Nil(t, err) - assert.True(t, newKey.IsValid()) - }) - t.Run("Can_Create_Key_Unlimited", func(t *testing.T) { - newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, nil, nil, true, uuid.Nil, false, false, false) - assert.Nil(t, err) - assert.True(t, newKey.IsValid()) - }) - t.Run("Can_Create_Key_WithNetworks", func(t *testing.T) { - newKey, err := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil, false, false, false) - assert.Nil(t, err) - assert.True(t, newKey.IsValid()) - assert.True(t, len(newKey.Networks) == 2) - }) - t.Run("Can_Create_Key_WithTags", func(t *testing.T) { - newKey, err := CreateEnrollmentKey(0, time.Time{}, nil, []string{"tag1", "tag2"}, nil, true, uuid.Nil, false, false, false) - assert.Nil(t, err) - assert.True(t, newKey.IsValid()) - assert.True(t, len(newKey.Tags) == 2) - }) - - t.Run("Can_Get_List_of_Keys", func(t *testing.T) { - keys, err := GetAllEnrollmentKeys() - assert.Nil(t, err) - assert.True(t, len(keys) > 0) - for i := range keys { - assert.Equal(t, len(keys[i].Value), models.EnrollmentKeyLength) - } - }) - t.Run("Can_Get_Default_Key_For_Network", func(t *testing.T) { - defaultKey, err := CreateEnrollmentKey(0, time.Time{}, []string{"mynet"}, []string{"mynet"}, nil, true, uuid.Nil, true, false, false) - assert.Nil(t, err) - assert.True(t, defaultKey.Default) - - found, err := GetDefaultEnrollmentKeyForNetwork("mynet") - assert.Nil(t, err) - assert.Equal(t, defaultKey.Value, found.Value) - assert.True(t, found.Default) - - _, err = GetDefaultEnrollmentKeyForNetwork("unknown-net") - assert.ErrorIs(t, err, EnrollmentErrors.NoKeyFound) - }) - removeAllEnrollments() -} - -func TestDefaultEnrollmentKeyUniquenessPerNetwork(t *testing.T) { - db.InitializeDB(schema.ListModels()...) - defer db.CloseDB() - - database.InitializeDatabase() - defer database.CloseDB() - - first, err := CreateEnrollmentKey(0, time.Time{}, []string{"mynet"}, []string{"mynet"}, nil, true, uuid.Nil, true, false, false) - assert.Nil(t, err) - assert.True(t, first.Default) - - second, err := CreateEnrollmentKey(0, time.Time{}, []string{"mynet"}, []string{"mynet"}, nil, true, uuid.Nil, true, false, false) - assert.Nil(t, err) - assert.True(t, second.Default) - - firstAgain, err := GetEnrollmentKey(first.Value) - assert.Nil(t, err) - assert.False(t, firstAgain.Default) - - found, err := GetDefaultEnrollmentKeyForNetwork("mynet") - assert.Nil(t, err) - assert.Equal(t, second.Value, found.Value) - - removeAllEnrollments() -} - -func TestRegenerate_EnrollmentKeyToken(t *testing.T) { - db.InitializeDB(schema.ListModels()...) - defer db.CloseDB() - - database.InitializeDatabase() - defer database.CloseDB() - newKey, err := CreateEnrollmentKey(0, time.Time{}, []string{"mynet"}, []string{"mynet"}, nil, true, uuid.Nil, false, false, false) - assert.Nil(t, err) - - regenerated, err := RegenerateEnrollmentKeyToken(newKey.Value) - assert.Nil(t, err) - assert.NotEqual(t, newKey.Value, regenerated.Value) - assert.True(t, regenerated.IsValid()) - assert.Equal(t, newKey.Networks, regenerated.Networks) - assert.Equal(t, newKey.Tags, regenerated.Tags) - - _, err = GetEnrollmentKey(newKey.Value) - assert.Equal(t, err, EnrollmentErrors.NoKeyFound) - - found, err := GetEnrollmentKey(regenerated.Value) - assert.Nil(t, err) - assert.Equal(t, regenerated.Value, found.Value) - - _, err = RegenerateEnrollmentKeyToken("notakey") - assert.Equal(t, err, EnrollmentErrors.NoKeyFound) - removeAllEnrollments() -} - -func TestDelete_EnrollmentKey(t *testing.T) { - db.InitializeDB(schema.ListModels()...) - defer db.CloseDB() - - database.InitializeDatabase() - defer database.CloseDB() - newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil, false, false, false) - t.Run("Can_Delete_Key", func(t *testing.T) { - assert.True(t, newKey.IsValid()) - err := DeleteEnrollmentKey(newKey.Value, false) - assert.Nil(t, err) - oldKey, err := GetEnrollmentKey(newKey.Value) - assert.Equal(t, oldKey, models.EnrollmentKey{}) - assert.NotNil(t, err) - assert.Equal(t, err, EnrollmentErrors.NoKeyFound) - }) - t.Run("Can_Not_Delete_Invalid_Key", func(t *testing.T) { - err := DeleteEnrollmentKey("notakey", false) - assert.NotNil(t, err) - assert.Equal(t, err, EnrollmentErrors.NoKeyFound) - }) - removeAllEnrollments() -} - -func TestDecrement_EnrollmentKey(t *testing.T) { - db.InitializeDB(schema.ListModels()...) - defer db.CloseDB() - - database.InitializeDatabase() - defer database.CloseDB() - newKey, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, nil, false, uuid.Nil, false, false, false) - t.Run("Check_initial_uses", func(t *testing.T) { - assert.True(t, newKey.IsValid()) - assert.Equal(t, newKey.UsesRemaining, 1) - }) - t.Run("Check can decrement", func(t *testing.T) { - assert.Equal(t, newKey.UsesRemaining, 1) - k, err := decrementEnrollmentKey(newKey.Value) - assert.Nil(t, err) - newKey = k - }) - t.Run("Check can not decrement", func(t *testing.T) { - assert.Equal(t, newKey.UsesRemaining, 0) - _, err := decrementEnrollmentKey(newKey.Value) - assert.NotNil(t, err) - assert.Equal(t, err, EnrollmentErrors.NoUsesRemaining) - }) - - removeAllEnrollments() -} - -func TestUsability_EnrollmentKey(t *testing.T) { - db.InitializeDB(schema.ListModels()...) - defer db.CloseDB() - - database.InitializeDatabase() - defer database.CloseDB() - key1, _ := CreateEnrollmentKey(1, time.Time{}, nil, nil, nil, false, uuid.Nil, false, false, false) - key2, _ := CreateEnrollmentKey(0, time.Now().Add(time.Minute<<4), nil, nil, nil, false, uuid.Nil, false, false, false) - key3, _ := CreateEnrollmentKey(0, time.Time{}, nil, nil, nil, true, uuid.Nil, false, false, false) - t.Run("Check if valid use key can be used", func(t *testing.T) { - assert.Equal(t, key1.UsesRemaining, 1) - ok := TryToUseEnrollmentKey(key1) - assert.True(t, ok) - assert.Equal(t, 0, key1.UsesRemaining) - }) - - t.Run("Check if valid time key can be used", func(t *testing.T) { - assert.True(t, !key2.Expiration.IsZero()) - ok := TryToUseEnrollmentKey(key2) - assert.True(t, ok) - }) - - t.Run("Check if valid unlimited key can be used", func(t *testing.T) { - assert.True(t, key3.Unlimited) - ok := TryToUseEnrollmentKey(key3) - assert.True(t, ok) - }) - - t.Run("check invalid key can not be used", func(t *testing.T) { - ok := TryToUseEnrollmentKey(key1) - assert.False(t, ok) - }) -} - -func removeAllEnrollments() { - database.DeleteAllRecords(database.ENROLLMENT_KEYS_TABLE_NAME) -} - -//Test that cheks if it can tokenize -//Test that cheks if it can't tokenize - -func TestTokenize_EnrollmentKeys(t *testing.T) { - db.InitializeDB(schema.ListModels()...) - defer db.CloseDB() - - database.InitializeDatabase() - defer database.CloseDB() - newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil, false, false, false) - const defaultValue = "MwE5MwE5MwE5MwE5MwE5MwE5MwE5MwE5" - const b64value = "eyJzZXJ2ZXIiOiJhcGkubXlzZXJ2ZXIuY29tIiwidmFsdWUiOiJNd0U1TXdFNU13RTVNd0U1TXdFNU13RTVNd0U1TXdFNSJ9" - const serverAddr = "api.myserver.com" - t.Run("Can_Not_Tokenize_Nil_Key", func(t *testing.T) { - err := Tokenize(nil, "ServerAddress") - assert.NotNil(t, err) - assert.Equal(t, err, EnrollmentErrors.FailedToTokenize) - }) - t.Run("Can_Not_Tokenize_Empty_Server_Address", func(t *testing.T) { - err := Tokenize(newKey, "") - assert.NotNil(t, err) - assert.Equal(t, err, EnrollmentErrors.FailedToTokenize) - }) - - t.Run("Can_Tokenize", func(t *testing.T) { - err := Tokenize(newKey, serverAddr) - assert.Nil(t, err) - assert.True(t, len(newKey.Token) > 0) - }) - - t.Run("Is_Correct_B64_Token", func(t *testing.T) { - newKey.Value = defaultValue - err := Tokenize(newKey, serverAddr) - assert.Nil(t, err) - assert.Equal(t, newKey.Token, b64value) - }) - removeAllEnrollments() -} - -func TestDeTokenize_EnrollmentKeys(t *testing.T) { - db.InitializeDB(schema.ListModels()...) - defer db.CloseDB() - - database.InitializeDatabase() - defer database.CloseDB() - newKey, _ := CreateEnrollmentKey(0, time.Time{}, []string{"mynet", "skynet"}, nil, nil, true, uuid.Nil, false, false, false) - const b64Value = "eyJzZXJ2ZXIiOiJhcGkubXlzZXJ2ZXIuY29tIiwidmFsdWUiOiJNd0U1TXdFNU13RTVNd0U1TXdFNU13RTVNd0U1TXdFNSJ9" - const serverAddr = "api.myserver.com" - - t.Run("Can_Not_DeTokenize", func(t *testing.T) { - value, err := DeTokenize("") - assert.Nil(t, value) - assert.NotNil(t, err) - assert.Equal(t, err, EnrollmentErrors.FailedToDeTokenize) - }) - t.Run("Can_Not_Find_Key", func(t *testing.T) { - value, err := DeTokenize(b64Value) - assert.Nil(t, value) - assert.NotNil(t, err) - assert.Equal(t, err, EnrollmentErrors.NoKeyFound) - }) - t.Run("Can_DeTokenize", func(t *testing.T) { - err := Tokenize(newKey, serverAddr) - assert.Nil(t, err) - output, err := DeTokenize(newKey.Token) - assert.Nil(t, err) - assert.NotNil(t, output) - assert.Equal(t, newKey.Value, output.Value) - }) - - removeAllEnrollments() -} diff --git a/logic/jwts.go b/logic/jwts.go index db8732070..b0ad2dc17 100644 --- a/logic/jwts.go +++ b/logic/jwts.go @@ -23,7 +23,7 @@ var jwtSecretKey []byte // SetJWTSecret - sets the jwt secret on server startup func SetJWTSecret() { - currentSecret, jwtErr := FetchJWTSecret() + currentSecret, jwtErr := GetJwtSecretValue() if jwtErr != nil { newValue := RandomString(64) jwtSecretKey = []byte(newValue) // 512 bit random password diff --git a/logic/networks.go b/logic/networks.go index 20aec4d8c..c5086294b 100644 --- a/logic/networks.go +++ b/logic/networks.go @@ -12,7 +12,6 @@ import ( "strings" "time" - "github.com/google/uuid" "github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/db" "github.com/gravitl/netmaker/models" @@ -24,10 +23,11 @@ import ( func DeleteNetwork(network string, force bool, done chan struct{}) error { defer func() { // Delete default network enrollment key - keys, _ := GetAllEnrollmentKeys() + ctx := db.WithContext(context.TODO()) + keys, _ := GetAllEnrollmentKeys(ctx) for _, key := range keys { - if key.Default && len(key.Tags) > 0 && key.Tags[0] == network { - _ = DeleteEnrollmentKey(key.Value, true) + if key.Default && enrollmentKeyAppliesToNetwork(key, network) { + _ = DeleteEnrollmentKey(ctx, key.Value, true) break } } @@ -311,25 +311,7 @@ func CreateNetwork(_network *schema.Network) error { return err } - err = _network.Create(db.WithContext(context.TODO())) - if err != nil { - return err - } - - _, _ = CreateEnrollmentKey( - 0, - time.Time{}, - []string{_network.Name}, - []string{_network.Name}, - []models.TagID{}, - true, - uuid.Nil, - true, - false, - false, - ) - - return nil + return _network.Create(db.WithContext(context.TODO())) } func GetNetworkNetworkCIDR4(network *schema.Network) *net.IPNet { diff --git a/logic/serverconf.go b/logic/serverconf.go index 23c4aa6fc..4dbbe119e 100644 --- a/logic/serverconf.go +++ b/logic/serverconf.go @@ -1,51 +1,30 @@ package logic import ( - "encoding/json" - "time" + "context" - "github.com/gravitl/netmaker/database" + "github.com/gravitl/netmaker/db" + "github.com/gravitl/netmaker/schema" ) -var ( - FreeTier = false - // DefaultTrialEndDate - is a placeholder date for not applicable trial end dates - DefaultTrialEndDate, _ = time.Parse("2006-Jan-02", "2021-Apr-01") - - GetTrialEndDate = func() (time.Time, error) { - return DefaultTrialEndDate, nil - } -) - -type serverData struct { - PrivateKey string `json:"privatekey,omitempty" bson:"privatekey,omitempty"` -} - -// FetchJWTSecret - fetches jwt secret from db -func FetchJWTSecret() (string, error) { - var dbData string - var err error - var fetchedData = serverData{} - dbData, err = database.FetchRecord(database.SERVERCONF_TABLE_NAME, "nm-jwt-secret") - if err != nil { - return "", err +// GetJwtSecretValue fetches jwt secret from db +func GetJwtSecretValue() (string, error) { + jwtSecret := &schema.Internal{ + Key: schema.InternalKey_JwtSecret, } - err = json.Unmarshal([]byte(dbData), &fetchedData) + err := jwtSecret.Get(db.WithContext(context.TODO())) if err != nil { return "", err } - return fetchedData.PrivateKey, nil + + return jwtSecret.Value, nil } -// StoreJWTSecret - stores server jwt secret if needed +// StoreJWTSecret stores server jwt secret if needed func StoreJWTSecret(privateKey string) error { - var newData = serverData{} - var err error - var data []byte - newData.PrivateKey = privateKey - data, err = json.Marshal(&newData) - if err != nil { - return err + jwtSecret := &schema.Internal{ + Key: schema.InternalKey_JwtSecret, + Value: privateKey, } - return database.Insert("nm-jwt-secret", string(data), database.SERVERCONF_TABLE_NAME) + return jwtSecret.Set(db.WithContext(context.TODO())) } diff --git a/logic/telemetry.go b/logic/telemetry.go index c26d63dfe..47f77778a 100644 --- a/logic/telemetry.go +++ b/logic/telemetry.go @@ -2,15 +2,15 @@ package logic import ( "context" - "encoding/json" + "errors" "os" "time" "github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/db" dbtypes "github.com/gravitl/netmaker/db/types" - "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/schema" + "gorm.io/gorm" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/servercfg" @@ -18,12 +18,6 @@ import ( "golang.org/x/exp/slog" ) -var ( - // flags to keep for telemetry - isFreeTier bool - telServerRecord = models.Telemetry{} -) - var LogEvent = func(a *models.Event) {} // posthog_pub_key - Key for sending data to PostHog @@ -32,22 +26,20 @@ const posthog_pub_key = "phc_1vEXhPOA1P7HP5jP2dVU9xDTUqXHAelmtravyZ1vvES" // posthog_endpoint - Endpoint of PostHog server const posthog_endpoint = "https://app.posthog.com" -// setFreeTierForTelemetry - store free tier flag without having an import cycle when used for telemetry -// (as the pro package needs the logic package as currently written). -func SetFreeTierForTelemetry(freeTierFlag bool) { - isFreeTier = freeTierFlag -} - // sendTelemetry - gathers telemetry data and sends to posthog func sendTelemetry() error { if Telemetry() == "off" { return nil } - var telRecord, err = FetchTelemetryRecord() + serverID := &schema.Internal{ + Key: schema.InternalKey_ServerID, + } + err := serverID.Get(db.WithContext(context.TODO())) if err != nil { return err } + // get telemetry data d := FetchTelemetryData() // get tenant admin email @@ -62,7 +54,7 @@ func sendTelemetry() error { // send to posthog return client.Enqueue(posthog.Capture{ - DistinctId: telRecord.UUID, + DistinctId: serverID.Value, Event: "daily checkin", Properties: posthog.NewProperties(). Set("nodes", d.Nodes). @@ -80,11 +72,11 @@ func sendTelemetry() error { Set("k8s", d.Count.K8S). Set("version", d.Version). Set("is_ee", d.IsPro). // TODO change is_ee to is_pro for consistency, but probably needs changes in posthog - Set("is_free_tier", isFreeTier). + Set("is_free_tier", false). Set("is_pro_trial", d.IsProTrial). Set("pro_trial_end_date", d.ProTrialEndDate.In(time.UTC).Format("2006-01-02")). Set("admin_email", adminEmail). - Set("email", adminEmail). // needed for posthog intgration with hubspot. "admin_email" can only be removed if not used in posthog + Set("email", adminEmail). // needed for posthog integration with hubspot. "admin_email" can only be removed if not used in posthog Set("is_saas_tenant", d.IsSaasTenant). Set("domain", d.Domain), }) @@ -104,11 +96,7 @@ func FetchTelemetryData() telemetryData { nodes, _ := (&schema.Node{}).ListAll(db.WithContext(context.TODO()), dbtypes.WithPreloads("Host")) data.Nodes = len(nodes) data.Count = getClientCount(nodes) - endDate, _ := GetTrialEndDate() - data.ProTrialEndDate = endDate - if endDate.After(time.Now()) { - data.IsProTrial = true - } + data.ProTrialEndDate, _ = time.Parse("2006-Jan-02", "2021-Apr-01") data.IsSaasTenant = servercfg.DeployedByOperator() data.Domain = servercfg.GetNmBaseDomain() return data @@ -116,31 +104,37 @@ func FetchTelemetryData() telemetryData { // getServerCount returns number of servers from database func getServerCount() int { - data, err := database.FetchRecords(database.SERVER_UUID_TABLE_NAME) - if err != nil { - logger.Log(0, "error retrieving server data", err.Error()) - } - return len(data) + return 1 } -// setTelemetryTimestamp - Give the entry in the DB a new timestamp -func setTelemetryTimestamp(telRecord *models.Telemetry) error { - lastsend := time.Now().Unix() - var serverTelData = models.Telemetry{ - UUID: telRecord.UUID, - LastSend: lastsend, - TrafficKeyPriv: telRecord.TrafficKeyPriv, - TrafficKeyPub: telRecord.TrafficKeyPub, +func getTelemetryLastReportedAt() (time.Time, error) { + telemetryLastReportedAt := &schema.Internal{ + Key: schema.InternalKey_TelemetryLastReportedAt, } - jsonObj, err := json.Marshal(&serverTelData) + err := telemetryLastReportedAt.Get(db.WithContext(context.TODO())) if err != nil { - return err + if errors.Is(err, gorm.ErrRecordNotFound) { + return time.Time{}, nil + } + + return time.Time{}, err } - err = database.Insert(database.SERVER_UUID_RECORD_KEY, string(jsonObj), database.SERVER_UUID_TABLE_NAME) - if err == nil { - telServerRecord = serverTelData + + telemetryLastReportedAtValue, err := time.Parse(telemetryLastReportedAt.Value, time.RFC3339) + if err != nil { + return time.Time{}, err + } + + return telemetryLastReportedAtValue, nil +} + +// setTelemetryLastReportedAt sets the time for the last hook run. +func setTelemetryLastReportedAt() error { + lastHookRunAt := &schema.Internal{ + Key: schema.InternalKey_TelemetryLastReportedAt, + Value: time.Now().UTC().Format(time.RFC3339), } - return err + return lastHookRunAt.Set(db.WithContext(context.TODO())) } // getClientCount - returns counts of nodes with various OS types and conditions @@ -164,25 +158,6 @@ func getClientCount(nodes []schema.Node) clientCount { return count } -// FetchTelemetryRecord - get the existing UUID and Timestamp from the DB -func FetchTelemetryRecord() (models.Telemetry, error) { - if telServerRecord.TrafficKeyPub != nil { - return telServerRecord, nil - } - var rawData string - var telObj models.Telemetry - var err error - rawData, err = database.FetchRecord(database.SERVER_UUID_TABLE_NAME, database.SERVER_UUID_RECORD_KEY) - if err != nil { - return telObj, err - } - err = json.Unmarshal([]byte(rawData), &telObj) - if err == nil { - telServerRecord = telObj - } - return telObj, err -} - // getDBLength - get length of DB to get count of objects func getDBLength(dbname string) int { data, err := database.FetchRecords(dbname) diff --git a/logic/timer.go b/logic/timer.go index ed0a7a3a8..2efdfb366 100644 --- a/logic/timer.go +++ b/logic/timer.go @@ -41,12 +41,11 @@ var hooksMutex sync.RWMutex // TimerCheckpoint - Checks if 24 hours has passed since telemetry was last sent. If so, sends telemetry data to posthog func TimerCheckpoint() error { - // get the telemetry record in the DB, which contains a timestamp - telRecord, err := FetchTelemetryRecord() + lastReportedAt, err := getTelemetryLastReportedAt() if err != nil { return err } - sendtime := time.Unix(telRecord.LastSend, 0).Add(time.Hour * time.Duration(timer_hours_between_runs)) + sendtime := lastReportedAt.Add(time.Hour * time.Duration(timer_hours_between_runs)) // can set to 2 minutes for testing // sendtime := time.Unix(telRecord.LastSend, 0).Add(time.Minute * 2) enoughTimeElapsed := time.Now().After(sendtime) @@ -54,7 +53,7 @@ func TimerCheckpoint() error { if enoughTimeElapsed { // run any time hooks runHooks() - return setTelemetryTimestamp(&telRecord) + return setTelemetryLastReportedAt() } return nil diff --git a/logic/traffic.go b/logic/traffic.go index 3c065c29e..5a8dd7d24 100644 --- a/logic/traffic.go +++ b/logic/traffic.go @@ -1,21 +1,35 @@ package logic -// RetrievePrivateTrafficKey - retrieves private key of server +import ( + "context" + "encoding/base64" + + "github.com/gravitl/netmaker/db" + "github.com/gravitl/netmaker/schema" +) + +// RetrievePrivateTrafficKey retrieves private key of server func RetrievePrivateTrafficKey() ([]byte, error) { - var telRecord, err = FetchTelemetryRecord() + mqPrivateKey := &schema.Internal{ + Key: schema.InternalKey_MqPrivateKey, + } + err := mqPrivateKey.Get(db.WithContext(context.TODO())) if err != nil { return nil, err } - return telRecord.TrafficKeyPriv, nil + return base64.StdEncoding.DecodeString(mqPrivateKey.Value) } -// RetrievePublicTrafficKey - retrieves public key of server +// RetrievePublicTrafficKey retrieves public key of server func RetrievePublicTrafficKey() ([]byte, error) { - var telRecord, err = FetchTelemetryRecord() + mqPublicKey := &schema.Internal{ + Key: schema.InternalKey_MqPublicKey, + } + err := mqPublicKey.Get(db.WithContext(context.TODO())) if err != nil { return nil, err } - return telRecord.TrafficKeyPub, nil + return base64.StdEncoding.DecodeString(mqPublicKey.Value) } diff --git a/main.go b/main.go index 4fb85dde3..486628372 100644 --- a/main.go +++ b/main.go @@ -4,7 +4,8 @@ package main import ( "context" "crypto/rand" - "encoding/json" + "encoding/base64" + "errors" "flag" "fmt" "os" @@ -19,6 +20,7 @@ import ( "github.com/gravitl/netmaker/orchestrator" "github.com/gravitl/netmaker/orchestrator/extensions" "github.com/gravitl/netmaker/schema" + "gorm.io/gorm" "github.com/google/uuid" "github.com/gravitl/netmaker/config" @@ -137,17 +139,26 @@ func initialize() { // Client Mode Prereq Check panic(err) } migrate.Run() - } - initializeUUID() + err = setServerID() + if err != nil { + logger.Log(0, "error setting server id: ", err.Error()) + } + + logic.SetJWTSecret() + + err = setMqKeys() + if err != nil { + logger.Log(0, "error setting mq keys: ", err.Error()) + } + + } //initialize cache _, _ = logic.GetAllExtClients() _ = logic.ListAcls() - _, _ = logic.GetAllEnrollmentKeys() _ = logic.CleanExpiredSSOStates() - logic.SetJWTSecret() } func startControllers(wg *sync.WaitGroup, ctx context.Context) { @@ -274,40 +285,66 @@ func setGarbageCollection() { } } -// initializeUUID - create a UUID record for server if none exists -func initializeUUID() error { - records, err := database.FetchRecords(database.SERVER_UUID_TABLE_NAME) - if err != nil { - if !database.IsEmptyRecord(err) { - return err - } - } else if len(records) > 0 { +func setServerID() error { + serverID := &schema.Internal{ + Key: schema.InternalKey_ServerID, + } + err := serverID.Get(db.WithContext(context.TODO())) + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + return err + } + + if serverID.Value != "" { return nil } - // setup encryption keys - var trafficPubKey, trafficPrivKey, errT = box.GenerateKey(rand.Reader) // generate traffic keys - if errT != nil { - return errT + + serverID.Value = uuid.NewString() + return serverID.Set(db.WithContext(context.TODO())) +} + +func setMqKeys() error { + mqPrivateKey := &schema.Internal{ + Key: schema.InternalKey_MqPrivateKey, + } + err := mqPrivateKey.Get(db.WithContext(context.TODO())) + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + return err + } + + mqPublicKey := &schema.Internal{ + Key: schema.InternalKey_MqPublicKey, + } + err = mqPublicKey.Get(db.WithContext(context.TODO())) + if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + return err + } + + if mqPrivateKey.Value != "" && mqPublicKey.Value != "" { + return nil } - tPriv, err := ncutils.ConvertKeyToBytes(trafficPrivKey) + + publicKey, privateKey, err := box.GenerateKey(rand.Reader) if err != nil { return err } - tPub, err := ncutils.ConvertKeyToBytes(trafficPubKey) + privateKeyBytes, err := ncutils.ConvertKeyToBytes(privateKey) if err != nil { return err } - telemetry := models.Telemetry{ - UUID: uuid.NewString(), - TrafficKeyPriv: tPriv, - TrafficKeyPub: tPub, + publicKeyBytes, err := ncutils.ConvertKeyToBytes(publicKey) + if err != nil { + return err } - telJSON, err := json.Marshal(&telemetry) + + mqPrivateKey.Value = base64.StdEncoding.EncodeToString(privateKeyBytes) + mqPublicKey.Value = base64.StdEncoding.EncodeToString(publicKeyBytes) + + err = mqPrivateKey.Set(db.WithContext(context.TODO())) if err != nil { return err } - return database.Insert(database.SERVER_UUID_RECORD_KEY, string(telJSON), database.SERVER_UUID_TABLE_NAME) + return mqPublicKey.Set(db.WithContext(context.TODO())) } diff --git a/migrate/migrate.go b/migrate/migrate.go index fec52c70c..e79b23566 100644 --- a/migrate/migrate.go +++ b/migrate/migrate.go @@ -206,68 +206,26 @@ func assignSuperAdmin() { } func updateEnrollmentKeys() { - rows, err := database.FetchRecords(database.ENROLLMENT_KEYS_TABLE_NAME) + ctx := db.WithContext(context.TODO()) + existingKeys, err := logic.GetAllEnrollmentKeys(ctx) if err != nil { return } - for _, row := range rows { - var key models.EnrollmentKey - if err = json.Unmarshal([]byte(row), &key); err != nil { - continue - } - if key.Type != models.Undefined { - logger.Log(2, "migration: enrollment key type already set") - continue - } else { - logger.Log(2, "migration: updating enrollment key type") - if key.Unlimited { - key.Type = models.Unlimited - } else if key.UsesRemaining > 0 { - key.Type = models.Uses - } else if !key.Expiration.IsZero() { - key.Type = models.TimeExpiration - } - } - data, err := json.Marshal(key) - if err != nil { - logger.Log(0, "migration: marshalling enrollment key: "+err.Error()) - continue - } - if err = database.Insert(key.Value, string(data), database.ENROLLMENT_KEYS_TABLE_NAME); err != nil { - logger.Log(0, "migration: inserting enrollment key: "+err.Error()) - continue - } - - } - - existingKeys, err := logic.GetAllEnrollmentKeys() - if err != nil { - return - } - // check if any tags are duplicate - existingTags := make(map[string]struct{}) + // check if any networks already have a default enrollment key + existingNetworks := make(map[string]struct{}) for _, existingKey := range existingKeys { - for _, t := range existingKey.Tags { - existingTags[t] = struct{}{} + if existingKey.Default { + for _, n := range existingKey.Networks { + existingNetworks[n] = struct{}{} + } } } - networks, _ := (&schema.Network{}).ListAll(db.WithContext(context.TODO())) + networks, _ := (&schema.Network{}).ListAll(ctx) for _, network := range networks { - if _, ok := existingTags[network.Name]; ok { + if _, ok := existingNetworks[network.Name]; ok { continue } - _, _ = logic.CreateEnrollmentKey( - 0, - time.Time{}, - []string{network.Name}, - []string{network.Name}, - []models.TagID{}, - true, - uuid.Nil, - true, - false, - false, - ) + _, _ = logic.CreateDefaultNetworkEnrollmentKey(network.Name) } } diff --git a/migrate/migrate_schema.go b/migrate/migrate_schema.go index 484fb5d2b..ff8adfa23 100644 --- a/migrate/migrate_schema.go +++ b/migrate/migrate_schema.go @@ -34,6 +34,12 @@ func ToSQLSchema() error { return err } + // v1.7.0 migration includes migrating the server conf, generated and server uuid table. + err = ensureMigrationCompleted(context.TODO(), "migration-v1.7.0", migrateV1_7_0) + if err != nil { + return err + } + return nil } diff --git a/migrate/migrate_v1_5_1.go b/migrate/migrate_v1_5_1.go index dbacf2fa6..e3c432b03 100644 --- a/migrate/migrate_v1_5_1.go +++ b/migrate/migrate_v1_5_1.go @@ -17,6 +17,14 @@ import ( "gorm.io/datatypes" ) +const ( + TableName_Users = "users" + TableName_Networks = "networks" + TableName_UserPermissions = "user_permissions" + TableName_UserGroups = "usergroups" + TableName_Hosts = "hosts" +) + func migrateV1_5_1(ctx context.Context) error { err := migrateUsers(ctx) if err != nil { @@ -42,11 +50,11 @@ func migrateV1_5_1(ctx context.Context) error { } func migrateUsers(ctx context.Context) error { - if !db.FromContext(ctx).Migrator().HasTable(database.USERS_TABLE_NAME) { + if !db.FromContext(ctx).Migrator().HasTable(TableName_Users) { return nil } - records, err := kvList(ctx, database.USERS_TABLE_NAME) + records, err := kvList(ctx, TableName_Users) if err != nil && !database.IsEmptyRecord(err) { return err } @@ -104,11 +112,11 @@ func migrateUsers(ctx context.Context) error { } func migrateNetworks(ctx context.Context) error { - if !db.FromContext(ctx).Migrator().HasTable(database.NETWORKS_TABLE_NAME) { + if !db.FromContext(ctx).Migrator().HasTable(TableName_Networks) { return nil } - records, err := kvList(ctx, database.NETWORKS_TABLE_NAME) + records, err := kvList(ctx, TableName_Networks) if err != nil && !database.IsEmptyRecord(err) { return err } @@ -252,11 +260,11 @@ func migrateNetworks_Nameserver(ctx context.Context, network *models.Network) er } func migrateUserRoles(ctx context.Context) error { - if !db.FromContext(ctx).Migrator().HasTable(database.USER_PERMISSIONS_TABLE_NAME) { + if !db.FromContext(ctx).Migrator().HasTable(TableName_UserPermissions) { return nil } - records, err := kvList(ctx, database.USER_PERMISSIONS_TABLE_NAME) + records, err := kvList(ctx, TableName_UserPermissions) if err != nil && !database.IsEmptyRecord(err) { return err } @@ -281,11 +289,11 @@ func migrateUserRoles(ctx context.Context) error { } func migrateUserGroups(ctx context.Context) error { - if !db.FromContext(ctx).Migrator().HasTable(database.USER_GROUPS_TABLE_NAME) { + if !db.FromContext(ctx).Migrator().HasTable(TableName_UserGroups) { return nil } - records, err := kvList(ctx, database.USER_GROUPS_TABLE_NAME) + records, err := kvList(ctx, TableName_UserGroups) if err != nil && !database.IsEmptyRecord(err) { return err } @@ -310,11 +318,11 @@ func migrateUserGroups(ctx context.Context) error { } func migrateHosts(ctx context.Context) error { - if !db.FromContext(ctx).Migrator().HasTable(database.HOSTS_TABLE_NAME) { + if !db.FromContext(ctx).Migrator().HasTable(TableName_Hosts) { return nil } - records, err := kvList(ctx, database.HOSTS_TABLE_NAME) + records, err := kvList(ctx, TableName_Hosts) if err != nil && !database.IsEmptyRecord(err) { return err } diff --git a/migrate/migrate_v1_6_0.go b/migrate/migrate_v1_6_0.go index 2f0d08582..19c502fba 100644 --- a/migrate/migrate_v1_6_0.go +++ b/migrate/migrate_v1_6_0.go @@ -20,6 +20,12 @@ import ( "gorm.io/gorm" ) +const ( + TableName_PendingUsers = "pending_users" + TableName_UserInvites = "user_invites" + TableName_Nodes = "nodes" +) + func migrateV1_6_0(ctx context.Context) error { err := migratePendingUsers(ctx) if err != nil { @@ -35,11 +41,11 @@ func migrateV1_6_0(ctx context.Context) error { } func migratePendingUsers(ctx context.Context) error { - if !db.FromContext(ctx).Migrator().HasTable(database.PENDING_USERS_TABLE_NAME) { + if !db.FromContext(ctx).Migrator().HasTable(TableName_PendingUsers) { return nil } - records, err := kvList(ctx, database.PENDING_USERS_TABLE_NAME) + records, err := kvList(ctx, TableName_PendingUsers) if err != nil && !database.IsEmptyRecord(err) { return err } @@ -69,11 +75,11 @@ func migratePendingUsers(ctx context.Context) error { } func migrateUserInvites(ctx context.Context) error { - if !db.FromContext(ctx).Migrator().HasTable(database.USER_INVITES_TABLE_NAME) { + if !db.FromContext(ctx).Migrator().HasTable(TableName_UserInvites) { return nil } - records, err := kvList(ctx, database.USER_INVITES_TABLE_NAME) + records, err := kvList(ctx, TableName_UserInvites) if err != nil && !database.IsEmptyRecord(err) { return err } @@ -106,11 +112,11 @@ func migrateUserInvites(ctx context.Context) error { } func migrateNodes(ctx context.Context) error { - if !db.FromContext(ctx).Migrator().HasTable(database.NODES_TABLE_NAME) { + if !db.FromContext(ctx).Migrator().HasTable(TableName_Nodes) { return nil } - records, err := kvList(ctx, database.NODES_TABLE_NAME) + records, err := kvList(ctx, TableName_Nodes) if err != nil && !database.IsEmptyRecord(err) { return err } diff --git a/migrate/migrate_v1_7_0.go b/migrate/migrate_v1_7_0.go new file mode 100644 index 000000000..8478477c4 --- /dev/null +++ b/migrate/migrate_v1_7_0.go @@ -0,0 +1,299 @@ +package migrate + +import ( + "context" + "encoding/base64" + "encoding/json" + "time" + + "github.com/google/uuid" + "github.com/gravitl/netmaker/database" + "github.com/gravitl/netmaker/db" + "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/schema" + "gorm.io/datatypes" +) + +const ( + TableName_ServerConf = "serverconf" + TableName_Generated = "generated" + TableName_ServerUUID = "serveruuid" + TableName_EnrollmentKey = "enrollmentkeys" +) + +func migrateV1_7_0(ctx context.Context) error { + err := migrateServerConf(ctx) + if err != nil { + return err + } + + err = migrateGenerated(ctx) + if err != nil { + return err + } + + err = migrateServerUUID(ctx) + if err != nil { + return err + } + + return migrateEnrollmentKeys(ctx) +} + +func migrateServerConf(ctx context.Context) error { + if !db.FromContext(ctx).Migrator().HasTable(TableName_ServerConf) { + return nil + } + + records, err := kvList(ctx, TableName_ServerConf) + if err != nil && !database.IsEmptyRecord(err) { + return err + } + + record, ok := records["nm-jwt-secret"] + if ok { + recordData := make(map[string]string) + err = json.Unmarshal([]byte(record), &recordData) + if err != nil { + return err + } + + jwtSecretValue, ok := recordData["privatekey"] + if ok { + jwtSecret := &schema.Internal{ + Key: schema.InternalKey_JwtSecret, + Value: jwtSecretValue, + } + err = jwtSecret.Set(ctx) + if err != nil { + return err + } + } + } + + record, ok = records["netmaker-id-key-pair"] + if ok { + recordData := make(map[string][]byte) + err = json.Unmarshal([]byte(record), &recordData) + if err != nil { + return err + } + + privateKeyValue, ok := recordData["private_key"] + if ok { + privateKey := &schema.Internal{ + Key: schema.InternalKey_LicenseValidationPrivateKey, + Value: base64.StdEncoding.EncodeToString(privateKeyValue), + } + err = privateKey.Set(ctx) + if err != nil { + return err + } + } + + publicKeyValue, ok := recordData["public_key"] + if ok { + publicKey := &schema.Internal{ + Key: schema.InternalKey_LicenseValidationPublicKey, + Value: base64.StdEncoding.EncodeToString(publicKeyValue), + } + err = publicKey.Set(ctx) + if err != nil { + return err + } + } + } + + return nil +} + +func migrateGenerated(ctx context.Context) error { + if !db.FromContext(ctx).Migrator().HasTable(TableName_Generated) { + return nil + } + + records, err := kvList(ctx, TableName_Generated) + if err != nil && !database.IsEmptyRecord(err) { + return err + } + + record, ok := records["netmaker_auth"] + if ok { + recordData := make(map[string]string) + err = json.Unmarshal([]byte(record), &recordData) + if err != nil { + return err + } + + oauthSecretValue, ok := recordData["value"] + if ok { + oauthSecret := &schema.Internal{ + Key: schema.InternalKey_OAuthSecret, + Value: oauthSecretValue, + } + err = oauthSecret.Set(ctx) + if err != nil { + return err + } + } + } + + return nil +} + +func migrateServerUUID(ctx context.Context) error { + if !db.FromContext(ctx).Migrator().HasTable(TableName_ServerUUID) { + return nil + } + + records, err := kvList(ctx, TableName_ServerUUID) + if err != nil && !database.IsEmptyRecord(err) { + return err + } + + record, ok := records["serveruuid"] + if ok { + type recordType struct { + UUID string `json:"uuid"` + LastSend int64 `json:"lastsend"` + TrafficKeyPriv []byte `json:"traffickeypriv"` + TrafficKeyPub []byte `json:"traffickeypub"` + } + + var recordData recordType + err = json.Unmarshal([]byte(record), &recordData) + if err != nil { + return err + } + + if recordData.UUID != "" { + serverID := &schema.Internal{ + Key: schema.InternalKey_ServerID, + Value: recordData.UUID, + } + err = serverID.Set(ctx) + if err != nil { + return err + } + } + + if recordData.LastSend != 0 { + telemetryLastReportedAt := &schema.Internal{ + Key: schema.InternalKey_TelemetryLastReportedAt, + Value: time.Unix(recordData.LastSend, 0).UTC().Format(time.RFC3339), + } + err = telemetryLastReportedAt.Set(ctx) + if err != nil { + return err + } + } + + if recordData.TrafficKeyPriv != nil && recordData.TrafficKeyPub != nil { + mqPrivateKey := &schema.Internal{ + Key: schema.InternalKey_MqPrivateKey, + Value: base64.StdEncoding.EncodeToString(recordData.TrafficKeyPriv), + } + err = mqPrivateKey.Set(ctx) + if err != nil { + return err + } + + mqPublicKey := &schema.Internal{ + Key: schema.InternalKey_MqPublicKey, + Value: base64.StdEncoding.EncodeToString(recordData.TrafficKeyPub), + } + err = mqPublicKey.Set(ctx) + if err != nil { + return err + } + } + } + + return nil +} + +func migrateEnrollmentKeys(ctx context.Context) error { + if !db.FromContext(ctx).Migrator().HasTable(TableName_EnrollmentKey) { + return nil + } + + records, err := kvList(ctx, TableName_EnrollmentKey) + if err != nil && !database.IsEmptyRecord(err) { + return err + } + + for _, record := range records { + var key models.EnrollmentKey + if err = json.Unmarshal([]byte(record), &key); err != nil { + return err + } + + // merge models.Networks and models.Tags (both hold network names) + networksSet := make(map[string]struct{}, len(key.Networks)+len(key.Tags)) + for _, n := range key.Networks { + networksSet[n] = struct{}{} + } + for _, t := range key.Tags { + networksSet[t] = struct{}{} + } + networks := make(datatypes.JSONSlice[string], 0, len(networksSet)) + for n := range networksSet { + networks = append(networks, n) + } + + // models.Groups (device tags) → schema.Tags + tags := make(datatypes.JSONSlice[string], 0, len(key.Groups)) + for _, g := range key.Groups { + tags = append(tags, g.String()) + } + + var gatewayID *string + if key.Relay != uuid.Nil { + s := key.Relay.String() + gatewayID = &s + } + + var keyType schema.EnrollmentKeyType + switch key.Type { + case models.Unlimited: + keyType = schema.EnrollmentKeyType_UnlimitedUses + case models.Uses: + keyType = schema.EnrollmentKeyType_LimitedUses + case models.TimeExpiration: + keyType = schema.EnrollmentKeyType_TimedExpiry + default: + keyType = schema.EnrollmentKeyType_UnlimitedUses + } + + // models.Tags[0] was used as the enrollment key display name + name := "" + if len(key.Tags) > 0 { + name = key.Tags[0] + } + + _key := &schema.EnrollmentKey{ + ID: uuid.NewString(), + Name: name, + Value: key.Value, + Token: key.Token, + Default: key.Default, + Unlimited: key.Unlimited, + UsesRemaining: key.UsesRemaining, + Expiration: key.Expiration, + Networks: networks, + Tags: tags, + GatewayID: gatewayID, + AutoEgress: key.AutoEgress, + AutoAssignGateway: key.AutoAssignGateway, + Type: keyType, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } + + if err = _key.Create(ctx); err != nil { + return err + } + } + + return nil +} diff --git a/models/structs.go b/models/structs.go index 541b6e58c..456403719 100644 --- a/models/structs.go +++ b/models/structs.go @@ -246,15 +246,6 @@ type InetNodeReq struct { InetNodeClientIDs []string `json:"inet_node_client_ids"` } -// Telemetry - contains UUID of the server and timestamp of last send to posthog -// also contains assymetrical encryption pub/priv keys for any server traffic -type Telemetry struct { - UUID string `json:"uuid" bson:"uuid"` - LastSend int64 `json:"lastsend" bson:"lastsend" swaggertype:"primitive,integer" format:"int64"` - TrafficKeyPriv []byte `json:"traffickeypriv" bson:"traffickeypriv"` - TrafficKeyPub []byte `json:"traffickeypub" bson:"traffickeypub"` -} - // ServerAddr - to pass to clients to tell server addresses and if it's the leader or not type ServerAddr struct { IsLeader bool `json:"isleader" bson:"isleader" yaml:"isleader"` diff --git a/orchestrator/extensions/node.go b/orchestrator/extensions/node.go index 96fc0f874..951a0bf08 100644 --- a/orchestrator/extensions/node.go +++ b/orchestrator/extensions/node.go @@ -7,7 +7,7 @@ import ( type NodeExtensions interface { ConfigureAutoRelay(node *schema.Node) - ConfigureAutoAssignGateway(node *schema.Node, key *models.EnrollmentKey) + ConfigureAutoAssignGateway(node *schema.Node, key *schema.EnrollmentKey) ConfigureTag(node *schema.Node, tagID models.TagID) } @@ -17,7 +17,7 @@ func (c *CENodeExtensions) ConfigureAutoRelay(node *schema.Node) { node.IsAutoRelay = "no" } -func (c *CENodeExtensions) ConfigureAutoAssignGateway(node *schema.Node, _ *models.EnrollmentKey) { +func (c *CENodeExtensions) ConfigureAutoAssignGateway(node *schema.Node, _ *schema.EnrollmentKey) { node.AutoAssignGateway = false } diff --git a/orchestrator/node.go b/orchestrator/node.go index acc8d2eac..bdb814100 100644 --- a/orchestrator/node.go +++ b/orchestrator/node.go @@ -47,8 +47,8 @@ func (n *NodeOrchestrator) CreateNode(ctx context.Context, host *schema.Host, ne if ops.useKey { n.nodeExt.ConfigureAutoAssignGateway(node, ops.key) - for _, tag := range ops.key.Groups { - n.nodeExt.ConfigureTag(node, tag) + for _, tag := range ops.key.Tags { + n.nodeExt.ConfigureTag(node, models.TagID(tag)) } } @@ -127,14 +127,14 @@ func (n *NodeOrchestrator) CreateNode(ctx context.Context, host *schema.Host, ne if err != nil { return nil, err } - } else if ops.useKey && ops.key.Relay != uuid.Nil { + } else if ops.useKey && ops.key.GatewayID != nil { gateway := &schema.Node{ - ID: ops.key.Relay.String(), + ID: *ops.key.GatewayID, } err = gateway.Get(ctx) if err == nil { // TODO: merge operation - relayID := ops.key.Relay.String() + relayID := *ops.key.GatewayID node.RelayedByNodeID = &relayID err = node.UpdateRelayingNode(ctx) if err != nil { diff --git a/orchestrator/node_test.go b/orchestrator/node_test.go index 114fff1cf..c53f890da 100644 --- a/orchestrator/node_test.go +++ b/orchestrator/node_test.go @@ -4,10 +4,9 @@ import ( "context" "net" - "github.com/google/uuid" "github.com/gravitl/netmaker/db" - "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/orchestrator/extensions" + "github.com/gravitl/netmaker/schema" testutils "github.com/gravitl/netmaker/test/utils" "github.com/stretchr/testify/suite" "gorm.io/datatypes" @@ -178,7 +177,7 @@ func (c *CENodeOrchestratorTestSuite) TestCreateNodeWithEnrollmentKey() { network := testutils.CreateIPv10Network(c.T(), "network-0") c.Run("With AutoAssignGateway", func() { - key := &models.EnrollmentKey{ + key := &schema.EnrollmentKey{ AutoAssignGateway: true, } @@ -190,7 +189,7 @@ func (c *CENodeOrchestratorTestSuite) TestCreateNodeWithEnrollmentKey() { }) c.Run("Without AutoAssignGateway", func() { - key := &models.EnrollmentKey{ + key := &schema.EnrollmentKey{ AutoAssignGateway: false, } @@ -202,20 +201,20 @@ func (c *CENodeOrchestratorTestSuite) TestCreateNodeWithEnrollmentKey() { }) c.Run("With Tags", func() { - key := &models.EnrollmentKey{ - Groups: []models.TagID{"tag-0"}, + key := &schema.EnrollmentKey{ + Tags: []string{"tag-0"}, } node, err := GetRepository().NodeOrchestrator().CreateNode(db.WithContext(context.TODO()), host, network, UseKey(key)) c.Require().NoError(err) - c.Require().NotContains(node.Tags, string(key.Groups[0])) + c.Require().NotContains(node.Tags, key.Tags[0]) testutils.DeleteNode(c.T(), node) }) c.Run("Without Tags", func() { - key := &models.EnrollmentKey{ - Groups: []models.TagID{}, + key := &schema.EnrollmentKey{ + Tags: []string{}, } node, err := GetRepository().NodeOrchestrator().CreateNode(db.WithContext(context.TODO()), host, network, UseKey(key)) @@ -234,8 +233,8 @@ func (c *CENodeOrchestratorTestSuite) TestCreateNodeWithEnrollmentKey() { gateway, err := GetRepository().NodeOrchestrator().CreateNode(db.WithContext(context.TODO()), gatewayHost, network) c.Require().NoError(err) - key := &models.EnrollmentKey{ - Relay: uuid.MustParse(gateway.ID), + key := &schema.EnrollmentKey{ + GatewayID: &gateway.ID, } node, err := GetRepository().NodeOrchestrator().CreateNode(db.WithContext(context.TODO()), host, network, UseKey(key)) @@ -253,7 +252,7 @@ func (c *CENodeOrchestratorTestSuite) TestCreateNodeWithEnrollmentKey() { }) c.Run("Without Gateway", func() { - key := &models.EnrollmentKey{} + key := &schema.EnrollmentKey{} node, err := GetRepository().NodeOrchestrator().CreateNode(db.WithContext(context.TODO()), host, network, UseKey(key)) c.Require().NoError(err) diff --git a/orchestrator/options.go b/orchestrator/options.go index 20484d958..c0a152576 100644 --- a/orchestrator/options.go +++ b/orchestrator/options.go @@ -1,10 +1,10 @@ package orchestrator -import "github.com/gravitl/netmaker/models" +import "github.com/gravitl/netmaker/schema" type Options struct { useKey bool - key *models.EnrollmentKey + key *schema.EnrollmentKey skipHostUpdate bool skipNodeUpdate bool skipPublishPeerUpdate bool @@ -23,7 +23,7 @@ func applyOptions(opts ...Option) *Options { return o } -func UseKey(key *models.EnrollmentKey) Option { +func UseKey(key *schema.EnrollmentKey) Option { return func(o *Options) *Options { o.useKey = true o.key = key diff --git a/pro/auth/auth.go b/pro/auth/auth.go index b556668aa..26d4f4999 100644 --- a/pro/auth/auth.go +++ b/pro/auth/auth.go @@ -117,9 +117,9 @@ func InitializeAuthProvider() string { return "" } logger.Log(0, "setting oauth secret") - var err = logic.SetAuthSecret(logic.RandomString(64)) + var err = logic.SetOAuthSecret(logic.RandomString(64)) if err != nil { - logger.FatalLog("failed to set auth_secret", err.Error()) + logger.FatalLog("failed to set oauth secret", err.Error()) } var authInfo = logic.GetAuthProviderInfo(logic.GetServerSettings()) var serverConn = servercfg.GetAPIHost() diff --git a/pro/auth/azure-ad.go b/pro/auth/azure-ad.go index f30335d99..c9dddd79f 100644 --- a/pro/auth/azure-ad.go +++ b/pro/auth/azure-ad.go @@ -163,7 +163,7 @@ func handleAzureCallback(w http.ResponseWriter, r *http.Request) { handleOauthUserNotAllowed(w) return } - var newPass, fetchErr = logic.FetchPassValue("") + var newPass, fetchErr = logic.FetchOAuthSecret() if fetchErr != nil { return } diff --git a/pro/auth/github.go b/pro/auth/github.go index 12f0aae92..1dc0e1d97 100644 --- a/pro/auth/github.go +++ b/pro/auth/github.go @@ -174,7 +174,7 @@ func handleGithubCallback(w http.ResponseWriter, r *http.Request) { handleOauthUserNotAllowed(w) return } - var newPass, fetchErr = logic.FetchPassValue("") + var newPass, fetchErr = logic.FetchOAuthSecret() if fetchErr != nil { return } diff --git a/pro/auth/google.go b/pro/auth/google.go index 329ad6446..8a7638e47 100644 --- a/pro/auth/google.go +++ b/pro/auth/google.go @@ -168,7 +168,7 @@ func handleGoogleCallback(w http.ResponseWriter, r *http.Request) { handleOauthUserNotAllowed(w) return } - var newPass, fetchErr = logic.FetchPassValue("") + var newPass, fetchErr = logic.FetchOAuthSecret() if fetchErr != nil { return } diff --git a/pro/auth/headless_callback.go b/pro/auth/headless_callback.go index 9d2a4931e..3882af3b5 100644 --- a/pro/auth/headless_callback.go +++ b/pro/auth/headless_callback.go @@ -91,7 +91,7 @@ func HandleHeadlessSSOCallback(w http.ResponseWriter, r *http.Request) { handleUserAccountDisabled(w) return } - newPass, fetchErr := logic.FetchPassValue("") + newPass, fetchErr := logic.FetchOAuthSecret() if fetchErr != nil { return } diff --git a/pro/auth/oidc.go b/pro/auth/oidc.go index 1462d0eff..aa63930fd 100644 --- a/pro/auth/oidc.go +++ b/pro/auth/oidc.go @@ -175,7 +175,7 @@ func handleOIDCCallback(w http.ResponseWriter, r *http.Request) { handleOauthUserNotAllowed(w) return } - var newPass, fetchErr = logic.FetchPassValue("") + var newPass, fetchErr = logic.FetchOAuthSecret() if fetchErr != nil { return } diff --git a/pro/auth/sync.go b/pro/auth/sync.go index afd8e42f7..0977f0e87 100644 --- a/pro/auth/sync.go +++ b/pro/auth/sync.go @@ -140,7 +140,7 @@ func syncUsers(idpUsers []idp.User, removeIntegration bool) error { return err } - password, err := logic.FetchPassValue("") + password, err := logic.FetchOAuthSecret() if err != nil { return err } diff --git a/pro/controllers/users.go b/pro/controllers/users.go index b0f34dcbf..bf4326b3b 100644 --- a/pro/controllers/users.go +++ b/pro/controllers/users.go @@ -2027,7 +2027,7 @@ func approvePendingUser(w http.ResponseWriter, r *http.Request) { return } - var newPass, fetchErr = logic.FetchPassValue("") + var newPass, fetchErr = logic.FetchOAuthSecret() if fetchErr != nil { logic.ReturnErrorResponse(w, r, logic.FormatError(fetchErr, "internal")) return diff --git a/pro/initialize.go b/pro/initialize.go index 26bbea938..7bd7e9302 100644 --- a/pro/initialize.go +++ b/pro/initialize.go @@ -56,49 +56,16 @@ func InitPro() { ) controller.ListRoles = proControllers.ListRoles logic.EnterpriseCheckFuncs = append(logic.EnterpriseCheckFuncs, func(ctx context.Context, wg *sync.WaitGroup) { - // == License Handling == - enableLicenseHook := true - // licenseKeyValue := servercfg.GetLicenseKey() - // netmakerTenantID := servercfg.GetNetmakerTenantID() - // if licenseKeyValue != "" && netmakerTenantID != "" { - // enableLicenseHook = true - // } - if !enableLicenseHook { - err := initTrial() - if err != nil { - logger.Log(0, "failed to init trial", err.Error()) - enableLicenseHook = true - } - trialEndDate, err := getTrialEndDate() - if err != nil { - slog.Error("failed to get trial end date", "error", err) - enableLicenseHook = true - } else { - // check if trial ended - if time.Now().After(trialEndDate) { - // trial ended already - enableLicenseHook = true - } - } - - } - - if enableLicenseHook { - logger.Log(0, "starting license checker") - license.ClearLicenseCache() - if err := license.ValidateLicense(); err != nil { - slog.Error(err.Error()) - return - } - logger.Log(0, "proceeding with Paid Tier license") - logic.SetFreeTierForTelemetry(false) - // == End License Handling == - // License validation runs on all pods to avoid audit issues - license.AddLicenseHooks() - } else { - logger.Log(0, "starting trial license hook") - addTrialLicenseHook() + logger.Log(0, "starting license checker") + _ = license.ClearLicenseCache() + if err := license.ValidateLicense(); err != nil { + slog.Error(err.Error()) + return } + logger.Log(0, "proceeding with Paid Tier license") + // == End License Handling == + // License validation runs on all pods to avoid audit issues + license.AddLicenseHooks() //AddUnauthorisedUserNodeHooks() @@ -156,7 +123,6 @@ func InitPro() { logic.DeleteNodeMetricsFromPeers = proLogic.DeleteNodeMetricsFromPeers logic.SetPeerMetricsDisconnected = proLogic.SetPeerMetricsDisconnected logic.TriggerCollectMetrics = proLogic.PublishCollectMetrics - logic.GetTrialEndDate = getTrialEndDate mq.UpdateMetrics = proLogic.MQUpdateMetrics mq.UpdateMetricsFallBack = proLogic.MQUpdateMetricsFallBack logic.GetFilteredNodesByUserAccess = proLogic.GetFilteredNodesByUserAccess diff --git a/pro/license/license.go b/pro/license/license.go index a2486337e..56db86edc 100644 --- a/pro/license/license.go +++ b/pro/license/license.go @@ -2,6 +2,7 @@ package license import ( "bytes" + "context" "crypto/rand" "encoding/json" "errors" @@ -10,29 +11,22 @@ import ( "net/http" "time" + "github.com/gravitl/netmaker/db" "github.com/gravitl/netmaker/mq" proLogic "github.com/gravitl/netmaker/pro/logic" + "github.com/gravitl/netmaker/schema" "github.com/gravitl/netmaker/utils" + "gorm.io/gorm" "golang.org/x/crypto/nacl/box" "golang.org/x/exp/slog" - "github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/netclient/ncutils" "github.com/gravitl/netmaker/servercfg" ) -const ( - db_license_key = "netmaker-id-key-pair" -) - -type apiServerConf struct { - PrivateKey []byte `json:"private_key" binding:"required"` - PublicKey []byte `json:"public_key" binding:"required"` -} - // AddLicenseHooks - adds the validation and cache clear hooks func AddLicenseHooks() { logic.HookManagerCh <- models.HookDetails{ @@ -148,41 +142,62 @@ func ValidateLicense() (err error) { // as well as secure communication with API // if none present, it generates a new pair func FetchApiServerKeys() (pub *[32]byte, priv *[32]byte, err error) { - returnData := apiServerConf{} - currentData, err := database.FetchRecord(database.SERVERCONF_TABLE_NAME, db_license_key) - if err != nil && !database.IsEmptyRecord(err) { - return nil, nil, err - } else if database.IsEmptyRecord(err) { // need to generate a new identifier pair + var create bool + privateKey := &schema.Internal{ + Key: schema.InternalKey_LicenseValidationPrivateKey, + } + err = privateKey.Get(db.WithContext(context.TODO())) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + create = true + } else { + return nil, nil, err + } + } + + publicKey := &schema.Internal{ + Key: schema.InternalKey_LicenseValidationPublicKey, + } + err = publicKey.Get(db.WithContext(context.TODO())) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + create = true + } else { + return nil, nil, err + } + } + + if create { pub, priv, err = box.GenerateKey(rand.Reader) if err != nil { return nil, nil, err } - pubBytes, err := ncutils.ConvertKeyToBytes(pub) + privateKeyBytes, err := ncutils.ConvertKeyToBytes(priv) if err != nil { return nil, nil, err } - privBytes, err := ncutils.ConvertKeyToBytes(priv) + publicKeyBytes, err := ncutils.ConvertKeyToBytes(pub) if err != nil { return nil, nil, err } - returnData.PrivateKey = privBytes - returnData.PublicKey = pubBytes - record, err := json.Marshal(&returnData) + + privateKey.Value = base64encode(privateKeyBytes) + err = privateKey.Set(db.WithContext(context.TODO())) if err != nil { return nil, nil, err } - if err = database.Insert(db_license_key, string(record), database.SERVERCONF_TABLE_NAME); err != nil { + + publicKey.Value = base64encode(publicKeyBytes) + err = publicKey.Set(db.WithContext(context.TODO())) + if err != nil { return nil, nil, err } } else { - if err = json.Unmarshal([]byte(currentData), &returnData); err != nil { - return nil, nil, err - } - priv, err = ncutils.ConvertBytesToKey(returnData.PrivateKey) + priv, err = ncutils.ConvertBytesToKey(base64decode(privateKey.Value)) if err != nil { return nil, nil, err } - pub, err = ncutils.ConvertBytesToKey(returnData.PublicKey) + pub, err = ncutils.ConvertBytesToKey(base64decode(publicKey.Value)) if err != nil { return nil, nil, err } @@ -291,31 +306,29 @@ func validateLicenseKey(encryptedData []byte, publicKey *[32]byte) ([]byte, bool } func cacheResponse(response []byte) error { - lrc := licenseResponseCache{ - Body: response, + cachedResponse := &schema.Internal{ + Key: schema.InternalKey_LicenseValidationCachedResponse, + Value: base64encode(response), } - - record, err := json.Marshal(&lrc) - if err != nil { - return err - } - - return database.Insert(license_cache_key, string(record), database.CACHE_TABLE_NAME) + return cachedResponse.Set(db.WithContext(context.TODO())) } func getCachedResponse() ([]byte, error) { - var lrc licenseResponseCache - record, err := database.FetchRecord(database.CACHE_TABLE_NAME, license_cache_key) - if err != nil { - return nil, err + cachedResponse := &schema.Internal{ + Key: schema.InternalKey_LicenseValidationCachedResponse, } - if err = json.Unmarshal([]byte(record), &lrc); err != nil { + err := cachedResponse.Get(db.WithContext(context.TODO())) + if err != nil { return nil, err } - return lrc.Body, nil + + return base64decode(cachedResponse.Value), nil } // ClearLicenseCache - clears the cached validate response func ClearLicenseCache() error { - return database.DeleteRecord(database.CACHE_TABLE_NAME, license_cache_key) + cachedResponse := &schema.Internal{ + Key: schema.InternalKey_LicenseValidationCachedResponse, + } + return cachedResponse.Reset(db.WithContext(context.TODO())) } diff --git a/pro/license/types.go b/pro/license/types.go index ce8ac70aa..8ec70b876 100644 --- a/pro/license/types.go +++ b/pro/license/types.go @@ -7,9 +7,7 @@ import ( ) const ( - license_cache_key = "license_response_cache" license_validation_err_msg = "invalid license" - server_id_key = "nm-server-id" ) var errValidation = errors.New(license_validation_err_msg) @@ -50,7 +48,3 @@ type ValidateLicenseRequest struct { EncryptedPart string `json:"secret" binding:"required"` NmBaseDomain string `json:"nm_base_domain"` } - -type licenseResponseCache struct { - Body []byte `json:"body" binding:"required"` -} diff --git a/pro/logic/user_mgmt.go b/pro/logic/user_mgmt.go index 1813fc6a8..bc3e97322 100644 --- a/pro/logic/user_mgmt.go +++ b/pro/logic/user_mgmt.go @@ -820,7 +820,7 @@ func IsNetworkRolesValid(networkRoles map[schema.NetworkID]map[schema.UserRoleID // PrepareOauthUserFromInvite - init oauth user before create func PrepareOauthUserFromInvite(in *schema.UserInvite) (schema.User, error) { - var newPass, fetchErr = logic.FetchPassValue("") + var newPass, fetchErr = logic.FetchOAuthSecret() if fetchErr != nil { return schema.User{}, fetchErr } diff --git a/pro/orchestrator/extensions/node.go b/pro/orchestrator/extensions/node.go index 124ec5ecd..a2d315c7d 100644 --- a/pro/orchestrator/extensions/node.go +++ b/pro/orchestrator/extensions/node.go @@ -12,7 +12,7 @@ func (p *ProNodeExtensions) ConfigureAutoRelay(node *schema.Node) { node.IsAutoRelay = "yes" } -func (p *ProNodeExtensions) ConfigureAutoAssignGateway(node *schema.Node, key *models.EnrollmentKey) { +func (p *ProNodeExtensions) ConfigureAutoAssignGateway(node *schema.Node, key *schema.EnrollmentKey) { node.AutoAssignGateway = key.AutoAssignGateway } diff --git a/pro/orchestrator/node_test.go b/pro/orchestrator/node_test.go index fdb49aff8..e0063708a 100644 --- a/pro/orchestrator/node_test.go +++ b/pro/orchestrator/node_test.go @@ -4,11 +4,10 @@ import ( "context" "net" - "github.com/google/uuid" "github.com/gravitl/netmaker/db" - "github.com/gravitl/netmaker/models" core "github.com/gravitl/netmaker/orchestrator" "github.com/gravitl/netmaker/pro/orchestrator/extensions" + "github.com/gravitl/netmaker/schema" testutils "github.com/gravitl/netmaker/test/utils" "github.com/stretchr/testify/suite" @@ -181,7 +180,7 @@ func (c *ProNodeOrchestratorTestSuite) TestCreateNodeWithEnrollmentKey() { tag := testutils.CreateTag(c.T(), "tag-0", network.Name) c.Run("With AutoAssignGateway", func() { - key := &models.EnrollmentKey{ + key := &schema.EnrollmentKey{ AutoAssignGateway: true, } @@ -193,7 +192,7 @@ func (c *ProNodeOrchestratorTestSuite) TestCreateNodeWithEnrollmentKey() { }) c.Run("Without AutoAssignGateway", func() { - key := &models.EnrollmentKey{ + key := &schema.EnrollmentKey{ AutoAssignGateway: false, } @@ -205,20 +204,20 @@ func (c *ProNodeOrchestratorTestSuite) TestCreateNodeWithEnrollmentKey() { }) c.Run("With Tags", func() { - key := &models.EnrollmentKey{ - Groups: []models.TagID{tag.ID}, + key := &schema.EnrollmentKey{ + Tags: []string{tag.ID.String()}, } node, err := core.GetRepository().NodeOrchestrator().CreateNode(db.WithContext(context.TODO()), host, network, core.UseKey(key)) c.Require().NoError(err) - c.Require().Contains(node.Tags, string(key.Groups[0])) + c.Require().Contains(node.Tags, key.Tags[0]) testutils.DeleteNode(c.T(), node) }) c.Run("Without Tags", func() { - key := &models.EnrollmentKey{ - Groups: []models.TagID{}, + key := &schema.EnrollmentKey{ + Tags: []string{}, } node, err := core.GetRepository().NodeOrchestrator().CreateNode(db.WithContext(context.TODO()), host, network, core.UseKey(key)) @@ -237,8 +236,8 @@ func (c *ProNodeOrchestratorTestSuite) TestCreateNodeWithEnrollmentKey() { gateway, err := core.GetRepository().NodeOrchestrator().CreateNode(db.WithContext(context.TODO()), gatewayHost, network) c.Require().NoError(err) - key := &models.EnrollmentKey{ - Relay: uuid.MustParse(gateway.ID), + key := &schema.EnrollmentKey{ + GatewayID: &gateway.ID, } node, err := core.GetRepository().NodeOrchestrator().CreateNode(db.WithContext(context.TODO()), host, network, core.UseKey(key)) @@ -256,7 +255,7 @@ func (c *ProNodeOrchestratorTestSuite) TestCreateNodeWithEnrollmentKey() { }) c.Run("Without Gateway", func() { - key := &models.EnrollmentKey{} + key := &schema.EnrollmentKey{} node, err := core.GetRepository().NodeOrchestrator().CreateNode(db.WithContext(context.TODO()), host, network, core.UseKey(key)) c.Require().NoError(err) diff --git a/pro/trial.go b/pro/trial.go deleted file mode 100644 index 5b711b993..000000000 --- a/pro/trial.go +++ /dev/null @@ -1,161 +0,0 @@ -//go:build ee -// +build ee - -package pro - -import ( - "crypto/rand" - "encoding/json" - "errors" - "time" - - "github.com/gravitl/netmaker/database" - "github.com/gravitl/netmaker/logger" - "github.com/gravitl/netmaker/logic" - "github.com/gravitl/netmaker/models" - "github.com/gravitl/netmaker/netclient/ncutils" - "github.com/gravitl/netmaker/servercfg" - "golang.org/x/crypto/nacl/box" -) - -type TrialInfo struct { - PrivKey []byte `json:"priv_key"` - PubKey []byte `json:"pub_key"` - Secret []byte `json:"secret"` -} - -func addTrialLicenseHook() { - logic.HookManagerCh <- models.HookDetails{ - Hook: logic.WrapHook(TrialLicenseHook), - Interval: time.Hour, - } -} - -type TrialDates struct { - TrialStartedAt time.Time `json:"trial_started_at"` - TrialEndsAt time.Time `json:"trial_ends_at"` -} - -const trial_table_name = "trial" - -const trial_data_key = "trialdata" - -// stores trial end date -func initTrial() error { - telData := logic.FetchTelemetryData() - if telData.Hosts > 0 || telData.Networks > 0 || telData.Users > 0 { - return nil // database is already populated, so skip creating trial - } - database.CreateTable(trial_table_name) - records, err := database.FetchRecords(trial_table_name) - if err != nil && !database.IsEmptyRecord(err) { - return err - } - if len(records) > 0 { - return nil - } - // setup encryption keys - trafficPubKey, trafficPrivKey, err := box.GenerateKey(rand.Reader) // generate traffic keys - if err != nil { - return err - } - tPriv, err := ncutils.ConvertKeyToBytes(trafficPrivKey) - if err != nil { - return err - } - - tPub, err := ncutils.ConvertKeyToBytes(trafficPubKey) - if err != nil { - return err - } - trialDates := TrialDates{ - TrialStartedAt: time.Now(), - TrialEndsAt: time.Now().Add(time.Hour * 24 * 14), - } - t := TrialInfo{ - PrivKey: tPriv, - PubKey: tPub, - } - tel, err := logic.FetchTelemetryRecord() - if err != nil { - return err - } - - trialDatesData, err := json.Marshal(trialDates) - if err != nil { - return err - } - telePubKey, err := ncutils.ConvertBytesToKey(tel.TrafficKeyPub) - if err != nil { - return err - } - trialDatesSecret, err := ncutils.BoxEncrypt(trialDatesData, telePubKey, trafficPrivKey) - if err != nil { - return err - } - t.Secret = trialDatesSecret - trialData, err := json.Marshal(t) - if err != nil { - return err - } - err = database.Insert(trial_data_key, string(trialData), trial_table_name) - if err != nil { - return err - } - return nil -} - -// TrialLicenseHook - hook func to check if pro trial has ended -func TrialLicenseHook() error { - endDate, err := getTrialEndDate() - if err != nil { - logger.FatalLog0("failed to trial end date", err.Error()) - } - if time.Now().After(endDate) { - logger.Log(0, "***IMPORTANT: Your Trial Has Ended, to continue using pro version, please visit https://app.netmaker.io/ and create on-prem tenant to obtain a license***\nIf you wish to downgrade to community version, please run this command `/root/nm-quick.sh -d`") - err = errors.New("your trial has ended") - servercfg.ErrLicenseValidation = err - return err - } - return nil -} - -// get trial date -func getTrialEndDate() (time.Time, error) { - record, err := database.FetchRecord(trial_table_name, trial_data_key) - if err != nil { - return logic.DefaultTrialEndDate, err - } - var trialInfo TrialInfo - err = json.Unmarshal([]byte(record), &trialInfo) - if err != nil { - return logic.DefaultTrialEndDate, err - } - tel, err := logic.FetchTelemetryRecord() - if err != nil { - return logic.DefaultTrialEndDate, err - } - telePrivKey, err := ncutils.ConvertBytesToKey(tel.TrafficKeyPriv) - if err != nil { - return logic.DefaultTrialEndDate, err - } - trialPubKey, err := ncutils.ConvertBytesToKey(trialInfo.PubKey) - if err != nil { - return logic.DefaultTrialEndDate, err - } - // decrypt secret - secretDecrypt, err := ncutils.BoxDecrypt(trialInfo.Secret, trialPubKey, telePrivKey) - if err != nil { - return logic.DefaultTrialEndDate, err - } - trialDates := TrialDates{} - err = json.Unmarshal(secretDecrypt, &trialDates) - if err != nil { - return logic.DefaultTrialEndDate, err - } - if trialDates.TrialEndsAt.IsZero() { - return logic.DefaultTrialEndDate, errors.New("invalid date") - } - return trialDates.TrialEndsAt, nil - -} diff --git a/schema/enrollment_keys.go b/schema/enrollment_keys.go new file mode 100644 index 000000000..8065dc317 --- /dev/null +++ b/schema/enrollment_keys.go @@ -0,0 +1,87 @@ +package schema + +import ( + "context" + "time" + + dbtypes "github.com/gravitl/netmaker/db/types" + "gorm.io/datatypes" + + "github.com/gravitl/netmaker/db" +) + +type EnrollmentKeyType string + +const ( + EnrollmentKeyType_UnlimitedUses EnrollmentKeyType = "unlimited_uses" + EnrollmentKeyType_LimitedUses EnrollmentKeyType = "limited_uses" + EnrollmentKeyType_TimedExpiry EnrollmentKeyType = "timed_expiry" +) + +type EnrollmentKey struct { + ID string `gorm:"primaryKey" json:"id"` + Name string `json:"name"` + Value string `json:"value"` + Token string `json:"token"` + Default bool `json:"default"` + Unlimited bool `json:"unlimited"` + UsesRemaining int `json:"uses_remaining"` + Expiration time.Time `json:"expiration"` + Networks datatypes.JSONSlice[string] `json:"networks"` + Tags datatypes.JSONSlice[string] `json:"tags"` + GatewayID *string `json:"gateway_id"` + AutoEgress bool `json:"auto_egress"` + AutoAssignGateway bool `json:"auto_assign_gateway"` + Type EnrollmentKeyType `json:"type"` + CreatedBy string `json:"created_by"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +func (e *EnrollmentKey) TableName() string { + return "enrollment_keys_v1" +} + +func (e *EnrollmentKey) Create(ctx context.Context) error { + return db.FromContext(ctx).Model(&EnrollmentKey{}).Create(e).Error +} + +func (e *EnrollmentKey) Get(ctx context.Context) error { + return db.FromContext(ctx).Model(&EnrollmentKey{}).Where("id = ?", e.ID).First(e).Error +} + +func (e *EnrollmentKey) GetByValue(ctx context.Context) error { + return db.FromContext(ctx).Model(&EnrollmentKey{}).Where("value = ?", e.Value).First(e).Error +} + +func (e *EnrollmentKey) Upsert(ctx context.Context) error { + return db.FromContext(ctx).Save(e).Error +} + +func (e *EnrollmentKey) Delete(ctx context.Context) error { + return db.FromContext(ctx).Model(&EnrollmentKey{}).Where("id = ?", e.ID).Delete(e).Error +} + +func (e *EnrollmentKey) DeleteByValue(ctx context.Context) error { + return db.FromContext(ctx).Model(&EnrollmentKey{}).Where("value = ?", e.Value).Delete(e).Error +} + +func (e *EnrollmentKey) ListAll(ctx context.Context, options ...dbtypes.Option) ([]EnrollmentKey, error) { + var keys []EnrollmentKey + query := db.FromContext(ctx).Model(&EnrollmentKey{}) + for _, opt := range options { + query = opt(query) + } + err := query.Find(&keys).Error + return keys, err +} + +func (e *EnrollmentKey) Count(ctx context.Context, options ...dbtypes.Option) (int, error) { + var count int64 + query := db.FromContext(ctx).Model(&EnrollmentKey{}) + for _, opt := range options { + query = opt(query) + } + err := query.Count(&count).Error + return int(count), err +} diff --git a/schema/internal.go b/schema/internal.go new file mode 100644 index 000000000..4582cb913 --- /dev/null +++ b/schema/internal.go @@ -0,0 +1,46 @@ +package schema + +import ( + "context" + + "github.com/gravitl/netmaker/db" +) + +const ( + InternalKey_ServerID = "server_id" + InternalKey_JwtSecret = "jwt_secret" + InternalKey_OAuthSecret = "oauth_secret" + InternalKey_MqPrivateKey = "mq_private_key" + InternalKey_MqPublicKey = "mq_public_key" + InternalKey_LicenseValidationPrivateKey = "license_validation_private_key" + InternalKey_LicenseValidationPublicKey = "license_validation_public_key" + InternalKey_LicenseValidationCachedResponse = "license_validation_cached_response" + InternalKey_TelemetryLastReportedAt = "telemetry_last_reported_at" +) + +type Internal struct { + Key string `gorm:"primaryKey"` + Value string `gorm:"not null"` +} + +func (i *Internal) TableName() string { + return "__internal__" +} + +func (i *Internal) Set(ctx context.Context) error { + return db.FromContext(ctx).Save(i).Error +} + +func (i *Internal) Get(ctx context.Context) error { + return db.FromContext(ctx).Model(&Internal{}). + Where("key = ?", i.Key). + First(i). + Error +} + +func (i *Internal) Reset(ctx context.Context) error { + return db.FromContext(ctx).Model(&Internal{}). + Where("key = ?", i.Key). + Delete(i). + Error +} diff --git a/schema/models.go b/schema/models.go index e50870f8f..d73d6ad0c 100644 --- a/schema/models.go +++ b/schema/models.go @@ -22,5 +22,6 @@ func ListModels() []interface{} { &Node{}, &PostureCheckViolation{}, &Integration{}, + &EnrollmentKey{}, } } diff --git a/schema/nodes.go b/schema/nodes.go index 4a80a3504..0a5a03022 100644 --- a/schema/nodes.go +++ b/schema/nodes.go @@ -504,3 +504,10 @@ func (n *Node) ResetGateway(ctx context.Context) error { UpdateColumn("auto_relayed_peers", expr.RemoveByValue("auto_relayed_peers", n.ID)). Error } + +func (n *Node) ClearGatewayIDFromEnrollmentKeys(ctx context.Context) error { + return db.FromContext(ctx).Model(&EnrollmentKey{}). + Where("gateway_id = ?", n.ID). + Update("gateway_id", nil). + Error +}