mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2025-10-27 20:25:41 +00:00
224 lines
5.6 KiB
Go
224 lines
5.6 KiB
Go
package handlers
|
|
|
|
import (
|
|
"fmt"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
"tinyauth/internal/types"
|
|
"tinyauth/internal/utils"
|
|
|
|
"github.com/gin-gonic/gin"
|
|
"github.com/google/go-querystring/query"
|
|
"github.com/rs/zerolog/log"
|
|
)
|
|
|
|
func (h *Handlers) OAuthURLHandler(c *gin.Context) {
|
|
var request types.OAuthRequest
|
|
|
|
err := c.BindUri(&request)
|
|
if err != nil {
|
|
log.Error().Err(err).Msg("Failed to bind URI")
|
|
c.JSON(400, gin.H{
|
|
"status": 400,
|
|
"message": "Bad Request",
|
|
})
|
|
return
|
|
}
|
|
|
|
log.Debug().Msg("Got OAuth request")
|
|
|
|
// Check if provider exists
|
|
provider := h.Providers.GetProvider(request.Provider)
|
|
|
|
if provider == nil {
|
|
c.JSON(404, gin.H{
|
|
"status": 404,
|
|
"message": "Not Found",
|
|
})
|
|
return
|
|
}
|
|
|
|
log.Debug().Str("provider", request.Provider).Msg("Got provider")
|
|
|
|
// Create state
|
|
state := provider.GenerateState()
|
|
|
|
// Get auth URL
|
|
authURL := provider.GetAuthURL(state)
|
|
|
|
log.Debug().Msg("Got auth URL")
|
|
|
|
// Set CSRF cookie
|
|
c.SetCookie(h.Config.CsrfCookieName, state, int(time.Hour.Seconds()), "/", "", h.Config.CookieSecure, true)
|
|
|
|
// Get redirect URI
|
|
redirectURI := c.Query("redirect_uri")
|
|
|
|
// Set redirect cookie if redirect URI is provided
|
|
if redirectURI != "" {
|
|
log.Debug().Str("redirectURI", redirectURI).Msg("Setting redirect cookie")
|
|
c.SetCookie(h.Config.RedirectCookieName, redirectURI, int(time.Hour.Seconds()), "/", "", h.Config.CookieSecure, true)
|
|
}
|
|
|
|
// Return auth URL
|
|
c.JSON(200, gin.H{
|
|
"status": 200,
|
|
"message": "OK",
|
|
"url": authURL,
|
|
})
|
|
}
|
|
|
|
func (h *Handlers) OAuthCallbackHandler(c *gin.Context) {
|
|
var providerName types.OAuthRequest
|
|
|
|
err := c.BindUri(&providerName)
|
|
if err != nil {
|
|
log.Error().Err(err).Msg("Failed to bind URI")
|
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL))
|
|
return
|
|
}
|
|
|
|
log.Debug().Interface("provider", providerName.Provider).Msg("Got provider name")
|
|
|
|
// Get state
|
|
state := c.Query("state")
|
|
|
|
// Get CSRF cookie
|
|
csrfCookie, err := c.Cookie(h.Config.CsrfCookieName)
|
|
|
|
if err != nil {
|
|
log.Debug().Msg("No CSRF cookie")
|
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL))
|
|
return
|
|
}
|
|
|
|
log.Debug().Str("csrfCookie", csrfCookie).Msg("Got CSRF cookie")
|
|
|
|
// Check if CSRF cookie is valid
|
|
if csrfCookie != state {
|
|
log.Warn().Msg("Invalid CSRF cookie or CSRF cookie does not match with the state")
|
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL))
|
|
return
|
|
}
|
|
|
|
// Clean up CSRF cookie
|
|
c.SetCookie(h.Config.CsrfCookieName, "", -1, "/", "", h.Config.CookieSecure, true)
|
|
|
|
// Get code
|
|
code := c.Query("code")
|
|
|
|
log.Debug().Msg("Got code")
|
|
|
|
// Get provider
|
|
provider := h.Providers.GetProvider(providerName.Provider)
|
|
|
|
if provider == nil {
|
|
c.Redirect(http.StatusTemporaryRedirect, "/not-found")
|
|
return
|
|
}
|
|
|
|
log.Debug().Str("provider", providerName.Provider).Msg("Got provider")
|
|
|
|
// Exchange token (authenticates user)
|
|
_, err = provider.ExchangeToken(code)
|
|
if err != nil {
|
|
log.Error().Err(err).Msg("Failed to exchange token")
|
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL))
|
|
return
|
|
}
|
|
|
|
log.Debug().Msg("Got token")
|
|
|
|
// Get user
|
|
user, err := h.Providers.GetUser(providerName.Provider)
|
|
if err != nil {
|
|
log.Error().Err(err).Msg("Failed to get user")
|
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL))
|
|
return
|
|
}
|
|
|
|
log.Debug().Interface("user", user).Msg("Got user")
|
|
|
|
// Check that email is not empty
|
|
if user.Email == "" {
|
|
log.Error().Msg("Email is empty")
|
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL))
|
|
return
|
|
}
|
|
|
|
// Email is not whitelisted
|
|
if !h.Auth.EmailWhitelisted(user.Email) {
|
|
log.Warn().Str("email", user.Email).Msg("Email not whitelisted")
|
|
queries, err := query.Values(types.UnauthorizedQuery{
|
|
Username: user.Email,
|
|
})
|
|
|
|
if err != nil {
|
|
log.Error().Err(err).Msg("Failed to build queries")
|
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL))
|
|
return
|
|
}
|
|
|
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", h.Config.AppURL, queries.Encode()))
|
|
}
|
|
|
|
log.Debug().Msg("Email whitelisted")
|
|
|
|
// Get username
|
|
var username string
|
|
|
|
if user.PreferredUsername != "" {
|
|
username = user.PreferredUsername
|
|
} else {
|
|
username = fmt.Sprintf("%s_%s", strings.Split(user.Email, "@")[0], strings.Split(user.Email, "@")[1])
|
|
}
|
|
|
|
// Get name
|
|
var name string
|
|
|
|
if user.Name != "" {
|
|
name = user.Name
|
|
} else {
|
|
name = fmt.Sprintf("%s (%s)", utils.Capitalize(strings.Split(user.Email, "@")[0]), strings.Split(user.Email, "@")[1])
|
|
}
|
|
|
|
// Create session cookie
|
|
h.Auth.CreateSessionCookie(c, &types.SessionCookie{
|
|
Username: username,
|
|
Name: name,
|
|
Email: user.Email,
|
|
Provider: providerName.Provider,
|
|
OAuthGroups: utils.CoalesceToString(user.Groups),
|
|
})
|
|
|
|
// Check if we have a redirect URI
|
|
redirectCookie, err := c.Cookie(h.Config.RedirectCookieName)
|
|
|
|
if err != nil {
|
|
log.Debug().Msg("No redirect cookie")
|
|
c.Redirect(http.StatusTemporaryRedirect, h.Config.AppURL)
|
|
return
|
|
}
|
|
|
|
log.Debug().Str("redirectURI", redirectCookie).Msg("Got redirect URI")
|
|
|
|
queries, err := query.Values(types.LoginQuery{
|
|
RedirectURI: redirectCookie,
|
|
})
|
|
|
|
if err != nil {
|
|
log.Error().Err(err).Msg("Failed to build queries")
|
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", h.Config.AppURL))
|
|
return
|
|
}
|
|
|
|
log.Debug().Msg("Got redirect query")
|
|
|
|
// Clean up redirect cookie
|
|
c.SetCookie(h.Config.RedirectCookieName, "", -1, "/", "", h.Config.CookieSecure, true)
|
|
|
|
// Redirect to continue with the redirect URI
|
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/continue?%s", h.Config.AppURL, queries.Encode()))
|
|
}
|