mirror of
				https://github.com/steveiliop56/tinyauth.git
				synced 2025-10-30 21:55:43 +00:00 
			
		
		
		
	refactor: move oauth providers into services (non-working)
This commit is contained in:
		| @@ -1,146 +0,0 @@ | |||||||
| package auth_test |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"testing" |  | ||||||
| 	"time" |  | ||||||
| 	"tinyauth/internal/auth" |  | ||||||
| 	"tinyauth/internal/types" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| var config = types.AuthConfig{ |  | ||||||
| 	Users:          types.Users{}, |  | ||||||
| 	OauthWhitelist: "", |  | ||||||
| 	SessionExpiry:  3600, |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestLoginRateLimiting(t *testing.T) { |  | ||||||
| 	// Initialize a new auth service with 3 max retries and 5 seconds timeout |  | ||||||
| 	config.LoginMaxRetries = 3 |  | ||||||
| 	config.LoginTimeout = 5 |  | ||||||
| 	authService := auth.NewAuth(config, nil, nil) |  | ||||||
|  |  | ||||||
| 	// Test identifier |  | ||||||
| 	identifier := "test_user" |  | ||||||
|  |  | ||||||
| 	// Test successful login - should not lock account |  | ||||||
| 	t.Log("Testing successful login") |  | ||||||
|  |  | ||||||
| 	authService.RecordLoginAttempt(identifier, true) |  | ||||||
| 	locked, _ := authService.IsAccountLocked(identifier) |  | ||||||
|  |  | ||||||
| 	if locked { |  | ||||||
| 		t.Fatalf("Account should not be locked after successful login") |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// Test 2 failed attempts - should not lock account yet |  | ||||||
| 	t.Log("Testing 2 failed login attempts") |  | ||||||
|  |  | ||||||
| 	authService.RecordLoginAttempt(identifier, false) |  | ||||||
| 	authService.RecordLoginAttempt(identifier, false) |  | ||||||
| 	locked, _ = authService.IsAccountLocked(identifier) |  | ||||||
|  |  | ||||||
| 	if locked { |  | ||||||
| 		t.Fatalf("Account should not be locked after only 2 failed attempts") |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// Add one more failed attempt (total 3) - should lock account with maxRetries=3 |  | ||||||
| 	t.Log("Testing 3 failed login attempts") |  | ||||||
| 	authService.RecordLoginAttempt(identifier, false) |  | ||||||
| 	locked, remainingTime := authService.IsAccountLocked(identifier) |  | ||||||
|  |  | ||||||
| 	if !locked { |  | ||||||
| 		t.Fatalf("Account should be locked after reaching max retries") |  | ||||||
| 	} |  | ||||||
| 	if remainingTime <= 0 || remainingTime > 5 { |  | ||||||
| 		t.Fatalf("Expected remaining time between 1-5 seconds, got %d", remainingTime) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// Test reset after waiting for timeout - use 1 second timeout for fast testing |  | ||||||
| 	t.Log("Testing unlocking after timeout") |  | ||||||
|  |  | ||||||
| 	// Reinitialize auth service with a shorter timeout for testing |  | ||||||
| 	config.LoginTimeout = 1 |  | ||||||
| 	config.LoginMaxRetries = 3 |  | ||||||
| 	authService = auth.NewAuth(config, nil, nil) |  | ||||||
|  |  | ||||||
| 	// Add enough failed attempts to lock the account |  | ||||||
| 	for i := 0; i < 3; i++ { |  | ||||||
| 		authService.RecordLoginAttempt(identifier, false) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// Verify it's locked |  | ||||||
| 	locked, _ = authService.IsAccountLocked(identifier) |  | ||||||
| 	if !locked { |  | ||||||
| 		t.Fatalf("Account should be locked initially") |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// Wait a bit and verify it gets unlocked after timeout |  | ||||||
| 	time.Sleep(1500 * time.Millisecond) // Wait longer than the timeout |  | ||||||
| 	locked, _ = authService.IsAccountLocked(identifier) |  | ||||||
|  |  | ||||||
| 	if locked { |  | ||||||
| 		t.Fatalf("Account should be unlocked after timeout period") |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// Test disabled rate limiting |  | ||||||
| 	t.Log("Testing disabled rate limiting") |  | ||||||
| 	config.LoginMaxRetries = 0 |  | ||||||
| 	config.LoginTimeout = 0 |  | ||||||
| 	authService = auth.NewAuth(config, nil, nil) |  | ||||||
|  |  | ||||||
| 	for i := 0; i < 10; i++ { |  | ||||||
| 		authService.RecordLoginAttempt(identifier, false) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	locked, _ = authService.IsAccountLocked(identifier) |  | ||||||
| 	if locked { |  | ||||||
| 		t.Fatalf("Account should not be locked when rate limiting is disabled") |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func TestConcurrentLoginAttempts(t *testing.T) { |  | ||||||
| 	// Initialize a new auth service with 2 max retries and 5 seconds timeout |  | ||||||
| 	config.LoginMaxRetries = 2 |  | ||||||
| 	config.LoginTimeout = 5 |  | ||||||
| 	authService := auth.NewAuth(config, nil, nil) |  | ||||||
|  |  | ||||||
| 	// Test multiple identifiers |  | ||||||
| 	identifiers := []string{"user1", "user2", "user3"} |  | ||||||
|  |  | ||||||
| 	// Test that locking one identifier doesn't affect others |  | ||||||
| 	t.Log("Testing multiple identifiers") |  | ||||||
|  |  | ||||||
| 	// Add enough failed attempts to lock first user (2 attempts with maxRetries=2) |  | ||||||
| 	authService.RecordLoginAttempt(identifiers[0], false) |  | ||||||
| 	authService.RecordLoginAttempt(identifiers[0], false) |  | ||||||
|  |  | ||||||
| 	// Check if first user is locked |  | ||||||
| 	locked, _ := authService.IsAccountLocked(identifiers[0]) |  | ||||||
| 	if !locked { |  | ||||||
| 		t.Fatalf("User1 should be locked after reaching max retries") |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// Check that other users are not affected |  | ||||||
| 	for i := 1; i < len(identifiers); i++ { |  | ||||||
| 		locked, _ := authService.IsAccountLocked(identifiers[i]) |  | ||||||
| 		if locked { |  | ||||||
| 			t.Fatalf("User%d should not be locked", i+1) |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// Test successful login after failed attempts (but before lock) |  | ||||||
| 	t.Log("Testing successful login after failed attempts but before lock") |  | ||||||
|  |  | ||||||
| 	// One failed attempt for user2 |  | ||||||
| 	authService.RecordLoginAttempt(identifiers[1], false) |  | ||||||
|  |  | ||||||
| 	// Successful login should reset the counter |  | ||||||
| 	authService.RecordLoginAttempt(identifiers[1], true) |  | ||||||
|  |  | ||||||
| 	// Now try a failed login again - should not be locked as counter was reset |  | ||||||
| 	authService.RecordLoginAttempt(identifiers[1], false) |  | ||||||
| 	locked, _ = authService.IsAccountLocked(identifiers[1]) |  | ||||||
| 	if locked { |  | ||||||
| 		t.Fatalf("User2 should not be locked after successful login reset") |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
| @@ -1,6 +1,22 @@ | |||||||
| package types | package config | ||||||
|  | 
 | ||||||
|  | import "time" | ||||||
|  | 
 | ||||||
|  | type Claims struct { | ||||||
|  | 	Name              string `json:"name"` | ||||||
|  | 	Email             string `json:"email"` | ||||||
|  | 	PreferredUsername string `json:"preferred_username"` | ||||||
|  | 	Groups            any    `json:"groups"` | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | var Version = "development" | ||||||
|  | var CommitHash = "n/a" | ||||||
|  | var BuildTimestamp = "n/a" | ||||||
|  | 
 | ||||||
|  | var SessionCookieName = "tinyauth-session" | ||||||
|  | var CsrfCookieName = "tinyauth-csrf" | ||||||
|  | var RedirectCookieName = "tinyauth-redirect" | ||||||
| 
 | 
 | ||||||
| // Config is the configuration for the tinyauth server |  | ||||||
| type Config struct { | type Config struct { | ||||||
| 	Port                    int    `mapstructure:"port" validate:"required"` | 	Port                    int    `mapstructure:"port" validate:"required"` | ||||||
| 	Address                 string `validate:"required,ip4_addr" mapstructure:"address"` | 	Address                 string `validate:"required,ip4_addr" mapstructure:"address"` | ||||||
| @@ -44,62 +60,27 @@ type Config struct { | |||||||
| 	LdapSearchFilter        string `mapstructure:"ldap-search-filter"` | 	LdapSearchFilter        string `mapstructure:"ldap-search-filter"` | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // OAuthConfig is the configuration for the providers |  | ||||||
| type OAuthConfig struct { |  | ||||||
| 	GithubClientId      string |  | ||||||
| 	GithubClientSecret  string |  | ||||||
| 	GoogleClientId      string |  | ||||||
| 	GoogleClientSecret  string |  | ||||||
| 	GenericClientId     string |  | ||||||
| 	GenericClientSecret string |  | ||||||
| 	GenericScopes       []string |  | ||||||
| 	GenericAuthURL      string |  | ||||||
| 	GenericTokenURL     string |  | ||||||
| 	GenericUserURL      string |  | ||||||
| 	GenericSkipSSL      bool |  | ||||||
| 	AppURL              string |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // AuthConfig is the configuration for the auth service |  | ||||||
| type AuthConfig struct { |  | ||||||
| 	Users             Users |  | ||||||
| 	OauthWhitelist    string |  | ||||||
| 	SessionExpiry     int |  | ||||||
| 	CookieSecure      bool |  | ||||||
| 	Domain            string |  | ||||||
| 	LoginTimeout      int |  | ||||||
| 	LoginMaxRetries   int |  | ||||||
| 	SessionCookieName string |  | ||||||
| 	HMACSecret        string |  | ||||||
| 	EncryptionSecret  string |  | ||||||
| } |  | ||||||
| 
 |  | ||||||
| // OAuthLabels is a list of labels that can be used in a tinyauth protected container |  | ||||||
| type OAuthLabels struct { | type OAuthLabels struct { | ||||||
| 	Whitelist string | 	Whitelist string | ||||||
| 	Groups    string | 	Groups    string | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Basic auth labels for a tinyauth protected container |  | ||||||
| type BasicLabels struct { | type BasicLabels struct { | ||||||
| 	Username string | 	Username string | ||||||
| 	Password PassowrdLabels | 	Password PassowrdLabels | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // PassowrdLabels is a struct that contains the password labels for a tinyauth protected container |  | ||||||
| type PassowrdLabels struct { | type PassowrdLabels struct { | ||||||
| 	Plain string | 	Plain string | ||||||
| 	File  string | 	File  string | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // IP labels for a tinyauth protected container |  | ||||||
| type IPLabels struct { | type IPLabels struct { | ||||||
| 	Allow  []string | 	Allow  []string | ||||||
| 	Block  []string | 	Block  []string | ||||||
| 	Bypass []string | 	Bypass []string | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Labels is a struct that contains the labels for a tinyauth protected container |  | ||||||
| type Labels struct { | type Labels struct { | ||||||
| 	Users   string | 	Users   string | ||||||
| 	Allowed string | 	Allowed string | ||||||
| @@ -110,12 +91,65 @@ type Labels struct { | |||||||
| 	IP      IPLabels | 	IP      IPLabels | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| // Ldap config is a struct that contains the configuration for the LDAP service | type OAuthServiceConfig struct { | ||||||
| type LdapConfig struct { | 	ClientID           string | ||||||
| 	Address      string | 	ClientSecret       string | ||||||
| 	BindDN       string | 	Scopes             []string | ||||||
| 	BindPassword string | 	RedirectURL        string | ||||||
| 	BaseDN       string | 	AuthURL            string | ||||||
| 	Insecure     bool | 	TokenURL           string | ||||||
| 	SearchFilter string | 	UserinfoURL        string | ||||||
|  | 	InsecureSkipVerify bool | ||||||
|  | 	Name               string | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type User struct { | ||||||
|  | 	Username   string | ||||||
|  | 	Password   string | ||||||
|  | 	TotpSecret string | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type UserSearch struct { | ||||||
|  | 	Username string | ||||||
|  | 	Type     string // local, ldap or unknown | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type Users []User | ||||||
|  | 
 | ||||||
|  | type SessionCookie struct { | ||||||
|  | 	Username    string | ||||||
|  | 	Name        string | ||||||
|  | 	Email       string | ||||||
|  | 	Provider    string | ||||||
|  | 	TotpPending bool | ||||||
|  | 	OAuthGroups string | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type UserContext struct { | ||||||
|  | 	Username    string | ||||||
|  | 	Name        string | ||||||
|  | 	Email       string | ||||||
|  | 	IsLoggedIn  bool | ||||||
|  | 	OAuth       bool | ||||||
|  | 	Provider    string | ||||||
|  | 	TotpPending bool | ||||||
|  | 	OAuthGroups string | ||||||
|  | 	TotpEnabled bool | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type LoginAttempt struct { | ||||||
|  | 	FailedAttempts int | ||||||
|  | 	LastAttempt    time.Time | ||||||
|  | 	LockedUntil    time.Time | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type UnauthorizedQuery struct { | ||||||
|  | 	Username string `url:"username"` | ||||||
|  | 	Resource string `url:"resource"` | ||||||
|  | 	GroupErr bool   `url:"groupErr"` | ||||||
|  | 	IP       string `url:"ip"` | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type RedirectQuery struct { | ||||||
|  | 	RedirectURI string `url:"redirect_uri"` | ||||||
| } | } | ||||||
| @@ -1,19 +0,0 @@ | |||||||
| package constants |  | ||||||
|  |  | ||||||
| // Claims are the OIDC supported claims (prefered username is included for convinience) |  | ||||||
| type Claims struct { |  | ||||||
| 	Name              string `json:"name"` |  | ||||||
| 	Email             string `json:"email"` |  | ||||||
| 	PreferredUsername string `json:"preferred_username"` |  | ||||||
| 	Groups            any    `json:"groups"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Version information |  | ||||||
| var Version = "development" |  | ||||||
| var CommitHash = "n/a" |  | ||||||
| var BuildTimestamp = "n/a" |  | ||||||
|  |  | ||||||
| // Base cookie names |  | ||||||
| var SessionCookieName = "tinyauth-session" |  | ||||||
| var CsrfCookieName = "tinyauth-csrf" |  | ||||||
| var RedirectCookieName = "tinyauth-redirect" |  | ||||||
| @@ -1,71 +0,0 @@ | |||||||
| package oauth |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"context" |  | ||||||
| 	"crypto/rand" |  | ||||||
| 	"crypto/tls" |  | ||||||
| 	"encoding/base64" |  | ||||||
| 	"net/http" |  | ||||||
|  |  | ||||||
| 	"golang.org/x/oauth2" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| type OAuth struct { |  | ||||||
| 	Config   oauth2.Config |  | ||||||
| 	Context  context.Context |  | ||||||
| 	Token    *oauth2.Token |  | ||||||
| 	Verifier string |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func NewOAuth(config oauth2.Config, insecureSkipVerify bool) *OAuth { |  | ||||||
| 	transport := &http.Transport{ |  | ||||||
| 		TLSClientConfig: &tls.Config{ |  | ||||||
| 			InsecureSkipVerify: insecureSkipVerify, |  | ||||||
| 			MinVersion:         tls.VersionTLS12, |  | ||||||
| 		}, |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	httpClient := &http.Client{ |  | ||||||
| 		Transport: transport, |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	ctx := context.Background() |  | ||||||
|  |  | ||||||
| 	// Set the HTTP client in the context |  | ||||||
| 	ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) |  | ||||||
|  |  | ||||||
| 	verifier := oauth2.GenerateVerifier() |  | ||||||
|  |  | ||||||
| 	return &OAuth{ |  | ||||||
| 		Config:   config, |  | ||||||
| 		Context:  ctx, |  | ||||||
| 		Verifier: verifier, |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (oauth *OAuth) GetAuthURL(state string) string { |  | ||||||
| 	return oauth.Config.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(oauth.Verifier)) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (oauth *OAuth) ExchangeToken(code string) (string, error) { |  | ||||||
| 	token, err := oauth.Config.Exchange(oauth.Context, code, oauth2.VerifierOption(oauth.Verifier)) |  | ||||||
|  |  | ||||||
| 	if err != nil { |  | ||||||
| 		return "", err |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// Set and return the token |  | ||||||
| 	oauth.Token = token |  | ||||||
| 	return oauth.Token.AccessToken, nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (oauth *OAuth) GetClient() *http.Client { |  | ||||||
| 	return oauth.Config.Client(oauth.Context, oauth.Token) |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (oauth *OAuth) GenerateState() string { |  | ||||||
| 	b := make([]byte, 128) |  | ||||||
| 	rand.Read(b) |  | ||||||
| 	state := base64.URLEncoding.EncodeToString(b) |  | ||||||
| 	return state |  | ||||||
| } |  | ||||||
| @@ -1,37 +0,0 @@ | |||||||
| package providers |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"encoding/json" |  | ||||||
| 	"io" |  | ||||||
| 	"net/http" |  | ||||||
| 	"tinyauth/internal/constants" |  | ||||||
|  |  | ||||||
| 	"github.com/rs/zerolog/log" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| func GetGenericUser(client *http.Client, url string) (constants.Claims, error) { |  | ||||||
| 	var user constants.Claims |  | ||||||
|  |  | ||||||
| 	res, err := client.Get(url) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return user, err |  | ||||||
| 	} |  | ||||||
| 	defer res.Body.Close() |  | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Got response from generic provider") |  | ||||||
|  |  | ||||||
| 	body, err := io.ReadAll(res.Body) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return user, err |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Read body from generic provider") |  | ||||||
|  |  | ||||||
| 	err = json.Unmarshal(body, &user) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return user, err |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Parsed user from generic provider") |  | ||||||
| 	return user, nil |  | ||||||
| } |  | ||||||
| @@ -1,102 +0,0 @@ | |||||||
| package providers |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"encoding/json" |  | ||||||
| 	"errors" |  | ||||||
| 	"io" |  | ||||||
| 	"net/http" |  | ||||||
| 	"tinyauth/internal/constants" |  | ||||||
|  |  | ||||||
| 	"github.com/rs/zerolog/log" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| // Response for the github email endpoint |  | ||||||
| type GithubEmailResponse []struct { |  | ||||||
| 	Email   string `json:"email"` |  | ||||||
| 	Primary bool   `json:"primary"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Response for the github user endpoint |  | ||||||
| type GithubUserInfoResponse struct { |  | ||||||
| 	Login string `json:"login"` |  | ||||||
| 	Name  string `json:"name"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // The scopes required for the github provider |  | ||||||
| func GithubScopes() []string { |  | ||||||
| 	return []string{"user:email", "read:user"} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func GetGithubUser(client *http.Client) (constants.Claims, error) { |  | ||||||
| 	var user constants.Claims |  | ||||||
|  |  | ||||||
| 	res, err := client.Get("https://api.github.com/user") |  | ||||||
| 	if err != nil { |  | ||||||
| 		return user, err |  | ||||||
| 	} |  | ||||||
| 	defer res.Body.Close() |  | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Got user response from github") |  | ||||||
|  |  | ||||||
| 	body, err := io.ReadAll(res.Body) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return user, err |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Read user body from github") |  | ||||||
|  |  | ||||||
| 	var userInfo GithubUserInfoResponse |  | ||||||
|  |  | ||||||
| 	err = json.Unmarshal(body, &userInfo) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return user, err |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	res, err = client.Get("https://api.github.com/user/emails") |  | ||||||
| 	if err != nil { |  | ||||||
| 		return user, err |  | ||||||
| 	} |  | ||||||
| 	defer res.Body.Close() |  | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Got email response from github") |  | ||||||
|  |  | ||||||
| 	body, err = io.ReadAll(res.Body) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return user, err |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Read email body from github") |  | ||||||
|  |  | ||||||
| 	var emails GithubEmailResponse |  | ||||||
|  |  | ||||||
| 	err = json.Unmarshal(body, &emails) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return user, err |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Parsed emails from github") |  | ||||||
|  |  | ||||||
| 	// Find and return the primary email |  | ||||||
| 	for _, email := range emails { |  | ||||||
| 		if email.Primary { |  | ||||||
| 			log.Debug().Str("email", email.Email).Msg("Found primary email") |  | ||||||
| 			user.Email = email.Email |  | ||||||
| 			break |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if len(emails) == 0 { |  | ||||||
| 		return user, errors.New("no emails found") |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// Use first available email if no primary email was found |  | ||||||
| 	if user.Email == "" { |  | ||||||
| 		log.Warn().Str("email", emails[0].Email).Msg("No primary email found, using first email") |  | ||||||
| 		user.Email = emails[0].Email |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	user.PreferredUsername = userInfo.Login |  | ||||||
| 	user.Name = userInfo.Name |  | ||||||
|  |  | ||||||
| 	return user, nil |  | ||||||
| } |  | ||||||
| @@ -1,56 +0,0 @@ | |||||||
| package providers |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"encoding/json" |  | ||||||
| 	"io" |  | ||||||
| 	"net/http" |  | ||||||
| 	"strings" |  | ||||||
| 	"tinyauth/internal/constants" |  | ||||||
|  |  | ||||||
| 	"github.com/rs/zerolog/log" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| // Response for the google user endpoint |  | ||||||
| type GoogleUserInfoResponse struct { |  | ||||||
| 	Email string `json:"email"` |  | ||||||
| 	Name  string `json:"name"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // The scopes required for the google provider |  | ||||||
| func GoogleScopes() []string { |  | ||||||
| 	return []string{"https://www.googleapis.com/auth/userinfo.email", "https://www.googleapis.com/auth/userinfo.profile"} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func GetGoogleUser(client *http.Client) (constants.Claims, error) { |  | ||||||
| 	var user constants.Claims |  | ||||||
|  |  | ||||||
| 	res, err := client.Get("https://www.googleapis.com/userinfo/v2/me") |  | ||||||
| 	if err != nil { |  | ||||||
| 		return user, err |  | ||||||
| 	} |  | ||||||
| 	defer res.Body.Close() |  | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Got response from google") |  | ||||||
|  |  | ||||||
| 	body, err := io.ReadAll(res.Body) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return user, err |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Read body from google") |  | ||||||
|  |  | ||||||
| 	var userInfo GoogleUserInfoResponse |  | ||||||
|  |  | ||||||
| 	err = json.Unmarshal(body, &userInfo) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return user, err |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	log.Debug().Msg("Parsed user from google") |  | ||||||
|  |  | ||||||
| 	user.PreferredUsername = strings.Split(userInfo.Email, "@")[0] |  | ||||||
| 	user.Name = userInfo.Name |  | ||||||
| 	user.Email = userInfo.Email |  | ||||||
|  |  | ||||||
| 	return user, nil |  | ||||||
| } |  | ||||||
| @@ -1,154 +0,0 @@ | |||||||
| package providers |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"fmt" |  | ||||||
| 	"tinyauth/internal/constants" |  | ||||||
| 	"tinyauth/internal/oauth" |  | ||||||
| 	"tinyauth/internal/types" |  | ||||||
|  |  | ||||||
| 	"github.com/rs/zerolog/log" |  | ||||||
| 	"golang.org/x/oauth2" |  | ||||||
| 	"golang.org/x/oauth2/endpoints" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| type Providers struct { |  | ||||||
| 	Config  types.OAuthConfig |  | ||||||
| 	Github  *oauth.OAuth |  | ||||||
| 	Google  *oauth.OAuth |  | ||||||
| 	Generic *oauth.OAuth |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func NewProviders(config types.OAuthConfig) *Providers { |  | ||||||
| 	providers := &Providers{ |  | ||||||
| 		Config: config, |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if config.GithubClientId != "" && config.GithubClientSecret != "" { |  | ||||||
| 		log.Info().Msg("Initializing Github OAuth") |  | ||||||
| 		providers.Github = oauth.NewOAuth(oauth2.Config{ |  | ||||||
| 			ClientID:     config.GithubClientId, |  | ||||||
| 			ClientSecret: config.GithubClientSecret, |  | ||||||
| 			RedirectURL:  fmt.Sprintf("%s/api/oauth/callback/github", config.AppURL), |  | ||||||
| 			Scopes:       GithubScopes(), |  | ||||||
| 			Endpoint:     endpoints.GitHub, |  | ||||||
| 		}, false) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if config.GoogleClientId != "" && config.GoogleClientSecret != "" { |  | ||||||
| 		log.Info().Msg("Initializing Google OAuth") |  | ||||||
| 		providers.Google = oauth.NewOAuth(oauth2.Config{ |  | ||||||
| 			ClientID:     config.GoogleClientId, |  | ||||||
| 			ClientSecret: config.GoogleClientSecret, |  | ||||||
| 			RedirectURL:  fmt.Sprintf("%s/api/oauth/callback/google", config.AppURL), |  | ||||||
| 			Scopes:       GoogleScopes(), |  | ||||||
| 			Endpoint:     endpoints.Google, |  | ||||||
| 		}, false) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	if config.GenericClientId != "" && config.GenericClientSecret != "" { |  | ||||||
| 		log.Info().Msg("Initializing Generic OAuth") |  | ||||||
| 		providers.Generic = oauth.NewOAuth(oauth2.Config{ |  | ||||||
| 			ClientID:     config.GenericClientId, |  | ||||||
| 			ClientSecret: config.GenericClientSecret, |  | ||||||
| 			RedirectURL:  fmt.Sprintf("%s/api/oauth/callback/generic", config.AppURL), |  | ||||||
| 			Scopes:       config.GenericScopes, |  | ||||||
| 			Endpoint: oauth2.Endpoint{ |  | ||||||
| 				AuthURL:  config.GenericAuthURL, |  | ||||||
| 				TokenURL: config.GenericTokenURL, |  | ||||||
| 			}, |  | ||||||
| 		}, config.GenericSkipSSL) |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	return providers |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (providers *Providers) GetProvider(provider string) *oauth.OAuth { |  | ||||||
| 	switch provider { |  | ||||||
| 	case "github": |  | ||||||
| 		return providers.Github |  | ||||||
| 	case "google": |  | ||||||
| 		return providers.Google |  | ||||||
| 	case "generic": |  | ||||||
| 		return providers.Generic |  | ||||||
| 	default: |  | ||||||
| 		return nil |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (providers *Providers) GetUser(provider string) (constants.Claims, error) { |  | ||||||
| 	var user constants.Claims |  | ||||||
|  |  | ||||||
| 	// Get the user from the provider |  | ||||||
| 	switch provider { |  | ||||||
| 	case "github": |  | ||||||
| 		if providers.Github == nil { |  | ||||||
| 			log.Debug().Msg("Github provider not configured") |  | ||||||
| 			return user, nil |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		client := providers.Github.GetClient() |  | ||||||
|  |  | ||||||
| 		log.Debug().Msg("Got client from github") |  | ||||||
|  |  | ||||||
| 		user, err := GetGithubUser(client) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return user, err |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		log.Debug().Msg("Got user from github") |  | ||||||
|  |  | ||||||
| 		return user, nil |  | ||||||
| 	case "google": |  | ||||||
| 		if providers.Google == nil { |  | ||||||
| 			log.Debug().Msg("Google provider not configured") |  | ||||||
| 			return user, nil |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		client := providers.Google.GetClient() |  | ||||||
|  |  | ||||||
| 		log.Debug().Msg("Got client from google") |  | ||||||
|  |  | ||||||
| 		user, err := GetGoogleUser(client) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return user, err |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		log.Debug().Msg("Got user from google") |  | ||||||
|  |  | ||||||
| 		return user, nil |  | ||||||
| 	case "generic": |  | ||||||
| 		if providers.Generic == nil { |  | ||||||
| 			log.Debug().Msg("Generic provider not configured") |  | ||||||
| 			return user, nil |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		client := providers.Generic.GetClient() |  | ||||||
|  |  | ||||||
| 		log.Debug().Msg("Got client from generic") |  | ||||||
|  |  | ||||||
| 		user, err := GetGenericUser(client, providers.Config.GenericUserURL) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return user, err |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		log.Debug().Msg("Got user from generic") |  | ||||||
|  |  | ||||||
| 		return user, nil |  | ||||||
| 	default: |  | ||||||
| 		return user, nil |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func (provider *Providers) GetConfiguredProviders() []string { |  | ||||||
| 	providers := []string{} |  | ||||||
| 	if provider.Github != nil { |  | ||||||
| 		providers = append(providers, "github") |  | ||||||
| 	} |  | ||||||
| 	if provider.Google != nil { |  | ||||||
| 		providers = append(providers, "google") |  | ||||||
| 	} |  | ||||||
| 	if provider.Generic != nil { |  | ||||||
| 		providers = append(providers, "generic") |  | ||||||
| 	} |  | ||||||
| 	return providers |  | ||||||
| } |  | ||||||
| @@ -1,4 +1,4 @@ | |||||||
| package auth | package service | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
| @@ -6,8 +6,6 @@ import ( | |||||||
| 	"strings" | 	"strings" | ||||||
| 	"sync" | 	"sync" | ||||||
| 	"time" | 	"time" | ||||||
| 	"tinyauth/internal/docker" |  | ||||||
| 	"tinyauth/internal/ldap" |  | ||||||
| 	"tinyauth/internal/types" | 	"tinyauth/internal/types" | ||||||
| 	"tinyauth/internal/utils" | 	"tinyauth/internal/utils" | ||||||
| 
 | 
 | ||||||
| @@ -17,35 +15,50 @@ import ( | |||||||
| 	"golang.org/x/crypto/bcrypt" | 	"golang.org/x/crypto/bcrypt" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type Auth struct { | type AuthServiceConfig struct { | ||||||
| 	Config        types.AuthConfig | 	Users             types.Users | ||||||
| 	Docker        *docker.Docker | 	OauthWhitelist    string | ||||||
|  | 	SessionExpiry     int | ||||||
|  | 	CookieSecure      bool | ||||||
|  | 	Domain            string | ||||||
|  | 	LoginTimeout      int | ||||||
|  | 	LoginMaxRetries   int | ||||||
|  | 	SessionCookieName string | ||||||
|  | 	HMACSecret        string | ||||||
|  | 	EncryptionSecret  string | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type AuthService struct { | ||||||
|  | 	Config        AuthServiceConfig | ||||||
|  | 	Docker        *DockerService | ||||||
| 	LoginAttempts map[string]*types.LoginAttempt | 	LoginAttempts map[string]*types.LoginAttempt | ||||||
| 	LoginMutex    sync.RWMutex | 	LoginMutex    sync.RWMutex | ||||||
| 	Store         *sessions.CookieStore | 	Store         *sessions.CookieStore | ||||||
| 	LDAP          *ldap.LDAP | 	LDAP          *LdapService | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func NewAuth(config types.AuthConfig, docker *docker.Docker, ldap *ldap.LDAP) *Auth { | func NewAuthService(config AuthServiceConfig, docker *DockerService, ldap *LdapService) *AuthService { | ||||||
| 	// Setup cookie store and create the auth service | 	return &AuthService{ | ||||||
| 	store := sessions.NewCookieStore([]byte(config.HMACSecret), []byte(config.EncryptionSecret)) |  | ||||||
| 	store.Options = &sessions.Options{ |  | ||||||
| 		Path:     "/", |  | ||||||
| 		MaxAge:   config.SessionExpiry, |  | ||||||
| 		Secure:   config.CookieSecure, |  | ||||||
| 		HttpOnly: true, |  | ||||||
| 		Domain:   fmt.Sprintf(".%s", config.Domain), |  | ||||||
| 	} |  | ||||||
| 	return &Auth{ |  | ||||||
| 		Config:        config, | 		Config:        config, | ||||||
| 		Docker:        docker, | 		Docker:        docker, | ||||||
| 		LoginAttempts: make(map[string]*types.LoginAttempt), | 		LoginAttempts: make(map[string]*types.LoginAttempt), | ||||||
| 		Store:         store, |  | ||||||
| 		LDAP:          ldap, | 		LDAP:          ldap, | ||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (auth *Auth) GetSession(c *gin.Context) (*sessions.Session, error) { | func (auth *AuthService) Init() error { | ||||||
|  | 	store := sessions.NewCookieStore([]byte(auth.Config.HMACSecret), []byte(auth.Config.EncryptionSecret)) | ||||||
|  | 	store.Options = &sessions.Options{ | ||||||
|  | 		Path:     "/", | ||||||
|  | 		MaxAge:   auth.Config.SessionExpiry, | ||||||
|  | 		Secure:   auth.Config.CookieSecure, | ||||||
|  | 		HttpOnly: true, | ||||||
|  | 		Domain:   fmt.Sprintf(".%s", auth.Config.Domain), | ||||||
|  | 	} | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (auth *AuthService) GetSession(c *gin.Context) (*sessions.Session, error) { | ||||||
| 	session, err := auth.Store.Get(c.Request, auth.Config.SessionCookieName) | 	session, err := auth.Store.Get(c.Request, auth.Config.SessionCookieName) | ||||||
| 
 | 
 | ||||||
| 	// If there was an error getting the session, it might be invalid so let's clear it and retry | 	// If there was an error getting the session, it might be invalid so let's clear it and retry | ||||||
| @@ -62,7 +75,7 @@ func (auth *Auth) GetSession(c *gin.Context) (*sessions.Session, error) { | |||||||
| 	return session, nil | 	return session, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (auth *Auth) SearchUser(username string) types.UserSearch { | func (auth *AuthService) SearchUser(username string) types.UserSearch { | ||||||
| 	log.Debug().Str("username", username).Msg("Searching for user") | 	log.Debug().Str("username", username).Msg("Searching for user") | ||||||
| 
 | 
 | ||||||
| 	// Check local users first | 	// Check local users first | ||||||
| @@ -93,7 +106,7 @@ func (auth *Auth) SearchUser(username string) types.UserSearch { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (auth *Auth) VerifyUser(search types.UserSearch, password string) bool { | func (auth *AuthService) VerifyUser(search types.UserSearch, password string) bool { | ||||||
| 	// Authenticate the user based on the type | 	// Authenticate the user based on the type | ||||||
| 	switch search.Type { | 	switch search.Type { | ||||||
| 	case "local": | 	case "local": | ||||||
| @@ -131,7 +144,7 @@ func (auth *Auth) VerifyUser(search types.UserSearch, password string) bool { | |||||||
| 	return false | 	return false | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (auth *Auth) GetLocalUser(username string) types.User { | func (auth *AuthService) GetLocalUser(username string) types.User { | ||||||
| 	// Loop through users and return the user if the username matches | 	// Loop through users and return the user if the username matches | ||||||
| 	log.Debug().Str("username", username).Msg("Searching for local user") | 	log.Debug().Str("username", username).Msg("Searching for local user") | ||||||
| 
 | 
 | ||||||
| @@ -146,11 +159,11 @@ func (auth *Auth) GetLocalUser(username string) types.User { | |||||||
| 	return types.User{} | 	return types.User{} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (auth *Auth) CheckPassword(user types.User, password string) bool { | func (auth *AuthService) CheckPassword(user types.User, password string) bool { | ||||||
| 	return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) == nil | 	return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) == nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (auth *Auth) IsAccountLocked(identifier string) (bool, int) { | func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) { | ||||||
| 	auth.LoginMutex.RLock() | 	auth.LoginMutex.RLock() | ||||||
| 	defer auth.LoginMutex.RUnlock() | 	defer auth.LoginMutex.RUnlock() | ||||||
| 
 | 
 | ||||||
| @@ -176,7 +189,7 @@ func (auth *Auth) IsAccountLocked(identifier string) (bool, int) { | |||||||
| 	return false, 0 | 	return false, 0 | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (auth *Auth) RecordLoginAttempt(identifier string, success bool) { | func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) { | ||||||
| 	// Skip if rate limiting is not configured | 	// Skip if rate limiting is not configured | ||||||
| 	if auth.Config.LoginMaxRetries <= 0 || auth.Config.LoginTimeout <= 0 { | 	if auth.Config.LoginMaxRetries <= 0 || auth.Config.LoginTimeout <= 0 { | ||||||
| 		return | 		return | ||||||
| @@ -212,11 +225,11 @@ func (auth *Auth) RecordLoginAttempt(identifier string, success bool) { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (auth *Auth) EmailWhitelisted(email string) bool { | func (auth *AuthService) EmailWhitelisted(email string) bool { | ||||||
| 	return utils.CheckFilter(auth.Config.OauthWhitelist, email) | 	return utils.CheckFilter(auth.Config.OauthWhitelist, email) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (auth *Auth) CreateSessionCookie(c *gin.Context, data *types.SessionCookie) error { | func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *types.SessionCookie) error { | ||||||
| 	log.Debug().Msg("Creating session cookie") | 	log.Debug().Msg("Creating session cookie") | ||||||
| 
 | 
 | ||||||
| 	session, err := auth.GetSession(c) | 	session, err := auth.GetSession(c) | ||||||
| @@ -252,7 +265,7 @@ func (auth *Auth) CreateSessionCookie(c *gin.Context, data *types.SessionCookie) | |||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (auth *Auth) DeleteSessionCookie(c *gin.Context) error { | func (auth *AuthService) DeleteSessionCookie(c *gin.Context) error { | ||||||
| 	log.Debug().Msg("Deleting session cookie") | 	log.Debug().Msg("Deleting session cookie") | ||||||
| 
 | 
 | ||||||
| 	session, err := auth.GetSession(c) | 	session, err := auth.GetSession(c) | ||||||
| @@ -275,7 +288,7 @@ func (auth *Auth) DeleteSessionCookie(c *gin.Context) error { | |||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (auth *Auth) GetSessionCookie(c *gin.Context) (types.SessionCookie, error) { | func (auth *AuthService) GetSessionCookie(c *gin.Context) (types.SessionCookie, error) { | ||||||
| 	log.Debug().Msg("Getting session cookie") | 	log.Debug().Msg("Getting session cookie") | ||||||
| 
 | 
 | ||||||
| 	session, err := auth.GetSession(c) | 	session, err := auth.GetSession(c) | ||||||
| @@ -319,12 +332,12 @@ func (auth *Auth) GetSessionCookie(c *gin.Context) (types.SessionCookie, error) | |||||||
| 	}, nil | 	}, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (auth *Auth) UserAuthConfigured() bool { | func (auth *AuthService) UserAuthConfigured() bool { | ||||||
| 	// If there are users or LDAP is configured, return true | 	// If there are users or LDAP is configured, return true | ||||||
| 	return len(auth.Config.Users) > 0 || auth.LDAP != nil | 	return len(auth.Config.Users) > 0 || auth.LDAP != nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (auth *Auth) ResourceAllowed(c *gin.Context, context types.UserContext, labels types.Labels) bool { | func (auth *AuthService) ResourceAllowed(c *gin.Context, context types.UserContext, labels types.Labels) bool { | ||||||
| 	if context.OAuth { | 	if context.OAuth { | ||||||
| 		log.Debug().Msg("Checking OAuth whitelist") | 		log.Debug().Msg("Checking OAuth whitelist") | ||||||
| 		return utils.CheckFilter(labels.OAuth.Whitelist, context.Email) | 		return utils.CheckFilter(labels.OAuth.Whitelist, context.Email) | ||||||
| @@ -334,7 +347,7 @@ func (auth *Auth) ResourceAllowed(c *gin.Context, context types.UserContext, lab | |||||||
| 	return utils.CheckFilter(labels.Users, context.Username) | 	return utils.CheckFilter(labels.Users, context.Username) | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (auth *Auth) OAuthGroup(c *gin.Context, context types.UserContext, labels types.Labels) bool { | func (auth *AuthService) OAuthGroup(c *gin.Context, context types.UserContext, labels types.Labels) bool { | ||||||
| 	if labels.OAuth.Groups == "" { | 	if labels.OAuth.Groups == "" { | ||||||
| 		return true | 		return true | ||||||
| 	} | 	} | ||||||
| @@ -361,7 +374,7 @@ func (auth *Auth) OAuthGroup(c *gin.Context, context types.UserContext, labels t | |||||||
| 	return false | 	return false | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (auth *Auth) AuthEnabled(uri string, labels types.Labels) (bool, error) { | func (auth *AuthService) AuthEnabled(uri string, labels types.Labels) (bool, error) { | ||||||
| 	// If the label is empty, auth is enabled | 	// If the label is empty, auth is enabled | ||||||
| 	if labels.Allowed == "" { | 	if labels.Allowed == "" { | ||||||
| 		return true, nil | 		return true, nil | ||||||
| @@ -385,7 +398,7 @@ func (auth *Auth) AuthEnabled(uri string, labels types.Labels) (bool, error) { | |||||||
| 	return true, nil | 	return true, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (auth *Auth) GetBasicAuth(c *gin.Context) *types.User { | func (auth *AuthService) GetBasicAuth(c *gin.Context) *types.User { | ||||||
| 	username, password, ok := c.Request.BasicAuth() | 	username, password, ok := c.Request.BasicAuth() | ||||||
| 	if !ok { | 	if !ok { | ||||||
| 		return nil | 		return nil | ||||||
| @@ -396,7 +409,7 @@ func (auth *Auth) GetBasicAuth(c *gin.Context) *types.User { | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (auth *Auth) CheckIP(labels types.Labels, ip string) bool { | func (auth *AuthService) CheckIP(labels types.Labels, ip string) bool { | ||||||
| 	// Check if the IP is in block list | 	// Check if the IP is in block list | ||||||
| 	for _, blocked := range labels.IP.Block { | 	for _, blocked := range labels.IP.Block { | ||||||
| 		res, err := utils.FilterIP(blocked, ip) | 		res, err := utils.FilterIP(blocked, ip) | ||||||
| @@ -433,7 +446,7 @@ func (auth *Auth) CheckIP(labels types.Labels, ip string) bool { | |||||||
| 	return true | 	return true | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (auth *Auth) BypassedIP(labels types.Labels, ip string) bool { | func (auth *AuthService) BypassedIP(labels types.Labels, ip string) bool { | ||||||
| 	// For every IP in the bypass list, check if the IP matches | 	// For every IP in the bypass list, check if the IP matches | ||||||
| 	for _, bypassed := range labels.IP.Bypass { | 	for _, bypassed := range labels.IP.Bypass { | ||||||
| 		res, err := utils.FilterIP(bypassed, ip) | 		res, err := utils.FilterIP(bypassed, ip) | ||||||
| @@ -1,9 +1,9 @@ | |||||||
| package docker | package service | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"strings" | 	"strings" | ||||||
| 	"tinyauth/internal/types" | 	"tinyauth/internal/config" | ||||||
| 	"tinyauth/internal/utils" | 	"tinyauth/internal/utils" | ||||||
| 
 | 
 | ||||||
| 	container "github.com/docker/docker/api/types/container" | 	container "github.com/docker/docker/api/types/container" | ||||||
| @@ -11,27 +11,27 @@ import ( | |||||||
| 	"github.com/rs/zerolog/log" | 	"github.com/rs/zerolog/log" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type Docker struct { | type DockerService struct { | ||||||
| 	Client  *client.Client | 	Client  *client.Client | ||||||
| 	Context context.Context | 	Context context.Context | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func NewDocker() (*Docker, error) { | func NewDockerService() *DockerService { | ||||||
|  | 	return &DockerService{} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (docker *DockerService) Init() error { | ||||||
| 	client, err := client.NewClientWithOpts(client.FromEnv) | 	client, err := client.NewClientWithOpts(client.FromEnv) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	ctx := context.Background() | 	ctx := context.Background() | ||||||
| 	client.NegotiateAPIVersion(ctx) | 	client.NegotiateAPIVersion(ctx) | ||||||
| 
 | 	return nil | ||||||
| 	return &Docker{ |  | ||||||
| 		Client:  client, |  | ||||||
| 		Context: ctx, |  | ||||||
| 	}, nil |  | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (docker *Docker) GetContainers() ([]container.Summary, error) { | func (docker *DockerService) GetContainers() ([]container.Summary, error) { | ||||||
| 	containers, err := docker.Client.ContainerList(docker.Context, container.ListOptions{}) | 	containers, err := docker.Client.ContainerList(docker.Context, container.ListOptions{}) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| @@ -39,7 +39,7 @@ func (docker *Docker) GetContainers() ([]container.Summary, error) { | |||||||
| 	return containers, nil | 	return containers, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (docker *Docker) InspectContainer(containerId string) (container.InspectResponse, error) { | func (docker *DockerService) InspectContainer(containerId string) (container.InspectResponse, error) { | ||||||
| 	inspect, err := docker.Client.ContainerInspect(docker.Context, containerId) | 	inspect, err := docker.Client.ContainerInspect(docker.Context, containerId) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return container.InspectResponse{}, err | 		return container.InspectResponse{}, err | ||||||
| @@ -47,17 +47,17 @@ func (docker *Docker) InspectContainer(containerId string) (container.InspectRes | |||||||
| 	return inspect, nil | 	return inspect, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (docker *Docker) DockerConnected() bool { | func (docker *DockerService) DockerConnected() bool { | ||||||
| 	_, err := docker.Client.Ping(docker.Context) | 	_, err := docker.Client.Ping(docker.Context) | ||||||
| 	return err == nil | 	return err == nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (docker *Docker) GetLabels(app string, domain string) (types.Labels, error) { | func (docker *DockerService) GetLabels(app string, domain string) (config.Labels, error) { | ||||||
| 	isConnected := docker.DockerConnected() | 	isConnected := docker.DockerConnected() | ||||||
| 
 | 
 | ||||||
| 	if !isConnected { | 	if !isConnected { | ||||||
| 		log.Debug().Msg("Docker not connected, returning empty labels") | 		log.Debug().Msg("Docker not connected, returning empty labels") | ||||||
| 		return types.Labels{}, nil | 		return config.Labels{}, nil | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	log.Debug().Msg("Getting containers") | 	log.Debug().Msg("Getting containers") | ||||||
| @@ -65,7 +65,7 @@ func (docker *Docker) GetLabels(app string, domain string) (types.Labels, error) | |||||||
| 	containers, err := docker.GetContainers() | 	containers, err := docker.GetContainers() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		log.Error().Err(err).Msg("Error getting containers") | 		log.Error().Err(err).Msg("Error getting containers") | ||||||
| 		return types.Labels{}, err | 		return config.Labels{}, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	for _, container := range containers { | 	for _, container := range containers { | ||||||
| @@ -98,5 +98,5 @@ func (docker *Docker) GetLabels(app string, domain string) (types.Labels, error) | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	log.Debug().Msg("No matching container found, returning empty labels") | 	log.Debug().Msg("No matching container found, returning empty labels") | ||||||
| 	return types.Labels{}, nil | 	return config.Labels{}, nil | ||||||
| } | } | ||||||
							
								
								
									
										114
									
								
								internal/service/generic_oauth_service.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										114
									
								
								internal/service/generic_oauth_service.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,114 @@ | |||||||
|  | package service | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"crypto/rand" | ||||||
|  | 	"crypto/tls" | ||||||
|  | 	"encoding/base64" | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"io" | ||||||
|  | 	"net/http" | ||||||
|  | 	"tinyauth/internal/config" | ||||||
|  |  | ||||||
|  | 	"golang.org/x/oauth2" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type GenericOAuthService struct { | ||||||
|  | 	Config             oauth2.Config | ||||||
|  | 	Context            context.Context | ||||||
|  | 	Token              *oauth2.Token | ||||||
|  | 	Verifier           string | ||||||
|  | 	InsecureSkipVerify bool | ||||||
|  | 	ServiceName        string | ||||||
|  | 	UserinfoURL        string | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func NewGenericOAuthService(config config.OAuthServiceConfig) *GenericOAuthService { | ||||||
|  | 	return &GenericOAuthService{ | ||||||
|  | 		Config: oauth2.Config{ | ||||||
|  | 			ClientID:     config.ClientID, | ||||||
|  | 			ClientSecret: config.ClientSecret, | ||||||
|  | 			RedirectURL:  config.RedirectURL, | ||||||
|  | 			Scopes:       config.Scopes, | ||||||
|  | 			Endpoint: oauth2.Endpoint{ | ||||||
|  | 				AuthURL:  config.AuthURL, | ||||||
|  | 				TokenURL: config.TokenURL, | ||||||
|  | 			}, | ||||||
|  | 		}, | ||||||
|  | 		InsecureSkipVerify: config.InsecureSkipVerify, | ||||||
|  | 		ServiceName:        config.Name, | ||||||
|  | 		UserinfoURL:        config.UserinfoURL, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (generic *GenericOAuthService) Init() error { | ||||||
|  | 	transport := &http.Transport{ | ||||||
|  | 		TLSClientConfig: &tls.Config{ | ||||||
|  | 			InsecureSkipVerify: generic.InsecureSkipVerify, | ||||||
|  | 			MinVersion:         tls.VersionTLS12, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	httpClient := &http.Client{ | ||||||
|  | 		Transport: transport, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	ctx := context.Background() | ||||||
|  |  | ||||||
|  | 	ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) | ||||||
|  | 	verifier := oauth2.GenerateVerifier() | ||||||
|  |  | ||||||
|  | 	generic.Context = ctx | ||||||
|  | 	generic.Verifier = verifier | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (generic *GenericOAuthService) Name() string { | ||||||
|  | 	return generic.ServiceName | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (generic *GenericOAuthService) GenerateState() string { | ||||||
|  | 	b := make([]byte, 128) | ||||||
|  | 	rand.Read(b) | ||||||
|  | 	state := base64.URLEncoding.EncodeToString(b) | ||||||
|  | 	return state | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (generic *GenericOAuthService) GetAuthURL(state string) string { | ||||||
|  | 	return generic.Config.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(generic.Verifier)) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (generic *GenericOAuthService) VerifyCode(code string) error { | ||||||
|  | 	token, err := generic.Config.Exchange(generic.Context, code, oauth2.VerifierOption(generic.Verifier)) | ||||||
|  |  | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	generic.Token = token | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (generic *GenericOAuthService) Userinfo() (config.Claims, error) { | ||||||
|  | 	var user config.Claims | ||||||
|  |  | ||||||
|  | 	client := generic.Config.Client(generic.Context, generic.Token) | ||||||
|  |  | ||||||
|  | 	res, err := client.Get(generic.UserinfoURL) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return user, err | ||||||
|  | 	} | ||||||
|  | 	defer res.Body.Close() | ||||||
|  |  | ||||||
|  | 	body, err := io.ReadAll(res.Body) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return user, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	err = json.Unmarshal(body, &user) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return user, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return user, nil | ||||||
|  | } | ||||||
							
								
								
									
										144
									
								
								internal/service/github_oauth_service.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										144
									
								
								internal/service/github_oauth_service.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,144 @@ | |||||||
|  | package service | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"crypto/rand" | ||||||
|  | 	"encoding/base64" | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"errors" | ||||||
|  | 	"io" | ||||||
|  | 	"net/http" | ||||||
|  | 	"tinyauth/internal/config" | ||||||
|  |  | ||||||
|  | 	"golang.org/x/oauth2" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | var GithubOAuthScopes = []string{"user:email", "read:user"} | ||||||
|  |  | ||||||
|  | type GithubEmailResponse []struct { | ||||||
|  | 	Email   string `json:"email"` | ||||||
|  | 	Primary bool   `json:"primary"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type GithubUserInfoResponse struct { | ||||||
|  | 	Login string `json:"login"` | ||||||
|  | 	Name  string `json:"name"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type GithubOAuthService struct { | ||||||
|  | 	Config   oauth2.Config | ||||||
|  | 	Context  context.Context | ||||||
|  | 	Token    *oauth2.Token | ||||||
|  | 	Verifier string | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func NewGithubOAuthService(config config.OAuthServiceConfig) *GithubOAuthService { | ||||||
|  | 	return &GithubOAuthService{ | ||||||
|  | 		Config: oauth2.Config{ | ||||||
|  | 			ClientID:     config.ClientID, | ||||||
|  | 			ClientSecret: config.ClientSecret, | ||||||
|  | 			RedirectURL:  config.RedirectURL, | ||||||
|  | 			Scopes:       GithubOAuthScopes, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (github *GithubOAuthService) Init() error { | ||||||
|  | 	httpClient := &http.Client{} | ||||||
|  | 	ctx := context.Background() | ||||||
|  | 	ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) | ||||||
|  | 	verifier := oauth2.GenerateVerifier() | ||||||
|  |  | ||||||
|  | 	github.Context = ctx | ||||||
|  | 	github.Verifier = verifier | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (github *GithubOAuthService) Name() string { | ||||||
|  | 	return "github" | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (github *GithubOAuthService) GenerateState() string { | ||||||
|  | 	b := make([]byte, 128) | ||||||
|  | 	rand.Read(b) | ||||||
|  | 	state := base64.URLEncoding.EncodeToString(b) | ||||||
|  | 	return state | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (github *GithubOAuthService) GetAuthURL(state string) string { | ||||||
|  | 	return github.Config.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(github.Verifier)) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (github *GithubOAuthService) VerifyCode(code string) error { | ||||||
|  | 	token, err := github.Config.Exchange(github.Context, code, oauth2.VerifierOption(github.Verifier)) | ||||||
|  |  | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	github.Token = token | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (github *GithubOAuthService) Userinfo() (config.Claims, error) { | ||||||
|  | 	var user config.Claims | ||||||
|  |  | ||||||
|  | 	client := github.Config.Client(github.Context, github.Token) | ||||||
|  |  | ||||||
|  | 	res, err := client.Get("https://api.github.com/user") | ||||||
|  | 	if err != nil { | ||||||
|  | 		return user, err | ||||||
|  | 	} | ||||||
|  | 	defer res.Body.Close() | ||||||
|  |  | ||||||
|  | 	body, err := io.ReadAll(res.Body) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return user, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	var userInfo GithubUserInfoResponse | ||||||
|  |  | ||||||
|  | 	err = json.Unmarshal(body, &userInfo) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return user, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	res, err = client.Get("https://api.github.com/user/emails") | ||||||
|  | 	if err != nil { | ||||||
|  | 		return user, err | ||||||
|  | 	} | ||||||
|  | 	defer res.Body.Close() | ||||||
|  |  | ||||||
|  | 	body, err = io.ReadAll(res.Body) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return user, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	var emails GithubEmailResponse | ||||||
|  |  | ||||||
|  | 	err = json.Unmarshal(body, &emails) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return user, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	for _, email := range emails { | ||||||
|  | 		if email.Primary { | ||||||
|  | 			user.Email = email.Email | ||||||
|  | 			break | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	if len(emails) == 0 { | ||||||
|  | 		return user, errors.New("no emails found") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Use first available email if no primary email was found | ||||||
|  | 	if user.Email == "" { | ||||||
|  | 		user.Email = emails[0].Email | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	user.PreferredUsername = userInfo.Login | ||||||
|  | 	user.Name = userInfo.Name | ||||||
|  |  | ||||||
|  | 	return user, nil | ||||||
|  | } | ||||||
							
								
								
									
										106
									
								
								internal/service/google_oauth_service.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										106
									
								
								internal/service/google_oauth_service.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,106 @@ | |||||||
|  | package service | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"context" | ||||||
|  | 	"crypto/rand" | ||||||
|  | 	"encoding/base64" | ||||||
|  | 	"encoding/json" | ||||||
|  | 	"io" | ||||||
|  | 	"net/http" | ||||||
|  | 	"strings" | ||||||
|  | 	"tinyauth/internal/config" | ||||||
|  |  | ||||||
|  | 	"golang.org/x/oauth2" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | var GoogleOAuthScopes = []string{"https://www.googleapis.com/auth/userinfo.email", "https://www.googleapis.com/auth/userinfo.profile"} | ||||||
|  |  | ||||||
|  | type GoogleUserInfoResponse struct { | ||||||
|  | 	Email string `json:"email"` | ||||||
|  | 	Name  string `json:"name"` | ||||||
|  | } | ||||||
|  |  | ||||||
|  | type GoogleOAuthService struct { | ||||||
|  | 	Config   oauth2.Config | ||||||
|  | 	Context  context.Context | ||||||
|  | 	Token    *oauth2.Token | ||||||
|  | 	Verifier string | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func NewGoogleOAuthService(config config.OAuthServiceConfig) *GoogleOAuthService { | ||||||
|  | 	return &GoogleOAuthService{ | ||||||
|  | 		Config: oauth2.Config{ | ||||||
|  | 			ClientID:     config.ClientID, | ||||||
|  | 			ClientSecret: config.ClientSecret, | ||||||
|  | 			RedirectURL:  config.RedirectURL, | ||||||
|  | 			Scopes:       GoogleOAuthScopes, | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (google *GoogleOAuthService) Init() error { | ||||||
|  | 	httpClient := &http.Client{} | ||||||
|  | 	ctx := context.Background() | ||||||
|  | 	ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) | ||||||
|  | 	verifier := oauth2.GenerateVerifier() | ||||||
|  |  | ||||||
|  | 	google.Context = ctx | ||||||
|  | 	google.Verifier = verifier | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (google *GoogleOAuthService) Name() string { | ||||||
|  | 	return "google" | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (oauth *GoogleOAuthService) GenerateState() string { | ||||||
|  | 	b := make([]byte, 128) | ||||||
|  | 	rand.Read(b) | ||||||
|  | 	state := base64.URLEncoding.EncodeToString(b) | ||||||
|  | 	return state | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (google *GoogleOAuthService) GetAuthURL(state string) string { | ||||||
|  | 	return google.Config.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(google.Verifier)) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (google *GoogleOAuthService) VerifyCode(code string) error { | ||||||
|  | 	token, err := google.Config.Exchange(google.Context, code, oauth2.VerifierOption(google.Verifier)) | ||||||
|  |  | ||||||
|  | 	if err != nil { | ||||||
|  | 		return nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	google.Token = token | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func (google *GoogleOAuthService) Userinfo() (config.Claims, error) { | ||||||
|  | 	var user config.Claims | ||||||
|  |  | ||||||
|  | 	client := google.Config.Client(google.Context, google.Token) | ||||||
|  |  | ||||||
|  | 	res, err := client.Get("https://www.googleapis.com/userinfo/v2/me") | ||||||
|  | 	if err != nil { | ||||||
|  | 		return config.Claims{}, err | ||||||
|  | 	} | ||||||
|  | 	defer res.Body.Close() | ||||||
|  |  | ||||||
|  | 	body, err := io.ReadAll(res.Body) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return config.Claims{}, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	var userInfo GoogleUserInfoResponse | ||||||
|  |  | ||||||
|  | 	err = json.Unmarshal(body, &userInfo) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return config.Claims{}, err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	user.PreferredUsername = strings.Split(userInfo.Email, "@")[0] | ||||||
|  | 	user.Name = userInfo.Name | ||||||
|  | 	user.Email = userInfo.Email | ||||||
|  |  | ||||||
|  | 	return user, nil | ||||||
|  | } | ||||||
| @@ -1,30 +1,40 @@ | |||||||
| package ldap | package service | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
| 	"crypto/tls" | 	"crypto/tls" | ||||||
| 	"fmt" | 	"fmt" | ||||||
| 	"time" | 	"time" | ||||||
| 	"tinyauth/internal/types" |  | ||||||
| 
 | 
 | ||||||
| 	"github.com/cenkalti/backoff/v5" | 	"github.com/cenkalti/backoff/v5" | ||||||
| 	ldapgo "github.com/go-ldap/ldap/v3" | 	ldapgo "github.com/go-ldap/ldap/v3" | ||||||
| 	"github.com/rs/zerolog/log" | 	"github.com/rs/zerolog/log" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| type LDAP struct { | type LdapServiceConfig struct { | ||||||
| 	Config types.LdapConfig | 	Address      string | ||||||
|  | 	BindDN       string | ||||||
|  | 	BindPassword string | ||||||
|  | 	BaseDN       string | ||||||
|  | 	Insecure     bool | ||||||
|  | 	SearchFilter string | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type LdapService struct { | ||||||
|  | 	Config LdapServiceConfig | ||||||
| 	Conn   *ldapgo.Conn | 	Conn   *ldapgo.Conn | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func NewLDAP(config types.LdapConfig) (*LDAP, error) { | func NewLdapService(config LdapServiceConfig) *LdapService { | ||||||
| 	ldap := &LDAP{ | 	return &LdapService{ | ||||||
| 		Config: config, | 		Config: config, | ||||||
| 	} | 	} | ||||||
|  | } | ||||||
| 
 | 
 | ||||||
|  | func (ldap *LdapService) Init() error { | ||||||
| 	_, err := ldap.connect() | 	_, err := ldap.connect() | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, fmt.Errorf("failed to connect to LDAP server: %w", err) | 		return fmt.Errorf("failed to connect to LDAP server: %w", err) | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	go func() { | 	go func() { | ||||||
| @@ -41,13 +51,13 @@ func NewLDAP(config types.LdapConfig) (*LDAP, error) { | |||||||
| 		} | 		} | ||||||
| 	}() | 	}() | ||||||
| 
 | 
 | ||||||
| 	return ldap, nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (l *LDAP) connect() (*ldapgo.Conn, error) { | func (ldap *LdapService) connect() (*ldapgo.Conn, error) { | ||||||
| 	log.Debug().Msg("Connecting to LDAP server") | 	log.Debug().Msg("Connecting to LDAP server") | ||||||
| 	conn, err := ldapgo.DialURL(l.Config.Address, ldapgo.DialWithTLSConfig(&tls.Config{ | 	conn, err := ldapgo.DialURL(ldap.Config.Address, ldapgo.DialWithTLSConfig(&tls.Config{ | ||||||
| 		InsecureSkipVerify: l.Config.Insecure, | 		InsecureSkipVerify: ldap.Config.Insecure, | ||||||
| 		MinVersion:         tls.VersionTLS12, | 		MinVersion:         tls.VersionTLS12, | ||||||
| 	})) | 	})) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| @@ -55,30 +65,30 @@ func (l *LDAP) connect() (*ldapgo.Conn, error) { | |||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	log.Debug().Msg("Binding to LDAP server") | 	log.Debug().Msg("Binding to LDAP server") | ||||||
| 	err = conn.Bind(l.Config.BindDN, l.Config.BindPassword) | 	err = conn.Bind(ldap.Config.BindDN, ldap.Config.BindPassword) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return nil, err | 		return nil, err | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	// Set and return the connection | 	// Set and return the connection | ||||||
| 	l.Conn = conn | 	ldap.Conn = conn | ||||||
| 	return conn, nil | 	return conn, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (l *LDAP) Search(username string) (string, error) { | func (ldap *LdapService) Search(username string) (string, error) { | ||||||
| 	// Escape the username to prevent LDAP injection | 	// Escape the username to prevent LDAP injection | ||||||
| 	escapedUsername := ldapgo.EscapeFilter(username) | 	escapedUsername := ldapgo.EscapeFilter(username) | ||||||
| 	filter := fmt.Sprintf(l.Config.SearchFilter, escapedUsername) | 	filter := fmt.Sprintf(ldap.Config.SearchFilter, escapedUsername) | ||||||
| 
 | 
 | ||||||
| 	searchRequest := ldapgo.NewSearchRequest( | 	searchRequest := ldapgo.NewSearchRequest( | ||||||
| 		l.Config.BaseDN, | 		ldap.Config.BaseDN, | ||||||
| 		ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false, | 		ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false, | ||||||
| 		filter, | 		filter, | ||||||
| 		[]string{"dn"}, | 		[]string{"dn"}, | ||||||
| 		nil, | 		nil, | ||||||
| 	) | 	) | ||||||
| 
 | 
 | ||||||
| 	searchResult, err := l.Conn.Search(searchRequest) | 	searchResult, err := ldap.Conn.Search(searchRequest) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return "", err | 		return "", err | ||||||
| 	} | 	} | ||||||
| @@ -91,15 +101,15 @@ func (l *LDAP) Search(username string) (string, error) { | |||||||
| 	return userDN, nil | 	return userDN, nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (l *LDAP) Bind(userDN string, password string) error { | func (ldap *LdapService) Bind(userDN string, password string) error { | ||||||
| 	err := l.Conn.Bind(userDN, password) | 	err := ldap.Conn.Bind(userDN, password) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (l *LDAP) heartbeat() error { | func (ldap *LdapService) heartbeat() error { | ||||||
| 	log.Debug().Msg("Performing LDAP connection heartbeat") | 	log.Debug().Msg("Performing LDAP connection heartbeat") | ||||||
| 
 | 
 | ||||||
| 	searchRequest := ldapgo.NewSearchRequest( | 	searchRequest := ldapgo.NewSearchRequest( | ||||||
| @@ -110,7 +120,7 @@ func (l *LDAP) heartbeat() error { | |||||||
| 		nil, | 		nil, | ||||||
| 	) | 	) | ||||||
| 
 | 
 | ||||||
| 	_, err := l.Conn.Search(searchRequest) | 	_, err := ldap.Conn.Search(searchRequest) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		return err | 		return err | ||||||
| 	} | 	} | ||||||
| @@ -119,7 +129,7 @@ func (l *LDAP) heartbeat() error { | |||||||
| 	return nil | 	return nil | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (l *LDAP) reconnect() error { | func (ldap *LdapService) reconnect() error { | ||||||
| 	log.Info().Msg("Reconnecting to LDAP server") | 	log.Info().Msg("Reconnecting to LDAP server") | ||||||
| 
 | 
 | ||||||
| 	exp := backoff.NewExponentialBackOff() | 	exp := backoff.NewExponentialBackOff() | ||||||
| @@ -129,8 +139,8 @@ func (l *LDAP) reconnect() error { | |||||||
| 	exp.Reset() | 	exp.Reset() | ||||||
| 
 | 
 | ||||||
| 	operation := func() (*ldapgo.Conn, error) { | 	operation := func() (*ldapgo.Conn, error) { | ||||||
| 		l.Conn.Close() | 		ldap.Conn.Close() | ||||||
| 		conn, err := l.connect() | 		conn, err := ldap.connect() | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return nil, nil | 			return nil, nil | ||||||
| 		} | 		} | ||||||
| @@ -1,70 +0,0 @@ | |||||||
| package types |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"time" |  | ||||||
| 	"tinyauth/internal/oauth" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| // User is the struct for a user |  | ||||||
| type User struct { |  | ||||||
| 	Username   string |  | ||||||
| 	Password   string |  | ||||||
| 	TotpSecret string |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // UserSearch is the response of the get user |  | ||||||
| type UserSearch struct { |  | ||||||
| 	Username string |  | ||||||
| 	Type     string // "local", "ldap" or empty |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // Users is a list of users |  | ||||||
| type Users []User |  | ||||||
|  |  | ||||||
| // OAuthProviders is the struct for the OAuth providers |  | ||||||
| type OAuthProviders struct { |  | ||||||
| 	Github    *oauth.OAuth |  | ||||||
| 	Google    *oauth.OAuth |  | ||||||
| 	Microsoft *oauth.OAuth |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // SessionCookie is the cookie for the session (exculding the expiry) |  | ||||||
| type SessionCookie struct { |  | ||||||
| 	Username    string |  | ||||||
| 	Name        string |  | ||||||
| 	Email       string |  | ||||||
| 	Provider    string |  | ||||||
| 	TotpPending bool |  | ||||||
| 	OAuthGroups string |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // UserContext is the context for the user |  | ||||||
| type UserContext struct { |  | ||||||
| 	Username    string |  | ||||||
| 	Name        string |  | ||||||
| 	Email       string |  | ||||||
| 	IsLoggedIn  bool |  | ||||||
| 	OAuth       bool |  | ||||||
| 	Provider    string |  | ||||||
| 	TotpPending bool |  | ||||||
| 	OAuthGroups string |  | ||||||
| 	TotpEnabled bool |  | ||||||
| } |  | ||||||
|  |  | ||||||
| // LoginAttempt tracks information about login attempts for rate limiting |  | ||||||
| type LoginAttempt struct { |  | ||||||
| 	FailedAttempts int |  | ||||||
| 	LastAttempt    time.Time |  | ||||||
| 	LockedUntil    time.Time |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type UnauthorizedQuery struct { |  | ||||||
| 	Username string `url:"username"` |  | ||||||
| 	Resource string `url:"resource"` |  | ||||||
| 	GroupErr bool   `url:"groupErr"` |  | ||||||
| 	IP       string `url:"ip"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type RedirectQuery struct { |  | ||||||
| 	RedirectURI string `url:"redirect_uri"` |  | ||||||
| } |  | ||||||
		Reference in New Issue
	
	Block a user
	 Stavros
					Stavros