mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2025-10-29 05:05:42 +00:00
refactor: move oauth providers into services (non-working)
This commit is contained in:
@@ -1,146 +0,0 @@
|
||||
package auth_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
"tinyauth/internal/auth"
|
||||
"tinyauth/internal/types"
|
||||
)
|
||||
|
||||
var config = types.AuthConfig{
|
||||
Users: types.Users{},
|
||||
OauthWhitelist: "",
|
||||
SessionExpiry: 3600,
|
||||
}
|
||||
|
||||
func TestLoginRateLimiting(t *testing.T) {
|
||||
// Initialize a new auth service with 3 max retries and 5 seconds timeout
|
||||
config.LoginMaxRetries = 3
|
||||
config.LoginTimeout = 5
|
||||
authService := auth.NewAuth(config, nil, nil)
|
||||
|
||||
// Test identifier
|
||||
identifier := "test_user"
|
||||
|
||||
// Test successful login - should not lock account
|
||||
t.Log("Testing successful login")
|
||||
|
||||
authService.RecordLoginAttempt(identifier, true)
|
||||
locked, _ := authService.IsAccountLocked(identifier)
|
||||
|
||||
if locked {
|
||||
t.Fatalf("Account should not be locked after successful login")
|
||||
}
|
||||
|
||||
// Test 2 failed attempts - should not lock account yet
|
||||
t.Log("Testing 2 failed login attempts")
|
||||
|
||||
authService.RecordLoginAttempt(identifier, false)
|
||||
authService.RecordLoginAttempt(identifier, false)
|
||||
locked, _ = authService.IsAccountLocked(identifier)
|
||||
|
||||
if locked {
|
||||
t.Fatalf("Account should not be locked after only 2 failed attempts")
|
||||
}
|
||||
|
||||
// Add one more failed attempt (total 3) - should lock account with maxRetries=3
|
||||
t.Log("Testing 3 failed login attempts")
|
||||
authService.RecordLoginAttempt(identifier, false)
|
||||
locked, remainingTime := authService.IsAccountLocked(identifier)
|
||||
|
||||
if !locked {
|
||||
t.Fatalf("Account should be locked after reaching max retries")
|
||||
}
|
||||
if remainingTime <= 0 || remainingTime > 5 {
|
||||
t.Fatalf("Expected remaining time between 1-5 seconds, got %d", remainingTime)
|
||||
}
|
||||
|
||||
// Test reset after waiting for timeout - use 1 second timeout for fast testing
|
||||
t.Log("Testing unlocking after timeout")
|
||||
|
||||
// Reinitialize auth service with a shorter timeout for testing
|
||||
config.LoginTimeout = 1
|
||||
config.LoginMaxRetries = 3
|
||||
authService = auth.NewAuth(config, nil, nil)
|
||||
|
||||
// Add enough failed attempts to lock the account
|
||||
for i := 0; i < 3; i++ {
|
||||
authService.RecordLoginAttempt(identifier, false)
|
||||
}
|
||||
|
||||
// Verify it's locked
|
||||
locked, _ = authService.IsAccountLocked(identifier)
|
||||
if !locked {
|
||||
t.Fatalf("Account should be locked initially")
|
||||
}
|
||||
|
||||
// Wait a bit and verify it gets unlocked after timeout
|
||||
time.Sleep(1500 * time.Millisecond) // Wait longer than the timeout
|
||||
locked, _ = authService.IsAccountLocked(identifier)
|
||||
|
||||
if locked {
|
||||
t.Fatalf("Account should be unlocked after timeout period")
|
||||
}
|
||||
|
||||
// Test disabled rate limiting
|
||||
t.Log("Testing disabled rate limiting")
|
||||
config.LoginMaxRetries = 0
|
||||
config.LoginTimeout = 0
|
||||
authService = auth.NewAuth(config, nil, nil)
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
authService.RecordLoginAttempt(identifier, false)
|
||||
}
|
||||
|
||||
locked, _ = authService.IsAccountLocked(identifier)
|
||||
if locked {
|
||||
t.Fatalf("Account should not be locked when rate limiting is disabled")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConcurrentLoginAttempts(t *testing.T) {
|
||||
// Initialize a new auth service with 2 max retries and 5 seconds timeout
|
||||
config.LoginMaxRetries = 2
|
||||
config.LoginTimeout = 5
|
||||
authService := auth.NewAuth(config, nil, nil)
|
||||
|
||||
// Test multiple identifiers
|
||||
identifiers := []string{"user1", "user2", "user3"}
|
||||
|
||||
// Test that locking one identifier doesn't affect others
|
||||
t.Log("Testing multiple identifiers")
|
||||
|
||||
// Add enough failed attempts to lock first user (2 attempts with maxRetries=2)
|
||||
authService.RecordLoginAttempt(identifiers[0], false)
|
||||
authService.RecordLoginAttempt(identifiers[0], false)
|
||||
|
||||
// Check if first user is locked
|
||||
locked, _ := authService.IsAccountLocked(identifiers[0])
|
||||
if !locked {
|
||||
t.Fatalf("User1 should be locked after reaching max retries")
|
||||
}
|
||||
|
||||
// Check that other users are not affected
|
||||
for i := 1; i < len(identifiers); i++ {
|
||||
locked, _ := authService.IsAccountLocked(identifiers[i])
|
||||
if locked {
|
||||
t.Fatalf("User%d should not be locked", i+1)
|
||||
}
|
||||
}
|
||||
|
||||
// Test successful login after failed attempts (but before lock)
|
||||
t.Log("Testing successful login after failed attempts but before lock")
|
||||
|
||||
// One failed attempt for user2
|
||||
authService.RecordLoginAttempt(identifiers[1], false)
|
||||
|
||||
// Successful login should reset the counter
|
||||
authService.RecordLoginAttempt(identifiers[1], true)
|
||||
|
||||
// Now try a failed login again - should not be locked as counter was reset
|
||||
authService.RecordLoginAttempt(identifiers[1], false)
|
||||
locked, _ = authService.IsAccountLocked(identifiers[1])
|
||||
if locked {
|
||||
t.Fatalf("User2 should not be locked after successful login reset")
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,22 @@
|
||||
package types
|
||||
package config
|
||||
|
||||
import "time"
|
||||
|
||||
type Claims struct {
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
PreferredUsername string `json:"preferred_username"`
|
||||
Groups any `json:"groups"`
|
||||
}
|
||||
|
||||
var Version = "development"
|
||||
var CommitHash = "n/a"
|
||||
var BuildTimestamp = "n/a"
|
||||
|
||||
var SessionCookieName = "tinyauth-session"
|
||||
var CsrfCookieName = "tinyauth-csrf"
|
||||
var RedirectCookieName = "tinyauth-redirect"
|
||||
|
||||
// Config is the configuration for the tinyauth server
|
||||
type Config struct {
|
||||
Port int `mapstructure:"port" validate:"required"`
|
||||
Address string `validate:"required,ip4_addr" mapstructure:"address"`
|
||||
@@ -44,62 +60,27 @@ type Config struct {
|
||||
LdapSearchFilter string `mapstructure:"ldap-search-filter"`
|
||||
}
|
||||
|
||||
// OAuthConfig is the configuration for the providers
|
||||
type OAuthConfig struct {
|
||||
GithubClientId string
|
||||
GithubClientSecret string
|
||||
GoogleClientId string
|
||||
GoogleClientSecret string
|
||||
GenericClientId string
|
||||
GenericClientSecret string
|
||||
GenericScopes []string
|
||||
GenericAuthURL string
|
||||
GenericTokenURL string
|
||||
GenericUserURL string
|
||||
GenericSkipSSL bool
|
||||
AppURL string
|
||||
}
|
||||
|
||||
// AuthConfig is the configuration for the auth service
|
||||
type AuthConfig struct {
|
||||
Users Users
|
||||
OauthWhitelist string
|
||||
SessionExpiry int
|
||||
CookieSecure bool
|
||||
Domain string
|
||||
LoginTimeout int
|
||||
LoginMaxRetries int
|
||||
SessionCookieName string
|
||||
HMACSecret string
|
||||
EncryptionSecret string
|
||||
}
|
||||
|
||||
// OAuthLabels is a list of labels that can be used in a tinyauth protected container
|
||||
type OAuthLabels struct {
|
||||
Whitelist string
|
||||
Groups string
|
||||
}
|
||||
|
||||
// Basic auth labels for a tinyauth protected container
|
||||
type BasicLabels struct {
|
||||
Username string
|
||||
Password PassowrdLabels
|
||||
}
|
||||
|
||||
// PassowrdLabels is a struct that contains the password labels for a tinyauth protected container
|
||||
type PassowrdLabels struct {
|
||||
Plain string
|
||||
File string
|
||||
}
|
||||
|
||||
// IP labels for a tinyauth protected container
|
||||
type IPLabels struct {
|
||||
Allow []string
|
||||
Block []string
|
||||
Bypass []string
|
||||
}
|
||||
|
||||
// Labels is a struct that contains the labels for a tinyauth protected container
|
||||
type Labels struct {
|
||||
Users string
|
||||
Allowed string
|
||||
@@ -110,12 +91,65 @@ type Labels struct {
|
||||
IP IPLabels
|
||||
}
|
||||
|
||||
// Ldap config is a struct that contains the configuration for the LDAP service
|
||||
type LdapConfig struct {
|
||||
Address string
|
||||
BindDN string
|
||||
BindPassword string
|
||||
BaseDN string
|
||||
Insecure bool
|
||||
SearchFilter string
|
||||
type OAuthServiceConfig struct {
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
Scopes []string
|
||||
RedirectURL string
|
||||
AuthURL string
|
||||
TokenURL string
|
||||
UserinfoURL string
|
||||
InsecureSkipVerify bool
|
||||
Name string
|
||||
}
|
||||
|
||||
type User struct {
|
||||
Username string
|
||||
Password string
|
||||
TotpSecret string
|
||||
}
|
||||
|
||||
type UserSearch struct {
|
||||
Username string
|
||||
Type string // local, ldap or unknown
|
||||
}
|
||||
|
||||
type Users []User
|
||||
|
||||
type SessionCookie struct {
|
||||
Username string
|
||||
Name string
|
||||
Email string
|
||||
Provider string
|
||||
TotpPending bool
|
||||
OAuthGroups string
|
||||
}
|
||||
|
||||
type UserContext struct {
|
||||
Username string
|
||||
Name string
|
||||
Email string
|
||||
IsLoggedIn bool
|
||||
OAuth bool
|
||||
Provider string
|
||||
TotpPending bool
|
||||
OAuthGroups string
|
||||
TotpEnabled bool
|
||||
}
|
||||
|
||||
type LoginAttempt struct {
|
||||
FailedAttempts int
|
||||
LastAttempt time.Time
|
||||
LockedUntil time.Time
|
||||
}
|
||||
|
||||
type UnauthorizedQuery struct {
|
||||
Username string `url:"username"`
|
||||
Resource string `url:"resource"`
|
||||
GroupErr bool `url:"groupErr"`
|
||||
IP string `url:"ip"`
|
||||
}
|
||||
|
||||
type RedirectQuery struct {
|
||||
RedirectURI string `url:"redirect_uri"`
|
||||
}
|
||||
@@ -1,19 +0,0 @@
|
||||
package constants
|
||||
|
||||
// Claims are the OIDC supported claims (prefered username is included for convinience)
|
||||
type Claims struct {
|
||||
Name string `json:"name"`
|
||||
Email string `json:"email"`
|
||||
PreferredUsername string `json:"preferred_username"`
|
||||
Groups any `json:"groups"`
|
||||
}
|
||||
|
||||
// Version information
|
||||
var Version = "development"
|
||||
var CommitHash = "n/a"
|
||||
var BuildTimestamp = "n/a"
|
||||
|
||||
// Base cookie names
|
||||
var SessionCookieName = "tinyauth-session"
|
||||
var CsrfCookieName = "tinyauth-csrf"
|
||||
var RedirectCookieName = "tinyauth-redirect"
|
||||
@@ -1,71 +0,0 @@
|
||||
package oauth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
type OAuth struct {
|
||||
Config oauth2.Config
|
||||
Context context.Context
|
||||
Token *oauth2.Token
|
||||
Verifier string
|
||||
}
|
||||
|
||||
func NewOAuth(config oauth2.Config, insecureSkipVerify bool) *OAuth {
|
||||
transport := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: insecureSkipVerify,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
},
|
||||
}
|
||||
|
||||
httpClient := &http.Client{
|
||||
Transport: transport,
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
// Set the HTTP client in the context
|
||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
|
||||
|
||||
verifier := oauth2.GenerateVerifier()
|
||||
|
||||
return &OAuth{
|
||||
Config: config,
|
||||
Context: ctx,
|
||||
Verifier: verifier,
|
||||
}
|
||||
}
|
||||
|
||||
func (oauth *OAuth) GetAuthURL(state string) string {
|
||||
return oauth.Config.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(oauth.Verifier))
|
||||
}
|
||||
|
||||
func (oauth *OAuth) ExchangeToken(code string) (string, error) {
|
||||
token, err := oauth.Config.Exchange(oauth.Context, code, oauth2.VerifierOption(oauth.Verifier))
|
||||
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Set and return the token
|
||||
oauth.Token = token
|
||||
return oauth.Token.AccessToken, nil
|
||||
}
|
||||
|
||||
func (oauth *OAuth) GetClient() *http.Client {
|
||||
return oauth.Config.Client(oauth.Context, oauth.Token)
|
||||
}
|
||||
|
||||
func (oauth *OAuth) GenerateState() string {
|
||||
b := make([]byte, 128)
|
||||
rand.Read(b)
|
||||
state := base64.URLEncoding.EncodeToString(b)
|
||||
return state
|
||||
}
|
||||
@@ -1,37 +0,0 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"tinyauth/internal/constants"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func GetGenericUser(client *http.Client, url string) (constants.Claims, error) {
|
||||
var user constants.Claims
|
||||
|
||||
res, err := client.Get(url)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
log.Debug().Msg("Got response from generic provider")
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
|
||||
log.Debug().Msg("Read body from generic provider")
|
||||
|
||||
err = json.Unmarshal(body, &user)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
|
||||
log.Debug().Msg("Parsed user from generic provider")
|
||||
return user, nil
|
||||
}
|
||||
@@ -1,102 +0,0 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"tinyauth/internal/constants"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// Response for the github email endpoint
|
||||
type GithubEmailResponse []struct {
|
||||
Email string `json:"email"`
|
||||
Primary bool `json:"primary"`
|
||||
}
|
||||
|
||||
// Response for the github user endpoint
|
||||
type GithubUserInfoResponse struct {
|
||||
Login string `json:"login"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// The scopes required for the github provider
|
||||
func GithubScopes() []string {
|
||||
return []string{"user:email", "read:user"}
|
||||
}
|
||||
|
||||
func GetGithubUser(client *http.Client) (constants.Claims, error) {
|
||||
var user constants.Claims
|
||||
|
||||
res, err := client.Get("https://api.github.com/user")
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
log.Debug().Msg("Got user response from github")
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
|
||||
log.Debug().Msg("Read user body from github")
|
||||
|
||||
var userInfo GithubUserInfoResponse
|
||||
|
||||
err = json.Unmarshal(body, &userInfo)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
|
||||
res, err = client.Get("https://api.github.com/user/emails")
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
log.Debug().Msg("Got email response from github")
|
||||
|
||||
body, err = io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
|
||||
log.Debug().Msg("Read email body from github")
|
||||
|
||||
var emails GithubEmailResponse
|
||||
|
||||
err = json.Unmarshal(body, &emails)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
|
||||
log.Debug().Msg("Parsed emails from github")
|
||||
|
||||
// Find and return the primary email
|
||||
for _, email := range emails {
|
||||
if email.Primary {
|
||||
log.Debug().Str("email", email.Email).Msg("Found primary email")
|
||||
user.Email = email.Email
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(emails) == 0 {
|
||||
return user, errors.New("no emails found")
|
||||
}
|
||||
|
||||
// Use first available email if no primary email was found
|
||||
if user.Email == "" {
|
||||
log.Warn().Str("email", emails[0].Email).Msg("No primary email found, using first email")
|
||||
user.Email = emails[0].Email
|
||||
}
|
||||
|
||||
user.PreferredUsername = userInfo.Login
|
||||
user.Name = userInfo.Name
|
||||
|
||||
return user, nil
|
||||
}
|
||||
@@ -1,56 +0,0 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"tinyauth/internal/constants"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// Response for the google user endpoint
|
||||
type GoogleUserInfoResponse struct {
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// The scopes required for the google provider
|
||||
func GoogleScopes() []string {
|
||||
return []string{"https://www.googleapis.com/auth/userinfo.email", "https://www.googleapis.com/auth/userinfo.profile"}
|
||||
}
|
||||
|
||||
func GetGoogleUser(client *http.Client) (constants.Claims, error) {
|
||||
var user constants.Claims
|
||||
|
||||
res, err := client.Get("https://www.googleapis.com/userinfo/v2/me")
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
log.Debug().Msg("Got response from google")
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
|
||||
log.Debug().Msg("Read body from google")
|
||||
|
||||
var userInfo GoogleUserInfoResponse
|
||||
|
||||
err = json.Unmarshal(body, &userInfo)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
|
||||
log.Debug().Msg("Parsed user from google")
|
||||
|
||||
user.PreferredUsername = strings.Split(userInfo.Email, "@")[0]
|
||||
user.Name = userInfo.Name
|
||||
user.Email = userInfo.Email
|
||||
|
||||
return user, nil
|
||||
}
|
||||
@@ -1,154 +0,0 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"tinyauth/internal/constants"
|
||||
"tinyauth/internal/oauth"
|
||||
"tinyauth/internal/types"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/endpoints"
|
||||
)
|
||||
|
||||
type Providers struct {
|
||||
Config types.OAuthConfig
|
||||
Github *oauth.OAuth
|
||||
Google *oauth.OAuth
|
||||
Generic *oauth.OAuth
|
||||
}
|
||||
|
||||
func NewProviders(config types.OAuthConfig) *Providers {
|
||||
providers := &Providers{
|
||||
Config: config,
|
||||
}
|
||||
|
||||
if config.GithubClientId != "" && config.GithubClientSecret != "" {
|
||||
log.Info().Msg("Initializing Github OAuth")
|
||||
providers.Github = oauth.NewOAuth(oauth2.Config{
|
||||
ClientID: config.GithubClientId,
|
||||
ClientSecret: config.GithubClientSecret,
|
||||
RedirectURL: fmt.Sprintf("%s/api/oauth/callback/github", config.AppURL),
|
||||
Scopes: GithubScopes(),
|
||||
Endpoint: endpoints.GitHub,
|
||||
}, false)
|
||||
}
|
||||
|
||||
if config.GoogleClientId != "" && config.GoogleClientSecret != "" {
|
||||
log.Info().Msg("Initializing Google OAuth")
|
||||
providers.Google = oauth.NewOAuth(oauth2.Config{
|
||||
ClientID: config.GoogleClientId,
|
||||
ClientSecret: config.GoogleClientSecret,
|
||||
RedirectURL: fmt.Sprintf("%s/api/oauth/callback/google", config.AppURL),
|
||||
Scopes: GoogleScopes(),
|
||||
Endpoint: endpoints.Google,
|
||||
}, false)
|
||||
}
|
||||
|
||||
if config.GenericClientId != "" && config.GenericClientSecret != "" {
|
||||
log.Info().Msg("Initializing Generic OAuth")
|
||||
providers.Generic = oauth.NewOAuth(oauth2.Config{
|
||||
ClientID: config.GenericClientId,
|
||||
ClientSecret: config.GenericClientSecret,
|
||||
RedirectURL: fmt.Sprintf("%s/api/oauth/callback/generic", config.AppURL),
|
||||
Scopes: config.GenericScopes,
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: config.GenericAuthURL,
|
||||
TokenURL: config.GenericTokenURL,
|
||||
},
|
||||
}, config.GenericSkipSSL)
|
||||
}
|
||||
|
||||
return providers
|
||||
}
|
||||
|
||||
func (providers *Providers) GetProvider(provider string) *oauth.OAuth {
|
||||
switch provider {
|
||||
case "github":
|
||||
return providers.Github
|
||||
case "google":
|
||||
return providers.Google
|
||||
case "generic":
|
||||
return providers.Generic
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (providers *Providers) GetUser(provider string) (constants.Claims, error) {
|
||||
var user constants.Claims
|
||||
|
||||
// Get the user from the provider
|
||||
switch provider {
|
||||
case "github":
|
||||
if providers.Github == nil {
|
||||
log.Debug().Msg("Github provider not configured")
|
||||
return user, nil
|
||||
}
|
||||
|
||||
client := providers.Github.GetClient()
|
||||
|
||||
log.Debug().Msg("Got client from github")
|
||||
|
||||
user, err := GetGithubUser(client)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
|
||||
log.Debug().Msg("Got user from github")
|
||||
|
||||
return user, nil
|
||||
case "google":
|
||||
if providers.Google == nil {
|
||||
log.Debug().Msg("Google provider not configured")
|
||||
return user, nil
|
||||
}
|
||||
|
||||
client := providers.Google.GetClient()
|
||||
|
||||
log.Debug().Msg("Got client from google")
|
||||
|
||||
user, err := GetGoogleUser(client)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
|
||||
log.Debug().Msg("Got user from google")
|
||||
|
||||
return user, nil
|
||||
case "generic":
|
||||
if providers.Generic == nil {
|
||||
log.Debug().Msg("Generic provider not configured")
|
||||
return user, nil
|
||||
}
|
||||
|
||||
client := providers.Generic.GetClient()
|
||||
|
||||
log.Debug().Msg("Got client from generic")
|
||||
|
||||
user, err := GetGenericUser(client, providers.Config.GenericUserURL)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
|
||||
log.Debug().Msg("Got user from generic")
|
||||
|
||||
return user, nil
|
||||
default:
|
||||
return user, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (provider *Providers) GetConfiguredProviders() []string {
|
||||
providers := []string{}
|
||||
if provider.Github != nil {
|
||||
providers = append(providers, "github")
|
||||
}
|
||||
if provider.Google != nil {
|
||||
providers = append(providers, "google")
|
||||
}
|
||||
if provider.Generic != nil {
|
||||
providers = append(providers, "generic")
|
||||
}
|
||||
return providers
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package auth
|
||||
package service
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
@@ -6,8 +6,6 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
"tinyauth/internal/docker"
|
||||
"tinyauth/internal/ldap"
|
||||
"tinyauth/internal/types"
|
||||
"tinyauth/internal/utils"
|
||||
|
||||
@@ -17,35 +15,50 @@ import (
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
type Auth struct {
|
||||
Config types.AuthConfig
|
||||
Docker *docker.Docker
|
||||
type AuthServiceConfig struct {
|
||||
Users types.Users
|
||||
OauthWhitelist string
|
||||
SessionExpiry int
|
||||
CookieSecure bool
|
||||
Domain string
|
||||
LoginTimeout int
|
||||
LoginMaxRetries int
|
||||
SessionCookieName string
|
||||
HMACSecret string
|
||||
EncryptionSecret string
|
||||
}
|
||||
|
||||
type AuthService struct {
|
||||
Config AuthServiceConfig
|
||||
Docker *DockerService
|
||||
LoginAttempts map[string]*types.LoginAttempt
|
||||
LoginMutex sync.RWMutex
|
||||
Store *sessions.CookieStore
|
||||
LDAP *ldap.LDAP
|
||||
LDAP *LdapService
|
||||
}
|
||||
|
||||
func NewAuth(config types.AuthConfig, docker *docker.Docker, ldap *ldap.LDAP) *Auth {
|
||||
// Setup cookie store and create the auth service
|
||||
store := sessions.NewCookieStore([]byte(config.HMACSecret), []byte(config.EncryptionSecret))
|
||||
store.Options = &sessions.Options{
|
||||
Path: "/",
|
||||
MaxAge: config.SessionExpiry,
|
||||
Secure: config.CookieSecure,
|
||||
HttpOnly: true,
|
||||
Domain: fmt.Sprintf(".%s", config.Domain),
|
||||
}
|
||||
return &Auth{
|
||||
func NewAuthService(config AuthServiceConfig, docker *DockerService, ldap *LdapService) *AuthService {
|
||||
return &AuthService{
|
||||
Config: config,
|
||||
Docker: docker,
|
||||
LoginAttempts: make(map[string]*types.LoginAttempt),
|
||||
Store: store,
|
||||
LDAP: ldap,
|
||||
}
|
||||
}
|
||||
|
||||
func (auth *Auth) GetSession(c *gin.Context) (*sessions.Session, error) {
|
||||
func (auth *AuthService) Init() error {
|
||||
store := sessions.NewCookieStore([]byte(auth.Config.HMACSecret), []byte(auth.Config.EncryptionSecret))
|
||||
store.Options = &sessions.Options{
|
||||
Path: "/",
|
||||
MaxAge: auth.Config.SessionExpiry,
|
||||
Secure: auth.Config.CookieSecure,
|
||||
HttpOnly: true,
|
||||
Domain: fmt.Sprintf(".%s", auth.Config.Domain),
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (auth *AuthService) GetSession(c *gin.Context) (*sessions.Session, error) {
|
||||
session, err := auth.Store.Get(c.Request, auth.Config.SessionCookieName)
|
||||
|
||||
// If there was an error getting the session, it might be invalid so let's clear it and retry
|
||||
@@ -62,7 +75,7 @@ func (auth *Auth) GetSession(c *gin.Context) (*sessions.Session, error) {
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (auth *Auth) SearchUser(username string) types.UserSearch {
|
||||
func (auth *AuthService) SearchUser(username string) types.UserSearch {
|
||||
log.Debug().Str("username", username).Msg("Searching for user")
|
||||
|
||||
// Check local users first
|
||||
@@ -93,7 +106,7 @@ func (auth *Auth) SearchUser(username string) types.UserSearch {
|
||||
}
|
||||
}
|
||||
|
||||
func (auth *Auth) VerifyUser(search types.UserSearch, password string) bool {
|
||||
func (auth *AuthService) VerifyUser(search types.UserSearch, password string) bool {
|
||||
// Authenticate the user based on the type
|
||||
switch search.Type {
|
||||
case "local":
|
||||
@@ -131,7 +144,7 @@ func (auth *Auth) VerifyUser(search types.UserSearch, password string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (auth *Auth) GetLocalUser(username string) types.User {
|
||||
func (auth *AuthService) GetLocalUser(username string) types.User {
|
||||
// Loop through users and return the user if the username matches
|
||||
log.Debug().Str("username", username).Msg("Searching for local user")
|
||||
|
||||
@@ -146,11 +159,11 @@ func (auth *Auth) GetLocalUser(username string) types.User {
|
||||
return types.User{}
|
||||
}
|
||||
|
||||
func (auth *Auth) CheckPassword(user types.User, password string) bool {
|
||||
func (auth *AuthService) CheckPassword(user types.User, password string) bool {
|
||||
return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) == nil
|
||||
}
|
||||
|
||||
func (auth *Auth) IsAccountLocked(identifier string) (bool, int) {
|
||||
func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
|
||||
auth.LoginMutex.RLock()
|
||||
defer auth.LoginMutex.RUnlock()
|
||||
|
||||
@@ -176,7 +189,7 @@ func (auth *Auth) IsAccountLocked(identifier string) (bool, int) {
|
||||
return false, 0
|
||||
}
|
||||
|
||||
func (auth *Auth) RecordLoginAttempt(identifier string, success bool) {
|
||||
func (auth *AuthService) RecordLoginAttempt(identifier string, success bool) {
|
||||
// Skip if rate limiting is not configured
|
||||
if auth.Config.LoginMaxRetries <= 0 || auth.Config.LoginTimeout <= 0 {
|
||||
return
|
||||
@@ -212,11 +225,11 @@ func (auth *Auth) RecordLoginAttempt(identifier string, success bool) {
|
||||
}
|
||||
}
|
||||
|
||||
func (auth *Auth) EmailWhitelisted(email string) bool {
|
||||
func (auth *AuthService) EmailWhitelisted(email string) bool {
|
||||
return utils.CheckFilter(auth.Config.OauthWhitelist, email)
|
||||
}
|
||||
|
||||
func (auth *Auth) CreateSessionCookie(c *gin.Context, data *types.SessionCookie) error {
|
||||
func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *types.SessionCookie) error {
|
||||
log.Debug().Msg("Creating session cookie")
|
||||
|
||||
session, err := auth.GetSession(c)
|
||||
@@ -252,7 +265,7 @@ func (auth *Auth) CreateSessionCookie(c *gin.Context, data *types.SessionCookie)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (auth *Auth) DeleteSessionCookie(c *gin.Context) error {
|
||||
func (auth *AuthService) DeleteSessionCookie(c *gin.Context) error {
|
||||
log.Debug().Msg("Deleting session cookie")
|
||||
|
||||
session, err := auth.GetSession(c)
|
||||
@@ -275,7 +288,7 @@ func (auth *Auth) DeleteSessionCookie(c *gin.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (auth *Auth) GetSessionCookie(c *gin.Context) (types.SessionCookie, error) {
|
||||
func (auth *AuthService) GetSessionCookie(c *gin.Context) (types.SessionCookie, error) {
|
||||
log.Debug().Msg("Getting session cookie")
|
||||
|
||||
session, err := auth.GetSession(c)
|
||||
@@ -319,12 +332,12 @@ func (auth *Auth) GetSessionCookie(c *gin.Context) (types.SessionCookie, error)
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (auth *Auth) UserAuthConfigured() bool {
|
||||
func (auth *AuthService) UserAuthConfigured() bool {
|
||||
// If there are users or LDAP is configured, return true
|
||||
return len(auth.Config.Users) > 0 || auth.LDAP != nil
|
||||
}
|
||||
|
||||
func (auth *Auth) ResourceAllowed(c *gin.Context, context types.UserContext, labels types.Labels) bool {
|
||||
func (auth *AuthService) ResourceAllowed(c *gin.Context, context types.UserContext, labels types.Labels) bool {
|
||||
if context.OAuth {
|
||||
log.Debug().Msg("Checking OAuth whitelist")
|
||||
return utils.CheckFilter(labels.OAuth.Whitelist, context.Email)
|
||||
@@ -334,7 +347,7 @@ func (auth *Auth) ResourceAllowed(c *gin.Context, context types.UserContext, lab
|
||||
return utils.CheckFilter(labels.Users, context.Username)
|
||||
}
|
||||
|
||||
func (auth *Auth) OAuthGroup(c *gin.Context, context types.UserContext, labels types.Labels) bool {
|
||||
func (auth *AuthService) OAuthGroup(c *gin.Context, context types.UserContext, labels types.Labels) bool {
|
||||
if labels.OAuth.Groups == "" {
|
||||
return true
|
||||
}
|
||||
@@ -361,7 +374,7 @@ func (auth *Auth) OAuthGroup(c *gin.Context, context types.UserContext, labels t
|
||||
return false
|
||||
}
|
||||
|
||||
func (auth *Auth) AuthEnabled(uri string, labels types.Labels) (bool, error) {
|
||||
func (auth *AuthService) AuthEnabled(uri string, labels types.Labels) (bool, error) {
|
||||
// If the label is empty, auth is enabled
|
||||
if labels.Allowed == "" {
|
||||
return true, nil
|
||||
@@ -385,7 +398,7 @@ func (auth *Auth) AuthEnabled(uri string, labels types.Labels) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (auth *Auth) GetBasicAuth(c *gin.Context) *types.User {
|
||||
func (auth *AuthService) GetBasicAuth(c *gin.Context) *types.User {
|
||||
username, password, ok := c.Request.BasicAuth()
|
||||
if !ok {
|
||||
return nil
|
||||
@@ -396,7 +409,7 @@ func (auth *Auth) GetBasicAuth(c *gin.Context) *types.User {
|
||||
}
|
||||
}
|
||||
|
||||
func (auth *Auth) CheckIP(labels types.Labels, ip string) bool {
|
||||
func (auth *AuthService) CheckIP(labels types.Labels, ip string) bool {
|
||||
// Check if the IP is in block list
|
||||
for _, blocked := range labels.IP.Block {
|
||||
res, err := utils.FilterIP(blocked, ip)
|
||||
@@ -433,7 +446,7 @@ func (auth *Auth) CheckIP(labels types.Labels, ip string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (auth *Auth) BypassedIP(labels types.Labels, ip string) bool {
|
||||
func (auth *AuthService) BypassedIP(labels types.Labels, ip string) bool {
|
||||
// For every IP in the bypass list, check if the IP matches
|
||||
for _, bypassed := range labels.IP.Bypass {
|
||||
res, err := utils.FilterIP(bypassed, ip)
|
||||
@@ -1,9 +1,9 @@
|
||||
package docker
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"tinyauth/internal/types"
|
||||
"tinyauth/internal/config"
|
||||
"tinyauth/internal/utils"
|
||||
|
||||
container "github.com/docker/docker/api/types/container"
|
||||
@@ -11,27 +11,27 @@ import (
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
type Docker struct {
|
||||
type DockerService struct {
|
||||
Client *client.Client
|
||||
Context context.Context
|
||||
}
|
||||
|
||||
func NewDocker() (*Docker, error) {
|
||||
func NewDockerService() *DockerService {
|
||||
return &DockerService{}
|
||||
}
|
||||
|
||||
func (docker *DockerService) Init() error {
|
||||
client, err := client.NewClientWithOpts(client.FromEnv)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
client.NegotiateAPIVersion(ctx)
|
||||
|
||||
return &Docker{
|
||||
Client: client,
|
||||
Context: ctx,
|
||||
}, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (docker *Docker) GetContainers() ([]container.Summary, error) {
|
||||
func (docker *DockerService) GetContainers() ([]container.Summary, error) {
|
||||
containers, err := docker.Client.ContainerList(docker.Context, container.ListOptions{})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -39,7 +39,7 @@ func (docker *Docker) GetContainers() ([]container.Summary, error) {
|
||||
return containers, nil
|
||||
}
|
||||
|
||||
func (docker *Docker) InspectContainer(containerId string) (container.InspectResponse, error) {
|
||||
func (docker *DockerService) InspectContainer(containerId string) (container.InspectResponse, error) {
|
||||
inspect, err := docker.Client.ContainerInspect(docker.Context, containerId)
|
||||
if err != nil {
|
||||
return container.InspectResponse{}, err
|
||||
@@ -47,17 +47,17 @@ func (docker *Docker) InspectContainer(containerId string) (container.InspectRes
|
||||
return inspect, nil
|
||||
}
|
||||
|
||||
func (docker *Docker) DockerConnected() bool {
|
||||
func (docker *DockerService) DockerConnected() bool {
|
||||
_, err := docker.Client.Ping(docker.Context)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func (docker *Docker) GetLabels(app string, domain string) (types.Labels, error) {
|
||||
func (docker *DockerService) GetLabels(app string, domain string) (config.Labels, error) {
|
||||
isConnected := docker.DockerConnected()
|
||||
|
||||
if !isConnected {
|
||||
log.Debug().Msg("Docker not connected, returning empty labels")
|
||||
return types.Labels{}, nil
|
||||
return config.Labels{}, nil
|
||||
}
|
||||
|
||||
log.Debug().Msg("Getting containers")
|
||||
@@ -65,7 +65,7 @@ func (docker *Docker) GetLabels(app string, domain string) (types.Labels, error)
|
||||
containers, err := docker.GetContainers()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Error getting containers")
|
||||
return types.Labels{}, err
|
||||
return config.Labels{}, err
|
||||
}
|
||||
|
||||
for _, container := range containers {
|
||||
@@ -98,5 +98,5 @@ func (docker *Docker) GetLabels(app string, domain string) (types.Labels, error)
|
||||
}
|
||||
|
||||
log.Debug().Msg("No matching container found, returning empty labels")
|
||||
return types.Labels{}, nil
|
||||
return config.Labels{}, nil
|
||||
}
|
||||
114
internal/service/generic_oauth_service.go
Normal file
114
internal/service/generic_oauth_service.go
Normal file
@@ -0,0 +1,114 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"tinyauth/internal/config"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
type GenericOAuthService struct {
|
||||
Config oauth2.Config
|
||||
Context context.Context
|
||||
Token *oauth2.Token
|
||||
Verifier string
|
||||
InsecureSkipVerify bool
|
||||
ServiceName string
|
||||
UserinfoURL string
|
||||
}
|
||||
|
||||
func NewGenericOAuthService(config config.OAuthServiceConfig) *GenericOAuthService {
|
||||
return &GenericOAuthService{
|
||||
Config: oauth2.Config{
|
||||
ClientID: config.ClientID,
|
||||
ClientSecret: config.ClientSecret,
|
||||
RedirectURL: config.RedirectURL,
|
||||
Scopes: config.Scopes,
|
||||
Endpoint: oauth2.Endpoint{
|
||||
AuthURL: config.AuthURL,
|
||||
TokenURL: config.TokenURL,
|
||||
},
|
||||
},
|
||||
InsecureSkipVerify: config.InsecureSkipVerify,
|
||||
ServiceName: config.Name,
|
||||
UserinfoURL: config.UserinfoURL,
|
||||
}
|
||||
}
|
||||
|
||||
func (generic *GenericOAuthService) Init() error {
|
||||
transport := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: generic.InsecureSkipVerify,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
},
|
||||
}
|
||||
|
||||
httpClient := &http.Client{
|
||||
Transport: transport,
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
|
||||
verifier := oauth2.GenerateVerifier()
|
||||
|
||||
generic.Context = ctx
|
||||
generic.Verifier = verifier
|
||||
return nil
|
||||
}
|
||||
|
||||
func (generic *GenericOAuthService) Name() string {
|
||||
return generic.ServiceName
|
||||
}
|
||||
|
||||
func (generic *GenericOAuthService) GenerateState() string {
|
||||
b := make([]byte, 128)
|
||||
rand.Read(b)
|
||||
state := base64.URLEncoding.EncodeToString(b)
|
||||
return state
|
||||
}
|
||||
|
||||
func (generic *GenericOAuthService) GetAuthURL(state string) string {
|
||||
return generic.Config.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(generic.Verifier))
|
||||
}
|
||||
|
||||
func (generic *GenericOAuthService) VerifyCode(code string) error {
|
||||
token, err := generic.Config.Exchange(generic.Context, code, oauth2.VerifierOption(generic.Verifier))
|
||||
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
generic.Token = token
|
||||
return nil
|
||||
}
|
||||
|
||||
func (generic *GenericOAuthService) Userinfo() (config.Claims, error) {
|
||||
var user config.Claims
|
||||
|
||||
client := generic.Config.Client(generic.Context, generic.Token)
|
||||
|
||||
res, err := client.Get(generic.UserinfoURL)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
|
||||
err = json.Unmarshal(body, &user)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
144
internal/service/github_oauth_service.go
Normal file
144
internal/service/github_oauth_service.go
Normal file
@@ -0,0 +1,144 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"tinyauth/internal/config"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
var GithubOAuthScopes = []string{"user:email", "read:user"}
|
||||
|
||||
type GithubEmailResponse []struct {
|
||||
Email string `json:"email"`
|
||||
Primary bool `json:"primary"`
|
||||
}
|
||||
|
||||
type GithubUserInfoResponse struct {
|
||||
Login string `json:"login"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
type GithubOAuthService struct {
|
||||
Config oauth2.Config
|
||||
Context context.Context
|
||||
Token *oauth2.Token
|
||||
Verifier string
|
||||
}
|
||||
|
||||
func NewGithubOAuthService(config config.OAuthServiceConfig) *GithubOAuthService {
|
||||
return &GithubOAuthService{
|
||||
Config: oauth2.Config{
|
||||
ClientID: config.ClientID,
|
||||
ClientSecret: config.ClientSecret,
|
||||
RedirectURL: config.RedirectURL,
|
||||
Scopes: GithubOAuthScopes,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (github *GithubOAuthService) Init() error {
|
||||
httpClient := &http.Client{}
|
||||
ctx := context.Background()
|
||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
|
||||
verifier := oauth2.GenerateVerifier()
|
||||
|
||||
github.Context = ctx
|
||||
github.Verifier = verifier
|
||||
return nil
|
||||
}
|
||||
|
||||
func (github *GithubOAuthService) Name() string {
|
||||
return "github"
|
||||
}
|
||||
|
||||
func (github *GithubOAuthService) GenerateState() string {
|
||||
b := make([]byte, 128)
|
||||
rand.Read(b)
|
||||
state := base64.URLEncoding.EncodeToString(b)
|
||||
return state
|
||||
}
|
||||
|
||||
func (github *GithubOAuthService) GetAuthURL(state string) string {
|
||||
return github.Config.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(github.Verifier))
|
||||
}
|
||||
|
||||
func (github *GithubOAuthService) VerifyCode(code string) error {
|
||||
token, err := github.Config.Exchange(github.Context, code, oauth2.VerifierOption(github.Verifier))
|
||||
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
github.Token = token
|
||||
return nil
|
||||
}
|
||||
|
||||
func (github *GithubOAuthService) Userinfo() (config.Claims, error) {
|
||||
var user config.Claims
|
||||
|
||||
client := github.Config.Client(github.Context, github.Token)
|
||||
|
||||
res, err := client.Get("https://api.github.com/user")
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
|
||||
var userInfo GithubUserInfoResponse
|
||||
|
||||
err = json.Unmarshal(body, &userInfo)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
|
||||
res, err = client.Get("https://api.github.com/user/emails")
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
body, err = io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
|
||||
var emails GithubEmailResponse
|
||||
|
||||
err = json.Unmarshal(body, &emails)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
|
||||
for _, email := range emails {
|
||||
if email.Primary {
|
||||
user.Email = email.Email
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if len(emails) == 0 {
|
||||
return user, errors.New("no emails found")
|
||||
}
|
||||
|
||||
// Use first available email if no primary email was found
|
||||
if user.Email == "" {
|
||||
user.Email = emails[0].Email
|
||||
}
|
||||
|
||||
user.PreferredUsername = userInfo.Login
|
||||
user.Name = userInfo.Name
|
||||
|
||||
return user, nil
|
||||
}
|
||||
106
internal/service/google_oauth_service.go
Normal file
106
internal/service/google_oauth_service.go
Normal file
@@ -0,0 +1,106 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"tinyauth/internal/config"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
var GoogleOAuthScopes = []string{"https://www.googleapis.com/auth/userinfo.email", "https://www.googleapis.com/auth/userinfo.profile"}
|
||||
|
||||
type GoogleUserInfoResponse struct {
|
||||
Email string `json:"email"`
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
type GoogleOAuthService struct {
|
||||
Config oauth2.Config
|
||||
Context context.Context
|
||||
Token *oauth2.Token
|
||||
Verifier string
|
||||
}
|
||||
|
||||
func NewGoogleOAuthService(config config.OAuthServiceConfig) *GoogleOAuthService {
|
||||
return &GoogleOAuthService{
|
||||
Config: oauth2.Config{
|
||||
ClientID: config.ClientID,
|
||||
ClientSecret: config.ClientSecret,
|
||||
RedirectURL: config.RedirectURL,
|
||||
Scopes: GoogleOAuthScopes,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (google *GoogleOAuthService) Init() error {
|
||||
httpClient := &http.Client{}
|
||||
ctx := context.Background()
|
||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
|
||||
verifier := oauth2.GenerateVerifier()
|
||||
|
||||
google.Context = ctx
|
||||
google.Verifier = verifier
|
||||
return nil
|
||||
}
|
||||
|
||||
func (google *GoogleOAuthService) Name() string {
|
||||
return "google"
|
||||
}
|
||||
|
||||
func (oauth *GoogleOAuthService) GenerateState() string {
|
||||
b := make([]byte, 128)
|
||||
rand.Read(b)
|
||||
state := base64.URLEncoding.EncodeToString(b)
|
||||
return state
|
||||
}
|
||||
|
||||
func (google *GoogleOAuthService) GetAuthURL(state string) string {
|
||||
return google.Config.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(google.Verifier))
|
||||
}
|
||||
|
||||
func (google *GoogleOAuthService) VerifyCode(code string) error {
|
||||
token, err := google.Config.Exchange(google.Context, code, oauth2.VerifierOption(google.Verifier))
|
||||
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
google.Token = token
|
||||
return nil
|
||||
}
|
||||
|
||||
func (google *GoogleOAuthService) Userinfo() (config.Claims, error) {
|
||||
var user config.Claims
|
||||
|
||||
client := google.Config.Client(google.Context, google.Token)
|
||||
|
||||
res, err := client.Get("https://www.googleapis.com/userinfo/v2/me")
|
||||
if err != nil {
|
||||
return config.Claims{}, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
return config.Claims{}, err
|
||||
}
|
||||
|
||||
var userInfo GoogleUserInfoResponse
|
||||
|
||||
err = json.Unmarshal(body, &userInfo)
|
||||
if err != nil {
|
||||
return config.Claims{}, err
|
||||
}
|
||||
|
||||
user.PreferredUsername = strings.Split(userInfo.Email, "@")[0]
|
||||
user.Name = userInfo.Name
|
||||
user.Email = userInfo.Email
|
||||
|
||||
return user, nil
|
||||
}
|
||||
@@ -1,30 +1,40 @@
|
||||
package ldap
|
||||
package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"time"
|
||||
"tinyauth/internal/types"
|
||||
|
||||
"github.com/cenkalti/backoff/v5"
|
||||
ldapgo "github.com/go-ldap/ldap/v3"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
type LDAP struct {
|
||||
Config types.LdapConfig
|
||||
type LdapServiceConfig struct {
|
||||
Address string
|
||||
BindDN string
|
||||
BindPassword string
|
||||
BaseDN string
|
||||
Insecure bool
|
||||
SearchFilter string
|
||||
}
|
||||
|
||||
type LdapService struct {
|
||||
Config LdapServiceConfig
|
||||
Conn *ldapgo.Conn
|
||||
}
|
||||
|
||||
func NewLDAP(config types.LdapConfig) (*LDAP, error) {
|
||||
ldap := &LDAP{
|
||||
func NewLdapService(config LdapServiceConfig) *LdapService {
|
||||
return &LdapService{
|
||||
Config: config,
|
||||
}
|
||||
}
|
||||
|
||||
func (ldap *LdapService) Init() error {
|
||||
_, err := ldap.connect()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to connect to LDAP server: %w", err)
|
||||
return fmt.Errorf("failed to connect to LDAP server: %w", err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
@@ -41,13 +51,13 @@ func NewLDAP(config types.LdapConfig) (*LDAP, error) {
|
||||
}
|
||||
}()
|
||||
|
||||
return ldap, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *LDAP) connect() (*ldapgo.Conn, error) {
|
||||
func (ldap *LdapService) connect() (*ldapgo.Conn, error) {
|
||||
log.Debug().Msg("Connecting to LDAP server")
|
||||
conn, err := ldapgo.DialURL(l.Config.Address, ldapgo.DialWithTLSConfig(&tls.Config{
|
||||
InsecureSkipVerify: l.Config.Insecure,
|
||||
conn, err := ldapgo.DialURL(ldap.Config.Address, ldapgo.DialWithTLSConfig(&tls.Config{
|
||||
InsecureSkipVerify: ldap.Config.Insecure,
|
||||
MinVersion: tls.VersionTLS12,
|
||||
}))
|
||||
if err != nil {
|
||||
@@ -55,30 +65,30 @@ func (l *LDAP) connect() (*ldapgo.Conn, error) {
|
||||
}
|
||||
|
||||
log.Debug().Msg("Binding to LDAP server")
|
||||
err = conn.Bind(l.Config.BindDN, l.Config.BindPassword)
|
||||
err = conn.Bind(ldap.Config.BindDN, ldap.Config.BindPassword)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Set and return the connection
|
||||
l.Conn = conn
|
||||
ldap.Conn = conn
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func (l *LDAP) Search(username string) (string, error) {
|
||||
func (ldap *LdapService) Search(username string) (string, error) {
|
||||
// Escape the username to prevent LDAP injection
|
||||
escapedUsername := ldapgo.EscapeFilter(username)
|
||||
filter := fmt.Sprintf(l.Config.SearchFilter, escapedUsername)
|
||||
filter := fmt.Sprintf(ldap.Config.SearchFilter, escapedUsername)
|
||||
|
||||
searchRequest := ldapgo.NewSearchRequest(
|
||||
l.Config.BaseDN,
|
||||
ldap.Config.BaseDN,
|
||||
ldapgo.ScopeWholeSubtree, ldapgo.NeverDerefAliases, 0, 0, false,
|
||||
filter,
|
||||
[]string{"dn"},
|
||||
nil,
|
||||
)
|
||||
|
||||
searchResult, err := l.Conn.Search(searchRequest)
|
||||
searchResult, err := ldap.Conn.Search(searchRequest)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -91,15 +101,15 @@ func (l *LDAP) Search(username string) (string, error) {
|
||||
return userDN, nil
|
||||
}
|
||||
|
||||
func (l *LDAP) Bind(userDN string, password string) error {
|
||||
err := l.Conn.Bind(userDN, password)
|
||||
func (ldap *LdapService) Bind(userDN string, password string) error {
|
||||
err := ldap.Conn.Bind(userDN, password)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *LDAP) heartbeat() error {
|
||||
func (ldap *LdapService) heartbeat() error {
|
||||
log.Debug().Msg("Performing LDAP connection heartbeat")
|
||||
|
||||
searchRequest := ldapgo.NewSearchRequest(
|
||||
@@ -110,7 +120,7 @@ func (l *LDAP) heartbeat() error {
|
||||
nil,
|
||||
)
|
||||
|
||||
_, err := l.Conn.Search(searchRequest)
|
||||
_, err := ldap.Conn.Search(searchRequest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -119,7 +129,7 @@ func (l *LDAP) heartbeat() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *LDAP) reconnect() error {
|
||||
func (ldap *LdapService) reconnect() error {
|
||||
log.Info().Msg("Reconnecting to LDAP server")
|
||||
|
||||
exp := backoff.NewExponentialBackOff()
|
||||
@@ -129,8 +139,8 @@ func (l *LDAP) reconnect() error {
|
||||
exp.Reset()
|
||||
|
||||
operation := func() (*ldapgo.Conn, error) {
|
||||
l.Conn.Close()
|
||||
conn, err := l.connect()
|
||||
ldap.Conn.Close()
|
||||
conn, err := ldap.connect()
|
||||
if err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
@@ -1,70 +0,0 @@
|
||||
package types
|
||||
|
||||
import (
|
||||
"time"
|
||||
"tinyauth/internal/oauth"
|
||||
)
|
||||
|
||||
// User is the struct for a user
|
||||
type User struct {
|
||||
Username string
|
||||
Password string
|
||||
TotpSecret string
|
||||
}
|
||||
|
||||
// UserSearch is the response of the get user
|
||||
type UserSearch struct {
|
||||
Username string
|
||||
Type string // "local", "ldap" or empty
|
||||
}
|
||||
|
||||
// Users is a list of users
|
||||
type Users []User
|
||||
|
||||
// OAuthProviders is the struct for the OAuth providers
|
||||
type OAuthProviders struct {
|
||||
Github *oauth.OAuth
|
||||
Google *oauth.OAuth
|
||||
Microsoft *oauth.OAuth
|
||||
}
|
||||
|
||||
// SessionCookie is the cookie for the session (exculding the expiry)
|
||||
type SessionCookie struct {
|
||||
Username string
|
||||
Name string
|
||||
Email string
|
||||
Provider string
|
||||
TotpPending bool
|
||||
OAuthGroups string
|
||||
}
|
||||
|
||||
// UserContext is the context for the user
|
||||
type UserContext struct {
|
||||
Username string
|
||||
Name string
|
||||
Email string
|
||||
IsLoggedIn bool
|
||||
OAuth bool
|
||||
Provider string
|
||||
TotpPending bool
|
||||
OAuthGroups string
|
||||
TotpEnabled bool
|
||||
}
|
||||
|
||||
// LoginAttempt tracks information about login attempts for rate limiting
|
||||
type LoginAttempt struct {
|
||||
FailedAttempts int
|
||||
LastAttempt time.Time
|
||||
LockedUntil time.Time
|
||||
}
|
||||
|
||||
type UnauthorizedQuery struct {
|
||||
Username string `url:"username"`
|
||||
Resource string `url:"resource"`
|
||||
GroupErr bool `url:"groupErr"`
|
||||
IP string `url:"ip"`
|
||||
}
|
||||
|
||||
type RedirectQuery struct {
|
||||
RedirectURI string `url:"redirect_uri"`
|
||||
}
|
||||
Reference in New Issue
Block a user