diff --git a/internal/bootstrap/app_bootstrap.go b/internal/bootstrap/app_bootstrap.go index 55a5f082..0914282c 100644 --- a/internal/bootstrap/app_bootstrap.go +++ b/internal/bootstrap/app_bootstrap.go @@ -142,15 +142,6 @@ func (app *BootstrapApp) Setup() error { provider.ClientSecret = secret provider.ClientSecretFile = "" - if provider.RedirectURL == "" { - provider.RedirectURL = app.runtime.AppURL + "/api/oauth/callback/" + id - } - - app.runtime.OAuthProviders[id] = provider - } - - // set presets for built-in providers - for id, provider := range app.runtime.OAuthProviders { if provider.Name == "" { if name, ok := model.OverrideProviders[id]; ok { provider.Name = name @@ -158,6 +149,7 @@ func (app *BootstrapApp) Setup() error { provider.Name = utils.Capitalize(id) } } + app.runtime.OAuthProviders[id] = provider } diff --git a/internal/bootstrap/router_bootstrap.go b/internal/bootstrap/router_bootstrap.go index bb3d1df6..703d0442 100644 --- a/internal/bootstrap/router_bootstrap.go +++ b/internal/bootstrap/router_bootstrap.go @@ -147,7 +147,7 @@ func (app *BootstrapApp) getListenerFunc() (func(ctx context.Context) error, err func (app *BootstrapApp) serveHTTP(ctx context.Context) error { address := fmt.Sprintf("%s:%d", app.config.Server.Address, app.config.Server.Port) - app.log.App.Info().Msgf("Starting server on %s", address) + app.log.App.Info().Msgf("Starting server on http://%s", address) listener, err := net.Listen("tcp", address) diff --git a/internal/controller/context_controller.go b/internal/controller/context_controller.go index 84f52cc3..abfabaad 100644 --- a/internal/controller/context_controller.go +++ b/internal/controller/context_controller.go @@ -1,6 +1,8 @@ package controller import ( + "errors" + "github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/utils/logger" "go.uber.org/dig" @@ -109,7 +111,9 @@ func (controller *ContextController) userContextHandler(c *gin.Context) { context, err := new(model.UserContext).NewFromGin(c) if err != nil { - controller.log.App.Error().Err(err).Msg("Failed to create user context from request") + if !errors.Is(err, model.ErrUserContextNotFound) { + controller.log.App.Error().Err(err).Msg("Failed to create user context from request") + } c.JSON(200, UserContextResponse{ Status: 401, Message: "Unauthorized", diff --git a/internal/controller/oauth_controller.go b/internal/controller/oauth_controller.go index 79f77dec..29663872 100644 --- a/internal/controller/oauth_controller.go +++ b/internal/controller/oauth_controller.go @@ -335,7 +335,17 @@ func (controller *OAuthController) isRedirectSafe(redirectURI string) bool { return false } - if u.Port() != au.Port() { + getEffectivePort := func(u *url.URL) string { + if u.Port() != "" { + return u.Port() + } + if u.Scheme == "https" { + return "443" + } + return "80" + } + + if getEffectivePort(u) != getEffectivePort(au) { controller.log.App.Warn().Msg("Redirect URI port does not match app URL port") return false } diff --git a/internal/controller/user_controller.go b/internal/controller/user_controller.go index f17b7d79..ae6c23bf 100644 --- a/internal/controller/user_controller.go +++ b/internal/controller/user_controller.go @@ -295,6 +295,14 @@ func (controller *UserController) totpHandler(c *gin.Context) { context, err := new(model.UserContext).NewFromGin(c) if err != nil { + if errors.Is(err, model.ErrUserContextNotFound) { + controller.log.App.Warn().Msg("TOTP verification attempt without user context") + c.JSON(401, gin.H{ + "status": 401, + "message": "Unauthorized", + }) + return + } controller.log.App.Error().Err(err).Msg("Failed to create user context from request for TOTP verification") c.JSON(500, gin.H{ "status": 500, @@ -405,6 +413,14 @@ func (controller *UserController) tailscaleHandler(c *gin.Context) { context, err := new(model.UserContext).NewFromGin(c) if err != nil { + if errors.Is(err, model.ErrUserContextNotFound) { + controller.log.App.Warn().Msg("Tailscale login attempt without user context") + c.JSON(401, gin.H{ + "status": 401, + "message": "Unauthorized", + }) + return + } controller.log.App.Error().Err(err).Msg("Failed to create user context from request") c.JSON(401, gin.H{ "status": 401, diff --git a/internal/service/auth_service.go b/internal/service/auth_service.go index 3ab4e0d9..627dd127 100644 --- a/internal/service/auth_service.go +++ b/internal/service/auth_service.go @@ -46,7 +46,7 @@ type OAuthPendingSession struct { State string Verifier string Token *oauth2.Token - Service *OAuthServiceImpl + Service IOAuthService ExpiresAt time.Time CallbackParams OAuthCallbackParams } @@ -527,7 +527,7 @@ func (auth *AuthService) NewOAuthSession(serviceName string, params OAuthCallbac session := OAuthPendingSession{ State: state, Verifier: verifier, - Service: &service, + Service: service, ExpiresAt: time.Now().Add(1 * time.Hour), CallbackParams: params, } @@ -544,7 +544,18 @@ func (auth *AuthService) GetOAuthURL(sessionId string) (string, error) { return "", err } - return (*session.Service).GetAuthURL(session.State, session.Verifier), nil + svc := session.Service + + cfg := svc.GetConfig() + + // If the redirect URL is not set in the service config, we set it ourselves + if cfg.RedirectURL == "" { + cfg.RedirectURL = auth.runtime.AppURL + "/api/oauth/callback/" + svc.ID() + } + + svc.UpdateConfig(cfg) + + return svc.GetAuthURL(session.State, session.Verifier), nil } func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.Token, error) { @@ -554,7 +565,7 @@ func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.T return nil, fmt.Errorf("oauth session not found: %s", sessionId) } - token, err := (*session.Service).GetToken(code, session.Verifier) + token, err := session.Service.GetToken(code, session.Verifier) if err != nil { return nil, fmt.Errorf("failed to exchange code for token: %w", err) @@ -583,7 +594,7 @@ func (auth *AuthService) GetOAuthUserinfo(sessionId string) (*model.Claims, erro return nil, fmt.Errorf("oauth token not found for session: %s", sessionId) } - userinfo, err := (*session.Service).GetUserinfo(session.Token) + userinfo, err := session.Service.GetUserinfo(session.Token) if err != nil { return nil, fmt.Errorf("failed to get userinfo: %w", err) @@ -592,14 +603,14 @@ func (auth *AuthService) GetOAuthUserinfo(sessionId string) (*model.Claims, erro return userinfo, nil } -func (auth *AuthService) GetOAuthService(sessionId string) (OAuthServiceImpl, error) { +func (auth *AuthService) GetOAuthService(sessionId string) (IOAuthService, error) { session, err := auth.GetOAuthPendingSession(sessionId) if err != nil { return nil, err } - return *session.Service, nil + return session.Service, nil } func (auth *AuthService) EndOAuthSession(sessionId string) { diff --git a/internal/service/oauth_broker_service.go b/internal/service/oauth_broker_service.go index 63503abc..4df0e825 100644 --- a/internal/service/oauth_broker_service.go +++ b/internal/service/oauth_broker_service.go @@ -12,19 +12,21 @@ import ( "golang.org/x/oauth2" ) -type OAuthServiceImpl interface { +type IOAuthService interface { Name() string ID() string NewRandom() string - GetAuthURL(state string, verifier string) string - GetToken(code string, verifier string) (*oauth2.Token, error) + GetAuthURL(state, verifier string) string + GetToken(code, verifier string) (*oauth2.Token, error) GetUserinfo(token *oauth2.Token) (*model.Claims, error) + GetConfig() model.OAuthServiceConfig + UpdateConfig(config model.OAuthServiceConfig) } type OAuthBrokerService struct { log *logger.Logger - services map[string]OAuthServiceImpl + services map[string]IOAuthService configs map[string]model.OAuthServiceConfig } @@ -44,7 +46,7 @@ type OAuthBrokerServiceInput struct { func NewOAuthBrokerService(i OAuthBrokerServiceInput) *OAuthBrokerService { service := &OAuthBrokerService{ log: i.Log, - services: make(map[string]OAuthServiceImpl), + services: make(map[string]IOAuthService), configs: i.Runtime.OAuthProviders, } @@ -70,7 +72,7 @@ func (broker *OAuthBrokerService) GetConfiguredServices() []string { return services } -func (broker *OAuthBrokerService) GetService(name string) (OAuthServiceImpl, bool) { +func (broker *OAuthBrokerService) GetService(name string) (IOAuthService, bool) { service, exists := broker.services[name] return service, exists } diff --git a/internal/service/oauth_service.go b/internal/service/oauth_service.go index 07d0e1cc..888614ec 100644 --- a/internal/service/oauth_service.go +++ b/internal/service/oauth_service.go @@ -70,7 +70,7 @@ func (s *OAuthService) NewRandom() string { return random } -func (s *OAuthService) GetAuthURL(state string, verifier string) string { +func (s *OAuthService) GetAuthURL(state, verifier string) string { return s.config.AuthCodeURL(state, oauth2.AccessTypeOnline, oauth2.S256ChallengeOption(verifier)) } @@ -82,3 +82,17 @@ func (s *OAuthService) GetUserinfo(token *oauth2.Token) (*model.Claims, error) { client := oauth2.NewClient(s.ctx, oauth2.StaticTokenSource(token)) return s.userinfoExtractor(client, s.serviceCfg.UserinfoURL) } + +func (s *OAuthService) GetConfig() model.OAuthServiceConfig { + return s.serviceCfg +} + +func (s *OAuthService) UpdateConfig(config model.OAuthServiceConfig) { + s.serviceCfg = config + s.config.ClientID = config.ClientID + s.config.ClientSecret = config.ClientSecret + s.config.Scopes = config.Scopes + s.config.Endpoint.AuthURL = config.AuthURL + s.config.Endpoint.TokenURL = config.TokenURL + s.config.RedirectURL = config.RedirectURL +}