mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-05-10 06:18:11 +00:00
fix: coderabbit comments
This commit is contained in:
@@ -102,7 +102,7 @@ func (app *BootstrapApp) Setup() error {
|
||||
|
||||
app.runtime.OAuthWhitelist = oauthWhitelist
|
||||
|
||||
// Setup oauth providers
|
||||
// setup oauth providers
|
||||
app.runtime.OAuthProviders = app.config.OAuth.Providers
|
||||
|
||||
for id, provider := range app.runtime.OAuthProviders {
|
||||
@@ -168,6 +168,14 @@ func (app *BootstrapApp) Setup() error {
|
||||
return fmt.Errorf("failed to setup database: %w", err)
|
||||
}
|
||||
|
||||
// after this point, we start initializing dependencies so it's a good time to setup a defer
|
||||
// to ensure that resources are cleaned up properly in case of an error during initialization
|
||||
defer func() {
|
||||
app.cancel()
|
||||
app.wg.Wait()
|
||||
app.db.Close()
|
||||
}()
|
||||
|
||||
// queries
|
||||
queries := repository.New(app.db)
|
||||
app.queries = queries
|
||||
@@ -279,9 +287,6 @@ func (app *BootstrapApp) Setup() error {
|
||||
for {
|
||||
select {
|
||||
case <-app.ctx.Done():
|
||||
app.wg.Wait()
|
||||
app.log.App.Debug().Msg("Closing database")
|
||||
app.db.Close()
|
||||
app.log.App.Info().Msg("Oh, it's time for me to go, bye!")
|
||||
return nil
|
||||
case err := <-errChan:
|
||||
@@ -305,7 +310,7 @@ func (app *BootstrapApp) serveHTTP() error {
|
||||
go func() {
|
||||
<-app.ctx.Done()
|
||||
app.log.App.Debug().Msg("Shutting down http listener")
|
||||
server.Close()
|
||||
server.Shutdown(app.ctx)
|
||||
}()
|
||||
|
||||
err := server.ListenAndServe()
|
||||
@@ -345,21 +350,23 @@ func (app *BootstrapApp) serveUnix() error {
|
||||
Handler: app.router.Handler(),
|
||||
}
|
||||
|
||||
defer server.Close()
|
||||
defer listener.Close()
|
||||
defer os.Remove(app.config.Server.SocketPath)
|
||||
shutdown := func() {
|
||||
server.Shutdown(app.ctx)
|
||||
listener.Close()
|
||||
os.Remove(app.config.Server.SocketPath)
|
||||
}
|
||||
|
||||
defer shutdown()
|
||||
|
||||
go func() {
|
||||
<-app.ctx.Done()
|
||||
app.log.App.Debug().Msg("Shutting down unix socket listener")
|
||||
server.Close()
|
||||
listener.Close()
|
||||
os.Remove(app.config.Server.SocketPath)
|
||||
shutdown()
|
||||
}()
|
||||
|
||||
err = server.Serve(listener)
|
||||
|
||||
if err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, http.ErrServerClosed) {
|
||||
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
return fmt.Errorf("failed to start unix socket listener: %w", err)
|
||||
}
|
||||
|
||||
|
||||
@@ -27,6 +27,13 @@ func (app *BootstrapApp) SetupDatabase() error {
|
||||
return fmt.Errorf("failed to open database: %w", err)
|
||||
}
|
||||
|
||||
// Close the database if there is an error during migration
|
||||
defer func() {
|
||||
if err != nil {
|
||||
db.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
// Limit to 1 connection to sequence writes, this may need to be revisited in the future
|
||||
// if the sqlite connection starts being a bottleneck
|
||||
db.SetMaxOpenConns(1)
|
||||
|
||||
@@ -43,7 +43,7 @@ func (app *BootstrapApp) setupRouter() error {
|
||||
|
||||
controller.NewContextController(app.log, app.config, app.runtime, apiRouter)
|
||||
controller.NewOAuthController(app.log, app.config, app.runtime, apiRouter, app.services.authService)
|
||||
controller.NewOIDCController(app.log, app.services.oidcService, apiRouter)
|
||||
controller.NewOIDCController(app.log, app.services.oidcService, app.runtime, apiRouter)
|
||||
controller.NewProxyController(app.log, app.runtime, apiRouter, app.services.accessControlService, app.services.authService)
|
||||
controller.NewUserController(app.log, app.runtime, apiRouter, app.services.authService)
|
||||
controller.NewResourcesController(app.config, &engine.RouterGroup)
|
||||
|
||||
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
)
|
||||
@@ -19,7 +20,7 @@ func TestContextController(t *testing.T) {
|
||||
log := logger.NewLogger().WithTestConfig()
|
||||
log.Init()
|
||||
|
||||
cfg, runtime := createTestConfigs(t)
|
||||
cfg, runtime := test.CreateTestConfigs(t)
|
||||
|
||||
tests := []struct {
|
||||
description string
|
||||
|
||||
@@ -131,7 +131,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
||||
|
||||
if err != nil {
|
||||
controller.log.App.Error().Err(err).Msg("Failed to get OAuth session cookie")
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -141,7 +141,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
||||
|
||||
if err != nil {
|
||||
controller.log.App.Error().Err(err).Msg("Failed to get pending OAuth session")
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -150,7 +150,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
||||
state := c.Query("state")
|
||||
if state != oauthPendingSession.State {
|
||||
controller.log.App.Warn().Msg("OAuth state mismatch")
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -159,15 +159,27 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
||||
|
||||
if err != nil {
|
||||
controller.log.App.Error().Err(err).Msg("Failed to exchange code for token")
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
||||
return
|
||||
}
|
||||
|
||||
user, err := controller.auth.GetOAuthUserinfo(sessionIdCookie)
|
||||
|
||||
if err != nil {
|
||||
controller.log.App.Error().Err(err).Msg("Failed to get user info from OAuth provider")
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
||||
return
|
||||
}
|
||||
|
||||
if user == nil {
|
||||
controller.log.App.Warn().Msg("OAuth provider did not return user info")
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
||||
return
|
||||
}
|
||||
|
||||
if user.Email == "" {
|
||||
controller.log.App.Warn().Msg("OAuth provider did not return an email")
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -181,11 +193,11 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
||||
|
||||
if err != nil {
|
||||
controller.log.App.Error().Err(err).Msg("Failed to encode unauthorized query")
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
||||
return
|
||||
}
|
||||
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode()))
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.runtime.AppURL, queries.Encode()))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -213,13 +225,13 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
||||
|
||||
if err != nil {
|
||||
controller.log.App.Error().Err(err).Msg("Failed to get OAuth service for session")
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
||||
return
|
||||
}
|
||||
|
||||
if svc.ID() != req.Provider {
|
||||
controller.log.App.Warn().Msgf("OAuth provider mismatch: expected %s, got %s", req.Provider, svc.ID())
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -239,7 +251,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
||||
|
||||
if err != nil {
|
||||
controller.log.App.Error().Err(err).Msg("Failed to create session cookie")
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -252,10 +264,10 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
||||
queries, err := query.Values(oauthPendingSession.CallbackParams)
|
||||
if err != nil {
|
||||
controller.log.App.Error().Err(err).Msg("Failed to encode OIDC callback query")
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
||||
return
|
||||
}
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/authorize?%s", controller.config.AppURL, queries.Encode()))
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/authorize?%s", controller.runtime.AppURL, queries.Encode()))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -266,15 +278,15 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
||||
|
||||
if err != nil {
|
||||
controller.log.App.Error().Err(err).Msg("Failed to encode redirect query")
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
||||
return
|
||||
}
|
||||
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/continue?%s", controller.config.AppURL, queries.Encode()))
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/continue?%s", controller.runtime.AppURL, queries.Encode()))
|
||||
return
|
||||
}
|
||||
|
||||
c.Redirect(http.StatusTemporaryRedirect, controller.config.AppURL)
|
||||
c.Redirect(http.StatusTemporaryRedirect, controller.runtime.AppURL)
|
||||
}
|
||||
|
||||
func (controller *OAuthController) isOidcRequest(params service.OAuthURLParams) bool {
|
||||
|
||||
@@ -17,8 +17,9 @@ import (
|
||||
)
|
||||
|
||||
type OIDCController struct {
|
||||
log *logger.Logger
|
||||
oidc *service.OIDCService
|
||||
log *logger.Logger
|
||||
oidc *service.OIDCService
|
||||
runtime model.RuntimeConfig
|
||||
}
|
||||
|
||||
type AuthorizeCallback struct {
|
||||
@@ -58,10 +59,12 @@ type ClientCredentials struct {
|
||||
func NewOIDCController(
|
||||
log *logger.Logger,
|
||||
oidcService *service.OIDCService,
|
||||
runtimeConfig model.RuntimeConfig,
|
||||
router *gin.RouterGroup) *OIDCController {
|
||||
controller := &OIDCController{
|
||||
log: log,
|
||||
oidc: oidcService,
|
||||
log: log,
|
||||
oidc: oidcService,
|
||||
runtime: runtimeConfig,
|
||||
}
|
||||
|
||||
oidcGroup := router.Group("/oidc")
|
||||
@@ -75,6 +78,15 @@ func NewOIDCController(
|
||||
}
|
||||
|
||||
func (controller *OIDCController) GetClientInfo(c *gin.Context) {
|
||||
if controller.oidc == nil {
|
||||
controller.log.App.Warn().Msg("Received OIDC client info request but OIDC server is not configured")
|
||||
c.JSON(500, gin.H{
|
||||
"status": 500,
|
||||
"message": "OIDC not configured",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var req ClientRequest
|
||||
|
||||
err := c.BindUri(&req)
|
||||
@@ -198,8 +210,8 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
|
||||
func (controller *OIDCController) Token(c *gin.Context) {
|
||||
if controller.oidc == nil {
|
||||
controller.log.App.Warn().Msg("Received OIDC request but OIDC server is not configured")
|
||||
c.JSON(404, gin.H{
|
||||
"error": "not_found",
|
||||
c.JSON(500, gin.H{
|
||||
"error": "server_error",
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -374,8 +386,8 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
||||
func (controller *OIDCController) Userinfo(c *gin.Context) {
|
||||
if controller.oidc == nil {
|
||||
controller.log.App.Warn().Msg("Received OIDC userinfo request but OIDC server is not configured")
|
||||
c.JSON(404, gin.H{
|
||||
"error": "not_found",
|
||||
c.JSON(500, gin.H{
|
||||
"error": "server_error",
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -507,8 +519,16 @@ func (controller *OIDCController) authorizeError(c *gin.Context, err error, reas
|
||||
return
|
||||
}
|
||||
|
||||
redirectUrl := ""
|
||||
|
||||
if controller.oidc != nil {
|
||||
redirectUrl = fmt.Sprintf("%s/error?%s", controller.oidc.GetIssuer(), queries.Encode())
|
||||
} else {
|
||||
redirectUrl = fmt.Sprintf("%s/error?%s", controller.runtime.AppURL, queries.Encode())
|
||||
}
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
"status": 200,
|
||||
"redirect_uri": fmt.Sprintf("%s/error?%s", controller.oidc.GetIssuer(), queries.Encode()),
|
||||
"redirect_uri": redirectUrl,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
)
|
||||
|
||||
@@ -27,7 +28,7 @@ func TestOIDCController(t *testing.T) {
|
||||
log := logger.NewLogger().WithTestConfig()
|
||||
log.Init()
|
||||
|
||||
cfg, runtime := createTestConfigs(t)
|
||||
cfg, runtime := test.CreateTestConfigs(t)
|
||||
|
||||
simpleCtx := func(c *gin.Context) {
|
||||
c.Set("context", &model.UserContext{
|
||||
@@ -861,7 +862,7 @@ func TestOIDCController(t *testing.T) {
|
||||
group := router.Group("/api")
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
controller.NewOIDCController(log, oidcService, group)
|
||||
controller.NewOIDCController(log, oidcService, runtime, group)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
)
|
||||
|
||||
@@ -21,7 +22,7 @@ func TestProxyController(t *testing.T) {
|
||||
log := logger.NewLogger().WithTestConfig()
|
||||
log.Init()
|
||||
|
||||
cfg, runtime := createTestConfigs(t)
|
||||
cfg, runtime := test.CreateTestConfigs(t)
|
||||
|
||||
acls := map[string]model.App{
|
||||
"app_path_allow": {
|
||||
|
||||
@@ -10,10 +10,11 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||
)
|
||||
|
||||
func TestResourcesController(t *testing.T) {
|
||||
cfg, _ := createTestConfigs(t)
|
||||
cfg, _ := test.CreateTestConfigs(t)
|
||||
|
||||
err := os.MkdirAll(cfg.Resources.Path, 0777)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
)
|
||||
|
||||
@@ -26,7 +27,7 @@ func TestUserController(t *testing.T) {
|
||||
log := logger.NewLogger().WithTestConfig()
|
||||
log.Init()
|
||||
|
||||
cfg, runtime := createTestConfigs(t)
|
||||
cfg, runtime := test.CreateTestConfigs(t)
|
||||
|
||||
totpCtx := func(c *gin.Context) {
|
||||
c.Set("context", &model.UserContext{
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
)
|
||||
|
||||
@@ -22,7 +23,7 @@ func TestWellKnownController(t *testing.T) {
|
||||
log := logger.NewLogger().WithTestConfig()
|
||||
log.Init()
|
||||
|
||||
cfg, runtime := createTestConfigs(t)
|
||||
cfg, runtime := test.CreateTestConfigs(t)
|
||||
|
||||
type testCase struct {
|
||||
description string
|
||||
@@ -99,6 +100,7 @@ func TestWellKnownController(t *testing.T) {
|
||||
queries := repository.New(app.GetDB())
|
||||
|
||||
oidcService, err := service.NewOIDCService(log, cfg, runtime, queries, ctx, wg)
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.description, func(t *testing.T) {
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
)
|
||||
|
||||
@@ -24,7 +25,7 @@ func TestContextMiddleware(t *testing.T) {
|
||||
log := logger.NewLogger().WithTestConfig()
|
||||
log.Init()
|
||||
|
||||
cfg, runtime := createTestConfigs(t)
|
||||
cfg, runtime := test.CreateTestConfigs(t)
|
||||
|
||||
basicAuthHeader := func(username, password string) string {
|
||||
return "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password))
|
||||
|
||||
@@ -1,108 +0,0 @@
|
||||
package middleware_test
|
||||
|
||||
import (
|
||||
"path"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// Note: This code is duplicated from controller_test.go
|
||||
|
||||
var testingTOTPSecret = "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK"
|
||||
|
||||
func createTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
config := model.Config{
|
||||
UI: model.UIConfig{
|
||||
Title: "Tinyauth Test",
|
||||
ForgotPasswordMessage: "foo",
|
||||
BackgroundImage: "/background.jpg",
|
||||
WarningsEnabled: true,
|
||||
},
|
||||
OAuth: model.OAuthConfig{
|
||||
AutoRedirect: "none",
|
||||
},
|
||||
OIDC: model.OIDCConfig{
|
||||
Clients: map[string]model.OIDCClientConfig{
|
||||
"test": {
|
||||
ClientID: "some-client-id",
|
||||
ClientSecret: "some-client-secret",
|
||||
TrustedRedirectURIs: []string{"https://test.example.com/callback"},
|
||||
Name: "Test Client",
|
||||
},
|
||||
},
|
||||
PrivateKeyPath: path.Join(tempDir, "key.pem"),
|
||||
PublicKeyPath: path.Join(tempDir, "key.pub"),
|
||||
},
|
||||
Auth: model.AuthConfig{
|
||||
SessionExpiry: 10,
|
||||
LoginTimeout: 10,
|
||||
LoginMaxRetries: 3,
|
||||
},
|
||||
Database: model.DatabaseConfig{
|
||||
Path: path.Join(tempDir, "test.db"),
|
||||
},
|
||||
Resources: model.ResourcesConfig{
|
||||
Enabled: true,
|
||||
Path: path.Join(tempDir, "resources"),
|
||||
},
|
||||
}
|
||||
|
||||
passwd, err := bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost)
|
||||
require.NoError(t, err)
|
||||
|
||||
runtime := model.RuntimeConfig{
|
||||
ConfiguredProviders: []model.Provider{
|
||||
{
|
||||
Name: "Local",
|
||||
ID: "local",
|
||||
OAuth: false,
|
||||
},
|
||||
},
|
||||
LocalUsers: []model.LocalUser{
|
||||
{
|
||||
Username: "testuser",
|
||||
Password: string(passwd),
|
||||
},
|
||||
{
|
||||
Username: "totpuser",
|
||||
Password: string(passwd),
|
||||
TOTPSecret: testingTOTPSecret,
|
||||
},
|
||||
{
|
||||
Username: "attruser",
|
||||
Password: string(passwd),
|
||||
Attributes: model.UserAttributes{
|
||||
Name: "Alice Smith",
|
||||
Email: "alice@example.com",
|
||||
},
|
||||
},
|
||||
{
|
||||
Username: "attrtotpuser",
|
||||
Password: string(passwd),
|
||||
TOTPSecret: testingTOTPSecret,
|
||||
Attributes: model.UserAttributes{
|
||||
Name: "Bob Jones",
|
||||
Email: "bob@example.com",
|
||||
},
|
||||
},
|
||||
},
|
||||
CookieDomain: "example.com",
|
||||
AppURL: "https://tinyauth.example.com",
|
||||
SessionCookieName: "tinyauth-session",
|
||||
OIDCClients: func() []model.OIDCClientConfig {
|
||||
var clients []model.OIDCClientConfig
|
||||
for id, client := range config.OIDC.Clients {
|
||||
client.ID = id
|
||||
clients = append(clients, client)
|
||||
}
|
||||
return clients
|
||||
}(),
|
||||
}
|
||||
|
||||
return config, runtime
|
||||
}
|
||||
@@ -790,10 +790,8 @@ func (service *OIDCService) cleanupRoutine() {
|
||||
token, err := service.queries.GetOidcTokenBySub(service.context, expiredCode.Sub)
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
continue
|
||||
}
|
||||
service.log.App.Warn().Err(err).Msg("Failed to get token by sub for expired code")
|
||||
continue
|
||||
}
|
||||
|
||||
if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
package controller_test
|
||||
package test
|
||||
|
||||
import (
|
||||
"path"
|
||||
@@ -9,9 +9,9 @@ import (
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
var testingTOTPSecret = "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK"
|
||||
var TestingTOTPSecret = "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK"
|
||||
|
||||
func createTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
|
||||
func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
config := model.Config{
|
||||
@@ -69,7 +69,7 @@ func createTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
|
||||
{
|
||||
Username: "totpuser",
|
||||
Password: string(passwd),
|
||||
TOTPSecret: testingTOTPSecret,
|
||||
TOTPSecret: TestingTOTPSecret,
|
||||
},
|
||||
{
|
||||
Username: "attruser",
|
||||
@@ -82,7 +82,7 @@ func createTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
|
||||
{
|
||||
Username: "attrtotpuser",
|
||||
Password: string(passwd),
|
||||
TOTPSecret: testingTOTPSecret,
|
||||
TOTPSecret: TestingTOTPSecret,
|
||||
Attributes: model.UserAttributes{
|
||||
Name: "Bob Jones",
|
||||
Email: "bob@example.com",
|
||||
@@ -33,7 +33,7 @@ func NewLogger() *Logger {
|
||||
App: model.LogStreamConfig{
|
||||
Enabled: true,
|
||||
},
|
||||
// No reason to enabled audit by default since it will be suppressed by the log level
|
||||
// No reason to enable audit by default since it will be suppressed by the log level
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -162,7 +162,7 @@ func TestLogger(t *testing.T) {
|
||||
l.AuditLoginFailure("test_nop", "test_nop", "test_nop", "test_nop")
|
||||
|
||||
assert.NotEmpty(t, buf.String())
|
||||
assert.NotContains(t, "test_nop", buf.String())
|
||||
assert.NotContains(t, buf.String(), "test_nop")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user