From 93c861cbdbc643a4a30ce0b1904cd449c71ab077 Mon Sep 17 00:00:00 2001 From: VishalDalwadi Date: Thu, 18 Jun 2026 13:48:10 +0530 Subject: [PATCH 01/21] feat(go): add schema definition for organization and tenant tables; --- schema/organizations.go | 121 ++++++++++++++++++++++++++++++++++++++++ schema/tenants.go | 103 ++++++++++++++++++++++++++++++++++ 2 files changed, 224 insertions(+) create mode 100644 schema/organizations.go create mode 100644 schema/tenants.go diff --git a/schema/organizations.go b/schema/organizations.go new file mode 100644 index 000000000..b9e99231b --- /dev/null +++ b/schema/organizations.go @@ -0,0 +1,121 @@ +package schema + +import ( + "context" + "fmt" + "math/rand" + "regexp" + "strings" + "time" + + "github.com/google/uuid" + "github.com/gravitl/netmaker/db" + "gorm.io/gorm" +) + +var slugNonAlphaNumericRegex = regexp.MustCompile(`[^a-z0-9]+`) + +// generateSlug produces a URL-friendly slug from name with a random 4-digit +// suffix to reduce collisions (e.g. "acme-corp-4821"). +func generateSlug(name string) string { + base := strings.Trim(slugNonAlphaNumericRegex.ReplaceAllString(strings.ToLower(name), "-"), "-") + if base == "" { + base = "org" + } + return fmt.Sprintf("%s-%04d", base, rand.Intn(9000)+1000) +} + +type Organization struct { + ID string `gorm:"primaryKey" json:"id"` + Name string `gorm:"not null" json:"name"` + Slug string `gorm:"uniqueIndex;not null" json:"slug"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +func (o *Organization) TableName() string { + return "organizations_v1" +} + +func (o *Organization) Create(ctx context.Context) error { + if o.ID == "" { + o.ID = uuid.NewString() + } + // Retry slug generation on unique constraint violation (max 5 attempts). + for i := 0; i < 5; i++ { + if o.Slug == "" { + o.Slug = generateSlug(o.Name) + } + err := db.FromContext(ctx).Model(&Organization{}).Create(o).Error + if err == nil { + return nil + } + if !isUniqueConstraintErr(err) { + return err + } + o.Slug = "" + } + return fmt.Errorf("failed to generate unique slug for organization %q after 5 attempts", o.Name) +} + +func (o *Organization) Get(ctx context.Context) error { + return db.FromContext(ctx).Model(&Organization{}). + Where("id = ? OR slug = ?", o.ID, o.Slug). + First(o).Error +} + +func (o *Organization) ListAll(ctx context.Context) ([]Organization, error) { + var orgs []Organization + err := db.FromContext(ctx).Model(&Organization{}).Find(&orgs).Error + return orgs, err +} + +func (o *Organization) Update(ctx context.Context) error { + return db.FromContext(ctx).Model(&Organization{}). + Where("id = ?", o.ID). + Updates(o).Error +} + +func (o *Organization) Delete(ctx context.Context) error { + return db.FromContext(ctx).Model(&Organization{}). + Where("id = ?", o.ID). + Delete(o).Error +} + +// isUniqueConstraintErr returns true if err is a unique constraint violation +// from SQLite or PostgreSQL. +func isUniqueConstraintErr(err error) bool { + if err == nil { + return false + } + msg := err.Error() + // SQLite: "UNIQUE constraint failed" + // PostgreSQL: "duplicate key value violates unique constraint" + return strings.Contains(msg, "UNIQUE constraint failed") || + strings.Contains(msg, "duplicate key value violates unique constraint") || + strings.Contains(msg, "23505") // pg error code +} + +// EnsureDefaultOrganization creates the default organization if none exists, +// returning the org (existing or newly created). +func EnsureDefaultOrganization(ctx context.Context) (*Organization, error) { + var orgs []Organization + if err := db.FromContext(ctx).Model(&Organization{}).Limit(1).Find(&orgs).Error; err != nil { + return nil, err + } + if len(orgs) > 0 { + return &orgs[0], nil + } + org := &Organization{Name: "Default", Slug: "default"} + err := db.FromContext(ctx).Model(&Organization{}). + Where(gorm.Model{}). + FirstOrCreate(org, Organization{Slug: "default"}).Error + if err != nil { + // Slug "default" taken — use generated slug. + org.Slug = "" + if createErr := org.Create(ctx); createErr != nil { + return nil, createErr + } + } + return org, nil +} diff --git a/schema/tenants.go b/schema/tenants.go new file mode 100644 index 000000000..e71d0d6a8 --- /dev/null +++ b/schema/tenants.go @@ -0,0 +1,103 @@ +package schema + +import ( + "context" + "fmt" + "time" + + "github.com/google/uuid" + "github.com/gravitl/netmaker/db" + "gorm.io/gorm" +) + +type Tenant struct { + ID string `gorm:"primaryKey" json:"id"` + Name string `gorm:"not null" json:"name"` + Slug string `gorm:"uniqueIndex;not null" json:"slug"` + OrganizationID string `gorm:"not null;index" json:"organization_id"` + CreatedAt time.Time `json:"created_at"` + UpdatedAt time.Time `json:"updated_at"` +} + +func (t *Tenant) TableName() string { + return "tenants_v1" +} + +func (t *Tenant) Create(ctx context.Context) error { + if t.ID == "" { + t.ID = uuid.NewString() + } + // Retry slug generation on unique constraint violation (max 5 attempts). + for i := 0; i < 5; i++ { + if t.Slug == "" { + t.Slug = generateSlug(t.Name) + } + err := db.FromContext(ctx).Model(&Tenant{}).Create(t).Error + if err == nil { + return nil + } + if !isUniqueConstraintErr(err) { + return err + } + t.Slug = "" + } + return fmt.Errorf("failed to generate unique slug for tenant %q after 5 attempts", t.Name) +} + +func (t *Tenant) Get(ctx context.Context) error { + return db.FromContext(ctx).Model(&Tenant{}). + Where("id = ? OR slug = ?", t.ID, t.Slug). + First(t).Error +} + +func (t *Tenant) ListAll(ctx context.Context) ([]Tenant, error) { + var tenants []Tenant + err := db.FromContext(ctx).Model(&Tenant{}).Find(&tenants).Error + return tenants, err +} + +func (t *Tenant) ListByOrg(ctx context.Context, orgID string) ([]Tenant, error) { + var tenants []Tenant + err := db.FromContext(ctx).Model(&Tenant{}). + Where("organization_id = ?", orgID). + Find(&tenants).Error + return tenants, err +} + +func (t *Tenant) Update(ctx context.Context) error { + return db.FromContext(ctx).Model(&Tenant{}). + Where("id = ?", t.ID). + Updates(t).Error +} + +func (t *Tenant) Delete(ctx context.Context) error { + return db.FromContext(ctx).Model(&Tenant{}). + Where("id = ?", t.ID). + Delete(t).Error +} + +// EnsureDefaultTenant creates the default tenant for the given org if none +// exists, returning the tenant (existing or newly created). +func EnsureDefaultTenant(ctx context.Context, orgID string) (*Tenant, error) { + var tenants []Tenant + if err := db.FromContext(ctx).Model(&Tenant{}). + Where("organization_id = ?", orgID). + Limit(1).Find(&tenants).Error; err != nil { + return nil, err + } + if len(tenants) > 0 { + return &tenants[0], nil + } + tenant := &Tenant{OrganizationID: orgID, Name: "Default", Slug: "default"} + err := db.FromContext(ctx).Model(&Tenant{}). + Where(gorm.Model{}). + FirstOrCreate(tenant, Tenant{Slug: "default", OrganizationID: orgID}).Error + if err != nil { + // Slug "default" taken — use generated slug. + tenant.Slug = "" + if createErr := tenant.Create(ctx); createErr != nil { + return nil, createErr + } + } + return tenant, nil +} From be4988694aa1c943a22b18214e88bfaa64bb83f5 Mon Sep 17 00:00:00 2001 From: VishalDalwadi Date: Thu, 18 Jun 2026 13:55:40 +0530 Subject: [PATCH 02/21] feat(go): add org/tenant scoping to db; add middleware for api handlers to declare scope; --- controllers/scope.go | 53 ++++++++++++++++++++++++++++++++++++++++++++ db/scope.go | 52 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 105 insertions(+) create mode 100644 controllers/scope.go create mode 100644 db/scope.go diff --git a/controllers/scope.go b/controllers/scope.go new file mode 100644 index 000000000..89f8da79e --- /dev/null +++ b/controllers/scope.go @@ -0,0 +1,53 @@ +package controller + +import ( + "errors" + "net/http" + + "github.com/gravitl/netmaker/db" + "github.com/gravitl/netmaker/logic" +) + +var ( + errMissingTenantID = errors.New("X-Tenant-ID header is required") + errMissingOrgID = errors.New("X-Organization-ID header is required") +) + +const ( + HeaderTenantID = "X-Tenant-ID" + HeaderOrgID = "X-Organization-ID" +) + +// Scope wraps an http.Handler to enforce request-level tenancy scoping. +// +// For db.TenantScope: requires the X-Tenant-ID header and injects a +// WHERE tenant_id = ? scope into the GORM db stored in the request context. +// +// For db.OrgScope: requires the X-Organization-ID header and injects a +// WHERE organization_id = ? scope. +// +// For db.GlobalScope: passes through without modification. +func Scope(level db.ScopeLevel, next http.Handler) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + var id string + switch level { + case db.TenantScope: + id = r.Header.Get(HeaderTenantID) + if id == "" { + logic.ReturnErrorResponse(w, r, logic.FormatError(errMissingTenantID, logic.BadReq)) + return + } + case db.OrgScope: + id = r.Header.Get(HeaderOrgID) + if id == "" { + logic.ReturnErrorResponse(w, r, logic.FormatError(errMissingOrgID, logic.BadReq)) + return + } + case db.GlobalScope: + // no header required + } + + ctx := db.Scope(r.Context(), level, id) + next.ServeHTTP(w, r.WithContext(ctx)) + } +} diff --git a/db/scope.go b/db/scope.go new file mode 100644 index 000000000..36127ef21 --- /dev/null +++ b/db/scope.go @@ -0,0 +1,52 @@ +package db + +import ( + "context" + + "gorm.io/gorm" +) + +// ScopeLevel represents the tenancy scope of a request. +type ScopeLevel int + +const ( + // GlobalScope applies no tenant filtering — raw, unscoped access. + GlobalScope ScopeLevel = iota + // OrgScope filters queries to a specific organization (WHERE organization_id = ?). + OrgScope + // TenantScope filters queries to a specific tenant (WHERE tenant_id = ?). + TenantScope +) + +// Scope returns a new context whose GORM db is scoped to the given level. +// +// For OrgScope and TenantScope, exactly one id must be provided. +// For GlobalScope, no id is needed; the db is returned unscoped. +// +// Panics on invalid usage (wrong number of ids). These call sites are always +// static, so invalid usage is caught during development and code review. +func Scope(ctx context.Context, level ScopeLevel, ids ...string) context.Context { + if len(ids) > 1 { + panic("db.Scope: at most one id is allowed") + } + if level != GlobalScope && len(ids) == 0 { + panic("db.Scope: id required for non-global scope") + } + if level == GlobalScope { + return ctx + } + gdb := FromContext(ctx) + switch level { + case TenantScope: + gdb = gdb.Scopes(func(db *gorm.DB) *gorm.DB { + return db.Where("tenant_id = ?", ids[0]) + }) + case OrgScope: + gdb = gdb.Scopes(func(db *gorm.DB) *gorm.DB { + return db.Where("organization_id = ?", ids[0]) + }) + default: + panic("db.Scope: unknown level") + } + return context.WithValue(ctx, dbCtxKey, gdb) +} From 0242485e96edc21a199978796ba05db0fc7535d2 Mon Sep 17 00:00:00 2001 From: VishalDalwadi Date: Thu, 18 Jun 2026 14:05:47 +0530 Subject: [PATCH 03/21] feat(go): add tenant scope declaration for some controllers; --- controllers/acls.go | 14 +++---- controllers/dns.go | 26 ++++++------ controllers/egress.go | 10 ++--- controllers/enrollmentkeys.go | 15 +++---- controllers/ext_client.go | 18 ++++---- controllers/gateway.go | 12 +++--- controllers/hosts.go | 36 ++++++++-------- controllers/inet_gws.go | 7 ++-- controllers/network.go | 14 +++---- controllers/node.go | 22 +++++----- controllers/server.go | 8 ++-- controllers/user.go | 44 ++++++++++---------- pro/controllers/auto_relay.go | 6 +-- pro/controllers/events.go | 7 ++-- pro/controllers/flows.go | 4 +- pro/controllers/integrations.go | 10 +++-- pro/controllers/jit.go | 17 ++++---- pro/controllers/metrics.go | 10 +++-- pro/controllers/networks.go | 7 +++- pro/controllers/posture_check.go | 13 +++--- pro/controllers/rac.go | 8 ++-- pro/controllers/tags.go | 9 ++-- pro/controllers/users.go | 71 ++++++++++++++++---------------- schema/integrations.go | 3 +- 24 files changed, 205 insertions(+), 186 deletions(-) diff --git a/controllers/acls.go b/controllers/acls.go index 7276703d6..2f43aa699 100644 --- a/controllers/acls.go +++ b/controllers/acls.go @@ -18,19 +18,19 @@ import ( ) func aclHandlers(r *mux.Router) { - r.HandleFunc("/api/v1/acls", logic.SecurityCheck(true, http.HandlerFunc(getAcls))). + r.HandleFunc("/api/v1/acls", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getAcls)))). Methods(http.MethodGet) - r.HandleFunc("/api/v1/acls/egress", logic.SecurityCheck(true, http.HandlerFunc(getEgressAcls))). + r.HandleFunc("/api/v1/acls/egress", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getEgressAcls)))). Methods(http.MethodGet) - r.HandleFunc("/api/v1/acls/policy_types", logic.SecurityCheck(true, http.HandlerFunc(aclPolicyTypes))). + r.HandleFunc("/api/v1/acls/policy_types", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(aclPolicyTypes)))). Methods(http.MethodGet) - r.HandleFunc("/api/v1/acls", logic.SecurityCheck(true, http.HandlerFunc(createAcl))). + r.HandleFunc("/api/v1/acls", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(createAcl)))). Methods(http.MethodPost) - r.HandleFunc("/api/v1/acls", logic.SecurityCheck(true, http.HandlerFunc(updateAcl))). + r.HandleFunc("/api/v1/acls", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(updateAcl)))). Methods(http.MethodPut) - r.HandleFunc("/api/v1/acls", logic.SecurityCheck(true, http.HandlerFunc(deleteAcl))). + r.HandleFunc("/api/v1/acls", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteAcl)))). Methods(http.MethodDelete) - r.HandleFunc("/api/v1/acls/debug", logic.SecurityCheck(true, http.HandlerFunc(aclDebug))). + r.HandleFunc("/api/v1/acls/debug", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(aclDebug)))). Methods(http.MethodGet) } diff --git a/controllers/dns.go b/controllers/dns.go index 4db8b9a8f..4a3273e41 100644 --- a/controllers/dns.go +++ b/controllers/dns.go @@ -24,27 +24,27 @@ import ( func dnsHandlers(r *mux.Router) { - r.HandleFunc("/api/dns", logic.SecurityCheck(true, http.HandlerFunc(getAllDNS))). + r.HandleFunc("/api/dns", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getAllDNS)))). Methods(http.MethodGet) - r.HandleFunc("/api/dns/adm/{network}/nodes", logic.SecurityCheck(true, http.HandlerFunc(getNodeDNS))). + r.HandleFunc("/api/dns/adm/{network}/nodes", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getNodeDNS)))). Methods(http.MethodGet) - r.HandleFunc("/api/dns/adm/{network}/custom", logic.SecurityCheck(true, http.HandlerFunc(getCustomDNS))). + r.HandleFunc("/api/dns/adm/{network}/custom", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getCustomDNS)))). Methods(http.MethodGet) - r.HandleFunc("/api/dns/adm/{network}", logic.SecurityCheck(true, http.HandlerFunc(getDNS))). + r.HandleFunc("/api/dns/adm/{network}", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getDNS)))). Methods(http.MethodGet) - r.HandleFunc("/api/dns/adm/{network}/sync", logic.SecurityCheck(true, http.HandlerFunc(syncDNS))). + r.HandleFunc("/api/dns/adm/{network}/sync", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(syncDNS)))). Methods(http.MethodPost) - r.HandleFunc("/api/dns/{network}", logic.SecurityCheck(true, http.HandlerFunc(createDNS))). + r.HandleFunc("/api/dns/{network}", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(createDNS)))). Methods(http.MethodPost) - r.HandleFunc("/api/dns/adm/pushdns", logic.SecurityCheck(true, http.HandlerFunc(pushDNS))). + r.HandleFunc("/api/dns/adm/pushdns", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(pushDNS)))). Methods(http.MethodPost) - r.HandleFunc("/api/dns/{network}/{domain}", logic.SecurityCheck(true, http.HandlerFunc(deleteDNS))). + r.HandleFunc("/api/dns/{network}/{domain}", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteDNS)))). Methods(http.MethodDelete) - r.HandleFunc("/api/v1/nameserver", logic.SecurityCheck(true, http.HandlerFunc(createNs))).Methods(http.MethodPost) - r.HandleFunc("/api/v1/nameserver", logic.SecurityCheck(true, http.HandlerFunc(listNs))).Methods(http.MethodGet) - r.HandleFunc("/api/v1/nameserver", logic.SecurityCheck(true, http.HandlerFunc(updateNs))).Methods(http.MethodPut) - r.HandleFunc("/api/v1/nameserver", logic.SecurityCheck(true, http.HandlerFunc(deleteNs))).Methods(http.MethodDelete) - r.HandleFunc("/api/v1/nameserver/global", logic.SecurityCheck(true, http.HandlerFunc(getGlobalNs))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/nameserver", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(createNs)))).Methods(http.MethodPost) + r.HandleFunc("/api/v1/nameserver", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(listNs)))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/nameserver", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(updateNs)))).Methods(http.MethodPut) + r.HandleFunc("/api/v1/nameserver", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteNs)))).Methods(http.MethodDelete) + r.HandleFunc("/api/v1/nameserver/global", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getGlobalNs)))).Methods(http.MethodGet) } // @Summary List Global Nameservers diff --git a/controllers/egress.go b/controllers/egress.go index abdd52254..fa4b873d0 100644 --- a/controllers/egress.go +++ b/controllers/egress.go @@ -20,11 +20,11 @@ import ( ) func egressHandlers(r *mux.Router) { - r.HandleFunc("/api/v1/egress/presets", logic.SecurityCheck(true, http.HandlerFunc(getEgressPresets))).Methods(http.MethodGet) - r.HandleFunc("/api/v1/egress", logic.SecurityCheck(true, http.HandlerFunc(createEgress))).Methods(http.MethodPost) - r.HandleFunc("/api/v1/egress", logic.SecurityCheck(true, http.HandlerFunc(listEgress))).Methods(http.MethodGet) - r.HandleFunc("/api/v1/egress", logic.SecurityCheck(true, http.HandlerFunc(updateEgress))).Methods(http.MethodPut) - r.HandleFunc("/api/v1/egress", logic.SecurityCheck(true, http.HandlerFunc(deleteEgress))).Methods(http.MethodDelete) + r.HandleFunc("/api/v1/egress/presets", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getEgressPresets)))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/egress", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(createEgress)))).Methods(http.MethodPost) + r.HandleFunc("/api/v1/egress", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(listEgress)))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/egress", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(updateEgress)))).Methods(http.MethodPut) + r.HandleFunc("/api/v1/egress", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteEgress)))).Methods(http.MethodDelete) } // @Summary List egress domain presets diff --git a/controllers/enrollmentkeys.go b/controllers/enrollmentkeys.go index 7cd3dd26b..17c42da18 100644 --- a/controllers/enrollmentkeys.go +++ b/controllers/enrollmentkeys.go @@ -15,6 +15,7 @@ import ( "golang.org/x/exp/slog" "github.com/gravitl/netmaker/auth" + "github.com/gravitl/netmaker/db" dbtypes "github.com/gravitl/netmaker/db/types" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" @@ -24,21 +25,21 @@ import ( ) func enrollmentKeyHandlers(r *mux.Router) { - r.HandleFunc("/api/v1/enrollment-keys", logic.SecurityCheck(true, http.HandlerFunc(createEnrollmentKey))). + r.HandleFunc("/api/v1/enrollment-keys", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(createEnrollmentKey)))). Methods(http.MethodPost) - r.HandleFunc("/api/v1/enrollment-keys", logic.SecurityCheck(true, http.HandlerFunc(getEnrollmentKeys))). + r.HandleFunc("/api/v1/enrollment-keys", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getEnrollmentKeys)))). Methods(http.MethodGet) - r.HandleFunc("/api/v2/enrollment-keys", logic.SecurityCheck(true, http.HandlerFunc(listEnrollmentKeys))). + r.HandleFunc("/api/v2/enrollment-keys", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(listEnrollmentKeys)))). Methods(http.MethodGet) - r.HandleFunc("/api/v1/enrollment-keys/network/{network}/default", logic.SecurityCheck(true, http.HandlerFunc(getDefaultEnrollmentKeyForNetwork))). + r.HandleFunc("/api/v1/enrollment-keys/network/{network}/default", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getDefaultEnrollmentKeyForNetwork)))). Methods(http.MethodGet) - r.HandleFunc("/api/v1/enrollment-keys/{keyID}/regenerate-token", logic.SecurityCheck(true, http.HandlerFunc(regenerateEnrollmentKeyToken))). + r.HandleFunc("/api/v1/enrollment-keys/{keyID}/regenerate-token", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(regenerateEnrollmentKeyToken)))). Methods(http.MethodPost) - r.HandleFunc("/api/v1/enrollment-keys/{keyID}", logic.SecurityCheck(true, http.HandlerFunc(deleteEnrollmentKey))). + r.HandleFunc("/api/v1/enrollment-keys/{keyID}", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteEnrollmentKey)))). Methods(http.MethodDelete) r.HandleFunc("/api/v1/host/register/{token}", http.HandlerFunc(handleHostRegister)). Methods(http.MethodPost) - r.HandleFunc("/api/v1/enrollment-keys/{keyID}", logic.SecurityCheck(true, http.HandlerFunc(updateEnrollmentKey))). + r.HandleFunc("/api/v1/enrollment-keys/{keyID}", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(updateEnrollmentKey)))). Methods(http.MethodPut) } diff --git a/controllers/ext_client.go b/controllers/ext_client.go index e963a644a..a72e089f6 100644 --- a/controllers/ext_client.go +++ b/controllers/ext_client.go @@ -32,23 +32,23 @@ var extUpdateMutex = &sync.Mutex{} func extClientHandlers(r *mux.Router) { - r.HandleFunc("/api/extclients", logic.SecurityCheck(true, http.HandlerFunc(getAllExtClients))). + r.HandleFunc("/api/extclients", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getAllExtClients)))). Methods(http.MethodGet) - r.HandleFunc("/api/extclients/{network}", logic.SecurityCheck(true, http.HandlerFunc(getNetworkExtClients))). + r.HandleFunc("/api/extclients/{network}", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getNetworkExtClients)))). Methods(http.MethodGet) - r.HandleFunc("/api/extclients/{network}/{clientid}", logic.SecurityCheck(false, http.HandlerFunc(getExtClient))). + r.HandleFunc("/api/extclients/{network}/{clientid}", Scope(db.TenantScope, logic.SecurityCheck(false, http.HandlerFunc(getExtClient)))). Methods(http.MethodGet) - r.HandleFunc("/api/extclients/{network}/{clientid}/{type}", logic.SecurityCheck(false, http.HandlerFunc(getExtClientConf))). + r.HandleFunc("/api/extclients/{network}/{clientid}/{type}", Scope(db.TenantScope, logic.SecurityCheck(false, http.HandlerFunc(getExtClientConf)))). Methods(http.MethodGet) - r.HandleFunc("/api/extclients/{network}/{clientid}", logic.SecurityCheck(false, http.HandlerFunc(updateExtClient))). + r.HandleFunc("/api/extclients/{network}/{clientid}", Scope(db.TenantScope, logic.SecurityCheck(false, http.HandlerFunc(updateExtClient)))). Methods(http.MethodPut) - r.HandleFunc("/api/extclients/{network}/{clientid}", logic.SecurityCheck(false, http.HandlerFunc(deleteExtClient))). + r.HandleFunc("/api/extclients/{network}/{clientid}", Scope(db.TenantScope, logic.SecurityCheck(false, http.HandlerFunc(deleteExtClient)))). Methods(http.MethodDelete) - r.HandleFunc("/api/v1/extclients/{network}/bulk", logic.SecurityCheck(true, http.HandlerFunc(bulkDeleteExtClients))). + r.HandleFunc("/api/v1/extclients/{network}/bulk", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(bulkDeleteExtClients)))). Methods(http.MethodDelete) - r.HandleFunc("/api/v1/extclients/{network}/bulk/status", logic.SecurityCheck(true, http.HandlerFunc(bulkUpdateExtClientStatus))). + r.HandleFunc("/api/v1/extclients/{network}/bulk/status", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(bulkUpdateExtClientStatus)))). Methods(http.MethodPut) - r.HandleFunc("/api/extclients/{network}/{nodeid}", logic.SecurityCheck(false, http.HandlerFunc(createExtClient))). + r.HandleFunc("/api/extclients/{network}/{nodeid}", Scope(db.TenantScope, logic.SecurityCheck(false, http.HandlerFunc(createExtClient)))). Methods(http.MethodPost) // unused API //r.HandleFunc("/api/v1/client_conf/{network}", logic.SecurityCheck(false, http.HandlerFunc(getExtClientHAConf))).Methods(http.MethodGet) diff --git a/controllers/gateway.go b/controllers/gateway.go index 7a1ec9e32..5b5c4fdd6 100644 --- a/controllers/gateway.go +++ b/controllers/gateway.go @@ -23,13 +23,13 @@ import ( ) func gwHandlers(r *mux.Router) { - r.HandleFunc("/api/nodes/{network}/{nodeid}/gateway", logic.SecurityCheck(true, http.HandlerFunc(createGateway))).Methods(http.MethodPost) - r.HandleFunc("/api/nodes/{network}/{nodeid}/gateway", logic.SecurityCheck(true, http.HandlerFunc(deleteGateway))).Methods(http.MethodDelete) - r.HandleFunc("/api/nodes/{network}/{nodeid}/gateway/assign", logic.SecurityCheck(true, http.HandlerFunc(assignGw))).Methods(http.MethodPost) - r.HandleFunc("/api/nodes/{network}/{nodeid}/gateway/unassign", logic.SecurityCheck(true, http.HandlerFunc(unassignGw))).Methods(http.MethodPost) + r.HandleFunc("/api/nodes/{network}/{nodeid}/gateway", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(createGateway)))).Methods(http.MethodPost) + r.HandleFunc("/api/nodes/{network}/{nodeid}/gateway", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteGateway)))).Methods(http.MethodDelete) + r.HandleFunc("/api/nodes/{network}/{nodeid}/gateway/assign", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(assignGw)))).Methods(http.MethodPost) + r.HandleFunc("/api/nodes/{network}/{nodeid}/gateway/unassign", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(unassignGw)))).Methods(http.MethodPost) // old relay handlers - r.HandleFunc("/api/nodes/{network}/{nodeid}/createrelay", logic.SecurityCheck(true, http.HandlerFunc(createGateway))).Methods(http.MethodPost) - r.HandleFunc("/api/nodes/{network}/{nodeid}/deleterelay", logic.SecurityCheck(true, http.HandlerFunc(deleteGateway))).Methods(http.MethodDelete) + r.HandleFunc("/api/nodes/{network}/{nodeid}/createrelay", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(createGateway)))).Methods(http.MethodPost) + r.HandleFunc("/api/nodes/{network}/{nodeid}/deleterelay", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteGateway)))).Methods(http.MethodDelete) } // @Summary Create a gateway diff --git a/controllers/hosts.go b/controllers/hosts.go index 52cc0bffd..5e707d847 100644 --- a/controllers/hosts.go +++ b/controllers/hosts.go @@ -29,37 +29,37 @@ import ( ) func hostHandlers(r *mux.Router) { - r.HandleFunc("/api/hosts", logic.SecurityCheck(true, http.HandlerFunc(getHosts))). + r.HandleFunc("/api/hosts", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getHosts)))). Methods(http.MethodGet) - r.HandleFunc("/api/v1/hosts", logic.SecurityCheck(true, http.HandlerFunc(listHosts))). + r.HandleFunc("/api/v1/hosts", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(listHosts)))). Methods(http.MethodGet) - r.HandleFunc("/api/hosts/keys", logic.SecurityCheck(true, http.HandlerFunc(updateAllKeys))). + r.HandleFunc("/api/hosts/keys", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(updateAllKeys)))). Methods(http.MethodPut) - r.HandleFunc("/api/hosts/sync", logic.SecurityCheck(true, http.HandlerFunc(syncHosts))). + r.HandleFunc("/api/hosts/sync", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(syncHosts)))). Methods(http.MethodPost) - r.HandleFunc("/api/hosts/upgrade", logic.SecurityCheck(true, http.HandlerFunc(upgradeHosts))). + r.HandleFunc("/api/hosts/upgrade", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(upgradeHosts)))). Methods(http.MethodPost) - r.HandleFunc("/api/hosts/{hostid}/keys", logic.SecurityCheck(true, http.HandlerFunc(updateKeys))). + r.HandleFunc("/api/hosts/{hostid}/keys", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(updateKeys)))). Methods(http.MethodPut) - r.HandleFunc("/api/hosts/{hostid}/sync", logic.SecurityCheck(true, http.HandlerFunc(syncHost))). + r.HandleFunc("/api/hosts/{hostid}/sync", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(syncHost)))). Methods(http.MethodPost) - r.HandleFunc("/api/hosts/{hostid}", logic.SecurityCheck(true, http.HandlerFunc(updateHost))). + r.HandleFunc("/api/hosts/{hostid}", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(updateHost)))). Methods(http.MethodPut) - r.HandleFunc("/api/hosts/{hostid}", logic.SecurityCheck(true, http.HandlerFunc(getHost))). + r.HandleFunc("/api/hosts/{hostid}", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getHost)))). Methods(http.MethodGet) // used by netclient r.HandleFunc("/api/hosts/{hostid}", AuthorizeHost(http.HandlerFunc(deleteHost))). Methods(http.MethodDelete) // used by UI - r.HandleFunc("/api/v1/ui/hosts/{hostid}", logic.SecurityCheck(true, http.HandlerFunc(deleteHost))). + r.HandleFunc("/api/v1/ui/hosts/{hostid}", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteHost)))). Methods(http.MethodDelete) - r.HandleFunc("/api/v1/hosts/bulk", logic.SecurityCheck(true, http.HandlerFunc(bulkDeleteHosts))). + r.HandleFunc("/api/v1/hosts/bulk", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(bulkDeleteHosts)))). Methods(http.MethodDelete) - r.HandleFunc("/api/hosts/{hostid}/upgrade", logic.SecurityCheck(true, http.HandlerFunc(upgradeHost))). + r.HandleFunc("/api/hosts/{hostid}/upgrade", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(upgradeHost)))). Methods(http.MethodPut) - r.HandleFunc("/api/hosts/{hostid}/networks/{network}", logic.SecurityCheck(true, http.HandlerFunc(addHostToNetwork))). + r.HandleFunc("/api/hosts/{hostid}/networks/{network}", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(addHostToNetwork)))). Methods(http.MethodPost) - r.HandleFunc("/api/hosts/{hostid}/networks/{network}", logic.SecurityCheck(true, http.HandlerFunc(deleteHostFromNetwork))). + r.HandleFunc("/api/hosts/{hostid}/networks/{network}", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteHostFromNetwork)))). Methods(http.MethodDelete) r.HandleFunc("/api/hosts/adm/authenticate", authenticateHost).Methods(http.MethodPost) r.HandleFunc("/api/v1/host", AuthorizeHost(http.HandlerFunc(pull))). @@ -70,13 +70,13 @@ func hostHandlers(r *mux.Router) { Methods(http.MethodPut) r.HandleFunc("/api/v1/host/{hostid}/peer_info", AuthorizeHost(http.HandlerFunc(getHostPeerInfo))). Methods(http.MethodGet) - r.HandleFunc("/api/v1/pending_hosts", logic.SecurityCheck(true, http.HandlerFunc(getPendingHosts))). + r.HandleFunc("/api/v1/pending_hosts", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getPendingHosts)))). Methods(http.MethodGet) - r.HandleFunc("/api/v1/pending_hosts/approve/{id}", logic.SecurityCheck(true, http.HandlerFunc(approvePendingHost))). + r.HandleFunc("/api/v1/pending_hosts/approve/{id}", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(approvePendingHost)))). Methods(http.MethodPost) - r.HandleFunc("/api/v1/pending_hosts/reject/{id}", logic.SecurityCheck(true, http.HandlerFunc(rejectPendingHost))). + r.HandleFunc("/api/v1/pending_hosts/reject/{id}", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(rejectPendingHost)))). Methods(http.MethodPost) - r.HandleFunc("/api/emqx/hosts", logic.SecurityCheck(true, http.HandlerFunc(delEmqxHosts))). + r.HandleFunc("/api/emqx/hosts", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(delEmqxHosts)))). Methods(http.MethodDelete) r.HandleFunc("/api/v1/auth-register/host", socketHandler) } diff --git a/controllers/inet_gws.go b/controllers/inet_gws.go index dddd2002f..7ffc2c28a 100644 --- a/controllers/inet_gws.go +++ b/controllers/inet_gws.go @@ -6,6 +6,7 @@ import ( "net/http" "github.com/gorilla/mux" + "github.com/gravitl/netmaker/db" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/models" @@ -15,11 +16,11 @@ import ( ) func internetGatewayHandlers(r *mux.Router) { - r.HandleFunc("/api/nodes/{network}/{nodeid}/inet_gw", logic.SecurityCheck(true, http.HandlerFunc(createInternetGw))). + r.HandleFunc("/api/nodes/{network}/{nodeid}/inet_gw", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(createInternetGw)))). Methods(http.MethodPost) - r.HandleFunc("/api/nodes/{network}/{nodeid}/inet_gw", logic.SecurityCheck(true, http.HandlerFunc(updateInternetGw))). + r.HandleFunc("/api/nodes/{network}/{nodeid}/inet_gw", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(updateInternetGw)))). Methods(http.MethodPut) - r.HandleFunc("/api/nodes/{network}/{nodeid}/inet_gw", logic.SecurityCheck(true, http.HandlerFunc(deleteInternetGw))). + r.HandleFunc("/api/nodes/{network}/{nodeid}/inet_gw", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteInternetGw)))). Methods(http.MethodDelete) } diff --git a/controllers/network.go b/controllers/network.go index e07c60588..62b43a430 100644 --- a/controllers/network.go +++ b/controllers/network.go @@ -24,19 +24,19 @@ import ( ) func networkHandlers(r *mux.Router) { - r.HandleFunc("/api/networks", logic.SecurityCheck(true, http.HandlerFunc(getNetworks))). + r.HandleFunc("/api/networks", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getNetworks)))). Methods(http.MethodGet) - r.HandleFunc("/api/v1/networks/stats", logic.SecurityCheck(true, http.HandlerFunc(getNetworksStats))). + r.HandleFunc("/api/v1/networks/stats", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getNetworksStats)))). Methods(http.MethodGet) - r.HandleFunc("/api/networks", logic.SecurityCheck(true, http.HandlerFunc(createNetwork))). + r.HandleFunc("/api/networks", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(createNetwork)))). Methods(http.MethodPost) - r.HandleFunc("/api/networks/{networkname}", logic.SecurityCheck(true, http.HandlerFunc(getNetwork))). + r.HandleFunc("/api/networks/{networkname}", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getNetwork)))). Methods(http.MethodGet) - r.HandleFunc("/api/networks/{networkname}", logic.SecurityCheck(true, http.HandlerFunc(deleteNetwork))). + r.HandleFunc("/api/networks/{networkname}", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteNetwork)))). Methods(http.MethodDelete) - r.HandleFunc("/api/networks/{networkname}", logic.SecurityCheck(true, http.HandlerFunc(updateNetwork))). + r.HandleFunc("/api/networks/{networkname}", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(updateNetwork)))). Methods(http.MethodPut) - r.HandleFunc("/api/networks/{networkname}/egress_routes", logic.SecurityCheck(true, http.HandlerFunc(getNetworkEgressRoutes))) + r.HandleFunc("/api/networks/{networkname}/egress_routes", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getNetworkEgressRoutes)))) } // @Summary Lists all networks diff --git a/controllers/node.go b/controllers/node.go index da54cf81e..8a60959bd 100644 --- a/controllers/node.go +++ b/controllers/node.go @@ -31,20 +31,20 @@ var hostIDHeader = "host-id" func nodeHandlers(r *mux.Router) { - r.HandleFunc("/api/nodes", logic.SecurityCheck(true, http.HandlerFunc(getAllNodes))).Methods(http.MethodGet) - r.HandleFunc("/api/nodes/{network}", logic.SecurityCheck(true, http.HandlerFunc(getNetworkNodes))).Methods(http.MethodGet) - r.HandleFunc("/api/v1/nodes/{network}", logic.SecurityCheck(true, http.HandlerFunc(listNetworkNodes))).Methods(http.MethodGet) + r.HandleFunc("/api/nodes", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getAllNodes)))).Methods(http.MethodGet) + r.HandleFunc("/api/nodes/{network}", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getNetworkNodes)))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/nodes/{network}", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(listNetworkNodes)))).Methods(http.MethodGet) r.HandleFunc("/api/nodes/{network}/{nodeid}", AuthorizeHost(http.HandlerFunc(getNode))).Methods(http.MethodGet) - r.HandleFunc("/api/nodes/{network}/{nodeid}", logic.SecurityCheck(true, http.HandlerFunc(updateNode))).Methods(http.MethodPut) + r.HandleFunc("/api/nodes/{network}/{nodeid}", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(updateNode)))).Methods(http.MethodPut) r.HandleFunc("/api/nodes/{network}/{nodeid}", AuthorizeHost(http.HandlerFunc(deleteNode))).Methods(http.MethodDelete) - r.HandleFunc("/api/nodes/{network}/{nodeid}/creategateway", logic.SecurityCheck(true, http.HandlerFunc(createEgressGateway))).Methods(http.MethodPost) - r.HandleFunc("/api/nodes/{network}/{nodeid}/deletegateway", logic.SecurityCheck(true, http.HandlerFunc(deleteEgressGateway))).Methods(http.MethodDelete) - r.HandleFunc("/api/nodes/{network}/{nodeid}/createingress", logic.SecurityCheck(true, http.HandlerFunc(createGateway))).Methods(http.MethodPost) - r.HandleFunc("/api/nodes/{network}/{nodeid}/deleteingress", logic.SecurityCheck(true, http.HandlerFunc(deleteGateway))).Methods(http.MethodDelete) + r.HandleFunc("/api/nodes/{network}/{nodeid}/creategateway", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(createEgressGateway)))).Methods(http.MethodPost) + r.HandleFunc("/api/nodes/{network}/{nodeid}/deletegateway", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteEgressGateway)))).Methods(http.MethodDelete) + r.HandleFunc("/api/nodes/{network}/{nodeid}/createingress", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(createGateway)))).Methods(http.MethodPost) + r.HandleFunc("/api/nodes/{network}/{nodeid}/deleteingress", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteGateway)))).Methods(http.MethodDelete) r.HandleFunc("/api/nodes/adm/{network}/authenticate", authenticate).Methods(http.MethodPost) - r.HandleFunc("/api/v1/nodes/{network}/bulk", logic.SecurityCheck(true, http.HandlerFunc(bulkDeleteNodes))).Methods(http.MethodDelete) - r.HandleFunc("/api/v1/nodes/{network}/bulk/status", logic.SecurityCheck(true, http.HandlerFunc(bulkUpdateNodeStatus))).Methods(http.MethodPut) - r.HandleFunc("/api/v1/nodes/{network}/status", logic.SecurityCheck(true, http.HandlerFunc(getNetworkNodeStatus))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/nodes/{network}/bulk", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(bulkDeleteNodes)))).Methods(http.MethodDelete) + r.HandleFunc("/api/v1/nodes/{network}/bulk/status", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(bulkUpdateNodeStatus)))).Methods(http.MethodPut) + r.HandleFunc("/api/v1/nodes/{network}/status", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getNetworkNodeStatus)))).Methods(http.MethodGet) } func authenticate(response http.ResponseWriter, request *http.Request) { diff --git a/controllers/server.go b/controllers/server.go index bbb432231..4aeb99d9a 100644 --- a/controllers/server.go +++ b/controllers/server.go @@ -61,19 +61,19 @@ func serverHandlers(r *mux.Router) { Methods(http.MethodGet) r.HandleFunc("/api/server/settings", allowUsers(http.HandlerFunc(getSettings))). Methods(http.MethodGet) - r.HandleFunc("/api/server/settings", logic.SecurityCheck(true, http.HandlerFunc(updateSettings))). + r.HandleFunc("/api/server/settings", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(updateSettings)))). Methods(http.MethodPut) - r.HandleFunc("/api/server/getserverinfo", logic.SecurityCheck(true, http.HandlerFunc(getServerInfo))). + r.HandleFunc("/api/server/getserverinfo", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getServerInfo)))). Methods(http.MethodGet) r.HandleFunc("/api/server/status", getStatus).Methods(http.MethodGet) - r.HandleFunc("/api/server/usage", logic.SecurityCheck(false, http.HandlerFunc(getUsage))). + r.HandleFunc("/api/server/usage", Scope(db.TenantScope, logic.SecurityCheck(false, http.HandlerFunc(getUsage)))). Methods(http.MethodGet) r.HandleFunc("/api/server/cpu_profile", logic.SecurityCheck(false, http.HandlerFunc(cpuProfile))). Methods(http.MethodPost) r.HandleFunc("/api/server/mem_profile", logic.SecurityCheck(false, http.HandlerFunc(memProfile))). Methods(http.MethodPost) r.HandleFunc("/api/server/feature_flags", getFeatureFlags).Methods(http.MethodGet) - r.HandleFunc("/api/server/onboarding", logic.SecurityCheck(true, http.HandlerFunc(getOnboarding))).Methods(http.MethodGet) + r.HandleFunc("/api/server/onboarding", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getOnboarding)))).Methods(http.MethodGet) } func cpuProfile(w http.ResponseWriter, r *http.Request) { diff --git a/controllers/user.go b/controllers/user.go index 86d8e8b3d..c0ac0b9a9 100644 --- a/controllers/user.go +++ b/controllers/user.go @@ -40,31 +40,31 @@ var ListRoles = listRoles func userHandlers(r *mux.Router) { r.HandleFunc("/api/users/adm/hassuperadmin", hasSuperAdmin).Methods(http.MethodGet) r.HandleFunc("/api/users/adm/createsuperadmin", createSuperAdmin).Methods(http.MethodPost) - r.HandleFunc("/api/users/adm/transfersuperadmin/{username}", logic.SecurityCheck(true, http.HandlerFunc(transferSuperAdmin))). + r.HandleFunc("/api/users/adm/transfersuperadmin/{username}", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(transferSuperAdmin)))). Methods(http.MethodPost) r.HandleFunc("/api/users/adm/authenticate", authenticateUser).Methods(http.MethodPost) - r.HandleFunc("/api/users/{username}/validate-identity", logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(validateUserIdentity)))).Methods(http.MethodPost) - r.HandleFunc("/api/users/{username}/auth/init-totp", logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(initiateTOTPSetup)))).Methods(http.MethodPost) - r.HandleFunc("/api/users/{username}/auth/complete-totp", logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(completeTOTPSetup)))).Methods(http.MethodPost) + r.HandleFunc("/api/users/{username}/validate-identity", Scope(db.TenantScope, logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(validateUserIdentity))))).Methods(http.MethodPost) + r.HandleFunc("/api/users/{username}/auth/init-totp", Scope(db.TenantScope, logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(initiateTOTPSetup))))).Methods(http.MethodPost) + r.HandleFunc("/api/users/{username}/auth/complete-totp", Scope(db.TenantScope, logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(completeTOTPSetup))))).Methods(http.MethodPost) r.HandleFunc("/api/users/{username}/auth/verify-totp", logic.PreAuthCheck(logic.ContinueIfUserMatch(http.HandlerFunc(verifyTOTP)))).Methods(http.MethodPost) - r.HandleFunc("/api/users/{username}", logic.SecurityCheck(true, http.HandlerFunc(updateUser))).Methods(http.MethodPut) - r.HandleFunc("/api/users/{username}", logic.SecurityCheck(true, http.HandlerFunc(createUser))).Methods(http.MethodPost) - r.HandleFunc("/api/users/{username}", logic.SecurityCheck(true, http.HandlerFunc(deleteUser))).Methods(http.MethodDelete) - r.HandleFunc("/api/users/{username}", logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(getUser)))).Methods(http.MethodGet) - r.HandleFunc("/api/users/{username}/enable", logic.SecurityCheck(true, http.HandlerFunc(enableUserAccount))).Methods(http.MethodPost) - r.HandleFunc("/api/users/{username}/disable", logic.SecurityCheck(true, http.HandlerFunc(disableUserAccount))).Methods(http.MethodPost) - r.HandleFunc("/api/users/{username}/settings", logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(getUserSettings)))).Methods(http.MethodGet) - r.HandleFunc("/api/users/{username}/settings", logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(updateUserSettings)))).Methods(http.MethodPut) - r.HandleFunc("/api/v1/users", logic.SecurityCheck(false, logic.ContinueIfUserMatchOrAdmin(http.HandlerFunc(getUserV1)))).Methods(http.MethodGet) - r.HandleFunc("/api/users", logic.SecurityCheck(true, http.HandlerFunc(getUsers))).Methods(http.MethodGet) - r.HandleFunc("/api/v2/users", logic.SecurityCheck(true, http.HandlerFunc(listUsers))).Methods(http.MethodGet) - r.HandleFunc("/api/v1/users/bulk", logic.SecurityCheck(true, http.HandlerFunc(bulkDeleteUsers))).Methods(http.MethodDelete) - r.HandleFunc("/api/v1/users/bulk/status", logic.SecurityCheck(true, http.HandlerFunc(bulkUpdateUserStatus))).Methods(http.MethodPost) - r.HandleFunc("/api/v1/users/roles", logic.SecurityCheck(true, http.HandlerFunc(ListRoles))).Methods(http.MethodGet) - r.HandleFunc("/api/v1/users/access_token", logic.SecurityCheck(true, http.HandlerFunc(createUserAccessToken))).Methods(http.MethodPost) - r.HandleFunc("/api/v1/users/access_token", logic.SecurityCheck(true, http.HandlerFunc(getUserAccessTokens))).Methods(http.MethodGet) - r.HandleFunc("/api/v1/users/access_token", logic.SecurityCheck(true, http.HandlerFunc(deleteUserAccessTokens))).Methods(http.MethodDelete) - r.HandleFunc("/api/v1/users/logout", logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(logout)))).Methods(http.MethodPost) + r.HandleFunc("/api/users/{username}", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(updateUser)))).Methods(http.MethodPut) + r.HandleFunc("/api/users/{username}", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(createUser)))).Methods(http.MethodPost) + r.HandleFunc("/api/users/{username}", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteUser)))).Methods(http.MethodDelete) + r.HandleFunc("/api/users/{username}", Scope(db.TenantScope, logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(getUser))))).Methods(http.MethodGet) + r.HandleFunc("/api/users/{username}/enable", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(enableUserAccount)))).Methods(http.MethodPost) + r.HandleFunc("/api/users/{username}/disable", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(disableUserAccount)))).Methods(http.MethodPost) + r.HandleFunc("/api/users/{username}/settings", Scope(db.TenantScope, logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(getUserSettings))))).Methods(http.MethodGet) + r.HandleFunc("/api/users/{username}/settings", Scope(db.TenantScope, logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(updateUserSettings))))).Methods(http.MethodPut) + r.HandleFunc("/api/v1/users", Scope(db.TenantScope, logic.SecurityCheck(false, logic.ContinueIfUserMatchOrAdmin(http.HandlerFunc(getUserV1))))).Methods(http.MethodGet) + r.HandleFunc("/api/users", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getUsers)))).Methods(http.MethodGet) + r.HandleFunc("/api/v2/users", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(listUsers)))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/users/bulk", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(bulkDeleteUsers)))).Methods(http.MethodDelete) + r.HandleFunc("/api/v1/users/bulk/status", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(bulkUpdateUserStatus)))).Methods(http.MethodPost) + r.HandleFunc("/api/v1/users/roles", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(ListRoles)))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/users/access_token", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(createUserAccessToken)))).Methods(http.MethodPost) + r.HandleFunc("/api/v1/users/access_token", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getUserAccessTokens)))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/users/access_token", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteUserAccessTokens)))).Methods(http.MethodDelete) + r.HandleFunc("/api/v1/users/logout", Scope(db.TenantScope, logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(logout))))).Methods(http.MethodPost) } // @Summary Create a user API access token diff --git a/pro/controllers/auto_relay.go b/pro/controllers/auto_relay.go index 41407fe1a..4f0db8fc1 100644 --- a/pro/controllers/auto_relay.go +++ b/pro/controllers/auto_relay.go @@ -24,11 +24,11 @@ import ( func AutoRelayHandlers(r *mux.Router) { r.HandleFunc("/api/v1/node/{nodeid}/auto_relay", controller.AuthorizeHost(http.HandlerFunc(getAutoRelayGws))). Methods(http.MethodGet) - r.HandleFunc("/api/v1/node/{nodeid}/auto_relay", logic.SecurityCheck(true, http.HandlerFunc(setAutoRelay))). + r.HandleFunc("/api/v1/node/{nodeid}/auto_relay", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(setAutoRelay)))). Methods(http.MethodPost) - r.HandleFunc("/api/v1/node/{nodeid}/auto_relay", logic.SecurityCheck(true, http.HandlerFunc(unsetAutoRelay))). + r.HandleFunc("/api/v1/node/{nodeid}/auto_relay", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(unsetAutoRelay)))). Methods(http.MethodDelete) - r.HandleFunc("/api/v1/node/{network}/auto_relay/reset", logic.SecurityCheck(true, http.HandlerFunc(resetAutoRelayGw))). + r.HandleFunc("/api/v1/node/{network}/auto_relay/reset", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(resetAutoRelayGw)))). Methods(http.MethodPost) r.HandleFunc("/api/v1/node/{nodeid}/auto_relay_me", controller.AuthorizeHost(http.HandlerFunc(autoRelayME))). Methods(http.MethodPost) diff --git a/pro/controllers/events.go b/pro/controllers/events.go index 954ab69b3..81ac3e1f7 100644 --- a/pro/controllers/events.go +++ b/pro/controllers/events.go @@ -6,6 +6,7 @@ import ( "time" "github.com/gorilla/mux" + controller "github.com/gravitl/netmaker/controllers" "github.com/gravitl/netmaker/db" "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/models" @@ -13,9 +14,9 @@ import ( ) func EventHandlers(r *mux.Router) { - r.HandleFunc("/api/v1/network/activity", logic.SecurityCheck(true, http.HandlerFunc(listNetworkActivity))).Methods(http.MethodGet) - r.HandleFunc("/api/v1/user/activity", logic.SecurityCheck(false, http.HandlerFunc(listUserActivity))).Methods(http.MethodGet) - r.HandleFunc("/api/v1/activity", logic.SecurityCheck(true, http.HandlerFunc(listActivity))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/network/activity", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(listNetworkActivity)))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/user/activity", controller.Scope(db.TenantScope, logic.SecurityCheck(false, http.HandlerFunc(listUserActivity)))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/activity", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(listActivity)))).Methods(http.MethodGet) } // @Summary List network activity diff --git a/pro/controllers/flows.go b/pro/controllers/flows.go index da75f993b..aa3592539 100644 --- a/pro/controllers/flows.go +++ b/pro/controllers/flows.go @@ -10,13 +10,15 @@ import ( "github.com/gorilla/mux" ch "github.com/gravitl/netmaker/clickhouse" + controller "github.com/gravitl/netmaker/controllers" "github.com/gravitl/netmaker/database" + "github.com/gravitl/netmaker/db" "github.com/gravitl/netmaker/logic" proLogic "github.com/gravitl/netmaker/pro/logic" ) func FlowHandlers(r *mux.Router) { - r.HandleFunc("/api/v1/flows", logic.SecurityCheck(true, http.HandlerFunc(handleListFlows))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/flows", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(handleListFlows)))).Methods(http.MethodGet) } const ( diff --git a/pro/controllers/integrations.go b/pro/controllers/integrations.go index 8ba3667a7..5f60b33f5 100644 --- a/pro/controllers/integrations.go +++ b/pro/controllers/integrations.go @@ -8,6 +8,8 @@ import ( "net/http" "github.com/gorilla/mux" + controller "github.com/gravitl/netmaker/controllers" + "github.com/gravitl/netmaker/db" "github.com/gravitl/netmaker/grpc/siem" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" @@ -21,10 +23,10 @@ import ( ) func IntegrationHandlers(r *mux.Router) { - r.HandleFunc("/api/v1/integrations/{type}", logic.SecurityCheck(true, http.HandlerFunc(getIntegration))).Methods(http.MethodGet) - r.HandleFunc("/api/v1/integrations/{type}/{id}", logic.SecurityCheck(true, http.HandlerFunc(upsertIntegration))).Methods(http.MethodPut) - r.HandleFunc("/api/v1/integrations/{type}/{id}", logic.SecurityCheck(true, http.HandlerFunc(deleteIntegration))).Methods(http.MethodDelete) - r.HandleFunc("/api/v1/integrations/{type}/{id}/test", logic.SecurityCheck(true, http.HandlerFunc(testIntegration))).Methods(http.MethodPost) + r.HandleFunc("/api/v1/integrations/{type}", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getIntegration)))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/integrations/{type}/{id}", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(upsertIntegration)))).Methods(http.MethodPut) + r.HandleFunc("/api/v1/integrations/{type}/{id}", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteIntegration)))).Methods(http.MethodDelete) + r.HandleFunc("/api/v1/integrations/{type}/{id}/test", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(testIntegration)))).Methods(http.MethodPost) } // extractAndValidateIntegration pulls {type} and {id} from the URL diff --git a/pro/controllers/jit.go b/pro/controllers/jit.go index c4e753cc0..d660eaf96 100644 --- a/pro/controllers/jit.go +++ b/pro/controllers/jit.go @@ -9,6 +9,7 @@ import ( "time" "github.com/gorilla/mux" + controller "github.com/gravitl/netmaker/controllers" "github.com/gravitl/netmaker/db" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" @@ -20,17 +21,17 @@ import ( ) func JITHandlers(r *mux.Router) { - r.HandleFunc("/api/v1/jit", logic.SecurityCheck(true, - http.HandlerFunc(handleJIT))).Methods(http.MethodPost, http.MethodGet) + r.HandleFunc("/api/v1/jit", controller.Scope(db.TenantScope, logic.SecurityCheck(true, + http.HandlerFunc(handleJIT)))).Methods(http.MethodPost, http.MethodGet) - r.HandleFunc("/api/v1/jit", logic.SecurityCheck(true, - http.HandlerFunc(deleteJITGrant))).Methods(http.MethodDelete) + r.HandleFunc("/api/v1/jit", controller.Scope(db.TenantScope, logic.SecurityCheck(true, + http.HandlerFunc(deleteJITGrant)))).Methods(http.MethodDelete) - r.HandleFunc("/api/v1/jit_user/networks", logic.SecurityCheck(false, - http.HandlerFunc(getUserJITNetworks))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/jit_user/networks", controller.Scope(db.TenantScope, logic.SecurityCheck(false, + http.HandlerFunc(getUserJITNetworks)))).Methods(http.MethodGet) - r.HandleFunc("/api/v1/jit_user/request", logic.SecurityCheck(false, - http.HandlerFunc(requestJITAccess))).Methods(http.MethodPost) + r.HandleFunc("/api/v1/jit_user/request", controller.Scope(db.TenantScope, logic.SecurityCheck(false, + http.HandlerFunc(requestJITAccess)))).Methods(http.MethodPost) } // @Summary List JIT requests for a network diff --git a/pro/controllers/metrics.go b/pro/controllers/metrics.go index 6323f6eb2..2fd4dc710 100644 --- a/pro/controllers/metrics.go +++ b/pro/controllers/metrics.go @@ -8,7 +8,9 @@ import ( "golang.org/x/exp/slog" "github.com/gorilla/mux" + controller "github.com/gravitl/netmaker/controllers" "github.com/gravitl/netmaker/database" + "github.com/gravitl/netmaker/db" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/models" @@ -16,10 +18,10 @@ import ( // MetricHandlers - How we handle Pro Metrics func MetricHandlers(r *mux.Router) { - r.HandleFunc("/api/metrics/{network}/{nodeid}", logic.SecurityCheck(true, http.HandlerFunc(getNodeMetrics))).Methods(http.MethodGet) - r.HandleFunc("/api/metrics/{network}", logic.SecurityCheck(true, http.HandlerFunc(getNetworkNodesMetrics))).Methods(http.MethodGet) - r.HandleFunc("/api/metrics", logic.SecurityCheck(true, http.HandlerFunc(getAllMetrics))).Methods(http.MethodGet) - r.HandleFunc("/api/metrics-ext/{network}", logic.SecurityCheck(true, http.HandlerFunc(getNetworkExtMetrics))).Methods(http.MethodGet) + r.HandleFunc("/api/metrics/{network}/{nodeid}", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getNodeMetrics)))).Methods(http.MethodGet) + r.HandleFunc("/api/metrics/{network}", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getNetworkNodesMetrics)))).Methods(http.MethodGet) + r.HandleFunc("/api/metrics", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getAllMetrics)))).Methods(http.MethodGet) + r.HandleFunc("/api/metrics-ext/{network}", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getNetworkExtMetrics)))).Methods(http.MethodGet) } // @Summary Get metrics for a specific node diff --git a/pro/controllers/networks.go b/pro/controllers/networks.go index 05c9a077d..14801fa1e 100644 --- a/pro/controllers/networks.go +++ b/pro/controllers/networks.go @@ -2,14 +2,17 @@ package controllers import ( "encoding/json" + "net/http" + "github.com/gorilla/mux" + controller "github.com/gravitl/netmaker/controllers" + "github.com/gravitl/netmaker/db" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" - "net/http" ) func NetworkHandlers(r *mux.Router) { - r.HandleFunc("/api/v1/networks/{network}/graph", logic.SecurityCheck(true, http.HandlerFunc(getNetworkGraph))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/networks/{network}/graph", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getNetworkGraph)))).Methods(http.MethodGet) } // @Summary Get network topology graph diff --git a/pro/controllers/posture_check.go b/pro/controllers/posture_check.go index bba0c3330..73285caa9 100644 --- a/pro/controllers/posture_check.go +++ b/pro/controllers/posture_check.go @@ -10,6 +10,7 @@ import ( "github.com/google/uuid" "github.com/gorilla/mux" + controller "github.com/gravitl/netmaker/controllers" "github.com/gravitl/netmaker/db" dbtypes "github.com/gravitl/netmaker/db/types" "github.com/gravitl/netmaker/logger" @@ -22,12 +23,12 @@ import ( ) func PostureCheckHandlers(r *mux.Router) { - r.HandleFunc("/api/v1/posture_check", logic.SecurityCheck(true, http.HandlerFunc(createPostureCheck))).Methods(http.MethodPost) - r.HandleFunc("/api/v1/posture_check", logic.SecurityCheck(true, http.HandlerFunc(listPostureChecks))).Methods(http.MethodGet) - r.HandleFunc("/api/v1/posture_check", logic.SecurityCheck(true, http.HandlerFunc(updatePostureCheck))).Methods(http.MethodPut) - r.HandleFunc("/api/v1/posture_check", logic.SecurityCheck(true, http.HandlerFunc(deletePostureCheck))).Methods(http.MethodDelete) - r.HandleFunc("/api/v1/posture_check/attrs", logic.SecurityCheck(true, http.HandlerFunc(listPostureChecksAttrs))).Methods(http.MethodGet) - r.HandleFunc("/api/v1/posture_check/violations", logic.SecurityCheck(true, http.HandlerFunc(listPostureCheckViolatedNodes))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/posture_check", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(createPostureCheck)))).Methods(http.MethodPost) + r.HandleFunc("/api/v1/posture_check", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(listPostureChecks)))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/posture_check", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(updatePostureCheck)))).Methods(http.MethodPut) + r.HandleFunc("/api/v1/posture_check", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deletePostureCheck)))).Methods(http.MethodDelete) + r.HandleFunc("/api/v1/posture_check/attrs", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(listPostureChecksAttrs)))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/posture_check/violations", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(listPostureCheckViolatedNodes)))).Methods(http.MethodGet) } // @Summary List Posture Checks Available Attributes diff --git a/pro/controllers/rac.go b/pro/controllers/rac.go index 0d1b127fd..8f253188c 100644 --- a/pro/controllers/rac.go +++ b/pro/controllers/rac.go @@ -4,11 +4,13 @@ import ( "net/http" "github.com/gorilla/mux" + controller "github.com/gravitl/netmaker/controllers" + "github.com/gravitl/netmaker/db" "github.com/gravitl/netmaker/logic" ) func RacHandlers(r *mux.Router) { - r.HandleFunc("/api/v1/rac/networks", logic.SecurityCheck(false, http.HandlerFunc(getUserRemoteAccessNetworks))).Methods(http.MethodGet) - r.HandleFunc("/api/v1/rac/network/{network}/access_points", logic.SecurityCheck(false, http.HandlerFunc(getUserRemoteAccessNetworkGateways))).Methods(http.MethodGet) - r.HandleFunc("/api/v1/rac/access_point/{access_point_id}/config", logic.SecurityCheck(false, http.HandlerFunc(getRemoteAccessGatewayConf))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/rac/networks", controller.Scope(db.TenantScope, logic.SecurityCheck(false, http.HandlerFunc(getUserRemoteAccessNetworks)))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/rac/network/{network}/access_points", controller.Scope(db.TenantScope, logic.SecurityCheck(false, http.HandlerFunc(getUserRemoteAccessNetworkGateways)))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/rac/access_point/{access_point_id}/config", controller.Scope(db.TenantScope, logic.SecurityCheck(false, http.HandlerFunc(getRemoteAccessGatewayConf)))).Methods(http.MethodGet) } diff --git a/pro/controllers/tags.go b/pro/controllers/tags.go index fe8371cfa..ab010e605 100644 --- a/pro/controllers/tags.go +++ b/pro/controllers/tags.go @@ -11,6 +11,7 @@ import ( "time" "github.com/gorilla/mux" + controller "github.com/gravitl/netmaker/controllers" "github.com/gravitl/netmaker/db" dbtypes "github.com/gravitl/netmaker/db/types" "github.com/gravitl/netmaker/logger" @@ -22,13 +23,13 @@ import ( ) func TagHandlers(r *mux.Router) { - r.HandleFunc("/api/v1/tags", logic.SecurityCheck(true, http.HandlerFunc(getTags))). + r.HandleFunc("/api/v1/tags", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getTags)))). Methods(http.MethodGet) - r.HandleFunc("/api/v1/tags", logic.SecurityCheck(true, http.HandlerFunc(createTag))). + r.HandleFunc("/api/v1/tags", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(createTag)))). Methods(http.MethodPost) - r.HandleFunc("/api/v1/tags", logic.SecurityCheck(true, http.HandlerFunc(updateTag))). + r.HandleFunc("/api/v1/tags", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(updateTag)))). Methods(http.MethodPut) - r.HandleFunc("/api/v1/tags", logic.SecurityCheck(true, http.HandlerFunc(deleteTag))). + r.HandleFunc("/api/v1/tags", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteTag)))). Methods(http.MethodDelete) } diff --git a/pro/controllers/users.go b/pro/controllers/users.go index bf4326b3b..42e056ebd 100644 --- a/pro/controllers/users.go +++ b/pro/controllers/users.go @@ -12,6 +12,7 @@ import ( "time" "github.com/gorilla/mux" + controller "github.com/gravitl/netmaker/controllers" "github.com/gravitl/netmaker/db" dbtypes "github.com/gravitl/netmaker/db/types" "github.com/gravitl/netmaker/logger" @@ -43,47 +44,47 @@ func UserHandlers(r *mux.Router) { r.HandleFunc("/api/oauth/register/{regKey}", proAuth.RegisterHostSSO).Methods(http.MethodGet) // User Role Handlers - r.HandleFunc("/api/v1/users/role", logic.SecurityCheck(true, http.HandlerFunc(getRole))).Methods(http.MethodGet) - r.HandleFunc("/api/v1/users/role", logic.SecurityCheck(true, http.HandlerFunc(createRole))).Methods(http.MethodPost) - r.HandleFunc("/api/v1/users/role", logic.SecurityCheck(true, http.HandlerFunc(updateRole))).Methods(http.MethodPut) - r.HandleFunc("/api/v1/users/role", logic.SecurityCheck(true, http.HandlerFunc(deleteRole))).Methods(http.MethodDelete) + r.HandleFunc("/api/v1/users/role", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getRole)))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/users/role", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(createRole)))).Methods(http.MethodPost) + r.HandleFunc("/api/v1/users/role", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(updateRole)))).Methods(http.MethodPut) + r.HandleFunc("/api/v1/users/role", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteRole)))).Methods(http.MethodDelete) // User Group Handlers - r.HandleFunc("/api/v1/users/groups", logic.SecurityCheck(true, http.HandlerFunc(getUserGroups))).Methods(http.MethodGet) - r.HandleFunc("/api/v2/users/groups", logic.SecurityCheck(true, http.HandlerFunc(listUserGroups))).Methods(http.MethodGet) - r.HandleFunc("/api/v1/users/group", logic.SecurityCheck(true, http.HandlerFunc(getUserGroup))).Methods(http.MethodGet) - r.HandleFunc("/api/v1/users/group", logic.SecurityCheck(true, http.HandlerFunc(createUserGroup))).Methods(http.MethodPost) - r.HandleFunc("/api/v1/users/group", logic.SecurityCheck(true, http.HandlerFunc(updateUserGroup))).Methods(http.MethodPut) - r.HandleFunc("/api/v1/users/group", logic.SecurityCheck(true, http.HandlerFunc(deleteUserGroup))).Methods(http.MethodDelete) - r.HandleFunc("/api/v1/users/groups/network", logic.SecurityCheck(true, http.HandlerFunc(listNetworkUserGroups))).Methods(http.MethodGet) - r.HandleFunc("/api/v1/users/network", logic.SecurityCheck(true, http.HandlerFunc(listNetworkUsers))).Methods(http.MethodGet) - r.HandleFunc("/api/v1/users/add_network_user", logic.SecurityCheck(true, http.HandlerFunc(addUsertoNetwork))).Methods(http.MethodPut) - r.HandleFunc("/api/v1/users/remove_network_user", logic.SecurityCheck(true, http.HandlerFunc(removeUserfromNetwork))).Methods(http.MethodPut) - r.HandleFunc("/api/v1/users/unassigned_network_users", logic.SecurityCheck(true, http.HandlerFunc(listUnAssignedNetUsers))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/users/groups", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getUserGroups)))).Methods(http.MethodGet) + r.HandleFunc("/api/v2/users/groups", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(listUserGroups)))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/users/group", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getUserGroup)))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/users/group", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(createUserGroup)))).Methods(http.MethodPost) + r.HandleFunc("/api/v1/users/group", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(updateUserGroup)))).Methods(http.MethodPut) + r.HandleFunc("/api/v1/users/group", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteUserGroup)))).Methods(http.MethodDelete) + r.HandleFunc("/api/v1/users/groups/network", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(listNetworkUserGroups)))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/users/network", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(listNetworkUsers)))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/users/add_network_user", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(addUsertoNetwork)))).Methods(http.MethodPut) + r.HandleFunc("/api/v1/users/remove_network_user", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(removeUserfromNetwork)))).Methods(http.MethodPut) + r.HandleFunc("/api/v1/users/unassigned_network_users", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(listUnAssignedNetUsers)))).Methods(http.MethodGet) // User Invite Handlers r.HandleFunc("/api/v1/users/invite", userInviteVerify).Methods(http.MethodGet) r.HandleFunc("/api/v1/users/invite-signup", userInviteSignUp).Methods(http.MethodPost) - r.HandleFunc("/api/v1/users/invite", logic.SecurityCheck(true, http.HandlerFunc(inviteUsers))).Methods(http.MethodPost) - r.HandleFunc("/api/v1/users/invites", logic.SecurityCheck(true, http.HandlerFunc(listUserInvites))).Methods(http.MethodGet) - r.HandleFunc("/api/v1/users/invite", logic.SecurityCheck(true, http.HandlerFunc(deleteUserInvite))).Methods(http.MethodDelete) - r.HandleFunc("/api/v1/users/invites", logic.SecurityCheck(true, http.HandlerFunc(deleteAllUserInvites))).Methods(http.MethodDelete) - - r.HandleFunc("/api/users_pending", logic.SecurityCheck(true, http.HandlerFunc(getPendingUsers))).Methods(http.MethodGet) - r.HandleFunc("/api/users_pending", logic.SecurityCheck(true, http.HandlerFunc(deleteAllPendingUsers))).Methods(http.MethodDelete) - r.HandleFunc("/api/users_pending/user/{username}", logic.SecurityCheck(true, http.HandlerFunc(deletePendingUser))).Methods(http.MethodDelete) - r.HandleFunc("/api/users_pending/user/{username}", logic.SecurityCheck(true, http.HandlerFunc(approvePendingUser))).Methods(http.MethodPost) - - r.HandleFunc("/api/users/{username}/remote_access_gw/{remote_access_gateway_id}", logic.SecurityCheck(true, http.HandlerFunc(attachUserToRemoteAccessGw))).Methods(http.MethodPost) - r.HandleFunc("/api/users/{username}/remote_access_gw/{remote_access_gateway_id}", logic.SecurityCheck(true, http.HandlerFunc(removeUserFromRemoteAccessGW))).Methods(http.MethodDelete) - r.HandleFunc("/api/users/{username}/remote_access_gw", logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(getUserRemoteAccessGwsV1)))).Methods(http.MethodGet) - r.HandleFunc("/api/users/ingress/{ingress_id}", logic.SecurityCheck(true, http.HandlerFunc(ingressGatewayUsers))).Methods(http.MethodGet) - r.HandleFunc("/api/v1/users/network_ip", logic.SecurityCheck(true, http.HandlerFunc(userNetworkMapping))).Methods(http.MethodGet) - - r.HandleFunc("/api/idp/sync", logic.SecurityCheck(true, http.HandlerFunc(syncIDP))).Methods(http.MethodPost) - r.HandleFunc("/api/idp/sync/test", logic.SecurityCheck(true, http.HandlerFunc(testIDPSync))).Methods(http.MethodPost) - r.HandleFunc("/api/idp/sync/status", logic.SecurityCheck(true, http.HandlerFunc(getIDPSyncStatus))).Methods(http.MethodGet) - r.HandleFunc("/api/idp", logic.SecurityCheck(true, http.HandlerFunc(removeIDPIntegration))).Methods(http.MethodDelete) + r.HandleFunc("/api/v1/users/invite", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(inviteUsers)))).Methods(http.MethodPost) + r.HandleFunc("/api/v1/users/invites", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(listUserInvites)))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/users/invite", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteUserInvite)))).Methods(http.MethodDelete) + r.HandleFunc("/api/v1/users/invites", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteAllUserInvites)))).Methods(http.MethodDelete) + + r.HandleFunc("/api/users_pending", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getPendingUsers)))).Methods(http.MethodGet) + r.HandleFunc("/api/users_pending", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteAllPendingUsers)))).Methods(http.MethodDelete) + r.HandleFunc("/api/users_pending/user/{username}", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deletePendingUser)))).Methods(http.MethodDelete) + r.HandleFunc("/api/users_pending/user/{username}", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(approvePendingUser)))).Methods(http.MethodPost) + + r.HandleFunc("/api/users/{username}/remote_access_gw/{remote_access_gateway_id}", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(attachUserToRemoteAccessGw)))).Methods(http.MethodPost) + r.HandleFunc("/api/users/{username}/remote_access_gw/{remote_access_gateway_id}", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(removeUserFromRemoteAccessGW)))).Methods(http.MethodDelete) + r.HandleFunc("/api/users/{username}/remote_access_gw", controller.Scope(db.TenantScope, logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(getUserRemoteAccessGwsV1))))).Methods(http.MethodGet) + r.HandleFunc("/api/users/ingress/{ingress_id}", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(ingressGatewayUsers)))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/users/network_ip", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(userNetworkMapping)))).Methods(http.MethodGet) + + r.HandleFunc("/api/idp/sync", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(syncIDP)))).Methods(http.MethodPost) + r.HandleFunc("/api/idp/sync/test", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(testIDPSync)))).Methods(http.MethodPost) + r.HandleFunc("/api/idp/sync/status", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getIDPSyncStatus)))).Methods(http.MethodGet) + r.HandleFunc("/api/idp", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(removeIDPIntegration)))).Methods(http.MethodDelete) } // @Summary User signup via invite diff --git a/schema/integrations.go b/schema/integrations.go index 502257156..1d82516a6 100644 --- a/schema/integrations.go +++ b/schema/integrations.go @@ -9,7 +9,8 @@ import ( ) type Integration struct { - ID string `gorm:"primaryKey;column:id" json:"id"` + ID string `gorm:"primaryKey;column:id" json:"id"` + TenantID string `gorm:"default:'';index" json:"tenant_id"` Type string `gorm:"not null;column:type" json:"type"` Config datatypes.JSON `gorm:"not null;column:config" json:"config"` CreatedAt time.Time `json:"created_at"` From 231f879904ce4d1b2fef3216ca2735f4704e3233 Mon Sep 17 00:00:00 2001 From: VishalDalwadi Date: Thu, 18 Jun 2026 14:32:34 +0530 Subject: [PATCH 04/21] feat(go): add tenant id columns to all tables in schema pkg; --- schema/dns.go | 7 ++++--- schema/egress.go | 31 +++++++++++++++--------------- schema/enrollment_keys.go | 3 ++- schema/event.go | 5 +++-- schema/hosts.go | 3 ++- schema/jit_grant.go | 5 +++-- schema/jit_request.go | 5 +++-- schema/job.go | 1 + schema/models.go | 2 ++ schema/networks.go | 5 +++-- schema/nodes.go | 5 +++-- schema/pending_hosts.go | 5 +++-- schema/pending_users.go | 5 +++-- schema/posture_check.go | 7 ++++--- schema/posture_check_violations.go | 5 +++-- schema/user_access_token.go | 3 ++- schema/user_groups.go | 3 ++- schema/user_invites.go | 5 +++-- 18 files changed, 62 insertions(+), 43 deletions(-) diff --git a/schema/dns.go b/schema/dns.go index b70644e41..b637d8d6d 100644 --- a/schema/dns.go +++ b/schema/dns.go @@ -10,9 +10,10 @@ import ( ) type Nameserver struct { - ID string `gorm:"primaryKey" json:"id"` - Name string `gorm:"name" json:"name"` - NetworkID string `gorm:"network_id" json:"network_id"` + ID string `gorm:"primaryKey" json:"id"` + TenantID string `gorm:"default:'';index" json:"tenant_id"` + Name string `gorm:"name" json:"name"` + NetworkID string `gorm:"network_id" json:"network_id"` Description string `gorm:"description" json:"description"` Default bool `gorm:"column:default" json:"default"` Fallback bool `gorm:"fallback" json:"fallback"` diff --git a/schema/egress.go b/schema/egress.go index 13794ee41..e0b350f49 100644 --- a/schema/egress.go +++ b/schema/egress.go @@ -20,27 +20,28 @@ const ( ) type Egress struct { - ID string `gorm:"primaryKey" json:"id"` - Name string `gorm:"name" json:"name"` - Network string `gorm:"network" json:"network"` - Description string `gorm:"description" json:"description"` - Nodes datatypes.JSONMap `gorm:"nodes" json:"nodes"` - Tags datatypes.JSONMap `gorm:"tags" json:"tags"` - Range string `gorm:"range" json:"range"` - Mode EgressNATMode `gorm:"mode;default:direct_nat" json:"mode"` - VirtualRange string `gorm:"virtual_range" json:"virtual_range"` + ID string `gorm:"primaryKey" json:"id"` + TenantID string `gorm:"default:'';index" json:"tenant_id"` + Name string `gorm:"name" json:"name"` + Network string `gorm:"network" json:"network"` + Description string `gorm:"description" json:"description"` + Nodes datatypes.JSONMap `gorm:"nodes" json:"nodes"` + Tags datatypes.JSONMap `gorm:"tags" json:"tags"` + Range string `gorm:"range" json:"range"` + Mode EgressNATMode `gorm:"mode;default:direct_nat" json:"mode"` + VirtualRange string `gorm:"virtual_range" json:"virtual_range"` // Domains is the user-configured hostname list (exact or *.suffix). Domains datatypes.JSONSlice[string] `gorm:"domains" json:"domains"` // DomainAnsByDomain maps each configured domain to its resolved CIDRs. DomainAnsByDomain datatypes.JSONMap `gorm:"domain_ans_by_domain" json:"domain_ans_by_domain"` - Nat bool `gorm:"nat" json:"nat"` + Nat bool `gorm:"nat" json:"nat"` //IsInetGw bool `gorm:"is_inet_gw" json:"is_internet_gateway"` // PresetID is the catalog id when this egress was created from a preset (empty if custom). - PresetID string `gorm:"preset_id" json:"preset_id"` - Status bool `gorm:"status" json:"status"` - CreatedBy string `gorm:"created_by" json:"created_by"` - CreatedAt time.Time `gorm:"created_at" json:"created_at"` - UpdatedAt time.Time `gorm:"updated_at" json:"updated_at"` + PresetID string `gorm:"preset_id" json:"preset_id"` + Status bool `gorm:"status" json:"status"` + CreatedBy string `gorm:"created_by" json:"created_by"` + CreatedAt time.Time `gorm:"created_at" json:"created_at"` + UpdatedAt time.Time `gorm:"updated_at" json:"updated_at"` } func (e *Egress) Table() string { diff --git a/schema/enrollment_keys.go b/schema/enrollment_keys.go index 8065dc317..8a7a6f683 100644 --- a/schema/enrollment_keys.go +++ b/schema/enrollment_keys.go @@ -19,7 +19,8 @@ const ( ) type EnrollmentKey struct { - ID string `gorm:"primaryKey" json:"id"` + ID string `gorm:"primaryKey" json:"id"` + TenantID string `gorm:"default:'';index" json:"tenant_id"` Name string `json:"name"` Value string `json:"value"` Token string `json:"token"` diff --git a/schema/event.go b/schema/event.go index 80e210d82..7af60d37a 100644 --- a/schema/event.go +++ b/schema/event.go @@ -95,8 +95,9 @@ const ( ) type Event struct { - ID string `gorm:"primaryKey" json:"id"` - Action Action `gorm:"action" json:"action"` + ID string `gorm:"primaryKey" json:"id"` + TenantID string `gorm:"default:'';index" json:"tenant_id"` + Action Action `gorm:"action" json:"action"` Source datatypes.JSON `gorm:"source" json:"source"` Origin Origin `gorm:"origin" json:"origin"` Target datatypes.JSON `gorm:"target" json:"target"` diff --git a/schema/hosts.go b/schema/hosts.go index 3024c58cd..1ac194d4d 100644 --- a/schema/hosts.go +++ b/schema/hosts.go @@ -119,7 +119,8 @@ func (a *AddrPort) UnmarshalJSON(data []byte) error { } type Host struct { - ID uuid.UUID `gorm:"primaryKey" json:"id" yaml:"id"` + ID uuid.UUID `gorm:"primaryKey" json:"id" yaml:"id"` + TenantID string `gorm:"default:'';index" json:"tenant_id"` Verbosity int `json:"verbosity" yaml:"verbosity"` FirewallInUse string `json:"firewallinuse" yaml:"firewallinuse"` Version string `json:"version" yaml:"version"` diff --git a/schema/jit_grant.go b/schema/jit_grant.go index 08f2e1678..09dd6411c 100644 --- a/schema/jit_grant.go +++ b/schema/jit_grant.go @@ -10,8 +10,9 @@ import ( const jitGrantTable = "jit_grants" type JITGrant struct { - ID string `gorm:"primaryKey" json:"id"` - NetworkID string `gorm:"network_id" json:"network_id"` + ID string `gorm:"primaryKey" json:"id"` + TenantID string `gorm:"default:'';index" json:"tenant_id"` + NetworkID string `gorm:"network_id" json:"network_id"` UserID string `gorm:"user_id" json:"user_id"` RequestID string `gorm:"request_id" json:"request_id"` GrantedAt time.Time `gorm:"granted_at" json:"granted_at"` diff --git a/schema/jit_request.go b/schema/jit_request.go index d3e88e828..3467a483b 100644 --- a/schema/jit_request.go +++ b/schema/jit_request.go @@ -10,8 +10,9 @@ import ( const jitRequestTable = "jit_requests" type JITRequest struct { - ID string `gorm:"primaryKey" json:"id"` - NetworkID string `gorm:"network_id" json:"network_id"` + ID string `gorm:"primaryKey" json:"id"` + TenantID string `gorm:"default:'';index" json:"tenant_id"` + NetworkID string `gorm:"network_id" json:"network_id"` UserID string `gorm:"user_id" json:"user_id"` UserName string `gorm:"user_name" json:"user_name"` Reason string `gorm:"reason" json:"reason"` diff --git a/schema/job.go b/schema/job.go index 1ba3badae..e1bceb093 100644 --- a/schema/job.go +++ b/schema/job.go @@ -18,6 +18,7 @@ import ( // being executed again. type Job struct { ID string `gorm:"primaryKey"` + TenantID string `gorm:"default:'';index"` CreatedAt time.Time } diff --git a/schema/models.go b/schema/models.go index d73d6ad0c..85d525915 100644 --- a/schema/models.go +++ b/schema/models.go @@ -3,6 +3,8 @@ package schema // ListModels lists all the models in this schema. func ListModels() []interface{} { return []interface{}{ + &Organization{}, + &Tenant{}, &Job{}, &Egress{}, &UserAccessToken{}, diff --git a/schema/networks.go b/schema/networks.go index 5cc3bb579..9e01c8824 100644 --- a/schema/networks.go +++ b/schema/networks.go @@ -26,8 +26,9 @@ var ( // // NOTE: json tags are different from field names to ensure compatibility with the older model. type Network struct { - ID string `gorm:"primaryKey" json:"id"` - Name string `gorm:"unique" json:"netid"` + ID string `gorm:"primaryKey" json:"id"` + TenantID string `gorm:"default:'';index" json:"tenant_id"` + Name string `gorm:"unique" json:"netid"` AddressRange string `json:"addressrange"` AddressRange6 string `json:"addressrange6"` // in seconds. diff --git a/schema/nodes.go b/schema/nodes.go index 0a5a03022..f408ba665 100644 --- a/schema/nodes.go +++ b/schema/nodes.go @@ -36,8 +36,9 @@ const ( ) type Node struct { - ID string `gorm:"primaryKey" json:"id"` - HostID string `gorm:"not null;index" json:"host_id"` + ID string `gorm:"primaryKey" json:"id"` + TenantID string `gorm:"default:'';index" json:"tenant_id"` + HostID string `gorm:"not null;index" json:"host_id"` Host *Host `gorm:"foreignKey:HostID;constraint:OnDelete:CASCADE" json:"host,omitempty"` NetworkID string `gorm:"not null;index" json:"network_id"` Network *Network `gorm:"foreignKey:NetworkID;constraint:OnDelete:CASCADE" json:"network,omitempty"` diff --git a/schema/pending_hosts.go b/schema/pending_hosts.go index 5c11bba87..c58fa0d0a 100644 --- a/schema/pending_hosts.go +++ b/schema/pending_hosts.go @@ -9,8 +9,9 @@ import ( ) type PendingHost struct { - ID string `gorm:"id" json:"id"` - HostID string `gorm:"host_id" json:"host_id"` + ID string `gorm:"id" json:"id"` + TenantID string `gorm:"default:'';index" json:"tenant_id"` + HostID string `gorm:"host_id" json:"host_id"` Hostname string `gorm:"host_name" json:"host_name"` Network string `gorm:"network" json:"network"` PublicKey string `gorm:"public_key" json:"public_key"` diff --git a/schema/pending_users.go b/schema/pending_users.go index 6cf0c8438..6ced53f15 100644 --- a/schema/pending_users.go +++ b/schema/pending_users.go @@ -15,8 +15,9 @@ var ( ) type PendingUser struct { - ID string `gorm:"primaryKey" json:"id"` - Username string `gorm:"unique" json:"username"` + ID string `gorm:"primaryKey" json:"id"` + TenantID string `gorm:"default:'';index" json:"tenant_id"` + Username string `gorm:"unique" json:"username"` ExternalIdentityProviderID string `json:"external_identity_provider_id"` CreatedAt time.Time `json:"created_at"` } diff --git a/schema/posture_check.go b/schema/posture_check.go index dd4479276..838c82260 100644 --- a/schema/posture_check.go +++ b/schema/posture_check.go @@ -90,9 +90,10 @@ var PostureCheckAttrValues = map[Attribute][]string{ } type PostureCheck struct { - ID string `gorm:"primaryKey" json:"id"` - Name string `gorm:"name" json:"name"` - NetworkID NetworkID `gorm:"network_id" json:"network_id"` + ID string `gorm:"primaryKey" json:"id"` + TenantID string `gorm:"default:'';index" json:"tenant_id"` + Name string `gorm:"name" json:"name"` + NetworkID NetworkID `gorm:"network_id" json:"network_id"` Description string `gorm:"description" json:"description"` Attribute Attribute `gorm:"attribute" json:"attribute"` Values datatypes.JSONSlice[string] `gorm:"values" json:"values"` diff --git a/schema/posture_check_violations.go b/schema/posture_check_violations.go index 533ebaa34..84cd50065 100644 --- a/schema/posture_check_violations.go +++ b/schema/posture_check_violations.go @@ -8,8 +8,9 @@ const postureCheckViolationsTable = "posture_check_violations_v1" type PostureCheckViolation struct { EvaluationCycleID string `gorm:"primaryKey;column:evaluation_cycle_id" json:"evaluation_cycle_id"` - CheckID string `gorm:"primaryKey;column:check_id" json:"check_id"` - NodeID string `gorm:"primaryKey;column:node_id" json:"node_id"` + CheckID string `gorm:"primaryKey;column:check_id" json:"check_id"` + NodeID string `gorm:"primaryKey;column:node_id" json:"node_id"` + TenantID string `gorm:"default:'';index" json:"tenant_id"` Name string `json:"name"` Attribute string `json:"attribute"` Message string `json:"message"` diff --git a/schema/user_access_token.go b/schema/user_access_token.go index 1760c4d9f..14b1056a4 100644 --- a/schema/user_access_token.go +++ b/schema/user_access_token.go @@ -9,7 +9,8 @@ import ( // UserAccessToken - token used to access netmaker type UserAccessToken struct { - ID string `gorm:"primaryKey" json:"id"` + ID string `gorm:"primaryKey" json:"id"` + TenantID string `gorm:"default:'';index" json:"tenant_id"` Name string `json:"name"` UserName string `json:"user_name"` ExpiresAt time.Time `json:"expires_at"` diff --git a/schema/user_groups.go b/schema/user_groups.go index fb3ba0241..1eb07f9b3 100644 --- a/schema/user_groups.go +++ b/schema/user_groups.go @@ -18,7 +18,8 @@ func (g UserGroupID) String() string { } type UserGroup struct { - ID UserGroupID `gorm:"primaryKey" json:"id"` + ID UserGroupID `gorm:"primaryKey" json:"id"` + TenantID string `gorm:"default:'';index" json:"tenant_id"` Name string `json:"name"` Default bool `json:"default"` ExternalIdentityProviderID string `json:"external_identity_provider_id"` diff --git a/schema/user_invites.go b/schema/user_invites.go index cac6c86eb..c65d10f40 100644 --- a/schema/user_invites.go +++ b/schema/user_invites.go @@ -15,8 +15,9 @@ var ( ) type UserInvite struct { - ID string `gorm:"primaryKey" json:"id"` - InviteCode string `gorm:"unique" json:"invite_code"` + ID string `gorm:"primaryKey" json:"id"` + TenantID string `gorm:"default:'';index" json:"tenant_id"` + InviteCode string `gorm:"unique" json:"invite_code"` InviteURL string `json:"invite_url"` Email string `json:"email"` PlatformRoleID string `json:"platform_role_id"` From 7cdb593ab7200dbab9155489ceb0e6be7d83b9d8 Mon Sep 17 00:00:00 2001 From: VishalDalwadi Date: Thu, 18 Jun 2026 14:33:20 +0530 Subject: [PATCH 05/21] feat(go): add multi-tenancy migration; --- migrate/migrate_multitenancy.go | 280 ++++++++++++++++++++++++++++++++ migrate/migrate_schema.go | 6 + 2 files changed, 286 insertions(+) create mode 100644 migrate/migrate_multitenancy.go diff --git a/migrate/migrate_multitenancy.go b/migrate/migrate_multitenancy.go new file mode 100644 index 000000000..b3ccd9731 --- /dev/null +++ b/migrate/migrate_multitenancy.go @@ -0,0 +1,280 @@ +package migrate + +import ( + "context" + "encoding/json" + "fmt" + "strings" + + "github.com/gravitl/netmaker/db" + "github.com/gravitl/netmaker/logger" + "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/schema" +) + +const ( + TableName_DNS = "dns" + TableName_ExtClients = "extclients" + TableName_Acls = "acls" + TableName_Metrics = "metrics" + TableName_Tags = "tags" +) + +// migrateMultitenancy bootstraps the multi-tenancy foundation: +// 1. Adds tenant_id (and network_id where applicable) columns to old KV tables. +// 2. Creates the default Organization and Tenant if none exist. +// 3. Backfills tenant_id and network_id on all existing KV table records. +// 4. Backfills tenant_id on all existing GORM resource records. +func migrateMultitenancy(ctx context.Context) error { + if err := addTenantColumnsToKVTables(ctx); err != nil { + return err + } + + org, err := schema.EnsureDefaultOrganization(ctx) + if err != nil { + return fmt.Errorf("multitenancy migration: ensure default organization: %w", err) + } + + tenant, err := schema.EnsureDefaultTenant(ctx, org.ID) + if err != nil { + return fmt.Errorf("multitenancy migration: ensure default tenant: %w", err) + } + + if err = backfillKVTableTenantID(ctx, tenant.ID); err != nil { + return err + } + if err = backfillKVTableNetworkID(ctx); err != nil { + return err + } + return backfillTenantID(ctx, tenant.ID) +} + +// addTenantColumnsToKVTables runs ALTER TABLE on the legacy key-value tables to +// add tenant_id and network_id columns. Errors caused by the column already +// existing are ignored so the migration is safe to re-run. +func addTenantColumnsToKVTables(ctx context.Context) error { + withNetworkID := []string{TableName_DNS, TableName_ExtClients, TableName_Acls, TableName_Metrics, TableName_Tags} + for _, table := range withNetworkID { + if err := addColumnIfNotExists(ctx, table, "tenant_id", "TEXT"); err != nil { + return fmt.Errorf("multitenancy migration: alter table %s (tenant_id): %w", table, err) + } + if err := addColumnIfNotExists(ctx, table, "network_id", "TEXT"); err != nil { + return fmt.Errorf("multitenancy migration: alter table %s (network_id): %w", table, err) + } + } + + tenantOnly := []string{"server_settings"} + for _, table := range tenantOnly { + if err := addColumnIfNotExists(ctx, table, "tenant_id", "TEXT"); err != nil { + return fmt.Errorf("multitenancy migration: alter table %s (tenant_id): %w", table, err) + } + } + + return nil +} + +// addColumnIfNotExists attempts ALTER TABLE … ADD COLUMN and ignores the error +// if the column already exists (SQLite: "duplicate column name", PostgreSQL: "already exists"). +func addColumnIfNotExists(ctx context.Context, table, column, colDef string) error { + sql := fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s %s", table, column, colDef) + err := db.FromContext(ctx).Exec(sql).Error + if err == nil { + return nil + } + msg := err.Error() + if strings.Contains(msg, "duplicate column name") || strings.Contains(msg, "already exists") { + return nil + } + return err +} + +// backfillKVTableTenantID sets tenant_id = tenantID on every row in the legacy +// KV tables where tenant_id is NULL or empty. +func backfillKVTableTenantID(ctx context.Context, tenantID string) error { + tables := []string{TableName_DNS, TableName_ExtClients, TableName_Acls, TableName_Metrics, TableName_Tags, "server_settings"} + for _, table := range tables { + sql := fmt.Sprintf("UPDATE %s SET tenant_id = ? WHERE tenant_id IS NULL OR tenant_id = ''", table) + if err := db.FromContext(ctx).Exec(sql, tenantID).Error; err != nil { + return fmt.Errorf("multitenancy migration: backfill tenant_id in %s: %w", table, err) + } + } + return nil +} + +// backfillKVTableNetworkID unmarshals each legacy KV table row into its typed +// model and writes the network identifier into the new network_id column. +func backfillKVTableNetworkID(ctx context.Context) error { + if err := backfillDNSNetworkID(ctx); err != nil { + return err + } + if err := backfillExtClientNetworkID(ctx); err != nil { + return err + } + if err := backfillAclNetworkID(ctx); err != nil { + return err + } + if err := backfillMetricsNetworkID(ctx); err != nil { + return err + } + return backfillTagNetworkID(ctx) +} + +func backfillDNSNetworkID(ctx context.Context) error { + records, err := kvList(ctx, TableName_DNS) + if err != nil { + return fmt.Errorf("multitenancy migration: list dns records: %w", err) + } + for key, value := range records { + var entry models.DNSEntry + if err := json.Unmarshal([]byte(value), &entry); err != nil { + return fmt.Errorf("multitenancy migration: parse dns record %s: %w", key, err) + } + if entry.Network == "" { + continue + } + logger.Log(4, fmt.Sprintf("multitenancy migration: backfilling network_id for dns record %s", key)) + if err := db.FromContext(ctx).Exec( + "UPDATE dns SET network_id = ? WHERE key = ? AND (network_id IS NULL OR network_id = '')", + entry.Network, key, + ).Error; err != nil { + return fmt.Errorf("multitenancy migration: set network_id in dns record %s: %w", key, err) + } + } + return nil +} + +func backfillExtClientNetworkID(ctx context.Context) error { + records, err := kvList(ctx, TableName_ExtClients) + if err != nil { + return fmt.Errorf("multitenancy migration: list extclients records: %w", err) + } + for key, value := range records { + var client models.ExtClient + if err := json.Unmarshal([]byte(value), &client); err != nil { + return fmt.Errorf("multitenancy migration: parse extclient record %s: %w", key, err) + } + if client.Network == "" { + continue + } + logger.Log(4, fmt.Sprintf("multitenancy migration: backfilling network_id for extclient record %s", key)) + if err := db.FromContext(ctx).Exec( + "UPDATE extclients SET network_id = ? WHERE key = ? AND (network_id IS NULL OR network_id = '')", + client.Network, key, + ).Error; err != nil { + return fmt.Errorf("multitenancy migration: set network_id in extclient record %s: %w", key, err) + } + } + return nil +} + +func backfillAclNetworkID(ctx context.Context) error { + records, err := kvList(ctx, TableName_Acls) + if err != nil { + return fmt.Errorf("multitenancy migration: list acls records: %w", err) + } + for key, value := range records { + var acl models.Acl + if err := json.Unmarshal([]byte(value), &acl); err != nil { + return fmt.Errorf("multitenancy migration: parse acl record %s: %w", key, err) + } + if acl.NetworkID == "" { + continue + } + logger.Log(4, fmt.Sprintf("multitenancy migration: backfilling network_id for acl record %s", key)) + if err := db.FromContext(ctx).Exec( + "UPDATE acls SET network_id = ? WHERE key = ? AND (network_id IS NULL OR network_id = '')", + string(acl.NetworkID), key, + ).Error; err != nil { + return fmt.Errorf("multitenancy migration: set network_id in acl record %s: %w", key, err) + } + } + return nil +} + +func backfillMetricsNetworkID(ctx context.Context) error { + records, err := kvList(ctx, TableName_Metrics) + if err != nil { + return fmt.Errorf("multitenancy migration: list metrics records: %w", err) + } + for key, value := range records { + var m models.Metrics + if err := json.Unmarshal([]byte(value), &m); err != nil { + return fmt.Errorf("multitenancy migration: parse metrics record %s: %w", key, err) + } + if m.Network == "" { + continue + } + logger.Log(4, fmt.Sprintf("multitenancy migration: backfilling network_id for metrics record %s", key)) + if err := db.FromContext(ctx).Exec( + "UPDATE metrics SET network_id = ? WHERE key = ? AND (network_id IS NULL OR network_id = '')", + m.Network, key, + ).Error; err != nil { + return fmt.Errorf("multitenancy migration: set network_id in metrics record %s: %w", key, err) + } + } + return nil +} + +func backfillTagNetworkID(ctx context.Context) error { + records, err := kvList(ctx, TableName_Tags) + if err != nil { + return fmt.Errorf("multitenancy migration: list tags records: %w", err) + } + for key, value := range records { + var tag models.Tag + if err := json.Unmarshal([]byte(value), &tag); err != nil { + return fmt.Errorf("multitenancy migration: parse tag record %s: %w", key, err) + } + if tag.Network == "" { + continue + } + logger.Log(4, fmt.Sprintf("multitenancy migration: backfilling network_id for tag record %s", key)) + if err := db.FromContext(ctx).Exec( + "UPDATE tags SET network_id = ? WHERE key = ? AND (network_id IS NULL OR network_id = '')", + string(tag.Network), key, + ).Error; err != nil { + return fmt.Errorf("multitenancy migration: set network_id in tag record %s: %w", key, err) + } + } + return nil +} + +// backfillTenantID updates all existing GORM resource records that have an +// empty tenant_id to the provided default tenant ID. +func backfillTenantID(ctx context.Context, tenantID string) error { + gormDB := db.FromContext(ctx) + + type namedModel struct { + name string + model any + } + + gormModels := []namedModel{ + {"Network", &schema.Network{}}, + {"Host", &schema.Host{}}, + {"Node", &schema.Node{}}, + {"EnrollmentKey", &schema.EnrollmentKey{}}, + {"Egress", &schema.Egress{}}, + {"PendingHost", &schema.PendingHost{}}, + {"PendingUser", &schema.PendingUser{}}, + {"UserInvite", &schema.UserInvite{}}, + {"UserGroup", &schema.UserGroup{}}, + {"JITRequest", &schema.JITRequest{}}, + {"JITGrant", &schema.JITGrant{}}, + {"PostureCheck", &schema.PostureCheck{}}, + {"PostureCheckViolation", &schema.PostureCheckViolation{}}, + {"Integration", &schema.Integration{}}, + {"Event", &schema.Event{}}, + {"Job", &schema.Job{}}, + {"UserAccessToken", &schema.UserAccessToken{}}, + {"Nameserver", &schema.Nameserver{}}, + } + + for _, m := range gormModels { + logger.Log(4, fmt.Sprintf("multitenancy migration: backfilling tenant_id for %s", m.name)) + if err := gormDB.Model(m.model).Where("tenant_id = ''").Update("tenant_id", tenantID).Error; err != nil { + return fmt.Errorf("multitenancy migration: backfill tenant_id for %s: %w", m.name, err) + } + } + return nil +} diff --git a/migrate/migrate_schema.go b/migrate/migrate_schema.go index ff8adfa23..cf03b2cdf 100644 --- a/migrate/migrate_schema.go +++ b/migrate/migrate_schema.go @@ -40,6 +40,12 @@ func ToSQLSchema() error { return err } + // v1.7.0 multi-tenancy: adds org/tenant tables, tenant_id columns, and bootstraps defaults. + err = ensureMigrationCompleted(context.TODO(), "migration-v1.7.0-multitenancy", migrateMultitenancy) + if err != nil { + return err + } + return nil } From 8b8d5ce907eab861e760e73ea4504f0fdf24be9f Mon Sep 17 00:00:00 2001 From: VishalDalwadi Date: Thu, 18 Jun 2026 18:55:06 +0530 Subject: [PATCH 06/21] feat(go): migration ensures default org and tenant is created only once; move slug generation to utils; --- migrate/migrate_multitenancy.go | 25 ++++++++++++------ schema/organizations.go | 46 ++++++--------------------------- schema/tenants.go | 36 +++++++------------------- schema/utils.go | 20 ++++++++++++++ 4 files changed, 54 insertions(+), 73 deletions(-) create mode 100644 schema/utils.go diff --git a/migrate/migrate_multitenancy.go b/migrate/migrate_multitenancy.go index b3ccd9731..470d29b45 100644 --- a/migrate/migrate_multitenancy.go +++ b/migrate/migrate_multitenancy.go @@ -26,27 +26,36 @@ const ( // 3. Backfills tenant_id and network_id on all existing KV table records. // 4. Backfills tenant_id on all existing GORM resource records. func migrateMultitenancy(ctx context.Context) error { - if err := addTenantColumnsToKVTables(ctx); err != nil { + err := addTenantColumnsToKVTables(ctx) + if err != nil { return err } - org, err := schema.EnsureDefaultOrganization(ctx) + defaultOrg := &schema.Organization{} + err = defaultOrg.CreateDefault(ctx) if err != nil { - return fmt.Errorf("multitenancy migration: ensure default organization: %w", err) + return fmt.Errorf("multitenancy migration: failed to create default organization: %w", err) } - tenant, err := schema.EnsureDefaultTenant(ctx, org.ID) + defaultTenant := &schema.Tenant{ + OrganizationID: defaultOrg.ID, + } + err = defaultTenant.CreateDefault(ctx) if err != nil { - return fmt.Errorf("multitenancy migration: ensure default tenant: %w", err) + return fmt.Errorf("multitenancy migration: failed to create default tenant: %w", err) } - if err = backfillKVTableTenantID(ctx, tenant.ID); err != nil { + err = backfillKVTableTenantID(ctx, defaultTenant.ID) + if err != nil { return err } - if err = backfillKVTableNetworkID(ctx); err != nil { + + err = backfillTenantID(ctx, defaultTenant.ID) + if err != nil { return err } - return backfillTenantID(ctx, tenant.ID) + + return backfillKVTableNetworkID(ctx) } // addTenantColumnsToKVTables runs ALTER TABLE on the legacy key-value tables to diff --git a/schema/organizations.go b/schema/organizations.go index b9e99231b..da765b410 100644 --- a/schema/organizations.go +++ b/schema/organizations.go @@ -3,27 +3,14 @@ package schema import ( "context" "fmt" - "math/rand" - "regexp" "strings" "time" "github.com/google/uuid" "github.com/gravitl/netmaker/db" - "gorm.io/gorm" ) -var slugNonAlphaNumericRegex = regexp.MustCompile(`[^a-z0-9]+`) - -// generateSlug produces a URL-friendly slug from name with a random 4-digit -// suffix to reduce collisions (e.g. "acme-corp-4821"). -func generateSlug(name string) string { - base := strings.Trim(slugNonAlphaNumericRegex.ReplaceAllString(strings.ToLower(name), "-"), "-") - if base == "" { - base = "org" - } - return fmt.Sprintf("%s-%04d", base, rand.Intn(9000)+1000) -} +const defaultOrgSlug = "default" type Organization struct { ID string `gorm:"primaryKey" json:"id"` @@ -37,6 +24,13 @@ func (o *Organization) TableName() string { return "organizations_v1" } +func (o *Organization) CreateDefault(ctx context.Context) error { + o.ID = uuid.NewString() + o.Name = defaultOrgSlug + o.Slug = defaultOrgSlug + return db.FromContext(ctx).Model(&Organization{}).Create(o).Error +} + func (o *Organization) Create(ctx context.Context) error { if o.ID == "" { o.ID = uuid.NewString() @@ -95,27 +89,3 @@ func isUniqueConstraintErr(err error) bool { strings.Contains(msg, "duplicate key value violates unique constraint") || strings.Contains(msg, "23505") // pg error code } - -// EnsureDefaultOrganization creates the default organization if none exists, -// returning the org (existing or newly created). -func EnsureDefaultOrganization(ctx context.Context) (*Organization, error) { - var orgs []Organization - if err := db.FromContext(ctx).Model(&Organization{}).Limit(1).Find(&orgs).Error; err != nil { - return nil, err - } - if len(orgs) > 0 { - return &orgs[0], nil - } - org := &Organization{Name: "Default", Slug: "default"} - err := db.FromContext(ctx).Model(&Organization{}). - Where(gorm.Model{}). - FirstOrCreate(org, Organization{Slug: "default"}).Error - if err != nil { - // Slug "default" taken — use generated slug. - org.Slug = "" - if createErr := org.Create(ctx); createErr != nil { - return nil, createErr - } - } - return org, nil -} diff --git a/schema/tenants.go b/schema/tenants.go index e71d0d6a8..706b6f738 100644 --- a/schema/tenants.go +++ b/schema/tenants.go @@ -7,9 +7,10 @@ import ( "github.com/google/uuid" "github.com/gravitl/netmaker/db" - "gorm.io/gorm" ) +const defaultTenantSlug = "default" + type Tenant struct { ID string `gorm:"primaryKey" json:"id"` Name string `gorm:"not null" json:"name"` @@ -23,6 +24,13 @@ func (t *Tenant) TableName() string { return "tenants_v1" } +func (t *Tenant) CreateDefault(ctx context.Context) error { + t.ID = uuid.NewString() + t.Name = defaultTenantSlug + t.Slug = defaultTenantSlug + return db.FromContext(ctx).Model(&Tenant{}).Create(&t).Error +} + func (t *Tenant) Create(ctx context.Context) error { if t.ID == "" { t.ID = uuid.NewString() @@ -75,29 +83,3 @@ func (t *Tenant) Delete(ctx context.Context) error { Where("id = ?", t.ID). Delete(t).Error } - -// EnsureDefaultTenant creates the default tenant for the given org if none -// exists, returning the tenant (existing or newly created). -func EnsureDefaultTenant(ctx context.Context, orgID string) (*Tenant, error) { - var tenants []Tenant - if err := db.FromContext(ctx).Model(&Tenant{}). - Where("organization_id = ?", orgID). - Limit(1).Find(&tenants).Error; err != nil { - return nil, err - } - if len(tenants) > 0 { - return &tenants[0], nil - } - tenant := &Tenant{OrganizationID: orgID, Name: "Default", Slug: "default"} - err := db.FromContext(ctx).Model(&Tenant{}). - Where(gorm.Model{}). - FirstOrCreate(tenant, Tenant{Slug: "default", OrganizationID: orgID}).Error - if err != nil { - // Slug "default" taken — use generated slug. - tenant.Slug = "" - if createErr := tenant.Create(ctx); createErr != nil { - return nil, createErr - } - } - return tenant, nil -} diff --git a/schema/utils.go b/schema/utils.go new file mode 100644 index 000000000..560699882 --- /dev/null +++ b/schema/utils.go @@ -0,0 +1,20 @@ +package schema + +import ( + "fmt" + "math/rand" + "regexp" + "strings" +) + +var slugNonAlphaNumericRegex = regexp.MustCompile(`[^a-z0-9]+`) + +// generateSlug produces a URL-friendly slug from name with a random 4-digit +// suffix to reduce collisions (e.g. "acme-corp-4821"). +func generateSlug(name string) string { + base := strings.Trim(slugNonAlphaNumericRegex.ReplaceAllString(strings.ToLower(name), "-"), "-") + if base == "" { + base = "org" + } + return fmt.Sprintf("%s-%04d", base, rand.Intn(9000)+1000) +} From 21f4316a89d8a3861b38f1b6d2a8dd8c85221671 Mon Sep 17 00:00:00 2001 From: VishalDalwadi Date: Thu, 18 Jun 2026 18:56:00 +0530 Subject: [PATCH 07/21] feat(go): remove shutdown api; --- controllers/server.go | 22 +--------------------- 1 file changed, 1 insertion(+), 21 deletions(-) diff --git a/controllers/server.go b/controllers/server.go index 4aeb99d9a..9d78def49 100644 --- a/controllers/server.go +++ b/controllers/server.go @@ -8,7 +8,6 @@ import ( "net/http" "os" "strings" - "syscall" "time" "github.com/google/go-cmp/cmp" @@ -37,26 +36,7 @@ func serverHandlers(r *mux.Router) { resp.Write([]byte("Server is up and running!!")) }, ).Methods(http.MethodGet) - r.HandleFunc( - "/api/server/shutdown", logic.SecurityCheck(true, - http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if r.Header.Get("ismaster") != "yes" { - caller := &schema.User{ - Username: r.Header.Get("user"), - } - err := caller.Get(r.Context()) - if err != nil || caller.PlatformRoleID != schema.SuperAdminRole { - logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("only a super-admin can shut down the server"), "forbidden")) - return - } - } - msg := "received api call to shutdown server, sending interruption..." - slog.Warn(msg) - _, _ = w.Write([]byte(msg)) - w.WriteHeader(http.StatusOK) - _ = syscall.Kill(syscall.Getpid(), syscall.SIGINT) - })), - ).Methods(http.MethodPost) + // TODO: scope to tenant r.HandleFunc("/api/server/getconfig", allowUsers(http.HandlerFunc(getConfig))). Methods(http.MethodGet) r.HandleFunc("/api/server/settings", allowUsers(http.HandlerFunc(getSettings))). From 9a2b6a753332278a99bdf62d9c7442f39f4192aa Mon Sep 17 00:00:00 2001 From: VishalDalwadi Date: Thu, 18 Jun 2026 18:58:59 +0530 Subject: [PATCH 08/21] feat(go): assign tenant scope to server config and settings api; --- controllers/server.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/controllers/server.go b/controllers/server.go index 9d78def49..8e9804757 100644 --- a/controllers/server.go +++ b/controllers/server.go @@ -36,10 +36,9 @@ func serverHandlers(r *mux.Router) { resp.Write([]byte("Server is up and running!!")) }, ).Methods(http.MethodGet) - // TODO: scope to tenant - r.HandleFunc("/api/server/getconfig", allowUsers(http.HandlerFunc(getConfig))). + r.HandleFunc("/api/server/getconfig", Scope(db.TenantScope, allowUsers(http.HandlerFunc(getConfig)))). Methods(http.MethodGet) - r.HandleFunc("/api/server/settings", allowUsers(http.HandlerFunc(getSettings))). + r.HandleFunc("/api/server/settings", Scope(db.TenantScope, allowUsers(http.HandlerFunc(getSettings)))). Methods(http.MethodGet) r.HandleFunc("/api/server/settings", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(updateSettings)))). Methods(http.MethodPut) From 91853656d1652a2caabd6e16ed6f281c057e870c Mon Sep 17 00:00:00 2001 From: VishalDalwadi Date: Thu, 18 Jun 2026 19:06:57 +0530 Subject: [PATCH 09/21] feat(go): assign tenant scope to apis host calls; --- controllers/hosts.go | 12 ++++++------ pro/controllers/auto_relay.go | 8 ++++---- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/controllers/hosts.go b/controllers/hosts.go index 5e707d847..6ffbd2c30 100644 --- a/controllers/hosts.go +++ b/controllers/hosts.go @@ -48,7 +48,7 @@ func hostHandlers(r *mux.Router) { r.HandleFunc("/api/hosts/{hostid}", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getHost)))). Methods(http.MethodGet) // used by netclient - r.HandleFunc("/api/hosts/{hostid}", AuthorizeHost(http.HandlerFunc(deleteHost))). + r.HandleFunc("/api/hosts/{hostid}", Scope(db.TenantScope, AuthorizeHost(http.HandlerFunc(deleteHost)))). Methods(http.MethodDelete) // used by UI r.HandleFunc("/api/v1/ui/hosts/{hostid}", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteHost)))). @@ -61,14 +61,14 @@ func hostHandlers(r *mux.Router) { Methods(http.MethodPost) r.HandleFunc("/api/hosts/{hostid}/networks/{network}", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteHostFromNetwork)))). Methods(http.MethodDelete) - r.HandleFunc("/api/hosts/adm/authenticate", authenticateHost).Methods(http.MethodPost) - r.HandleFunc("/api/v1/host", AuthorizeHost(http.HandlerFunc(pull))). + r.HandleFunc("/api/hosts/adm/authenticate", Scope(db.TenantScope, http.HandlerFunc(authenticateHost))).Methods(http.MethodPost) + r.HandleFunc("/api/v1/host", Scope(db.TenantScope, AuthorizeHost(http.HandlerFunc(pull)))). Methods(http.MethodGet) - r.HandleFunc("/api/v1/host/{hostid}/signalpeer", AuthorizeHost(http.HandlerFunc(signalPeer))). + r.HandleFunc("/api/v1/host/{hostid}/signalpeer", Scope(db.TenantScope, AuthorizeHost(http.HandlerFunc(signalPeer)))). Methods(http.MethodPost) - r.HandleFunc("/api/v1/fallback/host/{hostid}", AuthorizeHost(http.HandlerFunc(hostUpdateFallback))). + r.HandleFunc("/api/v1/fallback/host/{hostid}", Scope(db.TenantScope, AuthorizeHost(http.HandlerFunc(hostUpdateFallback)))). Methods(http.MethodPut) - r.HandleFunc("/api/v1/host/{hostid}/peer_info", AuthorizeHost(http.HandlerFunc(getHostPeerInfo))). + r.HandleFunc("/api/v1/host/{hostid}/peer_info", Scope(db.TenantScope, AuthorizeHost(http.HandlerFunc(getHostPeerInfo)))). Methods(http.MethodGet) r.HandleFunc("/api/v1/pending_hosts", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getPendingHosts)))). Methods(http.MethodGet) diff --git a/pro/controllers/auto_relay.go b/pro/controllers/auto_relay.go index 4f0db8fc1..bbb0e4854 100644 --- a/pro/controllers/auto_relay.go +++ b/pro/controllers/auto_relay.go @@ -22,7 +22,7 @@ import ( // AutoRelayHandlers - handlers for AutoRelay func AutoRelayHandlers(r *mux.Router) { - r.HandleFunc("/api/v1/node/{nodeid}/auto_relay", controller.AuthorizeHost(http.HandlerFunc(getAutoRelayGws))). + r.HandleFunc("/api/v1/node/{nodeid}/auto_relay", controller.Scope(db.TenantScope, controller.AuthorizeHost(http.HandlerFunc(getAutoRelayGws)))). Methods(http.MethodGet) r.HandleFunc("/api/v1/node/{nodeid}/auto_relay", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(setAutoRelay)))). Methods(http.MethodPost) @@ -30,11 +30,11 @@ func AutoRelayHandlers(r *mux.Router) { Methods(http.MethodDelete) r.HandleFunc("/api/v1/node/{network}/auto_relay/reset", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(resetAutoRelayGw)))). Methods(http.MethodPost) - r.HandleFunc("/api/v1/node/{nodeid}/auto_relay_me", controller.AuthorizeHost(http.HandlerFunc(autoRelayME))). + r.HandleFunc("/api/v1/node/{nodeid}/auto_relay_me", controller.Scope(db.TenantScope, controller.AuthorizeHost(http.HandlerFunc(autoRelayME)))). Methods(http.MethodPost) - r.HandleFunc("/api/v1/node/{nodeid}/auto_relay_me", controller.AuthorizeHost(http.HandlerFunc(autoRelayMEUpdate))). + r.HandleFunc("/api/v1/node/{nodeid}/auto_relay_me", controller.Scope(db.TenantScope, controller.AuthorizeHost(http.HandlerFunc(autoRelayMEUpdate)))). Methods(http.MethodPut) - r.HandleFunc("/api/v1/node/{nodeid}/auto_relay_check", controller.AuthorizeHost(http.HandlerFunc(checkautoRelayCtx))). + r.HandleFunc("/api/v1/node/{nodeid}/auto_relay_check", controller.Scope(db.TenantScope, controller.AuthorizeHost(http.HandlerFunc(checkautoRelayCtx)))). Methods(http.MethodGet) } From 10c9b1a14840c1631e2dac97ac5bc08b0ed41832 Mon Sep 17 00:00:00 2001 From: VishalDalwadi Date: Thu, 18 Jun 2026 21:14:56 +0530 Subject: [PATCH 10/21] feat(go): prevent error on missing tenant/organization id for existing tenants; For existing tenants, there's no concept of multiple tenants and organization. But the code still creates default tenant and org for them. For these tenants, we can relax the requirement of the tenant id and organization id headers and just use the default tenant and org scope. --- controllers/scope.go | 57 ++++++++++++++++++++++++++++++++++++----- models/structs.go | 1 + schema/organizations.go | 16 +++++++++--- schema/tenants.go | 19 +++++++++++--- 4 files changed, 80 insertions(+), 13 deletions(-) diff --git a/controllers/scope.go b/controllers/scope.go index 89f8da79e..aed802adb 100644 --- a/controllers/scope.go +++ b/controllers/scope.go @@ -6,11 +6,16 @@ import ( "github.com/gravitl/netmaker/db" "github.com/gravitl/netmaker/logic" + "github.com/gravitl/netmaker/schema" ) var ( - errMissingTenantID = errors.New("X-Tenant-ID header is required") - errMissingOrgID = errors.New("X-Organization-ID header is required") + errMissingTenantID = errors.New("X-Tenant-ID header is required") + errDefaultTenantNotFound = errors.New("default tenant not found") + errTenantNotFound = errors.New("tenant not found") + errMissingOrgID = errors.New("X-Organization-ID header is required") + errDefaultOrgNotFound = errors.New("default organization not found") + errOrgNotFound = errors.New("organization not found") ) const ( @@ -34,14 +39,54 @@ func Scope(level db.ScopeLevel, next http.Handler) http.HandlerFunc { case db.TenantScope: id = r.Header.Get(HeaderTenantID) if id == "" { - logic.ReturnErrorResponse(w, r, logic.FormatError(errMissingTenantID, logic.BadReq)) - return + if logic.GetFeatureFlags().AllowMultipleTenants { + logic.ReturnErrorResponse(w, r, logic.FormatError(errMissingTenantID, logic.BadReq)) + return + } + + defaultTenant := &schema.Tenant{} + err := defaultTenant.GetDefault(r.Context()) + if err != nil { + logic.ReturnErrorResponse(w, r, logic.FormatError(errDefaultTenantNotFound, logic.Internal)) + return + } + + id = defaultTenant.ID + } else { + tenant := &schema.Tenant{ + ID: id, + } + err := tenant.Get(r.Context()) + if err != nil { + logic.ReturnErrorResponse(w, r, logic.FormatError(errTenantNotFound, logic.BadReq)) + return + } } case db.OrgScope: id = r.Header.Get(HeaderOrgID) if id == "" { - logic.ReturnErrorResponse(w, r, logic.FormatError(errMissingOrgID, logic.BadReq)) - return + if logic.GetFeatureFlags().AllowMultipleTenants { + logic.ReturnErrorResponse(w, r, logic.FormatError(errMissingOrgID, logic.BadReq)) + return + } + + defaultOrg := &schema.Organization{} + err := defaultOrg.Get(r.Context()) + if err != nil { + logic.ReturnErrorResponse(w, r, logic.FormatError(errDefaultOrgNotFound, logic.Internal)) + return + } + + id = defaultOrg.ID + } else { + org := &schema.Organization{ + ID: id, + } + err := org.Get(r.Context()) + if err != nil { + logic.ReturnErrorResponse(w, r, logic.FormatError(errOrgNotFound, logic.BadReq)) + return + } } case db.GlobalScope: // no header required diff --git a/models/structs.go b/models/structs.go index 456403719..7026a2156 100644 --- a/models/structs.go +++ b/models/structs.go @@ -22,6 +22,7 @@ type FeatureFlags struct { EnableJIT bool `json:"enable_jit"` EnableOverlappingEgressRanges bool `json:"enable_overlapping_egress_ranges"` EnableSIEMIntegration bool `json:"enable_siem_integration"` + AllowMultipleTenants bool `json:"allow_multiple_tenants"` } // AuthParams - struct for auth params diff --git a/schema/organizations.go b/schema/organizations.go index da765b410..0796347e1 100644 --- a/schema/organizations.go +++ b/schema/organizations.go @@ -55,7 +55,15 @@ func (o *Organization) Create(ctx context.Context) error { func (o *Organization) Get(ctx context.Context) error { return db.FromContext(ctx).Model(&Organization{}). Where("id = ? OR slug = ?", o.ID, o.Slug). - First(o).Error + First(o). + Error +} + +func (o *Organization) GetDefault(ctx context.Context) error { + return db.FromContext(ctx).Model(&Organization{}). + Where("slug = ?", defaultOrgSlug). + Find(o). + Error } func (o *Organization) ListAll(ctx context.Context) ([]Organization, error) { @@ -67,13 +75,15 @@ func (o *Organization) ListAll(ctx context.Context) ([]Organization, error) { func (o *Organization) Update(ctx context.Context) error { return db.FromContext(ctx).Model(&Organization{}). Where("id = ?", o.ID). - Updates(o).Error + Updates(o). + Error } func (o *Organization) Delete(ctx context.Context) error { return db.FromContext(ctx).Model(&Organization{}). Where("id = ?", o.ID). - Delete(o).Error + Delete(o). + Error } // isUniqueConstraintErr returns true if err is a unique constraint violation diff --git a/schema/tenants.go b/schema/tenants.go index 706b6f738..be569c989 100644 --- a/schema/tenants.go +++ b/schema/tenants.go @@ -55,7 +55,15 @@ func (t *Tenant) Create(ctx context.Context) error { func (t *Tenant) Get(ctx context.Context) error { return db.FromContext(ctx).Model(&Tenant{}). Where("id = ? OR slug = ?", t.ID, t.Slug). - First(t).Error + First(t). + Error +} + +func (t *Tenant) GetDefault(ctx context.Context) error { + return db.FromContext(ctx).Model(&Tenant{}). + Where("slug = ?", defaultTenantSlug). + First(t). + Error } func (t *Tenant) ListAll(ctx context.Context) ([]Tenant, error) { @@ -68,18 +76,21 @@ func (t *Tenant) ListByOrg(ctx context.Context, orgID string) ([]Tenant, error) var tenants []Tenant err := db.FromContext(ctx).Model(&Tenant{}). Where("organization_id = ?", orgID). - Find(&tenants).Error + Find(&tenants). + Error return tenants, err } func (t *Tenant) Update(ctx context.Context) error { return db.FromContext(ctx).Model(&Tenant{}). Where("id = ?", t.ID). - Updates(t).Error + Updates(t). + Error } func (t *Tenant) Delete(ctx context.Context) error { return db.FromContext(ctx).Model(&Tenant{}). Where("id = ?", t.ID). - Delete(t).Error + Delete(t). + Error } From 0f34afa1ef5aa96d72325d6988704941f2dd1a59 Mon Sep 17 00:00:00 2001 From: VishalDalwadi Date: Fri, 19 Jun 2026 11:27:21 +0530 Subject: [PATCH 11/21] feat(go): add schema definition for organization settings; --- schema/org_settings.go | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 schema/org_settings.go diff --git a/schema/org_settings.go b/schema/org_settings.go new file mode 100644 index 000000000..83e92abd4 --- /dev/null +++ b/schema/org_settings.go @@ -0,0 +1,31 @@ +package schema + +import ( + "context" + + "github.com/gravitl/netmaker/db" + "gorm.io/datatypes" +) + +const orgSettingsTable = "org_settings_v1" + +type OrganizationSettings struct { + OrganizationID string `gorm:"primaryKey"` + Settings datatypes.JSONType[OrganizationSettingsData] +} + +type OrganizationSettingsData struct{} + +func (o *OrganizationSettings) TableName() string { + return orgSettingsTable +} + +func (o *OrganizationSettings) Upsert(ctx context.Context) error { + return db.FromContext(ctx).Save(&o).Error +} + +func (o *OrganizationSettings) Get(ctx context.Context) error { + return db.FromContext(ctx).Model(&OrganizationSettings{}). + First(&o). + Error +} From d2427d3fc9b0f8940cd86137181983987c37850d Mon Sep 17 00:00:00 2001 From: VishalDalwadi Date: Fri, 19 Jun 2026 12:10:29 +0530 Subject: [PATCH 12/21] feat(go): move user settings to user table; --- controllers/user.go | 8 ++++---- logic/settings.go | 41 ++++++++++++++++++++++++----------------- models/settings.go | 14 ++++---------- schema/users.go | 33 ++++++++++++++++++++++++++++----- 4 files changed, 60 insertions(+), 36 deletions(-) diff --git a/controllers/user.go b/controllers/user.go index c0ac0b9a9..a9233870b 100644 --- a/controllers/user.go +++ b/controllers/user.go @@ -929,8 +929,8 @@ func updateUserAccountStatus(w http.ResponseWriter, r *http.Request, disableAcco // @Param username path string true "Username of the user" // @Success 200 {object} models.UserSettings func getUserSettings(w http.ResponseWriter, r *http.Request) { - userID := r.Header.Get("user") - userSettings := logic.GetUserSettings(userID) + username := r.Header.Get("user") + userSettings := logic.GetUserSettings(username) logic.ReturnSuccessResponseWithJson(w, r, userSettings, "fetched user settings") } @@ -946,7 +946,7 @@ func getUserSettings(w http.ResponseWriter, r *http.Request) { // @Failure 400 {object} models.ErrorResponse // @Failure 500 {object} models.ErrorResponse func updateUserSettings(w http.ResponseWriter, r *http.Request) { - userID := r.Header.Get("user") + username := r.Header.Get("user") var req models.UserSettings err := json.NewDecoder(r.Body).Decode(&req) if err != nil { @@ -956,7 +956,7 @@ func updateUserSettings(w http.ResponseWriter, r *http.Request) { return } - err = logic.UpsertUserSettings(userID, req) + err = logic.UpsertUserSettings(username, req) if err != nil { err = fmt.Errorf("failed to update user settings: %v", err.Error()) logger.Log(0, err.Error()) diff --git a/logic/settings.go b/logic/settings.go index 9b00ff5aa..2aee99260 100644 --- a/logic/settings.go +++ b/logic/settings.go @@ -1,6 +1,7 @@ package logic import ( + "context" "encoding/json" "errors" "fmt" @@ -14,7 +15,9 @@ import ( "github.com/gravitl/netmaker/config" "github.com/gravitl/netmaker/database" + "github.com/gravitl/netmaker/db" "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/schema" "github.com/gravitl/netmaker/servercfg" ) @@ -31,7 +34,7 @@ var serverSettingsCache atomic.Value var defaultUserSettings = models.UserSettings{ TextSize: "16", - Theme: models.Dark, + Theme: schema.Dark, ReducedMotion: false, } @@ -113,38 +116,42 @@ func UpsertServerSettings(s models.ServerSettings) error { return nil } -func GetUserSettings(userID string) models.UserSettings { - data, err := database.FetchRecord(database.SERVER_SETTINGS, userID) +func GetUserSettings(username string) models.UserSettings { + user := &schema.User{ + Username: username, + } + err := user.Get(db.WithContext(context.TODO())) if err != nil { return defaultUserSettings } - var userSettings models.UserSettings - err = json.Unmarshal([]byte(data), &userSettings) - if err != nil { + + if user.Theme == "" && user.TextSize == "" && user.ReducedMotion == false { return defaultUserSettings } - return userSettings + return models.UserSettings{ + Theme: user.Theme, + TextSize: user.TextSize, + ReducedMotion: user.ReducedMotion, + } } -func UpsertUserSettings(userID string, userSettings models.UserSettings) error { +func UpsertUserSettings(username string, userSettings models.UserSettings) error { if userSettings.TextSize == "" { userSettings.TextSize = "16" } if userSettings.Theme == "" { - userSettings.Theme = models.Dark + userSettings.Theme = schema.Dark } - data, err := json.Marshal(userSettings) - if err != nil { - return err + user := &schema.User{ + Username: username, + Theme: userSettings.Theme, + TextSize: userSettings.TextSize, + ReducedMotion: userSettings.ReducedMotion, } - return database.Insert(userID, string(data), database.SERVER_SETTINGS) -} - -func DeleteUserSettings(userID string) error { - return database.DeleteRecord(database.SERVER_SETTINGS, userID) + return user.UpdateUserSettings(db.WithContext(context.TODO())) } func ValidateNewSettings(req models.ServerSettings) error { diff --git a/models/settings.go b/models/settings.go index 8cb52d6d7..e60ccf463 100644 --- a/models/settings.go +++ b/models/settings.go @@ -1,12 +1,6 @@ package models -type Theme string - -const ( - Dark Theme = "dark" - Light Theme = "light" - System Theme = "system" -) +import "github.com/gravitl/netmaker/schema" type ServerSettings struct { NetclientAutoUpdate bool `json:"netclientautoupdate"` @@ -57,7 +51,7 @@ type ServerSettings struct { } type UserSettings struct { - Theme Theme `json:"theme"` - TextSize string `json:"text_size"` - ReducedMotion bool `json:"reduced_motion"` + Theme schema.Theme `json:"theme"` + TextSize string `json:"text_size"` + ReducedMotion bool `json:"reduced_motion"` } diff --git a/schema/users.go b/schema/users.go index 35a2ce984..ff4786276 100644 --- a/schema/users.go +++ b/schema/users.go @@ -13,11 +13,19 @@ import ( type AuthType string -var ( +const ( BasicAuth AuthType = "basic_auth" OAuth AuthType = "oauth" ) +type Theme string + +const ( + Dark Theme = "dark" + Light Theme = "light" + System Theme = "system" +) + var ( ErrUserIdentifiersNotProvided = errors.New("user identifiers not provided") ) @@ -33,6 +41,9 @@ type User struct { Password string `json:"password"` IsMFAEnabled bool `json:"is_mfa_enabled"` TOTPSecret string `json:"totp_secret"` + Theme Theme `json:"theme"` + TextSize string `json:"text_size"` + ReducedMotion bool `json:"reduced_motion"` // NOTE: json tag is different from field name to ensure compatibility with the older model. LastLoginAt time.Time `json:"last_login_time"` // NOTE: json tag is different from field name to ensure compatibility with the older model. @@ -136,8 +147,7 @@ func (u *User) UpdateAccountStatus(ctx context.Context) error { Where("id = ? OR username = ?", u.ID, u.Username). Updates(map[string]any{ "account_disabled": u.AccountDisabled, - }). - Error + }).Error } func (u *User) UpdateMFA(ctx context.Context) error { @@ -150,8 +160,21 @@ func (u *User) UpdateMFA(ctx context.Context) error { Updates(map[string]any{ "is_mfa_enabled": u.IsMFAEnabled, "totp_secret": u.TOTPSecret, - }). - Error + }).Error +} + +func (u *User) UpdateUserSettings(ctx context.Context) error { + if u.ID == "" && u.Username == "" { + return ErrUserIdentifiersNotProvided + } + + return db.FromContext(ctx).Model(&User{}). + Where("id = ? OR username = ?", u.ID, u.Username). + Updates(map[string]any{ + "theme": u.Theme, + "text_size": u.TextSize, + "reduced_motion": u.ReducedMotion, + }).Error } func (u *User) Delete(ctx context.Context) error { From 194fe9b3ba9c796fc1928b7899593182808e04c5 Mon Sep 17 00:00:00 2001 From: VishalDalwadi Date: Fri, 19 Jun 2026 12:22:34 +0530 Subject: [PATCH 13/21] feat(go): add migration of user settings from server settings table; --- migrate/migrate_schema.go | 3 +- migrate/migrate_v1_7_0.go | 72 ++++++++++++++++++++++++++++++++++++--- migrate/utils.go | 7 ++++ 3 files changed, 76 insertions(+), 6 deletions(-) diff --git a/migrate/migrate_schema.go b/migrate/migrate_schema.go index cf03b2cdf..4513bd326 100644 --- a/migrate/migrate_schema.go +++ b/migrate/migrate_schema.go @@ -34,7 +34,8 @@ func ToSQLSchema() error { return err } - // v1.7.0 migration includes migrating the server conf, generated and server uuid table. + // v1.7.0 migration includes migrating the server conf, generated, server uuid and server + // settings tables. err = ensureMigrationCompleted(context.TODO(), "migration-v1.7.0", migrateV1_7_0) if err != nil { return err diff --git a/migrate/migrate_v1_7_0.go b/migrate/migrate_v1_7_0.go index 8478477c4..27224c0b1 100644 --- a/migrate/migrate_v1_7_0.go +++ b/migrate/migrate_v1_7_0.go @@ -4,6 +4,7 @@ import ( "context" "encoding/base64" "encoding/json" + "errors" "time" "github.com/google/uuid" @@ -12,13 +13,19 @@ import ( "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/schema" "gorm.io/datatypes" + "gorm.io/gorm" ) const ( - TableName_ServerConf = "serverconf" - TableName_Generated = "generated" - TableName_ServerUUID = "serveruuid" - TableName_EnrollmentKey = "enrollmentkeys" + TableName_ServerConf = "serverconf" + TableName_Generated = "generated" + TableName_ServerUUID = "serveruuid" + TableName_EnrollmentKey = "enrollmentkeys" + TableName_ServerSettings = "server_settings" +) + +const ( + Key_ServerSettings = "server_cfg" ) func migrateV1_7_0(ctx context.Context) error { @@ -37,7 +44,12 @@ func migrateV1_7_0(ctx context.Context) error { return err } - return migrateEnrollmentKeys(ctx) + err = migrateEnrollmentKeys(ctx) + if err != nil { + return err + } + + return migrateUserSettings(ctx) } func migrateServerConf(ctx context.Context) error { @@ -297,3 +309,53 @@ func migrateEnrollmentKeys(ctx context.Context) error { return nil } + +func migrateUserSettings(ctx context.Context) error { + if !db.FromContext(ctx).Migrator().HasTable(TableName_ServerSettings) { + return nil + } + + records, err := kvList(ctx, TableName_ServerSettings) + if err != nil && !database.IsEmptyRecord(err) { + return err + } + + for key, record := range records { + if key == Key_ServerSettings { + continue + } + + user := &schema.User{ + Username: key, + } + err = user.Get(ctx) + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + err = kvDelete(ctx, TableName_ServerSettings, key) + if err != nil { + return err + } + + continue + } + + return err + } + + var userSettings models.UserSettings + err = json.Unmarshal([]byte(record), &userSettings) + if err != nil { + return err + } + + user.Theme = userSettings.Theme + user.TextSize = userSettings.TextSize + user.ReducedMotion = userSettings.ReducedMotion + err = user.UpdateUserSettings(ctx) + if err != nil { + return err + } + } + + return nil +} diff --git a/migrate/utils.go b/migrate/utils.go index d2ea12c8b..07204674b 100644 --- a/migrate/utils.go +++ b/migrate/utils.go @@ -35,3 +35,10 @@ func kvList(ctx context.Context, tableName string) (map[string]string, error) { return list, nil } + +func kvDelete(ctx context.Context, tableName, key string) error { + return db.FromContext(ctx).Table(tableName). + Where("key = ?", key). + Delete(&KVRecord{}). + Error +} From 1f0354e417588bbe45ea1477242b201613724bd8 Mon Sep 17 00:00:00 2001 From: VishalDalwadi Date: Fri, 19 Jun 2026 13:32:58 +0530 Subject: [PATCH 14/21] feat(go): remove usage of host actions table; --- auth/host_session.go | 10 +----- database/database.go | 3 -- logic/hostactions/hostactions.go | 56 -------------------------------- mq/handlers.go | 25 +++----------- 4 files changed, 6 insertions(+), 88 deletions(-) delete mode 100644 logic/hostactions/hostactions.go diff --git a/auth/host_session.go b/auth/host_session.go index c8bd9f38b..9650d9bf3 100644 --- a/auth/host_session.go +++ b/auth/host_session.go @@ -12,7 +12,6 @@ import ( "github.com/gravitl/netmaker/db" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" - "github.com/gravitl/netmaker/logic/hostactions" "github.com/gravitl/netmaker/logic/pro/netcache" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/mq" @@ -287,7 +286,7 @@ func CheckNetRegAndHostUpdate(key schema.EnrollmentKey, host *schema.Host, usern continue } - node, err := orchestrator.GetRepository().NodeOrchestrator().CreateNode( + _, err := orchestrator.GetRepository().NodeOrchestrator().CreateNode( db.WithContext(context.TODO()), host, network, @@ -298,13 +297,6 @@ func CheckNetRegAndHostUpdate(key schema.EnrollmentKey, host *schema.Host, usern if err != nil { logger.Log(0, fmt.Sprintf("failed to add host (%s, %s) to network (%s): %v", host.ID.String(), host.Name, netID, err.Error())) } else { - newNode := logic.ConvertSchemaNodeToModelsNode(node) - hostactions.AddAction(models.HostUpdate{ - Action: models.JoinHostToNetwork, - Host: *host, - Node: *newNode, - }) - if len(username) > 0 { logic.LogEvent(&models.Event{ Action: schema.JoinHostToNet, diff --git a/database/database.go b/database/database.go index 89a5ac16b..8f6d7055d 100644 --- a/database/database.go +++ b/database/database.go @@ -22,8 +22,6 @@ const ( METRICS_TABLE_NAME = "metrics" // CACHE_TABLE_NAME - caching table CACHE_TABLE_NAME = "cache" - // HOST_ACTIONS_TABLE_NAME - table name for enrollmentkeys - HOST_ACTIONS_TABLE_NAME = "hostactions" // TAG_TABLE_NAME - table for tags TAG_TABLE_NAME = "tags" // SERVER_SETTINGS - table for server settings @@ -61,7 +59,6 @@ var Tables = []string{ SSO_STATE_CACHE, METRICS_TABLE_NAME, CACHE_TABLE_NAME, - HOST_ACTIONS_TABLE_NAME, TAG_TABLE_NAME, ACLS_TABLE_NAME, SERVER_SETTINGS, diff --git a/logic/hostactions/hostactions.go b/logic/hostactions/hostactions.go deleted file mode 100644 index dcafdcd17..000000000 --- a/logic/hostactions/hostactions.go +++ /dev/null @@ -1,56 +0,0 @@ -package hostactions - -import ( - "encoding/json" - - "github.com/gravitl/netmaker/database" - "github.com/gravitl/netmaker/models" -) - -// AddAction - adds a host action to a host's list to be retrieved from broker update -func AddAction(hu models.HostUpdate) { - hostID := hu.Host.ID.String() - currentRecords, err := database.FetchRecord(database.HOST_ACTIONS_TABLE_NAME, hostID) - if err != nil { - if database.IsEmptyRecord(err) { // no list exists yet - newEntry, err := json.Marshal([]models.HostUpdate{hu}) - if err != nil { - return - } - _ = database.Insert(hostID, string(newEntry), database.HOST_ACTIONS_TABLE_NAME) - } - return - } - var currentList []models.HostUpdate - if err := json.Unmarshal([]byte(currentRecords), ¤tList); err != nil { - return - } - currentList = append(currentList, hu) - newData, err := json.Marshal(currentList) - if err != nil { - return - } - _ = database.Insert(hostID, string(newData), database.HOST_ACTIONS_TABLE_NAME) -} - -// GetAction - gets an action if exists -func GetAction(id string) *models.HostUpdate { - currentRecords, err := database.FetchRecord(database.HOST_ACTIONS_TABLE_NAME, id) - if err != nil { - return nil - } - var currentList []models.HostUpdate - if err = json.Unmarshal([]byte(currentRecords), ¤tList); err != nil { - return nil - } - if len(currentList) > 0 { - hu := currentList[0] - newData, err := json.Marshal(currentList[1:]) - if err != nil { - newData, _ = json.Marshal([]models.HostUpdate{}) - } - _ = database.Insert(id, string(newData), database.HOST_ACTIONS_TABLE_NAME) - return &hu - } - return nil -} diff --git a/mq/handlers.go b/mq/handlers.go index 0d18045d6..5e2b562e5 100644 --- a/mq/handlers.go +++ b/mq/handlers.go @@ -12,7 +12,6 @@ import ( "github.com/gravitl/netmaker/db" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" - "github.com/gravitl/netmaker/logic/hostactions" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/netclient/ncutils" "github.com/gravitl/netmaker/schema" @@ -117,25 +116,11 @@ func UpdateHost(client mqtt.Client, msg mqtt.Message) { if err != nil { return } - hu := hostactions.GetAction(currentHost.ID.String()) - if hu != nil { - if err = HostUpdate(hu); err != nil { - slog.Error("failed to send new node to host", "name", hostUpdate.Host.Name, "id", currentHost.ID, "error", err) - return - } else { - - if err = PublishSingleHostPeerUpdate(currentHost, nodes, nil, nil, nil, false, nil); err != nil { - slog.Error("failed peers publish after join acknowledged", "name", hostUpdate.Host.Name, "id", currentHost.ID, "error", err) - return - } - } - } else { - // send latest host update - HostUpdate(&models.HostUpdate{ - Action: models.UpdateHost, - Host: *currentHost}) - PublishSingleHostPeerUpdate(currentHost, nodes, nil, nil, nil, false, nil) - } + // send latest host update + HostUpdate(&models.HostUpdate{ + Action: models.UpdateHost, + Host: *currentHost}) + PublishSingleHostPeerUpdate(currentHost, nodes, nil, nil, nil, false, nil) case models.UpdateHost: if hostUpdate.Host.PublicKey != currentHost.PublicKey { //remove old peer entry From 8816f8d4dc081fac74539bf57894f1d2f3ca9a60 Mon Sep 17 00:00:00 2001 From: VishalDalwadi Date: Fri, 19 Jun 2026 13:50:58 +0530 Subject: [PATCH 15/21] feat(go): add schema definition for tenant settings; --- schema/server_settings.go | 80 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) create mode 100644 schema/server_settings.go diff --git a/schema/server_settings.go b/schema/server_settings.go new file mode 100644 index 000000000..7e9afc070 --- /dev/null +++ b/schema/server_settings.go @@ -0,0 +1,80 @@ +package schema + +import ( + "context" + + "github.com/gravitl/netmaker/db" + "gorm.io/datatypes" +) + +type ServerSettings struct { + Key string `gorm:"primaryKey"` + Value datatypes.JSONType[ServerSettingsData] +} + +type ServerSettingsData struct { + NetclientAutoUpdate bool `json:"netclientautoupdate"` + Verbosity int32 `json:"verbosity"` + AuthProvider string `json:"authprovider"` + OIDCIssuer string `json:"oidcissuer"` + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` + SyncEnabled bool `json:"sync_enabled"` + GoogleAdminEmail string `json:"google_admin_email"` + GoogleSACredsJson string `json:"google_sa_creds_json"` + AzureTenant string `json:"azure_tenant"` + OktaOrgURL string `json:"okta_org_url"` + OktaAPIToken string `json:"okta_api_token"` + UserFilters []string `json:"user_filters"` + GroupFilters []string `json:"group_filters"` + IDPSyncInterval string `json:"idp_sync_interval"` + Telemetry string `json:"telemetry"` + BasicAuth bool `json:"basic_auth"` + // JwtValidityDuration is the validity duration of auth tokens for users + // on the dashboard (NMUI). + JwtValidityDuration int `json:"jwt_validity_duration"` + // JwtValidityDurationClients is the validity duration of auth tokens for + // users on the clients (NetDesk). + JwtValidityDurationClients int `json:"jwt_validity_duration_clients"` + MFAEnforced bool `json:"mfa_enforced"` + RacRestrictToSingleNetwork bool `json:"rac_restrict_to_single_network"` + EndpointDetection bool `json:"endpoint_detection"` + AllowedEmailDomains string `json:"allowed_email_domains"` + EmailSenderAddr string `json:"email_sender_addr"` + EmailSenderUser string `json:"email_sender_user"` + EmailSenderPassword string `json:"email_sender_password"` + SmtpHost string `json:"smtp_host"` + SmtpPort int `json:"smtp_port"` + MetricInterval string `json:"metric_interval"` + MetricsPort int `json:"metrics_port"` + // IPDetectionInterval is the interval (in seconds) at which devices check for changes in public ip. + IPDetectionInterval int `json:"ip_detection_interval"` + ManageDNS bool `json:"manage_dns"` + DefaultDomain string `json:"default_domain"` + Stun bool `json:"stun"` + StunServers string `json:"stun_servers"` + AuditLogsRetentionPeriodInDays int `json:"audit_logs_retention_period"` + PeerConnectionCheckInterval string `json:"peer_connection_check_interval"` + PostureCheckInterval string `json:"posture_check_interval"` // in minutes + CleanUpInterval int `json:"clean_up_interval_in_mins"` + EnableFlowLogs bool `json:"enable_flow_logs"` +} + +func (s *ServerSettings) TableName() string { + return "server_settings" +} + +func (s *ServerSettings) Upsert(ctx context.Context) error { + return db.FromContext(ctx).Save(&s).Error +} + +func (s *ServerSettings) Get(ctx context.Context) error { + return db.FromContext(ctx).Model(&ServerSettings{}). + Where("key = ?", s.Key). + First(&s). + Error +} + +func (s *ServerSettings) Delete(ctx context.Context) error { + return db.FromContext(ctx).Delete(&s).Error +} From a494009cd03d60aec4839466116251ae10c8b57f Mon Sep 17 00:00:00 2001 From: VishalDalwadi Date: Mon, 22 Jun 2026 11:34:38 +0530 Subject: [PATCH 16/21] feat(go): move scope to it's own package; --- controllers/acls.go | 15 ++--- controllers/dns.go | 60 ++++++++++---------- controllers/egress.go | 11 ++-- controllers/enrollmentkeys.go | 16 +++--- controllers/ext_client.go | 19 ++++--- controllers/gateway.go | 13 +++-- controllers/hosts.go | 49 ++++++++-------- controllers/inet_gws.go | 8 +-- controllers/network.go | 15 ++--- controllers/node.go | 23 ++++---- controllers/server.go | 19 ++++--- controllers/user.go | 45 +++++++-------- db/scope.go | 52 ----------------- pro/controllers/auto_relay.go | 15 ++--- pro/controllers/events.go | 8 +-- pro/controllers/flows.go | 5 +- pro/controllers/integrations.go | 11 ++-- pro/controllers/jit.go | 10 ++-- pro/controllers/metrics.go | 11 ++-- pro/controllers/networks.go | 5 +- pro/controllers/posture_check.go | 14 ++--- pro/controllers/rac.go | 9 ++- pro/controllers/tags.go | 10 ++-- pro/controllers/users.go | 72 ++++++++++++------------ {controllers => scope}/scope.go | 97 ++++++++++++++++++++++++++++---- 25 files changed, 318 insertions(+), 294 deletions(-) delete mode 100644 db/scope.go rename {controllers => scope}/scope.go (52%) diff --git a/controllers/acls.go b/controllers/acls.go index 2f43aa699..12a87537b 100644 --- a/controllers/acls.go +++ b/controllers/acls.go @@ -15,22 +15,23 @@ import ( "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/mq" "github.com/gravitl/netmaker/schema" + "github.com/gravitl/netmaker/scope" ) func aclHandlers(r *mux.Router) { - r.HandleFunc("/api/v1/acls", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getAcls)))). + r.HandleFunc("/api/v1/acls", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getAcls)))). Methods(http.MethodGet) - r.HandleFunc("/api/v1/acls/egress", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getEgressAcls)))). + r.HandleFunc("/api/v1/acls/egress", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getEgressAcls)))). Methods(http.MethodGet) - r.HandleFunc("/api/v1/acls/policy_types", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(aclPolicyTypes)))). + r.HandleFunc("/api/v1/acls/policy_types", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(aclPolicyTypes)))). Methods(http.MethodGet) - r.HandleFunc("/api/v1/acls", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(createAcl)))). + r.HandleFunc("/api/v1/acls", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(createAcl)))). Methods(http.MethodPost) - r.HandleFunc("/api/v1/acls", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(updateAcl)))). + r.HandleFunc("/api/v1/acls", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(updateAcl)))). Methods(http.MethodPut) - r.HandleFunc("/api/v1/acls", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteAcl)))). + r.HandleFunc("/api/v1/acls", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteAcl)))). Methods(http.MethodDelete) - r.HandleFunc("/api/v1/acls/debug", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(aclDebug)))). + r.HandleFunc("/api/v1/acls/debug", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(aclDebug)))). Methods(http.MethodGet) } diff --git a/controllers/dns.go b/controllers/dns.go index 4a3273e41..0f7eef7ce 100644 --- a/controllers/dns.go +++ b/controllers/dns.go @@ -11,40 +11,40 @@ import ( "github.com/google/uuid" "github.com/gorilla/mux" - "github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/db" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/mq" "github.com/gravitl/netmaker/schema" + "github.com/gravitl/netmaker/scope" "github.com/gravitl/netmaker/servercfg" "gorm.io/datatypes" ) func dnsHandlers(r *mux.Router) { - r.HandleFunc("/api/dns", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getAllDNS)))). + r.HandleFunc("/api/dns", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getAllDNS)))). Methods(http.MethodGet) - r.HandleFunc("/api/dns/adm/{network}/nodes", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getNodeDNS)))). + r.HandleFunc("/api/dns/adm/{network}/nodes", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getNodeDNS)))). Methods(http.MethodGet) - r.HandleFunc("/api/dns/adm/{network}/custom", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getCustomDNS)))). + r.HandleFunc("/api/dns/adm/{network}/custom", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getCustomDNS)))). Methods(http.MethodGet) - r.HandleFunc("/api/dns/adm/{network}", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getDNS)))). + r.HandleFunc("/api/dns/adm/{network}", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getDNS)))). Methods(http.MethodGet) - r.HandleFunc("/api/dns/adm/{network}/sync", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(syncDNS)))). + r.HandleFunc("/api/dns/adm/{network}/sync", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(syncDNS)))). Methods(http.MethodPost) - r.HandleFunc("/api/dns/{network}", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(createDNS)))). + r.HandleFunc("/api/dns/{network}", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(createDNS)))). Methods(http.MethodPost) - r.HandleFunc("/api/dns/adm/pushdns", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(pushDNS)))). + r.HandleFunc("/api/dns/adm/pushdns", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(pushDNS)))). Methods(http.MethodPost) - r.HandleFunc("/api/dns/{network}/{domain}", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteDNS)))). + r.HandleFunc("/api/dns/{network}/{domain}", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteDNS)))). Methods(http.MethodDelete) - r.HandleFunc("/api/v1/nameserver", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(createNs)))).Methods(http.MethodPost) - r.HandleFunc("/api/v1/nameserver", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(listNs)))).Methods(http.MethodGet) - r.HandleFunc("/api/v1/nameserver", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(updateNs)))).Methods(http.MethodPut) - r.HandleFunc("/api/v1/nameserver", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteNs)))).Methods(http.MethodDelete) - r.HandleFunc("/api/v1/nameserver/global", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getGlobalNs)))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/nameserver", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(createNs)))).Methods(http.MethodPost) + r.HandleFunc("/api/v1/nameserver", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(listNs)))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/nameserver", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(updateNs)))).Methods(http.MethodPut) + r.HandleFunc("/api/v1/nameserver", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteNs)))).Methods(http.MethodDelete) + r.HandleFunc("/api/v1/nameserver/global", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getGlobalNs)))).Methods(http.MethodGet) } // @Summary List Global Nameservers @@ -389,10 +389,10 @@ func getNodeDNS(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - var dns []models.DNSEntry + var dns []schema.DNSEntry var params = mux.Vars(r) network := params["network"] - dns, err := logic.GetNodeDNS(network) + dns, err := logic.GetNodeDNS(r.Context(), network) if err != nil { logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to get node DNS entries for network [%s]: %v", network, err)) @@ -412,7 +412,7 @@ func getNodeDNS(w http.ResponseWriter, r *http.Request) { // @Failure 500 {object} models.ErrorResponse func getAllDNS(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - dns, err := logic.GetAllDNS() + dns, err := logic.GetAllDNS(r.Context()) if err != nil { logger.Log(0, r.Header.Get("user"), "failed to get all DNS entries: ", err.Error()) logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) @@ -435,10 +435,10 @@ func getCustomDNS(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - var dns []models.DNSEntry + var dns []schema.DNSEntry var params = mux.Vars(r) network := params["network"] - dns, err := logic.GetCustomDNS(network) + dns, err := logic.GetCustomDNS(r.Context(), network) if err != nil { logger.Log( 0, @@ -468,10 +468,10 @@ func getDNS(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - var dns []models.DNSEntry + var dns []schema.DNSEntry var params = mux.Vars(r) network := params["network"] - dns, err := logic.GetDNS(network) + dns, err := logic.GetDNS(r.Context(), network) if err != nil { logger.Log(0, r.Header.Get("user"), fmt.Sprintf("failed to get all DNS entries for network [%s]: %v", network, err.Error())) @@ -496,7 +496,7 @@ func getDNS(w http.ResponseWriter, r *http.Request) { func createDNS(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - var entry models.DNSEntry + var entry schema.DNSEntry var params = mux.Vars(r) netID := params["network"] @@ -566,18 +566,16 @@ func deleteDNS(w http.ResponseWriter, r *http.Request) { } // GetDNSEntry - gets a DNS entry -func GetDNSEntry(domain string, network string) (models.DNSEntry, error) { - var entry models.DNSEntry +func GetDNSEntry(domain string, network string) (schema.DNSEntry, error) { key, err := logic.GetRecordKey(domain, network) if err != nil { - return entry, err + return schema.DNSEntry{}, err } - record, err := database.FetchRecord(database.DNS_TABLE_NAME, key) - if err != nil { - return entry, err + d := &schema.DNS{Key: key} + if err = d.Get(db.WithContext(context.Background())); err != nil { + return schema.DNSEntry{}, err } - err = json.Unmarshal([]byte(record), &entry) - return entry, err + return d.Value.Data(), nil } // @Summary Push DNS entries to nameserver @@ -628,7 +626,7 @@ func syncDNS(w http.ResponseWriter, r *http.Request) { } var params = mux.Vars(r) netID := params["network"] - k, err := logic.GetDNS(netID) + k, err := logic.GetDNS(r.Context(), netID) if err == nil && len(k) > 0 { err = mq.PushSyncDNS(k) } diff --git a/controllers/egress.go b/controllers/egress.go index fa4b873d0..a66f117c9 100644 --- a/controllers/egress.go +++ b/controllers/egress.go @@ -15,16 +15,17 @@ import ( "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/mq" "github.com/gravitl/netmaker/schema" + "github.com/gravitl/netmaker/scope" "github.com/gravitl/netmaker/servercfg" "gorm.io/datatypes" ) func egressHandlers(r *mux.Router) { - r.HandleFunc("/api/v1/egress/presets", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getEgressPresets)))).Methods(http.MethodGet) - r.HandleFunc("/api/v1/egress", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(createEgress)))).Methods(http.MethodPost) - r.HandleFunc("/api/v1/egress", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(listEgress)))).Methods(http.MethodGet) - r.HandleFunc("/api/v1/egress", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(updateEgress)))).Methods(http.MethodPut) - r.HandleFunc("/api/v1/egress", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteEgress)))).Methods(http.MethodDelete) + r.HandleFunc("/api/v1/egress/presets", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getEgressPresets)))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/egress", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(createEgress)))).Methods(http.MethodPost) + r.HandleFunc("/api/v1/egress", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(listEgress)))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/egress", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(updateEgress)))).Methods(http.MethodPut) + r.HandleFunc("/api/v1/egress", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteEgress)))).Methods(http.MethodDelete) } // @Summary List egress domain presets diff --git a/controllers/enrollmentkeys.go b/controllers/enrollmentkeys.go index 17c42da18..22e48118e 100644 --- a/controllers/enrollmentkeys.go +++ b/controllers/enrollmentkeys.go @@ -12,10 +12,10 @@ import ( "github.com/google/uuid" "github.com/gorilla/mux" "github.com/gravitl/netmaker/schema" + "github.com/gravitl/netmaker/scope" "golang.org/x/exp/slog" "github.com/gravitl/netmaker/auth" - "github.com/gravitl/netmaker/db" dbtypes "github.com/gravitl/netmaker/db/types" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" @@ -25,21 +25,21 @@ import ( ) func enrollmentKeyHandlers(r *mux.Router) { - r.HandleFunc("/api/v1/enrollment-keys", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(createEnrollmentKey)))). + r.HandleFunc("/api/v1/enrollment-keys", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(createEnrollmentKey)))). Methods(http.MethodPost) - r.HandleFunc("/api/v1/enrollment-keys", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getEnrollmentKeys)))). + r.HandleFunc("/api/v1/enrollment-keys", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getEnrollmentKeys)))). Methods(http.MethodGet) - r.HandleFunc("/api/v2/enrollment-keys", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(listEnrollmentKeys)))). + r.HandleFunc("/api/v2/enrollment-keys", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(listEnrollmentKeys)))). Methods(http.MethodGet) - r.HandleFunc("/api/v1/enrollment-keys/network/{network}/default", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getDefaultEnrollmentKeyForNetwork)))). + r.HandleFunc("/api/v1/enrollment-keys/network/{network}/default", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getDefaultEnrollmentKeyForNetwork)))). Methods(http.MethodGet) - r.HandleFunc("/api/v1/enrollment-keys/{keyID}/regenerate-token", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(regenerateEnrollmentKeyToken)))). + r.HandleFunc("/api/v1/enrollment-keys/{keyID}/regenerate-token", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(regenerateEnrollmentKeyToken)))). Methods(http.MethodPost) - r.HandleFunc("/api/v1/enrollment-keys/{keyID}", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteEnrollmentKey)))). + r.HandleFunc("/api/v1/enrollment-keys/{keyID}", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteEnrollmentKey)))). Methods(http.MethodDelete) r.HandleFunc("/api/v1/host/register/{token}", http.HandlerFunc(handleHostRegister)). Methods(http.MethodPost) - r.HandleFunc("/api/v1/enrollment-keys/{keyID}", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(updateEnrollmentKey)))). + r.HandleFunc("/api/v1/enrollment-keys/{keyID}", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(updateEnrollmentKey)))). Methods(http.MethodPut) } diff --git a/controllers/ext_client.go b/controllers/ext_client.go index a72e089f6..c06c929ae 100644 --- a/controllers/ext_client.go +++ b/controllers/ext_client.go @@ -20,6 +20,7 @@ import ( "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/orchestrator" "github.com/gravitl/netmaker/schema" + "github.com/gravitl/netmaker/scope" "github.com/gravitl/netmaker/mq" "github.com/skip2/go-qrcode" @@ -32,23 +33,23 @@ var extUpdateMutex = &sync.Mutex{} func extClientHandlers(r *mux.Router) { - r.HandleFunc("/api/extclients", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getAllExtClients)))). + r.HandleFunc("/api/extclients", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getAllExtClients)))). Methods(http.MethodGet) - r.HandleFunc("/api/extclients/{network}", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getNetworkExtClients)))). + r.HandleFunc("/api/extclients/{network}", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getNetworkExtClients)))). Methods(http.MethodGet) - r.HandleFunc("/api/extclients/{network}/{clientid}", Scope(db.TenantScope, logic.SecurityCheck(false, http.HandlerFunc(getExtClient)))). + r.HandleFunc("/api/extclients/{network}/{clientid}", scope.Middleware(scope.TenantScope, logic.SecurityCheck(false, http.HandlerFunc(getExtClient)))). Methods(http.MethodGet) - r.HandleFunc("/api/extclients/{network}/{clientid}/{type}", Scope(db.TenantScope, logic.SecurityCheck(false, http.HandlerFunc(getExtClientConf)))). + r.HandleFunc("/api/extclients/{network}/{clientid}/{type}", scope.Middleware(scope.TenantScope, logic.SecurityCheck(false, http.HandlerFunc(getExtClientConf)))). Methods(http.MethodGet) - r.HandleFunc("/api/extclients/{network}/{clientid}", Scope(db.TenantScope, logic.SecurityCheck(false, http.HandlerFunc(updateExtClient)))). + r.HandleFunc("/api/extclients/{network}/{clientid}", scope.Middleware(scope.TenantScope, logic.SecurityCheck(false, http.HandlerFunc(updateExtClient)))). Methods(http.MethodPut) - r.HandleFunc("/api/extclients/{network}/{clientid}", Scope(db.TenantScope, logic.SecurityCheck(false, http.HandlerFunc(deleteExtClient)))). + r.HandleFunc("/api/extclients/{network}/{clientid}", scope.Middleware(scope.TenantScope, logic.SecurityCheck(false, http.HandlerFunc(deleteExtClient)))). Methods(http.MethodDelete) - r.HandleFunc("/api/v1/extclients/{network}/bulk", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(bulkDeleteExtClients)))). + r.HandleFunc("/api/v1/extclients/{network}/bulk", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(bulkDeleteExtClients)))). Methods(http.MethodDelete) - r.HandleFunc("/api/v1/extclients/{network}/bulk/status", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(bulkUpdateExtClientStatus)))). + r.HandleFunc("/api/v1/extclients/{network}/bulk/status", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(bulkUpdateExtClientStatus)))). Methods(http.MethodPut) - r.HandleFunc("/api/extclients/{network}/{nodeid}", Scope(db.TenantScope, logic.SecurityCheck(false, http.HandlerFunc(createExtClient)))). + r.HandleFunc("/api/extclients/{network}/{nodeid}", scope.Middleware(scope.TenantScope, logic.SecurityCheck(false, http.HandlerFunc(createExtClient)))). Methods(http.MethodPost) // unused API //r.HandleFunc("/api/v1/client_conf/{network}", logic.SecurityCheck(false, http.HandlerFunc(getExtClientHAConf))).Methods(http.MethodGet) diff --git a/controllers/gateway.go b/controllers/gateway.go index 5b5c4fdd6..e2c1e4bb4 100644 --- a/controllers/gateway.go +++ b/controllers/gateway.go @@ -16,6 +16,7 @@ import ( "github.com/gravitl/netmaker/mq" "github.com/gravitl/netmaker/orchestrator" "github.com/gravitl/netmaker/schema" + "github.com/gravitl/netmaker/scope" "github.com/gravitl/netmaker/servercfg" "golang.org/x/exp/slog" "gorm.io/datatypes" @@ -23,13 +24,13 @@ import ( ) func gwHandlers(r *mux.Router) { - r.HandleFunc("/api/nodes/{network}/{nodeid}/gateway", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(createGateway)))).Methods(http.MethodPost) - r.HandleFunc("/api/nodes/{network}/{nodeid}/gateway", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteGateway)))).Methods(http.MethodDelete) - r.HandleFunc("/api/nodes/{network}/{nodeid}/gateway/assign", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(assignGw)))).Methods(http.MethodPost) - r.HandleFunc("/api/nodes/{network}/{nodeid}/gateway/unassign", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(unassignGw)))).Methods(http.MethodPost) + r.HandleFunc("/api/nodes/{network}/{nodeid}/gateway", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(createGateway)))).Methods(http.MethodPost) + r.HandleFunc("/api/nodes/{network}/{nodeid}/gateway", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteGateway)))).Methods(http.MethodDelete) + r.HandleFunc("/api/nodes/{network}/{nodeid}/gateway/assign", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(assignGw)))).Methods(http.MethodPost) + r.HandleFunc("/api/nodes/{network}/{nodeid}/gateway/unassign", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(unassignGw)))).Methods(http.MethodPost) // old relay handlers - r.HandleFunc("/api/nodes/{network}/{nodeid}/createrelay", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(createGateway)))).Methods(http.MethodPost) - r.HandleFunc("/api/nodes/{network}/{nodeid}/deleterelay", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteGateway)))).Methods(http.MethodDelete) + r.HandleFunc("/api/nodes/{network}/{nodeid}/createrelay", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(createGateway)))).Methods(http.MethodPost) + r.HandleFunc("/api/nodes/{network}/{nodeid}/deleterelay", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteGateway)))).Methods(http.MethodDelete) } // @Summary Create a gateway diff --git a/controllers/hosts.go b/controllers/hosts.go index 6ffbd2c30..2adf37445 100644 --- a/controllers/hosts.go +++ b/controllers/hosts.go @@ -22,6 +22,7 @@ import ( "github.com/gravitl/netmaker/mq" "github.com/gravitl/netmaker/orchestrator" "github.com/gravitl/netmaker/schema" + "github.com/gravitl/netmaker/scope" "github.com/gravitl/netmaker/servercfg" "golang.org/x/crypto/bcrypt" "golang.org/x/exp/slog" @@ -29,54 +30,54 @@ import ( ) func hostHandlers(r *mux.Router) { - r.HandleFunc("/api/hosts", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getHosts)))). + r.HandleFunc("/api/hosts", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getHosts)))). Methods(http.MethodGet) - r.HandleFunc("/api/v1/hosts", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(listHosts)))). + r.HandleFunc("/api/v1/hosts", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(listHosts)))). Methods(http.MethodGet) - r.HandleFunc("/api/hosts/keys", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(updateAllKeys)))). + r.HandleFunc("/api/hosts/keys", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(updateAllKeys)))). Methods(http.MethodPut) - r.HandleFunc("/api/hosts/sync", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(syncHosts)))). + r.HandleFunc("/api/hosts/sync", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(syncHosts)))). Methods(http.MethodPost) - r.HandleFunc("/api/hosts/upgrade", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(upgradeHosts)))). + r.HandleFunc("/api/hosts/upgrade", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(upgradeHosts)))). Methods(http.MethodPost) - r.HandleFunc("/api/hosts/{hostid}/keys", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(updateKeys)))). + r.HandleFunc("/api/hosts/{hostid}/keys", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(updateKeys)))). Methods(http.MethodPut) - r.HandleFunc("/api/hosts/{hostid}/sync", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(syncHost)))). + r.HandleFunc("/api/hosts/{hostid}/sync", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(syncHost)))). Methods(http.MethodPost) - r.HandleFunc("/api/hosts/{hostid}", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(updateHost)))). + r.HandleFunc("/api/hosts/{hostid}", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(updateHost)))). Methods(http.MethodPut) - r.HandleFunc("/api/hosts/{hostid}", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getHost)))). + r.HandleFunc("/api/hosts/{hostid}", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getHost)))). Methods(http.MethodGet) // used by netclient - r.HandleFunc("/api/hosts/{hostid}", Scope(db.TenantScope, AuthorizeHost(http.HandlerFunc(deleteHost)))). + r.HandleFunc("/api/hosts/{hostid}", scope.Middleware(scope.TenantScope, AuthorizeHost(http.HandlerFunc(deleteHost)))). Methods(http.MethodDelete) // used by UI - r.HandleFunc("/api/v1/ui/hosts/{hostid}", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteHost)))). + r.HandleFunc("/api/v1/ui/hosts/{hostid}", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteHost)))). Methods(http.MethodDelete) - r.HandleFunc("/api/v1/hosts/bulk", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(bulkDeleteHosts)))). + r.HandleFunc("/api/v1/hosts/bulk", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(bulkDeleteHosts)))). Methods(http.MethodDelete) - r.HandleFunc("/api/hosts/{hostid}/upgrade", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(upgradeHost)))). + r.HandleFunc("/api/hosts/{hostid}/upgrade", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(upgradeHost)))). Methods(http.MethodPut) - r.HandleFunc("/api/hosts/{hostid}/networks/{network}", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(addHostToNetwork)))). + r.HandleFunc("/api/hosts/{hostid}/networks/{network}", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(addHostToNetwork)))). Methods(http.MethodPost) - r.HandleFunc("/api/hosts/{hostid}/networks/{network}", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteHostFromNetwork)))). + r.HandleFunc("/api/hosts/{hostid}/networks/{network}", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteHostFromNetwork)))). Methods(http.MethodDelete) - r.HandleFunc("/api/hosts/adm/authenticate", Scope(db.TenantScope, http.HandlerFunc(authenticateHost))).Methods(http.MethodPost) - r.HandleFunc("/api/v1/host", Scope(db.TenantScope, AuthorizeHost(http.HandlerFunc(pull)))). + r.HandleFunc("/api/hosts/adm/authenticate", scope.Middleware(scope.TenantScope, http.HandlerFunc(authenticateHost))).Methods(http.MethodPost) + r.HandleFunc("/api/v1/host", scope.Middleware(scope.TenantScope, AuthorizeHost(http.HandlerFunc(pull)))). Methods(http.MethodGet) - r.HandleFunc("/api/v1/host/{hostid}/signalpeer", Scope(db.TenantScope, AuthorizeHost(http.HandlerFunc(signalPeer)))). + r.HandleFunc("/api/v1/host/{hostid}/signalpeer", scope.Middleware(scope.TenantScope, AuthorizeHost(http.HandlerFunc(signalPeer)))). Methods(http.MethodPost) - r.HandleFunc("/api/v1/fallback/host/{hostid}", Scope(db.TenantScope, AuthorizeHost(http.HandlerFunc(hostUpdateFallback)))). + r.HandleFunc("/api/v1/fallback/host/{hostid}", scope.Middleware(scope.TenantScope, AuthorizeHost(http.HandlerFunc(hostUpdateFallback)))). Methods(http.MethodPut) - r.HandleFunc("/api/v1/host/{hostid}/peer_info", Scope(db.TenantScope, AuthorizeHost(http.HandlerFunc(getHostPeerInfo)))). + r.HandleFunc("/api/v1/host/{hostid}/peer_info", scope.Middleware(scope.TenantScope, AuthorizeHost(http.HandlerFunc(getHostPeerInfo)))). Methods(http.MethodGet) - r.HandleFunc("/api/v1/pending_hosts", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getPendingHosts)))). + r.HandleFunc("/api/v1/pending_hosts", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getPendingHosts)))). Methods(http.MethodGet) - r.HandleFunc("/api/v1/pending_hosts/approve/{id}", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(approvePendingHost)))). + r.HandleFunc("/api/v1/pending_hosts/approve/{id}", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(approvePendingHost)))). Methods(http.MethodPost) - r.HandleFunc("/api/v1/pending_hosts/reject/{id}", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(rejectPendingHost)))). + r.HandleFunc("/api/v1/pending_hosts/reject/{id}", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(rejectPendingHost)))). Methods(http.MethodPost) - r.HandleFunc("/api/emqx/hosts", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(delEmqxHosts)))). + r.HandleFunc("/api/emqx/hosts", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(delEmqxHosts)))). Methods(http.MethodDelete) r.HandleFunc("/api/v1/auth-register/host", socketHandler) } diff --git a/controllers/inet_gws.go b/controllers/inet_gws.go index 7ffc2c28a..eba2888fb 100644 --- a/controllers/inet_gws.go +++ b/controllers/inet_gws.go @@ -6,21 +6,21 @@ import ( "net/http" "github.com/gorilla/mux" - "github.com/gravitl/netmaker/db" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/mq" "github.com/gravitl/netmaker/schema" + "github.com/gravitl/netmaker/scope" "github.com/gravitl/netmaker/servercfg" ) func internetGatewayHandlers(r *mux.Router) { - r.HandleFunc("/api/nodes/{network}/{nodeid}/inet_gw", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(createInternetGw)))). + r.HandleFunc("/api/nodes/{network}/{nodeid}/inet_gw", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(createInternetGw)))). Methods(http.MethodPost) - r.HandleFunc("/api/nodes/{network}/{nodeid}/inet_gw", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(updateInternetGw)))). + r.HandleFunc("/api/nodes/{network}/{nodeid}/inet_gw", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(updateInternetGw)))). Methods(http.MethodPut) - r.HandleFunc("/api/nodes/{network}/{nodeid}/inet_gw", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteInternetGw)))). + r.HandleFunc("/api/nodes/{network}/{nodeid}/inet_gw", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteInternetGw)))). Methods(http.MethodDelete) } diff --git a/controllers/network.go b/controllers/network.go index 62b43a430..9d313a24f 100644 --- a/controllers/network.go +++ b/controllers/network.go @@ -14,6 +14,7 @@ import ( dbtypes "github.com/gravitl/netmaker/db/types" "github.com/gravitl/netmaker/orchestrator" "github.com/gravitl/netmaker/schema" + "github.com/gravitl/netmaker/scope" "golang.org/x/exp/slog" "github.com/gravitl/netmaker/database" @@ -24,19 +25,19 @@ import ( ) func networkHandlers(r *mux.Router) { - r.HandleFunc("/api/networks", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getNetworks)))). + r.HandleFunc("/api/networks", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getNetworks)))). Methods(http.MethodGet) - r.HandleFunc("/api/v1/networks/stats", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getNetworksStats)))). + r.HandleFunc("/api/v1/networks/stats", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getNetworksStats)))). Methods(http.MethodGet) - r.HandleFunc("/api/networks", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(createNetwork)))). + r.HandleFunc("/api/networks", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(createNetwork)))). Methods(http.MethodPost) - r.HandleFunc("/api/networks/{networkname}", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getNetwork)))). + r.HandleFunc("/api/networks/{networkname}", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getNetwork)))). Methods(http.MethodGet) - r.HandleFunc("/api/networks/{networkname}", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteNetwork)))). + r.HandleFunc("/api/networks/{networkname}", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteNetwork)))). Methods(http.MethodDelete) - r.HandleFunc("/api/networks/{networkname}", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(updateNetwork)))). + r.HandleFunc("/api/networks/{networkname}", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(updateNetwork)))). Methods(http.MethodPut) - r.HandleFunc("/api/networks/{networkname}/egress_routes", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getNetworkEgressRoutes)))) + r.HandleFunc("/api/networks/{networkname}/egress_routes", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getNetworkEgressRoutes)))) } // @Summary Lists all networks diff --git a/controllers/node.go b/controllers/node.go index 8a60959bd..fdd68a41e 100644 --- a/controllers/node.go +++ b/controllers/node.go @@ -21,6 +21,7 @@ import ( "github.com/gravitl/netmaker/mq" "github.com/gravitl/netmaker/orchestrator" "github.com/gravitl/netmaker/schema" + "github.com/gravitl/netmaker/scope" "github.com/gravitl/netmaker/servercfg" "golang.org/x/crypto/bcrypt" "golang.org/x/exp/slog" @@ -31,20 +32,20 @@ var hostIDHeader = "host-id" func nodeHandlers(r *mux.Router) { - r.HandleFunc("/api/nodes", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getAllNodes)))).Methods(http.MethodGet) - r.HandleFunc("/api/nodes/{network}", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getNetworkNodes)))).Methods(http.MethodGet) - r.HandleFunc("/api/v1/nodes/{network}", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(listNetworkNodes)))).Methods(http.MethodGet) + r.HandleFunc("/api/nodes", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getAllNodes)))).Methods(http.MethodGet) + r.HandleFunc("/api/nodes/{network}", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getNetworkNodes)))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/nodes/{network}", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(listNetworkNodes)))).Methods(http.MethodGet) r.HandleFunc("/api/nodes/{network}/{nodeid}", AuthorizeHost(http.HandlerFunc(getNode))).Methods(http.MethodGet) - r.HandleFunc("/api/nodes/{network}/{nodeid}", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(updateNode)))).Methods(http.MethodPut) + r.HandleFunc("/api/nodes/{network}/{nodeid}", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(updateNode)))).Methods(http.MethodPut) r.HandleFunc("/api/nodes/{network}/{nodeid}", AuthorizeHost(http.HandlerFunc(deleteNode))).Methods(http.MethodDelete) - r.HandleFunc("/api/nodes/{network}/{nodeid}/creategateway", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(createEgressGateway)))).Methods(http.MethodPost) - r.HandleFunc("/api/nodes/{network}/{nodeid}/deletegateway", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteEgressGateway)))).Methods(http.MethodDelete) - r.HandleFunc("/api/nodes/{network}/{nodeid}/createingress", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(createGateway)))).Methods(http.MethodPost) - r.HandleFunc("/api/nodes/{network}/{nodeid}/deleteingress", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteGateway)))).Methods(http.MethodDelete) + r.HandleFunc("/api/nodes/{network}/{nodeid}/creategateway", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(createEgressGateway)))).Methods(http.MethodPost) + r.HandleFunc("/api/nodes/{network}/{nodeid}/deletegateway", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteEgressGateway)))).Methods(http.MethodDelete) + r.HandleFunc("/api/nodes/{network}/{nodeid}/createingress", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(createGateway)))).Methods(http.MethodPost) + r.HandleFunc("/api/nodes/{network}/{nodeid}/deleteingress", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteGateway)))).Methods(http.MethodDelete) r.HandleFunc("/api/nodes/adm/{network}/authenticate", authenticate).Methods(http.MethodPost) - r.HandleFunc("/api/v1/nodes/{network}/bulk", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(bulkDeleteNodes)))).Methods(http.MethodDelete) - r.HandleFunc("/api/v1/nodes/{network}/bulk/status", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(bulkUpdateNodeStatus)))).Methods(http.MethodPut) - r.HandleFunc("/api/v1/nodes/{network}/status", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getNetworkNodeStatus)))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/nodes/{network}/bulk", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(bulkDeleteNodes)))).Methods(http.MethodDelete) + r.HandleFunc("/api/v1/nodes/{network}/bulk/status", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(bulkUpdateNodeStatus)))).Methods(http.MethodPut) + r.HandleFunc("/api/v1/nodes/{network}/status", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getNetworkNodeStatus)))).Methods(http.MethodGet) } func authenticate(response http.ResponseWriter, request *http.Request) { diff --git a/controllers/server.go b/controllers/server.go index 8e9804757..8e84469ce 100644 --- a/controllers/server.go +++ b/controllers/server.go @@ -15,6 +15,7 @@ import ( ch "github.com/gravitl/netmaker/clickhouse" "github.com/gravitl/netmaker/db" "github.com/gravitl/netmaker/schema" + "github.com/gravitl/netmaker/scope" "golang.org/x/exp/slog" "github.com/gravitl/netmaker/database" @@ -36,23 +37,23 @@ func serverHandlers(r *mux.Router) { resp.Write([]byte("Server is up and running!!")) }, ).Methods(http.MethodGet) - r.HandleFunc("/api/server/getconfig", Scope(db.TenantScope, allowUsers(http.HandlerFunc(getConfig)))). + r.HandleFunc("/api/server/getconfig", scope.Middleware(scope.TenantScope, allowUsers(http.HandlerFunc(getConfig)))). Methods(http.MethodGet) - r.HandleFunc("/api/server/settings", Scope(db.TenantScope, allowUsers(http.HandlerFunc(getSettings)))). + r.HandleFunc("/api/server/settings", scope.Middleware(scope.TenantScope, allowUsers(http.HandlerFunc(getSettings)))). Methods(http.MethodGet) - r.HandleFunc("/api/server/settings", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(updateSettings)))). + r.HandleFunc("/api/server/settings", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(updateSettings)))). Methods(http.MethodPut) - r.HandleFunc("/api/server/getserverinfo", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getServerInfo)))). + r.HandleFunc("/api/server/getserverinfo", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getServerInfo)))). Methods(http.MethodGet) r.HandleFunc("/api/server/status", getStatus).Methods(http.MethodGet) - r.HandleFunc("/api/server/usage", Scope(db.TenantScope, logic.SecurityCheck(false, http.HandlerFunc(getUsage)))). + r.HandleFunc("/api/server/usage", scope.Middleware(scope.TenantScope, logic.SecurityCheck(false, http.HandlerFunc(getUsage)))). Methods(http.MethodGet) r.HandleFunc("/api/server/cpu_profile", logic.SecurityCheck(false, http.HandlerFunc(cpuProfile))). Methods(http.MethodPost) r.HandleFunc("/api/server/mem_profile", logic.SecurityCheck(false, http.HandlerFunc(memProfile))). Methods(http.MethodPost) r.HandleFunc("/api/server/feature_flags", getFeatureFlags).Methods(http.MethodGet) - r.HandleFunc("/api/server/onboarding", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getOnboarding)))).Methods(http.MethodGet) + r.HandleFunc("/api/server/onboarding", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getOnboarding)))).Methods(http.MethodGet) } func cpuProfile(w http.ResponseWriter, r *http.Request) { @@ -226,7 +227,7 @@ func getSettings(w http.ResponseWriter, r *http.Request) { // @Failure 400 {object} models.ErrorResponse // @Failure 500 {object} models.ErrorResponse func updateSettings(w http.ResponseWriter, r *http.Request) { - var req models.ServerSettings + var req schema.ServerSettingsData force := r.URL.Query().Get("force") if err := json.NewDecoder(r.Body).Decode(&req); err != nil { logger.Log(0, r.Header.Get("user"), "error decoding request body: ", err.Error()) @@ -303,7 +304,7 @@ func updateSettings(w http.ResponseWriter, r *http.Request) { logic.ReturnSuccessResponseWithJson(w, r, req, "updated server settings successfully") } -func reInit(curr, new models.ServerSettings, force bool) { +func reInit(curr, new schema.ServerSettingsData, force bool) { logic.SettingsMutex.Lock() defer logic.SettingsMutex.Unlock() logic.ResetAuthProvider() @@ -349,7 +350,7 @@ func reInit(curr, new models.ServerSettings, force bool) { go mq.PublishPeerUpdate(false) } -func identifySettingsUpdateAction(old, new models.ServerSettings) schema.Action { +func identifySettingsUpdateAction(old, new schema.ServerSettingsData) schema.Action { // TODO: here we are relying on the dashboard to only // make singular updates, but it's possible that the // API can be called to make multiple changes to the diff --git a/controllers/user.go b/controllers/user.go index a9233870b..01bda2d54 100644 --- a/controllers/user.go +++ b/controllers/user.go @@ -24,6 +24,7 @@ import ( "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/mq" "github.com/gravitl/netmaker/schema" + "github.com/gravitl/netmaker/scope" "github.com/gravitl/netmaker/servercfg" "github.com/pquerna/otp" "github.com/pquerna/otp/totp" @@ -40,31 +41,31 @@ var ListRoles = listRoles func userHandlers(r *mux.Router) { r.HandleFunc("/api/users/adm/hassuperadmin", hasSuperAdmin).Methods(http.MethodGet) r.HandleFunc("/api/users/adm/createsuperadmin", createSuperAdmin).Methods(http.MethodPost) - r.HandleFunc("/api/users/adm/transfersuperadmin/{username}", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(transferSuperAdmin)))). + r.HandleFunc("/api/users/adm/transfersuperadmin/{username}", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(transferSuperAdmin)))). Methods(http.MethodPost) r.HandleFunc("/api/users/adm/authenticate", authenticateUser).Methods(http.MethodPost) - r.HandleFunc("/api/users/{username}/validate-identity", Scope(db.TenantScope, logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(validateUserIdentity))))).Methods(http.MethodPost) - r.HandleFunc("/api/users/{username}/auth/init-totp", Scope(db.TenantScope, logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(initiateTOTPSetup))))).Methods(http.MethodPost) - r.HandleFunc("/api/users/{username}/auth/complete-totp", Scope(db.TenantScope, logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(completeTOTPSetup))))).Methods(http.MethodPost) + r.HandleFunc("/api/users/{username}/validate-identity", scope.Middleware(scope.TenantScope, logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(validateUserIdentity))))).Methods(http.MethodPost) + r.HandleFunc("/api/users/{username}/auth/init-totp", scope.Middleware(scope.TenantScope, logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(initiateTOTPSetup))))).Methods(http.MethodPost) + r.HandleFunc("/api/users/{username}/auth/complete-totp", scope.Middleware(scope.TenantScope, logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(completeTOTPSetup))))).Methods(http.MethodPost) r.HandleFunc("/api/users/{username}/auth/verify-totp", logic.PreAuthCheck(logic.ContinueIfUserMatch(http.HandlerFunc(verifyTOTP)))).Methods(http.MethodPost) - r.HandleFunc("/api/users/{username}", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(updateUser)))).Methods(http.MethodPut) - r.HandleFunc("/api/users/{username}", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(createUser)))).Methods(http.MethodPost) - r.HandleFunc("/api/users/{username}", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteUser)))).Methods(http.MethodDelete) - r.HandleFunc("/api/users/{username}", Scope(db.TenantScope, logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(getUser))))).Methods(http.MethodGet) - r.HandleFunc("/api/users/{username}/enable", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(enableUserAccount)))).Methods(http.MethodPost) - r.HandleFunc("/api/users/{username}/disable", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(disableUserAccount)))).Methods(http.MethodPost) - r.HandleFunc("/api/users/{username}/settings", Scope(db.TenantScope, logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(getUserSettings))))).Methods(http.MethodGet) - r.HandleFunc("/api/users/{username}/settings", Scope(db.TenantScope, logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(updateUserSettings))))).Methods(http.MethodPut) - r.HandleFunc("/api/v1/users", Scope(db.TenantScope, logic.SecurityCheck(false, logic.ContinueIfUserMatchOrAdmin(http.HandlerFunc(getUserV1))))).Methods(http.MethodGet) - r.HandleFunc("/api/users", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getUsers)))).Methods(http.MethodGet) - r.HandleFunc("/api/v2/users", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(listUsers)))).Methods(http.MethodGet) - r.HandleFunc("/api/v1/users/bulk", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(bulkDeleteUsers)))).Methods(http.MethodDelete) - r.HandleFunc("/api/v1/users/bulk/status", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(bulkUpdateUserStatus)))).Methods(http.MethodPost) - r.HandleFunc("/api/v1/users/roles", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(ListRoles)))).Methods(http.MethodGet) - r.HandleFunc("/api/v1/users/access_token", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(createUserAccessToken)))).Methods(http.MethodPost) - r.HandleFunc("/api/v1/users/access_token", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getUserAccessTokens)))).Methods(http.MethodGet) - r.HandleFunc("/api/v1/users/access_token", Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteUserAccessTokens)))).Methods(http.MethodDelete) - r.HandleFunc("/api/v1/users/logout", Scope(db.TenantScope, logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(logout))))).Methods(http.MethodPost) + r.HandleFunc("/api/users/{username}", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(updateUser)))).Methods(http.MethodPut) + r.HandleFunc("/api/users/{username}", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(createUser)))).Methods(http.MethodPost) + r.HandleFunc("/api/users/{username}", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteUser)))).Methods(http.MethodDelete) + r.HandleFunc("/api/users/{username}", scope.Middleware(scope.TenantScope, logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(getUser))))).Methods(http.MethodGet) + r.HandleFunc("/api/users/{username}/enable", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(enableUserAccount)))).Methods(http.MethodPost) + r.HandleFunc("/api/users/{username}/disable", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(disableUserAccount)))).Methods(http.MethodPost) + r.HandleFunc("/api/users/{username}/settings", scope.Middleware(scope.TenantScope, logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(getUserSettings))))).Methods(http.MethodGet) + r.HandleFunc("/api/users/{username}/settings", scope.Middleware(scope.TenantScope, logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(updateUserSettings))))).Methods(http.MethodPut) + r.HandleFunc("/api/v1/users", scope.Middleware(scope.TenantScope, logic.SecurityCheck(false, logic.ContinueIfUserMatchOrAdmin(http.HandlerFunc(getUserV1))))).Methods(http.MethodGet) + r.HandleFunc("/api/users", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getUsers)))).Methods(http.MethodGet) + r.HandleFunc("/api/v2/users", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(listUsers)))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/users/bulk", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(bulkDeleteUsers)))).Methods(http.MethodDelete) + r.HandleFunc("/api/v1/users/bulk/status", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(bulkUpdateUserStatus)))).Methods(http.MethodPost) + r.HandleFunc("/api/v1/users/roles", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(ListRoles)))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/users/access_token", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(createUserAccessToken)))).Methods(http.MethodPost) + r.HandleFunc("/api/v1/users/access_token", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getUserAccessTokens)))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/users/access_token", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteUserAccessTokens)))).Methods(http.MethodDelete) + r.HandleFunc("/api/v1/users/logout", scope.Middleware(scope.TenantScope, logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(logout))))).Methods(http.MethodPost) } // @Summary Create a user API access token diff --git a/db/scope.go b/db/scope.go deleted file mode 100644 index 36127ef21..000000000 --- a/db/scope.go +++ /dev/null @@ -1,52 +0,0 @@ -package db - -import ( - "context" - - "gorm.io/gorm" -) - -// ScopeLevel represents the tenancy scope of a request. -type ScopeLevel int - -const ( - // GlobalScope applies no tenant filtering — raw, unscoped access. - GlobalScope ScopeLevel = iota - // OrgScope filters queries to a specific organization (WHERE organization_id = ?). - OrgScope - // TenantScope filters queries to a specific tenant (WHERE tenant_id = ?). - TenantScope -) - -// Scope returns a new context whose GORM db is scoped to the given level. -// -// For OrgScope and TenantScope, exactly one id must be provided. -// For GlobalScope, no id is needed; the db is returned unscoped. -// -// Panics on invalid usage (wrong number of ids). These call sites are always -// static, so invalid usage is caught during development and code review. -func Scope(ctx context.Context, level ScopeLevel, ids ...string) context.Context { - if len(ids) > 1 { - panic("db.Scope: at most one id is allowed") - } - if level != GlobalScope && len(ids) == 0 { - panic("db.Scope: id required for non-global scope") - } - if level == GlobalScope { - return ctx - } - gdb := FromContext(ctx) - switch level { - case TenantScope: - gdb = gdb.Scopes(func(db *gorm.DB) *gorm.DB { - return db.Where("tenant_id = ?", ids[0]) - }) - case OrgScope: - gdb = gdb.Scopes(func(db *gorm.DB) *gorm.DB { - return db.Where("organization_id = ?", ids[0]) - }) - default: - panic("db.Scope: unknown level") - } - return context.WithValue(ctx, dbCtxKey, gdb) -} diff --git a/pro/controllers/auto_relay.go b/pro/controllers/auto_relay.go index bbb0e4854..c38ba5fc5 100644 --- a/pro/controllers/auto_relay.go +++ b/pro/controllers/auto_relay.go @@ -16,25 +16,26 @@ import ( "github.com/gravitl/netmaker/mq" proLogic "github.com/gravitl/netmaker/pro/logic" "github.com/gravitl/netmaker/schema" + "github.com/gravitl/netmaker/scope" "github.com/gravitl/netmaker/servercfg" "golang.org/x/exp/slog" ) // AutoRelayHandlers - handlers for AutoRelay func AutoRelayHandlers(r *mux.Router) { - r.HandleFunc("/api/v1/node/{nodeid}/auto_relay", controller.Scope(db.TenantScope, controller.AuthorizeHost(http.HandlerFunc(getAutoRelayGws)))). + r.HandleFunc("/api/v1/node/{nodeid}/auto_relay", scope.Middleware(scope.TenantScope, controller.AuthorizeHost(http.HandlerFunc(getAutoRelayGws)))). Methods(http.MethodGet) - r.HandleFunc("/api/v1/node/{nodeid}/auto_relay", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(setAutoRelay)))). + r.HandleFunc("/api/v1/node/{nodeid}/auto_relay", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(setAutoRelay)))). Methods(http.MethodPost) - r.HandleFunc("/api/v1/node/{nodeid}/auto_relay", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(unsetAutoRelay)))). + r.HandleFunc("/api/v1/node/{nodeid}/auto_relay", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(unsetAutoRelay)))). Methods(http.MethodDelete) - r.HandleFunc("/api/v1/node/{network}/auto_relay/reset", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(resetAutoRelayGw)))). + r.HandleFunc("/api/v1/node/{network}/auto_relay/reset", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(resetAutoRelayGw)))). Methods(http.MethodPost) - r.HandleFunc("/api/v1/node/{nodeid}/auto_relay_me", controller.Scope(db.TenantScope, controller.AuthorizeHost(http.HandlerFunc(autoRelayME)))). + r.HandleFunc("/api/v1/node/{nodeid}/auto_relay_me", scope.Middleware(scope.TenantScope, controller.AuthorizeHost(http.HandlerFunc(autoRelayME)))). Methods(http.MethodPost) - r.HandleFunc("/api/v1/node/{nodeid}/auto_relay_me", controller.Scope(db.TenantScope, controller.AuthorizeHost(http.HandlerFunc(autoRelayMEUpdate)))). + r.HandleFunc("/api/v1/node/{nodeid}/auto_relay_me", scope.Middleware(scope.TenantScope, controller.AuthorizeHost(http.HandlerFunc(autoRelayMEUpdate)))). Methods(http.MethodPut) - r.HandleFunc("/api/v1/node/{nodeid}/auto_relay_check", controller.Scope(db.TenantScope, controller.AuthorizeHost(http.HandlerFunc(checkautoRelayCtx)))). + r.HandleFunc("/api/v1/node/{nodeid}/auto_relay_check", scope.Middleware(scope.TenantScope, controller.AuthorizeHost(http.HandlerFunc(checkautoRelayCtx)))). Methods(http.MethodGet) } diff --git a/pro/controllers/events.go b/pro/controllers/events.go index 81ac3e1f7..260d76790 100644 --- a/pro/controllers/events.go +++ b/pro/controllers/events.go @@ -6,17 +6,17 @@ import ( "time" "github.com/gorilla/mux" - controller "github.com/gravitl/netmaker/controllers" "github.com/gravitl/netmaker/db" "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/schema" + "github.com/gravitl/netmaker/scope" ) func EventHandlers(r *mux.Router) { - r.HandleFunc("/api/v1/network/activity", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(listNetworkActivity)))).Methods(http.MethodGet) - r.HandleFunc("/api/v1/user/activity", controller.Scope(db.TenantScope, logic.SecurityCheck(false, http.HandlerFunc(listUserActivity)))).Methods(http.MethodGet) - r.HandleFunc("/api/v1/activity", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(listActivity)))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/network/activity", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(listNetworkActivity)))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/user/activity", scope.Middleware(scope.TenantScope, logic.SecurityCheck(false, http.HandlerFunc(listUserActivity)))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/activity", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(listActivity)))).Methods(http.MethodGet) } // @Summary List network activity diff --git a/pro/controllers/flows.go b/pro/controllers/flows.go index aa3592539..3b672f031 100644 --- a/pro/controllers/flows.go +++ b/pro/controllers/flows.go @@ -10,15 +10,14 @@ import ( "github.com/gorilla/mux" ch "github.com/gravitl/netmaker/clickhouse" - controller "github.com/gravitl/netmaker/controllers" "github.com/gravitl/netmaker/database" - "github.com/gravitl/netmaker/db" "github.com/gravitl/netmaker/logic" proLogic "github.com/gravitl/netmaker/pro/logic" + "github.com/gravitl/netmaker/scope" ) func FlowHandlers(r *mux.Router) { - r.HandleFunc("/api/v1/flows", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(handleListFlows)))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/flows", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(handleListFlows)))).Methods(http.MethodGet) } const ( diff --git a/pro/controllers/integrations.go b/pro/controllers/integrations.go index 5f60b33f5..ad1e25dc4 100644 --- a/pro/controllers/integrations.go +++ b/pro/controllers/integrations.go @@ -8,8 +8,6 @@ import ( "net/http" "github.com/gorilla/mux" - controller "github.com/gravitl/netmaker/controllers" - "github.com/gravitl/netmaker/db" "github.com/gravitl/netmaker/grpc/siem" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" @@ -17,16 +15,17 @@ import ( "github.com/gravitl/netmaker/pro/integration" logic2 "github.com/gravitl/netmaker/pro/logic" "github.com/gravitl/netmaker/schema" + "github.com/gravitl/netmaker/scope" "google.golang.org/protobuf/types/known/structpb" "gorm.io/datatypes" "gorm.io/gorm" ) func IntegrationHandlers(r *mux.Router) { - r.HandleFunc("/api/v1/integrations/{type}", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getIntegration)))).Methods(http.MethodGet) - r.HandleFunc("/api/v1/integrations/{type}/{id}", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(upsertIntegration)))).Methods(http.MethodPut) - r.HandleFunc("/api/v1/integrations/{type}/{id}", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteIntegration)))).Methods(http.MethodDelete) - r.HandleFunc("/api/v1/integrations/{type}/{id}/test", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(testIntegration)))).Methods(http.MethodPost) + r.HandleFunc("/api/v1/integrations/{type}", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getIntegration)))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/integrations/{type}/{id}", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(upsertIntegration)))).Methods(http.MethodPut) + r.HandleFunc("/api/v1/integrations/{type}/{id}", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteIntegration)))).Methods(http.MethodDelete) + r.HandleFunc("/api/v1/integrations/{type}/{id}/test", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(testIntegration)))).Methods(http.MethodPost) } // extractAndValidateIntegration pulls {type} and {id} from the URL diff --git a/pro/controllers/jit.go b/pro/controllers/jit.go index d660eaf96..c79e5e7c4 100644 --- a/pro/controllers/jit.go +++ b/pro/controllers/jit.go @@ -9,7 +9,6 @@ import ( "time" "github.com/gorilla/mux" - controller "github.com/gravitl/netmaker/controllers" "github.com/gravitl/netmaker/db" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" @@ -17,20 +16,21 @@ import ( "github.com/gravitl/netmaker/pro/email" proLogic "github.com/gravitl/netmaker/pro/logic" "github.com/gravitl/netmaker/schema" + "github.com/gravitl/netmaker/scope" "golang.org/x/exp/slog" ) func JITHandlers(r *mux.Router) { - r.HandleFunc("/api/v1/jit", controller.Scope(db.TenantScope, logic.SecurityCheck(true, + r.HandleFunc("/api/v1/jit", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(handleJIT)))).Methods(http.MethodPost, http.MethodGet) - r.HandleFunc("/api/v1/jit", controller.Scope(db.TenantScope, logic.SecurityCheck(true, + r.HandleFunc("/api/v1/jit", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteJITGrant)))).Methods(http.MethodDelete) - r.HandleFunc("/api/v1/jit_user/networks", controller.Scope(db.TenantScope, logic.SecurityCheck(false, + r.HandleFunc("/api/v1/jit_user/networks", scope.Middleware(scope.TenantScope, logic.SecurityCheck(false, http.HandlerFunc(getUserJITNetworks)))).Methods(http.MethodGet) - r.HandleFunc("/api/v1/jit_user/request", controller.Scope(db.TenantScope, logic.SecurityCheck(false, + r.HandleFunc("/api/v1/jit_user/request", scope.Middleware(scope.TenantScope, logic.SecurityCheck(false, http.HandlerFunc(requestJITAccess)))).Methods(http.MethodPost) } diff --git a/pro/controllers/metrics.go b/pro/controllers/metrics.go index 2fd4dc710..8fbff781e 100644 --- a/pro/controllers/metrics.go +++ b/pro/controllers/metrics.go @@ -5,12 +5,11 @@ import ( "net/http" proLogic "github.com/gravitl/netmaker/pro/logic" + "github.com/gravitl/netmaker/scope" "golang.org/x/exp/slog" "github.com/gorilla/mux" - controller "github.com/gravitl/netmaker/controllers" "github.com/gravitl/netmaker/database" - "github.com/gravitl/netmaker/db" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/models" @@ -18,10 +17,10 @@ import ( // MetricHandlers - How we handle Pro Metrics func MetricHandlers(r *mux.Router) { - r.HandleFunc("/api/metrics/{network}/{nodeid}", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getNodeMetrics)))).Methods(http.MethodGet) - r.HandleFunc("/api/metrics/{network}", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getNetworkNodesMetrics)))).Methods(http.MethodGet) - r.HandleFunc("/api/metrics", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getAllMetrics)))).Methods(http.MethodGet) - r.HandleFunc("/api/metrics-ext/{network}", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getNetworkExtMetrics)))).Methods(http.MethodGet) + r.HandleFunc("/api/metrics/{network}/{nodeid}", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getNodeMetrics)))).Methods(http.MethodGet) + r.HandleFunc("/api/metrics/{network}", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getNetworkNodesMetrics)))).Methods(http.MethodGet) + r.HandleFunc("/api/metrics", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getAllMetrics)))).Methods(http.MethodGet) + r.HandleFunc("/api/metrics-ext/{network}", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getNetworkExtMetrics)))).Methods(http.MethodGet) } // @Summary Get metrics for a specific node diff --git a/pro/controllers/networks.go b/pro/controllers/networks.go index 14801fa1e..06302c528 100644 --- a/pro/controllers/networks.go +++ b/pro/controllers/networks.go @@ -5,14 +5,13 @@ import ( "net/http" "github.com/gorilla/mux" - controller "github.com/gravitl/netmaker/controllers" - "github.com/gravitl/netmaker/db" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" + "github.com/gravitl/netmaker/scope" ) func NetworkHandlers(r *mux.Router) { - r.HandleFunc("/api/v1/networks/{network}/graph", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getNetworkGraph)))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/networks/{network}/graph", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getNetworkGraph)))).Methods(http.MethodGet) } // @Summary Get network topology graph diff --git a/pro/controllers/posture_check.go b/pro/controllers/posture_check.go index 73285caa9..e40c57c82 100644 --- a/pro/controllers/posture_check.go +++ b/pro/controllers/posture_check.go @@ -10,7 +10,6 @@ import ( "github.com/google/uuid" "github.com/gorilla/mux" - controller "github.com/gravitl/netmaker/controllers" "github.com/gravitl/netmaker/db" dbtypes "github.com/gravitl/netmaker/db/types" "github.com/gravitl/netmaker/logger" @@ -19,16 +18,17 @@ import ( "github.com/gravitl/netmaker/mq" proLogic "github.com/gravitl/netmaker/pro/logic" "github.com/gravitl/netmaker/schema" + "github.com/gravitl/netmaker/scope" "gorm.io/gorm" ) func PostureCheckHandlers(r *mux.Router) { - r.HandleFunc("/api/v1/posture_check", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(createPostureCheck)))).Methods(http.MethodPost) - r.HandleFunc("/api/v1/posture_check", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(listPostureChecks)))).Methods(http.MethodGet) - r.HandleFunc("/api/v1/posture_check", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(updatePostureCheck)))).Methods(http.MethodPut) - r.HandleFunc("/api/v1/posture_check", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deletePostureCheck)))).Methods(http.MethodDelete) - r.HandleFunc("/api/v1/posture_check/attrs", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(listPostureChecksAttrs)))).Methods(http.MethodGet) - r.HandleFunc("/api/v1/posture_check/violations", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(listPostureCheckViolatedNodes)))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/posture_check", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(createPostureCheck)))).Methods(http.MethodPost) + r.HandleFunc("/api/v1/posture_check", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(listPostureChecks)))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/posture_check", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(updatePostureCheck)))).Methods(http.MethodPut) + r.HandleFunc("/api/v1/posture_check", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deletePostureCheck)))).Methods(http.MethodDelete) + r.HandleFunc("/api/v1/posture_check/attrs", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(listPostureChecksAttrs)))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/posture_check/violations", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(listPostureCheckViolatedNodes)))).Methods(http.MethodGet) } // @Summary List Posture Checks Available Attributes diff --git a/pro/controllers/rac.go b/pro/controllers/rac.go index 8f253188c..0b374efb7 100644 --- a/pro/controllers/rac.go +++ b/pro/controllers/rac.go @@ -4,13 +4,12 @@ import ( "net/http" "github.com/gorilla/mux" - controller "github.com/gravitl/netmaker/controllers" - "github.com/gravitl/netmaker/db" "github.com/gravitl/netmaker/logic" + "github.com/gravitl/netmaker/scope" ) func RacHandlers(r *mux.Router) { - r.HandleFunc("/api/v1/rac/networks", controller.Scope(db.TenantScope, logic.SecurityCheck(false, http.HandlerFunc(getUserRemoteAccessNetworks)))).Methods(http.MethodGet) - r.HandleFunc("/api/v1/rac/network/{network}/access_points", controller.Scope(db.TenantScope, logic.SecurityCheck(false, http.HandlerFunc(getUserRemoteAccessNetworkGateways)))).Methods(http.MethodGet) - r.HandleFunc("/api/v1/rac/access_point/{access_point_id}/config", controller.Scope(db.TenantScope, logic.SecurityCheck(false, http.HandlerFunc(getRemoteAccessGatewayConf)))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/rac/networks", scope.Middleware(scope.TenantScope, logic.SecurityCheck(false, http.HandlerFunc(getUserRemoteAccessNetworks)))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/rac/network/{network}/access_points", scope.Middleware(scope.TenantScope, logic.SecurityCheck(false, http.HandlerFunc(getUserRemoteAccessNetworkGateways)))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/rac/access_point/{access_point_id}/config", scope.Middleware(scope.TenantScope, logic.SecurityCheck(false, http.HandlerFunc(getRemoteAccessGatewayConf)))).Methods(http.MethodGet) } diff --git a/pro/controllers/tags.go b/pro/controllers/tags.go index ab010e605..e4f4e4fd4 100644 --- a/pro/controllers/tags.go +++ b/pro/controllers/tags.go @@ -11,7 +11,6 @@ import ( "time" "github.com/gorilla/mux" - controller "github.com/gravitl/netmaker/controllers" "github.com/gravitl/netmaker/db" dbtypes "github.com/gravitl/netmaker/db/types" "github.com/gravitl/netmaker/logger" @@ -20,16 +19,17 @@ import ( "github.com/gravitl/netmaker/mq" proLogic "github.com/gravitl/netmaker/pro/logic" "github.com/gravitl/netmaker/schema" + "github.com/gravitl/netmaker/scope" ) func TagHandlers(r *mux.Router) { - r.HandleFunc("/api/v1/tags", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getTags)))). + r.HandleFunc("/api/v1/tags", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getTags)))). Methods(http.MethodGet) - r.HandleFunc("/api/v1/tags", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(createTag)))). + r.HandleFunc("/api/v1/tags", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(createTag)))). Methods(http.MethodPost) - r.HandleFunc("/api/v1/tags", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(updateTag)))). + r.HandleFunc("/api/v1/tags", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(updateTag)))). Methods(http.MethodPut) - r.HandleFunc("/api/v1/tags", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteTag)))). + r.HandleFunc("/api/v1/tags", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteTag)))). Methods(http.MethodDelete) } diff --git a/pro/controllers/users.go b/pro/controllers/users.go index 42e056ebd..3215f2ee2 100644 --- a/pro/controllers/users.go +++ b/pro/controllers/users.go @@ -12,7 +12,6 @@ import ( "time" "github.com/gorilla/mux" - controller "github.com/gravitl/netmaker/controllers" "github.com/gravitl/netmaker/db" dbtypes "github.com/gravitl/netmaker/db/types" "github.com/gravitl/netmaker/logger" @@ -28,6 +27,7 @@ import ( "github.com/gravitl/netmaker/pro/idp/okta" proLogic "github.com/gravitl/netmaker/pro/logic" "github.com/gravitl/netmaker/schema" + "github.com/gravitl/netmaker/scope" "github.com/gravitl/netmaker/servercfg" "github.com/gravitl/netmaker/utils" "golang.org/x/exp/slog" @@ -44,47 +44,47 @@ func UserHandlers(r *mux.Router) { r.HandleFunc("/api/oauth/register/{regKey}", proAuth.RegisterHostSSO).Methods(http.MethodGet) // User Role Handlers - r.HandleFunc("/api/v1/users/role", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getRole)))).Methods(http.MethodGet) - r.HandleFunc("/api/v1/users/role", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(createRole)))).Methods(http.MethodPost) - r.HandleFunc("/api/v1/users/role", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(updateRole)))).Methods(http.MethodPut) - r.HandleFunc("/api/v1/users/role", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteRole)))).Methods(http.MethodDelete) + r.HandleFunc("/api/v1/users/role", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getRole)))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/users/role", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(createRole)))).Methods(http.MethodPost) + r.HandleFunc("/api/v1/users/role", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(updateRole)))).Methods(http.MethodPut) + r.HandleFunc("/api/v1/users/role", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteRole)))).Methods(http.MethodDelete) // User Group Handlers - r.HandleFunc("/api/v1/users/groups", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getUserGroups)))).Methods(http.MethodGet) - r.HandleFunc("/api/v2/users/groups", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(listUserGroups)))).Methods(http.MethodGet) - r.HandleFunc("/api/v1/users/group", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getUserGroup)))).Methods(http.MethodGet) - r.HandleFunc("/api/v1/users/group", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(createUserGroup)))).Methods(http.MethodPost) - r.HandleFunc("/api/v1/users/group", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(updateUserGroup)))).Methods(http.MethodPut) - r.HandleFunc("/api/v1/users/group", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteUserGroup)))).Methods(http.MethodDelete) - r.HandleFunc("/api/v1/users/groups/network", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(listNetworkUserGroups)))).Methods(http.MethodGet) - r.HandleFunc("/api/v1/users/network", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(listNetworkUsers)))).Methods(http.MethodGet) - r.HandleFunc("/api/v1/users/add_network_user", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(addUsertoNetwork)))).Methods(http.MethodPut) - r.HandleFunc("/api/v1/users/remove_network_user", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(removeUserfromNetwork)))).Methods(http.MethodPut) - r.HandleFunc("/api/v1/users/unassigned_network_users", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(listUnAssignedNetUsers)))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/users/groups", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getUserGroups)))).Methods(http.MethodGet) + r.HandleFunc("/api/v2/users/groups", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(listUserGroups)))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/users/group", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getUserGroup)))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/users/group", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(createUserGroup)))).Methods(http.MethodPost) + r.HandleFunc("/api/v1/users/group", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(updateUserGroup)))).Methods(http.MethodPut) + r.HandleFunc("/api/v1/users/group", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteUserGroup)))).Methods(http.MethodDelete) + r.HandleFunc("/api/v1/users/groups/network", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(listNetworkUserGroups)))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/users/network", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(listNetworkUsers)))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/users/add_network_user", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(addUsertoNetwork)))).Methods(http.MethodPut) + r.HandleFunc("/api/v1/users/remove_network_user", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(removeUserfromNetwork)))).Methods(http.MethodPut) + r.HandleFunc("/api/v1/users/unassigned_network_users", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(listUnAssignedNetUsers)))).Methods(http.MethodGet) // User Invite Handlers r.HandleFunc("/api/v1/users/invite", userInviteVerify).Methods(http.MethodGet) r.HandleFunc("/api/v1/users/invite-signup", userInviteSignUp).Methods(http.MethodPost) - r.HandleFunc("/api/v1/users/invite", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(inviteUsers)))).Methods(http.MethodPost) - r.HandleFunc("/api/v1/users/invites", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(listUserInvites)))).Methods(http.MethodGet) - r.HandleFunc("/api/v1/users/invite", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteUserInvite)))).Methods(http.MethodDelete) - r.HandleFunc("/api/v1/users/invites", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteAllUserInvites)))).Methods(http.MethodDelete) - - r.HandleFunc("/api/users_pending", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getPendingUsers)))).Methods(http.MethodGet) - r.HandleFunc("/api/users_pending", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteAllPendingUsers)))).Methods(http.MethodDelete) - r.HandleFunc("/api/users_pending/user/{username}", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deletePendingUser)))).Methods(http.MethodDelete) - r.HandleFunc("/api/users_pending/user/{username}", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(approvePendingUser)))).Methods(http.MethodPost) - - r.HandleFunc("/api/users/{username}/remote_access_gw/{remote_access_gateway_id}", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(attachUserToRemoteAccessGw)))).Methods(http.MethodPost) - r.HandleFunc("/api/users/{username}/remote_access_gw/{remote_access_gateway_id}", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(removeUserFromRemoteAccessGW)))).Methods(http.MethodDelete) - r.HandleFunc("/api/users/{username}/remote_access_gw", controller.Scope(db.TenantScope, logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(getUserRemoteAccessGwsV1))))).Methods(http.MethodGet) - r.HandleFunc("/api/users/ingress/{ingress_id}", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(ingressGatewayUsers)))).Methods(http.MethodGet) - r.HandleFunc("/api/v1/users/network_ip", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(userNetworkMapping)))).Methods(http.MethodGet) - - r.HandleFunc("/api/idp/sync", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(syncIDP)))).Methods(http.MethodPost) - r.HandleFunc("/api/idp/sync/test", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(testIDPSync)))).Methods(http.MethodPost) - r.HandleFunc("/api/idp/sync/status", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getIDPSyncStatus)))).Methods(http.MethodGet) - r.HandleFunc("/api/idp", controller.Scope(db.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(removeIDPIntegration)))).Methods(http.MethodDelete) + r.HandleFunc("/api/v1/users/invite", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(inviteUsers)))).Methods(http.MethodPost) + r.HandleFunc("/api/v1/users/invites", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(listUserInvites)))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/users/invite", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteUserInvite)))).Methods(http.MethodDelete) + r.HandleFunc("/api/v1/users/invites", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteAllUserInvites)))).Methods(http.MethodDelete) + + r.HandleFunc("/api/users_pending", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getPendingUsers)))).Methods(http.MethodGet) + r.HandleFunc("/api/users_pending", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deleteAllPendingUsers)))).Methods(http.MethodDelete) + r.HandleFunc("/api/users_pending/user/{username}", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(deletePendingUser)))).Methods(http.MethodDelete) + r.HandleFunc("/api/users_pending/user/{username}", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(approvePendingUser)))).Methods(http.MethodPost) + + r.HandleFunc("/api/users/{username}/remote_access_gw/{remote_access_gateway_id}", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(attachUserToRemoteAccessGw)))).Methods(http.MethodPost) + r.HandleFunc("/api/users/{username}/remote_access_gw/{remote_access_gateway_id}", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(removeUserFromRemoteAccessGW)))).Methods(http.MethodDelete) + r.HandleFunc("/api/users/{username}/remote_access_gw", scope.Middleware(scope.TenantScope, logic.SecurityCheck(false, logic.ContinueIfUserMatch(http.HandlerFunc(getUserRemoteAccessGwsV1))))).Methods(http.MethodGet) + r.HandleFunc("/api/users/ingress/{ingress_id}", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(ingressGatewayUsers)))).Methods(http.MethodGet) + r.HandleFunc("/api/v1/users/network_ip", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(userNetworkMapping)))).Methods(http.MethodGet) + + r.HandleFunc("/api/idp/sync", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(syncIDP)))).Methods(http.MethodPost) + r.HandleFunc("/api/idp/sync/test", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(testIDPSync)))).Methods(http.MethodPost) + r.HandleFunc("/api/idp/sync/status", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(getIDPSyncStatus)))).Methods(http.MethodGet) + r.HandleFunc("/api/idp", scope.Middleware(scope.TenantScope, logic.SecurityCheck(true, http.HandlerFunc(removeIDPIntegration)))).Methods(http.MethodDelete) } // @Summary User signup via invite diff --git a/controllers/scope.go b/scope/scope.go similarity index 52% rename from controllers/scope.go rename to scope/scope.go index aed802adb..4f78b2c49 100644 --- a/controllers/scope.go +++ b/scope/scope.go @@ -1,12 +1,38 @@ -package controller +package scope import ( + "context" "errors" "net/http" "github.com/gravitl/netmaker/db" "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/schema" + "gorm.io/gorm" +) + +const ( + HeaderTenantID = "X-Tenant-ID" + HeaderOrgID = "X-Organization-ID" +) + +// Level represents the tenancy scope of a request. +type Level int + +const ( + // GlobalScope applies no tenant filtering — raw, unscoped access. + GlobalScope Level = iota + // OrgScope filters queries to a specific organization (WHERE organization_id = ?). + OrgScope + // TenantScope filters queries to a specific tenant (WHERE tenant_id = ?). + TenantScope +) + +type scopeCtxKeyType int + +const ( + scopeLevel scopeCtxKeyType = iota + scopeID ) var ( @@ -18,12 +44,7 @@ var ( errOrgNotFound = errors.New("organization not found") ) -const ( - HeaderTenantID = "X-Tenant-ID" - HeaderOrgID = "X-Organization-ID" -) - -// Scope wraps an http.Handler to enforce request-level tenancy scoping. +// Middleware wraps an http.Handler to enforce request-level tenancy scoping. // // For db.TenantScope: requires the X-Tenant-ID header and injects a // WHERE tenant_id = ? scope into the GORM db stored in the request context. @@ -32,11 +53,11 @@ const ( // WHERE organization_id = ? scope. // // For db.GlobalScope: passes through without modification. -func Scope(level db.ScopeLevel, next http.Handler) http.HandlerFunc { +func Middleware(level Level, next http.Handler) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { var id string switch level { - case db.TenantScope: + case TenantScope: id = r.Header.Get(HeaderTenantID) if id == "" { if logic.GetFeatureFlags().AllowMultipleTenants { @@ -62,7 +83,7 @@ func Scope(level db.ScopeLevel, next http.Handler) http.HandlerFunc { return } } - case db.OrgScope: + case OrgScope: id = r.Header.Get(HeaderOrgID) if id == "" { if logic.GetFeatureFlags().AllowMultipleTenants { @@ -88,11 +109,63 @@ func Scope(level db.ScopeLevel, next http.Handler) http.HandlerFunc { return } } - case db.GlobalScope: + case GlobalScope: // no header required } - ctx := db.Scope(r.Context(), level, id) + ctx := WithContext(r.Context(), level, id) next.ServeHTTP(w, r.WithContext(ctx)) } } + +// WithContext returns a new context with GORM handle that is scoped to the given level. +// +// For OrgScope and TenantScope, exactly one id must be provided. +// For GlobalScope, no id is needed; the db is returned unscoped. +// +// Panics on invalid usage (wrong number of ids). These call sites are always +// static, so invalid usage is caught during development and code review. +func WithContext(ctx context.Context, level Level, ids ...string) context.Context { + if len(ids) > 1 { + panic("db.Scope: at most one id is allowed") + } + if level != GlobalScope && len(ids) == 0 { + panic("db.Scope: id required for non-global scope") + } + + ctx = context.WithValue(ctx, scopeLevel, level) + ctx = context.WithValue(ctx, scopeID, ids[0]) + + switch level { + case TenantScope: + return db.Modify(ctx, func(db *gorm.DB) *gorm.DB { + return db.Scopes(func(db *gorm.DB) *gorm.DB { + return db.Where("tenant_id = ?", ids[0]) + }) + }) + case OrgScope: + return db.Modify(ctx, func(db *gorm.DB) *gorm.DB { + return db.Scopes(func(db *gorm.DB) *gorm.DB { + return db.Where("organization_id = ?", ids[0]) + }) + }) + case GlobalScope: + return db.Modify(ctx, func(db *gorm.DB) *gorm.DB { return db }) + default: + panic("db.Scope: unknown level") + } +} + +func Default(ctx context.Context) context.Context { + defaultTenant := &schema.Tenant{} + err := defaultTenant.GetDefault(db.WithContext(context.TODO())) + if err != nil { + return ctx + } + + return WithContext(ctx, TenantScope, defaultTenant.ID) +} + +func ID(ctx context.Context) string { + return ctx.Value(scopeID).(string) +} From 41fd5de4e0897de9d1e7ff26ba74b64edfcde419 Mon Sep 17 00:00:00 2001 From: VishalDalwadi Date: Mon, 22 Jun 2026 12:07:58 +0530 Subject: [PATCH 17/21] feat(go): add schema definition for dns records; --- schema/dns_records.go | 63 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 schema/dns_records.go diff --git a/schema/dns_records.go b/schema/dns_records.go new file mode 100644 index 000000000..8c919cb28 --- /dev/null +++ b/schema/dns_records.go @@ -0,0 +1,63 @@ +package schema + +import ( + "context" + + "github.com/gravitl/netmaker/db" + "gorm.io/datatypes" +) + +type DNSEntryType string + +const ( + DNSEntryType_Node = "node" + DNSEntryType_Custom = "custom" +) + +// DNSEntry - a DNS entry represented as struct +type DNSEntry struct { + Type DNSEntryType `json:"type"` + Address string `json:"address" validate:"omitempty,ip"` + Address6 string `json:"address6" validate:"omitempty,ip"` + Name string `json:"name" validate:"required,name_unique,min=1,max=192,whitespace"` + Network string `json:"network" validate:"network_exists"` +} + +// DNS is the GORM model for the legacy "dns" key-value table, extended with +// tenant_id and network_id columns for multi-tenancy. +type DNS struct { + Key string `gorm:"primaryKey"` + TenantID string `gorm:"default:''"` + NetworkID string + Value datatypes.JSONType[DNSEntry] +} + +func (*DNS) TableName() string { return "dns" } + +func (d *DNS) Create(ctx context.Context) error { + return db.FromContext(ctx).Create(d).Error +} + +func (d *DNS) Get(ctx context.Context) error { + return db.FromContext(ctx).Where("key = ?", d.Key).First(d).Error +} + +func (d *DNS) ListAll(ctx context.Context) ([]DNS, error) { + var entries []DNS + err := db.FromContext(ctx).Find(&entries).Error + return entries, err +} + +func (d *DNS) ListByNetwork(ctx context.Context) ([]DNS, error) { + var entries []DNS + err := db.FromContext(ctx).Where("network_id = ?", d.NetworkID).Find(&entries).Error + return entries, err +} + +func (d *DNS) Delete(ctx context.Context) error { + return db.FromContext(ctx).Where("key = ?", d.Key).Delete(&DNS{}).Error +} + +func (d *DNS) DeleteByNetwork(ctx context.Context) error { + return db.FromContext(ctx).Where("network_id = ?", d.NetworkID).Delete(&DNS{}).Error +} From 0ccf21a8cc4896cff5c0b75733e2e23f1c70ec66 Mon Sep 17 00:00:00 2001 From: VishalDalwadi Date: Mon, 22 Jun 2026 12:08:54 +0530 Subject: [PATCH 18/21] wip(go): add method to modify db in ctx; --- db/db.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/db/db.go b/db/db.go index 491ee8d82..6fe593229 100644 --- a/db/db.go +++ b/db/db.go @@ -111,6 +111,18 @@ func BeginTx(ctx context.Context) context.Context { return context.WithValue(ctx, dbCtxKey, dbInCtx.Begin()) } +func Modify(ctx context.Context, mod func(db *gorm.DB) *gorm.DB) context.Context { + var moddb *gorm.DB + dbInCtx, ok := ctx.Value(dbCtxKey).(*gorm.DB) + if ok { + moddb = dbInCtx + } else { + moddb = db + } + + return context.WithValue(ctx, dbCtxKey, mod(moddb)) +} + // CloseDB close a connection to the database // (if one exists). It panics if any error // occurs. From a0c89eabbbbce716bce745be3a7f2ead81438a5d Mon Sep 17 00:00:00 2001 From: VishalDalwadi Date: Mon, 22 Jun 2026 12:19:08 +0530 Subject: [PATCH 19/21] wip(go): move dns records db management to schema pkg; --- cli/cmd/dns/create.go | 4 +- cli/cmd/dns/list.go | 4 +- cli/functions/dns.go | 22 ++-- controllers/dns_test.go | 89 +++++++++-------- database/database.go | 3 - logic/dns.go | 172 ++++++++------------------------ logic/networks.go | 2 +- migrate/migrate.go | 14 +-- migrate/migrate_multitenancy.go | 2 +- models/dnsEntry.go | 16 --- mq/publishers.go | 7 +- schema/models.go | 1 + scope/scope.go | 3 + 13 files changed, 116 insertions(+), 223 deletions(-) diff --git a/cli/cmd/dns/create.go b/cli/cmd/dns/create.go index e0eb5ebec..a37e737aa 100644 --- a/cli/cmd/dns/create.go +++ b/cli/cmd/dns/create.go @@ -4,7 +4,7 @@ import ( "log" "github.com/gravitl/netmaker/cli/functions" - "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/schema" "github.com/spf13/cobra" ) @@ -17,7 +17,7 @@ var dnsCreateCmd = &cobra.Command{ if address == "" && address6 == "" { log.Fatal("Either IPv4 or IPv6 address is required") } - dnsEntry := &models.DNSEntry{Name: dnsName, Address: address, Address6: address6, Network: networkName} + dnsEntry := &schema.DNSEntry{Name: dnsName, Address: address, Address6: address6, Network: networkName} functions.PrettyPrint(functions.CreateDNS(networkName, dnsEntry)) }, } diff --git a/cli/cmd/dns/list.go b/cli/cmd/dns/list.go index 3ce09689a..e5726a82e 100644 --- a/cli/cmd/dns/list.go +++ b/cli/cmd/dns/list.go @@ -6,7 +6,7 @@ import ( "github.com/gravitl/netmaker/cli/cmd/commons" "github.com/gravitl/netmaker/cli/functions" - "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/schema" "github.com/guumaster/tablewriter" "github.com/spf13/cobra" ) @@ -17,7 +17,7 @@ var dnsListCmd = &cobra.Command{ Short: "List DNS entries", Long: `List DNS entries`, Run: func(cmd *cobra.Command, args []string) { - var data []models.DNSEntry + var data []schema.DNSEntry if networkName != "" { switch dnsType { case "node": diff --git a/cli/functions/dns.go b/cli/functions/dns.go index 9d1487704..92974dfe6 100644 --- a/cli/functions/dns.go +++ b/cli/functions/dns.go @@ -4,32 +4,32 @@ import ( "fmt" "net/http" - "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/schema" ) // GetDNS - fetch all DNS entries -func GetDNS() *[]models.DNSEntry { - return request[[]models.DNSEntry](http.MethodGet, "/api/dns", nil) +func GetDNS() *[]schema.DNSEntry { + return request[[]schema.DNSEntry](http.MethodGet, "/api/dns", nil) } // GetNodeDNS - fetch all Node DNS entires -func GetNodeDNS(networkName string) *[]models.DNSEntry { - return request[[]models.DNSEntry](http.MethodGet, fmt.Sprintf("/api/dns/adm/%s/nodes", networkName), nil) +func GetNodeDNS(networkName string) *[]schema.DNSEntry { + return request[[]schema.DNSEntry](http.MethodGet, fmt.Sprintf("/api/dns/adm/%s/nodes", networkName), nil) } // GetCustomDNS - fetch user defined DNS entriees -func GetCustomDNS(networkName string) *[]models.DNSEntry { - return request[[]models.DNSEntry](http.MethodGet, fmt.Sprintf("/api/dns/adm/%s/custom", networkName), nil) +func GetCustomDNS(networkName string) *[]schema.DNSEntry { + return request[[]schema.DNSEntry](http.MethodGet, fmt.Sprintf("/api/dns/adm/%s/custom", networkName), nil) } // GetNetworkDNS - fetch DNS entries associated with a network -func GetNetworkDNS(networkName string) *[]models.DNSEntry { - return request[[]models.DNSEntry](http.MethodGet, "/api/dns/adm/"+networkName, nil) +func GetNetworkDNS(networkName string) *[]schema.DNSEntry { + return request[[]schema.DNSEntry](http.MethodGet, "/api/dns/adm/"+networkName, nil) } // CreateDNS - create a DNS entry -func CreateDNS(networkName string, payload *models.DNSEntry) *models.DNSEntry { - return request[models.DNSEntry](http.MethodPost, "/api/dns/"+networkName, payload) +func CreateDNS(networkName string, payload *schema.DNSEntry) *schema.DNSEntry { + return request[schema.DNSEntry](http.MethodPost, "/api/dns/"+networkName, payload) } // PushDNS - push a DNS entry to CoreDNS diff --git a/controllers/dns_test.go b/controllers/dns_test.go index 9fbeb846b..af13ceaf0 100644 --- a/controllers/dns_test.go +++ b/controllers/dns_test.go @@ -1,16 +1,17 @@ package controller import ( + "context" "fmt" "testing" "github.com/google/uuid" "github.com/gravitl/netmaker/schema" + "github.com/gravitl/netmaker/scope" "github.com/stretchr/testify/assert" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" "github.com/gravitl/netmaker/logic" - "github.com/gravitl/netmaker/models" ) var dnsHost schema.Host @@ -21,25 +22,25 @@ func TestGetAllDNS(t *testing.T) { createNet() createHost() t.Run("NoEntries", func(t *testing.T) { - entries, err := logic.GetAllDNS() + entries, err := logic.GetAllDNS(scope.Default(context.TODO())) assert.Nil(t, err) - assert.Equal(t, []models.DNSEntry(nil), entries) + assert.Equal(t, []schema.DNSEntry(nil), entries) }) t.Run("OneEntry", func(t *testing.T) { - entry := models.DNSEntry{ + entry := schema.DNSEntry{ Address: "10.0.0.3", Name: "newhost", Network: "skynet", } _, err := logic.CreateDNS(entry) assert.Nil(t, err) - entries, err := logic.GetAllDNS() + entries, err := logic.GetAllDNS(scope.Default(context.TODO())) assert.Nil(t, err) assert.Equal(t, 1, len(entries)) }) t.Run("MultipleEntry", func(t *testing.T) { - entry := models.DNSEntry{Address: "10.0.0.7", Name: "anotherhost", Network: "skynet"} + entry := schema.DNSEntry{Address: "10.0.0.7", Name: "anotherhost", Network: "skynet"} _, err := logic.CreateDNS(entry) assert.Nil(t, err) - entries, err := logic.GetAllDNS() + entries, err := logic.GetAllDNS(scope.Default(context.TODO())) assert.Nil(t, err) assert.Equal(t, 2, len(entries)) }) @@ -49,35 +50,35 @@ func TestGetCustomDNS(t *testing.T) { deleteAllDNS(t) deleteAllNetworks() t.Run("NoNetworks", func(t *testing.T) { - dns, err := logic.GetCustomDNS("skynet") + dns, err := logic.GetCustomDNS(scope.Default(context.TODO()), "skynet") assert.EqualError(t, err, "could not find any records") - assert.Equal(t, []models.DNSEntry(nil), dns) + assert.Equal(t, []schema.DNSEntry(nil), dns) }) t.Run("NoNodes", func(t *testing.T) { createNet() - dns, err := logic.GetCustomDNS("skynet") + dns, err := logic.GetCustomDNS(scope.Default(context.TODO()), "skynet") assert.EqualError(t, err, "could not find any records") - assert.Equal(t, []models.DNSEntry(nil), dns) + assert.Equal(t, []schema.DNSEntry(nil), dns) }) t.Run("NodeExists", func(t *testing.T) { createTestNode() - dns, err := logic.GetCustomDNS("skynet") + dns, err := logic.GetCustomDNS(scope.Default(context.TODO()), "skynet") assert.EqualError(t, err, "could not find any records") assert.Equal(t, 0, len(dns)) }) t.Run("EntryExist", func(t *testing.T) { - entry := models.DNSEntry{Address: "10.0.0.3", Name: "custom1", Network: "skynet"} + entry := schema.DNSEntry{Address: "10.0.0.3", Name: "custom1", Network: "skynet"} _, err := logic.CreateDNS(entry) assert.Nil(t, err) - dns, err := logic.GetCustomDNS("skynet") + dns, err := logic.GetCustomDNS(scope.Default(context.TODO()), "skynet") assert.Nil(t, err) assert.Equal(t, 1, len(dns)) }) t.Run("MultipleEntries", func(t *testing.T) { - entry := models.DNSEntry{Address: "10.0.0.4", Name: "host4", Network: "skynet"} + entry := schema.DNSEntry{Address: "10.0.0.4", Name: "host4", Network: "skynet"} _, err := logic.CreateDNS(entry) assert.Nil(t, err) - dns, err := logic.GetCustomDNS("skynet") + dns, err := logic.GetCustomDNS(scope.Default(context.TODO()), "skynet") assert.Nil(t, err) assert.Equal(t, 2, len(dns)) }) @@ -93,7 +94,7 @@ func TestGetDNSEntryNum(t *testing.T) { assert.Equal(t, 0, num) }) t.Run("NodeExists", func(t *testing.T) { - entry := models.DNSEntry{Address: "10.0.0.2", Name: "newhost", Network: "skynet"} + entry := schema.DNSEntry{Address: "10.0.0.2", Name: "newhost", Network: "skynet"} _, err := logic.CreateDNS(entry) assert.Nil(t, err) num, err := logic.GetDNSEntryNum("newhost", "skynet") @@ -106,15 +107,15 @@ func TestGetDNS(t *testing.T) { deleteAllNetworks() createNet() t.Run("NoEntries", func(t *testing.T) { - dns, err := logic.GetDNS("skynet") + dns, err := logic.GetDNS(scope.Default(context.TODO()), "skynet") assert.Nil(t, err) assert.Nil(t, dns) }) t.Run("CustomDNSExists", func(t *testing.T) { - entry := models.DNSEntry{Address: "10.0.0.2", Name: "newhost", Network: "skynet"} + entry := schema.DNSEntry{Address: "10.0.0.2", Name: "newhost", Network: "skynet"} _, err := logic.CreateDNS(entry) assert.Nil(t, err) - dns, err := logic.GetDNS("skynet") + dns, err := logic.GetDNS(scope.Default(context.TODO()), "skynet") t.Log(dns) assert.Nil(t, err) assert.NotNil(t, dns) @@ -124,17 +125,17 @@ func TestGetDNS(t *testing.T) { t.Run("NodeExists", func(t *testing.T) { deleteAllDNS(t) createTestNode() - dns, err := logic.GetDNS("skynet") + dns, err := logic.GetDNS(scope.Default(context.TODO()), "skynet") assert.Nil(t, err) assert.NotNil(t, dns) assert.Equal(t, "skynet", dns[0].Network) assert.Equal(t, 1, len(dns)) }) t.Run("NodeAndCustomDNS", func(t *testing.T) { - entry := models.DNSEntry{Address: "10.0.0.2", Name: "newhost", Network: "skynet"} + entry := schema.DNSEntry{Address: "10.0.0.2", Name: "newhost", Network: "skynet"} _, err := logic.CreateDNS(entry) assert.Nil(t, err) - dns, err := logic.GetDNS("skynet") + dns, err := logic.GetDNS(scope.Default(context.TODO()), "skynet") t.Log(dns) assert.Nil(t, err) assert.NotNil(t, dns) @@ -148,7 +149,7 @@ func TestCreateDNS(t *testing.T) { deleteAllDNS(t) deleteAllNetworks() createNet() - entry := models.DNSEntry{Address: "10.0.0.2", Name: "newhost", Network: "skynet"} + entry := schema.DNSEntry{Address: "10.0.0.2", Name: "newhost", Network: "skynet"} dns, err := logic.CreateDNS(entry) assert.Nil(t, err) assert.Equal(t, "newhost", dns.Name) @@ -159,17 +160,17 @@ func TestGetDNSEntry(t *testing.T) { deleteAllNetworks() createNet() createTestNode() - entry := models.DNSEntry{Address: "10.0.0.2", Name: "newhost", Network: "skynet"} + entry := schema.DNSEntry{Address: "10.0.0.2", Name: "newhost", Network: "skynet"} _, _ = logic.CreateDNS(entry) t.Run("wrong net", func(t *testing.T) { entry, err := GetDNSEntry("newhost", "w286 Toronto Street South, Uxbridge, ONirecat") assert.EqualError(t, err, "no result found") - assert.Equal(t, models.DNSEntry{}, entry) + assert.Equal(t, schema.DNSEntry{}, entry) }) t.Run("wrong host", func(t *testing.T) { entry, err := GetDNSEntry("badhost", "skynet") assert.EqualError(t, err, "no result found") - assert.Equal(t, models.DNSEntry{}, entry) + assert.Equal(t, schema.DNSEntry{}, entry) }) t.Run("good host", func(t *testing.T) { entry, err := GetDNSEntry("newhost", "skynet") @@ -179,7 +180,7 @@ func TestGetDNSEntry(t *testing.T) { t.Run("node", func(t *testing.T) { entry, err := GetDNSEntry("testnode", "skynet") assert.EqualError(t, err, "no result found") - assert.Equal(t, models.DNSEntry{}, entry) + assert.Equal(t, schema.DNSEntry{}, entry) }) } @@ -187,7 +188,7 @@ func TestDeleteDNS(t *testing.T) { deleteAllDNS(t) deleteAllNetworks() createNet() - entry := models.DNSEntry{Address: "10.0.0.2", Name: "newhost", Network: "skynet"} + entry := schema.DNSEntry{Address: "10.0.0.2", Name: "newhost", Network: "skynet"} _, _ = logic.CreateDNS(entry) t.Run("EntryExists", func(t *testing.T) { err := logic.DeleteDNS("newhost", "skynet") @@ -208,16 +209,16 @@ func TestValidateDNSUpdate(t *testing.T) { deleteAllDNS(t) deleteAllNetworks() createNet() - entry := models.DNSEntry{Address: "10.0.0.2", Name: "myhost", Network: "skynet"} + entry := schema.DNSEntry{Address: "10.0.0.2", Name: "myhost", Network: "skynet"} t.Run("BadNetwork", func(t *testing.T) { - change := models.DNSEntry{Address: "10.0.0.2", Name: "myhost", Network: "badnet"} + change := schema.DNSEntry{Address: "10.0.0.2", Name: "myhost", Network: "badnet"} err := logic.ValidateDNSUpdate(change, entry) assert.NotNil(t, err) assert.Contains(t, err.Error(), "Field validation for 'Network' failed on the 'network_exists' tag") }) t.Run("EmptyNetwork", func(t *testing.T) { // this can't actually happen as change.Network is populated if is blank - change := models.DNSEntry{Address: "10.0.0.2", Name: "myhost"} + change := schema.DNSEntry{Address: "10.0.0.2", Name: "myhost"} err := logic.ValidateDNSUpdate(change, entry) assert.NotNil(t, err) assert.Contains(t, err.Error(), "Field validation for 'Network' failed on the 'network_exists' tag") @@ -230,14 +231,14 @@ func TestValidateDNSUpdate(t *testing.T) { // assert.Contains(t, err.Error(), "Field validation for 'Address' failed on the 'required' tag") // }) t.Run("BadAddress", func(t *testing.T) { - change := models.DNSEntry{Address: "10.0.256.1", Name: "myhost", Network: "skynet"} + change := schema.DNSEntry{Address: "10.0.256.1", Name: "myhost", Network: "skynet"} err := logic.ValidateDNSUpdate(change, entry) assert.NotNil(t, err) assert.Contains(t, err.Error(), "Field validation for 'Address' failed on the 'ip' tag") }) t.Run("EmptyName", func(t *testing.T) { // this can't actually happen as change.Name is populated if is blank - change := models.DNSEntry{Address: "10.0.0.2", Network: "skynet"} + change := schema.DNSEntry{Address: "10.0.0.2", Network: "skynet"} err := logic.ValidateDNSUpdate(change, entry) assert.NotNil(t, err) assert.Contains(t, err.Error(), "Field validation for 'Name' failed on the 'required' tag") @@ -247,13 +248,13 @@ func TestValidateDNSUpdate(t *testing.T) { for i := 1; i < 194; i++ { name = name + "a" } - change := models.DNSEntry{Address: "10.0.0.2", Name: name, Network: "skynet"} + change := schema.DNSEntry{Address: "10.0.0.2", Name: name, Network: "skynet"} err := logic.ValidateDNSUpdate(change, entry) assert.NotNil(t, err) assert.Contains(t, err.Error(), "Field validation for 'Name' failed on the 'max' tag") }) t.Run("NameUnique", func(t *testing.T) { - change := models.DNSEntry{Address: "10.0.0.2", Name: "myhost", Network: "wirecat"} + change := schema.DNSEntry{Address: "10.0.0.2", Name: "myhost", Network: "wirecat"} _, _ = logic.CreateDNS(entry) _, _ = logic.CreateDNS(change) err := logic.ValidateDNSUpdate(change, entry) @@ -268,7 +269,7 @@ func TestValidateDNSUpdate(t *testing.T) { func TestValidateDNSCreate(t *testing.T) { _ = logic.DeleteDNS("mynode", "skynet") t.Run("NoNetwork", func(t *testing.T) { - entry := models.DNSEntry{Address: "10.0.0.2", Name: "myhost", Network: "badnet"} + entry := schema.DNSEntry{Address: "10.0.0.2", Name: "myhost", Network: "badnet"} err := logic.ValidateDNSCreate(entry) assert.NotNil(t, err) assert.Contains(t, err.Error(), "Field validation for 'Network' failed on the 'network_exists' tag") @@ -280,13 +281,13 @@ func TestValidateDNSCreate(t *testing.T) { // assert.Contains(t, err.Error(), "Field validation for 'Address' failed on the 'required' tag") // }) t.Run("BadAddress", func(t *testing.T) { - entry := models.DNSEntry{Address: "10.0.256.1", Name: "myhost", Network: "skynet"} + entry := schema.DNSEntry{Address: "10.0.256.1", Name: "myhost", Network: "skynet"} err := logic.ValidateDNSCreate(entry) assert.NotNil(t, err) assert.Contains(t, err.Error(), "Field validation for 'Address' failed on the 'ip' tag") }) t.Run("EmptyName", func(t *testing.T) { - entry := models.DNSEntry{Address: "10.0.0.2", Network: "skynet"} + entry := schema.DNSEntry{Address: "10.0.0.2", Network: "skynet"} err := logic.ValidateDNSCreate(entry) assert.NotNil(t, err) assert.Contains(t, err.Error(), "invalid input") @@ -296,26 +297,26 @@ func TestValidateDNSCreate(t *testing.T) { for i := 1; i < 194; i++ { name = name + "a" } - entry := models.DNSEntry{Address: "10.0.0.2", Name: name, Network: "skynet"} + entry := schema.DNSEntry{Address: "10.0.0.2", Name: name, Network: "skynet"} err := logic.ValidateDNSCreate(entry) assert.NotNil(t, err) assert.Contains(t, err.Error(), "Field validation for 'Name' failed on the 'max' tag") }) t.Run("NameUnique", func(t *testing.T) { - entry := models.DNSEntry{Address: "10.0.0.2", Name: "myhost", Network: "skynet"} + entry := schema.DNSEntry{Address: "10.0.0.2", Name: "myhost", Network: "skynet"} _, _ = logic.CreateDNS(entry) err := logic.ValidateDNSCreate(entry) assert.NotNil(t, err) assert.Contains(t, err.Error(), "Field validation for 'Name' failed on the 'name_unique' tag") }) t.Run("WhiteSpace", func(t *testing.T) { - entry := models.DNSEntry{Address: "10.10.10.5", Name: "white space", Network: "skynet"} + entry := schema.DNSEntry{Address: "10.10.10.5", Name: "white space", Network: "skynet"} err := logic.ValidateDNSCreate(entry) assert.NotNil(t, err) assert.Contains(t, err.Error(), "invalid input") }) t.Run("AllSpaces", func(t *testing.T) { - entry := models.DNSEntry{Address: "10.10.10.5", Name: " ", Network: "skynet"} + entry := schema.DNSEntry{Address: "10.10.10.5", Name: " ", Network: "skynet"} err := logic.ValidateDNSCreate(entry) assert.NotNil(t, err) assert.Contains(t, err.Error(), "invalid input") @@ -339,7 +340,7 @@ func createHost() { } func deleteAllDNS(t *testing.T) { - dns, err := logic.GetAllDNS() + dns, err := logic.GetAllDNS(scope.Default(context.TODO())) assert.Nil(t, err) for _, record := range dns { err := logic.DeleteDNS(record.Name, record.Network) diff --git a/database/database.go b/database/database.go index 8f6d7055d..27b59d233 100644 --- a/database/database.go +++ b/database/database.go @@ -10,8 +10,6 @@ import ( const ( // == Table Names == - // DNS_TABLE_NAME - dns table - DNS_TABLE_NAME = "dns" // EXT_CLIENT_TABLE_NAME - ext client table EXT_CLIENT_TABLE_NAME = "extclients" // ACLS_TABLE_NAME - table for acls v2 @@ -54,7 +52,6 @@ const ( ) var Tables = []string{ - DNS_TABLE_NAME, EXT_CLIENT_TABLE_NAME, SSO_STATE_CACHE, METRICS_TABLE_NAME, diff --git a/logic/dns.go b/logic/dns.go index a9bd40763..faf95503d 100644 --- a/logic/dns.go +++ b/logic/dns.go @@ -2,11 +2,9 @@ package logic import ( "context" - "encoding/json" "errors" "fmt" "net" - "os" "regexp" "sort" "strings" @@ -19,7 +17,9 @@ import ( "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/schema" + "github.com/gravitl/netmaker/scope" "github.com/gravitl/netmaker/servercfg" + "gorm.io/datatypes" ) const ( @@ -103,13 +103,13 @@ func CreateFallbackNameserver(networkID string) error { } // GetDNS - gets the DNS of a current network -func GetDNS(network string) ([]models.DNSEntry, error) { +func GetDNS(ctx context.Context, network string) ([]schema.DNSEntry, error) { - dns, err := GetNodeDNS(network) + dns, err := GetNodeDNS(ctx, network) if err != nil && !database.IsEmptyRecord(err) { return dns, err } - customdns, err := GetCustomDNS(network) + customdns, err := GetCustomDNS(ctx, network) if err != nil && !database.IsEmptyRecord(err) { return dns, err } @@ -118,7 +118,7 @@ func GetDNS(network string) ([]models.DNSEntry, error) { return dns, nil } -func EgressDNs(network string) (entries []models.DNSEntry) { +func EgressDNs(network string) (entries []schema.DNSEntry) { egs, _ := (&schema.Egress{ Network: network, }).ListByNetwork(db.WithContext(context.TODO())) @@ -128,7 +128,7 @@ func EgressDNs(network string) (entries []models.DNSEntry) { } if IsDomainBasedEgress(egI) && HasEgressDomainAns(egI) { for _, name := range ConfiguredDomainsForEgress(egI) { - entry := models.DNSEntry{ + entry := schema.DNSEntry{ Name: name, } for _, domainAns := range DomainAnsForDomain(egI, name) { @@ -148,32 +148,10 @@ func EgressDNs(network string) (entries []models.DNSEntry) { return } -// GetExtclientDNS - gets all extclients dns entries -func GetExtclientDNS() []models.DNSEntry { - extclients, err := GetAllExtClients() - if err != nil { - return []models.DNSEntry{} - } - var dns []models.DNSEntry - for _, extclient := range extclients { - var entry = models.DNSEntry{} - entry.Name = fmt.Sprintf("%s.%s", extclient.ClientID, extclient.Network) - entry.Network = extclient.Network - if extclient.Address != "" { - entry.Address = extclient.Address - } - if extclient.Address6 != "" { - entry.Address6 = extclient.Address6 - } - dns = append(dns, entry) - } - return dns -} - // GetNodeDNS - gets the DNS of a network node -func GetNodeDNS(network string) ([]models.DNSEntry, error) { +func GetNodeDNS(ctx context.Context, network string) ([]schema.DNSEntry, error) { - var dns []models.DNSEntry + var dns []schema.DNSEntry nodes, err := GetNetworkNodes(network) if err != nil { @@ -187,11 +165,11 @@ func GetNodeDNS(network string) ([]models.DNSEntry, error) { host := &schema.Host{ ID: node.HostID, } - err = host.Get(db.WithContext(context.TODO())) + err = host.Get(ctx) if err != nil { continue } - var entry = models.DNSEntry{} + var entry = schema.DNSEntry{} if defaultDomain == "" { entry.Name = fmt.Sprintf("%s.%s", host.Name, network) } else { @@ -204,7 +182,7 @@ func GetNodeDNS(network string) ([]models.DNSEntry, error) { if node.Address6.IP != nil { entry.Address6 = node.Address6.IP.String() } - entry.Type = models.DNSEntryType_Node + entry.Type = schema.DNSEntryType_Node dns = append(dns, entry) } @@ -244,98 +222,35 @@ func SetDNSOnWgConfig(gwNode *models.Node, extclient *models.ExtClient) { } // GetCustomDNS - gets the custom DNS of a network -func GetCustomDNS(network string) ([]models.DNSEntry, error) { - - var dns []models.DNSEntry - - collection, err := database.FetchRecords(database.DNS_TABLE_NAME) +func GetCustomDNS(ctx context.Context, network string) ([]schema.DNSEntry, error) { + records, err := (&schema.DNS{NetworkID: network}).ListByNetwork(ctx) if err != nil { - return dns, err + return nil, err } defaultDomain := GetDefaultDomain() - for _, value := range collection { // filter for entries based on network - var entry models.DNSEntry - if err := json.Unmarshal([]byte(value), &entry); err != nil { - continue - } - - if entry.Network == network { - if defaultDomain != "" { - entry.Name = fmt.Sprintf("%s.%s", entry.Name, defaultDomain) - } - entry.Type = models.DNSEntryType_Custom - dns = append(dns, entry) - } - } - - return dns, err -} - -func DeleteNetworkDNS(network string) error { - records, err := database.FetchRecords(database.DNS_TABLE_NAME) - if err != nil { - if database.IsEmptyRecord(err) { - return nil - } - - return err - } - - for key, record := range records { - var entry models.DNSEntry - err := json.Unmarshal([]byte(record), &entry) - if err != nil { - continue - } - - if entry.Network == network { - _ = database.DeleteRecord(database.DNS_TABLE_NAME, key) + dns := make([]schema.DNSEntry, 0, len(records)) + for _, r := range records { + entry := r.Value.Data() + if defaultDomain != "" { + entry.Name = fmt.Sprintf("%s.%s", entry.Name, defaultDomain) } + entry.Type = schema.DNSEntryType_Custom + dns = append(dns, entry) } - - return nil -} - -// SetCorefile - sets the core file of the system -func SetCorefile(domains string) error { - dir, err := os.Getwd() - if err != nil { - return err - } - - err = os.MkdirAll(dir+"/config/dnsconfig", 0744) - if err != nil { - logger.Log(0, "couldnt find or create /config/dnsconfig") - return err - } - - corefile := domains + ` { - reload 15s - hosts /root/dnsconfig/netmaker.hosts { - fallthrough - } - forward . 8.8.8.8 8.8.4.4 - log -} -` - err = os.WriteFile(dir+"/config/dnsconfig/Corefile", []byte(corefile), 0644) - if err != nil { - return err - } - return err + return dns, nil } // GetAllDNS - gets all dns entries -func GetAllDNS() ([]models.DNSEntry, error) { - var dns []models.DNSEntry - networks, err := (&schema.Network{}).ListAll(db.WithContext(context.TODO())) +func GetAllDNS(ctx context.Context) ([]schema.DNSEntry, error) { + var dns []schema.DNSEntry + networks, err := (&schema.Network{}).ListAll(ctx) if err != nil { - return []models.DNSEntry{}, err + return []schema.DNSEntry{}, err } for _, net := range networks { - netdns, err := GetDNS(net.Name) + netdns, err := GetDNS(ctx, net.Name) if err != nil { - return []models.DNSEntry{}, nil + return []schema.DNSEntry{}, nil } dns = append(dns, netdns...) } @@ -347,7 +262,7 @@ func GetDNSEntryNum(domain string, network string) (int, error) { num := 0 - entries, err := GetDNS(network) + entries, err := GetDNS(scope.Default(context.TODO()), network) if err != nil { return 0, err } @@ -363,7 +278,7 @@ func GetDNSEntryNum(domain string, network string) (int, error) { } // SortDNSEntrys - Sorts slice of DNSEnteys by their Address alphabetically with numbers first -func SortDNSEntrys(unsortedDNSEntrys []models.DNSEntry) { +func SortDNSEntrys(unsortedDNSEntrys []schema.DNSEntry) { sort.Slice(unsortedDNSEntrys, func(i, j int) bool { return unsortedDNSEntrys[i].Address < unsortedDNSEntrys[j].Address }) @@ -376,7 +291,7 @@ func IsDNSEntryValid(d string) bool { } // ValidateDNSCreate - checks if an entry is valid -func ValidateDNSCreate(entry models.DNSEntry) error { +func ValidateDNSCreate(entry schema.DNSEntry) error { if !IsDNSEntryValid(entry.Name) { return errors.New("invalid input. Only uppercase letters (A-Z), lowercase letters (a-z), numbers (0-9), minus sign (-) and dots (.) are allowed") } @@ -407,7 +322,7 @@ func ValidateDNSCreate(entry models.DNSEntry) error { } // ValidateDNSUpdate - validates a DNS update -func ValidateDNSUpdate(change models.DNSEntry, entry models.DNSEntry) error { +func ValidateDNSUpdate(change schema.DNSEntry, entry schema.DNSEntry) error { v := validator.New() @@ -445,25 +360,22 @@ func DeleteDNS(domain string, network string) error { if err != nil { return err } - err = database.DeleteRecord(database.DNS_TABLE_NAME, key) - return err + return (&schema.DNS{Key: key}).Delete(db.WithContext(context.TODO())) } // CreateDNS - creates a DNS entry -func CreateDNS(entry models.DNSEntry) (models.DNSEntry, error) { - entry.Type = models.DNSEntryType_Custom +func CreateDNS(entry schema.DNSEntry) (schema.DNSEntry, error) { + entry.Type = schema.DNSEntryType_Custom k, err := GetRecordKey(entry.Name, entry.Network) if err != nil { - return models.DNSEntry{}, err + return schema.DNSEntry{}, err } - - data, err := json.Marshal(&entry) - if err != nil { - return models.DNSEntry{}, err + d := &schema.DNS{ + Key: k, + NetworkID: entry.Network, + Value: datatypes.NewJSONType(entry), } - - err = database.Insert(k, string(data), database.DNS_TABLE_NAME) - return entry, err + return entry, d.Create(db.WithContext(context.TODO())) } func validateNameserverReq(ns *schema.Nameserver) error { diff --git a/logic/networks.go b/logic/networks.go index c5086294b..f44aa6410 100644 --- a/logic/networks.go +++ b/logic/networks.go @@ -32,7 +32,7 @@ func DeleteNetwork(network string, force bool, done chan struct{}) error { } } - _ = DeleteNetworkDNS(network) + _ = (&schema.DNS{NetworkID: network}).DeleteByNetwork(db.WithContext(context.TODO())) }() nodeCount, err := GetNetworkNonServerNodeCount(network) diff --git a/migrate/migrate.go b/migrate/migrate.go index e79b23566..c38d353fa 100644 --- a/migrate/migrate.go +++ b/migrate/migrate.go @@ -679,17 +679,11 @@ func cleanUpDeleteNetworksRefs() { networksMap[network.Name] = true } - records, _ := database.FetchRecords(database.DNS_TABLE_NAME) - for key, record := range records { - var entry models.DNSEntry - err := json.Unmarshal([]byte(record), &entry) - if err != nil { - continue - } - - _, ok := networksMap[entry.Network] + entries, _ := (&schema.DNS{}).ListAll(db.WithContext(context.TODO())) + for _, entry := range entries { + _, ok := networksMap[entry.Value.Data().Network] if !ok { - _ = database.DeleteRecord(database.DNS_TABLE_NAME, key) + _ = entry.Delete(db.WithContext(context.TODO())) } } diff --git a/migrate/migrate_multitenancy.go b/migrate/migrate_multitenancy.go index 470d29b45..4e920222d 100644 --- a/migrate/migrate_multitenancy.go +++ b/migrate/migrate_multitenancy.go @@ -134,7 +134,7 @@ func backfillDNSNetworkID(ctx context.Context) error { return fmt.Errorf("multitenancy migration: list dns records: %w", err) } for key, value := range records { - var entry models.DNSEntry + var entry schema.DNSEntry if err := json.Unmarshal([]byte(value), &entry); err != nil { return fmt.Errorf("multitenancy migration: parse dns record %s: %w", key, err) } diff --git a/models/dnsEntry.go b/models/dnsEntry.go index 67bfba71b..b266401fb 100644 --- a/models/dnsEntry.go +++ b/models/dnsEntry.go @@ -40,22 +40,6 @@ type DNSUpdate struct { NewAddress string } -type DNSEntryType string - -const ( - DNSEntryType_Node = "node" - DNSEntryType_Custom = "custom" -) - -// DNSEntry - a DNS entry represented as struct -type DNSEntry struct { - Type DNSEntryType `json:"type"` - Address string `json:"address" validate:"omitempty,ip"` - Address6 string `json:"address6" validate:"omitempty,ip"` - Name string `json:"name" validate:"required,name_unique,min=1,max=192,whitespace"` - Network string `json:"network" validate:"network_exists"` -} - type NameserverReq struct { Name string `json:"name"` Network string `json:"network"` diff --git a/mq/publishers.go b/mq/publishers.go index a9183f829..4b31d3e3d 100644 --- a/mq/publishers.go +++ b/mq/publishers.go @@ -19,6 +19,7 @@ import ( "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/schema" + "github.com/gravitl/netmaker/scope" "github.com/gravitl/netmaker/servercfg" "golang.org/x/exp/slog" ) @@ -403,7 +404,7 @@ func sendPeers() { func SendDNSSyncByNetwork(network string) error { - k, err := logic.GetDNS(network) + k, err := logic.GetDNS(scope.Default(context.TODO()), network) k = append(k, logic.EgressDNs(network)...) if err == nil && len(k) > 0 { err = PushSyncDNS(k) @@ -419,7 +420,7 @@ func sendDNSSync() error { networks, err := (&schema.Network{}).ListAll(db.WithContext(context.TODO())) if err == nil && len(networks) > 0 { for _, v := range networks { - k, err := logic.GetDNS(v.Name) + k, err := logic.GetDNS(scope.Default(context.TODO()), v.Name) k = append(k, logic.EgressDNs(v.Name)...) if err == nil && len(k) > 0 { err = PushSyncDNS(k) @@ -433,7 +434,7 @@ func sendDNSSync() error { return err } -func PushSyncDNS(dnsEntries []models.DNSEntry) error { +func PushSyncDNS(dnsEntries []schema.DNSEntry) error { logger.Log(2, "----> Pushing Sync DNS") data, err := json.Marshal(dnsEntries) if err != nil { diff --git a/schema/models.go b/schema/models.go index 85d525915..d1de16e63 100644 --- a/schema/models.go +++ b/schema/models.go @@ -25,5 +25,6 @@ func ListModels() []interface{} { &PostureCheckViolation{}, &Integration{}, &EnrollmentKey{}, + &DNS{}, } } diff --git a/scope/scope.go b/scope/scope.go index 4f78b2c49..6be59838b 100644 --- a/scope/scope.go +++ b/scope/scope.go @@ -156,6 +156,9 @@ func WithContext(ctx context.Context, level Level, ids ...string) context.Contex } } +// Default returns a default tenant context. +// TODO: this is a temporary function. remove it and all it's usages. +// TODO: tenant context setting MUST be explicit. func Default(ctx context.Context) context.Context { defaultTenant := &schema.Tenant{} err := defaultTenant.GetDefault(db.WithContext(context.TODO())) From b2a98ef7777aa61cb22e36842bf5fb3602f28d98 Mon Sep 17 00:00:00 2001 From: VishalDalwadi Date: Mon, 22 Jun 2026 16:43:25 +0530 Subject: [PATCH 20/21] wip(go): move acl, metrics, extclient and tag records db management to schema pkg; --- auth/host_session.go | 4 +- cli/cmd/ext_client/list.go | 4 +- cli/functions/ext_client.go | 17 +- cli/functions/metrics.go | 9 +- controllers/acls.go | 94 +++--- controllers/egress.go | 2 +- controllers/enrollmentkeys.go | 4 +- controllers/ext_client.go | 30 +- controllers/hosts.go | 8 +- controllers/user.go | 6 +- database/database.go | 9 - logic/acls.go | 446 ++++++++++++++-------------- logic/clients.go | 10 +- logic/dns.go | 5 +- logic/egress.go | 32 +- logic/enrollmentkey.go | 4 +- logic/extpeers.go | 109 +++---- logic/gateway.go | 8 +- logic/hosts.go | 4 +- logic/metrics.go | 8 +- logic/networks.go | 2 +- logic/nodes.go | 10 +- logic/peers.go | 14 +- logic/relay.go | 2 +- logic/settings.go | 22 +- logic/telemetry.go | 4 +- logic/usage.go | 2 +- logic/user_mgmt.go | 16 +- migrate/migrate.go | 44 +-- migrate/migrate_multitenancy.go | 9 +- migrate/migrate_v1_6_0.go | 38 +-- models/acl.go | 111 ++----- models/api_node.go | 36 +-- models/enrollment_key.go | 48 +-- models/extclient.go | 89 ++---- models/host.go | 2 +- models/metrics.go | 28 +- models/mqtt.go | 10 +- models/node.go | 56 ++-- models/settings.go | 48 --- models/structs.go | 12 +- models/tags.go | 33 +- mq/handlers.go | 2 +- mq/publishers.go | 19 +- orchestrator/extensions/node.go | 5 +- orchestrator/node.go | 6 +- pro/controllers/auto_relay.go | 2 +- pro/controllers/metrics.go | 9 +- pro/controllers/posture_check.go | 2 +- pro/controllers/tags.go | 14 +- pro/controllers/users.go | 12 +- pro/logic/acls.go | 210 ++++++------- pro/logic/dns.go | 4 +- pro/logic/egress.go | 5 +- pro/logic/metrics.go | 56 ++-- pro/logic/migrate.go | 11 +- pro/logic/nodes.go | 22 +- pro/logic/posture_check.go | 28 +- pro/logic/status.go | 4 +- pro/logic/tags.go | 68 ++--- pro/logic/user_mgmt.go | 96 +++--- pro/orchestrator/extensions/node.go | 3 +- pro/remote_access_client.go | 4 +- schema/acl.go | 136 +++++++++ schema/extclient.go | 116 ++++++++ schema/metrics.go | 83 ++++++ schema/models.go | 4 + schema/nodes.go | 7 + schema/org_settings.go | 4 + schema/posture_check.go | 9 + schema/tag.go | 78 +++++ test/utils/tag.go | 14 +- 72 files changed, 1302 insertions(+), 1180 deletions(-) create mode 100644 schema/acl.go create mode 100644 schema/extclient.go create mode 100644 schema/metrics.go create mode 100644 schema/tag.go diff --git a/auth/host_session.go b/auth/host_session.go index 9650d9bf3..8bd5d81bf 100644 --- a/auth/host_session.go +++ b/auth/host_session.go @@ -231,9 +231,9 @@ func SessionHandler(conn *websocket.Conn) { func CheckNetRegAndHostUpdate(key schema.EnrollmentKey, host *schema.Host, username string) { // publish host update through MQ featureFlags := logic.GetFeatureFlags() - keyTags := make(map[models.TagID]struct{}) + keyTags := make(map[schema.TagID]struct{}) for _, tagI := range key.Tags { - keyTags[models.TagID(tagI)] = struct{}{} + keyTags[schema.TagID(tagI)] = struct{}{} } for _, netID := range key.Networks { network := &schema.Network{Name: netID} diff --git a/cli/cmd/ext_client/list.go b/cli/cmd/ext_client/list.go index 7c36b66e9..dc6eb59f1 100644 --- a/cli/cmd/ext_client/list.go +++ b/cli/cmd/ext_client/list.go @@ -7,7 +7,7 @@ import ( "github.com/gravitl/netmaker/cli/cmd/commons" "github.com/gravitl/netmaker/cli/functions" - "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/schema" "github.com/guumaster/tablewriter" "github.com/spf13/cobra" ) @@ -20,7 +20,7 @@ var extClientListCmd = &cobra.Command{ Short: "List External Clients", Long: `List External Clients`, Run: func(cmd *cobra.Command, args []string) { - var data []models.ExtClient + var data []schema.ExtClient if networkName != "" { data = *functions.GetNetworkExtClients(networkName) } else { diff --git a/cli/functions/ext_client.go b/cli/functions/ext_client.go index 085e6a4d6..b8a154294 100644 --- a/cli/functions/ext_client.go +++ b/cli/functions/ext_client.go @@ -5,21 +5,22 @@ import ( "net/http" "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/schema" ) // GetAllExtClients - fetch all external clients -func GetAllExtClients() *[]models.ExtClient { - return request[[]models.ExtClient](http.MethodGet, "/api/extclients", nil) +func GetAllExtClients() *[]schema.ExtClient { + return request[[]schema.ExtClient](http.MethodGet, "/api/extclients", nil) } // GetNetworkExtClients - fetch external clients associated with a network -func GetNetworkExtClients(networkName string) *[]models.ExtClient { - return request[[]models.ExtClient](http.MethodGet, "/api/extclients/"+networkName, nil) +func GetNetworkExtClients(networkName string) *[]schema.ExtClient { + return request[[]schema.ExtClient](http.MethodGet, "/api/extclients/"+networkName, nil) } // GetExtClient - fetch a single external client -func GetExtClient(networkName, clientID string) *models.ExtClient { - return request[models.ExtClient](http.MethodGet, fmt.Sprintf("/api/extclients/%s/%s", networkName, clientID), nil) +func GetExtClient(networkName, clientID string) *schema.ExtClient { + return request[schema.ExtClient](http.MethodGet, fmt.Sprintf("/api/extclients/%s/%s", networkName, clientID), nil) } // GetExtClientConfig - fetch a wireguard config of an external client @@ -43,6 +44,6 @@ func DeleteExtClient(networkName, clientID string) *models.SuccessResponse { } // UpdateExtClient - update an external client -func UpdateExtClient(networkName, clientID string, payload *models.CustomExtClient) *models.ExtClient { - return request[models.ExtClient](http.MethodPut, fmt.Sprintf("/api/extclients/%s/%s", networkName, clientID), payload) +func UpdateExtClient(networkName, clientID string, payload *models.CustomExtClient) *schema.ExtClient { + return request[schema.ExtClient](http.MethodPut, fmt.Sprintf("/api/extclients/%s/%s", networkName, clientID), payload) } diff --git a/cli/functions/metrics.go b/cli/functions/metrics.go index aea0ed25e..ce1ede64c 100644 --- a/cli/functions/metrics.go +++ b/cli/functions/metrics.go @@ -5,11 +5,12 @@ import ( "net/http" "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/schema" ) // GetNodeMetrics - fetch a single node's metrics -func GetNodeMetrics(networkName, nodeID string) *models.Metrics { - return request[models.Metrics](http.MethodGet, fmt.Sprintf("/api/metrics/%s/%s", networkName, nodeID), nil) +func GetNodeMetrics(networkName, nodeID string) *schema.Metrics { + return request[schema.Metrics](http.MethodGet, fmt.Sprintf("/api/metrics/%s/%s", networkName, nodeID), nil) } // GetNetworkNodeMetrics - fetch an entire network's metrics @@ -23,6 +24,6 @@ func GetAllMetrics() *models.NetworkMetrics { } // GetNetworkExtMetrics - fetch external client metrics belonging to a network -func GetNetworkExtMetrics(networkName string) *map[string]models.Metric { - return request[map[string]models.Metric](http.MethodGet, "/api/metrics-ext/"+networkName, nil) +func GetNetworkExtMetrics(networkName string) *map[string]schema.Metric { + return request[map[string]schema.Metric](http.MethodGet, "/api/metrics-ext/"+networkName, nil) } diff --git a/controllers/acls.go b/controllers/acls.go index 12a87537b..c74722064 100644 --- a/controllers/acls.go +++ b/controllers/acls.go @@ -44,100 +44,100 @@ func aclHandlers(r *mux.Router) { // @Failure 500 {object} models.ErrorResponse func aclPolicyTypes(w http.ResponseWriter, r *http.Request) { resp := models.AclPolicyTypes{ - RuleTypes: []models.AclPolicyType{ - models.DevicePolicy, - models.UserPolicy, + RuleTypes: []schema.AclPolicyType{ + schema.DevicePolicy, + schema.UserPolicy, }, - SrcGroupTypes: []models.AclGroupType{ - models.UserAclID, - models.UserGroupAclID, - models.NodeTagID, - models.NodeID, + SrcGroupTypes: []schema.AclGroupType{ + schema.UserAclID, + schema.UserGroupAclID, + schema.NodeTagID, + schema.NodeID, }, - DstGroupTypes: []models.AclGroupType{ - models.NodeTagID, - models.NodeID, - models.EgressID, - models.NetmakerIPAclID, - // models.NetmakerSubNetRangeAClID, + DstGroupTypes: []schema.AclGroupType{ + schema.NodeTagID, + schema.NodeID, + schema.EgressID, + schema.NetmakerIPAclID, + // schema.NetmakerSubNetRangeAClID, }, ProtocolTypes: []models.ProtocolType{ { Name: models.Any, - AllowedProtocols: []models.Protocol{ - models.ALL, + AllowedProtocols: []schema.Protocol{ + schema.ALL, }, PortRange: "All ports", AllowPortSetting: false, }, { Name: models.Http, - AllowedProtocols: []models.Protocol{ - models.TCP, + AllowedProtocols: []schema.Protocol{ + schema.TCP, }, PortRange: "80", }, { Name: models.Https, - AllowedProtocols: []models.Protocol{ - models.TCP, + AllowedProtocols: []schema.Protocol{ + schema.TCP, }, PortRange: "443", }, // { // Name: "MySQL", - // AllowedProtocols: []models.Protocol{ - // models.TCP, + // AllowedProtocols: []schema.Protocol{ + // schema.TCP, // }, // PortRange: "3306", // }, // { // Name: "DNS TCP", - // AllowedProtocols: []models.Protocol{ - // models.TCP, + // AllowedProtocols: []schema.Protocol{ + // schema.TCP, // }, // PortRange: "53", // }, // { // Name: "DNS UDP", - // AllowedProtocols: []models.Protocol{ - // models.UDP, + // AllowedProtocols: []schema.Protocol{ + // schema.UDP, // }, // PortRange: "53", // }, { Name: models.AllTCP, - AllowedProtocols: []models.Protocol{ - models.TCP, + AllowedProtocols: []schema.Protocol{ + schema.TCP, }, PortRange: "All ports", }, { Name: models.AllUDP, - AllowedProtocols: []models.Protocol{ - models.UDP, + AllowedProtocols: []schema.Protocol{ + schema.UDP, }, PortRange: "All ports", }, { Name: models.ICMPService, - AllowedProtocols: []models.Protocol{ - models.ICMP, + AllowedProtocols: []schema.Protocol{ + schema.ICMP, }, PortRange: "", }, { Name: models.SSH, - AllowedProtocols: []models.Protocol{ - models.TCP, + AllowedProtocols: []schema.Protocol{ + schema.TCP, }, PortRange: "22", }, { Name: models.Custom, - AllowedProtocols: []models.Protocol{ - models.UDP, - models.TCP, + AllowedProtocols: []schema.Protocol{ + schema.UDP, + schema.TCP, }, PortRange: "All ports", AllowPortSetting: true, @@ -163,7 +163,7 @@ func aclDebug(w http.ResponseWriter, r *http.Request) { logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } - peer = extclient.ConvertToStaticNode() + peer = models.ConvertToStaticNode(&extclient) } else { peer, err = logic.GetNodeByID(peerID) @@ -175,7 +175,7 @@ func aclDebug(w http.ResponseWriter, r *http.Request) { type resp struct { IsNodeAllowed bool IsPeerAllowed bool - Policies []models.Acl + Policies []schema.Acl IngressRules []models.FwRule NodeAllPolicy bool EgressNets map[string]models.Node @@ -203,7 +203,7 @@ func aclDebug(w http.ResponseWriter, r *http.Request) { // @Security oauth // @Produce json // @Param network query string true "Network ID" -// @Success 200 {array} models.Acl +// @Success 200 {array} schema.Acl // @Failure 500 {object} models.ErrorResponse func getAcls(w http.ResponseWriter, r *http.Request) { netID := r.URL.Query().Get("network") @@ -234,7 +234,7 @@ func getAcls(w http.ResponseWriter, r *http.Request) { // @Security oauth // @Produce json // @Param egress_id query string true "Egress ID" -// @Success 200 {array} models.Acl +// @Success 200 {array} schema.Acl // @Failure 500 {object} models.ErrorResponse func getEgressAcls(w http.ResponseWriter, r *http.Request) { eID := r.URL.Query().Get("egress_id") @@ -266,12 +266,12 @@ func getEgressAcls(w http.ResponseWriter, r *http.Request) { // @Security oauth // @Accept json // @Produce json -// @Param body body models.Acl true "ACL policy details" -// @Success 200 {object} models.Acl +// @Param body body schema.Acl true "ACL policy details" +// @Success 200 {object} schema.Acl // @Failure 400 {object} models.ErrorResponse // @Failure 500 {object} models.ErrorResponse func createAcl(w http.ResponseWriter, r *http.Request) { - var req models.Acl + var req schema.Acl err := json.NewDecoder(r.Body).Decode(&req) if err != nil { logger.Log(0, "error decoding request body: ", @@ -302,7 +302,7 @@ func createAcl(w http.ResponseWriter, r *http.Request) { acl.Default = false if acl.ServiceType == models.Any { acl.Port = []string{} - acl.Proto = models.ALL + acl.Proto = schema.ALL } // validate create acl policy if err := logic.IsAclPolicyValid(acl); err != nil { @@ -336,7 +336,7 @@ func createAcl(w http.ResponseWriter, r *http.Request) { Origin: schema.Dashboard, }) go mq.PublishPeerUpdate(true) - acls := []models.Acl{acl} + acls := []schema.Acl{acl} logic.PopulateAclPolicyTagNames(acls) logic.ReturnSuccessResponseWithJson(w, r, acls[0], "created acl successfully") } @@ -421,7 +421,7 @@ func updateAcl(w http.ResponseWriter, r *http.Request) { logic.ReturnSuccessResponse(w, r, "updated acl "+acl.Name) return } - acls := []models.Acl{updatedAcl} + acls := []schema.Acl{updatedAcl} logic.PopulateAclPolicyTagNames(acls) logic.ReturnSuccessResponseWithJson(w, r, acls[0], "updated acl "+acl.Name) } diff --git a/controllers/egress.go b/controllers/egress.go index a66f117c9..4c4f3cac6 100644 --- a/controllers/egress.go +++ b/controllers/egress.go @@ -525,7 +525,7 @@ func deleteEgress(w http.ResponseWriter, r *http.Request) { for _, acl := range acls { for i := len(acl.Dst) - 1; i >= 0; i-- { - if acl.Dst[i].ID == models.EgressID && acl.Dst[i].Value == id { + if acl.Dst[i].ID == schema.EgressID && acl.Dst[i].Value == id { acl.Dst = append(acl.Dst[:i], acl.Dst[i+1:]...) } } diff --git a/controllers/enrollmentkeys.go b/controllers/enrollmentkeys.go index 22e48118e..cfd8422cb 100644 --- a/controllers/enrollmentkeys.go +++ b/controllers/enrollmentkeys.go @@ -520,9 +520,9 @@ func handleHostRegister(w http.ResponseWriter, r *http.Request) { return } - keyTags := make(map[models.TagID]struct{}) + keyTags := make(map[schema.TagID]struct{}) for _, tagI := range enrollmentKey.Tags { - keyTags[models.TagID(tagI)] = struct{}{} + keyTags[schema.TagID(tagI)] = struct{}{} } var joinNetworks []string diff --git a/controllers/ext_client.go b/controllers/ext_client.go index c06c929ae..027a03d77 100644 --- a/controllers/ext_client.go +++ b/controllers/ext_client.go @@ -69,13 +69,13 @@ func checkIngressExists(nodeID string) bool { // @Security oauth // @Produce json // @Param network path string true "Network ID" -// @Success 200 {array} models.ExtClient +// @Success 200 {array} schema.ExtClient // @Failure 500 {object} models.ErrorResponse func getNetworkExtClients(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") - var extclients []models.ExtClient + var extclients []schema.ExtClient var params = mux.Vars(r) network := params["network"] extclients, err := logic.GetNetworkExtClients(network) @@ -98,7 +98,7 @@ func getNetworkExtClients(w http.ResponseWriter, r *http.Request) { } err := userRole.Get(r.Context()) if err != nil || !userRole.FullAccess { - filtered := []models.ExtClient{} + filtered := []schema.ExtClient{} for _, ec := range extclients { if logic.IsUserAllowedAccessToExtClient(username, ec) { filtered = append(filtered, ec) @@ -121,7 +121,7 @@ func getNetworkExtClients(w http.ResponseWriter, r *http.Request) { // @Tags Config Files // @Security oauth // @Produce json -// @Success 200 {array} models.ExtClient +// @Success 200 {array} schema.ExtClient // @Failure 500 {object} models.ErrorResponse func getAllExtClients(w http.ResponseWriter, r *http.Request) { @@ -148,7 +148,7 @@ func getAllExtClients(w http.ResponseWriter, r *http.Request) { // @Produce json // @Param network path string true "Network ID" // @Param clientid path string true "Client ID" -// @Success 200 {object} models.ExtClient +// @Success 200 {object} schema.ExtClient // @Failure 500 {object} models.ErrorResponse // @Failure 403 {object} models.ErrorResponse func getExtClient(w http.ResponseWriter, r *http.Request) { @@ -191,7 +191,7 @@ func getExtClient(w http.ResponseWriter, r *http.Request) { // @Param clientid path string true "Client ID" // @Param type path string true "Config type (qr or file)" // @Param preferredip query string false "Preferred endpoint IP" -// @Success 200 {object} models.ExtClient +// @Success 200 {object} schema.ExtClient // @Failure 500 {object} models.ErrorResponse // @Failure 403 {object} models.ErrorResponse func getExtClientConf(w http.ResponseWriter, r *http.Request) { @@ -433,7 +433,7 @@ Endpoint = %s // @Param network path string true "Network ID" // @Param nodeid path string true "Node ID (Ingress Gateway)" // @Param body body models.CustomExtClient true "Custom ext client parameters" -// @Success 200 {object} models.ExtClient +// @Success 200 {object} schema.ExtClient // @Failure 500 {object} models.ErrorResponse // @Failure 400 {object} models.ErrorResponse // @Failure 403 {object} models.ErrorResponse @@ -548,14 +548,14 @@ func createExtClient(w http.ResponseWriter, r *http.Request) { } } - extclient := logic.UpdateExtClient(&models.ExtClient{}, &customExtClient) + extclient := logic.UpdateExtClient(&schema.ExtClient{}, &customExtClient) extclient.OwnerID = userName extclient.RemoteAccessClientID = customExtClient.RemoteAccessClientID extclient.IngressGatewayID = nodeid extclient.Network = node.Network - extclient.Tags = make(map[models.TagID]struct{}) - // extclient.Tags[models.TagID(fmt.Sprintf("%s.%s", extclient.Network, + extclient.Tags = make(map[schema.TagID]struct{}) + // extclient.Tags[schema.TagID(fmt.Sprintf("%s.%s", extclient.Network, // models.RemoteAccessTagName))] = struct{}{} // set extclient dns to ingressdns if extclient dns is not explicitly gwDNS := logic.GetGwDNS(&node) @@ -603,7 +603,7 @@ func createExtClient(w http.ResponseWriter, r *http.Request) { if extclient.DeviceID != "" { // check for violations connecting from desktop app - staticNode := extclient.ConvertToStaticNode() + staticNode := models.ConvertToStaticNode(&extclient) violations, _ := logic.CheckPostureViolations(logic.GetPostureCheckDeviceInfoByNode(&staticNode), schema.NetworkID(extclient.Network)) if len(violations) > 0 { logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("posture check violations"), logic.Forbidden)) @@ -809,7 +809,7 @@ func createExtClient(w http.ResponseWriter, r *http.Request) { // @Param network path string true "Network ID" // @Param clientid path string true "Client ID" // @Param body body models.CustomExtClient true "Custom ext client update" -// @Success 200 {object} models.ExtClient +// @Success 200 {object} schema.ExtClient // @Failure 500 {object} models.ErrorResponse // @Failure 400 {object} models.ErrorResponse // @Failure 403 {object} models.ErrorResponse @@ -819,7 +819,7 @@ func updateExtClient(w http.ResponseWriter, r *http.Request) { var params = mux.Vars(r) var update models.CustomExtClient - //var oldExtClient models.ExtClient + //var oldExtClient schema.ExtClient var replacePeers bool err := json.NewDecoder(r.Body).Decode(&update) if err != nil { @@ -875,7 +875,7 @@ func updateExtClient(w http.ResponseWriter, r *http.Request) { newclient := logic.UpdateExtClient(&oldExtClient, &update) if newclient.DeviceID != "" && newclient.Enabled { // check for violations connecting from desktop app - staticNode := newclient.ConvertToStaticNode() + staticNode := models.ConvertToStaticNode(&newclient) violations, _ := logic.CheckPostureViolations(logic.GetPostureCheckDeviceInfoByNode(&staticNode), schema.NetworkID(newclient.Network)) if len(violations) > 0 { logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("posture check violations"), logic.Forbidden)) @@ -1062,7 +1062,7 @@ func bulkDeleteExtClients(w http.ResponseWriter, r *http.Request) { go func() { deleted := 0 - gwDeletedClients := make(map[string][]models.ExtClient) + gwDeletedClients := make(map[string][]schema.ExtClient) for _, clientID := range req.IDs { extclient, err := logic.GetExtClient(clientID, network) if err != nil { diff --git a/controllers/hosts.go b/controllers/hosts.go index 2adf37445..3a9a5fe31 100644 --- a/controllers/hosts.go +++ b/controllers/hosts.go @@ -583,7 +583,7 @@ func hostUpdateFallback(w http.ResponseWriter, r *http.Request) { nodes := make([]models.Node, 0, len(extclients)+1) nodes = append(nodes, node) for _, extclient := range extclients { - nodes = append(nodes, extclient.ConvertToStaticNode()) + nodes = append(nodes, models.ConvertToStaticNode(&extclient)) } nodesWithStatus := logic.AddStatusToNodes(nodes, true) @@ -1738,9 +1738,9 @@ func approvePendingHost(w http.ResponseWriter, r *http.Request) { return } - keyTags := make(map[models.TagID]struct{}) + keyTags := make(map[schema.TagID]struct{}) for _, tagI := range key.Tags { - keyTags[models.TagID(tagI)] = struct{}{} + keyTags[schema.TagID(tagI)] = struct{}{} } violations, _ := logic.CheckPostureViolations( @@ -1834,7 +1834,7 @@ func addDefaultHostToNetworks(host *schema.Host) { KernelVersion: host.KernelVersion, AutoUpdate: host.AutoUpdate, SkipAutoUpdate: true, - Tags: make(map[models.TagID]struct{}), + Tags: make(map[schema.TagID]struct{}), }, schema.NetworkID(network.Name), ) diff --git a/controllers/user.go b/controllers/user.go index 01bda2d54..d4454af17 100644 --- a/controllers/user.go +++ b/controllers/user.go @@ -1788,7 +1788,7 @@ func bulkDeleteUsers(w http.ResponseWriter, r *http.Request) { logic.ReturnAcceptedResponse(w, r, fmt.Sprintf("bulk delete of %d user(s) accepted", len(req.IDs))) go func() { - ownerExtClients := make(map[string][]models.ExtClient) + ownerExtClients := make(map[string][]schema.ExtClient) extclients, err := logic.GetAllExtClients() if err != nil { slog.Error("bulk user delete: failed to get extclients", "error", err) @@ -1911,13 +1911,13 @@ func bulkUpdateUserStatus(w http.ResponseWriter, r *http.Request) { logic.ReturnAcceptedResponse(w, r, fmt.Sprintf("bulk %s of %d user(s) accepted", action, len(req.IDs))) go func() { - var ownerExtClients map[string][]models.ExtClient + var ownerExtClients map[string][]schema.ExtClient if forceToggle { extclients, err := logic.GetAllExtClients() if err != nil { slog.Error("bulk user status: failed to get extclients", "error", err) } else { - ownerExtClients = make(map[string][]models.ExtClient, len(req.IDs)) + ownerExtClients = make(map[string][]schema.ExtClient, len(req.IDs)) for _, ec := range extclients { ownerExtClients[ec.OwnerID] = append(ownerExtClients[ec.OwnerID], ec) } diff --git a/database/database.go b/database/database.go index 27b59d233..28b0a3194 100644 --- a/database/database.go +++ b/database/database.go @@ -10,18 +10,12 @@ import ( const ( // == Table Names == - // EXT_CLIENT_TABLE_NAME - ext client table - EXT_CLIENT_TABLE_NAME = "extclients" // ACLS_TABLE_NAME - table for acls v2 ACLS_TABLE_NAME = "acls" // SSO_STATE_CACHE - holds sso session information for OAuth2 sign-ins SSO_STATE_CACHE = "ssostatecache" - // METRICS_TABLE_NAME - stores network metrics - METRICS_TABLE_NAME = "metrics" // CACHE_TABLE_NAME - caching table CACHE_TABLE_NAME = "cache" - // TAG_TABLE_NAME - table for tags - TAG_TABLE_NAME = "tags" // SERVER_SETTINGS - table for server settings SERVER_SETTINGS = "server_settings" // == ERROR CONSTS == @@ -52,11 +46,8 @@ const ( ) var Tables = []string{ - EXT_CLIENT_TABLE_NAME, SSO_STATE_CACHE, - METRICS_TABLE_NAME, CACHE_TABLE_NAME, - TAG_TABLE_NAME, ACLS_TABLE_NAME, SERVER_SETTINGS, } diff --git a/logic/acls.go b/logic/acls.go index 067ebf60a..c9e96d637 100644 --- a/logic/acls.go +++ b/logic/acls.go @@ -2,7 +2,6 @@ package logic import ( "context" - "encoding/json" "errors" "fmt" "maps" @@ -17,6 +16,7 @@ import ( "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/schema" "github.com/gravitl/netmaker/servercfg" + "gorm.io/datatypes" ) var GetFwRulesForNodeAndPeerOnGw = getFwRulesForNodeAndPeerOnGw @@ -31,12 +31,12 @@ var getEgressByNetwork = func(network string) ([]schema.Egress, error) { e := schema.Egress{Network: network} return e.ListByNetwork(db.WithContext(context.Background())) } -var getDevicePoliciesByNetwork = func(netID schema.NetworkID) []models.Acl { +var getDevicePoliciesByNetwork = func(netID schema.NetworkID) []schema.Acl { return ListDevicePolicies(netID) } // listNetworkExtClients fetches all extclients in a network; tests may override. -var listNetworkExtClients = func(network string) ([]models.ExtClient, error) { +var listNetworkExtClients = func(network string) ([]schema.ExtClient, error) { return GetNetworkExtClients(network) } @@ -59,9 +59,9 @@ var GetUserAclRulesForNode = func(targetnode *models.Node, var GetFwRulesForUserNodesOnGw = func(node models.Node, nodes []models.Node) (rules []models.FwRule) { return } -func getEgressToEgressPoliciesForNode(targetnode models.Node) []models.Acl { +func getEgressToEgressPoliciesForNode(targetnode models.Node) []schema.Acl { policies := getDevicePoliciesByNetwork(schema.NetworkID(targetnode.Network)) - filtered := make([]models.Acl, 0) + filtered := make([]schema.Acl, 0) for _, policy := range policies { if !policy.Enabled { continue @@ -74,7 +74,7 @@ func getEgressToEgressPoliciesForNode(targetnode models.Node) []models.Acl { return filtered } -func isEgressToEgressPolicyForTarget(policy models.Acl, targetnode models.Node) bool { +func isEgressToEgressPolicyForTarget(policy schema.Acl, targetnode models.Node) bool { srcEgresses := getEgressesFromPolicyTags(policy.Src, targetnode.Network) if len(srcEgresses) == 0 { return false @@ -93,7 +93,7 @@ func isEgressToEgressPolicyForTarget(policy models.Acl, targetnode models.Node) return targetRoutesSrcEgress || targetRoutesDstEgress } -func getEgressesFromPolicyTags(tags []models.AclPolicyTag, network string) []schema.Egress { +func getEgressesFromPolicyTags(tags []schema.AclPolicyTag, network string) []schema.Egress { egresses := make([]schema.Egress, 0) seen := make(map[string]struct{}) for _, tag := range tags { @@ -110,7 +110,7 @@ func getEgressesFromPolicyTags(tags []models.AclPolicyTag, network string) []sch seen[e.ID] = struct{}{} egresses = append(egresses, e) } - case tag.ID == models.EgressID || tag.ID == models.EgressRange: + case tag.ID == schema.EgressID || tag.ID == schema.EgressRange: e, err := getEgressByID(tag.Value) if err != nil { continue @@ -149,7 +149,7 @@ func targetNodeRoutesAnyEgress(targetnode models.Node, egresses []schema.Egress) // pair routes the matching egress, otherwise the dst-side router would never // add the src-side router as a peer (callers query symmetrically as // (X, Y) and (Y, X)) and the handshake would silently never occur. -func IsEgressRoutingPolicyAllowedForNodes(policy models.Acl, node, peer models.Node) bool { +func IsEgressRoutingPolicyAllowedForNodes(policy schema.Acl, node, peer models.Node) bool { srcEgresses := getEgressesFromPolicyTags(policy.Src, node.Network) if len(srcEgresses) == 0 { return false @@ -207,8 +207,8 @@ func egressSiteToSiteRuleKey(aclID string, reverse bool, idx int, _ int) string func appendEgressSiteToSiteRules( rules map[string]models.AclRule, - acl models.Acl, - direction models.AllowedTrafficDirection, + acl schema.Acl, + direction schema.AllowedTrafficDirection, v4pairs, v6pairs []struct{ Src, Dst net.IPNet }, reverse bool, ) { @@ -346,7 +346,7 @@ func getEgressAclRulesForTargetNode(targetnode models.Node) map[string]models.Ac appendEgressSiteToSiteRules(rules, acl, acl.AllowedDirection, v4pairs, v6pairs, false) - if acl.AllowedDirection == models.TrafficDirectionBi { + if acl.AllowedDirection == schema.TrafficDirectionBi { revSrcIP4 := append([]net.IPNet(nil), dstIP4...) revSrcIP6 := append([]net.IPNet(nil), dstIP6...) if srcRouted && !dstRouted && dstNat { @@ -367,7 +367,7 @@ func getEgressAclRulesForTargetNode(targetnode models.Node) map[string]models.Ac if len(revV4) == 0 && len(revV6) == 0 { continue } - appendEgressSiteToSiteRules(rules, acl, models.TrafficDirectionUni, revV4, revV6, true) + appendEgressSiteToSiteRules(rules, acl, schema.TrafficDirectionUni, revV4, revV6, true) } } return rules @@ -445,7 +445,7 @@ func GetFwRulesOnIngressGateway(node models.Node) (rules []models.FwRule) { return string(rules[i].DstIP.IP.To16()) < string(rules[j].DstIP.IP.To16()) }) }() - defaultDevicePolicy, _ := GetDefaultPolicy(schema.NetworkID(node.Network), models.DevicePolicy) + defaultDevicePolicy, _ := GetDefaultPolicy(schema.NetworkID(node.Network), schema.DevicePolicy) nodes, _ := GetNetworkNodes(node.Network) nodes = append(nodes, GetStaticNodesByNetwork(schema.NetworkID(node.Network), true)...) rules = GetFwRulesForUserNodesOnGw(node, nodes) @@ -459,14 +459,14 @@ func GetFwRulesOnIngressGateway(node models.Node) (rules []models.FwRule) { if relayedNode.Address.IP != nil { rules = append(rules, models.FwRule{ - AllowedProtocol: models.ALL, + AllowedProtocol: schema.ALL, AllowedPorts: []string{}, Allow: true, DstIP: relayedNode.AddressIPNet4(), SrcIP: node.NetworkRange, }) rules = append(rules, models.FwRule{ - AllowedProtocol: models.ALL, + AllowedProtocol: schema.ALL, AllowedPorts: []string{}, Allow: true, DstIP: node.NetworkRange, @@ -476,14 +476,14 @@ func GetFwRulesOnIngressGateway(node models.Node) (rules []models.FwRule) { if relayedNode.Address6.IP != nil { rules = append(rules, models.FwRule{ - AllowedProtocol: models.ALL, + AllowedProtocol: schema.ALL, AllowedPorts: []string{}, Allow: true, DstIP: relayedNode.AddressIPNet6(), SrcIP: node.NetworkRange6, }) rules = append(rules, models.FwRule{ - AllowedProtocol: models.ALL, + AllowedProtocol: schema.ALL, AllowedPorts: []string{}, Allow: true, DstIP: node.NetworkRange6, @@ -567,16 +567,16 @@ func GetFwRulesOnIngressGateway(node models.Node) (rules []models.FwRule) { continue } if peer.IsStatic { - peer = peer.StaticNode.ConvertToStaticNode() + peer = models.ConvertToStaticNode(&peer.StaticNode) } - var allowedPolicies1 []models.Acl + var allowedPolicies1 []schema.Acl var ok bool - if ok, allowedPolicies1 = IsNodeAllowedToCommunicate(nodeI.StaticNode.ConvertToStaticNode(), peer, true); ok { - rules = append(rules, GetFwRulesForNodeAndPeerOnGw(nodeI.StaticNode.ConvertToStaticNode(), peer, allowedPolicies1)...) + if ok, allowedPolicies1 = IsNodeAllowedToCommunicate(models.ConvertToStaticNode(&nodeI.StaticNode), peer, true); ok { + rules = append(rules, GetFwRulesForNodeAndPeerOnGw(models.ConvertToStaticNode(&nodeI.StaticNode), peer, allowedPolicies1)...) } - if ok, allowedPolicies2 := IsNodeAllowedToCommunicate(peer, nodeI.StaticNode.ConvertToStaticNode(), true); ok { + if ok, allowedPolicies2 := IsNodeAllowedToCommunicate(peer, models.ConvertToStaticNode(&nodeI.StaticNode), true); ok { rules = append(rules, - GetFwRulesForNodeAndPeerOnGw(peer, nodeI.StaticNode.ConvertToStaticNode(), + GetFwRulesForNodeAndPeerOnGw(peer, models.ConvertToStaticNode(&nodeI.StaticNode), getUniquePolicies(allowedPolicies1, allowedPolicies2))...) } } @@ -605,14 +605,14 @@ func GetFwRulesOnIngressGateway(node models.Node) (rules []models.FwRule) { if relayedNode.Address.IP != nil { rules = append(rules, models.FwRule{ - AllowedProtocol: models.ALL, + AllowedProtocol: schema.ALL, AllowedPorts: []string{}, Allow: true, DstIP: relayedNode.AddressIPNet4(), SrcIP: node.NetworkRange, }) rules = append(rules, models.FwRule{ - AllowedProtocol: models.ALL, + AllowedProtocol: schema.ALL, AllowedPorts: []string{}, Allow: true, DstIP: node.NetworkRange, @@ -622,14 +622,14 @@ func GetFwRulesOnIngressGateway(node models.Node) (rules []models.FwRule) { if relayedNode.Address6.IP != nil { rules = append(rules, models.FwRule{ - AllowedProtocol: models.ALL, + AllowedProtocol: schema.ALL, AllowedPorts: []string{}, Allow: true, DstIP: relayedNode.AddressIPNet6(), SrcIP: node.NetworkRange6, }) rules = append(rules, models.FwRule{ - AllowedProtocol: models.ALL, + AllowedProtocol: schema.ALL, AllowedPorts: []string{}, Allow: true, DstIP: node.NetworkRange6, @@ -642,7 +642,7 @@ func GetFwRulesOnIngressGateway(node models.Node) (rules []models.FwRule) { return } -func getFwRulesForNodeAndPeerOnGw(node, peer models.Node, allowedPolicies []models.Acl) (rules []models.FwRule) { +func getFwRulesForNodeAndPeerOnGw(node, peer models.Node, allowedPolicies []schema.Acl) (rules []models.FwRule) { for _, policy := range allowedPolicies { // if static peer dst rule not for ingress node -> skip @@ -673,7 +673,7 @@ func getFwRulesForNodeAndPeerOnGw(node, peer models.Node, allowedPolicies []mode Allow: true, }) } - if policy.AllowedDirection == models.TrafficDirectionBi { + if policy.AllowedDirection == schema.TrafficDirectionBi { if node.Address.IP != nil { rules = append(rules, models.FwRule{ SrcIP: net.IPNet{ @@ -764,7 +764,7 @@ func getFwRulesForNodeAndPeerOnGw(node, peer models.Node, allowedPolicies []mode // add egress range rules selectedIP4, selectedIP6 := getSelectedEgressIPNets(policy.Dst) for _, dstI := range policy.Dst { - if dstI.ID == models.EgressID { + if dstI.ID == schema.EgressID { e := schema.Egress{ID: dstI.Value} err := e.Get(db.WithContext(context.TODO())) @@ -915,7 +915,7 @@ func getFwRulesForNodeAndPeerOnGw(node, peer models.Node, allowedPolicies []mode return } -func getUniquePolicies(policies1, policies2 []models.Acl) []models.Acl { +func getUniquePolicies(policies1, policies2 []schema.Acl) []schema.Acl { policies1Map := make(map[string]struct{}) for _, policy1I := range policies1 { policies1Map[policy1I.ID] = struct{}{} @@ -940,8 +940,8 @@ func GetStaticNodeIps(node models.Node) (ips []net.IP) { defer func() { sortIPs(ips) }() - defaultUserPolicy, _ := GetDefaultPolicy(schema.NetworkID(node.Network), models.UserPolicy) - defaultDevicePolicy, _ := GetDefaultPolicy(schema.NetworkID(node.Network), models.DevicePolicy) + defaultUserPolicy, _ := GetDefaultPolicy(schema.NetworkID(node.Network), schema.UserPolicy) + defaultDevicePolicy, _ := GetDefaultPolicy(schema.NetworkID(node.Network), schema.DevicePolicy) extclients := GetStaticNodesByNetwork(schema.NetworkID(node.Network), false) for _, extclient := range extclients { @@ -966,7 +966,7 @@ func GetStaticNodeIps(node models.Node) (ips []net.IP) { var CleanupGwsMigration = func() {} -var CheckIfAnyPolicyisUniDirectional = func(targetNode models.Node, acls []models.Acl) bool { +var CheckIfAnyPolicyisUniDirectional = func(targetNode models.Node, acls []schema.Acl) bool { return false } @@ -981,8 +981,8 @@ func GetAclRulesForNode(targetnodeI *models.Node) (rules map[string]models.AclRu if IsNodeAllowedToCommunicateWithAllRsrcs(targetnode) { aclRule := models.AclRule{ ID: fmt.Sprintf("%s-all-allowed-node-rule", targetnode.ID.String()), - AllowedProtocol: models.ALL, - Direction: models.TrafficDirectionBi, + AllowedProtocol: schema.ALL, + Direction: schema.TrafficDirectionBi, Allowed: true, IPList: []net.IPNet{targetnode.NetworkRange}, IP6List: []net.IPNet{targetnode.NetworkRange6}, @@ -1019,14 +1019,14 @@ func GetAclRulesForNode(targetnodeI *models.Node) (rules map[string]models.AclRu rules[aclRule.ID] = aclRule return } - var taggedNodes map[models.TagID][]models.Node + var taggedNodes map[schema.TagID][]models.Node if targetnode.IsIngressGateway { taggedNodes = GetTagMapWithNodesByNetwork(schema.NetworkID(targetnode.Network), false) } else { taggedNodes = GetTagMapWithNodesByNetwork(schema.NetworkID(targetnode.Network), true) } acls := getDevicePoliciesByNetwork(schema.NetworkID(targetnode.Network)) - var targetNodeTags = make(map[models.TagID]struct{}) + var targetNodeTags = make(map[schema.TagID]struct{}) if targetnode.Mutex != nil { targetnode.Mutex.Lock() targetNodeTags = maps.Clone(targetnode.Tags) @@ -1035,9 +1035,9 @@ func GetAclRulesForNode(targetnodeI *models.Node) (rules map[string]models.AclRu targetNodeTags = maps.Clone(targetnode.Tags) } if targetNodeTags == nil { - targetNodeTags = make(map[models.TagID]struct{}) + targetNodeTags = make(map[schema.TagID]struct{}) } - targetNodeTags[models.TagID(targetnode.ID.String())] = struct{}{} + targetNodeTags[schema.TagID(targetnode.ID.String())] = struct{}{} targetNodeTags["*"] = struct{}{} for _, acl := range acls { if !acl.Enabled { @@ -1058,7 +1058,7 @@ func GetAclRulesForNode(targetnodeI *models.Node) (rules map[string]models.AclRu // srcTags lets the existing taggedNodes[NodeID] / GetNodeByID(NodeID) // resolution add their mesh AddressIPNet4/6 to the rule's IPList. for _, src := range acl.Src { - if src.ID != models.EgressID { + if src.ID != schema.EgressID { continue } e, err := getEgressByID(src.Value) @@ -1123,7 +1123,7 @@ func GetAclRulesForNode(targetnodeI *models.Node) (rules map[string]models.AclRu } break } - if dst.ID == models.EgressID { + if dst.ID == schema.EgressID { e, err := getEgressByID(dst.Value) if err == nil && e.Status && len(e.Nodes) > 0 { nodeOwnsEgress := false @@ -1190,7 +1190,7 @@ func GetAclRulesForNode(targetnodeI *models.Node) (rules map[string]models.AclRu aclRule.Dst6 = append(aclRule.Dst6, egressRanges6...) } for nodeTag := range targetNodeTags { - if acl.AllowedDirection == models.TrafficDirectionBi { + if acl.AllowedDirection == schema.TrafficDirectionBi { var existsInSrcTag bool var existsInDstTag bool @@ -1214,7 +1214,7 @@ func GetAclRulesForNode(targetnodeI *models.Node) (rules map[string]models.AclRu continue } // Get peers in the tags and add allowed rules - nodes := taggedNodes[models.TagID(dst)] + nodes := taggedNodes[schema.TagID(dst)] if dst != targetnode.ID.String() { node, err := GetNodeByID(dst) if err == nil { @@ -1251,7 +1251,7 @@ func GetAclRulesForNode(targetnodeI *models.Node) (rules map[string]models.AclRu continue } // Get peers in the tags and add allowed rules - nodes := taggedNodes[models.TagID(src)] + nodes := taggedNodes[schema.TagID(src)] if src != targetnode.ID.String() { node, err := GetNodeByID(src) if err == nil { @@ -1289,7 +1289,7 @@ func GetAclRulesForNode(targetnodeI *models.Node) (rules map[string]models.AclRu continue } // Get peers in the tags and add allowed rules - nodes := taggedNodes[models.TagID(src)] + nodes := taggedNodes[schema.TagID(src)] for _, node := range nodes { if node.ID == targetnode.ID { continue @@ -1352,8 +1352,8 @@ func GetEgressDefaultAllowAllFwRule(node models.Node) (models.AclRule, bool) { } rule := models.AclRule{ ID: fmt.Sprintf("%s-egress-all-rsrc-mesh", node.ID.String()), - AllowedProtocol: models.ALL, - Direction: models.TrafficDirectionBi, + AllowedProtocol: schema.ALL, + Direction: schema.TrafficDirectionBi, Allowed: true, } if node.NetworkRange.IP != nil { @@ -1379,11 +1379,11 @@ func GetEgressRulesForNode(targetnode models.Node) (rules map[string]models.AclR taggedNodes := GetTagMapWithNodesByNetwork(schema.NetworkID(targetnode.Network), true) acls := getDevicePoliciesByNetwork(schema.NetworkID(targetnode.Network)) - var targetNodeTags = make(map[models.TagID]struct{}) - targetNodeTags[models.TagID(targetnode.ID.String())] = struct{}{} + var targetNodeTags = make(map[schema.TagID]struct{}) + targetNodeTags[schema.TagID(targetnode.ID.String())] = struct{}{} targetNodeTags["*"] = struct{}{} if targetnode.IsGw && !servercfg.IsPro { - targetNodeTags[models.TagID(fmt.Sprintf("%s.%s", targetnode.Network, models.GwTagName))] = struct{}{} + targetNodeTags[schema.TagID(fmt.Sprintf("%s.%s", targetnode.Network, schema.GwTagName))] = struct{}{} } egs, _ := getEgressByNetwork(targetnode.Network) @@ -1460,7 +1460,7 @@ func GetEgressRulesForNode(targetnode models.Node) (rules map[string]models.AclR // get all src tags for src := range srcTags { // Get peers in the tags and add allowed rules - nodes := taggedNodes[models.TagID(src)] + nodes := taggedNodes[schema.TagID(src)] for _, node := range nodes { if node.ID == targetnode.ID { continue @@ -1493,7 +1493,7 @@ func GetEgressRulesForNode(targetnode models.Node) (rules map[string]models.AclR // (egress IPs/range -> src devices) emitted as a separate AclRule, because the // downstream firewall generator pairs IPList x Dst rather than expanding the // Bi direction into both legs. - if acl.AllowedDirection == models.TrafficDirectionBi && + if acl.AllowedDirection == schema.TrafficDirectionBi && (len(aclRule.Dst) > 0 || len(aclRule.Dst6) > 0) { revID := acl.ID + egressSiteACLReverseSuffix rules[revID] = models.AclRule{ @@ -1536,7 +1536,7 @@ func GetEgressRulesForNode(targetnode models.Node) (rules map[string]models.AclR // rules keyed by acl.ID, and a "-reverse" companion is added for Bi policies. func appendExtClientRemoteEgressFwdRules( targetnode models.Node, - acls []models.Acl, + acls []schema.Acl, remoteEgresses map[string]schema.Egress, ) map[string]models.AclRule { out := make(map[string]models.AclRule) @@ -1652,7 +1652,7 @@ func appendExtClientRemoteEgressFwdRules( } out[ruleID] = aclRule - if acl.AllowedDirection == models.TrafficDirectionBi && + if acl.AllowedDirection == schema.TrafficDirectionBi && (len(aclRule.Dst) > 0 || len(aclRule.Dst6) > 0) { revID := ruleID + egressSiteACLReverseSuffix out[revID] = models.AclRule{ @@ -1680,7 +1680,7 @@ func appendExtClientRemoteEgressFwdRules( // keys, and a "-reverse" companion is added for Bi policies. func appendDeviceRemoteEgressFwdRules( targetnode models.Node, - acls []models.Acl, + acls []schema.Acl, remoteEgresses map[string]schema.Egress, ) map[string]models.AclRule { out := make(map[string]models.AclRule) @@ -1740,7 +1740,7 @@ func appendDeviceRemoteEgressFwdRules( } out[ruleID] = aclRule - if acl.AllowedDirection == models.TrafficDirectionBi && + if acl.AllowedDirection == schema.TrafficDirectionBi && (len(aclRule.Dst) > 0 || len(aclRule.Dst6) > 0) { revID := ruleID + egressSiteACLReverseSuffix out[revID] = models.AclRule{ @@ -1771,7 +1771,7 @@ func getExtClientEgressFwRulesOnIngressGw(node models.Node) (rules []models.FwRu if err != nil { return } - var attached []models.ExtClient + var attached []schema.ExtClient for _, ec := range extclients { if !ec.Enabled { continue @@ -1905,7 +1905,7 @@ func getDeviceEgressFwRulesOnIngressGw(node models.Node) (rules []models.FwRule) // egress range, and domain answers take precedence over Range when configured. func computeEgressDstsForAcl( nodeID string, - acl models.Acl, + acl schema.Acl, egByID map[string]schema.Egress, ) (dst4, dst6 []net.IPNet) { dstTags := ConvAclTagToValueMap(acl.Dst) @@ -1962,7 +1962,7 @@ func computeEgressDstsForAcl( // against the dst CIDRs. For Bi-directional acls a reverse leg is also emitted // so return traffic from the egress range back to the source is allowed. // Zero-valued IPNets (IP == nil) are treated as "this address family not present". -func emitEgressFwRulesForSrc(acl models.Acl, src4, src6 net.IPNet, dst4, dst6 []net.IPNet) (rules []models.FwRule) { +func emitEgressFwRulesForSrc(acl schema.Acl, src4, src6 net.IPNet, dst4, dst6 []net.IPNet) (rules []models.FwRule) { if src4.IP != nil { for _, cidr := range dst4 { rules = append(rules, models.FwRule{ @@ -1985,7 +1985,7 @@ func emitEgressFwRulesForSrc(acl models.Acl, src4, src6 net.IPNet, dst4, dst6 [] }) } } - if acl.AllowedDirection != models.TrafficDirectionBi { + if acl.AllowedDirection != schema.TrafficDirectionBi { return } if src4.IP != nil { @@ -2016,7 +2016,7 @@ func emitEgressFwRulesForSrc(acl models.Acl, src4, src6 net.IPNet, dst4, dst6 [] // extclientMatchesAclSrc reports whether an extclient is permitted as a source by an acl, // matching on its ClientID or any of its tags (mirroring how AddTagMapWithStaticNodes // keys the tag map). -func extclientMatchesAclSrc(ec models.ExtClient, srcTags map[string]struct{}, srcAll bool) bool { +func extclientMatchesAclSrc(ec schema.ExtClient, srcTags map[string]struct{}, srcAll bool) bool { if srcAll { return true } @@ -2061,9 +2061,9 @@ func GetAclRuleForInetGw(targetnode models.Node) (rules map[string]models.AclRul if targetnode.IsInternetGateway { aclRule := models.AclRule{ ID: fmt.Sprintf("%s-inet-gw-internal-rule", targetnode.ID.String()), - AllowedProtocol: models.ALL, + AllowedProtocol: schema.ALL, AllowedPorts: []string{}, - Direction: models.TrafficDirectionBi, + Direction: schema.TrafficDirectionBi, Allowed: true, } if targetnode.NetworkRange.IP != nil { @@ -2159,13 +2159,13 @@ func cidrContainsCIDR(parent, child *net.IPNet) bool { return parent.Contains(last) } -func NormalizeAndValidateAclEgressIPs(acl *models.Acl) error { +func NormalizeAndValidateAclEgressIPs(acl *schema.Acl) error { if acl == nil { return nil } egressCIDRs := []*net.IPNet{} for _, dst := range acl.Dst { - if dst.ID != models.EgressID || dst.Value == "*" { + if dst.ID != schema.EgressID || dst.Value == "*" { continue } e, err := getEgressByID(dst.Value) @@ -2182,7 +2182,7 @@ func NormalizeAndValidateAclEgressIPs(acl *models.Acl) error { egressCIDRs = append(egressCIDRs, cidr) } for i := range acl.Dst { - if acl.Dst[i].ID != models.NetmakerIPAclID { + if acl.Dst[i].ID != schema.NetmakerIPAclID { continue } if len(egressCIDRs) == 0 { @@ -2210,7 +2210,7 @@ func NormalizeAndValidateAclEgressIPs(acl *models.Acl) error { } srcEgressCIDRs := []*net.IPNet{} for _, src := range acl.Src { - if src.ID != models.EgressID || src.Value == "*" { + if src.ID != schema.EgressID || src.Value == "*" { continue } e, err := getEgressByID(src.Value) @@ -2227,7 +2227,7 @@ func NormalizeAndValidateAclEgressIPs(acl *models.Acl) error { srcEgressCIDRs = append(srcEgressCIDRs, cidr) } for i := range acl.Src { - if acl.Src[i].ID != models.NetmakerIPAclID { + if acl.Src[i].ID != schema.NetmakerIPAclID { continue } if len(srcEgressCIDRs) == 0 { @@ -2256,9 +2256,9 @@ func NormalizeAndValidateAclEgressIPs(acl *models.Acl) error { return nil } -func getSelectedEgressIPNets(dstTags []models.AclPolicyTag) (dst4, dst6 []net.IPNet) { +func getSelectedEgressIPNets(dstTags []schema.AclPolicyTag) (dst4, dst6 []net.IPNet) { for _, dst := range dstTags { - if dst.ID != models.NetmakerIPAclID { + if dst.ID != schema.NetmakerIPAclID { continue } normalized, err := NormalizeIPOrCIDR(dst.Value) @@ -2278,10 +2278,10 @@ func getSelectedEgressIPNets(dstTags []models.AclPolicyTag) (dst4, dst6 []net.IP return } -func checkIfAclTagisValid(a models.Acl, t models.AclPolicyTag, isSrc bool) (err error) { +func checkIfAclTagisValid(a schema.Acl, t schema.AclPolicyTag, isSrc bool) (err error) { switch t.ID { - case models.NodeID: - if a.RuleType == models.UserPolicy && isSrc { + case schema.NodeID: + if a.RuleType == schema.UserPolicy && isSrc { return errors.New("user policy source mismatch") } _, nodeErr := GetNodeByID(t.Value) @@ -2291,7 +2291,7 @@ func checkIfAclTagisValid(a models.Acl, t models.AclPolicyTag, isSrc bool) (err return errors.New("invalid node " + t.Value) } } - case models.EgressID, models.EgressRange: + case schema.EgressID, schema.EgressRange: e := schema.Egress{ ID: t.Value, } @@ -2299,7 +2299,7 @@ func checkIfAclTagisValid(a models.Acl, t models.AclPolicyTag, isSrc bool) (err if err != nil { return errors.New("invalid egress") } - case models.NetmakerIPAclID: + case schema.NetmakerIPAclID: _, err := NormalizeIPOrCIDR(t.Value) if err != nil { return err @@ -2310,20 +2310,20 @@ func checkIfAclTagisValid(a models.Acl, t models.AclPolicyTag, isSrc bool) (err return nil } -var IsAclPolicyValid = func(acl models.Acl) (err error) { +var IsAclPolicyValid = func(acl schema.Acl) (err error) { //check if src and dst are valid - if acl.AllowedDirection == models.TrafficDirectionUni { + if acl.AllowedDirection == schema.TrafficDirectionUni { return errors.New("uni traffic flow not allowed on CE") } switch acl.RuleType { - case models.DevicePolicy: + case schema.DevicePolicy: for _, srcI := range acl.Src { if srcI.Value == "*" { continue } - if srcI.ID == models.NodeTagID && srcI.Value == fmt.Sprintf("%s.%s", acl.NetworkID.String(), models.GwTagName) { + if srcI.ID == schema.NodeTagID && srcI.Value == fmt.Sprintf("%s.%s", acl.NetworkID.String(), schema.GwTagName) { continue } if err = checkIfAclTagisValid(acl, srcI, true); err != nil { @@ -2335,7 +2335,7 @@ var IsAclPolicyValid = func(acl models.Acl) (err error) { if dstI.Value == "*" { continue } - if dstI.ID == models.NodeTagID && dstI.Value == fmt.Sprintf("%s.%s", acl.NetworkID.String(), models.GwTagName) { + if dstI.ID == schema.NodeTagID && dstI.Value == fmt.Sprintf("%s.%s", acl.NetworkID.String(), schema.GwTagName) { continue } if err = checkIfAclTagisValid(acl, dstI, false); err != nil { @@ -2361,30 +2361,30 @@ var IsPeerAllowed = func(node, peer models.Node, checkDefaultPolicy bool) bool { // } if node.IsStatic { nodeId = node.StaticNode.ClientID - node = node.StaticNode.ConvertToStaticNode() + node = models.ConvertToStaticNode(&node.StaticNode) } else { nodeId = node.ID.String() } if peer.IsStatic { peerId = peer.StaticNode.ClientID - peer = peer.StaticNode.ConvertToStaticNode() + peer = models.ConvertToStaticNode(&peer.StaticNode) } else { peerId = peer.ID.String() } - peerTags := make(map[models.TagID]struct{}) - nodeTags := make(map[models.TagID]struct{}) - nodeTags[models.TagID(nodeId)] = struct{}{} - peerTags[models.TagID(peerId)] = struct{}{} + peerTags := make(map[schema.TagID]struct{}) + nodeTags := make(map[schema.TagID]struct{}) + nodeTags[schema.TagID(nodeId)] = struct{}{} + peerTags[schema.TagID(peerId)] = struct{}{} if peer.IsGw { - peerTags[models.TagID(fmt.Sprintf("%s.%s", peer.Network, models.GwTagName))] = struct{}{} + peerTags[schema.TagID(fmt.Sprintf("%s.%s", peer.Network, schema.GwTagName))] = struct{}{} } if node.IsGw { - nodeTags[models.TagID(fmt.Sprintf("%s.%s", node.Network, models.GwTagName))] = struct{}{} + nodeTags[schema.TagID(fmt.Sprintf("%s.%s", node.Network, schema.GwTagName))] = struct{}{} } if checkDefaultPolicy { // check default policy if all allowed return true - defaultPolicy, err := GetDefaultPolicy(schema.NetworkID(node.Network), models.DevicePolicy) + defaultPolicy, err := GetDefaultPolicy(schema.NetworkID(node.Network), schema.DevicePolicy) if err == nil { if defaultPolicy.Enabled { return true @@ -2411,7 +2411,7 @@ var IsPeerAllowed = func(node, peer models.Node, checkDefaultPolicy bool) bool { srcMap = ConvAclTagToValueMap(policy.Src) dstMap = ConvAclTagToValueMap(policy.Dst) for _, dst := range policy.Dst { - if dst.ID == models.EgressID { + if dst.ID == schema.EgressID { e := schema.Egress{ID: dst.Value} err := e.Get(db.WithContext(context.TODO())) if err == nil && e.Status { @@ -2430,7 +2430,7 @@ var IsPeerAllowed = func(node, peer models.Node, checkDefaultPolicy bool) bool { } func CheckTagGroupPolicy(srcMap, dstMap map[string]struct{}, node, peer models.Node, - nodeTags, peerTags map[models.TagID]struct{}) bool { + nodeTags, peerTags map[schema.TagID]struct{}) bool { // check for node ID if _, ok := srcMap[node.ID.String()]; ok { if _, ok = dstMap[peer.ID.String()]; ok { @@ -2497,8 +2497,8 @@ var ( DeleteAllNetworkTags = func(networkID schema.NetworkID) {} - IsUserAllowedToCommunicate = func(userName string, peer models.Node) (bool, []models.Acl) { - return false, []models.Acl{} + IsUserAllowedToCommunicate = func(userName string, peer models.Node) (bool, []schema.Acl) { + return false, []schema.Acl{} } RemoveUserFromAclPolicy = func(userName string) {} @@ -2514,7 +2514,7 @@ var ( var ( aclCacheMutex = &sync.RWMutex{} - aclCacheMap = make(map[string]models.Acl) + aclCacheMap = make(map[string]schema.Acl) aclCacheFullyLoaded atomic.Bool ) @@ -2522,14 +2522,14 @@ func MigrateAclPolicies() { acls := ListAcls() for _, acl := range acls { if acl.Proto.String() == "" { - acl.Proto = models.ALL + acl.Proto = schema.ALL acl.ServiceType = models.Any acl.Port = []string{} UpsertAcl(acl) } if !servercfg.IsPro { - if acl.AllowedDirection == models.TrafficDirectionUni { - acl.AllowedDirection = models.TrafficDirectionBi + if acl.AllowedDirection == schema.TrafficDirectionUni { + acl.AllowedDirection = schema.TrafficDirectionBi UpsertAcl(acl) } } @@ -2539,7 +2539,7 @@ func MigrateAclPolicies() { func IsNodeAllowedToCommunicateWithAllRsrcs(node models.Node) bool { // check default policy if all allowed return true - defaultPolicy, err := GetDefaultPolicy(schema.NetworkID(node.Network), models.DevicePolicy) + defaultPolicy, err := GetDefaultPolicy(schema.NetworkID(node.Network), schema.DevicePolicy) if err == nil { if defaultPolicy.Enabled { return true @@ -2548,11 +2548,11 @@ func IsNodeAllowedToCommunicateWithAllRsrcs(node models.Node) bool { var nodeId string if node.IsStatic { nodeId = node.StaticNode.ClientID - node = node.StaticNode.ConvertToStaticNode() + node = models.ConvertToStaticNode(&node.StaticNode) } else { nodeId = node.ID.String() } - var nodeTags map[models.TagID]struct{} + var nodeTags map[schema.TagID]struct{} if node.Mutex != nil { node.Mutex.Lock() nodeTags = maps.Clone(node.Tags) @@ -2561,13 +2561,13 @@ func IsNodeAllowedToCommunicateWithAllRsrcs(node models.Node) bool { nodeTags = maps.Clone(node.Tags) } if nodeTags == nil { - nodeTags = make(map[models.TagID]struct{}) + nodeTags = make(map[schema.TagID]struct{}) } - nodeTags[models.TagID(node.ID.String())] = struct{}{} + nodeTags[schema.TagID(node.ID.String())] = struct{}{} nodeTags["*"] = struct{}{} - nodeTags[models.TagID(nodeId)] = struct{}{} + nodeTags[schema.TagID(nodeId)] = struct{}{} if !servercfg.IsPro && node.IsGw { - node.Tags[models.TagID(fmt.Sprintf("%s.%s", node.Network, models.GwTagName))] = struct{}{} + node.Tags[schema.TagID(fmt.Sprintf("%s.%s", node.Network, schema.GwTagName))] = struct{}{} } // list device policies policies := ListDevicePolicies(schema.NetworkID(node.Network)) @@ -2606,34 +2606,34 @@ func IsNodeAllowedToCommunicateWithAllRsrcs(node models.Node) bool { } // IsNodeAllowedToCommunicate - check node is allowed to communicate with the peer // ADD ALLOWED DIRECTION - 0 => node -> peer, 1 => peer-> node, -func IsNodeAllowedToCommunicate(node, peer models.Node, checkDefaultPolicy bool) (bool, []models.Acl) { +func IsNodeAllowedToCommunicate(node, peer models.Node, checkDefaultPolicy bool) (bool, []schema.Acl) { var nodeId, peerId string // if peer.IsFailOver && node.FailedOverBy != uuid.Nil && node.FailedOverBy == peer.ID { - // return true, []models.Acl{} + // return true, []schema.Acl{} // } // if node.IsFailOver && peer.FailedOverBy != uuid.Nil && peer.FailedOverBy == node.ID { - // return true, []models.Acl{} + // return true, []schema.Acl{} // } // if node.IsGw && peer.IsRelayed && peer.RelayedBy == node.ID.String() { - // return true, []models.Acl{} + // return true, []schema.Acl{} // } // if peer.IsGw && node.IsRelayed && node.RelayedBy == peer.ID.String() { - // return true, []models.Acl{} + // return true, []schema.Acl{} // } if node.IsStatic { nodeId = node.StaticNode.ClientID - node = node.StaticNode.ConvertToStaticNode() + node = models.ConvertToStaticNode(&node.StaticNode) } else { nodeId = node.ID.String() } if peer.IsStatic { peerId = peer.StaticNode.ClientID - peer = peer.StaticNode.ConvertToStaticNode() + peer = models.ConvertToStaticNode(&peer.StaticNode) } else { peerId = peer.ID.String() } - var nodeTags, peerTags map[models.TagID]struct{} + var nodeTags, peerTags map[schema.TagID]struct{} if node.Mutex != nil { node.Mutex.Lock() nodeTags = maps.Clone(node.Tags) @@ -2649,23 +2649,23 @@ func IsNodeAllowedToCommunicate(node, peer models.Node, checkDefaultPolicy bool) peerTags = peer.Tags } if nodeTags == nil { - nodeTags = make(map[models.TagID]struct{}) + nodeTags = make(map[schema.TagID]struct{}) } if peerTags == nil { - peerTags = make(map[models.TagID]struct{}) + peerTags = make(map[schema.TagID]struct{}) } - nodeTags[models.TagID(nodeId)] = struct{}{} - peerTags[models.TagID(peerId)] = struct{}{} + nodeTags[schema.TagID(nodeId)] = struct{}{} + peerTags[schema.TagID(peerId)] = struct{}{} if checkDefaultPolicy { // check default policy if all allowed return true - defaultPolicy, err := GetDefaultPolicy(schema.NetworkID(node.Network), models.DevicePolicy) + defaultPolicy, err := GetDefaultPolicy(schema.NetworkID(node.Network), schema.DevicePolicy) if err == nil { if defaultPolicy.Enabled { - return true, []models.Acl{defaultPolicy} + return true, []schema.Acl{defaultPolicy} } } } - allowedPolicies := []models.Acl{} + allowedPolicies := []schema.Acl{} defer func() { allowedPolicies = UniquePolicies(allowedPolicies) }() @@ -2689,7 +2689,7 @@ func IsNodeAllowedToCommunicate(node, peer models.Node, checkDefaultPolicy bool) srcMap = ConvAclTagToValueMap(policy.Src) dstMap = ConvAclTagToValueMap(policy.Dst) for _, dst := range policy.Dst { - if dst.ID == models.EgressID { + if dst.ID == schema.EgressID { e := schema.Egress{ID: dst.Value} err := e.Get(db.WithContext(context.TODO())) if err == nil && e.Status { @@ -2701,7 +2701,7 @@ func IsNodeAllowedToCommunicate(node, peer models.Node, checkDefaultPolicy bool) } _, srcAll := srcMap["*"] _, dstAll := dstMap["*"] - if policy.AllowedDirection == models.TrafficDirectionBi { + if policy.AllowedDirection == schema.TrafficDirectionBi { if _, ok := srcMap[nodeId]; ok || srcAll { if _, ok := dstMap[peerId]; ok || dstAll { allowedPolicies = append(allowedPolicies, policy) @@ -2722,7 +2722,7 @@ func IsNodeAllowedToCommunicate(node, peer models.Node, checkDefaultPolicy bool) continue } } - if policy.AllowedDirection == models.TrafficDirectionBi { + if policy.AllowedDirection == schema.TrafficDirectionBi { for tagID := range nodeTags { @@ -2792,17 +2792,17 @@ func IsNodeAllowedToCommunicate(node, peer models.Node, checkDefaultPolicy bool) } // GetDefaultPolicy - fetches default policy in the network by ruleType -func GetDefaultPolicy(netID schema.NetworkID, ruleType models.AclPolicyType) (models.Acl, error) { +func GetDefaultPolicy(netID schema.NetworkID, ruleType schema.AclPolicyType) (schema.Acl, error) { aclID := "all-users" - if ruleType == models.DevicePolicy { + if ruleType == schema.DevicePolicy { aclID = "all-nodes" } - if !servercfg.IsPro && ruleType == models.UserPolicy { - return models.Acl{Enabled: true}, nil + if !servercfg.IsPro && ruleType == schema.UserPolicy { + return schema.Acl{Enabled: true}, nil } acl, err := GetAcl(fmt.Sprintf("%s.%s", netID, aclID)) if err != nil { - return models.Acl{}, errors.New("default rule not found") + return schema.Acl{}, errors.New("default rule not found") } if acl.Enabled { return acl, nil @@ -2835,12 +2835,12 @@ func GetDefaultPolicy(netID schema.NetworkID, ruleType models.AclPolicyType) (mo } // ListAcls - lists all acl policies -func ListAclsByNetwork(netID schema.NetworkID) ([]models.Acl, error) { +func ListAclsByNetwork(netID schema.NetworkID) ([]schema.Acl, error) { allAcls := ListAcls() - netAcls := []models.Acl{} + netAcls := []schema.Acl{} for _, acl := range allAcls { - if !servercfg.IsPro && acl.RuleType == models.UserPolicy { + if !servercfg.IsPro && acl.RuleType == schema.UserPolicy { continue } if acl.NetworkID == netID { @@ -2851,15 +2851,15 @@ func ListAclsByNetwork(netID schema.NetworkID) ([]models.Acl, error) { } // ListEgressAcls - list egress acl policies -func ListEgressAcls(eID string) ([]models.Acl, error) { +func ListEgressAcls(eID string) ([]schema.Acl, error) { allAcls := ListAcls() - egressAcls := []models.Acl{} + egressAcls := []schema.Acl{} for _, acl := range allAcls { - if !servercfg.IsPro && acl.RuleType == models.UserPolicy { + if !servercfg.IsPro && acl.RuleType == schema.UserPolicy { continue } for _, dst := range acl.Dst { - if dst.ID == models.EgressID && dst.Value == eID { + if dst.ID == schema.EgressID && dst.Value == eID { egressAcls = append(egressAcls, acl) } } @@ -2868,11 +2868,11 @@ func ListEgressAcls(eID string) ([]models.Acl, error) { } // ListDevicePolicies - lists all device policies in a network -func ListDevicePolicies(netID schema.NetworkID) []models.Acl { +func ListDevicePolicies(netID schema.NetworkID) []schema.Acl { allAcls := ListAcls() - deviceAcls := []models.Acl{} + deviceAcls := []schema.Acl{} for _, acl := range allAcls { - if acl.NetworkID == netID && acl.RuleType == models.DevicePolicy { + if acl.NetworkID == netID && acl.RuleType == schema.DevicePolicy { deviceAcls = append(deviceAcls, acl) } } @@ -2880,18 +2880,18 @@ func ListDevicePolicies(netID schema.NetworkID) []models.Acl { } // ListUserPolicies - lists all user policies in a network -func ListUserPolicies(netID schema.NetworkID) []models.Acl { +func ListUserPolicies(netID schema.NetworkID) []schema.Acl { allAcls := ListAcls() - userAcls := []models.Acl{} + userAcls := []schema.Acl{} for _, acl := range allAcls { - if acl.NetworkID == netID && acl.RuleType == models.UserPolicy { + if acl.NetworkID == netID && acl.RuleType == schema.UserPolicy { userAcls = append(userAcls, acl) } } return userAcls } -func ConvAclTagToValueMap(acltags []models.AclPolicyTag) map[string]struct{} { +func ConvAclTagToValueMap(acltags []schema.AclPolicyTag) map[string]struct{} { aclValueMap := make(map[string]struct{}) for _, aclTagI := range acltags { aclValueMap[aclTagI.Value] = struct{}{} @@ -2899,9 +2899,9 @@ func ConvAclTagToValueMap(acltags []models.AclPolicyTag) map[string]struct{} { return aclValueMap } -func UniqueAclPolicyTags(tags []models.AclPolicyTag) []models.AclPolicyTag { +func UniqueAclPolicyTags(tags []schema.AclPolicyTag) []schema.AclPolicyTag { seen := make(map[string]bool) - var result []models.AclPolicyTag + var result []schema.AclPolicyTag for _, tag := range tags { key := fmt.Sprintf("%v-%s", tag.ID, tag.Value) @@ -2914,7 +2914,7 @@ func UniqueAclPolicyTags(tags []models.AclPolicyTag) []models.AclPolicyTag { } // UpdateAcl - updates allowed fields on acls and commits to DB -func UpdateAcl(newAcl, acl models.Acl) error { +func UpdateAcl(newAcl, acl schema.Acl) error { if !acl.Default { acl.Name = newAcl.Name acl.Src = newAcl.Src @@ -2926,14 +2926,10 @@ func UpdateAcl(newAcl, acl models.Acl) error { } if newAcl.ServiceType == models.Any { acl.Port = []string{} - acl.Proto = models.ALL + acl.Proto = schema.ALL } acl.Enabled = newAcl.Enabled - d, err := json.Marshal(acl) - if err != nil { - return err - } - err = database.Insert(acl.ID, string(d), database.ACLS_TABLE_NAME) + err := (&schema.AclEntry{Key: acl.ID, NetworkID: string(acl.NetworkID), Value: datatypes.NewJSONType(acl)}).Save(db.WithContext(context.TODO())) if err == nil && servercfg.CacheEnabled() { storeAclInCache(acl) } @@ -2941,12 +2937,8 @@ func UpdateAcl(newAcl, acl models.Acl) error { } // UpsertAcl - upserts acl -func UpsertAcl(acl models.Acl) error { - d, err := json.Marshal(acl) - if err != nil { - return err - } - err = database.Insert(acl.ID, string(d), database.ACLS_TABLE_NAME) +func UpsertAcl(acl schema.Acl) error { + err := (&schema.AclEntry{Key: acl.ID, NetworkID: string(acl.NetworkID), Value: datatypes.NewJSONType(acl)}).Save(db.WithContext(context.TODO())) if err == nil && servercfg.CacheEnabled() { storeAclInCache(acl) } @@ -2954,39 +2946,38 @@ func UpsertAcl(acl models.Acl) error { } // DeleteAcl - deletes acl policy -func DeleteAcl(a models.Acl) error { - err := database.DeleteRecord(database.ACLS_TABLE_NAME, a.ID) +func DeleteAcl(a schema.Acl) error { + err := (&schema.AclEntry{Key: a.ID}).Delete(db.WithContext(context.TODO())) if err == nil && servercfg.CacheEnabled() { removeAclFromCache(a) } return err } -func ListAcls() (acls []models.Acl) { +func ListAcls() (acls []schema.Acl) { if servercfg.CacheEnabled() && aclCacheFullyLoaded.Load() { return listAclFromCache() } - data, err := database.FetchRecords(database.ACLS_TABLE_NAME) + entries, err := (&schema.AclEntry{}).ListAll(db.WithContext(context.TODO())) if err != nil && !database.IsEmptyRecord(err) { - return []models.Acl{} + return []schema.Acl{} } if servercfg.CacheEnabled() { resetAclCacheLocked() } - for _, dataI := range data { - acl := models.Acl{} - err := json.Unmarshal([]byte(dataI), &acl) - if err != nil { + for _, entry := range entries { + acl := entry.Value.Data() + if acl.ID == "" { continue } if !servercfg.IsPro { - if acl.RuleType == models.UserPolicy { + if acl.RuleType == schema.UserPolicy { continue } skip := false for _, srcI := range acl.Src { - if srcI.ID == models.NodeTagID && (srcI.Value != "*" && srcI.Value != fmt.Sprintf("%s.%s", acl.NetworkID.String(), models.GwTagName)) { + if srcI.ID == schema.NodeTagID && (srcI.Value != "*" && srcI.Value != fmt.Sprintf("%s.%s", acl.NetworkID.String(), schema.GwTagName)) { skip = true break } @@ -2996,7 +2987,7 @@ func ListAcls() (acls []models.Acl) { } for _, dstI := range acl.Dst { - if dstI.ID == models.NodeTagID && (dstI.Value != "*" && dstI.Value != fmt.Sprintf("%s.%s", acl.NetworkID.String(), models.GwTagName)) { + if dstI.ID == schema.NodeTagID && (dstI.Value != "*" && dstI.Value != fmt.Sprintf("%s.%s", acl.NetworkID.String(), schema.GwTagName)) { skip = true break } @@ -3016,12 +3007,12 @@ func ListAcls() (acls []models.Acl) { return } -func UniquePolicies(items []models.Acl) []models.Acl { +func UniquePolicies(items []schema.Acl) []schema.Acl { if len(items) == 0 { return items } seen := make(map[string]bool) - var result []models.Acl + var result []schema.Acl for _, item := range items { if !seen[item.ID] { seen[item.ID] = true @@ -3043,21 +3034,21 @@ func DeleteNetworkPolicies(netId schema.NetworkID) { } // SortTagEntrys - Sorts slice of Tag entries by their id -func SortAclEntrys(acls []models.Acl) { +func SortAclEntrys(acls []schema.Acl) { sort.Slice(acls, func(i, j int) bool { return acls[i].Name < acls[j].Name }) } // PopulateAclPolicyTagNames resolves human-readable names for ACL policy tags -func PopulateAclPolicyTagNames(acls []models.Acl) { +func PopulateAclPolicyTagNames(acls []schema.Acl) { for i := range acls { populateTagNames(acls[i].Src) populateTagNames(acls[i].Dst) } } -func populateTagNames(tags []models.AclPolicyTag) { +func populateTagNames(tags []schema.AclPolicyTag) { for i := range tags { tag := &tags[i] if tag.Value == "" || tag.Value == "*" { @@ -3065,18 +3056,18 @@ func populateTagNames(tags []models.AclPolicyTag) { continue } switch tag.ID { - case models.UserAclID: + case schema.UserAclID: tag.Name = tag.Value - case models.UserGroupAclID: + case schema.UserGroupAclID: grp, err := GetUserGroup(schema.UserGroupID(tag.Value)) if err == nil { tag.Name = grp.Name } else { tag.Name = tag.Value } - case models.NodeTagID: + case schema.NodeTagID: tag.Name = tag.Value - case models.NodeID: + case schema.NodeID: node, err := GetNodeByID(tag.Value) if err == nil { host := &schema.Host{ID: node.HostID} @@ -3088,14 +3079,14 @@ func populateTagNames(tags []models.AclPolicyTag) { } else { tag.Name = tag.Value } - case models.EgressID: + case schema.EgressID: egress := schema.Egress{ID: tag.Value} if err := egress.Get(db.WithContext(context.TODO())); err == nil { tag.Name = egress.Name } else { tag.Name = tag.Value } - case models.EgressRange: + case schema.EgressRange: tag.Name = tag.Value default: tag.Name = tag.Value @@ -3104,7 +3095,7 @@ func populateTagNames(tags []models.AclPolicyTag) { } // ValidateCreateAclReq - validates create req for acl -func ValidateCreateAclReq(req models.Acl) error { +func ValidateCreateAclReq(req schema.Acl) error { // check if acl network exists err := (&schema.Network{Name: req.NetworkID.String()}).Get(db.WithContext(context.TODO())) if err != nil { @@ -3115,7 +3106,7 @@ func ValidateCreateAclReq(req models.Acl) error { // return err // } for _, src := range req.Src { - if src.ID == models.UserGroupAclID { + if src.ID == schema.UserGroupAclID { userGroup, err := GetUserGroup(schema.UserGroupID(src.Value)) if err != nil { return err @@ -3135,7 +3126,7 @@ func ValidateCreateAclReq(req models.Acl) error { return nil } -func listAclFromCache() (acls []models.Acl) { +func listAclFromCache() (acls []schema.Acl) { aclCacheMutex.RLock() defer aclCacheMutex.RUnlock() for _, acl := range aclCacheMap { @@ -3147,24 +3138,24 @@ func listAclFromCache() (acls []models.Acl) { func resetAclCacheLocked() { aclCacheMutex.Lock() defer aclCacheMutex.Unlock() - aclCacheMap = make(map[string]models.Acl) + aclCacheMap = make(map[string]schema.Acl) aclCacheFullyLoaded.Store(false) } -func storeAclInCache(a models.Acl) { +func storeAclInCache(a schema.Acl) { aclCacheMutex.Lock() defer aclCacheMutex.Unlock() aclCacheMap[a.ID] = a } -func removeAclFromCache(a models.Acl) { +func removeAclFromCache(a schema.Acl) { aclCacheMutex.Lock() defer aclCacheMutex.Unlock() delete(aclCacheMap, a.ID) } -func getAclFromCache(aID string) (a models.Acl, ok bool) { +func getAclFromCache(aID string) (a schema.Acl, ok bool) { aclCacheMutex.RLock() defer aclCacheMutex.RUnlock() a, ok = aclCacheMap[aID] @@ -3172,12 +3163,8 @@ func getAclFromCache(aID string) (a models.Acl, ok bool) { } // InsertAcl - creates acl policy -func InsertAcl(a models.Acl) error { - d, err := json.Marshal(a) - if err != nil { - return err - } - err = database.Insert(a.ID, string(d), database.ACLS_TABLE_NAME) +func InsertAcl(a schema.Acl) error { + err := (&schema.AclEntry{Key: a.ID, NetworkID: string(a.NetworkID), Value: datatypes.NewJSONType(a)}).Save(db.WithContext(context.TODO())) if err == nil && servercfg.CacheEnabled() { storeAclInCache(a) } @@ -3185,8 +3172,8 @@ func InsertAcl(a models.Acl) error { } // GetAcl - gets acl info by id -func GetAcl(aID string) (models.Acl, error) { - a := models.Acl{} +func GetAcl(aID string) (schema.Acl, error) { + a := schema.Acl{} if servercfg.CacheEnabled() { var ok bool a, ok = getAclFromCache(aID) @@ -3194,14 +3181,11 @@ func GetAcl(aID string) (models.Acl, error) { return a, nil } } - d, err := database.FetchRecord(database.ACLS_TABLE_NAME, aID) - if err != nil { - return a, err - } - err = json.Unmarshal([]byte(d), &a) - if err != nil { + entry := &schema.AclEntry{Key: aID} + if err := entry.Get(db.WithContext(context.TODO())); err != nil { return a, err } + a = entry.Value.Data() if servercfg.CacheEnabled() { storeAclInCache(a) } @@ -3225,9 +3209,9 @@ func RemoveNodeFromAclPolicy(node models.Node) { for _, acl := range acls { delete := false update := false - if acl.RuleType == models.DevicePolicy { + if acl.RuleType == schema.DevicePolicy { for i := len(acl.Src) - 1; i >= 0; i-- { - if acl.Src[i].ID == models.NodeID && acl.Src[i].Value == nodeID { + if acl.Src[i].ID == schema.NodeID && acl.Src[i].Value == nodeID { if len(acl.Src) == 1 { // delete policy delete = true @@ -3243,7 +3227,7 @@ func RemoveNodeFromAclPolicy(node models.Node) { continue } for i := len(acl.Dst) - 1; i >= 0; i-- { - if acl.Dst[i].ID == models.NodeID && acl.Dst[i].Value == nodeID { + if acl.Dst[i].ID == schema.NodeID && acl.Dst[i].Value == nodeID { if len(acl.Dst) == 1 { // delete policy delete = true @@ -3263,9 +3247,9 @@ func RemoveNodeFromAclPolicy(node models.Node) { } } - if acl.RuleType == models.UserPolicy { + if acl.RuleType == schema.UserPolicy { for i := len(acl.Dst) - 1; i >= 0; i-- { - if acl.Dst[i].ID == models.NodeID && acl.Dst[i].Value == nodeID { + if acl.Dst[i].ID == schema.NodeID && acl.Dst[i].Value == nodeID { if len(acl.Dst) == 1 { // delete policy delete = true @@ -3294,27 +3278,27 @@ func CreateDefaultAclNetworkPolicies(netID schema.NetworkID) { } _, _ = ListAclsByNetwork(netID) if !IsAclExists(fmt.Sprintf("%s.%s", netID, "all-nodes")) { - defaultDeviceAcl := models.Acl{ + defaultDeviceAcl := schema.Acl{ ID: fmt.Sprintf("%s.%s", netID, "all-nodes"), Name: "All Nodes", MetaData: "This Policy allows all nodes in the network to communicate with each other", Default: true, NetworkID: netID, - Proto: models.ALL, + Proto: schema.ALL, ServiceType: models.Any, Port: []string{}, - RuleType: models.DevicePolicy, - Src: []models.AclPolicyTag{ + RuleType: schema.DevicePolicy, + Src: []schema.AclPolicyTag{ { - ID: models.NodeTagID, + ID: schema.NodeTagID, Value: "*", }}, - Dst: []models.AclPolicyTag{ + Dst: []schema.AclPolicyTag{ { - ID: models.NodeTagID, + ID: schema.NodeTagID, Value: "*", }}, - AllowedDirection: models.TrafficDirectionBi, + AllowedDirection: schema.TrafficDirectionBi, Enabled: true, CreatedBy: "auto", CreatedAt: time.Now().UTC(), @@ -3323,28 +3307,28 @@ func CreateDefaultAclNetworkPolicies(netID schema.NetworkID) { } if !IsAclExists(fmt.Sprintf("%s.%s", netID, "all-gateways")) { - defaultUserAcl := models.Acl{ + defaultUserAcl := schema.Acl{ ID: fmt.Sprintf("%s.%s", netID, "all-gateways"), Default: true, Name: "All Gateways", NetworkID: netID, - Proto: models.ALL, + Proto: schema.ALL, ServiceType: models.Any, Port: []string{}, - RuleType: models.DevicePolicy, - Src: []models.AclPolicyTag{ + RuleType: schema.DevicePolicy, + Src: []schema.AclPolicyTag{ { - ID: models.NodeTagID, - Value: fmt.Sprintf("%s.%s", netID, models.GwTagName), + ID: schema.NodeTagID, + Value: fmt.Sprintf("%s.%s", netID, schema.GwTagName), }, }, - Dst: []models.AclPolicyTag{ + Dst: []schema.AclPolicyTag{ { - ID: models.NodeTagID, + ID: schema.NodeTagID, Value: "*", }, }, - AllowedDirection: models.TrafficDirectionBi, + AllowedDirection: schema.TrafficDirectionBi, Enabled: true, CreatedBy: "auto", CreatedAt: time.Now().UTC(), @@ -3354,12 +3338,12 @@ func CreateDefaultAclNetworkPolicies(netID schema.NetworkID) { CreateDefaultUserPolicies(netID) } -func getTagMapWithNodesByNetwork(netID schema.NetworkID, withStaticNodes bool) (tagNodesMap map[models.TagID][]models.Node) { - tagNodesMap = make(map[models.TagID][]models.Node) +func getTagMapWithNodesByNetwork(netID schema.NetworkID, withStaticNodes bool) (tagNodesMap map[schema.TagID][]models.Node) { + tagNodesMap = make(map[schema.TagID][]models.Node) nodes, _ := GetNetworkNodes(netID.String()) - netGwTag := models.TagID(fmt.Sprintf("%s.%s", netID.String(), models.GwTagName)) + netGwTag := schema.TagID(fmt.Sprintf("%s.%s", netID.String(), schema.GwTagName)) for _, nodeI := range nodes { - tagNodesMap[models.TagID(nodeI.ID.String())] = append(tagNodesMap[models.TagID(nodeI.ID.String())], nodeI) + tagNodesMap[schema.TagID(nodeI.ID.String())] = append(tagNodesMap[schema.TagID(nodeI.ID.String())], nodeI) if nodeI.IsGw { tagNodesMap[netGwTag] = append(tagNodesMap[netGwTag], nodeI) } @@ -3372,7 +3356,7 @@ func getTagMapWithNodesByNetwork(netID schema.NetworkID, withStaticNodes bool) ( } func addTagMapWithStaticNodes(netID schema.NetworkID, - tagNodesMap map[models.TagID][]models.Node) map[models.TagID][]models.Node { + tagNodesMap map[schema.TagID][]models.Node) map[schema.TagID][]models.Node { extclients, err := GetNetworkExtClients(netID.String()) if err != nil { return tagNodesMap @@ -3381,13 +3365,13 @@ func addTagMapWithStaticNodes(netID schema.NetworkID, if extclient.RemoteAccessClientID != "" { continue } - tagNodesMap[models.TagID(extclient.ClientID)] = []models.Node{ + tagNodesMap[schema.TagID(extclient.ClientID)] = []models.Node{ { IsStatic: true, StaticNode: extclient, }, } - tagNodesMap["*"] = append(tagNodesMap["*"], extclient.ConvertToStaticNode()) + tagNodesMap["*"] = append(tagNodesMap["*"], models.ConvertToStaticNode(&extclient)) } return tagNodesMap diff --git a/logic/clients.go b/logic/clients.go index e4adb5fdf..2042f90f7 100644 --- a/logic/clients.go +++ b/logic/clients.go @@ -4,26 +4,26 @@ import ( "errors" "sort" - "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/schema" ) // SortExtClient - Sorts slice of ExtClients by their ClientID alphabetically with numbers first -func SortExtClient(unsortedExtClient []models.ExtClient) { +func SortExtClient(unsortedExtClient []schema.ExtClient) { sort.Slice(unsortedExtClient, func(i, j int) bool { return unsortedExtClient[i].ClientID < unsortedExtClient[j].ClientID }) } // GetExtClientByName - gets an ext client by name -func GetExtClientByName(ID string) (models.ExtClient, error) { +func GetExtClientByName(ID string) (schema.ExtClient, error) { clients, err := GetAllExtClients() if err != nil { - return models.ExtClient{}, err + return schema.ExtClient{}, err } for i := range clients { if clients[i].ClientID == ID { return clients[i], nil } } - return models.ExtClient{}, errors.New("client not found") + return schema.ExtClient{}, errors.New("client not found") } diff --git a/logic/dns.go b/logic/dns.go index faf95503d..df95324af 100644 --- a/logic/dns.go +++ b/logic/dns.go @@ -17,7 +17,6 @@ import ( "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/schema" - "github.com/gravitl/netmaker/scope" "github.com/gravitl/netmaker/servercfg" "gorm.io/datatypes" ) @@ -214,7 +213,7 @@ func GetGwDNS(node *models.Node) string { } -func SetDNSOnWgConfig(gwNode *models.Node, extclient *models.ExtClient) { +func SetDNSOnWgConfig(gwNode *models.Node, extclient *schema.ExtClient) { if extclient.DNS != "" { return } @@ -262,7 +261,7 @@ func GetDNSEntryNum(domain string, network string) (int, error) { num := 0 - entries, err := GetDNS(scope.Default(context.TODO()), network) + entries, err := GetDNS(db.WithContext(context.TODO()), network) if err != nil { return 0, err } diff --git a/logic/egress.go b/logic/egress.go index 8b0d7c1bf..420413623 100644 --- a/logic/egress.go +++ b/logic/egress.go @@ -136,7 +136,7 @@ func EgressDomainsEqual(a, b []string) bool { return slices.Equal(aa, bb) } -func DoesUserHaveAccessToEgress(user *schema.User, e *schema.Egress, acls []models.Acl) bool { +func DoesUserHaveAccessToEgress(user *schema.User, e *schema.Egress, acls []schema.Acl) bool { if !e.Status { return false } @@ -150,9 +150,9 @@ func DoesUserHaveAccessToEgress(user *schema.User, e *schema.Egress, acls []mode if _, ok := dstTags[e.ID]; ok || all { // get all src tags for _, srcAcl := range acl.Src { - if srcAcl.ID == models.UserAclID && srcAcl.Value == user.Username { + if srcAcl.ID == schema.UserAclID && srcAcl.Value == user.Username { return true - } else if srcAcl.ID == models.UserGroupAclID { + } else if srcAcl.ID == schema.UserGroupAclID { // fetch all users in the group if _, ok := user.UserGroups.Data()[schema.UserGroupID(srcAcl.Value)]; ok { return true @@ -164,18 +164,18 @@ func DoesUserHaveAccessToEgress(user *schema.User, e *schema.Egress, acls []mode return false } -func DoesNodeHaveAccessToEgress(node *models.Node, e *schema.Egress, acls []models.Acl) bool { +func DoesNodeHaveAccessToEgress(node *models.Node, e *schema.Egress, acls []schema.Acl) bool { nodeTags := maps.Clone(node.Tags) - nodeTags[models.TagID(node.ID.String())] = struct{}{} - nodeTags[models.TagID("*")] = struct{}{} + nodeTags[schema.TagID(node.ID.String())] = struct{}{} + nodeTags[schema.TagID("*")] = struct{}{} for _, acl := range acls { if !acl.Enabled { continue } srcVal := ConvAclTagToValueMap(acl.Src) for _, dstI := range acl.Dst { - if (dstI.ID == models.EgressID && dstI.Value == e.ID) || (dstI.ID == models.NodeTagID && dstI.Value == "*") { - if dstI.ID == models.EgressID { + if (dstI.ID == schema.EgressID && dstI.Value == e.ID) || (dstI.ID == schema.NodeTagID && dstI.Value == "*") { + if dstI.ID == schema.EgressID { e := schema.Egress{ID: dstI.Value} err := e.Get(db.WithContext(context.TODO())) if err != nil { @@ -205,7 +205,7 @@ func DoesNodeHaveAccessToEgress(node *models.Node, e *schema.Egress, acls []mode return false } -func doesNodeHaveAccessToEgressByRoutingPolicy(node, targetNode *models.Node, e *schema.Egress, acls []models.Acl) bool { +func doesNodeHaveAccessToEgressByRoutingPolicy(node, targetNode *models.Node, e *schema.Egress, acls []schema.Acl) bool { if node == nil || targetNode == nil || e == nil { return false } @@ -225,7 +225,7 @@ func doesNodeHaveAccessToEgressByRoutingPolicy(node, targetNode *models.Node, e nodeRoutesDst := targetNodeRoutesAnyEgress(*node, dstEgresses) targetRoutesSrc := targetNodeRoutesAnyEgress(*targetNode, srcEgresses) targetRoutesDst := targetNodeRoutesAnyEgress(*targetNode, dstEgresses) - if acl.AllowedDirection == models.TrafficDirectionUni { + if acl.AllowedDirection == schema.TrafficDirectionUni { if nodeRoutesSrc && targetRoutesDst && egressListContainsID(dstEgresses, e.ID) { return true } @@ -254,7 +254,7 @@ func egressListContainsID(egresses []schema.Egress, id string) bool { // with writers on the same node (shallow copies may share the Tags map). When Mutex is nil, // tags are still read so tag-based egress matching applies; that matches patterns like // maps.Clone(node.Tags) elsewhere for nodes without an initialized mutex. -func snapshotNodeTagIDs(n *models.Node) []models.TagID { +func snapshotNodeTagIDs(n *models.Node) []schema.TagID { if n == nil { return nil } @@ -265,14 +265,14 @@ func snapshotNodeTagIDs(n *models.Node) []models.TagID { if len(n.Tags) == 0 { return nil } - out := make([]models.TagID, 0, len(n.Tags)) + out := make([]schema.TagID, 0, len(n.Tags)) for tid := range n.Tags { out = append(out, tid) } return out } -func AddEgressInfoToPeerByAccess(node, targetNode *models.Node, eli []schema.Egress, acls []models.Acl, isDefaultPolicyActive bool) { +func AddEgressInfoToPeerByAccess(node, targetNode *models.Node, eli []schema.Egress, acls []schema.Acl, isDefaultPolicyActive bool) { req := models.EgressGatewayRequest{ NodeID: targetNode.ID.String(), @@ -411,7 +411,7 @@ func AddEgressInfoToPeerByAccess(node, targetNode *models.Node, eli []schema.Egr func GetEgressDomainsByAccessForUser(user *schema.User, network schema.NetworkID) (domains []string) { acls := ListUserPolicies(network) eli, _ := (&schema.Egress{Network: network.String()}).ListByNetwork(db.WithContext(context.TODO())) - defaultDevicePolicy, _ := GetDefaultPolicy(network, models.UserPolicy) + defaultDevicePolicy, _ := GetDefaultPolicy(network, schema.UserPolicy) isDefaultPolicyActive := defaultDevicePolicy.Enabled seen := make(map[string]struct{}) for _, e := range eli { @@ -444,7 +444,7 @@ func GetEgressDomainsByAccessForUser(user *schema.User, network schema.NetworkID func GetEgressDomainNSForNode(node *models.Node) (returnNsLi []models.Nameserver) { acls := ListDevicePolicies(schema.NetworkID(node.Network)) eli, _ := (&schema.Egress{Network: node.Network}).ListByNetwork(db.WithContext(context.TODO())) - defaultDevicePolicy, _ := GetDefaultPolicy(schema.NetworkID(node.Network), models.DevicePolicy) + defaultDevicePolicy, _ := GetDefaultPolicy(schema.NetworkID(node.Network), schema.DevicePolicy) isDefaultPolicyActive := defaultDevicePolicy.Enabled for _, e := range eli { if !e.Status || e.Network != node.Network { @@ -490,7 +490,7 @@ func GetEgressDomainNSForNode(node *models.Node) (returnNsLi []models.Nameserver return } -func GetNodeEgressInfo(targetNode *models.Node, eli []schema.Egress, acls []models.Acl) { +func GetNodeEgressInfo(targetNode *models.Node, eli []schema.Egress, acls []schema.Acl) { req := models.EgressGatewayRequest{ NodeID: targetNode.ID.String(), diff --git a/logic/enrollmentkey.go b/logic/enrollmentkey.go index b4ce9fe7e..53e9589a1 100644 --- a/logic/enrollmentkey.go +++ b/logic/enrollmentkey.go @@ -39,7 +39,7 @@ var EnrollmentErrors = struct { // CreateEnrollmentKey - creates a new enrollment key in db func CreateEnrollmentKey(ctx context.Context, uses int, expiration time.Time, networks, - tags []string, groups []models.TagID, unlimited bool, relay uuid.UUID, + tags []string, groups []schema.TagID, unlimited bool, relay uuid.UUID, defaultKey, autoEgress, autoAssignGw bool) (*schema.EnrollmentKey, error) { newKeyID, err := getUniqueEnrollmentID(ctx) @@ -354,7 +354,7 @@ func DeTokenize(ctx context.Context, b64Token string) (*schema.EnrollmentKey, er return GetEnrollmentKey(ctx, newToken.Value) } -func RemoveTagFromEnrollmentKeys(deletedTagID models.TagID) { +func RemoveTagFromEnrollmentKeys(deletedTagID schema.TagID) { keys, _ := GetAllEnrollmentKeys(db.WithContext(context.TODO())) for _, key := range keys { newTags := datatypes.JSONSlice[string]{} diff --git a/logic/extpeers.go b/logic/extpeers.go index df065378a..bbaf39e34 100644 --- a/logic/extpeers.go +++ b/logic/extpeers.go @@ -2,7 +2,6 @@ package logic import ( "context" - "encoding/json" "errors" "fmt" "net" @@ -13,7 +12,6 @@ import ( "time" "github.com/goombaio/namegenerator" - "github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/db" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/models" @@ -21,14 +19,15 @@ import ( "github.com/gravitl/netmaker/servercfg" "golang.org/x/exp/slog" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" + "gorm.io/datatypes" ) var ( extClientCacheMutex = &sync.RWMutex{} - extClientCacheMap = make(map[string]models.ExtClient) + extClientCacheMap = make(map[string]schema.ExtClient) ) -func getAllExtClientsFromCache() (extClients []models.ExtClient) { +func getAllExtClientsFromCache() (extClients []schema.ExtClient) { extClientCacheMutex.RLock() for _, extclient := range extClientCacheMap { if extclient.Mutex == nil { @@ -46,7 +45,7 @@ func deleteExtClientFromCache(key string) { extClientCacheMutex.Unlock() } -func getExtClientFromCache(key string) (extclient models.ExtClient, ok bool) { +func getExtClientFromCache(key string) (extclient schema.ExtClient, ok bool) { extClientCacheMutex.RLock() extclient, ok = extClientCacheMap[key] if extclient.Mutex == nil { @@ -56,7 +55,7 @@ func getExtClientFromCache(key string) (extclient models.ExtClient, ok bool) { return } -func storeExtClientInCache(key string, extclient models.ExtClient) { +func storeExtClientInCache(key string, extclient schema.ExtClient) { extClientCacheMutex.Lock() if extclient.Mutex == nil { extclient.Mutex = &sync.Mutex{} @@ -66,13 +65,13 @@ func storeExtClientInCache(key string, extclient models.ExtClient) { } // ExtClient.GetEgressRangesOnNetwork - returns the egress ranges on network of ext client -func GetEgressRangesOnNetwork(client *models.ExtClient) ([]string, error) { +func GetEgressRangesOnNetwork(client *schema.ExtClient) ([]string, error) { var result []string eli, _ := (&schema.Egress{Network: client.Network}).ListByNetwork(db.WithContext(context.TODO())) - staticNode := client.ConvertToStaticNode() + staticNode := models.ConvertToStaticNode(client) userPolicies := ListUserPolicies(schema.NetworkID(client.Network)) - defaultUserPolicy, _ := GetDefaultPolicy(schema.NetworkID(client.Network), models.UserPolicy) + defaultUserPolicy, _ := GetDefaultPolicy(schema.NetworkID(client.Network), schema.UserPolicy) for _, eI := range eli { if !eI.Status { @@ -157,8 +156,7 @@ func DeleteExtClient(network string, clientid string, isUpdate bool) error { if err != nil { return err } - err = database.DeleteRecord(database.EXT_CLIENT_TABLE_NAME, key) - if err != nil { + if err = (&schema.ExtClientEntry{Key: key}).Delete(db.WithContext(context.TODO())); err != nil { return err } if servercfg.CacheEnabled() { @@ -183,12 +181,12 @@ func DeleteExtClient(network string, clientid string, isUpdate bool) error { Origin: schema.ClientApp, }) } - go RemoveNodeFromAclPolicy(extClient.ConvertToStaticNode()) + go RemoveNodeFromAclPolicy(models.ConvertToStaticNode(&extClient)) return nil } // DeleteExtClientAndCleanup - deletes an existing ext client and update ACLs -func DeleteExtClientAndCleanup(extClient models.ExtClient) error { +func DeleteExtClientAndCleanup(extClient schema.ExtClient) error { //delete extClient record err := DeleteExtClient(extClient.Network, extClient.ClientID, false) @@ -207,8 +205,8 @@ a. check against each user node, if allowed add rule */ // GetNetworkExtClients - gets the ext clients of given network -func GetNetworkExtClients(network string) ([]models.ExtClient, error) { - var extclients []models.ExtClient +func GetNetworkExtClients(network string) ([]schema.ExtClient, error) { + var extclients []schema.ExtClient if servercfg.CacheEnabled() { allextclients := getAllExtClientsFromCache() if len(allextclients) != 0 { @@ -220,35 +218,25 @@ func GetNetworkExtClients(network string) ([]models.ExtClient, error) { return extclients, nil } } - records, err := database.FetchRecords(database.EXT_CLIENT_TABLE_NAME) + entries, err := (&schema.ExtClientEntry{NetworkID: network}).ListByNetwork(db.WithContext(context.TODO())) if err != nil { - if database.IsEmptyRecord(err) { - return extclients, nil - } return extclients, err } - for _, value := range records { - var extclient models.ExtClient - err = json.Unmarshal([]byte(value), &extclient) - if err != nil { - continue - } - key, err := GetRecordKey(extclient.ClientID, extclient.Network) - if err == nil { - if servercfg.CacheEnabled() { + for _, entry := range entries { + extclient := entry.Value.Data() + if servercfg.CacheEnabled() { + if key, kerr := GetRecordKey(extclient.ClientID, extclient.Network); kerr == nil { storeExtClientInCache(key, extclient) } } - if extclient.Network == network { - extclients = append(extclients, extclient) - } + extclients = append(extclients, extclient) } - return extclients, err + return extclients, nil } // GetExtClient - gets a single ext client on a network -func GetExtClient(clientid string, network string) (models.ExtClient, error) { - var extclient models.ExtClient +func GetExtClient(clientid string, network string) (schema.ExtClient, error) { + var extclient schema.ExtClient key, err := GetRecordKey(clientid, network) if err != nil { return extclient, err @@ -258,15 +246,15 @@ func GetExtClient(clientid string, network string) (models.ExtClient, error) { return extclient, nil } } - data, err := database.FetchRecord(database.EXT_CLIENT_TABLE_NAME, key) - if err != nil { + entry := &schema.ExtClientEntry{Key: key} + if err = entry.Get(db.WithContext(context.TODO())); err != nil { return extclient, err } - err = json.Unmarshal([]byte(data), &extclient) + extclient = entry.Value.Data() if servercfg.CacheEnabled() { storeExtClientInCache(key, extclient) } - return extclient, err + return extclient, nil } func GenerateNodeName(network string) (string, error) { @@ -294,16 +282,17 @@ func GenerateNodeName(network string) (string, error) { } // SaveExtClient - saves an ext client to database -func SaveExtClient(extclient *models.ExtClient) error { +func SaveExtClient(extclient *schema.ExtClient) error { key, err := GetRecordKey(extclient.ClientID, extclient.Network) if err != nil { return err } - data, err := json.Marshal(&extclient) - if err != nil { - return err + entry := &schema.ExtClientEntry{ + Key: key, + NetworkID: extclient.Network, + Value: datatypes.NewJSONType(*extclient), } - if err = database.Insert(key, string(data), database.EXT_CLIENT_TABLE_NAME); err != nil { + if err = entry.Save(db.WithContext(context.TODO())); err != nil { return err } if servercfg.CacheEnabled() { @@ -314,7 +303,7 @@ func SaveExtClient(extclient *models.ExtClient) error { } // UpdateExtClient - updates an ext client with new values -func UpdateExtClient(old *models.ExtClient, update *models.CustomExtClient) models.ExtClient { +func UpdateExtClient(old *schema.ExtClient, update *models.CustomExtClient) schema.ExtClient { new := *old new.ClientID = update.ClientID if update.PublicKey != "" && old.PublicKey != update.PublicKey { @@ -362,8 +351,8 @@ func UpdateExtClient(old *models.ExtClient, update *models.CustomExtClient) mode } // GetExtClientsByID - gets the clients of attached gateway -func GetExtClientsByID(nodeid, network string) ([]models.ExtClient, error) { - var result []models.ExtClient +func GetExtClientsByID(nodeid, network string) ([]schema.ExtClient, error) { + var result []schema.ExtClient currentClients, err := GetNetworkExtClients(network) if err != nil { return result, err @@ -377,12 +366,10 @@ func GetExtClientsByID(nodeid, network string) ([]models.ExtClient, error) { } // GetAllExtClients - gets all ext clients from DB -func GetAllExtClients() ([]models.ExtClient, error) { - var clients = []models.ExtClient{} +func GetAllExtClients() ([]schema.ExtClient, error) { + var clients = []schema.ExtClient{} currentNetworks, err := (&schema.Network{}).ListAll(db.WithContext(context.TODO())) - if err != nil && database.IsEmptyRecord(err) { - return clients, nil - } else if err != nil { + if err != nil { return clients, err } @@ -400,13 +387,13 @@ func GetAllExtClients() ([]models.ExtClient, error) { // GetAllExtClientsWithStatus - returns all external clients with // given status. -func GetAllExtClientsWithStatus(status schema.NodeStatus) ([]models.ExtClient, error) { +func GetAllExtClientsWithStatus(status schema.NodeStatus) ([]schema.ExtClient, error) { extClients, err := GetAllExtClients() if err != nil { return nil, err } - var validExtClients []models.ExtClient + var validExtClients []schema.ExtClient for _, extClient := range extClients { if extClient.Status == status { validExtClients = append(validExtClients, extClient) @@ -417,7 +404,7 @@ func GetAllExtClientsWithStatus(status schema.NodeStatus) ([]models.ExtClient, e } // ToggleExtClientConnectivity - enables or disables an ext client -func ToggleExtClientConnectivity(client *models.ExtClient, enable bool) (models.ExtClient, error) { +func ToggleExtClientConnectivity(client *schema.ExtClient, enable bool) (schema.ExtClient, error) { update := models.CustomExtClient{ Enabled: enable, ClientID: client.ClientID, @@ -464,7 +451,7 @@ func GetExtPeers(node, peer *models.Node, addressIdentityMap map[string]models.P for _, extPeer := range extPeers { extPeer := extPeer if extPeer.RemoteAccessClientID == "" { - if ok := IsPeerAllowed(extPeer.ConvertToStaticNode(), *peer, true); !ok { + if ok := IsPeerAllowed(models.ConvertToStaticNode(&extPeer), *peer, true); !ok { continue } } else { @@ -572,7 +559,7 @@ func GetExtPeers(node, peer *models.Node, addressIdentityMap map[string]models.P return peers, idsAndAddr, egressRoutes, nil } -func getExtPeerEgressRoute(node models.Node, extPeer models.ExtClient) (egressRoutes []models.EgressNetworkRoutes) { +func getExtPeerEgressRoute(node models.Node, extPeer schema.ExtClient) (egressRoutes []models.EgressNetworkRoutes) { r := models.EgressNetworkRoutes{ PeerKey: extPeer.PublicKey, EgressGwAddr: extPeer.AddressIPNet4(), @@ -601,7 +588,7 @@ func getExtpeerEgressRanges(node models.Node) (ranges, ranges6 []net.IPNet) { if len(extPeer.ExtraAllowedIPs) == 0 { continue } - if ok, _ := IsNodeAllowedToCommunicate(extPeer.ConvertToStaticNode(), node, true); !ok { + if ok, _ := IsNodeAllowedToCommunicate(models.ConvertToStaticNode(&extPeer), node, true); !ok { continue } for _, allowedRange := range extPeer.ExtraAllowedIPs { @@ -628,7 +615,7 @@ func getExtpeersExtraRoutes(node models.Node) (egressRoutes []models.EgressNetwo if len(extPeer.ExtraAllowedIPs) == 0 || !extPeer.Enabled { continue } - if ok, _ := IsNodeAllowedToCommunicate(extPeer.ConvertToStaticNode(), node, true); !ok { + if ok, _ := IsNodeAllowedToCommunicate(models.ConvertToStaticNode(&extPeer), node, true); !ok { continue } egressRoutes = append(egressRoutes, getExtPeerEgressRoute(node, extPeer)...) @@ -636,7 +623,7 @@ func getExtpeersExtraRoutes(node models.Node) (egressRoutes []models.EgressNetwo return } -func GetExtclientAllowedIPs(client models.ExtClient) (allowedIPs []string) { +func GetExtclientAllowedIPs(client schema.ExtClient) (allowedIPs []string) { gwnode, err := GetNodeByID(client.IngressGatewayID) if err != nil { logger.Log(0, @@ -680,7 +667,7 @@ func GetStaticNodesByNetwork(network schema.NetworkID, onlyWg bool) (staticNode if onlyWg && extI.RemoteAccessClientID != "" { continue } - staticNode = append(staticNode, extI.ConvertToStaticNode()) + staticNode = append(staticNode, models.ConvertToStaticNode(&extI)) } } @@ -688,7 +675,7 @@ func GetStaticNodesByNetwork(network schema.NetworkID, onlyWg bool) (staticNode } // CleanupOtherExtclients cleans up other clients owned by the same use for the same device and network. -func CleanupOtherExtclients(extclient *models.ExtClient) error { +func CleanupOtherExtclients(extclient *schema.ExtClient) error { extclients, err := GetNetworkExtClients(extclient.Network) if err != nil { return err diff --git a/logic/gateway.go b/logic/gateway.go index da6b4822b..49da65b5b 100644 --- a/logic/gateway.go +++ b/logic/gateway.go @@ -154,8 +154,8 @@ func GetIngressGwUsers(node models.Node) (models.IngressGwUsers, error) { } // DeleteIngressGateway - deletes an ingress gateway -func DeleteIngressGateway(nodeid string) (models.Node, []models.ExtClient, error) { - removedClients := []models.ExtClient{} +func DeleteIngressGateway(nodeid string) (models.Node, []schema.ExtClient, error) { + removedClients := []schema.ExtClient{} node, err := GetNodeByID(nodeid) if err != nil { return models.Node{}, removedClients, err @@ -174,7 +174,7 @@ func DeleteIngressGateway(nodeid string) (models.Node, []models.ExtClient, error logger.Log(3, "deleting ingress gateway") node.LastModified = time.Now().UTC() node.IsIngressGateway = false - delete(node.Tags, models.TagID(fmt.Sprintf("%s.%s", node.Network, models.GwTagName))) + delete(node.Tags, schema.TagID(fmt.Sprintf("%s.%s", node.Network, schema.GwTagName))) node.IngressGatewayRange = "" node.Metadata = "" err = UpsertNode(&node) @@ -206,7 +206,7 @@ func DeleteGatewayExtClients(gatewayID string, networkName string) error { } // IsUserAllowedAccessToExtClient - checks if user has permission to access extclient -func IsUserAllowedAccessToExtClient(username string, client models.ExtClient) bool { +func IsUserAllowedAccessToExtClient(username string, client schema.ExtClient) bool { if username == MasterUser { return true } diff --git a/logic/hosts.go b/logic/hosts.go index e18d6f7a9..5eeb14e3c 100644 --- a/logic/hosts.go +++ b/logic/hosts.go @@ -35,8 +35,8 @@ var ( ErrInvalidHostID error = errors.New("invalid host id") ) -var CheckPostureViolations = func(d models.PostureCheckDeviceInfo, network schema.NetworkID) (v []models.Violation, level schema.Severity) { - return []models.Violation{}, schema.SeverityUnknown +var CheckPostureViolations = func(d models.PostureCheckDeviceInfo, network schema.NetworkID) (v []schema.Violation, level schema.Severity) { + return []schema.Violation{}, schema.SeverityUnknown } var GetPostureCheckDeviceInfoByNode = func(node *models.Node) (d models.PostureCheckDeviceInfo) { diff --git a/logic/metrics.go b/logic/metrics.go index 8c84f40ab..01eb8d174 100644 --- a/logic/metrics.go +++ b/logic/metrics.go @@ -6,7 +6,7 @@ import ( "strconv" "time" - "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/schema" ) type MetricsMonitor struct { @@ -73,12 +73,12 @@ var DeleteMetrics = func(string) error { return nil } -var UpdateMetrics = func(string, *models.Metrics) error { +var UpdateMetrics = func(string, *schema.Metrics) error { return nil } -var GetMetrics = func(string) (*models.Metrics, error) { - var metrics models.Metrics +var GetMetrics = func(string) (*schema.Metrics, error) { + var metrics schema.Metrics return &metrics, nil } diff --git a/logic/networks.go b/logic/networks.go index f44aa6410..806d9dede 100644 --- a/logic/networks.go +++ b/logic/networks.go @@ -552,7 +552,7 @@ var NetworkHook models.HookFunc = func(params ...interface{}) error { exists = true break } - if _, ok := node.Tags[models.TagID(tagI)]; ok { + if _, ok := node.Tags[schema.TagID(tagI)]; ok { exists = true break } diff --git a/logic/nodes.go b/logic/nodes.go index 56cbb3ea7..e24b8e466 100644 --- a/logic/nodes.go +++ b/logic/nodes.go @@ -335,7 +335,7 @@ func AddStatusToNodes(nodes []models.Node, statusCall bool) (nodesWithStatus []m for _, node := range nodes { if _, ok := aclDefaultPolicyStatusMap[node.Network]; !ok { // check default policy if all allowed return true - defaultPolicy, _ := GetDefaultPolicy(schema.NetworkID(node.Network), models.DevicePolicy) + defaultPolicy, _ := GetDefaultPolicy(schema.NetworkID(node.Network), schema.DevicePolicy) aclDefaultPolicyStatusMap[node.Network] = defaultPolicy.Enabled } if statusCall { @@ -659,11 +659,11 @@ func ConvertSchemaNodeToModelsNode(_node *schema.Node) *models.Node { netAddr6Range = *cidr } - var violations []models.Violation + var violations []schema.Violation _violations, err := _node.ListViolations(db.WithContext(context.TODO())) if err == nil { for _, _violation := range _violations { - violations = append(violations, models.Violation{ + violations = append(violations, schema.Violation{ CheckID: _violation.CheckID, Name: _violation.Name, Attribute: _violation.Attribute, @@ -698,7 +698,7 @@ func ConvertSchemaNodeToModelsNode(_node *schema.Node) *models.Node { IsAutoRelay: _node.IsAutoRelay == "yes", AutoRelayedPeers: _node.AutoRelayedPeers.Data(), IsInternetGateway: _node.IsInternetGateway, - Tags: make(map[models.TagID]struct{}), + Tags: make(map[schema.TagID]struct{}), Status: _node.Status, PostureChecksViolations: violations, PostureCheckViolationSeverityLevel: _node.PostureCheckSeverity, @@ -745,7 +745,7 @@ func ConvertSchemaNodeToModelsNode(_node *schema.Node) *models.Node { } for tagID := range _node.Tags { - node.Tags[models.TagID(tagID)] = struct{}{} + node.Tags[schema.TagID(tagID)] = struct{}{} } return node diff --git a/logic/peers.go b/logic/peers.go index 4da5b46e4..30731a038 100644 --- a/logic/peers.go +++ b/logic/peers.go @@ -156,7 +156,7 @@ func computeHostPeerInfo(host *schema.Host, allNodes []models.Node, serverInfo m continue } networkPeersInfo := make(models.PeerMap) - defaultDevicePolicy, _ := GetDefaultPolicy(schema.NetworkID(node.Network), models.DevicePolicy) + defaultDevicePolicy, _ := GetDefaultPolicy(schema.NetworkID(node.Network), schema.DevicePolicy) currentPeers := GetNetworkNodesMemory(allNodes, node.Network) for _, peer := range currentPeers { @@ -213,7 +213,7 @@ func computeHostPeerInfo(host *schema.Host, allNodes []models.Node, serverInfo m } // GetPeerUpdateForHost - gets the consolidated peer update for the host from all networks -func GetPeerUpdateForHost(network string, host *schema.Host, allNodes []models.Node, deletedHost *schema.Host, deletedNode *models.Node, deletedClients []models.ExtClient) (hostPeerUpdate models.HostPeerUpdate, err error) { +func GetPeerUpdateForHost(network string, host *schema.Host, allNodes []models.Node, deletedHost *schema.Host, deletedNode *models.Node, deletedClients []schema.ExtClient) (hostPeerUpdate models.HostPeerUpdate, err error) { if host == nil { return models.HostPeerUpdate{}, errors.New("host is nil") } @@ -317,15 +317,15 @@ func GetPeerUpdateForHost(network string, host *schema.Host, allNodes []models.N hostPeerUpdate.IsInternetGw = IsInternetGw(node) } hostPeerUpdate.DnsNameservers = append(hostPeerUpdate.DnsNameservers, GetEgressDomainNSForNode(&node)...) - defaultUserPolicy, _ := GetDefaultPolicy(schema.NetworkID(node.Network), models.UserPolicy) - defaultDevicePolicy, _ := GetDefaultPolicy(schema.NetworkID(node.Network), models.DevicePolicy) + defaultUserPolicy, _ := GetDefaultPolicy(schema.NetworkID(node.Network), schema.UserPolicy) + defaultDevicePolicy, _ := GetDefaultPolicy(schema.NetworkID(node.Network), schema.DevicePolicy) if (defaultDevicePolicy.Enabled && defaultUserPolicy.Enabled) || (!CheckIfAnyPolicyisUniDirectional(node, acls) && !(node.EgressDetails.IsEgressGateway && len(node.EgressDetails.EgressGatewayRanges) > 0)) { aclRule := models.AclRule{ ID: fmt.Sprintf("%s-allowed-network-rules", node.ID.String()), - AllowedProtocol: models.ALL, - Direction: models.TrafficDirectionBi, + AllowedProtocol: schema.ALL, + Direction: schema.TrafficDirectionBi, Allowed: true, IPList: []net.IPNet{node.NetworkRange}, IP6List: []net.IPNet{node.NetworkRange6}, @@ -793,7 +793,7 @@ func filterConflictingEgressRoutesWithMetric(node, peer models.Node) []models.Eg } // GetAllowedIPs - calculates the wireguard allowedip field for a peer of a node based on the peer and node settings -func GetAllowedIPs(node, peer *models.Node, metrics *models.Metrics) []net.IPNet { +func GetAllowedIPs(node, peer *models.Node, metrics *schema.Metrics) []net.IPNet { var allowedips []net.IPNet allowedips = getNodeAllowedIPs(peer, node) if peer.IsInternetGateway && node.InternetGwID == peer.ID.String() { diff --git a/logic/relay.go b/logic/relay.go index 13d96da5c..e0648b8f6 100644 --- a/logic/relay.go +++ b/logic/relay.go @@ -166,7 +166,7 @@ func GetAllowedIpsForRelayed(relayed, relay *models.Node) (allowedIPs []net.IPNe } acls, _ := ListAclsByNetwork(schema.NetworkID(relay.Network)) eli, _ := (&schema.Egress{Network: relay.Network}).ListByNetwork(db.WithContext(context.TODO())) - defaultPolicy, _ := GetDefaultPolicy(schema.NetworkID(relay.Network), models.DevicePolicy) + defaultPolicy, _ := GetDefaultPolicy(schema.NetworkID(relay.Network), schema.DevicePolicy) for _, peer := range peers { if peer.ID == relayed.ID || peer.ID == relay.ID { continue diff --git a/logic/settings.go b/logic/settings.go index 2aee99260..cc489dc21 100644 --- a/logic/settings.go +++ b/logic/settings.go @@ -30,7 +30,7 @@ var ( var ServerSettingsDBKey = "server_cfg" var SettingsMutex = &sync.RWMutex{} -var serverSettingsCache atomic.Value +var serverSettingsCache atomic.Pointer[schema.ServerSettingsData] var defaultUserSettings = models.UserSettings{ TextSize: "16", @@ -38,8 +38,8 @@ var defaultUserSettings = models.UserSettings{ ReducedMotion: false, } -func GetServerSettings() (s models.ServerSettings) { - if cached, ok := serverSettingsCache.Load().(*models.ServerSettings); ok && cached != nil { +func GetServerSettings() (s schema.ServerSettingsData) { + if cached := serverSettingsCache.Load(); cached != nil { return *cached } s, err := getServerSettingsFromDB() @@ -52,11 +52,11 @@ func GetServerSettings() (s models.ServerSettings) { // InvalidateServerSettingsCache clears the in-memory settings cache so // the next GetServerSettings call re-reads from the database. func InvalidateServerSettingsCache() { - serverSettingsCache.Store((*models.ServerSettings)(nil)) + serverSettingsCache.Store((*schema.ServerSettingsData)(nil)) } -func getServerSettingsFromDB() (models.ServerSettings, error) { - var s models.ServerSettings +func getServerSettingsFromDB() (schema.ServerSettingsData, error) { + var s schema.ServerSettingsData data, err := database.FetchRecord(database.SERVER_SETTINGS, ServerSettingsDBKey) if err != nil { return s, err @@ -67,7 +67,7 @@ func getServerSettingsFromDB() (models.ServerSettings, error) { return s, nil } -func UpsertServerSettings(s models.ServerSettings) error { +func UpsertServerSettings(s schema.ServerSettingsData) error { // get curr settings from DB directly (not cache) for accurate comparison currSettings, _ := getServerSettingsFromDB() if s.ClientSecret == Mask() { @@ -154,7 +154,7 @@ func UpsertUserSettings(username string, userSettings models.UserSettings) error return user.UpdateUserSettings(db.WithContext(context.TODO())) } -func ValidateNewSettings(req models.ServerSettings) error { +func ValidateNewSettings(req schema.ServerSettingsData) error { // TODO: add checks for different fields if req.JwtValidityDuration > 525600 || req.JwtValidityDuration < 5 { return ErrInvalidJwtValidityDuration @@ -171,9 +171,9 @@ func ValidateNewSettings(req models.ServerSettings) error { return nil } -func GetServerSettingsFromEnv() (s models.ServerSettings) { +func GetServerSettingsFromEnv() (s schema.ServerSettingsData) { - s = models.ServerSettings{ + s = schema.ServerSettingsData{ NetclientAutoUpdate: servercfg.AutoUpdateEnabled(), Verbosity: servercfg.GetVerbosity(), AuthProvider: os.Getenv("AUTH_PROVIDER"), @@ -371,7 +371,7 @@ func AutoUpdateEnabled() bool { } // GetAuthProviderInfo = gets the oauth provider info -func GetAuthProviderInfo(settings models.ServerSettings) (pi []string) { +func GetAuthProviderInfo(settings schema.ServerSettingsData) (pi []string) { var authProvider = "" defer func() { diff --git a/logic/telemetry.go b/logic/telemetry.go index 47f77778a..b90c09f95 100644 --- a/logic/telemetry.go +++ b/logic/telemetry.go @@ -87,7 +87,9 @@ func FetchTelemetryData() telemetryData { var data telemetryData data.IsPro = servercfg.IsPro - data.ExtClients = getDBLength(database.EXT_CLIENT_TABLE_NAME) + if cnt, err := (&schema.ExtClientEntry{}).Count(db.WithContext(context.TODO())); err == nil { + data.ExtClients = int(cnt) + } data.Users, _ = (&schema.User{}).Count(db.WithContext(context.TODO())) data.Networks, _ = (&schema.Network{}).Count(db.WithContext(context.TODO())) data.Hosts, _ = (&schema.Host{}).Count(db.WithContext(context.TODO())) diff --git a/logic/usage.go b/logic/usage.go index 4f5099f4c..cf8166119 100644 --- a/logic/usage.go +++ b/logic/usage.go @@ -25,7 +25,7 @@ func GetCurrentServerUsage() (limits models.Usage) { nodes, _ := GetAllNodes() for _, client := range clients { - nodes = append(nodes, client.ConvertToStaticNode()) + nodes = append(nodes, models.ConvertToStaticNode(&client)) } limits.NetworkUsage = make(map[string]models.NetworkUsage) diff --git a/logic/user_mgmt.go b/logic/user_mgmt.go index 1288332ea..7adf9cff8 100644 --- a/logic/user_mgmt.go +++ b/logic/user_mgmt.go @@ -50,27 +50,27 @@ var CreateDefaultUserPolicies = func(netID schema.NetworkID) { return } if !IsAclExists(fmt.Sprintf("%s.%s", netID, "all-users")) { - defaultUserAcl := models.Acl{ + defaultUserAcl := schema.Acl{ ID: fmt.Sprintf("%s.%s", netID, "all-users"), Default: true, Name: "All Users", MetaData: "This policy gives access to everything in the network for an user", NetworkID: netID, - Proto: models.ALL, + Proto: schema.ALL, ServiceType: models.Any, Port: []string{}, - RuleType: models.UserPolicy, - Src: []models.AclPolicyTag{ + RuleType: schema.UserPolicy, + Src: []schema.AclPolicyTag{ { - ID: models.UserAclID, + ID: schema.UserAclID, Value: "*", }, }, - Dst: []models.AclPolicyTag{{ - ID: models.NodeTagID, + Dst: []schema.AclPolicyTag{{ + ID: schema.NodeTagID, Value: "*", }}, - AllowedDirection: models.TrafficDirectionUni, + AllowedDirection: schema.TrafficDirectionUni, Enabled: true, CreatedBy: "auto", CreatedAt: time.Now().UTC(), diff --git a/migrate/migrate.go b/migrate/migrate.go index c38d353fa..61e3c3a16 100644 --- a/migrate/migrate.go +++ b/migrate/migrate.go @@ -250,7 +250,7 @@ func updateNodes() { extclients, _ := logic.GetAllExtClients() for _, extclient := range extclients { if extclient.Tags == nil { - extclient.Tags = make(map[models.TagID]struct{}) + extclient.Tags = make(map[schema.TagID]struct{}) logic.SaveExtClient(&extclient) } } @@ -286,9 +286,9 @@ func updateNewAcls() { enableSeparateACL := true adminAcl, err := logic.GetAcl(fmt.Sprintf("%s.%s-grp", networkID, schema.NetworkAdmin)) if err == nil { - var newAclSrc []models.AclPolicyTag + var newAclSrc []schema.AclPolicyTag for _, src := range adminAcl.Src { - if src.ID == models.UserGroupAclID && src.Value == group.ID.String() { + if src.ID == schema.UserGroupAclID && src.Value == group.ID.String() { createSeparateACL = true enableSeparateACL = adminAcl.Enabled } else { @@ -302,9 +302,9 @@ func updateNewAcls() { userAcl, err := logic.GetAcl(fmt.Sprintf("%s.%s-grp", networkID, schema.NetworkUser)) if err == nil { - var newAclSrc []models.AclPolicyTag + var newAclSrc []schema.AclPolicyTag for _, src := range userAcl.Src { - if src.ID == models.UserGroupAclID && src.Value == group.ID.String() { + if src.ID == schema.UserGroupAclID && src.Value == group.ID.String() { if !createSeparateACL { // if group src not found in adminACL, then create. createSeparateACL = true @@ -325,27 +325,27 @@ func updateNewAcls() { _ = logic.UpsertAcl(userAcl) } - expectedAcl := models.Acl{ + expectedAcl := schema.Acl{ ID: uuid.New().String(), Name: fmt.Sprintf("%s group", group.Name), MetaData: "This Policy allows user group to communicate with all gateways", Default: true, ServiceType: models.Any, NetworkID: schema.NetworkID(network.Name), - Proto: models.ALL, - RuleType: models.UserPolicy, - Src: []models.AclPolicyTag{ + Proto: schema.ALL, + RuleType: schema.UserPolicy, + Src: []schema.AclPolicyTag{ { - ID: models.UserGroupAclID, + ID: schema.UserGroupAclID, Value: group.ID.String(), }, }, - Dst: []models.AclPolicyTag{ + Dst: []schema.AclPolicyTag{ { - ID: models.NodeTagID, - Value: fmt.Sprintf("%s.%s", schema.NetworkID(network.Name), models.GwTagName), + ID: schema.NodeTagID, + Value: fmt.Sprintf("%s.%s", schema.NetworkID(network.Name), schema.GwTagName), }}, - AllowedDirection: models.TrafficDirectionUni, + AllowedDirection: schema.TrafficDirectionUni, Enabled: true, CreatedBy: "auto", CreatedAt: time.Now().UTC(), @@ -439,15 +439,15 @@ func createDefaultTagsAndPolicies() { logic.CreateDefaultTags(schema.NetworkID(network.Name)) logic.CreateDefaultAclNetworkPolicies(schema.NetworkID(network.Name)) // delete old remote access gws policy - logic.DeleteAcl(models.Acl{ID: fmt.Sprintf("%s.%s", network.Name, "all-remote-access-gws")}) + logic.DeleteAcl(schema.Acl{ID: fmt.Sprintf("%s.%s", network.Name, "all-remote-access-gws")}) } logic.MigrateAclPolicies() if !servercfg.IsPro { nodes, _ := logic.GetAllNodes() for _, node := range nodes { if node.IsGw { - node.Tags = make(map[models.TagID]struct{}) - node.Tags[models.TagID(fmt.Sprintf("%s.%s", node.Network, models.GwTagName))] = struct{}{} + node.Tags = make(map[schema.TagID]struct{}) + node.Tags[schema.TagID(fmt.Sprintf("%s.%s", node.Network, schema.GwTagName))] = struct{}{} logic.UpsertNode(&node) } } @@ -527,7 +527,7 @@ func migrateSettings() { func deleteOldExtclients() { extclients, _ := logic.GetAllExtClients() - userExtclientMap := make(map[string][]models.ExtClient) + userExtclientMap := make(map[string][]schema.ExtClient) for _, extclient := range extclients { if extclient.RemoteAccessClientID == "" { continue @@ -538,7 +538,7 @@ func deleteOldExtclients() { } if _, ok := userExtclientMap[extclient.OwnerID]; !ok { - userExtclientMap[extclient.OwnerID] = make([]models.ExtClient, 0) + userExtclientMap[extclient.OwnerID] = make([]schema.ExtClient, 0) } userExtclientMap[extclient.OwnerID] = append(userExtclientMap[extclient.OwnerID], extclient) @@ -583,9 +583,9 @@ func cleanupDeletedUserGroupRefs() { } for _, acl := range logic.ListAcls() { - var newSrc []models.AclPolicyTag + var newSrc []schema.AclPolicyTag for _, src := range acl.Src { - if src.ID == models.UserGroupAclID { + if src.ID == schema.UserGroupAclID { if group, ok := existingGroups[schema.UserGroupID(src.Value)]; ok { var hasAccess bool if _, ok := group.NetworkRoles.Data()[schema.AllNetworks]; ok { @@ -600,7 +600,7 @@ func cleanupDeletedUserGroupRefs() { newSrc = append(newSrc, src) } } - } else if src.ID == models.UserAclID && src.Value != "*" { + } else if src.ID == schema.UserAclID && src.Value != "*" { if _, ok := existingUsers[src.Value]; ok { newSrc = append(newSrc, src) } diff --git a/migrate/migrate_multitenancy.go b/migrate/migrate_multitenancy.go index 4e920222d..f5fd224a4 100644 --- a/migrate/migrate_multitenancy.go +++ b/migrate/migrate_multitenancy.go @@ -8,7 +8,6 @@ import ( "github.com/gravitl/netmaker/db" "github.com/gravitl/netmaker/logger" - "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/schema" ) @@ -158,7 +157,7 @@ func backfillExtClientNetworkID(ctx context.Context) error { return fmt.Errorf("multitenancy migration: list extclients records: %w", err) } for key, value := range records { - var client models.ExtClient + var client schema.ExtClient if err := json.Unmarshal([]byte(value), &client); err != nil { return fmt.Errorf("multitenancy migration: parse extclient record %s: %w", key, err) } @@ -182,7 +181,7 @@ func backfillAclNetworkID(ctx context.Context) error { return fmt.Errorf("multitenancy migration: list acls records: %w", err) } for key, value := range records { - var acl models.Acl + var acl schema.Acl if err := json.Unmarshal([]byte(value), &acl); err != nil { return fmt.Errorf("multitenancy migration: parse acl record %s: %w", key, err) } @@ -206,7 +205,7 @@ func backfillMetricsNetworkID(ctx context.Context) error { return fmt.Errorf("multitenancy migration: list metrics records: %w", err) } for key, value := range records { - var m models.Metrics + var m schema.Metrics if err := json.Unmarshal([]byte(value), &m); err != nil { return fmt.Errorf("multitenancy migration: parse metrics record %s: %w", key, err) } @@ -230,7 +229,7 @@ func backfillTagNetworkID(ctx context.Context) error { return fmt.Errorf("multitenancy migration: list tags records: %w", err) } for key, value := range records { - var tag models.Tag + var tag schema.Tag if err := json.Unmarshal([]byte(value), &tag); err != nil { return fmt.Errorf("multitenancy migration: parse tag record %s: %w", key, err) } diff --git a/migrate/migrate_v1_6_0.go b/migrate/migrate_v1_6_0.go index 19c502fba..fb303976d 100644 --- a/migrate/migrate_v1_6_0.go +++ b/migrate/migrate_v1_6_0.go @@ -179,10 +179,10 @@ func migrateNodes(ctx context.Context) error { node.IsAutoRelay = true } if node.Tags == nil { - node.Tags = make(map[models.TagID]struct{}) + node.Tags = make(map[schema.TagID]struct{}) } - node.Tags[models.TagID(fmt.Sprintf("%s.%s", node.Network, models.GwTagName))] = struct{}{} - delete(node.Tags, models.TagID(fmt.Sprintf("%s.%s", node.Network, models.OldRemoteAccessTagName))) + node.Tags[schema.TagID(fmt.Sprintf("%s.%s", node.Network, schema.GwTagName))] = struct{}{} + delete(node.Tags, schema.TagID(fmt.Sprintf("%s.%s", node.Network, schema.OldRemoteAccessTagName))) } } @@ -364,30 +364,30 @@ func migrateNodes_Egress(ctx context.Context, node *models.Node) error { return err } - acl := models.Acl{ + acl := schema.Acl{ ID: uuid.New().String(), Name: "egress node policy", MetaData: "", Default: false, ServiceType: models.Any, NetworkID: schema.NetworkID(node.Network), - Proto: models.ALL, - RuleType: models.DevicePolicy, - Src: []models.AclPolicyTag{ + Proto: schema.ALL, + RuleType: schema.DevicePolicy, + Src: []schema.AclPolicyTag{ { - ID: models.NodeTagID, + ID: schema.NodeTagID, Value: "*", }, }, - Dst: []models.AclPolicyTag{ + Dst: []schema.AclPolicyTag{ { - ID: models.EgressID, + ID: schema.EgressID, Value: egress.ID, }, }, - AllowedDirection: models.TrafficDirectionBi, + AllowedDirection: schema.TrafficDirectionBi, Enabled: true, CreatedBy: "auto", CreatedAt: time.Now().UTC(), @@ -397,30 +397,30 @@ func migrateNodes_Egress(ctx context.Context, node *models.Node) error { return err } - acl = models.Acl{ + acl = schema.Acl{ ID: uuid.New().String(), Name: "egress node policy", MetaData: "", Default: false, ServiceType: models.Any, NetworkID: schema.NetworkID(node.Network), - Proto: models.ALL, - RuleType: models.UserPolicy, - Src: []models.AclPolicyTag{ + Proto: schema.ALL, + RuleType: schema.UserPolicy, + Src: []schema.AclPolicyTag{ { - ID: models.UserAclID, + ID: schema.UserAclID, Value: "*", }, }, - Dst: []models.AclPolicyTag{ + Dst: []schema.AclPolicyTag{ { - ID: models.EgressID, + ID: schema.EgressID, Value: egress.ID, }, }, - AllowedDirection: models.TrafficDirectionBi, + AllowedDirection: schema.TrafficDirectionBi, Enabled: true, CreatedBy: "auto", CreatedAt: time.Now().UTC(), diff --git a/models/acl.go b/models/acl.go index da950dcc2..4a917a291 100644 --- a/models/acl.go +++ b/models/acl.go @@ -2,31 +2,11 @@ package models import ( "net" - "time" "github.com/gravitl/netmaker/schema" ) -// AllowedTrafficDirection - allowed direction of traffic -type AllowedTrafficDirection int - -const ( - // TrafficDirectionUni implies traffic is only allowed in one direction (src --> dst) - TrafficDirectionUni AllowedTrafficDirection = iota - // TrafficDirectionBi implies traffic is allowed both direction (src <--> dst ) - TrafficDirectionBi -) - -// Protocol - allowed protocol -type Protocol string - -const ( - ALL Protocol = "all" - UDP Protocol = "udp" - TCP Protocol = "tcp" - ICMP Protocol = "icmp" -) - +// Non-KV string constants that stay in models. const ( Http = "HTTP" Https = "HTTPS" @@ -38,90 +18,33 @@ const ( Any = "Any" ) -func (p Protocol) String() string { - return string(p) -} - -type AclPolicyType string - -const ( - UserPolicy AclPolicyType = "user-policy" - DevicePolicy AclPolicyType = "device-policy" -) - -type AclPolicyTag struct { - ID AclGroupType `json:"id"` - Name string `json:"name"` - Value string `json:"value"` -} - -type AclGroupType string - -const ( - UserAclID AclGroupType = "user" - UserGroupAclID AclGroupType = "user-group" - NodeTagID AclGroupType = "tag" - NodeID AclGroupType = "device" - EgressRange AclGroupType = "egress-range" - EgressID AclGroupType = "egress-id" - NetmakerIPAclID AclGroupType = "ip" - NetmakerSubNetRangeAClID AclGroupType = "ipset" -) - -func (g AclGroupType) String() string { - return string(g) -} - type UpdateAclRequest struct { - Acl + schema.Acl NewName string `json:"new_name"` } -type AclPolicy struct { - TypeID AclPolicyType - PrefixTagUser AclGroupType -} - -type Acl struct { - ID string `json:"id"` - Default bool `json:"default"` - MetaData string `json:"meta_data"` - Name string `json:"name"` - NetworkID schema.NetworkID `json:"network_id"` - RuleType AclPolicyType `json:"policy_type"` - Src []AclPolicyTag `json:"src_type"` - Dst []AclPolicyTag `json:"dst_type"` - Proto Protocol `json:"protocol"` // tcp, udp, etc. - ServiceType string `json:"type"` - Port []string `json:"ports"` - AllowedDirection AllowedTrafficDirection `json:"allowed_traffic_direction"` - Enabled bool `json:"enabled"` - CreatedBy string `json:"created_by"` - CreatedAt time.Time `json:"created_at"` -} - type AclPolicyTypes struct { ProtocolTypes []ProtocolType - RuleTypes []AclPolicyType `json:"policy_types"` - SrcGroupTypes []AclGroupType `json:"src_grp_types"` - DstGroupTypes []AclGroupType `json:"dst_grp_types"` + RuleTypes []schema.AclPolicyType `json:"policy_types"` + SrcGroupTypes []schema.AclGroupType `json:"src_grp_types"` + DstGroupTypes []schema.AclGroupType `json:"dst_grp_types"` } type ProtocolType struct { - Name string `json:"name"` - AllowedProtocols []Protocol `json:"allowed_protocols"` - PortRange string `json:"port_range"` - AllowPortSetting bool `json:"allow_port_setting"` + Name string `json:"name"` + AllowedProtocols []schema.Protocol `json:"allowed_protocols"` + PortRange string `json:"port_range"` + AllowPortSetting bool `json:"allow_port_setting"` } type AclRule struct { - ID string `json:"id"` - IPList []net.IPNet `json:"ip_list"` - IP6List []net.IPNet `json:"ip6_list"` - AllowedProtocol Protocol `json:"allowed_protocols"` // tcp, udp, etc. - AllowedPorts []string `json:"allowed_ports"` - Direction AllowedTrafficDirection `json:"direction"` // single or two-way - Dst []net.IPNet `json:"dst"` - Dst6 []net.IPNet `json:"dst6"` + ID string `json:"id"` + IPList []net.IPNet `json:"ip_list"` + IP6List []net.IPNet `json:"ip6_list"` + AllowedProtocol schema.Protocol `json:"allowed_protocols"` // tcp, udp, etc. + AllowedPorts []string `json:"allowed_ports"` + Direction schema.AllowedTrafficDirection `json:"direction"` // single or two-way + Dst []net.IPNet `json:"dst"` + Dst6 []net.IPNet `json:"dst6"` Allowed bool } diff --git a/models/api_node.go b/models/api_node.go index 856138eb0..11322b10d 100644 --- a/models/api_node.go +++ b/models/api_node.go @@ -56,24 +56,24 @@ type ApiNode struct { PendingDelete bool `json:"pendingdelete"` Metadata string `json:"metadata"` // == PRO == - DefaultACL string `json:"defaultacl,omitempty" validate:"checkyesornoorunset"` - IsFailOver bool `json:"is_fail_over"` - FailOverPeers map[string]struct{} `json:"fail_over_peers" yaml:"fail_over_peers"` - FailedOverBy uuid.UUID `json:"failed_over_by" yaml:"failed_over_by"` - IsInternetGateway bool `json:"isinternetgateway" yaml:"isinternetgateway"` - InetNodeReq InetNodeReq `json:"inet_node_req" yaml:"inet_node_req"` - InternetGwID string `json:"internetgw_node_id" yaml:"internetgw_node_id"` - AdditionalRagIps []string `json:"additional_rag_ips" yaml:"additional_rag_ips"` - Tags map[TagID]struct{} `json:"tags" yaml:"tags"` - IsStatic bool `json:"is_static"` - IsUserNode bool `json:"is_user_node"` - StaticNode ExtClient `json:"static_node"` - Status schema.NodeStatus `json:"status"` - Location string `json:"location"` - Country string `json:"country"` - PostureChecksViolations []Violation `json:"posture_check_violations"` - PostureCheckVolationSeverityLevel schema.Severity `json:"posture_check_violation_severity_level"` - LastEvaluatedAt time.Time `json:"last_evaluated_at"` + DefaultACL string `json:"defaultacl,omitempty" validate:"checkyesornoorunset"` + IsFailOver bool `json:"is_fail_over"` + FailOverPeers map[string]struct{} `json:"fail_over_peers" yaml:"fail_over_peers"` + FailedOverBy uuid.UUID `json:"failed_over_by" yaml:"failed_over_by"` + IsInternetGateway bool `json:"isinternetgateway" yaml:"isinternetgateway"` + InetNodeReq InetNodeReq `json:"inet_node_req" yaml:"inet_node_req"` + InternetGwID string `json:"internetgw_node_id" yaml:"internetgw_node_id"` + AdditionalRagIps []string `json:"additional_rag_ips" yaml:"additional_rag_ips"` + Tags map[schema.TagID]struct{} `json:"tags" yaml:"tags"` + IsStatic bool `json:"is_static"` + IsUserNode bool `json:"is_user_node"` + StaticNode schema.ExtClient `json:"static_node"` + Status schema.NodeStatus `json:"status"` + Location string `json:"location"` + Country string `json:"country"` + PostureChecksViolations []schema.Violation `json:"posture_check_violations"` + PostureCheckVolationSeverityLevel schema.Severity `json:"posture_check_violation_severity_level"` + LastEvaluatedAt time.Time `json:"last_evaluated_at"` } // ApiNode.ConvertToServerNode - converts an api node to a server node diff --git a/models/enrollment_key.go b/models/enrollment_key.go index 1ac99e459..d885fc30f 100644 --- a/models/enrollment_key.go +++ b/models/enrollment_key.go @@ -44,34 +44,34 @@ const EnrollmentKeyLength = 32 // EnrollmentKey - the key used to register hosts and join them to specific networks type EnrollmentKey struct { - Expiration time.Time `json:"expiration"` - UsesRemaining int `json:"uses_remaining"` - Value string `json:"value"` - Networks []string `json:"networks"` - Unlimited bool `json:"unlimited"` - Tags []string `json:"tags"` - Token string `json:"token,omitempty"` // B64 value of EnrollmentToken - Type KeyType `json:"type"` - Relay uuid.UUID `json:"relay"` - Groups []TagID `json:"groups"` - Default bool `json:"default"` - AutoEgress bool `json:"auto_egress"` - AutoAssignGateway bool `json:"auto_assign_gw"` + Expiration time.Time `json:"expiration"` + UsesRemaining int `json:"uses_remaining"` + Value string `json:"value"` + Networks []string `json:"networks"` + Unlimited bool `json:"unlimited"` + Tags []string `json:"tags"` + Token string `json:"token,omitempty"` // B64 value of EnrollmentToken + Type KeyType `json:"type"` + Relay uuid.UUID `json:"relay"` + Groups []schema.TagID `json:"groups"` + Default bool `json:"default"` + AutoEgress bool `json:"auto_egress"` + AutoAssignGateway bool `json:"auto_assign_gw"` } // APIEnrollmentKey - used to create enrollment keys via API type APIEnrollmentKey struct { - Expiration int64 `json:"expiration" swaggertype:"primitive,integer" format:"int64"` - UsesRemaining int `json:"uses_remaining"` - Networks []string `json:"networks"` - Unlimited bool `json:"unlimited"` - Tags []string `json:"tags" validate:"required,dive,min=3,max=32"` - Type KeyType `json:"type"` - Relay string `json:"relay"` - Groups []TagID `json:"groups"` - Default bool `json:"default"` - AutoEgress bool `json:"auto_egress"` - AutoAssignGateway bool `json:"auto_assign_gw"` + Expiration int64 `json:"expiration" swaggertype:"primitive,integer" format:"int64"` + UsesRemaining int `json:"uses_remaining"` + Networks []string `json:"networks"` + Unlimited bool `json:"unlimited"` + Tags []string `json:"tags" validate:"required,dive,min=3,max=32"` + Type KeyType `json:"type"` + Relay string `json:"relay"` + Groups []schema.TagID `json:"groups"` + Default bool `json:"default"` + AutoEgress bool `json:"auto_egress"` + AutoAssignGateway bool `json:"auto_assign_gw"` } // RegisterResponse - the response to a successful enrollment register diff --git a/models/extclient.go b/models/extclient.go index f5c3b4ebe..373ad3711 100644 --- a/models/extclient.go +++ b/models/extclient.go @@ -1,79 +1,38 @@ package models import ( - "sync" - "time" - "github.com/gravitl/netmaker/schema" ) -// ExtClient - struct for external clients -type ExtClient struct { - ClientID string `json:"clientid" bson:"clientid"` - PrivateKey string `json:"privatekey" bson:"privatekey"` - PublicKey string `json:"publickey" bson:"publickey"` - Network string `json:"network" bson:"network"` - DNS string `json:"dns" bson:"dns"` - Address string `json:"address" bson:"address"` - Address6 string `json:"address6" bson:"address6"` - ExtraAllowedIPs []string `json:"extraallowedips" bson:"extraallowedips"` - AllowedIPs []string `json:"allowed_ips"` - IngressGatewayID string `json:"ingressgatewayid" bson:"ingressgatewayid"` - IngressGatewayEndpoint string `json:"ingressgatewayendpoint" bson:"ingressgatewayendpoint"` - LastModified int64 `json:"lastmodified" bson:"lastmodified" swaggertype:"primitive,integer" format:"int64"` - Enabled bool `json:"enabled" bson:"enabled"` - OwnerID string `json:"ownerid" bson:"ownerid"` - DeniedACLs map[string]struct{} `json:"deniednodeacls" bson:"acls,omitempty"` - RemoteAccessClientID string `json:"remote_access_client_id"` // unique ID (MAC address) of RAC machine - PostUp string `json:"postup" bson:"postup"` - PostDown string `json:"postdown" bson:"postdown"` - Tags map[TagID]struct{} `json:"tags"` - OS string `json:"os"` - OSFamily string `json:"os_family" yaml:"os_family"` - OSVersion string `json:"os_version" yaml:"os_version"` - KernelVersion string `json:"kernel_version" yaml:"kernel_version"` - ClientVersion string `json:"client_version"` - DeviceID string `json:"device_id"` - DeviceName string `json:"device_name"` - PublicEndpoint string `json:"public_endpoint"` - Country string `json:"country"` - Location string `json:"location"` //format: lat,long - PostureChecksViolations []Violation `json:"posture_check_violations"` - PostureCheckVolationSeverityLevel schema.Severity `json:"posture_check_violation_severity_level"` - LastEvaluatedAt time.Time `json:"last_evaluated_at"` - JITExpiresAt *time.Time `json:"jit_expires_at,omitempty" bson:"jit_expires_at,omitempty"` // JIT grant expiry time (nil if JIT not enabled or user is admin) - Status schema.NodeStatus `json:"status" bson:"status"` - Mutex *sync.Mutex `json:"-"` -} - // CustomExtClient - struct for CustomExtClient params type CustomExtClient struct { - ClientID string `json:"clientid,omitempty"` - PublicKey string `json:"publickey,omitempty"` - DNS string `json:"dns,omitempty"` - ExtraAllowedIPs []string `json:"extraallowedips,omitempty"` - Enabled bool `json:"enabled,omitempty"` - DeniedACLs map[string]struct{} `json:"deniednodeacls" bson:"acls,omitempty"` - RemoteAccessClientID string `json:"remote_access_client_id"` // unique ID (MAC address) of RAC machine - PostUp string `json:"postup" bson:"postup" validate:"max=1024"` - PostDown string `json:"postdown" bson:"postdown" validate:"max=1024"` - Tags map[TagID]struct{} `json:"tags"` - DeviceID string `json:"device_id"` - DeviceName string `json:"device_name"` - IsAlreadyConnectedToInetGw bool `json:"is_already_connected_to_inet_gw"` - PublicEndpoint string `json:"public_endpoint"` - OS string `json:"os"` - OSFamily string `json:"os_family" yaml:"os_family"` - OSVersion string `json:"os_version" yaml:"os_version"` - KernelVersion string `json:"kernel_version" yaml:"kernel_version"` - ClientVersion string `json:"client_version"` - Country string `json:"country"` - Location string `json:"location"` //format: lat,long + ClientID string `json:"clientid,omitempty"` + PublicKey string `json:"publickey,omitempty"` + DNS string `json:"dns,omitempty"` + ExtraAllowedIPs []string `json:"extraallowedips,omitempty"` + Enabled bool `json:"enabled,omitempty"` + DeniedACLs map[string]struct{} `json:"deniednodeacls" bson:"acls,omitempty"` + RemoteAccessClientID string `json:"remote_access_client_id"` // unique ID (MAC address) of RAC machine + PostUp string `json:"postup" bson:"postup" validate:"max=1024"` + PostDown string `json:"postdown" bson:"postdown" validate:"max=1024"` + Tags map[schema.TagID]struct{} `json:"tags"` + DeviceID string `json:"device_id"` + DeviceName string `json:"device_name"` + IsAlreadyConnectedToInetGw bool `json:"is_already_connected_to_inet_gw"` + PublicEndpoint string `json:"public_endpoint"` + OS string `json:"os"` + OSFamily string `json:"os_family" yaml:"os_family"` + OSVersion string `json:"os_version" yaml:"os_version"` + KernelVersion string `json:"kernel_version" yaml:"kernel_version"` + ClientVersion string `json:"client_version"` + Country string `json:"country"` + Location string `json:"location"` //format: lat,long } -func (ext *ExtClient) ConvertToStaticNode() Node { +// ConvertToStaticNode converts an ExtClient to a Node suitable for static node operations. +func ConvertToStaticNode(ext *schema.ExtClient) Node { if ext.Tags == nil { - ext.Tags = make(map[TagID]struct{}) + ext.Tags = make(map[schema.TagID]struct{}) } return Node{ CommonNode: CommonNode{ diff --git a/models/host.go b/models/host.go index 08b44327a..98f6c6161 100644 --- a/models/host.go +++ b/models/host.go @@ -157,7 +157,7 @@ type HostUpdate struct { Node Node Signal Signal EgressDomain EgressDomain - NewMetrics Metrics + NewMetrics schema.Metrics } // HostTurnRegister - struct for host turn registration diff --git a/models/metrics.go b/models/metrics.go index f57d17f6d..432e5288e 100644 --- a/models/metrics.go +++ b/models/metrics.go @@ -1,35 +1,9 @@ package models import ( - "time" - "github.com/gravitl/netmaker/schema" ) -// Metrics - metrics struct -type Metrics struct { - Network string `json:"network" bson:"network" yaml:"network"` - NodeID string `json:"node_id" bson:"node_id" yaml:"node_id"` - NodeName string `json:"node_name" bson:"node_name" yaml:"node_name"` - Connectivity map[string]Metric `json:"connectivity" bson:"connectivity" yaml:"connectivity"` - UpdatedAt time.Time `json:"updated_at" bson:"updated_at" yaml:"updated_at"` -} - -// Metric - holds a metric for data between nodes -type Metric struct { - NodeName string `json:"node_name" bson:"node_name" yaml:"node_name"` - Uptime int64 `json:"uptime" bson:"uptime" yaml:"uptime" swaggertype:"primitive,integer" format:"int64"` - TotalTime int64 `json:"totaltime" bson:"totaltime" yaml:"totaltime" swaggertype:"primitive,integer" format:"int64"` - Latency int64 `json:"latency" bson:"latency" yaml:"latency" swaggertype:"primitive,integer" format:"int64"` - TotalReceived int64 `json:"totalreceived" bson:"totalreceived" yaml:"totalreceived" swaggertype:"primitive,integer" format:"int64"` - LastTotalReceived int64 `json:"lasttotalreceived" bson:"lasttotalreceived" yaml:"lasttotalreceived" swaggertype:"primitive,integer" format:"int64"` - TotalSent int64 `json:"totalsent" bson:"totalsent" yaml:"totalsent" swaggertype:"primitive,integer" format:"int64"` - LastTotalSent int64 `json:"lasttotalsent" bson:"lasttotalsent" yaml:"lasttotalsent" swaggertype:"primitive,integer" format:"int64"` - ActualUptime time.Duration `json:"actualuptime" swaggertype:"primitive,integer" format:"int64" bson:"actualuptime" yaml:"actualuptime"` - PercentUp float64 `json:"percentup" bson:"percentup" yaml:"percentup"` - Connected bool `json:"connected" bson:"connected" yaml:"connected"` -} - // IDandAddr - struct to hold ID and primary Address type IDandAddr struct { ID string `json:"id" bson:"id" yaml:"id"` @@ -61,7 +35,7 @@ type HostNetworkInfo struct { type PeerMap map[string]IDandAddr // MetricsMap - map for holding multiple metrics in memory -type MetricsMap map[string]Metrics +type MetricsMap map[string]schema.Metrics // NetworkMetrics - metrics model for all nodes in a network type NetworkMetrics struct { diff --git a/models/mqtt.go b/models/mqtt.go index 100ee9a29..c959c97a1 100644 --- a/models/mqtt.go +++ b/models/mqtt.go @@ -76,11 +76,11 @@ type OldPeerUpdateFields struct { } type FwRule struct { - SrcIP net.IPNet `json:"src_ip"` - DstIP net.IPNet `json:"dst_ip"` - AllowedProtocol Protocol `json:"allowed_protocols"` // tcp, udp, etc. - AllowedPorts []string `json:"allowed_ports"` - Allow bool `json:"allow"` + SrcIP net.IPNet `json:"src_ip"` + DstIP net.IPNet `json:"dst_ip"` + AllowedProtocol schema.Protocol `json:"allowed_protocols"` // tcp, udp, etc. + AllowedPorts []string `json:"allowed_ports"` + Allow bool `json:"allow"` } // IngressInfo - struct for ingress info diff --git a/models/node.go b/models/node.go index 950f49ce6..a56d8813f 100644 --- a/models/node.go +++ b/models/node.go @@ -68,25 +68,25 @@ type Node struct { //AutoRelayedPeers map[string]struct{} `json:"auto_relayed_peers"` AutoRelayedPeers map[string]string `json:"auto_relayed_peers_v1"` //AutoRelayedBy uuid.UUID `json:"auto_relayed_by"` - FailOverPeers map[string]struct{} `json:"fail_over_peers"` - FailedOverBy uuid.UUID `json:"failed_over_by"` - IsInternetGateway bool `json:"isinternetgateway"` - InetNodeReq InetNodeReq `json:"inet_node_req"` - InternetGwID string `json:"internetgw_node_id"` - AdditionalRagIps []net.IP `json:"additional_rag_ips" swaggertype:"array,number"` - Tags map[TagID]struct{} `json:"tags"` - IsStatic bool `json:"is_static"` - IsUserNode bool `json:"is_user_node"` - StaticNode ExtClient `json:"static_node"` - Status schema.NodeStatus `json:"node_status"` - Mutex *sync.Mutex `json:"-"` - EgressDetails EgressDetails `json:"-"` - PostureChecksViolations []Violation `json:"posture_check_violations"` - PostureCheckViolationSeverityLevel schema.Severity `json:"posture_check_violation_severity_level"` - LastEvaluationCycleID string `json:"last_evaluation_cycle_id"` - LastEvaluatedAt time.Time `json:"last_evaluated_at"` - Location string `json:"location"` // Format: "lat,lon" - CountryCode string `json:"country_code"` + FailOverPeers map[string]struct{} `json:"fail_over_peers"` + FailedOverBy uuid.UUID `json:"failed_over_by"` + IsInternetGateway bool `json:"isinternetgateway"` + InetNodeReq InetNodeReq `json:"inet_node_req"` + InternetGwID string `json:"internetgw_node_id"` + AdditionalRagIps []net.IP `json:"additional_rag_ips" swaggertype:"array,number"` + Tags map[schema.TagID]struct{} `json:"tags"` + IsStatic bool `json:"is_static"` + IsUserNode bool `json:"is_user_node"` + StaticNode schema.ExtClient `json:"static_node"` + Status schema.NodeStatus `json:"node_status"` + Mutex *sync.Mutex `json:"-"` + EgressDetails EgressDetails `json:"-"` + PostureChecksViolations []schema.Violation `json:"posture_check_violations"` + PostureCheckViolationSeverityLevel schema.Severity `json:"posture_check_violation_severity_level"` + LastEvaluationCycleID string `json:"last_evaluation_cycle_id"` + LastEvaluatedAt time.Time `json:"last_evaluated_at"` + Location string `json:"location"` // Format: "lat,lon" + CountryCode string `json:"country_code"` } type EgressDetails struct { EgressGatewayNatEnabled bool @@ -147,22 +147,6 @@ func (node *Node) AddressIPNet6() net.IPNet { } } -// ExtClient.PrimaryAddress - returns ipv4 IPNet format -func (extPeer *ExtClient) AddressIPNet4() net.IPNet { - return net.IPNet{ - IP: net.ParseIP(extPeer.Address), - Mask: net.CIDRMask(32, 32), - } -} - -// ExtClient.AddressIPNet6 - return ipv6 IPNet format -func (extPeer *ExtClient) AddressIPNet6() net.IPNet { - return net.IPNet{ - IP: net.ParseIP(extPeer.Address6), - Mask: net.CIDRMask(128, 128), - } -} - // Node.PrimaryNetworkRange - returns node's parent network, returns ipv4 address if present, else return ipv6 func (node *Node) PrimaryNetworkRange() net.IPNet { if node.NetworkRange.IP != nil { @@ -259,7 +243,7 @@ func (newNode *Node) Fill( newNode.FailOverPeers = currentNode.FailOverPeers if newNode.Tags == nil { if currentNode.Tags == nil { - currentNode.Tags = make(map[TagID]struct{}) + currentNode.Tags = make(map[schema.TagID]struct{}) } newNode.Tags = currentNode.Tags } diff --git a/models/settings.go b/models/settings.go index e60ccf463..06095b7ed 100644 --- a/models/settings.go +++ b/models/settings.go @@ -2,54 +2,6 @@ package models import "github.com/gravitl/netmaker/schema" -type ServerSettings struct { - NetclientAutoUpdate bool `json:"netclientautoupdate"` - Verbosity int32 `json:"verbosity"` - AuthProvider string `json:"authprovider"` - OIDCIssuer string `json:"oidcissuer"` - ClientID string `json:"client_id"` - ClientSecret string `json:"client_secret"` - SyncEnabled bool `json:"sync_enabled"` - GoogleAdminEmail string `json:"google_admin_email"` - GoogleSACredsJson string `json:"google_sa_creds_json"` - AzureTenant string `json:"azure_tenant"` - OktaOrgURL string `json:"okta_org_url"` - OktaAPIToken string `json:"okta_api_token"` - UserFilters []string `json:"user_filters"` - GroupFilters []string `json:"group_filters"` - IDPSyncInterval string `json:"idp_sync_interval"` - Telemetry string `json:"telemetry"` - BasicAuth bool `json:"basic_auth"` - // JwtValidityDuration is the validity duration of auth tokens for users - // on the dashboard (NMUI). - JwtValidityDuration int `json:"jwt_validity_duration"` - // JwtValidityDurationClients is the validity duration of auth tokens for - // users on the clients (NetDesk). - JwtValidityDurationClients int `json:"jwt_validity_duration_clients"` - MFAEnforced bool `json:"mfa_enforced"` - RacRestrictToSingleNetwork bool `json:"rac_restrict_to_single_network"` - EndpointDetection bool `json:"endpoint_detection"` - AllowedEmailDomains string `json:"allowed_email_domains"` - EmailSenderAddr string `json:"email_sender_addr"` - EmailSenderUser string `json:"email_sender_user"` - EmailSenderPassword string `json:"email_sender_password"` - SmtpHost string `json:"smtp_host"` - SmtpPort int `json:"smtp_port"` - MetricInterval string `json:"metric_interval"` - MetricsPort int `json:"metrics_port"` - // IPDetectionInterval is the interval (in seconds) at which devices check for changes in public ip. - IPDetectionInterval int `json:"ip_detection_interval"` - ManageDNS bool `json:"manage_dns"` - DefaultDomain string `json:"default_domain"` - Stun bool `json:"stun"` - StunServers string `json:"stun_servers"` - AuditLogsRetentionPeriodInDays int `json:"audit_logs_retention_period"` - PeerConnectionCheckInterval string `json:"peer_connection_check_interval"` - PostureCheckInterval string `json:"posture_check_interval"` // in minutes - CleanUpInterval int `json:"clean_up_interval_in_mins"` - EnableFlowLogs bool `json:"enable_flow_logs"` -} - type UserSettings struct { Theme schema.Theme `json:"theme"` TextSize string `json:"text_size"` diff --git a/models/structs.go b/models/structs.go index 7026a2156..e2ee1ebc5 100644 --- a/models/structs.go +++ b/models/structs.go @@ -46,7 +46,7 @@ type UserRemoteGws struct { Network string `json:"network"` Connected bool `json:"connected"` IsInternetGateway bool `json:"is_internet_gateway"` - GwClient ExtClient `json:"gw_client"` + GwClient schema.ExtClient `json:"gw_client"` GwPeerPublicKey string `json:"gw_peer_public_key"` GwListenPort int `json:"gw_listen_port"` Metadata string `json:"metadata"` @@ -465,19 +465,11 @@ type PostureCheckDeviceInfo struct { KernelVersion string AutoUpdate bool SkipAutoUpdate bool - Tags map[TagID]struct{} + Tags map[schema.TagID]struct{} IsUser bool UserGroups map[schema.UserGroupID]struct{} } -type Violation struct { - CheckID string `json:"check_id"` - Name string `json:"name"` - Attribute string `json:"attribute"` - Message string `json:"message"` - Severity schema.Severity `json:"severity"` -} - type BulkDeleteRequest struct { IDs []string `json:"ids"` } diff --git a/models/tags.go b/models/tags.go index 99c96446e..9808a8f14 100644 --- a/models/tags.go +++ b/models/tags.go @@ -1,36 +1,9 @@ package models import ( - "fmt" - "time" - "github.com/gravitl/netmaker/schema" ) -type TagID string - -const ( - OldRemoteAccessTagName = "remote-access-gws" - GwTagName = "gateways" -) - -func (id TagID) String() string { - return string(id) -} - -func (t Tag) GetIDFromName() string { - return fmt.Sprintf("%s.%s", t.Network, t.TagName) -} - -type Tag struct { - ID TagID `json:"id"` - TagName string `json:"tag_name"` - Network schema.NetworkID `json:"network"` - ColorCode string `json:"color_code"` - CreatedBy string `json:"created_by"` - CreatedAt time.Time `json:"created_at"` -} - type CreateTagReq struct { TagName string `json:"tag_name"` Network schema.NetworkID `json:"network"` @@ -39,19 +12,19 @@ type CreateTagReq struct { } type TagListResp struct { - Tag + schema.Tag UsedByCnt int `json:"used_by_count"` TaggedNodes []ApiNode `json:"tagged_nodes"` } type TagListRespNodes struct { - Tag + schema.Tag UsedByCnt int `json:"used_by_count"` TaggedNodes []ApiNode `json:"tagged_nodes"` } type UpdateTagReq struct { - Tag + schema.Tag NewName string `json:"new_name"` ColorCode string `json:"color_code"` TaggedNodes []ApiNode `json:"tagged_nodes"` diff --git a/mq/handlers.go b/mq/handlers.go index 5e2b562e5..5d048f6ef 100644 --- a/mq/handlers.go +++ b/mq/handlers.go @@ -24,7 +24,7 @@ import ( var UpdateMetrics = func(client mqtt.Client, msg mqtt.Message) { } -var UpdateMetricsFallBack = func(nodeid string, newMetrics models.Metrics) {} +var UpdateMetricsFallBack = func(nodeid string, newMetrics schema.Metrics) {} // DefaultHandler default message queue handler -- NOT USED func DefaultHandler(client mqtt.Client, msg mqtt.Message) { diff --git a/mq/publishers.go b/mq/publishers.go index 4b31d3e3d..53c0bbeb4 100644 --- a/mq/publishers.go +++ b/mq/publishers.go @@ -13,7 +13,6 @@ import ( "time" "github.com/google/uuid" - "github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/db" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" @@ -188,7 +187,7 @@ func PublishDeletedNodePeerUpdate(delHost *schema.Host, delNode *models.Node) er // PublishDeletedClientPeerUpdate --- determines and publishes a peer update // to all the hosts with a deleted ext client to account for -func PublishDeletedClientPeerUpdate(delClient *models.ExtClient) error { +func PublishDeletedClientPeerUpdate(delClient *schema.ExtClient) error { if !servercfg.IsMessageQueueBackend() { return nil } @@ -205,7 +204,7 @@ func PublishDeletedClientPeerUpdate(delClient *models.ExtClient) error { for _, host := range hosts { host := host if host.OS != models.OS_Types.IoT { - if err = PublishSingleHostPeerUpdate(&host, nodes, nil, nil, []models.ExtClient{*delClient}, false, nil); err != nil { + if err = PublishSingleHostPeerUpdate(&host, nodes, nil, nil, []schema.ExtClient{*delClient}, false, nil); err != nil { logger.Log(1, "failed to publish peer update to host", host.ID.String(), ": ", err.Error()) } } @@ -214,7 +213,7 @@ func PublishDeletedClientPeerUpdate(delClient *models.ExtClient) error { } // PublishSingleHostPeerUpdate --- determines and publishes a peer update to one host -func PublishSingleHostPeerUpdate(host *schema.Host, allNodes []models.Node, deletedHost *schema.Host, deletedNode *models.Node, deletedClients []models.ExtClient, replacePeers bool, wg *sync.WaitGroup) error { +func PublishSingleHostPeerUpdate(host *schema.Host, allNodes []models.Node, deletedHost *schema.Host, deletedNode *models.Node, deletedClients []schema.ExtClient, replacePeers bool, wg *sync.WaitGroup) error { if wg != nil { defer wg.Done() } @@ -342,18 +341,14 @@ func PushAllMetricsToExporter() { slog.Warn("metrics export: exporter unhealthy, skipping", "status", healthResp.StatusCode) return } - records, err := database.FetchRecords(database.METRICS_TABLE_NAME) + entries, err := (&schema.MetricsEntry{}).ListAll(db.WithContext(context.TODO())) if err != nil { slog.Error("metrics export: failed to fetch records", "error", err) return } - batch := make([]models.Metrics, 0, len(records)) - for _, data := range records { - var m models.Metrics - if err := json.Unmarshal([]byte(data), &m); err != nil { - continue - } - batch = append(batch, m) + batch := make([]schema.Metrics, 0, len(entries)) + for _, entry := range entries { + batch = append(batch, entry.Value.Data()) } if len(batch) == 0 { return diff --git a/orchestrator/extensions/node.go b/orchestrator/extensions/node.go index 951a0bf08..6cb8850ab 100644 --- a/orchestrator/extensions/node.go +++ b/orchestrator/extensions/node.go @@ -1,14 +1,13 @@ package extensions import ( - "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/schema" ) type NodeExtensions interface { ConfigureAutoRelay(node *schema.Node) ConfigureAutoAssignGateway(node *schema.Node, key *schema.EnrollmentKey) - ConfigureTag(node *schema.Node, tagID models.TagID) + ConfigureTag(node *schema.Node, tagID schema.TagID) } type CENodeExtensions struct{} @@ -21,4 +20,4 @@ func (c *CENodeExtensions) ConfigureAutoAssignGateway(node *schema.Node, _ *sche node.AutoAssignGateway = false } -func (c *CENodeExtensions) ConfigureTag(_ *schema.Node, _ models.TagID) {} +func (c *CENodeExtensions) ConfigureTag(_ *schema.Node, _ schema.TagID) {} diff --git a/orchestrator/node.go b/orchestrator/node.go index bdb814100..b9138792a 100644 --- a/orchestrator/node.go +++ b/orchestrator/node.go @@ -48,7 +48,7 @@ func (n *NodeOrchestrator) CreateNode(ctx context.Context, host *schema.Host, ne n.nodeExt.ConfigureAutoAssignGateway(node, ops.key) for _, tag := range ops.key.Tags { - n.nodeExt.ConfigureTag(node, models.TagID(tag)) + n.nodeExt.ConfigureTag(node, schema.TagID(tag)) } } @@ -111,7 +111,7 @@ func (n *NodeOrchestrator) CreateNode(ctx context.Context, host *schema.Host, ne go logic.CheckZombies(node) go func() { - err := logic.UpdateMetrics(node.ID, &models.Metrics{Connectivity: make(map[string]models.Metric)}) + err := logic.UpdateMetrics(node.ID, &schema.Metrics{Connectivity: make(map[string]schema.Metric)}) if err != nil { logger.Log(1, fmt.Sprintf("failed to initialize metrics for node (%s): %v", node.ID, err)) } @@ -224,7 +224,7 @@ func (n *NodeOrchestrator) CreateGateway(ctx context.Context, node *schema.Node, n.nodeExt.ConfigureAutoRelay(node) - node.Tags[fmt.Sprintf("%s.%s", node.Network.Name, models.GwTagName)] = struct{}{} + node.Tags[fmt.Sprintf("%s.%s", node.Network.Name, schema.GwTagName)] = struct{}{} err := node.Update(ctx) if err != nil { diff --git a/pro/controllers/auto_relay.go b/pro/controllers/auto_relay.go index c38ba5fc5..62485ec82 100644 --- a/pro/controllers/auto_relay.go +++ b/pro/controllers/auto_relay.go @@ -66,7 +66,7 @@ func getAutoRelayGws(w http.ResponseWriter, r *http.Request) { ) return } - defaultPolicy, err := logic.GetDefaultPolicy(schema.NetworkID(node.Network), models.DevicePolicy) + defaultPolicy, err := logic.GetDefaultPolicy(schema.NetworkID(node.Network), schema.DevicePolicy) if err != nil { logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return diff --git a/pro/controllers/metrics.go b/pro/controllers/metrics.go index 8fbff781e..087bac2a8 100644 --- a/pro/controllers/metrics.go +++ b/pro/controllers/metrics.go @@ -13,6 +13,7 @@ import ( "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/schema" ) // MetricHandlers - How we handle Pro Metrics @@ -30,7 +31,7 @@ func MetricHandlers(r *mux.Router) { // @Produce json // @Param network path string true "Network ID" // @Param nodeid path string true "Node ID" -// @Success 200 {object} models.Metrics +// @Success 200 {object} schema.Metrics // @Failure 500 {object} models.ErrorResponse func getNodeMetrics(w http.ResponseWriter, r *http.Request) { // set header. @@ -99,7 +100,7 @@ func getNetworkNodesMetrics(w http.ResponseWriter, r *http.Request) { // @Security oauth // @Produce json // @Param network path string true "Network ID" -// @Success 200 {object} models.Metrics +// @Success 200 {object} schema.Metrics // @Failure 500 {object} models.ErrorResponse func getNetworkExtMetrics(w http.ResponseWriter, r *http.Request) { // set header. @@ -129,8 +130,8 @@ func getNetworkExtMetrics(w http.ResponseWriter, r *http.Request) { return } - networkMetrics := models.Metrics{} - networkMetrics.Connectivity = make(map[string]models.Metric) + networkMetrics := schema.Metrics{} + networkMetrics.Connectivity = make(map[string]schema.Metric) for i := range ingresses { id := ingresses[i].ID diff --git a/pro/controllers/posture_check.go b/pro/controllers/posture_check.go index e40c57c82..2f505f821 100644 --- a/pro/controllers/posture_check.go +++ b/pro/controllers/posture_check.go @@ -340,7 +340,7 @@ func listPostureCheckViolatedNodes(w http.ResponseWriter, r *http.Request) { for _, extclient := range extclients { if extclient.DeviceID != "" { if len(extclient.PostureChecksViolations) > 0 { - violatedNodes = append(violatedNodes, extclient.ConvertToStaticNode()) + violatedNodes = append(violatedNodes, models.ConvertToStaticNode(&extclient)) } } } diff --git a/pro/controllers/tags.go b/pro/controllers/tags.go index e4f4e4fd4..9d0f53f05 100644 --- a/pro/controllers/tags.go +++ b/pro/controllers/tags.go @@ -98,8 +98,8 @@ func createTag(w http.ResponseWriter, r *http.Request) { return } // check if tag exists - tag := models.Tag{ - ID: models.TagID(fmt.Sprintf("%s.%s", req.Network, req.TagName)), + tag := schema.Tag{ + ID: schema.TagID(fmt.Sprintf("%s.%s", req.Network, req.TagName)), TagName: req.TagName, Network: req.Network, CreatedBy: user.Username, @@ -129,7 +129,7 @@ func createTag(w http.ResponseWriter, r *http.Request) { extclient, err := logic.GetExtClient(node.StaticNode.ClientID, node.StaticNode.Network) if err == nil && extclient.RemoteAccessClientID == "" { if extclient.Tags == nil { - extclient.Tags = make(map[models.TagID]struct{}) + extclient.Tags = make(map[schema.TagID]struct{}) } extclient.Tags[tag.ID] = struct{}{} logic.SaveExtClient(&extclient) @@ -226,7 +226,7 @@ func updateTag(w http.ResponseWriter, r *http.Request) { Origin: schema.Dashboard, } updateTag.NewName = strings.TrimSpace(updateTag.NewName) - var newID models.TagID + var newID schema.TagID if updateTag.NewName != "" { // validate name err = proLogic.CheckIDSyntax(updateTag.NewName) @@ -234,7 +234,7 @@ func updateTag(w http.ResponseWriter, r *http.Request) { logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return } - newID = models.TagID(fmt.Sprintf("%s.%s", tag.Network, updateTag.NewName)) + newID = schema.TagID(fmt.Sprintf("%s.%s", tag.Network, updateTag.NewName)) tag.ID = newID tag.TagName = updateTag.NewName err = proLogic.InsertTag(tag) @@ -285,7 +285,7 @@ func deleteTag(w http.ResponseWriter, r *http.Request) { logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("role is required"), "badrequest")) return } - tag, err := proLogic.GetTag(models.TagID(tagID)) + tag, err := proLogic.GetTag(schema.TagID(tagID)) if err != nil { logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest")) return @@ -295,7 +295,7 @@ func deleteTag(w http.ResponseWriter, r *http.Request) { logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("tag is currently in use by an active policy"), "badrequest")) return } - err = proLogic.DeleteTag(models.TagID(tagID), true) + err = proLogic.DeleteTag(schema.TagID(tagID), true) if err != nil { logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) return diff --git a/pro/controllers/users.go b/pro/controllers/users.go index 3215f2ee2..e5bd954b6 100644 --- a/pro/controllers/users.go +++ b/pro/controllers/users.go @@ -1450,7 +1450,7 @@ func getRemoteAccessGatewayConf(w http.ResponseWriter, r *http.Request) { if err != nil { slog.Error("failed to get node network", "error", err) } - var userConf models.ExtClient + var userConf schema.ExtClient allextClients, err := logic.GetAllExtClients() if err != nil { logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal")) @@ -1480,8 +1480,8 @@ func getRemoteAccessGatewayConf(w http.ResponseWriter, r *http.Request) { userConf.IngressGatewayEndpoint = fmt.Sprintf("%s:%d", host.EndpointIP.String(), listenPort) } userConf.Enabled = true - userConf.Tags = make(map[models.TagID]struct{}) - // userConf.Tags[models.TagID(fmt.Sprintf("%s.%s", userConf.Network, + userConf.Tags = make(map[schema.TagID]struct{}) + // userConf.Tags[schema.TagID(fmt.Sprintf("%s.%s", userConf.Network, // models.RemoteAccessTagName))] = struct{}{} if len(userConf.PublicKey) == 0 { privateKey, err := wgtypes.GeneratePrivateKey() @@ -1671,7 +1671,7 @@ func getUserRemoteAccessGwsV1(w http.ResponseWriter, r *http.Request) { } userGwNodes := proLogic.GetUserRAGNodes(user) - userExtClients := make(map[string][]models.ExtClient) + userExtClients := make(map[string][]schema.ExtClient) // group all extclients of the requesting user by ingress // gateway. @@ -1687,7 +1687,7 @@ func getUserRemoteAccessGwsV1(w http.ResponseWriter, r *http.Request) { _, ok := userExtClients[extClient.IngressGatewayID] if !ok { - userExtClients[extClient.IngressGatewayID] = []models.ExtClient{} + userExtClients[extClient.IngressGatewayID] = []schema.ExtClient{} } userExtClients[extClient.IngressGatewayID] = append(userExtClients[extClient.IngressGatewayID], extClient) @@ -1701,7 +1701,7 @@ func getUserRemoteAccessGwsV1(w http.ResponseWriter, r *http.Request) { continue } - var gwClient models.ExtClient + var gwClient schema.ExtClient var found bool if deviceID != "" { for _, extClient := range extClients { diff --git a/pro/logic/acls.go b/pro/logic/acls.go index ba747ec52..98f3cf68a 100644 --- a/pro/logic/acls.go +++ b/pro/logic/acls.go @@ -20,7 +20,7 @@ func getStaticUserNodesByNetwork(network schema.NetworkID) (staticNode []models. for _, extI := range extClients { if extI.Network == network.String() { if extI.RemoteAccessClientID != "" { - n := extI.ConvertToStaticNode() + n := models.ConvertToStaticNode(&extI) staticNode = append(staticNode, n) } } @@ -29,7 +29,7 @@ func getStaticUserNodesByNetwork(network schema.NetworkID) (staticNode []models. } func GetFwRulesForUserNodesOnGw(node models.Node, nodes []models.Node) (rules []models.FwRule) { - defaultUserPolicy, _ := logic.GetDefaultPolicy(schema.NetworkID(node.Network), models.UserPolicy) + defaultUserPolicy, _ := logic.GetDefaultPolicy(schema.NetworkID(node.Network), schema.UserPolicy) userNodes := getStaticUserNodesByNetwork(schema.NetworkID(node.Network)) for _, userNodeI := range userNodes { if !userNodeI.StaticNode.Enabled { @@ -40,7 +40,7 @@ func GetFwRulesForUserNodesOnGw(node models.Node, nodes []models.Node) (rules [] rules = append(rules, models.FwRule{ SrcIP: userNodeI.StaticNode.AddressIPNet4(), DstIP: net.IPNet{}, - AllowedProtocol: models.ALL, + AllowedProtocol: schema.ALL, AllowedPorts: []string{}, Allow: true, }) @@ -49,7 +49,7 @@ func GetFwRulesForUserNodesOnGw(node models.Node, nodes []models.Node) (rules [] rules = append(rules, models.FwRule{ SrcIP: userNodeI.StaticNode.AddressIPNet6(), DstIP: net.IPNet{}, - AllowedProtocol: models.ALL, + AllowedProtocol: schema.ALL, AllowedPorts: []string{}, Allow: true, }) @@ -63,7 +63,7 @@ func GetFwRulesForUserNodesOnGw(node models.Node, nodes []models.Node) (rules [] if ok, allowedPolicies := IsUserAllowedToCommunicate(userNodeI.StaticNode.OwnerID, peer); ok { if peer.IsStatic { - peer = peer.StaticNode.ConvertToStaticNode() + peer = models.ConvertToStaticNode(&peer.StaticNode) } for _, policy := range allowedPolicies { selectedIP4, selectedIP6 := getSelectedUserEgressIPNets(policy.Dst) @@ -104,7 +104,7 @@ func GetFwRulesForUserNodesOnGw(node models.Node, nodes []models.Node) (rules [] }) break } - if dstI.ID == models.EgressID { + if dstI.ID == schema.EgressID { e := schema.Egress{ID: dstI.Value} err := e.Get(db.WithContext(context.TODO())) @@ -206,7 +206,7 @@ func GetFwRulesForUserNodesOnGw(node models.Node, nodes []models.Node) (rules [] return } -func GetFwRulesForNodeAndPeerOnGw(node, peer models.Node, allowedPolicies []models.Acl) (rules []models.FwRule) { +func GetFwRulesForNodeAndPeerOnGw(node, peer models.Node, allowedPolicies []schema.Acl) (rules []models.FwRule) { for _, policy := range allowedPolicies { selectedIP4, selectedIP6 := getSelectedUserEgressIPNets(policy.Dst) @@ -242,7 +242,7 @@ func GetFwRulesForNodeAndPeerOnGw(node, peer models.Node, allowedPolicies []mode Allow: true, }) } - if policy.AllowedDirection == models.TrafficDirectionBi { + if policy.AllowedDirection == schema.TrafficDirectionBi { if node.Address.IP != nil { rules = append(rules, models.FwRule{ SrcIP: net.IPNet{ @@ -336,7 +336,7 @@ func GetFwRulesForNodeAndPeerOnGw(node, peer models.Node, allowedPolicies []mode // add egress range rules for _, dstI := range policy.Dst { - if dstI.ID == models.EgressID { + if dstI.ID == schema.EgressID { e := schema.Egress{ID: dstI.Value} err := e.Get(db.WithContext(context.TODO())) @@ -493,19 +493,19 @@ func GetFwRulesForNodeAndPeerOnGw(node, peer models.Node, allowedPolicies []mode return } -func checkIfAclTagisValid(a models.Acl, t models.AclPolicyTag, isSrc bool) (err error) { +func checkIfAclTagisValid(a schema.Acl, t schema.AclPolicyTag, isSrc bool) (err error) { switch t.ID { - case models.NodeTagID: - if a.RuleType == models.UserPolicy && isSrc { + case schema.NodeTagID: + if a.RuleType == schema.UserPolicy && isSrc { return errors.New("user policy source mismatch") } // check if tag is valid - _, err := GetTag(models.TagID(t.Value)) + _, err := GetTag(schema.TagID(t.Value)) if err != nil { return errors.New("invalid tag " + t.Value) } - case models.NodeID: - if a.RuleType == models.UserPolicy && isSrc { + case schema.NodeID: + if a.RuleType == schema.UserPolicy && isSrc { return errors.New("user policy source mismatch") } _, nodeErr := logic.GetNodeByID(t.Value) @@ -515,7 +515,7 @@ func checkIfAclTagisValid(a models.Acl, t models.AclPolicyTag, isSrc bool) (err return errors.New("invalid node " + t.Value) } } - case models.EgressID, models.EgressRange: + case schema.EgressID, schema.EgressRange: e := schema.Egress{ ID: t.Value, } @@ -523,14 +523,14 @@ func checkIfAclTagisValid(a models.Acl, t models.AclPolicyTag, isSrc bool) (err if err != nil { return errors.New("invalid egress") } - case models.NetmakerIPAclID: + case schema.NetmakerIPAclID: _, err := logic.NormalizeIPOrCIDR(t.Value) if err != nil { return err } - case models.UserAclID: - if a.RuleType == models.DevicePolicy { + case schema.UserAclID: + if a.RuleType == schema.DevicePolicy { return errors.New("device policy source mismatch") } if !isSrc { @@ -541,8 +541,8 @@ func checkIfAclTagisValid(a models.Acl, t models.AclPolicyTag, isSrc bool) (err if err != nil { return errors.New("invalid user " + t.Value) } - case models.UserGroupAclID: - if a.RuleType == models.DevicePolicy { + case schema.UserGroupAclID: + if a.RuleType == schema.DevicePolicy { return errors.New("device policy source mismatch") } if !isSrc { @@ -564,14 +564,14 @@ func checkIfAclTagisValid(a models.Acl, t models.AclPolicyTag, isSrc bool) (err } // IsAclPolicyValid - validates if acl policy is valid -func IsAclPolicyValid(acl models.Acl) (err error) { +func IsAclPolicyValid(acl schema.Acl) (err error) { //check if src and dst are valid - if acl.AllowedDirection != models.TrafficDirectionBi && - acl.AllowedDirection != models.TrafficDirectionUni { + if acl.AllowedDirection != schema.TrafficDirectionBi && + acl.AllowedDirection != schema.TrafficDirectionUni { return errors.New("invalid traffic direction") } switch acl.RuleType { - case models.UserPolicy: + case schema.UserPolicy: // src list should only contain users for _, srcI := range acl.Src { @@ -594,7 +594,7 @@ func IsAclPolicyValid(acl models.Acl) (err error) { return } } - case models.DevicePolicy: + case schema.DevicePolicy: for _, srcI := range acl.Src { if srcI.Value == "*" { continue @@ -622,9 +622,9 @@ func IsAclPolicyValid(acl models.Acl) (err error) { } // listPoliciesOfUser - lists all user acl policies applied to user in an network -func listPoliciesOfUser(user *schema.User, netID schema.NetworkID) []models.Acl { +func listPoliciesOfUser(user *schema.User, netID schema.NetworkID) []schema.Acl { allAcls := logic.ListAcls() - var userAcls []models.Acl + var userAcls []schema.Acl if _, ok := user.UserGroups.Data()[globalNetworksAdminGroupID]; ok { user.UserGroups.Data()[GetDefaultNetworkAdminGroupID(netID)] = struct{}{} } @@ -632,7 +632,7 @@ func listPoliciesOfUser(user *schema.User, netID schema.NetworkID) []models.Acl user.UserGroups.Data()[GetDefaultNetworkUserGroupID(netID)] = struct{}{} } for _, acl := range allAcls { - if acl.NetworkID == netID && acl.RuleType == models.UserPolicy { + if acl.NetworkID == netID && acl.RuleType == schema.UserPolicy { srcMap := logic.ConvAclTagToValueMap(acl.Src) if _, ok := srcMap[user.Username]; ok { userAcls = append(userAcls, acl) @@ -651,20 +651,20 @@ func listPoliciesOfUser(user *schema.User, netID schema.NetworkID) []models.Acl } // listUserPolicies - lists all user policies in a network -func listUserPolicies(netID schema.NetworkID) []models.Acl { +func listUserPolicies(netID schema.NetworkID) []schema.Acl { allAcls := logic.ListAcls() - deviceAcls := []models.Acl{} + deviceAcls := []schema.Acl{} for _, acl := range allAcls { - if acl.NetworkID == netID && acl.RuleType == models.UserPolicy { + if acl.NetworkID == netID && acl.RuleType == schema.UserPolicy { deviceAcls = append(deviceAcls, acl) } } return deviceAcls } -func getSelectedUserEgressIPNets(dstTags []models.AclPolicyTag) (dst4, dst6 []net.IPNet) { +func getSelectedUserEgressIPNets(dstTags []schema.AclPolicyTag) (dst4, dst6 []net.IPNet) { for _, dst := range dstTags { - if dst.ID != models.NetmakerIPAclID { + if dst.ID != schema.NetmakerIPAclID { continue } normalized, err := logic.NormalizeIPOrCIDR(dst.Value) @@ -685,16 +685,16 @@ func getSelectedUserEgressIPNets(dstTags []models.AclPolicyTag) (dst4, dst6 []ne } // IsUserAllowedToCommunicate - check if user is allowed to communicate with peer -func IsUserAllowedToCommunicate(userName string, peer models.Node) (bool, []models.Acl) { +func IsUserAllowedToCommunicate(userName string, peer models.Node) (bool, []schema.Acl) { var peerId string if peer.IsStatic { peerId = peer.StaticNode.ClientID - peer = peer.StaticNode.ConvertToStaticNode() + peer = models.ConvertToStaticNode(&peer.StaticNode) } else { peerId = peer.ID.String() } - var peerTags map[models.TagID]struct{} + var peerTags map[schema.TagID]struct{} if peer.Mutex != nil { peer.Mutex.Lock() peerTags = maps.Clone(peer.Tags) @@ -703,20 +703,20 @@ func IsUserAllowedToCommunicate(userName string, peer models.Node) (bool, []mode peerTags = peer.Tags } if peerTags == nil { - peerTags = make(map[models.TagID]struct{}) + peerTags = make(map[schema.TagID]struct{}) } - peerTags[models.TagID(peerId)] = struct{}{} - peerTags[models.TagID("*")] = struct{}{} - acl, _ := logic.GetDefaultPolicy(schema.NetworkID(peer.Network), models.UserPolicy) + peerTags[schema.TagID(peerId)] = struct{}{} + peerTags[schema.TagID("*")] = struct{}{} + acl, _ := logic.GetDefaultPolicy(schema.NetworkID(peer.Network), schema.UserPolicy) if acl.Enabled { - return true, []models.Acl{acl} + return true, []schema.Acl{acl} } user := &schema.User{Username: userName} err := user.Get(db.WithContext(context.TODO())) if err != nil { - return false, []models.Acl{} + return false, []schema.Acl{} } - allowedPolicies := []models.Acl{} + allowedPolicies := []schema.Acl{} policies := listPoliciesOfUser(user, schema.NetworkID(peer.Network)) for _, policy := range policies { if !policy.Enabled { @@ -724,7 +724,7 @@ func IsUserAllowedToCommunicate(userName string, peer models.Node) (bool, []mode } dstMap := logic.ConvAclTagToValueMap(policy.Dst) for _, dst := range policy.Dst { - if dst.ID == models.EgressID { + if dst.ID == schema.EgressID { e := schema.Egress{ID: dst.Value} err := e.Get(db.WithContext(context.TODO())) if err == nil && e.Status { @@ -753,7 +753,7 @@ func IsUserAllowedToCommunicate(userName string, peer models.Node) (bool, []mode if len(allowedPolicies) > 0 { return true, allowedPolicies } - return false, []models.Acl{} + return false, []schema.Acl{} } // IsPeerAllowed - checks if peer needs to be added to the interface @@ -773,18 +773,18 @@ func IsPeerAllowed(node, peer models.Node, checkDefaultPolicy bool) bool { // } if node.IsStatic { nodeId = node.StaticNode.ClientID - node = node.StaticNode.ConvertToStaticNode() + node = models.ConvertToStaticNode(&node.StaticNode) } else { nodeId = node.ID.String() } if peer.IsStatic { peerId = peer.StaticNode.ClientID - peer = peer.StaticNode.ConvertToStaticNode() + peer = models.ConvertToStaticNode(&peer.StaticNode) } else { peerId = peer.ID.String() } - var nodeTags, peerTags map[models.TagID]struct{} + var nodeTags, peerTags map[schema.TagID]struct{} if node.Mutex != nil { node.Mutex.Lock() nodeTags = maps.Clone(node.Tags) @@ -800,16 +800,16 @@ func IsPeerAllowed(node, peer models.Node, checkDefaultPolicy bool) bool { peerTags = peer.Tags } if nodeTags == nil { - nodeTags = make(map[models.TagID]struct{}) + nodeTags = make(map[schema.TagID]struct{}) } if peerTags == nil { - peerTags = make(map[models.TagID]struct{}) + peerTags = make(map[schema.TagID]struct{}) } - nodeTags[models.TagID(nodeId)] = struct{}{} - peerTags[models.TagID(peerId)] = struct{}{} + nodeTags[schema.TagID(nodeId)] = struct{}{} + peerTags[schema.TagID(peerId)] = struct{}{} if checkDefaultPolicy { // check default policy if all allowed return true - defaultPolicy, err := logic.GetDefaultPolicy(schema.NetworkID(node.Network), models.DevicePolicy) + defaultPolicy, err := logic.GetDefaultPolicy(schema.NetworkID(node.Network), schema.DevicePolicy) if err == nil { if defaultPolicy.Enabled { return true @@ -836,7 +836,7 @@ func IsPeerAllowed(node, peer models.Node, checkDefaultPolicy bool) bool { srcMap = logic.ConvAclTagToValueMap(policy.Src) dstMap = logic.ConvAclTagToValueMap(policy.Dst) for _, dst := range policy.Dst { - if dst.ID == models.EgressID { + if dst.ID == schema.EgressID { e := schema.Egress{ID: dst.Value} err := e.Get(db.WithContext(context.TODO())) if err == nil && e.Status { @@ -859,9 +859,9 @@ func RemoveUserFromAclPolicy(userName string) { for _, acl := range acls { delete := false update := false - if acl.RuleType == models.UserPolicy { + if acl.RuleType == schema.UserPolicy { for i := len(acl.Src) - 1; i >= 0; i-- { - if acl.Src[i].ID == models.UserAclID && acl.Src[i].Value == userName { + if acl.Src[i].ID == schema.UserAclID && acl.Src[i].Value == userName { if len(acl.Src) == 1 { // delete policy delete = true @@ -884,12 +884,12 @@ func RemoveUserFromAclPolicy(userName string) { } // UpdateDeviceTag - updates device tag on acl policies -func UpdateDeviceTag(OldID, newID models.TagID, netID schema.NetworkID) { +func UpdateDeviceTag(OldID, newID schema.TagID, netID schema.NetworkID) { acls := logic.ListDevicePolicies(netID) update := false for _, acl := range acls { for i, srcTagI := range acl.Src { - if srcTagI.ID == models.NodeTagID { + if srcTagI.ID == schema.NodeTagID { if OldID.String() == srcTagI.Value { acl.Src[i].Value = newID.String() update = true @@ -897,7 +897,7 @@ func UpdateDeviceTag(OldID, newID models.TagID, netID schema.NetworkID) { } } for i, dstTagI := range acl.Dst { - if dstTagI.ID == models.NodeTagID { + if dstTagI.ID == schema.NodeTagID { if OldID.String() == dstTagI.Value { acl.Dst[i].Value = newID.String() update = true @@ -910,18 +910,18 @@ func UpdateDeviceTag(OldID, newID models.TagID, netID schema.NetworkID) { } } -func CheckIfTagAsActivePolicy(tagID models.TagID, netID schema.NetworkID) bool { +func CheckIfTagAsActivePolicy(tagID schema.TagID, netID schema.NetworkID) bool { acls := logic.ListDevicePolicies(netID) for _, acl := range acls { for _, srcTagI := range acl.Src { - if srcTagI.ID == models.NodeTagID { + if srcTagI.ID == schema.NodeTagID { if tagID.String() == srcTagI.Value { return true } } } for _, dstTagI := range acl.Dst { - if dstTagI.ID == models.NodeTagID { + if dstTagI.ID == schema.NodeTagID { if tagID.String() == dstTagI.Value { return true } @@ -932,12 +932,12 @@ func CheckIfTagAsActivePolicy(tagID models.TagID, netID schema.NetworkID) bool { } // RemoveDeviceTagFromAclPolicies - remove device tag from acl policies -func RemoveDeviceTagFromAclPolicies(tagID models.TagID, netID schema.NetworkID) error { +func RemoveDeviceTagFromAclPolicies(tagID schema.TagID, netID schema.NetworkID) error { acls := logic.ListDevicePolicies(netID) update := false for _, acl := range acls { for i := len(acl.Src) - 1; i >= 0; i-- { - if acl.Src[i].ID == models.NodeTagID { + if acl.Src[i].ID == schema.NodeTagID { if tagID.String() == acl.Src[i].Value { acl.Src = append(acl.Src[:i], acl.Src[i+1:]...) update = true @@ -945,7 +945,7 @@ func RemoveDeviceTagFromAclPolicies(tagID models.TagID, netID schema.NetworkID) } } for i := len(acl.Dst) - 1; i >= 0; i-- { - if acl.Dst[i].ID == models.NodeTagID { + if acl.Dst[i].ID == schema.NodeTagID { if tagID.String() == acl.Dst[i].Value { acl.Dst = append(acl.Dst[:i], acl.Dst[i+1:]...) update = true @@ -963,15 +963,15 @@ func GetEgressUserRulesForNode(targetnode *models.Node, rules map[string]models.AclRule) map[string]models.AclRule { userNodes := getStaticUserNodesByNetwork(schema.NetworkID(targetnode.Network)) userGrpMap := GetUserGrpMap() - allowedUsers := make(map[string][]models.Acl) + allowedUsers := make(map[string][]schema.Acl) acls := listUserPolicies(schema.NetworkID(targetnode.Network)) - var targetNodeTags = make(map[models.TagID]struct{}) + var targetNodeTags = make(map[schema.TagID]struct{}) targetNodeTags["*"] = struct{}{} egs, _ := (&schema.Egress{Network: targetnode.Network}).ListByNetwork(db.WithContext(context.TODO())) if len(egs) == 0 { return rules } - defaultPolicy, _ := logic.GetDefaultPolicy(schema.NetworkID(targetnode.Network), models.UserPolicy) + defaultPolicy, _ := logic.GetDefaultPolicy(schema.NetworkID(targetnode.Network), schema.UserPolicy) for _, egI := range egs { if !egI.Status { @@ -979,14 +979,14 @@ func GetEgressUserRulesForNode(targetnode *models.Node, } if _, ok := egI.Nodes[targetnode.ID.String()]; ok { if egI.Range != "" { - targetNodeTags[models.TagID(egI.Range)] = struct{}{} + targetNodeTags[schema.TagID(egI.Range)] = struct{}{} } else if logic.HasEgressDomainAns(egI) { for _, domainAnsI := range logic.AllDomainAnsFromEgress(egI) { - targetNodeTags[models.TagID(domainAnsI)] = struct{}{} + targetNodeTags[schema.TagID(domainAnsI)] = struct{}{} } } - targetNodeTags[models.TagID(egI.ID)] = struct{}{} + targetNodeTags[schema.TagID(egI.ID)] = struct{}{} } } if !defaultPolicy.Enabled { @@ -996,7 +996,7 @@ func GetEgressUserRulesForNode(targetnode *models.Node, } dstTags := logic.ConvAclTagToValueMap(acl.Dst) for _, dst := range acl.Dst { - if dst.ID == models.EgressID { + if dst.ID == schema.EgressID { e := schema.Egress{ID: dst.Value} err := e.Get(db.WithContext(context.TODO())) if err == nil && e.Status { @@ -1030,9 +1030,9 @@ func GetEgressUserRulesForNode(targetnode *models.Node, if addUsers { // get all src tags for _, srcAcl := range acl.Src { - if srcAcl.ID == models.UserAclID { + if srcAcl.ID == schema.UserAclID { allowedUsers[srcAcl.Value] = append(allowedUsers[srcAcl.Value], acl) - } else if srcAcl.ID == models.UserGroupAclID { + } else if srcAcl.ID == schema.UserGroupAclID { // fetch all users in the group if usersMap, ok := userGrpMap[schema.UserGroupID(srcAcl.Value)]; ok { for userName := range usersMap { @@ -1099,7 +1099,7 @@ func GetEgressUserRulesForNode(targetnode *models.Node, r.IP6List = append(r.IP6List, userNode.StaticNode.AddressIPNet6()) } for _, dstI := range acl.Dst { - if dstI.ID == models.EgressID { + if dstI.ID == schema.EgressID { e := schema.Egress{ID: dstI.Value} err := e.Get(db.WithContext(context.TODO())) if err != nil { @@ -1190,7 +1190,7 @@ func GetEgressUserRulesForNode(targetnode *models.Node, // rules keyed by acl.ID, and a "-reverse" companion is added for Bi policies. func appendUserExtClientRemoteEgressFwdRules( targetnode *models.Node, - acls []models.Acl, + acls []schema.Acl, egs []schema.Egress, userNodes []models.Node, userGrpMap map[schema.UserGroupID]map[string]struct{}, @@ -1263,9 +1263,9 @@ func appendUserExtClientRemoteEgressFwdRules( allowedOwners := make(map[string]struct{}) for _, srcAcl := range acl.Src { - if srcAcl.ID == models.UserAclID { + if srcAcl.ID == schema.UserAclID { allowedOwners[srcAcl.Value] = struct{}{} - } else if srcAcl.ID == models.UserGroupAclID { + } else if srcAcl.ID == schema.UserGroupAclID { if usersMap, ok := userGrpMap[schema.UserGroupID(srcAcl.Value)]; ok { for userName := range usersMap { allowedOwners[userName] = struct{}{} @@ -1313,7 +1313,7 @@ func appendUserExtClientRemoteEgressFwdRules( } rules[ruleID] = aclRule - if acl.AllowedDirection == models.TrafficDirectionBi && + if acl.AllowedDirection == schema.TrafficDirectionBi && (len(aclRule.Dst) > 0 || len(aclRule.Dst6) > 0) { revID := ruleID + "-reverse" rules[revID] = models.AclRule{ @@ -1335,9 +1335,9 @@ func GetUserAclRulesForNode(targetnode *models.Node, rules map[string]models.AclRule) map[string]models.AclRule { userNodes := getStaticUserNodesByNetwork(schema.NetworkID(targetnode.Network)) userGrpMap := GetUserGrpMap() - allowedUsers := make(map[string][]models.Acl) + allowedUsers := make(map[string][]schema.Acl) acls := listUserPolicies(schema.NetworkID(targetnode.Network)) - var targetNodeTags = make(map[models.TagID]struct{}) + var targetNodeTags = make(map[schema.TagID]struct{}) if targetnode.Mutex != nil { targetnode.Mutex.Lock() targetNodeTags = maps.Clone(targetnode.Tags) @@ -1346,10 +1346,10 @@ func GetUserAclRulesForNode(targetnode *models.Node, targetNodeTags = maps.Clone(targetnode.Tags) } if targetNodeTags == nil { - targetNodeTags = make(map[models.TagID]struct{}) + targetNodeTags = make(map[schema.TagID]struct{}) } - defaultPolicy, _ := logic.GetDefaultPolicy(schema.NetworkID(targetnode.Network), models.UserPolicy) - targetNodeTags[models.TagID(targetnode.ID.String())] = struct{}{} + defaultPolicy, _ := logic.GetDefaultPolicy(schema.NetworkID(targetnode.Network), schema.UserPolicy) + targetNodeTags[schema.TagID(targetnode.ID.String())] = struct{}{} if !defaultPolicy.Enabled { for _, acl := range acls { if !acl.Enabled { @@ -1360,7 +1360,7 @@ func GetUserAclRulesForNode(targetnode *models.Node, addUsers := false if !all { for _, dst := range acl.Dst { - if dst.ID == models.EgressID { + if dst.ID == schema.EgressID { e := schema.Egress{ID: dst.Value} err := e.Get(db.WithContext(context.TODO())) if err == nil && e.Status && len(e.Nodes) > 0 { @@ -1383,9 +1383,9 @@ func GetUserAclRulesForNode(targetnode *models.Node, if addUsers { // get all src tags for _, srcAcl := range acl.Src { - if srcAcl.ID == models.UserAclID { + if srcAcl.ID == schema.UserAclID { allowedUsers[srcAcl.Value] = append(allowedUsers[srcAcl.Value], acl) - } else if srcAcl.ID == models.UserGroupAclID { + } else if srcAcl.ID == schema.UserGroupAclID { // fetch all users in the group if usersMap, ok := userGrpMap[schema.UserGroupID(srcAcl.Value)]; ok { for userName := range usersMap { @@ -1489,7 +1489,7 @@ func GetUserAclRulesForNode(targetnode *models.Node, } break } - if dst.ID == models.EgressID { + if dst.ID == schema.EgressID { e := schema.Egress{ID: dst.Value} err := e.Get(db.WithContext(context.TODO())) if err == nil && e.Status && len(e.Nodes) > 0 { @@ -1589,8 +1589,8 @@ func GetUserAclRulesForNode(targetnode *models.Node, return rules } -func CheckIfAnyPolicyisUniDirectional(targetNode models.Node, acls []models.Acl) bool { - var targetNodeTags = make(map[models.TagID]struct{}) +func CheckIfAnyPolicyisUniDirectional(targetNode models.Node, acls []schema.Acl) bool { + var targetNodeTags = make(map[schema.TagID]struct{}) if targetNode.Mutex != nil { targetNode.Mutex.Lock() targetNodeTags = maps.Clone(targetNode.Tags) @@ -1599,24 +1599,24 @@ func CheckIfAnyPolicyisUniDirectional(targetNode models.Node, acls []models.Acl) targetNodeTags = maps.Clone(targetNode.Tags) } if targetNodeTags == nil { - targetNodeTags = make(map[models.TagID]struct{}) + targetNodeTags = make(map[schema.TagID]struct{}) } - targetNodeTags[models.TagID(targetNode.ID.String())] = struct{}{} + targetNodeTags[schema.TagID(targetNode.ID.String())] = struct{}{} targetNodeTags["*"] = struct{}{} for _, acl := range acls { if !acl.Enabled { continue } - if acl.AllowedDirection == models.TrafficDirectionBi && acl.Proto == models.ALL && acl.ServiceType == models.Any { + if acl.AllowedDirection == schema.TrafficDirectionBi && acl.Proto == schema.ALL && acl.ServiceType == models.Any { continue } - if acl.Proto != models.ALL || acl.ServiceType != models.Any { + if acl.Proto != schema.ALL || acl.ServiceType != models.Any { return true } srcTags := logic.ConvAclTagToValueMap(acl.Src) dstTags := logic.ConvAclTagToValueMap(acl.Dst) for nodeTag := range targetNodeTags { - if acl.RuleType == models.DevicePolicy { + if acl.RuleType == schema.DevicePolicy { if _, ok := srcTags[nodeTag.String()]; ok { return true } @@ -1636,11 +1636,11 @@ func CheckIfAnyPolicyisUniDirectional(targetNode models.Node, acls []models.Acl) return false } -func GetTagMapWithNodesByNetwork(netID schema.NetworkID, withStaticNodes bool) (tagNodesMap map[models.TagID][]models.Node) { - tagNodesMap = make(map[models.TagID][]models.Node) +func GetTagMapWithNodesByNetwork(netID schema.NetworkID, withStaticNodes bool) (tagNodesMap map[schema.TagID][]models.Node) { + tagNodesMap = make(map[schema.TagID][]models.Node) nodes, _ := logic.GetNetworkNodes(netID.String()) for _, nodeI := range nodes { - tagNodesMap[models.TagID(nodeI.ID.String())] = []models.Node{ + tagNodesMap[schema.TagID(nodeI.ID.String())] = []models.Node{ nodeI, } if nodeI.Tags == nil { @@ -1650,7 +1650,7 @@ func GetTagMapWithNodesByNetwork(netID schema.NetworkID, withStaticNodes bool) ( nodeI.Mutex.Lock() } for nodeTagID := range nodeI.Tags { - if nodeTagID == models.TagID(nodeI.ID.String()) { + if nodeTagID == schema.TagID(nodeI.ID.String()) { continue } tagNodesMap[nodeTagID] = append(tagNodesMap[nodeTagID], nodeI) @@ -1667,7 +1667,7 @@ func GetTagMapWithNodesByNetwork(netID schema.NetworkID, withStaticNodes bool) ( } func AddTagMapWithStaticNodes(netID schema.NetworkID, - tagNodesMap map[models.TagID][]models.Node) map[models.TagID][]models.Node { + tagNodesMap map[schema.TagID][]models.Node) map[schema.TagID][]models.Node { extclients, err := logic.GetNetworkExtClients(netID.String()) if err != nil { return tagNodesMap @@ -1676,7 +1676,7 @@ func AddTagMapWithStaticNodes(netID schema.NetworkID, if extclient.RemoteAccessClientID != "" { continue } - tagNodesMap[models.TagID(extclient.ClientID)] = []models.Node{ + tagNodesMap[schema.TagID(extclient.ClientID)] = []models.Node{ { IsStatic: true, StaticNode: extclient, @@ -1690,11 +1690,11 @@ func AddTagMapWithStaticNodes(netID schema.NetworkID, extclient.Mutex.Lock() } for tagID := range extclient.Tags { - if tagID == models.TagID(extclient.ClientID) { + if tagID == schema.TagID(extclient.ClientID) { continue } - tagNodesMap[tagID] = append(tagNodesMap[tagID], extclient.ConvertToStaticNode()) - tagNodesMap["*"] = append(tagNodesMap["*"], extclient.ConvertToStaticNode()) + tagNodesMap[tagID] = append(tagNodesMap[tagID], models.ConvertToStaticNode(&extclient)) + tagNodesMap["*"] = append(tagNodesMap["*"], models.ConvertToStaticNode(&extclient)) } if extclient.Mutex != nil { extclient.Mutex.Unlock() diff --git a/pro/logic/dns.go b/pro/logic/dns.go index 16b4e69ea..cf1f51310 100644 --- a/pro/logic/dns.go +++ b/pro/logic/dns.go @@ -25,7 +25,7 @@ func ValidateNameserverReq(ns *schema.Nameserver) error { if tagI == "*" { continue } - _, err := GetTag(models.TagID(tagI)) + _, err := GetTag(schema.TagID(tagI)) if err != nil { return errors.New("invalid tag") } @@ -243,7 +243,7 @@ func GetNameserversForHost(h *schema.Host) (returnNsLi []models.Nameserver) { return } -func RemoveTagFromNameservers(tagID models.TagID, netID schema.NetworkID) error { +func RemoveTagFromNameservers(tagID schema.TagID, netID schema.NetworkID) error { nameservers, err := (&schema.Nameserver{ NetworkID: netID.String(), }).ListByNetwork(db.WithContext(context.TODO())) diff --git a/pro/logic/egress.go b/pro/logic/egress.go index d3f0e9c54..c3990d46c 100644 --- a/pro/logic/egress.go +++ b/pro/logic/egress.go @@ -13,7 +13,6 @@ import ( "github.com/gravitl/netmaker/db" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" - "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/schema" "github.com/gravitl/netmaker/servercfg" "gorm.io/datatypes" @@ -60,7 +59,7 @@ func ValidateEgressReq(e *schema.Egress) error { if len(e.Tags) > 0 { e.Nodes = make(datatypes.JSONMap) for tagID := range e.Tags { - _, err := GetTag(models.TagID(tagID)) + _, err := GetTag(schema.TagID(tagID)) if err != nil { return errors.New("invalid tag " + tagID) } @@ -69,7 +68,7 @@ func ValidateEgressReq(e *schema.Egress) error { return nil } -func RemoveTagFromEgress(net schema.NetworkID, tagID models.TagID) { +func RemoveTagFromEgress(net schema.NetworkID, tagID schema.TagID) { eli, _ := (&schema.Egress{Network: net.String()}).ListByNetwork(db.WithContext(context.TODO())) for _, eI := range eli { if _, ok := eI.Tags[tagID.String()]; ok { diff --git a/pro/logic/metrics.go b/pro/logic/metrics.go index 9cd9e28aa..99075321f 100644 --- a/pro/logic/metrics.go +++ b/pro/logic/metrics.go @@ -18,21 +18,22 @@ import ( "github.com/gravitl/netmaker/schema" "github.com/gravitl/netmaker/servercfg" "golang.org/x/exp/slog" + "gorm.io/datatypes" ) var ( metricsCacheMutex = &sync.RWMutex{} - metricsCacheMap = make(map[string]models.Metrics) + metricsCacheMap = make(map[string]schema.Metrics) ) -func getMetricsFromCache(key string) (metrics models.Metrics, ok bool) { +func getMetricsFromCache(key string) (metrics schema.Metrics, ok bool) { metricsCacheMutex.RLock() metrics, ok = metricsCacheMap[key] metricsCacheMutex.RUnlock() return } -func storeMetricsInCache(key string, metrics models.Metrics) { +func storeMetricsInCache(key string, metrics schema.Metrics) { metricsCacheMutex.Lock() metricsCacheMap[key] = metrics metricsCacheMutex.Unlock() @@ -40,7 +41,7 @@ func storeMetricsInCache(key string, metrics models.Metrics) { // metricsReadCopy returns a shallow copy of m with Connectivity deep-cloned so callers never // share the cached map with MQTT / UpdateMetrics / metrics monitor goroutines that mutate it. -func metricsReadCopy(m *models.Metrics) *models.Metrics { +func metricsReadCopy(m *schema.Metrics) *schema.Metrics { if m == nil { return nil } @@ -60,22 +61,18 @@ func deleteMetricsFromCache(key string) { func LoadNodeMetricsToCache() error { slog.Info("loading metrics to cache") if metricsCacheMap == nil { - metricsCacheMap = map[string]models.Metrics{} + metricsCacheMap = map[string]schema.Metrics{} } - collection, err := database.FetchRecords(database.METRICS_TABLE_NAME) + entries, err := (&schema.MetricsEntry{}).ListAll(db.WithContext(context.TODO())) if err != nil { return err } - for key, value := range collection { - var metrics models.Metrics - if err := json.Unmarshal([]byte(value), &metrics); err != nil { - slog.Error("parse metric record error", "error", err.Error()) - continue - } + for _, entry := range entries { + metrics := entry.Value.Data() if servercfg.CacheEnabled() { - storeMetricsInCache(key, metrics) + storeMetricsInCache(entry.Key, metrics) } } @@ -84,24 +81,21 @@ func LoadNodeMetricsToCache() error { } // GetMetrics - gets the metrics -func GetMetrics(nodeid string) (*models.Metrics, error) { - var metrics models.Metrics +func GetMetrics(nodeid string) (*schema.Metrics, error) { + var metrics schema.Metrics if servercfg.CacheEnabled() { if m, ok := getMetricsFromCache(nodeid); ok { return metricsReadCopy(&m), nil } } - record, err := database.FetchRecord(database.METRICS_TABLE_NAME, nodeid) - if err != nil { + entry := &schema.MetricsEntry{Key: nodeid} + if err := entry.Get(db.WithContext(context.TODO())); err != nil { if database.IsEmptyRecord(err) { return &metrics, nil } return &metrics, err } - err = json.Unmarshal([]byte(record), &metrics) - if err != nil { - return &metrics, err - } + metrics = entry.Value.Data() if servercfg.CacheEnabled() { storeMetricsInCache(nodeid, metrics) return metricsReadCopy(&metrics), nil @@ -110,13 +104,9 @@ func GetMetrics(nodeid string) (*models.Metrics, error) { } // UpdateMetrics - updates the metrics of a given client -func UpdateMetrics(nodeid string, metrics *models.Metrics) error { +func UpdateMetrics(nodeid string, metrics *schema.Metrics) error { metrics.UpdatedAt = time.Now() - data, err := json.Marshal(metrics) - if err != nil { - return err - } - err = database.Insert(nodeid, string(data), database.METRICS_TABLE_NAME) + err := (&schema.MetricsEntry{Key: nodeid, NetworkID: metrics.Network, Value: datatypes.NewJSONType(*metrics)}).Save(db.WithContext(context.TODO())) if err != nil { return err } @@ -128,7 +118,7 @@ func UpdateMetrics(nodeid string, metrics *models.Metrics) error { // DeleteMetrics - deletes metrics of a given node func DeleteMetrics(nodeid string) error { - err := database.DeleteRecord(database.METRICS_TABLE_NAME, nodeid) + err := (&schema.MetricsEntry{Key: nodeid}).Delete(db.WithContext(context.TODO())) if err != nil { return err } @@ -208,7 +198,7 @@ func SetPeerMetricsDisconnected(nodeID string) { } // MQUpdateMetricsFallBack - called when mq fallback thread is triggered on client -func MQUpdateMetricsFallBack(nodeid string, newMetrics models.Metrics) { +func MQUpdateMetricsFallBack(nodeid string, newMetrics schema.Metrics) { currentNode, err := logic.GetNodeByID(nodeid) if err != nil { @@ -243,7 +233,7 @@ func MQUpdateMetrics(client mqtt.Client, msg mqtt.Message) { return } - var newMetrics models.Metrics + var newMetrics schema.Metrics if err := json.Unmarshal(decrypted, &newMetrics); err != nil { slog.Error("error unmarshaling payload", "error", err) return @@ -256,14 +246,14 @@ func MQUpdateMetrics(client mqtt.Client, msg mqtt.Message) { slog.Debug("updated node metrics", "id", id) } -func updateNodeMetrics(currentNode *models.Node, newMetrics *models.Metrics) { +func updateNodeMetrics(currentNode *models.Node, newMetrics *schema.Metrics) { oldMetrics, err := logic.GetMetrics(currentNode.ID.String()) if err != nil { slog.Error("error finding old metrics for node", "id", currentNode.ID, "error", err) return } - var attachedClients []models.ExtClient + var attachedClients []schema.ExtClient if currentNode.IsIngressGateway { clients, err := logic.GetExtClientsByID(currentNode.ID.String(), currentNode.Network) if err == nil { @@ -271,7 +261,7 @@ func updateNodeMetrics(currentNode *models.Node, newMetrics *models.Metrics) { } } if newMetrics.Connectivity == nil { - newMetrics.Connectivity = make(map[string]models.Metric) + newMetrics.Connectivity = make(map[string]schema.Metric) } for i := range attachedClients { slog.Debug("[metrics] processing attached client", "client", attachedClients[i].ClientID, "public key", attachedClients[i].PublicKey) diff --git a/pro/logic/migrate.go b/pro/logic/migrate.go index 0841dd371..b9b0f771e 100644 --- a/pro/logic/migrate.go +++ b/pro/logic/migrate.go @@ -7,7 +7,6 @@ import ( "github.com/gravitl/netmaker/db" "github.com/gravitl/netmaker/logic" - "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/schema" ) @@ -16,15 +15,15 @@ func CleanupGwsMigration() { for _, acl := range acls { upsert := false for i, srcI := range acl.Src { - if srcI.ID == models.NodeTagID && srcI.Value == fmt.Sprintf("%s.%s", acl.NetworkID.String(), models.OldRemoteAccessTagName) { - srcI.Value = fmt.Sprintf("%s.%s", acl.NetworkID.String(), models.GwTagName) + if srcI.ID == schema.NodeTagID && srcI.Value == fmt.Sprintf("%s.%s", acl.NetworkID.String(), schema.OldRemoteAccessTagName) { + srcI.Value = fmt.Sprintf("%s.%s", acl.NetworkID.String(), schema.GwTagName) acl.Src[i] = srcI upsert = true } } for i, dstI := range acl.Dst { - if dstI.ID == models.NodeTagID && dstI.Value == fmt.Sprintf("%s.%s", acl.NetworkID.String(), models.OldRemoteAccessTagName) { - dstI.Value = fmt.Sprintf("%s.%s", acl.NetworkID.String(), models.GwTagName) + if dstI.ID == schema.NodeTagID && dstI.Value == fmt.Sprintf("%s.%s", acl.NetworkID.String(), schema.OldRemoteAccessTagName) { + dstI.Value = fmt.Sprintf("%s.%s", acl.NetworkID.String(), schema.GwTagName) acl.Dst[i] = dstI upsert = true } @@ -35,6 +34,6 @@ func CleanupGwsMigration() { } nets, _ := (&schema.Network{}).ListAll(db.WithContext(context.TODO())) for _, netI := range nets { - DeleteTag(models.TagID(fmt.Sprintf("%s.%s", netI.Name, models.OldRemoteAccessTagName)), true) + DeleteTag(schema.TagID(fmt.Sprintf("%s.%s", netI.Name, schema.OldRemoteAccessTagName)), true) } } diff --git a/pro/logic/nodes.go b/pro/logic/nodes.go index 9e71f259f..75f74746c 100644 --- a/pro/logic/nodes.go +++ b/pro/logic/nodes.go @@ -21,8 +21,8 @@ func GetNetworkIngresses(network string) ([]models.Node, error) { return ingresses, nil } -func GetTagMapWithNodes() (tagNodesMap map[models.TagID][]models.Node) { - tagNodesMap = make(map[models.TagID][]models.Node) +func GetTagMapWithNodes() (tagNodesMap map[schema.TagID][]models.Node) { + tagNodesMap = make(map[schema.TagID][]models.Node) nodes, _ := logic.GetAllNodes() for _, nodeI := range nodes { if nodeI.Tags == nil { @@ -43,13 +43,13 @@ func GetTagMapWithNodes() (tagNodesMap map[models.TagID][]models.Node) { } func AddTagMapWithStaticNodesWithUsers(netID schema.NetworkID, - tagNodesMap map[models.TagID][]models.Node) map[models.TagID][]models.Node { + tagNodesMap map[schema.TagID][]models.Node) map[schema.TagID][]models.Node { extclients, err := logic.GetNetworkExtClients(netID.String()) if err != nil { return tagNodesMap } for _, extclient := range extclients { - tagNodesMap[models.TagID(extclient.ClientID)] = []models.Node{ + tagNodesMap[schema.TagID(extclient.ClientID)] = []models.Node{ { IsStatic: true, StaticNode: extclient, @@ -62,7 +62,7 @@ func AddTagMapWithStaticNodesWithUsers(netID schema.NetworkID, extclient.Mutex.Lock() } for tagID := range extclient.Tags { - tagNodesMap[tagID] = append(tagNodesMap[tagID], extclient.ConvertToStaticNode()) + tagNodesMap[tagID] = append(tagNodesMap[tagID], models.ConvertToStaticNode(&extclient)) } if extclient.Mutex != nil { extclient.Mutex.Unlock() @@ -72,7 +72,7 @@ func AddTagMapWithStaticNodesWithUsers(netID schema.NetworkID, return tagNodesMap } -func GetNodeIDsWithTag(tagID models.TagID) (ids []string) { +func GetNodeIDsWithTag(tagID schema.TagID) (ids []string) { tag, err := GetTag(tagID) if err != nil { @@ -96,7 +96,7 @@ func GetNodeIDsWithTag(tagID models.TagID) (ids []string) { return } -func GetNodesWithTag(tagID models.TagID) map[string]models.Node { +func GetNodesWithTag(tagID schema.TagID) map[string]models.Node { nMap := make(map[string]models.Node) tag, err := GetTag(tagID) if err != nil { @@ -120,7 +120,7 @@ func GetNodesWithTag(tagID models.TagID) map[string]models.Node { return AddStaticNodesWithTag(tag, nMap) } -func AddStaticNodesWithTag(tag models.Tag, nMap map[string]models.Node) map[string]models.Node { +func AddStaticNodesWithTag(tag schema.Tag, nMap map[string]models.Node) map[string]models.Node { extclients, err := logic.GetNetworkExtClients(tag.Network.String()) if err != nil { return nMap @@ -133,7 +133,7 @@ func AddStaticNodesWithTag(tag models.Tag, nMap map[string]models.Node) map[stri extclient.Mutex.Lock() } if _, ok := extclient.Tags[tag.ID]; ok { - nMap[extclient.ClientID] = extclient.ConvertToStaticNode() + nMap[extclient.ClientID] = models.ConvertToStaticNode(&extclient) } if extclient.Mutex != nil { extclient.Mutex.Unlock() @@ -142,7 +142,7 @@ func AddStaticNodesWithTag(tag models.Tag, nMap map[string]models.Node) map[stri return nMap } -func GetStaticNodeWithTag(tagID models.TagID) map[string]models.Node { +func GetStaticNodeWithTag(tagID schema.TagID) map[string]models.Node { nMap := make(map[string]models.Node) tag, err := GetTag(tagID) if err != nil { @@ -153,7 +153,7 @@ func GetStaticNodeWithTag(tagID models.TagID) map[string]models.Node { return nMap } for _, extclient := range extclients { - nMap[extclient.ClientID] = extclient.ConvertToStaticNode() + nMap[extclient.ClientID] = models.ConvertToStaticNode(&extclient) } return nMap } diff --git a/pro/logic/posture_check.go b/pro/logic/posture_check.go index 1111f8356..8eee24527 100644 --- a/pro/logic/posture_check.go +++ b/pro/logic/posture_check.go @@ -33,7 +33,7 @@ func AddPostureCheckHook() { Interval: interval, } } -func RemoveTagFromPostureChecks(tagID models.TagID, netID schema.NetworkID) { +func RemoveTagFromPostureChecks(tagID schema.TagID, netID schema.NetworkID) { pcLi, err := (&schema.PostureCheck{NetworkID: netID}).ListByNetwork(db.WithContext(context.TODO())) if err != nil || len(pcLi) == 0 { return @@ -87,7 +87,7 @@ func RunPostureChecks() error { if nodeI.IsStatic && !nodeI.IsUserNode { continue } - var postureChecksViolations []models.Violation + var postureChecksViolations []schema.Violation var postureCheckVolationSeverityLevel schema.Severity if noChecks { postureCheckVolationSeverityLevel = schema.SeverityUnknown @@ -140,22 +140,22 @@ func RunPostureChecks() error { return nil } -func CheckPostureViolations(d models.PostureCheckDeviceInfo, network schema.NetworkID) ([]models.Violation, schema.Severity) { +func CheckPostureViolations(d models.PostureCheckDeviceInfo, network schema.NetworkID) ([]schema.Violation, schema.Severity) { if !GetFeatureFlags().EnablePostureChecks { - return []models.Violation{}, schema.SeverityUnknown + return []schema.Violation{}, schema.SeverityUnknown } pcLi, err := (&schema.PostureCheck{NetworkID: network}).ListByNetwork(db.WithContext(context.TODO())) if err != nil || len(pcLi) == 0 { - return []models.Violation{}, schema.SeverityUnknown + return []schema.Violation{}, schema.SeverityUnknown } violations, level := GetPostureCheckViolations(pcLi, d) return violations, level } -func GetPostureCheckViolations(checks []schema.PostureCheck, d models.PostureCheckDeviceInfo) ([]models.Violation, schema.Severity) { +func GetPostureCheckViolations(checks []schema.PostureCheck, d models.PostureCheckDeviceInfo) ([]schema.Violation, schema.Severity) { if !GetFeatureFlags().EnablePostureChecks { - return []models.Violation{}, schema.SeverityUnknown + return []schema.Violation{}, schema.SeverityUnknown } - var violations []models.Violation + var violations []schema.Violation highest := schema.SeverityUnknown // Group checks by attribute @@ -181,7 +181,7 @@ func GetPostureCheckViolations(checks []schema.PostureCheck, d models.PostureChe } exists := false for tagID := range c.Tags { - if _, ok := d.Tags[models.TagID(tagID)]; ok { + if _, ok := d.Tags[schema.TagID(tagID)]; ok { exists = true break } @@ -241,7 +241,7 @@ func GetPostureCheckViolations(checks []schema.PostureCheck, d models.PostureChe if sev > highest { highest = sev } - v := models.Violation{ + v := schema.Violation{ CheckID: denied.check.ID, Name: denied.check.Name, Attribute: string(denied.check.Attribute), @@ -255,7 +255,7 @@ func GetPostureCheckViolations(checks []schema.PostureCheck, d models.PostureChe if sev > highest { highest = sev } - v := models.Violation{ + v := schema.Violation{ CheckID: denied.check.ID, Name: denied.check.Name, Attribute: string(denied.check.Attribute), @@ -303,7 +303,7 @@ func GetPostureCheckViolations(checks []schema.PostureCheck, d models.PostureChe highest = sev } - v := models.Violation{ + v := schema.Violation{ CheckID: denied.check.ID, Name: denied.check.Name, Attribute: string(denied.check.Attribute), @@ -348,7 +348,7 @@ func GetPostureCheckDeviceInfoByNode(node *models.Node) models.PostureCheckDevic OSVersion: node.StaticNode.OSVersion, OSFamily: node.StaticNode.OSFamily, KernelVersion: node.StaticNode.KernelVersion, - Tags: make(map[models.TagID]struct{}), + Tags: make(map[schema.TagID]struct{}), IsUser: true, UserGroups: make(map[schema.UserGroupID]struct{}), } @@ -587,7 +587,7 @@ func ValidatePostureCheck(pc *schema.PostureCheck) error { if tagID == "*" { continue } - _, err := GetTag(models.TagID(tagID)) + _, err := GetTag(schema.TagID(tagID)) if err != nil { return errors.New("unknown tag") } diff --git a/pro/logic/status.go b/pro/logic/status.go index 26127a93a..98dcf5b33 100644 --- a/pro/logic/status.go +++ b/pro/logic/status.go @@ -165,7 +165,7 @@ func GetNodeStatus(node *models.Node, defaultEnabledPolicy bool) { // This collapses the per-peer GetNodeByID storm that previously dominated // status computation: with P peers the old path issued O(P^2) preloaded // First() queries; this path issues exactly one IN-query. -func buildPeerCache(node *models.Node, metrics *models.Metrics) map[string]models.Node { +func buildPeerCache(node *models.Node, metrics *schema.Metrics) map[string]models.Node { if metrics == nil || len(metrics.Connectivity) == 0 { return map[string]models.Node{} } @@ -250,7 +250,7 @@ func CheckPeerStatus(node *models.Node, defaultAclPolicy bool, peers map[string] node.Status = schema.WarningSt } -func checkPeerConnectivity(node *models.Node, metrics *models.Metrics, defaultAclPolicy bool, peers map[string]models.Node) { +func checkPeerConnectivity(node *models.Node, metrics *schema.Metrics, defaultAclPolicy bool, peers map[string]models.Node) { peerNotConnectedCnt := 0 for peerID, metric := range metrics.Connectivity { peer, ok := peers[peerID] diff --git a/pro/logic/tags.go b/pro/logic/tags.go index 01ab56bfd..cd55c1ceb 100644 --- a/pro/logic/tags.go +++ b/pro/logic/tags.go @@ -2,7 +2,6 @@ package logic import ( "context" - "encoding/json" "errors" "fmt" "regexp" @@ -17,49 +16,37 @@ import ( "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/schema" "golang.org/x/exp/slog" + "gorm.io/datatypes" ) var tagMutex = &sync.RWMutex{} // GetTag - fetches tag info -func GetTag(tagID models.TagID) (models.Tag, error) { - data, err := database.FetchRecord(database.TAG_TABLE_NAME, tagID.String()) - if err != nil { - return models.Tag{}, err - } - tag := models.Tag{} - err = json.Unmarshal([]byte(data), &tag) - if err != nil { - return tag, err +func GetTag(tagID schema.TagID) (schema.Tag, error) { + entry := &schema.TagEntry{Key: tagID.String()} + if err := entry.Get(db.WithContext(context.TODO())); err != nil { + return schema.Tag{}, err } - return tag, nil + return entry.Value.Data(), nil } -func UpsertTag(tag models.Tag) error { - d, err := json.Marshal(tag) - if err != nil { - return err - } - return database.Insert(tag.ID.String(), string(d), database.TAG_TABLE_NAME) +func UpsertTag(tag schema.Tag) error { + return (&schema.TagEntry{Key: tag.ID.String(), NetworkID: string(tag.Network), Value: datatypes.NewJSONType(tag)}).Save(db.WithContext(context.TODO())) } // InsertTag - creates new tag -func InsertTag(tag models.Tag) error { +func InsertTag(tag schema.Tag) error { tagMutex.Lock() defer tagMutex.Unlock() - _, err := database.FetchRecord(database.TAG_TABLE_NAME, tag.ID.String()) - if err == nil { + existing := &schema.TagEntry{Key: tag.ID.String()} + if err := existing.Get(db.WithContext(context.TODO())); err == nil { return fmt.Errorf("tag `%s` exists already", tag.ID) } - d, err := json.Marshal(tag) - if err != nil { - return err - } - return database.Insert(tag.ID.String(), string(d), database.TAG_TABLE_NAME) + return (&schema.TagEntry{Key: tag.ID.String(), NetworkID: string(tag.Network), Value: datatypes.NewJSONType(tag)}).Save(db.WithContext(context.TODO())) } // DeleteTag - delete tag, will also untag hosts -func DeleteTag(tagID models.TagID, removeFromPolicy bool) error { +func DeleteTag(tagID schema.TagID, removeFromPolicy bool) error { tagMutex.Lock() defer tagMutex.Unlock() // cleanUp tags on hosts @@ -93,7 +80,7 @@ func DeleteTag(tagID models.TagID, removeFromPolicy bool) error { logic.SaveExtClient(&extclient) } } - return database.DeleteRecord(database.TAG_TABLE_NAME, tagID.String()) + return (&schema.TagEntry{Key: tagID.String()}).Delete(db.WithContext(context.TODO())) } // ListTagsWithHosts - lists all tags with tagged hosts @@ -122,30 +109,25 @@ func DeleteAllNetworkTags(networkID schema.NetworkID) { } // ListNetworkTags - lists all tags in network -func ListNetworkTags(netID schema.NetworkID) ([]models.Tag, error) { +func ListNetworkTags(netID schema.NetworkID) ([]schema.Tag, error) { tagMutex.RLock() defer tagMutex.RUnlock() - data, err := database.FetchRecords(database.TAG_TABLE_NAME) + entries, err := (&schema.TagEntry{}).ListAll(db.WithContext(context.TODO())) if err != nil && !database.IsEmptyRecord(err) { - return []models.Tag{}, err + return []schema.Tag{}, err } - tags := []models.Tag{} - for _, dataI := range data { - tag := models.Tag{} - err := json.Unmarshal([]byte(dataI), &tag) - if err != nil { - continue - } + tags := []schema.Tag{} + for _, entry := range entries { + tag := entry.Value.Data() if tag.Network == netID { tags = append(tags, tag) } - } return tags, nil } // UpdateTag - updates and syncs hosts with tag update -func UpdateTag(req models.UpdateTagReq, newID models.TagID) { +func UpdateTag(req models.UpdateTagReq, newID schema.TagID) { tagMutex.Lock() defer tagMutex.Unlock() network := &schema.Network{ @@ -196,7 +178,7 @@ func UpdateTag(req models.UpdateTagReq, newID models.TagID) { extclients, _ := logic.GetNetworkExtClients(req.Network.String()) for _, extclient := range extclients { if extclient.Tags == nil { - extclient.Tags = make(map[models.TagID]struct{}) + extclient.Tags = make(map[schema.TagID]struct{}) } // unassign old tag @@ -240,9 +222,9 @@ func CheckIDSyntax(id string) error { func CreateDefaultTags(netID schema.NetworkID) { // create tag for gws in the network - tag := models.Tag{ - ID: models.TagID(fmt.Sprintf("%s.%s", netID.String(), models.GwTagName)), - TagName: models.GwTagName, + tag := schema.Tag{ + ID: schema.TagID(fmt.Sprintf("%s.%s", netID.String(), schema.GwTagName)), + TagName: schema.GwTagName, Network: netID, CreatedBy: "auto", CreatedAt: time.Now().UTC(), diff --git a/pro/logic/user_mgmt.go b/pro/logic/user_mgmt.go index bc3e97322..39c816059 100644 --- a/pro/logic/user_mgmt.go +++ b/pro/logic/user_mgmt.go @@ -1062,27 +1062,27 @@ func EnsureDefaultUserGroupNetworkPolicies(old, new *schema.UserGroup) error { } if !exists { - _ = logic.InsertAcl(models.Acl{ + _ = logic.InsertAcl(schema.Acl{ ID: uuid.New().String(), Name: defaultAclName, MetaData: "This Policy allows user group to communicate with all gateways", Default: true, ServiceType: models.Any, NetworkID: schema.NetworkID(network.Name), - Proto: models.ALL, - RuleType: models.UserPolicy, - Src: []models.AclPolicyTag{ + Proto: schema.ALL, + RuleType: schema.UserPolicy, + Src: []schema.AclPolicyTag{ { - ID: models.UserGroupAclID, + ID: schema.UserGroupAclID, Value: groupID, }, }, - Dst: []models.AclPolicyTag{ + Dst: []schema.AclPolicyTag{ { - ID: models.NodeTagID, - Value: fmt.Sprintf("%s.%s", schema.NetworkID(network.Name), models.GwTagName), + ID: schema.NodeTagID, + Value: fmt.Sprintf("%s.%s", schema.NetworkID(network.Name), schema.GwTagName), }}, - AllowedDirection: models.TrafficDirectionUni, + AllowedDirection: schema.TrafficDirectionUni, Enabled: true, CreatedBy: "auto", CreatedAt: time.Now().UTC(), @@ -1118,10 +1118,10 @@ func EnsureDefaultUserGroupNetworkPolicies(old, new *schema.UserGroup) error { // _ = logic.DeleteAcl(acl) //} - var newAclSrc []models.AclPolicyTag + var newAclSrc []schema.AclPolicyTag var groupSrcExists bool for _, src := range acl.Src { - if src.ID == models.UserGroupAclID && src.Value == groupID { + if src.ID == schema.UserGroupAclID && src.Value == groupID { groupSrcExists = true } else { newAclSrc = append(newAclSrc, src) @@ -1188,27 +1188,27 @@ func CreateDefaultUserPolicies(netID schema.NetworkID) { } if !logic.IsAclExists(fmt.Sprintf("%s.%s", netID, "all-users")) { - defaultUserAcl := models.Acl{ + defaultUserAcl := schema.Acl{ ID: fmt.Sprintf("%s.%s", netID, "all-users"), Default: true, Name: "All Users", MetaData: "This policy gives access to everything in the network for an user", NetworkID: netID, - Proto: models.ALL, + Proto: schema.ALL, ServiceType: models.Any, Port: []string{}, - RuleType: models.UserPolicy, - Src: []models.AclPolicyTag{ + RuleType: schema.UserPolicy, + Src: []schema.AclPolicyTag{ { - ID: models.UserAclID, + ID: schema.UserAclID, Value: "*", }, }, - Dst: []models.AclPolicyTag{{ - ID: models.NodeTagID, + Dst: []schema.AclPolicyTag{{ + ID: schema.NodeTagID, Value: "*", }}, - AllowedDirection: models.TrafficDirectionUni, + AllowedDirection: schema.TrafficDirectionUni, Enabled: true, CreatedBy: "auto", CreatedAt: time.Now().UTC(), @@ -1219,31 +1219,31 @@ func CreateDefaultUserPolicies(netID schema.NetworkID) { if !logic.IsAclExists(fmt.Sprintf("%s.%s-grp", netID, schema.NetworkAdmin)) { networkAdminGroupID := GetDefaultNetworkAdminGroupID(netID) - defaultUserAcl := models.Acl{ + defaultUserAcl := schema.Acl{ ID: fmt.Sprintf("%s.%s-grp", netID, schema.NetworkAdmin), Name: "Network Admin", MetaData: "This Policy allows all network admins to communicate with all gateways", Default: true, ServiceType: models.Any, NetworkID: netID, - Proto: models.ALL, - RuleType: models.UserPolicy, - Src: []models.AclPolicyTag{ + Proto: schema.ALL, + RuleType: schema.UserPolicy, + Src: []schema.AclPolicyTag{ { - ID: models.UserGroupAclID, + ID: schema.UserGroupAclID, Value: globalNetworksAdminGroupID.String(), }, { - ID: models.UserGroupAclID, + ID: schema.UserGroupAclID, Value: networkAdminGroupID.String(), }, }, - Dst: []models.AclPolicyTag{ + Dst: []schema.AclPolicyTag{ { - ID: models.NodeTagID, - Value: fmt.Sprintf("%s.%s", netID, models.GwTagName), + ID: schema.NodeTagID, + Value: fmt.Sprintf("%s.%s", netID, schema.GwTagName), }}, - AllowedDirection: models.TrafficDirectionUni, + AllowedDirection: schema.TrafficDirectionUni, Enabled: true, CreatedBy: "auto", CreatedAt: time.Now().UTC(), @@ -1254,31 +1254,31 @@ func CreateDefaultUserPolicies(netID schema.NetworkID) { if !logic.IsAclExists(fmt.Sprintf("%s.%s-grp", netID, schema.NetworkUser)) { networkUserGroupID := GetDefaultNetworkUserGroupID(netID) - defaultUserAcl := models.Acl{ + defaultUserAcl := schema.Acl{ ID: fmt.Sprintf("%s.%s-grp", netID, schema.NetworkUser), Name: "Network User", MetaData: "This Policy allows all network users to communicate with all gateways", Default: true, ServiceType: models.Any, NetworkID: netID, - Proto: models.ALL, - RuleType: models.UserPolicy, - Src: []models.AclPolicyTag{ + Proto: schema.ALL, + RuleType: schema.UserPolicy, + Src: []schema.AclPolicyTag{ { - ID: models.UserGroupAclID, + ID: schema.UserGroupAclID, Value: globalNetworksUserGroupID.String(), }, { - ID: models.UserGroupAclID, + ID: schema.UserGroupAclID, Value: networkUserGroupID.String(), }, }, - Dst: []models.AclPolicyTag{ + Dst: []schema.AclPolicyTag{ { - ID: models.NodeTagID, - Value: fmt.Sprintf("%s.%s", netID, models.GwTagName), + ID: schema.NodeTagID, + Value: fmt.Sprintf("%s.%s", netID, schema.GwTagName), }}, - AllowedDirection: models.TrafficDirectionUni, + AllowedDirection: schema.TrafficDirectionUni, Enabled: true, CreatedBy: "auto", CreatedAt: time.Now().UTC(), @@ -1317,27 +1317,27 @@ func CreateDefaultUserPolicies(netID schema.NetworkID) { } if !exists { - _ = logic.InsertAcl(models.Acl{ + _ = logic.InsertAcl(schema.Acl{ ID: uuid.New().String(), Name: defaultAclName, MetaData: "This Policy allows user group to communicate with all gateways", Default: true, ServiceType: models.Any, NetworkID: netID, - Proto: models.ALL, - RuleType: models.UserPolicy, - Src: []models.AclPolicyTag{ + Proto: schema.ALL, + RuleType: schema.UserPolicy, + Src: []schema.AclPolicyTag{ { - ID: models.UserGroupAclID, + ID: schema.UserGroupAclID, Value: group.ID.String(), }, }, - Dst: []models.AclPolicyTag{ + Dst: []schema.AclPolicyTag{ { - ID: models.NodeTagID, - Value: fmt.Sprintf("%s.%s", netID, models.GwTagName), + ID: schema.NodeTagID, + Value: fmt.Sprintf("%s.%s", netID, schema.GwTagName), }}, - AllowedDirection: models.TrafficDirectionUni, + AllowedDirection: schema.TrafficDirectionUni, Enabled: true, CreatedBy: "auto", CreatedAt: time.Now().UTC(), diff --git a/pro/orchestrator/extensions/node.go b/pro/orchestrator/extensions/node.go index a2d315c7d..b785e8773 100644 --- a/pro/orchestrator/extensions/node.go +++ b/pro/orchestrator/extensions/node.go @@ -1,7 +1,6 @@ package extensions import ( - "github.com/gravitl/netmaker/models" proLogic "github.com/gravitl/netmaker/pro/logic" "github.com/gravitl/netmaker/schema" ) @@ -16,7 +15,7 @@ func (p *ProNodeExtensions) ConfigureAutoAssignGateway(node *schema.Node, key *s node.AutoAssignGateway = key.AutoAssignGateway } -func (p *ProNodeExtensions) ConfigureTag(node *schema.Node, tagID models.TagID) { +func (p *ProNodeExtensions) ConfigureTag(node *schema.Node, tagID schema.TagID) { tag, err := proLogic.GetTag(tagID) if err != nil { return diff --git a/pro/remote_access_client.go b/pro/remote_access_client.go index 3e9cd874c..0d1fde65d 100644 --- a/pro/remote_access_client.go +++ b/pro/remote_access_client.go @@ -72,7 +72,7 @@ func unauthorisedUserNodeHook() error { return nil } -func disableExtClient(client *models.ExtClient) error { +func disableExtClient(client *schema.ExtClient) error { if newClient, err := logic.ToggleExtClientConnectivity(client, false); err != nil { return err } else { @@ -92,7 +92,7 @@ func disableExtClient(client *models.ExtClient) error { if err != nil { return err } - go mq.PublishSingleHostPeerUpdate(ingressHost, nodes, nil, nil, []models.ExtClient{*client}, false, nil) + go mq.PublishSingleHostPeerUpdate(ingressHost, nodes, nil, nil, []schema.ExtClient{*client}, false, nil) } else { return err } diff --git a/schema/acl.go b/schema/acl.go new file mode 100644 index 000000000..135227271 --- /dev/null +++ b/schema/acl.go @@ -0,0 +1,136 @@ +package schema + +import ( + "context" + "time" + + "github.com/gravitl/netmaker/db" + "gorm.io/datatypes" +) + +// AllowedTrafficDirection - allowed direction of traffic +type AllowedTrafficDirection int + +const ( + // TrafficDirectionUni implies traffic is only allowed in one direction (src --> dst) + TrafficDirectionUni AllowedTrafficDirection = iota + // TrafficDirectionBi implies traffic is allowed both direction (src <--> dst ) + TrafficDirectionBi +) + +// Protocol - allowed protocol +type Protocol string + +const ( + ALL Protocol = "all" + UDP Protocol = "udp" + TCP Protocol = "tcp" + ICMP Protocol = "icmp" +) + +func (p Protocol) String() string { + return string(p) +} + +type AclPolicyType string + +const ( + UserPolicy AclPolicyType = "user-policy" + DevicePolicy AclPolicyType = "device-policy" +) + +type AclPolicyTag struct { + ID AclGroupType `json:"id"` + Name string `json:"name"` + Value string `json:"value"` +} + +type AclGroupType string + +const ( + UserAclID AclGroupType = "user" + UserGroupAclID AclGroupType = "user-group" + NodeTagID AclGroupType = "tag" + NodeID AclGroupType = "device" + EgressRange AclGroupType = "egress-range" + EgressID AclGroupType = "egress-id" + NetmakerIPAclID AclGroupType = "ip" + NetmakerSubNetRangeAClID AclGroupType = "ipset" +) + +func (g AclGroupType) String() string { + return string(g) +} + +type AclPolicy struct { + TypeID AclPolicyType + PrefixTagUser AclGroupType +} + +type Acl struct { + ID string `json:"id"` + Default bool `json:"default"` + MetaData string `json:"meta_data"` + Name string `json:"name"` + NetworkID NetworkID `json:"network_id"` + RuleType AclPolicyType `json:"policy_type"` + Src []AclPolicyTag `json:"src_type"` + Dst []AclPolicyTag `json:"dst_type"` + Proto Protocol `json:"protocol"` // tcp, udp, etc. + ServiceType string `json:"type"` + Port []string `json:"ports"` + AllowedDirection AllowedTrafficDirection `json:"allowed_traffic_direction"` + Enabled bool `json:"enabled"` + CreatedBy string `json:"created_by"` + CreatedAt time.Time `json:"created_at"` +} + +// AclEntry is the GORM model for the legacy "acls" key-value table, +// extended with tenant_id and network_id columns for multi-tenancy. +type AclEntry struct { + Key string `gorm:"primaryKey;column:key"` + TenantID string `gorm:"column:tenant_id;default:''"` + NetworkID string `gorm:"column:network_id"` + Value datatypes.JSONType[Acl] `gorm:"column:value"` +} + +func (*AclEntry) TableName() string { return "acls" } + +func (e *AclEntry) Create(ctx context.Context) error { + return db.FromContext(ctx).Create(e).Error +} + +// Save does an upsert — insert or replace on primary key conflict. +func (e *AclEntry) Save(ctx context.Context) error { + return db.FromContext(ctx).Save(e).Error +} + +func (e *AclEntry) Get(ctx context.Context) error { + return db.FromContext(ctx).Where("key = ?", e.Key).First(e).Error +} + +func (e *AclEntry) ListAll(ctx context.Context) ([]AclEntry, error) { + var entries []AclEntry + err := db.FromContext(ctx).Find(&entries).Error + return entries, err +} + +func (e *AclEntry) ListByNetwork(ctx context.Context) ([]AclEntry, error) { + var entries []AclEntry + err := db.FromContext(ctx).Where("network_id = ?", e.NetworkID).Find(&entries).Error + return entries, err +} + +func (e *AclEntry) Count(ctx context.Context) (int64, error) { + var count int64 + err := db.FromContext(ctx).Model(&AclEntry{}).Count(&count).Error + return count, err +} + +func (e *AclEntry) Delete(ctx context.Context) error { + return db.FromContext(ctx).Where("key = ?", e.Key).Delete(&AclEntry{}).Error +} + +func (e *AclEntry) DeleteByNetwork(ctx context.Context) error { + return db.FromContext(ctx).Where("network_id = ?", e.NetworkID).Delete(&AclEntry{}).Error +} diff --git a/schema/extclient.go b/schema/extclient.go new file mode 100644 index 000000000..5124847af --- /dev/null +++ b/schema/extclient.go @@ -0,0 +1,116 @@ +package schema + +import ( + "context" + "net" + "sync" + "time" + + "github.com/gravitl/netmaker/db" + "gorm.io/datatypes" +) + +// ExtClient - struct for external clients +type ExtClient struct { + ClientID string `json:"clientid" bson:"clientid"` + PrivateKey string `json:"privatekey" bson:"privatekey"` + PublicKey string `json:"publickey" bson:"publickey"` + Network string `json:"network" bson:"network"` + DNS string `json:"dns" bson:"dns"` + Address string `json:"address" bson:"address"` + Address6 string `json:"address6" bson:"address6"` + ExtraAllowedIPs []string `json:"extraallowedips" bson:"extraallowedips"` + AllowedIPs []string `json:"allowed_ips"` + IngressGatewayID string `json:"ingressgatewayid" bson:"ingressgatewayid"` + IngressGatewayEndpoint string `json:"ingressgatewayendpoint" bson:"ingressgatewayendpoint"` + LastModified int64 `json:"lastmodified" bson:"lastmodified" swaggertype:"primitive,integer" format:"int64"` + Enabled bool `json:"enabled" bson:"enabled"` + OwnerID string `json:"ownerid" bson:"ownerid"` + DeniedACLs map[string]struct{} `json:"deniednodeacls" bson:"acls,omitempty"` + RemoteAccessClientID string `json:"remote_access_client_id"` // unique ID (MAC address) of RAC machine + PostUp string `json:"postup" bson:"postup"` + PostDown string `json:"postdown" bson:"postdown"` + Tags map[TagID]struct{} `json:"tags"` + OS string `json:"os"` + OSFamily string `json:"os_family" yaml:"os_family"` + OSVersion string `json:"os_version" yaml:"os_version"` + KernelVersion string `json:"kernel_version" yaml:"kernel_version"` + ClientVersion string `json:"client_version"` + DeviceID string `json:"device_id"` + DeviceName string `json:"device_name"` + PublicEndpoint string `json:"public_endpoint"` + Country string `json:"country"` + Location string `json:"location"` //format: lat,long + PostureChecksViolations []Violation `json:"posture_check_violations"` + PostureCheckVolationSeverityLevel Severity `json:"posture_check_violation_severity_level"` + LastEvaluatedAt time.Time `json:"last_evaluated_at"` + JITExpiresAt *time.Time `json:"jit_expires_at,omitempty" bson:"jit_expires_at,omitempty"` // JIT grant expiry time (nil if JIT not enabled or user is admin) + Status NodeStatus `json:"status" bson:"status"` + Mutex *sync.Mutex `json:"-"` +} + +// AddressIPNet4 returns the IPv4 address of the ExtClient in IPNet format. +func (extPeer *ExtClient) AddressIPNet4() net.IPNet { + return net.IPNet{ + IP: net.ParseIP(extPeer.Address), + Mask: net.CIDRMask(32, 32), + } +} + +// AddressIPNet6 returns the IPv6 address of the ExtClient in IPNet format. +func (extPeer *ExtClient) AddressIPNet6() net.IPNet { + return net.IPNet{ + IP: net.ParseIP(extPeer.Address6), + Mask: net.CIDRMask(128, 128), + } +} + +// ExtClientEntry is the GORM model for the legacy "extclients" key-value table, +// extended with tenant_id and network_id columns for multi-tenancy. +type ExtClientEntry struct { + Key string `gorm:"primaryKey;column:key"` + TenantID string `gorm:"column:tenant_id;default:''"` + NetworkID string `gorm:"column:network_id"` + Value datatypes.JSONType[ExtClient] `gorm:"column:value"` +} + +func (*ExtClientEntry) TableName() string { return "extclients" } + +func (e *ExtClientEntry) Create(ctx context.Context) error { + return db.FromContext(ctx).Create(e).Error +} + +// Save does an upsert — insert or replace on primary key conflict. +func (e *ExtClientEntry) Save(ctx context.Context) error { + return db.FromContext(ctx).Save(e).Error +} + +func (e *ExtClientEntry) Get(ctx context.Context) error { + return db.FromContext(ctx).Where("key = ?", e.Key).First(e).Error +} + +func (e *ExtClientEntry) ListAll(ctx context.Context) ([]ExtClientEntry, error) { + var entries []ExtClientEntry + err := db.FromContext(ctx).Find(&entries).Error + return entries, err +} + +func (e *ExtClientEntry) ListByNetwork(ctx context.Context) ([]ExtClientEntry, error) { + var entries []ExtClientEntry + err := db.FromContext(ctx).Where("network_id = ?", e.NetworkID).Find(&entries).Error + return entries, err +} + +func (e *ExtClientEntry) Count(ctx context.Context) (int64, error) { + var count int64 + err := db.FromContext(ctx).Model(&ExtClientEntry{}).Count(&count).Error + return count, err +} + +func (e *ExtClientEntry) Delete(ctx context.Context) error { + return db.FromContext(ctx).Where("key = ?", e.Key).Delete(&ExtClientEntry{}).Error +} + +func (e *ExtClientEntry) DeleteByNetwork(ctx context.Context) error { + return db.FromContext(ctx).Where("network_id = ?", e.NetworkID).Delete(&ExtClientEntry{}).Error +} diff --git a/schema/metrics.go b/schema/metrics.go new file mode 100644 index 000000000..102ad9249 --- /dev/null +++ b/schema/metrics.go @@ -0,0 +1,83 @@ +package schema + +import ( + "context" + "time" + + "github.com/gravitl/netmaker/db" + "gorm.io/datatypes" +) + +// Metrics - metrics struct +type Metrics struct { + Network string `json:"network" bson:"network" yaml:"network"` + NodeID string `json:"node_id" bson:"node_id" yaml:"node_id"` + NodeName string `json:"node_name" bson:"node_name" yaml:"node_name"` + Connectivity map[string]Metric `json:"connectivity" bson:"connectivity" yaml:"connectivity"` + UpdatedAt time.Time `json:"updated_at" bson:"updated_at" yaml:"updated_at"` +} + +// Metric - holds a metric for data between nodes +type Metric struct { + NodeName string `json:"node_name" bson:"node_name" yaml:"node_name"` + Uptime int64 `json:"uptime" bson:"uptime" yaml:"uptime" swaggertype:"primitive,integer" format:"int64"` + TotalTime int64 `json:"totaltime" bson:"totaltime" yaml:"totaltime" swaggertype:"primitive,integer" format:"int64"` + Latency int64 `json:"latency" bson:"latency" yaml:"latency" swaggertype:"primitive,integer" format:"int64"` + TotalReceived int64 `json:"totalreceived" bson:"totalreceived" yaml:"totalreceived" swaggertype:"primitive,integer" format:"int64"` + LastTotalReceived int64 `json:"lasttotalreceived" bson:"lasttotalreceived" yaml:"lasttotalreceived" swaggertype:"primitive,integer" format:"int64"` + TotalSent int64 `json:"totalsent" bson:"totalsent" yaml:"totalsent" swaggertype:"primitive,integer" format:"int64"` + LastTotalSent int64 `json:"lasttotalsent" bson:"lasttotalsent" yaml:"lasttotalsent" swaggertype:"primitive,integer" format:"int64"` + ActualUptime time.Duration `json:"actualuptime" swaggertype:"primitive,integer" format:"int64" bson:"actualuptime" yaml:"actualuptime"` + PercentUp float64 `json:"percentup" bson:"percentup" yaml:"percentup"` + Connected bool `json:"connected" bson:"connected" yaml:"connected"` +} + +// MetricsEntry is the GORM model for the legacy "metrics" key-value table, +// extended with tenant_id and network_id columns for multi-tenancy. +type MetricsEntry struct { + Key string `gorm:"primaryKey;column:key"` + TenantID string `gorm:"column:tenant_id;default:''"` + NetworkID string `gorm:"column:network_id"` + Value datatypes.JSONType[Metrics] `gorm:"column:value"` +} + +func (*MetricsEntry) TableName() string { return "metrics" } + +func (e *MetricsEntry) Create(ctx context.Context) error { + return db.FromContext(ctx).Create(e).Error +} + +// Save does an upsert — insert or replace on primary key conflict. +func (e *MetricsEntry) Save(ctx context.Context) error { + return db.FromContext(ctx).Save(e).Error +} + +func (e *MetricsEntry) Get(ctx context.Context) error { + return db.FromContext(ctx).Where("key = ?", e.Key).First(e).Error +} + +func (e *MetricsEntry) ListAll(ctx context.Context) ([]MetricsEntry, error) { + var entries []MetricsEntry + err := db.FromContext(ctx).Find(&entries).Error + return entries, err +} + +func (e *MetricsEntry) ListByNetwork(ctx context.Context) ([]MetricsEntry, error) { + var entries []MetricsEntry + err := db.FromContext(ctx).Where("network_id = ?", e.NetworkID).Find(&entries).Error + return entries, err +} + +func (e *MetricsEntry) Count(ctx context.Context) (int64, error) { + var count int64 + err := db.FromContext(ctx).Model(&MetricsEntry{}).Count(&count).Error + return count, err +} + +func (e *MetricsEntry) Delete(ctx context.Context) error { + return db.FromContext(ctx).Where("key = ?", e.Key).Delete(&MetricsEntry{}).Error +} + +func (e *MetricsEntry) DeleteByNetwork(ctx context.Context) error { + return db.FromContext(ctx).Where("network_id = ?", e.NetworkID).Delete(&MetricsEntry{}).Error +} diff --git a/schema/models.go b/schema/models.go index d1de16e63..78c2f16da 100644 --- a/schema/models.go +++ b/schema/models.go @@ -26,5 +26,9 @@ func ListModels() []interface{} { &Integration{}, &EnrollmentKey{}, &DNS{}, + &ExtClientEntry{}, + &AclEntry{}, + &MetricsEntry{}, + &TagEntry{}, } } diff --git a/schema/nodes.go b/schema/nodes.go index f408ba665..a753a3f4a 100644 --- a/schema/nodes.go +++ b/schema/nodes.go @@ -35,6 +35,13 @@ const ( Disconnected NodeStatus = "disconnected" ) +// TagID is a string identifier for a node tag. +type TagID string + +func (id TagID) String() string { + return string(id) +} + type Node struct { ID string `gorm:"primaryKey" json:"id"` TenantID string `gorm:"default:'';index" json:"tenant_id"` diff --git a/schema/org_settings.go b/schema/org_settings.go index 83e92abd4..5e62faa1b 100644 --- a/schema/org_settings.go +++ b/schema/org_settings.go @@ -29,3 +29,7 @@ func (o *OrganizationSettings) Get(ctx context.Context) error { First(&o). Error } + +func (o *OrganizationSettings) Delete(ctx context.Context) error { + return db.FromContext(ctx).Delete(&o).Error +} diff --git a/schema/posture_check.go b/schema/posture_check.go index 838c82260..dbfbc5526 100644 --- a/schema/posture_check.go +++ b/schema/posture_check.go @@ -89,6 +89,15 @@ var PostureCheckAttrValues = map[Attribute][]string{ AutoUpdate: {"true", "false"}, } +// Violation represents a posture check violation for a node or ext client. +type Violation struct { + CheckID string `json:"check_id"` + Name string `json:"name"` + Attribute string `json:"attribute"` + Message string `json:"message"` + Severity Severity `json:"severity"` +} + type PostureCheck struct { ID string `gorm:"primaryKey" json:"id"` TenantID string `gorm:"default:'';index" json:"tenant_id"` diff --git a/schema/tag.go b/schema/tag.go new file mode 100644 index 000000000..7276551b4 --- /dev/null +++ b/schema/tag.go @@ -0,0 +1,78 @@ +package schema + +import ( + "context" + "fmt" + "time" + + "github.com/gravitl/netmaker/db" + "gorm.io/datatypes" +) + +const ( + OldRemoteAccessTagName = "remote-access-gws" + GwTagName = "gateways" +) + +type Tag struct { + ID TagID `json:"id"` + TagName string `json:"tag_name"` + Network NetworkID `json:"network"` + ColorCode string `json:"color_code"` + CreatedBy string `json:"created_by"` + CreatedAt time.Time `json:"created_at"` +} + +func (t Tag) GetIDFromName() string { + return fmt.Sprintf("%s.%s", t.Network, t.TagName) +} + +// TagEntry is the GORM model for the legacy "tags" key-value table, +// extended with tenant_id and network_id columns for multi-tenancy. +type TagEntry struct { + Key string `gorm:"primaryKey;column:key"` + TenantID string `gorm:"column:tenant_id;default:''"` + NetworkID string `gorm:"column:network_id"` + Value datatypes.JSONType[Tag] `gorm:"column:value"` +} + +func (*TagEntry) TableName() string { return "tags" } + +func (e *TagEntry) Create(ctx context.Context) error { + return db.FromContext(ctx).Create(e).Error +} + +// Save does an upsert — insert or replace on primary key conflict. +func (e *TagEntry) Save(ctx context.Context) error { + return db.FromContext(ctx).Save(e).Error +} + +func (e *TagEntry) Get(ctx context.Context) error { + return db.FromContext(ctx).Where("key = ?", e.Key).First(e).Error +} + +func (e *TagEntry) ListAll(ctx context.Context) ([]TagEntry, error) { + var entries []TagEntry + err := db.FromContext(ctx).Find(&entries).Error + return entries, err +} + +func (e *TagEntry) ListByNetwork(ctx context.Context) ([]TagEntry, error) { + var entries []TagEntry + err := db.FromContext(ctx).Where("network_id = ?", e.NetworkID).Find(&entries).Error + return entries, err +} + +func (e *TagEntry) Count(ctx context.Context) (int64, error) { + var count int64 + err := db.FromContext(ctx).Model(&TagEntry{}).Count(&count).Error + return count, err +} + +func (e *TagEntry) Delete(ctx context.Context) error { + return db.FromContext(ctx).Where("key = ?", e.Key).Delete(&TagEntry{}).Error +} + +func (e *TagEntry) DeleteByNetwork(ctx context.Context) error { + return db.FromContext(ctx).Where("network_id = ?", e.NetworkID).Delete(&TagEntry{}).Error +} diff --git a/test/utils/tag.go b/test/utils/tag.go index 27441e70c..056a50610 100644 --- a/test/utils/tag.go +++ b/test/utils/tag.go @@ -1,19 +1,19 @@ package utils import ( + "context" "testing" "time" - "github.com/gravitl/netmaker/database" - "github.com/gravitl/netmaker/models" + "github.com/gravitl/netmaker/db" "github.com/gravitl/netmaker/pro/logic" "github.com/gravitl/netmaker/schema" "github.com/stretchr/testify/require" ) -func CreateTag(t *testing.T, tagID, network string) *models.Tag { - tag := models.Tag{ - ID: models.TagID(tagID), +func CreateTag(t *testing.T, tagID, network string) *schema.Tag { + tag := schema.Tag{ + ID: schema.TagID(tagID), TagName: tagID, Network: schema.NetworkID(network), CreatedAt: time.Now(), @@ -24,7 +24,7 @@ func CreateTag(t *testing.T, tagID, network string) *models.Tag { return &tag } -func DeleteTag(t *testing.T, tag *models.Tag) { - err := database.DeleteRecord(database.TAG_TABLE_NAME, tag.ID.String()) +func DeleteTag(t *testing.T, tag *schema.Tag) { + err := (&schema.TagEntry{Key: tag.ID.String()}).Delete(db.WithContext(context.TODO())) require.NoError(t, err) } From 1750b0f4ca8b6b706eb1c6ad46c99df83217fad1 Mon Sep 17 00:00:00 2001 From: VishalDalwadi Date: Mon, 22 Jun 2026 17:50:20 +0530 Subject: [PATCH 21/21] wip(go): move cache db management to schema pkg; --- auth/host_session.go | 4 +-- database/database.go | 6 ---- logic/auth.go | 49 ++++++++++------------------ logic/pro/netcache/netcache.go | 45 +++++++++----------------- logic/settings.go | 58 ++++++++++++---------------------- models/ssocache.go | 17 ---------- mq/serversync.go | 1 - pro/auth/auth.go | 3 +- pro/auth/register_callback.go | 6 ++-- schema/cache.go | 39 +++++++++++++++++++++++ schema/models.go | 2 ++ schema/sso_state.go | 47 +++++++++++++++++++++++++++ 12 files changed, 147 insertions(+), 130 deletions(-) create mode 100644 schema/cache.go create mode 100644 schema/sso_state.go diff --git a/auth/host_session.go b/auth/host_session.go index 8bd5d81bf..c7ee217d1 100644 --- a/auth/host_session.go +++ b/auth/host_session.go @@ -44,7 +44,7 @@ func SessionHandler(conn *websocket.Conn) { return } - req := new(netcache.CValue) + req := new(schema.CacheValue) req.Value = string(registerMessage.RegisterHost.ID.String()) req.Network = registerMessage.Network req.Host = registerMessage.RegisterHost @@ -64,7 +64,7 @@ func SessionHandler(conn *websocket.Conn) { defer netcache.Del(stateStr) // Wait for the user to finish his auth flow... timeout := make(chan bool, 2) - answer := make(chan netcache.CValue, 1) + answer := make(chan schema.CacheValue, 1) defer close(answer) defer close(timeout) if len(registerMessage.User) > 0 { // handle basic auth diff --git a/database/database.go b/database/database.go index 28b0a3194..7f3b53d81 100644 --- a/database/database.go +++ b/database/database.go @@ -12,10 +12,6 @@ const ( // == Table Names == // ACLS_TABLE_NAME - table for acls v2 ACLS_TABLE_NAME = "acls" - // SSO_STATE_CACHE - holds sso session information for OAuth2 sign-ins - SSO_STATE_CACHE = "ssostatecache" - // CACHE_TABLE_NAME - caching table - CACHE_TABLE_NAME = "cache" // SERVER_SETTINGS - table for server settings SERVER_SETTINGS = "server_settings" // == ERROR CONSTS == @@ -46,8 +42,6 @@ const ( ) var Tables = []string{ - SSO_STATE_CACHE, - CACHE_TABLE_NAME, ACLS_TABLE_NAME, SERVER_SETTINGS, } diff --git a/logic/auth.go b/logic/auth.go index 8a094a37e..7d3f30479 100644 --- a/logic/auth.go +++ b/logic/auth.go @@ -3,7 +3,6 @@ package logic import ( "context" "encoding/base64" - "encoding/json" "errors" "fmt" "net/mail" @@ -18,7 +17,6 @@ import ( "golang.org/x/crypto/bcrypt" "golang.org/x/exp/slog" - "github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/models" ) @@ -492,38 +490,30 @@ func FetchOAuthSecret() (string, error) { } // GetState - gets an SsoState from DB, if expired returns error -func GetState(state string) (*models.SsoState, error) { - var s models.SsoState - record, err := database.FetchRecord(database.SSO_STATE_CACHE, state) - if err != nil { - return &s, err - } - - if err = json.Unmarshal([]byte(record), &s); err != nil { - return &s, err +func GetState(state string) (*schema.SsoState, error) { + entry := &schema.SsoStateEntry{Key: state} + if err := entry.Get(db.WithContext(context.TODO())); err != nil { + return nil, err } - + s := entry.Value.Data() if s.IsExpired() { return &s, fmt.Errorf("state expired") } - return &s, nil } // SetState - sets a state with new expiration func SetState(appName, state string) error { - s := models.SsoState{ + s := schema.SsoState{ AppName: appName, Value: state, - Expiration: time.Now().Add(models.DefaultExpDuration), + Expiration: time.Now().Add(schema.DefaultSsoStateDuration), } - - data, err := json.Marshal(&s) - if err != nil { - return err + entry := &schema.SsoStateEntry{ + Key: state, + Value: datatypes.NewJSONType(s), } - - return database.Insert(state, string(data), database.SSO_STATE_CACHE) + return entry.Save(db.WithContext(context.TODO())) } // IsStateValid - checks if given state is valid or not @@ -545,27 +535,20 @@ func IsStateValid(state string) (string, bool) { // delState - removes a state from cache/db func delState(state string) error { - return database.DeleteRecord(database.SSO_STATE_CACHE, state) + return (&schema.SsoStateEntry{Key: state}).Delete(db.WithContext(context.TODO())) } // CleanExpiredSSOStates removes expired SSO state entries from the database // to prevent unbounded table growth that degrades FetchRecord performance. func CleanExpiredSSOStates() error { - records, err := database.FetchRecords(database.SSO_STATE_CACHE) + entries, err := (&schema.SsoStateEntry{}).ListAll(db.WithContext(context.TODO())) if err != nil { - if database.IsEmptyRecord(err) { - return nil - } return err } - for key, value := range records { - var s models.SsoState - if err := json.Unmarshal([]byte(value), &s); err != nil { - _ = database.DeleteRecord(database.SSO_STATE_CACHE, key) - continue - } + for _, entry := range entries { + s := entry.Value.Data() if s.IsExpired() { - _ = database.DeleteRecord(database.SSO_STATE_CACHE, key) + _ = (&schema.SsoStateEntry{Key: entry.Key}).Delete(db.WithContext(context.TODO())) } } return nil diff --git a/logic/pro/netcache/netcache.go b/logic/pro/netcache/netcache.go index 2e7f5e963..de588e167 100644 --- a/logic/pro/netcache/netcache.go +++ b/logic/pro/netcache/netcache.go @@ -1,60 +1,45 @@ package netcache import ( - "encoding/json" + "context" "fmt" "time" - "github.com/gravitl/netmaker/database" + "github.com/gravitl/netmaker/db" "github.com/gravitl/netmaker/schema" + "gorm.io/datatypes" ) const ( expirationTime = time.Minute * 5 ) -// CValue - the cache object for a network -type CValue struct { - Network string `json:"network,omitempty"` - Value string `json:"value"` - Host schema.Host `json:"host"` - Pass string `json:"pass,omitempty"` - User string `json:"user,omitempty"` - ALL bool `json:"all,omitempty"` - Expiration time.Time `json:"expiration"` -} - var ErrExpired = fmt.Errorf("expired") // Set - sets a value to a key in db -func Set(k string, newValue *CValue) error { +func Set(k string, newValue *schema.CacheValue) error { newValue.Expiration = time.Now().Add(expirationTime) - newData, err := json.Marshal(newValue) - if err != nil { - return err + entry := &schema.CacheEntry{ + Key: k, + Value: datatypes.NewJSONType(*newValue), } - - return database.Insert(k, string(newData), database.CACHE_TABLE_NAME) + return entry.Save(db.WithContext(context.TODO())) } // Get - gets a value from db, if expired, return err -func Get(k string) (*CValue, error) { - record, err := database.FetchRecord(database.CACHE_TABLE_NAME, k) - if err != nil { +func Get(k string) (*schema.CacheValue, error) { + entry := &schema.CacheEntry{Key: k} + if err := entry.Get(db.WithContext(context.TODO())); err != nil { return nil, err } - var entry CValue - if err := json.Unmarshal([]byte(record), &entry); err != nil { - return nil, err - } - if time.Now().After(entry.Expiration) { + v := entry.Value.Data() + if time.Now().After(v.Expiration) { return nil, ErrExpired } - - return &entry, nil + return &v, nil } // Del - deletes a value from db func Del(k string) error { - return database.DeleteRecord(database.CACHE_TABLE_NAME, k) + return (&schema.CacheEntry{Key: k}).Delete(db.WithContext(context.TODO())) } diff --git a/logic/settings.go b/logic/settings.go index cc489dc21..6da349e6a 100644 --- a/logic/settings.go +++ b/logic/settings.go @@ -2,7 +2,6 @@ package logic import ( "context" - "encoding/json" "errors" "fmt" "os" @@ -10,15 +9,15 @@ import ( "strconv" "strings" "sync" - "sync/atomic" "time" "github.com/gravitl/netmaker/config" - "github.com/gravitl/netmaker/database" "github.com/gravitl/netmaker/db" "github.com/gravitl/netmaker/models" "github.com/gravitl/netmaker/schema" + "github.com/gravitl/netmaker/scope" "github.com/gravitl/netmaker/servercfg" + "gorm.io/datatypes" ) var ( @@ -27,11 +26,6 @@ var ( ErrInvalidIPDetectionInterval = errors.New("invalid ip detection interval (must be greater than or equal to 15s)") ) -var ServerSettingsDBKey = "server_cfg" -var SettingsMutex = &sync.RWMutex{} - -var serverSettingsCache atomic.Pointer[schema.ServerSettingsData] - var defaultUserSettings = models.UserSettings{ TextSize: "16", Theme: schema.Dark, @@ -39,37 +33,27 @@ var defaultUserSettings = models.UserSettings{ } func GetServerSettings() (s schema.ServerSettingsData) { - if cached := serverSettingsCache.Load(); cached != nil { - return *cached - } - s, err := getServerSettingsFromDB() - if err == nil { - serverSettingsCache.Store(&s) - } + s, _ = getServerSettingsFromDB() return } -// InvalidateServerSettingsCache clears the in-memory settings cache so -// the next GetServerSettings call re-reads from the database. -func InvalidateServerSettingsCache() { - serverSettingsCache.Store((*schema.ServerSettingsData)(nil)) -} - -func getServerSettingsFromDB() (schema.ServerSettingsData, error) { - var s schema.ServerSettingsData - data, err := database.FetchRecord(database.SERVER_SETTINGS, ServerSettingsDBKey) - if err != nil { - return s, err +func getServerSettingsFromDB(ctx context.Context) (schema.ServerSettingsData, error) { + settings := &schema.ServerSettings{ + Key: scope.ID(ctx), } - if err := json.Unmarshal([]byte(data), &s); err != nil { - return s, err + err := settings.Get(ctx) + if err != nil { + return schema.ServerSettingsData{}, err } - return s, nil + + return settings.Value.Data(), nil } -func UpsertServerSettings(s schema.ServerSettingsData) error { - // get curr settings from DB directly (not cache) for accurate comparison - currSettings, _ := getServerSettingsFromDB() +func UpsertServerSettings(ctx context.Context, s schema.ServerSettingsData) error { + currSettings, err := getServerSettingsFromDB(ctx) + if err != nil { + return err + } if s.ClientSecret == Mask() { s.ClientSecret = currSettings.ClientSecret } @@ -101,15 +85,15 @@ func UpsertServerSettings(s schema.ServerSettingsData) error { } } s.GroupFilters = groupFilters - data, err := json.Marshal(s) - if err != nil { - return err + settings := &schema.ServerSettings{ + Key: scope.ID(ctx), + Value: datatypes.NewJSONType(s), } - err = database.Insert(ServerSettingsDBKey, string(data), database.SERVER_SETTINGS) + err = settings.Upsert(ctx) if err != nil { return err } - serverSettingsCache.Store(&s) + if PublishServerSync != nil { PublishServerSync(SyncTypeSettings) } diff --git a/models/ssocache.go b/models/ssocache.go index e71f1be90..2640e7f93 100644 --- a/models/ssocache.go +++ b/models/ssocache.go @@ -1,18 +1 @@ package models - -import "time" - -// DefaultExpDuration - the default expiration time of SsoState -const DefaultExpDuration = time.Minute * 5 - -// SsoState - holds SSO sign-in session data -type SsoState struct { - AppName string `json:"app_name"` - Value string `json:"value"` - Expiration time.Time `json:"expiration"` -} - -// SsoState.IsExpired - tells if an SsoState is expired or not -func (s *SsoState) IsExpired() bool { - return time.Now().After(s.Expiration) -} diff --git a/mq/serversync.go b/mq/serversync.go index 9d2bd443c..afc096b5b 100644 --- a/mq/serversync.go +++ b/mq/serversync.go @@ -59,7 +59,6 @@ func handleServerSync(_ mqtt.Client, msg mqtt.Message) { switch syncMsg.SyncType { case logic.SyncTypeSettings: oldInterval := logic.GetMetricInterval() - logic.InvalidateServerSettingsCache() if logic.GetMetricInterval() != oldInterval { logic.NotifyMetricExportIntervalChanged() } diff --git a/pro/auth/auth.go b/pro/auth/auth.go index 26d4f4999..89e5e1ce4 100644 --- a/pro/auth/auth.go +++ b/pro/auth/auth.go @@ -14,6 +14,7 @@ import ( "github.com/gravitl/netmaker/logger" "github.com/gravitl/netmaker/logic" "github.com/gravitl/netmaker/logic/pro/netcache" + "github.com/gravitl/netmaker/schema" "github.com/gravitl/netmaker/servercfg" "golang.org/x/oauth2" ) @@ -211,7 +212,7 @@ func HandleHeadlessSSO(w http.ResponseWriter, r *http.Request) { } defer conn.Close() - req := &netcache.CValue{User: "", Pass: ""} + req := &schema.CacheValue{User: "", Pass: ""} stateStr := logic.RandomString(headless_signin_length) if err = netcache.Set(stateStr, req); err != nil { logger.Log(0, "Failed to process sso request -", err.Error()) diff --git a/pro/auth/register_callback.go b/pro/auth/register_callback.go index 3b991705a..acfeebd4c 100644 --- a/pro/auth/register_callback.go +++ b/pro/auth/register_callback.go @@ -55,7 +55,7 @@ func HandleHostSSOCallback(w http.ResponseWriter, r *http.Request) { reqKeyIf, machineKeyFoundErr := netcache.Get(state) if machineKeyFoundErr != nil { logger.Log(0, "requested machine state key expired before authorisation completed -", machineKeyFoundErr.Error()) - reqKeyIf = &netcache.CValue{ + reqKeyIf = &schema.CacheValue{ Network: "invalid", Value: state, Pass: "", @@ -104,7 +104,7 @@ func HandleHostSSOCallback(w http.ResponseWriter, r *http.Request) { } } -func setNetcache(ncache *netcache.CValue, state string) error { +func setNetcache(ncache *schema.CacheValue, state string) error { if ncache == nil { return fmt.Errorf("cache miss") } @@ -115,7 +115,7 @@ func setNetcache(ncache *netcache.CValue, state string) error { return err } -func returnErrTemplate(uname, message, state string, ncache *netcache.CValue) []byte { +func returnErrTemplate(uname, message, state string, ncache *schema.CacheValue) []byte { var response bytes.Buffer if ncache != nil { ncache.Pass = message diff --git a/schema/cache.go b/schema/cache.go new file mode 100644 index 000000000..bc23ab016 --- /dev/null +++ b/schema/cache.go @@ -0,0 +1,39 @@ +package schema + +import ( + "context" + "time" + + "github.com/gravitl/netmaker/db" + "gorm.io/datatypes" +) + +type CacheValue struct { + Network string `json:"network,omitempty"` + Value string `json:"value"` + Host Host `json:"host"` + Pass string `json:"pass,omitempty"` + User string `json:"user,omitempty"` + ALL bool `json:"all,omitempty"` + Expiration time.Time `json:"expiration"` +} + +type CacheEntry struct { + Key string `gorm:"primaryKey;column:key"` + TenantID string `gorm:"column:tenant_id;default:''"` + Value datatypes.JSONType[CacheValue] `gorm:"column:value"` +} + +func (*CacheEntry) TableName() string { return "cache" } + +func (e *CacheEntry) Save(ctx context.Context) error { + return db.FromContext(ctx).Save(e).Error +} + +func (e *CacheEntry) Get(ctx context.Context) error { + return db.FromContext(ctx).Where("key = ?", e.Key).First(e).Error +} + +func (e *CacheEntry) Delete(ctx context.Context) error { + return db.FromContext(ctx).Where("key = ?", e.Key).Delete(&CacheEntry{}).Error +} diff --git a/schema/models.go b/schema/models.go index 78c2f16da..df1089c28 100644 --- a/schema/models.go +++ b/schema/models.go @@ -30,5 +30,7 @@ func ListModels() []interface{} { &AclEntry{}, &MetricsEntry{}, &TagEntry{}, + &SsoStateEntry{}, + &CacheEntry{}, } } diff --git a/schema/sso_state.go b/schema/sso_state.go new file mode 100644 index 000000000..744a90539 --- /dev/null +++ b/schema/sso_state.go @@ -0,0 +1,47 @@ +package schema + +import ( + "context" + "time" + + "github.com/gravitl/netmaker/db" + "gorm.io/datatypes" +) + +const DefaultSsoStateDuration = time.Minute * 5 + +type SsoState struct { + AppName string `json:"app_name"` + Value string `json:"value"` + Expiration time.Time `json:"expiration"` +} + +func (s *SsoState) IsExpired() bool { + return time.Now().After(s.Expiration) +} + +type SsoStateEntry struct { + Key string `gorm:"primaryKey;column:key"` + TenantID string `gorm:"column:tenant_id;default:''"` + Value datatypes.JSONType[SsoState] `gorm:"column:value"` +} + +func (*SsoStateEntry) TableName() string { return "ssostatecache" } + +func (e *SsoStateEntry) Save(ctx context.Context) error { + return db.FromContext(ctx).Save(e).Error +} + +func (e *SsoStateEntry) Get(ctx context.Context) error { + return db.FromContext(ctx).Where("key = ?", e.Key).First(e).Error +} + +func (e *SsoStateEntry) ListAll(ctx context.Context) ([]SsoStateEntry, error) { + var entries []SsoStateEntry + err := db.FromContext(ctx).Find(&entries).Error + return entries, err +} + +func (e *SsoStateEntry) Delete(ctx context.Context) error { + return db.FromContext(ctx).Where("key = ?", e.Key).Delete(&SsoStateEntry{}).Error +}