mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2025-11-03 07:35:44 +00:00
Compare commits
7 Commits
v3.6.1
...
v3.6.2-bet
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5854d973ea | ||
|
|
f25ab72747 | ||
|
|
2233557990 | ||
|
|
d3bec635f8 | ||
|
|
6519644fc1 | ||
|
|
736f65b7b2 | ||
|
|
63d39b5500 |
@@ -12,6 +12,7 @@ import {
|
|||||||
} from "../ui/form";
|
} from "../ui/form";
|
||||||
import { Button } from "../ui/button";
|
import { Button } from "../ui/button";
|
||||||
import { loginSchema, LoginSchema } from "@/schemas/login-schema";
|
import { loginSchema, LoginSchema } from "@/schemas/login-schema";
|
||||||
|
import z from "zod";
|
||||||
|
|
||||||
interface Props {
|
interface Props {
|
||||||
onSubmit: (data: LoginSchema) => void;
|
onSubmit: (data: LoginSchema) => void;
|
||||||
@@ -22,6 +23,11 @@ export const LoginForm = (props: Props) => {
|
|||||||
const { onSubmit, loading } = props;
|
const { onSubmit, loading } = props;
|
||||||
const { t } = useTranslation();
|
const { t } = useTranslation();
|
||||||
|
|
||||||
|
z.config({
|
||||||
|
customError: (iss) =>
|
||||||
|
iss.input === undefined ? t("fieldRequired") : t("invalidInput"),
|
||||||
|
});
|
||||||
|
|
||||||
const form = useForm<LoginSchema>({
|
const form = useForm<LoginSchema>({
|
||||||
resolver: zodResolver(loginSchema),
|
resolver: zodResolver(loginSchema),
|
||||||
});
|
});
|
||||||
|
|||||||
@@ -8,6 +8,8 @@ import {
|
|||||||
import { zodResolver } from "@hookform/resolvers/zod";
|
import { zodResolver } from "@hookform/resolvers/zod";
|
||||||
import { useForm } from "react-hook-form";
|
import { useForm } from "react-hook-form";
|
||||||
import { totpSchema, TotpSchema } from "@/schemas/totp-schema";
|
import { totpSchema, TotpSchema } from "@/schemas/totp-schema";
|
||||||
|
import { useTranslation } from "react-i18next";
|
||||||
|
import z from "zod";
|
||||||
|
|
||||||
interface Props {
|
interface Props {
|
||||||
formId: string;
|
formId: string;
|
||||||
@@ -17,6 +19,12 @@ interface Props {
|
|||||||
|
|
||||||
export const TotpForm = (props: Props) => {
|
export const TotpForm = (props: Props) => {
|
||||||
const { formId, onSubmit, loading } = props;
|
const { formId, onSubmit, loading } = props;
|
||||||
|
const { t } = useTranslation();
|
||||||
|
|
||||||
|
z.config({
|
||||||
|
customError: (iss) =>
|
||||||
|
iss.input === undefined ? t("fieldRequired") : t("invalidInput"),
|
||||||
|
});
|
||||||
|
|
||||||
const form = useForm<TotpSchema>({
|
const form = useForm<TotpSchema>({
|
||||||
resolver: zodResolver(totpSchema),
|
resolver: zodResolver(totpSchema),
|
||||||
|
|||||||
@@ -51,5 +51,7 @@
|
|||||||
"failedToFetchProvidersTitle": "Failed to load authentication providers. Please check your configuration.",
|
"failedToFetchProvidersTitle": "Failed to load authentication providers. Please check your configuration.",
|
||||||
"errorTitle": "An error occurred",
|
"errorTitle": "An error occurred",
|
||||||
"errorSubtitle": "An error occurred while trying to perform this action. Please check the console for more information.",
|
"errorSubtitle": "An error occurred while trying to perform this action. Please check the console for more information.",
|
||||||
"forgotPasswordMessage": "You can reset your password by changing the `USERS` environment variable."
|
"forgotPasswordMessage": "You can reset your password by changing the `USERS` environment variable.",
|
||||||
|
"fieldRequired": "This field is required",
|
||||||
|
"invalidInput": "Invalid input"
|
||||||
}
|
}
|
||||||
@@ -51,5 +51,7 @@
|
|||||||
"failedToFetchProvidersTitle": "Failed to load authentication providers. Please check your configuration.",
|
"failedToFetchProvidersTitle": "Failed to load authentication providers. Please check your configuration.",
|
||||||
"errorTitle": "An error occurred",
|
"errorTitle": "An error occurred",
|
||||||
"errorSubtitle": "An error occurred while trying to perform this action. Please check the console for more information.",
|
"errorSubtitle": "An error occurred while trying to perform this action. Please check the console for more information.",
|
||||||
"forgotPasswordMessage": "You can reset your password by changing the `USERS` environment variable."
|
"forgotPasswordMessage": "You can reset your password by changing the `USERS` environment variable.",
|
||||||
|
"fieldRequired": "This field is required",
|
||||||
|
"invalidInput": "Invalid input"
|
||||||
}
|
}
|
||||||
1
go.mod
1
go.mod
@@ -17,6 +17,7 @@ require (
|
|||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 // indirect
|
github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 // indirect
|
||||||
|
github.com/cenkalti/backoff/v5 v5.0.2 // indirect
|
||||||
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect
|
github.com/charmbracelet/colorprofile v0.2.3-0.20250311203215-f60798e515dc // indirect
|
||||||
github.com/charmbracelet/x/cellbuf v0.0.13 // indirect
|
github.com/charmbracelet/x/cellbuf v0.0.13 // indirect
|
||||||
github.com/containerd/errdefs v1.0.0 // indirect
|
github.com/containerd/errdefs v1.0.0 // indirect
|
||||||
|
|||||||
2
go.sum
2
go.sum
@@ -26,6 +26,8 @@ github.com/catppuccin/go v0.3.0 h1:d+0/YicIq+hSTo5oPuRi5kOpqkVA5tAsU6dNhvRu+aY=
|
|||||||
github.com/catppuccin/go v0.3.0/go.mod h1:8IHJuMGaUUjQM82qBrGNBv7LFq6JI3NnQCF6MOlZjpc=
|
github.com/catppuccin/go v0.3.0/go.mod h1:8IHJuMGaUUjQM82qBrGNBv7LFq6JI3NnQCF6MOlZjpc=
|
||||||
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
|
github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8=
|
||||||
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
|
github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE=
|
||||||
|
github.com/cenkalti/backoff/v5 v5.0.2 h1:rIfFVxEf1QsI7E1ZHfp/B4DF/6QBAUhmgkxc0H7Zss8=
|
||||||
|
github.com/cenkalti/backoff/v5 v5.0.2/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw=
|
||||||
github.com/charmbracelet/bubbles v0.21.0 h1:9TdC97SdRVg/1aaXNVWfFH3nnLAwOXr8Fn6u6mfQdFs=
|
github.com/charmbracelet/bubbles v0.21.0 h1:9TdC97SdRVg/1aaXNVWfFH3nnLAwOXr8Fn6u6mfQdFs=
|
||||||
github.com/charmbracelet/bubbles v0.21.0/go.mod h1:HF+v6QUR4HkEpz62dx7ym2xc71/KBHg+zKwJtMw+qtg=
|
github.com/charmbracelet/bubbles v0.21.0/go.mod h1:HF+v6QUR4HkEpz62dx7ym2xc71/KBHg+zKwJtMw+qtg=
|
||||||
github.com/charmbracelet/bubbletea v1.3.4 h1:kCg7B+jSCFPLYRA52SDZjr51kG/fMUEoPoZrkaDHyoI=
|
github.com/charmbracelet/bubbletea v1.3.4 h1:kCg7B+jSCFPLYRA52SDZjr51kG/fMUEoPoZrkaDHyoI=
|
||||||
|
|||||||
@@ -50,7 +50,7 @@ func (auth *Auth) GetSession(c *gin.Context) (*sessions.Session, error) {
|
|||||||
|
|
||||||
// If there was an error getting the session, it might be invalid so let's clear it and retry
|
// If there was an error getting the session, it might be invalid so let's clear it and retry
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn().Err(err).Msg("Invalid session, clearing cookie and retrying")
|
log.Error().Err(err).Msg("Invalid session, clearing cookie and retrying")
|
||||||
c.SetCookie(auth.Config.SessionCookieName, "", -1, "/", fmt.Sprintf(".%s", auth.Config.Domain), auth.Config.CookieSecure, true)
|
c.SetCookie(auth.Config.SessionCookieName, "", -1, "/", fmt.Sprintf(".%s", auth.Config.Domain), auth.Config.CookieSecure, true)
|
||||||
session, err = auth.Store.Get(c.Request, auth.Config.SessionCookieName)
|
session, err = auth.Store.Get(c.Request, auth.Config.SessionCookieName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -79,7 +79,7 @@ func (auth *Auth) SearchUser(username string) types.UserSearch {
|
|||||||
log.Debug().Str("username", username).Msg("Checking LDAP for user")
|
log.Debug().Str("username", username).Msg("Checking LDAP for user")
|
||||||
userDN, err := auth.LDAP.Search(username)
|
userDN, err := auth.LDAP.Search(username)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn().Err(err).Str("username", username).Msg("Failed to find user in LDAP")
|
log.Error().Err(err).Str("username", username).Msg("Failed to find user in LDAP")
|
||||||
return types.UserSearch{}
|
return types.UserSearch{}
|
||||||
}
|
}
|
||||||
return types.UserSearch{
|
return types.UserSearch{
|
||||||
@@ -88,7 +88,9 @@ func (auth *Auth) SearchUser(username string) types.UserSearch {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return types.UserSearch{}
|
return types.UserSearch{
|
||||||
|
Type: "unknown",
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *Auth) VerifyUser(search types.UserSearch, password string) bool {
|
func (auth *Auth) VerifyUser(search types.UserSearch, password string) bool {
|
||||||
@@ -105,7 +107,7 @@ func (auth *Auth) VerifyUser(search types.UserSearch, password string) bool {
|
|||||||
|
|
||||||
err := auth.LDAP.Bind(search.Username, password)
|
err := auth.LDAP.Bind(search.Username, password)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn().Err(err).Str("username", search.Username).Msg("Failed to bind to LDAP")
|
log.Error().Err(err).Str("username", search.Username).Msg("Failed to bind to LDAP")
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -370,7 +372,7 @@ func (auth *Auth) AuthEnabled(uri string, labels types.Labels) (bool, error) {
|
|||||||
|
|
||||||
// If there is an error, invalid regex, auth enabled
|
// If there is an error, invalid regex, auth enabled
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn().Err(err).Msg("Invalid regex")
|
log.Error().Err(err).Msg("Invalid regex")
|
||||||
return true, err
|
return true, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -399,7 +401,7 @@ func (auth *Auth) CheckIP(labels types.Labels, ip string) bool {
|
|||||||
for _, blocked := range labels.IP.Block {
|
for _, blocked := range labels.IP.Block {
|
||||||
res, err := utils.FilterIP(blocked, ip)
|
res, err := utils.FilterIP(blocked, ip)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list")
|
log.Error().Err(err).Str("item", blocked).Msg("Invalid IP/CIDR in block list")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if res {
|
if res {
|
||||||
@@ -412,7 +414,7 @@ func (auth *Auth) CheckIP(labels types.Labels, ip string) bool {
|
|||||||
for _, allowed := range labels.IP.Allow {
|
for _, allowed := range labels.IP.Allow {
|
||||||
res, err := utils.FilterIP(allowed, ip)
|
res, err := utils.FilterIP(allowed, ip)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list")
|
log.Error().Err(err).Str("item", allowed).Msg("Invalid IP/CIDR in allow list")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if res {
|
if res {
|
||||||
@@ -436,7 +438,7 @@ func (auth *Auth) BypassedIP(labels types.Labels, ip string) bool {
|
|||||||
for _, bypassed := range labels.IP.Bypass {
|
for _, bypassed := range labels.IP.Bypass {
|
||||||
res, err := utils.FilterIP(bypassed, ip)
|
res, err := utils.FilterIP(bypassed, ip)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list")
|
log.Error().Err(err).Str("item", bypassed).Msg("Invalid IP/CIDR in bypass list")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if res {
|
if res {
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ type Claims struct {
|
|||||||
Name string `json:"name"`
|
Name string `json:"name"`
|
||||||
Email string `json:"email"`
|
Email string `json:"email"`
|
||||||
PreferredUsername string `json:"preferred_username"`
|
PreferredUsername string `json:"preferred_username"`
|
||||||
Groups []string `json:"groups"`
|
Groups any `json:"groups"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// Version information
|
// Version information
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package server_test
|
package handlers_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@@ -189,7 +189,7 @@ func (h *Handlers) OAuthCallbackHandler(c *gin.Context) {
|
|||||||
Name: name,
|
Name: name,
|
||||||
Email: user.Email,
|
Email: user.Email,
|
||||||
Provider: providerName.Provider,
|
Provider: providerName.Provider,
|
||||||
OAuthGroups: strings.Join(user.Groups, ","),
|
OAuthGroups: utils.CoalesceToString(user.Groups),
|
||||||
})
|
})
|
||||||
|
|
||||||
// Check if we have a redirect URI
|
// Check if we have a redirect URI
|
||||||
|
|||||||
@@ -40,10 +40,7 @@ func (h *Handlers) ProxyHandler(c *gin.Context) {
|
|||||||
proto := c.Request.Header.Get("X-Forwarded-Proto")
|
proto := c.Request.Header.Get("X-Forwarded-Proto")
|
||||||
host := c.Request.Header.Get("X-Forwarded-Host")
|
host := c.Request.Header.Get("X-Forwarded-Host")
|
||||||
|
|
||||||
// Remove the port from the host if it exists
|
|
||||||
hostPortless := strings.Split(host, ":")[0] // *lol*
|
hostPortless := strings.Split(host, ":")[0] // *lol*
|
||||||
|
|
||||||
// Get the id
|
|
||||||
id := strings.Split(hostPortless, ".")[0]
|
id := strings.Split(hostPortless, ".")[0]
|
||||||
|
|
||||||
labels, err := h.Docker.GetLabels(id, hostPortless)
|
labels, err := h.Docker.GetLabels(id, hostPortless)
|
||||||
@@ -66,10 +63,10 @@ func (h *Handlers) ProxyHandler(c *gin.Context) {
|
|||||||
|
|
||||||
ip := c.ClientIP()
|
ip := c.ClientIP()
|
||||||
|
|
||||||
// Check if the IP is in bypass list
|
|
||||||
if h.Auth.BypassedIP(labels, ip) {
|
if h.Auth.BypassedIP(labels, ip) {
|
||||||
headersParsed := utils.ParseHeaders(labels.Headers)
|
c.Header("Authorization", c.Request.Header.Get("Authorization"))
|
||||||
|
|
||||||
|
headersParsed := utils.ParseHeaders(labels.Headers)
|
||||||
for key, value := range headersParsed {
|
for key, value := range headersParsed {
|
||||||
log.Debug().Str("key", key).Msg("Setting header")
|
log.Debug().Str("key", key).Msg("Setting header")
|
||||||
c.Header(key, value)
|
c.Header(key, value)
|
||||||
@@ -87,7 +84,6 @@ func (h *Handlers) ProxyHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if the IP is allowed/blocked
|
|
||||||
if !h.Auth.CheckIP(labels, ip) {
|
if !h.Auth.CheckIP(labels, ip) {
|
||||||
if proxy.Proxy == "nginx" || !isBrowser {
|
if proxy.Proxy == "nginx" || !isBrowser {
|
||||||
c.JSON(403, gin.H{
|
c.JSON(403, gin.H{
|
||||||
@@ -113,7 +109,6 @@ func (h *Handlers) ProxyHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if auth is enabled
|
|
||||||
authEnabled, err := h.Auth.AuthEnabled(uri, labels)
|
authEnabled, err := h.Auth.AuthEnabled(uri, labels)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("Failed to check if app is allowed")
|
log.Error().Err(err).Msg("Failed to check if app is allowed")
|
||||||
@@ -129,8 +124,9 @@ func (h *Handlers) ProxyHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// If auth is not enabled, return 200
|
|
||||||
if !authEnabled {
|
if !authEnabled {
|
||||||
|
c.Header("Authorization", c.Request.Header.Get("Authorization"))
|
||||||
|
|
||||||
headersParsed := utils.ParseHeaders(labels.Headers)
|
headersParsed := utils.ParseHeaders(labels.Headers)
|
||||||
for key, value := range headersParsed {
|
for key, value := range headersParsed {
|
||||||
log.Debug().Str("key", key).Msg("Setting header")
|
log.Debug().Str("key", key).Msg("Setting header")
|
||||||
@@ -150,7 +146,6 @@ func (h *Handlers) ProxyHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get user context
|
|
||||||
userContext := h.Hooks.UseUserContext(c)
|
userContext := h.Hooks.UseUserContext(c)
|
||||||
|
|
||||||
// If we are using basic auth, we need to check if the user has totp and if it does then disable basic auth
|
// If we are using basic auth, we need to check if the user has totp and if it does then disable basic auth
|
||||||
@@ -159,7 +154,6 @@ func (h *Handlers) ProxyHandler(c *gin.Context) {
|
|||||||
userContext.IsLoggedIn = false
|
userContext.IsLoggedIn = false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check if user is logged in
|
|
||||||
if userContext.IsLoggedIn {
|
if userContext.IsLoggedIn {
|
||||||
log.Debug().Msg("Authenticated")
|
log.Debug().Msg("Authenticated")
|
||||||
|
|
||||||
@@ -200,7 +194,6 @@ func (h *Handlers) ProxyHandler(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check groups if using OAuth
|
|
||||||
if userContext.OAuth {
|
if userContext.OAuth {
|
||||||
groupOk := h.Auth.OAuthGroup(c, userContext, labels)
|
groupOk := h.Auth.OAuthGroup(c, userContext, labels)
|
||||||
|
|
||||||
@@ -239,19 +232,18 @@ func (h *Handlers) ProxyHandler(c *gin.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c.Header("Authorization", c.Request.Header.Get("Authorization"))
|
||||||
c.Header("Remote-User", utils.SanitizeHeader(userContext.Username))
|
c.Header("Remote-User", utils.SanitizeHeader(userContext.Username))
|
||||||
c.Header("Remote-Name", utils.SanitizeHeader(userContext.Name))
|
c.Header("Remote-Name", utils.SanitizeHeader(userContext.Name))
|
||||||
c.Header("Remote-Email", utils.SanitizeHeader(userContext.Email))
|
c.Header("Remote-Email", utils.SanitizeHeader(userContext.Email))
|
||||||
c.Header("Remote-Groups", utils.SanitizeHeader(userContext.OAuthGroups))
|
c.Header("Remote-Groups", utils.SanitizeHeader(userContext.OAuthGroups))
|
||||||
|
|
||||||
// Set the rest of the headers
|
|
||||||
parsedHeaders := utils.ParseHeaders(labels.Headers)
|
parsedHeaders := utils.ParseHeaders(labels.Headers)
|
||||||
for key, value := range parsedHeaders {
|
for key, value := range parsedHeaders {
|
||||||
log.Debug().Str("key", key).Msg("Setting header")
|
log.Debug().Str("key", key).Msg("Setting header")
|
||||||
c.Header(key, value)
|
c.Header(key, value)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set basic auth headers if configured
|
|
||||||
if labels.Basic.Username != "" && utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File) != "" {
|
if labels.Basic.Username != "" && utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File) != "" {
|
||||||
log.Debug().Str("username", labels.Basic.Username).Msg("Setting basic auth headers")
|
log.Debug().Str("username", labels.Basic.Username).Msg("Setting basic auth headers")
|
||||||
c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Basic.Username, utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File))))
|
c.Header("Authorization", fmt.Sprintf("Basic %s", utils.GetBasicAuth(labels.Basic.Username, utils.GetSecret(labels.Basic.Password.Plain, labels.Basic.Password.File))))
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
"tinyauth/internal/auth"
|
"tinyauth/internal/auth"
|
||||||
|
"tinyauth/internal/oauth"
|
||||||
"tinyauth/internal/providers"
|
"tinyauth/internal/providers"
|
||||||
"tinyauth/internal/types"
|
"tinyauth/internal/types"
|
||||||
"tinyauth/internal/utils"
|
"tinyauth/internal/utils"
|
||||||
@@ -27,28 +28,92 @@ func NewHooks(config types.HooksConfig, auth *auth.Auth, providers *providers.Pr
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext {
|
func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext {
|
||||||
// Get session cookie and basic auth
|
|
||||||
cookie, err := hooks.Auth.GetSessionCookie(c)
|
cookie, err := hooks.Auth.GetSessionCookie(c)
|
||||||
|
var provider *oauth.OAuth
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("Failed to get session cookie")
|
||||||
|
goto basic
|
||||||
|
}
|
||||||
|
|
||||||
|
if cookie.TotpPending {
|
||||||
|
log.Debug().Msg("Totp pending")
|
||||||
|
return types.UserContext{
|
||||||
|
Username: cookie.Username,
|
||||||
|
Name: cookie.Name,
|
||||||
|
Email: cookie.Email,
|
||||||
|
Provider: cookie.Provider,
|
||||||
|
TotpPending: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if cookie.Provider == "username" {
|
||||||
|
log.Debug().Msg("Provider is username")
|
||||||
|
|
||||||
|
userSearch := hooks.Auth.SearchUser(cookie.Username)
|
||||||
|
|
||||||
|
if userSearch.Type == "unknown" {
|
||||||
|
log.Warn().Str("username", cookie.Username).Msg("User does not exist")
|
||||||
|
goto basic
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug().Str("type", userSearch.Type).Msg("User exists")
|
||||||
|
|
||||||
|
return types.UserContext{
|
||||||
|
Username: cookie.Username,
|
||||||
|
Name: cookie.Name,
|
||||||
|
Email: cookie.Email,
|
||||||
|
IsLoggedIn: true,
|
||||||
|
Provider: "username",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug().Msg("Provider is not username")
|
||||||
|
|
||||||
|
provider = hooks.Providers.GetProvider(cookie.Provider)
|
||||||
|
|
||||||
|
if provider != nil {
|
||||||
|
log.Debug().Msg("Provider exists")
|
||||||
|
|
||||||
|
if !hooks.Auth.EmailWhitelisted(cookie.Email) {
|
||||||
|
log.Warn().Str("email", cookie.Email).Msg("Email is not whitelisted")
|
||||||
|
hooks.Auth.DeleteSessionCookie(c)
|
||||||
|
goto basic
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug().Msg("Email is whitelisted")
|
||||||
|
|
||||||
|
return types.UserContext{
|
||||||
|
Username: cookie.Username,
|
||||||
|
Name: cookie.Name,
|
||||||
|
Email: cookie.Email,
|
||||||
|
IsLoggedIn: true,
|
||||||
|
OAuth: true,
|
||||||
|
Provider: cookie.Provider,
|
||||||
|
OAuthGroups: cookie.OAuthGroups,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
basic:
|
||||||
|
log.Debug().Msg("Trying basic auth")
|
||||||
|
|
||||||
basic := hooks.Auth.GetBasicAuth(c)
|
basic := hooks.Auth.GetBasicAuth(c)
|
||||||
|
|
||||||
// Check if basic auth is set
|
|
||||||
if basic != nil {
|
if basic != nil {
|
||||||
log.Debug().Msg("Got basic auth")
|
log.Debug().Msg("Got basic auth")
|
||||||
|
|
||||||
userSearch := hooks.Auth.SearchUser(basic.Username)
|
userSearch := hooks.Auth.SearchUser(basic.Username)
|
||||||
|
|
||||||
if userSearch.Type == "" {
|
if userSearch.Type == "unkown" {
|
||||||
log.Error().Str("username", basic.Username).Msg("User does not exist")
|
log.Error().Str("username", basic.Username).Msg("Basic auth user does not exist")
|
||||||
return types.UserContext{}
|
return types.UserContext{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify the user
|
|
||||||
if !hooks.Auth.VerifyUser(userSearch, basic.Password) {
|
if !hooks.Auth.VerifyUser(userSearch, basic.Password) {
|
||||||
log.Error().Str("username", basic.Username).Msg("Password incorrect")
|
log.Error().Str("username", basic.Username).Msg("Basic auth user password incorrect")
|
||||||
return types.UserContext{}
|
return types.UserContext{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the user type
|
|
||||||
if userSearch.Type == "ldap" {
|
if userSearch.Type == "ldap" {
|
||||||
log.Debug().Msg("User is LDAP")
|
log.Debug().Msg("User is LDAP")
|
||||||
|
|
||||||
@@ -75,73 +140,5 @@ func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext {
|
|||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check cookie error after basic auth
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Err(err).Msg("Failed to get session cookie")
|
|
||||||
return types.UserContext{}
|
|
||||||
}
|
|
||||||
|
|
||||||
if cookie.TotpPending {
|
|
||||||
log.Debug().Msg("Totp pending")
|
|
||||||
return types.UserContext{
|
|
||||||
Username: cookie.Username,
|
|
||||||
Name: cookie.Name,
|
|
||||||
Email: cookie.Email,
|
|
||||||
Provider: cookie.Provider,
|
|
||||||
TotpPending: true,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if session cookie is username/password auth
|
|
||||||
if cookie.Provider == "username" {
|
|
||||||
log.Debug().Msg("Provider is username")
|
|
||||||
|
|
||||||
userSearch := hooks.Auth.SearchUser(cookie.Username)
|
|
||||||
|
|
||||||
if userSearch.Type == "" {
|
|
||||||
log.Error().Str("username", cookie.Username).Msg("User does not exist")
|
|
||||||
return types.UserContext{}
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debug().Str("type", userSearch.Type).Msg("User exists")
|
|
||||||
|
|
||||||
return types.UserContext{
|
|
||||||
Username: cookie.Username,
|
|
||||||
Name: cookie.Name,
|
|
||||||
Email: cookie.Email,
|
|
||||||
IsLoggedIn: true,
|
|
||||||
Provider: "username",
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debug().Msg("Provider is not username")
|
|
||||||
|
|
||||||
// The provider is not username so we need to check if it is an oauth provider
|
|
||||||
provider := hooks.Providers.GetProvider(cookie.Provider)
|
|
||||||
|
|
||||||
// If we have a provider with this name
|
|
||||||
if provider != nil {
|
|
||||||
log.Debug().Msg("Provider exists")
|
|
||||||
|
|
||||||
// If the email is not whitelisted we delete the cookie and return an empty context
|
|
||||||
if !hooks.Auth.EmailWhitelisted(cookie.Email) {
|
|
||||||
log.Error().Str("email", cookie.Email).Msg("Email is not whitelisted")
|
|
||||||
hooks.Auth.DeleteSessionCookie(c)
|
|
||||||
return types.UserContext{}
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debug().Msg("Email is whitelisted")
|
|
||||||
|
|
||||||
return types.UserContext{
|
|
||||||
Username: cookie.Username,
|
|
||||||
Name: cookie.Name,
|
|
||||||
Email: cookie.Email,
|
|
||||||
IsLoggedIn: true,
|
|
||||||
OAuth: true,
|
|
||||||
Provider: cookie.Provider,
|
|
||||||
OAuthGroups: cookie.OAuthGroups,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return types.UserContext{}
|
return types.UserContext{}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,11 +1,13 @@
|
|||||||
package ldap
|
package ldap
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"fmt"
|
"fmt"
|
||||||
"time"
|
"time"
|
||||||
"tinyauth/internal/types"
|
"tinyauth/internal/types"
|
||||||
|
|
||||||
|
"github.com/cenkalti/backoff/v5"
|
||||||
ldapgo "github.com/go-ldap/ldap/v3"
|
ldapgo "github.com/go-ldap/ldap/v3"
|
||||||
"github.com/rs/zerolog/log"
|
"github.com/rs/zerolog/log"
|
||||||
)
|
)
|
||||||
@@ -30,6 +32,11 @@ func NewLDAP(config types.LdapConfig) (*LDAP, error) {
|
|||||||
err := ldap.heartbeat()
|
err := ldap.heartbeat()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("LDAP connection heartbeat failed")
|
log.Error().Err(err).Msg("LDAP connection heartbeat failed")
|
||||||
|
if reconnectErr := ldap.reconnect(); reconnectErr != nil {
|
||||||
|
log.Error().Err(reconnectErr).Msg("Failed to reconnect to LDAP server")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
log.Info().Msg("Successfully reconnected to LDAP server")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
@@ -38,6 +45,7 @@ func NewLDAP(config types.LdapConfig) (*LDAP, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (l *LDAP) connect() (*ldapgo.Conn, error) {
|
func (l *LDAP) connect() (*ldapgo.Conn, error) {
|
||||||
|
log.Debug().Msg("Connecting to LDAP server")
|
||||||
conn, err := ldapgo.DialURL(l.Config.Address, ldapgo.DialWithTLSConfig(&tls.Config{
|
conn, err := ldapgo.DialURL(l.Config.Address, ldapgo.DialWithTLSConfig(&tls.Config{
|
||||||
InsecureSkipVerify: l.Config.Insecure,
|
InsecureSkipVerify: l.Config.Insecure,
|
||||||
MinVersion: tls.VersionTLS12,
|
MinVersion: tls.VersionTLS12,
|
||||||
@@ -46,6 +54,7 @@ func (l *LDAP) connect() (*ldapgo.Conn, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Debug().Msg("Binding to LDAP server")
|
||||||
err = conn.Bind(l.Config.BindDN, l.Config.BindPassword)
|
err = conn.Bind(l.Config.BindDN, l.Config.BindPassword)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -109,3 +118,30 @@ func (l *LDAP) heartbeat() error {
|
|||||||
// No error means the connection is alive
|
// No error means the connection is alive
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (l *LDAP) reconnect() error {
|
||||||
|
log.Info().Msg("Reconnecting to LDAP server")
|
||||||
|
|
||||||
|
exp := backoff.NewExponentialBackOff()
|
||||||
|
exp.InitialInterval = 500 * time.Millisecond
|
||||||
|
exp.RandomizationFactor = 0.1
|
||||||
|
exp.Multiplier = 1.5
|
||||||
|
exp.Reset()
|
||||||
|
|
||||||
|
operation := func() (*ldapgo.Conn, error) {
|
||||||
|
l.Conn.Close()
|
||||||
|
_, err := l.connect()
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := backoff.Retry(context.TODO(), operation, backoff.WithBackOff(exp), backoff.WithMaxTries(3))
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -327,3 +327,15 @@ func DeriveKey(secret string, info string) (string, error) {
|
|||||||
encodedKey := base64.StdEncoding.EncodeToString(key)
|
encodedKey := base64.StdEncoding.EncodeToString(key)
|
||||||
return encodedKey, nil
|
return encodedKey, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func CoalesceToString(value any) string {
|
||||||
|
switch v := value.(type) {
|
||||||
|
case []string:
|
||||||
|
return strings.Join(v, ",")
|
||||||
|
case string:
|
||||||
|
return v
|
||||||
|
default:
|
||||||
|
log.Warn().Interface("value", value).Msg("Unsupported type, returning empty string")
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -511,3 +511,38 @@ func TestDeriveKey(t *testing.T) {
|
|||||||
t.Fatalf("Expected %v, got %v", expected, result)
|
t.Fatalf("Expected %v, got %v", expected, result)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCoalesceToString(t *testing.T) {
|
||||||
|
t.Log("Testing coalesce to string with a string")
|
||||||
|
|
||||||
|
value := "test"
|
||||||
|
expected := "test"
|
||||||
|
|
||||||
|
result := utils.CoalesceToString(value)
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Fatalf("Expected %v, got %v", expected, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Log("Testing coalesce to string with a slice of strings")
|
||||||
|
|
||||||
|
valueSlice := []string{"test1", "test2"}
|
||||||
|
expected = "test1,test2"
|
||||||
|
|
||||||
|
result = utils.CoalesceToString(valueSlice)
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Fatalf("Expected %v, got %v", expected, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Log("Testing coalesce to string with an unsupported type")
|
||||||
|
|
||||||
|
valueUnsupported := 12345
|
||||||
|
expected = ""
|
||||||
|
|
||||||
|
result = utils.CoalesceToString(valueUnsupported)
|
||||||
|
|
||||||
|
if result != expected {
|
||||||
|
t.Fatalf("Expected %v, got %v", expected, result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user