feat: map info from OIDC claims to headers

This commit is contained in:
Stavros
2025-04-25 16:41:45 +03:00
parent 5e4e2ddbd9
commit dca09a3d9d
13 changed files with 117 additions and 63 deletions

View File

@@ -111,6 +111,11 @@ var rootCmd = &cobra.Command{
LoginMaxRetries: config.LoginMaxRetries, LoginMaxRetries: config.LoginMaxRetries,
} }
// Create hooks config
hooksConfig := types.HooksConfig{
Domain: domain,
}
// Create docker service // Create docker service
docker := docker.NewDocker() docker := docker.NewDocker()
@@ -128,7 +133,7 @@ var rootCmd = &cobra.Command{
providers.Init() providers.Init()
// Create hooks service // Create hooks service
hooks := hooks.NewHooks(auth, providers) hooks := hooks.NewHooks(hooksConfig, auth, providers)
// Create handlers // Create handlers
handlers := handlers.NewHandlers(handlersConfig, auth, hooks, providers, docker) handlers := handlers.NewHandlers(handlersConfig, auth, hooks, providers, docker)

View File

@@ -10,7 +10,7 @@ import { useAppContext } from "../context/app-context";
import { Trans, useTranslation } from "react-i18next"; import { Trans, useTranslation } from "react-i18next";
export const LogoutPage = () => { export const LogoutPage = () => {
const { isLoggedIn, username, oauth, provider } = useUserContext(); const { isLoggedIn, oauth, provider, email } = useUserContext();
const { genericName } = useAppContext(); const { genericName } = useAppContext();
const { t } = useTranslation(); const { t } = useTranslation();
@@ -56,7 +56,7 @@ export const LogoutPage = () => {
values={{ values={{
provider: provider:
provider === "generic" ? genericName : capitalize(provider), provider === "generic" ? genericName : capitalize(provider),
username: username, username: email,
}} }}
/> />
) : ( ) : (
@@ -65,7 +65,7 @@ export const LogoutPage = () => {
t={t} t={t}
components={{ Code: <Code /> }} components={{ Code: <Code /> }}
values={{ values={{
username: username, username: email,
}} }}
/> />
)} )}

View File

@@ -3,6 +3,8 @@ import { z } from "zod";
export const userContextSchema = z.object({ export const userContextSchema = z.object({
isLoggedIn: z.boolean(), isLoggedIn: z.boolean(),
username: z.string(), username: z.string(),
name: z.string(),
email: z.string(),
oauth: z.boolean(), oauth: z.boolean(),
provider: z.string(), provider: z.string(),
totpPending: z.boolean(), totpPending: z.boolean(),

View File

@@ -45,6 +45,11 @@ var authConfig = types.AuthConfig{
LoginMaxRetries: 0, LoginMaxRetries: 0,
} }
// Simple hooks config for tests
var hooksConfig = types.HooksConfig{
Domain: "localhost",
}
// Cookie // Cookie
var cookie string var cookie string
@@ -83,7 +88,7 @@ func getAPI(t *testing.T) *api.API {
providers.Init() providers.Init()
// Create hooks service // Create hooks service
hooks := hooks.NewHooks(auth, providers) hooks := hooks.NewHooks(hooksConfig, auth, providers)
// Create handlers service // Create handlers service
handlers := handlers.NewHandlers(handlersConfig, auth, hooks, providers, docker) handlers := handlers.NewHandlers(handlersConfig, auth, hooks, providers, docker)

View File

@@ -160,6 +160,8 @@ func (auth *Auth) CreateSessionCookie(c *gin.Context, data *types.SessionCookie)
// Set data // Set data
session.Values["username"] = data.Username session.Values["username"] = data.Username
session.Values["name"] = data.Name
session.Values["email"] = data.Email
session.Values["provider"] = data.Provider session.Values["provider"] = data.Provider
session.Values["expiry"] = time.Now().Add(time.Duration(sessionExpiry) * time.Second).Unix() session.Values["expiry"] = time.Now().Add(time.Duration(sessionExpiry) * time.Second).Unix()
session.Values["totpPending"] = data.TotpPending session.Values["totpPending"] = data.TotpPending
@@ -211,14 +213,23 @@ func (auth *Auth) GetSessionCookie(c *gin.Context) (types.SessionCookie, error)
return types.SessionCookie{}, err return types.SessionCookie{}, err
} }
log.Debug().Interface("session", session).Msg("Got session")
// Get data from session // Get data from session
username, usernameOk := session.Values["username"].(string) username, usernameOk := session.Values["username"].(string)
email, emailOk := session.Values["email"].(string)
name, nameOk := session.Values["name"].(string)
provider, providerOK := session.Values["provider"].(string) provider, providerOK := session.Values["provider"].(string)
expiry, expiryOk := session.Values["expiry"].(int64) expiry, expiryOk := session.Values["expiry"].(int64)
totpPending, totpPendingOk := session.Values["totpPending"].(bool) totpPending, totpPendingOk := session.Values["totpPending"].(bool)
if !usernameOk || !providerOK || !expiryOk || !totpPendingOk { if !usernameOk || !providerOK || !expiryOk || !totpPendingOk || !emailOk || !nameOk {
log.Warn().Msg("Session cookie is missing data") log.Warn().Msg("Session cookie is invalid")
// If any data is missing, delete the session cookie
auth.DeleteSessionCookie(c)
// Return empty cookie
return types.SessionCookie{}, nil return types.SessionCookie{}, nil
} }
@@ -233,11 +244,13 @@ func (auth *Auth) GetSessionCookie(c *gin.Context) (types.SessionCookie, error)
return types.SessionCookie{}, nil return types.SessionCookie{}, nil
} }
log.Debug().Str("username", username).Str("provider", provider).Int64("expiry", expiry).Bool("totpPending", totpPending).Msg("Parsed cookie") log.Debug().Str("username", username).Str("provider", provider).Int64("expiry", expiry).Bool("totpPending", totpPending).Str("name", name).Str("email", email).Msg("Parsed cookie")
// Return the cookie // Return the cookie
return types.SessionCookie{ return types.SessionCookie{
Username: username, Username: username,
Name: name,
Email: email,
Provider: provider, Provider: provider,
TotpPending: totpPending, TotpPending: totpPending,
}, nil }, nil

View File

@@ -8,13 +8,9 @@ var TinyauthLabels = []string{
"tinyauth.headers", "tinyauth.headers",
} }
// Claims are the OIDC supported claims // Claims are the OIDC supported claims (including preferd username for some reason)
type Claims struct { type Claims struct {
Name string `json:"name"` Name string `json:"name"`
FamilyName string `json:"family_name"` Email string `json:"email"`
GivenName string `json:"given_name"` PreferredUsername string `json:"preferred_username"`
MiddleName string `json:"middle_name"`
Nickname string `json:"nickname"`
Picture string `json:"picture"`
Email string `json:"email"`
} }

View File

@@ -6,8 +6,7 @@ import (
"tinyauth/internal/types" "tinyauth/internal/types"
"tinyauth/internal/utils" "tinyauth/internal/utils"
apiTypes "github.com/docker/docker/api/types" container "github.com/docker/docker/api/types/container"
containerTypes "github.com/docker/docker/api/types/container"
"github.com/docker/docker/client" "github.com/docker/docker/client"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
@@ -38,9 +37,9 @@ func (docker *Docker) Init() error {
return nil return nil
} }
func (docker *Docker) GetContainers() ([]apiTypes.Container, error) { func (docker *Docker) GetContainers() ([]container.Summary, error) {
// Get the list of containers // Get the list of containers
containers, err := docker.Client.ContainerList(docker.Context, containerTypes.ListOptions{}) containers, err := docker.Client.ContainerList(docker.Context, container.ListOptions{})
// Check if there was an error // Check if there was an error
if err != nil { if err != nil {
@@ -51,13 +50,13 @@ func (docker *Docker) GetContainers() ([]apiTypes.Container, error) {
return containers, nil return containers, nil
} }
func (docker *Docker) InspectContainer(containerId string) (apiTypes.ContainerJSON, error) { func (docker *Docker) InspectContainer(containerId string) (container.InspectResponse, error) {
// Inspect the container // Inspect the container
inspect, err := docker.Client.ContainerInspect(docker.Context, containerId) inspect, err := docker.Client.ContainerInspect(docker.Context, containerId)
// Check if there was an error // Check if there was an error
if err != nil { if err != nil {
return apiTypes.ContainerJSON{}, err return container.InspectResponse{}, err
} }
// Return the inspect // Return the inspect

View File

@@ -10,6 +10,7 @@ import (
"tinyauth/internal/hooks" "tinyauth/internal/hooks"
"tinyauth/internal/providers" "tinyauth/internal/providers"
"tinyauth/internal/types" "tinyauth/internal/types"
"tinyauth/internal/utils"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/go-querystring/query" "github.com/google/go-querystring/query"
@@ -183,8 +184,9 @@ func (h *Handlers) AuthHandler(c *gin.Context) {
return return
} }
// Set the user header
c.Header("Remote-User", userContext.Username) c.Header("Remote-User", userContext.Username)
c.Header("Remote-Name", userContext.Name)
c.Header("Remote-Email", userContext.Email)
// Set the rest of the headers // Set the rest of the headers
for key, value := range labels.Headers { for key, value := range labels.Headers {
@@ -310,6 +312,8 @@ func (h *Handlers) LoginHandler(c *gin.Context) {
// Set totp pending cookie // Set totp pending cookie
h.Auth.CreateSessionCookie(c, &types.SessionCookie{ h.Auth.CreateSessionCookie(c, &types.SessionCookie{
Username: login.Username, Username: login.Username,
Name: utils.Capitalize(login.Username),
Email: fmt.Sprintf("%s@%s", strings.ToLower(login.Username), h.Config.Domain),
Provider: "username", Provider: "username",
TotpPending: true, TotpPending: true,
}) })
@@ -328,6 +332,8 @@ func (h *Handlers) LoginHandler(c *gin.Context) {
// Create session cookie with username as provider // Create session cookie with username as provider
h.Auth.CreateSessionCookie(c, &types.SessionCookie{ h.Auth.CreateSessionCookie(c, &types.SessionCookie{
Username: login.Username, Username: login.Username,
Name: utils.Capitalize(login.Username),
Email: fmt.Sprintf("%s@%s", strings.ToLower(login.Username), h.Config.Domain),
Provider: "username", Provider: "username",
}) })
@@ -402,6 +408,8 @@ func (h *Handlers) TotpHandler(c *gin.Context) {
// Create session cookie with username as provider // Create session cookie with username as provider
h.Auth.CreateSessionCookie(c, &types.SessionCookie{ h.Auth.CreateSessionCookie(c, &types.SessionCookie{
Username: user.Username, Username: user.Username,
Name: utils.Capitalize(user.Username),
Email: fmt.Sprintf("%s@%s", strings.ToLower(user.Username), h.Config.Domain),
Provider: "username", Provider: "username",
}) })
@@ -465,6 +473,8 @@ func (h *Handlers) UserHandler(c *gin.Context) {
Status: 200, Status: 200,
IsLoggedIn: userContext.IsLoggedIn, IsLoggedIn: userContext.IsLoggedIn,
Username: userContext.Username, Username: userContext.Username,
Name: userContext.Name,
Email: userContext.Email,
Provider: userContext.Provider, Provider: userContext.Provider,
Oauth: userContext.OAuth, Oauth: userContext.OAuth,
TotpPending: userContext.TotpPending, TotpPending: userContext.TotpPending,
@@ -654,9 +664,29 @@ func (h *Handlers) OauthCallbackHandler(c *gin.Context) {
log.Debug().Msg("Email whitelisted") log.Debug().Msg("Email whitelisted")
// Get username
var username string
if user.PreferredUsername != "" {
username = user.PreferredUsername
} else {
username = fmt.Sprintf("%s_%s", strings.Split(user.Email, "@")[0], strings.Split(user.Email, "@")[1])
}
// Get name
var name string
if user.Name != "" {
name = user.Name
} else {
name = fmt.Sprintf("%s (%s)", utils.Capitalize(strings.Split(user.Email, "@")[0]), strings.Split(user.Email, "@")[1])
}
// Create session cookie (also cleans up redirect cookie) // Create session cookie (also cleans up redirect cookie)
h.Auth.CreateSessionCookie(c, &types.SessionCookie{ h.Auth.CreateSessionCookie(c, &types.SessionCookie{
Username: user.Email, Username: username,
Name: name,
Email: user.Email,
Provider: providerName.Provider, Provider: providerName.Provider,
}) })

View File

@@ -1,22 +1,27 @@
package hooks package hooks
import ( import (
"fmt"
"strings"
"tinyauth/internal/auth" "tinyauth/internal/auth"
"tinyauth/internal/providers" "tinyauth/internal/providers"
"tinyauth/internal/types" "tinyauth/internal/types"
"tinyauth/internal/utils"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
func NewHooks(auth *auth.Auth, providers *providers.Providers) *Hooks { func NewHooks(config types.HooksConfig, auth *auth.Auth, providers *providers.Providers) *Hooks {
return &Hooks{ return &Hooks{
Config: config,
Auth: auth, Auth: auth,
Providers: providers, Providers: providers,
} }
} }
type Hooks struct { type Hooks struct {
Config types.HooksConfig
Auth *auth.Auth Auth *auth.Auth
Providers *providers.Providers Providers *providers.Providers
} }
@@ -36,11 +41,11 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext {
if user != nil && hooks.Auth.CheckPassword(*user, basic.Password) { if user != nil && hooks.Auth.CheckPassword(*user, basic.Password) {
// Return user context since we are logged in with basic auth // Return user context since we are logged in with basic auth
return types.UserContext{ return types.UserContext{
Username: basic.Username, Username: basic.Username,
IsLoggedIn: true, Name: utils.Capitalize(basic.Username),
OAuth: false, Email: fmt.Sprintf("%s@%s", strings.ToLower(basic.Username), hooks.Config.Domain),
Provider: "basic", IsLoggedIn: true,
TotpPending: false, Provider: "basic",
} }
} }
@@ -50,13 +55,7 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext {
if err != nil { if err != nil {
log.Error().Err(err).Msg("Failed to get session cookie") log.Error().Err(err).Msg("Failed to get session cookie")
// Return empty context // Return empty context
return types.UserContext{ return types.UserContext{}
Username: "",
IsLoggedIn: false,
OAuth: false,
Provider: "",
TotpPending: false,
}
} }
// Check if session cookie has totp pending // Check if session cookie has totp pending
@@ -65,8 +64,8 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext {
// Return empty context since we are pending totp // Return empty context since we are pending totp
return types.UserContext{ return types.UserContext{
Username: cookie.Username, Username: cookie.Username,
IsLoggedIn: false, Name: cookie.Name,
OAuth: false, Email: cookie.Email,
Provider: cookie.Provider, Provider: cookie.Provider,
TotpPending: true, TotpPending: true,
} }
@@ -82,11 +81,11 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext {
// It exists so we are logged in // It exists so we are logged in
return types.UserContext{ return types.UserContext{
Username: cookie.Username, Username: cookie.Username,
IsLoggedIn: true, Name: cookie.Name,
OAuth: false, Email: cookie.Email,
Provider: "username", IsLoggedIn: true,
TotpPending: false, Provider: "username",
} }
} }
} }
@@ -108,33 +107,22 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext {
hooks.Auth.DeleteSessionCookie(c) hooks.Auth.DeleteSessionCookie(c)
// Return empty context // Return empty context
return types.UserContext{ return types.UserContext{}
Username: "",
IsLoggedIn: false,
OAuth: false,
Provider: "",
TotpPending: false,
}
} }
log.Debug().Msg("Email is whitelisted") log.Debug().Msg("Email is whitelisted")
// Return user context since we are logged in with oauth // Return user context since we are logged in with oauth
return types.UserContext{ return types.UserContext{
Username: cookie.Username, Username: cookie.Username,
IsLoggedIn: true, Name: cookie.Name,
OAuth: true, Email: cookie.Email,
Provider: cookie.Provider, IsLoggedIn: true,
TotpPending: false, OAuth: true,
Provider: cookie.Provider,
} }
} }
// Neither basic auth or oauth is set so we return an empty context // Neither basic auth or oauth is set so we return an empty context
return types.UserContext{ return types.UserContext{}
Username: "",
IsLoggedIn: false,
OAuth: false,
Provider: "",
TotpPending: false,
}
} }

View File

@@ -33,6 +33,8 @@ type UserContextResponse struct {
Message string `json:"message"` Message string `json:"message"`
IsLoggedIn bool `json:"isLoggedIn"` IsLoggedIn bool `json:"isLoggedIn"`
Username string `json:"username"` Username string `json:"username"`
Name string `json:"name"`
Email string `json:"email"`
Provider string `json:"provider"` Provider string `json:"provider"`
Oauth bool `json:"oauth"` Oauth bool `json:"oauth"`
TotpPending bool `json:"totpPending"` TotpPending bool `json:"totpPending"`

View File

@@ -78,3 +78,8 @@ type AuthConfig struct {
LoginTimeout int LoginTimeout int
LoginMaxRetries int LoginMaxRetries int
} }
// HooksConfig is the configuration for the hooks service
type HooksConfig struct {
Domain string
}

View File

@@ -25,6 +25,8 @@ type OAuthProviders struct {
// SessionCookie is the cookie for the session (exculding the expiry) // SessionCookie is the cookie for the session (exculding the expiry)
type SessionCookie struct { type SessionCookie struct {
Username string Username string
Name string
Email string
Provider string Provider string
TotpPending bool TotpPending bool
} }
@@ -40,6 +42,8 @@ type TinyauthLabels struct {
// UserContext is the context for the user // UserContext is the context for the user
type UserContext struct { type UserContext struct {
Username string Username string
Name string
Email string
IsLoggedIn bool IsLoggedIn bool
OAuth bool OAuth bool
Provider string Provider string

View File

@@ -323,3 +323,8 @@ func CheckWhitelist(whitelist string, str string) bool {
// Return false if no match was found // Return false if no match was found
return false return false
} }
// Capitalize just the first letter of a string
func Capitalize(str string) string {
return strings.ToUpper(string([]rune(str)[0])) + string([]rune(str)[1:])
}