fix: coderabbit comments

This commit is contained in:
Stavros
2026-05-09 17:00:02 +03:00
parent 548d97fa62
commit d5009070e3
17 changed files with 107 additions and 163 deletions
+19 -12
View File
@@ -102,7 +102,7 @@ func (app *BootstrapApp) Setup() error {
app.runtime.OAuthWhitelist = oauthWhitelist app.runtime.OAuthWhitelist = oauthWhitelist
// Setup oauth providers // setup oauth providers
app.runtime.OAuthProviders = app.config.OAuth.Providers app.runtime.OAuthProviders = app.config.OAuth.Providers
for id, provider := range app.runtime.OAuthProviders { for id, provider := range app.runtime.OAuthProviders {
@@ -168,6 +168,14 @@ func (app *BootstrapApp) Setup() error {
return fmt.Errorf("failed to setup database: %w", err) return fmt.Errorf("failed to setup database: %w", err)
} }
// after this point, we start initializing dependencies so it's a good time to setup a defer
// to ensure that resources are cleaned up properly in case of an error during initialization
defer func() {
app.cancel()
app.wg.Wait()
app.db.Close()
}()
// queries // queries
queries := repository.New(app.db) queries := repository.New(app.db)
app.queries = queries app.queries = queries
@@ -279,9 +287,6 @@ func (app *BootstrapApp) Setup() error {
for { for {
select { select {
case <-app.ctx.Done(): case <-app.ctx.Done():
app.wg.Wait()
app.log.App.Debug().Msg("Closing database")
app.db.Close()
app.log.App.Info().Msg("Oh, it's time for me to go, bye!") app.log.App.Info().Msg("Oh, it's time for me to go, bye!")
return nil return nil
case err := <-errChan: case err := <-errChan:
@@ -305,7 +310,7 @@ func (app *BootstrapApp) serveHTTP() error {
go func() { go func() {
<-app.ctx.Done() <-app.ctx.Done()
app.log.App.Debug().Msg("Shutting down http listener") app.log.App.Debug().Msg("Shutting down http listener")
server.Close() server.Shutdown(app.ctx)
}() }()
err := server.ListenAndServe() err := server.ListenAndServe()
@@ -345,21 +350,23 @@ func (app *BootstrapApp) serveUnix() error {
Handler: app.router.Handler(), Handler: app.router.Handler(),
} }
defer server.Close() shutdown := func() {
defer listener.Close() server.Shutdown(app.ctx)
defer os.Remove(app.config.Server.SocketPath) listener.Close()
os.Remove(app.config.Server.SocketPath)
}
defer shutdown()
go func() { go func() {
<-app.ctx.Done() <-app.ctx.Done()
app.log.App.Debug().Msg("Shutting down unix socket listener") app.log.App.Debug().Msg("Shutting down unix socket listener")
server.Close() shutdown()
listener.Close()
os.Remove(app.config.Server.SocketPath)
}() }()
err = server.Serve(listener) err = server.Serve(listener)
if err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, http.ErrServerClosed) { if err != nil && !errors.Is(err, http.ErrServerClosed) {
return fmt.Errorf("failed to start unix socket listener: %w", err) return fmt.Errorf("failed to start unix socket listener: %w", err)
} }
+7
View File
@@ -27,6 +27,13 @@ func (app *BootstrapApp) SetupDatabase() error {
return fmt.Errorf("failed to open database: %w", err) return fmt.Errorf("failed to open database: %w", err)
} }
// Close the database if there is an error during migration
defer func() {
if err != nil {
db.Close()
}
}()
// Limit to 1 connection to sequence writes, this may need to be revisited in the future // Limit to 1 connection to sequence writes, this may need to be revisited in the future
// if the sqlite connection starts being a bottleneck // if the sqlite connection starts being a bottleneck
db.SetMaxOpenConns(1) db.SetMaxOpenConns(1)
+1 -1
View File
@@ -43,7 +43,7 @@ func (app *BootstrapApp) setupRouter() error {
controller.NewContextController(app.log, app.config, app.runtime, apiRouter) controller.NewContextController(app.log, app.config, app.runtime, apiRouter)
controller.NewOAuthController(app.log, app.config, app.runtime, apiRouter, app.services.authService) controller.NewOAuthController(app.log, app.config, app.runtime, apiRouter, app.services.authService)
controller.NewOIDCController(app.log, app.services.oidcService, apiRouter) controller.NewOIDCController(app.log, app.services.oidcService, app.runtime, apiRouter)
controller.NewProxyController(app.log, app.runtime, apiRouter, app.services.accessControlService, app.services.authService) controller.NewProxyController(app.log, app.runtime, apiRouter, app.services.accessControlService, app.services.authService)
controller.NewUserController(app.log, app.runtime, apiRouter, app.services.authService) controller.NewUserController(app.log, app.runtime, apiRouter, app.services.authService)
controller.NewResourcesController(app.config, &engine.RouterGroup) controller.NewResourcesController(app.config, &engine.RouterGroup)
@@ -11,6 +11,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/test"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
) )
@@ -19,7 +20,7 @@ func TestContextController(t *testing.T) {
log := logger.NewLogger().WithTestConfig() log := logger.NewLogger().WithTestConfig()
log.Init() log.Init()
cfg, runtime := createTestConfigs(t) cfg, runtime := test.CreateTestConfigs(t)
tests := []struct { tests := []struct {
description string description string
+27 -15
View File
@@ -131,7 +131,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to get OAuth session cookie") controller.log.App.Error().Err(err).Msg("Failed to get OAuth session cookie")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return return
} }
@@ -141,7 +141,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to get pending OAuth session") controller.log.App.Error().Err(err).Msg("Failed to get pending OAuth session")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return return
} }
@@ -150,7 +150,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
state := c.Query("state") state := c.Query("state")
if state != oauthPendingSession.State { if state != oauthPendingSession.State {
controller.log.App.Warn().Msg("OAuth state mismatch") controller.log.App.Warn().Msg("OAuth state mismatch")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return return
} }
@@ -159,15 +159,27 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to exchange code for token") controller.log.App.Error().Err(err).Msg("Failed to exchange code for token")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return return
} }
user, err := controller.auth.GetOAuthUserinfo(sessionIdCookie) user, err := controller.auth.GetOAuthUserinfo(sessionIdCookie)
if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to get user info from OAuth provider")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return
}
if user == nil {
controller.log.App.Warn().Msg("OAuth provider did not return user info")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return
}
if user.Email == "" { if user.Email == "" {
controller.log.App.Warn().Msg("OAuth provider did not return an email") controller.log.App.Warn().Msg("OAuth provider did not return an email")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return return
} }
@@ -181,11 +193,11 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to encode unauthorized query") controller.log.App.Error().Err(err).Msg("Failed to encode unauthorized query")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return return
} }
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode())) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.runtime.AppURL, queries.Encode()))
return return
} }
@@ -213,13 +225,13 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to get OAuth service for session") controller.log.App.Error().Err(err).Msg("Failed to get OAuth service for session")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return return
} }
if svc.ID() != req.Provider { if svc.ID() != req.Provider {
controller.log.App.Warn().Msgf("OAuth provider mismatch: expected %s, got %s", req.Provider, svc.ID()) controller.log.App.Warn().Msgf("OAuth provider mismatch: expected %s, got %s", req.Provider, svc.ID())
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return return
} }
@@ -239,7 +251,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to create session cookie") controller.log.App.Error().Err(err).Msg("Failed to create session cookie")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return return
} }
@@ -252,10 +264,10 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
queries, err := query.Values(oauthPendingSession.CallbackParams) queries, err := query.Values(oauthPendingSession.CallbackParams)
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to encode OIDC callback query") controller.log.App.Error().Err(err).Msg("Failed to encode OIDC callback query")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return return
} }
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/authorize?%s", controller.config.AppURL, queries.Encode())) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/authorize?%s", controller.runtime.AppURL, queries.Encode()))
return return
} }
@@ -266,15 +278,15 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
if err != nil { if err != nil {
controller.log.App.Error().Err(err).Msg("Failed to encode redirect query") controller.log.App.Error().Err(err).Msg("Failed to encode redirect query")
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL)) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
return return
} }
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/continue?%s", controller.config.AppURL, queries.Encode())) c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/continue?%s", controller.runtime.AppURL, queries.Encode()))
return return
} }
c.Redirect(http.StatusTemporaryRedirect, controller.config.AppURL) c.Redirect(http.StatusTemporaryRedirect, controller.runtime.AppURL)
} }
func (controller *OAuthController) isOidcRequest(params service.OAuthURLParams) bool { func (controller *OAuthController) isOidcRequest(params service.OAuthURLParams) bool {
+29 -9
View File
@@ -17,8 +17,9 @@ import (
) )
type OIDCController struct { type OIDCController struct {
log *logger.Logger log *logger.Logger
oidc *service.OIDCService oidc *service.OIDCService
runtime model.RuntimeConfig
} }
type AuthorizeCallback struct { type AuthorizeCallback struct {
@@ -58,10 +59,12 @@ type ClientCredentials struct {
func NewOIDCController( func NewOIDCController(
log *logger.Logger, log *logger.Logger,
oidcService *service.OIDCService, oidcService *service.OIDCService,
runtimeConfig model.RuntimeConfig,
router *gin.RouterGroup) *OIDCController { router *gin.RouterGroup) *OIDCController {
controller := &OIDCController{ controller := &OIDCController{
log: log, log: log,
oidc: oidcService, oidc: oidcService,
runtime: runtimeConfig,
} }
oidcGroup := router.Group("/oidc") oidcGroup := router.Group("/oidc")
@@ -75,6 +78,15 @@ func NewOIDCController(
} }
func (controller *OIDCController) GetClientInfo(c *gin.Context) { func (controller *OIDCController) GetClientInfo(c *gin.Context) {
if controller.oidc == nil {
controller.log.App.Warn().Msg("Received OIDC client info request but OIDC server is not configured")
c.JSON(500, gin.H{
"status": 500,
"message": "OIDC not configured",
})
return
}
var req ClientRequest var req ClientRequest
err := c.BindUri(&req) err := c.BindUri(&req)
@@ -198,8 +210,8 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
func (controller *OIDCController) Token(c *gin.Context) { func (controller *OIDCController) Token(c *gin.Context) {
if controller.oidc == nil { if controller.oidc == nil {
controller.log.App.Warn().Msg("Received OIDC request but OIDC server is not configured") controller.log.App.Warn().Msg("Received OIDC request but OIDC server is not configured")
c.JSON(404, gin.H{ c.JSON(500, gin.H{
"error": "not_found", "error": "server_error",
}) })
return return
} }
@@ -374,8 +386,8 @@ func (controller *OIDCController) Token(c *gin.Context) {
func (controller *OIDCController) Userinfo(c *gin.Context) { func (controller *OIDCController) Userinfo(c *gin.Context) {
if controller.oidc == nil { if controller.oidc == nil {
controller.log.App.Warn().Msg("Received OIDC userinfo request but OIDC server is not configured") controller.log.App.Warn().Msg("Received OIDC userinfo request but OIDC server is not configured")
c.JSON(404, gin.H{ c.JSON(500, gin.H{
"error": "not_found", "error": "server_error",
}) })
return return
} }
@@ -507,8 +519,16 @@ func (controller *OIDCController) authorizeError(c *gin.Context, err error, reas
return return
} }
redirectUrl := ""
if controller.oidc != nil {
redirectUrl = fmt.Sprintf("%s/error?%s", controller.oidc.GetIssuer(), queries.Encode())
} else {
redirectUrl = fmt.Sprintf("%s/error?%s", controller.runtime.AppURL, queries.Encode())
}
c.JSON(200, gin.H{ c.JSON(200, gin.H{
"status": 200, "status": 200,
"redirect_uri": fmt.Sprintf("%s/error?%s", controller.oidc.GetIssuer(), queries.Encode()), "redirect_uri": redirectUrl,
}) })
} }
+3 -2
View File
@@ -20,6 +20,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/test"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
) )
@@ -27,7 +28,7 @@ func TestOIDCController(t *testing.T) {
log := logger.NewLogger().WithTestConfig() log := logger.NewLogger().WithTestConfig()
log.Init() log.Init()
cfg, runtime := createTestConfigs(t) cfg, runtime := test.CreateTestConfigs(t)
simpleCtx := func(c *gin.Context) { simpleCtx := func(c *gin.Context) {
c.Set("context", &model.UserContext{ c.Set("context", &model.UserContext{
@@ -861,7 +862,7 @@ func TestOIDCController(t *testing.T) {
group := router.Group("/api") group := router.Group("/api")
gin.SetMode(gin.TestMode) gin.SetMode(gin.TestMode)
controller.NewOIDCController(log, oidcService, group) controller.NewOIDCController(log, oidcService, runtime, group)
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
+2 -1
View File
@@ -14,6 +14,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/test"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
) )
@@ -21,7 +22,7 @@ func TestProxyController(t *testing.T) {
log := logger.NewLogger().WithTestConfig() log := logger.NewLogger().WithTestConfig()
log.Init() log.Init()
cfg, runtime := createTestConfigs(t) cfg, runtime := test.CreateTestConfigs(t)
acls := map[string]model.App{ acls := map[string]model.App{
"app_path_allow": { "app_path_allow": {
@@ -10,10 +10,11 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/test"
) )
func TestResourcesController(t *testing.T) { func TestResourcesController(t *testing.T) {
cfg, _ := createTestConfigs(t) cfg, _ := test.CreateTestConfigs(t)
err := os.MkdirAll(cfg.Resources.Path, 0777) err := os.MkdirAll(cfg.Resources.Path, 0777)
require.NoError(t, err) require.NoError(t, err)
+2 -1
View File
@@ -19,6 +19,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/test"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
) )
@@ -26,7 +27,7 @@ func TestUserController(t *testing.T) {
log := logger.NewLogger().WithTestConfig() log := logger.NewLogger().WithTestConfig()
log.Init() log.Init()
cfg, runtime := createTestConfigs(t) cfg, runtime := test.CreateTestConfigs(t)
totpCtx := func(c *gin.Context) { totpCtx := func(c *gin.Context) {
c.Set("context", &model.UserContext{ c.Set("context", &model.UserContext{
@@ -15,6 +15,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/test"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
) )
@@ -22,7 +23,7 @@ func TestWellKnownController(t *testing.T) {
log := logger.NewLogger().WithTestConfig() log := logger.NewLogger().WithTestConfig()
log.Init() log.Init()
cfg, runtime := createTestConfigs(t) cfg, runtime := test.CreateTestConfigs(t)
type testCase struct { type testCase struct {
description string description string
@@ -99,6 +100,7 @@ func TestWellKnownController(t *testing.T) {
queries := repository.New(app.GetDB()) queries := repository.New(app.GetDB())
oidcService, err := service.NewOIDCService(log, cfg, runtime, queries, ctx, wg) oidcService, err := service.NewOIDCService(log, cfg, runtime, queries, ctx, wg)
require.NoError(t, err)
for _, test := range tests { for _, test := range tests {
t.Run(test.description, func(t *testing.T) { t.Run(test.description, func(t *testing.T) {
@@ -17,6 +17,7 @@ import (
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/repository" "github.com/tinyauthapp/tinyauth/internal/repository"
"github.com/tinyauthapp/tinyauth/internal/service" "github.com/tinyauthapp/tinyauth/internal/service"
"github.com/tinyauthapp/tinyauth/internal/test"
"github.com/tinyauthapp/tinyauth/internal/utils/logger" "github.com/tinyauthapp/tinyauth/internal/utils/logger"
) )
@@ -24,7 +25,7 @@ func TestContextMiddleware(t *testing.T) {
log := logger.NewLogger().WithTestConfig() log := logger.NewLogger().WithTestConfig()
log.Init() log.Init()
cfg, runtime := createTestConfigs(t) cfg, runtime := test.CreateTestConfigs(t)
basicAuthHeader := func(username, password string) string { basicAuthHeader := func(username, password string) string {
return "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password)) return "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password))
-108
View File
@@ -1,108 +0,0 @@
package middleware_test
import (
"path"
"testing"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/model"
"golang.org/x/crypto/bcrypt"
)
// Note: This code is duplicated from controller_test.go
var testingTOTPSecret = "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK"
func createTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
tempDir := t.TempDir()
config := model.Config{
UI: model.UIConfig{
Title: "Tinyauth Test",
ForgotPasswordMessage: "foo",
BackgroundImage: "/background.jpg",
WarningsEnabled: true,
},
OAuth: model.OAuthConfig{
AutoRedirect: "none",
},
OIDC: model.OIDCConfig{
Clients: map[string]model.OIDCClientConfig{
"test": {
ClientID: "some-client-id",
ClientSecret: "some-client-secret",
TrustedRedirectURIs: []string{"https://test.example.com/callback"},
Name: "Test Client",
},
},
PrivateKeyPath: path.Join(tempDir, "key.pem"),
PublicKeyPath: path.Join(tempDir, "key.pub"),
},
Auth: model.AuthConfig{
SessionExpiry: 10,
LoginTimeout: 10,
LoginMaxRetries: 3,
},
Database: model.DatabaseConfig{
Path: path.Join(tempDir, "test.db"),
},
Resources: model.ResourcesConfig{
Enabled: true,
Path: path.Join(tempDir, "resources"),
},
}
passwd, err := bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost)
require.NoError(t, err)
runtime := model.RuntimeConfig{
ConfiguredProviders: []model.Provider{
{
Name: "Local",
ID: "local",
OAuth: false,
},
},
LocalUsers: []model.LocalUser{
{
Username: "testuser",
Password: string(passwd),
},
{
Username: "totpuser",
Password: string(passwd),
TOTPSecret: testingTOTPSecret,
},
{
Username: "attruser",
Password: string(passwd),
Attributes: model.UserAttributes{
Name: "Alice Smith",
Email: "alice@example.com",
},
},
{
Username: "attrtotpuser",
Password: string(passwd),
TOTPSecret: testingTOTPSecret,
Attributes: model.UserAttributes{
Name: "Bob Jones",
Email: "bob@example.com",
},
},
},
CookieDomain: "example.com",
AppURL: "https://tinyauth.example.com",
SessionCookieName: "tinyauth-session",
OIDCClients: func() []model.OIDCClientConfig {
var clients []model.OIDCClientConfig
for id, client := range config.OIDC.Clients {
client.ID = id
clients = append(clients, client)
}
return clients
}(),
}
return config, runtime
}
+1 -3
View File
@@ -790,10 +790,8 @@ func (service *OIDCService) cleanupRoutine() {
token, err := service.queries.GetOidcTokenBySub(service.context, expiredCode.Sub) token, err := service.queries.GetOidcTokenBySub(service.context, expiredCode.Sub)
if err != nil { if err != nil {
if errors.Is(err, sql.ErrNoRows) {
continue
}
service.log.App.Warn().Err(err).Msg("Failed to get token by sub for expired code") service.log.App.Warn().Err(err).Msg("Failed to get token by sub for expired code")
continue
} }
if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime { if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime {
@@ -1,4 +1,4 @@
package controller_test package test
import ( import (
"path" "path"
@@ -9,9 +9,9 @@ import (
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
) )
var testingTOTPSecret = "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK" var TestingTOTPSecret = "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK"
func createTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) { func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
tempDir := t.TempDir() tempDir := t.TempDir()
config := model.Config{ config := model.Config{
@@ -69,7 +69,7 @@ func createTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
{ {
Username: "totpuser", Username: "totpuser",
Password: string(passwd), Password: string(passwd),
TOTPSecret: testingTOTPSecret, TOTPSecret: TestingTOTPSecret,
}, },
{ {
Username: "attruser", Username: "attruser",
@@ -82,7 +82,7 @@ func createTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
{ {
Username: "attrtotpuser", Username: "attrtotpuser",
Password: string(passwd), Password: string(passwd),
TOTPSecret: testingTOTPSecret, TOTPSecret: TestingTOTPSecret,
Attributes: model.UserAttributes{ Attributes: model.UserAttributes{
Name: "Bob Jones", Name: "Bob Jones",
Email: "bob@example.com", Email: "bob@example.com",
+1 -1
View File
@@ -33,7 +33,7 @@ func NewLogger() *Logger {
App: model.LogStreamConfig{ App: model.LogStreamConfig{
Enabled: true, Enabled: true,
}, },
// No reason to enabled audit by default since it will be suppressed by the log level // No reason to enable audit by default since it will be suppressed by the log level
}, },
}, },
} }
+1 -1
View File
@@ -162,7 +162,7 @@ func TestLogger(t *testing.T) {
l.AuditLoginFailure("test_nop", "test_nop", "test_nop", "test_nop") l.AuditLoginFailure("test_nop", "test_nop", "test_nop", "test_nop")
assert.NotEmpty(t, buf.String()) assert.NotEmpty(t, buf.String())
assert.NotContains(t, "test_nop", buf.String()) assert.NotContains(t, buf.String(), "test_nop")
}, },
}, },
} }