mirror of
				https://github.com/steveiliop56/tinyauth.git
				synced 2025-10-30 21:55:43 +00:00 
			
		
		
		
	refactor: move oauth providers into services (non-working)
This commit is contained in:
		| @@ -1,146 +0,0 @@ | ||||
| package auth_test | ||||
|  | ||||
| import ( | ||||
| 	"testing" | ||||
| 	"time" | ||||
| 	"tinyauth/internal/auth" | ||||
| 	"tinyauth/internal/types" | ||||
| ) | ||||
|  | ||||
| var config = types.AuthConfig{ | ||||
| 	Users:          types.Users{}, | ||||
| 	OauthWhitelist: "", | ||||
| 	SessionExpiry:  3600, | ||||
| } | ||||
|  | ||||
| func TestLoginRateLimiting(t *testing.T) { | ||||
| 	// Initialize a new auth service with 3 max retries and 5 seconds timeout | ||||
| 	config.LoginMaxRetries = 3 | ||||
| 	config.LoginTimeout = 5 | ||||
| 	authService := auth.NewAuth(config, nil, nil) | ||||
|  | ||||
| 	// Test identifier | ||||
| 	identifier := "test_user" | ||||
|  | ||||
| 	// Test successful login - should not lock account | ||||
| 	t.Log("Testing successful login") | ||||
|  | ||||
| 	authService.RecordLoginAttempt(identifier, true) | ||||
| 	locked, _ := authService.IsAccountLocked(identifier) | ||||
|  | ||||
| 	if locked { | ||||
| 		t.Fatalf("Account should not be locked after successful login") | ||||
| 	} | ||||
|  | ||||
| 	// Test 2 failed attempts - should not lock account yet | ||||
| 	t.Log("Testing 2 failed login attempts") | ||||
|  | ||||
| 	authService.RecordLoginAttempt(identifier, false) | ||||
| 	authService.RecordLoginAttempt(identifier, false) | ||||
| 	locked, _ = authService.IsAccountLocked(identifier) | ||||
|  | ||||
| 	if locked { | ||||
| 		t.Fatalf("Account should not be locked after only 2 failed attempts") | ||||
| 	} | ||||
|  | ||||
| 	// Add one more failed attempt (total 3) - should lock account with maxRetries=3 | ||||
| 	t.Log("Testing 3 failed login attempts") | ||||
| 	authService.RecordLoginAttempt(identifier, false) | ||||
| 	locked, remainingTime := authService.IsAccountLocked(identifier) | ||||
|  | ||||
| 	if !locked { | ||||
| 		t.Fatalf("Account should be locked after reaching max retries") | ||||
| 	} | ||||
| 	if remainingTime <= 0 || remainingTime > 5 { | ||||
| 		t.Fatalf("Expected remaining time between 1-5 seconds, got %d", remainingTime) | ||||
| 	} | ||||
|  | ||||
| 	// Test reset after waiting for timeout - use 1 second timeout for fast testing | ||||
| 	t.Log("Testing unlocking after timeout") | ||||
|  | ||||
| 	// Reinitialize auth service with a shorter timeout for testing | ||||
| 	config.LoginTimeout = 1 | ||||
| 	config.LoginMaxRetries = 3 | ||||
| 	authService = auth.NewAuth(config, nil, nil) | ||||
|  | ||||
| 	// Add enough failed attempts to lock the account | ||||
| 	for i := 0; i < 3; i++ { | ||||
| 		authService.RecordLoginAttempt(identifier, false) | ||||
| 	} | ||||
|  | ||||
| 	// Verify it's locked | ||||
| 	locked, _ = authService.IsAccountLocked(identifier) | ||||
| 	if !locked { | ||||
| 		t.Fatalf("Account should be locked initially") | ||||
| 	} | ||||
|  | ||||
| 	// Wait a bit and verify it gets unlocked after timeout | ||||
| 	time.Sleep(1500 * time.Millisecond) // Wait longer than the timeout | ||||
| 	locked, _ = authService.IsAccountLocked(identifier) | ||||
|  | ||||
| 	if locked { | ||||
| 		t.Fatalf("Account should be unlocked after timeout period") | ||||
| 	} | ||||
|  | ||||
| 	// Test disabled rate limiting | ||||
| 	t.Log("Testing disabled rate limiting") | ||||
| 	config.LoginMaxRetries = 0 | ||||
| 	config.LoginTimeout = 0 | ||||
| 	authService = auth.NewAuth(config, nil, nil) | ||||
|  | ||||
| 	for i := 0; i < 10; i++ { | ||||
| 		authService.RecordLoginAttempt(identifier, false) | ||||
| 	} | ||||
|  | ||||
| 	locked, _ = authService.IsAccountLocked(identifier) | ||||
| 	if locked { | ||||
| 		t.Fatalf("Account should not be locked when rate limiting is disabled") | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func TestConcurrentLoginAttempts(t *testing.T) { | ||||
| 	// Initialize a new auth service with 2 max retries and 5 seconds timeout | ||||
| 	config.LoginMaxRetries = 2 | ||||
| 	config.LoginTimeout = 5 | ||||
| 	authService := auth.NewAuth(config, nil, nil) | ||||
|  | ||||
| 	// Test multiple identifiers | ||||
| 	identifiers := []string{"user1", "user2", "user3"} | ||||
|  | ||||
| 	// Test that locking one identifier doesn't affect others | ||||
| 	t.Log("Testing multiple identifiers") | ||||
|  | ||||
| 	// Add enough failed attempts to lock first user (2 attempts with maxRetries=2) | ||||
| 	authService.RecordLoginAttempt(identifiers[0], false) | ||||
| 	authService.RecordLoginAttempt(identifiers[0], false) | ||||
|  | ||||
| 	// Check if first user is locked | ||||
| 	locked, _ := authService.IsAccountLocked(identifiers[0]) | ||||
| 	if !locked { | ||||
| 		t.Fatalf("User1 should be locked after reaching max retries") | ||||
| 	} | ||||
|  | ||||
| 	// Check that other users are not affected | ||||
| 	for i := 1; i < len(identifiers); i++ { | ||||
| 		locked, _ := authService.IsAccountLocked(identifiers[i]) | ||||
| 		if locked { | ||||
| 			t.Fatalf("User%d should not be locked", i+1) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// Test successful login after failed attempts (but before lock) | ||||
| 	t.Log("Testing successful login after failed attempts but before lock") | ||||
|  | ||||
| 	// One failed attempt for user2 | ||||
| 	authService.RecordLoginAttempt(identifiers[1], false) | ||||
|  | ||||
| 	// Successful login should reset the counter | ||||
| 	authService.RecordLoginAttempt(identifiers[1], true) | ||||
|  | ||||
| 	// Now try a failed login again - should not be locked as counter was reset | ||||
| 	authService.RecordLoginAttempt(identifiers[1], false) | ||||
| 	locked, _ = authService.IsAccountLocked(identifiers[1]) | ||||
| 	if locked { | ||||
| 		t.Fatalf("User2 should not be locked after successful login reset") | ||||
| 	} | ||||
| } | ||||
| @@ -1,6 +1,22 @@ | ||||
| package types | ||||
| package config | ||||
| 
 | ||||
| import "time" | ||||
| 
 | ||||
| type Claims struct { | ||||
| 	Name              string `json:"name"` | ||||
| 	Email             string `json:"email"` | ||||
| 	PreferredUsername string `json:"preferred_username"` | ||||
| 	Groups            any    `json:"groups"` | ||||
| } | ||||
| 
 | ||||
| var Version = "development" | ||||
| var CommitHash = "n/a" | ||||
| var BuildTimestamp = "n/a" | ||||
| 
 | ||||
| var SessionCookieName = "tinyauth-session" | ||||
| var CsrfCookieName = "tinyauth-csrf" | ||||
| var RedirectCookieName = "tinyauth-redirect" | ||||
| 
 | ||||
| // Config is the configuration for the tinyauth server | ||||
| type Config struct { | ||||
| 	Port                    int    `mapstructure:"port" validate:"required"` | ||||
| 	Address                 string `validate:"required,ip4_addr" mapstructure:"address"` | ||||
| @@ -44,62 +60,27 @@ type Config struct { | ||||
| 	LdapSearchFilter        string `mapstructure:"ldap-search-filter"` | ||||
| } | ||||
| 
 | ||||
| // OAuthConfig is the configuration for the providers | ||||
| type OAuthConfig struct { | ||||
| 	GithubClientId      string | ||||
| 	GithubClientSecret  string | ||||
| 	GoogleClientId      string | ||||
| 	GoogleClientSecret  string | ||||
| 	GenericClientId     string | ||||
| 	GenericClientSecret string | ||||
| 	GenericScopes       []string | ||||
| 	GenericAuthURL      string | ||||
| 	GenericTokenURL     string | ||||
| 	GenericUserURL      string | ||||
| 	GenericSkipSSL      bool | ||||
| 	AppURL              string | ||||
| } | ||||
| 
 | ||||
| // AuthConfig is the configuration for the auth service | ||||
| type AuthConfig struct { | ||||
| 	Users             Users | ||||
| 	OauthWhitelist    string | ||||
| 	SessionExpiry     int | ||||
| 	CookieSecure      bool | ||||
| 	Domain            string | ||||
| 	LoginTimeout      int | ||||
| 	LoginMaxRetries   int | ||||
| 	SessionCookieName string | ||||
| 	HMACSecret        string | ||||
| 	EncryptionSecret  string | ||||
| } | ||||
| 
 | ||||
| // OAuthLabels is a list of labels that can be used in a tinyauth protected container | ||||
| type OAuthLabels struct { | ||||
| 	Whitelist string | ||||
| 	Groups    string | ||||
| } | ||||
| 
 | ||||
| // Basic auth labels for a tinyauth protected container | ||||
| type BasicLabels struct { | ||||
| 	Username string | ||||
| 	Password PassowrdLabels | ||||
| } | ||||
| 
 | ||||
| // PassowrdLabels is a struct that contains the password labels for a tinyauth protected container | ||||
| type PassowrdLabels struct { | ||||
| 	Plain string | ||||
| 	File  string | ||||
| } | ||||
| 
 | ||||
| // IP labels for a tinyauth protected container | ||||
| type IPLabels struct { | ||||
| 	Allow  []string | ||||
| 	Block  []string | ||||
| 	Bypass []string | ||||
| } | ||||
| 
 | ||||
| // Labels is a struct that contains the labels for a tinyauth protected container | ||||
| type Labels struct { | ||||
| 	Users   string | ||||
| 	Allowed string | ||||
| @@ -110,12 +91,65 @@ type Labels struct { | ||||
| 	IP      IPLabels | ||||
| } | ||||
| 
 | ||||
| // Ldap config is a struct that contains the configuration for the LDAP service | ||||
| type LdapConfig struct { | ||||
| 	Address      string | ||||
| 	BindDN       string | ||||
| 	BindPassword string | ||||
| 	BaseDN       string | ||||
| 	Insecure     bool | ||||
| 	SearchFilter string | ||||
| type OAuthServiceConfig struct { | ||||
| 	ClientID           string | ||||
| 	ClientSecret       string | ||||
| 	Scopes             []string | ||||
| 	RedirectURL        string | ||||
| 	AuthURL            string | ||||
| 	TokenURL           string | ||||
| 	UserinfoURL        string | ||||
| 	InsecureSkipVerify bool | ||||
| 	Name               string | ||||
| } | ||||
| 
 | ||||
| type User struct { | ||||
| 	Username   string | ||||
| 	Password   string | ||||
| 	TotpSecret string | ||||
| } | ||||
| 
 | ||||
| type UserSearch struct { | ||||
| 	Username string | ||||
| 	Type     string // local, ldap or unknown | ||||
| } | ||||
| 
 | ||||
| type Users []User | ||||
| 
 | ||||
| type SessionCookie struct { | ||||
| 	Username    string | ||||
| 	Name        string | ||||
| 	Email       string | ||||
| 	Provider    string | ||||
| 	TotpPending bool | ||||
| 	OAuthGroups string | ||||
| } | ||||
| 
 | ||||
| type UserContext struct { | ||||
| 	Username    string | ||||
| 	Name        string | ||||
| 	Email       string | ||||
| 	IsLoggedIn  bool | ||||
| 	OAuth       bool | ||||
| 	Provider    string | ||||
| 	TotpPending bool | ||||
| 	OAuthGroups string | ||||
| 	TotpEnabled bool | ||||
| } | ||||
| 
 | ||||
| type LoginAttempt struct { | ||||
| 	FailedAttempts int | ||||
| 	LastAttempt    time.Time | ||||
| 	LockedUntil    time.Time | ||||
| } | ||||
| 
 | ||||
| type UnauthorizedQuery struct { | ||||
| 	Username string `url:"username"` | ||||
| 	Resource string `url:"resource"` | ||||
| 	GroupErr bool   `url:"groupErr"` | ||||
| 	IP       string `url:"ip"` | ||||
| } | ||||
| 
 | ||||
| type RedirectQuery struct { | ||||
| 	RedirectURI string `url:"redirect_uri"` | ||||
| } | ||||
| @@ -1,19 +0,0 @@ | ||||
| package constants | ||||
|  | ||||
| // Claims are the OIDC supported claims (prefered username is included for convinience) | ||||
| type Claims struct { | ||||
| 	Name              string `json:"name"` | ||||
| 	Email             string `json:"email"` | ||||
| 	PreferredUsername string `json:"preferred_username"` | ||||
| 	Groups            any    `json:"groups"` | ||||
| } | ||||
|  | ||||
| // Version information | ||||
| var Version = "development" | ||||
| var CommitHash = "n/a" | ||||
| var BuildTimestamp = "n/a" | ||||
|  | ||||
| // Base cookie names | ||||
| var SessionCookieName = "tinyauth-session" | ||||
| var CsrfCookieName = "tinyauth-csrf" | ||||
| var RedirectCookieName = "tinyauth-redirect" | ||||
| @@ -1,71 +0,0 @@ | ||||
| package oauth | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"crypto/rand" | ||||
| 	"crypto/tls" | ||||
| 	"encoding/base64" | ||||
| 	"net/http" | ||||
|  | ||||
| 	"golang.org/x/oauth2" | ||||
| ) | ||||
|  | ||||
| type OAuth struct { | ||||
| 	Config   oauth2.Config | ||||
| 	Context  context.Context | ||||
| 	Token    *oauth2.Token | ||||
| 	Verifier string | ||||
| } | ||||
|  | ||||
| func NewOAuth(config oauth2.Config, insecureSkipVerify bool) *OAuth { | ||||
| 	transport := &http.Transport{ | ||||
| 		TLSClientConfig: &tls.Config{ | ||||
| 			InsecureSkipVerify: insecureSkipVerify, | ||||
| 			MinVersion:         tls.VersionTLS12, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	httpClient := &http.Client{ | ||||
| 		Transport: transport, | ||||
| 	} | ||||
|  | ||||
| 	ctx := context.Background() | ||||
|  | ||||
| 	// Set the HTTP client in the context | ||||
| 	ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) | ||||
|  | ||||
| 	verifier := oauth2.GenerateVerifier() | ||||
|  | ||||
| 	return &OAuth{ | ||||
| 		Config:   config, | ||||
| 		Context:  ctx, | ||||
| 		Verifier: verifier, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (oauth *OAuth) GetAuthURL(state string) string { | ||||
| 	return oauth.Config.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(oauth.Verifier)) | ||||
| } | ||||
|  | ||||
| func (oauth *OAuth) ExchangeToken(code string) (string, error) { | ||||
| 	token, err := oauth.Config.Exchange(oauth.Context, code, oauth2.VerifierOption(oauth.Verifier)) | ||||
|  | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
|  | ||||
| 	// Set and return the token | ||||
| 	oauth.Token = token | ||||
| 	return oauth.Token.AccessToken, nil | ||||
| } | ||||
|  | ||||
| func (oauth *OAuth) GetClient() *http.Client { | ||||
| 	return oauth.Config.Client(oauth.Context, oauth.Token) | ||||
| } | ||||
|  | ||||
| func (oauth *OAuth) GenerateState() string { | ||||
| 	b := make([]byte, 128) | ||||
| 	rand.Read(b) | ||||
| 	state := base64.URLEncoding.EncodeToString(b) | ||||
| 	return state | ||||
| } | ||||
| @@ -1,37 +0,0 @@ | ||||
| package providers | ||||
|  | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"tinyauth/internal/constants" | ||||
|  | ||||
| 	"github.com/rs/zerolog/log" | ||||
| ) | ||||
|  | ||||
| func GetGenericUser(client *http.Client, url string) (constants.Claims, error) { | ||||
| 	var user constants.Claims | ||||
|  | ||||
| 	res, err := client.Get(url) | ||||
| 	if err != nil { | ||||
| 		return user, err | ||||
| 	} | ||||
| 	defer res.Body.Close() | ||||
|  | ||||
| 	log.Debug().Msg("Got response from generic provider") | ||||
|  | ||||
| 	body, err := io.ReadAll(res.Body) | ||||
| 	if err != nil { | ||||
| 		return user, err | ||||
| 	} | ||||
|  | ||||
| 	log.Debug().Msg("Read body from generic provider") | ||||
|  | ||||
| 	err = json.Unmarshal(body, &user) | ||||
| 	if err != nil { | ||||
| 		return user, err | ||||
| 	} | ||||
|  | ||||
| 	log.Debug().Msg("Parsed user from generic provider") | ||||
| 	return user, nil | ||||
| } | ||||
| @@ -1,102 +0,0 @@ | ||||
| package providers | ||||
|  | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"tinyauth/internal/constants" | ||||
|  | ||||
| 	"github.com/rs/zerolog/log" | ||||
| ) | ||||
|  | ||||
| // Response for the github email endpoint | ||||
| type GithubEmailResponse []struct { | ||||
| 	Email   string `json:"email"` | ||||
| 	Primary bool   `json:"primary"` | ||||
| } | ||||
|  | ||||
| // Response for the github user endpoint | ||||
| type GithubUserInfoResponse struct { | ||||
| 	Login string `json:"login"` | ||||
| 	Name  string `json:"name"` | ||||
| } | ||||
|  | ||||
| // The scopes required for the github provider | ||||
| func GithubScopes() []string { | ||||
| 	return []string{"user:email", "read:user"} | ||||
| } | ||||
|  | ||||
| func GetGithubUser(client *http.Client) (constants.Claims, error) { | ||||
| 	var user constants.Claims | ||||
|  | ||||
| 	res, err := client.Get("https://api.github.com/user") | ||||
| 	if err != nil { | ||||
| 		return user, err | ||||
| 	} | ||||
| 	defer res.Body.Close() | ||||
|  | ||||
| 	log.Debug().Msg("Got user response from github") | ||||
|  | ||||
| 	body, err := io.ReadAll(res.Body) | ||||
| 	if err != nil { | ||||
| 		return user, err | ||||
| 	} | ||||
|  | ||||
| 	log.Debug().Msg("Read user body from github") | ||||
|  | ||||
| 	var userInfo GithubUserInfoResponse | ||||
|  | ||||
| 	err = json.Unmarshal(body, &userInfo) | ||||
| 	if err != nil { | ||||
| 		return user, err | ||||
| 	} | ||||
|  | ||||
| 	res, err = client.Get("https://api.github.com/user/emails") | ||||
| 	if err != nil { | ||||
| 		return user, err | ||||
| 	} | ||||
| 	defer res.Body.Close() | ||||
|  | ||||
| 	log.Debug().Msg("Got email response from github") | ||||
|  | ||||
| 	body, err = io.ReadAll(res.Body) | ||||
| 	if err != nil { | ||||
| 		return user, err | ||||
| 	} | ||||
|  | ||||
| 	log.Debug().Msg("Read email body from github") | ||||
|  | ||||
| 	var emails GithubEmailResponse | ||||
|  | ||||
| 	err = json.Unmarshal(body, &emails) | ||||
| 	if err != nil { | ||||
| 		return user, err | ||||
| 	} | ||||
|  | ||||
| 	log.Debug().Msg("Parsed emails from github") | ||||
|  | ||||
| 	// Find and return the primary email | ||||
| 	for _, email := range emails { | ||||
| 		if email.Primary { | ||||
| 			log.Debug().Str("email", email.Email).Msg("Found primary email") | ||||
| 			user.Email = email.Email | ||||
| 			break | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if len(emails) == 0 { | ||||
| 		return user, errors.New("no emails found") | ||||
| 	} | ||||
|  | ||||
| 	// Use first available email if no primary email was found | ||||
| 	if user.Email == "" { | ||||
| 		log.Warn().Str("email", emails[0].Email).Msg("No primary email found, using first email") | ||||
| 		user.Email = emails[0].Email | ||||
| 	} | ||||
|  | ||||
| 	user.PreferredUsername = userInfo.Login | ||||
| 	user.Name = userInfo.Name | ||||
|  | ||||
| 	return user, nil | ||||
| } | ||||
| @@ -1,56 +0,0 @@ | ||||
| package providers | ||||
|  | ||||
| import ( | ||||
| 	"encoding/json" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| 	"tinyauth/internal/constants" | ||||
|  | ||||
| 	"github.com/rs/zerolog/log" | ||||
| ) | ||||
|  | ||||
| // Response for the google user endpoint | ||||
| type GoogleUserInfoResponse struct { | ||||
| 	Email string `json:"email"` | ||||
| 	Name  string `json:"name"` | ||||
| } | ||||
|  | ||||
| // The scopes required for the google provider | ||||
| func GoogleScopes() []string { | ||||
| 	return []string{"https://www.googleapis.com/auth/userinfo.email", "https://www.googleapis.com/auth/userinfo.profile"} | ||||
| } | ||||
|  | ||||
| func GetGoogleUser(client *http.Client) (constants.Claims, error) { | ||||
| 	var user constants.Claims | ||||
|  | ||||
| 	res, err := client.Get("https://www.googleapis.com/userinfo/v2/me") | ||||
| 	if err != nil { | ||||
| 		return user, err | ||||
| 	} | ||||
| 	defer res.Body.Close() | ||||
|  | ||||
| 	log.Debug().Msg("Got response from google") | ||||
|  | ||||
| 	body, err := io.ReadAll(res.Body) | ||||
| 	if err != nil { | ||||
| 		return user, err | ||||
| 	} | ||||
|  | ||||
| 	log.Debug().Msg("Read body from google") | ||||
|  | ||||
| 	var userInfo GoogleUserInfoResponse | ||||
|  | ||||
| 	err = json.Unmarshal(body, &userInfo) | ||||
| 	if err != nil { | ||||
| 		return user, err | ||||
| 	} | ||||
|  | ||||
| 	log.Debug().Msg("Parsed user from google") | ||||
|  | ||||
| 	user.PreferredUsername = strings.Split(userInfo.Email, "@")[0] | ||||
| 	user.Name = userInfo.Name | ||||
| 	user.Email = userInfo.Email | ||||
|  | ||||
| 	return user, nil | ||||
| } | ||||
| @@ -1,154 +0,0 @@ | ||||
| package providers | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"tinyauth/internal/constants" | ||||
| 	"tinyauth/internal/oauth" | ||||
| 	"tinyauth/internal/types" | ||||
|  | ||||
| 	"github.com/rs/zerolog/log" | ||||
| 	"golang.org/x/oauth2" | ||||
| 	"golang.org/x/oauth2/endpoints" | ||||
| ) | ||||
|  | ||||
| type Providers struct { | ||||
| 	Config  types.OAuthConfig | ||||
| 	Github  *oauth.OAuth | ||||
| 	Google  *oauth.OAuth | ||||
| 	Generic *oauth.OAuth | ||||
| } | ||||
|  | ||||
| func NewProviders(config types.OAuthConfig) *Providers { | ||||
| 	providers := &Providers{ | ||||
| 		Config: config, | ||||
| 	} | ||||
|  | ||||
| 	if config.GithubClientId != "" && config.GithubClientSecret != "" { | ||||
| 		log.Info().Msg("Initializing Github OAuth") | ||||
| 		providers.Github = oauth.NewOAuth(oauth2.Config{ | ||||
| 			ClientID:     config.GithubClientId, | ||||
| 			ClientSecret: config.GithubClientSecret, | ||||
| 			RedirectURL:  fmt.Sprintf("%s/api/oauth/callback/github", config.AppURL), | ||||
| 			Scopes:       GithubScopes(), | ||||
| 			Endpoint:     endpoints.GitHub, | ||||
| 		}, false) | ||||
| 	} | ||||
|  | ||||
| 	if config.GoogleClientId != "" && config.GoogleClientSecret != "" { | ||||
| 		log.Info().Msg("Initializing Google OAuth") | ||||
| 		providers.Google = oauth.NewOAuth(oauth2.Config{ | ||||
| 			ClientID:     config.GoogleClientId, | ||||
| 			ClientSecret: config.GoogleClientSecret, | ||||
| 			RedirectURL:  fmt.Sprintf("%s/api/oauth/callback/google", config.AppURL), | ||||
| 			Scopes:       GoogleScopes(), | ||||
| 			Endpoint:     endpoints.Google, | ||||
| 		}, false) | ||||
| 	} | ||||
|  | ||||
| 	if config.GenericClientId != "" && config.GenericClientSecret != "" { | ||||
| 		log.Info().Msg("Initializing Generic OAuth") | ||||
| 		providers.Generic = oauth.NewOAuth(oauth2.Config{ | ||||
| 			ClientID:     config.GenericClientId, | ||||
| 			ClientSecret: config.GenericClientSecret, | ||||
| 			RedirectURL:  fmt.Sprintf("%s/api/oauth/callback/generic", config.AppURL), | ||||
| 			Scopes:       config.GenericScopes, | ||||
| 			Endpoint: oauth2.Endpoint{ | ||||
| 				AuthURL:  config.GenericAuthURL, | ||||
| 				TokenURL: config.GenericTokenURL, | ||||
| 			}, | ||||
| 		}, config.GenericSkipSSL) | ||||
| 	} | ||||
|  | ||||
| 	return providers | ||||
| } | ||||
|  | ||||
| func (providers *Providers) GetProvider(provider string) *oauth.OAuth { | ||||
| 	switch provider { | ||||
| 	case "github": | ||||
| 		return providers.Github | ||||
| 	case "google": | ||||
| 		return providers.Google | ||||
| 	case "generic": | ||||
| 		return providers.Generic | ||||
| 	default: | ||||
| 		return nil | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (providers *Providers) GetUser(provider string) (constants.Claims, error) { | ||||
| 	var user constants.Claims | ||||
|  | ||||
| 	// Get the user from the provider | ||||
| 	switch provider { | ||||
| 	case "github": | ||||
| 		if providers.Github == nil { | ||||
| 			log.Debug().Msg("Github provider not configured") | ||||
| 			return user, nil | ||||
| 		} | ||||
|  | ||||
| 		client := providers.Github.GetClient() | ||||
|  | ||||
| 		log.Debug().Msg("Got client from github") | ||||
|  | ||||
| 		user, err := GetGithubUser(client) | ||||
| 		if err != nil { | ||||
| 			return user, err | ||||
| 		} | ||||
|  | ||||
| 		log.Debug().Msg("Got user from github") | ||||
|  | ||||
| 		return user, nil | ||||
| 	case "google": | ||||
| 		if providers.Google == nil { | ||||
| 			log.Debug().Msg("Google provider not configured") | ||||
| 			return user, nil | ||||
| 		} | ||||
|  | ||||
| 		client := providers.Google.GetClient() | ||||
|  | ||||
| 		log.Debug().Msg("Got client from google") | ||||
|  | ||||
| 		user, err := GetGoogleUser(client) | ||||
| 		if err != nil { | ||||
| 			return user, err | ||||
| 		} | ||||
|  | ||||
| 		log.Debug().Msg("Got user from google") | ||||
|  | ||||
| 		return user, nil | ||||
| 	case "generic": | ||||
| 		if providers.Generic == nil { | ||||
| 			log.Debug().Msg("Generic provider not configured") | ||||
| 			return user, nil | ||||
| 		} | ||||
|  | ||||
| 		client := providers.Generic.GetClient() | ||||
|  | ||||
| 		log.Debug().Msg("Got client from generic") | ||||
|  | ||||
| 		user, err := GetGenericUser(client, providers.Config.GenericUserURL) | ||||
| 		if err != nil { | ||||
| 			return user, err | ||||
| 		} | ||||
|  | ||||
| 		log.Debug().Msg("Got user from generic") | ||||
|  | ||||
| 		return user, nil | ||||
| 	default: | ||||
| 		return user, nil | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (provider *Providers) GetConfiguredProviders() []string { | ||||
| 	providers := []string{} | ||||
| 	if provider.Github != nil { | ||||
| 		providers = append(providers, "github") | ||||
| 	} | ||||
| 	if provider.Google != nil { | ||||
| 		providers = append(providers, "google") | ||||
| 	} | ||||
| 	if provider.Generic != nil { | ||||
| 		providers = append(providers, "generic") | ||||
| 	} | ||||
| 	return providers | ||||
| } | ||||
| @@ -1,4 +1,4 @@ | ||||
| package auth | ||||
| package service | ||||
| 
 | ||||
| import ( | ||||
| 	"fmt" | ||||
| @@ -6,8 +6,6 @@ import ( | ||||
| 	"strings" | ||||
| 	"sync" | ||||
| 	"time" | ||||
| 	"tinyauth/internal/docker" | ||||
| 	"tinyauth/internal/ldap" | ||||
| 	"tinyauth/internal/types" | ||||
| 	"tinyauth/internal/utils" | ||||
| 
 | ||||
| @@ -17,35 +15,50 @@ import ( | ||||
| 	"golang.org/x/crypto/bcrypt" | ||||
| ) | ||||
| 
 | ||||
| type Auth struct { | ||||
| 	Config        types.AuthConfig | ||||
| 	Docker        *docker.Docker | ||||
| type AuthServiceConfig struct { | ||||
| 	Users             types.Users | ||||
| 	OauthWhitelist    string | ||||
| 	SessionExpiry     int | ||||
| 	CookieSecure      bool | ||||
| 	Domain            string | ||||
| 	LoginTimeout      int | ||||
| 	LoginMaxRetries   int | ||||
| 	SessionCookieName string | ||||
| 	HMACSecret        string | ||||
| 	EncryptionSecret  string | ||||
| } | ||||
| 
 | ||||
| type AuthService struct { | ||||
| 	Config        AuthServiceConfig | ||||
| 	Docker        *DockerService | ||||
| 	LoginAttempts map[string]*types.LoginAttempt | ||||
| 	LoginMutex    sync.RWMutex | ||||
| 	Store         *sessions.CookieStore | ||||
| 	LDAP          *ldap.LDAP | ||||
| 	LDAP          *LdapService | ||||
| } | ||||
| 
 | ||||
| func NewAuth(config types.AuthConfig, docker *docker.Docker, ldap *ldap.LDAP) *Auth { | ||||
| 	// Setup cookie store and create the auth service | ||||
| 	store := sessions.NewCookieStore([]byte(config.HMACSecret), []byte(config.EncryptionSecret)) | ||||
| 	store.Options = &sessions.Options{ | ||||
| 		Path:     "/", | ||||
| 		MaxAge:   config.SessionExpiry, | ||||
| 		Secure:   config.CookieSecure, | ||||
| 		HttpOnly: true, | ||||
| 		Domain:   fmt.Sprintf(".%s", config.Domain), | ||||
| 	} | ||||
| 	return &Auth{ | ||||
| func NewAuthService(config AuthServiceConfig, docker *DockerService, ldap *LdapService) *AuthService { | ||||
| 	return &AuthService{ | ||||
| 		Config:        config, | ||||
| 		Docker:        docker, | ||||
| 		LoginAttempts: make(map[string]*types.LoginAttempt), | ||||
| 		Store:         store, | ||||
| 		LDAP:          ldap, | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (auth *Auth) GetSession(c *gin.Context) (*sessions.Session, error) { | ||||
| func (auth *AuthService) Init() error { | ||||
| 	store := sessions.NewCookieStore([]byte(auth.Config.HMACSecret), []byte(auth.Config.EncryptionSecret)) | ||||
| 	store.Options = &sessions.Options{ | ||||
| 		Path:     "/", | ||||
| 		MaxAge:   auth.Config.SessionExpiry, | ||||
| 		Secure:   auth.Config.CookieSecure, | ||||
| 		HttpOnly: true, | ||||
| 		Domain:   fmt.Sprintf(".%s", auth.Config.Domain), | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (auth *AuthService) GetSession(c *gin.Context) (*sessions.Session, error) { | ||||
| 	session, err := auth.Store.Get(c.Request, auth.Config.SessionCookieName) | ||||
| 
 | ||||
| 	// If there was an error getting the session, it might be invalid so let's clear it and retry | ||||
| @@ -62,7 +75,7 @@ func (auth *Auth) GetSession(c *gin.Context) (*sessions.Session, error) { | ||||
| 	return session, nil | ||||
| } | ||||
| 
 | ||||
| func (auth *Auth) SearchUser(username string) types.UserSearch { | ||||
| func (auth *AuthService) SearchUser(username string) types.UserSearch { | ||||
| 	log.Debug().Str("username", username).Msg("Searching for user") | ||||
| 
 | ||||
| 	// Check local users first | ||||
| @@ -93,7 +106,7 @@ func (auth *Auth) SearchUser(username string) types.UserSearch { | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (auth *Auth) VerifyUser(search types.UserSearch, password string) bool { | ||||
| func (auth *AuthService) VerifyUser(search types.UserSearch, password string) bool { | ||||
| 	// Authenticate the user based on the type | ||||
| 	switch search.Type { | ||||
| 	case "local": | ||||
| @@ -131,7 +144,7 @@ func (auth *Auth) VerifyUser(search types.UserSearch, password string) bool { | ||||
| 	return false | ||||
| } | ||||
| 
 | ||||
| func (auth *Auth) GetLocalUser(username string) types.User { | ||||
| func (auth *AuthService) GetLocalUser(username string) types.User { | ||||
| 	// Loop through users and return the user if the username matches | ||||
| 	log.Debug().Str("username", username).Msg("Searching for local user") | ||||
| 
 | ||||
| @@ -146,11 +159,11 @@ func (auth *Auth) GetLocalUser(username string) types.User { | ||||
| 	return types.User{} | ||||
| } | ||||
| 
 | ||||
| func (auth *Auth) CheckPassword(user types.User, password string) bool { | ||||
| func (auth *AuthService) CheckPassword(user types.User, password string) bool { | ||||
| 	return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) == nil | ||||
| } | ||||
| 
 | ||||
| func (auth *Auth) IsAccountLocked(identifier string) (bool, int) { | ||||
| func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) { | ||||
| 	auth.LoginMutex.RLock() | ||||
| 	defer auth.LoginMutex.RUnlock() | ||||
| 
 | ||||
| @@ -176,7 +189,7 @@ func (auth *Auth) IsAccountLocked(identifier string) (bool, int) { | ||||
| 	return false, 0 | ||||
| } | ||||
| 
 | ||||
| func (auth *Auth) RecordLoginAttempt(identifier string, success bool) { | ||||
| func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) { | ||||
| 	// Skip if rate limiting is not configured | ||||
| 	if auth.Config.LoginMaxRetries <= 0 || auth.Config.LoginTimeout <= 0 { | ||||
| 		return | ||||
| @@ -212,11 +225,11 @@ func (auth *Auth) RecordLoginAttempt(identifier string, success bool) { | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (auth *Auth) EmailWhitelisted(email string) bool { | ||||
| func (auth *AuthService) EmailWhitelisted(email string) bool { | ||||
| 	return utils.CheckFilter(auth.Config.OauthWhitelist, email) | ||||
| } | ||||
| 
 | ||||
| func (auth *Auth) CreateSessionCookie(c *gin.Context, data *types.SessionCookie) error { | ||||
| func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *types.SessionCookie) error { | ||||
| 	log.Debug().Msg("Creating session cookie") | ||||
| 
 | ||||
| 	session, err := auth.GetSession(c) | ||||
| @@ -252,7 +265,7 @@ func (auth *Auth) CreateSessionCookie(c *gin.Context, data *types.SessionCookie) | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (auth *Auth) DeleteSessionCookie(c *gin.Context) error { | ||||
| func (auth *AuthService) DeleteSessionCookie(c *gin.Context) error { | ||||
| 	log.Debug().Msg("Deleting session cookie") | ||||
| 
 | ||||
| 	session, err := auth.GetSession(c) | ||||
| @@ -275,7 +288,7 @@ func (auth *Auth) DeleteSessionCookie(c *gin.Context) error { | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (auth *Auth) GetSessionCookie(c *gin.Context) (types.SessionCookie, error) { | ||||
| func (auth *AuthService) GetSessionCookie(c *gin.Context) (types.SessionCookie, error) { | ||||
| 	log.Debug().Msg("Getting session cookie") | ||||
| 
 | ||||
| 	session, err := auth.GetSession(c) | ||||
| @@ -319,12 +332,12 @@ func (auth *Auth) GetSessionCookie(c *gin.Context) (types.SessionCookie, error) | ||||
| 	}, nil | ||||
| } | ||||
| 
 | ||||
| func (auth *Auth) UserAuthConfigured() bool { | ||||
| func (auth *AuthService) UserAuthConfigured() bool { | ||||
| 	// If there are users or LDAP is configured, return true | ||||
| 	return len(auth.Config.Users) > 0 || auth.LDAP != nil | ||||
| } | ||||
| 
 | ||||
| func (auth *Auth) ResourceAllowed(c *gin.Context, context types.UserContext, labels types.Labels) bool { | ||||
| func (auth *AuthService) ResourceAllowed(c *gin.Context, context types.UserContext, labels types.Labels) bool { | ||||
| 	if context.OAuth { | ||||
| 		log.Debug().Msg("Checking OAuth whitelist") | ||||
| 		return utils.CheckFilter(labels.OAuth.Whitelist, context.Email) | ||||
| @@ -334,7 +347,7 @@ func (auth *Auth) ResourceAllowed(c *gin.Context, context types.UserContext, lab | ||||
| 	return utils.CheckFilter(labels.Users, context.Username) | ||||
| } | ||||
| 
 | ||||
| func (auth *Auth) OAuthGroup(c *gin.Context, context types.UserContext, labels types.Labels) bool { | ||||
| func (auth *AuthService) OAuthGroup(c *gin.Context, context types.UserContext, labels types.Labels) bool { | ||||
| 	if labels.OAuth.Groups == "" { | ||||
| 		return true | ||||
| 	} | ||||
| @@ -361,7 +374,7 @@ func (auth *Auth) OAuthGroup(c *gin.Context, context types.UserContext, labels t | ||||
| 	return false | ||||
| } | ||||
| 
 | ||||
| func (auth *Auth) AuthEnabled(uri string, labels types.Labels) (bool, error) { | ||||
| func (auth *AuthService) AuthEnabled(uri string, labels types.Labels) (bool, error) { | ||||
| 	// If the label is empty, auth is enabled | ||||
| 	if labels.Allowed == "" { | ||||
| 		return true, nil | ||||
| @@ -385,7 +398,7 @@ func (auth *Auth) AuthEnabled(uri string, labels types.Labels) (bool, error) { | ||||
| 	return true, nil | ||||
| } | ||||
| 
 | ||||
| func (auth *Auth) GetBasicAuth(c *gin.Context) *types.User { | ||||
| func (auth *AuthService) GetBasicAuth(c *gin.Context) *types.User { | ||||
| 	username, password, ok := c.Request.BasicAuth() | ||||
| 	if !ok { | ||||
| 		return nil | ||||
| @@ -396,7 +409,7 @@ func (auth *Auth) GetBasicAuth(c *gin.Context) *types.User { | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (auth *Auth) CheckIP(labels types.Labels, ip string) bool { | ||||
| func (auth *AuthService) CheckIP(labels types.Labels, ip string) bool { | ||||
| 	// Check if the IP is in block list | ||||
| 	for _, blocked := range labels.IP.Block { | ||||
| 		res, err := utils.FilterIP(blocked, ip) | ||||
| @@ -433,7 +446,7 @@ func (auth *Auth) CheckIP(labels types.Labels, ip string) bool { | ||||
| 	return true | ||||
| } | ||||
| 
 | ||||
| func (auth *Auth) BypassedIP(labels types.Labels, ip string) bool { | ||||
| func (auth *AuthService) BypassedIP(labels types.Labels, ip string) bool { | ||||
| 	// For every IP in the bypass list, check if the IP matches | ||||
| 	for _, bypassed := range labels.IP.Bypass { | ||||
| 		res, err := utils.FilterIP(bypassed, ip) | ||||
| @@ -1,9 +1,9 @@ | ||||
| package docker | ||||
| package service | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"strings" | ||||
| 	"tinyauth/internal/types" | ||||
| 	"tinyauth/internal/config" | ||||
| 	"tinyauth/internal/utils" | ||||
| 
 | ||||
| 	container "github.com/docker/docker/api/types/container" | ||||
| @@ -11,27 +11,27 @@ import ( | ||||
| 	"github.com/rs/zerolog/log" | ||||
| ) | ||||
| 
 | ||||
| type Docker struct { | ||||
| type DockerService struct { | ||||
| 	Client  *client.Client | ||||
| 	Context context.Context | ||||
| } | ||||
| 
 | ||||
| func NewDocker() (*Docker, error) { | ||||
| func NewDockerService() *DockerService { | ||||
| 	return &DockerService{} | ||||
| } | ||||
| 
 | ||||
| func (docker *DockerService) Init() error { | ||||
| 	client, err := client.NewClientWithOpts(client.FromEnv) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 		return err | ||||
| 	} | ||||
| 
 | ||||
| 	ctx := context.Background() | ||||
| 	client.NegotiateAPIVersion(ctx) | ||||
| 
 | ||||
| 	return &Docker{ | ||||
| 		Client:  client, | ||||
| 		Context: ctx, | ||||
| 	}, nil | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (docker *Docker) GetContainers() ([]container.Summary, error) { | ||||
| func (docker *DockerService) GetContainers() ([]container.Summary, error) { | ||||
| 	containers, err := docker.Client.ContainerList(docker.Context, container.ListOptions{}) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| @@ -39,7 +39,7 @@ func (docker *Docker) GetContainers() ([]container.Summary, error) { | ||||
| 	return containers, nil | ||||
| } | ||||
| 
 | ||||
| func (docker *Docker) InspectContainer(containerId string) (container.InspectResponse, error) { | ||||
| func (docker *DockerService) InspectContainer(containerId string) (container.InspectResponse, error) { | ||||
| 	inspect, err := docker.Client.ContainerInspect(docker.Context, containerId) | ||||
| 	if err != nil { | ||||
| 		return container.InspectResponse{}, err | ||||
| @@ -47,17 +47,17 @@ func (docker *Docker) InspectContainer(containerId string) (container.InspectRes | ||||
| 	return inspect, nil | ||||
| } | ||||
| 
 | ||||
| func (docker *Docker) DockerConnected() bool { | ||||
| func (docker *DockerService) DockerConnected() bool { | ||||
| 	_, err := docker.Client.Ping(docker.Context) | ||||
| 	return err == nil | ||||
| } | ||||
| 
 | ||||
| func (docker *Docker) GetLabels(app string, domain string) (types.Labels, error) { | ||||
| func (docker *DockerService) GetLabels(app string, domain string) (config.Labels, error) { | ||||
| 	isConnected := docker.DockerConnected() | ||||
| 
 | ||||
| 	if !isConnected { | ||||
| 		log.Debug().Msg("Docker not connected, returning empty labels") | ||||
| 		return types.Labels{}, nil | ||||
| 		return config.Labels{}, nil | ||||
| 	} | ||||
| 
 | ||||
| 	log.Debug().Msg("Getting containers") | ||||
| @@ -65,7 +65,7 @@ func (docker *Docker) GetLabels(app string, domain string) (types.Labels, error) | ||||
| 	containers, err := docker.GetContainers() | ||||
| 	if err != nil { | ||||
| 		log.Error().Err(err).Msg("Error getting containers") | ||||
| 		return types.Labels{}, err | ||||
| 		return config.Labels{}, err | ||||
| 	} | ||||
| 
 | ||||
| 	for _, container := range containers { | ||||
| @@ -98,5 +98,5 @@ func (docker *Docker) GetLabels(app string, domain string) (types.Labels, error) | ||||
| 	} | ||||
| 
 | ||||
| 	log.Debug().Msg("No matching container found, returning empty labels") | ||||
| 	return types.Labels{}, nil | ||||
| 	return config.Labels{}, nil | ||||
| } | ||||
							
								
								
									
										114
									
								
								internal/service/generic_oauth_service.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										114
									
								
								internal/service/generic_oauth_service.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,114 @@ | ||||
| package service | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"crypto/rand" | ||||
| 	"crypto/tls" | ||||
| 	"encoding/base64" | ||||
| 	"encoding/json" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"tinyauth/internal/config" | ||||
|  | ||||
| 	"golang.org/x/oauth2" | ||||
| ) | ||||
|  | ||||
| type GenericOAuthService struct { | ||||
| 	Config             oauth2.Config | ||||
| 	Context            context.Context | ||||
| 	Token              *oauth2.Token | ||||
| 	Verifier           string | ||||
| 	InsecureSkipVerify bool | ||||
| 	ServiceName        string | ||||
| 	UserinfoURL        string | ||||
| } | ||||
|  | ||||
| func NewGenericOAuthService(config config.OAuthServiceConfig) *GenericOAuthService { | ||||
| 	return &GenericOAuthService{ | ||||
| 		Config: oauth2.Config{ | ||||
| 			ClientID:     config.ClientID, | ||||
| 			ClientSecret: config.ClientSecret, | ||||
| 			RedirectURL:  config.RedirectURL, | ||||
| 			Scopes:       config.Scopes, | ||||
| 			Endpoint: oauth2.Endpoint{ | ||||
| 				AuthURL:  config.AuthURL, | ||||
| 				TokenURL: config.TokenURL, | ||||
| 			}, | ||||
| 		}, | ||||
| 		InsecureSkipVerify: config.InsecureSkipVerify, | ||||
| 		ServiceName:        config.Name, | ||||
| 		UserinfoURL:        config.UserinfoURL, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (generic *GenericOAuthService) Init() error { | ||||
| 	transport := &http.Transport{ | ||||
| 		TLSClientConfig: &tls.Config{ | ||||
| 			InsecureSkipVerify: generic.InsecureSkipVerify, | ||||
| 			MinVersion:         tls.VersionTLS12, | ||||
| 		}, | ||||
| 	} | ||||
|  | ||||
| 	httpClient := &http.Client{ | ||||
| 		Transport: transport, | ||||
| 	} | ||||
|  | ||||
| 	ctx := context.Background() | ||||
|  | ||||
| 	ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) | ||||
| 	verifier := oauth2.GenerateVerifier() | ||||
|  | ||||
| 	generic.Context = ctx | ||||
| 	generic.Verifier = verifier | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (generic *GenericOAuthService) Name() string { | ||||
| 	return generic.ServiceName | ||||
| } | ||||
|  | ||||
| func (generic *GenericOAuthService) GenerateState() string { | ||||
| 	b := make([]byte, 128) | ||||
| 	rand.Read(b) | ||||
| 	state := base64.URLEncoding.EncodeToString(b) | ||||
| 	return state | ||||
| } | ||||
|  | ||||
| func (generic *GenericOAuthService) GetAuthURL(state string) string { | ||||
| 	return generic.Config.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(generic.Verifier)) | ||||
| } | ||||
|  | ||||
| func (generic *GenericOAuthService) VerifyCode(code string) error { | ||||
| 	token, err := generic.Config.Exchange(generic.Context, code, oauth2.VerifierOption(generic.Verifier)) | ||||
|  | ||||
| 	if err != nil { | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	generic.Token = token | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (generic *GenericOAuthService) Userinfo() (config.Claims, error) { | ||||
| 	var user config.Claims | ||||
|  | ||||
| 	client := generic.Config.Client(generic.Context, generic.Token) | ||||
|  | ||||
| 	res, err := client.Get(generic.UserinfoURL) | ||||
| 	if err != nil { | ||||
| 		return user, err | ||||
| 	} | ||||
| 	defer res.Body.Close() | ||||
|  | ||||
| 	body, err := io.ReadAll(res.Body) | ||||
| 	if err != nil { | ||||
| 		return user, err | ||||
| 	} | ||||
|  | ||||
| 	err = json.Unmarshal(body, &user) | ||||
| 	if err != nil { | ||||
| 		return user, err | ||||
| 	} | ||||
|  | ||||
| 	return user, nil | ||||
| } | ||||
							
								
								
									
										144
									
								
								internal/service/github_oauth_service.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										144
									
								
								internal/service/github_oauth_service.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,144 @@ | ||||
| package service | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"crypto/rand" | ||||
| 	"encoding/base64" | ||||
| 	"encoding/json" | ||||
| 	"errors" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"tinyauth/internal/config" | ||||
|  | ||||
| 	"golang.org/x/oauth2" | ||||
| ) | ||||
|  | ||||
| var GithubOAuthScopes = []string{"user:email", "read:user"} | ||||
|  | ||||
| type GithubEmailResponse []struct { | ||||
| 	Email   string `json:"email"` | ||||
| 	Primary bool   `json:"primary"` | ||||
| } | ||||
|  | ||||
| type GithubUserInfoResponse struct { | ||||
| 	Login string `json:"login"` | ||||
| 	Name  string `json:"name"` | ||||
| } | ||||
|  | ||||
| type GithubOAuthService struct { | ||||
| 	Config   oauth2.Config | ||||
| 	Context  context.Context | ||||
| 	Token    *oauth2.Token | ||||
| 	Verifier string | ||||
| } | ||||
|  | ||||
| func NewGithubOAuthService(config config.OAuthServiceConfig) *GithubOAuthService { | ||||
| 	return &GithubOAuthService{ | ||||
| 		Config: oauth2.Config{ | ||||
| 			ClientID:     config.ClientID, | ||||
| 			ClientSecret: config.ClientSecret, | ||||
| 			RedirectURL:  config.RedirectURL, | ||||
| 			Scopes:       GithubOAuthScopes, | ||||
| 		}, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (github *GithubOAuthService) Init() error { | ||||
| 	httpClient := &http.Client{} | ||||
| 	ctx := context.Background() | ||||
| 	ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) | ||||
| 	verifier := oauth2.GenerateVerifier() | ||||
|  | ||||
| 	github.Context = ctx | ||||
| 	github.Verifier = verifier | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (github *GithubOAuthService) Name() string { | ||||
| 	return "github" | ||||
| } | ||||
|  | ||||
| func (github *GithubOAuthService) GenerateState() string { | ||||
| 	b := make([]byte, 128) | ||||
| 	rand.Read(b) | ||||
| 	state := base64.URLEncoding.EncodeToString(b) | ||||
| 	return state | ||||
| } | ||||
|  | ||||
| func (github *GithubOAuthService) GetAuthURL(state string) string { | ||||
| 	return github.Config.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(github.Verifier)) | ||||
| } | ||||
|  | ||||
| func (github *GithubOAuthService) VerifyCode(code string) error { | ||||
| 	token, err := github.Config.Exchange(github.Context, code, oauth2.VerifierOption(github.Verifier)) | ||||
|  | ||||
| 	if err != nil { | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	github.Token = token | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (github *GithubOAuthService) Userinfo() (config.Claims, error) { | ||||
| 	var user config.Claims | ||||
|  | ||||
| 	client := github.Config.Client(github.Context, github.Token) | ||||
|  | ||||
| 	res, err := client.Get("https://api.github.com/user") | ||||
| 	if err != nil { | ||||
| 		return user, err | ||||
| 	} | ||||
| 	defer res.Body.Close() | ||||
|  | ||||
| 	body, err := io.ReadAll(res.Body) | ||||
| 	if err != nil { | ||||
| 		return user, err | ||||
| 	} | ||||
|  | ||||
| 	var userInfo GithubUserInfoResponse | ||||
|  | ||||
| 	err = json.Unmarshal(body, &userInfo) | ||||
| 	if err != nil { | ||||
| 		return user, err | ||||
| 	} | ||||
|  | ||||
| 	res, err = client.Get("https://api.github.com/user/emails") | ||||
| 	if err != nil { | ||||
| 		return user, err | ||||
| 	} | ||||
| 	defer res.Body.Close() | ||||
|  | ||||
| 	body, err = io.ReadAll(res.Body) | ||||
| 	if err != nil { | ||||
| 		return user, err | ||||
| 	} | ||||
|  | ||||
| 	var emails GithubEmailResponse | ||||
|  | ||||
| 	err = json.Unmarshal(body, &emails) | ||||
| 	if err != nil { | ||||
| 		return user, err | ||||
| 	} | ||||
|  | ||||
| 	for _, email := range emails { | ||||
| 		if email.Primary { | ||||
| 			user.Email = email.Email | ||||
| 			break | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	if len(emails) == 0 { | ||||
| 		return user, errors.New("no emails found") | ||||
| 	} | ||||
|  | ||||
| 	// Use first available email if no primary email was found | ||||
| 	if user.Email == "" { | ||||
| 		user.Email = emails[0].Email | ||||
| 	} | ||||
|  | ||||
| 	user.PreferredUsername = userInfo.Login | ||||
| 	user.Name = userInfo.Name | ||||
|  | ||||
| 	return user, nil | ||||
| } | ||||
							
								
								
									
										106
									
								
								internal/service/google_oauth_service.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										106
									
								
								internal/service/google_oauth_service.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,106 @@ | ||||
| package service | ||||
|  | ||||
| import ( | ||||
| 	"context" | ||||
| 	"crypto/rand" | ||||
| 	"encoding/base64" | ||||
| 	"encoding/json" | ||||
| 	"io" | ||||
| 	"net/http" | ||||
| 	"strings" | ||||
| 	"tinyauth/internal/config" | ||||
|  | ||||
| 	"golang.org/x/oauth2" | ||||
| ) | ||||
|  | ||||
| var GoogleOAuthScopes = []string{"https://www.googleapis.com/auth/userinfo.email", "https://www.googleapis.com/auth/userinfo.profile"} | ||||
|  | ||||
| type GoogleUserInfoResponse struct { | ||||
| 	Email string `json:"email"` | ||||
| 	Name  string `json:"name"` | ||||
| } | ||||
|  | ||||
| type GoogleOAuthService struct { | ||||
| 	Config   oauth2.Config | ||||
| 	Context  context.Context | ||||
| 	Token    *oauth2.Token | ||||
| 	Verifier string | ||||
| } | ||||
|  | ||||
| func NewGoogleOAuthService(config config.OAuthServiceConfig) *GoogleOAuthService { | ||||
| 	return &GoogleOAuthService{ | ||||
| 		Config: oauth2.Config{ | ||||
| 			ClientID:     config.ClientID, | ||||
| 			ClientSecret: config.ClientSecret, | ||||
| 			RedirectURL:  config.RedirectURL, | ||||
| 			Scopes:       GoogleOAuthScopes, | ||||
| 		}, | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (google *GoogleOAuthService) Init() error { | ||||
| 	httpClient := &http.Client{} | ||||
| 	ctx := context.Background() | ||||
| 	ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) | ||||
| 	verifier := oauth2.GenerateVerifier() | ||||
|  | ||||
| 	google.Context = ctx | ||||
| 	google.Verifier = verifier | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (google *GoogleOAuthService) Name() string { | ||||
| 	return "google" | ||||
| } | ||||
|  | ||||
| func (oauth *GoogleOAuthService) GenerateState() string { | ||||
| 	b := make([]byte, 128) | ||||
| 	rand.Read(b) | ||||
| 	state := base64.URLEncoding.EncodeToString(b) | ||||
| 	return state | ||||
| } | ||||
|  | ||||
| func (google *GoogleOAuthService) GetAuthURL(state string) string { | ||||
| 	return google.Config.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(google.Verifier)) | ||||
| } | ||||
|  | ||||
| func (google *GoogleOAuthService) VerifyCode(code string) error { | ||||
| 	token, err := google.Config.Exchange(google.Context, code, oauth2.VerifierOption(google.Verifier)) | ||||
|  | ||||
| 	if err != nil { | ||||
| 		return nil | ||||
| 	} | ||||
|  | ||||
| 	google.Token = token | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (google *GoogleOAuthService) Userinfo() (config.Claims, error) { | ||||
| 	var user config.Claims | ||||
|  | ||||
| 	client := google.Config.Client(google.Context, google.Token) | ||||
|  | ||||
| 	res, err := client.Get("https://www.googleapis.com/userinfo/v2/me") | ||||
| 	if err != nil { | ||||
| 		return config.Claims{}, err | ||||
| 	} | ||||
| 	defer res.Body.Close() | ||||
|  | ||||
| 	body, err := io.ReadAll(res.Body) | ||||
| 	if err != nil { | ||||
| 		return config.Claims{}, err | ||||
| 	} | ||||
|  | ||||
| 	var userInfo GoogleUserInfoResponse | ||||
|  | ||||
| 	err = json.Unmarshal(body, &userInfo) | ||||
| 	if err != nil { | ||||
| 		return config.Claims{}, err | ||||
| 	} | ||||
|  | ||||
| 	user.PreferredUsername = strings.Split(userInfo.Email, "@")[0] | ||||
| 	user.Name = userInfo.Name | ||||
| 	user.Email = userInfo.Email | ||||
|  | ||||
| 	return user, nil | ||||
| } | ||||
| @@ -1,30 +1,40 @@ | ||||
| package ldap | ||||
| package service | ||||
| 
 | ||||
| import ( | ||||
| 	"context" | ||||
| 	"crypto/tls" | ||||
| 	"fmt" | ||||
| 	"time" | ||||
| 	"tinyauth/internal/types" | ||||
| 
 | ||||
| 	"github.com/cenkalti/backoff/v5" | ||||
| 	ldapgo "github.com/go-ldap/ldap/v3" | ||||
| 	"github.com/rs/zerolog/log" | ||||
| ) | ||||
| 
 | ||||
| type LDAP struct { | ||||
| 	Config types.LdapConfig | ||||
| type LdapServiceConfig struct { | ||||
| 	Address      string | ||||
| 	BindDN       string | ||||
| 	BindPassword string | ||||
| 	BaseDN       string | ||||
| 	Insecure     bool | ||||
| 	SearchFilter string | ||||
| } | ||||
| 
 | ||||
| type LdapService struct { | ||||
| 	Config LdapServiceConfig | ||||
| 	Conn   *ldapgo.Conn | ||||
| } | ||||
| 
 | ||||
| func NewLDAP(config types.LdapConfig) (*LDAP, error) { | ||||
| 	ldap := &LDAP{ | ||||
| func NewLdapService(config LdapServiceConfig) *LdapService { | ||||
| 	return &LdapService{ | ||||
| 		Config: config, | ||||
| 	} | ||||
| } | ||||
| 
 | ||||
| func (ldap *LdapService) Init() error { | ||||
| 	_, err := ldap.connect() | ||||
| 	if err != nil { | ||||
| 		return nil, fmt.Errorf("failed to connect to LDAP server: %w", err) | ||||
| 		return fmt.Errorf("failed to connect to LDAP server: %w", err) | ||||
| 	} | ||||
| 
 | ||||
| 	go func() { | ||||
| @@ -41,13 +51,13 @@ func NewLDAP(config types.LdapConfig) (*LDAP, error) { | ||||
| 		} | ||||
| 	}() | ||||
| 
 | ||||
| 	return ldap, nil | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (l *LDAP) connect() (*ldapgo.Conn, error) { | ||||
| func (ldap *LdapService) connect() (*ldapgo.Conn, error) { | ||||
| 	log.Debug().Msg("Connecting to LDAP server") | ||||
| 	conn, err := ldapgo.DialURL(l.Config.Address, ldapgo.DialWithTLSConfig(&tls.Config{ | ||||
| 		InsecureSkipVerify: l.Config.Insecure, | ||||
| 	conn, err := ldapgo.DialURL(ldap.Config.Address, ldapgo.DialWithTLSConfig(&tls.Config{ | ||||
| 		InsecureSkipVerify: ldap.Config.Insecure, | ||||
| 		MinVersion:         tls.VersionTLS12, | ||||
| 	})) | ||||
| 	if err != nil { | ||||
| @@ -55,30 +65,30 @@ func (l *LDAP) connect() (*ldapgo.Conn, error) { | ||||
| 	} | ||||
| 
 | ||||
| 	log.Debug().Msg("Binding to LDAP server") | ||||
| 	err = conn.Bind(l.Config.BindDN, l.Config.BindPassword) | ||||
| 	err = conn.Bind(ldap.Config.BindDN, ldap.Config.BindPassword) | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 
 | ||||
| 	// Set and return the connection | ||||
| 	l.Conn = conn | ||||
| 	ldap.Conn = conn | ||||
| 	return conn, nil | ||||
| } | ||||
| 
 | ||||
| func (l *LDAP) Search(username string) (string, error) { | ||||
| func (ldap *LdapService) Search(username string) (string, error) { | ||||
| 	// Escape the username to prevent LDAP injection | ||||
| 	escapedUsername := ldapgo.EscapeFilter(username) | ||||
| 	filter := fmt.Sprintf(l.Config.SearchFilter, escapedUsername) | ||||
| 	filter := fmt.Sprintf(ldap.Config.SearchFilter, escapedUsername) | ||||
| 
 | ||||
| 	searchRequest := ldapgo.NewSearchRequest( | ||||
| 		l.Config.BaseDN, | ||||
| 		ldap.Config.BaseDN, | ||||
| 		ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false, | ||||
| 		filter, | ||||
| 		[]string{"dn"}, | ||||
| 		nil, | ||||
| 	) | ||||
| 
 | ||||
| 	searchResult, err := l.Conn.Search(searchRequest) | ||||
| 	searchResult, err := ldap.Conn.Search(searchRequest) | ||||
| 	if err != nil { | ||||
| 		return "", err | ||||
| 	} | ||||
| @@ -91,15 +101,15 @@ func (l *LDAP) Search(username string) (string, error) { | ||||
| 	return userDN, nil | ||||
| } | ||||
| 
 | ||||
| func (l *LDAP) Bind(userDN string, password string) error { | ||||
| 	err := l.Conn.Bind(userDN, password) | ||||
| func (ldap *LdapService) Bind(userDN string, password string) error { | ||||
| 	err := ldap.Conn.Bind(userDN, password) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (l *LDAP) heartbeat() error { | ||||
| func (ldap *LdapService) heartbeat() error { | ||||
| 	log.Debug().Msg("Performing LDAP connection heartbeat") | ||||
| 
 | ||||
| 	searchRequest := ldapgo.NewSearchRequest( | ||||
| @@ -110,7 +120,7 @@ func (l *LDAP) heartbeat() error { | ||||
| 		nil, | ||||
| 	) | ||||
| 
 | ||||
| 	_, err := l.Conn.Search(searchRequest) | ||||
| 	_, err := ldap.Conn.Search(searchRequest) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| @@ -119,7 +129,7 @@ func (l *LDAP) heartbeat() error { | ||||
| 	return nil | ||||
| } | ||||
| 
 | ||||
| func (l *LDAP) reconnect() error { | ||||
| func (ldap *LdapService) reconnect() error { | ||||
| 	log.Info().Msg("Reconnecting to LDAP server") | ||||
| 
 | ||||
| 	exp := backoff.NewExponentialBackOff() | ||||
| @@ -129,8 +139,8 @@ func (l *LDAP) reconnect() error { | ||||
| 	exp.Reset() | ||||
| 
 | ||||
| 	operation := func() (*ldapgo.Conn, error) { | ||||
| 		l.Conn.Close() | ||||
| 		conn, err := l.connect() | ||||
| 		ldap.Conn.Close() | ||||
| 		conn, err := ldap.connect() | ||||
| 		if err != nil { | ||||
| 			return nil, nil | ||||
| 		} | ||||
| @@ -1,70 +0,0 @@ | ||||
| package types | ||||
|  | ||||
| import ( | ||||
| 	"time" | ||||
| 	"tinyauth/internal/oauth" | ||||
| ) | ||||
|  | ||||
| // User is the struct for a user | ||||
| type User struct { | ||||
| 	Username   string | ||||
| 	Password   string | ||||
| 	TotpSecret string | ||||
| } | ||||
|  | ||||
| // UserSearch is the response of the get user | ||||
| type UserSearch struct { | ||||
| 	Username string | ||||
| 	Type     string // "local", "ldap" or empty | ||||
| } | ||||
|  | ||||
| // Users is a list of users | ||||
| type Users []User | ||||
|  | ||||
| // OAuthProviders is the struct for the OAuth providers | ||||
| type OAuthProviders struct { | ||||
| 	Github    *oauth.OAuth | ||||
| 	Google    *oauth.OAuth | ||||
| 	Microsoft *oauth.OAuth | ||||
| } | ||||
|  | ||||
| // SessionCookie is the cookie for the session (exculding the expiry) | ||||
| type SessionCookie struct { | ||||
| 	Username    string | ||||
| 	Name        string | ||||
| 	Email       string | ||||
| 	Provider    string | ||||
| 	TotpPending bool | ||||
| 	OAuthGroups string | ||||
| } | ||||
|  | ||||
| // UserContext is the context for the user | ||||
| type UserContext struct { | ||||
| 	Username    string | ||||
| 	Name        string | ||||
| 	Email       string | ||||
| 	IsLoggedIn  bool | ||||
| 	OAuth       bool | ||||
| 	Provider    string | ||||
| 	TotpPending bool | ||||
| 	OAuthGroups string | ||||
| 	TotpEnabled bool | ||||
| } | ||||
|  | ||||
| // LoginAttempt tracks information about login attempts for rate limiting | ||||
| type LoginAttempt struct { | ||||
| 	FailedAttempts int | ||||
| 	LastAttempt    time.Time | ||||
| 	LockedUntil    time.Time | ||||
| } | ||||
|  | ||||
| type UnauthorizedQuery struct { | ||||
| 	Username string `url:"username"` | ||||
| 	Resource string `url:"resource"` | ||||
| 	GroupErr bool   `url:"groupErr"` | ||||
| 	IP       string `url:"ip"` | ||||
| } | ||||
|  | ||||
| type RedirectQuery struct { | ||||
| 	RedirectURI string `url:"redirect_uri"` | ||||
| } | ||||
		Reference in New Issue
	
	Block a user
	 Stavros
					Stavros