mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2025-10-28 04:35:40 +00:00
chore: add comments to code
This commit is contained in:
14
cmd/root.go
14
cmd/root.go
@@ -125,20 +125,24 @@ var rootCmd = &cobra.Command{
|
|||||||
|
|
||||||
func Execute() {
|
func Execute() {
|
||||||
err := rootCmd.Execute()
|
err := rootCmd.Execute()
|
||||||
if err != nil {
|
HandleError(err, "Failed to execute root command")
|
||||||
log.Fatal().Err(err).Msg("Failed to execute command")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func HandleError(err error, msg string) {
|
func HandleError(err error, msg string) {
|
||||||
|
// If error log it and exit
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatal().Err(err).Msg(msg)
|
log.Fatal().Err(err).Msg(msg)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
// Add user command
|
||||||
rootCmd.AddCommand(cmd.UserCmd())
|
rootCmd.AddCommand(cmd.UserCmd())
|
||||||
|
|
||||||
|
// Read environment variables
|
||||||
viper.AutomaticEnv()
|
viper.AutomaticEnv()
|
||||||
|
|
||||||
|
// Flags
|
||||||
rootCmd.Flags().Int("port", 3000, "Port to run the server on.")
|
rootCmd.Flags().Int("port", 3000, "Port to run the server on.")
|
||||||
rootCmd.Flags().String("address", "0.0.0.0", "Address to bind the server to.")
|
rootCmd.Flags().String("address", "0.0.0.0", "Address to bind the server to.")
|
||||||
rootCmd.Flags().String("secret", "", "Secret to use for the cookie.")
|
rootCmd.Flags().String("secret", "", "Secret to use for the cookie.")
|
||||||
@@ -167,6 +171,8 @@ func init() {
|
|||||||
rootCmd.Flags().String("oauth-whitelist", "", "Comma separated list of email addresses to whitelist when using OAuth.")
|
rootCmd.Flags().String("oauth-whitelist", "", "Comma separated list of email addresses to whitelist when using OAuth.")
|
||||||
rootCmd.Flags().Int("session-expiry", 86400, "Session (cookie) expiration time in seconds.")
|
rootCmd.Flags().Int("session-expiry", 86400, "Session (cookie) expiration time in seconds.")
|
||||||
rootCmd.Flags().Int("log-level", 1, "Log level.")
|
rootCmd.Flags().Int("log-level", 1, "Log level.")
|
||||||
|
|
||||||
|
// Bind flags to environment
|
||||||
viper.BindEnv("port", "PORT")
|
viper.BindEnv("port", "PORT")
|
||||||
viper.BindEnv("address", "ADDRESS")
|
viper.BindEnv("address", "ADDRESS")
|
||||||
viper.BindEnv("secret", "SECRET")
|
viper.BindEnv("secret", "SECRET")
|
||||||
@@ -195,5 +201,7 @@ func init() {
|
|||||||
viper.BindEnv("oauth-whitelist", "OAUTH_WHITELIST")
|
viper.BindEnv("oauth-whitelist", "OAUTH_WHITELIST")
|
||||||
viper.BindEnv("session-expiry", "SESSION_EXPIRY")
|
viper.BindEnv("session-expiry", "SESSION_EXPIRY")
|
||||||
viper.BindEnv("log-level", "LOG_LEVEL")
|
viper.BindEnv("log-level", "LOG_LEVEL")
|
||||||
|
|
||||||
|
// Bind flags to viper
|
||||||
viper.BindPFlags(rootCmd.Flags())
|
viper.BindPFlags(rootCmd.Flags())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,9 +22,12 @@ var CreateCmd = &cobra.Command{
|
|||||||
Short: "Create a user",
|
Short: "Create a user",
|
||||||
Long: `Create a user either interactively or by passing flags.`,
|
Long: `Create a user either interactively or by passing flags.`,
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
|
// Setup logger
|
||||||
log.Logger = log.Level(zerolog.InfoLevel)
|
log.Logger = log.Level(zerolog.InfoLevel)
|
||||||
|
|
||||||
|
// Check if interactive
|
||||||
if interactive {
|
if interactive {
|
||||||
|
// Create huh form
|
||||||
form := huh.NewForm(
|
form := huh.NewForm(
|
||||||
huh.NewGroup(
|
huh.NewGroup(
|
||||||
huh.NewInput().Title("Username").Value(&username).Validate((func(s string) error {
|
huh.NewInput().Title("Username").Value(&username).Validate((func(s string) error {
|
||||||
@@ -43,6 +46,7 @@ var CreateCmd = &cobra.Command{
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Use simple theme
|
||||||
var baseTheme *huh.Theme = huh.ThemeBase()
|
var baseTheme *huh.Theme = huh.ThemeBase()
|
||||||
|
|
||||||
formErr := form.WithTheme(baseTheme).Run()
|
formErr := form.WithTheme(baseTheme).Run()
|
||||||
@@ -52,12 +56,14 @@ var CreateCmd = &cobra.Command{
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Do we have username and password?
|
||||||
if username == "" || password == "" {
|
if username == "" || password == "" {
|
||||||
log.Error().Msg("Username and password cannot be empty")
|
log.Error().Msg("Username and password cannot be empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info().Str("username", username).Str("password", password).Bool("docker", docker).Msg("Creating user")
|
log.Info().Str("username", username).Str("password", password).Bool("docker", docker).Msg("Creating user")
|
||||||
|
|
||||||
|
// Hash password
|
||||||
passwordByte, passwordErr := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
passwordByte, passwordErr := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost)
|
||||||
|
|
||||||
if passwordErr != nil {
|
if passwordErr != nil {
|
||||||
@@ -66,15 +72,18 @@ var CreateCmd = &cobra.Command{
|
|||||||
|
|
||||||
passwordString := string(passwordByte)
|
passwordString := string(passwordByte)
|
||||||
|
|
||||||
|
// Escape $ for docker
|
||||||
if docker {
|
if docker {
|
||||||
passwordString = strings.ReplaceAll(passwordString, "$", "$$")
|
passwordString = strings.ReplaceAll(passwordString, "$", "$$")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Log user created
|
||||||
log.Info().Str("user", fmt.Sprintf("%s:%s", username, passwordString)).Msg("User created")
|
log.Info().Str("user", fmt.Sprintf("%s:%s", username, passwordString)).Msg("User created")
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
// Flags
|
||||||
CreateCmd.Flags().BoolVar(&interactive, "interactive", false, "Create a user interactively")
|
CreateCmd.Flags().BoolVar(&interactive, "interactive", false, "Create a user interactively")
|
||||||
CreateCmd.Flags().BoolVar(&docker, "docker", false, "Format output for docker")
|
CreateCmd.Flags().BoolVar(&docker, "docker", false, "Format output for docker")
|
||||||
CreateCmd.Flags().StringVar(&username, "username", "", "Username")
|
CreateCmd.Flags().StringVar(&username, "username", "", "Username")
|
||||||
|
|||||||
@@ -22,9 +22,12 @@ var VerifyCmd = &cobra.Command{
|
|||||||
Short: "Verify a user is set up correctly",
|
Short: "Verify a user is set up correctly",
|
||||||
Long: `Verify a user is set up correctly meaning that it has a correct username and password.`,
|
Long: `Verify a user is set up correctly meaning that it has a correct username and password.`,
|
||||||
Run: func(cmd *cobra.Command, args []string) {
|
Run: func(cmd *cobra.Command, args []string) {
|
||||||
|
// Setup logger
|
||||||
log.Logger = log.Level(zerolog.InfoLevel)
|
log.Logger = log.Level(zerolog.InfoLevel)
|
||||||
|
|
||||||
|
// Check if interactive
|
||||||
if interactive {
|
if interactive {
|
||||||
|
// Create huh form
|
||||||
form := huh.NewForm(
|
form := huh.NewForm(
|
||||||
huh.NewGroup(
|
huh.NewGroup(
|
||||||
huh.NewInput().Title("User (username:hash)").Value(&user).Validate((func(s string) error {
|
huh.NewInput().Title("User (username:hash)").Value(&user).Validate((func(s string) error {
|
||||||
@@ -49,6 +52,7 @@ var VerifyCmd = &cobra.Command{
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Use simple theme
|
||||||
var baseTheme *huh.Theme = huh.ThemeBase()
|
var baseTheme *huh.Theme = huh.ThemeBase()
|
||||||
|
|
||||||
formErr := form.WithTheme(baseTheme).Run()
|
formErr := form.WithTheme(baseTheme).Run()
|
||||||
@@ -58,22 +62,26 @@ var VerifyCmd = &cobra.Command{
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Do we have username, password and user?
|
||||||
if username == "" || password == "" || user == "" {
|
if username == "" || password == "" || user == "" {
|
||||||
log.Fatal().Msg("Username, password and user cannot be empty")
|
log.Fatal().Msg("Username, password and user cannot be empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info().Str("user", user).Str("username", username).Str("password", password).Bool("docker", docker).Msg("Verifying user")
|
log.Info().Str("user", user).Str("username", username).Str("password", password).Bool("docker", docker).Msg("Verifying user")
|
||||||
|
|
||||||
|
// Split username and password
|
||||||
userSplit := strings.Split(user, ":")
|
userSplit := strings.Split(user, ":")
|
||||||
|
|
||||||
if userSplit[1] == "" {
|
if userSplit[1] == "" {
|
||||||
log.Fatal().Msg("User is not formatted correctly")
|
log.Fatal().Msg("User is not formatted correctly")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Replace $$ with $ if formatted for docker
|
||||||
if docker {
|
if docker {
|
||||||
userSplit[1] = strings.ReplaceAll(userSplit[1], "$$", "$")
|
userSplit[1] = strings.ReplaceAll(userSplit[1], "$$", "$")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Compare username and password
|
||||||
verifyErr := bcrypt.CompareHashAndPassword([]byte(userSplit[1]), []byte(password))
|
verifyErr := bcrypt.CompareHashAndPassword([]byte(userSplit[1]), []byte(password))
|
||||||
|
|
||||||
if verifyErr != nil || username != userSplit[0] {
|
if verifyErr != nil || username != userSplit[0] {
|
||||||
@@ -85,6 +93,7 @@ var VerifyCmd = &cobra.Command{
|
|||||||
}
|
}
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
|
// Flags
|
||||||
VerifyCmd.Flags().BoolVarP(&interactive, "interactive", "i", false, "Create a user interactively")
|
VerifyCmd.Flags().BoolVarP(&interactive, "interactive", "i", false, "Create a user interactively")
|
||||||
VerifyCmd.Flags().BoolVar(&docker, "docker", false, "Is the user formatted for docker?")
|
VerifyCmd.Flags().BoolVar(&docker, "docker", false, "Is the user formatted for docker?")
|
||||||
VerifyCmd.Flags().StringVar(&username, "username", "", "Username")
|
VerifyCmd.Flags().StringVar(&username, "username", "", "Username")
|
||||||
|
|||||||
@@ -41,11 +41,15 @@ type API struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (api *API) Init() {
|
func (api *API) Init() {
|
||||||
|
// Disable gin logs
|
||||||
gin.SetMode(gin.ReleaseMode)
|
gin.SetMode(gin.ReleaseMode)
|
||||||
|
|
||||||
|
// Create router and use zerolog for logs
|
||||||
log.Debug().Msg("Setting up router")
|
log.Debug().Msg("Setting up router")
|
||||||
router := gin.New()
|
router := gin.New()
|
||||||
router.Use(zerolog())
|
router.Use(zerolog())
|
||||||
|
|
||||||
|
// Read UI assets
|
||||||
log.Debug().Msg("Setting up assets")
|
log.Debug().Msg("Setting up assets")
|
||||||
dist, distErr := fs.Sub(assets.Assets, "dist")
|
dist, distErr := fs.Sub(assets.Assets, "dist")
|
||||||
|
|
||||||
@@ -53,11 +57,15 @@ func (api *API) Init() {
|
|||||||
log.Fatal().Err(distErr).Msg("Failed to get UI assets")
|
log.Fatal().Err(distErr).Msg("Failed to get UI assets")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Create file server
|
||||||
log.Debug().Msg("Setting up file server")
|
log.Debug().Msg("Setting up file server")
|
||||||
fileServer := http.FileServer(http.FS(dist))
|
fileServer := http.FileServer(http.FS(dist))
|
||||||
|
|
||||||
|
// Setup cookie store
|
||||||
log.Debug().Msg("Setting up cookie store")
|
log.Debug().Msg("Setting up cookie store")
|
||||||
store := cookie.NewStore([]byte(api.Config.Secret))
|
store := cookie.NewStore([]byte(api.Config.Secret))
|
||||||
|
|
||||||
|
// Get domain to use for session cookies
|
||||||
log.Debug().Msg("Getting domain")
|
log.Debug().Msg("Getting domain")
|
||||||
domain, domainErr := utils.GetRootURL(api.Config.AppURL)
|
domain, domainErr := utils.GetRootURL(api.Config.AppURL)
|
||||||
|
|
||||||
@@ -70,6 +78,7 @@ func (api *API) Init() {
|
|||||||
|
|
||||||
api.Domain = fmt.Sprintf(".%s", domain)
|
api.Domain = fmt.Sprintf(".%s", domain)
|
||||||
|
|
||||||
|
// Use session middleware
|
||||||
store.Options(sessions.Options{
|
store.Options(sessions.Options{
|
||||||
Domain: api.Domain,
|
Domain: api.Domain,
|
||||||
Path: "/",
|
Path: "/",
|
||||||
@@ -80,175 +89,169 @@ func (api *API) Init() {
|
|||||||
|
|
||||||
router.Use(sessions.Sessions("tinyauth", store))
|
router.Use(sessions.Sessions("tinyauth", store))
|
||||||
|
|
||||||
|
// UI middleware
|
||||||
router.Use(func(c *gin.Context) {
|
router.Use(func(c *gin.Context) {
|
||||||
|
// If not an API request, serve the UI
|
||||||
if !strings.HasPrefix(c.Request.URL.Path, "/api") {
|
if !strings.HasPrefix(c.Request.URL.Path, "/api") {
|
||||||
_, err := fs.Stat(dist, strings.TrimPrefix(c.Request.URL.Path, "/"))
|
_, err := fs.Stat(dist, strings.TrimPrefix(c.Request.URL.Path, "/"))
|
||||||
|
|
||||||
|
// If the file doesn't exist, serve the index.html
|
||||||
if os.IsNotExist(err) {
|
if os.IsNotExist(err) {
|
||||||
c.Request.URL.Path = "/"
|
c.Request.URL.Path = "/"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Serve the file
|
||||||
fileServer.ServeHTTP(c.Writer, c.Request)
|
fileServer.ServeHTTP(c.Writer, c.Request)
|
||||||
|
|
||||||
|
// Stop further processing
|
||||||
c.Abort()
|
c.Abort()
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Set router
|
||||||
api.Router = router
|
api.Router = router
|
||||||
}
|
}
|
||||||
|
|
||||||
func (api *API) SetupRoutes() {
|
func (api *API) SetupRoutes() {
|
||||||
api.Router.GET("/api/auth/:proxy", func(c *gin.Context) {
|
api.Router.GET("/api/auth/:proxy", func(c *gin.Context) {
|
||||||
|
// Create struct for proxy
|
||||||
var proxy types.Proxy
|
var proxy types.Proxy
|
||||||
|
|
||||||
|
// Bind URI
|
||||||
bindErr := c.BindUri(&proxy)
|
bindErr := c.BindUri(&proxy)
|
||||||
|
|
||||||
|
// Handle error
|
||||||
if api.handleError(c, "Failed to bind URI", bindErr) {
|
if api.handleError(c, "Failed to bind URI", bindErr) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug().Interface("proxy", proxy.Proxy).Msg("Got proxy")
|
log.Debug().Interface("proxy", proxy.Proxy).Msg("Got proxy")
|
||||||
|
|
||||||
|
// Get user context
|
||||||
userContext := api.Hooks.UseUserContext(c)
|
userContext := api.Hooks.UseUserContext(c)
|
||||||
|
|
||||||
|
// Get headers
|
||||||
uri := c.Request.Header.Get("X-Forwarded-Uri")
|
uri := c.Request.Header.Get("X-Forwarded-Uri")
|
||||||
proto := c.Request.Header.Get("X-Forwarded-Proto")
|
proto := c.Request.Header.Get("X-Forwarded-Proto")
|
||||||
host := c.Request.Header.Get("X-Forwarded-Host")
|
host := c.Request.Header.Get("X-Forwarded-Host")
|
||||||
|
|
||||||
|
// Check if user is logged in
|
||||||
if userContext.IsLoggedIn {
|
if userContext.IsLoggedIn {
|
||||||
log.Debug().Msg("Authenticated")
|
log.Debug().Msg("Authenticated")
|
||||||
|
|
||||||
|
// Check if user is allowed to access subdomain, if request is nginx.example.com the subdomain (resource) is nginx
|
||||||
appAllowed, appAllowedErr := api.Auth.ResourceAllowed(userContext, host)
|
appAllowed, appAllowedErr := api.Auth.ResourceAllowed(userContext, host)
|
||||||
|
|
||||||
|
// Check if there was an error
|
||||||
if appAllowedErr != nil {
|
if appAllowedErr != nil {
|
||||||
switch proxy.Proxy {
|
// Return 501 if nginx is the proxy or if the request is using an Authorization header
|
||||||
case "nginx":
|
if proxy.Proxy == "nginx" || c.GetHeader("Authorization") != "" {
|
||||||
log.Error().Err(appAllowedErr).Msg("Failed to check if resource is allowed")
|
log.Error().Err(appAllowedErr).Msg("Failed to check if app is allowed")
|
||||||
c.JSON(501, gin.H{
|
c.JSON(501, gin.H{
|
||||||
"status": 501,
|
"status": 501,
|
||||||
"message": "Internal Server Error",
|
"message": "Internal Server Error",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
default:
|
}
|
||||||
if c.GetHeader("Authorization") != "" {
|
|
||||||
log.Error().Err(appAllowedErr).Msg("Failed to check if resource is allowed")
|
// Return the internal server error page
|
||||||
c.JSON(501, gin.H{
|
if api.handleError(c, "Failed to check if app is allowed", appAllowedErr) {
|
||||||
"status": 501,
|
return
|
||||||
"message": "Internal Server Error",
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if api.handleError(c, "Failed to check if resource is allowed", appAllowedErr) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug().Bool("appAllowed", appAllowed).Msg("Checking if app is allowed")
|
log.Debug().Bool("appAllowed", appAllowed).Msg("Checking if app is allowed")
|
||||||
|
|
||||||
|
// The user is not allowed to access the app
|
||||||
if !appAllowed {
|
if !appAllowed {
|
||||||
log.Warn().Str("username", userContext.Username).Str("host", host).Msg("User not allowed")
|
log.Warn().Str("username", userContext.Username).Str("host", host).Msg("User not allowed")
|
||||||
|
|
||||||
|
// Build query
|
||||||
queries, queryErr := query.Values(types.UnauthorizedQuery{
|
queries, queryErr := query.Values(types.UnauthorizedQuery{
|
||||||
Username: userContext.Username,
|
Username: userContext.Username,
|
||||||
Resource: strings.Split(host, ".")[0],
|
Resource: strings.Split(host, ".")[0],
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Check if there was an error
|
||||||
if queryErr != nil {
|
if queryErr != nil {
|
||||||
switch proxy.Proxy {
|
// Return 501 if nginx is the proxy or if the request is using an Authorization header
|
||||||
case "nginx":
|
if proxy.Proxy == "nginx" || c.GetHeader("Authorization") != "" {
|
||||||
log.Error().Err(queryErr).Msg("Failed to build query")
|
log.Error().Err(queryErr).Msg("Failed to build query")
|
||||||
c.JSON(501, gin.H{
|
c.JSON(501, gin.H{
|
||||||
"status": 501,
|
"status": 501,
|
||||||
"message": "Internal Server Error",
|
"message": "Internal Server Error",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
default:
|
}
|
||||||
if c.GetHeader("Authorization") != "" {
|
|
||||||
log.Error().Err(appAllowedErr).Msg("Failed to build query")
|
// Return the internal server error page
|
||||||
c.JSON(501, gin.H{
|
if api.handleError(c, "Failed to build query", queryErr) {
|
||||||
"status": 501,
|
return
|
||||||
"message": "Internal Server Error",
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if api.handleError(c, "Failed to build query", queryErr) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
switch proxy.Proxy {
|
// Return 401 if nginx is the proxy or if the request is using an Authorization header
|
||||||
case "nginx":
|
if proxy.Proxy == "nginx" || c.GetHeader("Authorization") != "" {
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
"status": 401,
|
"status": 401,
|
||||||
"message": "Unauthorized",
|
"message": "Unauthorized",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
default:
|
|
||||||
if c.GetHeader("Authorization") != "" {
|
|
||||||
c.JSON(401, gin.H{
|
|
||||||
"status": 401,
|
|
||||||
"message": "Unauthorized",
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", api.Config.AppURL, queries.Encode()))
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// We are using caddy/traefik so redirect
|
||||||
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", api.Config.AppURL, queries.Encode()))
|
||||||
|
|
||||||
|
// Stop further processing
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// The user is allowed to access the app
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"status": 200,
|
"status": 200,
|
||||||
"message": "Authenticated",
|
"message": "Authenticated",
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Stop further processing
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
switch proxy.Proxy {
|
// The user is not logged in
|
||||||
case "nginx":
|
log.Debug().Msg("Unauthorized")
|
||||||
|
|
||||||
|
// Return 401 if nginx is the proxy or if the request is using an Authorization header
|
||||||
|
if proxy.Proxy == "nginx" || c.GetHeader("Authorization") != "" {
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
"status": 401,
|
"status": 401,
|
||||||
"message": "Unauthorized",
|
"message": "Unauthorized",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
default:
|
|
||||||
if c.GetHeader("Authorization") != "" {
|
|
||||||
c.JSON(401, gin.H{
|
|
||||||
"status": 401,
|
|
||||||
"message": "Unauthorized",
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
queries, queryErr := query.Values(types.LoginQuery{
|
|
||||||
RedirectURI: fmt.Sprintf("%s://%s%s", proto, host, uri),
|
|
||||||
})
|
|
||||||
|
|
||||||
log.Debug().Interface("redirect_uri", fmt.Sprintf("%s://%s%s", proto, host, uri)).Msg("Redirecting to login")
|
|
||||||
|
|
||||||
if queryErr != nil {
|
|
||||||
switch proxy.Proxy {
|
|
||||||
case "nginx":
|
|
||||||
log.Error().Err(queryErr).Msg("Failed to build query")
|
|
||||||
c.JSON(501, gin.H{
|
|
||||||
"status": 501,
|
|
||||||
"message": "Internal Server Error",
|
|
||||||
})
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
if api.handleError(c, "Failed to build query", queryErr) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/?%s", api.Config.AppURL, queries.Encode()))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Build query
|
||||||
|
queries, queryErr := query.Values(types.LoginQuery{
|
||||||
|
RedirectURI: fmt.Sprintf("%s://%s%s", proto, host, uri),
|
||||||
|
})
|
||||||
|
|
||||||
|
log.Debug().Interface("redirect_uri", fmt.Sprintf("%s://%s%s", proto, host, uri)).Msg("Redirecting to login")
|
||||||
|
|
||||||
|
// Handle error (no need to check for nginx/headers since we are sure we are using caddy/traefik)
|
||||||
|
if api.handleError(c, "Failed to build query", queryErr) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Redirect to login
|
||||||
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/?%s", api.Config.AppURL, queries.Encode()))
|
||||||
})
|
})
|
||||||
|
|
||||||
api.Router.POST("/api/login", func(c *gin.Context) {
|
api.Router.POST("/api/login", func(c *gin.Context) {
|
||||||
|
// Create login struct
|
||||||
var login types.LoginRequest
|
var login types.LoginRequest
|
||||||
|
|
||||||
|
// Bind JSON
|
||||||
err := c.BindJSON(&login)
|
err := c.BindJSON(&login)
|
||||||
|
|
||||||
|
// Handle error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("Failed to bind JSON")
|
log.Error().Err(err).Msg("Failed to bind JSON")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
@@ -260,8 +263,10 @@ func (api *API) SetupRoutes() {
|
|||||||
|
|
||||||
log.Debug().Msg("Got login request")
|
log.Debug().Msg("Got login request")
|
||||||
|
|
||||||
|
// Get user based on username
|
||||||
user := api.Auth.GetUser(login.Username)
|
user := api.Auth.GetUser(login.Username)
|
||||||
|
|
||||||
|
// User does not exist
|
||||||
if user == nil {
|
if user == nil {
|
||||||
log.Debug().Str("username", login.Username).Msg("User not found")
|
log.Debug().Str("username", login.Username).Msg("User not found")
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
@@ -271,6 +276,9 @@ func (api *API) SetupRoutes() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Debug().Msg("Got user")
|
||||||
|
|
||||||
|
// Check if password is correct
|
||||||
if !api.Auth.CheckPassword(*user, login.Password) {
|
if !api.Auth.CheckPassword(*user, login.Password) {
|
||||||
log.Debug().Str("username", login.Username).Msg("Password incorrect")
|
log.Debug().Str("username", login.Username).Msg("Password incorrect")
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
@@ -282,11 +290,13 @@ func (api *API) SetupRoutes() {
|
|||||||
|
|
||||||
log.Debug().Msg("Password correct, logging in")
|
log.Debug().Msg("Password correct, logging in")
|
||||||
|
|
||||||
|
// Create session cookie with username as provider
|
||||||
api.Auth.CreateSessionCookie(c, &types.SessionCookie{
|
api.Auth.CreateSessionCookie(c, &types.SessionCookie{
|
||||||
Username: login.Username,
|
Username: login.Username,
|
||||||
Provider: "username",
|
Provider: "username",
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Return logged in
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"status": 200,
|
"status": 200,
|
||||||
"message": "Logged in",
|
"message": "Logged in",
|
||||||
@@ -294,12 +304,17 @@ func (api *API) SetupRoutes() {
|
|||||||
})
|
})
|
||||||
|
|
||||||
api.Router.POST("/api/logout", func(c *gin.Context) {
|
api.Router.POST("/api/logout", func(c *gin.Context) {
|
||||||
|
log.Debug().Msg("Logging out")
|
||||||
|
|
||||||
|
// Delete session cookie
|
||||||
api.Auth.DeleteSessionCookie(c)
|
api.Auth.DeleteSessionCookie(c)
|
||||||
|
|
||||||
log.Debug().Msg("Cleaning up redirect cookie")
|
log.Debug().Msg("Cleaning up redirect cookie")
|
||||||
|
|
||||||
|
// Clean up redirect cookie if it exists
|
||||||
c.SetCookie("tinyauth_redirect_uri", "", -1, "/", api.Domain, api.Config.CookieSecure, true)
|
c.SetCookie("tinyauth_redirect_uri", "", -1, "/", api.Domain, api.Config.CookieSecure, true)
|
||||||
|
|
||||||
|
// Return logged out
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"status": 200,
|
"status": 200,
|
||||||
"message": "Logged out",
|
"message": "Logged out",
|
||||||
@@ -308,19 +323,24 @@ func (api *API) SetupRoutes() {
|
|||||||
|
|
||||||
api.Router.GET("/api/status", func(c *gin.Context) {
|
api.Router.GET("/api/status", func(c *gin.Context) {
|
||||||
log.Debug().Msg("Checking status")
|
log.Debug().Msg("Checking status")
|
||||||
|
|
||||||
|
// Get user context
|
||||||
userContext := api.Hooks.UseUserContext(c)
|
userContext := api.Hooks.UseUserContext(c)
|
||||||
|
|
||||||
|
// Get configured providers
|
||||||
configuredProviders := api.Providers.GetConfiguredProviders()
|
configuredProviders := api.Providers.GetConfiguredProviders()
|
||||||
|
|
||||||
|
// We have username/password configured so add it to our providers
|
||||||
if api.Auth.UserAuthConfigured() {
|
if api.Auth.UserAuthConfigured() {
|
||||||
configuredProviders = append(configuredProviders, "username")
|
configuredProviders = append(configuredProviders, "username")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// We are not logged in so return unauthorized
|
||||||
if !userContext.IsLoggedIn {
|
if !userContext.IsLoggedIn {
|
||||||
log.Debug().Msg("Unauthenticated")
|
log.Debug().Msg("Unauthorized")
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"status": 200,
|
"status": 200,
|
||||||
"message": "Unauthenticated",
|
"message": "Unauthorized",
|
||||||
"username": "",
|
"username": "",
|
||||||
"isLoggedIn": false,
|
"isLoggedIn": false,
|
||||||
"oauth": false,
|
"oauth": false,
|
||||||
@@ -333,6 +353,7 @@ func (api *API) SetupRoutes() {
|
|||||||
|
|
||||||
log.Debug().Interface("userContext", userContext).Strs("configuredProviders", configuredProviders).Bool("disableContinue", api.Config.DisableContinue).Msg("Authenticated")
|
log.Debug().Interface("userContext", userContext).Strs("configuredProviders", configuredProviders).Bool("disableContinue", api.Config.DisableContinue).Msg("Authenticated")
|
||||||
|
|
||||||
|
// We are logged in so return our user context
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"status": 200,
|
"status": 200,
|
||||||
"message": "Authenticated",
|
"message": "Authenticated",
|
||||||
@@ -345,18 +366,14 @@ func (api *API) SetupRoutes() {
|
|||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
api.Router.GET("/api/healthcheck", func(c *gin.Context) {
|
|
||||||
c.JSON(200, gin.H{
|
|
||||||
"status": 200,
|
|
||||||
"message": "OK",
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
api.Router.GET("/api/oauth/url/:provider", func(c *gin.Context) {
|
api.Router.GET("/api/oauth/url/:provider", func(c *gin.Context) {
|
||||||
|
// Create struct for OAuth request
|
||||||
var request types.OAuthRequest
|
var request types.OAuthRequest
|
||||||
|
|
||||||
|
// Bind URI
|
||||||
bindErr := c.BindUri(&request)
|
bindErr := c.BindUri(&request)
|
||||||
|
|
||||||
|
// Handle error
|
||||||
if bindErr != nil {
|
if bindErr != nil {
|
||||||
log.Error().Err(bindErr).Msg("Failed to bind URI")
|
log.Error().Err(bindErr).Msg("Failed to bind URI")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
@@ -368,8 +385,10 @@ func (api *API) SetupRoutes() {
|
|||||||
|
|
||||||
log.Debug().Msg("Got OAuth request")
|
log.Debug().Msg("Got OAuth request")
|
||||||
|
|
||||||
|
// Check if provider exists
|
||||||
provider := api.Providers.GetProvider(request.Provider)
|
provider := api.Providers.GetProvider(request.Provider)
|
||||||
|
|
||||||
|
// Provider does not exist
|
||||||
if provider == nil {
|
if provider == nil {
|
||||||
c.JSON(404, gin.H{
|
c.JSON(404, gin.H{
|
||||||
"status": 404,
|
"status": 404,
|
||||||
@@ -380,24 +399,38 @@ func (api *API) SetupRoutes() {
|
|||||||
|
|
||||||
log.Debug().Str("provider", request.Provider).Msg("Got provider")
|
log.Debug().Str("provider", request.Provider).Msg("Got provider")
|
||||||
|
|
||||||
|
// Get auth URL
|
||||||
authURL := provider.GetAuthURL()
|
authURL := provider.GetAuthURL()
|
||||||
|
|
||||||
log.Debug().Msg("Got auth URL")
|
log.Debug().Msg("Got auth URL")
|
||||||
|
|
||||||
|
// Get redirect URI
|
||||||
redirectURI := c.Query("redirect_uri")
|
redirectURI := c.Query("redirect_uri")
|
||||||
|
|
||||||
|
// Set redirect cookie if redirect URI is provided
|
||||||
if redirectURI != "" {
|
if redirectURI != "" {
|
||||||
log.Debug().Str("redirectURI", redirectURI).Msg("Setting redirect cookie")
|
log.Debug().Str("redirectURI", redirectURI).Msg("Setting redirect cookie")
|
||||||
c.SetCookie("tinyauth_redirect_uri", redirectURI, 3600, "/", api.Domain, api.Config.CookieSecure, true)
|
c.SetCookie("tinyauth_redirect_uri", redirectURI, 3600, "/", api.Domain, api.Config.CookieSecure, true)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Tailscale does not have an auth url so we create a random code (does not need to be secure) to avoid caching and send it
|
||||||
if request.Provider == "tailscale" {
|
if request.Provider == "tailscale" {
|
||||||
|
// Build tailscale query
|
||||||
tailscaleQuery, tailscaleQueryErr := query.Values(types.TailscaleQuery{
|
tailscaleQuery, tailscaleQueryErr := query.Values(types.TailscaleQuery{
|
||||||
Code: (1000 + rand.IntN(9000)), // doesn't need to be secure, just there to avoid caching
|
Code: (1000 + rand.IntN(9000)),
|
||||||
})
|
})
|
||||||
if api.handleError(c, "Failed to build query", tailscaleQueryErr) {
|
|
||||||
|
// Handle error
|
||||||
|
if tailscaleQueryErr != nil {
|
||||||
|
log.Error().Err(tailscaleQueryErr).Msg("Failed to build query")
|
||||||
|
c.JSON(500, gin.H{
|
||||||
|
"status": 500,
|
||||||
|
"message": "Internal Server Error",
|
||||||
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Return tailscale URL (immidiately redirects to the callback)
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"status": 200,
|
"status": 200,
|
||||||
"message": "Ok",
|
"message": "Ok",
|
||||||
@@ -406,6 +439,7 @@ func (api *API) SetupRoutes() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Return auth URL
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"status": 200,
|
"status": 200,
|
||||||
"message": "Ok",
|
"message": "Ok",
|
||||||
@@ -414,18 +448,23 @@ func (api *API) SetupRoutes() {
|
|||||||
})
|
})
|
||||||
|
|
||||||
api.Router.GET("/api/oauth/callback/:provider", func(c *gin.Context) {
|
api.Router.GET("/api/oauth/callback/:provider", func(c *gin.Context) {
|
||||||
|
// Create struct for OAuth request
|
||||||
var providerName types.OAuthRequest
|
var providerName types.OAuthRequest
|
||||||
|
|
||||||
|
// Bind URI
|
||||||
bindErr := c.BindUri(&providerName)
|
bindErr := c.BindUri(&providerName)
|
||||||
|
|
||||||
|
// Handle error
|
||||||
if api.handleError(c, "Failed to bind URI", bindErr) {
|
if api.handleError(c, "Failed to bind URI", bindErr) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug().Interface("provider", providerName.Provider).Msg("Got provider name")
|
log.Debug().Interface("provider", providerName.Provider).Msg("Got provider name")
|
||||||
|
|
||||||
|
// Get code
|
||||||
code := c.Query("code")
|
code := c.Query("code")
|
||||||
|
|
||||||
|
// Code empty so redirect to error
|
||||||
if code == "" {
|
if code == "" {
|
||||||
log.Error().Msg("No code provided")
|
log.Error().Msg("No code provided")
|
||||||
c.Redirect(http.StatusPermanentRedirect, "/error")
|
c.Redirect(http.StatusPermanentRedirect, "/error")
|
||||||
@@ -434,51 +473,67 @@ func (api *API) SetupRoutes() {
|
|||||||
|
|
||||||
log.Debug().Msg("Got code")
|
log.Debug().Msg("Got code")
|
||||||
|
|
||||||
|
// Get provider
|
||||||
provider := api.Providers.GetProvider(providerName.Provider)
|
provider := api.Providers.GetProvider(providerName.Provider)
|
||||||
|
|
||||||
log.Debug().Str("provider", providerName.Provider).Msg("Got provider")
|
log.Debug().Str("provider", providerName.Provider).Msg("Got provider")
|
||||||
|
|
||||||
|
// Provider does not exist
|
||||||
if provider == nil {
|
if provider == nil {
|
||||||
c.Redirect(http.StatusPermanentRedirect, "/not-found")
|
c.Redirect(http.StatusPermanentRedirect, "/not-found")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Exchange token (authenticates user)
|
||||||
_, tokenErr := provider.ExchangeToken(code)
|
_, tokenErr := provider.ExchangeToken(code)
|
||||||
|
|
||||||
log.Debug().Msg("Got token")
|
log.Debug().Msg("Got token")
|
||||||
|
|
||||||
|
// Handle error
|
||||||
if api.handleError(c, "Failed to exchange token", tokenErr) {
|
if api.handleError(c, "Failed to exchange token", tokenErr) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get email
|
||||||
email, emailErr := api.Providers.GetUser(providerName.Provider)
|
email, emailErr := api.Providers.GetUser(providerName.Provider)
|
||||||
|
|
||||||
log.Debug().Str("email", email).Msg("Got email")
|
log.Debug().Str("email", email).Msg("Got email")
|
||||||
|
|
||||||
|
// Handle error
|
||||||
if api.handleError(c, "Failed to get user", emailErr) {
|
if api.handleError(c, "Failed to get user", emailErr) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Email is not whitelisted
|
||||||
if !api.Auth.EmailWhitelisted(email) {
|
if !api.Auth.EmailWhitelisted(email) {
|
||||||
log.Warn().Str("email", email).Msg("Email not whitelisted")
|
log.Warn().Str("email", email).Msg("Email not whitelisted")
|
||||||
|
|
||||||
|
// Build query
|
||||||
unauthorizedQuery, unauthorizedQueryErr := query.Values(types.UnauthorizedQuery{
|
unauthorizedQuery, unauthorizedQueryErr := query.Values(types.UnauthorizedQuery{
|
||||||
Username: email,
|
Username: email,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Handle error
|
||||||
if api.handleError(c, "Failed to build query", unauthorizedQueryErr) {
|
if api.handleError(c, "Failed to build query", unauthorizedQueryErr) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Redirect to unauthorized
|
||||||
c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/unauthorized?%s", api.Config.AppURL, unauthorizedQuery.Encode()))
|
c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/unauthorized?%s", api.Config.AppURL, unauthorizedQuery.Encode()))
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug().Msg("Email whitelisted")
|
log.Debug().Msg("Email whitelisted")
|
||||||
|
|
||||||
|
// Create session cookie
|
||||||
api.Auth.CreateSessionCookie(c, &types.SessionCookie{
|
api.Auth.CreateSessionCookie(c, &types.SessionCookie{
|
||||||
Username: email,
|
Username: email,
|
||||||
Provider: providerName.Provider,
|
Provider: providerName.Provider,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Get redirect URI
|
||||||
redirectURI, redirectURIErr := c.Cookie("tinyauth_redirect_uri")
|
redirectURI, redirectURIErr := c.Cookie("tinyauth_redirect_uri")
|
||||||
|
|
||||||
|
// If it is empty it means that no redirect_uri was provided to the login screen so we just log in
|
||||||
if redirectURIErr != nil {
|
if redirectURIErr != nil {
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"status": 200,
|
"status": 200,
|
||||||
@@ -488,28 +543,44 @@ func (api *API) SetupRoutes() {
|
|||||||
|
|
||||||
log.Debug().Str("redirectURI", redirectURI).Msg("Got redirect URI")
|
log.Debug().Str("redirectURI", redirectURI).Msg("Got redirect URI")
|
||||||
|
|
||||||
|
// Clean up redirect cookie since we already have the value
|
||||||
c.SetCookie("tinyauth_redirect_uri", "", -1, "/", api.Domain, api.Config.CookieSecure, true)
|
c.SetCookie("tinyauth_redirect_uri", "", -1, "/", api.Domain, api.Config.CookieSecure, true)
|
||||||
|
|
||||||
|
// Build query
|
||||||
redirectQuery, redirectQueryErr := query.Values(types.LoginQuery{
|
redirectQuery, redirectQueryErr := query.Values(types.LoginQuery{
|
||||||
RedirectURI: redirectURI,
|
RedirectURI: redirectURI,
|
||||||
})
|
})
|
||||||
|
|
||||||
log.Debug().Msg("Got redirect query")
|
log.Debug().Msg("Got redirect query")
|
||||||
|
|
||||||
|
// Handle error
|
||||||
if api.handleError(c, "Failed to build query", redirectQueryErr) {
|
if api.handleError(c, "Failed to build query", redirectQueryErr) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Redirect to continue with the redirect URI
|
||||||
c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/continue?%s", api.Config.AppURL, redirectQuery.Encode()))
|
c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/continue?%s", api.Config.AppURL, redirectQuery.Encode()))
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Simple healthcheck
|
||||||
|
api.Router.GET("/api/healthcheck", func(c *gin.Context) {
|
||||||
|
c.JSON(200, gin.H{
|
||||||
|
"status": 200,
|
||||||
|
"message": "OK",
|
||||||
|
})
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (api *API) Run() {
|
func (api *API) Run() {
|
||||||
log.Info().Str("address", api.Config.Address).Int("port", api.Config.Port).Msg("Starting server")
|
log.Info().Str("address", api.Config.Address).Int("port", api.Config.Port).Msg("Starting server")
|
||||||
|
|
||||||
|
// Run server
|
||||||
api.Router.Run(fmt.Sprintf("%s:%d", api.Config.Address, api.Config.Port))
|
api.Router.Run(fmt.Sprintf("%s:%d", api.Config.Address, api.Config.Port))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// handleError logs the error and redirects to the error page (only meant for stuff the user may access does not apply for login paths)
|
||||||
func (api *API) handleError(c *gin.Context, msg string, err error) bool {
|
func (api *API) handleError(c *gin.Context, msg string, err error) bool {
|
||||||
|
// If error is not nil log it and redirect to error page also return true so we can stop further processing
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg(msg)
|
log.Error().Err(err).Msg(msg)
|
||||||
c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", api.Config.AppURL))
|
c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", api.Config.AppURL))
|
||||||
@@ -518,19 +589,25 @@ func (api *API) handleError(c *gin.Context, msg string, err error) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// zerolog is a middleware for gin that logs requests using zerolog
|
||||||
func zerolog() gin.HandlerFunc {
|
func zerolog() gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
|
// Get initial time
|
||||||
tStart := time.Now()
|
tStart := time.Now()
|
||||||
|
|
||||||
|
// Process request
|
||||||
c.Next()
|
c.Next()
|
||||||
|
|
||||||
|
// Get status code, address, method and path
|
||||||
code := c.Writer.Status()
|
code := c.Writer.Status()
|
||||||
address := c.Request.RemoteAddr
|
address := c.Request.RemoteAddr
|
||||||
method := c.Request.Method
|
method := c.Request.Method
|
||||||
path := c.Request.URL.Path
|
path := c.Request.URL.Path
|
||||||
|
|
||||||
|
// Get latency
|
||||||
latency := time.Since(tStart).String()
|
latency := time.Since(tStart).String()
|
||||||
|
|
||||||
|
// Log request
|
||||||
switch {
|
switch {
|
||||||
case code >= 200 && code < 300:
|
case code >= 200 && code < 300:
|
||||||
log.Info().Str("method", method).Str("path", path).Str("address", address).Int("status", code).Str("latency", latency).Msg("Request")
|
log.Info().Str("method", method).Str("path", path).Str("address", address).Int("status", code).Str("latency", latency).Msg("Request")
|
||||||
|
|||||||
@@ -4,8 +4,12 @@ import (
|
|||||||
"embed"
|
"embed"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// UI assets
|
||||||
|
//
|
||||||
//go:embed dist
|
//go:embed dist
|
||||||
var Assets embed.FS
|
var Assets embed.FS
|
||||||
|
|
||||||
|
// Version file
|
||||||
|
//
|
||||||
//go:embed version
|
//go:embed version
|
||||||
var Version string
|
var Version string
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ type Auth struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (auth *Auth) GetUser(username string) *types.User {
|
func (auth *Auth) GetUser(username string) *types.User {
|
||||||
|
// Loop through users and return the user if the username matches
|
||||||
for _, user := range auth.Users {
|
for _, user := range auth.Users {
|
||||||
if user.Username == username {
|
if user.Username == username {
|
||||||
return &user
|
return &user
|
||||||
@@ -40,64 +41,93 @@ func (auth *Auth) GetUser(username string) *types.User {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (auth *Auth) CheckPassword(user types.User, password string) bool {
|
func (auth *Auth) CheckPassword(user types.User, password string) bool {
|
||||||
hashedPasswordErr := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password))
|
// Compare the hashed password with the password provided
|
||||||
return hashedPasswordErr == nil
|
return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) == nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *Auth) EmailWhitelisted(emailSrc string) bool {
|
func (auth *Auth) EmailWhitelisted(emailSrc string) bool {
|
||||||
|
// If the whitelist is empty, allow all emails
|
||||||
if len(auth.OAuthWhitelist) == 0 {
|
if len(auth.OAuthWhitelist) == 0 {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Loop through the whitelist and return true if the email matches
|
||||||
for _, email := range auth.OAuthWhitelist {
|
for _, email := range auth.OAuthWhitelist {
|
||||||
if email == emailSrc {
|
if email == emailSrc {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If no emails match, return false
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *Auth) CreateSessionCookie(c *gin.Context, data *types.SessionCookie) {
|
func (auth *Auth) CreateSessionCookie(c *gin.Context, data *types.SessionCookie) {
|
||||||
log.Debug().Msg("Creating session cookie")
|
log.Debug().Msg("Creating session cookie")
|
||||||
|
|
||||||
|
// Get session
|
||||||
sessions := sessions.Default(c)
|
sessions := sessions.Default(c)
|
||||||
|
|
||||||
log.Debug().Msg("Setting session cookie")
|
log.Debug().Msg("Setting session cookie")
|
||||||
|
|
||||||
|
// Set data
|
||||||
sessions.Set("username", data.Username)
|
sessions.Set("username", data.Username)
|
||||||
sessions.Set("provider", data.Provider)
|
sessions.Set("provider", data.Provider)
|
||||||
sessions.Set("expiry", time.Now().Add(time.Duration(auth.SessionExpiry)*time.Second).Unix())
|
sessions.Set("expiry", time.Now().Add(time.Duration(auth.SessionExpiry)*time.Second).Unix())
|
||||||
|
|
||||||
|
// Save session
|
||||||
sessions.Save()
|
sessions.Save()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *Auth) DeleteSessionCookie(c *gin.Context) {
|
func (auth *Auth) DeleteSessionCookie(c *gin.Context) {
|
||||||
log.Debug().Msg("Deleting session cookie")
|
log.Debug().Msg("Deleting session cookie")
|
||||||
|
|
||||||
|
// Get session
|
||||||
sessions := sessions.Default(c)
|
sessions := sessions.Default(c)
|
||||||
|
|
||||||
|
// Clear session
|
||||||
sessions.Clear()
|
sessions.Clear()
|
||||||
|
|
||||||
|
// Save session
|
||||||
sessions.Save()
|
sessions.Save()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *Auth) GetSessionCookie(c *gin.Context) types.SessionCookie {
|
func (auth *Auth) GetSessionCookie(c *gin.Context) types.SessionCookie {
|
||||||
log.Debug().Msg("Getting session cookie")
|
log.Debug().Msg("Getting session cookie")
|
||||||
|
|
||||||
|
// Get session
|
||||||
sessions := sessions.Default(c)
|
sessions := sessions.Default(c)
|
||||||
|
|
||||||
|
// Get data
|
||||||
cookieUsername := sessions.Get("username")
|
cookieUsername := sessions.Get("username")
|
||||||
cookieProvider := sessions.Get("provider")
|
cookieProvider := sessions.Get("provider")
|
||||||
cookieExpiry := sessions.Get("expiry")
|
cookieExpiry := sessions.Get("expiry")
|
||||||
|
|
||||||
|
// Convert interfaces to correct types
|
||||||
username, usernameOk := cookieUsername.(string)
|
username, usernameOk := cookieUsername.(string)
|
||||||
provider, providerOk := cookieProvider.(string)
|
provider, providerOk := cookieProvider.(string)
|
||||||
expiry, expiryOk := cookieExpiry.(int64)
|
expiry, expiryOk := cookieExpiry.(int64)
|
||||||
|
|
||||||
|
// Check if the cookie is invalid
|
||||||
if !usernameOk || !providerOk || !expiryOk {
|
if !usernameOk || !providerOk || !expiryOk {
|
||||||
log.Warn().Msg("Session cookie invalid")
|
log.Warn().Msg("Session cookie invalid")
|
||||||
return types.SessionCookie{}
|
return types.SessionCookie{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if the cookie has expired
|
||||||
if time.Now().Unix() > expiry {
|
if time.Now().Unix() > expiry {
|
||||||
log.Warn().Msg("Session cookie expired")
|
log.Warn().Msg("Session cookie expired")
|
||||||
|
|
||||||
|
// If it has, delete it
|
||||||
auth.DeleteSessionCookie(c)
|
auth.DeleteSessionCookie(c)
|
||||||
|
|
||||||
|
// Return empty cookie
|
||||||
return types.SessionCookie{}
|
return types.SessionCookie{}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug().Str("username", username).Str("provider", provider).Int64("expiry", expiry).Msg("Parsed cookie")
|
log.Debug().Str("username", username).Str("provider", provider).Int64("expiry", expiry).Msg("Parsed cookie")
|
||||||
|
|
||||||
|
// Return the cookie
|
||||||
return types.SessionCookie{
|
return types.SessionCookie{
|
||||||
Username: username,
|
Username: username,
|
||||||
Provider: provider,
|
Provider: provider,
|
||||||
@@ -105,42 +135,56 @@ func (auth *Auth) GetSessionCookie(c *gin.Context) types.SessionCookie {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (auth *Auth) UserAuthConfigured() bool {
|
func (auth *Auth) UserAuthConfigured() bool {
|
||||||
|
// If there are users, return true
|
||||||
return len(auth.Users) > 0
|
return len(auth.Users) > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *Auth) ResourceAllowed(context types.UserContext, host string) (bool, error) {
|
func (auth *Auth) ResourceAllowed(context types.UserContext, host string) (bool, error) {
|
||||||
|
// Check if we have access to the Docker API
|
||||||
isConnected := auth.Docker.DockerConnected()
|
isConnected := auth.Docker.DockerConnected()
|
||||||
|
|
||||||
|
// If we don't have access, it is assumed that the user has access
|
||||||
if !isConnected {
|
if !isConnected {
|
||||||
log.Debug().Msg("Docker not connected, allowing access")
|
log.Debug().Msg("Docker not connected, allowing access")
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get the app ID from the host
|
||||||
appId := strings.Split(host, ".")[0]
|
appId := strings.Split(host, ".")[0]
|
||||||
|
|
||||||
|
// Get the containers
|
||||||
containers, containersErr := auth.Docker.GetContainers()
|
containers, containersErr := auth.Docker.GetContainers()
|
||||||
|
|
||||||
|
// If there is an error, return false
|
||||||
if containersErr != nil {
|
if containersErr != nil {
|
||||||
return false, containersErr
|
return false, containersErr
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug().Msg("Got containers")
|
log.Debug().Msg("Got containers")
|
||||||
|
|
||||||
|
// Loop through the containers
|
||||||
for _, container := range containers {
|
for _, container := range containers {
|
||||||
|
// Inspect the container
|
||||||
inspect, inspectErr := auth.Docker.InspectContainer(container.ID)
|
inspect, inspectErr := auth.Docker.InspectContainer(container.ID)
|
||||||
|
|
||||||
|
// If there is an error, return false
|
||||||
if inspectErr != nil {
|
if inspectErr != nil {
|
||||||
return false, inspectErr
|
return false, inspectErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get the container name (for some reason it is /name)
|
||||||
containerName := strings.Split(inspect.Name, "/")[1]
|
containerName := strings.Split(inspect.Name, "/")[1]
|
||||||
|
|
||||||
|
// There is a container with the same name as the app ID
|
||||||
if containerName == appId {
|
if containerName == appId {
|
||||||
log.Debug().Str("container", containerName).Msg("Found container")
|
log.Debug().Str("container", containerName).Msg("Found container")
|
||||||
|
|
||||||
|
// Get only the tinyauth labels in a struct
|
||||||
labels := utils.GetTinyauthLabels(inspect.Config.Labels)
|
labels := utils.GetTinyauthLabels(inspect.Config.Labels)
|
||||||
|
|
||||||
log.Debug().Msg("Got labels")
|
log.Debug().Msg("Got labels")
|
||||||
|
|
||||||
|
// If the container has an oauth whitelist, check if the user is in it
|
||||||
if context.OAuth && len(labels.OAuthWhitelist) != 0 {
|
if context.OAuth && len(labels.OAuthWhitelist) != 0 {
|
||||||
log.Debug().Msg("Checking OAuth whitelist")
|
log.Debug().Msg("Checking OAuth whitelist")
|
||||||
if slices.Contains(labels.OAuthWhitelist, context.Username) {
|
if slices.Contains(labels.OAuthWhitelist, context.Username) {
|
||||||
@@ -149,6 +193,7 @@ func (auth *Auth) ResourceAllowed(context types.UserContext, host string) (bool,
|
|||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If the container has users, check if the user is in it
|
||||||
if len(labels.Users) != 0 {
|
if len(labels.Users) != 0 {
|
||||||
log.Debug().Msg("Checking users")
|
log.Debug().Msg("Checking users")
|
||||||
if slices.Contains(labels.Users, context.Username) {
|
if slices.Contains(labels.Users, context.Username) {
|
||||||
@@ -162,32 +207,40 @@ func (auth *Auth) ResourceAllowed(context types.UserContext, host string) (bool,
|
|||||||
|
|
||||||
log.Debug().Msg("No matching container found, allowing access")
|
log.Debug().Msg("No matching container found, allowing access")
|
||||||
|
|
||||||
|
// If no matching container is found, allow access
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *Auth) GetBasicAuth(c *gin.Context) types.User {
|
func (auth *Auth) GetBasicAuth(c *gin.Context) types.User {
|
||||||
|
// Get the Authorization header
|
||||||
header := c.GetHeader("Authorization")
|
header := c.GetHeader("Authorization")
|
||||||
|
|
||||||
|
// If the header is empty, return an empty user
|
||||||
if header == "" {
|
if header == "" {
|
||||||
return types.User{}
|
return types.User{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Split the header
|
||||||
headerSplit := strings.Split(header, " ")
|
headerSplit := strings.Split(header, " ")
|
||||||
|
|
||||||
if len(headerSplit) != 2 {
|
if len(headerSplit) != 2 {
|
||||||
return types.User{}
|
return types.User{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if the header is Basic
|
||||||
if headerSplit[0] != "Basic" {
|
if headerSplit[0] != "Basic" {
|
||||||
return types.User{}
|
return types.User{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Split the credentials
|
||||||
credentials := strings.Split(headerSplit[1], ":")
|
credentials := strings.Split(headerSplit[1], ":")
|
||||||
|
|
||||||
|
// If the credentials are not in the correct format, return an empty user
|
||||||
if len(credentials) != 2 {
|
if len(credentials) != 2 {
|
||||||
return types.User{}
|
return types.User{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Return the user
|
||||||
return types.User{
|
return types.User{
|
||||||
Username: credentials[0],
|
Username: credentials[0],
|
||||||
Password: credentials[1],
|
Password: credentials[1],
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
package constants
|
package constants
|
||||||
|
|
||||||
|
// TinyauthLabels is a list of labels that can be used in a tinyauth protected container
|
||||||
var TinyauthLabels = []string{
|
var TinyauthLabels = []string{
|
||||||
"tinyauth.oauth.whitelist",
|
"tinyauth.oauth.whitelist",
|
||||||
"tinyauth.users",
|
"tinyauth.users",
|
||||||
|
|||||||
@@ -18,39 +18,50 @@ type Docker struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (docker *Docker) Init() error {
|
func (docker *Docker) Init() error {
|
||||||
|
// Create a new docker client
|
||||||
apiClient, err := client.NewClientWithOpts(client.FromEnv)
|
apiClient, err := client.NewClientWithOpts(client.FromEnv)
|
||||||
|
|
||||||
|
// Check if there was an error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set the context and api client
|
||||||
docker.Context = context.Background()
|
docker.Context = context.Background()
|
||||||
docker.Client = apiClient
|
docker.Client = apiClient
|
||||||
|
|
||||||
|
// Done
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (docker *Docker) GetContainers() ([]types.Container, error) {
|
func (docker *Docker) GetContainers() ([]types.Container, error) {
|
||||||
|
// Get the list of containers
|
||||||
containers, err := docker.Client.ContainerList(docker.Context, container.ListOptions{})
|
containers, err := docker.Client.ContainerList(docker.Context, container.ListOptions{})
|
||||||
|
|
||||||
|
// Check if there was an error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Return the containers
|
||||||
return containers, nil
|
return containers, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (docker *Docker) InspectContainer(containerId string) (types.ContainerJSON, error) {
|
func (docker *Docker) InspectContainer(containerId string) (types.ContainerJSON, error) {
|
||||||
|
// Inspect the container
|
||||||
inspect, err := docker.Client.ContainerInspect(docker.Context, containerId)
|
inspect, err := docker.Client.ContainerInspect(docker.Context, containerId)
|
||||||
|
|
||||||
|
// Check if there was an error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return types.ContainerJSON{}, err
|
return types.ContainerJSON{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Return the inspect
|
||||||
return inspect, nil
|
return inspect, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (docker *Docker) DockerConnected() bool {
|
func (docker *Docker) DockerConnected() bool {
|
||||||
|
// Ping the docker client if there is an error it is not connected
|
||||||
_, err := docker.Client.Ping(docker.Context)
|
_, err := docker.Client.Ping(docker.Context)
|
||||||
return err == nil
|
return err == nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,13 +22,19 @@ type Hooks struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext {
|
func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext {
|
||||||
|
// Get session cookie and basic auth
|
||||||
cookie := hooks.Auth.GetSessionCookie(c)
|
cookie := hooks.Auth.GetSessionCookie(c)
|
||||||
basic := hooks.Auth.GetBasicAuth(c)
|
basic := hooks.Auth.GetBasicAuth(c)
|
||||||
|
|
||||||
|
// Check if basic auth is set
|
||||||
if basic.Username != "" {
|
if basic.Username != "" {
|
||||||
log.Debug().Msg("Got basic auth")
|
log.Debug().Msg("Got basic auth")
|
||||||
|
|
||||||
|
// Check if user exists and password is correct
|
||||||
user := hooks.Auth.GetUser(basic.Username)
|
user := hooks.Auth.GetUser(basic.Username)
|
||||||
|
|
||||||
if user != nil && hooks.Auth.CheckPassword(*user, basic.Password) {
|
if user != nil && hooks.Auth.CheckPassword(*user, basic.Password) {
|
||||||
|
// Return user context since we are logged in with basic auth
|
||||||
return types.UserContext{
|
return types.UserContext{
|
||||||
Username: basic.Username,
|
Username: basic.Username,
|
||||||
IsLoggedIn: true,
|
IsLoggedIn: true,
|
||||||
@@ -39,10 +45,15 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if session cookie is username/password auth
|
||||||
if cookie.Provider == "username" {
|
if cookie.Provider == "username" {
|
||||||
log.Debug().Msg("Provider is username")
|
log.Debug().Msg("Provider is username")
|
||||||
|
|
||||||
|
// Check if user exists
|
||||||
if hooks.Auth.GetUser(cookie.Username) != nil {
|
if hooks.Auth.GetUser(cookie.Username) != nil {
|
||||||
log.Debug().Msg("User exists")
|
log.Debug().Msg("User exists")
|
||||||
|
|
||||||
|
// It exists so we are logged in
|
||||||
return types.UserContext{
|
return types.UserContext{
|
||||||
Username: cookie.Username,
|
Username: cookie.Username,
|
||||||
IsLoggedIn: true,
|
IsLoggedIn: true,
|
||||||
@@ -53,13 +64,22 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext {
|
|||||||
}
|
}
|
||||||
|
|
||||||
log.Debug().Msg("Provider is not username")
|
log.Debug().Msg("Provider is not username")
|
||||||
|
|
||||||
|
// The provider is not username so we need to check if it is an oauth provider
|
||||||
provider := hooks.Providers.GetProvider(cookie.Provider)
|
provider := hooks.Providers.GetProvider(cookie.Provider)
|
||||||
|
|
||||||
|
// If we have a provider with this name
|
||||||
if provider != nil {
|
if provider != nil {
|
||||||
log.Debug().Msg("Provider exists")
|
log.Debug().Msg("Provider exists")
|
||||||
|
|
||||||
|
// Check if the oauth email is whitelisted
|
||||||
if !hooks.Auth.EmailWhitelisted(cookie.Username) {
|
if !hooks.Auth.EmailWhitelisted(cookie.Username) {
|
||||||
log.Error().Str("email", cookie.Username).Msg("Email is not whitelisted")
|
log.Error().Str("email", cookie.Username).Msg("Email is not whitelisted")
|
||||||
|
|
||||||
|
// It isn't so we delete the cookie and return an empty context
|
||||||
hooks.Auth.DeleteSessionCookie(c)
|
hooks.Auth.DeleteSessionCookie(c)
|
||||||
|
|
||||||
|
// Return empty context
|
||||||
return types.UserContext{
|
return types.UserContext{
|
||||||
Username: "",
|
Username: "",
|
||||||
IsLoggedIn: false,
|
IsLoggedIn: false,
|
||||||
@@ -67,7 +87,10 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext {
|
|||||||
Provider: "",
|
Provider: "",
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug().Msg("Email is whitelisted")
|
log.Debug().Msg("Email is whitelisted")
|
||||||
|
|
||||||
|
// Return user context since we are logged in with oauth
|
||||||
return types.UserContext{
|
return types.UserContext{
|
||||||
Username: cookie.Username,
|
Username: cookie.Username,
|
||||||
IsLoggedIn: true,
|
IsLoggedIn: true,
|
||||||
@@ -76,6 +99,7 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Neither basic auth or oauth is set so we return an empty context
|
||||||
return types.UserContext{
|
return types.UserContext{
|
||||||
Username: "",
|
Username: "",
|
||||||
IsLoggedIn: false,
|
IsLoggedIn: false,
|
||||||
|
|||||||
@@ -21,23 +21,33 @@ type OAuth struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (oauth *OAuth) Init() {
|
func (oauth *OAuth) Init() {
|
||||||
|
// Create a new context and verifier
|
||||||
oauth.Context = context.Background()
|
oauth.Context = context.Background()
|
||||||
oauth.Verifier = oauth2.GenerateVerifier()
|
oauth.Verifier = oauth2.GenerateVerifier()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (oauth *OAuth) GetAuthURL() string {
|
func (oauth *OAuth) GetAuthURL() string {
|
||||||
|
// Return the auth url
|
||||||
return oauth.Config.AuthCodeURL("state", oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(oauth.Verifier))
|
return oauth.Config.AuthCodeURL("state", oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(oauth.Verifier))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (oauth *OAuth) ExchangeToken(code string) (string, error) {
|
func (oauth *OAuth) ExchangeToken(code string) (string, error) {
|
||||||
|
// Exchange the code for a token
|
||||||
token, err := oauth.Config.Exchange(oauth.Context, code, oauth2.VerifierOption(oauth.Verifier))
|
token, err := oauth.Config.Exchange(oauth.Context, code, oauth2.VerifierOption(oauth.Verifier))
|
||||||
|
|
||||||
|
// Check if there was an error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set the token
|
||||||
oauth.Token = token
|
oauth.Token = token
|
||||||
|
|
||||||
|
// Return the access token
|
||||||
return oauth.Token.AccessToken, nil
|
return oauth.Token.AccessToken, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (oauth *OAuth) GetClient() *http.Client {
|
func (oauth *OAuth) GetClient() *http.Client {
|
||||||
|
// Return the http client with the token set
|
||||||
return oauth.Config.Client(oauth.Context, oauth.Token)
|
return oauth.Config.Client(oauth.Context, oauth.Token)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,36 +8,45 @@ import (
|
|||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// We are assuming that the generic provider will return a JSON object with an email field
|
||||||
type GenericUserInfoResponse struct {
|
type GenericUserInfoResponse struct {
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetGenericEmail(client *http.Client, url string) (string, error) {
|
func GetGenericEmail(client *http.Client, url string) (string, error) {
|
||||||
|
// Using the oauth client get the user info url
|
||||||
res, resErr := client.Get(url)
|
res, resErr := client.Get(url)
|
||||||
|
|
||||||
|
// Check if there was an error
|
||||||
if resErr != nil {
|
if resErr != nil {
|
||||||
return "", resErr
|
return "", resErr
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug().Msg("Got response from generic provider")
|
log.Debug().Msg("Got response from generic provider")
|
||||||
|
|
||||||
|
// Read the body of the response
|
||||||
body, bodyErr := io.ReadAll(res.Body)
|
body, bodyErr := io.ReadAll(res.Body)
|
||||||
|
|
||||||
|
// Check if there was an error
|
||||||
if bodyErr != nil {
|
if bodyErr != nil {
|
||||||
return "", bodyErr
|
return "", bodyErr
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug().Msg("Read body from generic provider")
|
log.Debug().Msg("Read body from generic provider")
|
||||||
|
|
||||||
|
// Parse the body into a user struct
|
||||||
var user GenericUserInfoResponse
|
var user GenericUserInfoResponse
|
||||||
|
|
||||||
|
// Unmarshal the body into the user struct
|
||||||
jsonErr := json.Unmarshal(body, &user)
|
jsonErr := json.Unmarshal(body, &user)
|
||||||
|
|
||||||
|
// Check if there was an error
|
||||||
if jsonErr != nil {
|
if jsonErr != nil {
|
||||||
return "", jsonErr
|
return "", jsonErr
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug().Msg("Parsed user from generic provider")
|
log.Debug().Msg("Parsed user from generic provider")
|
||||||
|
|
||||||
|
// Return the email
|
||||||
return user.Email, nil
|
return user.Email, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,47 +9,58 @@ import (
|
|||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Github has a different response than the generic provider
|
||||||
type GithubUserInfoResponse []struct {
|
type GithubUserInfoResponse []struct {
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
Primary bool `json:"primary"`
|
Primary bool `json:"primary"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// The scopes required for the github provider
|
||||||
func GithubScopes() []string {
|
func GithubScopes() []string {
|
||||||
return []string{"user:email"}
|
return []string{"user:email"}
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetGithubEmail(client *http.Client) (string, error) {
|
func GetGithubEmail(client *http.Client) (string, error) {
|
||||||
|
// Get the user emails from github using the oauth http client
|
||||||
res, resErr := client.Get("https://api.github.com/user/emails")
|
res, resErr := client.Get("https://api.github.com/user/emails")
|
||||||
|
|
||||||
|
// Check if there was an error
|
||||||
if resErr != nil {
|
if resErr != nil {
|
||||||
return "", resErr
|
return "", resErr
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug().Msg("Got response from github")
|
log.Debug().Msg("Got response from github")
|
||||||
|
|
||||||
|
// Read the body of the response
|
||||||
body, bodyErr := io.ReadAll(res.Body)
|
body, bodyErr := io.ReadAll(res.Body)
|
||||||
|
|
||||||
|
// Check if there was an error
|
||||||
if bodyErr != nil {
|
if bodyErr != nil {
|
||||||
return "", bodyErr
|
return "", bodyErr
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug().Msg("Read body from github")
|
log.Debug().Msg("Read body from github")
|
||||||
|
|
||||||
|
// Parse the body into a user struct
|
||||||
var emails GithubUserInfoResponse
|
var emails GithubUserInfoResponse
|
||||||
|
|
||||||
|
// Unmarshal the body into the user struct
|
||||||
jsonErr := json.Unmarshal(body, &emails)
|
jsonErr := json.Unmarshal(body, &emails)
|
||||||
|
|
||||||
|
// Check if there was an error
|
||||||
if jsonErr != nil {
|
if jsonErr != nil {
|
||||||
return "", jsonErr
|
return "", jsonErr
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug().Msg("Parsed emails from github")
|
log.Debug().Msg("Parsed emails from github")
|
||||||
|
|
||||||
|
// Find and return the primary email
|
||||||
for _, email := range emails {
|
for _, email := range emails {
|
||||||
if email.Primary {
|
if email.Primary {
|
||||||
return email.Email, nil
|
return email.Email, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// User does not have a primary email?
|
||||||
return "", errors.New("no primary email found")
|
return "", errors.New("no primary email found")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,40 +8,50 @@ import (
|
|||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Google works the same as the generic provider
|
||||||
type GoogleUserInfoResponse struct {
|
type GoogleUserInfoResponse struct {
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// The scopes required for the google provider
|
||||||
func GoogleScopes() []string {
|
func GoogleScopes() []string {
|
||||||
return []string{"https://www.googleapis.com/auth/userinfo.email"}
|
return []string{"https://www.googleapis.com/auth/userinfo.email"}
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetGoogleEmail(client *http.Client) (string, error) {
|
func GetGoogleEmail(client *http.Client) (string, error) {
|
||||||
|
// Get the user info from google using the oauth http client
|
||||||
res, resErr := client.Get("https://www.googleapis.com/userinfo/v2/me")
|
res, resErr := client.Get("https://www.googleapis.com/userinfo/v2/me")
|
||||||
|
|
||||||
|
// Check if there was an error
|
||||||
if resErr != nil {
|
if resErr != nil {
|
||||||
return "", resErr
|
return "", resErr
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug().Msg("Got response from google")
|
log.Debug().Msg("Got response from google")
|
||||||
|
|
||||||
|
// Read the body of the response
|
||||||
body, bodyErr := io.ReadAll(res.Body)
|
body, bodyErr := io.ReadAll(res.Body)
|
||||||
|
|
||||||
|
// Check if there was an error
|
||||||
if bodyErr != nil {
|
if bodyErr != nil {
|
||||||
return "", bodyErr
|
return "", bodyErr
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug().Msg("Read body from google")
|
log.Debug().Msg("Read body from google")
|
||||||
|
|
||||||
|
// Parse the body into a user struct
|
||||||
var user GoogleUserInfoResponse
|
var user GoogleUserInfoResponse
|
||||||
|
|
||||||
|
// Unmarshal the body into the user struct
|
||||||
jsonErr := json.Unmarshal(body, &user)
|
jsonErr := json.Unmarshal(body, &user)
|
||||||
|
|
||||||
|
// Check if there was an error
|
||||||
if jsonErr != nil {
|
if jsonErr != nil {
|
||||||
return "", jsonErr
|
return "", jsonErr
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug().Msg("Parsed user from google")
|
log.Debug().Msg("Parsed user from google")
|
||||||
|
|
||||||
|
// Return the email
|
||||||
return user.Email, nil
|
return user.Email, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -25,8 +25,11 @@ type Providers struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (providers *Providers) Init() {
|
func (providers *Providers) Init() {
|
||||||
|
// If we have a client id and secret for github, initialize the oauth provider
|
||||||
if providers.Config.GithubClientId != "" && providers.Config.GithubClientSecret != "" {
|
if providers.Config.GithubClientId != "" && providers.Config.GithubClientSecret != "" {
|
||||||
log.Info().Msg("Initializing Github OAuth")
|
log.Info().Msg("Initializing Github OAuth")
|
||||||
|
|
||||||
|
// Create a new oauth provider with the github config
|
||||||
providers.Github = oauth.NewOAuth(oauth2.Config{
|
providers.Github = oauth.NewOAuth(oauth2.Config{
|
||||||
ClientID: providers.Config.GithubClientId,
|
ClientID: providers.Config.GithubClientId,
|
||||||
ClientSecret: providers.Config.GithubClientSecret,
|
ClientSecret: providers.Config.GithubClientSecret,
|
||||||
@@ -34,10 +37,16 @@ func (providers *Providers) Init() {
|
|||||||
Scopes: GithubScopes(),
|
Scopes: GithubScopes(),
|
||||||
Endpoint: endpoints.GitHub,
|
Endpoint: endpoints.GitHub,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Initialize the oauth provider
|
||||||
providers.Github.Init()
|
providers.Github.Init()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If we have a client id and secret for google, initialize the oauth provider
|
||||||
if providers.Config.GoogleClientId != "" && providers.Config.GoogleClientSecret != "" {
|
if providers.Config.GoogleClientId != "" && providers.Config.GoogleClientSecret != "" {
|
||||||
log.Info().Msg("Initializing Google OAuth")
|
log.Info().Msg("Initializing Google OAuth")
|
||||||
|
|
||||||
|
// Create a new oauth provider with the google config
|
||||||
providers.Google = oauth.NewOAuth(oauth2.Config{
|
providers.Google = oauth.NewOAuth(oauth2.Config{
|
||||||
ClientID: providers.Config.GoogleClientId,
|
ClientID: providers.Config.GoogleClientId,
|
||||||
ClientSecret: providers.Config.GoogleClientSecret,
|
ClientSecret: providers.Config.GoogleClientSecret,
|
||||||
@@ -45,10 +54,15 @@ func (providers *Providers) Init() {
|
|||||||
Scopes: GoogleScopes(),
|
Scopes: GoogleScopes(),
|
||||||
Endpoint: endpoints.Google,
|
Endpoint: endpoints.Google,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Initialize the oauth provider
|
||||||
providers.Google.Init()
|
providers.Google.Init()
|
||||||
}
|
}
|
||||||
|
|
||||||
if providers.Config.TailscaleClientId != "" && providers.Config.TailscaleClientSecret != "" {
|
if providers.Config.TailscaleClientId != "" && providers.Config.TailscaleClientSecret != "" {
|
||||||
log.Info().Msg("Initializing Tailscale OAuth")
|
log.Info().Msg("Initializing Tailscale OAuth")
|
||||||
|
|
||||||
|
// Create a new oauth provider with the tailscale config
|
||||||
providers.Tailscale = oauth.NewOAuth(oauth2.Config{
|
providers.Tailscale = oauth.NewOAuth(oauth2.Config{
|
||||||
ClientID: providers.Config.TailscaleClientId,
|
ClientID: providers.Config.TailscaleClientId,
|
||||||
ClientSecret: providers.Config.TailscaleClientSecret,
|
ClientSecret: providers.Config.TailscaleClientSecret,
|
||||||
@@ -56,10 +70,16 @@ func (providers *Providers) Init() {
|
|||||||
Scopes: TailscaleScopes(),
|
Scopes: TailscaleScopes(),
|
||||||
Endpoint: TailscaleEndpoint,
|
Endpoint: TailscaleEndpoint,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Initialize the oauth provider
|
||||||
providers.Tailscale.Init()
|
providers.Tailscale.Init()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If we have a client id and secret for generic oauth, initialize the oauth provider
|
||||||
if providers.Config.GenericClientId != "" && providers.Config.GenericClientSecret != "" {
|
if providers.Config.GenericClientId != "" && providers.Config.GenericClientSecret != "" {
|
||||||
log.Info().Msg("Initializing Generic OAuth")
|
log.Info().Msg("Initializing Generic OAuth")
|
||||||
|
|
||||||
|
// Create a new oauth provider with the generic config
|
||||||
providers.Generic = oauth.NewOAuth(oauth2.Config{
|
providers.Generic = oauth.NewOAuth(oauth2.Config{
|
||||||
ClientID: providers.Config.GenericClientId,
|
ClientID: providers.Config.GenericClientId,
|
||||||
ClientSecret: providers.Config.GenericClientSecret,
|
ClientSecret: providers.Config.GenericClientSecret,
|
||||||
@@ -70,11 +90,14 @@ func (providers *Providers) Init() {
|
|||||||
TokenURL: providers.Config.GenericTokenURL,
|
TokenURL: providers.Config.GenericTokenURL,
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Initialize the oauth provider
|
||||||
providers.Generic.Init()
|
providers.Generic.Init()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (providers *Providers) GetProvider(provider string) *oauth.OAuth {
|
func (providers *Providers) GetProvider(provider string) *oauth.OAuth {
|
||||||
|
// Return the provider based on the provider string
|
||||||
switch provider {
|
switch provider {
|
||||||
case "github":
|
case "github":
|
||||||
return providers.Github
|
return providers.Github
|
||||||
@@ -90,58 +113,103 @@ func (providers *Providers) GetProvider(provider string) *oauth.OAuth {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (providers *Providers) GetUser(provider string) (string, error) {
|
func (providers *Providers) GetUser(provider string) (string, error) {
|
||||||
|
// Get the email from the provider
|
||||||
switch provider {
|
switch provider {
|
||||||
case "github":
|
case "github":
|
||||||
|
// If the github provider is not configured, return an error
|
||||||
if providers.Github == nil {
|
if providers.Github == nil {
|
||||||
log.Debug().Msg("Github provider not configured")
|
log.Debug().Msg("Github provider not configured")
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get the client from the github provider
|
||||||
client := providers.Github.GetClient()
|
client := providers.Github.GetClient()
|
||||||
|
|
||||||
log.Debug().Msg("Got client from github")
|
log.Debug().Msg("Got client from github")
|
||||||
|
|
||||||
|
// Get the email from the github provider
|
||||||
email, emailErr := GetGithubEmail(client)
|
email, emailErr := GetGithubEmail(client)
|
||||||
|
|
||||||
|
// Check if there was an error
|
||||||
if emailErr != nil {
|
if emailErr != nil {
|
||||||
return "", emailErr
|
return "", emailErr
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug().Msg("Got email from github")
|
log.Debug().Msg("Got email from github")
|
||||||
|
|
||||||
|
// Return the email
|
||||||
return email, nil
|
return email, nil
|
||||||
case "google":
|
case "google":
|
||||||
|
// If the google provider is not configured, return an error
|
||||||
if providers.Google == nil {
|
if providers.Google == nil {
|
||||||
log.Debug().Msg("Google provider not configured")
|
log.Debug().Msg("Google provider not configured")
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get the client from the google provider
|
||||||
client := providers.Google.GetClient()
|
client := providers.Google.GetClient()
|
||||||
|
|
||||||
log.Debug().Msg("Got client from google")
|
log.Debug().Msg("Got client from google")
|
||||||
|
|
||||||
|
// Get the email from the google provider
|
||||||
email, emailErr := GetGoogleEmail(client)
|
email, emailErr := GetGoogleEmail(client)
|
||||||
|
|
||||||
|
// Check if there was an error
|
||||||
if emailErr != nil {
|
if emailErr != nil {
|
||||||
return "", emailErr
|
return "", emailErr
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug().Msg("Got email from google")
|
log.Debug().Msg("Got email from google")
|
||||||
|
|
||||||
|
// Return the email
|
||||||
return email, nil
|
return email, nil
|
||||||
case "tailscale":
|
case "tailscale":
|
||||||
|
// If the tailscale provider is not configured, return an error
|
||||||
if providers.Tailscale == nil {
|
if providers.Tailscale == nil {
|
||||||
log.Debug().Msg("Tailscale provider not configured")
|
log.Debug().Msg("Tailscale provider not configured")
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get the client from the tailscale provider
|
||||||
client := providers.Tailscale.GetClient()
|
client := providers.Tailscale.GetClient()
|
||||||
|
|
||||||
log.Debug().Msg("Got client from tailscale")
|
log.Debug().Msg("Got client from tailscale")
|
||||||
|
|
||||||
|
// Get the email from the tailscale provider
|
||||||
email, emailErr := GetTailscaleEmail(client)
|
email, emailErr := GetTailscaleEmail(client)
|
||||||
|
|
||||||
|
// Check if there was an error
|
||||||
if emailErr != nil {
|
if emailErr != nil {
|
||||||
return "", emailErr
|
return "", emailErr
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug().Msg("Got email from tailscale")
|
log.Debug().Msg("Got email from tailscale")
|
||||||
|
|
||||||
|
// Return the email
|
||||||
return email, nil
|
return email, nil
|
||||||
case "generic":
|
case "generic":
|
||||||
|
// If the generic provider is not configured, return an error
|
||||||
if providers.Generic == nil {
|
if providers.Generic == nil {
|
||||||
log.Debug().Msg("Generic provider not configured")
|
log.Debug().Msg("Generic provider not configured")
|
||||||
return "", nil
|
return "", nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get the client from the generic provider
|
||||||
client := providers.Generic.GetClient()
|
client := providers.Generic.GetClient()
|
||||||
|
|
||||||
log.Debug().Msg("Got client from generic")
|
log.Debug().Msg("Got client from generic")
|
||||||
|
|
||||||
|
// Get the email from the generic provider
|
||||||
email, emailErr := GetGenericEmail(client, providers.Config.GenericUserURL)
|
email, emailErr := GetGenericEmail(client, providers.Config.GenericUserURL)
|
||||||
|
|
||||||
|
// Check if there was an error
|
||||||
if emailErr != nil {
|
if emailErr != nil {
|
||||||
return "", emailErr
|
return "", emailErr
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug().Msg("Got email from generic")
|
log.Debug().Msg("Got email from generic")
|
||||||
|
|
||||||
|
// Return the email
|
||||||
return email, nil
|
return email, nil
|
||||||
default:
|
default:
|
||||||
return "", nil
|
return "", nil
|
||||||
@@ -149,6 +217,7 @@ func (providers *Providers) GetUser(provider string) (string, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (provider *Providers) GetConfiguredProviders() []string {
|
func (provider *Providers) GetConfiguredProviders() []string {
|
||||||
|
// Create a list of the configured providers
|
||||||
providers := []string{}
|
providers := []string{}
|
||||||
if provider.Github != nil {
|
if provider.Github != nil {
|
||||||
providers = append(providers, "github")
|
providers = append(providers, "github")
|
||||||
|
|||||||
@@ -9,48 +9,60 @@ import (
|
|||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// The tailscale email is the loginName
|
||||||
type TailscaleUser struct {
|
type TailscaleUser struct {
|
||||||
LoginName string `json:"loginName"`
|
LoginName string `json:"loginName"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// The response from the tailscale user info endpoint
|
||||||
type TailscaleUserInfoResponse struct {
|
type TailscaleUserInfoResponse struct {
|
||||||
Users []TailscaleUser `json:"users"`
|
Users []TailscaleUser `json:"users"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// The scopes required for the tailscale provider
|
||||||
func TailscaleScopes() []string {
|
func TailscaleScopes() []string {
|
||||||
return []string{"users:read"}
|
return []string{"users:read"}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// The tailscale endpoint
|
||||||
var TailscaleEndpoint = oauth2.Endpoint{
|
var TailscaleEndpoint = oauth2.Endpoint{
|
||||||
TokenURL: "https://api.tailscale.com/api/v2/oauth/token",
|
TokenURL: "https://api.tailscale.com/api/v2/oauth/token",
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetTailscaleEmail(client *http.Client) (string, error) {
|
func GetTailscaleEmail(client *http.Client) (string, error) {
|
||||||
|
// Get the user info from tailscale using the oauth http client
|
||||||
res, resErr := client.Get("https://api.tailscale.com/api/v2/tailnet/-/users")
|
res, resErr := client.Get("https://api.tailscale.com/api/v2/tailnet/-/users")
|
||||||
|
|
||||||
|
// Check if there was an error
|
||||||
if resErr != nil {
|
if resErr != nil {
|
||||||
return "", resErr
|
return "", resErr
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug().Msg("Got response from tailscale")
|
log.Debug().Msg("Got response from tailscale")
|
||||||
|
|
||||||
|
// Read the body of the response
|
||||||
body, bodyErr := io.ReadAll(res.Body)
|
body, bodyErr := io.ReadAll(res.Body)
|
||||||
|
|
||||||
|
// Check if there was an error
|
||||||
if bodyErr != nil {
|
if bodyErr != nil {
|
||||||
return "", bodyErr
|
return "", bodyErr
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug().Msg("Read body from tailscale")
|
log.Debug().Msg("Read body from tailscale")
|
||||||
|
|
||||||
|
// Parse the body into a user struct
|
||||||
var users TailscaleUserInfoResponse
|
var users TailscaleUserInfoResponse
|
||||||
|
|
||||||
|
// Unmarshal the body into the user struct
|
||||||
jsonErr := json.Unmarshal(body, &users)
|
jsonErr := json.Unmarshal(body, &users)
|
||||||
|
|
||||||
|
// Check if there was an error
|
||||||
if jsonErr != nil {
|
if jsonErr != nil {
|
||||||
return "", jsonErr
|
return "", jsonErr
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debug().Msg("Parsed users from tailscale")
|
log.Debug().Msg("Parsed users from tailscale")
|
||||||
|
|
||||||
|
// Return the email of the first user
|
||||||
return users.Users[0].LoginName, nil
|
return users.Users[0].LoginName, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,22 +2,27 @@ package types
|
|||||||
|
|
||||||
import "tinyauth/internal/oauth"
|
import "tinyauth/internal/oauth"
|
||||||
|
|
||||||
|
// LoginQuery is the query parameters for the login endpoint
|
||||||
type LoginQuery struct {
|
type LoginQuery struct {
|
||||||
RedirectURI string `url:"redirect_uri"`
|
RedirectURI string `url:"redirect_uri"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// LoginRequest is the request body for the login endpoint
|
||||||
type LoginRequest struct {
|
type LoginRequest struct {
|
||||||
Username string `json:"username"`
|
Username string `json:"username"`
|
||||||
Password string `json:"password"`
|
Password string `json:"password"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// User is the struct for a user
|
||||||
type User struct {
|
type User struct {
|
||||||
Username string
|
Username string
|
||||||
Password string
|
Password string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Users is a list of users
|
||||||
type Users []User
|
type Users []User
|
||||||
|
|
||||||
|
// Config is the configuration for the tinyauth server
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Port int `mapstructure:"port" validate:"required"`
|
Port int `mapstructure:"port" validate:"required"`
|
||||||
Address string `validate:"required,ip4_addr" mapstructure:"address"`
|
Address string `validate:"required,ip4_addr" mapstructure:"address"`
|
||||||
@@ -49,6 +54,7 @@ type Config struct {
|
|||||||
LogLevel int8 `mapstructure:"log-level" validate:"min=-1,max=5"`
|
LogLevel int8 `mapstructure:"log-level" validate:"min=-1,max=5"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UserContext is the context for the user
|
||||||
type UserContext struct {
|
type UserContext struct {
|
||||||
Username string
|
Username string
|
||||||
IsLoggedIn bool
|
IsLoggedIn bool
|
||||||
@@ -56,6 +62,7 @@ type UserContext struct {
|
|||||||
Provider string
|
Provider string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// APIConfig is the configuration for the API
|
||||||
type APIConfig struct {
|
type APIConfig struct {
|
||||||
Port int
|
Port int
|
||||||
Address string
|
Address string
|
||||||
@@ -66,6 +73,7 @@ type APIConfig struct {
|
|||||||
DisableContinue bool
|
DisableContinue bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OAuthConfig is the configuration for the providers
|
||||||
type OAuthConfig struct {
|
type OAuthConfig struct {
|
||||||
GithubClientId string
|
GithubClientId string
|
||||||
GithubClientSecret string
|
GithubClientSecret string
|
||||||
@@ -82,35 +90,42 @@ type OAuthConfig struct {
|
|||||||
AppURL string
|
AppURL string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OAuthRequest is the request for the OAuth endpoint
|
||||||
type OAuthRequest struct {
|
type OAuthRequest struct {
|
||||||
Provider string `uri:"provider" binding:"required"`
|
Provider string `uri:"provider" binding:"required"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// OAuthProviders is the struct for the OAuth providers
|
||||||
type OAuthProviders struct {
|
type OAuthProviders struct {
|
||||||
Github *oauth.OAuth
|
Github *oauth.OAuth
|
||||||
Google *oauth.OAuth
|
Google *oauth.OAuth
|
||||||
Microsoft *oauth.OAuth
|
Microsoft *oauth.OAuth
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UnauthorizedQuery is the query parameters for the unauthorized endpoint
|
||||||
type UnauthorizedQuery struct {
|
type UnauthorizedQuery struct {
|
||||||
Username string `url:"username"`
|
Username string `url:"username"`
|
||||||
Resource string `url:"resource"`
|
Resource string `url:"resource"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SessionCookie is the cookie for the session (exculding the expiry)
|
||||||
type SessionCookie struct {
|
type SessionCookie struct {
|
||||||
Username string
|
Username string
|
||||||
Provider string
|
Provider string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TinyauthLabels is the labels for the tinyauth container
|
||||||
type TinyauthLabels struct {
|
type TinyauthLabels struct {
|
||||||
OAuthWhitelist []string
|
OAuthWhitelist []string
|
||||||
Users []string
|
Users []string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TailscaleQuery is the query parameters for the tailscale endpoint
|
||||||
type TailscaleQuery struct {
|
type TailscaleQuery struct {
|
||||||
Code int `url:"code"`
|
Code int `url:"code"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Proxy is the uri parameters for the proxy endpoint
|
||||||
type Proxy struct {
|
type Proxy struct {
|
||||||
Proxy string `uri:"proxy" binding:"required"`
|
Proxy string `uri:"proxy" binding:"required"`
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,20 +12,32 @@ import (
|
|||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Parses a list of comma separated users in a struct
|
||||||
func ParseUsers(users string) (types.Users, error) {
|
func ParseUsers(users string) (types.Users, error) {
|
||||||
log.Debug().Msg("Parsing users")
|
log.Debug().Msg("Parsing users")
|
||||||
|
|
||||||
|
// Create a new users struct
|
||||||
var usersParsed types.Users
|
var usersParsed types.Users
|
||||||
|
|
||||||
|
// Split the users by comma
|
||||||
userList := strings.Split(users, ",")
|
userList := strings.Split(users, ",")
|
||||||
|
|
||||||
|
// Check if there are any users
|
||||||
if len(userList) == 0 {
|
if len(userList) == 0 {
|
||||||
return types.Users{}, errors.New("invalid user format")
|
return types.Users{}, errors.New("invalid user format")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Loop through the users and split them by colon
|
||||||
for _, user := range userList {
|
for _, user := range userList {
|
||||||
|
// Split the user by colon
|
||||||
userSplit := strings.Split(user, ":")
|
userSplit := strings.Split(user, ":")
|
||||||
|
|
||||||
|
// Check if the user is in the correct format
|
||||||
if len(userSplit) != 2 {
|
if len(userSplit) != 2 {
|
||||||
return types.Users{}, errors.New("invalid user format")
|
return types.Users{}, errors.New("invalid user format")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Append the user to the users struct
|
||||||
usersParsed = append(usersParsed, types.User{
|
usersParsed = append(usersParsed, types.User{
|
||||||
Username: userSplit[0],
|
Username: userSplit[0],
|
||||||
Password: userSplit[1],
|
Password: userSplit[1],
|
||||||
@@ -34,43 +46,61 @@ func ParseUsers(users string) (types.Users, error) {
|
|||||||
|
|
||||||
log.Debug().Msg("Parsed users")
|
log.Debug().Msg("Parsed users")
|
||||||
|
|
||||||
|
// Return the users struct
|
||||||
return usersParsed, nil
|
return usersParsed, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Root url parses parses a hostname and returns the root domain (e.g. sub1.sub2.domain.com -> domain.com)
|
||||||
func GetRootURL(urlSrc string) (string, error) {
|
func GetRootURL(urlSrc string) (string, error) {
|
||||||
|
// Make sure the url is valid
|
||||||
urlParsed, parseErr := url.Parse(urlSrc)
|
urlParsed, parseErr := url.Parse(urlSrc)
|
||||||
|
|
||||||
|
// Check if there was an error
|
||||||
if parseErr != nil {
|
if parseErr != nil {
|
||||||
return "", parseErr
|
return "", parseErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Split the hostname by period
|
||||||
urlSplitted := strings.Split(urlParsed.Hostname(), ".")
|
urlSplitted := strings.Split(urlParsed.Hostname(), ".")
|
||||||
|
|
||||||
|
// Get the last part of the url
|
||||||
urlFinal := strings.Join(urlSplitted[1:], ".")
|
urlFinal := strings.Join(urlSplitted[1:], ".")
|
||||||
|
|
||||||
|
// Return the root domain
|
||||||
return urlFinal, nil
|
return urlFinal, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Reads a file and returns the contents
|
||||||
func ReadFile(file string) (string, error) {
|
func ReadFile(file string) (string, error) {
|
||||||
|
// Check if the file exists
|
||||||
_, statErr := os.Stat(file)
|
_, statErr := os.Stat(file)
|
||||||
|
|
||||||
|
// Check if there was an error
|
||||||
if statErr != nil {
|
if statErr != nil {
|
||||||
return "", statErr
|
return "", statErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Read the file
|
||||||
data, readErr := os.ReadFile(file)
|
data, readErr := os.ReadFile(file)
|
||||||
|
|
||||||
|
// Check if there was an error
|
||||||
if readErr != nil {
|
if readErr != nil {
|
||||||
return "", readErr
|
return "", readErr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Return the file contents
|
||||||
return string(data), nil
|
return string(data), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Parses a file into a comma separated list of users
|
||||||
func ParseFileToLine(content string) string {
|
func ParseFileToLine(content string) string {
|
||||||
|
// Split the content by newline
|
||||||
lines := strings.Split(content, "\n")
|
lines := strings.Split(content, "\n")
|
||||||
|
|
||||||
|
// Create a list of users
|
||||||
users := make([]string, 0)
|
users := make([]string, 0)
|
||||||
|
|
||||||
|
// Loop through the lines, trimming the whitespace and appending to the users list
|
||||||
for _, line := range lines {
|
for _, line := range lines {
|
||||||
if strings.TrimSpace(line) == "" {
|
if strings.TrimSpace(line) == "" {
|
||||||
continue
|
continue
|
||||||
@@ -79,63 +109,92 @@ func ParseFileToLine(content string) string {
|
|||||||
users = append(users, strings.TrimSpace(line))
|
users = append(users, strings.TrimSpace(line))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Return the users as a comma separated string
|
||||||
return strings.Join(users, ",")
|
return strings.Join(users, ",")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get the secret from the config or file
|
||||||
func GetSecret(conf string, file string) string {
|
func GetSecret(conf string, file string) string {
|
||||||
|
// If neither the config or file is set, return an empty string
|
||||||
if conf == "" && file == "" {
|
if conf == "" && file == "" {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If the config is set, return the config (environment variable)
|
||||||
if conf != "" {
|
if conf != "" {
|
||||||
return conf
|
return conf
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If the file is set, read the file
|
||||||
contents, err := ReadFile(file)
|
contents, err := ReadFile(file)
|
||||||
|
|
||||||
|
// Check if there was an error
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Return the contents of the file
|
||||||
return contents
|
return contents
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get the users from the config or file
|
||||||
func GetUsers(conf string, file string) (types.Users, error) {
|
func GetUsers(conf string, file string) (types.Users, error) {
|
||||||
|
// Create a string to store the users
|
||||||
var users string
|
var users string
|
||||||
|
|
||||||
|
// If neither the config or file is set, return an empty users struct
|
||||||
if conf == "" && file == "" {
|
if conf == "" && file == "" {
|
||||||
return types.Users{}, nil
|
return types.Users{}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If the config (environment) is set, append the users to the users string
|
||||||
if conf != "" {
|
if conf != "" {
|
||||||
log.Debug().Msg("Using users from config")
|
log.Debug().Msg("Using users from config")
|
||||||
users += conf
|
users += conf
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If the file is set, read the file and append the users to the users string
|
||||||
if file != "" {
|
if file != "" {
|
||||||
|
// Read the file
|
||||||
fileContents, fileErr := ReadFile(file)
|
fileContents, fileErr := ReadFile(file)
|
||||||
|
|
||||||
|
// If there isn't an error we can append the users to the users string
|
||||||
if fileErr == nil {
|
if fileErr == nil {
|
||||||
log.Debug().Msg("Using users from file")
|
log.Debug().Msg("Using users from file")
|
||||||
|
|
||||||
|
// Append the users to the users string
|
||||||
if users != "" {
|
if users != "" {
|
||||||
users += ","
|
users += ","
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Parse the file contents into a comma separated list of users
|
||||||
users += ParseFileToLine(fileContents)
|
users += ParseFileToLine(fileContents)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Return the parsed users
|
||||||
return ParseUsers(users)
|
return ParseUsers(users)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check if any of the OAuth providers are configured based on the client id and secret
|
||||||
func OAuthConfigured(config types.Config) bool {
|
func OAuthConfigured(config types.Config) bool {
|
||||||
return (config.GithubClientId != "" && config.GithubClientSecret != "") || (config.GoogleClientId != "" && config.GoogleClientSecret != "") || (config.GenericClientId != "" && config.GenericClientSecret != "") || (config.TailscaleClientId != "" && config.TailscaleClientSecret != "")
|
return (config.GithubClientId != "" && config.GithubClientSecret != "") || (config.GoogleClientId != "" && config.GoogleClientSecret != "") || (config.GenericClientId != "" && config.GenericClientSecret != "") || (config.TailscaleClientId != "" && config.TailscaleClientSecret != "")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Parse the docker labels to the tinyauth labels struct
|
||||||
func GetTinyauthLabels(labels map[string]string) types.TinyauthLabels {
|
func GetTinyauthLabels(labels map[string]string) types.TinyauthLabels {
|
||||||
|
// Create a new tinyauth labels struct
|
||||||
var tinyauthLabels types.TinyauthLabels
|
var tinyauthLabels types.TinyauthLabels
|
||||||
|
|
||||||
|
// Loop through the labels
|
||||||
for label, value := range labels {
|
for label, value := range labels {
|
||||||
|
|
||||||
|
// Check if the label is in the tinyauth labels
|
||||||
if slices.Contains(constants.TinyauthLabels, label) {
|
if slices.Contains(constants.TinyauthLabels, label) {
|
||||||
|
|
||||||
log.Debug().Str("label", label).Msg("Found label")
|
log.Debug().Str("label", label).Msg("Found label")
|
||||||
|
|
||||||
|
// Add the label value to the tinyauth labels struct
|
||||||
switch label {
|
switch label {
|
||||||
case "tinyauth.oauth.whitelist":
|
case "tinyauth.oauth.whitelist":
|
||||||
tinyauthLabels.OAuthWhitelist = strings.Split(value, ",")
|
tinyauthLabels.OAuthWhitelist = strings.Split(value, ",")
|
||||||
@@ -144,5 +203,7 @@ func GetTinyauthLabels(labels map[string]string) types.TinyauthLabels {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Return the tinyauth labels
|
||||||
return tinyauthLabels
|
return tinyauthLabels
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user