This commit is contained in:
Stavros
2025-01-23 19:16:35 +02:00
parent 143b13af2c
commit 80d25551e0
16 changed files with 491 additions and 115 deletions

View File

@@ -38,7 +38,7 @@ COPY --from=site-builder /site/dist ./internal/assets/dist
RUN go build RUN go build
# Runner # Runner
FROM busybox:1.37-musl AS runner FROM alpine:3.21 AS runner
WORKDIR /tinyauth WORKDIR /tinyauth

View File

@@ -7,6 +7,7 @@ import (
"tinyauth/internal/api" "tinyauth/internal/api"
"tinyauth/internal/auth" "tinyauth/internal/auth"
"tinyauth/internal/hooks" "tinyauth/internal/hooks"
"tinyauth/internal/providers"
"tinyauth/internal/types" "tinyauth/internal/types"
"tinyauth/internal/utils" "tinyauth/internal/utils"
@@ -19,7 +20,7 @@ import (
var rootCmd = &cobra.Command{ var rootCmd = &cobra.Command{
Use: "tinyauth", Use: "tinyauth",
Short: "An extremely simple traefik forward auth proxy.", Short: "An extremely simple traefik forward auth proxy.",
Long: `Tinyauth is an extremely simple traefik forward-auth login screen that makes securing your apps easy.`, Long: `Tinyauth is an extremely simple traefik forward-auth login screen that makes securing your apps easy.`,
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
// Get config // Get config
log.Info().Msg("Parsing config") log.Info().Msg("Parsing config")
@@ -58,20 +59,36 @@ var rootCmd = &cobra.Command{
users, parseErr := utils.ParseUsers(usersString) users, parseErr := utils.ParseUsers(usersString)
HandleError(parseErr, "Failed to parse users") HandleError(parseErr, "Failed to parse users")
// Create OAuth config
oauthConfig := types.OAuthConfig{
GithubClientId: config.GithubClientId,
GithubClientSecret: config.GithubClientSecret,
GoogleClientId: config.GoogleClientId,
GoogleClientSecret: config.GoogleClientSecret,
MicrosoftClientId: config.MicrosoftClientId,
MicrosoftClientSecret: config.MicrosoftClientSecret,
}
// Create auth service // Create auth service
auth := auth.NewAuth(users) auth := auth.NewAuth(users)
// Create OAuth providers service
providers := providers.NewProviders(oauthConfig)
// Initialize providers
providers.Init()
// Create hooks service // Create hooks service
hooks := hooks.NewHooks(auth) hooks := hooks.NewHooks(auth, providers)
// Create API // Create API
api := api.NewAPI(types.APIConfig{ api := api.NewAPI(types.APIConfig{
Port: config.Port, Port: config.Port,
Address: config.Address, Address: config.Address,
Secret: config.Secret, Secret: config.Secret,
AppURL: config.AppURL, AppURL: config.AppURL,
CookieSecure: config.CookieSecure, CookieSecure: config.CookieSecure,
}, hooks, auth) }, hooks, auth, providers)
// Setup routes // Setup routes
api.Init() api.Init()
@@ -107,6 +124,12 @@ func init() {
rootCmd.Flags().String("users", "", "Comma separated list of users in the format username:bcrypt-hashed-password.") rootCmd.Flags().String("users", "", "Comma separated list of users in the format username:bcrypt-hashed-password.")
rootCmd.Flags().String("users-file", "", "Path to a file containing users in the format username:bcrypt-hashed-password.") rootCmd.Flags().String("users-file", "", "Path to a file containing users in the format username:bcrypt-hashed-password.")
rootCmd.Flags().Bool("cookie-secure", false, "Send cookie over secure connection only.") rootCmd.Flags().Bool("cookie-secure", false, "Send cookie over secure connection only.")
rootCmd.Flags().String("github-client-id", "", "Github OAuth client ID.")
rootCmd.Flags().String("github-client-secret", "", "Github OAuth client secret.")
rootCmd.Flags().String("google-client-id", "", "Google OAuth client ID.")
rootCmd.Flags().String("google-client-secret", "", "Google OAuth client secret.")
rootCmd.Flags().String("microsoft-client-id", "", "Microsoft OAuth client ID.")
rootCmd.Flags().String("microsoft-client-secret", "", "Microsoft OAuth client secret.")
viper.BindEnv("port", "PORT") viper.BindEnv("port", "PORT")
viper.BindEnv("address", "ADDRESS") viper.BindEnv("address", "ADDRESS")
viper.BindEnv("secret", "SECRET") viper.BindEnv("secret", "SECRET")
@@ -114,5 +137,11 @@ func init() {
viper.BindEnv("users", "USERS") viper.BindEnv("users", "USERS")
viper.BindEnv("users-file", "USERS_FILE") viper.BindEnv("users-file", "USERS_FILE")
viper.BindEnv("cookie-secure", "COOKIE_SECURE") viper.BindEnv("cookie-secure", "COOKIE_SECURE")
viper.BindEnv("github-client-id", "GITHUB_CLIENT_ID")
viper.BindEnv("github-client-secret", "GITHUB_CLIENT_SECRET")
viper.BindEnv("google-client-id", "GOOGLE_CLIENT_ID")
viper.BindEnv("google-client-secret", "GOOGLE_CLIENT_SECRET")
viper.BindEnv("microsoft-client-id", "MICROSOFT_CLIENT_ID")
viper.BindEnv("microsoft-client-secret", "MICROSOFT_CLIENT_SECRET")
viper.BindPFlags(rootCmd.Flags()) viper.BindPFlags(rootCmd.Flags())
} }

1
go.mod
View File

@@ -72,6 +72,7 @@ require (
golang.org/x/arch v0.13.0 // indirect golang.org/x/arch v0.13.0 // indirect
golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect
golang.org/x/net v0.34.0 // indirect golang.org/x/net v0.34.0 // indirect
golang.org/x/oauth2 v0.25.0 // indirect
golang.org/x/sync v0.10.0 // indirect golang.org/x/sync v0.10.0 // indirect
golang.org/x/sys v0.29.0 // indirect golang.org/x/sys v0.29.0 // indirect
golang.org/x/text v0.21.0 // indirect golang.org/x/text v0.21.0 // indirect

2
go.sum
View File

@@ -180,6 +180,8 @@ golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjs
golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k= golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k=
golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0= golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0=
golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k= golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k=
golang.org/x/oauth2 v0.25.0 h1:CY4y7XT9v0cRI9oupztF8AgiIu99L/ksR/Xp/6jrZ70=
golang.org/x/oauth2 v0.25.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI=
golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ=
golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk=
golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210809222454-d867a43fc93e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=

View File

@@ -10,6 +10,7 @@ import (
"tinyauth/internal/assets" "tinyauth/internal/assets"
"tinyauth/internal/auth" "tinyauth/internal/auth"
"tinyauth/internal/hooks" "tinyauth/internal/hooks"
"tinyauth/internal/providers"
"tinyauth/internal/types" "tinyauth/internal/types"
"tinyauth/internal/utils" "tinyauth/internal/utils"
@@ -20,25 +21,26 @@ import (
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
func NewAPI(config types.APIConfig, hooks *hooks.Hooks, auth *auth.Auth) (*API) { func NewAPI(config types.APIConfig, hooks *hooks.Hooks, auth *auth.Auth, providers *providers.Providers) *API {
return &API{ return &API{
Config: config, Config: config,
Hooks: hooks, Hooks: hooks,
Auth: auth, Auth: auth,
Router: nil, Providers: providers,
} }
} }
type API struct { type API struct {
Config types.APIConfig Config types.APIConfig
Router *gin.Engine Router *gin.Engine
Hooks *hooks.Hooks Hooks *hooks.Hooks
Auth *auth.Auth Auth *auth.Auth
Providers *providers.Providers
} }
func (api *API) Init() { func (api *API) Init() {
gin.SetMode(gin.ReleaseMode) gin.SetMode(gin.ReleaseMode)
router := gin.New() router := gin.New()
router.Use(zerolog()) router.Use(zerolog())
dist, distErr := fs.Sub(assets.Assets, "dist") dist, distErr := fs.Sub(assets.Assets, "dist")
@@ -67,17 +69,17 @@ func (api *API) Init() {
} else { } else {
isSecure = false isSecure = false
} }
store.Options(sessions.Options{
Domain: fmt.Sprintf(".%s", domain),
Path: "/",
HttpOnly: true,
Secure: isSecure,
})
router.Use(sessions.Sessions("tinyauth", store))
router.Use(func(c *gin.Context) { store.Options(sessions.Options{
Domain: fmt.Sprintf(".%s", domain),
Path: "/",
HttpOnly: true,
Secure: isSecure,
})
router.Use(sessions.Sessions("tinyauth", store))
router.Use(func(c *gin.Context) {
if !strings.HasPrefix(c.Request.URL.Path, "/api") { if !strings.HasPrefix(c.Request.URL.Path, "/api") {
_, err := fs.Stat(dist, strings.TrimPrefix(c.Request.URL.Path, "/")) _, err := fs.Stat(dist, strings.TrimPrefix(c.Request.URL.Path, "/"))
if os.IsNotExist(err) { if os.IsNotExist(err) {
@@ -92,12 +94,20 @@ func (api *API) Init() {
} }
func (api *API) SetupRoutes() { func (api *API) SetupRoutes() {
api.Router.GET("/api/auth", func (c *gin.Context) { api.Router.GET("/api/auth", func(c *gin.Context) {
userContext := api.Hooks.UseUserContext(c) userContext, userContextErr := api.Hooks.UseUserContext(c)
if userContextErr != nil {
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
})
return
}
if userContext.IsLoggedIn { if userContext.IsLoggedIn {
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
"message": "Authenticated", "message": "Authenticated",
}) })
return return
@@ -112,7 +122,7 @@ func (api *API) SetupRoutes() {
if queryErr != nil { if queryErr != nil {
c.JSON(501, gin.H{ c.JSON(501, gin.H{
"status": 501, "status": 501,
"message": "Internal Server Error", "message": "Internal Server Error",
}) })
return return
@@ -121,24 +131,24 @@ func (api *API) SetupRoutes() {
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/?%s", api.Config.AppURL, queries.Encode())) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/?%s", api.Config.AppURL, queries.Encode()))
}) })
api.Router.POST("/api/login", func (c *gin.Context) { api.Router.POST("/api/login", func(c *gin.Context) {
var login types.LoginRequest var login types.LoginRequest
err := c.BindJSON(&login) err := c.BindJSON(&login)
if err != nil { if err != nil {
c.JSON(400, gin.H{ c.JSON(400, gin.H{
"status": 400, "status": 400,
"message": "Bad Request", "message": "Bad Request",
}) })
return return
} }
user := api.Auth.GetUser(login.Username) user := api.Auth.GetUser(login.Email)
if user == nil { if user == nil {
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"status": 401, "status": 401,
"message": "Unauthorized", "message": "Unauthorized",
}) })
return return
@@ -146,62 +156,149 @@ func (api *API) SetupRoutes() {
if !api.Auth.CheckPassword(*user, login.Password) { if !api.Auth.CheckPassword(*user, login.Password) {
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"status": 401, "status": 401,
"message": "Unauthorized", "message": "Unauthorized",
}) })
return return
} }
session := sessions.Default(c) session := sessions.Default(c)
session.Set("tinyauth", user.Username) session.Set("tinyauth_sid", user.Email)
session.Set("tinyauth_oauth_provider", "")
session.Save() session.Save()
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
"message": "Logged in", "message": "Logged in",
}) })
}) })
api.Router.POST("/api/logout", func (c *gin.Context) { api.Router.POST("/api/logout", func(c *gin.Context) {
session := sessions.Default(c) session := sessions.Default(c)
session.Delete("tinyauth") session.Delete("tinyauth_sid")
session.Delete("tinyauth_oauth_provider")
session.Save() session.Save()
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
"message": "Logged out", "message": "Logged out",
}) })
}) })
api.Router.GET("/api/status", func (c *gin.Context) { api.Router.GET("/api/status", func(c *gin.Context) {
userContext := api.Hooks.UseUserContext(c) userContext, userContextErr := api.Hooks.UseUserContext(c)
if userContextErr != nil {
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
})
return
}
if !userContext.IsLoggedIn { if !userContext.IsLoggedIn {
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
"message": "Unauthenticated", "message": "Unauthenticated",
"username": "", "email": "",
"isLoggedIn": false, "isLoggedIn": false,
"oauth": false,
"provider": "",
}) })
return return
} }
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
"message": "Authenticated", "message": "Authenticated",
"username": userContext.Username, "email": userContext.Email,
"isLoggedIn": true, "isLoggedIn": userContext.IsLoggedIn,
"oauth": userContext.OAuth,
"provider": userContext.Provider,
}) })
}) })
api.Router.GET("/api/healthcheck", func (c *gin.Context) { api.Router.GET("/api/healthcheck", func(c *gin.Context) {
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
"message": "OK", "message": "OK",
}) })
}) })
}
api.Router.GET("/api/oauth/url/:provider", func(c *gin.Context) {
var provider types.OAuthBind
bindErr := c.BindUri(&provider)
if bindErr != nil {
c.JSON(400, gin.H{
"status": 400,
"message": "Bad Request",
})
return
}
authURL := api.Providers.GetAuthURL(provider.Provider)
if authURL == "" {
c.JSON(400, gin.H{
"status": 400,
"message": "Bad Request",
})
return
}
c.JSON(200, gin.H{
"status": 200,
"message": "Ok",
"url": authURL,
})
})
api.Router.GET("/api/oauth/callback/:provider", func(c *gin.Context) {
var provider types.OAuthBind
bindErr := c.BindUri(&provider)
if bindErr != nil {
c.JSON(400, gin.H{
"status": 400,
"message": "Bad Request",
})
return
}
code := c.Query("code")
if code == "" {
c.JSON(400, gin.H{
"status": 400,
"message": "Bad Request",
})
return
}
email, emailErr := api.Providers.Login(code, provider.Provider)
if emailErr != nil {
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
})
return
}
session := sessions.Default(c)
session.Set("tinyauth_sid", email)
session.Set("tinyauth_oauth_provider", provider.Provider)
session.Save()
c.JSON(200, gin.H{
"status": 200,
"message": "Logged in",
})
})
}
func (api *API) Run() { func (api *API) Run() {
log.Info().Str("address", api.Config.Address).Int("port", api.Config.Port).Msg("Starting server") log.Info().Str("address", api.Config.Address).Int("port", api.Config.Port).Msg("Starting server")
@@ -218,16 +315,16 @@ func zerolog() gin.HandlerFunc {
address := c.Request.RemoteAddr address := c.Request.RemoteAddr
method := c.Request.Method method := c.Request.Method
path := c.Request.URL.Path path := c.Request.URL.Path
latency := time.Since(tStart).String() latency := time.Since(tStart).String()
switch { switch {
case code >= 200 && code < 300: case code >= 200 && code < 300:
log.Info().Str("method", method).Str("path", path).Str("address", address).Int("status", code).Str("latency", latency).Msg("Request") log.Info().Str("method", method).Str("path", path).Str("address", address).Int("status", code).Str("latency", latency).Msg("Request")
case code >= 300 && code < 400: case code >= 300 && code < 400:
log.Warn().Str("method", method).Str("path", path).Str("address", address).Int("status", code).Str("latency", latency).Msg("Request") log.Warn().Str("method", method).Str("path", path).Str("address", address).Int("status", code).Str("latency", latency).Msg("Request")
case code >= 400: case code >= 400:
log.Error().Str("method", method).Str("path", path).Str("address", address).Int("status", code).Str("latency", latency).Msg("Request") log.Error().Str("method", method).Str("path", path).Str("address", address).Int("status", code).Str("latency", latency).Msg("Request")
} }
} }
} }

View File

@@ -16,9 +16,9 @@ type Auth struct {
Users types.Users Users types.Users
} }
func (auth *Auth) GetUser(username string) *types.User { func (auth *Auth) GetUser(email string) *types.User {
for _, user := range auth.Users { for _, user := range auth.Users {
if user.Username == username { if user.Email == email {
return &user return &user
} }
} }

View File

@@ -2,53 +2,83 @@ package hooks
import ( import (
"tinyauth/internal/auth" "tinyauth/internal/auth"
"tinyauth/internal/providers"
"tinyauth/internal/types" "tinyauth/internal/types"
"github.com/gin-contrib/sessions" "github.com/gin-contrib/sessions"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
func NewHooks(auth *auth.Auth) *Hooks { func NewHooks(auth *auth.Auth, providers *providers.Providers) *Hooks {
return &Hooks{ return &Hooks{
Auth: auth, Auth: auth,
Providers: providers,
} }
} }
type Hooks struct { type Hooks struct {
Auth *auth.Auth Auth *auth.Auth
Providers *providers.Providers
} }
func (hooks *Hooks) UseUserContext(c *gin.Context) (types.UserContext) { func (hooks *Hooks) UseUserContext(c *gin.Context) (types.UserContext, error) {
session := sessions.Default(c) session := sessions.Default(c)
cookie := session.Get("tinyauth") sessionCookie := session.Get("tinyauth_sid")
oauthProviderCookie := session.Get("tinyauth_oauth_provider")
if cookie == nil { if sessionCookie == nil {
return types.UserContext{ return types.UserContext{
Username: "", Email: "",
IsLoggedIn: false, IsLoggedIn: false,
} OAuth: false,
Provider: "",
}, nil
} }
username, ok := cookie.(string) email, emailOk := sessionCookie.(string)
provider, providerOk := oauthProviderCookie.(string)
if !ok { if provider == "" || !providerOk {
return types.UserContext{ if !emailOk {
Username: "", return types.UserContext{
IsLoggedIn: false, Email: "",
IsLoggedIn: false,
OAuth: false,
Provider: "",
}, nil
} }
user := hooks.Auth.GetUser(email)
if user == nil {
return types.UserContext{
Email: "",
IsLoggedIn: false,
OAuth: false,
Provider: "",
}, nil
}
return types.UserContext{
Email: email,
IsLoggedIn: true,
OAuth: false,
Provider: "",
}, nil
} }
user := hooks.Auth.GetUser(username) oauthEmail, oauthEmailErr := hooks.Providers.GetUser(provider)
if user == nil { if oauthEmailErr != nil {
return types.UserContext{ return types.UserContext{
Username: "", Email: "",
IsLoggedIn: false, IsLoggedIn: false,
} OAuth: false,
Provider: "",
}, nil
} }
return types.UserContext{ return types.UserContext{
Username: username, Email: oauthEmail,
IsLoggedIn: true, IsLoggedIn: true,
} OAuth: true,
} Provider: provider,
}, nil
}

45
internal/oauth/oauth.go Normal file
View File

@@ -0,0 +1,45 @@
package oauth
import (
"context"
"net/http"
"github.com/rs/zerolog/log"
"golang.org/x/oauth2"
)
func NewOAuth(config oauth2.Config) *OAuth {
return &OAuth{
Config: config,
}
}
type OAuth struct {
Config oauth2.Config
Context context.Context
Token *oauth2.Token
Verifier string
}
func (oauth *OAuth) Init() {
oauth.Context = context.Background()
oauth.Verifier = oauth2.GenerateVerifier()
}
func (oauth *OAuth) GetAuthURL() string {
return oauth.Config.AuthCodeURL("state", oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(oauth.Verifier))
}
func (oauth *OAuth) ExchangeToken(code string) error {
token, err := oauth.Config.Exchange(oauth.Context, code, oauth2.VerifierOption(oauth.Verifier))
if err != nil {
log.Error().Err(err).Msg("Failed to exchange code")
return err
}
oauth.Token = token
return nil
}
func (oauth *OAuth) GetClient() *http.Client {
return oauth.Config.Client(oauth.Context, oauth.Token)
}

View File

@@ -0,0 +1,47 @@
package providers
import (
"encoding/json"
"errors"
"io"
"net/http"
)
type GithubEmailsResponse []struct {
Email string `json:"email"`
Primary bool `json:"primary"`
}
func GithubScopes() ([]string) {
return []string{"user:email"}
}
func GetGithubEmail(client *http.Client) (string, error) {
res, resErr := client.Get("https://api.github.com/user/emails")
if resErr != nil {
return "", resErr
}
body, bodyErr := io.ReadAll(res.Body)
if bodyErr != nil {
return "", bodyErr
}
var emails GithubEmailsResponse
jsonErr := json.Unmarshal(body, &emails)
if jsonErr != nil {
return "", jsonErr
}
for _, email := range emails {
if email.Primary {
return email.Email, nil
}
}
return "", errors.New("no primary email found")
}

View File

@@ -0,0 +1,86 @@
package providers
import (
"tinyauth/internal/oauth"
"tinyauth/internal/types"
"github.com/rs/zerolog/log"
"golang.org/x/oauth2"
"golang.org/x/oauth2/endpoints"
)
func NewProviders(config types.OAuthConfig) *Providers {
return &Providers{
Config: config,
}
}
type Providers struct {
Config types.OAuthConfig
Github *oauth.OAuth
Google *oauth.OAuth
Microsoft *oauth.OAuth
}
func (providers *Providers) Init() {
if providers.Config.GithubClientId != "" && providers.Config.GithubClientSecret != "" {
log.Info().Msg("Initializing Github OAuth")
providers.Github = oauth.NewOAuth(oauth2.Config{
ClientID: providers.Config.GithubClientId,
ClientSecret: providers.Config.GithubClientSecret,
Scopes: GithubScopes(),
Endpoint: endpoints.GitHub,
})
providers.Github.Init()
}
}
func (providers *Providers) Login(code string, provider string) (string, error) {
switch provider {
case "github":
if providers.Github == nil {
return "", nil
}
exchangeErr := providers.Github.ExchangeToken(code)
if exchangeErr != nil {
return "", exchangeErr
}
client := providers.Github.GetClient()
email, emailErr := GetGithubEmail(client)
if emailErr != nil {
return "", emailErr
}
return email, nil
default:
return "", nil
}
}
func (providers *Providers) GetUser(provider string) (string, error) {
switch provider {
case "github":
if providers.Github == nil {
return "", nil
}
client := providers.Github.GetClient()
email, emailErr := GetGithubEmail(client)
if emailErr != nil {
return "", emailErr
}
return email, nil
default:
return "", nil
}
}
func (providers *Providers) GetAuthURL(provider string) string {
switch provider {
case "github":
if providers.Github == nil {
return ""
}
return providers.Github.GetAuthURL()
default:
return ""
}
}

View File

@@ -1,40 +1,74 @@
package types package types
import "tinyauth/internal/oauth"
type LoginQuery struct { type LoginQuery struct {
RedirectURI string `url:"redirect_uri"` RedirectURI string `url:"redirect_uri"`
} }
type LoginRequest struct { type LoginRequest struct {
Username string `json:"username"` Email string `json:"email"`
Password string `json:"password"` Password string `json:"password"`
} }
type User struct { type User struct {
Username string Email string
Password string Password string
} }
type Users []User type Users []User
type Config struct { type Config struct {
Port int `validate:"number" mapstructure:"port"` Port int `validate:"number" mapstructure:"port"`
Address string `mapstructure:"address, ip4_addr"` Address string `mapstructure:"address, ip4_addr"`
Secret string `validate:"required,len=32" mapstructure:"secret"` Secret string `validate:"required,len=32" mapstructure:"secret"`
AppURL string `validate:"required,url" mapstructure:"app-url"` AppURL string `validate:"required,url" mapstructure:"app-url"`
Users string `mapstructure:"users"` Users string `mapstructure:"users"`
UsersFile string `mapstructure:"users-file"` UsersFile string `mapstructure:"users-file"`
CookieSecure bool `mapstructure:"cookie-secure"` CookieSecure bool `mapstructure:"cookie-secure"`
GithubClientId string `mapstructure:"github-client-id"`
GithubClientSecret string `mapstructure:"github-client-secret"`
GoogleClientId string `mapstructure:"google-client-id"`
GoogleClientSecret string `mapstructure:"google-client-secret"`
MicrosoftClientId string `mapstructure:"microsoft-client-id"`
MicrosoftClientSecret string `mapstructure:"microsoft-client-secret"`
} }
type UserContext struct { type UserContext struct {
Username string Email string
IsLoggedIn bool IsLoggedIn bool
OAuth bool
Provider string
} }
type APIConfig struct { type APIConfig struct {
Port int Port int
Address string Address string
Secret string Secret string
AppURL string AppURL string
CookieSecure bool CookieSecure bool
} }
type OAuthConfig struct {
GithubClientId string
GithubClientSecret string
GoogleClientId string
GoogleClientSecret string
MicrosoftClientId string
MicrosoftClientSecret string
}
type OAuthBind struct {
Provider string `uri:"provider" binding:"required"`
}
type OAuthProviders struct {
Github *oauth.OAuth
Google *oauth.OAuth
Microsoft *oauth.OAuth
}
type OAuthLogin struct {
Email string
Token string
}

View File

@@ -22,7 +22,7 @@ func ParseUsers(users string) (types.Users, error) {
return types.Users{}, errors.New("invalid user format") return types.Users{}, errors.New("invalid user format")
} }
usersParsed = append(usersParsed, types.User{ usersParsed = append(usersParsed, types.User{
Username: userSplit[0], Email: userSplit[0],
Password: userSplit[1], Password: userSplit[1],
}) })
} }

View File

@@ -20,7 +20,7 @@ export const LoginPage = () => {
} }
const schema = z.object({ const schema = z.object({
username: z.string(), email: z.string().email(),
password: z.string(), password: z.string(),
}); });
@@ -29,7 +29,7 @@ export const LoginPage = () => {
const form = useForm({ const form = useForm({
mode: "uncontrolled", mode: "uncontrolled",
initialValues: { initialValues: {
username: "", email: "",
password: "", password: "",
}, },
validate: zodResolver(schema), validate: zodResolver(schema),
@@ -42,7 +42,7 @@ export const LoginPage = () => {
onError: () => { onError: () => {
notifications.show({ notifications.show({
title: "Failed to login", title: "Failed to login",
message: "Check your username and password", message: "Check your email and password",
color: "red", color: "red",
}); });
}, },
@@ -68,12 +68,12 @@ export const LoginPage = () => {
<Paper shadow="md" p={30} mt={30} radius="md" withBorder> <Paper shadow="md" p={30} mt={30} radius="md" withBorder>
<form onSubmit={form.onSubmit(handleSubmit)}> <form onSubmit={form.onSubmit(handleSubmit)}>
<TextInput <TextInput
label="Username" label="Email"
placeholder="tinyauth" placeholder="user@example.com"
required required
disabled={loginMutation.isLoading} disabled={loginMutation.isLoading}
key={form.key("username")} key={form.key("email")}
{...form.getInputProps("username")} {...form.getInputProps("email")}
/> />
<PasswordInput <PasswordInput
label="Password" label="Password"
@@ -90,7 +90,7 @@ export const LoginPage = () => {
type="submit" type="submit"
loading={loginMutation.isLoading} loading={loginMutation.isLoading}
> >
Sign in Login
</Button> </Button>
</form> </form>
</Paper> </Paper>

View File

@@ -5,9 +5,10 @@ import axios from "axios";
import { useUserContext } from "../context/user-context"; import { useUserContext } from "../context/user-context";
import { Navigate } from "react-router"; import { Navigate } from "react-router";
import { Layout } from "../components/layouts/layout"; import { Layout } from "../components/layouts/layout";
import { capitalize } from "../utils/utils";
export const LogoutPage = () => { export const LogoutPage = () => {
const { isLoggedIn, username } = useUserContext(); const { isLoggedIn, email, oauth, provider } = useUserContext();
if (!isLoggedIn) { if (!isLoggedIn) {
return <Navigate to="/login" />; return <Navigate to="/login" />;
@@ -43,8 +44,9 @@ export const LogoutPage = () => {
Logout Logout
</Text> </Text>
<Text> <Text>
You are currently logged in as <Code>{username}</Code>, click the You are currently logged in as <Code>{email}</Code>{" "}
button below to log out. {oauth && `using ${capitalize(provider)}`}. Click the button below to
log out.
</Text> </Text>
<Button <Button
fullWidth fullWidth

View File

@@ -2,7 +2,9 @@ import { z } from "zod";
export const userContextSchema = z.object({ export const userContextSchema = z.object({
isLoggedIn: z.boolean(), isLoggedIn: z.boolean(),
username: z.string(), email: z.string(),
oauth: z.boolean(),
provider: z.string(),
}); });
export type UserContextSchemaType = z.infer<typeof userContextSchema>; export type UserContextSchemaType = z.infer<typeof userContextSchema>;

1
site/src/utils/utils.ts Normal file
View File

@@ -0,0 +1 @@
export const capitalize = (s: string) => s.charAt(0).toUpperCase() + s.slice(1);