mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-05-10 22:38:10 +00:00
fix: coderabbit comments
This commit is contained in:
@@ -11,6 +11,7 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
)
|
||||
@@ -19,7 +20,7 @@ func TestContextController(t *testing.T) {
|
||||
log := logger.NewLogger().WithTestConfig()
|
||||
log.Init()
|
||||
|
||||
cfg, runtime := createTestConfigs(t)
|
||||
cfg, runtime := test.CreateTestConfigs(t)
|
||||
|
||||
tests := []struct {
|
||||
description string
|
||||
|
||||
@@ -1,106 +0,0 @@
|
||||
package controller_test
|
||||
|
||||
import (
|
||||
"path"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
var testingTOTPSecret = "JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK"
|
||||
|
||||
func createTestConfigs(t *testing.T) (model.Config, model.RuntimeConfig) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
config := model.Config{
|
||||
UI: model.UIConfig{
|
||||
Title: "Tinyauth Test",
|
||||
ForgotPasswordMessage: "foo",
|
||||
BackgroundImage: "/background.jpg",
|
||||
WarningsEnabled: true,
|
||||
},
|
||||
OAuth: model.OAuthConfig{
|
||||
AutoRedirect: "none",
|
||||
},
|
||||
OIDC: model.OIDCConfig{
|
||||
Clients: map[string]model.OIDCClientConfig{
|
||||
"test": {
|
||||
ClientID: "some-client-id",
|
||||
ClientSecret: "some-client-secret",
|
||||
TrustedRedirectURIs: []string{"https://test.example.com/callback"},
|
||||
Name: "Test Client",
|
||||
},
|
||||
},
|
||||
PrivateKeyPath: path.Join(tempDir, "key.pem"),
|
||||
PublicKeyPath: path.Join(tempDir, "key.pub"),
|
||||
},
|
||||
Auth: model.AuthConfig{
|
||||
SessionExpiry: 10,
|
||||
LoginTimeout: 10,
|
||||
LoginMaxRetries: 3,
|
||||
},
|
||||
Database: model.DatabaseConfig{
|
||||
Path: path.Join(tempDir, "test.db"),
|
||||
},
|
||||
Resources: model.ResourcesConfig{
|
||||
Enabled: true,
|
||||
Path: path.Join(tempDir, "resources"),
|
||||
},
|
||||
}
|
||||
|
||||
passwd, err := bcrypt.GenerateFromPassword([]byte("password"), bcrypt.DefaultCost)
|
||||
require.NoError(t, err)
|
||||
|
||||
runtime := model.RuntimeConfig{
|
||||
ConfiguredProviders: []model.Provider{
|
||||
{
|
||||
Name: "Local",
|
||||
ID: "local",
|
||||
OAuth: false,
|
||||
},
|
||||
},
|
||||
LocalUsers: []model.LocalUser{
|
||||
{
|
||||
Username: "testuser",
|
||||
Password: string(passwd),
|
||||
},
|
||||
{
|
||||
Username: "totpuser",
|
||||
Password: string(passwd),
|
||||
TOTPSecret: testingTOTPSecret,
|
||||
},
|
||||
{
|
||||
Username: "attruser",
|
||||
Password: string(passwd),
|
||||
Attributes: model.UserAttributes{
|
||||
Name: "Alice Smith",
|
||||
Email: "alice@example.com",
|
||||
},
|
||||
},
|
||||
{
|
||||
Username: "attrtotpuser",
|
||||
Password: string(passwd),
|
||||
TOTPSecret: testingTOTPSecret,
|
||||
Attributes: model.UserAttributes{
|
||||
Name: "Bob Jones",
|
||||
Email: "bob@example.com",
|
||||
},
|
||||
},
|
||||
},
|
||||
CookieDomain: "example.com",
|
||||
AppURL: "https://tinyauth.example.com",
|
||||
SessionCookieName: "tinyauth-session",
|
||||
OIDCClients: func() []model.OIDCClientConfig {
|
||||
var clients []model.OIDCClientConfig
|
||||
for id, client := range config.OIDC.Clients {
|
||||
client.ID = id
|
||||
clients = append(clients, client)
|
||||
}
|
||||
return clients
|
||||
}(),
|
||||
}
|
||||
|
||||
return config, runtime
|
||||
}
|
||||
@@ -131,7 +131,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
||||
|
||||
if err != nil {
|
||||
controller.log.App.Error().Err(err).Msg("Failed to get OAuth session cookie")
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -141,7 +141,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
||||
|
||||
if err != nil {
|
||||
controller.log.App.Error().Err(err).Msg("Failed to get pending OAuth session")
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -150,7 +150,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
||||
state := c.Query("state")
|
||||
if state != oauthPendingSession.State {
|
||||
controller.log.App.Warn().Msg("OAuth state mismatch")
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -159,15 +159,27 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
||||
|
||||
if err != nil {
|
||||
controller.log.App.Error().Err(err).Msg("Failed to exchange code for token")
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
||||
return
|
||||
}
|
||||
|
||||
user, err := controller.auth.GetOAuthUserinfo(sessionIdCookie)
|
||||
|
||||
if err != nil {
|
||||
controller.log.App.Error().Err(err).Msg("Failed to get user info from OAuth provider")
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
||||
return
|
||||
}
|
||||
|
||||
if user == nil {
|
||||
controller.log.App.Warn().Msg("OAuth provider did not return user info")
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
||||
return
|
||||
}
|
||||
|
||||
if user.Email == "" {
|
||||
controller.log.App.Warn().Msg("OAuth provider did not return an email")
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -181,11 +193,11 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
||||
|
||||
if err != nil {
|
||||
controller.log.App.Error().Err(err).Msg("Failed to encode unauthorized query")
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
||||
return
|
||||
}
|
||||
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.config.AppURL, queries.Encode()))
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/unauthorized?%s", controller.runtime.AppURL, queries.Encode()))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -213,13 +225,13 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
||||
|
||||
if err != nil {
|
||||
controller.log.App.Error().Err(err).Msg("Failed to get OAuth service for session")
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
||||
return
|
||||
}
|
||||
|
||||
if svc.ID() != req.Provider {
|
||||
controller.log.App.Warn().Msgf("OAuth provider mismatch: expected %s, got %s", req.Provider, svc.ID())
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -239,7 +251,7 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
||||
|
||||
if err != nil {
|
||||
controller.log.App.Error().Err(err).Msg("Failed to create session cookie")
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -252,10 +264,10 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
||||
queries, err := query.Values(oauthPendingSession.CallbackParams)
|
||||
if err != nil {
|
||||
controller.log.App.Error().Err(err).Msg("Failed to encode OIDC callback query")
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
||||
return
|
||||
}
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/authorize?%s", controller.config.AppURL, queries.Encode()))
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/authorize?%s", controller.runtime.AppURL, queries.Encode()))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -266,15 +278,15 @@ func (controller *OAuthController) oauthCallbackHandler(c *gin.Context) {
|
||||
|
||||
if err != nil {
|
||||
controller.log.App.Error().Err(err).Msg("Failed to encode redirect query")
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.config.AppURL))
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/error", controller.runtime.AppURL))
|
||||
return
|
||||
}
|
||||
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/continue?%s", controller.config.AppURL, queries.Encode()))
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s/continue?%s", controller.runtime.AppURL, queries.Encode()))
|
||||
return
|
||||
}
|
||||
|
||||
c.Redirect(http.StatusTemporaryRedirect, controller.config.AppURL)
|
||||
c.Redirect(http.StatusTemporaryRedirect, controller.runtime.AppURL)
|
||||
}
|
||||
|
||||
func (controller *OAuthController) isOidcRequest(params service.OAuthURLParams) bool {
|
||||
|
||||
@@ -17,8 +17,9 @@ import (
|
||||
)
|
||||
|
||||
type OIDCController struct {
|
||||
log *logger.Logger
|
||||
oidc *service.OIDCService
|
||||
log *logger.Logger
|
||||
oidc *service.OIDCService
|
||||
runtime model.RuntimeConfig
|
||||
}
|
||||
|
||||
type AuthorizeCallback struct {
|
||||
@@ -58,10 +59,12 @@ type ClientCredentials struct {
|
||||
func NewOIDCController(
|
||||
log *logger.Logger,
|
||||
oidcService *service.OIDCService,
|
||||
runtimeConfig model.RuntimeConfig,
|
||||
router *gin.RouterGroup) *OIDCController {
|
||||
controller := &OIDCController{
|
||||
log: log,
|
||||
oidc: oidcService,
|
||||
log: log,
|
||||
oidc: oidcService,
|
||||
runtime: runtimeConfig,
|
||||
}
|
||||
|
||||
oidcGroup := router.Group("/oidc")
|
||||
@@ -75,6 +78,15 @@ func NewOIDCController(
|
||||
}
|
||||
|
||||
func (controller *OIDCController) GetClientInfo(c *gin.Context) {
|
||||
if controller.oidc == nil {
|
||||
controller.log.App.Warn().Msg("Received OIDC client info request but OIDC server is not configured")
|
||||
c.JSON(500, gin.H{
|
||||
"status": 500,
|
||||
"message": "OIDC not configured",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
var req ClientRequest
|
||||
|
||||
err := c.BindUri(&req)
|
||||
@@ -198,8 +210,8 @@ func (controller *OIDCController) Authorize(c *gin.Context) {
|
||||
func (controller *OIDCController) Token(c *gin.Context) {
|
||||
if controller.oidc == nil {
|
||||
controller.log.App.Warn().Msg("Received OIDC request but OIDC server is not configured")
|
||||
c.JSON(404, gin.H{
|
||||
"error": "not_found",
|
||||
c.JSON(500, gin.H{
|
||||
"error": "server_error",
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -374,8 +386,8 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
||||
func (controller *OIDCController) Userinfo(c *gin.Context) {
|
||||
if controller.oidc == nil {
|
||||
controller.log.App.Warn().Msg("Received OIDC userinfo request but OIDC server is not configured")
|
||||
c.JSON(404, gin.H{
|
||||
"error": "not_found",
|
||||
c.JSON(500, gin.H{
|
||||
"error": "server_error",
|
||||
})
|
||||
return
|
||||
}
|
||||
@@ -507,8 +519,16 @@ func (controller *OIDCController) authorizeError(c *gin.Context, err error, reas
|
||||
return
|
||||
}
|
||||
|
||||
redirectUrl := ""
|
||||
|
||||
if controller.oidc != nil {
|
||||
redirectUrl = fmt.Sprintf("%s/error?%s", controller.oidc.GetIssuer(), queries.Encode())
|
||||
} else {
|
||||
redirectUrl = fmt.Sprintf("%s/error?%s", controller.runtime.AppURL, queries.Encode())
|
||||
}
|
||||
|
||||
c.JSON(200, gin.H{
|
||||
"status": 200,
|
||||
"redirect_uri": fmt.Sprintf("%s/error?%s", controller.oidc.GetIssuer(), queries.Encode()),
|
||||
"redirect_uri": redirectUrl,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -20,6 +20,7 @@ import (
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
)
|
||||
|
||||
@@ -27,7 +28,7 @@ func TestOIDCController(t *testing.T) {
|
||||
log := logger.NewLogger().WithTestConfig()
|
||||
log.Init()
|
||||
|
||||
cfg, runtime := createTestConfigs(t)
|
||||
cfg, runtime := test.CreateTestConfigs(t)
|
||||
|
||||
simpleCtx := func(c *gin.Context) {
|
||||
c.Set("context", &model.UserContext{
|
||||
@@ -861,7 +862,7 @@ func TestOIDCController(t *testing.T) {
|
||||
group := router.Group("/api")
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
controller.NewOIDCController(log, oidcService, group)
|
||||
controller.NewOIDCController(log, oidcService, runtime, group)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
)
|
||||
|
||||
@@ -21,7 +22,7 @@ func TestProxyController(t *testing.T) {
|
||||
log := logger.NewLogger().WithTestConfig()
|
||||
log.Init()
|
||||
|
||||
cfg, runtime := createTestConfigs(t)
|
||||
cfg, runtime := test.CreateTestConfigs(t)
|
||||
|
||||
acls := map[string]model.App{
|
||||
"app_path_allow": {
|
||||
|
||||
@@ -10,10 +10,11 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||
)
|
||||
|
||||
func TestResourcesController(t *testing.T) {
|
||||
cfg, _ := createTestConfigs(t)
|
||||
cfg, _ := test.CreateTestConfigs(t)
|
||||
|
||||
err := os.MkdirAll(cfg.Resources.Path, 0777)
|
||||
require.NoError(t, err)
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
"github.com/tinyauthapp/tinyauth/internal/model"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
)
|
||||
|
||||
@@ -26,7 +27,7 @@ func TestUserController(t *testing.T) {
|
||||
log := logger.NewLogger().WithTestConfig()
|
||||
log.Init()
|
||||
|
||||
cfg, runtime := createTestConfigs(t)
|
||||
cfg, runtime := test.CreateTestConfigs(t)
|
||||
|
||||
totpCtx := func(c *gin.Context) {
|
||||
c.Set("context", &model.UserContext{
|
||||
|
||||
@@ -15,6 +15,7 @@ import (
|
||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||
"github.com/tinyauthapp/tinyauth/internal/test"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/logger"
|
||||
)
|
||||
|
||||
@@ -22,7 +23,7 @@ func TestWellKnownController(t *testing.T) {
|
||||
log := logger.NewLogger().WithTestConfig()
|
||||
log.Init()
|
||||
|
||||
cfg, runtime := createTestConfigs(t)
|
||||
cfg, runtime := test.CreateTestConfigs(t)
|
||||
|
||||
type testCase struct {
|
||||
description string
|
||||
@@ -99,6 +100,7 @@ func TestWellKnownController(t *testing.T) {
|
||||
queries := repository.New(app.GetDB())
|
||||
|
||||
oidcService, err := service.NewOIDCService(log, cfg, runtime, queries, ctx, wg)
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.description, func(t *testing.T) {
|
||||
|
||||
Reference in New Issue
Block a user