diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go deleted file mode 100644 index 1ab7329..0000000 --- a/internal/auth/auth_test.go +++ /dev/null @@ -1,146 +0,0 @@ -package auth_test - -import ( - "testing" - "time" - "tinyauth/internal/auth" - "tinyauth/internal/types" -) - -var config = types.AuthConfig{ - Users: types.Users{}, - OauthWhitelist: "", - SessionExpiry: 3600, -} - -func TestLoginRateLimiting(t *testing.T) { - // Initialize a new auth service with 3 max retries and 5 seconds timeout - config.LoginMaxRetries = 3 - config.LoginTimeout = 5 - authService := auth.NewAuth(config, nil, nil) - - // Test identifier - identifier := "test_user" - - // Test successful login - should not lock account - t.Log("Testing successful login") - - authService.RecordLoginAttempt(identifier, true) - locked, _ := authService.IsAccountLocked(identifier) - - if locked { - t.Fatalf("Account should not be locked after successful login") - } - - // Test 2 failed attempts - should not lock account yet - t.Log("Testing 2 failed login attempts") - - authService.RecordLoginAttempt(identifier, false) - authService.RecordLoginAttempt(identifier, false) - locked, _ = authService.IsAccountLocked(identifier) - - if locked { - t.Fatalf("Account should not be locked after only 2 failed attempts") - } - - // Add one more failed attempt (total 3) - should lock account with maxRetries=3 - t.Log("Testing 3 failed login attempts") - authService.RecordLoginAttempt(identifier, false) - locked, remainingTime := authService.IsAccountLocked(identifier) - - if !locked { - t.Fatalf("Account should be locked after reaching max retries") - } - if remainingTime <= 0 || remainingTime > 5 { - t.Fatalf("Expected remaining time between 1-5 seconds, got %d", remainingTime) - } - - // Test reset after waiting for timeout - use 1 second timeout for fast testing - t.Log("Testing unlocking after timeout") - - // Reinitialize auth service with a shorter timeout for testing - config.LoginTimeout = 1 - config.LoginMaxRetries = 3 - authService = auth.NewAuth(config, nil, nil) - - // Add enough failed attempts to lock the account - for i := 0; i < 3; i++ { - authService.RecordLoginAttempt(identifier, false) - } - - // Verify it's locked - locked, _ = authService.IsAccountLocked(identifier) - if !locked { - t.Fatalf("Account should be locked initially") - } - - // Wait a bit and verify it gets unlocked after timeout - time.Sleep(1500 * time.Millisecond) // Wait longer than the timeout - locked, _ = authService.IsAccountLocked(identifier) - - if locked { - t.Fatalf("Account should be unlocked after timeout period") - } - - // Test disabled rate limiting - t.Log("Testing disabled rate limiting") - config.LoginMaxRetries = 0 - config.LoginTimeout = 0 - authService = auth.NewAuth(config, nil, nil) - - for i := 0; i < 10; i++ { - authService.RecordLoginAttempt(identifier, false) - } - - locked, _ = authService.IsAccountLocked(identifier) - if locked { - t.Fatalf("Account should not be locked when rate limiting is disabled") - } -} - -func TestConcurrentLoginAttempts(t *testing.T) { - // Initialize a new auth service with 2 max retries and 5 seconds timeout - config.LoginMaxRetries = 2 - config.LoginTimeout = 5 - authService := auth.NewAuth(config, nil, nil) - - // Test multiple identifiers - identifiers := []string{"user1", "user2", "user3"} - - // Test that locking one identifier doesn't affect others - t.Log("Testing multiple identifiers") - - // Add enough failed attempts to lock first user (2 attempts with maxRetries=2) - authService.RecordLoginAttempt(identifiers[0], false) - authService.RecordLoginAttempt(identifiers[0], false) - - // Check if first user is locked - locked, _ := authService.IsAccountLocked(identifiers[0]) - if !locked { - t.Fatalf("User1 should be locked after reaching max retries") - } - - // Check that other users are not affected - for i := 1; i < len(identifiers); i++ { - locked, _ := authService.IsAccountLocked(identifiers[i]) - if locked { - t.Fatalf("User%d should not be locked", i+1) - } - } - - // Test successful login after failed attempts (but before lock) - t.Log("Testing successful login after failed attempts but before lock") - - // One failed attempt for user2 - authService.RecordLoginAttempt(identifiers[1], false) - - // Successful login should reset the counter - authService.RecordLoginAttempt(identifiers[1], true) - - // Now try a failed login again - should not be locked as counter was reset - authService.RecordLoginAttempt(identifiers[1], false) - locked, _ = authService.IsAccountLocked(identifiers[1]) - if locked { - t.Fatalf("User2 should not be locked after successful login reset") - } -} diff --git a/internal/types/config.go b/internal/config/config.go similarity index 67% rename from internal/types/config.go rename to internal/config/config.go index dfb9e98..48dc47f 100644 --- a/internal/types/config.go +++ b/internal/config/config.go @@ -1,6 +1,22 @@ -package types +package config + +import "time" + +type Claims struct { + Name string `json:"name"` + Email string `json:"email"` + PreferredUsername string `json:"preferred_username"` + Groups any `json:"groups"` +} + +var Version = "development" +var CommitHash = "n/a" +var BuildTimestamp = "n/a" + +var SessionCookieName = "tinyauth-session" +var CsrfCookieName = "tinyauth-csrf" +var RedirectCookieName = "tinyauth-redirect" -// Config is the configuration for the tinyauth server type Config struct { Port int `mapstructure:"port" validate:"required"` Address string `validate:"required,ip4_addr" mapstructure:"address"` @@ -44,62 +60,27 @@ type Config struct { LdapSearchFilter string `mapstructure:"ldap-search-filter"` } -// OAuthConfig is the configuration for the providers -type OAuthConfig struct { - GithubClientId string - GithubClientSecret string - GoogleClientId string - GoogleClientSecret string - GenericClientId string - GenericClientSecret string - GenericScopes []string - GenericAuthURL string - GenericTokenURL string - GenericUserURL string - GenericSkipSSL bool - AppURL string -} - -// AuthConfig is the configuration for the auth service -type AuthConfig struct { - Users Users - OauthWhitelist string - SessionExpiry int - CookieSecure bool - Domain string - LoginTimeout int - LoginMaxRetries int - SessionCookieName string - HMACSecret string - EncryptionSecret string -} - -// OAuthLabels is a list of labels that can be used in a tinyauth protected container type OAuthLabels struct { Whitelist string Groups string } -// Basic auth labels for a tinyauth protected container type BasicLabels struct { Username string Password PassowrdLabels } -// PassowrdLabels is a struct that contains the password labels for a tinyauth protected container type PassowrdLabels struct { Plain string File string } -// IP labels for a tinyauth protected container type IPLabels struct { Allow []string Block []string Bypass []string } -// Labels is a struct that contains the labels for a tinyauth protected container type Labels struct { Users string Allowed string @@ -110,12 +91,65 @@ type Labels struct { IP IPLabels } -// Ldap config is a struct that contains the configuration for the LDAP service -type LdapConfig struct { - Address string - BindDN string - BindPassword string - BaseDN string - Insecure bool - SearchFilter string +type OAuthServiceConfig struct { + ClientID string + ClientSecret string + Scopes []string + RedirectURL string + AuthURL string + TokenURL string + UserinfoURL string + InsecureSkipVerify bool + Name string +} + +type User struct { + Username string + Password string + TotpSecret string +} + +type UserSearch struct { + Username string + Type string // local, ldap or unknown +} + +type Users []User + +type SessionCookie struct { + Username string + Name string + Email string + Provider string + TotpPending bool + OAuthGroups string +} + +type UserContext struct { + Username string + Name string + Email string + IsLoggedIn bool + OAuth bool + Provider string + TotpPending bool + OAuthGroups string + TotpEnabled bool +} + +type LoginAttempt struct { + FailedAttempts int + LastAttempt time.Time + LockedUntil time.Time +} + +type UnauthorizedQuery struct { + Username string `url:"username"` + Resource string `url:"resource"` + GroupErr bool `url:"groupErr"` + IP string `url:"ip"` +} + +type RedirectQuery struct { + RedirectURI string `url:"redirect_uri"` } diff --git a/internal/constants/constants.go b/internal/constants/constants.go deleted file mode 100644 index d6f64fa..0000000 --- a/internal/constants/constants.go +++ /dev/null @@ -1,19 +0,0 @@ -package constants - -// Claims are the OIDC supported claims (prefered username is included for convinience) -type Claims struct { - Name string `json:"name"` - Email string `json:"email"` - PreferredUsername string `json:"preferred_username"` - Groups any `json:"groups"` -} - -// Version information -var Version = "development" -var CommitHash = "n/a" -var BuildTimestamp = "n/a" - -// Base cookie names -var SessionCookieName = "tinyauth-session" -var CsrfCookieName = "tinyauth-csrf" -var RedirectCookieName = "tinyauth-redirect" diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go deleted file mode 100644 index 9529fce..0000000 --- a/internal/oauth/oauth.go +++ /dev/null @@ -1,71 +0,0 @@ -package oauth - -import ( - "context" - "crypto/rand" - "crypto/tls" - "encoding/base64" - "net/http" - - "golang.org/x/oauth2" -) - -type OAuth struct { - Config oauth2.Config - Context context.Context - Token *oauth2.Token - Verifier string -} - -func NewOAuth(config oauth2.Config, insecureSkipVerify bool) *OAuth { - transport := &http.Transport{ - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: insecureSkipVerify, - MinVersion: tls.VersionTLS12, - }, - } - - httpClient := &http.Client{ - Transport: transport, - } - - ctx := context.Background() - - // Set the HTTP client in the context - ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) - - verifier := oauth2.GenerateVerifier() - - return &OAuth{ - Config: config, - Context: ctx, - Verifier: verifier, - } -} - -func (oauth *OAuth) GetAuthURL(state string) string { - return oauth.Config.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(oauth.Verifier)) -} - -func (oauth *OAuth) ExchangeToken(code string) (string, error) { - token, err := oauth.Config.Exchange(oauth.Context, code, oauth2.VerifierOption(oauth.Verifier)) - - if err != nil { - return "", err - } - - // Set and return the token - oauth.Token = token - return oauth.Token.AccessToken, nil -} - -func (oauth *OAuth) GetClient() *http.Client { - return oauth.Config.Client(oauth.Context, oauth.Token) -} - -func (oauth *OAuth) GenerateState() string { - b := make([]byte, 128) - rand.Read(b) - state := base64.URLEncoding.EncodeToString(b) - return state -} diff --git a/internal/providers/generic.go b/internal/providers/generic.go deleted file mode 100644 index 200f7c4..0000000 --- a/internal/providers/generic.go +++ /dev/null @@ -1,37 +0,0 @@ -package providers - -import ( - "encoding/json" - "io" - "net/http" - "tinyauth/internal/constants" - - "github.com/rs/zerolog/log" -) - -func GetGenericUser(client *http.Client, url string) (constants.Claims, error) { - var user constants.Claims - - res, err := client.Get(url) - if err != nil { - return user, err - } - defer res.Body.Close() - - log.Debug().Msg("Got response from generic provider") - - body, err := io.ReadAll(res.Body) - if err != nil { - return user, err - } - - log.Debug().Msg("Read body from generic provider") - - err = json.Unmarshal(body, &user) - if err != nil { - return user, err - } - - log.Debug().Msg("Parsed user from generic provider") - return user, nil -} diff --git a/internal/providers/github.go b/internal/providers/github.go deleted file mode 100644 index 67f8510..0000000 --- a/internal/providers/github.go +++ /dev/null @@ -1,102 +0,0 @@ -package providers - -import ( - "encoding/json" - "errors" - "io" - "net/http" - "tinyauth/internal/constants" - - "github.com/rs/zerolog/log" -) - -// Response for the github email endpoint -type GithubEmailResponse []struct { - Email string `json:"email"` - Primary bool `json:"primary"` -} - -// Response for the github user endpoint -type GithubUserInfoResponse struct { - Login string `json:"login"` - Name string `json:"name"` -} - -// The scopes required for the github provider -func GithubScopes() []string { - return []string{"user:email", "read:user"} -} - -func GetGithubUser(client *http.Client) (constants.Claims, error) { - var user constants.Claims - - res, err := client.Get("https://api.github.com/user") - if err != nil { - return user, err - } - defer res.Body.Close() - - log.Debug().Msg("Got user response from github") - - body, err := io.ReadAll(res.Body) - if err != nil { - return user, err - } - - log.Debug().Msg("Read user body from github") - - var userInfo GithubUserInfoResponse - - err = json.Unmarshal(body, &userInfo) - if err != nil { - return user, err - } - - res, err = client.Get("https://api.github.com/user/emails") - if err != nil { - return user, err - } - defer res.Body.Close() - - log.Debug().Msg("Got email response from github") - - body, err = io.ReadAll(res.Body) - if err != nil { - return user, err - } - - log.Debug().Msg("Read email body from github") - - var emails GithubEmailResponse - - err = json.Unmarshal(body, &emails) - if err != nil { - return user, err - } - - log.Debug().Msg("Parsed emails from github") - - // Find and return the primary email - for _, email := range emails { - if email.Primary { - log.Debug().Str("email", email.Email).Msg("Found primary email") - user.Email = email.Email - break - } - } - - if len(emails) == 0 { - return user, errors.New("no emails found") - } - - // Use first available email if no primary email was found - if user.Email == "" { - log.Warn().Str("email", emails[0].Email).Msg("No primary email found, using first email") - user.Email = emails[0].Email - } - - user.PreferredUsername = userInfo.Login - user.Name = userInfo.Name - - return user, nil -} diff --git a/internal/providers/google.go b/internal/providers/google.go deleted file mode 100644 index e794bee..0000000 --- a/internal/providers/google.go +++ /dev/null @@ -1,56 +0,0 @@ -package providers - -import ( - "encoding/json" - "io" - "net/http" - "strings" - "tinyauth/internal/constants" - - "github.com/rs/zerolog/log" -) - -// Response for the google user endpoint -type GoogleUserInfoResponse struct { - Email string `json:"email"` - Name string `json:"name"` -} - -// The scopes required for the google provider -func GoogleScopes() []string { - return []string{"https://www.googleapis.com/auth/userinfo.email", "https://www.googleapis.com/auth/userinfo.profile"} -} - -func GetGoogleUser(client *http.Client) (constants.Claims, error) { - var user constants.Claims - - res, err := client.Get("https://www.googleapis.com/userinfo/v2/me") - if err != nil { - return user, err - } - defer res.Body.Close() - - log.Debug().Msg("Got response from google") - - body, err := io.ReadAll(res.Body) - if err != nil { - return user, err - } - - log.Debug().Msg("Read body from google") - - var userInfo GoogleUserInfoResponse - - err = json.Unmarshal(body, &userInfo) - if err != nil { - return user, err - } - - log.Debug().Msg("Parsed user from google") - - user.PreferredUsername = strings.Split(userInfo.Email, "@")[0] - user.Name = userInfo.Name - user.Email = userInfo.Email - - return user, nil -} diff --git a/internal/providers/providers.go b/internal/providers/providers.go deleted file mode 100644 index 7af127e..0000000 --- a/internal/providers/providers.go +++ /dev/null @@ -1,154 +0,0 @@ -package providers - -import ( - "fmt" - "tinyauth/internal/constants" - "tinyauth/internal/oauth" - "tinyauth/internal/types" - - "github.com/rs/zerolog/log" - "golang.org/x/oauth2" - "golang.org/x/oauth2/endpoints" -) - -type Providers struct { - Config types.OAuthConfig - Github *oauth.OAuth - Google *oauth.OAuth - Generic *oauth.OAuth -} - -func NewProviders(config types.OAuthConfig) *Providers { - providers := &Providers{ - Config: config, - } - - if config.GithubClientId != "" && config.GithubClientSecret != "" { - log.Info().Msg("Initializing Github OAuth") - providers.Github = oauth.NewOAuth(oauth2.Config{ - ClientID: config.GithubClientId, - ClientSecret: config.GithubClientSecret, - RedirectURL: fmt.Sprintf("%s/api/oauth/callback/github", config.AppURL), - Scopes: GithubScopes(), - Endpoint: endpoints.GitHub, - }, false) - } - - if config.GoogleClientId != "" && config.GoogleClientSecret != "" { - log.Info().Msg("Initializing Google OAuth") - providers.Google = oauth.NewOAuth(oauth2.Config{ - ClientID: config.GoogleClientId, - ClientSecret: config.GoogleClientSecret, - RedirectURL: fmt.Sprintf("%s/api/oauth/callback/google", config.AppURL), - Scopes: GoogleScopes(), - Endpoint: endpoints.Google, - }, false) - } - - if config.GenericClientId != "" && config.GenericClientSecret != "" { - log.Info().Msg("Initializing Generic OAuth") - providers.Generic = oauth.NewOAuth(oauth2.Config{ - ClientID: config.GenericClientId, - ClientSecret: config.GenericClientSecret, - RedirectURL: fmt.Sprintf("%s/api/oauth/callback/generic", config.AppURL), - Scopes: config.GenericScopes, - Endpoint: oauth2.Endpoint{ - AuthURL: config.GenericAuthURL, - TokenURL: config.GenericTokenURL, - }, - }, config.GenericSkipSSL) - } - - return providers -} - -func (providers *Providers) GetProvider(provider string) *oauth.OAuth { - switch provider { - case "github": - return providers.Github - case "google": - return providers.Google - case "generic": - return providers.Generic - default: - return nil - } -} - -func (providers *Providers) GetUser(provider string) (constants.Claims, error) { - var user constants.Claims - - // Get the user from the provider - switch provider { - case "github": - if providers.Github == nil { - log.Debug().Msg("Github provider not configured") - return user, nil - } - - client := providers.Github.GetClient() - - log.Debug().Msg("Got client from github") - - user, err := GetGithubUser(client) - if err != nil { - return user, err - } - - log.Debug().Msg("Got user from github") - - return user, nil - case "google": - if providers.Google == nil { - log.Debug().Msg("Google provider not configured") - return user, nil - } - - client := providers.Google.GetClient() - - log.Debug().Msg("Got client from google") - - user, err := GetGoogleUser(client) - if err != nil { - return user, err - } - - log.Debug().Msg("Got user from google") - - return user, nil - case "generic": - if providers.Generic == nil { - log.Debug().Msg("Generic provider not configured") - return user, nil - } - - client := providers.Generic.GetClient() - - log.Debug().Msg("Got client from generic") - - user, err := GetGenericUser(client, providers.Config.GenericUserURL) - if err != nil { - return user, err - } - - log.Debug().Msg("Got user from generic") - - return user, nil - default: - return user, nil - } -} - -func (provider *Providers) GetConfiguredProviders() []string { - providers := []string{} - if provider.Github != nil { - providers = append(providers, "github") - } - if provider.Google != nil { - providers = append(providers, "google") - } - if provider.Generic != nil { - providers = append(providers, "generic") - } - return providers -} diff --git a/internal/auth/auth.go b/internal/service/auth_service.go similarity index 83% rename from internal/auth/auth.go rename to internal/service/auth_service.go index 3f18419..ebbd1ad 100644 --- a/internal/auth/auth.go +++ b/internal/service/auth_service.go @@ -1,4 +1,4 @@ -package auth +package service import ( "fmt" @@ -6,8 +6,6 @@ import ( "strings" "sync" "time" - "tinyauth/internal/docker" - "tinyauth/internal/ldap" "tinyauth/internal/types" "tinyauth/internal/utils" @@ -17,35 +15,50 @@ import ( "golang.org/x/crypto/bcrypt" ) -type Auth struct { - Config types.AuthConfig - Docker *docker.Docker +type AuthServiceConfig struct { + Users types.Users + OauthWhitelist string + SessionExpiry int + CookieSecure bool + Domain string + LoginTimeout int + LoginMaxRetries int + SessionCookieName string + HMACSecret string + EncryptionSecret string +} + +type AuthService struct { + Config AuthServiceConfig + Docker *DockerService LoginAttempts map[string]*types.LoginAttempt LoginMutex sync.RWMutex Store *sessions.CookieStore - LDAP *ldap.LDAP + LDAP *LdapService } -func NewAuth(config types.AuthConfig, docker *docker.Docker, ldap *ldap.LDAP) *Auth { - // Setup cookie store and create the auth service - store := sessions.NewCookieStore([]byte(config.HMACSecret), []byte(config.EncryptionSecret)) - store.Options = &sessions.Options{ - Path: "/", - MaxAge: config.SessionExpiry, - Secure: config.CookieSecure, - HttpOnly: true, - Domain: fmt.Sprintf(".%s", config.Domain), - } - return &Auth{ +func NewAuthService(config AuthServiceConfig, docker *DockerService, ldap *LdapService) *AuthService { + return &AuthService{ Config: config, Docker: docker, LoginAttempts: make(map[string]*types.LoginAttempt), - Store: store, LDAP: ldap, } } -func (auth *Auth) GetSession(c *gin.Context) (*sessions.Session, error) { +func (auth *AuthService) Init() error { + store := sessions.NewCookieStore([]byte(auth.Config.HMACSecret), []byte(auth.Config.EncryptionSecret)) + store.Options = &sessions.Options{ + Path: "/", + MaxAge: auth.Config.SessionExpiry, + Secure: auth.Config.CookieSecure, + HttpOnly: true, + Domain: fmt.Sprintf(".%s", auth.Config.Domain), + } + return nil +} + +func (auth *AuthService) GetSession(c *gin.Context) (*sessions.Session, error) { session, err := auth.Store.Get(c.Request, auth.Config.SessionCookieName) // If there was an error getting the session, it might be invalid so let's clear it and retry @@ -62,7 +75,7 @@ func (auth *Auth) GetSession(c *gin.Context) (*sessions.Session, error) { return session, nil } -func (auth *Auth) SearchUser(username string) types.UserSearch { +func (auth *AuthService) SearchUser(username string) types.UserSearch { log.Debug().Str("username", username).Msg("Searching for user") // Check local users first @@ -93,7 +106,7 @@ func (auth *Auth) SearchUser(username string) types.UserSearch { } } -func (auth *Auth) VerifyUser(search types.UserSearch, password string) bool { +func (auth *AuthService) VerifyUser(search types.UserSearch, password string) bool { // Authenticate the user based on the type switch search.Type { case "local": @@ -131,7 +144,7 @@ func (auth *Auth) VerifyUser(search types.UserSearch, password string) bool { return false } -func (auth *Auth) GetLocalUser(username string) types.User { +func (auth *AuthService) GetLocalUser(username string) types.User { // Loop through users and return the user if the username matches log.Debug().Str("username", username).Msg("Searching for local user") @@ -146,11 +159,11 @@ func (auth *Auth) GetLocalUser(username string) types.User { return types.User{} } -func (auth *Auth) CheckPassword(user types.User, password string) bool { +func (auth *AuthService) CheckPassword(user types.User, password string) bool { return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) == nil } -func (auth *Auth) IsAccountLocked(identifier string) (bool, int) { +func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) { auth.LoginMutex.RLock() defer auth.LoginMutex.RUnlock() @@ -176,7 +189,7 @@ func (auth *Auth) IsAccountLocked(identifier string) (bool, int) { return false, 0 } -func (auth *Auth) RecordLoginAttempt(identifier string, success bool) { +func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) { // Skip if rate limiting is not configured if auth.Config.LoginMaxRetries <= 0 || auth.Config.LoginTimeout <= 0 { return @@ -212,11 +225,11 @@ func (auth *Auth) RecordLoginAttempt(identifier string, success bool) { } } -func (auth *Auth) EmailWhitelisted(email string) bool { +func (auth *AuthService) EmailWhitelisted(email string) bool { return utils.CheckFilter(auth.Config.OauthWhitelist, email) } -func (auth *Auth) CreateSessionCookie(c *gin.Context, data *types.SessionCookie) error { +func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *types.SessionCookie) error { log.Debug().Msg("Creating session cookie") session, err := auth.GetSession(c) @@ -252,7 +265,7 @@ func (auth *Auth) CreateSessionCookie(c *gin.Context, data *types.SessionCookie) return nil } -func (auth *Auth) DeleteSessionCookie(c *gin.Context) error { +func (auth *AuthService) DeleteSessionCookie(c *gin.Context) error { log.Debug().Msg("Deleting session cookie") session, err := auth.GetSession(c) @@ -275,7 +288,7 @@ func (auth *Auth) DeleteSessionCookie(c *gin.Context) error { return nil } -func (auth *Auth) GetSessionCookie(c *gin.Context) (types.SessionCookie, error) { +func (auth *AuthService) GetSessionCookie(c *gin.Context) (types.SessionCookie, error) { log.Debug().Msg("Getting session cookie") session, err := auth.GetSession(c) @@ -319,12 +332,12 @@ func (auth *Auth) GetSessionCookie(c *gin.Context) (types.SessionCookie, error) }, nil } -func (auth *Auth) UserAuthConfigured() bool { +func (auth *AuthService) UserAuthConfigured() bool { // If there are users or LDAP is configured, return true return len(auth.Config.Users) > 0 || auth.LDAP != nil } -func (auth *Auth) ResourceAllowed(c *gin.Context, context types.UserContext, labels types.Labels) bool { +func (auth *AuthService) ResourceAllowed(c *gin.Context, context types.UserContext, labels types.Labels) bool { if context.OAuth { log.Debug().Msg("Checking OAuth whitelist") return utils.CheckFilter(labels.OAuth.Whitelist, context.Email) @@ -334,7 +347,7 @@ func (auth *Auth) ResourceAllowed(c *gin.Context, context types.UserContext, lab return utils.CheckFilter(labels.Users, context.Username) } -func (auth *Auth) OAuthGroup(c *gin.Context, context types.UserContext, labels types.Labels) bool { +func (auth *AuthService) OAuthGroup(c *gin.Context, context types.UserContext, labels types.Labels) bool { if labels.OAuth.Groups == "" { return true } @@ -361,7 +374,7 @@ func (auth *Auth) OAuthGroup(c *gin.Context, context types.UserContext, labels t return false } -func (auth *Auth) AuthEnabled(uri string, labels types.Labels) (bool, error) { +func (auth *AuthService) AuthEnabled(uri string, labels types.Labels) (bool, error) { // If the label is empty, auth is enabled if labels.Allowed == "" { return true, nil @@ -385,7 +398,7 @@ func (auth *Auth) AuthEnabled(uri string, labels types.Labels) (bool, error) { return true, nil } -func (auth *Auth) GetBasicAuth(c *gin.Context) *types.User { +func (auth *AuthService) GetBasicAuth(c *gin.Context) *types.User { username, password, ok := c.Request.BasicAuth() if !ok { return nil @@ -396,7 +409,7 @@ func (auth *Auth) GetBasicAuth(c *gin.Context) *types.User { } } -func (auth *Auth) CheckIP(labels types.Labels, ip string) bool { +func (auth *AuthService) CheckIP(labels types.Labels, ip string) bool { // Check if the IP is in block list for _, blocked := range labels.IP.Block { res, err := utils.FilterIP(blocked, ip) @@ -433,7 +446,7 @@ func (auth *Auth) CheckIP(labels types.Labels, ip string) bool { return true } -func (auth *Auth) BypassedIP(labels types.Labels, ip string) bool { +func (auth *AuthService) BypassedIP(labels types.Labels, ip string) bool { // For every IP in the bypass list, check if the IP matches for _, bypassed := range labels.IP.Bypass { res, err := utils.FilterIP(bypassed, ip) diff --git a/internal/docker/docker.go b/internal/service/docker_service.go similarity index 76% rename from internal/docker/docker.go rename to internal/service/docker_service.go index f5a0468..f067d7f 100644 --- a/internal/docker/docker.go +++ b/internal/service/docker_service.go @@ -1,9 +1,9 @@ -package docker +package service import ( "context" "strings" - "tinyauth/internal/types" + "tinyauth/internal/config" "tinyauth/internal/utils" container "github.com/docker/docker/api/types/container" @@ -11,27 +11,27 @@ import ( "github.com/rs/zerolog/log" ) -type Docker struct { +type DockerService struct { Client *client.Client Context context.Context } -func NewDocker() (*Docker, error) { +func NewDockerService() *DockerService { + return &DockerService{} +} + +func (docker *DockerService) Init() error { client, err := client.NewClientWithOpts(client.FromEnv) if err != nil { - return nil, err + return err } ctx := context.Background() client.NegotiateAPIVersion(ctx) - - return &Docker{ - Client: client, - Context: ctx, - }, nil + return nil } -func (docker *Docker) GetContainers() ([]container.Summary, error) { +func (docker *DockerService) GetContainers() ([]container.Summary, error) { containers, err := docker.Client.ContainerList(docker.Context, container.ListOptions{}) if err != nil { return nil, err @@ -39,7 +39,7 @@ func (docker *Docker) GetContainers() ([]container.Summary, error) { return containers, nil } -func (docker *Docker) InspectContainer(containerId string) (container.InspectResponse, error) { +func (docker *DockerService) InspectContainer(containerId string) (container.InspectResponse, error) { inspect, err := docker.Client.ContainerInspect(docker.Context, containerId) if err != nil { return container.InspectResponse{}, err @@ -47,17 +47,17 @@ func (docker *Docker) InspectContainer(containerId string) (container.InspectRes return inspect, nil } -func (docker *Docker) DockerConnected() bool { +func (docker *DockerService) DockerConnected() bool { _, err := docker.Client.Ping(docker.Context) return err == nil } -func (docker *Docker) GetLabels(app string, domain string) (types.Labels, error) { +func (docker *DockerService) GetLabels(app string, domain string) (config.Labels, error) { isConnected := docker.DockerConnected() if !isConnected { log.Debug().Msg("Docker not connected, returning empty labels") - return types.Labels{}, nil + return config.Labels{}, nil } log.Debug().Msg("Getting containers") @@ -65,7 +65,7 @@ func (docker *Docker) GetLabels(app string, domain string) (types.Labels, error) containers, err := docker.GetContainers() if err != nil { log.Error().Err(err).Msg("Error getting containers") - return types.Labels{}, err + return config.Labels{}, err } for _, container := range containers { @@ -98,5 +98,5 @@ func (docker *Docker) GetLabels(app string, domain string) (types.Labels, error) } log.Debug().Msg("No matching container found, returning empty labels") - return types.Labels{}, nil + return config.Labels{}, nil } diff --git a/internal/service/generic_oauth_service.go b/internal/service/generic_oauth_service.go new file mode 100644 index 0000000..9bd6a8e --- /dev/null +++ b/internal/service/generic_oauth_service.go @@ -0,0 +1,114 @@ +package service + +import ( + "context" + "crypto/rand" + "crypto/tls" + "encoding/base64" + "encoding/json" + "io" + "net/http" + "tinyauth/internal/config" + + "golang.org/x/oauth2" +) + +type GenericOAuthService struct { + Config oauth2.Config + Context context.Context + Token *oauth2.Token + Verifier string + InsecureSkipVerify bool + ServiceName string + UserinfoURL string +} + +func NewGenericOAuthService(config config.OAuthServiceConfig) *GenericOAuthService { + return &GenericOAuthService{ + Config: oauth2.Config{ + ClientID: config.ClientID, + ClientSecret: config.ClientSecret, + RedirectURL: config.RedirectURL, + Scopes: config.Scopes, + Endpoint: oauth2.Endpoint{ + AuthURL: config.AuthURL, + TokenURL: config.TokenURL, + }, + }, + InsecureSkipVerify: config.InsecureSkipVerify, + ServiceName: config.Name, + UserinfoURL: config.UserinfoURL, + } +} + +func (generic *GenericOAuthService) Init() error { + transport := &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: generic.InsecureSkipVerify, + MinVersion: tls.VersionTLS12, + }, + } + + httpClient := &http.Client{ + Transport: transport, + } + + ctx := context.Background() + + ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) + verifier := oauth2.GenerateVerifier() + + generic.Context = ctx + generic.Verifier = verifier + return nil +} + +func (generic *GenericOAuthService) Name() string { + return generic.ServiceName +} + +func (generic *GenericOAuthService) GenerateState() string { + b := make([]byte, 128) + rand.Read(b) + state := base64.URLEncoding.EncodeToString(b) + return state +} + +func (generic *GenericOAuthService) GetAuthURL(state string) string { + return generic.Config.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(generic.Verifier)) +} + +func (generic *GenericOAuthService) VerifyCode(code string) error { + token, err := generic.Config.Exchange(generic.Context, code, oauth2.VerifierOption(generic.Verifier)) + + if err != nil { + return nil + } + + generic.Token = token + return nil +} + +func (generic *GenericOAuthService) Userinfo() (config.Claims, error) { + var user config.Claims + + client := generic.Config.Client(generic.Context, generic.Token) + + res, err := client.Get(generic.UserinfoURL) + if err != nil { + return user, err + } + defer res.Body.Close() + + body, err := io.ReadAll(res.Body) + if err != nil { + return user, err + } + + err = json.Unmarshal(body, &user) + if err != nil { + return user, err + } + + return user, nil +} diff --git a/internal/service/github_oauth_service.go b/internal/service/github_oauth_service.go new file mode 100644 index 0000000..57d8391 --- /dev/null +++ b/internal/service/github_oauth_service.go @@ -0,0 +1,144 @@ +package service + +import ( + "context" + "crypto/rand" + "encoding/base64" + "encoding/json" + "errors" + "io" + "net/http" + "tinyauth/internal/config" + + "golang.org/x/oauth2" +) + +var GithubOAuthScopes = []string{"user:email", "read:user"} + +type GithubEmailResponse []struct { + Email string `json:"email"` + Primary bool `json:"primary"` +} + +type GithubUserInfoResponse struct { + Login string `json:"login"` + Name string `json:"name"` +} + +type GithubOAuthService struct { + Config oauth2.Config + Context context.Context + Token *oauth2.Token + Verifier string +} + +func NewGithubOAuthService(config config.OAuthServiceConfig) *GithubOAuthService { + return &GithubOAuthService{ + Config: oauth2.Config{ + ClientID: config.ClientID, + ClientSecret: config.ClientSecret, + RedirectURL: config.RedirectURL, + Scopes: GithubOAuthScopes, + }, + } +} + +func (github *GithubOAuthService) Init() error { + httpClient := &http.Client{} + ctx := context.Background() + ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) + verifier := oauth2.GenerateVerifier() + + github.Context = ctx + github.Verifier = verifier + return nil +} + +func (github *GithubOAuthService) Name() string { + return "github" +} + +func (github *GithubOAuthService) GenerateState() string { + b := make([]byte, 128) + rand.Read(b) + state := base64.URLEncoding.EncodeToString(b) + return state +} + +func (github *GithubOAuthService) GetAuthURL(state string) string { + return github.Config.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(github.Verifier)) +} + +func (github *GithubOAuthService) VerifyCode(code string) error { + token, err := github.Config.Exchange(github.Context, code, oauth2.VerifierOption(github.Verifier)) + + if err != nil { + return nil + } + + github.Token = token + return nil +} + +func (github *GithubOAuthService) Userinfo() (config.Claims, error) { + var user config.Claims + + client := github.Config.Client(github.Context, github.Token) + + res, err := client.Get("https://api.github.com/user") + if err != nil { + return user, err + } + defer res.Body.Close() + + body, err := io.ReadAll(res.Body) + if err != nil { + return user, err + } + + var userInfo GithubUserInfoResponse + + err = json.Unmarshal(body, &userInfo) + if err != nil { + return user, err + } + + res, err = client.Get("https://api.github.com/user/emails") + if err != nil { + return user, err + } + defer res.Body.Close() + + body, err = io.ReadAll(res.Body) + if err != nil { + return user, err + } + + var emails GithubEmailResponse + + err = json.Unmarshal(body, &emails) + if err != nil { + return user, err + } + + for _, email := range emails { + if email.Primary { + user.Email = email.Email + break + } + } + + if len(emails) == 0 { + return user, errors.New("no emails found") + } + + // Use first available email if no primary email was found + if user.Email == "" { + user.Email = emails[0].Email + } + + user.PreferredUsername = userInfo.Login + user.Name = userInfo.Name + + return user, nil +} diff --git a/internal/service/google_oauth_service.go b/internal/service/google_oauth_service.go new file mode 100644 index 0000000..2d86a56 --- /dev/null +++ b/internal/service/google_oauth_service.go @@ -0,0 +1,106 @@ +package service + +import ( + "context" + "crypto/rand" + "encoding/base64" + "encoding/json" + "io" + "net/http" + "strings" + "tinyauth/internal/config" + + "golang.org/x/oauth2" +) + +var GoogleOAuthScopes = []string{"https://www.googleapis.com/auth/userinfo.email", "https://www.googleapis.com/auth/userinfo.profile"} + +type GoogleUserInfoResponse struct { + Email string `json:"email"` + Name string `json:"name"` +} + +type GoogleOAuthService struct { + Config oauth2.Config + Context context.Context + Token *oauth2.Token + Verifier string +} + +func NewGoogleOAuthService(config config.OAuthServiceConfig) *GoogleOAuthService { + return &GoogleOAuthService{ + Config: oauth2.Config{ + ClientID: config.ClientID, + ClientSecret: config.ClientSecret, + RedirectURL: config.RedirectURL, + Scopes: GoogleOAuthScopes, + }, + } +} + +func (google *GoogleOAuthService) Init() error { + httpClient := &http.Client{} + ctx := context.Background() + ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) + verifier := oauth2.GenerateVerifier() + + google.Context = ctx + google.Verifier = verifier + return nil +} + +func (google *GoogleOAuthService) Name() string { + return "google" +} + +func (oauth *GoogleOAuthService) GenerateState() string { + b := make([]byte, 128) + rand.Read(b) + state := base64.URLEncoding.EncodeToString(b) + return state +} + +func (google *GoogleOAuthService) GetAuthURL(state string) string { + return google.Config.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(google.Verifier)) +} + +func (google *GoogleOAuthService) VerifyCode(code string) error { + token, err := google.Config.Exchange(google.Context, code, oauth2.VerifierOption(google.Verifier)) + + if err != nil { + return nil + } + + google.Token = token + return nil +} + +func (google *GoogleOAuthService) Userinfo() (config.Claims, error) { + var user config.Claims + + client := google.Config.Client(google.Context, google.Token) + + res, err := client.Get("https://www.googleapis.com/userinfo/v2/me") + if err != nil { + return config.Claims{}, err + } + defer res.Body.Close() + + body, err := io.ReadAll(res.Body) + if err != nil { + return config.Claims{}, err + } + + var userInfo GoogleUserInfoResponse + + err = json.Unmarshal(body, &userInfo) + if err != nil { + return config.Claims{}, err + } + + user.PreferredUsername = strings.Split(userInfo.Email, "@")[0] + user.Name = userInfo.Name + user.Email = userInfo.Email + + return user, nil +} diff --git a/internal/ldap/ldap.go b/internal/service/ldap_service.go similarity index 64% rename from internal/ldap/ldap.go rename to internal/service/ldap_service.go index 61578d7..805e2f7 100644 --- a/internal/ldap/ldap.go +++ b/internal/service/ldap_service.go @@ -1,30 +1,40 @@ -package ldap +package service import ( "context" "crypto/tls" "fmt" "time" - "tinyauth/internal/types" "github.com/cenkalti/backoff/v5" ldapgo "github.com/go-ldap/ldap/v3" "github.com/rs/zerolog/log" ) -type LDAP struct { - Config types.LdapConfig +type LdapServiceConfig struct { + Address string + BindDN string + BindPassword string + BaseDN string + Insecure bool + SearchFilter string +} + +type LdapService struct { + Config LdapServiceConfig Conn *ldapgo.Conn } -func NewLDAP(config types.LdapConfig) (*LDAP, error) { - ldap := &LDAP{ +func NewLdapService(config LdapServiceConfig) *LdapService { + return &LdapService{ Config: config, } +} +func (ldap *LdapService) Init() error { _, err := ldap.connect() if err != nil { - return nil, fmt.Errorf("failed to connect to LDAP server: %w", err) + return fmt.Errorf("failed to connect to LDAP server: %w", err) } go func() { @@ -41,13 +51,13 @@ func NewLDAP(config types.LdapConfig) (*LDAP, error) { } }() - return ldap, nil + return nil } -func (l *LDAP) connect() (*ldapgo.Conn, error) { +func (ldap *LdapService) connect() (*ldapgo.Conn, error) { log.Debug().Msg("Connecting to LDAP server") - conn, err := ldapgo.DialURL(l.Config.Address, ldapgo.DialWithTLSConfig(&tls.Config{ - InsecureSkipVerify: l.Config.Insecure, + conn, err := ldapgo.DialURL(ldap.Config.Address, ldapgo.DialWithTLSConfig(&tls.Config{ + InsecureSkipVerify: ldap.Config.Insecure, MinVersion: tls.VersionTLS12, })) if err != nil { @@ -55,30 +65,30 @@ func (l *LDAP) connect() (*ldapgo.Conn, error) { } log.Debug().Msg("Binding to LDAP server") - err = conn.Bind(l.Config.BindDN, l.Config.BindPassword) + err = conn.Bind(ldap.Config.BindDN, ldap.Config.BindPassword) if err != nil { return nil, err } // Set and return the connection - l.Conn = conn + ldap.Conn = conn return conn, nil } -func (l *LDAP) Search(username string) (string, error) { +func (ldap *LdapService) Search(username string) (string, error) { // Escape the username to prevent LDAP injection escapedUsername := ldapgo.EscapeFilter(username) - filter := fmt.Sprintf(l.Config.SearchFilter, escapedUsername) + filter := fmt.Sprintf(ldap.Config.SearchFilter, escapedUsername) searchRequest := ldapgo.NewSearchRequest( - l.Config.BaseDN, + ldap.Config.BaseDN, ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false, filter, []string{"dn"}, nil, ) - searchResult, err := l.Conn.Search(searchRequest) + searchResult, err := ldap.Conn.Search(searchRequest) if err != nil { return "", err } @@ -91,15 +101,15 @@ func (l *LDAP) Search(username string) (string, error) { return userDN, nil } -func (l *LDAP) Bind(userDN string, password string) error { - err := l.Conn.Bind(userDN, password) +func (ldap *LdapService) Bind(userDN string, password string) error { + err := ldap.Conn.Bind(userDN, password) if err != nil { return err } return nil } -func (l *LDAP) heartbeat() error { +func (ldap *LdapService) heartbeat() error { log.Debug().Msg("Performing LDAP connection heartbeat") searchRequest := ldapgo.NewSearchRequest( @@ -110,7 +120,7 @@ func (l *LDAP) heartbeat() error { nil, ) - _, err := l.Conn.Search(searchRequest) + _, err := ldap.Conn.Search(searchRequest) if err != nil { return err } @@ -119,7 +129,7 @@ func (l *LDAP) heartbeat() error { return nil } -func (l *LDAP) reconnect() error { +func (ldap *LdapService) reconnect() error { log.Info().Msg("Reconnecting to LDAP server") exp := backoff.NewExponentialBackOff() @@ -129,8 +139,8 @@ func (l *LDAP) reconnect() error { exp.Reset() operation := func() (*ldapgo.Conn, error) { - l.Conn.Close() - conn, err := l.connect() + ldap.Conn.Close() + conn, err := ldap.connect() if err != nil { return nil, nil } diff --git a/internal/types/types.go b/internal/types/types.go deleted file mode 100644 index 1cb6bed..0000000 --- a/internal/types/types.go +++ /dev/null @@ -1,70 +0,0 @@ -package types - -import ( - "time" - "tinyauth/internal/oauth" -) - -// User is the struct for a user -type User struct { - Username string - Password string - TotpSecret string -} - -// UserSearch is the response of the get user -type UserSearch struct { - Username string - Type string // "local", "ldap" or empty -} - -// Users is a list of users -type Users []User - -// OAuthProviders is the struct for the OAuth providers -type OAuthProviders struct { - Github *oauth.OAuth - Google *oauth.OAuth - Microsoft *oauth.OAuth -} - -// SessionCookie is the cookie for the session (exculding the expiry) -type SessionCookie struct { - Username string - Name string - Email string - Provider string - TotpPending bool - OAuthGroups string -} - -// UserContext is the context for the user -type UserContext struct { - Username string - Name string - Email string - IsLoggedIn bool - OAuth bool - Provider string - TotpPending bool - OAuthGroups string - TotpEnabled bool -} - -// LoginAttempt tracks information about login attempts for rate limiting -type LoginAttempt struct { - FailedAttempts int - LastAttempt time.Time - LockedUntil time.Time -} - -type UnauthorizedQuery struct { - Username string `url:"username"` - Resource string `url:"resource"` - GroupErr bool `url:"groupErr"` - IP string `url:"ip"` -} - -type RedirectQuery struct { - RedirectURI string `url:"redirect_uri"` -}