Compare commits

..

1 Commits

Author SHA1 Message Date
Stavros 44c763c302 fix: narrow down action permissions to per-job ones 2026-04-29 16:41:24 +03:00
37 changed files with 800 additions and 976 deletions
+17 -2
View File
@@ -5,12 +5,13 @@ on:
- cron: "0 0 * * *" - cron: "0 0 * * *"
permissions: permissions:
contents: write contents: read
packages: write
jobs: jobs:
create-release: create-release:
runs-on: ubuntu-latest runs-on: ubuntu-latest
permissions:
contents: write
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
@@ -145,6 +146,8 @@ jobs:
needs: needs:
- create-release - create-release
- generate-metadata - generate-metadata
permissions:
packages: read
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
@@ -203,6 +206,8 @@ jobs:
- create-release - create-release
- generate-metadata - generate-metadata
- image-build - image-build
permissions:
packages: read
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
@@ -261,6 +266,8 @@ jobs:
needs: needs:
- create-release - create-release
- generate-metadata - generate-metadata
permissions:
packages: read
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
@@ -319,6 +326,8 @@ jobs:
- create-release - create-release
- generate-metadata - generate-metadata
- image-build-arm - image-build-arm
permissions:
packages: read
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
@@ -377,6 +386,8 @@ jobs:
needs: needs:
- image-build - image-build
- image-build-arm - image-build-arm
permissions:
packages: write
steps: steps:
- name: Download digests - name: Download digests
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8 uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8
@@ -416,6 +427,8 @@ jobs:
needs: needs:
- image-build-distroless - image-build-distroless
- image-build-arm-distroless - image-build-arm-distroless
permissions:
packages: write
steps: steps:
- name: Download digests - name: Download digests
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8 uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8
@@ -455,6 +468,8 @@ jobs:
needs: needs:
- binary-build - binary-build
- binary-build-arm - binary-build-arm
permissions:
contents: write
steps: steps:
- uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8 - uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8
with: with:
+15 -2
View File
@@ -6,8 +6,7 @@ on:
- "v*" - "v*"
permissions: permissions:
contents: write contents: read
packages: write
jobs: jobs:
generate-metadata: generate-metadata:
@@ -117,6 +116,8 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
needs: needs:
- generate-metadata - generate-metadata
permissions:
packages: read
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
@@ -172,6 +173,8 @@ jobs:
needs: needs:
- generate-metadata - generate-metadata
- image-build - image-build
permissions:
packages: read
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
@@ -227,6 +230,8 @@ jobs:
runs-on: ubuntu-24.04-arm runs-on: ubuntu-24.04-arm
needs: needs:
- generate-metadata - generate-metadata
permissions:
packages: read
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
@@ -282,6 +287,8 @@ jobs:
needs: needs:
- generate-metadata - generate-metadata
- image-build-arm - image-build-arm
permissions:
packages: read
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
@@ -338,6 +345,8 @@ jobs:
needs: needs:
- image-build - image-build
- image-build-arm - image-build-arm
permissions:
packages: write
steps: steps:
- name: Download digests - name: Download digests
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8 uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8
@@ -379,6 +388,8 @@ jobs:
needs: needs:
- image-build-distroless - image-build-distroless
- image-build-arm-distroless - image-build-arm-distroless
permissions:
packages: write
steps: steps:
- name: Download digests - name: Download digests
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8 uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8
@@ -422,6 +433,8 @@ jobs:
needs: needs:
- binary-build - binary-build
- binary-build-arm - binary-build-arm
permissions:
contents: write
steps: steps:
- uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8 - uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8
with: with:
+4 -2
View File
@@ -3,12 +3,14 @@ on:
workflow_dispatch: workflow_dispatch:
permissions: permissions:
contents: write contents: read
pull-requests: write
jobs: jobs:
generate-sponsors: generate-sponsors:
runs-on: ubuntu-latest runs-on: ubuntu-latest
permissions:
contents: write
pull-requests: write
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
+4 -2
View File
@@ -4,12 +4,14 @@ on:
- cron: 0 10 * * * - cron: 0 10 * * *
permissions: permissions:
issues: write contents: read
pull-requests: write
jobs: jobs:
stale: stale:
runs-on: ubuntu-latest runs-on: ubuntu-latest
permissions:
issues: write
pull-requests: write
steps: steps:
- uses: actions/stale@b5d41d4e1d5dceea10e7104786b73624c18a190f # v10 - uses: actions/stale@b5d41d4e1d5dceea10e7104786b73624c18a190f # v10
with: with:
+3 -3
View File
@@ -73,7 +73,7 @@ func generateTotpCmd() *cli.Command {
docker = true docker = true
} }
if user.TOTPSecret != "" { if user.TotpSecret != "" {
return fmt.Errorf("user already has a TOTP secret") return fmt.Errorf("user already has a TOTP secret")
} }
@@ -102,14 +102,14 @@ func generateTotpCmd() *cli.Command {
qrterminal.GenerateWithConfig(key.URL(), config) qrterminal.GenerateWithConfig(key.URL(), config)
user.TOTPSecret = secret user.TotpSecret = secret
// If using docker escape re-escape it // If using docker escape re-escape it
if docker { if docker {
user.Password = strings.ReplaceAll(user.Password, "$", "$$") user.Password = strings.ReplaceAll(user.Password, "$", "$$")
} }
tlog.App.Info().Str("user", fmt.Sprintf("%s:%s:%s", user.Username, user.Password, user.TOTPSecret)).Msg("Add the totp secret to your authenticator app then use the verify command to ensure everything is working correctly.") tlog.App.Info().Str("user", fmt.Sprintf("%s:%s:%s", user.Username, user.Password, user.TotpSecret)).Msg("Add the totp secret to your authenticator app then use the verify command to ensure everything is working correctly.")
return nil return nil
}, },
+4 -4
View File
@@ -5,7 +5,7 @@ import (
"charm.land/huh/v2" "charm.land/huh/v2"
"github.com/tinyauthapp/tinyauth/internal/bootstrap" "github.com/tinyauthapp/tinyauth/internal/bootstrap"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/utils/loaders" "github.com/tinyauthapp/tinyauth/internal/utils/loaders"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
@@ -14,7 +14,7 @@ import (
) )
func main() { func main() {
tConfig := model.NewDefaultConfiguration() tConfig := config.NewDefaultConfiguration()
loaders := []cli.ResourceLoader{ loaders := []cli.ResourceLoader{
&loaders.FileLoader{}, &loaders.FileLoader{},
@@ -108,11 +108,11 @@ func main() {
} }
} }
func runCmd(cfg model.Config) error { func runCmd(cfg config.Config) error {
logger := tlog.NewLogger(cfg.Log) logger := tlog.NewLogger(cfg.Log)
logger.Init() logger.Init()
tlog.App.Info().Str("version", model.Version).Msg("Starting tinyauth") tlog.App.Info().Str("version", config.Version).Msg("Starting tinyauth")
app := bootstrap.NewBootstrapApp(cfg) app := bootstrap.NewBootstrapApp(cfg)
+2 -2
View File
@@ -95,7 +95,7 @@ func verifyUserCmd() *cli.Command {
return fmt.Errorf("password is incorrect: %w", err) return fmt.Errorf("password is incorrect: %w", err)
} }
if user.TOTPSecret == "" { if user.TotpSecret == "" {
if tCfg.Totp != "" { if tCfg.Totp != "" {
tlog.App.Warn().Msg("User does not have TOTP secret") tlog.App.Warn().Msg("User does not have TOTP secret")
} }
@@ -103,7 +103,7 @@ func verifyUserCmd() *cli.Command {
return nil return nil
} }
ok := totp.Validate(tCfg.Totp, user.TOTPSecret) ok := totp.Validate(tCfg.Totp, user.TotpSecret)
if !ok { if !ok {
return fmt.Errorf("TOTP code incorrect") return fmt.Errorf("TOTP code incorrect")
+5 -4
View File
@@ -3,8 +3,9 @@ package main
import ( import (
"fmt" "fmt"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/paerser/cli" "github.com/tinyauthapp/paerser/cli"
"github.com/tinyauthapp/tinyauth/internal/model"
) )
func versionCmd() *cli.Command { func versionCmd() *cli.Command {
@@ -14,9 +15,9 @@ func versionCmd() *cli.Command {
Configuration: nil, Configuration: nil,
Resources: nil, Resources: nil,
Run: func(_ []string) error { Run: func(_ []string) error {
fmt.Printf("Version: %s\n", model.Version) fmt.Printf("Version: %s\n", config.Version)
fmt.Printf("Commit Hash: %s\n", model.CommitHash) fmt.Printf("Commit Hash: %s\n", config.CommitHash)
fmt.Printf("Build Timestamp: %s\n", model.BuildTimestamp) fmt.Printf("Build Timestamp: %s\n", config.BuildTimestamp)
return nil return nil
}, },
} }
+16 -16
View File
@@ -12,15 +12,15 @@ import (
"strings" "strings"
"time" "time"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
) )
type BootstrapApp struct { type BootstrapApp struct {
config model.Config config config.Config
context struct { context struct {
appUrl string appUrl string
uuid string uuid string
@@ -29,15 +29,15 @@ type BootstrapApp struct {
csrfCookieName string csrfCookieName string
redirectCookieName string redirectCookieName string
oauthSessionCookieName string oauthSessionCookieName string
localUsers []model.LocalUser users []config.User
oauthProviders map[string]model.OAuthServiceConfig oauthProviders map[string]config.OAuthServiceConfig
configuredProviders []controller.Provider configuredProviders []controller.Provider
oidcClients []model.OIDCClientConfig oidcClients []config.OIDCClientConfig
} }
services Services services Services
} }
func NewBootstrapApp(config model.Config) *BootstrapApp { func NewBootstrapApp(config config.Config) *BootstrapApp {
return &BootstrapApp{ return &BootstrapApp{
config: config, config: config,
} }
@@ -69,7 +69,7 @@ func (app *BootstrapApp) Setup() error {
return err return err
} }
app.context.localUsers = *users app.context.users = users
// Setup OAuth providers // Setup OAuth providers
app.context.oauthProviders = app.config.OAuth.Providers app.context.oauthProviders = app.config.OAuth.Providers
@@ -88,7 +88,7 @@ func (app *BootstrapApp) Setup() error {
for id, provider := range app.context.oauthProviders { for id, provider := range app.context.oauthProviders {
if provider.Name == "" { if provider.Name == "" {
if name, ok := model.OverrideProviders[id]; ok { if name, ok := config.OverrideProviders[id]; ok {
provider.Name = name provider.Name = name
} else { } else {
provider.Name = utils.Capitalize(id) provider.Name = utils.Capitalize(id)
@@ -115,14 +115,14 @@ func (app *BootstrapApp) Setup() error {
// Cookie names // Cookie names
app.context.uuid = utils.GenerateUUID(appUrl.Hostname()) app.context.uuid = utils.GenerateUUID(appUrl.Hostname())
cookieId := strings.Split(app.context.uuid, "-")[0] cookieId := strings.Split(app.context.uuid, "-")[0]
app.context.sessionCookieName = fmt.Sprintf("%s-%s", model.SessionCookieName, cookieId) app.context.sessionCookieName = fmt.Sprintf("%s-%s", config.SessionCookieName, cookieId)
app.context.csrfCookieName = fmt.Sprintf("%s-%s", model.CSRFCookieName, cookieId) app.context.csrfCookieName = fmt.Sprintf("%s-%s", config.CSRFCookieName, cookieId)
app.context.redirectCookieName = fmt.Sprintf("%s-%s", model.RedirectCookieName, cookieId) app.context.redirectCookieName = fmt.Sprintf("%s-%s", config.RedirectCookieName, cookieId)
app.context.oauthSessionCookieName = fmt.Sprintf("%s-%s", model.OAuthSessionCookieName, cookieId) app.context.oauthSessionCookieName = fmt.Sprintf("%s-%s", config.OAuthSessionCookieName, cookieId)
// Dumps // Dumps
tlog.App.Trace().Interface("config", app.config).Msg("Config dump") tlog.App.Trace().Interface("config", app.config).Msg("Config dump")
tlog.App.Trace().Interface("users", app.context.localUsers).Msg("Users dump") tlog.App.Trace().Interface("users", app.context.users).Msg("Users dump")
tlog.App.Trace().Interface("oauthProviders", app.context.oauthProviders).Msg("OAuth providers dump") tlog.App.Trace().Interface("oauthProviders", app.context.oauthProviders).Msg("OAuth providers dump")
tlog.App.Trace().Str("cookieDomain", app.context.cookieDomain).Msg("Cookie domain") tlog.App.Trace().Str("cookieDomain", app.context.cookieDomain).Msg("Cookie domain")
tlog.App.Trace().Str("sessionCookieName", app.context.sessionCookieName).Msg("Session cookie name") tlog.App.Trace().Str("sessionCookieName", app.context.sessionCookieName).Msg("Session cookie name")
@@ -171,7 +171,7 @@ func (app *BootstrapApp) Setup() error {
}) })
} }
if services.authService.LDAPAuthConfigured() { if services.authService.LdapAuthConfigured() {
configuredProviders = append(configuredProviders, controller.Provider{ configuredProviders = append(configuredProviders, controller.Provider{
Name: "LDAP", Name: "LDAP",
ID: "ldap", ID: "ldap",
@@ -244,7 +244,7 @@ func (app *BootstrapApp) heartbeatRoutine() {
var body heartbeat var body heartbeat
body.UUID = app.context.uuid body.UUID = app.context.uuid
body.Version = model.Version body.Version = config.Version
bodyJson, err := json.Marshal(body) bodyJson, err := json.Marshal(body)
@@ -257,7 +257,7 @@ func (app *BootstrapApp) heartbeatRoutine() {
Timeout: 30 * time.Second, // The server should never take more than 30 seconds to respond Timeout: 30 * time.Second, // The server should never take more than 30 seconds to respond
} }
heartbeatURL := model.APIServer + "/v1/instances/heartbeat" heartbeatURL := config.ApiServer + "/v1/instances/heartbeat"
for range ticker.C { for range ticker.C {
tlog.App.Debug().Msg("Sending heartbeat") tlog.App.Debug().Msg("Sending heartbeat")
+2 -4
View File
@@ -4,9 +4,9 @@ import (
"fmt" "fmt"
"slices" "slices"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/middleware" "github.com/tinyauthapp/tinyauth/internal/middleware"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
) )
@@ -14,7 +14,7 @@ import (
var DEV_MODES = []string{"main", "test", "development"} var DEV_MODES = []string{"main", "test", "development"}
func (app *BootstrapApp) setupRouter() (*gin.Engine, error) { func (app *BootstrapApp) setupRouter() (*gin.Engine, error) {
if !slices.Contains(DEV_MODES, model.Version) { if !slices.Contains(DEV_MODES, config.Version) {
gin.SetMode(gin.ReleaseMode) gin.SetMode(gin.ReleaseMode)
} }
@@ -31,7 +31,6 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) {
contextMiddleware := middleware.NewContextMiddleware(middleware.ContextMiddlewareConfig{ contextMiddleware := middleware.NewContextMiddleware(middleware.ContextMiddlewareConfig{
CookieDomain: app.context.cookieDomain, CookieDomain: app.context.cookieDomain,
SessionCookieName: app.context.sessionCookieName,
}, app.services.authService, app.services.oauthBrokerService) }, app.services.authService, app.services.oauthBrokerService)
err := contextMiddleware.Init() err := contextMiddleware.Init()
@@ -100,7 +99,6 @@ func (app *BootstrapApp) setupRouter() (*gin.Engine, error) {
userController := controller.NewUserController(controller.UserControllerConfig{ userController := controller.NewUserController(controller.UserControllerConfig{
CookieDomain: app.context.cookieDomain, CookieDomain: app.context.cookieDomain,
SessionCookieName: app.context.sessionCookieName,
}, apiRouter, app.services.authService) }, apiRouter, app.services.authService)
userController.SetupRoutes() userController.SetupRoutes()
+10 -10
View File
@@ -22,14 +22,14 @@ func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, er
services := Services{} services := Services{}
ldapService := service.NewLdapService(service.LdapServiceConfig{ ldapService := service.NewLdapService(service.LdapServiceConfig{
Address: app.config.LDAP.Address, Address: app.config.Ldap.Address,
BindDN: app.config.LDAP.BindDN, BindDN: app.config.Ldap.BindDN,
BindPassword: app.config.LDAP.BindPassword, BindPassword: app.config.Ldap.BindPassword,
BaseDN: app.config.LDAP.BaseDN, BaseDN: app.config.Ldap.BaseDN,
Insecure: app.config.LDAP.Insecure, Insecure: app.config.Ldap.Insecure,
SearchFilter: app.config.LDAP.SearchFilter, SearchFilter: app.config.Ldap.SearchFilter,
AuthCert: app.config.LDAP.AuthCert, AuthCert: app.config.Ldap.AuthCert,
AuthKey: app.config.LDAP.AuthKey, AuthKey: app.config.Ldap.AuthKey,
}) })
err := ldapService.Init() err := ldapService.Init()
@@ -89,7 +89,7 @@ func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, er
services.oauthBrokerService = oauthBrokerService services.oauthBrokerService = oauthBrokerService
authService := service.NewAuthService(service.AuthServiceConfig{ authService := service.NewAuthService(service.AuthServiceConfig{
LocalUsers: app.context.localUsers, Users: app.context.users,
OauthWhitelist: app.config.OAuth.Whitelist, OauthWhitelist: app.config.OAuth.Whitelist,
SessionExpiry: app.config.Auth.SessionExpiry, SessionExpiry: app.config.Auth.SessionExpiry,
SessionMaxLifetime: app.config.Auth.SessionMaxLifetime, SessionMaxLifetime: app.config.Auth.SessionMaxLifetime,
@@ -99,7 +99,7 @@ func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, er
LoginMaxRetries: app.config.Auth.LoginMaxRetries, LoginMaxRetries: app.config.Auth.LoginMaxRetries,
SessionCookieName: app.context.sessionCookieName, SessionCookieName: app.context.sessionCookieName,
IP: app.config.Auth.IP, IP: app.config.Auth.IP,
LDAPGroupsCacheTTL: app.config.LDAP.GroupCacheTTL, LDAPGroupsCacheTTL: app.config.Ldap.GroupCacheTTL,
}, services.ldapService, queries, services.oauthBrokerService) }, services.ldapService, queries, services.oauthBrokerService)
err = authService.Init() err = authService.Init()
@@ -1,4 +1,4 @@
package model package config
// Default configuration // Default configuration
func NewDefaultConfiguration() *Config { func NewDefaultConfiguration() *Config {
@@ -29,7 +29,7 @@ func NewDefaultConfiguration() *Config {
BackgroundImage: "/background.jpg", BackgroundImage: "/background.jpg",
WarningsEnabled: true, WarningsEnabled: true,
}, },
LDAP: LDAPConfig{ Ldap: LdapConfig{
Insecure: false, Insecure: false,
SearchFilter: "(uid=%s)", SearchFilter: "(uid=%s)",
GroupCacheTTL: 900, // 15 minutes GroupCacheTTL: 900, // 15 minutes
@@ -63,6 +63,20 @@ func NewDefaultConfiguration() *Config {
} }
} }
// Version information, set at build time
var Version = "development"
var CommitHash = "development"
var BuildTimestamp = "0000-00-00T00:00:00Z"
// Cookie name templates
var SessionCookieName = "tinyauth-session"
var CSRFCookieName = "tinyauth-csrf"
var RedirectCookieName = "tinyauth-redirect"
var OAuthSessionCookieName = "tinyauth-oauth"
// Main app config
type Config struct { type Config struct {
AppURL string `description:"The base URL where the app is hosted." yaml:"appUrl"` AppURL string `description:"The base URL where the app is hosted." yaml:"appUrl"`
Database DatabaseConfig `description:"Database configuration." yaml:"database"` Database DatabaseConfig `description:"Database configuration." yaml:"database"`
@@ -74,7 +88,7 @@ type Config struct {
OAuth OAuthConfig `description:"OAuth configuration." yaml:"oauth"` OAuth OAuthConfig `description:"OAuth configuration." yaml:"oauth"`
OIDC OIDCConfig `description:"OIDC configuration." yaml:"oidc"` OIDC OIDCConfig `description:"OIDC configuration." yaml:"oidc"`
UI UIConfig `description:"UI customization." yaml:"ui"` UI UIConfig `description:"UI customization." yaml:"ui"`
LDAP LDAPConfig `description:"LDAP configuration." yaml:"ldap"` Ldap LdapConfig `description:"LDAP configuration." yaml:"ldap"`
Experimental ExperimentalConfig `description:"Experimental features, use with caution." yaml:"experimental"` Experimental ExperimentalConfig `description:"Experimental features, use with caution." yaml:"experimental"`
LabelProvider string `description:"Label provider to use for ACLs (auto, docker, or kubernetes). auto detects the environment." yaml:"labelProvider"` LabelProvider string `description:"Label provider to use for ACLs (auto, docker, or kubernetes). auto detects the environment." yaml:"labelProvider"`
Log LogConfig `description:"Logging configuration." yaml:"log"` Log LogConfig `description:"Logging configuration." yaml:"log"`
@@ -163,7 +177,7 @@ type UIConfig struct {
WarningsEnabled bool `description:"Enable UI warnings." yaml:"warningsEnabled"` WarningsEnabled bool `description:"Enable UI warnings." yaml:"warningsEnabled"`
} }
type LDAPConfig struct { type LdapConfig struct {
Address string `description:"LDAP server address." yaml:"address"` Address string `description:"LDAP server address." yaml:"address"`
BindDN string `description:"Bind DN for LDAP authentication." yaml:"bindDn"` BindDN string `description:"Bind DN for LDAP authentication." yaml:"bindDn"`
BindPassword string `description:"Bind password for LDAP authentication." yaml:"bindPassword"` BindPassword string `description:"Bind password for LDAP authentication." yaml:"bindPassword"`
@@ -196,6 +210,20 @@ type ExperimentalConfig struct {
ConfigFile string `description:"Path to config file." yaml:"-"` ConfigFile string `description:"Path to config file." yaml:"-"`
} }
// Config loader options
const DefaultNamePrefix = "TINYAUTH_"
// OAuth/OIDC config
type Claims struct {
Sub string `json:"sub"`
Name string `json:"name"`
Email string `json:"email"`
PreferredUsername string `json:"preferred_username"`
Groups any `json:"groups"`
}
type OAuthServiceConfig struct { type OAuthServiceConfig struct {
ClientID string `description:"OAuth client ID." yaml:"clientId"` ClientID string `description:"OAuth client ID." yaml:"clientId"`
ClientSecret string `description:"OAuth client secret." yaml:"clientSecret"` ClientSecret string `description:"OAuth client secret." yaml:"clientSecret"`
@@ -218,6 +246,60 @@ type OIDCClientConfig struct {
Name string `description:"Client name in UI." yaml:"name"` Name string `description:"Client name in UI." yaml:"name"`
} }
var OverrideProviders = map[string]string{
"google": "Google",
"github": "GitHub",
}
// User/session related stuff
type User struct {
Username string
Password string
TotpSecret string
Attributes UserAttributes
}
type LdapUser struct {
DN string
Groups []string
}
type UserSearch struct {
Username string
Type string // local, ldap or unknown
}
type UserContext struct {
Username string
Name string
Email string
IsLoggedIn bool
IsBasicAuth bool
OAuth bool
Provider string
TotpPending bool
OAuthGroups string
TotpEnabled bool
OAuthName string
OAuthSub string
LdapGroups string
Attributes UserAttributes
}
// API responses and queries
type UnauthorizedQuery struct {
Username string `url:"username"`
Resource string `url:"resource"`
GroupErr bool `url:"groupErr"`
IP string `url:"ip"`
}
type RedirectQuery struct {
RedirectURI string `url:"redirect_uri"`
}
// ACLs // ACLs
type Apps struct { type Apps struct {
@@ -273,3 +355,7 @@ type AppPath struct {
Allow string `description:"Comma-separated list of allowed paths." yaml:"allow"` Allow string `description:"Comma-separated list of allowed paths." yaml:"allow"`
Block string `description:"Comma-separated list of blocked paths." yaml:"block"` Block string `description:"Comma-separated list of blocked paths." yaml:"block"`
} }
// API server
var ApiServer = "https://api.tinyauth.app"
+20 -21
View File
@@ -4,7 +4,7 @@ import (
"fmt" "fmt"
"net/url" "net/url"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
@@ -19,7 +19,7 @@ type UserContextResponse struct {
Email string `json:"email"` Email string `json:"email"`
Provider string `json:"provider"` Provider string `json:"provider"`
OAuth bool `json:"oauth"` OAuth bool `json:"oauth"`
TOTPPending bool `json:"totpPending"` TotpPending bool `json:"totpPending"`
OAuthName string `json:"oauthName"` OAuthName string `json:"oauthName"`
} }
@@ -76,29 +76,28 @@ func (controller *ContextController) SetupRoutes() {
} }
func (controller *ContextController) userContextHandler(c *gin.Context) { func (controller *ContextController) userContextHandler(c *gin.Context) {
context, err := new(model.UserContext).NewFromGin(c) context, err := utils.GetContext(c)
if err != nil {
tlog.App.Debug().Err(err).Msg("No user context found in request")
c.JSON(200, UserContextResponse{
Status: 401,
Message: "Unauthorized",
IsLoggedIn: false,
})
return
}
userContext := UserContextResponse{ userContext := UserContextResponse{
Status: 200, Status: 200,
Message: "Success", Message: "Success",
IsLoggedIn: context.Authenticated, IsLoggedIn: context.IsLoggedIn,
Username: context.GetUsername(), Username: context.Username,
Name: context.GetName(), Name: context.Name,
Email: context.GetEmail(), Email: context.Email,
Provider: context.ProviderName(), Provider: context.Provider,
OAuth: context.IsOAuth(), OAuth: context.OAuth,
TOTPPending: context.TOTPPending(), TotpPending: context.TotpPending,
OAuthName: context.OAuthName(), OAuthName: context.OAuthName,
}
if err != nil {
tlog.App.Debug().Err(err).Msg("No user context found in request")
userContext.Status = 401
userContext.Message = "Unauthorized"
userContext.IsLoggedIn = false
c.JSON(200, userContext)
return
} }
c.JSON(200, userContext) c.JSON(200, userContext)
-12
View File
@@ -1,12 +0,0 @@
package controller
type UnauthorizedQuery struct {
Username string `url:"username"`
Resource string `url:"resource"`
GroupErr bool `url:"groupErr"`
IP string `url:"ip"`
}
type RedirectQuery struct {
RedirectURI string `url:"redirect_uri"`
}
+4 -5
View File
@@ -6,6 +6,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
@@ -175,7 +176,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
tlog.App.Warn().Str("email", user.Email).Msg("Email not whitelisted") tlog.App.Warn().Str("email", user.Email).Msg("Email not whitelisted")
tlog.AuditLoginFailure(c, user.Email, req.Provider, "email not whitelisted") tlog.AuditLoginFailure(c, user.Email, req.Provider, "email not whitelisted")
queries, err := query.Values(UnauthorizedQuery{ queries, err := query.Values(config.UnauthorizedQuery{
Username: user.Email, Username: user.Email,
}) })
@@ -235,7 +236,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie") tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie")
cookie, err := controller.auth.CreateSession(c, sessionCookie) err = controller.auth.CreateSessionCookie(c, &sessionCookie)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to create session cookie") tlog.App.Error().Err(err).Msg("Failed to create session cookie")
@@ -243,8 +244,6 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
return return
} }
http.SetCookie(c.Writer, cookie)
tlog.AuditLoginSuccess(c, sessionCookie.Username, sessionCookie.Provider) tlog.AuditLoginSuccess(c, sessionCookie.Username, sessionCookie.Provider)
if controller.isOidcRequest(oauthPendingSession.CallbackParams) { if controller.isOidcRequest(oauthPendingSession.CallbackParams) {
@@ -260,7 +259,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
} }
if oauthPendingSession.CallbackParams.RedirectURI != "" { if oauthPendingSession.CallbackParams.RedirectURI != "" {
queries, err := query.Values(RedirectQuery{ queries, err := query.Values(config.RedirectQuery{
RedirectURI: oauthPendingSession.CallbackParams.RedirectURI, RedirectURI: oauthPendingSession.CallbackParams.RedirectURI,
}) })
+4 -5
View File
@@ -10,7 +10,6 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/google/go-querystring/query" "github.com/google/go-querystring/query"
"github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
@@ -112,14 +111,14 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
return return
} }
userContext, err := new(model.UserContext).NewFromGin(c) userContext, err := utils.GetContext(c)
if err != nil { if err != nil {
controller.authorizeError(c, err, "Failed to get user context", "User is not logged in or the session is invalid", "", "", "") controller.authorizeError(c, err, "Failed to get user context", "User is not logged in or the session is invalid", "", "", "")
return return
} }
if !userContext.Authenticated { if !userContext.IsLoggedIn {
controller.authorizeError(c, errors.New("err user not logged in"), "User not logged in", "The user is not logged in", "", "", "") controller.authorizeError(c, errors.New("err user not logged in"), "User not logged in", "The user is not logged in", "", "", "")
return return
} }
@@ -152,7 +151,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
} }
// WARNING: Since Tinyauth is stateless, we cannot have a sub that never changes. We will just create a uuid out of the username and client name which remains stable, but if username or client name changes then sub changes too. // WARNING: Since Tinyauth is stateless, we cannot have a sub that never changes. We will just create a uuid out of the username and client name which remains stable, but if username or client name changes then sub changes too.
sub := utils.GenerateUUID(fmt.Sprintf("%s:%s", userContext.GetUsername(), client.ID)) sub := utils.GenerateUUID(fmt.Sprintf("%s:%s", userContext.Username, client.ID))
code := utils.GenerateString(32) code := utils.GenerateString(32)
// Before storing the code, delete old session // Before storing the code, delete old session
@@ -171,7 +170,7 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
// We also need a snapshot of the user that authorized this (skip if no openid scope) // We also need a snapshot of the user that authorized this (skip if no openid scope)
if slices.Contains(strings.Fields(req.Scope), "openid") { if slices.Contains(strings.Fields(req.Scope), "openid") {
err = controller.oidc.StoreUserinfo(c, sub, *userContext, req) err = controller.oidc.StoreUserinfo(c, sub, userContext, req)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to insert user info into database") tlog.App.Error().Err(err).Msg("Failed to insert user info into database")
+42 -43
View File
@@ -8,7 +8,7 @@ import (
"regexp" "regexp"
"strings" "strings"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
@@ -99,16 +99,12 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return return
} }
if acls == nil {
acls = &model.App{}
}
tlog.App.Trace().Interface("acls", acls).Msg("ACLs for resource") tlog.App.Trace().Interface("acls", acls).Msg("ACLs for resource")
clientIP := c.ClientIP() clientIP := c.ClientIP()
if controller.auth.IsBypassedIP(&acls.IP, clientIP) { if controller.auth.IsBypassedIP(acls.IP, clientIP) {
controller.setHeaders(c, *acls) controller.setHeaders(c, acls)
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
"message": "Authenticated", "message": "Authenticated",
@@ -116,7 +112,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return return
} }
authEnabled, err := controller.auth.IsAuthEnabled(proxyCtx.Path, &acls.Path) authEnabled, err := controller.auth.IsAuthEnabled(proxyCtx.Path, acls.Path)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to check if auth is enabled for resource") tlog.App.Error().Err(err).Msg("Failed to check if auth is enabled for resource")
@@ -126,7 +122,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
if !authEnabled { if !authEnabled {
tlog.App.Debug().Msg("Authentication disabled for resource, allowing access") tlog.App.Debug().Msg("Authentication disabled for resource, allowing access")
controller.setHeaders(c, *acls) controller.setHeaders(c, acls)
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
"message": "Authenticated", "message": "Authenticated",
@@ -134,8 +130,8 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return return
} }
if !controller.auth.CheckIP(&acls.IP, clientIP) { if !controller.auth.CheckIP(acls.IP, clientIP) {
queries, err := query.Values(UnauthorizedQuery{ queries, err := query.Values(config.UnauthorizedQuery{
Resource: strings.Split(proxyCtx.Host, ".")[0], Resource: strings.Split(proxyCtx.Host, ".")[0],
IP: clientIP, IP: clientIP,
}) })
@@ -161,24 +157,28 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return return
} }
userContext, err := new(model.UserContext).NewFromGin(c) var userContext config.UserContext
context, err := utils.GetContext(c)
if err != nil { if err != nil {
tlog.App.Debug().Err(err).Msg("No user context found in request, treating as unauthenticated") tlog.App.Debug().Msg("No user context found in request, treating as not logged in")
userContext = &model.UserContext{ userContext = config.UserContext{
Authenticated: false, IsLoggedIn: false,
} }
} else {
userContext = context
} }
tlog.App.Trace().Interface("context", userContext).Msg("User context from request") tlog.App.Trace().Interface("context", userContext).Msg("User context from request")
if userContext.Authenticated { if userContext.IsLoggedIn {
userAllowed := controller.auth.IsUserAllowed(c, *userContext, acls) userAllowed := controller.auth.IsUserAllowed(c, userContext, acls)
if !userAllowed { if !userAllowed {
tlog.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User not allowed to access resource") tlog.App.Warn().Str("user", userContext.Username).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User not allowed to access resource")
queries, err := query.Values(UnauthorizedQuery{ queries, err := query.Values(config.UnauthorizedQuery{
Resource: strings.Split(proxyCtx.Host, ".")[0], Resource: strings.Split(proxyCtx.Host, ".")[0],
}) })
@@ -188,10 +188,10 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return return
} }
if userContext.IsOAuth() { if userContext.OAuth {
queries.Set("username", userContext.GetEmail()) queries.Set("username", userContext.Email)
} else { } else {
queries.Set("username", userContext.GetUsername()) queries.Set("username", userContext.Username)
} }
redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode()) redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode())
@@ -209,19 +209,19 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return return
} }
if userContext.IsOAuth() || userContext.IsLDAP() { if userContext.OAuth || userContext.Provider == "ldap" {
var groupOK bool var groupOK bool
if userContext.IsOAuth() { if userContext.OAuth {
groupOK = controller.auth.IsInOAuthGroup(c, *userContext, acls.OAuth.Groups) groupOK = controller.auth.IsInOAuthGroup(c, userContext, acls.OAuth.Groups)
} else { } else {
groupOK = controller.auth.IsInLDAPGroup(c, *userContext, acls.LDAP.Groups) groupOK = controller.auth.IsInLdapGroup(c, userContext, acls.LDAP.Groups)
} }
if !groupOK { if !groupOK {
tlog.App.Warn().Str("user", userContext.GetUsername()).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User groups do not match resource requirements") tlog.App.Warn().Str("user", userContext.Username).Str("resource", strings.Split(proxyCtx.Host, ".")[0]).Msg("User groups do not match resource requirements")
queries, err := query.Values(UnauthorizedQuery{ queries, err := query.Values(config.UnauthorizedQuery{
Resource: strings.Split(proxyCtx.Host, ".")[0], Resource: strings.Split(proxyCtx.Host, ".")[0],
GroupErr: true, GroupErr: true,
}) })
@@ -232,10 +232,10 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return return
} }
if userContext.IsOAuth() { if userContext.OAuth {
queries.Set("username", userContext.GetEmail()) queries.Set("username", userContext.Email)
} else { } else {
queries.Set("username", userContext.GetUsername()) queries.Set("username", userContext.Username)
} }
redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode()) redirectURL := fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode())
@@ -254,20 +254,19 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
} }
} }
c.Header("Remote-User", utils.SanitizeHeader(userContext.GetUsername())) c.Header("Remote-User", utils.SanitizeHeader(userContext.Username))
c.Header("Remote-Name", utils.SanitizeHeader(userContext.GetName())) c.Header("Remote-Name", utils.SanitizeHeader(userContext.Name))
c.Header("Remote-Email", utils.SanitizeHeader(userContext.GetEmail())) c.Header("Remote-Email", utils.SanitizeHeader(userContext.Email))
if userContext.IsLDAP() { if userContext.Provider == "ldap" {
c.Header("Remote-Groups", utils.SanitizeHeader(strings.Join(userContext.LDAP.Groups, ","))) c.Header("Remote-Groups", utils.SanitizeHeader(userContext.LdapGroups))
} else if userContext.Provider != "local" {
c.Header("Remote-Groups", utils.SanitizeHeader(userContext.OAuthGroups))
} }
if userContext.IsOAuth() { c.Header("Remote-Sub", utils.SanitizeHeader(userContext.OAuthSub))
c.Header("Remote-Groups", utils.SanitizeHeader(strings.Join(userContext.OAuth.Groups, ",")))
c.Header("Remote-Sub", utils.SanitizeHeader(userContext.OAuth.Sub))
}
controller.setHeaders(c, *acls) controller.setHeaders(c, acls)
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
@@ -276,7 +275,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
return return
} }
queries, err := query.Values(RedirectQuery{ queries, err := query.Values(config.RedirectQuery{
RedirectURI: fmt.Sprintf("%s://%s%s", proxyCtx.Proto, proxyCtx.Host, proxyCtx.Path), RedirectURI: fmt.Sprintf("%s://%s%s", proxyCtx.Proto, proxyCtx.Host, proxyCtx.Path),
}) })
@@ -300,7 +299,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
c.Redirect(http.StatusTemporaryRedirect, redirectURL) c.Redirect(http.StatusTemporaryRedirect, redirectURL)
} }
func (controller *ProxyController) setHeaders(c *gin.Context, acls model.App) { func (controller *ProxyController) setHeaders(c *gin.Context, acls config.App) {
c.Header("Authorization", c.Request.Header.Get("Authorization")) c.Header("Authorization", c.Request.Header.Get("Authorization"))
headers := utils.ParseHeaders(acls.Response.Headers) headers := utils.ParseHeaders(acls.Response.Headers)
+39 -90
View File
@@ -1,12 +1,10 @@
package controller package controller
import ( import (
"errors"
"fmt" "fmt"
"net/http"
"time" "time"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
@@ -27,7 +25,6 @@ type TotpRequest struct {
type UserControllerConfig struct { type UserControllerConfig struct {
CookieDomain string CookieDomain string
SessionCookieName string
} }
type UserController struct { type UserController struct {
@@ -80,10 +77,9 @@ func (controller *UserController) loginHandler(c *gin.Context) {
return return
} }
search, err := controller.auth.SearchUser(req.Username) userSearch := controller.auth.SearchUser(req.Username)
if err != nil { if userSearch.Type == "unknown" {
if errors.Is(err, service.ErrUserNotFound) {
tlog.App.Warn().Str("username", req.Username).Msg("User not found") tlog.App.Warn().Str("username", req.Username).Msg("User not found")
controller.auth.RecordLoginAttempt(req.Username, false) controller.auth.RecordLoginAttempt(req.Username, false)
tlog.AuditLoginFailure(c, req.Username, "username", "user not found") tlog.AuditLoginFailure(c, req.Username, "username", "user not found")
@@ -93,15 +89,8 @@ func (controller *UserController) loginHandler(c *gin.Context) {
}) })
return return
} }
tlog.App.Error().Err(err).Str("username", req.Username).Msg("Error searching for user")
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
})
return
}
if err := controller.auth.CheckUserPassword(*search, req.Password); err != nil { if !controller.auth.VerifyUser(userSearch, req.Password) {
tlog.App.Warn().Str("username", req.Username).Msg("Invalid password") tlog.App.Warn().Str("username", req.Username).Msg("Invalid password")
controller.auth.RecordLoginAttempt(req.Username, false) controller.auth.RecordLoginAttempt(req.Username, false)
tlog.AuditLoginFailure(c, req.Username, "username", "invalid password") tlog.AuditLoginFailure(c, req.Username, "username", "invalid password")
@@ -117,26 +106,30 @@ func (controller *UserController) loginHandler(c *gin.Context) {
controller.auth.RecordLoginAttempt(req.Username, true) controller.auth.RecordLoginAttempt(req.Username, true)
var localUser *model.LocalUser var localUser *config.User
if userSearch.Type == "local" {
user := controller.auth.GetLocalUser(userSearch.Username)
localUser = &user
}
if search.Type == model.UserLocal { if userSearch.Type == "local" && localUser != nil {
localUser = controller.auth.GetLocalUser(req.Username) user := *localUser
if localUser.TOTPSecret != "" { if user.TotpSecret != "" {
tlog.App.Debug().Str("username", req.Username).Msg("User has TOTP enabled, requiring TOTP verification") tlog.App.Debug().Str("username", req.Username).Msg("User has TOTP enabled, requiring TOTP verification")
name := localUser.Attributes.Name name := user.Attributes.Name
if name == "" { if name == "" {
name = utils.Capitalize(localUser.Username) name = utils.Capitalize(user.Username)
} }
email := localUser.Attributes.Email email := user.Attributes.Email
if email == "" { if email == "" {
email = utils.CompileUserEmail(localUser.Username, controller.config.CookieDomain) email = utils.CompileUserEmail(user.Username, controller.config.CookieDomain)
} }
cookie, err := controller.auth.CreateSession(c, repository.Session{ err := controller.auth.CreateSessionCookie(c, &repository.Session{
Username: localUser.Username, Username: user.Username,
Name: name, Name: name,
Email: email, Email: email,
Provider: "local", Provider: "local",
@@ -152,8 +145,6 @@ func (controller *UserController) loginHandler(c *gin.Context) {
return return
} }
http.SetCookie(c.Writer, cookie)
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
"message": "TOTP required", "message": "TOTP required",
@@ -170,7 +161,7 @@ func (controller *UserController) loginHandler(c *gin.Context) {
Provider: "local", Provider: "local",
} }
if search.Type == model.UserLocal { if userSearch.Type == "local" && localUser != nil {
if localUser.Attributes.Name != "" { if localUser.Attributes.Name != "" {
sessionCookie.Name = localUser.Attributes.Name sessionCookie.Name = localUser.Attributes.Name
} }
@@ -179,13 +170,13 @@ func (controller *UserController) loginHandler(c *gin.Context) {
} }
} }
if search.Type == model.UserLDAP { if userSearch.Type == "ldap" {
sessionCookie.Provider = "ldap" sessionCookie.Provider = "ldap"
} }
tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie") tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie")
cookie, err := controller.auth.CreateSession(c, sessionCookie) err = controller.auth.CreateSessionCookie(c, &sessionCookie)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to create session cookie") tlog.App.Error().Err(err).Msg("Failed to create session cookie")
@@ -196,8 +187,6 @@ func (controller *UserController) loginHandler(c *gin.Context) {
return return
} }
http.SetCookie(c.Writer, cookie)
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
"message": "Login successful", "message": "Login successful",
@@ -207,50 +196,12 @@ func (controller *UserController) loginHandler(c *gin.Context) {
func (controller *UserController) logoutHandler(c *gin.Context) { func (controller *UserController) logoutHandler(c *gin.Context) {
tlog.App.Debug().Msg("Logout request received") tlog.App.Debug().Msg("Logout request received")
uuid, err := c.Cookie(controller.config.SessionCookieName) controller.auth.DeleteSessionCookie(c)
if err != nil { context, err := utils.GetContext(c)
if errors.Is(err, http.ErrNoCookie) { if err == nil && context.IsLoggedIn {
tlog.App.Warn().Msg("No session cookie found on logout request") tlog.AuditLogout(c, context.Username, context.Provider)
c.JSON(200, gin.H{
"status": 200,
"message": "Logout successful",
})
return
} }
tlog.App.Error().Err(err).Msg("Error retrieving session cookie on logout")
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
})
return
}
context, err := new(model.UserContext).NewFromGin(c)
if err != nil {
tlog.App.Error().Err(err).Msg("Failed to get user context on logout")
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
})
return
}
cookie, err := controller.auth.DeleteSession(c, uuid)
if err != nil {
tlog.App.Error().Err(err).Msg("Error deleting session on logout")
c.JSON(500, gin.H{
"status": 500,
"message": "Internal Server Error",
})
return
}
tlog.AuditLogout(c, context.GetUsername(), context.ProviderName())
http.SetCookie(c.Writer, cookie)
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
@@ -271,7 +222,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
return return
} }
context, err := new(model.UserContext).NewFromGin(c) context, err := utils.GetContext(c)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to get user context") tlog.App.Error().Err(err).Msg("Failed to get user context")
@@ -282,7 +233,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
return return
} }
if !context.TOTPPending() { if !context.TotpPending {
tlog.App.Warn().Msg("TOTP attempt without a pending TOTP session") tlog.App.Warn().Msg("TOTP attempt without a pending TOTP session")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"status": 401, "status": 401,
@@ -291,12 +242,12 @@ func (controller *UserController) totpHandler(c *gin.Context) {
return return
} }
tlog.App.Debug().Str("username", context.GetUsername()).Msg("TOTP verification attempt") tlog.App.Debug().Str("username", context.Username).Msg("TOTP verification attempt")
isLocked, remaining := controller.auth.IsAccountLocked(context.GetUsername()) isLocked, remaining := controller.auth.IsAccountLocked(context.Username)
if isLocked { if isLocked {
tlog.App.Warn().Str("username", context.GetUsername()).Msg("Account is locked due to too many failed TOTP attempts") tlog.App.Warn().Str("username", context.Username).Msg("Account is locked due to too many failed TOTP attempts")
c.Writer.Header().Add("x-tinyauth-lock-locked", "true") c.Writer.Header().Add("x-tinyauth-lock-locked", "true")
c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339)) c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339))
c.JSON(429, gin.H{ c.JSON(429, gin.H{
@@ -306,14 +257,14 @@ func (controller *UserController) totpHandler(c *gin.Context) {
return return
} }
user := controller.auth.GetLocalUser(context.GetUsername()) user := controller.auth.GetLocalUser(context.Username)
ok := totp.Validate(req.Code, user.TOTPSecret) ok := totp.Validate(req.Code, user.TotpSecret)
if !ok { if !ok {
tlog.App.Warn().Str("username", context.GetUsername()).Msg("Invalid TOTP code") tlog.App.Warn().Str("username", context.Username).Msg("Invalid TOTP code")
controller.auth.RecordLoginAttempt(context.GetUsername(), false) controller.auth.RecordLoginAttempt(context.Username, false)
tlog.AuditLoginFailure(c, context.GetUsername(), "totp", "invalid totp code") tlog.AuditLoginFailure(c, context.Username, "totp", "invalid totp code")
c.JSON(401, gin.H{ c.JSON(401, gin.H{
"status": 401, "status": 401,
"message": "Unauthorized", "message": "Unauthorized",
@@ -321,10 +272,10 @@ func (controller *UserController) totpHandler(c *gin.Context) {
return return
} }
tlog.App.Info().Str("username", context.GetUsername()).Msg("TOTP verification successful") tlog.App.Info().Str("username", context.Username).Msg("TOTP verification successful")
tlog.AuditLoginSuccess(c, context.GetUsername(), "totp") tlog.AuditLoginSuccess(c, context.Username, "totp")
controller.auth.RecordLoginAttempt(context.GetUsername(), true) controller.auth.RecordLoginAttempt(context.Username, true)
sessionCookie := repository.Session{ sessionCookie := repository.Session{
Username: user.Username, Username: user.Username,
@@ -342,7 +293,7 @@ func (controller *UserController) totpHandler(c *gin.Context) {
tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie") tlog.App.Trace().Interface("session_cookie", sessionCookie).Msg("Creating session cookie")
cookie, err := controller.auth.CreateSession(c, sessionCookie) err = controller.auth.CreateSessionCookie(c, &sessionCookie)
if err != nil { if err != nil {
tlog.App.Error().Err(err).Msg("Failed to create session cookie") tlog.App.Error().Err(err).Msg("Failed to create session cookie")
@@ -353,8 +304,6 @@ func (controller *UserController) totpHandler(c *gin.Context) {
return return
} }
http.SetCookie(c.Writer, cookie)
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
"message": "Login successful", "message": "Login successful",
+137 -134
View File
@@ -1,13 +1,10 @@
package middleware package middleware
import ( import (
"context"
"fmt"
"net/http"
"strings" "strings"
"time" "time"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
@@ -37,7 +34,6 @@ var (
type ContextMiddlewareConfig struct { type ContextMiddlewareConfig struct {
CookieDomain string CookieDomain string
SessionCookieName string
} }
type ContextMiddleware struct { type ContextMiddleware struct {
@@ -65,193 +61,200 @@ func (m *ContextMiddleware) Middleware() gin.HandlerFunc {
return return
} }
uuid, err := c.Cookie(m.config.SessionCookieName) cookie, err := m.auth.GetSessionCookie(c)
if err == nil {
userContext, cookie, err := m.cookieAuth(c.Request.Context(), uuid)
if err != nil { if err != nil {
tlog.App.Error().Msgf("Error authenticating session cookie: %v", err) tlog.App.Debug().Err(err).Msg("No valid session cookie found")
goto basic
}
if cookie.TotpPending {
c.Set("context", &config.UserContext{
Username: cookie.Username,
Name: cookie.Name,
Email: cookie.Email,
Provider: "local",
TotpPending: true,
TotpEnabled: true,
})
c.Next() c.Next()
return return
} }
if cookie != nil { switch cookie.Provider {
http.SetCookie(c.Writer, cookie) case "local", "ldap":
userSearch := m.auth.SearchUser(cookie.Username)
if userSearch.Type == "unknown" {
tlog.App.Debug().Msg("User from session cookie not found")
m.auth.DeleteSessionCookie(c)
goto basic
} }
tlog.App.Trace().Msgf("Authenticated user from session cookie: %s", userContext.GetUsername()) if userSearch.Type != cookie.Provider {
c.Set("context", userContext) tlog.App.Warn().Msg("User type from session cookie does not match user search type")
m.auth.DeleteSessionCookie(c)
c.Next() c.Next()
return return
} }
basic, err := m.auth.GetBasicAuth(c.Request) var ldapGroups []string
var localAttributes config.UserAttributes
if err == nil { if cookie.Provider == "ldap" {
userContext, headers, err := m.basicAuth(c.Request.Context(), basic) ldapUser, err := m.auth.GetLdapUser(userSearch.Username)
if err != nil { if err != nil {
tlog.App.Error().Msgf("Error authenticating basic auth: %v", err) tlog.App.Error().Err(err).Msg("Error retrieving LDAP user details")
c.Next() c.Next()
return return
} }
for k, v := range headers { ldapGroups = ldapUser.Groups
c.Header(k, v)
} }
c.Set("context", userContext) if cookie.Provider == "local" {
localUser := m.auth.GetLocalUser(cookie.Username)
localAttributes = localUser.Attributes
}
m.auth.RefreshSessionCookie(c)
c.Set("context", &config.UserContext{
Username: cookie.Username,
Name: cookie.Name,
Email: cookie.Email,
Provider: cookie.Provider,
IsLoggedIn: true,
LdapGroups: strings.Join(ldapGroups, ","),
Attributes: localAttributes,
})
c.Next() c.Next()
return return
} default:
_, exists := m.broker.GetService(cookie.Provider)
c.Next()
}
}
func (m *ContextMiddleware) cookieAuth(ctx context.Context, uuid string) (*model.UserContext, *http.Cookie, error) {
session, err := m.auth.GetSession(ctx, uuid)
if err != nil {
return nil, nil, fmt.Errorf("error retrieving session: %w", err)
}
userContext, err := new(model.UserContext).NewFromSession(session)
if err != nil {
return nil, nil, fmt.Errorf("error creating user context from session: %w", err)
}
if userContext.Provider == model.ProviderLocal &&
userContext.Local.TOTPPending {
userContext.Local.TOTPEnabled = true
return userContext, nil, nil
}
switch userContext.Provider {
case model.ProviderLocal:
user := m.auth.GetLocalUser(userContext.Local.Username)
if user == nil {
return nil, nil, fmt.Errorf("local user not found")
}
userContext.Local.Attributes = user.Attributes
if userContext.Local.Attributes.Name == "" {
userContext.Local.Attributes.Name = utils.Capitalize(user.Username)
}
if userContext.Local.Attributes.Email == "" {
userContext.Local.Attributes.Email = utils.CompileUserEmail(user.Username, m.config.CookieDomain)
}
case model.ProviderLDAP:
search, err := m.auth.SearchUser(userContext.LDAP.Username)
if err != nil {
return nil, nil, fmt.Errorf("error searching for ldap user: %w", err)
}
if search.Type != model.UserLDAP {
return nil, nil, fmt.Errorf("user from session cookie is not ldap")
}
user, err := m.auth.GetLDAPUser(search.Username)
if err != nil {
return nil, nil, fmt.Errorf("error retrieving ldap user details: %w", err)
}
userContext.LDAP.Groups = user.Groups
userContext.LDAP.Name = utils.Capitalize(userContext.LDAP.Username)
userContext.LDAP.Email = utils.CompileUserEmail(userContext.LDAP.Username, m.config.CookieDomain)
case model.ProviderOAuth:
_, exists := m.broker.GetService(userContext.OAuth.ID)
if !exists { if !exists {
return nil, nil, fmt.Errorf("oauth provider from session cookie not found: %s", userContext.OAuth.ID) tlog.App.Debug().Msg("OAuth provider from session cookie not found")
m.auth.DeleteSessionCookie(c)
goto basic
} }
if !m.auth.IsEmailWhitelisted(userContext.OAuth.Email) { if !m.auth.IsEmailWhitelisted(cookie.Email) {
m.auth.DeleteSession(ctx, uuid) tlog.App.Debug().Msg("Email from session cookie not whitelisted")
return nil, nil, fmt.Errorf("email from session cookie not whitelisted: %s", userContext.OAuth.Email) m.auth.DeleteSessionCookie(c)
} goto basic
} }
cookie, err := m.auth.RefreshSession(ctx, uuid) m.auth.RefreshSessionCookie(c)
c.Set("context", &config.UserContext{
if err != nil { Username: cookie.Username,
return nil, nil, fmt.Errorf("error refreshing session: %w", err) Name: cookie.Name,
Email: cookie.Email,
Provider: cookie.Provider,
OAuthGroups: cookie.OAuthGroups,
OAuthName: cookie.OAuthName,
OAuthSub: cookie.OAuthSub,
IsLoggedIn: true,
OAuth: true,
})
c.Next()
return
} }
return userContext, cookie, nil basic:
} basic := m.auth.GetBasicAuth(c)
if basic == nil {
tlog.App.Debug().Msg("No basic auth provided")
c.Next()
return
}
func (m *ContextMiddleware) basicAuth(ctx context.Context, basic *model.LocalUser) (*model.UserContext, map[string]string, error) {
headers := make(map[string]string)
userContext := new(model.UserContext)
locked, remaining := m.auth.IsAccountLocked(basic.Username) locked, remaining := m.auth.IsAccountLocked(basic.Username)
if locked { if locked {
tlog.App.Debug().Msgf("Account for user %s is locked for %d seconds, denying auth", basic.Username, remaining) tlog.App.Debug().Msgf("Account for user %s is locked for %d seconds, denying auth", basic.Username, remaining)
headers["x-tinyauth-lock-locked"] = "true" c.Writer.Header().Add("x-tinyauth-lock-locked", "true")
headers["x-tinyauth-lock-reset"] = time.Now().Add(time.Duration(remaining) * time.Second).Format(time.RFC3339) c.Writer.Header().Add("x-tinyauth-lock-reset", time.Now().Add(time.Duration(remaining)*time.Second).Format(time.RFC3339))
return nil, headers, nil c.Next()
return
} }
search, err := m.auth.SearchUser(basic.Username) userSearch := m.auth.SearchUser(basic.Username)
if err != nil { if userSearch.Type == "unknown" || userSearch.Type == "error" {
return nil, nil, fmt.Errorf("error searching for user: %w", err)
}
err = m.auth.CheckUserPassword(*search, basic.Password)
if err != nil {
m.auth.RecordLoginAttempt(basic.Username, false) m.auth.RecordLoginAttempt(basic.Username, false)
return nil, nil, fmt.Errorf("invalid password for basic auth user: %w", err) tlog.App.Debug().Msg("User from basic auth not found")
c.Next()
return
}
if !m.auth.VerifyUser(userSearch, basic.Password) {
m.auth.RecordLoginAttempt(basic.Username, false)
tlog.App.Debug().Msg("Invalid password for basic auth user")
c.Next()
return
} }
m.auth.RecordLoginAttempt(basic.Username, true) m.auth.RecordLoginAttempt(basic.Username, true)
switch search.Type { switch userSearch.Type {
case model.UserLocal: case "local":
tlog.App.Debug().Msg("Basic auth user is local")
user := m.auth.GetLocalUser(basic.Username) user := m.auth.GetLocalUser(basic.Username)
if user.TOTPSecret != "" { if user.TotpSecret != "" {
return nil, nil, fmt.Errorf("user with totp not allowed to login via basic auth: %s", basic.Username) tlog.App.Debug().Msg("User with TOTP not allowed to login via basic auth")
return
} }
userContext.Local = &model.LocalContext{ name := utils.Capitalize(user.Username)
BaseContext: model.BaseContext{ if user.Attributes.Name != "" {
Username: user.Username, name = user.Attributes.Name
Name: utils.Capitalize(user.Username),
Email: utils.CompileUserEmail(user.Username, m.config.CookieDomain),
},
Attributes: user.Attributes,
} }
userContext.Provider = model.ProviderLocal email := utils.CompileUserEmail(user.Username, m.config.CookieDomain)
case model.UserLDAP: if user.Attributes.Email != "" {
user, err := m.auth.GetLDAPUser(basic.Username) email = user.Attributes.Email
}
c.Set("context", &config.UserContext{
Username: user.Username,
Name: name,
Email: email,
Provider: "local",
IsLoggedIn: true,
IsBasicAuth: true,
Attributes: user.Attributes,
})
c.Next()
return
case "ldap":
tlog.App.Debug().Msg("Basic auth user is LDAP")
ldapUser, err := m.auth.GetLdapUser(basic.Username)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("error retrieving ldap user details: %w", err) tlog.App.Debug().Err(err).Msg("Error retrieving LDAP user details")
c.Next()
return
} }
userContext.LDAP = &model.LDAPContext{ c.Set("context", &config.UserContext{
BaseContext: model.BaseContext{
Username: basic.Username, Username: basic.Username,
Name: utils.Capitalize(basic.Username), Name: utils.Capitalize(basic.Username),
Email: utils.CompileUserEmail(basic.Username, m.config.CookieDomain), Email: utils.CompileUserEmail(basic.Username, m.config.CookieDomain),
}, Provider: "ldap",
Groups: user.Groups, IsLoggedIn: true,
} LdapGroups: strings.Join(ldapUser.Groups, ","),
userContext.Provider = model.ProviderLDAP IsBasicAuth: true,
})
c.Next()
return
} }
userContext.Authenticated = true c.Next()
return userContext, nil, nil }
} }
func (m *ContextMiddleware) isIgnorePath(path string) bool { func (m *ContextMiddleware) isIgnorePath(path string) bool {
-23
View File
@@ -1,23 +0,0 @@
package model
const DefaultNamePrefix = "TINYAUTH_"
const APIServer = "https://api.tinyauth.app"
type Claims struct {
Sub string `json:"sub"`
Name string `json:"name"`
Email string `json:"email"`
PreferredUsername string `json:"preferred_username"`
Groups any `json:"groups"`
}
var OverrideProviders = map[string]string{
"google": "Google",
"github": "GitHub",
}
const SessionCookieName = "tinyauth-session"
const CSRFCookieName = "tinyauth-csrf"
const RedirectCookieName = "tinyauth-redirect"
const OAuthSessionCookieName = "tinyauth-oauth"
-206
View File
@@ -1,206 +0,0 @@
package model
import (
"errors"
"strings"
"github.com/gin-gonic/gin"
"github.com/tinyauthapp/tinyauth/internal/repository"
)
type ProviderType int
const (
ProviderLocal ProviderType = iota
ProviderBasicAuth
ProviderOAuth
ProviderLDAP
)
type UserContext struct {
Authenticated bool
Provider ProviderType
Local *LocalContext
OAuth *OAuthContext
LDAP *LDAPContext
}
type BaseContext struct {
Username string
Name string
Email string
}
type LocalContext struct {
BaseContext
TOTPPending bool
TOTPEnabled bool
Attributes UserAttributes
}
type OAuthContext struct {
BaseContext
Groups []string
Sub string
DisplayName string
ID string
}
type LDAPContext struct {
BaseContext
Groups []string
}
func (c *UserContext) IsAuthenticated() bool {
return c.Authenticated
}
func (c *UserContext) IsLocal() bool {
return c.Provider == ProviderLocal
}
func (c *UserContext) IsOAuth() bool {
return c.Provider == ProviderOAuth
}
func (c *UserContext) IsLDAP() bool {
return c.Provider == ProviderLDAP
}
func (c *UserContext) IsBasicAuth() bool {
return c.Provider == ProviderBasicAuth
}
func (c *UserContext) NewFromGin(ginctx *gin.Context) (*UserContext, error) {
userContextValue, exists := ginctx.Get("context")
if !exists {
return nil, errors.New("failed to get user context")
}
userContext, ok := userContextValue.(*UserContext)
if !ok {
return nil, errors.New("invalid user context type")
}
*c = *userContext
return c, nil
}
// Compatability layer until we get an excuse to drop in database migrations
func (c *UserContext) NewFromSession(session *repository.Session) (*UserContext, error) {
switch session.Provider {
case "local":
c.Provider = ProviderLocal
c.Local = &LocalContext{
BaseContext: BaseContext{
Username: session.Username,
Name: session.Name,
Email: session.Email,
},
TOTPPending: session.TotpPending,
}
case "ldap":
c.Provider = ProviderLDAP
c.LDAP = &LDAPContext{
BaseContext: BaseContext{
Username: session.Username,
Name: session.Name,
Email: session.Email,
},
}
// By default we assume an unkown name which is oauth
default:
c.Provider = ProviderOAuth
c.OAuth = &OAuthContext{
BaseContext: BaseContext{
Username: session.Username,
Name: session.Name,
Email: session.Email,
},
Groups: strings.Split(session.OAuthGroups, ","),
Sub: session.OAuthSub,
DisplayName: session.OAuthName,
ID: session.Provider,
}
}
if !session.TotpPending {
c.Authenticated = true
}
return c, nil
}
func (c *UserContext) GetUsername() string {
switch c.Provider {
case ProviderLocal:
return c.Local.Username
case ProviderLDAP:
return c.LDAP.Username
case ProviderBasicAuth:
return c.Local.Username
case ProviderOAuth:
return c.OAuth.Username
default:
return ""
}
}
func (c *UserContext) GetEmail() string {
switch c.Provider {
case ProviderLocal:
return c.Local.Email
case ProviderLDAP:
return c.LDAP.Email
case ProviderBasicAuth:
return c.Local.Email
case ProviderOAuth:
return c.OAuth.Email
default:
return ""
}
}
func (c *UserContext) GetName() string {
switch c.Provider {
case ProviderLocal:
return c.Local.Name
case ProviderLDAP:
return c.LDAP.Name
case ProviderBasicAuth:
return c.Local.Name
case ProviderOAuth:
return c.OAuth.Name
default:
return ""
}
}
func (c *UserContext) ProviderName() string {
switch c.Provider {
case ProviderBasicAuth, ProviderLocal:
return "local"
case ProviderLDAP:
return "ldap"
case ProviderOAuth:
return c.OAuth.DisplayName // compatability
default:
return "unknown"
}
}
func (c *UserContext) TOTPPending() bool {
if c.Provider == ProviderLocal {
return c.Local.TOTPPending
}
return false
}
func (c *UserContext) OAuthName() string {
if c.Provider == ProviderOAuth {
return c.OAuth.DisplayName
}
return ""
}
-25
View File
@@ -1,25 +0,0 @@
package model
type UserSearchType int
const (
UserLocal UserSearchType = iota
UserLDAP
)
type LDAPUser struct {
DN string
Groups []string
}
type LocalUser struct {
Username string
Password string
TOTPSecret string
Attributes UserAttributes
}
type UserSearch struct {
Username string
Type UserSearchType
}
-5
View File
@@ -1,5 +0,0 @@
package model
var Version = "development"
var CommitHash = "development"
var BuildTimestamp = "0000-00-00T00:00:00Z"
+9 -9
View File
@@ -4,20 +4,20 @@ import (
"errors" "errors"
"strings" "strings"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
) )
type LabelProvider interface { type LabelProvider interface {
GetLabels(appDomain string) (*model.App, error) GetLabels(appDomain string) (config.App, error)
} }
type AccessControlsService struct { type AccessControlsService struct {
labelProvider LabelProvider labelProvider LabelProvider
static map[string]model.App static map[string]config.App
} }
func NewAccessControlsService(labelProvider LabelProvider, static map[string]model.App) *AccessControlsService { func NewAccessControlsService(labelProvider LabelProvider, static map[string]config.App) *AccessControlsService {
return &AccessControlsService{ return &AccessControlsService{
labelProvider: labelProvider, labelProvider: labelProvider,
static: static, static: static,
@@ -28,22 +28,22 @@ func (acls *AccessControlsService) Init() error {
return nil // No initialization needed return nil // No initialization needed
} }
func (acls *AccessControlsService) lookupStaticACLs(domain string) (*model.App, error) { func (acls *AccessControlsService) lookupStaticACLs(domain string) (config.App, error) {
for app, config := range acls.static { for app, config := range acls.static {
if config.Config.Domain == domain { if config.Config.Domain == domain {
tlog.App.Debug().Str("name", app).Msg("Found matching container by domain") tlog.App.Debug().Str("name", app).Msg("Found matching container by domain")
return &config, nil return config, nil
} }
if strings.SplitN(domain, ".", 2)[0] == app { if strings.SplitN(domain, ".", 2)[0] == app {
tlog.App.Debug().Str("name", app).Msg("Found matching container by app name") tlog.App.Debug().Str("name", app).Msg("Found matching container by app name")
return &config, nil return config, nil
} }
} }
return nil, errors.New("no results") return config.App{}, errors.New("no results")
} }
func (acls *AccessControlsService) GetAccessControls(domain string) (*model.App, error) { func (acls *AccessControlsService) GetAccessControls(domain string) (config.App, error) {
// First check in the static config // First check in the static config
app, err := acls.lookupStaticACLs(domain) app, err := acls.lookupStaticACLs(domain)
+136 -148
View File
@@ -5,13 +5,12 @@ import (
"database/sql" "database/sql"
"errors" "errors"
"fmt" "fmt"
"net/http"
"regexp" "regexp"
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
@@ -30,10 +29,6 @@ const MaxOAuthPendingSessions = 256
const OAuthCleanupCount = 16 const OAuthCleanupCount = 16
const MaxLoginAttemptRecords = 256 const MaxLoginAttemptRecords = 256
var (
ErrUserNotFound = errors.New("user not found")
)
// slightly modified version of the AuthorizeRequest from the OIDC service to basically accept all // slightly modified version of the AuthorizeRequest from the OIDC service to basically accept all
// parameters and pass them to the authorize page if needed // parameters and pass them to the authorize page if needed
type OAuthURLParams struct { type OAuthURLParams struct {
@@ -73,7 +68,7 @@ type Lockdown struct {
} }
type AuthServiceConfig struct { type AuthServiceConfig struct {
LocalUsers []model.LocalUser Users []config.User
OauthWhitelist []string OauthWhitelist []string
SessionExpiry int SessionExpiry int
SessionMaxLifetime int SessionMaxLifetime int
@@ -82,7 +77,7 @@ type AuthServiceConfig struct {
LoginTimeout int LoginTimeout int
LoginMaxRetries int LoginMaxRetries int
SessionCookieName string SessionCookieName string
IP model.IPConfig IP config.IPConfig
LDAPGroupsCacheTTL int LDAPGroupsCacheTTL int
} }
@@ -111,7 +106,7 @@ func NewAuthService(config AuthServiceConfig, ldap *LdapService, queries *reposi
ldap: ldap, ldap: ldap,
queries: queries, queries: queries,
oauthBroker: oauthBroker, oauthBroker: oauthBroker,
} }
} }
func (auth *AuthService) Init() error { func (auth *AuthService) Init() error {
@@ -119,67 +114,79 @@ func (auth *AuthService) Init() error {
return nil return nil
} }
func (auth *AuthService) SearchUser(username string) (*model.UserSearch, error) { func (auth *AuthService) SearchUser(username string) config.UserSearch {
if auth.GetLocalUser(username).Username != "" { if auth.GetLocalUser(username).Username != "" {
return &model.UserSearch{ return config.UserSearch{
Username: username, Username: username,
Type: model.UserLocal, Type: "local",
}, nil }
} }
if auth.ldap.IsConfigured() { if auth.ldap.IsConfigured() {
userDN, err := auth.ldap.GetUserDN(username) userDN, err := auth.ldap.GetUserDN(username)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get ldap user: %w", err) tlog.App.Warn().Err(err).Str("username", username).Msg("Failed to search for user in LDAP")
return config.UserSearch{
Type: "unknown",
}
} }
return &model.UserSearch{ return config.UserSearch{
Username: userDN, Username: userDN,
Type: model.UserLDAP, Type: "ldap",
}, nil }
} }
return nil, ErrUserNotFound return config.UserSearch{
Type: "unknown",
}
} }
func (auth *AuthService) CheckUserPassword(search model.UserSearch, password string) error { func (auth *AuthService) VerifyUser(search config.UserSearch, password string) bool {
switch search.Type { switch search.Type {
case model.UserLocal: case "local":
user := auth.GetLocalUser(search.Username) user := auth.GetLocalUser(search.Username)
return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) return auth.CheckPassword(user, password)
case model.UserLDAP: case "ldap":
if auth.ldap.IsConfigured() { if auth.ldap.IsConfigured() {
err := auth.ldap.Bind(search.Username, password) err := auth.ldap.Bind(search.Username, password)
if err != nil { if err != nil {
return fmt.Errorf("failed to bind to ldap user: %w", err) tlog.App.Warn().Err(err).Str("username", search.Username).Msg("Failed to bind to LDAP")
return false
} }
err = auth.ldap.BindService(true) err = auth.ldap.BindService(true)
if err != nil { if err != nil {
return fmt.Errorf("failed to bind to ldap service account: %w", err) tlog.App.Error().Err(err).Msg("Failed to rebind with service account after user authentication")
return false
} }
return nil return true
} }
default: default:
return errors.New("unknown user search type") tlog.App.Debug().Str("type", search.Type).Msg("Unknown user type for authentication")
return false
} }
return errors.New("user authentication failed")
tlog.App.Warn().Str("username", search.Username).Msg("User authentication failed")
return false
} }
func (auth *AuthService) GetLocalUser(username string) *model.LocalUser { func (auth *AuthService) GetLocalUser(username string) config.User {
for _, user := range auth.config.LocalUsers { for _, user := range auth.config.Users {
if user.Username == username { if user.Username == username {
return &user return user
} }
} }
return nil
tlog.App.Warn().Str("username", username).Msg("Local user not found")
return config.User{}
} }
func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) { func (auth *AuthService) GetLdapUser(userDN string) (config.LdapUser, error) {
if !auth.ldap.IsConfigured() { if !auth.ldap.IsConfigured() {
return nil, errors.New("ldap service not configured") return config.LdapUser{}, errors.New("LDAP service not initialized")
} }
auth.ldapGroupsMutex.RLock() auth.ldapGroupsMutex.RLock()
@@ -187,7 +194,7 @@ func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) {
auth.ldapGroupsMutex.RUnlock() auth.ldapGroupsMutex.RUnlock()
if exists && time.Now().Before(entry.Expires) { if exists && time.Now().Before(entry.Expires) {
return &model.LDAPUser{ return config.LdapUser{
DN: userDN, DN: userDN,
Groups: entry.Groups, Groups: entry.Groups,
}, nil }, nil
@@ -196,7 +203,7 @@ func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) {
groups, err := auth.ldap.GetUserGroups(userDN) groups, err := auth.ldap.GetUserGroups(userDN)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get ldap groups: %w", err) return config.LdapUser{}, err
} }
auth.ldapGroupsMutex.Lock() auth.ldapGroupsMutex.Lock()
@@ -206,12 +213,16 @@ func (auth *AuthService) GetLDAPUser(userDN string) (*model.LDAPUser, error) {
} }
auth.ldapGroupsMutex.Unlock() auth.ldapGroupsMutex.Unlock()
return &model.LDAPUser{ return config.LdapUser{
DN: userDN, DN: userDN,
Groups: groups, Groups: groups,
}, nil }, nil
} }
func (auth *AuthService) CheckPassword(user config.User, password string) bool {
return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)) == nil
}
func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) { func (auth *AuthService) IsAccountLocked(identifier string) (bool, int) {
auth.loginMutex.RLock() auth.loginMutex.RLock()
defer auth.loginMutex.RUnlock() defer auth.loginMutex.RUnlock()
@@ -280,11 +291,11 @@ func (auth *AuthService) IsEmailWhitelisted(email string) bool {
return utils.CheckFilter(strings.Join(auth.config.OauthWhitelist, ","), email) return utils.CheckFilter(strings.Join(auth.config.OauthWhitelist, ","), email)
} }
func (auth *AuthService) CreateSession(ctx context.Context, data repository.Session) (*http.Cookie, error) { func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *repository.Session) error {
uuid, err := uuid.NewRandom() uuid, err := uuid.NewRandom()
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to generate session uuid: %w", err) return err
} }
var expiry int var expiry int
@@ -309,30 +320,28 @@ func (auth *AuthService) CreateSession(ctx context.Context, data repository.Sess
OAuthSub: data.OAuthSub, OAuthSub: data.OAuthSub,
} }
_, err = auth.queries.CreateSession(ctx, session) _, err = auth.queries.CreateSession(c, session)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to create session entry: %w", err) return err
} }
return &http.Cookie{ c.SetCookie(auth.config.SessionCookieName, session.UUID, expiry, "/", fmt.Sprintf(".%s", auth.config.CookieDomain), auth.config.SecureCookie, true)
Name: auth.config.SessionCookieName,
Value: session.UUID, return nil
Path: "/",
Domain: fmt.Sprintf(".%s", auth.config.CookieDomain),
Expires: time.Now().Add(time.Duration(expiry) * time.Second),
MaxAge: expiry,
Secure: auth.config.SecureCookie,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}, nil
} }
func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http.Cookie, error) { func (auth *AuthService) RefreshSessionCookie(c *gin.Context) error {
session, err := auth.queries.GetSession(ctx, uuid) cookie, err := c.Cookie(auth.config.SessionCookieName)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to retrieve session: %w", err) return err
}
session, err := auth.queries.GetSession(c, cookie)
if err != nil {
return err
} }
currentTime := time.Now().Unix() currentTime := time.Now().Unix()
@@ -346,12 +355,12 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http
} }
if session.Expiry-currentTime > refreshThreshold { if session.Expiry-currentTime > refreshThreshold {
return nil, nil return nil
} }
newExpiry := session.Expiry + refreshThreshold newExpiry := session.Expiry + refreshThreshold
_, err = auth.queries.UpdateSession(ctx, repository.UpdateSessionParams{ _, err = auth.queries.UpdateSession(c, repository.UpdateSessionParams{
Username: session.Username, Username: session.Username,
Email: session.Email, Email: session.Email,
Name: session.Name, Name: session.Name,
@@ -365,121 +374,120 @@ func (auth *AuthService) RefreshSession(ctx context.Context, uuid string) (*http
}) })
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to update session expiry: %w", err) return err
} }
return &http.Cookie{ c.SetCookie(auth.config.SessionCookieName, cookie, int(newExpiry-currentTime), "/", fmt.Sprintf(".%s", auth.config.CookieDomain), auth.config.SecureCookie, true)
Name: auth.config.SessionCookieName, tlog.App.Trace().Str("username", session.Username).Msg("Session cookie refreshed")
Value: session.UUID,
Path: "/",
Domain: fmt.Sprintf(".%s", auth.config.CookieDomain),
Expires: time.Now().Add(time.Duration(newExpiry-currentTime) * time.Second),
MaxAge: auth.config.SessionExpiry,
Secure: auth.config.SecureCookie,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
}, nil
return nil
} }
func (auth *AuthService) DeleteSession(ctx context.Context, uuid string) (*http.Cookie, error) { func (auth *AuthService) DeleteSessionCookie(c *gin.Context) error {
err := auth.queries.DeleteSession(ctx, uuid) cookie, err := c.Cookie(auth.config.SessionCookieName)
if err != nil { if err != nil {
tlog.App.Warn().Err(err).Msg("Failed to delete session from database, proceeding to clear cookie anyway") return err
} }
return &http.Cookie{ err = auth.queries.DeleteSession(c, cookie)
Name: auth.config.SessionCookieName,
Value: "", if err != nil {
Path: "/", return err
Domain: fmt.Sprintf(".%s", auth.config.CookieDomain), }
Expires: time.Now(),
MaxAge: -1, c.SetCookie(auth.config.SessionCookieName, "", -1, "/", fmt.Sprintf(".%s", auth.config.CookieDomain), auth.config.SecureCookie, true)
Secure: auth.config.SecureCookie,
HttpOnly: true, return nil
SameSite: http.SameSiteLaxMode,
}, nil
} }
func (auth *AuthService) GetSession(ctx context.Context, uuid string) (*repository.Session, error) { func (auth *AuthService) GetSessionCookie(c *gin.Context) (repository.Session, error) {
session, err := auth.queries.GetSession(ctx, uuid) cookie, err := c.Cookie(auth.config.SessionCookieName)
if err != nil {
return repository.Session{}, err
}
session, err := auth.queries.GetSession(c, cookie)
if err != nil { if err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
return nil, errors.New("session not found") return repository.Session{}, fmt.Errorf("session not found")
} }
return nil, err return repository.Session{}, err
} }
currentTime := time.Now().Unix() currentTime := time.Now().Unix()
if auth.config.SessionMaxLifetime != 0 && session.CreatedAt != 0 { if auth.config.SessionMaxLifetime != 0 && session.CreatedAt != 0 {
if currentTime-session.CreatedAt > int64(auth.config.SessionMaxLifetime) { if currentTime-session.CreatedAt > int64(auth.config.SessionMaxLifetime) {
err = auth.queries.DeleteSession(ctx, uuid) err = auth.queries.DeleteSession(c, cookie)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to delete expired session: %w", err) tlog.App.Error().Err(err).Msg("Failed to delete session exceeding max lifetime")
} }
return nil, fmt.Errorf("session max lifetime exceeded") return repository.Session{}, fmt.Errorf("session expired due to max lifetime exceeded")
} }
} }
if currentTime > session.Expiry { if currentTime > session.Expiry {
err = auth.queries.DeleteSession(ctx, uuid) err = auth.queries.DeleteSession(c, cookie)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to delete expired session: %w", err) tlog.App.Error().Err(err).Msg("Failed to delete expired session")
} }
return nil, fmt.Errorf("session expired") return repository.Session{}, fmt.Errorf("session expired")
} }
return &session, nil return repository.Session{
UUID: session.UUID,
Username: session.Username,
Email: session.Email,
Name: session.Name,
Provider: session.Provider,
TotpPending: session.TotpPending,
OAuthGroups: session.OAuthGroups,
OAuthName: session.OAuthName,
OAuthSub: session.OAuthSub,
}, nil
} }
func (auth *AuthService) LocalAuthConfigured() bool { func (auth *AuthService) LocalAuthConfigured() bool {
return len(auth.config.LocalUsers) > 0 return len(auth.config.Users) > 0
} }
func (auth *AuthService) LDAPAuthConfigured() bool { func (auth *AuthService) LdapAuthConfigured() bool {
return auth.ldap.IsConfigured() return auth.ldap.IsConfigured()
} }
func (auth *AuthService) IsUserAllowed(c *gin.Context, context model.UserContext, acls *model.App) bool { func (auth *AuthService) IsUserAllowed(c *gin.Context, context config.UserContext, acls config.App) bool {
if acls == nil { if context.OAuth {
return true
}
if context.Provider == model.ProviderOAuth {
tlog.App.Debug().Msg("Checking OAuth whitelist") tlog.App.Debug().Msg("Checking OAuth whitelist")
return utils.CheckFilter(acls.OAuth.Whitelist, context.OAuth.Email) return utils.CheckFilter(acls.OAuth.Whitelist, context.Email)
} }
if acls.Users.Block != "" { if acls.Users.Block != "" {
tlog.App.Debug().Msg("Checking blocked users") tlog.App.Debug().Msg("Checking blocked users")
if utils.CheckFilter(acls.Users.Block, context.GetUsername()) { if utils.CheckFilter(acls.Users.Block, context.Username) {
return false return false
} }
} }
tlog.App.Debug().Msg("Checking users") tlog.App.Debug().Msg("Checking users")
return utils.CheckFilter(acls.Users.Allow, context.GetUsername()) return utils.CheckFilter(acls.Users.Allow, context.Username)
} }
func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context model.UserContext, requiredGroups string) bool { func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context config.UserContext, requiredGroups string) bool {
if requiredGroups == "" { if requiredGroups == "" {
return true return true
} }
if !context.IsOAuth() { for id := range config.OverrideProviders {
tlog.App.Debug().Msg("User is not an OAuth user, skipping OAuth group check") if context.Provider == id {
return false tlog.App.Info().Str("provider", id).Msg("OAuth groups not supported for this provider")
}
if _, ok := model.OverrideProviders[context.OAuth.ID]; ok {
tlog.App.Debug().Msg("Provider override for OAuth groups enabled, skipping group check")
return true return true
} }
}
for _, userGroup := range context.OAuth.Groups { for userGroup := range strings.SplitSeq(context.OAuthGroups, ",") {
if utils.CheckFilter(requiredGroups, strings.TrimSpace(userGroup)) { if utils.CheckFilter(requiredGroups, strings.TrimSpace(userGroup)) {
tlog.App.Trace().Str("group", userGroup).Str("required", requiredGroups).Msg("User group matched") tlog.App.Trace().Str("group", userGroup).Str("required", requiredGroups).Msg("User group matched")
return true return true
@@ -490,17 +498,12 @@ func (auth *AuthService) IsInOAuthGroup(c *gin.Context, context model.UserContex
return false return false
} }
func (auth *AuthService) IsInLDAPGroup(c *gin.Context, context model.UserContext, requiredGroups string) bool { func (auth *AuthService) IsInLdapGroup(c *gin.Context, context config.UserContext, requiredGroups string) bool {
if requiredGroups == "" { if requiredGroups == "" {
return true return true
} }
if !context.IsLDAP() { for userGroup := range strings.SplitSeq(context.LdapGroups, ",") {
tlog.App.Debug().Msg("User is not an LDAP user, skipping LDAP group check")
return false
}
for _, userGroup := range context.LDAP.Groups {
if utils.CheckFilter(requiredGroups, strings.TrimSpace(userGroup)) { if utils.CheckFilter(requiredGroups, strings.TrimSpace(userGroup)) {
tlog.App.Trace().Str("group", userGroup).Str("required", requiredGroups).Msg("User group matched") tlog.App.Trace().Str("group", userGroup).Str("required", requiredGroups).Msg("User group matched")
return true return true
@@ -511,11 +514,7 @@ func (auth *AuthService) IsInLDAPGroup(c *gin.Context, context model.UserContext
return false return false
} }
func (auth *AuthService) IsAuthEnabled(uri string, path *model.AppPath) (bool, error) { func (auth *AuthService) IsAuthEnabled(uri string, path config.AppPath) (bool, error) {
if path == nil {
return true, nil
}
// Check for block list // Check for block list
if path.Block != "" { if path.Block != "" {
regex, err := regexp.Compile(path.Block) regex, err := regexp.Compile(path.Block)
@@ -545,26 +544,19 @@ func (auth *AuthService) IsAuthEnabled(uri string, path *model.AppPath) (bool, e
return true, nil return true, nil
} }
// local user is used only as a medium to pass the basic auth credentials, user can be ldap too func (auth *AuthService) GetBasicAuth(c *gin.Context) *config.User {
func (auth *AuthService) GetBasicAuth(req *http.Request) (*model.LocalUser, error) { username, password, ok := c.Request.BasicAuth()
if req == nil {
return nil, errors.New("request is nil")
}
username, password, ok := req.BasicAuth()
if !ok { if !ok {
return nil, errors.New("no basic auth credentials provided") tlog.App.Debug().Msg("No basic auth provided")
return nil
} }
return &model.LocalUser{ return &config.User{
Username: username, Username: username,
Password: password, Password: password,
}, nil }
} }
func (auth *AuthService) CheckIP(acls *model.AppIP, ip string) bool { func (auth *AuthService) CheckIP(acls config.AppIP, ip string) bool {
if acls == nil {
acls = &model.AppIP{}
}
// Merge the global and app IP filter // Merge the global and app IP filter
blockedIps := append(auth.config.IP.Block, acls.Block...) blockedIps := append(auth.config.IP.Block, acls.Block...)
allowedIPs := append(auth.config.IP.Allow, acls.Allow...) allowedIPs := append(auth.config.IP.Allow, acls.Allow...)
@@ -602,11 +594,7 @@ func (auth *AuthService) CheckIP(acls *model.AppIP, ip string) bool {
return true return true
} }
func (auth *AuthService) IsBypassedIP(acls *model.AppIP, ip string) bool { func (auth *AuthService) IsBypassedIP(acls config.AppIP, ip string) bool {
if acls == nil {
return false
}
for _, bypassed := range acls.Bypass { for _, bypassed := range acls.Bypass {
res, err := utils.FilterIP(bypassed, ip) res, err := utils.FilterIP(bypassed, ip)
if err != nil { if err != nil {
@@ -686,21 +674,21 @@ func (auth *AuthService) GetOAuthToken(sessionId string, code string) (*oauth2.T
return token, nil return token, nil
} }
func (auth *AuthService) GetOAuthUserinfo(sessionId string) (*model.Claims, error) { func (auth *AuthService) GetOAuthUserinfo(sessionId string) (config.Claims, error) {
session, err := auth.GetOAuthPendingSession(sessionId) session, err := auth.GetOAuthPendingSession(sessionId)
if err != nil { if err != nil {
return nil, err return config.Claims{}, err
} }
if session.Token == nil { if session.Token == nil {
return nil, fmt.Errorf("oauth token not found for session: %s", sessionId) return config.Claims{}, fmt.Errorf("oauth token not found for session: %s", sessionId)
} }
userinfo, err := (*session.Service).GetUserinfo(session.Token) userinfo, err := (*session.Service).GetUserinfo(session.Token)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get userinfo: %w", err) return config.Claims{}, fmt.Errorf("failed to get userinfo: %w", err)
} }
return userinfo, nil return userinfo, nil
+10 -10
View File
@@ -4,7 +4,7 @@ import (
"context" "context"
"strings" "strings"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/utils/decoders" "github.com/tinyauthapp/tinyauth/internal/utils/decoders"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
@@ -66,41 +66,41 @@ func (docker *DockerService) inspectContainer(containerId string) (container.Ins
return inspect, nil return inspect, nil
} }
func (docker *DockerService) GetLabels(appDomain string) (*model.App, error) { func (docker *DockerService) GetLabels(appDomain string) (config.App, error) {
if !docker.isConnected { if !docker.isConnected {
tlog.App.Debug().Msg("Docker not connected, returning empty labels") tlog.App.Debug().Msg("Docker not connected, returning empty labels")
return nil, nil return config.App{}, nil
} }
containers, err := docker.getContainers() containers, err := docker.getContainers()
if err != nil { if err != nil {
return nil, err return config.App{}, err
} }
for _, ctr := range containers { for _, ctr := range containers {
inspect, err := docker.inspectContainer(ctr.ID) inspect, err := docker.inspectContainer(ctr.ID)
if err != nil { if err != nil {
return nil, err return config.App{}, err
} }
labels, err := decoders.DecodeLabels[model.Apps](inspect.Config.Labels, "apps") labels, err := decoders.DecodeLabels[config.Apps](inspect.Config.Labels, "apps")
if err != nil { if err != nil {
return nil, err return config.App{}, err
} }
for appName, appLabels := range labels.Apps { for appName, appLabels := range labels.Apps {
if appLabels.Config.Domain == appDomain { if appLabels.Config.Domain == appDomain {
tlog.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by domain") tlog.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by domain")
return &appLabels, nil return appLabels, nil
} }
if strings.SplitN(appDomain, ".", 2)[0] == appName { if strings.SplitN(appDomain, ".", 2)[0] == appName {
tlog.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by app name") tlog.App.Debug().Str("id", inspect.ID).Str("name", inspect.Name).Msg("Found matching container by app name")
return &appLabels, nil return appLabels, nil
} }
} }
} }
tlog.App.Debug().Msg("No matching container found, returning empty labels") tlog.App.Debug().Msg("No matching container found, returning empty labels")
return nil, nil return config.App{}, nil
} }
+13 -12
View File
@@ -7,7 +7,7 @@ import (
"sync" "sync"
"time" "time"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/utils/decoders" "github.com/tinyauthapp/tinyauth/internal/utils/decoders"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
@@ -32,7 +32,7 @@ type ingressAppKey struct {
type ingressApp struct { type ingressApp struct {
domain string domain string
appName string appName string
app model.App app config.App
} }
type KubernetesService struct { type KubernetesService struct {
@@ -89,7 +89,7 @@ func (k *KubernetesService) removeIngress(namespace, name string) {
} }
} }
func (k *KubernetesService) getByDomain(domain string) (*model.App, bool) { func (k *KubernetesService) getByDomain(domain string) (config.App, bool) {
k.mu.RLock() k.mu.RLock()
defer k.mu.RUnlock() defer k.mu.RUnlock()
@@ -97,15 +97,15 @@ func (k *KubernetesService) getByDomain(domain string) (*model.App, bool) {
if apps, ok := k.ingressApps[appKey.ingressKey]; ok { if apps, ok := k.ingressApps[appKey.ingressKey]; ok {
for _, app := range apps { for _, app := range apps {
if app.domain == domain && app.appName == appKey.appName { if app.domain == domain && app.appName == appKey.appName {
return &app.app, true return app.app, true
} }
} }
} }
} }
return nil, false return config.App{}, false
} }
func (k *KubernetesService) getByAppName(appName string) (*model.App, bool) { func (k *KubernetesService) getByAppName(appName string) (config.App, bool) {
k.mu.RLock() k.mu.RLock()
defer k.mu.RUnlock() defer k.mu.RUnlock()
@@ -113,12 +113,12 @@ func (k *KubernetesService) getByAppName(appName string) (*model.App, bool) {
if apps, ok := k.ingressApps[appKey.ingressKey]; ok { if apps, ok := k.ingressApps[appKey.ingressKey]; ok {
for _, app := range apps { for _, app := range apps {
if app.appName == appName { if app.appName == appName {
return &app.app, true return app.app, true
} }
} }
} }
} }
return nil, false return config.App{}, false
} }
func (k *KubernetesService) updateFromItem(item *unstructured.Unstructured) { func (k *KubernetesService) updateFromItem(item *unstructured.Unstructured) {
@@ -129,7 +129,7 @@ func (k *KubernetesService) updateFromItem(item *unstructured.Unstructured) {
k.removeIngress(namespace, name) k.removeIngress(namespace, name)
return return
} }
labels, err := decoders.DecodeLabels[model.Apps](annotations, "apps") labels, err := decoders.DecodeLabels[config.Apps](annotations, "apps")
if err != nil { if err != nil {
tlog.App.Debug().Err(err).Msg("Failed to decode labels from annotations") tlog.App.Debug().Err(err).Msg("Failed to decode labels from annotations")
k.removeIngress(namespace, name) k.removeIngress(namespace, name)
@@ -280,10 +280,10 @@ func (k *KubernetesService) Init() error {
return nil return nil
} }
func (k *KubernetesService) GetLabels(appDomain string) (*model.App, error) { func (k *KubernetesService) GetLabels(appDomain string) (config.App, error) {
if !k.started { if !k.started {
tlog.App.Debug().Msg("Kubernetes not connected, returning empty labels") tlog.App.Debug().Msg("Kubernetes not connected, returning empty labels")
return nil, nil return config.App{}, nil
} }
// First check cache // First check cache
@@ -298,5 +298,6 @@ func (k *KubernetesService) GetLabels(appDomain string) (*model.App, error) {
} }
tlog.App.Debug().Str("domain", appDomain).Msg("Cache miss, no matching ingress found") tlog.App.Debug().Str("domain", appDomain).Msg("Cache miss, no matching ingress found")
return nil, nil return config.App{}, nil
} }
+5 -5
View File
@@ -1,7 +1,7 @@
package service package service
import ( import (
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"slices" "slices"
@@ -15,20 +15,20 @@ type OAuthServiceImpl interface {
NewRandom() string NewRandom() string
GetAuthURL(state string, verifier string) string GetAuthURL(state string, verifier string) string
GetToken(code string, verifier string) (*oauth2.Token, error) GetToken(code string, verifier string) (*oauth2.Token, error)
GetUserinfo(token *oauth2.Token) (*model.Claims, error) GetUserinfo(token *oauth2.Token) (config.Claims, error)
} }
type OAuthBrokerService struct { type OAuthBrokerService struct {
services map[string]OAuthServiceImpl services map[string]OAuthServiceImpl
configs map[string]model.OAuthServiceConfig configs map[string]config.OAuthServiceConfig
} }
var presets = map[string]func(config model.OAuthServiceConfig) *OAuthService{ var presets = map[string]func(config config.OAuthServiceConfig) *OAuthService{
"github": newGitHubOAuthService, "github": newGitHubOAuthService,
"google": newGoogleOAuthService, "google": newGoogleOAuthService,
} }
func NewOAuthBrokerService(configs map[string]model.OAuthServiceConfig) *OAuthBrokerService { func NewOAuthBrokerService(configs map[string]config.OAuthServiceConfig) *OAuthBrokerService {
return &OAuthBrokerService{ return &OAuthBrokerService{
services: make(map[string]OAuthServiceImpl), services: make(map[string]OAuthServiceImpl),
configs: configs, configs: configs,
+19 -19
View File
@@ -8,7 +8,7 @@ import (
"net/http" "net/http"
"strconv" "strconv"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/config"
) )
type GithubEmailResponse []struct { type GithubEmailResponse []struct {
@@ -22,32 +22,32 @@ type GithubUserInfoResponse struct {
ID int `json:"id"` ID int `json:"id"`
} }
func defaultExtractor(client *http.Client, url string) (*model.Claims, error) { func defaultExtractor(client *http.Client, url string) (config.Claims, error) {
return simpleReq[model.Claims](client, url, nil) return simpleReq[config.Claims](client, url, nil)
} }
func githubExtractor(client *http.Client, url string) (*model.Claims, error) { func githubExtractor(client *http.Client, url string) (config.Claims, error) {
var user model.Claims var user config.Claims
userInfo, err := simpleReq[GithubUserInfoResponse](client, "https://api.github.com/user", map[string]string{ userInfo, err := simpleReq[GithubUserInfoResponse](client, "https://api.github.com/user", map[string]string{
"accept": "application/vnd.github+json", "accept": "application/vnd.github+json",
}) })
if err != nil { if err != nil {
return nil, err return config.Claims{}, err
} }
userEmails, err := simpleReq[GithubEmailResponse](client, "https://api.github.com/user/emails", map[string]string{ userEmails, err := simpleReq[GithubEmailResponse](client, "https://api.github.com/user/emails", map[string]string{
"accept": "application/vnd.github+json", "accept": "application/vnd.github+json",
}) })
if err != nil { if err != nil {
return nil, err return config.Claims{}, err
} }
if len(*userEmails) == 0 { if len(userEmails) == 0 {
return nil, errors.New("no emails found") return user, errors.New("no emails found")
} }
for _, email := range *userEmails { for _, email := range userEmails {
if email.Primary { if email.Primary {
user.Email = email.Email user.Email = email.Email
break break
@@ -56,22 +56,22 @@ func githubExtractor(client *http.Client, url string) (*model.Claims, error) {
// Use first available email if no primary email was found // Use first available email if no primary email was found
if user.Email == "" { if user.Email == "" {
user.Email = (*userEmails)[0].Email user.Email = userEmails[0].Email
} }
user.PreferredUsername = userInfo.Login user.PreferredUsername = userInfo.Login
user.Name = userInfo.Name user.Name = userInfo.Name
user.Sub = strconv.Itoa(userInfo.ID) user.Sub = strconv.Itoa(userInfo.ID)
return &user, nil return user, nil
} }
func simpleReq[T any](client *http.Client, url string, headers map[string]string) (*T, error) { func simpleReq[T any](client *http.Client, url string, headers map[string]string) (T, error) {
var decodedRes T var decodedRes T
req, err := http.NewRequest("GET", url, nil) req, err := http.NewRequest("GET", url, nil)
if err != nil { if err != nil {
return nil, err return decodedRes, err
} }
for key, value := range headers { for key, value := range headers {
@@ -80,23 +80,23 @@ func simpleReq[T any](client *http.Client, url string, headers map[string]string
res, err := client.Do(req) res, err := client.Do(req)
if err != nil { if err != nil {
return nil, err return decodedRes, err
} }
defer res.Body.Close() defer res.Body.Close()
if res.StatusCode < 200 || res.StatusCode >= 300 { if res.StatusCode < 200 || res.StatusCode >= 300 {
return nil, fmt.Errorf("request failed with status: %s", res.Status) return decodedRes, fmt.Errorf("request failed with status: %s", res.Status)
} }
body, err := io.ReadAll(res.Body) body, err := io.ReadAll(res.Body)
if err != nil { if err != nil {
return nil, err return decodedRes, err
} }
err = json.Unmarshal(body, &decodedRes) err = json.Unmarshal(body, &decodedRes)
if err != nil { if err != nil {
return nil, err return decodedRes, err
} }
return &decodedRes, nil return decodedRes, nil
} }
+3 -3
View File
@@ -1,11 +1,11 @@
package service package service
import ( import (
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/config"
"golang.org/x/oauth2/endpoints" "golang.org/x/oauth2/endpoints"
) )
func newGoogleOAuthService(config model.OAuthServiceConfig) *OAuthService { func newGoogleOAuthService(config config.OAuthServiceConfig) *OAuthService {
scopes := []string{"openid", "email", "profile"} scopes := []string{"openid", "email", "profile"}
config.Scopes = scopes config.Scopes = scopes
config.AuthURL = endpoints.Google.AuthURL config.AuthURL = endpoints.Google.AuthURL
@@ -14,7 +14,7 @@ func newGoogleOAuthService(config model.OAuthServiceConfig) *OAuthService {
return NewOAuthService(config, "google") return NewOAuthService(config, "google")
} }
func newGitHubOAuthService(config model.OAuthServiceConfig) *OAuthService { func newGitHubOAuthService(config config.OAuthServiceConfig) *OAuthService {
scopes := []string{"read:user", "user:email"} scopes := []string{"read:user", "user:email"}
config.Scopes = scopes config.Scopes = scopes
config.AuthURL = endpoints.GitHub.AuthURL config.AuthURL = endpoints.GitHub.AuthURL
+5 -5
View File
@@ -6,21 +6,21 @@ import (
"net/http" "net/http"
"time" "time"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/config"
"golang.org/x/oauth2" "golang.org/x/oauth2"
) )
type UserinfoExtractor func(client *http.Client, url string) (*model.Claims, error) type UserinfoExtractor func(client *http.Client, url string) (config.Claims, error)
type OAuthService struct { type OAuthService struct {
serviceCfg model.OAuthServiceConfig serviceCfg config.OAuthServiceConfig
config *oauth2.Config config *oauth2.Config
ctx context.Context ctx context.Context
userinfoExtractor UserinfoExtractor userinfoExtractor UserinfoExtractor
id string id string
} }
func NewOAuthService(config model.OAuthServiceConfig, id string) *OAuthService { func NewOAuthService(config config.OAuthServiceConfig, id string) *OAuthService {
httpClient := &http.Client{ httpClient := &http.Client{
Timeout: 30 * time.Second, Timeout: 30 * time.Second,
Transport: &http.Transport{ Transport: &http.Transport{
@@ -78,7 +78,7 @@ func (s *OAuthService) GetToken(code string, verifier string) (*oauth2.Token, er
return s.config.Exchange(s.ctx, code, oauth2.VerifierOption(verifier)) return s.config.Exchange(s.ctx, code, oauth2.VerifierOption(verifier))
} }
func (s *OAuthService) GetUserinfo(token *oauth2.Token) (*model.Claims, error) { func (s *OAuthService) GetUserinfo(token *oauth2.Token) (config.Claims, error) {
client := oauth2.NewClient(s.ctx, oauth2.StaticTokenSource(token)) client := oauth2.NewClient(s.ctx, oauth2.StaticTokenSource(token))
return s.userinfoExtractor(client, s.serviceCfg.UserinfoURL) return s.userinfoExtractor(client, s.serviceCfg.UserinfoURL)
} }
+37 -39
View File
@@ -22,7 +22,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/go-jose/go-jose/v4" "github.com/go-jose/go-jose/v4"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
@@ -87,7 +87,7 @@ type UserinfoResponse struct {
EmailVerified bool `json:"email_verified,omitempty"` EmailVerified bool `json:"email_verified,omitempty"`
PhoneNumber string `json:"phone_number,omitempty"` PhoneNumber string `json:"phone_number,omitempty"`
PhoneNumberVerified *bool `json:"phone_number_verified,omitempty"` PhoneNumberVerified *bool `json:"phone_number_verified,omitempty"`
Address *model.AddressClaim `json:"address,omitempty"` Address *config.AddressClaim `json:"address,omitempty"`
UpdatedAt int64 `json:"updated_at"` UpdatedAt int64 `json:"updated_at"`
} }
@@ -112,7 +112,7 @@ type AuthorizeRequest struct {
} }
type OIDCServiceConfig struct { type OIDCServiceConfig struct {
Clients map[string]model.OIDCClientConfig Clients map[string]config.OIDCClientConfig
PrivateKeyPath string PrivateKeyPath string
PublicKeyPath string PublicKeyPath string
Issuer string Issuer string
@@ -122,7 +122,7 @@ type OIDCServiceConfig struct {
type OIDCService struct { type OIDCService struct {
config OIDCServiceConfig config OIDCServiceConfig
queries *repository.Queries queries *repository.Queries
clients map[string]model.OIDCClientConfig clients map[string]config.OIDCClientConfig
privateKey *rsa.PrivateKey privateKey *rsa.PrivateKey
publicKey crypto.PublicKey publicKey crypto.PublicKey
issuer string issuer string
@@ -255,7 +255,7 @@ func (service *OIDCService) Init() error {
} }
// We will reorganize the client into a map with the client ID as the key // We will reorganize the client into a map with the client ID as the key
service.clients = make(map[string]model.OIDCClientConfig) service.clients = make(map[string]config.OIDCClientConfig)
for id, client := range service.config.Clients { for id, client := range service.config.Clients {
client.ID = id client.ID = id
@@ -283,7 +283,7 @@ func (service *OIDCService) GetIssuer() string {
return service.issuer return service.issuer
} }
func (service *OIDCService) GetClient(id string) (model.OIDCClientConfig, bool) { func (service *OIDCService) GetClient(id string) (config.OIDCClientConfig, bool) {
client, ok := service.clients[id] client, ok := service.clients[id]
return client, ok return client, ok
} }
@@ -367,45 +367,43 @@ func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, r
return err return err
} }
func (service *OIDCService) StoreUserinfo(c *gin.Context, sub string, userContext model.UserContext, req AuthorizeRequest) error { func (service *OIDCService) StoreUserinfo(c *gin.Context, sub string, userContext config.UserContext, req AuthorizeRequest) error {
userInfoParams := repository.CreateOidcUserInfoParams{ addressJSON, err := json.Marshal(userContext.Attributes.Address)
Sub: sub,
Name: userContext.GetName(),
Email: userContext.GetEmail(),
PreferredUsername: userContext.GetUsername(),
UpdatedAt: time.Now().Unix(),
}
if userContext.IsLocal() {
addressJSON, err := json.Marshal(userContext.Local.Attributes.Address)
if err != nil { if err != nil {
return err return err
} }
userInfoParams.GivenName = userContext.Local.Attributes.GivenName
userInfoParams.FamilyName = userContext.Local.Attributes.FamilyName userInfoParams := repository.CreateOidcUserInfoParams{
userInfoParams.MiddleName = userContext.Local.Attributes.MiddleName Sub: sub,
userInfoParams.Nickname = userContext.Local.Attributes.Nickname Name: userContext.Name,
userInfoParams.Profile = userContext.Local.Attributes.Profile Email: userContext.Email,
userInfoParams.Picture = userContext.Local.Attributes.Picture PreferredUsername: userContext.Username,
userInfoParams.Website = userContext.Local.Attributes.Website UpdatedAt: time.Now().Unix(),
userInfoParams.Gender = userContext.Local.Attributes.Gender GivenName: userContext.Attributes.GivenName,
userInfoParams.Birthdate = userContext.Local.Attributes.Birthdate FamilyName: userContext.Attributes.FamilyName,
userInfoParams.Zoneinfo = userContext.Local.Attributes.Zoneinfo MiddleName: userContext.Attributes.MiddleName,
userInfoParams.Locale = userContext.Local.Attributes.Locale Nickname: userContext.Attributes.Nickname,
userInfoParams.PhoneNumber = userContext.Local.Attributes.PhoneNumber Profile: userContext.Attributes.Profile,
userInfoParams.Address = string(addressJSON) Picture: userContext.Attributes.Picture,
Website: userContext.Attributes.Website,
Gender: userContext.Attributes.Gender,
Birthdate: userContext.Attributes.Birthdate,
Zoneinfo: userContext.Attributes.Zoneinfo,
Locale: userContext.Attributes.Locale,
PhoneNumber: userContext.Attributes.PhoneNumber,
Address: string(addressJSON),
} }
// Tinyauth will pass through the groups it got from an LDAP or an OIDC server // Tinyauth will pass through the groups it got from an LDAP or an OIDC server
if userContext.IsLDAP() { if userContext.Provider == "ldap" {
userInfoParams.Groups = strings.Join(userContext.LDAP.Groups, ",") userInfoParams.Groups = userContext.LdapGroups
} }
if userContext.IsOAuth() { if userContext.OAuth && len(userContext.OAuthGroups) > 0 {
userInfoParams.Groups = strings.Join(userContext.OAuth.Groups, ",") userInfoParams.Groups = userContext.OAuthGroups
} }
_, err := service.queries.CreateOidcUserInfo(c, userInfoParams) _, err = service.queries.CreateOidcUserInfo(c, userInfoParams)
return err return err
} }
@@ -447,7 +445,7 @@ func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string, client
return oidcCode, nil return oidcCode, nil
} }
func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user repository.OidcUserinfo, scope string, nonce string) (string, error) { func (service *OIDCService) generateIDToken(client config.OIDCClientConfig, user repository.OidcUserinfo, scope string, nonce string) (string, error) {
createdAt := time.Now().Unix() createdAt := time.Now().Unix()
expiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix() expiresAt := time.Now().Add(time.Duration(service.config.SessionExpiry) * time.Second).Unix()
@@ -513,7 +511,7 @@ func (service *OIDCService) generateIDToken(client model.OIDCClientConfig, user
return token, nil return token, nil
} }
func (service *OIDCService) GenerateAccessToken(c *gin.Context, client model.OIDCClientConfig, codeEntry repository.OidcCode) (TokenResponse, error) { func (service *OIDCService) GenerateAccessToken(c *gin.Context, client config.OIDCClientConfig, codeEntry repository.OidcCode) (TokenResponse, error) {
user, err := service.GetUserinfo(c, codeEntry.Sub) user, err := service.GetUserinfo(c, codeEntry.Sub)
if err != nil { if err != nil {
@@ -587,7 +585,7 @@ func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken stri
return TokenResponse{}, err return TokenResponse{}, err
} }
idToken, err := service.generateIDToken(model.OIDCClientConfig{ idToken, err := service.generateIDToken(config.OIDCClientConfig{
ClientID: entry.ClientID, ClientID: entry.ClientID,
}, user, entry.Scope, entry.Nonce) }, user, entry.Scope, entry.Nonce)
@@ -716,7 +714,7 @@ func (service *OIDCService) CompileUserinfo(user repository.OidcUserinfo, scope
} }
if slices.Contains(scopes, "address") { if slices.Contains(scopes, "address") {
var addr model.AddressClaim var addr config.AddressClaim
if err := json.Unmarshal([]byte(user.Address), &addr); err == nil { if err := json.Unmarshal([]byte(user.Address), &addr); err == nil {
userInfo.Address = &addr userInfo.Address = &addr
} }
+18
View File
@@ -7,8 +7,10 @@ import (
"net/url" "net/url"
"strings" "strings"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/utils/tlog" "github.com/tinyauthapp/tinyauth/internal/utils/tlog"
"github.com/gin-gonic/gin"
"github.com/weppos/publicsuffix-go/publicsuffix" "github.com/weppos/publicsuffix-go/publicsuffix"
) )
@@ -71,6 +73,22 @@ func Filter[T any](slice []T, test func(T) bool) (res []T) {
return res return res
} }
func GetContext(c *gin.Context) (config.UserContext, error) {
userContextValue, exists := c.Get("context")
if !exists {
return config.UserContext{}, errors.New("no user context in request")
}
userContext, ok := userContextValue.(*config.UserContext)
if !ok {
return config.UserContext{}, errors.New("invalid user context in request")
}
return *userContext, nil
}
func IsRedirectSafe(redirectURL string, domain string) bool { func IsRedirectSafe(redirectURL string, domain string) bool {
if redirectURL == "" { if redirectURL == "" {
return false return false
+24
View File
@@ -3,8 +3,10 @@ package utils_test
import ( import (
"testing" "testing"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/gin-gonic/gin"
"gotest.tools/v3/assert" "gotest.tools/v3/assert"
) )
@@ -127,6 +129,28 @@ func TestFilter(t *testing.T) {
assert.DeepEqual(t, expectedStr, resultStr) assert.DeepEqual(t, expectedStr, resultStr)
} }
func TestGetContext(t *testing.T) {
// Setup
gin.SetMode(gin.TestMode)
c, _ := gin.CreateTestContext(nil)
// Normal case
c.Set("context", &config.UserContext{Username: "testuser"})
result, err := utils.GetContext(c)
assert.NilError(t, err)
assert.Equal(t, "testuser", result.Username)
// Case with no context
c.Set("context", nil)
_, err = utils.GetContext(c)
assert.Error(t, err, "invalid user context in request")
// Case with invalid context type
c.Set("context", "invalid type")
_, err = utils.GetContext(c)
assert.Error(t, err, "invalid user context in request")
}
func TestIsRedirectSafe(t *testing.T) { func TestIsRedirectSafe(t *testing.T) {
// Setup // Setup
domain := "example.com" domain := "example.com"
+4 -3
View File
@@ -4,20 +4,21 @@ import (
"fmt" "fmt"
"os" "os"
"github.com/tinyauthapp/tinyauth/internal/config"
"github.com/tinyauthapp/paerser/cli" "github.com/tinyauthapp/paerser/cli"
"github.com/tinyauthapp/paerser/env" "github.com/tinyauthapp/paerser/env"
"github.com/tinyauthapp/tinyauth/internal/model"
) )
type EnvLoader struct{} type EnvLoader struct{}
func (e *EnvLoader) Load(_ []string, cmd *cli.Command) (bool, error) { func (e *EnvLoader) Load(_ []string, cmd *cli.Command) (bool, error) {
vars := env.FindPrefixedEnvVars(os.Environ(), model.DefaultNamePrefix, cmd.Configuration) vars := env.FindPrefixedEnvVars(os.Environ(), config.DefaultNamePrefix, cmd.Configuration)
if len(vars) == 0 { if len(vars) == 0 {
return false, nil return false, nil
} }
if err := env.Decode(vars, model.DefaultNamePrefix, cmd.Configuration); err != nil { if err := env.Decode(vars, config.DefaultNamePrefix, cmd.Configuration); err != nil {
return false, fmt.Errorf("failed to decode configuration from environment variables: %w", err) return false, fmt.Errorf("failed to decode configuration from environment variables: %w", err)
} }
+13 -13
View File
@@ -7,7 +7,7 @@ import (
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/config"
) )
type Logger struct { type Logger struct {
@@ -22,7 +22,7 @@ var (
App zerolog.Logger App zerolog.Logger
) )
func NewLogger(cfg model.LogConfig) *Logger { func NewLogger(cfg config.LogConfig) *Logger {
baseLogger := log.With(). baseLogger := log.With().
Timestamp(). Timestamp().
Caller(). Caller().
@@ -44,24 +44,24 @@ func NewLogger(cfg model.LogConfig) *Logger {
} }
func NewSimpleLogger() *Logger { func NewSimpleLogger() *Logger {
return NewLogger(model.LogConfig{ return NewLogger(config.LogConfig{
Level: "info", Level: "info",
Json: false, Json: false,
Streams: model.LogStreams{ Streams: config.LogStreams{
HTTP: model.LogStreamConfig{Enabled: true}, HTTP: config.LogStreamConfig{Enabled: true},
App: model.LogStreamConfig{Enabled: true}, App: config.LogStreamConfig{Enabled: true},
Audit: model.LogStreamConfig{Enabled: false}, Audit: config.LogStreamConfig{Enabled: false},
}, },
}) })
} }
func NewTestLogger() *Logger { func NewTestLogger() *Logger {
return NewLogger(model.LogConfig{ return NewLogger(config.LogConfig{
Level: "trace", Level: "trace",
Streams: model.LogStreams{ Streams: config.LogStreams{
HTTP: model.LogStreamConfig{Enabled: true}, HTTP: config.LogStreamConfig{Enabled: true},
App: model.LogStreamConfig{Enabled: true}, App: config.LogStreamConfig{Enabled: true},
Audit: model.LogStreamConfig{Enabled: true}, Audit: config.LogStreamConfig{Enabled: true},
}, },
}) })
} }
@@ -72,7 +72,7 @@ func (l *Logger) Init() {
App = l.App App = l.App
} }
func createLogger(component string, streamCfg model.LogStreamConfig, baseLogger zerolog.Logger) zerolog.Logger { func createLogger(component string, streamCfg config.LogStreamConfig, baseLogger zerolog.Logger) zerolog.Logger {
if !streamCfg.Enabled { if !streamCfg.Enabled {
return zerolog.Nop() return zerolog.Nop()
} }
+16 -16
View File
@@ -6,14 +6,14 @@ import (
"net/mail" "net/mail"
"strings" "strings"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/config"
) )
func ParseUsers(usersStr []string, userAttributes map[string]model.UserAttributes) (*[]model.LocalUser, error) { func ParseUsers(usersStr []string, userAttributes map[string]config.UserAttributes) ([]config.User, error) {
var users []model.LocalUser var users []config.User
if len(usersStr) == 0 { if len(usersStr) == 0 {
return &users, nil return []config.User{}, nil
} }
for _, user := range usersStr { for _, user := range usersStr {
@@ -22,22 +22,22 @@ func ParseUsers(usersStr []string, userAttributes map[string]model.UserAttribute
} }
parsed, err := ParseUser(strings.TrimSpace(user)) parsed, err := ParseUser(strings.TrimSpace(user))
if err != nil { if err != nil {
return nil, err return []config.User{}, err
} }
if attrs, ok := userAttributes[parsed.Username]; ok { if attrs, ok := userAttributes[parsed.Username]; ok {
parsed.Attributes = attrs parsed.Attributes = attrs
} }
users = append(users, *parsed) users = append(users, parsed)
} }
return &users, nil return users, nil
} }
func GetUsers(usersCfg []string, usersPath string, userAttributes map[string]model.UserAttributes) (*[]model.LocalUser, error) { func GetUsers(usersCfg []string, usersPath string, userAttributes map[string]config.UserAttributes) ([]config.User, error) {
var usersStr []string var usersStr []string
if len(usersCfg) == 0 && usersPath == "" { if len(usersCfg) == 0 && usersPath == "" {
return &[]model.LocalUser{}, nil return []config.User{}, nil
} }
if len(usersCfg) > 0 { if len(usersCfg) > 0 {
@@ -48,7 +48,7 @@ func GetUsers(usersCfg []string, usersPath string, userAttributes map[string]mod
contents, err := ReadFile(usersPath) contents, err := ReadFile(usersPath)
if err != nil { if err != nil {
return nil, err return []config.User{}, err
} }
lines := strings.SplitSeq(contents, "\n") lines := strings.SplitSeq(contents, "\n")
@@ -65,7 +65,7 @@ func GetUsers(usersCfg []string, usersPath string, userAttributes map[string]mod
return ParseUsers(usersStr, userAttributes) return ParseUsers(usersStr, userAttributes)
} }
func ParseUser(userStr string) (*model.LocalUser, error) { func ParseUser(userStr string) (config.User, error) {
if strings.Contains(userStr, "$$") { if strings.Contains(userStr, "$$") {
userStr = strings.ReplaceAll(userStr, "$$", "$") userStr = strings.ReplaceAll(userStr, "$$", "$")
} }
@@ -73,27 +73,27 @@ func ParseUser(userStr string) (*model.LocalUser, error) {
parts := strings.SplitN(userStr, ":", 4) parts := strings.SplitN(userStr, ":", 4)
if len(parts) < 2 || len(parts) > 3 { if len(parts) < 2 || len(parts) > 3 {
return nil, errors.New("invalid user format") return config.User{}, errors.New("invalid user format")
} }
for i, part := range parts { for i, part := range parts {
trimmed := strings.TrimSpace(part) trimmed := strings.TrimSpace(part)
if trimmed == "" { if trimmed == "" {
return nil, errors.New("invalid user format") return config.User{}, errors.New("invalid user format")
} }
parts[i] = trimmed parts[i] = trimmed
} }
user := model.LocalUser{ user := config.User{
Username: parts[0], Username: parts[0],
Password: parts[1], Password: parts[1],
} }
if len(parts) == 3 { if len(parts) == 3 {
user.TOTPSecret = parts[2] user.TotpSecret = parts[2]
} }
return &user, nil return user, nil
} }
func CompileUserEmail(username string, domain string) string { func CompileUserEmail(username string, domain string) string {