feat: map info from OIDC claims to headers (#122)

* refactor: return all values from body in the providers

* refactor: only accept claims following the OIDC spec

* feat: map info from OIDC claims to headers

* feat: add support for required oauth groups

* fix: bot suggestions

* feat: get claims from github and google

* fix: close body correctly
This commit is contained in:
Stavros
2025-04-30 19:57:49 +03:00
committed by GitHub
parent f824b84787
commit a9e8bf89a9
24 changed files with 528 additions and 210 deletions

View File

@@ -45,6 +45,11 @@ var authConfig = types.AuthConfig{
LoginMaxRetries: 0,
}
// Simple hooks config for tests
var hooksConfig = types.HooksConfig{
Domain: "localhost",
}
// Cookie
var cookie string
@@ -83,7 +88,7 @@ func getAPI(t *testing.T) *api.API {
providers.Init()
// Create hooks service
hooks := hooks.NewHooks(auth, providers)
hooks := hooks.NewHooks(hooksConfig, auth, providers)
// Create handlers service
handlers := handlers.NewHandlers(handlersConfig, auth, hooks, providers, docker)

View File

@@ -160,9 +160,12 @@ func (auth *Auth) CreateSessionCookie(c *gin.Context, data *types.SessionCookie)
// Set data
session.Values["username"] = data.Username
session.Values["name"] = data.Name
session.Values["email"] = data.Email
session.Values["provider"] = data.Provider
session.Values["expiry"] = time.Now().Add(time.Duration(sessionExpiry) * time.Second).Unix()
session.Values["totpPending"] = data.TotpPending
session.Values["oauthGroups"] = data.OAuthGroups
// Save session
err = session.Save(c.Request, c.Writer)
@@ -211,14 +214,24 @@ func (auth *Auth) GetSessionCookie(c *gin.Context) (types.SessionCookie, error)
return types.SessionCookie{}, err
}
log.Debug().Msg("Got session")
// Get data from session
username, usernameOk := session.Values["username"].(string)
email, emailOk := session.Values["email"].(string)
name, nameOk := session.Values["name"].(string)
provider, providerOK := session.Values["provider"].(string)
expiry, expiryOk := session.Values["expiry"].(int64)
totpPending, totpPendingOk := session.Values["totpPending"].(bool)
oauthGroups, oauthGroupsOk := session.Values["oauthGroups"].(string)
if !usernameOk || !providerOK || !expiryOk || !totpPendingOk {
log.Warn().Msg("Session cookie is missing data")
if !usernameOk || !providerOK || !expiryOk || !totpPendingOk || !emailOk || !nameOk || !oauthGroupsOk {
log.Warn().Msg("Session cookie is invalid")
// If any data is missing, delete the session cookie
auth.DeleteSessionCookie(c)
// Return empty cookie
return types.SessionCookie{}, nil
}
@@ -233,13 +246,16 @@ func (auth *Auth) GetSessionCookie(c *gin.Context) (types.SessionCookie, error)
return types.SessionCookie{}, nil
}
log.Debug().Str("username", username).Str("provider", provider).Int64("expiry", expiry).Bool("totpPending", totpPending).Msg("Parsed cookie")
log.Debug().Str("username", username).Str("provider", provider).Int64("expiry", expiry).Bool("totpPending", totpPending).Str("name", name).Str("email", email).Str("oauthGroups", oauthGroups).Msg("Parsed cookie")
// Return the cookie
return types.SessionCookie{
Username: username,
Name: name,
Email: email,
Provider: provider,
TotpPending: totpPending,
OAuthGroups: oauthGroups,
}, nil
}
@@ -248,48 +264,52 @@ func (auth *Auth) UserAuthConfigured() bool {
return len(auth.Config.Users) > 0
}
func (auth *Auth) ResourceAllowed(c *gin.Context, context types.UserContext) (bool, error) {
// Get headers
host := c.Request.Header.Get("X-Forwarded-Host")
// Get app id
appId := strings.Split(host, ".")[0]
// Get the container labels
labels, err := auth.Docker.GetLabels(appId)
// If there is an error, return false
if err != nil {
return false, err
}
func (auth *Auth) ResourceAllowed(c *gin.Context, context types.UserContext, labels types.TinyauthLabels) bool {
// Check if oauth is allowed
if context.OAuth {
log.Debug().Msg("Checking OAuth whitelist")
return utils.CheckWhitelist(labels.OAuthWhitelist, context.Username), nil
return utils.CheckWhitelist(labels.OAuthWhitelist, context.Email)
}
// Check users
log.Debug().Msg("Checking users")
return utils.CheckWhitelist(labels.Users, context.Username), nil
return utils.CheckWhitelist(labels.Users, context.Username)
}
func (auth *Auth) AuthEnabled(c *gin.Context) (bool, error) {
func (auth *Auth) OAuthGroup(c *gin.Context, context types.UserContext, labels types.TinyauthLabels) bool {
// Check if groups are required
if labels.OAuthGroups == "" {
return true
}
// Check if we are using the generic oauth provider
if context.Provider != "generic" {
log.Debug().Msg("Not using generic provider, skipping group check")
return true
}
// Split the groups by comma (no need to parse since they are from the API response)
oauthGroups := strings.Split(context.OAuthGroups, ",")
// For every group check if it is in the required groups
for _, group := range oauthGroups {
if utils.CheckWhitelist(labels.OAuthGroups, group) {
log.Debug().Str("group", group).Msg("Group is in required groups")
return true
}
}
// No groups matched
log.Debug().Msg("No groups matched")
// Return false
return false
}
func (auth *Auth) AuthEnabled(c *gin.Context, labels types.TinyauthLabels) (bool, error) {
// Get headers
uri := c.Request.Header.Get("X-Forwarded-Uri")
host := c.Request.Header.Get("X-Forwarded-Host")
// Get app id
appId := strings.Split(host, ".")[0]
// Get the container labels
labels, err := auth.Docker.GetLabels(appId)
// If there is an error, auth enabled
if err != nil {
return true, err
}
// Check if the allowed label is empty
if labels.Allowed == "" {

View File

@@ -6,4 +6,13 @@ var TinyauthLabels = []string{
"tinyauth.users",
"tinyauth.allowed",
"tinyauth.headers",
"tinyauth.oauth.groups",
}
// Claims are the OIDC supported claims (including preferd username for some reason)
type Claims struct {
Name string `json:"name"`
Email string `json:"email"`
PreferredUsername string `json:"preferred_username"`
Groups []string `json:"groups"`
}

View File

@@ -6,8 +6,7 @@ import (
"tinyauth/internal/types"
"tinyauth/internal/utils"
apiTypes "github.com/docker/docker/api/types"
containerTypes "github.com/docker/docker/api/types/container"
container "github.com/docker/docker/api/types/container"
"github.com/docker/docker/client"
"github.com/rs/zerolog/log"
)
@@ -38,9 +37,9 @@ func (docker *Docker) Init() error {
return nil
}
func (docker *Docker) GetContainers() ([]apiTypes.Container, error) {
func (docker *Docker) GetContainers() ([]container.Summary, error) {
// Get the list of containers
containers, err := docker.Client.ContainerList(docker.Context, containerTypes.ListOptions{})
containers, err := docker.Client.ContainerList(docker.Context, container.ListOptions{})
// Check if there was an error
if err != nil {
@@ -51,13 +50,13 @@ func (docker *Docker) GetContainers() ([]apiTypes.Container, error) {
return containers, nil
}
func (docker *Docker) InspectContainer(containerId string) (apiTypes.ContainerJSON, error) {
func (docker *Docker) InspectContainer(containerId string) (container.InspectResponse, error) {
// Inspect the container
inspect, err := docker.Client.ContainerInspect(docker.Context, containerId)
// Check if there was an error
if err != nil {
return apiTypes.ContainerJSON{}, err
return container.InspectResponse{}, err
}
// Return the inspect

View File

@@ -10,6 +10,7 @@ import (
"tinyauth/internal/hooks"
"tinyauth/internal/providers"
"tinyauth/internal/types"
"tinyauth/internal/utils"
"github.com/gin-gonic/gin"
"github.com/google/go-querystring/query"
@@ -68,12 +69,15 @@ func (h *Handlers) AuthHandler(c *gin.Context) {
proto := c.Request.Header.Get("X-Forwarded-Proto")
host := c.Request.Header.Get("X-Forwarded-Host")
// Check if auth is enabled
authEnabled, err := h.Auth.AuthEnabled(c)
// Get the app id
appId := strings.Split(host, ".")[0]
// Get the container labels
labels, err := h.Docker.GetLabels(appId)
// Check if there was an error
if err != nil {
log.Error().Err(err).Msg("Failed to check if app is allowed")
log.Error().Err(err).Msg("Failed to get container labels")
if proxy.Proxy == "nginx" || !isBrowser {
c.JSON(500, gin.H{
@@ -87,11 +91,8 @@ func (h *Handlers) AuthHandler(c *gin.Context) {
return
}
// Get the app id
appId := strings.Split(host, ".")[0]
// Get the container labels
labels, err := h.Docker.GetLabels(appId)
// Check if auth is enabled
authEnabled, err := h.Auth.AuthEnabled(c, labels)
// Check if there was an error
if err != nil {
@@ -113,7 +114,7 @@ func (h *Handlers) AuthHandler(c *gin.Context) {
if !authEnabled {
for key, value := range labels.Headers {
log.Debug().Str("key", key).Str("value", value).Msg("Setting header")
c.Header(key, value)
c.Header(key, utils.SanitizeHeader(value))
}
c.JSON(200, gin.H{
"status": 200,
@@ -130,23 +131,7 @@ func (h *Handlers) AuthHandler(c *gin.Context) {
log.Debug().Msg("Authenticated")
// Check if user is allowed to access subdomain, if request is nginx.example.com the subdomain (resource) is nginx
appAllowed, err := h.Auth.ResourceAllowed(c, userContext)
// Check if there was an error
if err != nil {
log.Error().Err(err).Msg("Failed to check if app is allowed")
if proxy.Proxy == "nginx" || !isBrowser {
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
})
return
}
c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL))
return
}
appAllowed := h.Auth.ResourceAllowed(c, userContext, labels)
log.Debug().Bool("appAllowed", appAllowed).Msg("Checking if app is allowed")
@@ -165,11 +150,20 @@ func (h *Handlers) AuthHandler(c *gin.Context) {
return
}
// Build query
queries, err := query.Values(types.UnauthorizedQuery{
Username: userContext.Username,
// Values
values := types.UnauthorizedQuery{
Resource: strings.Split(host, ".")[0],
})
}
// Use either username or email
if userContext.OAuth {
values.Username = userContext.Email
} else {
values.Username = userContext.Username
}
// Build query
queries, err := query.Values(values)
// Handle error (no need to check for nginx/headers since we are sure we are using caddy/traefik)
if err != nil {
@@ -183,13 +177,65 @@ func (h *Handlers) AuthHandler(c *gin.Context) {
return
}
// Set the user header
c.Header("Remote-User", userContext.Username)
log.Debug().Interface("labels", labels).Msg("Got labels")
// Check if user is in required groups
groupOk := h.Auth.OAuthGroup(c, userContext, labels)
log.Debug().Bool("groupOk", groupOk).Msg("Checking if user is in required groups")
// The user is not allowed to access the app
if !groupOk {
log.Warn().Str("username", userContext.Username).Str("host", host).Msg("User is not in required groups")
// Set WWW-Authenticate header
c.Header("WWW-Authenticate", "Basic realm=\"tinyauth\"")
if proxy.Proxy == "nginx" || !isBrowser {
c.JSON(401, gin.H{
"status": 401,
"message": "Unauthorized",
})
return
}
// Values
values := types.UnauthorizedQuery{
Resource: strings.Split(host, ".")[0],
GroupErr: true,
}
// Use either username or email
if userContext.OAuth {
values.Username = userContext.Email
} else {
values.Username = userContext.Username
}
// Build query
queries, err := query.Values(values)
// Handle error (no need to check for nginx/headers since we are sure we are using caddy/traefik)
if err != nil {
log.Error().Err(err).Msg("Failed to build queries")
c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL))
return
}
// We are using caddy/traefik so redirect
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", h.Config.AppURL, queries.Encode()))
return
}
c.Header("Remote-User", utils.SanitizeHeader(userContext.Username))
c.Header("Remote-Name", utils.SanitizeHeader(userContext.Name))
c.Header("Remote-Email", utils.SanitizeHeader(userContext.Email))
c.Header("Remote-Groups", utils.SanitizeHeader(userContext.OAuthGroups))
// Set the rest of the headers
for key, value := range labels.Headers {
log.Debug().Str("key", key).Str("value", value).Msg("Setting header")
c.Header(key, value)
c.Header(key, utils.SanitizeHeader(value))
}
// The user is allowed to access the app
@@ -310,6 +356,8 @@ func (h *Handlers) LoginHandler(c *gin.Context) {
// Set totp pending cookie
h.Auth.CreateSessionCookie(c, &types.SessionCookie{
Username: login.Username,
Name: utils.Capitalize(login.Username),
Email: fmt.Sprintf("%s@%s", strings.ToLower(login.Username), h.Config.Domain),
Provider: "username",
TotpPending: true,
})
@@ -328,6 +376,8 @@ func (h *Handlers) LoginHandler(c *gin.Context) {
// Create session cookie with username as provider
h.Auth.CreateSessionCookie(c, &types.SessionCookie{
Username: login.Username,
Name: utils.Capitalize(login.Username),
Email: fmt.Sprintf("%s@%s", strings.ToLower(login.Username), h.Config.Domain),
Provider: "username",
})
@@ -402,6 +452,8 @@ func (h *Handlers) TotpHandler(c *gin.Context) {
// Create session cookie with username as provider
h.Auth.CreateSessionCookie(c, &types.SessionCookie{
Username: user.Username,
Name: utils.Capitalize(user.Username),
Email: fmt.Sprintf("%s@%s", strings.ToLower(user.Username), h.Config.Domain),
Provider: "username",
})
@@ -465,6 +517,8 @@ func (h *Handlers) UserHandler(c *gin.Context) {
Status: 200,
IsLoggedIn: userContext.IsLoggedIn,
Username: userContext.Username,
Name: userContext.Name,
Email: userContext.Email,
Provider: userContext.Provider,
Oauth: userContext.OAuth,
TotpPending: userContext.TotpPending,
@@ -613,25 +667,32 @@ func (h *Handlers) OauthCallbackHandler(c *gin.Context) {
return
}
// Get email
email, err := h.Providers.GetUser(providerName.Provider)
log.Debug().Str("email", email).Msg("Got email")
// Get user
user, err := h.Providers.GetUser(providerName.Provider)
// Handle error
if err != nil {
log.Error().Err(err).Msg("Failed to get email")
log.Error().Msg("Failed to get user")
c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL))
return
}
log.Debug().Msg("Got user")
// Check that email is not empty
if user.Email == "" {
log.Error().Msg("Email is empty")
c.Redirect(http.StatusPermanentRedirect, fmt.Sprintf("%s/error", h.Config.AppURL))
return
}
// Email is not whitelisted
if !h.Auth.EmailWhitelisted(email) {
log.Warn().Str("email", email).Msg("Email not whitelisted")
if !h.Auth.EmailWhitelisted(user.Email) {
log.Warn().Str("email", user.Email).Msg("Email not whitelisted")
// Build query
queries, err := query.Values(types.UnauthorizedQuery{
Username: email,
Username: user.Email,
})
// Handle error
@@ -647,10 +708,31 @@ func (h *Handlers) OauthCallbackHandler(c *gin.Context) {
log.Debug().Msg("Email whitelisted")
// Get username
var username string
if user.PreferredUsername != "" {
username = user.PreferredUsername
} else {
username = fmt.Sprintf("%s_%s", strings.Split(user.Email, "@")[0], strings.Split(user.Email, "@")[1])
}
// Get name
var name string
if user.Name != "" {
name = user.Name
} else {
name = fmt.Sprintf("%s (%s)", utils.Capitalize(strings.Split(user.Email, "@")[0]), strings.Split(user.Email, "@")[1])
}
// Create session cookie (also cleans up redirect cookie)
h.Auth.CreateSessionCookie(c, &types.SessionCookie{
Username: email,
Provider: providerName.Provider,
Username: username,
Name: name,
Email: user.Email,
Provider: providerName.Provider,
OAuthGroups: strings.Join(user.Groups, ","),
})
// Check if we have a redirect URI

View File

@@ -1,22 +1,27 @@
package hooks
import (
"fmt"
"strings"
"tinyauth/internal/auth"
"tinyauth/internal/providers"
"tinyauth/internal/types"
"tinyauth/internal/utils"
"github.com/gin-gonic/gin"
"github.com/rs/zerolog/log"
)
func NewHooks(auth *auth.Auth, providers *providers.Providers) *Hooks {
func NewHooks(config types.HooksConfig, auth *auth.Auth, providers *providers.Providers) *Hooks {
return &Hooks{
Config: config,
Auth: auth,
Providers: providers,
}
}
type Hooks struct {
Config types.HooksConfig
Auth *auth.Auth
Providers *providers.Providers
}
@@ -36,11 +41,11 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext {
if user != nil && hooks.Auth.CheckPassword(*user, basic.Password) {
// Return user context since we are logged in with basic auth
return types.UserContext{
Username: basic.Username,
IsLoggedIn: true,
OAuth: false,
Provider: "basic",
TotpPending: false,
Username: basic.Username,
Name: utils.Capitalize(basic.Username),
Email: fmt.Sprintf("%s@%s", strings.ToLower(basic.Username), hooks.Config.Domain),
IsLoggedIn: true,
Provider: "basic",
}
}
@@ -50,13 +55,7 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext {
if err != nil {
log.Error().Err(err).Msg("Failed to get session cookie")
// Return empty context
return types.UserContext{
Username: "",
IsLoggedIn: false,
OAuth: false,
Provider: "",
TotpPending: false,
}
return types.UserContext{}
}
// Check if session cookie has totp pending
@@ -65,8 +64,8 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext {
// Return empty context since we are pending totp
return types.UserContext{
Username: cookie.Username,
IsLoggedIn: false,
OAuth: false,
Name: cookie.Name,
Email: cookie.Email,
Provider: cookie.Provider,
TotpPending: true,
}
@@ -82,11 +81,11 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext {
// It exists so we are logged in
return types.UserContext{
Username: cookie.Username,
IsLoggedIn: true,
OAuth: false,
Provider: "username",
TotpPending: false,
Username: cookie.Username,
Name: cookie.Name,
Email: cookie.Email,
IsLoggedIn: true,
Provider: "username",
}
}
}
@@ -108,13 +107,7 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext {
hooks.Auth.DeleteSessionCookie(c)
// Return empty context
return types.UserContext{
Username: "",
IsLoggedIn: false,
OAuth: false,
Provider: "",
TotpPending: false,
}
return types.UserContext{}
}
log.Debug().Msg("Email is whitelisted")
@@ -122,19 +115,15 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext {
// Return user context since we are logged in with oauth
return types.UserContext{
Username: cookie.Username,
Name: cookie.Name,
Email: cookie.Email,
IsLoggedIn: true,
OAuth: true,
Provider: cookie.Provider,
TotpPending: false,
OAuthGroups: cookie.OAuthGroups,
}
}
// Neither basic auth or oauth is set so we return an empty context
return types.UserContext{
Username: "",
IsLoggedIn: false,
OAuth: false,
Provider: "",
TotpPending: false,
}
return types.UserContext{}
}

View File

@@ -4,24 +4,25 @@ import (
"encoding/json"
"io"
"net/http"
"tinyauth/internal/constants"
"github.com/rs/zerolog/log"
)
// We are assuming that the generic provider will return a JSON object with an email field
type GenericUserInfoResponse struct {
Email string `json:"email"`
}
func GetGenericUser(client *http.Client, url string) (constants.Claims, error) {
// Create user struct
var user constants.Claims
func GetGenericEmail(client *http.Client, url string) (string, error) {
// Using the oauth client get the user info url
res, err := client.Get(url)
// Check if there was an error
if err != nil {
return "", err
return user, err
}
defer res.Body.Close()
log.Debug().Msg("Got response from generic provider")
// Read the body of the response
@@ -29,24 +30,21 @@ func GetGenericEmail(client *http.Client, url string) (string, error) {
// Check if there was an error
if err != nil {
return "", err
return user, err
}
log.Debug().Msg("Read body from generic provider")
// Parse the body into a user struct
var user GenericUserInfoResponse
// Unmarshal the body into the user struct
err = json.Unmarshal(body, &user)
// Check if there was an error
if err != nil {
return "", err
return user, err
}
log.Debug().Msg("Parsed user from generic provider")
// Return the email
return user.Email, nil
// Return the user
return user, nil
}

View File

@@ -5,51 +5,96 @@ import (
"errors"
"io"
"net/http"
"tinyauth/internal/constants"
"github.com/rs/zerolog/log"
)
// Github has a different response than the generic provider
type GithubUserInfoResponse []struct {
// Response for the github email endpoint
type GithubEmailResponse []struct {
Email string `json:"email"`
Primary bool `json:"primary"`
}
// The scopes required for the github provider
func GithubScopes() []string {
return []string{"user:email"}
// Response for the github user endpoint
type GithubUserInfoResponse struct {
Login string `json:"login"`
Name string `json:"name"`
}
func GetGithubEmail(client *http.Client) (string, error) {
// Get the user emails from github using the oauth http client
res, err := client.Get("https://api.github.com/user/emails")
// The scopes required for the github provider
func GithubScopes() []string {
return []string{"user:email", "read:user"}
}
func GetGithubUser(client *http.Client) (constants.Claims, error) {
// Create user struct
var user constants.Claims
// Get the user info from github using the oauth http client
res, err := client.Get("https://api.github.com/user")
// Check if there was an error
if err != nil {
return "", err
return user, err
}
log.Debug().Msg("Got response from github")
defer res.Body.Close()
log.Debug().Msg("Got user response from github")
// Read the body of the response
body, err := io.ReadAll(res.Body)
// Check if there was an error
if err != nil {
return "", err
return user, err
}
log.Debug().Msg("Read body from github")
log.Debug().Msg("Read user body from github")
// Parse the body into a user struct
var emails GithubUserInfoResponse
var userInfo GithubUserInfoResponse
// Unmarshal the body into the user struct
err = json.Unmarshal(body, &userInfo)
// Check if there was an error
if err != nil {
return user, err
}
// Get the user emails from github using the oauth http client
res, err = client.Get("https://api.github.com/user/emails")
// Check if there was an error
if err != nil {
return user, err
}
defer res.Body.Close()
log.Debug().Msg("Got email response from github")
// Read the body of the response
body, err = io.ReadAll(res.Body)
// Check if there was an error
if err != nil {
return user, err
}
log.Debug().Msg("Read email body from github")
// Parse the body into a user struct
var emails GithubEmailResponse
// Unmarshal the body into the user struct
err = json.Unmarshal(body, &emails)
// Check if there was an error
if err != nil {
return "", err
return user, err
}
log.Debug().Msg("Parsed emails from github")
@@ -57,10 +102,26 @@ func GetGithubEmail(client *http.Client) (string, error) {
// Find and return the primary email
for _, email := range emails {
if email.Primary {
return email.Email, nil
// Set the email then exit
user.Email = email.Email
break
}
}
// User does not have a primary email?
return "", errors.New("no primary email found")
// If no primary email was found, use the first available email
if len(emails) == 0 {
return user, errors.New("no emails found")
}
// Set the email if it is not set picking the first one
if user.Email == "" {
user.Email = emails[0].Email
}
// Set the username and name
user.PreferredUsername = userInfo.Login
user.Name = userInfo.Name
// Return
return user, nil
}

View File

@@ -4,29 +4,37 @@ import (
"encoding/json"
"io"
"net/http"
"strings"
"tinyauth/internal/constants"
"github.com/rs/zerolog/log"
)
// Google works the same as the generic provider
// Response for the google user endpoint
type GoogleUserInfoResponse struct {
Email string `json:"email"`
Name string `json:"name"`
}
// The scopes required for the google provider
func GoogleScopes() []string {
return []string{"https://www.googleapis.com/auth/userinfo.email"}
return []string{"https://www.googleapis.com/auth/userinfo.email", "https://www.googleapis.com/auth/userinfo.profile"}
}
func GetGoogleEmail(client *http.Client) (string, error) {
func GetGoogleUser(client *http.Client) (constants.Claims, error) {
// Create user struct
var user constants.Claims
// Get the user info from google using the oauth http client
res, err := client.Get("https://www.googleapis.com/userinfo/v2/me")
// Check if there was an error
if err != nil {
return "", err
return user, err
}
defer res.Body.Close()
log.Debug().Msg("Got response from google")
// Read the body of the response
@@ -34,24 +42,29 @@ func GetGoogleEmail(client *http.Client) (string, error) {
// Check if there was an error
if err != nil {
return "", err
return user, err
}
log.Debug().Msg("Read body from google")
// Parse the body into a user struct
var user GoogleUserInfoResponse
// Create a new user info struct
var userInfo GoogleUserInfoResponse
// Unmarshal the body into the user struct
err = json.Unmarshal(body, &user)
err = json.Unmarshal(body, &userInfo)
// Check if there was an error
if err != nil {
return "", err
return user, err
}
log.Debug().Msg("Parsed user from google")
// Return the email
return user.Email, nil
// Map the user info to the user struct
user.PreferredUsername = strings.Split(userInfo.Email, "@")[0]
user.Name = userInfo.Name
user.Email = userInfo.Email
// Return the user
return user, nil
}

View File

@@ -2,6 +2,7 @@ package providers
import (
"fmt"
"tinyauth/internal/constants"
"tinyauth/internal/oauth"
"tinyauth/internal/types"
@@ -93,14 +94,17 @@ func (providers *Providers) GetProvider(provider string) *oauth.OAuth {
}
}
func (providers *Providers) GetUser(provider string) (string, error) {
// Get the email from the provider
func (providers *Providers) GetUser(provider string) (constants.Claims, error) {
// Create user struct
var user constants.Claims
// Get the user from the provider
switch provider {
case "github":
// If the github provider is not configured, return an error
if providers.Github == nil {
log.Debug().Msg("Github provider not configured")
return "", nil
return user, nil
}
// Get the client from the github provider
@@ -108,23 +112,23 @@ func (providers *Providers) GetUser(provider string) (string, error) {
log.Debug().Msg("Got client from github")
// Get the email from the github provider
email, err := GetGithubEmail(client)
// Get the user from the github provider
user, err := GetGithubUser(client)
// Check if there was an error
if err != nil {
return "", err
return user, err
}
log.Debug().Msg("Got email from github")
log.Debug().Msg("Got user from github")
// Return the email
return email, nil
// Return the user
return user, nil
case "google":
// If the google provider is not configured, return an error
if providers.Google == nil {
log.Debug().Msg("Google provider not configured")
return "", nil
return user, nil
}
// Get the client from the google provider
@@ -132,23 +136,23 @@ func (providers *Providers) GetUser(provider string) (string, error) {
log.Debug().Msg("Got client from google")
// Get the email from the google provider
email, err := GetGoogleEmail(client)
// Get the user from the google provider
user, err := GetGoogleUser(client)
// Check if there was an error
if err != nil {
return "", err
return user, err
}
log.Debug().Msg("Got email from google")
log.Debug().Msg("Got user from google")
// Return the email
return email, nil
// Return the user
return user, nil
case "generic":
// If the generic provider is not configured, return an error
if providers.Generic == nil {
log.Debug().Msg("Generic provider not configured")
return "", nil
return user, nil
}
// Get the client from the generic provider
@@ -156,20 +160,20 @@ func (providers *Providers) GetUser(provider string) (string, error) {
log.Debug().Msg("Got client from generic")
// Get the email from the generic provider
email, err := GetGenericEmail(client, providers.Config.GenericUserURL)
// Get the user from the generic provider
user, err := GetGenericUser(client, providers.Config.GenericUserURL)
// Check if there was an error
if err != nil {
return "", err
return user, err
}
log.Debug().Msg("Got email from generic")
log.Debug().Msg("Got user from generic")
// Return the email
return email, nil
return user, nil
default:
return "", nil
return user, nil
}
}

View File

@@ -20,6 +20,7 @@ type OAuthRequest struct {
type UnauthorizedQuery struct {
Username string `url:"username"`
Resource string `url:"resource"`
GroupErr bool `url:"groupErr"`
}
// Proxy is the uri parameters for the proxy endpoint
@@ -33,6 +34,8 @@ type UserContextResponse struct {
Message string `json:"message"`
IsLoggedIn bool `json:"isLoggedIn"`
Username string `json:"username"`
Name string `json:"name"`
Email string `json:"email"`
Provider string `json:"provider"`
Oauth bool `json:"oauth"`
TotpPending bool `json:"totpPending"`

View File

@@ -78,3 +78,8 @@ type AuthConfig struct {
LoginTimeout int
LoginMaxRetries int
}
// HooksConfig is the configuration for the hooks service
type HooksConfig struct {
Domain string
}

View File

@@ -25,8 +25,11 @@ type OAuthProviders struct {
// SessionCookie is the cookie for the session (exculding the expiry)
type SessionCookie struct {
Username string
Name string
Email string
Provider string
TotpPending bool
OAuthGroups string
}
// TinyauthLabels is the labels for the tinyauth container
@@ -35,15 +38,19 @@ type TinyauthLabels struct {
Users string
Allowed string
Headers map[string]string
OAuthGroups string
}
// UserContext is the context for the user
type UserContext struct {
Username string
Name string
Email string
IsLoggedIn bool
OAuth bool
Provider string
TotpPending bool
OAuthGroups string
}
// LoginAttempt tracks information about login attempts for rate limiting

View File

@@ -204,6 +204,8 @@ func GetTinyauthLabels(labels map[string]string) types.TinyauthLabels {
}
tinyauthLabels.Headers[headerSplit[0]] = headerSplit[1]
}
case "tinyauth.oauth.groups":
tinyauthLabels.OAuthGroups = value
}
}
}
@@ -323,3 +325,22 @@ func CheckWhitelist(whitelist string, str string) bool {
// Return false if no match was found
return false
}
// Capitalize just the first letter of a string
func Capitalize(str string) string {
if len(str) == 0 {
return ""
}
return strings.ToUpper(string([]rune(str)[0])) + string([]rune(str)[1:])
}
// Sanitize header removes all control characters from a string
func SanitizeHeader(header string) string {
return strings.Map(func(r rune) rune {
// Allow only printable ASCII characters (32-126) and safe whitespace (space, tab)
if r == ' ' || r == '\t' || (r >= 32 && r <= 126) {
return r
}
return -1
}, header)
}

View File

@@ -467,3 +467,65 @@ func TestCheckWhitelist(t *testing.T) {
t.Fatalf("Expected %v, got %v", expected, result)
}
}
// Test capitalize
func TestCapitalize(t *testing.T) {
t.Log("Testing capitalize with a valid string")
// Create variables
str := "test"
expected := "Test"
// Test the capitalize function
result := utils.Capitalize(str)
// Check if the result is equal to the expected
if result != expected {
t.Fatalf("Expected %v, got %v", expected, result)
}
t.Log("Testing capitalize with an empty string")
// Create variables
str = ""
expected = ""
// Test the capitalize function
result = utils.Capitalize(str)
// Check if the result is equal to the expected
if result != expected {
t.Fatalf("Expected %v, got %v", expected, result)
}
}
// Test the header sanitizer
func TestSanitizeHeader(t *testing.T) {
t.Log("Testing sanitize header with a valid string")
// Create variables
str := "X-Header=value"
expected := "X-Header=value"
// Test the sanitize header function
result := utils.SanitizeHeader(str)
// Check if the result is equal to the expected
if result != expected {
t.Fatalf("Expected %v, got %v", expected, result)
}
t.Log("Testing sanitize header with an invalid string")
// Create variables
str = "X-Header=val\nue"
expected = "X-Header=value"
// Test the sanitize header function
result = utils.SanitizeHeader(str)
// Check if the result is equal to the expected
if result != expected {
t.Fatalf("Expected %v, got %v", expected, result)
}
}