mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-05-08 21:38:13 +00:00
fix: review comments batch 2
This commit is contained in:
@@ -29,7 +29,7 @@ type BootstrapApp struct {
|
|||||||
csrfCookieName string
|
csrfCookieName string
|
||||||
redirectCookieName string
|
redirectCookieName string
|
||||||
oauthSessionCookieName string
|
oauthSessionCookieName string
|
||||||
localUsers []model.LocalUser
|
localUsers *[]model.LocalUser
|
||||||
oauthProviders map[string]model.OAuthServiceConfig
|
oauthProviders map[string]model.OAuthServiceConfig
|
||||||
configuredProviders []controller.Provider
|
configuredProviders []controller.Provider
|
||||||
oidcClients []model.OIDCClientConfig
|
oidcClients []model.OIDCClientConfig
|
||||||
@@ -69,7 +69,7 @@ func (app *BootstrapApp) Setup() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
app.context.localUsers = *users
|
app.context.localUsers = users
|
||||||
|
|
||||||
// Setup OAuth providers
|
// Setup OAuth providers
|
||||||
app.context.oauthProviders = app.config.OAuth.Providers
|
app.context.oauthProviders = app.config.OAuth.Providers
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ func TestProxyController(t *testing.T) {
|
|||||||
tempDir := t.TempDir()
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
authServiceCfg := service.AuthServiceConfig{
|
authServiceCfg := service.AuthServiceConfig{
|
||||||
LocalUsers: []model.LocalUser{
|
LocalUsers: &[]model.LocalUser{
|
||||||
{
|
{
|
||||||
Username: "testuser",
|
Username: "testuser",
|
||||||
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
|
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ func TestUserController(t *testing.T) {
|
|||||||
tempDir := t.TempDir()
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
authServiceCfg := service.AuthServiceConfig{
|
authServiceCfg := service.AuthServiceConfig{
|
||||||
LocalUsers: []model.LocalUser{
|
LocalUsers: &[]model.LocalUser{
|
||||||
{
|
{
|
||||||
Username: "testuser",
|
Username: "testuser",
|
||||||
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
|
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ func TestContextMiddleware(t *testing.T) {
|
|||||||
tempDir := t.TempDir()
|
tempDir := t.TempDir()
|
||||||
|
|
||||||
authServiceCfg := service.AuthServiceConfig{
|
authServiceCfg := service.AuthServiceConfig{
|
||||||
LocalUsers: []model.LocalUser{
|
LocalUsers: &[]model.LocalUser{
|
||||||
{
|
{
|
||||||
Username: "testuser",
|
Username: "testuser",
|
||||||
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
|
Password: "$2a$10$ZwVYQH07JX2zq7Fjkt3gU.BjwvvwPeli4OqOno04RQIv0P7usBrXa", // password
|
||||||
|
|||||||
@@ -56,19 +56,19 @@ func (c *UserContext) IsAuthenticated() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *UserContext) IsLocal() bool {
|
func (c *UserContext) IsLocal() bool {
|
||||||
return c.Provider == ProviderLocal
|
return c.Provider == ProviderLocal && c.Local != nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *UserContext) IsOAuth() bool {
|
func (c *UserContext) IsOAuth() bool {
|
||||||
return c.Provider == ProviderOAuth
|
return c.Provider == ProviderOAuth && c.OAuth != nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *UserContext) IsLDAP() bool {
|
func (c *UserContext) IsLDAP() bool {
|
||||||
return c.Provider == ProviderLDAP
|
return c.Provider == ProviderLDAP && c.LDAP != nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *UserContext) IsBasicAuth() bool {
|
func (c *UserContext) IsBasicAuth() bool {
|
||||||
return c.Provider == ProviderBasicAuth
|
return c.Provider == ProviderBasicAuth && c.Local != nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *UserContext) NewFromGin(ginctx *gin.Context) (*UserContext, error) {
|
func (c *UserContext) NewFromGin(ginctx *gin.Context) (*UserContext, error) {
|
||||||
@@ -145,12 +145,24 @@ func (c *UserContext) NewFromSession(session *repository.Session) (*UserContext,
|
|||||||
func (c *UserContext) GetUsername() string {
|
func (c *UserContext) GetUsername() string {
|
||||||
switch c.Provider {
|
switch c.Provider {
|
||||||
case ProviderLocal:
|
case ProviderLocal:
|
||||||
|
if c.Local == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
return c.Local.Username
|
return c.Local.Username
|
||||||
case ProviderLDAP:
|
case ProviderLDAP:
|
||||||
|
if c.LDAP == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
return c.LDAP.Username
|
return c.LDAP.Username
|
||||||
case ProviderBasicAuth:
|
case ProviderBasicAuth:
|
||||||
|
if c.Local == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
return c.Local.Username
|
return c.Local.Username
|
||||||
case ProviderOAuth:
|
case ProviderOAuth:
|
||||||
|
if c.OAuth == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
return c.OAuth.Username
|
return c.OAuth.Username
|
||||||
default:
|
default:
|
||||||
return ""
|
return ""
|
||||||
@@ -160,12 +172,24 @@ func (c *UserContext) GetUsername() string {
|
|||||||
func (c *UserContext) GetEmail() string {
|
func (c *UserContext) GetEmail() string {
|
||||||
switch c.Provider {
|
switch c.Provider {
|
||||||
case ProviderLocal:
|
case ProviderLocal:
|
||||||
|
if c.Local == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
return c.Local.Email
|
return c.Local.Email
|
||||||
case ProviderLDAP:
|
case ProviderLDAP:
|
||||||
|
if c.LDAP == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
return c.LDAP.Email
|
return c.LDAP.Email
|
||||||
case ProviderBasicAuth:
|
case ProviderBasicAuth:
|
||||||
|
if c.Local == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
return c.Local.Email
|
return c.Local.Email
|
||||||
case ProviderOAuth:
|
case ProviderOAuth:
|
||||||
|
if c.OAuth == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
return c.OAuth.Email
|
return c.OAuth.Email
|
||||||
default:
|
default:
|
||||||
return ""
|
return ""
|
||||||
@@ -175,12 +199,24 @@ func (c *UserContext) GetEmail() string {
|
|||||||
func (c *UserContext) GetName() string {
|
func (c *UserContext) GetName() string {
|
||||||
switch c.Provider {
|
switch c.Provider {
|
||||||
case ProviderLocal:
|
case ProviderLocal:
|
||||||
|
if c.Local == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
return c.Local.Name
|
return c.Local.Name
|
||||||
case ProviderLDAP:
|
case ProviderLDAP:
|
||||||
|
if c.LDAP == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
return c.LDAP.Name
|
return c.LDAP.Name
|
||||||
case ProviderBasicAuth:
|
case ProviderBasicAuth:
|
||||||
|
if c.Local == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
return c.Local.Name
|
return c.Local.Name
|
||||||
case ProviderOAuth:
|
case ProviderOAuth:
|
||||||
|
if c.OAuth == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
return c.OAuth.Name
|
return c.OAuth.Name
|
||||||
default:
|
default:
|
||||||
return ""
|
return ""
|
||||||
@@ -201,14 +237,14 @@ func (c *UserContext) ProviderName() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *UserContext) TOTPPending() bool {
|
func (c *UserContext) TOTPPending() bool {
|
||||||
if c.Provider == ProviderLocal {
|
if c.Provider == ProviderLocal && c.Local != nil {
|
||||||
return c.Local.TOTPPending
|
return c.Local.TOTPPending
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *UserContext) OAuthName() string {
|
func (c *UserContext) OAuthName() string {
|
||||||
if c.Provider == ProviderOAuth {
|
if c.Provider == ProviderOAuth && c.OAuth != nil {
|
||||||
return c.OAuth.DisplayName
|
return c.OAuth.DisplayName
|
||||||
}
|
}
|
||||||
return ""
|
return ""
|
||||||
|
|||||||
@@ -34,25 +34,25 @@ func TestContext(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "IsLocal returns true for ProviderLocal",
|
description: "IsLocal returns true for ProviderLocal",
|
||||||
context: &model.UserContext{Provider: model.ProviderLocal},
|
context: &model.UserContext{Provider: model.ProviderLocal, Local: &model.LocalContext{}},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.IsLocal() },
|
run: func(t *testing.T, c *model.UserContext) any { return c.IsLocal() },
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "IsOAuth returns true for ProviderOAuth",
|
description: "IsOAuth returns true for ProviderOAuth",
|
||||||
context: &model.UserContext{Provider: model.ProviderOAuth},
|
context: &model.UserContext{Provider: model.ProviderOAuth, OAuth: &model.OAuthContext{}},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.IsOAuth() },
|
run: func(t *testing.T, c *model.UserContext) any { return c.IsOAuth() },
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "IsLDAP returns true for ProviderLDAP",
|
description: "IsLDAP returns true for ProviderLDAP",
|
||||||
context: &model.UserContext{Provider: model.ProviderLDAP},
|
context: &model.UserContext{Provider: model.ProviderLDAP, LDAP: &model.LDAPContext{}},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.IsLDAP() },
|
run: func(t *testing.T, c *model.UserContext) any { return c.IsLDAP() },
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
description: "IsBasicAuth returns true for ProviderBasicAuth",
|
description: "IsBasicAuth returns true for ProviderBasicAuth",
|
||||||
context: &model.UserContext{Provider: model.ProviderBasicAuth},
|
context: &model.UserContext{Provider: model.ProviderBasicAuth, Local: &model.LocalContext{}},
|
||||||
run: func(t *testing.T, c *model.UserContext) any { return c.IsBasicAuth() },
|
run: func(t *testing.T, c *model.UserContext) any { return c.IsBasicAuth() },
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
@@ -258,6 +258,14 @@ func TestContext(t *testing.T) {
|
|||||||
},
|
},
|
||||||
expected: "incomplete user context",
|
expected: "incomplete user context",
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
description: "Getters should not panic if provider context is empty",
|
||||||
|
context: &model.UserContext{Provider: model.ProviderLocal},
|
||||||
|
run: func(t *testing.T, c *model.UserContext) any {
|
||||||
|
return [3]string{c.GetUsername(), c.GetEmail(), c.GetName()}
|
||||||
|
},
|
||||||
|
expected: [3]string{"", "", ""},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, test := range tests {
|
for _, test := range tests {
|
||||||
|
|||||||
@@ -28,18 +28,21 @@ func (acls *AccessControlsService) Init() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (acls *AccessControlsService) lookupStaticACLs(domain string) *model.App {
|
func (acls *AccessControlsService) lookupStaticACLs(domain string) *model.App {
|
||||||
|
var appAcls *model.App
|
||||||
for app, config := range acls.static {
|
for app, config := range acls.static {
|
||||||
if config.Config.Domain == domain {
|
if config.Config.Domain == domain {
|
||||||
tlog.App.Debug().Str("name", app).Msg("Found matching container by domain")
|
tlog.App.Debug().Str("name", app).Msg("Found matching container by domain")
|
||||||
return &config
|
appAcls = &config
|
||||||
|
break // If we find a match by domain, we can stop searching
|
||||||
}
|
}
|
||||||
|
|
||||||
if strings.SplitN(domain, ".", 2)[0] == app {
|
if strings.SplitN(domain, ".", 2)[0] == app {
|
||||||
tlog.App.Debug().Str("name", app).Msg("Found matching container by app name")
|
tlog.App.Debug().Str("name", app).Msg("Found matching container by app name")
|
||||||
return &config
|
appAcls = &config
|
||||||
|
break // If we find a match by app name, we can stop searching
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
return appAcls
|
||||||
}
|
}
|
||||||
|
|
||||||
func (acls *AccessControlsService) GetAccessControls(domain string) (*model.App, error) {
|
func (acls *AccessControlsService) GetAccessControls(domain string) (*model.App, error) {
|
||||||
|
|||||||
@@ -73,7 +73,7 @@ type Lockdown struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type AuthServiceConfig struct {
|
type AuthServiceConfig struct {
|
||||||
LocalUsers []model.LocalUser
|
LocalUsers *[]model.LocalUser
|
||||||
OauthWhitelist []string
|
OauthWhitelist []string
|
||||||
SessionExpiry int
|
SessionExpiry int
|
||||||
SessionMaxLifetime int
|
SessionMaxLifetime int
|
||||||
@@ -147,6 +147,9 @@ func (auth *AuthService) CheckUserPassword(search model.UserSearch, password str
|
|||||||
switch search.Type {
|
switch search.Type {
|
||||||
case model.UserLocal:
|
case model.UserLocal:
|
||||||
user := auth.GetLocalUser(search.Username)
|
user := auth.GetLocalUser(search.Username)
|
||||||
|
if user == nil {
|
||||||
|
return ErrUserNotFound
|
||||||
|
}
|
||||||
return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password))
|
return bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password))
|
||||||
case model.UserLDAP:
|
case model.UserLDAP:
|
||||||
if auth.ldap.IsConfigured() {
|
if auth.ldap.IsConfigured() {
|
||||||
@@ -169,7 +172,10 @@ func (auth *AuthService) CheckUserPassword(search model.UserSearch, password str
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) GetLocalUser(username string) *model.LocalUser {
|
func (auth *AuthService) GetLocalUser(username string) *model.LocalUser {
|
||||||
for _, user := range auth.config.LocalUsers {
|
if auth.config.LocalUsers == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
for _, user := range *auth.config.LocalUsers {
|
||||||
if user.Username == username {
|
if user.Username == username {
|
||||||
return &user
|
return &user
|
||||||
}
|
}
|
||||||
@@ -438,7 +444,7 @@ func (auth *AuthService) GetSession(ctx context.Context, uuid string) (*reposito
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) LocalAuthConfigured() bool {
|
func (auth *AuthService) LocalAuthConfigured() bool {
|
||||||
return len(auth.config.LocalUsers) > 0
|
return auth.config.LocalUsers != nil && len(*auth.config.LocalUsers) > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (auth *AuthService) LDAPAuthConfigured() bool {
|
func (auth *AuthService) LDAPAuthConfigured() bool {
|
||||||
|
|||||||
@@ -95,7 +95,8 @@ func (k *KubernetesService) getByDomain(domain string) *model.App {
|
|||||||
|
|
||||||
if appKey, ok := k.domainIndex[domain]; ok {
|
if appKey, ok := k.domainIndex[domain]; ok {
|
||||||
if apps, ok := k.ingressApps[appKey.ingressKey]; ok {
|
if apps, ok := k.ingressApps[appKey.ingressKey]; ok {
|
||||||
for _, app := range apps {
|
for i := range apps {
|
||||||
|
app := &apps[i]
|
||||||
if app.domain == domain && app.appName == appKey.appName {
|
if app.domain == domain && app.appName == appKey.appName {
|
||||||
return &app.app
|
return &app.app
|
||||||
}
|
}
|
||||||
@@ -111,7 +112,8 @@ func (k *KubernetesService) getByAppName(appName string) *model.App {
|
|||||||
|
|
||||||
if appKey, ok := k.appNameIndex[appName]; ok {
|
if appKey, ok := k.appNameIndex[appName]; ok {
|
||||||
if apps, ok := k.ingressApps[appKey.ingressKey]; ok {
|
if apps, ok := k.ingressApps[appKey.ingressKey]; ok {
|
||||||
for _, app := range apps {
|
for i := range apps {
|
||||||
|
app := &apps[i]
|
||||||
if app.appName == appName {
|
if app.appName == appName {
|
||||||
return &app.app
|
return &app.app
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user