diff --git a/internal/controller/oauth_controller.go b/internal/controller/oauth_controller.go index d5dfc39..6be0bc5 100644 --- a/internal/controller/oauth_controller.go +++ b/internal/controller/oauth_controller.go @@ -206,11 +206,17 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) { return } + if service.ID() != req.Provider { + tlog.App.Error().Msgf("OAuth service ID mismatch: expected %s, got %s", service.ID(), req.Provider) + c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) + return + } + sessionCookie := repository.Session{ Username: username, Name: name, Email: user.Email, - Provider: req.Provider, + Provider: service.ID(), OAuthGroups: utils.CoalesceToString(user.Groups), OAuthName: service.Name(), OAuthSub: user.Sub, diff --git a/internal/service/oauth_broker_service.go b/internal/service/oauth_broker_service.go index 40b6734..2f94713 100644 --- a/internal/service/oauth_broker_service.go +++ b/internal/service/oauth_broker_service.go @@ -10,6 +10,7 @@ import ( type OAuthServiceImpl interface { Name() string + ID() string NewRandom() string GetAuthURL(state string, verifier string) string GetToken(code string, verifier string) (*oauth2.Token, error) @@ -39,7 +40,7 @@ func (broker *OAuthBrokerService) Init() error { broker.services[name] = presetFunc(cfg) tlog.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset") } else { - broker.services[name] = NewOAuthService(cfg) + broker.services[name] = NewOAuthService(cfg, name) tlog.App.Debug().Str("service", name).Msg("Loaded OAuth service from config") } } diff --git a/internal/service/oauth_presets.go b/internal/service/oauth_presets.go index 6c658dc..477ea0e 100644 --- a/internal/service/oauth_presets.go +++ b/internal/service/oauth_presets.go @@ -11,7 +11,7 @@ func newGoogleOAuthService(config config.OAuthServiceConfig) *OAuthService { config.AuthURL = endpoints.Google.AuthURL config.TokenURL = endpoints.Google.TokenURL config.UserinfoURL = "https://openidconnect.googleapis.com/v1/userinfo" - return NewOAuthService(config) + return NewOAuthService(config, "google") } func newGitHubOAuthService(config config.OAuthServiceConfig) *OAuthService { @@ -19,5 +19,5 @@ func newGitHubOAuthService(config config.OAuthServiceConfig) *OAuthService { config.Scopes = scopes config.AuthURL = endpoints.GitHub.AuthURL config.TokenURL = endpoints.GitHub.TokenURL - return NewOAuthService(config).WithUserinfoExtractor(githubExtractor) + return NewOAuthService(config, "github").WithUserinfoExtractor(githubExtractor) } diff --git a/internal/service/oauth_service.go b/internal/service/oauth_service.go index 76f5a92..1e6cd51 100644 --- a/internal/service/oauth_service.go +++ b/internal/service/oauth_service.go @@ -17,9 +17,10 @@ type OAuthService struct { config *oauth2.Config ctx context.Context userinfoExtractor UserinfoExtractor + id string } -func NewOAuthService(config config.OAuthServiceConfig) *OAuthService { +func NewOAuthService(config config.OAuthServiceConfig, id string) *OAuthService { httpClient := &http.Client{ Timeout: 30 * time.Second, Transport: &http.Transport{ @@ -45,6 +46,7 @@ func NewOAuthService(config config.OAuthServiceConfig) *OAuthService { }, ctx: ctx, userinfoExtractor: defaultExtractor, + id: id, } } @@ -57,6 +59,10 @@ func (s *OAuthService) Name() string { return s.serviceCfg.Name } +func (s *OAuthService) ID() string { + return s.id +} + func (s *OAuthService) NewRandom() string { // The generate verifier function just creates a random string, // so we can use it to generate a random state as well