diff --git a/internal/service/generic_oauth_service.go b/internal/service/generic_oauth_service.go deleted file mode 100644 index ef17b0e..0000000 --- a/internal/service/generic_oauth_service.go +++ /dev/null @@ -1,132 +0,0 @@ -package service - -import ( - "context" - "crypto/rand" - "crypto/tls" - "encoding/base64" - "encoding/json" - "fmt" - "io" - "net/http" - "time" - - "github.com/steveiliop56/tinyauth/internal/config" - "github.com/steveiliop56/tinyauth/internal/utils/tlog" - - "golang.org/x/oauth2" -) - -type GenericOAuthService struct { - config oauth2.Config - context context.Context - token *oauth2.Token - verifier string - insecureSkipVerify bool - userinfoUrl string - name string -} - -func NewGenericOAuthService(config config.OAuthServiceConfig) *GenericOAuthService { - return &GenericOAuthService{ - config: oauth2.Config{ - ClientID: config.ClientID, - ClientSecret: config.ClientSecret, - RedirectURL: config.RedirectURL, - Scopes: config.Scopes, - Endpoint: oauth2.Endpoint{ - AuthURL: config.AuthURL, - TokenURL: config.TokenURL, - }, - }, - insecureSkipVerify: config.Insecure, - userinfoUrl: config.UserinfoURL, - name: config.Name, - } -} - -func (generic *GenericOAuthService) Init() error { - transport := &http.Transport{ - TLSClientConfig: &tls.Config{ - InsecureSkipVerify: generic.insecureSkipVerify, - MinVersion: tls.VersionTLS12, - }, - } - - httpClient := &http.Client{ - Transport: transport, - Timeout: 30 * time.Second, - } - - ctx := context.Background() - - ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) - - generic.context = ctx - return nil -} - -func (generic *GenericOAuthService) GenerateState() string { - b := make([]byte, 128) - _, err := rand.Read(b) - if err != nil { - return base64.RawURLEncoding.EncodeToString(fmt.Appendf(nil, "state-%d", time.Now().UnixNano())) - } - state := base64.RawURLEncoding.EncodeToString(b) - return state -} - -func (generic *GenericOAuthService) GenerateVerifier() string { - verifier := oauth2.GenerateVerifier() - generic.verifier = verifier - return verifier -} - -func (generic *GenericOAuthService) GetAuthURL(state string) string { - return generic.config.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(generic.verifier)) -} - -func (generic *GenericOAuthService) VerifyCode(code string) error { - token, err := generic.config.Exchange(generic.context, code, oauth2.VerifierOption(generic.verifier)) - - if err != nil { - return err - } - - generic.token = token - return nil -} - -func (generic *GenericOAuthService) Userinfo() (config.Claims, error) { - var user config.Claims - - client := generic.config.Client(generic.context, generic.token) - - res, err := client.Get(generic.userinfoUrl) - if err != nil { - return user, err - } - defer res.Body.Close() - - if res.StatusCode < 200 || res.StatusCode >= 300 { - return user, fmt.Errorf("request failed with status: %s", res.Status) - } - - body, err := io.ReadAll(res.Body) - if err != nil { - return user, err - } - - tlog.App.Trace().Str("body", string(body)).Msg("Userinfo response body") - - err = json.Unmarshal(body, &user) - if err != nil { - return user, err - } - - return user, nil -} - -func (generic *GenericOAuthService) GetName() string { - return generic.name -} diff --git a/internal/service/github_oauth_service.go b/internal/service/github_oauth_service.go deleted file mode 100644 index 35b552a..0000000 --- a/internal/service/github_oauth_service.go +++ /dev/null @@ -1,184 +0,0 @@ -package service - -import ( - "context" - "crypto/rand" - "encoding/base64" - "encoding/json" - "errors" - "fmt" - "io" - "net/http" - "strconv" - "time" - - "github.com/steveiliop56/tinyauth/internal/config" - - "golang.org/x/oauth2" - "golang.org/x/oauth2/endpoints" -) - -var GithubOAuthScopes = []string{"user:email", "read:user"} - -type GithubEmailResponse []struct { - Email string `json:"email"` - Primary bool `json:"primary"` -} - -type GithubUserInfoResponse struct { - Login string `json:"login"` - Name string `json:"name"` - ID int `json:"id"` -} - -type GithubOAuthService struct { - config oauth2.Config - context context.Context - token *oauth2.Token - verifier string - name string -} - -func NewGithubOAuthService(config config.OAuthServiceConfig) *GithubOAuthService { - return &GithubOAuthService{ - config: oauth2.Config{ - ClientID: config.ClientID, - ClientSecret: config.ClientSecret, - RedirectURL: config.RedirectURL, - Scopes: GithubOAuthScopes, - Endpoint: endpoints.GitHub, - }, - name: config.Name, - } -} - -func (github *GithubOAuthService) Init() error { - httpClient := &http.Client{ - Timeout: 30 * time.Second, - } - ctx := context.Background() - ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) - github.context = ctx - return nil -} - -func (github *GithubOAuthService) GenerateState() string { - b := make([]byte, 128) - _, err := rand.Read(b) - if err != nil { - return base64.RawURLEncoding.EncodeToString(fmt.Appendf(nil, "state-%d", time.Now().UnixNano())) - } - state := base64.RawURLEncoding.EncodeToString(b) - return state -} - -func (github *GithubOAuthService) GenerateVerifier() string { - verifier := oauth2.GenerateVerifier() - github.verifier = verifier - return verifier -} - -func (github *GithubOAuthService) GetAuthURL(state string) string { - return github.config.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(github.verifier)) -} - -func (github *GithubOAuthService) VerifyCode(code string) error { - token, err := github.config.Exchange(github.context, code, oauth2.VerifierOption(github.verifier)) - - if err != nil { - return err - } - - github.token = token - return nil -} - -func (github *GithubOAuthService) Userinfo() (config.Claims, error) { - var user config.Claims - - client := github.config.Client(github.context, github.token) - - req, err := http.NewRequest("GET", "https://api.github.com/user", nil) - if err != nil { - return user, err - } - - req.Header.Set("Accept", "application/vnd.github+json") - - res, err := client.Do(req) - if err != nil { - return user, err - } - defer res.Body.Close() - - if res.StatusCode < 200 || res.StatusCode >= 300 { - return user, fmt.Errorf("request failed with status: %s", res.Status) - } - - body, err := io.ReadAll(res.Body) - if err != nil { - return user, err - } - - var userInfo GithubUserInfoResponse - - err = json.Unmarshal(body, &userInfo) - if err != nil { - return user, err - } - - req, err = http.NewRequest("GET", "https://api.github.com/user/emails", nil) - if err != nil { - return user, err - } - - req.Header.Set("Accept", "application/vnd.github+json") - - res, err = client.Do(req) - if err != nil { - return user, err - } - defer res.Body.Close() - - if res.StatusCode < 200 || res.StatusCode >= 300 { - return user, fmt.Errorf("request failed with status: %s", res.Status) - } - - body, err = io.ReadAll(res.Body) - if err != nil { - return user, err - } - - var emails GithubEmailResponse - - err = json.Unmarshal(body, &emails) - if err != nil { - return user, err - } - - for _, email := range emails { - if email.Primary { - user.Email = email.Email - break - } - } - - if len(emails) == 0 { - return user, errors.New("no emails found") - } - - // Use first available email if no primary email was found - if user.Email == "" { - user.Email = emails[0].Email - } - - user.PreferredUsername = userInfo.Login - user.Name = userInfo.Name - user.Sub = strconv.Itoa(userInfo.ID) - - return user, nil -} - -func (github *GithubOAuthService) GetName() string { - return github.name -} diff --git a/internal/service/google_oauth_service.go b/internal/service/google_oauth_service.go deleted file mode 100644 index 6dfbeaf..0000000 --- a/internal/service/google_oauth_service.go +++ /dev/null @@ -1,116 +0,0 @@ -package service - -import ( - "context" - "crypto/rand" - "encoding/base64" - "encoding/json" - "fmt" - "io" - "net/http" - "strings" - "time" - - "github.com/steveiliop56/tinyauth/internal/config" - - "golang.org/x/oauth2" - "golang.org/x/oauth2/endpoints" -) - -var GoogleOAuthScopes = []string{"openid", "email", "profile"} - -type GoogleOAuthService struct { - config oauth2.Config - context context.Context - token *oauth2.Token - verifier string - name string -} - -func NewGoogleOAuthService(config config.OAuthServiceConfig) *GoogleOAuthService { - return &GoogleOAuthService{ - config: oauth2.Config{ - ClientID: config.ClientID, - ClientSecret: config.ClientSecret, - RedirectURL: config.RedirectURL, - Scopes: GoogleOAuthScopes, - Endpoint: endpoints.Google, - }, - name: config.Name, - } -} - -func (google *GoogleOAuthService) Init() error { - httpClient := &http.Client{ - Timeout: 30 * time.Second, - } - ctx := context.Background() - ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) - google.context = ctx - return nil -} - -func (oauth *GoogleOAuthService) GenerateState() string { - b := make([]byte, 128) - _, err := rand.Read(b) - if err != nil { - return base64.RawURLEncoding.EncodeToString(fmt.Appendf(nil, "state-%d", time.Now().UnixNano())) - } - state := base64.RawURLEncoding.EncodeToString(b) - return state -} - -func (google *GoogleOAuthService) GenerateVerifier() string { - verifier := oauth2.GenerateVerifier() - google.verifier = verifier - return verifier -} - -func (google *GoogleOAuthService) GetAuthURL(state string) string { - return google.config.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(google.verifier)) -} - -func (google *GoogleOAuthService) VerifyCode(code string) error { - token, err := google.config.Exchange(google.context, code, oauth2.VerifierOption(google.verifier)) - - if err != nil { - return err - } - - google.token = token - return nil -} - -func (google *GoogleOAuthService) Userinfo() (config.Claims, error) { - var user config.Claims - - client := google.config.Client(google.context, google.token) - - res, err := client.Get("https://openidconnect.googleapis.com/v1/userinfo") - if err != nil { - return config.Claims{}, err - } - defer res.Body.Close() - - if res.StatusCode < 200 || res.StatusCode >= 300 { - return user, fmt.Errorf("request failed with status: %s", res.Status) - } - - body, err := io.ReadAll(res.Body) - if err != nil { - return config.Claims{}, err - } - - err = json.Unmarshal(body, &user) - if err != nil { - return config.Claims{}, err - } - - user.PreferredUsername = strings.SplitN(user.Email, "@", 2)[0] - - return user, nil -} - -func (google *GoogleOAuthService) GetName() string { - return google.name -} diff --git a/internal/service/oauth_broker_service.go b/internal/service/oauth_broker_service.go index 76c23e9..40b6734 100644 --- a/internal/service/oauth_broker_service.go +++ b/internal/service/oauth_broker_service.go @@ -1,60 +1,48 @@ package service import ( - "errors" - "github.com/steveiliop56/tinyauth/internal/config" "github.com/steveiliop56/tinyauth/internal/utils/tlog" "golang.org/x/exp/slices" + "golang.org/x/oauth2" ) -type OAuthService interface { - Init() error - GenerateState() string - GenerateVerifier() string - GetAuthURL(state string) string - VerifyCode(code string) error - Userinfo() (config.Claims, error) - GetName() string +type OAuthServiceImpl interface { + Name() string + NewRandom() string + GetAuthURL(state string, verifier string) string + GetToken(code string, verifier string) (*oauth2.Token, error) + GetUserinfo(token *oauth2.Token) (config.Claims, error) } type OAuthBrokerService struct { - services map[string]OAuthService + services map[string]OAuthServiceImpl configs map[string]config.OAuthServiceConfig } +var presets = map[string]func(config config.OAuthServiceConfig) *OAuthService{ + "github": newGitHubOAuthService, + "google": newGoogleOAuthService, +} + func NewOAuthBrokerService(configs map[string]config.OAuthServiceConfig) *OAuthBrokerService { return &OAuthBrokerService{ - services: make(map[string]OAuthService), + services: make(map[string]OAuthServiceImpl), configs: configs, } } func (broker *OAuthBrokerService) Init() error { for name, cfg := range broker.configs { - switch name { - case "github": - service := NewGithubOAuthService(cfg) - broker.services[name] = service - case "google": - service := NewGoogleOAuthService(cfg) - broker.services[name] = service - default: - service := NewGenericOAuthService(cfg) - broker.services[name] = service + if presetFunc, exists := presets[name]; exists { + broker.services[name] = presetFunc(cfg) + tlog.App.Debug().Str("service", name).Msg("Loaded OAuth service from preset") + } else { + broker.services[name] = NewOAuthService(cfg) + tlog.App.Debug().Str("service", name).Msg("Loaded OAuth service from config") } } - - for name, service := range broker.services { - err := service.Init() - if err != nil { - tlog.App.Error().Err(err).Msgf("Failed to initialize OAuth service: %s", name) - return err - } - tlog.App.Info().Str("service", name).Msg("Initialized OAuth service") - } - return nil } @@ -67,15 +55,7 @@ func (broker *OAuthBrokerService) GetConfiguredServices() []string { return services } -func (broker *OAuthBrokerService) GetService(name string) (OAuthService, bool) { +func (broker *OAuthBrokerService) GetService(name string) (OAuthServiceImpl, bool) { service, exists := broker.services[name] return service, exists } - -func (broker *OAuthBrokerService) GetUser(service string) (config.Claims, error) { - oauthService, exists := broker.services[service] - if !exists { - return config.Claims{}, errors.New("oauth service not found") - } - return oauthService.Userinfo() -} diff --git a/internal/service/oauth_extractors.go b/internal/service/oauth_extractors.go new file mode 100644 index 0000000..2bf4ab8 --- /dev/null +++ b/internal/service/oauth_extractors.go @@ -0,0 +1,121 @@ +package service + +import ( + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strconv" + + "github.com/steveiliop56/tinyauth/internal/config" + "github.com/steveiliop56/tinyauth/internal/utils/tlog" +) + +type GithubEmailResponse []struct { + Email string `json:"email"` + Primary bool `json:"primary"` +} + +type GithubUserInfoResponse struct { + Login string `json:"login"` + Name string `json:"name"` + ID int `json:"id"` +} + +func defaultExtractor(client *http.Client, url string) (config.Claims, error) { + var claims config.Claims + + res, err := client.Get(url) + if err != nil { + return config.Claims{}, err + } + defer res.Body.Close() + + if res.StatusCode < 200 || res.StatusCode >= 300 { + return config.Claims{}, fmt.Errorf("request failed with status: %s", res.Status) + } + + body, err := io.ReadAll(res.Body) + if err != nil { + return config.Claims{}, err + } + + tlog.App.Trace().Str("body", string(body)).Msg("Userinfo response body") + + err = json.Unmarshal(body, &claims) + if err != nil { + return config.Claims{}, err + } + + return claims, nil +} + +func githubExtractor(client *http.Client, url string) (config.Claims, error) { + var user config.Claims + + userInfo, err := githubRequest[GithubUserInfoResponse](client, "https://api.github.com/user") + if err != nil { + return config.Claims{}, err + } + + userEmails, err := githubRequest[GithubEmailResponse](client, "https://api.github.com/user/emails") + if err != nil { + return config.Claims{}, err + } + + if len(userEmails) == 0 { + return user, errors.New("no emails found") + } + + for _, email := range userEmails { + if email.Primary { + user.Email = email.Email + break + } + } + + // Use first available email if no primary email was found + if user.Email == "" { + user.Email = userEmails[0].Email + } + + user.PreferredUsername = userInfo.Login + user.Name = userInfo.Name + user.Sub = strconv.Itoa(userInfo.ID) + + return user, nil +} + +func githubRequest[T any](client *http.Client, url string) (T, error) { + var githubRes T + + req, err := http.NewRequest("GET", "https://api.github.com/user", nil) + if err != nil { + return githubRes, err + } + + req.Header.Set("Accept", "application/vnd.github+json") + + res, err := client.Do(req) + if err != nil { + return githubRes, err + } + defer res.Body.Close() + + if res.StatusCode < 200 || res.StatusCode >= 300 { + return githubRes, fmt.Errorf("request failed with status: %s", res.Status) + } + + body, err := io.ReadAll(res.Body) + if err != nil { + return githubRes, err + } + + err = json.Unmarshal(body, &githubRes) + if err != nil { + return githubRes, err + } + + return githubRes, nil +} diff --git a/internal/service/oauth_presets.go b/internal/service/oauth_presets.go new file mode 100644 index 0000000..6c658dc --- /dev/null +++ b/internal/service/oauth_presets.go @@ -0,0 +1,23 @@ +package service + +import ( + "github.com/steveiliop56/tinyauth/internal/config" + "golang.org/x/oauth2/endpoints" +) + +func newGoogleOAuthService(config config.OAuthServiceConfig) *OAuthService { + scopes := []string{"openid", "email", "profile"} + config.Scopes = scopes + config.AuthURL = endpoints.Google.AuthURL + config.TokenURL = endpoints.Google.TokenURL + config.UserinfoURL = "https://openidconnect.googleapis.com/v1/userinfo" + return NewOAuthService(config) +} + +func newGitHubOAuthService(config config.OAuthServiceConfig) *OAuthService { + scopes := []string{"read:user", "user:email"} + config.Scopes = scopes + config.AuthURL = endpoints.GitHub.AuthURL + config.TokenURL = endpoints.GitHub.TokenURL + return NewOAuthService(config).WithUserinfoExtractor(githubExtractor) +} diff --git a/internal/service/oauth_service.go b/internal/service/oauth_service.go new file mode 100644 index 0000000..76f5a92 --- /dev/null +++ b/internal/service/oauth_service.go @@ -0,0 +1,78 @@ +package service + +import ( + "context" + "crypto/tls" + "net/http" + "time" + + "github.com/steveiliop56/tinyauth/internal/config" + "golang.org/x/oauth2" +) + +type UserinfoExtractor func(client *http.Client, url string) (config.Claims, error) + +type OAuthService struct { + serviceCfg config.OAuthServiceConfig + config *oauth2.Config + ctx context.Context + userinfoExtractor UserinfoExtractor +} + +func NewOAuthService(config config.OAuthServiceConfig) *OAuthService { + httpClient := &http.Client{ + Timeout: 30 * time.Second, + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + InsecureSkipVerify: config.Insecure, + }, + }, + } + ctx := context.Background() + ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient) + + return &OAuthService{ + serviceCfg: config, + config: &oauth2.Config{ + ClientID: config.ClientID, + ClientSecret: config.ClientSecret, + RedirectURL: config.RedirectURL, + Scopes: config.Scopes, + Endpoint: oauth2.Endpoint{ + AuthURL: config.AuthURL, + TokenURL: config.TokenURL, + }, + }, + ctx: ctx, + userinfoExtractor: defaultExtractor, + } +} + +func (s *OAuthService) WithUserinfoExtractor(extractor UserinfoExtractor) *OAuthService { + s.userinfoExtractor = extractor + return s +} + +func (s *OAuthService) Name() string { + return s.serviceCfg.Name +} + +func (s *OAuthService) NewRandom() string { + // The generate verifier function just creates a random string, + // so we can use it to generate a random state as well + random := oauth2.GenerateVerifier() + return random +} + +func (s *OAuthService) GetAuthURL(state string, verifier string) string { + return s.config.AuthCodeURL(state, oauth2.AccessTypeOnline, oauth2.S256ChallengeOption(verifier)) +} + +func (s *OAuthService) GetToken(code string, verifier string) (*oauth2.Token, error) { + return s.config.Exchange(s.ctx, code, oauth2.VerifierOption(verifier)) +} + +func (s *OAuthService) GetUserinfo(token *oauth2.Token) (config.Claims, error) { + client := oauth2.NewClient(s.ctx, oauth2.StaticTokenSource(token)) + return s.userinfoExtractor(client, s.serviceCfg.UserinfoURL) +}