Compare commits

..

2 Commits

Author SHA1 Message Date
Stavros
07933ad71c tests: fix tests 2025-05-22 22:30:54 +03:00
Stavros
7151832cc7 feat: generate a unique id for the cookie names based on the domain 2025-05-22 22:21:29 +03:00
13 changed files with 93 additions and 70 deletions

View File

@@ -30,4 +30,3 @@ APP_TITLE=Tinyauth SSO
FORGOT_PASSWORD_MESSAGE=Some message about resetting the password FORGOT_PASSWORD_MESSAGE=Some message about resetting the password
OAUTH_AUTO_REDIRECT=none OAUTH_AUTO_REDIRECT=none
BACKGROUND_IMAGE=some_image_url BACKGROUND_IMAGE=some_image_url
GENERIC_SKIP_SSL=false

View File

@@ -52,4 +52,7 @@ COPY --from=builder /tinyauth/tinyauth ./
EXPOSE 3000 EXPOSE 3000
HEALTHCHECK --interval=10s --timeout=5s \
CMD curl -f http://localhost:3000/api/healthcheck || exit 1
ENTRYPOINT ["./tinyauth"] ENTRYPOINT ["./tinyauth"]

View File

@@ -2,6 +2,7 @@ package cmd
import ( import (
"errors" "errors"
"fmt"
"os" "os"
"strings" "strings"
"time" "time"
@@ -67,6 +68,12 @@ var rootCmd = &cobra.Command{
HandleError(err, "Failed to get upper domain") HandleError(err, "Failed to get upper domain")
log.Info().Str("domain", domain).Msg("Using domain for cookie store") log.Info().Str("domain", domain).Msg("Using domain for cookie store")
// Generate cookie name
cookieId := utils.GenerateIdentifier(strings.Split(domain, ".")[0])
sessionCookieName := fmt.Sprintf("%s-%s", constants.SessionCookieName, cookieId)
csrfCookieName := fmt.Sprintf("%s-%s", constants.CsrfCookieName, cookieId)
redirectCookieName := fmt.Sprintf("%s-%s", constants.RedirectCookieName, cookieId)
// Create OAuth config // Create OAuth config
oauthConfig := types.OAuthConfig{ oauthConfig := types.OAuthConfig{
GithubClientId: config.GithubClientId, GithubClientId: config.GithubClientId,
@@ -79,7 +86,6 @@ var rootCmd = &cobra.Command{
GenericAuthURL: config.GenericAuthURL, GenericAuthURL: config.GenericAuthURL,
GenericTokenURL: config.GenericTokenURL, GenericTokenURL: config.GenericTokenURL,
GenericUserURL: config.GenericUserURL, GenericUserURL: config.GenericUserURL,
GenericSkipSSL: config.GenericSkipSSL,
AppURL: config.AppURL, AppURL: config.AppURL,
} }
@@ -94,6 +100,8 @@ var rootCmd = &cobra.Command{
ForgotPasswordMessage: config.FogotPasswordMessage, ForgotPasswordMessage: config.FogotPasswordMessage,
BackgroundImage: config.BackgroundImage, BackgroundImage: config.BackgroundImage,
OAuthAutoRedirect: config.OAuthAutoRedirect, OAuthAutoRedirect: config.OAuthAutoRedirect,
CsrfCookieName: csrfCookieName,
RedirectCookieName: redirectCookieName,
} }
// Create api config // Create api config
@@ -104,14 +112,15 @@ var rootCmd = &cobra.Command{
// Create auth config // Create auth config
authConfig := types.AuthConfig{ authConfig := types.AuthConfig{
Users: users, Users: users,
OauthWhitelist: config.OAuthWhitelist, OauthWhitelist: config.OAuthWhitelist,
Secret: config.Secret, Secret: config.Secret,
CookieSecure: config.CookieSecure, CookieSecure: config.CookieSecure,
SessionExpiry: config.SessionExpiry, SessionExpiry: config.SessionExpiry,
Domain: domain, Domain: domain,
LoginTimeout: config.LoginTimeout, LoginTimeout: config.LoginTimeout,
LoginMaxRetries: config.LoginMaxRetries, LoginMaxRetries: config.LoginMaxRetries,
SessionCookieName: sessionCookieName,
} }
// Create hooks config // Create hooks config
@@ -198,7 +207,6 @@ func init() {
rootCmd.Flags().String("generic-token-url", "", "Generic OAuth token URL.") rootCmd.Flags().String("generic-token-url", "", "Generic OAuth token URL.")
rootCmd.Flags().String("generic-user-url", "", "Generic OAuth user info URL.") rootCmd.Flags().String("generic-user-url", "", "Generic OAuth user info URL.")
rootCmd.Flags().String("generic-name", "Generic", "Generic OAuth provider name.") rootCmd.Flags().String("generic-name", "Generic", "Generic OAuth provider name.")
rootCmd.Flags().Bool("generic-skip-ssl", false, "Skip SSL verification for the generic OAuth provider.")
rootCmd.Flags().Bool("disable-continue", false, "Disable continue screen and redirect to app directly.") rootCmd.Flags().Bool("disable-continue", false, "Disable continue screen and redirect to app directly.")
rootCmd.Flags().String("oauth-whitelist", "", "Comma separated list of email addresses to whitelist when using OAuth.") rootCmd.Flags().String("oauth-whitelist", "", "Comma separated list of email addresses to whitelist when using OAuth.")
rootCmd.Flags().String("oauth-auto-redirect", "none", "Auto redirect to the specified OAuth provider if configured. (available providers: github, google, generic)") rootCmd.Flags().String("oauth-auto-redirect", "none", "Auto redirect to the specified OAuth provider if configured. (available providers: github, google, generic)")
@@ -233,7 +241,6 @@ func init() {
viper.BindEnv("generic-token-url", "GENERIC_TOKEN_URL") viper.BindEnv("generic-token-url", "GENERIC_TOKEN_URL")
viper.BindEnv("generic-user-url", "GENERIC_USER_URL") viper.BindEnv("generic-user-url", "GENERIC_USER_URL")
viper.BindEnv("generic-name", "GENERIC_NAME") viper.BindEnv("generic-name", "GENERIC_NAME")
viper.BindEnv("generic-skip-ssl", "GENERIC_SKIP_SSL")
viper.BindEnv("disable-continue", "DISABLE_CONTINUE") viper.BindEnv("disable-continue", "DISABLE_CONTINUE")
viper.BindEnv("oauth-whitelist", "OAUTH_WHITELIST") viper.BindEnv("oauth-whitelist", "OAUTH_WHITELIST")
viper.BindEnv("oauth-auto-redirect", "OAUTH_AUTO_REDIRECT") viper.BindEnv("oauth-auto-redirect", "OAUTH_AUTO_REDIRECT")

View File

@@ -29,7 +29,7 @@ export const LoginPage = () => {
return <Navigate to="/logout" />; return <Navigate to="/logout" />;
} }
const { configuredProviders, title, oauthAutoRedirect, genericName } = useAppContext(); const { configuredProviders, title, oauthAutoRedirect } = useAppContext();
const { search } = useLocation(); const { search } = useLocation();
const { t } = useTranslation(); const { t } = useTranslation();
const isMounted = useIsMounted(); const isMounted = useIsMounted();
@@ -138,7 +138,7 @@ export const LoginPage = () => {
)} )}
{configuredProviders.includes("generic") && ( {configuredProviders.includes("generic") && (
<OAuthButton <OAuthButton
title={genericName} title="Generic"
icon={<GenericIcon />} icon={<GenericIcon />}
className="w-full" className="w-full"
onClick={() => oauthMutation.mutate("generic")} onClick={() => oauthMutation.mutate("generic")}

1
go.mod
View File

@@ -6,6 +6,7 @@ require (
github.com/gin-gonic/gin v1.10.0 github.com/gin-gonic/gin v1.10.0
github.com/go-playground/validator/v10 v10.26.0 github.com/go-playground/validator/v10 v10.26.0
github.com/google/go-querystring v1.1.0 github.com/google/go-querystring v1.1.0
github.com/google/uuid v1.6.0
github.com/mdp/qrterminal/v3 v3.2.1 github.com/mdp/qrterminal/v3 v3.2.1
github.com/rs/zerolog v1.34.0 github.com/rs/zerolog v1.34.0
github.com/spf13/cobra v1.9.1 github.com/spf13/cobra v1.9.1

View File

@@ -28,21 +28,29 @@ var apiConfig = types.APIConfig{
// Simple handlers config for tests // Simple handlers config for tests
var handlersConfig = types.HandlersConfig{ var handlersConfig = types.HandlersConfig{
AppURL: "http://localhost:8080", AppURL: "http://localhost:8080",
Domain: "localhost",
DisableContinue: false, DisableContinue: false,
CookieSecure: false,
Title: "Tinyauth", Title: "Tinyauth",
GenericName: "Generic", GenericName: "Generic",
ForgotPasswordMessage: "Some message", ForgotPasswordMessage: "Some message",
CsrfCookieName: "tinyauth-csrf",
RedirectCookieName: "tinyauth-redirect",
BackgroundImage: "https://example.com/image.png",
OAuthAutoRedirect: "none",
} }
// Simple auth config for tests // Simple auth config for tests
var authConfig = types.AuthConfig{ var authConfig = types.AuthConfig{
Users: types.Users{}, Users: types.Users{},
OauthWhitelist: "", OauthWhitelist: "",
Secret: "super-secret-api-thing-for-tests", // It is 32 chars long Secret: "super-secret-api-thing-for-tests", // It is 32 chars long
CookieSecure: false, CookieSecure: false,
SessionExpiry: 3600, SessionExpiry: 3600,
LoginTimeout: 0, LoginTimeout: 0,
LoginMaxRetries: 0, LoginMaxRetries: 0,
SessionCookieName: "tinyauth-session",
Domain: "localhost",
} }
// Simple hooks config for tests // Simple hooks config for tests
@@ -206,6 +214,9 @@ func TestAppContext(t *testing.T) {
Title: "Tinyauth", Title: "Tinyauth",
GenericName: "Generic", GenericName: "Generic",
ForgotPasswordMessage: "Some message", ForgotPasswordMessage: "Some message",
BackgroundImage: "https://example.com/image.png",
OAuthAutoRedirect: "none",
Domain: "localhost",
} }
// We should get the username back // We should get the username back
@@ -234,7 +245,7 @@ func TestUserContext(t *testing.T) {
// Set the cookie // Set the cookie
req.AddCookie(&http.Cookie{ req.AddCookie(&http.Cookie{
Name: "tinyauth", Name: "tinyauth-session",
Value: cookie, Value: cookie,
}) })

View File

@@ -45,7 +45,7 @@ func (auth *Auth) GetSession(c *gin.Context) (*sessions.Session, error) {
} }
// Get session // Get session
session, err := store.Get(c.Request, "tinyauth") session, err := store.Get(c.Request, auth.Config.SessionCookieName)
if err != nil { if err != nil {
log.Error().Err(err).Msg("Failed to get session") log.Error().Err(err).Msg("Failed to get session")
return nil, err return nil, err

View File

@@ -21,3 +21,8 @@ type Claims struct {
var Version = "development" var Version = "development"
var CommitHash = "n/a" var CommitHash = "n/a"
var BuildTimestamp = "n/a" var BuildTimestamp = "n/a"
// Cookie names
var SessionCookieName = "tinyauth-session"
var CsrfCookieName = "tinyauth-csrf"
var RedirectCookieName = "tinyauth-redirect"

View File

@@ -581,7 +581,7 @@ func (h *Handlers) OauthUrlHandler(c *gin.Context) {
log.Debug().Msg("Got auth URL") log.Debug().Msg("Got auth URL")
// Set CSRF cookie // Set CSRF cookie
c.SetCookie("tinyauth-csrf", state, int(time.Hour.Seconds()), "/", "", h.Config.CookieSecure, true) c.SetCookie(h.Config.CsrfCookieName, state, int(time.Hour.Seconds()), "/", "", h.Config.CookieSecure, true)
// Get redirect URI // Get redirect URI
redirectURI := c.Query("redirect_uri") redirectURI := c.Query("redirect_uri")
@@ -589,7 +589,7 @@ func (h *Handlers) OauthUrlHandler(c *gin.Context) {
// Set redirect cookie if redirect URI is provided // Set redirect cookie if redirect URI is provided
if redirectURI != "" { if redirectURI != "" {
log.Debug().Str("redirectURI", redirectURI).Msg("Setting redirect cookie") log.Debug().Str("redirectURI", redirectURI).Msg("Setting redirect cookie")
c.SetCookie("tinyauth-redirect", redirectURI, int(time.Hour.Seconds()), "/", "", h.Config.CookieSecure, true) c.SetCookie(h.Config.RedirectCookieName, redirectURI, int(time.Hour.Seconds()), "/", "", h.Config.CookieSecure, true)
} }
// Return auth URL // Return auth URL
@@ -620,7 +620,7 @@ func (h *Handlers) OauthCallbackHandler(c *gin.Context) {
state := c.Query("state") state := c.Query("state")
// Get CSRF cookie // Get CSRF cookie
csrfCookie, err := c.Cookie("tinyauth-csrf") csrfCookie, err := c.Cookie(h.Config.CsrfCookieName)
if err != nil { if err != nil {
log.Debug().Msg("No CSRF cookie") log.Debug().Msg("No CSRF cookie")
@@ -638,7 +638,7 @@ func (h *Handlers) OauthCallbackHandler(c *gin.Context) {
} }
// Clean up CSRF cookie // Clean up CSRF cookie
c.SetCookie("tinyauth-csrf", "", -1, "/", "", h.Config.CookieSecure, true) c.SetCookie(h.Config.CsrfCookieName, "", -1, "/", "", h.Config.CookieSecure, true)
// Get code // Get code
code := c.Query("code") code := c.Query("code")
@@ -737,7 +737,7 @@ func (h *Handlers) OauthCallbackHandler(c *gin.Context) {
}) })
// Check if we have a redirect URI // Check if we have a redirect URI
redirectCookie, err := c.Cookie("tinyauth-redirect") redirectCookie, err := c.Cookie(h.Config.RedirectCookieName)
if err != nil { if err != nil {
log.Debug().Msg("No redirect cookie") log.Debug().Msg("No redirect cookie")
@@ -762,7 +762,7 @@ func (h *Handlers) OauthCallbackHandler(c *gin.Context) {
} }
// Clean up redirect cookie // Clean up redirect cookie
c.SetCookie("tinyauth-redirect", "", -1, "/", "", h.Config.CookieSecure, true) c.SetCookie(h.Config.RedirectCookieName, "", -1, "/", "", h.Config.CookieSecure, true)
// Redirect to continue with the redirect URI // Redirect to continue with the redirect URI
c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/continue?%s", h.Config.AppURL, queries.Encode())) c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/continue?%s", h.Config.AppURL, queries.Encode()))

View File

@@ -3,48 +3,28 @@ package oauth
import ( import (
"context" "context"
"crypto/rand" "crypto/rand"
"crypto/tls"
"encoding/base64" "encoding/base64"
"net/http" "net/http"
"golang.org/x/oauth2" "golang.org/x/oauth2"
) )
func NewOAuth(config oauth2.Config, insecureSkipVerify bool) *OAuth { func NewOAuth(config oauth2.Config) *OAuth {
return &OAuth{ return &OAuth{
Config: config, Config: config,
InsecureSkipVerify: insecureSkipVerify,
} }
} }
type OAuth struct { type OAuth struct {
Config oauth2.Config Config oauth2.Config
Context context.Context Context context.Context
Token *oauth2.Token Token *oauth2.Token
Verifier string Verifier string
InsecureSkipVerify bool
} }
func (oauth *OAuth) Init() { func (oauth *OAuth) Init() {
// Create transport with TLS // Create a new context and verifier
transport := &http.Transport{
TLSClientConfig: &tls.Config{
InsecureSkipVerify: oauth.InsecureSkipVerify,
MinVersion: tls.VersionTLS12,
},
}
// Create a new context
oauth.Context = context.Background() oauth.Context = context.Background()
// Create the HTTP client with the transport
httpClient := &http.Client{
Transport: transport,
}
// Set the HTTP client in the context
oauth.Context = context.WithValue(oauth.Context, oauth2.HTTPClient, httpClient)
// Create the verifier
oauth.Verifier = oauth2.GenerateVerifier() oauth.Verifier = oauth2.GenerateVerifier()
} }

View File

@@ -36,7 +36,7 @@ func (providers *Providers) Init() {
RedirectURL: fmt.Sprintf("%s/api/oauth/callback/github", providers.Config.AppURL), RedirectURL: fmt.Sprintf("%s/api/oauth/callback/github", providers.Config.AppURL),
Scopes: GithubScopes(), Scopes: GithubScopes(),
Endpoint: endpoints.GitHub, Endpoint: endpoints.GitHub,
}, false) })
// Initialize the oauth provider // Initialize the oauth provider
providers.Github.Init() providers.Github.Init()
@@ -53,7 +53,7 @@ func (providers *Providers) Init() {
RedirectURL: fmt.Sprintf("%s/api/oauth/callback/google", providers.Config.AppURL), RedirectURL: fmt.Sprintf("%s/api/oauth/callback/google", providers.Config.AppURL),
Scopes: GoogleScopes(), Scopes: GoogleScopes(),
Endpoint: endpoints.Google, Endpoint: endpoints.Google,
}, false) })
// Initialize the oauth provider // Initialize the oauth provider
providers.Google.Init() providers.Google.Init()
@@ -73,7 +73,7 @@ func (providers *Providers) Init() {
AuthURL: providers.Config.GenericAuthURL, AuthURL: providers.Config.GenericAuthURL,
TokenURL: providers.Config.GenericTokenURL, TokenURL: providers.Config.GenericTokenURL,
}, },
}, providers.Config.GenericSkipSSL) })
// Initialize the oauth provider // Initialize the oauth provider
providers.Generic.Init() providers.Generic.Init()

View File

@@ -24,7 +24,6 @@ type Config struct {
GenericTokenURL string `mapstructure:"generic-token-url"` GenericTokenURL string `mapstructure:"generic-token-url"`
GenericUserURL string `mapstructure:"generic-user-url"` GenericUserURL string `mapstructure:"generic-user-url"`
GenericName string `mapstructure:"generic-name"` GenericName string `mapstructure:"generic-name"`
GenericSkipSSL bool `mapstructure:"generic-skip-ssl"`
DisableContinue bool `mapstructure:"disable-continue"` DisableContinue bool `mapstructure:"disable-continue"`
OAuthWhitelist string `mapstructure:"oauth-whitelist"` OAuthWhitelist string `mapstructure:"oauth-whitelist"`
OAuthAutoRedirect string `mapstructure:"oauth-auto-redirect" validate:"oneof=none github google generic"` OAuthAutoRedirect string `mapstructure:"oauth-auto-redirect" validate:"oneof=none github google generic"`
@@ -35,7 +34,7 @@ type Config struct {
LoginTimeout int `mapstructure:"login-timeout"` LoginTimeout int `mapstructure:"login-timeout"`
LoginMaxRetries int `mapstructure:"login-max-retries"` LoginMaxRetries int `mapstructure:"login-max-retries"`
FogotPasswordMessage string `mapstructure:"forgot-password-message" validate:"required"` FogotPasswordMessage string `mapstructure:"forgot-password-message" validate:"required"`
BackgroundImage string `mapstructure:"background-image" validate:"required"` BackgroundImage string `mapstructure:"background-image" validate:"required,url"`
} }
// Server configuration // Server configuration
@@ -49,6 +48,8 @@ type HandlersConfig struct {
ForgotPasswordMessage string ForgotPasswordMessage string
BackgroundImage string BackgroundImage string
OAuthAutoRedirect string OAuthAutoRedirect string
CsrfCookieName string
RedirectCookieName string
} }
// OAuthConfig is the configuration for the providers // OAuthConfig is the configuration for the providers
@@ -63,7 +64,6 @@ type OAuthConfig struct {
GenericAuthURL string GenericAuthURL string
GenericTokenURL string GenericTokenURL string
GenericUserURL string GenericUserURL string
GenericSkipSSL bool
AppURL string AppURL string
} }
@@ -75,14 +75,15 @@ type APIConfig struct {
// AuthConfig is the configuration for the auth service // AuthConfig is the configuration for the auth service
type AuthConfig struct { type AuthConfig struct {
Users Users Users Users
OauthWhitelist string OauthWhitelist string
SessionExpiry int SessionExpiry int
Secret string Secret string
CookieSecure bool CookieSecure bool
Domain string Domain string
LoginTimeout int LoginTimeout int
LoginMaxRetries int LoginMaxRetries int
SessionCookieName string
} }
// HooksConfig is the configuration for the hooks service // HooksConfig is the configuration for the hooks service

View File

@@ -10,6 +10,7 @@ import (
"tinyauth/internal/constants" "tinyauth/internal/constants"
"tinyauth/internal/types" "tinyauth/internal/types"
"github.com/google/uuid"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
@@ -344,3 +345,18 @@ func SanitizeHeader(header string) string {
return -1 return -1
}, header) }, header)
} }
// Generate a static identifier from a string
func GenerateIdentifier(str string) string {
// Create a new UUID
uuid := uuid.NewSHA1(uuid.NameSpaceURL, []byte(str))
// Convert the UUID to a string
uuidString := uuid.String()
// Show the UUID
log.Debug().Str("uuid", uuidString).Msg("Generated UUID")
// Convert the UUID to a string
return strings.Split(uuidString, "-")[0]
}