refactor: move oauth providers into services (non-working)

This commit is contained in:
Stavros
2025-08-25 17:50:34 +03:00
parent dfdc656145
commit 44f35af3bf
15 changed files with 544 additions and 778 deletions

View File

@@ -1,146 +0,0 @@
package auth_test
import (
"testing"
"time"
"tinyauth/internal/auth"
"tinyauth/internal/types"
)
var config = types.AuthConfig{
Users: types.Users{},
OauthWhitelist: "",
SessionExpiry: 3600,
}
func TestLoginRateLimiting(t *testing.T) {
// Initialize a new auth service with 3 max retries and 5 seconds timeout
config.LoginMaxRetries = 3
config.LoginTimeout = 5
authService := auth.NewAuth(config, nil, nil)
// Test identifier
identifier := "test_user"
// Test successful login - should not lock account
t.Log("Testing successful login")
authService.RecordLoginAttempt(identifier, true)
locked, _ := authService.IsAccountLocked(identifier)
if locked {
t.Fatalf("Account should not be locked after successful login")
}
// Test 2 failed attempts - should not lock account yet
t.Log("Testing 2 failed login attempts")
authService.RecordLoginAttempt(identifier, false)
authService.RecordLoginAttempt(identifier, false)
locked, _ = authService.IsAccountLocked(identifier)
if locked {
t.Fatalf("Account should not be locked after only 2 failed attempts")
}
// Add one more failed attempt (total 3) - should lock account with maxRetries=3
t.Log("Testing 3 failed login attempts")
authService.RecordLoginAttempt(identifier, false)
locked, remainingTime := authService.IsAccountLocked(identifier)
if !locked {
t.Fatalf("Account should be locked after reaching max retries")
}
if remainingTime <= 0 || remainingTime > 5 {
t.Fatalf("Expected remaining time between 1-5 seconds, got %d", remainingTime)
}
// Test reset after waiting for timeout - use 1 second timeout for fast testing
t.Log("Testing unlocking after timeout")
// Reinitialize auth service with a shorter timeout for testing
config.LoginTimeout = 1
config.LoginMaxRetries = 3
authService = auth.NewAuth(config, nil, nil)
// Add enough failed attempts to lock the account
for i := 0; i < 3; i++ {
authService.RecordLoginAttempt(identifier, false)
}
// Verify it's locked
locked, _ = authService.IsAccountLocked(identifier)
if !locked {
t.Fatalf("Account should be locked initially")
}
// Wait a bit and verify it gets unlocked after timeout
time.Sleep(1500 * time.Millisecond) // Wait longer than the timeout
locked, _ = authService.IsAccountLocked(identifier)
if locked {
t.Fatalf("Account should be unlocked after timeout period")
}
// Test disabled rate limiting
t.Log("Testing disabled rate limiting")
config.LoginMaxRetries = 0
config.LoginTimeout = 0
authService = auth.NewAuth(config, nil, nil)
for i := 0; i < 10; i++ {
authService.RecordLoginAttempt(identifier, false)
}
locked, _ = authService.IsAccountLocked(identifier)
if locked {
t.Fatalf("Account should not be locked when rate limiting is disabled")
}
}
func TestConcurrentLoginAttempts(t *testing.T) {
// Initialize a new auth service with 2 max retries and 5 seconds timeout
config.LoginMaxRetries = 2
config.LoginTimeout = 5
authService := auth.NewAuth(config, nil, nil)
// Test multiple identifiers
identifiers := []string{"user1", "user2", "user3"}
// Test that locking one identifier doesn't affect others
t.Log("Testing multiple identifiers")
// Add enough failed attempts to lock first user (2 attempts with maxRetries=2)
authService.RecordLoginAttempt(identifiers[0], false)
authService.RecordLoginAttempt(identifiers[0], false)
// Check if first user is locked
locked, _ := authService.IsAccountLocked(identifiers[0])
if !locked {
t.Fatalf("User1 should be locked after reaching max retries")
}
// Check that other users are not affected
for i := 1; i < len(identifiers); i++ {
locked, _ := authService.IsAccountLocked(identifiers[i])
if locked {
t.Fatalf("User%d should not be locked", i+1)
}
}
// Test successful login after failed attempts (but before lock)
t.Log("Testing successful login after failed attempts but before lock")
// One failed attempt for user2
authService.RecordLoginAttempt(identifiers[1], false)
// Successful login should reset the counter
authService.RecordLoginAttempt(identifiers[1], true)
// Now try a failed login again - should not be locked as counter was reset
authService.RecordLoginAttempt(identifiers[1], false)
locked, _ = authService.IsAccountLocked(identifiers[1])
if locked {
t.Fatalf("User2 should not be locked after successful login reset")
}
}

View File

@@ -1,6 +1,22 @@
package types
package config
import "time"
type Claims struct {
Name string `json:"name"`
Email string `json:"email"`
PreferredUsername string `json:"preferred_username"`
Groups any `json:"groups"`
}
var Version = "development"
var CommitHash = "n/a"
var BuildTimestamp = "n/a"
var SessionCookieName = "tinyauth-session"
var CsrfCookieName = "tinyauth-csrf"
var RedirectCookieName = "tinyauth-redirect"
// Config is the configuration for the tinyauth server
type Config struct {
Port int `mapstructure:"port" validate:"required"`
Address string `validate:"required,ip4_addr" mapstructure:"address"`
@@ -44,62 +60,27 @@ type Config struct {
LdapSearchFilter string `mapstructure:"ldap-search-filter"`
}
// OAuthConfig is the configuration for the providers
type OAuthConfig struct {
GithubClientId string
GithubClientSecret string
GoogleClientId string
GoogleClientSecret string
GenericClientId string
GenericClientSecret string
GenericScopes []string
GenericAuthURL string
GenericTokenURL string
GenericUserURL string
GenericSkipSSL bool
AppURL string
}
// AuthConfig is the configuration for the auth service
type AuthConfig struct {
Users Users
OauthWhitelist string
SessionExpiry int
CookieSecure bool
Domain string
LoginTimeout int
LoginMaxRetries int
SessionCookieName string
HMACSecret string
EncryptionSecret string
}
// OAuthLabels is a list of labels that can be used in a tinyauth protected container
type OAuthLabels struct {
Whitelist string
Groups string
}
// Basic auth labels for a tinyauth protected container
type BasicLabels struct {
Username string
Password PassowrdLabels
}
// PassowrdLabels is a struct that contains the password labels for a tinyauth protected container
type PassowrdLabels struct {
Plain string
File string
}
// IP labels for a tinyauth protected container
type IPLabels struct {
Allow []string
Block []string
Bypass []string
}
// Labels is a struct that contains the labels for a tinyauth protected container
type Labels struct {
Users string
Allowed string
@@ -110,12 +91,65 @@ type Labels struct {
IP IPLabels
}
// Ldap config is a struct that contains the configuration for the LDAP service
type LdapConfig struct {
Address string
BindDN string
BindPassword string
BaseDN string
Insecure bool
SearchFilter string
type OAuthServiceConfig struct {
ClientID string
ClientSecret string
Scopes []string
RedirectURL string
AuthURL string
TokenURL string
UserinfoURL string
InsecureSkipVerify bool
Name string
}
type User struct {
Username string
Password string
TotpSecret string
}
type UserSearch struct {
Username string
Type string // local, ldap or unknown
}
type Users []User
type SessionCookie struct {
Username string
Name string
Email string
Provider string
TotpPending bool
OAuthGroups string
}
type UserContext struct {
Username string
Name string
Email string
IsLoggedIn bool
OAuth bool
Provider string
TotpPending bool
OAuthGroups string
TotpEnabled bool
}
type LoginAttempt struct {
FailedAttempts int
LastAttempt time.Time
LockedUntil time.Time
}
type UnauthorizedQuery struct {
Username string `url:"username"`
Resource string `url:"resource"`
GroupErr bool `url:"groupErr"`
IP string `url:"ip"`
}
type RedirectQuery struct {
RedirectURI string `url:"redirect_uri"`
}

View File

@@ -1,19 +0,0 @@
package constants
// Claims are the OIDC supported claims (prefered username is included for convinience)
type Claims struct {
Name string `json:"name"`
Email string `json:"email"`
PreferredUsername string `json:"preferred_username"`
Groups any `json:"groups"`
}
// Version information
var Version = "development"
var CommitHash = "n/a"
var BuildTimestamp = "n/a"
// Base cookie names
var SessionCookieName = "tinyauth-session"
var CsrfCookieName = "tinyauth-csrf"
var RedirectCookieName = "tinyauth-redirect"

View File

@@ -1,71 +0,0 @@
package oauth
import (
"context"
"crypto/rand"
"crypto/tls"
"encoding/base64"
"net/http"
"golang.org/x/oauth2"
)
type OAuth struct {
Config oauth2.Config
Context context.Context
Token *oauth2.Token
Verifier string
}
func NewOAuth(config oauth2.Config, insecureSkipVerify bool) *OAuth {
transport := &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: insecureSkipVerify,
MinVersion: tls.VersionTLS12,
},
}
httpClient := &http.Client{
Transport: transport,
}
ctx := context.Background()
// Set the HTTP client in the context
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
verifier := oauth2.GenerateVerifier()
return &OAuth{
Config: config,
Context: ctx,
Verifier: verifier,
}
}
func (oauth *OAuth) GetAuthURL(state string) string {
return oauth.Config.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(oauth.Verifier))
}
func (oauth *OAuth) ExchangeToken(code string) (string, error) {
token, err := oauth.Config.Exchange(oauth.Context, code, oauth2.VerifierOption(oauth.Verifier))
if err != nil {
return "", err
}
// Set and return the token
oauth.Token = token
return oauth.Token.AccessToken, nil
}
func (oauth *OAuth) GetClient() *http.Client {
return oauth.Config.Client(oauth.Context, oauth.Token)
}
func (oauth *OAuth) GenerateState() string {
b := make([]byte, 128)
rand.Read(b)
state := base64.URLEncoding.EncodeToString(b)
return state
}

View File

@@ -1,37 +0,0 @@
package providers
import (
"encoding/json"
"io"
"net/http"
"tinyauth/internal/constants"
"github.com/rs/zerolog/log"
)
func GetGenericUser(client *http.Client, url string) (constants.Claims, error) {
var user constants.Claims
res, err := client.Get(url)
if err != nil {
return user, err
}
defer res.Body.Close()
log.Debug().Msg("Got response from generic provider")
body, err := io.ReadAll(res.Body)
if err != nil {
return user, err
}
log.Debug().Msg("Read body from generic provider")
err = json.Unmarshal(body, &user)
if err != nil {
return user, err
}
log.Debug().Msg("Parsed user from generic provider")
return user, nil
}

View File

@@ -1,102 +0,0 @@
package providers
import (
"encoding/json"
"errors"
"io"
"net/http"
"tinyauth/internal/constants"
"github.com/rs/zerolog/log"
)
// Response for the github email endpoint
type GithubEmailResponse []struct {
Email string `json:"email"`
Primary bool `json:"primary"`
}
// Response for the github user endpoint
type GithubUserInfoResponse struct {
Login string `json:"login"`
Name string `json:"name"`
}
// The scopes required for the github provider
func GithubScopes() []string {
return []string{"user:email", "read:user"}
}
func GetGithubUser(client *http.Client) (constants.Claims, error) {
var user constants.Claims
res, err := client.Get("https://api.github.com/user")
if err != nil {
return user, err
}
defer res.Body.Close()
log.Debug().Msg("Got user response from github")
body, err := io.ReadAll(res.Body)
if err != nil {
return user, err
}
log.Debug().Msg("Read user body from github")
var userInfo GithubUserInfoResponse
err = json.Unmarshal(body, &userInfo)
if err != nil {
return user, err
}
res, err = client.Get("https://api.github.com/user/emails")
if err != nil {
return user, err
}
defer res.Body.Close()
log.Debug().Msg("Got email response from github")
body, err = io.ReadAll(res.Body)
if err != nil {
return user, err
}
log.Debug().Msg("Read email body from github")
var emails GithubEmailResponse
err = json.Unmarshal(body, &emails)
if err != nil {
return user, err
}
log.Debug().Msg("Parsed emails from github")
// Find and return the primary email
for _, email := range emails {
if email.Primary {
log.Debug().Str("email", email.Email).Msg("Found primary email")
user.Email = email.Email
break
}
}
if len(emails) == 0 {
return user, errors.New("no emails found")
}
// Use first available email if no primary email was found
if user.Email == "" {
log.Warn().Str("email", emails[0].Email).Msg("No primary email found, using first email")
user.Email = emails[0].Email
}
user.PreferredUsername = userInfo.Login
user.Name = userInfo.Name
return user, nil
}

View File

@@ -1,56 +0,0 @@
package providers
import (
"encoding/json"
"io"
"net/http"
"strings"
"tinyauth/internal/constants"
"github.com/rs/zerolog/log"
)
// Response for the google user endpoint
type GoogleUserInfoResponse struct {
Email string `json:"email"`
Name string `json:"name"`
}
// The scopes required for the google provider
func GoogleScopes() []string {
return []string{"https://www.googleapis.com/auth/userinfo.email", "https://www.googleapis.com/auth/userinfo.profile"}
}
func GetGoogleUser(client *http.Client) (constants.Claims, error) {
var user constants.Claims
res, err := client.Get("https://www.googleapis.com/userinfo/v2/me")
if err != nil {
return user, err
}
defer res.Body.Close()
log.Debug().Msg("Got response from google")
body, err := io.ReadAll(res.Body)
if err != nil {
return user, err
}
log.Debug().Msg("Read body from google")
var userInfo GoogleUserInfoResponse
err = json.Unmarshal(body, &userInfo)
if err != nil {
return user, err
}
log.Debug().Msg("Parsed user from google")
user.PreferredUsername = strings.Split(userInfo.Email, "@")[0]
user.Name = userInfo.Name
user.Email = userInfo.Email
return user, nil
}

View File

@@ -1,154 +0,0 @@
package providers
import (
"fmt"
"tinyauth/internal/constants"
"tinyauth/internal/oauth"
"tinyauth/internal/types"
"github.com/rs/zerolog/log"
"golang.org/x/oauth2"
"golang.org/x/oauth2/endpoints"
)
type Providers struct {
Config types.OAuthConfig
Github *oauth.OAuth
Google *oauth.OAuth
Generic *oauth.OAuth
}
func NewProviders(config types.OAuthConfig) *Providers {
providers := &Providers{
Config: config,
}
if config.GithubClientId != "" && config.GithubClientSecret != "" {
log.Info().Msg("Initializing Github OAuth")
providers.Github = oauth.NewOAuth(oauth2.Config{
ClientID: config.GithubClientId,
ClientSecret: config.GithubClientSecret,
RedirectURL: fmt.Sprintf("%s/api/oauth/callback/github", config.AppURL),
Scopes: GithubScopes(),
Endpoint: endpoints.GitHub,
}, false)
}
if config.GoogleClientId != "" && config.GoogleClientSecret != "" {
log.Info().Msg("Initializing Google OAuth")
providers.Google = oauth.NewOAuth(oauth2.Config{
ClientID: config.GoogleClientId,
ClientSecret: config.GoogleClientSecret,
RedirectURL: fmt.Sprintf("%s/api/oauth/callback/google", config.AppURL),
Scopes: GoogleScopes(),
Endpoint: endpoints.Google,
}, false)
}
if config.GenericClientId != "" && config.GenericClientSecret != "" {
log.Info().Msg("Initializing Generic OAuth")
providers.Generic = oauth.NewOAuth(oauth2.Config{
ClientID: config.GenericClientId,
ClientSecret: config.GenericClientSecret,
RedirectURL: fmt.Sprintf("%s/api/oauth/callback/generic", config.AppURL),
Scopes: config.GenericScopes,
Endpoint: oauth2.Endpoint{
AuthURL: config.GenericAuthURL,
TokenURL: config.GenericTokenURL,
},
}, config.GenericSkipSSL)
}
return providers
}
func (providers *Providers) GetProvider(provider string) *oauth.OAuth {
switch provider {
case "github":
return providers.Github
case "google":
return providers.Google
case "generic":
return providers.Generic
default:
return nil
}
}
func (providers *Providers) GetUser(provider string) (constants.Claims, error) {
var user constants.Claims
// Get the user from the provider
switch provider {
case "github":
if providers.Github == nil {
log.Debug().Msg("Github provider not configured")
return user, nil
}
client := providers.Github.GetClient()
log.Debug().Msg("Got client from github")
user, err := GetGithubUser(client)
if err != nil {
return user, err
}
log.Debug().Msg("Got user from github")
return user, nil
case "google":
if providers.Google == nil {
log.Debug().Msg("Google provider not configured")
return user, nil
}
client := providers.Google.GetClient()
log.Debug().Msg("Got client from google")
user, err := GetGoogleUser(client)
if err != nil {
return user, err
}
log.Debug().Msg("Got user from google")
return user, nil
case "generic":
if providers.Generic == nil {
log.Debug().Msg("Generic provider not configured")
return user, nil
}
client := providers.Generic.GetClient()
log.Debug().Msg("Got client from generic")
user, err := GetGenericUser(client, providers.Config.GenericUserURL)
if err != nil {
return user, err
}
log.Debug().Msg("Got user from generic")
return user, nil
default:
return user, nil
}
}
func (provider *Providers) GetConfiguredProviders() []string {
providers := []string{}
if provider.Github != nil {
providers = append(providers, "github")
}
if provider.Google != nil {
providers = append(providers, "google")
}
if provider.Generic != nil {
providers = append(providers, "generic")
}
return providers
}

View File

@@ -1,4 +1,4 @@
package auth
package service
import (
"fmt"
@@ -6,8 +6,6 @@ import (
"strings"
"sync"
"time"
"tinyauth/internal/docker"
"tinyauth/internal/ldap"
"tinyauth/internal/types"
"tinyauth/internal/utils"
@@ -17,35 +15,50 @@ import (
"golang.org/x/crypto/bcrypt"
)
type Auth struct {
Config types.AuthConfig
Docker *docker.Docker
type AuthServiceConfig struct {
Users types.Users
OauthWhitelist string
SessionExpiry int
CookieSecure bool
Domain string
LoginTimeout int
LoginMaxRetries int
SessionCookieName string
HMACSecret string
EncryptionSecret string
}
type AuthService struct {
Config AuthServiceConfig
Docker *DockerService
LoginAttempts map[string]*types.LoginAttempt
LoginMutex sync.RWMutex
Store *sessions.CookieStore
LDAP *ldap.LDAP
LDAP *LdapService
}
func NewAuth(config types.AuthConfig, docker *docker.Docker, ldap *ldap.LDAP) *Auth {
// Setup cookie store and create the auth service
store := sessions.NewCookieStore([]byte(config.HMACSecret), []byte(config.EncryptionSecret))
store.Options = &sessions.Options{
Path: "/",
MaxAge: config.SessionExpiry,
Secure: config.CookieSecure,
HttpOnly: true,
Domain: fmt.Sprintf(".%s", config.Domain),
}
return &Auth{
func NewAuthService(config AuthServiceConfig, docker *DockerService, ldap *LdapService) *AuthService {
return &AuthService{
Config: config,
Docker: docker,
LoginAttempts: make(map[string]*types.LoginAttempt),
Store: store,
LDAP: ldap,
}
}
func (auth *Auth) GetSession(c *gin.Context) (*sessions.Session, error) {
func (auth *AuthService) Init() error {
store := sessions.NewCookieStore([]byte(auth.Config.HMACSecret), []byte(auth.Config.EncryptionSecret))
store.Options = &sessions.Options{
Path: "/",
MaxAge: auth.Config.SessionExpiry,
Secure: auth.Config.CookieSecure,
HttpOnly: true,
Domain: fmt.Sprintf(".%s", auth.Config.Domain),
}
return nil
}
func (auth *AuthService) GetSession(c *gin.Context) (*sessions.Session, error) {
session, err := auth.Store.Get(c.Request, auth.Config.SessionCookieName)
// If there was an error getting the session, it might be invalid so let's clear it and retry
@@ -62,7 +75,7 @@ func (auth *Auth) GetSession(c *gin.Context) (*sessions.Session, error) {
return session, nil
}
func (auth *Auth) SearchUser(username string) types.UserSearch {
func (auth *AuthService) SearchUser(username string) types.UserSearch {
log.Debug().Str("username", username).Msg("Searching for user")
// Check local users first
@@ -93,7 +106,7 @@ func (auth *Auth) SearchUser(username string) types.UserSearch {
}
}
func (auth *Auth) VerifyUser(search types.UserSearch, password string) bool {
func (auth *AuthService) VerifyUser(search types.UserSearch, password string) bool {
// Authenticate the user based on the type
switch search.Type {
case "local":
@@ -131,7 +144,7 @@ func (auth *Auth) VerifyUser(search types.UserSearch, password string) bool {
return false
}
func (auth *Auth) GetLocalUser(username string) types.User {
func (auth *AuthService) GetLocalUser(username string) types.User {
// Loop through users and return the user if the username matches
log.Debug().Str("username", username).Msg("Searching for local user")
@@ -146,11 +159,11 @@ func (auth *Auth) GetLocalUser(username string) types.User {
return types.User{}
}
func (auth *Auth) CheckPassword(user types.User, password string) bool {
func (auth *AuthService) CheckPassword(user types.User, password string) bool {
return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) == nil
}
func (auth *Auth) IsAccountLocked(identifier string) (bool, int) {
func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
auth.LoginMutex.RLock()
defer auth.LoginMutex.RUnlock()
@@ -176,7 +189,7 @@ func (auth *Auth) IsAccountLocked(identifier string) (bool, int) {
return false, 0
}
func (auth *Auth) RecordLoginAttempt(identifier string, success bool) {
func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
// Skip if rate limiting is not configured
if auth.Config.LoginMaxRetries <= 0 || auth.Config.LoginTimeout <= 0 {
return
@@ -212,11 +225,11 @@ func (auth *Auth) RecordLoginAttempt(identifier string, success bool) {
}
}
func (auth *Auth) EmailWhitelisted(email string) bool {
func (auth *AuthService) EmailWhitelisted(email string) bool {
return utils.CheckFilter(auth.Config.OauthWhitelist, email)
}
func (auth *Auth) CreateSessionCookie(c *gin.Context, data *types.SessionCookie) error {
func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *types.SessionCookie) error {
log.Debug().Msg("Creating session cookie")
session, err := auth.GetSession(c)
@@ -252,7 +265,7 @@ func (auth *Auth) CreateSessionCookie(c *gin.Context, data *types.SessionCookie)
return nil
}
func (auth *Auth) DeleteSessionCookie(c *gin.Context) error {
func (auth *AuthService) DeleteSessionCookie(c *gin.Context) error {
log.Debug().Msg("Deleting session cookie")
session, err := auth.GetSession(c)
@@ -275,7 +288,7 @@ func (auth *Auth) DeleteSessionCookie(c *gin.Context) error {
return nil
}
func (auth *Auth) GetSessionCookie(c *gin.Context) (types.SessionCookie, error) {
func (auth *AuthService) GetSessionCookie(c *gin.Context) (types.SessionCookie, error) {
log.Debug().Msg("Getting session cookie")
session, err := auth.GetSession(c)
@@ -319,12 +332,12 @@ func (auth *Auth) GetSessionCookie(c *gin.Context) (types.SessionCookie, error)
}, nil
}
func (auth *Auth) UserAuthConfigured() bool {
func (auth *AuthService) UserAuthConfigured() bool {
// If there are users or LDAP is configured, return true
return len(auth.Config.Users) > 0 || auth.LDAP != nil
}
func (auth *Auth) ResourceAllowed(c *gin.Context, context types.UserContext, labels types.Labels) bool {
func (auth *AuthService) ResourceAllowed(c *gin.Context, context types.UserContext, labels types.Labels) bool {
if context.OAuth {
log.Debug().Msg("Checking OAuth whitelist")
return utils.CheckFilter(labels.OAuth.Whitelist, context.Email)
@@ -334,7 +347,7 @@ func (auth *Auth) ResourceAllowed(c *gin.Context, context types.UserContext, lab
return utils.CheckFilter(labels.Users, context.Username)
}
func (auth *Auth) OAuthGroup(c *gin.Context, context types.UserContext, labels types.Labels) bool {
func (auth *AuthService) OAuthGroup(c *gin.Context, context types.UserContext, labels types.Labels) bool {
if labels.OAuth.Groups == "" {
return true
}
@@ -361,7 +374,7 @@ func (auth *Auth) OAuthGroup(c *gin.Context, context types.UserContext, labels t
return false
}
func (auth *Auth) AuthEnabled(uri string, labels types.Labels) (bool, error) {
func (auth *AuthService) AuthEnabled(uri string, labels types.Labels) (bool, error) {
// If the label is empty, auth is enabled
if labels.Allowed == "" {
return true, nil
@@ -385,7 +398,7 @@ func (auth *Auth) AuthEnabled(uri string, labels types.Labels) (bool, error) {
return true, nil
}
func (auth *Auth) GetBasicAuth(c *gin.Context) *types.User {
func (auth *AuthService) GetBasicAuth(c *gin.Context) *types.User {
username, password, ok := c.Request.BasicAuth()
if !ok {
return nil
@@ -396,7 +409,7 @@ func (auth *Auth) GetBasicAuth(c *gin.Context) *types.User {
}
}
func (auth *Auth) CheckIP(labels types.Labels, ip string) bool {
func (auth *AuthService) CheckIP(labels types.Labels, ip string) bool {
// Check if the IP is in block list
for _, blocked := range labels.IP.Block {
res, err := utils.FilterIP(blocked, ip)
@@ -433,7 +446,7 @@ func (auth *Auth) CheckIP(labels types.Labels, ip string) bool {
return true
}
func (auth *Auth) BypassedIP(labels types.Labels, ip string) bool {
func (auth *AuthService) BypassedIP(labels types.Labels, ip string) bool {
// For every IP in the bypass list, check if the IP matches
for _, bypassed := range labels.IP.Bypass {
res, err := utils.FilterIP(bypassed, ip)

View File

@@ -1,9 +1,9 @@
package docker
package service
import (
"context"
"strings"
"tinyauth/internal/types"
"tinyauth/internal/config"
"tinyauth/internal/utils"
container "github.com/docker/docker/api/types/container"
@@ -11,27 +11,27 @@ import (
"github.com/rs/zerolog/log"
)
type Docker struct {
type DockerService struct {
Client *client.Client
Context context.Context
}
func NewDocker() (*Docker, error) {
func NewDockerService() *DockerService {
return &DockerService{}
}
func (docker *DockerService) Init() error {
client, err := client.NewClientWithOpts(client.FromEnv)
if err != nil {
return nil, err
return err
}
ctx := context.Background()
client.NegotiateAPIVersion(ctx)
return &Docker{
Client: client,
Context: ctx,
}, nil
return nil
}
func (docker *Docker) GetContainers() ([]container.Summary, error) {
func (docker *DockerService) GetContainers() ([]container.Summary, error) {
containers, err := docker.Client.ContainerList(docker.Context, container.ListOptions{})
if err != nil {
return nil, err
@@ -39,7 +39,7 @@ func (docker *Docker) GetContainers() ([]container.Summary, error) {
return containers, nil
}
func (docker *Docker) InspectContainer(containerId string) (container.InspectResponse, error) {
func (docker *DockerService) InspectContainer(containerId string) (container.InspectResponse, error) {
inspect, err := docker.Client.ContainerInspect(docker.Context, containerId)
if err != nil {
return container.InspectResponse{}, err
@@ -47,17 +47,17 @@ func (docker *Docker) InspectContainer(containerId string) (container.InspectRes
return inspect, nil
}
func (docker *Docker) DockerConnected() bool {
func (docker *DockerService) DockerConnected() bool {
_, err := docker.Client.Ping(docker.Context)
return err == nil
}
func (docker *Docker) GetLabels(app string, domain string) (types.Labels, error) {
func (docker *DockerService) GetLabels(app string, domain string) (config.Labels, error) {
isConnected := docker.DockerConnected()
if !isConnected {
log.Debug().Msg("Docker not connected, returning empty labels")
return types.Labels{}, nil
return config.Labels{}, nil
}
log.Debug().Msg("Getting containers")
@@ -65,7 +65,7 @@ func (docker *Docker) GetLabels(app string, domain string) (types.Labels, error)
containers, err := docker.GetContainers()
if err != nil {
log.Error().Err(err).Msg("Error getting containers")
return types.Labels{}, err
return config.Labels{}, err
}
for _, container := range containers {
@@ -98,5 +98,5 @@ func (docker *Docker) GetLabels(app string, domain string) (types.Labels, error)
}
log.Debug().Msg("No matching container found, returning empty labels")
return types.Labels{}, nil
return config.Labels{}, nil
}

View File

@@ -0,0 +1,114 @@
package service
import (
"context"
"crypto/rand"
"crypto/tls"
"encoding/base64"
"encoding/json"
"io"
"net/http"
"tinyauth/internal/config"
"golang.org/x/oauth2"
)
type GenericOAuthService struct {
Config oauth2.Config
Context context.Context
Token *oauth2.Token
Verifier string
InsecureSkipVerify bool
ServiceName string
UserinfoURL string
}
func NewGenericOAuthService(config config.OAuthServiceConfig) *GenericOAuthService {
return &GenericOAuthService{
Config: oauth2.Config{
ClientID: config.ClientID,
ClientSecret: config.ClientSecret,
RedirectURL: config.RedirectURL,
Scopes: config.Scopes,
Endpoint: oauth2.Endpoint{
AuthURL: config.AuthURL,
TokenURL: config.TokenURL,
},
},
InsecureSkipVerify: config.InsecureSkipVerify,
ServiceName: config.Name,
UserinfoURL: config.UserinfoURL,
}
}
func (generic *GenericOAuthService) Init() error {
transport := &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: generic.InsecureSkipVerify,
MinVersion: tls.VersionTLS12,
},
}
httpClient := &http.Client{
Transport: transport,
}
ctx := context.Background()
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
verifier := oauth2.GenerateVerifier()
generic.Context = ctx
generic.Verifier = verifier
return nil
}
func (generic *GenericOAuthService) Name() string {
return generic.ServiceName
}
func (generic *GenericOAuthService) GenerateState() string {
b := make([]byte, 128)
rand.Read(b)
state := base64.URLEncoding.EncodeToString(b)
return state
}
func (generic *GenericOAuthService) GetAuthURL(state string) string {
return generic.Config.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(generic.Verifier))
}
func (generic *GenericOAuthService) VerifyCode(code string) error {
token, err := generic.Config.Exchange(generic.Context, code, oauth2.VerifierOption(generic.Verifier))
if err != nil {
return nil
}
generic.Token = token
return nil
}
func (generic *GenericOAuthService) Userinfo() (config.Claims, error) {
var user config.Claims
client := generic.Config.Client(generic.Context, generic.Token)
res, err := client.Get(generic.UserinfoURL)
if err != nil {
return user, err
}
defer res.Body.Close()
body, err := io.ReadAll(res.Body)
if err != nil {
return user, err
}
err = json.Unmarshal(body, &user)
if err != nil {
return user, err
}
return user, nil
}

View File

@@ -0,0 +1,144 @@
package service
import (
"context"
"crypto/rand"
"encoding/base64"
"encoding/json"
"errors"
"io"
"net/http"
"tinyauth/internal/config"
"golang.org/x/oauth2"
)
var GithubOAuthScopes = []string{"user:email", "read:user"}
type GithubEmailResponse []struct {
Email string `json:"email"`
Primary bool `json:"primary"`
}
type GithubUserInfoResponse struct {
Login string `json:"login"`
Name string `json:"name"`
}
type GithubOAuthService struct {
Config oauth2.Config
Context context.Context
Token *oauth2.Token
Verifier string
}
func NewGithubOAuthService(config config.OAuthServiceConfig) *GithubOAuthService {
return &GithubOAuthService{
Config: oauth2.Config{
ClientID: config.ClientID,
ClientSecret: config.ClientSecret,
RedirectURL: config.RedirectURL,
Scopes: GithubOAuthScopes,
},
}
}
func (github *GithubOAuthService) Init() error {
httpClient := &http.Client{}
ctx := context.Background()
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
verifier := oauth2.GenerateVerifier()
github.Context = ctx
github.Verifier = verifier
return nil
}
func (github *GithubOAuthService) Name() string {
return "github"
}
func (github *GithubOAuthService) GenerateState() string {
b := make([]byte, 128)
rand.Read(b)
state := base64.URLEncoding.EncodeToString(b)
return state
}
func (github *GithubOAuthService) GetAuthURL(state string) string {
return github.Config.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(github.Verifier))
}
func (github *GithubOAuthService) VerifyCode(code string) error {
token, err := github.Config.Exchange(github.Context, code, oauth2.VerifierOption(github.Verifier))
if err != nil {
return nil
}
github.Token = token
return nil
}
func (github *GithubOAuthService) Userinfo() (config.Claims, error) {
var user config.Claims
client := github.Config.Client(github.Context, github.Token)
res, err := client.Get("https://api.github.com/user")
if err != nil {
return user, err
}
defer res.Body.Close()
body, err := io.ReadAll(res.Body)
if err != nil {
return user, err
}
var userInfo GithubUserInfoResponse
err = json.Unmarshal(body, &userInfo)
if err != nil {
return user, err
}
res, err = client.Get("https://api.github.com/user/emails")
if err != nil {
return user, err
}
defer res.Body.Close()
body, err = io.ReadAll(res.Body)
if err != nil {
return user, err
}
var emails GithubEmailResponse
err = json.Unmarshal(body, &emails)
if err != nil {
return user, err
}
for _, email := range emails {
if email.Primary {
user.Email = email.Email
break
}
}
if len(emails) == 0 {
return user, errors.New("no emails found")
}
// Use first available email if no primary email was found
if user.Email == "" {
user.Email = emails[0].Email
}
user.PreferredUsername = userInfo.Login
user.Name = userInfo.Name
return user, nil
}

View File

@@ -0,0 +1,106 @@
package service
import (
"context"
"crypto/rand"
"encoding/base64"
"encoding/json"
"io"
"net/http"
"strings"
"tinyauth/internal/config"
"golang.org/x/oauth2"
)
var GoogleOAuthScopes = []string{"https://www.googleapis.com/auth/userinfo.email", "https://www.googleapis.com/auth/userinfo.profile"}
type GoogleUserInfoResponse struct {
Email string `json:"email"`
Name string `json:"name"`
}
type GoogleOAuthService struct {
Config oauth2.Config
Context context.Context
Token *oauth2.Token
Verifier string
}
func NewGoogleOAuthService(config config.OAuthServiceConfig) *GoogleOAuthService {
return &GoogleOAuthService{
Config: oauth2.Config{
ClientID: config.ClientID,
ClientSecret: config.ClientSecret,
RedirectURL: config.RedirectURL,
Scopes: GoogleOAuthScopes,
},
}
}
func (google *GoogleOAuthService) Init() error {
httpClient := &http.Client{}
ctx := context.Background()
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
verifier := oauth2.GenerateVerifier()
google.Context = ctx
google.Verifier = verifier
return nil
}
func (google *GoogleOAuthService) Name() string {
return "google"
}
func (oauth *GoogleOAuthService) GenerateState() string {
b := make([]byte, 128)
rand.Read(b)
state := base64.URLEncoding.EncodeToString(b)
return state
}
func (google *GoogleOAuthService) GetAuthURL(state string) string {
return google.Config.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(google.Verifier))
}
func (google *GoogleOAuthService) VerifyCode(code string) error {
token, err := google.Config.Exchange(google.Context, code, oauth2.VerifierOption(google.Verifier))
if err != nil {
return nil
}
google.Token = token
return nil
}
func (google *GoogleOAuthService) Userinfo() (config.Claims, error) {
var user config.Claims
client := google.Config.Client(google.Context, google.Token)
res, err := client.Get("https://www.googleapis.com/userinfo/v2/me")
if err != nil {
return config.Claims{}, err
}
defer res.Body.Close()
body, err := io.ReadAll(res.Body)
if err != nil {
return config.Claims{}, err
}
var userInfo GoogleUserInfoResponse
err = json.Unmarshal(body, &userInfo)
if err != nil {
return config.Claims{}, err
}
user.PreferredUsername = strings.Split(userInfo.Email, "@")[0]
user.Name = userInfo.Name
user.Email = userInfo.Email
return user, nil
}

View File

@@ -1,30 +1,40 @@
package ldap
package service
import (
"context"
"crypto/tls"
"fmt"
"time"
"tinyauth/internal/types"
"github.com/cenkalti/backoff/v5"
ldapgo "github.com/go-ldap/ldap/v3"
"github.com/rs/zerolog/log"
)
type LDAP struct {
Config types.LdapConfig
type LdapServiceConfig struct {
Address string
BindDN string
BindPassword string
BaseDN string
Insecure bool
SearchFilter string
}
type LdapService struct {
Config LdapServiceConfig
Conn *ldapgo.Conn
}
func NewLDAP(config types.LdapConfig) (*LDAP, error) {
ldap := &LDAP{
func NewLdapService(config LdapServiceConfig) *LdapService {
return &LdapService{
Config: config,
}
}
func (ldap *LdapService) Init() error {
_, err := ldap.connect()
if err != nil {
return nil, fmt.Errorf("failed to connect to LDAP server: %w", err)
return fmt.Errorf("failed to connect to LDAP server: %w", err)
}
go func() {
@@ -41,13 +51,13 @@ func NewLDAP(config types.LdapConfig) (*LDAP, error) {
}
}()
return ldap, nil
return nil
}
func (l *LDAP) connect() (*ldapgo.Conn, error) {
func (ldap *LdapService) connect() (*ldapgo.Conn, error) {
log.Debug().Msg("Connecting to LDAP server")
conn, err := ldapgo.DialURL(l.Config.Address, ldapgo.DialWithTLSConfig(&tls.Config{
InsecureSkipVerify: l.Config.Insecure,
conn, err := ldapgo.DialURL(ldap.Config.Address, ldapgo.DialWithTLSConfig(&tls.Config{
InsecureSkipVerify: ldap.Config.Insecure,
MinVersion: tls.VersionTLS12,
}))
if err != nil {
@@ -55,30 +65,30 @@ func (l *LDAP) connect() (*ldapgo.Conn, error) {
}
log.Debug().Msg("Binding to LDAP server")
err = conn.Bind(l.Config.BindDN, l.Config.BindPassword)
err = conn.Bind(ldap.Config.BindDN, ldap.Config.BindPassword)
if err != nil {
return nil, err
}
// Set and return the connection
l.Conn = conn
ldap.Conn = conn
return conn, nil
}
func (l *LDAP) Search(username string) (string, error) {
func (ldap *LdapService) Search(username string) (string, error) {
// Escape the username to prevent LDAP injection
escapedUsername := ldapgo.EscapeFilter(username)
filter := fmt.Sprintf(l.Config.SearchFilter, escapedUsername)
filter := fmt.Sprintf(ldap.Config.SearchFilter, escapedUsername)
searchRequest := ldapgo.NewSearchRequest(
l.Config.BaseDN,
ldap.Config.BaseDN,
ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false,
filter,
[]string{"dn"},
nil,
)
searchResult, err := l.Conn.Search(searchRequest)
searchResult, err := ldap.Conn.Search(searchRequest)
if err != nil {
return "", err
}
@@ -91,15 +101,15 @@ func (l *LDAP) Search(username string) (string, error) {
return userDN, nil
}
func (l *LDAP) Bind(userDN string, password string) error {
err := l.Conn.Bind(userDN, password)
func (ldap *LdapService) Bind(userDN string, password string) error {
err := ldap.Conn.Bind(userDN, password)
if err != nil {
return err
}
return nil
}
func (l *LDAP) heartbeat() error {
func (ldap *LdapService) heartbeat() error {
log.Debug().Msg("Performing LDAP connection heartbeat")
searchRequest := ldapgo.NewSearchRequest(
@@ -110,7 +120,7 @@ func (l *LDAP) heartbeat() error {
nil,
)
_, err := l.Conn.Search(searchRequest)
_, err := ldap.Conn.Search(searchRequest)
if err != nil {
return err
}
@@ -119,7 +129,7 @@ func (l *LDAP) heartbeat() error {
return nil
}
func (l *LDAP) reconnect() error {
func (ldap *LdapService) reconnect() error {
log.Info().Msg("Reconnecting to LDAP server")
exp := backoff.NewExponentialBackOff()
@@ -129,8 +139,8 @@ func (l *LDAP) reconnect() error {
exp.Reset()
operation := func() (*ldapgo.Conn, error) {
l.Conn.Close()
conn, err := l.connect()
ldap.Conn.Close()
conn, err := ldap.connect()
if err != nil {
return nil, nil
}

View File

@@ -1,70 +0,0 @@
package types
import (
"time"
"tinyauth/internal/oauth"
)
// User is the struct for a user
type User struct {
Username string
Password string
TotpSecret string
}
// UserSearch is the response of the get user
type UserSearch struct {
Username string
Type string // "local", "ldap" or empty
}
// Users is a list of users
type Users []User
// OAuthProviders is the struct for the OAuth providers
type OAuthProviders struct {
Github *oauth.OAuth
Google *oauth.OAuth
Microsoft *oauth.OAuth
}
// SessionCookie is the cookie for the session (exculding the expiry)
type SessionCookie struct {
Username string
Name string
Email string
Provider string
TotpPending bool
OAuthGroups string
}
// UserContext is the context for the user
type UserContext struct {
Username string
Name string
Email string
IsLoggedIn bool
OAuth bool
Provider string
TotpPending bool
OAuthGroups string
TotpEnabled bool
}
// LoginAttempt tracks information about login attempts for rate limiting
type LoginAttempt struct {
FailedAttempts int
LastAttempt time.Time
LockedUntil time.Time
}
type UnauthorizedQuery struct {
Username string `url:"username"`
Resource string `url:"resource"`
GroupErr bool `url:"groupErr"`
IP string `url:"ip"`
}
type RedirectQuery struct {
RedirectURI string `url:"redirect_uri"`
}