mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-04-30 01:18:12 +00:00
wip
This commit is contained in:
@@ -1,10 +1,13 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
@@ -33,7 +36,8 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type ContextMiddlewareConfig struct {
|
type ContextMiddlewareConfig struct {
|
||||||
CookieDomain string
|
CookieDomain string
|
||||||
|
SessionCookieName string
|
||||||
}
|
}
|
||||||
|
|
||||||
type ContextMiddleware struct {
|
type ContextMiddleware struct {
|
||||||
@@ -61,194 +65,42 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
cookie, err := m.auth.GetSessionCookie(c)
|
uuid, err := c.Cookie(m.config.SessionCookieName)
|
||||||
|
|
||||||
if err != nil {
|
if err == nil {
|
||||||
tlog.App.Debug().Err(err).Msg("No valid session cookie found")
|
userContext, cookie, err := m.cookieAuth(c.Request.Context(), uuid)
|
||||||
goto basic
|
|
||||||
}
|
|
||||||
|
|
||||||
if cookie.TotpPending {
|
|
||||||
c.Set("context", &config.UserContext{
|
|
||||||
Username: cookie.Username,
|
|
||||||
Name: cookie.Name,
|
|
||||||
Email: cookie.Email,
|
|
||||||
Provider: "local",
|
|
||||||
TotpPending: true,
|
|
||||||
TotpEnabled: true,
|
|
||||||
})
|
|
||||||
c.Next()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
switch cookie.Provider {
|
|
||||||
case "local", "ldap":
|
|
||||||
userSearch := m.auth.SearchUser(cookie.Username)
|
|
||||||
|
|
||||||
if userSearch.Type == "unknown" {
|
|
||||||
tlog.App.Debug().Msg("User from session cookie not found")
|
|
||||||
m.auth.DeleteSessionCookie(c)
|
|
||||||
goto basic
|
|
||||||
}
|
|
||||||
|
|
||||||
if userSearch.Type != cookie.Provider {
|
|
||||||
tlog.App.Warn().Msg("User type from session cookie does not match user search type")
|
|
||||||
m.auth.DeleteSessionCookie(c)
|
|
||||||
c.Next()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var ldapGroups []string
|
|
||||||
var localAttributes config.UserAttributes
|
|
||||||
|
|
||||||
if cookie.Provider == "ldap" {
|
|
||||||
ldapUser, err := m.auth.GetLdapUser(userSearch.Username)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
tlog.App.Error().Err(err).Msg("Error retrieving LDAP user details")
|
|
||||||
c.Next()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
ldapGroups = ldapUser.Groups
|
|
||||||
}
|
|
||||||
|
|
||||||
if cookie.Provider == "local" {
|
|
||||||
localUser := m.auth.GetLocalUser(cookie.Username)
|
|
||||||
localAttributes = localUser.Attributes
|
|
||||||
}
|
|
||||||
|
|
||||||
m.auth.RefreshSessionCookie(c)
|
|
||||||
c.Set("context", &config.UserContext{
|
|
||||||
Username: cookie.Username,
|
|
||||||
Name: cookie.Name,
|
|
||||||
Email: cookie.Email,
|
|
||||||
Provider: cookie.Provider,
|
|
||||||
IsLoggedIn: true,
|
|
||||||
LdapGroups: strings.Join(ldapGroups, ","),
|
|
||||||
Attributes: localAttributes,
|
|
||||||
})
|
|
||||||
c.Next()
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
_, exists := m.broker.GetService(cookie.Provider)
|
|
||||||
|
|
||||||
if !exists {
|
|
||||||
tlog.App.Debug().Msg("OAuth provider from session cookie not found")
|
|
||||||
m.auth.DeleteSessionCookie(c)
|
|
||||||
goto basic
|
|
||||||
}
|
|
||||||
|
|
||||||
if !m.auth.IsEmailWhitelisted(cookie.Email) {
|
|
||||||
tlog.App.Debug().Msg("Email from session cookie not whitelisted")
|
|
||||||
m.auth.DeleteSessionCookie(c)
|
|
||||||
goto basic
|
|
||||||
}
|
|
||||||
|
|
||||||
m.auth.RefreshSessionCookie(c)
|
|
||||||
c.Set("context", &config.UserContext{
|
|
||||||
Username: cookie.Username,
|
|
||||||
Name: cookie.Name,
|
|
||||||
Email: cookie.Email,
|
|
||||||
Provider: cookie.Provider,
|
|
||||||
OAuthGroups: cookie.OAuthGroups,
|
|
||||||
OAuthName: cookie.OAuthName,
|
|
||||||
OAuthSub: cookie.OAuthSub,
|
|
||||||
IsLoggedIn: true,
|
|
||||||
OAuth: true,
|
|
||||||
})
|
|
||||||
c.Next()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
basic:
|
|
||||||
basic := m.auth.GetBasicAuth(c)
|
|
||||||
|
|
||||||
if basic == nil {
|
|
||||||
tlog.App.Debug().Msg("No basic auth provided")
|
|
||||||
c.Next()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
locked, remaining := m.auth.IsAccountLocked(basic.Username)
|
|
||||||
|
|
||||||
if locked {
|
|
||||||
tlog.App.Debug().Msgf("Account for user %s is locked for %d seconds, denying auth", basic.Username, remaining)
|
|
||||||
c.Writer.Header().Add("x-tinyauth-lock-locked", "true")
|
|
||||||
c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339))
|
|
||||||
c.Next()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
userSearch := m.auth.SearchUser(basic.Username)
|
|
||||||
|
|
||||||
if userSearch.Type == "unknown" || userSearch.Type == "error" {
|
|
||||||
m.auth.RecordLoginAttempt(basic.Username, false)
|
|
||||||
tlog.App.Debug().Msg("User from basic auth not found")
|
|
||||||
c.Next()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if !m.auth.VerifyUser(userSearch, basic.Password) {
|
|
||||||
m.auth.RecordLoginAttempt(basic.Username, false)
|
|
||||||
tlog.App.Debug().Msg("Invalid password for basic auth user")
|
|
||||||
c.Next()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
m.auth.RecordLoginAttempt(basic.Username, true)
|
|
||||||
|
|
||||||
switch userSearch.Type {
|
|
||||||
case "local":
|
|
||||||
tlog.App.Debug().Msg("Basic auth user is local")
|
|
||||||
|
|
||||||
user := m.auth.GetLocalUser(basic.Username)
|
|
||||||
|
|
||||||
if user.TotpSecret != "" {
|
|
||||||
tlog.App.Debug().Msg("User with TOTP not allowed to login via basic auth")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
name := utils.Capitalize(user.Username)
|
|
||||||
if user.Attributes.Name != "" {
|
|
||||||
name = user.Attributes.Name
|
|
||||||
}
|
|
||||||
email := utils.CompileUserEmail(user.Username, m.config.CookieDomain)
|
|
||||||
if user.Attributes.Email != "" {
|
|
||||||
email = user.Attributes.Email
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Set("context", &config.UserContext{
|
|
||||||
Username: user.Username,
|
|
||||||
Name: name,
|
|
||||||
Email: email,
|
|
||||||
Provider: "local",
|
|
||||||
IsLoggedIn: true,
|
|
||||||
IsBasicAuth: true,
|
|
||||||
Attributes: user.Attributes,
|
|
||||||
})
|
|
||||||
c.Next()
|
|
||||||
return
|
|
||||||
case "ldap":
|
|
||||||
tlog.App.Debug().Msg("Basic auth user is LDAP")
|
|
||||||
|
|
||||||
ldapUser, err := m.auth.GetLdapUser(basic.Username)
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Debug().Err(err).Msg("Error retrieving LDAP user details")
|
tlog.App.Error().Msgf("Error authenticating session cookie: %v", err)
|
||||||
c.Next()
|
c.Next()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Set("context", &config.UserContext{
|
if cookie != nil {
|
||||||
Username: basic.Username,
|
http.SetCookie(c.Writer, cookie)
|
||||||
Name: utils.Capitalize(basic.Username),
|
}
|
||||||
Email: utils.CompileUserEmail(basic.Username, m.config.CookieDomain),
|
|
||||||
Provider: "ldap",
|
c.Set("context", userContext)
|
||||||
IsLoggedIn: true,
|
c.Next()
|
||||||
LdapGroups: strings.Join(ldapUser.Groups, ","),
|
return
|
||||||
IsBasicAuth: true,
|
}
|
||||||
})
|
|
||||||
|
basic, err := m.auth.GetBasicAuth(c.Request)
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
userContext, headers, err := m.basicAuth(c.Request.Context(), basic)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
tlog.App.Error().Msgf("Error authenticating basic auth: %v", err)
|
||||||
|
c.Next()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v := range headers {
|
||||||
|
c.Header(k, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Set("context", userContext)
|
||||||
c.Next()
|
c.Next()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -257,6 +109,150 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *ContextMiddleware) cookieAuth(ctx context.Context, uuid string) (*model.UserContext, *http.Cookie, error) {
|
||||||
|
session, err := m.auth.GetSession(ctx, uuid)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("error retrieving session: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
userContext, err := new(model.UserContext).NewFromSession(session)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("error creating user context from session: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if userContext.Provider == model.ProviderLocal &&
|
||||||
|
userContext.Local.TOTPPending {
|
||||||
|
userContext.Local.TOTPEnabled = true
|
||||||
|
return userContext, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch userContext.Provider {
|
||||||
|
case model.ProviderLocal:
|
||||||
|
user := m.auth.GetLocalUser(userContext.Local.Username)
|
||||||
|
|
||||||
|
if user == nil {
|
||||||
|
return nil, nil, fmt.Errorf("local user not found")
|
||||||
|
}
|
||||||
|
|
||||||
|
userContext.Local.Attributes = user.Attributes
|
||||||
|
|
||||||
|
if userContext.Local.Attributes.Name == "" {
|
||||||
|
userContext.Local.Attributes.Name = utils.Capitalize(user.Username)
|
||||||
|
}
|
||||||
|
|
||||||
|
if userContext.Local.Attributes.Email == "" {
|
||||||
|
userContext.Local.Attributes.Email = utils.CompileUserEmail(user.Username, m.config.CookieDomain)
|
||||||
|
}
|
||||||
|
case model.ProviderLDAP:
|
||||||
|
search, err := m.auth.SearchUser(userContext.LDAP.Username)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("error searching for ldap user: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if search.Type != model.UserLDAP {
|
||||||
|
return nil, nil, fmt.Errorf("user from session cookie is not ldap")
|
||||||
|
}
|
||||||
|
|
||||||
|
user, err := m.auth.GetLDAPUser(search.Username)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("error retrieving ldap user details: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
userContext.LDAP.Groups = user.Groups
|
||||||
|
userContext.LDAP.Name = utils.Capitalize(userContext.LDAP.Username)
|
||||||
|
userContext.LDAP.Email = utils.CompileUserEmail(userContext.LDAP.Username, m.config.CookieDomain)
|
||||||
|
case model.ProviderOAuth:
|
||||||
|
_, exists := m.broker.GetService(userContext.OAuth.ID)
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
return nil, nil, fmt.Errorf("oauth provider from session cookie not found: %s", userContext.OAuth.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !m.auth.IsEmailWhitelisted(userContext.OAuth.Email) {
|
||||||
|
m.auth.DeleteSession(ctx, uuid)
|
||||||
|
return nil, nil, fmt.Errorf("email from session cookie not whitelisted: %s", userContext.OAuth.Email)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cookie, err := m.auth.RefreshSession(ctx, uuid)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("error refreshing session: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return userContext, cookie, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *ContextMiddleware) basicAuth(ctx context.Context, basic *model.LocalUser) (*model.UserContext, map[string]string, error) {
|
||||||
|
headers := make(map[string]string)
|
||||||
|
userContext := new(model.UserContext)
|
||||||
|
locked, remaining := m.auth.IsAccountLocked(basic.Username)
|
||||||
|
|
||||||
|
if locked {
|
||||||
|
tlog.App.Debug().Msgf("Account for user %s is locked for %d seconds, denying auth", basic.Username, remaining)
|
||||||
|
headers["x-tinyauth-lock-locked"] = "true"
|
||||||
|
headers["x-tinyauth-lock-reset"] = time.Now().Add(time.Duration(remaining) * time.Second).Format(time.RFC3339)
|
||||||
|
return nil, headers, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
search, err := m.auth.SearchUser(basic.Username)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("error searching for user: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = m.auth.CheckUserPassword(*search, basic.Password)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
m.auth.RecordLoginAttempt(basic.Username, false)
|
||||||
|
return nil, nil, fmt.Errorf("invalid password for basic auth user: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.auth.RecordLoginAttempt(basic.Username, true)
|
||||||
|
|
||||||
|
switch search.Type {
|
||||||
|
case model.UserLocal:
|
||||||
|
user := m.auth.GetLocalUser(basic.Username)
|
||||||
|
|
||||||
|
if user.TOTPSecret != "" {
|
||||||
|
return nil, nil, fmt.Errorf("user with totp not allowed to login via basic auth: %s", basic.Username)
|
||||||
|
}
|
||||||
|
|
||||||
|
userContext.Local = &model.LocalContext{
|
||||||
|
BaseContext: model.BaseContext{
|
||||||
|
Username: user.Username,
|
||||||
|
Name: utils.Capitalize(user.Username),
|
||||||
|
Email: utils.CompileUserEmail(user.Username, m.config.CookieDomain),
|
||||||
|
},
|
||||||
|
Attributes: user.Attributes,
|
||||||
|
}
|
||||||
|
userContext.Provider = model.ProviderLocal
|
||||||
|
case model.UserLDAP:
|
||||||
|
user, err := m.auth.GetLDAPUser(basic.Username)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("error retrieving ldap user details: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
userContext.LDAP = &model.LDAPContext{
|
||||||
|
BaseContext: model.BaseContext{
|
||||||
|
Username: basic.Username,
|
||||||
|
Name: utils.Capitalize(basic.Username),
|
||||||
|
Email: utils.CompileUserEmail(basic.Username, m.config.CookieDomain),
|
||||||
|
},
|
||||||
|
Groups: user.Groups,
|
||||||
|
}
|
||||||
|
userContext.Provider = model.ProviderLDAP
|
||||||
|
}
|
||||||
|
|
||||||
|
userContext.Authenticated = true
|
||||||
|
return userContext, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *ContextMiddleware) isIgnorePath(path string) bool {
|
func (m *ContextMiddleware) isIgnorePath(path string) bool {
|
||||||
for _, prefix := range contextSkipPathsPrefix {
|
for _, prefix := range contextSkipPathsPrefix {
|
||||||
if strings.HasPrefix(path, prefix) {
|
if strings.HasPrefix(path, prefix) {
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package config
|
package model
|
||||||
|
|
||||||
// Default configuration
|
// Default configuration
|
||||||
func NewDefaultConfiguration() *Config {
|
func NewDefaultConfiguration() *Config {
|
||||||
@@ -29,7 +29,7 @@ func NewDefaultConfiguration() *Config {
|
|||||||
BackgroundImage: "/background.jpg",
|
BackgroundImage: "/background.jpg",
|
||||||
WarningsEnabled: true,
|
WarningsEnabled: true,
|
||||||
},
|
},
|
||||||
Ldap: LdapConfig{
|
LDAP: LDAPConfig{
|
||||||
Insecure: false,
|
Insecure: false,
|
||||||
SearchFilter: "(uid=%s)",
|
SearchFilter: "(uid=%s)",
|
||||||
GroupCacheTTL: 900, // 15 minutes
|
GroupCacheTTL: 900, // 15 minutes
|
||||||
@@ -63,20 +63,6 @@ func NewDefaultConfiguration() *Config {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Version information, set at build time
|
|
||||||
|
|
||||||
var Version = "development"
|
|
||||||
var CommitHash = "development"
|
|
||||||
var BuildTimestamp = "0000-00-00T00:00:00Z"
|
|
||||||
|
|
||||||
// Cookie name templates
|
|
||||||
|
|
||||||
var SessionCookieName = "tinyauth-session"
|
|
||||||
var CSRFCookieName = "tinyauth-csrf"
|
|
||||||
var RedirectCookieName = "tinyauth-redirect"
|
|
||||||
var OAuthSessionCookieName = "tinyauth-oauth"
|
|
||||||
|
|
||||||
// Main app config
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
AppURL string `description:"The base URL where the app is hosted." yaml:"appUrl"`
|
AppURL string `description:"The base URL where the app is hosted." yaml:"appUrl"`
|
||||||
Database DatabaseConfig `description:"Database configuration." yaml:"database"`
|
Database DatabaseConfig `description:"Database configuration." yaml:"database"`
|
||||||
@@ -88,7 +74,7 @@ type Config struct {
|
|||||||
OAuth OAuthConfig `description:"OAuth configuration." yaml:"oauth"`
|
OAuth OAuthConfig `description:"OAuth configuration." yaml:"oauth"`
|
||||||
OIDC OIDCConfig `description:"OIDC configuration." yaml:"oidc"`
|
OIDC OIDCConfig `description:"OIDC configuration." yaml:"oidc"`
|
||||||
UI UIConfig `description:"UI customization." yaml:"ui"`
|
UI UIConfig `description:"UI customization." yaml:"ui"`
|
||||||
Ldap LdapConfig `description:"LDAP configuration." yaml:"ldap"`
|
LDAP LDAPConfig `description:"LDAP configuration." yaml:"ldap"`
|
||||||
Experimental ExperimentalConfig `description:"Experimental features, use with caution." yaml:"experimental"`
|
Experimental ExperimentalConfig `description:"Experimental features, use with caution." yaml:"experimental"`
|
||||||
LabelProvider string `description:"Label provider to use for ACLs (auto, docker, or kubernetes). auto detects the environment." yaml:"labelProvider"`
|
LabelProvider string `description:"Label provider to use for ACLs (auto, docker, or kubernetes). auto detects the environment." yaml:"labelProvider"`
|
||||||
Log LogConfig `description:"Logging configuration." yaml:"log"`
|
Log LogConfig `description:"Logging configuration." yaml:"log"`
|
||||||
@@ -177,7 +163,7 @@ type UIConfig struct {
|
|||||||
WarningsEnabled bool `description:"Enable UI warnings." yaml:"warningsEnabled"`
|
WarningsEnabled bool `description:"Enable UI warnings." yaml:"warningsEnabled"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type LdapConfig struct {
|
type LDAPConfig struct {
|
||||||
Address string `description:"LDAP server address." yaml:"address"`
|
Address string `description:"LDAP server address." yaml:"address"`
|
||||||
BindDN string `description:"Bind DN for LDAP authentication." yaml:"bindDn"`
|
BindDN string `description:"Bind DN for LDAP authentication." yaml:"bindDn"`
|
||||||
BindPassword string `description:"Bind password for LDAP authentication." yaml:"bindPassword"`
|
BindPassword string `description:"Bind password for LDAP authentication." yaml:"bindPassword"`
|
||||||
@@ -210,20 +196,6 @@ type ExperimentalConfig struct {
|
|||||||
ConfigFile string `description:"Path to config file." yaml:"-"`
|
ConfigFile string `description:"Path to config file." yaml:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Config loader options
|
|
||||||
|
|
||||||
const DefaultNamePrefix = "TINYAUTH_"
|
|
||||||
|
|
||||||
// OAuth/OIDC config
|
|
||||||
|
|
||||||
type Claims struct {
|
|
||||||
Sub string `json:"sub"`
|
|
||||||
Name string `json:"name"`
|
|
||||||
Email string `json:"email"`
|
|
||||||
PreferredUsername string `json:"preferred_username"`
|
|
||||||
Groups any `json:"groups"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type OAuthServiceConfig struct {
|
type OAuthServiceConfig struct {
|
||||||
ClientID string `description:"OAuth client ID." yaml:"clientId"`
|
ClientID string `description:"OAuth client ID." yaml:"clientId"`
|
||||||
ClientSecret string `description:"OAuth client secret." yaml:"clientSecret"`
|
ClientSecret string `description:"OAuth client secret." yaml:"clientSecret"`
|
||||||
@@ -246,47 +218,6 @@ type OIDCClientConfig struct {
|
|||||||
Name string `description:"Client name in UI." yaml:"name"`
|
Name string `description:"Client name in UI." yaml:"name"`
|
||||||
}
|
}
|
||||||
|
|
||||||
var OverrideProviders = map[string]string{
|
|
||||||
"google": "Google",
|
|
||||||
"github": "GitHub",
|
|
||||||
}
|
|
||||||
|
|
||||||
// User/session related stuff
|
|
||||||
|
|
||||||
type User struct {
|
|
||||||
Username string
|
|
||||||
Password string
|
|
||||||
TotpSecret string
|
|
||||||
Attributes UserAttributes
|
|
||||||
}
|
|
||||||
|
|
||||||
type LdapUser struct {
|
|
||||||
DN string
|
|
||||||
Groups []string
|
|
||||||
}
|
|
||||||
|
|
||||||
type UserSearch struct {
|
|
||||||
Username string
|
|
||||||
Type string // local, ldap or unknown
|
|
||||||
}
|
|
||||||
|
|
||||||
type UserContext struct {
|
|
||||||
Username string
|
|
||||||
Name string
|
|
||||||
Email string
|
|
||||||
IsLoggedIn bool
|
|
||||||
IsBasicAuth bool
|
|
||||||
OAuth bool
|
|
||||||
Provider string
|
|
||||||
TotpPending bool
|
|
||||||
OAuthGroups string
|
|
||||||
TotpEnabled bool
|
|
||||||
OAuthName string
|
|
||||||
OAuthSub string
|
|
||||||
LdapGroups string
|
|
||||||
Attributes UserAttributes
|
|
||||||
}
|
|
||||||
|
|
||||||
// API responses and queries
|
// API responses and queries
|
||||||
|
|
||||||
type UnauthorizedQuery struct {
|
type UnauthorizedQuery struct {
|
||||||
@@ -355,7 +286,3 @@ type AppPath struct {
|
|||||||
Allow string `description:"Comma-separated list of allowed paths." yaml:"allow"`
|
Allow string `description:"Comma-separated list of allowed paths." yaml:"allow"`
|
||||||
Block string `description:"Comma-separated list of blocked paths." yaml:"block"`
|
Block string `description:"Comma-separated list of blocked paths." yaml:"block"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// API server
|
|
||||||
|
|
||||||
var ApiServer = "https://api.tinyauth.app"
|
|
||||||
@@ -0,0 +1,23 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
const DefaultNamePrefix = "TINYAUTH_"
|
||||||
|
|
||||||
|
const APIServer = "https://api.tinyauth.app"
|
||||||
|
|
||||||
|
type Claims struct {
|
||||||
|
Sub string `json:"sub"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
PreferredUsername string `json:"preferred_username"`
|
||||||
|
Groups any `json:"groups"`
|
||||||
|
}
|
||||||
|
|
||||||
|
var OverrideProviders = map[string]string{
|
||||||
|
"google": "Google",
|
||||||
|
"github": "GitHub",
|
||||||
|
}
|
||||||
|
|
||||||
|
const SessionCookieName = "tinyauth-session"
|
||||||
|
const CSRFCookieName = "tinyauth-csrf"
|
||||||
|
const RedirectCookieName = "tinyauth-redirect"
|
||||||
|
const OAuthSessionCookieName = "tinyauth-oauth"
|
||||||
@@ -0,0 +1,179 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ProviderType int
|
||||||
|
|
||||||
|
const (
|
||||||
|
ProviderLocal ProviderType = iota
|
||||||
|
ProviderBasicAuth
|
||||||
|
ProviderOAuth
|
||||||
|
ProviderLDAP
|
||||||
|
)
|
||||||
|
|
||||||
|
type UserContext struct {
|
||||||
|
Authenticated bool
|
||||||
|
Provider ProviderType
|
||||||
|
Local *LocalContext
|
||||||
|
OAuth *OAuthContext
|
||||||
|
LDAP *LDAPContext
|
||||||
|
}
|
||||||
|
|
||||||
|
type BaseContext struct {
|
||||||
|
Username string
|
||||||
|
Name string
|
||||||
|
Email string
|
||||||
|
}
|
||||||
|
|
||||||
|
type LocalContext struct {
|
||||||
|
BaseContext
|
||||||
|
TOTPPending bool
|
||||||
|
TOTPEnabled bool
|
||||||
|
Attributes UserAttributes
|
||||||
|
}
|
||||||
|
|
||||||
|
type OAuthContext struct {
|
||||||
|
BaseContext
|
||||||
|
Groups []string
|
||||||
|
Sub string
|
||||||
|
DisplayName string
|
||||||
|
ID string
|
||||||
|
}
|
||||||
|
|
||||||
|
type LDAPContext struct {
|
||||||
|
BaseContext
|
||||||
|
Groups []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *UserContext) IsAuthenticated() bool {
|
||||||
|
return c.Authenticated
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *UserContext) IsLocal() bool {
|
||||||
|
return c.Provider == ProviderLocal
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *UserContext) IsOAuth() bool {
|
||||||
|
return c.Provider == ProviderOAuth
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *UserContext) IsLDAP() bool {
|
||||||
|
return c.Provider == ProviderLDAP
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *UserContext) IsBasicAuth() bool {
|
||||||
|
return c.Provider == ProviderBasicAuth
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *UserContext) NewFromGin(ginctx *gin.Context) (*UserContext, error) {
|
||||||
|
userContextValue, exists := ginctx.Get("context")
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
return nil, errors.New("failed to get user context")
|
||||||
|
}
|
||||||
|
|
||||||
|
userContext, ok := userContextValue.(*UserContext)
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("invalid user context type")
|
||||||
|
}
|
||||||
|
|
||||||
|
*c = *userContext
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compatability layer until we get an excuse to drop in database migrations
|
||||||
|
func (c *UserContext) NewFromSession(session *repository.Session) (*UserContext, error) {
|
||||||
|
switch session.Provider {
|
||||||
|
case "local":
|
||||||
|
c.Provider = ProviderLocal
|
||||||
|
c.Local = &LocalContext{
|
||||||
|
BaseContext: BaseContext{
|
||||||
|
Username: session.Username,
|
||||||
|
Name: session.Name,
|
||||||
|
Email: session.Email,
|
||||||
|
},
|
||||||
|
TOTPPending: session.TotpPending,
|
||||||
|
}
|
||||||
|
case "ldap":
|
||||||
|
c.Provider = ProviderLDAP
|
||||||
|
c.LDAP = &LDAPContext{
|
||||||
|
BaseContext: BaseContext{
|
||||||
|
Username: session.Username,
|
||||||
|
Name: session.Name,
|
||||||
|
Email: session.Email,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
// By default we assume an unkown name which is oauth
|
||||||
|
default:
|
||||||
|
c.Provider = ProviderOAuth
|
||||||
|
c.OAuth = &OAuthContext{
|
||||||
|
BaseContext: BaseContext{
|
||||||
|
Username: session.Username,
|
||||||
|
Name: session.Name,
|
||||||
|
Email: session.Email,
|
||||||
|
},
|
||||||
|
Groups: strings.Split(session.OAuthGroups, ","),
|
||||||
|
Sub: session.OAuthSub,
|
||||||
|
DisplayName: session.OAuthName,
|
||||||
|
ID: session.Provider,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !session.TotpPending {
|
||||||
|
c.Authenticated = true
|
||||||
|
}
|
||||||
|
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *UserContext) GetUsername() string {
|
||||||
|
switch c.Provider {
|
||||||
|
case ProviderLocal:
|
||||||
|
return c.Local.Username
|
||||||
|
case ProviderLDAP:
|
||||||
|
return c.LDAP.Username
|
||||||
|
case ProviderBasicAuth:
|
||||||
|
return c.Local.Username
|
||||||
|
case ProviderOAuth:
|
||||||
|
return c.OAuth.Username
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *UserContext) GetEmail() string {
|
||||||
|
switch c.Provider {
|
||||||
|
case ProviderLocal:
|
||||||
|
return c.Local.Email
|
||||||
|
case ProviderLDAP:
|
||||||
|
return c.LDAP.Email
|
||||||
|
case ProviderBasicAuth:
|
||||||
|
return c.Local.Email
|
||||||
|
case ProviderOAuth:
|
||||||
|
return c.OAuth.Email
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *UserContext) GetName() string {
|
||||||
|
switch c.Provider {
|
||||||
|
case ProviderLocal:
|
||||||
|
return c.Local.Name
|
||||||
|
case ProviderLDAP:
|
||||||
|
return c.LDAP.Name
|
||||||
|
case ProviderBasicAuth:
|
||||||
|
return c.Local.Name
|
||||||
|
case ProviderOAuth:
|
||||||
|
return c.OAuth.Name
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,25 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
type UserSearchType int
|
||||||
|
|
||||||
|
const (
|
||||||
|
UserLocal UserSearchType = iota
|
||||||
|
UserLDAP
|
||||||
|
)
|
||||||
|
|
||||||
|
type LDAPUser struct {
|
||||||
|
DN string
|
||||||
|
Groups []string
|
||||||
|
}
|
||||||
|
|
||||||
|
type LocalUser struct {
|
||||||
|
Username string
|
||||||
|
Password string
|
||||||
|
TOTPSecret string
|
||||||
|
Attributes UserAttributes
|
||||||
|
}
|
||||||
|
|
||||||
|
type UserSearch struct {
|
||||||
|
Username string
|
||||||
|
Type UserSearchType
|
||||||
|
}
|
||||||
@@ -0,0 +1,5 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
var Version = "development"
|
||||||
|
var CommitHash = "development"
|
||||||
|
var BuildTimestamp = "0000-00-00T00:00:00Z"
|
||||||
@@ -4,20 +4,20 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
)
|
)
|
||||||
|
|
||||||
type LabelProvider interface {
|
type LabelProvider interface {
|
||||||
GetLabels(appDomain string) (config.App, error)
|
GetLabels(appDomain string) (*model.App, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type AccessControlsService struct {
|
type AccessControlsService struct {
|
||||||
labelProvider LabelProvider
|
labelProvider LabelProvider
|
||||||
static map[string]config.App
|
static map[string]model.App
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAccessControlsService(labelProvider LabelProvider, static map[string]config.App) *AccessControlsService {
|
func NewAccessControlsService(labelProvider LabelProvider, static map[string]model.App) *AccessControlsService {
|
||||||
return &AccessControlsService{
|
return &AccessControlsService{
|
||||||
labelProvider: labelProvider,
|
labelProvider: labelProvider,
|
||||||
static: static,
|
static: static,
|
||||||
@@ -28,22 +28,22 @@ func (acls *AccessControlsService) Init() error {
|
|||||||
return nil // No initialization needed
|
return nil // No initialization needed
|
||||||
}
|
}
|
||||||
|
|
||||||
func (acls *AccessControlsService) lookupStaticACLs(domain string) (config.App, error) {
|
func (acls *AccessControlsService) lookupStaticACLs(domain string) (*model.App, error) {
|
||||||
for app, config := range acls.static {
|
for app, config := range acls.static {
|
||||||
if config.Config.Domain == domain {
|
if config.Config.Domain == domain {
|
||||||
tlog.App.Debug().Str("name", app).Msg("Found matching container by domain")
|
tlog.App.Debug().Str("name", app).Msg("Found matching container by domain")
|
||||||
return config, nil
|
return &config, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.SplitN(domain, ".", 2)[0] == app {
|
if strings.SplitN(domain, ".", 2)[0] == app {
|
||||||
tlog.App.Debug().Str("name", app).Msg("Found matching container by app name")
|
tlog.App.Debug().Str("name", app).Msg("Found matching container by app name")
|
||||||
return config, nil
|
return &config, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return config.App{}, errors.New("no results")
|
return nil, errors.New("no results")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (acls *AccessControlsService) GetAccessControls(domain string) (config.App, error) {
|
func (acls *AccessControlsService) GetAccessControls(domain string) (*model.App, error) {
|
||||||
// First check in the static config
|
// First check in the static config
|
||||||
app, err := acls.lookupStaticACLs(domain)
|
app, err := acls.lookupStaticACLs(domain)
|
||||||
|
|
||||||
|
|||||||
+130
-138
@@ -5,12 +5,13 @@ import (
|
|||||||
"database/sql"
|
"database/sql"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
"regexp"
|
"regexp"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
@@ -68,7 +69,7 @@ type Lockdown struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type AuthServiceConfig struct {
|
type AuthServiceConfig struct {
|
||||||
Users []config.User
|
LocalUsers []model.LocalUser
|
||||||
OauthWhitelist []string
|
OauthWhitelist []string
|
||||||
SessionExpiry int
|
SessionExpiry int
|
||||||
SessionMaxLifetime int
|
SessionMaxLifetime int
|
||||||
@@ -77,7 +78,7 @@ type AuthServiceConfig struct {
|
|||||||
LoginTimeout int
|
LoginTimeout int
|
||||||
LoginMaxRetries int
|
LoginMaxRetries int
|
||||||
SessionCookieName string
|
SessionCookieName string
|
||||||
IP config.IPConfig
|
IP model.IPConfig
|
||||||
LDAPGroupsCacheTTL int
|
LDAPGroupsCacheTTL int
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -106,7 +107,7 @@ func NewAuthService(config AuthServiceConfig, ldap *LdapService, queries *reposi
|
|||||||
ldap: ldap,
|
ldap: ldap,
|
||||||
queries: queries,
|
queries: queries,
|
||||||
oauthBroker: oauthBroker,
|
oauthBroker: oauthBroker,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) Init() error {
|
func (auth *AuthService) Init() error {
|
||||||
@@ -114,79 +115,67 @@ func (auth *AuthService) Init() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) SearchUser(username string) config.UserSearch {
|
func (auth *AuthService) SearchUser(username string) (*model.UserSearch, error) {
|
||||||
if auth.GetLocalUser(username).Username != "" {
|
if auth.GetLocalUser(username).Username != "" {
|
||||||
return config.UserSearch{
|
return &model.UserSearch{
|
||||||
Username: username,
|
Username: username,
|
||||||
Type: "local",
|
Type: model.UserLocal,
|
||||||
}
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if auth.ldap.IsConfigured() {
|
if auth.ldap.IsConfigured() {
|
||||||
userDN, err := auth.ldap.GetUserDN(username)
|
userDN, err := auth.ldap.GetUserDN(username)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Warn().Err(err).Str("username", username).Msg("Failed to search for user in LDAP")
|
return nil, fmt.Errorf("failed to get ldap user: %w", err)
|
||||||
return config.UserSearch{
|
|
||||||
Type: "unknown",
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return config.UserSearch{
|
return &model.UserSearch{
|
||||||
Username: userDN,
|
Username: userDN,
|
||||||
Type: "ldap",
|
Type: model.UserLDAP,
|
||||||
}
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return config.UserSearch{
|
return nil, fmt.Errorf("user not found")
|
||||||
Type: "unknown",
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) VerifyUser(search config.UserSearch, password string) bool {
|
func (auth *AuthService) CheckUserPassword(search model.UserSearch, password string) error {
|
||||||
switch search.Type {
|
switch search.Type {
|
||||||
case "local":
|
case model.UserLocal:
|
||||||
user := auth.GetLocalUser(search.Username)
|
user := auth.GetLocalUser(search.Username)
|
||||||
return auth.CheckPassword(user, password)
|
return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password))
|
||||||
case "ldap":
|
case model.UserLDAP:
|
||||||
if auth.ldap.IsConfigured() {
|
if auth.ldap.IsConfigured() {
|
||||||
err := auth.ldap.Bind(search.Username, password)
|
err := auth.ldap.Bind(search.Username, password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Warn().Err(err).Str("username", search.Username).Msg("Failed to bind to LDAP")
|
return fmt.Errorf("failed to bind to ldap user: %w", err)
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = auth.ldap.BindService(true)
|
err = auth.ldap.BindService(true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Err(err).Msg("Failed to rebind with service account after user authentication")
|
return fmt.Errorf("failed to bind to ldap service account: %w", err)
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return true
|
return nil
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
tlog.App.Debug().Str("type", search.Type).Msg("Unknown user type for authentication")
|
return errors.New("unknown user search type")
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
return errors.New("user authentication failed")
|
||||||
tlog.App.Warn().Str("username", search.Username).Msg("User authentication failed")
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) GetLocalUser(username string) config.User {
|
func (auth *AuthService) GetLocalUser(username string) *model.LocalUser {
|
||||||
for _, user := range auth.config.Users {
|
for _, user := range auth.config.LocalUsers {
|
||||||
if user.Username == username {
|
if user.Username == username {
|
||||||
return user
|
return &user
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
tlog.App.Warn().Str("username", username).Msg("Local user not found")
|
|
||||||
return config.User{}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) GetLdapUser(userDN string) (config.LdapUser, error) {
|
func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) {
|
||||||
if !auth.ldap.IsConfigured() {
|
if !auth.ldap.IsConfigured() {
|
||||||
return config.LdapUser{}, errors.New("LDAP service not initialized")
|
return nil, errors.New("ldap service not configured")
|
||||||
}
|
}
|
||||||
|
|
||||||
auth.ldapGroupsMutex.RLock()
|
auth.ldapGroupsMutex.RLock()
|
||||||
@@ -194,7 +183,7 @@ func (auth *AuthService) GetLdapUser(userDN string) (config.LdapUser, error) {
|
|||||||
auth.ldapGroupsMutex.RUnlock()
|
auth.ldapGroupsMutex.RUnlock()
|
||||||
|
|
||||||
if exists && time.Now().Before(entry.Expires) {
|
if exists && time.Now().Before(entry.Expires) {
|
||||||
return config.LdapUser{
|
return &model.LDAPUser{
|
||||||
DN: userDN,
|
DN: userDN,
|
||||||
Groups: entry.Groups,
|
Groups: entry.Groups,
|
||||||
}, nil
|
}, nil
|
||||||
@@ -203,7 +192,7 @@ func (auth *AuthService) GetLdapUser(userDN string) (config.LdapUser, error) {
|
|||||||
groups, err := auth.ldap.GetUserGroups(userDN)
|
groups, err := auth.ldap.GetUserGroups(userDN)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return config.LdapUser{}, err
|
return nil, fmt.Errorf("failed to get ldap groups: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
auth.ldapGroupsMutex.Lock()
|
auth.ldapGroupsMutex.Lock()
|
||||||
@@ -213,16 +202,12 @@ func (auth *AuthService) GetLdapUser(userDN string) (config.LdapUser, error) {
|
|||||||
}
|
}
|
||||||
auth.ldapGroupsMutex.Unlock()
|
auth.ldapGroupsMutex.Unlock()
|
||||||
|
|
||||||
return config.LdapUser{
|
return &model.LDAPUser{
|
||||||
DN: userDN,
|
DN: userDN,
|
||||||
Groups: groups,
|
Groups: groups,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) CheckPassword(user config.User, password string) bool {
|
|
||||||
return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) == nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
|
func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
|
||||||
auth.loginMutex.RLock()
|
auth.loginMutex.RLock()
|
||||||
defer auth.loginMutex.RUnlock()
|
defer auth.loginMutex.RUnlock()
|
||||||
@@ -291,11 +276,11 @@ func (auth *AuthService) IsEmailWhitelisted(email string) bool {
|
|||||||
return utils.CheckFilter(strings.Join(auth.config.OauthWhitelist, ","), email)
|
return utils.CheckFilter(strings.Join(auth.config.OauthWhitelist, ","), email)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *repository.Session) error {
|
func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) {
|
||||||
uuid, err := uuid.NewRandom()
|
uuid, err := uuid.NewRandom()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, fmt.Errorf("failed to generate session uuid: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var expiry int
|
var expiry int
|
||||||
@@ -320,28 +305,30 @@ func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *repository.Se
|
|||||||
OAuthSub: data.OAuthSub,
|
OAuthSub: data.OAuthSub,
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = auth.queries.CreateSession(c, session)
|
_, err = auth.queries.CreateSession(ctx, session)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, fmt.Errorf("failed to create session entry: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.SetCookie(auth.config.SessionCookieName, session.UUID, expiry, "/", fmt.Sprintf(".%s", auth.config.CookieDomain), auth.config.SecureCookie, true)
|
return &http.Cookie{
|
||||||
|
Name: auth.config.SessionCookieName,
|
||||||
return nil
|
Value: session.UUID,
|
||||||
|
Path: "/",
|
||||||
|
Domain: fmt.Sprintf(".%s", auth.config.CookieDomain),
|
||||||
|
Expires: time.Now().Add(time.Duration(expiry) * time.Second),
|
||||||
|
MaxAge: expiry,
|
||||||
|
Secure: auth.config.SecureCookie,
|
||||||
|
HttpOnly: true,
|
||||||
|
SameSite: http.SameSiteLaxMode,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) RefreshSessionCookie(c *gin.Context) error {
|
func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http.Cookie, error) {
|
||||||
cookie, err := c.Cookie(auth.config.SessionCookieName)
|
session, err := auth.queries.GetSession(ctx, uuid)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, fmt.Errorf("failed to retrieve session: %w", err)
|
||||||
}
|
|
||||||
|
|
||||||
session, err := auth.queries.GetSession(c, cookie)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
currentTime := time.Now().Unix()
|
currentTime := time.Now().Unix()
|
||||||
@@ -355,12 +342,12 @@ func (auth *AuthService) RefreshSessionCookie(c *gin.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if session.Expiry-currentTime > refreshThreshold {
|
if session.Expiry-currentTime > refreshThreshold {
|
||||||
return nil
|
return nil, fmt.Errorf("session not eligible for refresh yet")
|
||||||
}
|
}
|
||||||
|
|
||||||
newExpiry := session.Expiry + refreshThreshold
|
newExpiry := session.Expiry + refreshThreshold
|
||||||
|
|
||||||
_, err = auth.queries.UpdateSession(c, repository.UpdateSessionParams{
|
_, err = auth.queries.UpdateSession(ctx, repository.UpdateSessionParams{
|
||||||
Username: session.Username,
|
Username: session.Username,
|
||||||
Email: session.Email,
|
Email: session.Email,
|
||||||
Name: session.Name,
|
Name: session.Name,
|
||||||
@@ -374,120 +361,117 @@ func (auth *AuthService) RefreshSessionCookie(c *gin.Context) error {
|
|||||||
})
|
})
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, fmt.Errorf("failed to update session expiry: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.SetCookie(auth.config.SessionCookieName, cookie, int(newExpiry-currentTime), "/", fmt.Sprintf(".%s", auth.config.CookieDomain), auth.config.SecureCookie, true)
|
return &http.Cookie{
|
||||||
tlog.App.Trace().Str("username", session.Username).Msg("Session cookie refreshed")
|
Name: auth.config.SessionCookieName,
|
||||||
|
Value: session.UUID,
|
||||||
|
Path: "/",
|
||||||
|
Domain: fmt.Sprintf(".%s", auth.config.CookieDomain),
|
||||||
|
Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second),
|
||||||
|
MaxAge: auth.config.SessionExpiry,
|
||||||
|
Secure: auth.config.SecureCookie,
|
||||||
|
HttpOnly: true,
|
||||||
|
SameSite: http.SameSiteLaxMode,
|
||||||
|
}, nil
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) DeleteSessionCookie(c *gin.Context) error {
|
func (auth *AuthService) DeleteSession(ctx context.Context, uuid string) (*http.Cookie, error) {
|
||||||
cookie, err := c.Cookie(auth.config.SessionCookieName)
|
err := auth.queries.DeleteSession(ctx, uuid)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
tlog.App.Warn().Err(err).Msg("Failed to delete session from database, proceeding to clear cookie anyway")
|
||||||
}
|
}
|
||||||
|
|
||||||
err = auth.queries.DeleteSession(c, cookie)
|
return &http.Cookie{
|
||||||
|
Name: auth.config.SessionCookieName,
|
||||||
if err != nil {
|
Value: "",
|
||||||
return err
|
Path: "/",
|
||||||
}
|
Domain: fmt.Sprintf(".%s", auth.config.CookieDomain),
|
||||||
|
Expires: time.Now(),
|
||||||
c.SetCookie(auth.config.SessionCookieName, "", -1, "/", fmt.Sprintf(".%s", auth.config.CookieDomain), auth.config.SecureCookie, true)
|
MaxAge: -1,
|
||||||
|
Secure: auth.config.SecureCookie,
|
||||||
return nil
|
HttpOnly: true,
|
||||||
|
SameSite: http.SameSiteLaxMode,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) GetSessionCookie(c *gin.Context) (repository.Session, error) {
|
func (auth *AuthService) GetSession(ctx context.Context, uuid string) (*repository.Session, error) {
|
||||||
cookie, err := c.Cookie(auth.config.SessionCookieName)
|
session, err := auth.queries.GetSession(ctx, uuid)
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return repository.Session{}, err
|
|
||||||
}
|
|
||||||
|
|
||||||
session, err := auth.queries.GetSession(c, cookie)
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
return repository.Session{}, fmt.Errorf("session not found")
|
return nil, errors.New("session not found")
|
||||||
}
|
}
|
||||||
return repository.Session{}, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
currentTime := time.Now().Unix()
|
currentTime := time.Now().Unix()
|
||||||
|
|
||||||
if auth.config.SessionMaxLifetime != 0 && session.CreatedAt != 0 {
|
if auth.config.SessionMaxLifetime != 0 && session.CreatedAt != 0 {
|
||||||
if currentTime-session.CreatedAt > int64(auth.config.SessionMaxLifetime) {
|
if currentTime-session.CreatedAt > int64(auth.config.SessionMaxLifetime) {
|
||||||
err = auth.queries.DeleteSession(c, cookie)
|
err = auth.queries.DeleteSession(ctx, uuid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Err(err).Msg("Failed to delete session exceeding max lifetime")
|
return nil, fmt.Errorf("failed to delete expired session: %w", err)
|
||||||
}
|
}
|
||||||
return repository.Session{}, fmt.Errorf("session expired due to max lifetime exceeded")
|
return nil, fmt.Errorf("session max lifetime exceeded")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if currentTime > session.Expiry {
|
if currentTime > session.Expiry {
|
||||||
err = auth.queries.DeleteSession(c, cookie)
|
err = auth.queries.DeleteSession(ctx, uuid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Err(err).Msg("Failed to delete expired session")
|
return nil, fmt.Errorf("failed to delete expired session: %w", err)
|
||||||
}
|
}
|
||||||
return repository.Session{}, fmt.Errorf("session expired")
|
return nil, fmt.Errorf("session expired")
|
||||||
}
|
}
|
||||||
|
|
||||||
return repository.Session{
|
return &session, nil
|
||||||
UUID: session.UUID,
|
|
||||||
Username: session.Username,
|
|
||||||
Email: session.Email,
|
|
||||||
Name: session.Name,
|
|
||||||
Provider: session.Provider,
|
|
||||||
TotpPending: session.TotpPending,
|
|
||||||
OAuthGroups: session.OAuthGroups,
|
|
||||||
OAuthName: session.OAuthName,
|
|
||||||
OAuthSub: session.OAuthSub,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) LocalAuthConfigured() bool {
|
func (auth *AuthService) LocalAuthConfigured() bool {
|
||||||
return len(auth.config.Users) > 0
|
return len(auth.config.LocalUsers) > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) LdapAuthConfigured() bool {
|
func (auth *AuthService) LDAPAuthConfigured() bool {
|
||||||
return auth.ldap.IsConfigured()
|
return auth.ldap.IsConfigured()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) IsUserAllowed(c *gin.Context, context config.UserContext, acls config.App) bool {
|
func (auth *AuthService) IsUserAllowed(c *gin.Context, context model.UserContext, acls model.App) bool {
|
||||||
if context.OAuth {
|
if context.Provider == model.ProviderOAuth {
|
||||||
tlog.App.Debug().Msg("Checking OAuth whitelist")
|
tlog.App.Debug().Msg("Checking OAuth whitelist")
|
||||||
return utils.CheckFilter(acls.OAuth.Whitelist, context.Email)
|
return utils.CheckFilter(acls.OAuth.Whitelist, context.OAuth.Email)
|
||||||
}
|
}
|
||||||
|
|
||||||
if acls.Users.Block != "" {
|
if acls.Users.Block != "" {
|
||||||
tlog.App.Debug().Msg("Checking blocked users")
|
tlog.App.Debug().Msg("Checking blocked users")
|
||||||
if utils.CheckFilter(acls.Users.Block, context.Username) {
|
if utils.CheckFilter(acls.Users.Block, context.GetUsername()) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tlog.App.Debug().Msg("Checking users")
|
tlog.App.Debug().Msg("Checking users")
|
||||||
return utils.CheckFilter(acls.Users.Allow, context.Username)
|
return utils.CheckFilter(acls.Users.Allow, context.GetUsername())
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context config.UserContext, requiredGroups string) bool {
|
func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context model.UserContext, requiredGroups string) bool {
|
||||||
if requiredGroups == "" {
|
if requiredGroups == "" {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
for id := range config.OverrideProviders {
|
if !context.IsOAuth() {
|
||||||
if context.Provider == id {
|
tlog.App.Debug().Msg("User is not an OAuth user, skipping OAuth group check")
|
||||||
tlog.App.Info().Str("provider", id).Msg("OAuth groups not supported for this provider")
|
return false
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for userGroup := range strings.SplitSeq(context.OAuthGroups, ",") {
|
if _, ok := model.OverrideProviders[context.OAuth.ID]; ok {
|
||||||
|
tlog.App.Debug().Msg("Provider override for OAuth groups enabled, skipping group check")
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, userGroup := range context.OAuth.Groups {
|
||||||
if utils.CheckFilter(requiredGroups, strings.TrimSpace(userGroup)) {
|
if utils.CheckFilter(requiredGroups, strings.TrimSpace(userGroup)) {
|
||||||
tlog.App.Trace().Str("group", userGroup).Str("required", requiredGroups).Msg("User group matched")
|
tlog.App.Trace().Str("group", userGroup).Str("required", requiredGroups).Msg("User group matched")
|
||||||
return true
|
return true
|
||||||
@@ -498,12 +482,17 @@ func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context config.UserConte
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) IsInLdapGroup(c *gin.Context, context config.UserContext, requiredGroups string) bool {
|
func (auth *AuthService) IsInLDAPGroup(c *gin.Context, context model.UserContext, requiredGroups string) bool {
|
||||||
if requiredGroups == "" {
|
if requiredGroups == "" {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
for userGroup := range strings.SplitSeq(context.LdapGroups, ",") {
|
if !context.IsLDAP() {
|
||||||
|
tlog.App.Debug().Msg("User is not an LDAP user, skipping LDAP group check")
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, userGroup := range context.LDAP.Groups {
|
||||||
if utils.CheckFilter(requiredGroups, strings.TrimSpace(userGroup)) {
|
if utils.CheckFilter(requiredGroups, strings.TrimSpace(userGroup)) {
|
||||||
tlog.App.Trace().Str("group", userGroup).Str("required", requiredGroups).Msg("User group matched")
|
tlog.App.Trace().Str("group", userGroup).Str("required", requiredGroups).Msg("User group matched")
|
||||||
return true
|
return true
|
||||||
@@ -514,7 +503,7 @@ func (auth *AuthService) IsInLdapGroup(c *gin.Context, context config.UserContex
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) IsAuthEnabled(uri string, path config.AppPath) (bool, error) {
|
func (auth *AuthService) IsAuthEnabled(uri string, path model.AppPath) (bool, error) {
|
||||||
// Check for block list
|
// Check for block list
|
||||||
if path.Block != "" {
|
if path.Block != "" {
|
||||||
regex, err := regexp.Compile(path.Block)
|
regex, err := regexp.Compile(path.Block)
|
||||||
@@ -544,19 +533,22 @@ func (auth *AuthService) IsAuthEnabled(uri string, path config.AppPath) (bool, e
|
|||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) GetBasicAuth(c *gin.Context) *config.User {
|
// local user is used only as a medium to pass the basic auth credentials, user can be ldap too
|
||||||
username, password, ok := c.Request.BasicAuth()
|
func (auth *AuthService) GetBasicAuth(req *http.Request) (*model.LocalUser, error) {
|
||||||
if !ok {
|
if req == nil {
|
||||||
tlog.App.Debug().Msg("No basic auth provided")
|
return nil, errors.New("request is nil")
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
return &config.User{
|
username, password, ok := req.BasicAuth()
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("no basic auth credentials provided")
|
||||||
|
}
|
||||||
|
return &model.LocalUser{
|
||||||
Username: username,
|
Username: username,
|
||||||
Password: password,
|
Password: password,
|
||||||
}
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) CheckIP(acls config.AppIP, ip string) bool {
|
func (auth *AuthService) CheckIP(acls model.AppIP, ip string) bool {
|
||||||
// Merge the global and app IP filter
|
// Merge the global and app IP filter
|
||||||
blockedIps := append(auth.config.IP.Block, acls.Block...)
|
blockedIps := append(auth.config.IP.Block, acls.Block...)
|
||||||
allowedIPs := append(auth.config.IP.Allow, acls.Allow...)
|
allowedIPs := append(auth.config.IP.Allow, acls.Allow...)
|
||||||
@@ -594,7 +586,7 @@ func (auth *AuthService) CheckIP(acls config.AppIP, ip string) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) IsBypassedIP(acls config.AppIP, ip string) bool {
|
func (auth *AuthService) IsBypassedIP(acls model.AppIP, ip string) bool {
|
||||||
for _, bypassed := range acls.Bypass {
|
for _, bypassed := range acls.Bypass {
|
||||||
res, err := utils.FilterIP(bypassed, ip)
|
res, err := utils.FilterIP(bypassed, ip)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -674,21 +666,21 @@ func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.T
|
|||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) GetOAuthUserinfo(sessionId string) (config.Claims, error) {
|
func (auth *AuthService) GetOAuthUserinfo(sessionId string) (*model.Claims, error) {
|
||||||
session, err := auth.GetOAuthPendingSession(sessionId)
|
session, err := auth.GetOAuthPendingSession(sessionId)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return config.Claims{}, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if session.Token == nil {
|
if session.Token == nil {
|
||||||
return config.Claims{}, fmt.Errorf("oauth token not found for session: %s", sessionId)
|
return nil, fmt.Errorf("oauth token not found for session: %s", sessionId)
|
||||||
}
|
}
|
||||||
|
|
||||||
userinfo, err := (*session.Service).GetUserinfo(session.Token)
|
userinfo, err := (*session.Service).GetUserinfo(session.Token)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return config.Claims{}, fmt.Errorf("failed to get userinfo: %w", err)
|
return nil, fmt.Errorf("failed to get userinfo: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return userinfo, nil
|
return userinfo, nil
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
|
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
|
|
||||||
@@ -66,41 +66,41 @@ func (docker *DockerService) inspectContainer(containerId string) (container.Ins
|
|||||||
return inspect, nil
|
return inspect, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (docker *DockerService) GetLabels(appDomain string) (config.App, error) {
|
func (docker *DockerService) GetLabels(appDomain string) (*model.App, error) {
|
||||||
if !docker.isConnected {
|
if !docker.isConnected {
|
||||||
tlog.App.Debug().Msg("Docker not connected, returning empty labels")
|
tlog.App.Debug().Msg("Docker not connected, returning empty labels")
|
||||||
return config.App{}, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
containers, err := docker.getContainers()
|
containers, err := docker.getContainers()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return config.App{}, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, ctr := range containers {
|
for _, ctr := range containers {
|
||||||
inspect, err := docker.inspectContainer(ctr.ID)
|
inspect, err := docker.inspectContainer(ctr.ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return config.App{}, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
labels, err := decoders.DecodeLabels[config.Apps](inspect.Config.Labels, "apps")
|
labels, err := decoders.DecodeLabels[model.Apps](inspect.Config.Labels, "apps")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return config.App{}, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
for appName, appLabels := range labels.Apps {
|
for appName, appLabels := range labels.Apps {
|
||||||
if appLabels.Config.Domain == appDomain {
|
if appLabels.Config.Domain == appDomain {
|
||||||
tlog.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by domain")
|
tlog.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by domain")
|
||||||
return appLabels, nil
|
return &appLabels, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.SplitN(appDomain, ".", 2)[0] == appName {
|
if strings.SplitN(appDomain, ".", 2)[0] == appName {
|
||||||
tlog.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by app name")
|
tlog.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by app name")
|
||||||
return appLabels, nil
|
return &appLabels, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
tlog.App.Debug().Msg("No matching container found, returning empty labels")
|
tlog.App.Debug().Msg("No matching container found, returning empty labels")
|
||||||
return config.App{}, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
|
"github.com/tinyauthapp/tinyauth/internal/utils/decoders"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
|
|
||||||
@@ -32,7 +32,7 @@ type ingressAppKey struct {
|
|||||||
type ingressApp struct {
|
type ingressApp struct {
|
||||||
domain string
|
domain string
|
||||||
appName string
|
appName string
|
||||||
app config.App
|
app model.App
|
||||||
}
|
}
|
||||||
|
|
||||||
type KubernetesService struct {
|
type KubernetesService struct {
|
||||||
@@ -89,7 +89,7 @@ func (k *KubernetesService) removeIngress(namespace, name string) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k *KubernetesService) getByDomain(domain string) (config.App, bool) {
|
func (k *KubernetesService) getByDomain(domain string) (*model.App, bool) {
|
||||||
k.mu.RLock()
|
k.mu.RLock()
|
||||||
defer k.mu.RUnlock()
|
defer k.mu.RUnlock()
|
||||||
|
|
||||||
@@ -97,15 +97,15 @@ func (k *KubernetesService) getByDomain(domain string) (config.App, bool) {
|
|||||||
if apps, ok := k.ingressApps[appKey.ingressKey]; ok {
|
if apps, ok := k.ingressApps[appKey.ingressKey]; ok {
|
||||||
for _, app := range apps {
|
for _, app := range apps {
|
||||||
if app.domain == domain && app.appName == appKey.appName {
|
if app.domain == domain && app.appName == appKey.appName {
|
||||||
return app.app, true
|
return &app.app, true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return config.App{}, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k *KubernetesService) getByAppName(appName string) (config.App, bool) {
|
func (k *KubernetesService) getByAppName(appName string) (*model.App, bool) {
|
||||||
k.mu.RLock()
|
k.mu.RLock()
|
||||||
defer k.mu.RUnlock()
|
defer k.mu.RUnlock()
|
||||||
|
|
||||||
@@ -113,12 +113,12 @@ func (k *KubernetesService) getByAppName(appName string) (config.App, bool) {
|
|||||||
if apps, ok := k.ingressApps[appKey.ingressKey]; ok {
|
if apps, ok := k.ingressApps[appKey.ingressKey]; ok {
|
||||||
for _, app := range apps {
|
for _, app := range apps {
|
||||||
if app.appName == appName {
|
if app.appName == appName {
|
||||||
return app.app, true
|
return &app.app, true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return config.App{}, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k *KubernetesService) updateFromItem(item *unstructured.Unstructured) {
|
func (k *KubernetesService) updateFromItem(item *unstructured.Unstructured) {
|
||||||
@@ -129,7 +129,7 @@ func (k *KubernetesService) updateFromItem(item *unstructured.Unstructured) {
|
|||||||
k.removeIngress(namespace, name)
|
k.removeIngress(namespace, name)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
labels, err := decoders.DecodeLabels[config.Apps](annotations, "apps")
|
labels, err := decoders.DecodeLabels[model.Apps](annotations, "apps")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Debug().Err(err).Msg("Failed to decode labels from annotations")
|
tlog.App.Debug().Err(err).Msg("Failed to decode labels from annotations")
|
||||||
k.removeIngress(namespace, name)
|
k.removeIngress(namespace, name)
|
||||||
@@ -280,10 +280,10 @@ func (k *KubernetesService) Init() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (k *KubernetesService) GetLabels(appDomain string) (config.App, error) {
|
func (k *KubernetesService) GetLabels(appDomain string) (*model.App, error) {
|
||||||
if !k.started {
|
if !k.started {
|
||||||
tlog.App.Debug().Msg("Kubernetes not connected, returning empty labels")
|
tlog.App.Debug().Msg("Kubernetes not connected, returning empty labels")
|
||||||
return config.App{}, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// First check cache
|
// First check cache
|
||||||
@@ -298,6 +298,5 @@ func (k *KubernetesService) GetLabels(appDomain string) (config.App, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
tlog.App.Debug().Str("domain", appDomain).Msg("Cache miss, no matching ingress found")
|
tlog.App.Debug().Str("domain", appDomain).Msg("Cache miss, no matching ingress found")
|
||||||
return config.App{}, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
|
|
||||||
"slices"
|
"slices"
|
||||||
@@ -15,20 +15,20 @@ type OAuthServiceImpl interface {
|
|||||||
NewRandom() string
|
NewRandom() string
|
||||||
GetAuthURL(state string, verifier string) string
|
GetAuthURL(state string, verifier string) string
|
||||||
GetToken(code string, verifier string) (*oauth2.Token, error)
|
GetToken(code string, verifier string) (*oauth2.Token, error)
|
||||||
GetUserinfo(token *oauth2.Token) (config.Claims, error)
|
GetUserinfo(token *oauth2.Token) (*model.Claims, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type OAuthBrokerService struct {
|
type OAuthBrokerService struct {
|
||||||
services map[string]OAuthServiceImpl
|
services map[string]OAuthServiceImpl
|
||||||
configs map[string]config.OAuthServiceConfig
|
configs map[string]model.OAuthServiceConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
var presets = map[string]func(config config.OAuthServiceConfig) *OAuthService{
|
var presets = map[string]func(config model.OAuthServiceConfig) *OAuthService{
|
||||||
"github": newGitHubOAuthService,
|
"github": newGitHubOAuthService,
|
||||||
"google": newGoogleOAuthService,
|
"google": newGoogleOAuthService,
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewOAuthBrokerService(configs map[string]config.OAuthServiceConfig) *OAuthBrokerService {
|
func NewOAuthBrokerService(configs map[string]model.OAuthServiceConfig) *OAuthBrokerService {
|
||||||
return &OAuthBrokerService{
|
return &OAuthBrokerService{
|
||||||
services: make(map[string]OAuthServiceImpl),
|
services: make(map[string]OAuthServiceImpl),
|
||||||
configs: configs,
|
configs: configs,
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
type GithubEmailResponse []struct {
|
type GithubEmailResponse []struct {
|
||||||
@@ -22,32 +22,32 @@ type GithubUserInfoResponse struct {
|
|||||||
ID int `json:"id"`
|
ID int `json:"id"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func defaultExtractor(client *http.Client, url string) (config.Claims, error) {
|
func defaultExtractor(client *http.Client, url string) (*model.Claims, error) {
|
||||||
return simpleReq[config.Claims](client, url, nil)
|
return simpleReq[model.Claims](client, url, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func githubExtractor(client *http.Client, url string) (config.Claims, error) {
|
func githubExtractor(client *http.Client, url string) (*model.Claims, error) {
|
||||||
var user config.Claims
|
var user model.Claims
|
||||||
|
|
||||||
userInfo, err := simpleReq[GithubUserInfoResponse](client, "https://api.github.com/user", map[string]string{
|
userInfo, err := simpleReq[GithubUserInfoResponse](client, "https://api.github.com/user", map[string]string{
|
||||||
"accept": "application/vnd.github+json",
|
"accept": "application/vnd.github+json",
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return config.Claims{}, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
userEmails, err := simpleReq[GithubEmailResponse](client, "https://api.github.com/user/emails", map[string]string{
|
userEmails, err := simpleReq[GithubEmailResponse](client, "https://api.github.com/user/emails", map[string]string{
|
||||||
"accept": "application/vnd.github+json",
|
"accept": "application/vnd.github+json",
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return config.Claims{}, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(userEmails) == 0 {
|
if len(*userEmails) == 0 {
|
||||||
return user, errors.New("no emails found")
|
return nil, errors.New("no emails found")
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, email := range userEmails {
|
for _, email := range *userEmails {
|
||||||
if email.Primary {
|
if email.Primary {
|
||||||
user.Email = email.Email
|
user.Email = email.Email
|
||||||
break
|
break
|
||||||
@@ -56,22 +56,22 @@ func githubExtractor(client *http.Client, url string) (config.Claims, error) {
|
|||||||
|
|
||||||
// Use first available email if no primary email was found
|
// Use first available email if no primary email was found
|
||||||
if user.Email == "" {
|
if user.Email == "" {
|
||||||
user.Email = userEmails[0].Email
|
user.Email = (*userEmails)[0].Email
|
||||||
}
|
}
|
||||||
|
|
||||||
user.PreferredUsername = userInfo.Login
|
user.PreferredUsername = userInfo.Login
|
||||||
user.Name = userInfo.Name
|
user.Name = userInfo.Name
|
||||||
user.Sub = strconv.Itoa(userInfo.ID)
|
user.Sub = strconv.Itoa(userInfo.ID)
|
||||||
|
|
||||||
return user, nil
|
return &user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func simpleReq[T any](client *http.Client, url string, headers map[string]string) (T, error) {
|
func simpleReq[T any](client *http.Client, url string, headers map[string]string) (*T, error) {
|
||||||
var decodedRes T
|
var decodedRes T
|
||||||
|
|
||||||
req, err := http.NewRequest("GET", url, nil)
|
req, err := http.NewRequest("GET", url, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return decodedRes, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
for key, value := range headers {
|
for key, value := range headers {
|
||||||
@@ -80,23 +80,23 @@ func simpleReq[T any](client *http.Client, url string, headers map[string]string
|
|||||||
|
|
||||||
res, err := client.Do(req)
|
res, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return decodedRes, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer res.Body.Close()
|
defer res.Body.Close()
|
||||||
|
|
||||||
if res.StatusCode < 200 || res.StatusCode >= 300 {
|
if res.StatusCode < 200 || res.StatusCode >= 300 {
|
||||||
return decodedRes, fmt.Errorf("request failed with status: %s", res.Status)
|
return nil, fmt.Errorf("request failed with status: %s", res.Status)
|
||||||
}
|
}
|
||||||
|
|
||||||
body, err := io.ReadAll(res.Body)
|
body, err := io.ReadAll(res.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return decodedRes, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = json.Unmarshal(body, &decodedRes)
|
err = json.Unmarshal(body, &decodedRes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return decodedRes, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return decodedRes, nil
|
return &decodedRes, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
package service
|
package service
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"golang.org/x/oauth2/endpoints"
|
"golang.org/x/oauth2/endpoints"
|
||||||
)
|
)
|
||||||
|
|
||||||
func newGoogleOAuthService(config config.OAuthServiceConfig) *OAuthService {
|
func newGoogleOAuthService(config model.OAuthServiceConfig) *OAuthService {
|
||||||
scopes := []string{"openid", "email", "profile"}
|
scopes := []string{"openid", "email", "profile"}
|
||||||
config.Scopes = scopes
|
config.Scopes = scopes
|
||||||
config.AuthURL = endpoints.Google.AuthURL
|
config.AuthURL = endpoints.Google.AuthURL
|
||||||
@@ -14,7 +14,7 @@ func newGoogleOAuthService(config config.OAuthServiceConfig) *OAuthService {
|
|||||||
return NewOAuthService(config, "google")
|
return NewOAuthService(config, "google")
|
||||||
}
|
}
|
||||||
|
|
||||||
func newGitHubOAuthService(config config.OAuthServiceConfig) *OAuthService {
|
func newGitHubOAuthService(config model.OAuthServiceConfig) *OAuthService {
|
||||||
scopes := []string{"read:user", "user:email"}
|
scopes := []string{"read:user", "user:email"}
|
||||||
config.Scopes = scopes
|
config.Scopes = scopes
|
||||||
config.AuthURL = endpoints.GitHub.AuthURL
|
config.AuthURL = endpoints.GitHub.AuthURL
|
||||||
|
|||||||
@@ -6,21 +6,21 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
)
|
)
|
||||||
|
|
||||||
type UserinfoExtractor func(client *http.Client, url string) (config.Claims, error)
|
type UserinfoExtractor func(client *http.Client, url string) (*model.Claims, error)
|
||||||
|
|
||||||
type OAuthService struct {
|
type OAuthService struct {
|
||||||
serviceCfg config.OAuthServiceConfig
|
serviceCfg model.OAuthServiceConfig
|
||||||
config *oauth2.Config
|
config *oauth2.Config
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
userinfoExtractor UserinfoExtractor
|
userinfoExtractor UserinfoExtractor
|
||||||
id string
|
id string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewOAuthService(config config.OAuthServiceConfig, id string) *OAuthService {
|
func NewOAuthService(config model.OAuthServiceConfig, id string) *OAuthService {
|
||||||
httpClient := &http.Client{
|
httpClient := &http.Client{
|
||||||
Timeout: 30 * time.Second,
|
Timeout: 30 * time.Second,
|
||||||
Transport: &http.Transport{
|
Transport: &http.Transport{
|
||||||
@@ -78,7 +78,7 @@ func (s *OAuthService) GetToken(code string, verifier string) (*oauth2.Token, er
|
|||||||
return s.config.Exchange(s.ctx, code, oauth2.VerifierOption(verifier))
|
return s.config.Exchange(s.ctx, code, oauth2.VerifierOption(verifier))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *OAuthService) GetUserinfo(token *oauth2.Token) (config.Claims, error) {
|
func (s *OAuthService) GetUserinfo(token *oauth2.Token) (*model.Claims, error) {
|
||||||
client := oauth2.NewClient(s.ctx, oauth2.StaticTokenSource(token))
|
client := oauth2.NewClient(s.ctx, oauth2.StaticTokenSource(token))
|
||||||
return s.userinfoExtractor(client, s.serviceCfg.UserinfoURL)
|
return s.userinfoExtractor(client, s.serviceCfg.UserinfoURL)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ import (
|
|||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/go-jose/go-jose/v4"
|
"github.com/go-jose/go-jose/v4"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
@@ -68,27 +68,27 @@ type ClaimSet struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type UserinfoResponse struct {
|
type UserinfoResponse struct {
|
||||||
Sub string `json:"sub"`
|
Sub string `json:"sub"`
|
||||||
Name string `json:"name,omitempty"`
|
Name string `json:"name,omitempty"`
|
||||||
GivenName string `json:"given_name,omitempty"`
|
GivenName string `json:"given_name,omitempty"`
|
||||||
FamilyName string `json:"family_name,omitempty"`
|
FamilyName string `json:"family_name,omitempty"`
|
||||||
MiddleName string `json:"middle_name,omitempty"`
|
MiddleName string `json:"middle_name,omitempty"`
|
||||||
Nickname string `json:"nickname,omitempty"`
|
Nickname string `json:"nickname,omitempty"`
|
||||||
Profile string `json:"profile,omitempty"`
|
Profile string `json:"profile,omitempty"`
|
||||||
Picture string `json:"picture,omitempty"`
|
Picture string `json:"picture,omitempty"`
|
||||||
Website string `json:"website,omitempty"`
|
Website string `json:"website,omitempty"`
|
||||||
Gender string `json:"gender,omitempty"`
|
Gender string `json:"gender,omitempty"`
|
||||||
Birthdate string `json:"birthdate,omitempty"`
|
Birthdate string `json:"birthdate,omitempty"`
|
||||||
Zoneinfo string `json:"zoneinfo,omitempty"`
|
Zoneinfo string `json:"zoneinfo,omitempty"`
|
||||||
Locale string `json:"locale,omitempty"`
|
Locale string `json:"locale,omitempty"`
|
||||||
Email string `json:"email,omitempty"`
|
Email string `json:"email,omitempty"`
|
||||||
PreferredUsername string `json:"preferred_username,omitempty"`
|
PreferredUsername string `json:"preferred_username,omitempty"`
|
||||||
Groups []string `json:"groups,omitempty"`
|
Groups []string `json:"groups,omitempty"`
|
||||||
EmailVerified bool `json:"email_verified,omitempty"`
|
EmailVerified bool `json:"email_verified,omitempty"`
|
||||||
PhoneNumber string `json:"phone_number,omitempty"`
|
PhoneNumber string `json:"phone_number,omitempty"`
|
||||||
PhoneNumberVerified *bool `json:"phone_number_verified,omitempty"`
|
PhoneNumberVerified *bool `json:"phone_number_verified,omitempty"`
|
||||||
Address *config.AddressClaim `json:"address,omitempty"`
|
Address *model.AddressClaim `json:"address,omitempty"`
|
||||||
UpdatedAt int64 `json:"updated_at"`
|
UpdatedAt int64 `json:"updated_at"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type TokenResponse struct {
|
type TokenResponse struct {
|
||||||
@@ -112,7 +112,7 @@ type AuthorizeRequest struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type OIDCServiceConfig struct {
|
type OIDCServiceConfig struct {
|
||||||
Clients map[string]config.OIDCClientConfig
|
Clients map[string]model.OIDCClientConfig
|
||||||
PrivateKeyPath string
|
PrivateKeyPath string
|
||||||
PublicKeyPath string
|
PublicKeyPath string
|
||||||
Issuer string
|
Issuer string
|
||||||
@@ -122,7 +122,7 @@ type OIDCServiceConfig struct {
|
|||||||
type OIDCService struct {
|
type OIDCService struct {
|
||||||
config OIDCServiceConfig
|
config OIDCServiceConfig
|
||||||
queries *repository.Queries
|
queries *repository.Queries
|
||||||
clients map[string]config.OIDCClientConfig
|
clients map[string]model.OIDCClientConfig
|
||||||
privateKey *rsa.PrivateKey
|
privateKey *rsa.PrivateKey
|
||||||
publicKey crypto.PublicKey
|
publicKey crypto.PublicKey
|
||||||
issuer string
|
issuer string
|
||||||
@@ -255,7 +255,7 @@ func (service *OIDCService) Init() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// We will reorganize the client into a map with the client ID as the key
|
// We will reorganize the client into a map with the client ID as the key
|
||||||
service.clients = make(map[string]config.OIDCClientConfig)
|
service.clients = make(map[string]model.OIDCClientConfig)
|
||||||
|
|
||||||
for id, client := range service.config.Clients {
|
for id, client := range service.config.Clients {
|
||||||
client.ID = id
|
client.ID = id
|
||||||
@@ -283,7 +283,7 @@ func (service *OIDCService) GetIssuer() string {
|
|||||||
return service.issuer
|
return service.issuer
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service *OIDCService) GetClient(id string) (config.OIDCClientConfig, bool) {
|
func (service *OIDCService) GetClient(id string) (model.OIDCClientConfig, bool) {
|
||||||
client, ok := service.clients[id]
|
client, ok := service.clients[id]
|
||||||
return client, ok
|
return client, ok
|
||||||
}
|
}
|
||||||
@@ -367,43 +367,45 @@ func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, r
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service *OIDCService) StoreUserinfo(c *gin.Context, sub string, userContext config.UserContext, req AuthorizeRequest) error {
|
func (service *OIDCService) StoreUserinfo(c *gin.Context, sub string, userContext model.UserContext, req AuthorizeRequest) error {
|
||||||
addressJSON, err := json.Marshal(userContext.Attributes.Address)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
userInfoParams := repository.CreateOidcUserInfoParams{
|
userInfoParams := repository.CreateOidcUserInfoParams{
|
||||||
Sub: sub,
|
Sub: sub,
|
||||||
Name: userContext.Name,
|
Name: userContext.GetName(),
|
||||||
Email: userContext.Email,
|
Email: userContext.GetEmail(),
|
||||||
PreferredUsername: userContext.Username,
|
PreferredUsername: userContext.GetUsername(),
|
||||||
UpdatedAt: time.Now().Unix(),
|
UpdatedAt: time.Now().Unix(),
|
||||||
GivenName: userContext.Attributes.GivenName,
|
}
|
||||||
FamilyName: userContext.Attributes.FamilyName,
|
|
||||||
MiddleName: userContext.Attributes.MiddleName,
|
if userContext.IsLocal() {
|
||||||
Nickname: userContext.Attributes.Nickname,
|
addressJSON, err := json.Marshal(userContext.Local.Attributes.Address)
|
||||||
Profile: userContext.Attributes.Profile,
|
if err != nil {
|
||||||
Picture: userContext.Attributes.Picture,
|
return err
|
||||||
Website: userContext.Attributes.Website,
|
}
|
||||||
Gender: userContext.Attributes.Gender,
|
userInfoParams.GivenName = userContext.Local.Attributes.GivenName
|
||||||
Birthdate: userContext.Attributes.Birthdate,
|
userInfoParams.FamilyName = userContext.Local.Attributes.FamilyName
|
||||||
Zoneinfo: userContext.Attributes.Zoneinfo,
|
userInfoParams.MiddleName = userContext.Local.Attributes.MiddleName
|
||||||
Locale: userContext.Attributes.Locale,
|
userInfoParams.Nickname = userContext.Local.Attributes.Nickname
|
||||||
PhoneNumber: userContext.Attributes.PhoneNumber,
|
userInfoParams.Profile = userContext.Local.Attributes.Profile
|
||||||
Address: string(addressJSON),
|
userInfoParams.Picture = userContext.Local.Attributes.Picture
|
||||||
|
userInfoParams.Website = userContext.Local.Attributes.Website
|
||||||
|
userInfoParams.Gender = userContext.Local.Attributes.Gender
|
||||||
|
userInfoParams.Birthdate = userContext.Local.Attributes.Birthdate
|
||||||
|
userInfoParams.Zoneinfo = userContext.Local.Attributes.Zoneinfo
|
||||||
|
userInfoParams.Locale = userContext.Local.Attributes.Locale
|
||||||
|
userInfoParams.PhoneNumber = userContext.Local.Attributes.PhoneNumber
|
||||||
|
userInfoParams.Address = string(addressJSON)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tinyauth will pass through the groups it got from an LDAP or an OIDC server
|
// Tinyauth will pass through the groups it got from an LDAP or an OIDC server
|
||||||
if userContext.Provider == "ldap" {
|
if userContext.IsLDAP() {
|
||||||
userInfoParams.Groups = userContext.LdapGroups
|
userInfoParams.Groups = strings.Join(userContext.LDAP.Groups, ",")
|
||||||
}
|
}
|
||||||
|
|
||||||
if userContext.OAuth && len(userContext.OAuthGroups) > 0 {
|
if userContext.IsOAuth() {
|
||||||
userInfoParams.Groups = userContext.OAuthGroups
|
userInfoParams.Groups = strings.Join(userContext.OAuth.Groups, ",")
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = service.queries.CreateOidcUserInfo(c, userInfoParams)
|
_, err := service.queries.CreateOidcUserInfo(c, userInfoParams)
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -445,7 +447,7 @@ func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string, client
|
|||||||
return oidcCode, nil
|
return oidcCode, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service *OIDCService) generateIDToken(client config.OIDCClientConfig, user repository.OidcUserinfo, scope string, nonce string) (string, error) {
|
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user repository.OidcUserinfo, scope string, nonce string) (string, error) {
|
||||||
createdAt := time.Now().Unix()
|
createdAt := time.Now().Unix()
|
||||||
expiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix()
|
expiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix()
|
||||||
|
|
||||||
@@ -511,7 +513,7 @@ func (service *OIDCService) generateIDToken(client config.OIDCClientConfig, user
|
|||||||
return token, nil
|
return token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (service *OIDCService) GenerateAccessToken(c *gin.Context, client config.OIDCClientConfig, codeEntry repository.OidcCode) (TokenResponse, error) {
|
func (service *OIDCService) GenerateAccessToken(c *gin.Context, client model.OIDCClientConfig, codeEntry repository.OidcCode) (TokenResponse, error) {
|
||||||
user, err := service.GetUserinfo(c, codeEntry.Sub)
|
user, err := service.GetUserinfo(c, codeEntry.Sub)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -585,7 +587,7 @@ func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken stri
|
|||||||
return TokenResponse{}, err
|
return TokenResponse{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
idToken, err := service.generateIDToken(config.OIDCClientConfig{
|
idToken, err := service.generateIDToken(model.OIDCClientConfig{
|
||||||
ClientID: entry.ClientID,
|
ClientID: entry.ClientID,
|
||||||
}, user, entry.Scope, entry.Nonce)
|
}, user, entry.Scope, entry.Nonce)
|
||||||
|
|
||||||
@@ -714,7 +716,7 @@ func (service *OIDCService) CompileUserinfo(user repository.OidcUserinfo, scope
|
|||||||
}
|
}
|
||||||
|
|
||||||
if slices.Contains(scopes, "address") {
|
if slices.Contains(scopes, "address") {
|
||||||
var addr config.AddressClaim
|
var addr model.AddressClaim
|
||||||
if err := json.Unmarshal([]byte(user.Address), &addr); err == nil {
|
if err := json.Unmarshal([]byte(user.Address), &addr); err == nil {
|
||||||
userInfo.Address = &addr
|
userInfo.Address = &addr
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,10 +7,8 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"github.com/weppos/publicsuffix-go/publicsuffix"
|
"github.com/weppos/publicsuffix-go/publicsuffix"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -73,22 +71,6 @@ func Filter[T any](slice []T, test func(T) bool) (res []T) {
|
|||||||
return res
|
return res
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetContext(c *gin.Context) (config.UserContext, error) {
|
|
||||||
userContextValue, exists := c.Get("context")
|
|
||||||
|
|
||||||
if !exists {
|
|
||||||
return config.UserContext{}, errors.New("no user context in request")
|
|
||||||
}
|
|
||||||
|
|
||||||
userContext, ok := userContextValue.(*config.UserContext)
|
|
||||||
|
|
||||||
if !ok {
|
|
||||||
return config.UserContext{}, errors.New("invalid user context in request")
|
|
||||||
}
|
|
||||||
|
|
||||||
return *userContext, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func IsRedirectSafe(redirectURL string, domain string) bool {
|
func IsRedirectSafe(redirectURL string, domain string) bool {
|
||||||
if redirectURL == "" {
|
if redirectURL == "" {
|
||||||
return false
|
return false
|
||||||
|
|||||||
@@ -3,10 +3,8 @@ package utils_test
|
|||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
|
||||||
"gotest.tools/v3/assert"
|
"gotest.tools/v3/assert"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -129,28 +127,6 @@ func TestFilter(t *testing.T) {
|
|||||||
assert.DeepEqual(t, expectedStr, resultStr)
|
assert.DeepEqual(t, expectedStr, resultStr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetContext(t *testing.T) {
|
|
||||||
// Setup
|
|
||||||
gin.SetMode(gin.TestMode)
|
|
||||||
c, _ := gin.CreateTestContext(nil)
|
|
||||||
|
|
||||||
// Normal case
|
|
||||||
c.Set("context", &config.UserContext{Username: "testuser"})
|
|
||||||
result, err := utils.GetContext(c)
|
|
||||||
assert.NilError(t, err)
|
|
||||||
assert.Equal(t, "testuser", result.Username)
|
|
||||||
|
|
||||||
// Case with no context
|
|
||||||
c.Set("context", nil)
|
|
||||||
_, err = utils.GetContext(c)
|
|
||||||
assert.Error(t, err, "invalid user context in request")
|
|
||||||
|
|
||||||
// Case with invalid context type
|
|
||||||
c.Set("context", "invalid type")
|
|
||||||
_, err = utils.GetContext(c)
|
|
||||||
assert.Error(t, err, "invalid user context in request")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestIsRedirectSafe(t *testing.T) {
|
func TestIsRedirectSafe(t *testing.T) {
|
||||||
// Setup
|
// Setup
|
||||||
domain := "example.com"
|
domain := "example.com"
|
||||||
|
|||||||
Reference in New Issue
Block a user