mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-05-10 14:28:12 +00:00
fix: coderabbit comments
This commit is contained in:
@@ -102,7 +102,7 @@ func (app *BootstrapApp) Setup() error {
|
|||||||
|
|
||||||
app.runtime.OAuthWhitelist = oauthWhitelist
|
app.runtime.OAuthWhitelist = oauthWhitelist
|
||||||
|
|
||||||
// Setup oauth providers
|
// setup oauth providers
|
||||||
app.runtime.OAuthProviders = app.config.OAuth.Providers
|
app.runtime.OAuthProviders = app.config.OAuth.Providers
|
||||||
|
|
||||||
for id, provider := range app.runtime.OAuthProviders {
|
for id, provider := range app.runtime.OAuthProviders {
|
||||||
@@ -168,6 +168,14 @@ func (app *BootstrapApp) Setup() error {
|
|||||||
return fmt.Errorf("failed to setup database: %w", err)
|
return fmt.Errorf("failed to setup database: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// after this point, we start initializing dependencies so it's a good time to setup a defer
|
||||||
|
// to ensure that resources are cleaned up properly in case of an error during initialization
|
||||||
|
defer func() {
|
||||||
|
app.cancel()
|
||||||
|
app.wg.Wait()
|
||||||
|
app.db.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
// queries
|
// queries
|
||||||
queries := repository.New(app.db)
|
queries := repository.New(app.db)
|
||||||
app.queries = queries
|
app.queries = queries
|
||||||
@@ -279,9 +287,6 @@ func (app *BootstrapApp) Setup() error {
|
|||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-app.ctx.Done():
|
case <-app.ctx.Done():
|
||||||
app.wg.Wait()
|
|
||||||
app.log.App.Debug().Msg("Closing database")
|
|
||||||
app.db.Close()
|
|
||||||
app.log.App.Info().Msg("Oh, it's time for me to go, bye!")
|
app.log.App.Info().Msg("Oh, it's time for me to go, bye!")
|
||||||
return nil
|
return nil
|
||||||
case err := <-errChan:
|
case err := <-errChan:
|
||||||
@@ -305,7 +310,7 @@ func (app *BootstrapApp) serveHTTP() error {
|
|||||||
go func() {
|
go func() {
|
||||||
<-app.ctx.Done()
|
<-app.ctx.Done()
|
||||||
app.log.App.Debug().Msg("Shutting down http listener")
|
app.log.App.Debug().Msg("Shutting down http listener")
|
||||||
server.Close()
|
server.Shutdown(app.ctx)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
err := server.ListenAndServe()
|
err := server.ListenAndServe()
|
||||||
@@ -345,21 +350,23 @@ func (app *BootstrapApp) serveUnix() error {
|
|||||||
Handler: app.router.Handler(),
|
Handler: app.router.Handler(),
|
||||||
}
|
}
|
||||||
|
|
||||||
defer server.Close()
|
shutdown := func() {
|
||||||
defer listener.Close()
|
server.Shutdown(app.ctx)
|
||||||
defer os.Remove(app.config.Server.SocketPath)
|
listener.Close()
|
||||||
|
os.Remove(app.config.Server.SocketPath)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer shutdown()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
<-app.ctx.Done()
|
<-app.ctx.Done()
|
||||||
app.log.App.Debug().Msg("Shutting down unix socket listener")
|
app.log.App.Debug().Msg("Shutting down unix socket listener")
|
||||||
server.Close()
|
shutdown()
|
||||||
listener.Close()
|
|
||||||
os.Remove(app.config.Server.SocketPath)
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
err = server.Serve(listener)
|
err = server.Serve(listener)
|
||||||
|
|
||||||
if err != nil && !errors.Is(err, net.ErrClosed) && !errors.Is(err, http.ErrServerClosed) {
|
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||||
return fmt.Errorf("failed to start unix socket listener: %w", err)
|
return fmt.Errorf("failed to start unix socket listener: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -27,6 +27,13 @@ func (app *BootstrapApp) SetupDatabase() error {
|
|||||||
return fmt.Errorf("failed to open database: %w", err)
|
return fmt.Errorf("failed to open database: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Close the database if there is an error during migration
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
db.Close()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
// Limit to 1 connection to sequence writes, this may need to be revisited in the future
|
// Limit to 1 connection to sequence writes, this may need to be revisited in the future
|
||||||
// if the sqlite connection starts being a bottleneck
|
// if the sqlite connection starts being a bottleneck
|
||||||
db.SetMaxOpenConns(1)
|
db.SetMaxOpenConns(1)
|
||||||
|
|||||||
@@ -43,7 +43,7 @@ func (app *BootstrapApp) setupRouter() error {
|
|||||||
|
|
||||||
controller.NewContextController(app.log, app.config, app.runtime, apiRouter)
|
controller.NewContextController(app.log, app.config, app.runtime, apiRouter)
|
||||||
controller.NewOAuthController(app.log, app.config, app.runtime, apiRouter, app.services.authService)
|
controller.NewOAuthController(app.log, app.config, app.runtime, apiRouter, app.services.authService)
|
||||||
controller.NewOIDCController(app.log, app.services.oidcService, apiRouter)
|
controller.NewOIDCController(app.log, app.services.oidcService, app.runtime, apiRouter)
|
||||||
controller.NewProxyController(app.log, app.runtime, apiRouter, app.services.accessControlService, app.services.authService)
|
controller.NewProxyController(app.log, app.runtime, apiRouter, app.services.accessControlService, app.services.authService)
|
||||||
controller.NewUserController(app.log, app.runtime, apiRouter, app.services.authService)
|
controller.NewUserController(app.log, app.runtime, apiRouter, app.services.authService)
|
||||||
controller.NewResourcesController(app.config, &engine.RouterGroup)
|
controller.NewResourcesController(app.config, &engine.RouterGroup)
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
)
|
)
|
||||||
@@ -19,7 +20,7 @@ func TestContextController(t *testing.T) {
|
|||||||
log := logger.NewLogger().WithTestConfig()
|
log := logger.NewLogger().WithTestConfig()
|
||||||
log.Init()
|
log.Init()
|
||||||
|
|
||||||
cfg, runtime := createTestConfigs(t)
|
cfg, runtime := test.CreateTestConfigs(t)
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
description string
|
description string
|
||||||
|
|||||||
@@ -131,7 +131,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
|||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to get OAuth session cookie")
|
controller.log.App.Error().Err(err).Msg("Failed to get OAuth session cookie")
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -141,7 +141,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
|||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to get pending OAuth session")
|
controller.log.App.Error().Err(err).Msg("Failed to get pending OAuth session")
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -150,7 +150,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
|||||||
state := c.Query("state")
|
state := c.Query("state")
|
||||||
if state != oauthPendingSession.State {
|
if state != oauthPendingSession.State {
|
||||||
controller.log.App.Warn().Msg("OAuth state mismatch")
|
controller.log.App.Warn().Msg("OAuth state mismatch")
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -159,15 +159,27 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
|||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to exchange code for token")
|
controller.log.App.Error().Err(err).Msg("Failed to exchange code for token")
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
user, err := controller.auth.GetOAuthUserinfo(sessionIdCookie)
|
user, err := controller.auth.GetOAuthUserinfo(sessionIdCookie)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
controller.log.App.Error().Err(err).Msg("Failed to get user info from OAuth provider")
|
||||||
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if user == nil {
|
||||||
|
controller.log.App.Warn().Msg("OAuth provider did not return user info")
|
||||||
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if user.Email == "" {
|
if user.Email == "" {
|
||||||
controller.log.App.Warn().Msg("OAuth provider did not return an email")
|
controller.log.App.Warn().Msg("OAuth provider did not return an email")
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -181,11 +193,11 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
|||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to encode unauthorized query")
|
controller.log.App.Error().Err(err).Msg("Failed to encode unauthorized query")
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode()))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.runtime.AppURL, queries.Encode()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -213,13 +225,13 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
|||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to get OAuth service for session")
|
controller.log.App.Error().Err(err).Msg("Failed to get OAuth service for session")
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if svc.ID() != req.Provider {
|
if svc.ID() != req.Provider {
|
||||||
controller.log.App.Warn().Msgf("OAuth provider mismatch: expected %s, got %s", req.Provider, svc.ID())
|
controller.log.App.Warn().Msgf("OAuth provider mismatch: expected %s, got %s", req.Provider, svc.ID())
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -239,7 +251,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
|||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to create session cookie")
|
controller.log.App.Error().Err(err).Msg("Failed to create session cookie")
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -252,10 +264,10 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
|||||||
queries, err := query.Values(oauthPendingSession.CallbackParams)
|
queries, err := query.Values(oauthPendingSession.CallbackParams)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to encode OIDC callback query")
|
controller.log.App.Error().Err(err).Msg("Failed to encode OIDC callback query")
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/authorize?%s", controller.config.AppURL, queries.Encode()))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/authorize?%s", controller.runtime.AppURL, queries.Encode()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -266,15 +278,15 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
|||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
controller.log.App.Error().Err(err).Msg("Failed to encode redirect query")
|
controller.log.App.Error().Err(err).Msg("Failed to encode redirect query")
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/continue?%s", controller.config.AppURL, queries.Encode()))
|
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/continue?%s", controller.runtime.AppURL, queries.Encode()))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
c.Redirect(http.StatusTemporaryRedirect, controller.config.AppURL)
|
c.Redirect(http.StatusTemporaryRedirect, controller.runtime.AppURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (controller *OAuthController) isOidcRequest(params service.OAuthURLParams) bool {
|
func (controller *OAuthController) isOidcRequest(params service.OAuthURLParams) bool {
|
||||||
|
|||||||
@@ -17,8 +17,9 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type OIDCController struct {
|
type OIDCController struct {
|
||||||
log *logger.Logger
|
log *logger.Logger
|
||||||
oidc *service.OIDCService
|
oidc *service.OIDCService
|
||||||
|
runtime model.RuntimeConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuthorizeCallback struct {
|
type AuthorizeCallback struct {
|
||||||
@@ -58,10 +59,12 @@ type ClientCredentials struct {
|
|||||||
func NewOIDCController(
|
func NewOIDCController(
|
||||||
log *logger.Logger,
|
log *logger.Logger,
|
||||||
oidcService *service.OIDCService,
|
oidcService *service.OIDCService,
|
||||||
|
runtimeConfig model.RuntimeConfig,
|
||||||
router *gin.RouterGroup) *OIDCController {
|
router *gin.RouterGroup) *OIDCController {
|
||||||
controller := &OIDCController{
|
controller := &OIDCController{
|
||||||
log: log,
|
log: log,
|
||||||
oidc: oidcService,
|
oidc: oidcService,
|
||||||
|
runtime: runtimeConfig,
|
||||||
}
|
}
|
||||||
|
|
||||||
oidcGroup := router.Group("/oidc")
|
oidcGroup := router.Group("/oidc")
|
||||||
@@ -75,6 +78,15 @@ func NewOIDCController(
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (controller *OIDCController) GetClientInfo(c *gin.Context) {
|
func (controller *OIDCController) GetClientInfo(c *gin.Context) {
|
||||||
|
if controller.oidc == nil {
|
||||||
|
controller.log.App.Warn().Msg("Received OIDC client info request but OIDC server is not configured")
|
||||||
|
c.JSON(500, gin.H{
|
||||||
|
"status": 500,
|
||||||
|
"message": "OIDC not configured",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
var req ClientRequest
|
var req ClientRequest
|
||||||
|
|
||||||
err := c.BindUri(&req)
|
err := c.BindUri(&req)
|
||||||
@@ -198,8 +210,8 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
|
|||||||
func (controller *OIDCController) Token(c *gin.Context) {
|
func (controller *OIDCController) Token(c *gin.Context) {
|
||||||
if controller.oidc == nil {
|
if controller.oidc == nil {
|
||||||
controller.log.App.Warn().Msg("Received OIDC request but OIDC server is not configured")
|
controller.log.App.Warn().Msg("Received OIDC request but OIDC server is not configured")
|
||||||
c.JSON(404, gin.H{
|
c.JSON(500, gin.H{
|
||||||
"error": "not_found",
|
"error": "server_error",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -374,8 +386,8 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
func (controller *OIDCController) Userinfo(c *gin.Context) {
|
func (controller *OIDCController) Userinfo(c *gin.Context) {
|
||||||
if controller.oidc == nil {
|
if controller.oidc == nil {
|
||||||
controller.log.App.Warn().Msg("Received OIDC userinfo request but OIDC server is not configured")
|
controller.log.App.Warn().Msg("Received OIDC userinfo request but OIDC server is not configured")
|
||||||
c.JSON(404, gin.H{
|
c.JSON(500, gin.H{
|
||||||
"error": "not_found",
|
"error": "server_error",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -507,8 +519,16 @@ func (controller *OIDCController) authorizeError(c *gin.Context, err error, reas
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
redirectUrl := ""
|
||||||
|
|
||||||
|
if controller.oidc != nil {
|
||||||
|
redirectUrl = fmt.Sprintf("%s/error?%s", controller.oidc.GetIssuer(), queries.Encode())
|
||||||
|
} else {
|
||||||
|
redirectUrl = fmt.Sprintf("%s/error?%s", controller.runtime.AppURL, queries.Encode())
|
||||||
|
}
|
||||||
|
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"status": 200,
|
"status": 200,
|
||||||
"redirect_uri": fmt.Sprintf("%s/error?%s", controller.oidc.GetIssuer(), queries.Encode()),
|
"redirect_uri": redirectUrl,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -27,7 +28,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
log := logger.NewLogger().WithTestConfig()
|
log := logger.NewLogger().WithTestConfig()
|
||||||
log.Init()
|
log.Init()
|
||||||
|
|
||||||
cfg, runtime := createTestConfigs(t)
|
cfg, runtime := test.CreateTestConfigs(t)
|
||||||
|
|
||||||
simpleCtx := func(c *gin.Context) {
|
simpleCtx := func(c *gin.Context) {
|
||||||
c.Set("context", &model.UserContext{
|
c.Set("context", &model.UserContext{
|
||||||
@@ -861,7 +862,7 @@ func TestOIDCController(t *testing.T) {
|
|||||||
group := router.Group("/api")
|
group := router.Group("/api")
|
||||||
gin.SetMode(gin.TestMode)
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
controller.NewOIDCController(log, oidcService, group)
|
controller.NewOIDCController(log, oidcService, runtime, group)
|
||||||
|
|
||||||
recorder := httptest.NewRecorder()
|
recorder := httptest.NewRecorder()
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -21,7 +22,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
log := logger.NewLogger().WithTestConfig()
|
log := logger.NewLogger().WithTestConfig()
|
||||||
log.Init()
|
log.Init()
|
||||||
|
|
||||||
cfg, runtime := createTestConfigs(t)
|
cfg, runtime := test.CreateTestConfigs(t)
|
||||||
|
|
||||||
acls := map[string]model.App{
|
acls := map[string]model.App{
|
||||||
"app_path_allow": {
|
"app_path_allow": {
|
||||||
|
|||||||
@@ -10,10 +10,11 @@ import (
|
|||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestResourcesController(t *testing.T) {
|
func TestResourcesController(t *testing.T) {
|
||||||
cfg, _ := createTestConfigs(t)
|
cfg, _ := test.CreateTestConfigs(t)
|
||||||
|
|
||||||
err := os.MkdirAll(cfg.Resources.Path, 0777)
|
err := os.MkdirAll(cfg.Resources.Path, 0777)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -26,7 +27,7 @@ func TestUserController(t *testing.T) {
|
|||||||
log := logger.NewLogger().WithTestConfig()
|
log := logger.NewLogger().WithTestConfig()
|
||||||
log.Init()
|
log.Init()
|
||||||
|
|
||||||
cfg, runtime := createTestConfigs(t)
|
cfg, runtime := test.CreateTestConfigs(t)
|
||||||
|
|
||||||
totpCtx := func(c *gin.Context) {
|
totpCtx := func(c *gin.Context) {
|
||||||
c.Set("context", &model.UserContext{
|
c.Set("context", &model.UserContext{
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -22,7 +23,7 @@ func TestWellKnownController(t *testing.T) {
|
|||||||
log := logger.NewLogger().WithTestConfig()
|
log := logger.NewLogger().WithTestConfig()
|
||||||
log.Init()
|
log.Init()
|
||||||
|
|
||||||
cfg, runtime := createTestConfigs(t)
|
cfg, runtime := test.CreateTestConfigs(t)
|
||||||
|
|
||||||
type testCase struct {
|
type testCase struct {
|
||||||
description string
|
description string
|
||||||
@@ -99,6 +100,7 @@ func TestWellKnownController(t *testing.T) {
|
|||||||
queries := repository.New(app.GetDB())
|
queries := repository.New(app.GetDB())
|
||||||
|
|
||||||
oidcService, err := service.NewOIDCService(log, cfg, runtime, queries, ctx, wg)
|
oidcService, err := service.NewOIDCService(log, cfg, runtime, queries, ctx, wg)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
t.Run(test.description, func(t *testing.T) {
|
t.Run(test.description, func(t *testing.T) {
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ import (
|
|||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||||
|
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -24,7 +25,7 @@ func TestContextMiddleware(t *testing.T) {
|
|||||||
log := logger.NewLogger().WithTestConfig()
|
log := logger.NewLogger().WithTestConfig()
|
||||||
log.Init()
|
log.Init()
|
||||||
|
|
||||||
cfg, runtime := createTestConfigs(t)
|
cfg, runtime := test.CreateTestConfigs(t)
|
||||||
|
|
||||||
basicAuthHeader := func(username, password string) string {
|
basicAuthHeader := func(username, password string) string {
|
||||||
return "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password))
|
return "Basic " + base64.StdEncoding.EncodeToString([]byte(username+":"+password))
|
||||||
|
|||||||
@@ -1,108 +0,0 @@
|
|||||||
package middleware_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"path"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
|
||||||
"golang.org/x/crypto/bcrypt"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Note: This code is duplicated from controller_test.go
|
|
||||||
|
|
||||||
var testingTOTPSecret = "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK"
|
|
||||||
|
|
||||||
func createTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
|
|
||||||
tempDir := t.TempDir()
|
|
||||||
|
|
||||||
config := model.Config{
|
|
||||||
UI: model.UIConfig{
|
|
||||||
Title: "Tinyauth Test",
|
|
||||||
ForgotPasswordMessage: "foo",
|
|
||||||
BackgroundImage: "/background.jpg",
|
|
||||||
WarningsEnabled: true,
|
|
||||||
},
|
|
||||||
OAuth: model.OAuthConfig{
|
|
||||||
AutoRedirect: "none",
|
|
||||||
},
|
|
||||||
OIDC: model.OIDCConfig{
|
|
||||||
Clients: map[string]model.OIDCClientConfig{
|
|
||||||
"test": {
|
|
||||||
ClientID: "some-client-id",
|
|
||||||
ClientSecret: "some-client-secret",
|
|
||||||
TrustedRedirectURIs: []string{"https://test.example.com/callback"},
|
|
||||||
Name: "Test Client",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
PrivateKeyPath: path.Join(tempDir, "key.pem"),
|
|
||||||
PublicKeyPath: path.Join(tempDir, "key.pub"),
|
|
||||||
},
|
|
||||||
Auth: model.AuthConfig{
|
|
||||||
SessionExpiry: 10,
|
|
||||||
LoginTimeout: 10,
|
|
||||||
LoginMaxRetries: 3,
|
|
||||||
},
|
|
||||||
Database: model.DatabaseConfig{
|
|
||||||
Path: path.Join(tempDir, "test.db"),
|
|
||||||
},
|
|
||||||
Resources: model.ResourcesConfig{
|
|
||||||
Enabled: true,
|
|
||||||
Path: path.Join(tempDir, "resources"),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
passwd, err := bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
runtime := model.RuntimeConfig{
|
|
||||||
ConfiguredProviders: []model.Provider{
|
|
||||||
{
|
|
||||||
Name: "Local",
|
|
||||||
ID: "local",
|
|
||||||
OAuth: false,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
LocalUsers: []model.LocalUser{
|
|
||||||
{
|
|
||||||
Username: "testuser",
|
|
||||||
Password: string(passwd),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Username: "totpuser",
|
|
||||||
Password: string(passwd),
|
|
||||||
TOTPSecret: testingTOTPSecret,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Username: "attruser",
|
|
||||||
Password: string(passwd),
|
|
||||||
Attributes: model.UserAttributes{
|
|
||||||
Name: "Alice Smith",
|
|
||||||
Email: "alice@example.com",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Username: "attrtotpuser",
|
|
||||||
Password: string(passwd),
|
|
||||||
TOTPSecret: testingTOTPSecret,
|
|
||||||
Attributes: model.UserAttributes{
|
|
||||||
Name: "Bob Jones",
|
|
||||||
Email: "bob@example.com",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
CookieDomain: "example.com",
|
|
||||||
AppURL: "https://tinyauth.example.com",
|
|
||||||
SessionCookieName: "tinyauth-session",
|
|
||||||
OIDCClients: func() []model.OIDCClientConfig {
|
|
||||||
var clients []model.OIDCClientConfig
|
|
||||||
for id, client := range config.OIDC.Clients {
|
|
||||||
client.ID = id
|
|
||||||
clients = append(clients, client)
|
|
||||||
}
|
|
||||||
return clients
|
|
||||||
}(),
|
|
||||||
}
|
|
||||||
|
|
||||||
return config, runtime
|
|
||||||
}
|
|
||||||
@@ -790,10 +790,8 @@ func (service *OIDCService) cleanupRoutine() {
|
|||||||
token, err := service.queries.GetOidcTokenBySub(service.context, expiredCode.Sub)
|
token, err := service.queries.GetOidcTokenBySub(service.context, expiredCode.Sub)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
service.log.App.Warn().Err(err).Msg("Failed to get token by sub for expired code")
|
service.log.App.Warn().Err(err).Msg("Failed to get token by sub for expired code")
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime {
|
if token.TokenExpiresAt < currentTime && token.RefreshTokenExpiresAt < currentTime {
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
package controller_test
|
package test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"path"
|
"path"
|
||||||
@@ -9,9 +9,9 @@ import (
|
|||||||
"golang.org/x/crypto/bcrypt"
|
"golang.org/x/crypto/bcrypt"
|
||||||
)
|
)
|
||||||
|
|
||||||
var testingTOTPSecret = "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK"
|
var TestingTOTPSecret = "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK"
|
||||||
|
|
||||||
func createTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
|
func CreateTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
|
||||||
tempDir := t.TempDir()
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
config := model.Config{
|
config := model.Config{
|
||||||
@@ -69,7 +69,7 @@ func createTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
|
|||||||
{
|
{
|
||||||
Username: "totpuser",
|
Username: "totpuser",
|
||||||
Password: string(passwd),
|
Password: string(passwd),
|
||||||
TOTPSecret: testingTOTPSecret,
|
TOTPSecret: TestingTOTPSecret,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
Username: "attruser",
|
Username: "attruser",
|
||||||
@@ -82,7 +82,7 @@ func createTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
|
|||||||
{
|
{
|
||||||
Username: "attrtotpuser",
|
Username: "attrtotpuser",
|
||||||
Password: string(passwd),
|
Password: string(passwd),
|
||||||
TOTPSecret: testingTOTPSecret,
|
TOTPSecret: TestingTOTPSecret,
|
||||||
Attributes: model.UserAttributes{
|
Attributes: model.UserAttributes{
|
||||||
Name: "Bob Jones",
|
Name: "Bob Jones",
|
||||||
Email: "bob@example.com",
|
Email: "bob@example.com",
|
||||||
@@ -33,7 +33,7 @@ func NewLogger() *Logger {
|
|||||||
App: model.LogStreamConfig{
|
App: model.LogStreamConfig{
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
},
|
},
|
||||||
// No reason to enabled audit by default since it will be suppressed by the log level
|
// No reason to enable audit by default since it will be suppressed by the log level
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -162,7 +162,7 @@ func TestLogger(t *testing.T) {
|
|||||||
l.AuditLoginFailure("test_nop", "test_nop", "test_nop", "test_nop")
|
l.AuditLoginFailure("test_nop", "test_nop", "test_nop", "test_nop")
|
||||||
|
|
||||||
assert.NotEmpty(t, buf.String())
|
assert.NotEmpty(t, buf.String())
|
||||||
assert.NotContains(t, "test_nop", buf.String())
|
assert.NotContains(t, buf.String(), "test_nop")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user