diff --git a/Dockerfile b/Dockerfile index eb0b3a0..1f4c1c0 100644 --- a/Dockerfile +++ b/Dockerfile @@ -38,7 +38,7 @@ COPY --from=site-builder /site/dist ./internal/assets/dist RUN go build # Runner -FROM busybox:1.37-musl AS runner +FROM alpine:3.21 AS runner WORKDIR /tinyauth diff --git a/cmd/root.go b/cmd/root.go index 576fc82..3b3d61a 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -7,6 +7,7 @@ import ( "tinyauth/internal/api" "tinyauth/internal/auth" "tinyauth/internal/hooks" + "tinyauth/internal/providers" "tinyauth/internal/types" "tinyauth/internal/utils" @@ -19,7 +20,7 @@ import ( var rootCmd = &cobra.Command{ Use: "tinyauth", Short: "An extremely simple traefik forward auth proxy.", - Long: `Tinyauth is an extremely simple traefik forward-auth login screen that makes securing your apps easy.`, + Long: `Tinyauth is an extremely simple traefik forward-auth login screen that makes securing your apps easy.`, Run: func(cmd *cobra.Command, args []string) { // Get config log.Info().Msg("Parsing config") @@ -58,20 +59,41 @@ var rootCmd = &cobra.Command{ users, parseErr := utils.ParseUsers(usersString) HandleError(parseErr, "Failed to parse users") + // Create OAuth config + oauthConfig := types.OAuthConfig{ + GithubClientId: config.GithubClientId, + GithubClientSecret: config.GithubClientSecret, + GoogleClientId: config.GoogleClientId, + GoogleClientSecret: config.GoogleClientSecret, + GenericClientId: config.GenericClientId, + GenericClientSecret: config.GenericClientSecret, + GenericScopes: config.GenericScopes, + GenericAuthURL: config.GenericAuthURL, + GenericTokenURL: config.GenericTokenURL, + GenericUserInfoURL: config.GenericUserInfoURL, + AppURL: config.AppURL, + } + // Create auth service auth := auth.NewAuth(users) - + + // Create OAuth providers service + providers := providers.NewProviders(oauthConfig) + + // Initialize providers + providers.Init() + // Create hooks service - hooks := hooks.NewHooks(auth) - + hooks := hooks.NewHooks(auth, providers) + // Create API api := api.NewAPI(types.APIConfig{ - Port: config.Port, - Address: config.Address, - Secret: config.Secret, - AppURL: config.AppURL, + Port: config.Port, + Address: config.Address, + Secret: config.Secret, + AppURL: config.AppURL, CookieSecure: config.CookieSecure, - }, hooks, auth) + }, hooks, auth, providers) // Setup routes api.Init() @@ -107,6 +129,16 @@ func init() { rootCmd.Flags().String("users", "", "Comma separated list of users in the format username:bcrypt-hashed-password.") rootCmd.Flags().String("users-file", "", "Path to a file containing users in the format username:bcrypt-hashed-password.") rootCmd.Flags().Bool("cookie-secure", false, "Send cookie over secure connection only.") + rootCmd.Flags().String("github-client-id", "", "Github OAuth client ID.") + rootCmd.Flags().String("github-client-secret", "", "Github OAuth client secret.") + rootCmd.Flags().String("google-client-id", "", "Google OAuth client ID.") + rootCmd.Flags().String("google-client-secret", "", "Google OAuth client secret.") + rootCmd.Flags().String("generic-client-id", "", "Generic OAuth client ID.") + rootCmd.Flags().String("generic-client-secret", "", "Generic OAuth client secret.") + rootCmd.Flags().String("generic-scopes", "", "Generic OAuth scopes.") + rootCmd.Flags().String("generic-auth-url", "", "Generic OAuth auth URL.") + rootCmd.Flags().String("generic-token-url", "", "Generic OAuth token URL.") + rootCmd.Flags().String("generic-user-info-url", "", "Generic OAuth user info URL.") viper.BindEnv("port", "PORT") viper.BindEnv("address", "ADDRESS") viper.BindEnv("secret", "SECRET") @@ -114,5 +146,15 @@ func init() { viper.BindEnv("users", "USERS") viper.BindEnv("users-file", "USERS_FILE") viper.BindEnv("cookie-secure", "COOKIE_SECURE") + viper.BindEnv("github-client-id", "GITHUB_CLIENT_ID") + viper.BindEnv("github-client-secret", "GITHUB_CLIENT_SECRET") + viper.BindEnv("google-client-id", "GOOGLE_CLIENT_ID") + viper.BindEnv("google-client-secret", "GOOGLE_CLIENT_SECRET") + viper.BindEnv("generic-client-id", "GENERIC_CLIENT_ID") + viper.BindEnv("generic-client-secret", "GENERIC_CLIENT_SECRET") + viper.BindEnv("generic-scopes", "GENERIC_SCOPES") + viper.BindEnv("generic-auth-url", "GENERIC_AUTH_URL") + viper.BindEnv("generic-token-url", "GENERIC_TOKEN_URL") + viper.BindEnv("generic-user-info-url", "GENERIC_USER_INFO_URL") viper.BindPFlags(rootCmd.Flags()) } diff --git a/go.mod b/go.mod index e8fe7a7..7d68c51 100644 --- a/go.mod +++ b/go.mod @@ -72,6 +72,7 @@ require ( golang.org/x/arch v0.13.0 // indirect golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect golang.org/x/net v0.34.0 // indirect + golang.org/x/oauth2 v0.25.0 // indirect golang.org/x/sync v0.10.0 // indirect golang.org/x/sys v0.29.0 // indirect golang.org/x/text v0.21.0 // indirect diff --git a/go.sum b/go.sum index 0887ca6..71ce8ac 100644 --- a/go.sum +++ b/go.sum @@ -180,6 +180,8 @@ golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjs golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k= golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0= golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k= +golang.org/x/oauth2 v0.25.0 h1:CY4y7XT9v0cRI9oupztF8AgiIu99L/ksR/Xp/6jrZ70= +golang.org/x/oauth2 v0.25.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/internal/api/api.go b/internal/api/api.go index 360aa3c..86f77c1 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -10,6 +10,7 @@ import ( "tinyauth/internal/assets" "tinyauth/internal/auth" "tinyauth/internal/hooks" + "tinyauth/internal/providers" "tinyauth/internal/types" "tinyauth/internal/utils" @@ -20,25 +21,27 @@ import ( "github.com/rs/zerolog/log" ) -func NewAPI(config types.APIConfig, hooks *hooks.Hooks, auth *auth.Auth) (*API) { +func NewAPI(config types.APIConfig, hooks *hooks.Hooks, auth *auth.Auth, providers *providers.Providers) *API { return &API{ - Config: config, - Hooks: hooks, - Auth: auth, - Router: nil, + Config: config, + Hooks: hooks, + Auth: auth, + Providers: providers, } } type API struct { - Config types.APIConfig - Router *gin.Engine - Hooks *hooks.Hooks - Auth *auth.Auth + Config types.APIConfig + Router *gin.Engine + Hooks *hooks.Hooks + Auth *auth.Auth + Providers *providers.Providers + Domain string } func (api *API) Init() { gin.SetMode(gin.ReleaseMode) - + router := gin.New() router.Use(zerolog()) dist, distErr := fs.Sub(assets.Assets, "dist") @@ -67,17 +70,19 @@ func (api *API) Init() { } else { isSecure = false } - - store.Options(sessions.Options{ - Domain: fmt.Sprintf(".%s", domain), - Path: "/", - HttpOnly: true, - Secure: isSecure, - }) - - router.Use(sessions.Sessions("tinyauth", store)) - router.Use(func(c *gin.Context) { + api.Domain = fmt.Sprintf(".%s", domain) + + store.Options(sessions.Options{ + Domain: api.Domain, + Path: "/", + HttpOnly: true, + Secure: isSecure, + }) + + router.Use(sessions.Sessions("tinyauth", store)) + + router.Use(func(c *gin.Context) { if !strings.HasPrefix(c.Request.URL.Path, "/api") { _, err := fs.Stat(dist, strings.TrimPrefix(c.Request.URL.Path, "/")) if os.IsNotExist(err) { @@ -92,12 +97,20 @@ func (api *API) Init() { } func (api *API) SetupRoutes() { - api.Router.GET("/api/auth", func (c *gin.Context) { - userContext := api.Hooks.UseUserContext(c) + api.Router.GET("/api/auth", func(c *gin.Context) { + userContext, userContextErr := api.Hooks.UseUserContext(c) + + if userContextErr != nil { + c.JSON(500, gin.H{ + "status": 500, + "message": "Internal Server Error", + }) + return + } if userContext.IsLoggedIn { c.JSON(200, gin.H{ - "status": 200, + "status": 200, "message": "Authenticated", }) return @@ -112,7 +125,7 @@ func (api *API) SetupRoutes() { if queryErr != nil { c.JSON(501, gin.H{ - "status": 501, + "status": 501, "message": "Internal Server Error", }) return @@ -121,24 +134,24 @@ func (api *API) SetupRoutes() { c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/?%s", api.Config.AppURL, queries.Encode())) }) - api.Router.POST("/api/login", func (c *gin.Context) { + api.Router.POST("/api/login", func(c *gin.Context) { var login types.LoginRequest err := c.BindJSON(&login) if err != nil { c.JSON(400, gin.H{ - "status": 400, + "status": 400, "message": "Bad Request", }) return } - user := api.Auth.GetUser(login.Username) + user := api.Auth.GetUser(login.Email) if user == nil { c.JSON(401, gin.H{ - "status": 401, + "status": 401, "message": "Unauthorized", }) return @@ -146,62 +159,188 @@ func (api *API) SetupRoutes() { if !api.Auth.CheckPassword(*user, login.Password) { c.JSON(401, gin.H{ - "status": 401, + "status": 401, "message": "Unauthorized", }) return } session := sessions.Default(c) - session.Set("tinyauth", user.Username) + session.Set("tinyauth_sid", fmt.Sprintf("email:%s", login.Email)) session.Save() c.JSON(200, gin.H{ - "status": 200, + "status": 200, "message": "Logged in", }) }) - api.Router.POST("/api/logout", func (c *gin.Context) { + api.Router.POST("/api/logout", func(c *gin.Context) { session := sessions.Default(c) - session.Delete("tinyauth") + session.Delete("tinyauth_sid") session.Save() + c.SetCookie("tinyauth_redirect_uri", "", -1, "/", api.Domain, api.Config.CookieSecure, true) + c.JSON(200, gin.H{ - "status": 200, + "status": 200, "message": "Logged out", }) }) - api.Router.GET("/api/status", func (c *gin.Context) { - userContext := api.Hooks.UseUserContext(c) + api.Router.GET("/api/status", func(c *gin.Context) { + userContext, userContextErr := api.Hooks.UseUserContext(c) + + if userContextErr != nil { + c.JSON(500, gin.H{ + "status": 500, + "message": "Internal Server Error", + }) + return + } if !userContext.IsLoggedIn { c.JSON(200, gin.H{ - "status": 200, - "message": "Unauthenticated", - "username": "", - "isLoggedIn": false, + "status": 200, + "message": "Unauthenticated", + "email": "", + "isLoggedIn": false, + "oauth": false, + "provider": "", + "configuredProviders": api.Providers.GetConfiguredProviders(), }) return - } + } c.JSON(200, gin.H{ - "status": 200, - "message": "Authenticated", - "username": userContext.Username, - "isLoggedIn": true, + "status": 200, + "message": "Authenticated", + "email": userContext.Email, + "isLoggedIn": userContext.IsLoggedIn, + "oauth": userContext.OAuth, + "provider": userContext.Provider, + "configuredProviders": api.Providers.GetConfiguredProviders(), }) }) - api.Router.GET("/api/healthcheck", func (c *gin.Context) { + api.Router.GET("/api/healthcheck", func(c *gin.Context) { c.JSON(200, gin.H{ - "status": 200, + "status": 200, "message": "OK", }) }) -} + api.Router.GET("/api/oauth/url/:provider", func(c *gin.Context) { + var request types.OAuthRequest + + bindErr := c.BindUri(&request) + + if bindErr != nil { + c.JSON(400, gin.H{ + "status": 400, + "message": "Bad Request", + }) + return + } + + provider := api.Providers.GetProvider(request.Provider) + + if provider == nil { + c.JSON(404, gin.H{ + "status": 404, + "message": "Not Found", + }) + return + } + + authURL := provider.GetAuthURL() + + redirectURI := c.Query("redirect_uri") + + if redirectURI != "" { + c.SetCookie("tinyauth_redirect_uri", redirectURI, 3600, "/", api.Domain, api.Config.CookieSecure, true) + } + + c.JSON(200, gin.H{ + "status": 200, + "message": "Ok", + "url": authURL, + }) + }) + + api.Router.GET("/api/oauth/callback/:provider", func(c *gin.Context) { + var providerName types.OAuthRequest + + bindErr := c.BindUri(&providerName) + + if bindErr != nil { + c.JSON(400, gin.H{ + "status": 400, + "message": "Bad Request", + }) + return + } + + code := c.Query("code") + + if code == "" { + c.JSON(400, gin.H{ + "status": 400, + "message": "Bad Request", + }) + return + } + + provider := api.Providers.GetProvider(providerName.Provider) + + if provider == nil { + c.JSON(404, gin.H{ + "status": 404, + "message": "Not Found", + }) + return + } + + token, tokenErr := provider.ExchangeToken(code) + + if tokenErr != nil { + c.JSON(500, gin.H{ + "status": 500, + "message": "Internal Server Error", + }) + return + } + + session := sessions.Default(c) + session.Set("tinyauth_sid", fmt.Sprintf("%s:%s", providerName.Provider, token)) + session.Save() + + redirectURI, redirectURIErr := c.Cookie("tinyauth_redirect_uri") + + if redirectURIErr != nil { + c.JSON(200, gin.H{ + "status": 200, + "message": "Logged in", + }) + } + + c.SetCookie("tinyauth_redirect_uri", "", -1, "/", api.Domain, api.Config.CookieSecure, true) + + queries, queryErr := query.Values(types.LoginQuery{ + RedirectURI: redirectURI, + }) + + if queryErr != nil { + c.JSON(501, gin.H{ + "status": 501, + "message": "Internal Server Error", + }) + return + } + + c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/continue?%s", api.Config.AppURL, queries.Encode())) + }) +} func (api *API) Run() { log.Info().Str("address", api.Config.Address).Int("port", api.Config.Port).Msg("Starting server") @@ -218,16 +357,16 @@ func zerolog() gin.HandlerFunc { address := c.Request.RemoteAddr method := c.Request.Method path := c.Request.URL.Path - + latency := time.Since(tStart).String() switch { - case code >= 200 && code < 300: - log.Info().Str("method", method).Str("path", path).Str("address", address).Int("status", code).Str("latency", latency).Msg("Request") - case code >= 300 && code < 400: - log.Warn().Str("method", method).Str("path", path).Str("address", address).Int("status", code).Str("latency", latency).Msg("Request") - case code >= 400: - log.Error().Str("method", method).Str("path", path).Str("address", address).Int("status", code).Str("latency", latency).Msg("Request") + case code >= 200 && code < 300: + log.Info().Str("method", method).Str("path", path).Str("address", address).Int("status", code).Str("latency", latency).Msg("Request") + case code >= 300 && code < 400: + log.Warn().Str("method", method).Str("path", path).Str("address", address).Int("status", code).Str("latency", latency).Msg("Request") + case code >= 400: + log.Error().Str("method", method).Str("path", path).Str("address", address).Int("status", code).Str("latency", latency).Msg("Request") } } -} \ No newline at end of file +} diff --git a/internal/auth/auth.go b/internal/auth/auth.go index a5f0cef..a82f737 100644 --- a/internal/auth/auth.go +++ b/internal/auth/auth.go @@ -16,9 +16,9 @@ type Auth struct { Users types.Users } -func (auth *Auth) GetUser(username string) *types.User { +func (auth *Auth) GetUser(email string) *types.User { for _, user := range auth.Users { - if user.Username == username { + if user.Email == email { return &user } } diff --git a/internal/hooks/hooks.go b/internal/hooks/hooks.go index d45d84f..bfb1ac3 100644 --- a/internal/hooks/hooks.go +++ b/internal/hooks/hooks.go @@ -1,54 +1,114 @@ package hooks import ( + "strings" "tinyauth/internal/auth" + "tinyauth/internal/providers" "tinyauth/internal/types" "github.com/gin-contrib/sessions" "github.com/gin-gonic/gin" + "golang.org/x/oauth2" ) -func NewHooks(auth *auth.Auth) *Hooks { +func NewHooks(auth *auth.Auth, providers *providers.Providers) *Hooks { return &Hooks{ - Auth: auth, + Auth: auth, + Providers: providers, } } type Hooks struct { - Auth *auth.Auth + Auth *auth.Auth + Providers *providers.Providers } -func (hooks *Hooks) UseUserContext(c *gin.Context) (types.UserContext) { +func (hooks *Hooks) UseUserContext(c *gin.Context) (types.UserContext, error) { session := sessions.Default(c) - cookie := session.Get("tinyauth") + sessionCookie := session.Get("tinyauth_sid") - if cookie == nil { + if sessionCookie == nil { return types.UserContext{ - Username: "", + Email: "", IsLoggedIn: false, - } + OAuth: false, + Provider: "", + }, nil } - username, ok := cookie.(string) + data, dataOk := sessionCookie.(string) - if !ok { + if !dataOk { return types.UserContext{ - Username: "", + Email: "", IsLoggedIn: false, - } + OAuth: false, + Provider: "", + }, nil } - user := hooks.Auth.GetUser(username) + split := strings.Split(data, ":") - if user == nil { + if len(split) != 2 { return types.UserContext{ - Username: "", + Email: "", IsLoggedIn: false, + OAuth: false, + Provider: "", + }, nil + } + + sessionType := split[0] + sessionValue := split[1] + + if sessionType == "email" { + user := hooks.Auth.GetUser(sessionValue) + if user == nil { + return types.UserContext{ + Email: "", + IsLoggedIn: false, + OAuth: false, + Provider: "", + }, nil } + return types.UserContext{ + Email: sessionValue, + IsLoggedIn: true, + OAuth: false, + Provider: "", + }, nil + } + + provider := hooks.Providers.GetProvider(sessionType) + + if provider == nil { + return types.UserContext{ + Email: "", + IsLoggedIn: false, + OAuth: false, + Provider: "", + }, nil + } + + provider.Token = &oauth2.Token{ + AccessToken: sessionValue, + } + + email, emailErr := hooks.Providers.GetUser(sessionType) + + if emailErr != nil { + return types.UserContext{ + Email: "", + IsLoggedIn: false, + OAuth: false, + Provider: "", + }, nil } return types.UserContext{ - Username: username, + Email: email, IsLoggedIn: true, - } -} \ No newline at end of file + OAuth: true, + Provider: sessionType, + }, nil +} diff --git a/internal/oauth/oauth.go b/internal/oauth/oauth.go new file mode 100644 index 0000000..7d13d81 --- /dev/null +++ b/internal/oauth/oauth.go @@ -0,0 +1,45 @@ +package oauth + +import ( + "context" + "net/http" + + "github.com/rs/zerolog/log" + "golang.org/x/oauth2" +) + +func NewOAuth(config oauth2.Config) *OAuth { + return &OAuth{ + Config: config, + } +} + +type OAuth struct { + Config oauth2.Config + Context context.Context + Token *oauth2.Token + Verifier string +} + +func (oauth *OAuth) Init() { + oauth.Context = context.Background() + oauth.Verifier = oauth2.GenerateVerifier() +} + +func (oauth *OAuth) GetAuthURL() string { + return oauth.Config.AuthCodeURL("state", oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(oauth.Verifier)) +} + +func (oauth *OAuth) ExchangeToken(code string) (string, error) { + token, err := oauth.Config.Exchange(oauth.Context, code, oauth2.VerifierOption(oauth.Verifier)) + if err != nil { + log.Error().Err(err).Msg("Failed to exchange code") + return "", err + } + oauth.Token = token + return oauth.Token.AccessToken, nil +} + +func (oauth *OAuth) GetClient() *http.Client { + return oauth.Config.Client(oauth.Context, oauth.Token) +} diff --git a/internal/providers/generic.go b/internal/providers/generic.go new file mode 100644 index 0000000..22b538c --- /dev/null +++ b/internal/providers/generic.go @@ -0,0 +1,35 @@ +package providers + +import ( + "encoding/json" + "io" + "net/http" +) + +type GenericUserInfoResponse struct { + Email string `json:"email"` +} + +func GetGenericEmail(client *http.Client, url string) (string, error) { + res, resErr := client.Get(url) + + if resErr != nil { + return "", resErr + } + + body, bodyErr := io.ReadAll(res.Body) + + if bodyErr != nil { + return "", bodyErr + } + + var user GenericUserInfoResponse + + jsonErr := json.Unmarshal(body, &user) + + if jsonErr != nil { + return "", jsonErr + } + + return user.Email, nil +} diff --git a/internal/providers/github.go b/internal/providers/github.go new file mode 100644 index 0000000..687be6b --- /dev/null +++ b/internal/providers/github.go @@ -0,0 +1,47 @@ +package providers + +import ( + "encoding/json" + "errors" + "io" + "net/http" +) + +type GithubUserInfoResponse []struct { + Email string `json:"email"` + Primary bool `json:"primary"` +} + +func GithubScopes() []string { + return []string{"user:email"} +} + +func GetGithubEmail(client *http.Client) (string, error) { + res, resErr := client.Get("https://api.github.com/user/emails") + + if resErr != nil { + return "", resErr + } + + body, bodyErr := io.ReadAll(res.Body) + + if bodyErr != nil { + return "", bodyErr + } + + var emails GithubUserInfoResponse + + jsonErr := json.Unmarshal(body, &emails) + + if jsonErr != nil { + return "", jsonErr + } + + for _, email := range emails { + if email.Primary { + return email.Email, nil + } + } + + return "", errors.New("no primary email found") +} diff --git a/internal/providers/google.go b/internal/providers/google.go new file mode 100644 index 0000000..49eaf6d --- /dev/null +++ b/internal/providers/google.go @@ -0,0 +1,39 @@ +package providers + +import ( + "encoding/json" + "io" + "net/http" +) + +type GoogleUserInfoResponse struct { + Email string `json:"email"` +} + +func GoogleScopes() []string { + return []string{"https://www.googleapis.com/auth/userinfo.email"} +} + +func GetGoogleEmail(client *http.Client) (string, error) { + res, resErr := client.Get("https://www.googleapis.com/userinfo/v2/me") + + if resErr != nil { + return "", resErr + } + + body, bodyErr := io.ReadAll(res.Body) + + if bodyErr != nil { + return "", bodyErr + } + + var user GoogleUserInfoResponse + + jsonErr := json.Unmarshal(body, &user) + + if jsonErr != nil { + return "", jsonErr + } + + return user.Email, nil +} diff --git a/internal/providers/providers.go b/internal/providers/providers.go new file mode 100644 index 0000000..dc6f354 --- /dev/null +++ b/internal/providers/providers.go @@ -0,0 +1,127 @@ +package providers + +import ( + "fmt" + "tinyauth/internal/oauth" + "tinyauth/internal/types" + + "github.com/rs/zerolog/log" + "golang.org/x/oauth2" + "golang.org/x/oauth2/endpoints" +) + +func NewProviders(config types.OAuthConfig) *Providers { + return &Providers{ + Config: config, + } +} + +type Providers struct { + Config types.OAuthConfig + Github *oauth.OAuth + Google *oauth.OAuth + Generic *oauth.OAuth +} + +func (providers *Providers) Init() { + if providers.Config.GithubClientId != "" && providers.Config.GithubClientSecret != "" { + log.Info().Msg("Initializing Github OAuth") + providers.Github = oauth.NewOAuth(oauth2.Config{ + ClientID: providers.Config.GithubClientId, + ClientSecret: providers.Config.GithubClientSecret, + RedirectURL: fmt.Sprintf("%s/api/oauth/callback/github", providers.Config.AppURL), + Scopes: GithubScopes(), + Endpoint: endpoints.GitHub, + }) + providers.Github.Init() + } + if providers.Config.GoogleClientId != "" && providers.Config.GoogleClientSecret != "" { + log.Info().Msg("Initializing Google OAuth") + providers.Google = oauth.NewOAuth(oauth2.Config{ + ClientID: providers.Config.GoogleClientId, + ClientSecret: providers.Config.GoogleClientSecret, + RedirectURL: fmt.Sprintf("%s/api/oauth/callback/google", providers.Config.AppURL), + Scopes: GoogleScopes(), + Endpoint: endpoints.Google, + }) + providers.Google.Init() + } + if providers.Config.GenericClientId != "" && providers.Config.GenericClientSecret != "" { + log.Info().Msg("Initializing Generic OAuth") + providers.Generic = oauth.NewOAuth(oauth2.Config{ + ClientID: providers.Config.GenericClientId, + ClientSecret: providers.Config.GenericClientSecret, + RedirectURL: fmt.Sprintf("%s/api/oauth/callback/generic", providers.Config.AppURL), + Scopes: []string{providers.Config.GenericScopes}, + Endpoint: oauth2.Endpoint{ + AuthURL: providers.Config.GenericAuthURL, + TokenURL: providers.Config.GenericTokenURL, + }, + }) + providers.Generic.Init() + } +} + +func (providers *Providers) GetProvider(provider string) *oauth.OAuth { + switch provider { + case "github": + return providers.Github + case "google": + return providers.Google + case "generic": + return providers.Generic + default: + return nil + } +} + +func (providers *Providers) GetUser(provider string) (string, error) { + switch provider { + case "github": + if providers.Github == nil { + return "", nil + } + client := providers.Github.GetClient() + email, emailErr := GetGithubEmail(client) + if emailErr != nil { + return "", emailErr + } + return email, nil + case "google": + if providers.Google == nil { + return "", nil + } + client := providers.Google.GetClient() + email, emailErr := GetGoogleEmail(client) + if emailErr != nil { + return "", emailErr + } + return email, nil + case "generic": + if providers.Generic == nil { + return "", nil + } + client := providers.Generic.GetClient() + email, emailErr := GetGenericEmail(client, providers.Config.GenericUserInfoURL) + if emailErr != nil { + return "", emailErr + } + return email, nil + default: + return "", nil + } +} + +func (provider *Providers) GetConfiguredProviders() []string { + providers := []string{} + if provider.Github != nil { + providers = append(providers, "github") + } + if provider.Google != nil { + providers = append(providers, "google") + } + if provider.Generic != nil { + providers = append(providers, "generic") + } + return providers +} diff --git a/internal/types/types.go b/internal/types/types.go index e1e3034..f6a3adb 100644 --- a/internal/types/types.go +++ b/internal/types/types.go @@ -1,40 +1,78 @@ package types +import "tinyauth/internal/oauth" + type LoginQuery struct { RedirectURI string `url:"redirect_uri"` } type LoginRequest struct { - Username string `json:"username"` + Email string `json:"email"` Password string `json:"password"` } type User struct { - Username string + Email string Password string } type Users []User type Config struct { - Port int `validate:"number" mapstructure:"port"` - Address string `mapstructure:"address, ip4_addr"` - Secret string `validate:"required,len=32" mapstructure:"secret"` - AppURL string `validate:"required,url" mapstructure:"app-url"` - Users string `mapstructure:"users"` - UsersFile string `mapstructure:"users-file"` - CookieSecure bool `mapstructure:"cookie-secure"` + Port int `validate:"number" mapstructure:"port"` + Address string `mapstructure:"address, ip4_addr"` + Secret string `validate:"required,len=32" mapstructure:"secret"` + AppURL string `validate:"required,url" mapstructure:"app-url"` + Users string `mapstructure:"users"` + UsersFile string `mapstructure:"users-file"` + CookieSecure bool `mapstructure:"cookie-secure"` + GithubClientId string `mapstructure:"github-client-id"` + GithubClientSecret string `mapstructure:"github-client-secret"` + GoogleClientId string `mapstructure:"google-client-id"` + GoogleClientSecret string `mapstructure:"google-client-secret"` + GenericClientId string `mapstructure:"generic-client-id"` + GenericClientSecret string `mapstructure:"generic-client-secret"` + GenericScopes string `mapstructure:"generic-scopes"` + GenericAuthURL string `mapstructure:"generic-auth-url"` + GenericTokenURL string `mapstructure:"generic-token-url"` + GenericUserInfoURL string `mapstructure:"generic-user-info-url"` } type UserContext struct { - Username string + Email string IsLoggedIn bool + OAuth bool + Provider string } type APIConfig struct { - Port int - Address string - Secret string - AppURL string + Port int + Address string + Secret string + AppURL string CookieSecure bool -} \ No newline at end of file +} + +type OAuthConfig struct { + GithubClientId string + GithubClientSecret string + GoogleClientId string + GoogleClientSecret string + GenericClientId string + GenericClientSecret string + GenericScopes string + GenericAuthURL string + GenericTokenURL string + GenericUserInfoURL string + AppURL string +} + +type OAuthRequest struct { + Provider string `uri:"provider" binding:"required"` +} + +type OAuthProviders struct { + Github *oauth.OAuth + Google *oauth.OAuth + Microsoft *oauth.OAuth +} diff --git a/internal/utils/utils.go b/internal/utils/utils.go index e237301..239f3dc 100644 --- a/internal/utils/utils.go +++ b/internal/utils/utils.go @@ -22,7 +22,7 @@ func ParseUsers(users string) (types.Users, error) { return types.Users{}, errors.New("invalid user format") } usersParsed = append(usersParsed, types.User{ - Username: userSplit[0], + Email: userSplit[0], Password: userSplit[1], }) } diff --git a/site/src/icons/github.tsx b/site/src/icons/github.tsx new file mode 100644 index 0000000..1485f35 --- /dev/null +++ b/site/src/icons/github.tsx @@ -0,0 +1,18 @@ +import type { SVGProps } from "react"; + +export function GithubIcon(props: SVGProps) { + return ( + + + + ); +} diff --git a/site/src/icons/google.tsx b/site/src/icons/google.tsx new file mode 100644 index 0000000..1148569 --- /dev/null +++ b/site/src/icons/google.tsx @@ -0,0 +1,30 @@ +import type { SVGProps } from "react"; + +export function GoogleIcon(props: SVGProps) { + return ( + + + + + + + ); +} diff --git a/site/src/icons/oauth.tsx b/site/src/icons/oauth.tsx new file mode 100644 index 0000000..3ca531d --- /dev/null +++ b/site/src/icons/oauth.tsx @@ -0,0 +1,24 @@ +import type { SVGProps } from "react"; + +export function OAuthIcon(props: SVGProps) { + return ( + + + + + + + ); +} diff --git a/site/src/pages/login-page.tsx b/site/src/pages/login-page.tsx index d332f32..6000d4e 100644 --- a/site/src/pages/login-page.tsx +++ b/site/src/pages/login-page.tsx @@ -1,4 +1,13 @@ -import { Button, Paper, PasswordInput, TextInput, Title } from "@mantine/core"; +import { + Button, + Paper, + PasswordInput, + TextInput, + Title, + Text, + Divider, + Grid, +} from "@mantine/core"; import { useForm, zodResolver } from "@mantine/form"; import { notifications } from "@mantine/notifications"; import { useMutation } from "@tanstack/react-query"; @@ -7,20 +16,23 @@ import { z } from "zod"; import { useUserContext } from "../context/user-context"; import { Navigate } from "react-router"; import { Layout } from "../components/layouts/layout"; +import { GoogleIcon } from "../icons/google"; +import { GithubIcon } from "../icons/github"; +import { OAuthIcon } from "../icons/oauth"; export const LoginPage = () => { const queryString = window.location.search; const params = new URLSearchParams(queryString); const redirectUri = params.get("redirect_uri"); - const { isLoggedIn } = useUserContext(); + const { isLoggedIn, configuredProviders } = useUserContext(); if (isLoggedIn) { return ; } const schema = z.object({ - username: z.string(), + email: z.string().email(), password: z.string(), }); @@ -29,7 +41,7 @@ export const LoginPage = () => { const form = useForm({ mode: "uncontrolled", initialValues: { - username: "", + email: "", password: "", }, validate: zodResolver(schema), @@ -42,7 +54,7 @@ export const LoginPage = () => { onError: () => { notifications.show({ title: "Failed to login", - message: "Check your username and password", + message: "Check your email and password", color: "red", }); }, @@ -58,22 +70,89 @@ export const LoginPage = () => { }, }); + const loginOAuthMutation = useMutation({ + mutationFn: (provider: string) => { + return axios.get( + `/api/oauth/url/${provider}?redirect_uri=${redirectUri}`, + ); + }, + onError: () => { + notifications.show({ + title: "Internal error", + message: "Failed to get OAuth URL", + color: "red", + }); + }, + onSuccess: (data) => { + window.location.replace(data.data.url); + }, + }); + const handleSubmit = (values: FormValues) => { loginMutation.mutate(values); }; return ( - Welcome back! - + Tinyauth + + + Welcome back, login with + + + {configuredProviders.includes("google") && ( + + + + )} + {configuredProviders.includes("github") && ( + + + + )} + {configuredProviders.includes("generic") && ( + + + + )} + +
{ type="submit" loading={loginMutation.isLoading} > - Sign in + Login
diff --git a/site/src/pages/logout-page.tsx b/site/src/pages/logout-page.tsx index 52a5e99..80ecd4b 100644 --- a/site/src/pages/logout-page.tsx +++ b/site/src/pages/logout-page.tsx @@ -5,9 +5,10 @@ import axios from "axios"; import { useUserContext } from "../context/user-context"; import { Navigate } from "react-router"; import { Layout } from "../components/layouts/layout"; +import { capitalize } from "../utils/utils"; export const LogoutPage = () => { - const { isLoggedIn, username } = useUserContext(); + const { isLoggedIn, email, oauth, provider } = useUserContext(); if (!isLoggedIn) { return ; @@ -43,8 +44,9 @@ export const LogoutPage = () => { Logout - You are currently logged in as {username}, click the - button below to log out. + You are currently logged in as {email} + {oauth && ` using ${capitalize(provider)}`}. Click the button below to + log out.