refactor: remove init functions from methods (#228)

This commit is contained in:
Stavros
2025-07-04 02:35:09 +03:00
committed by GitHub
parent 49c4c7a455
commit 1941de1125
10 changed files with 147 additions and 186 deletions

View File

@@ -8,13 +8,13 @@ import (
"time" "time"
totpCmd "tinyauth/cmd/totp" totpCmd "tinyauth/cmd/totp"
userCmd "tinyauth/cmd/user" userCmd "tinyauth/cmd/user"
"tinyauth/internal/api"
"tinyauth/internal/auth" "tinyauth/internal/auth"
"tinyauth/internal/constants" "tinyauth/internal/constants"
"tinyauth/internal/docker" "tinyauth/internal/docker"
"tinyauth/internal/handlers" "tinyauth/internal/handlers"
"tinyauth/internal/hooks" "tinyauth/internal/hooks"
"tinyauth/internal/providers" "tinyauth/internal/providers"
"tinyauth/internal/server"
"tinyauth/internal/types" "tinyauth/internal/types"
"tinyauth/internal/utils" "tinyauth/internal/utils"
@@ -114,8 +114,8 @@ var rootCmd = &cobra.Command{
RedirectCookieName: redirectCookieName, RedirectCookieName: redirectCookieName,
} }
// Create api config // Create server config
apiConfig := types.APIConfig{ serverConfig := types.ServerConfig{
Port: config.Port, Port: config.Port,
Address: config.Address, Address: config.Address,
} }
@@ -140,10 +140,7 @@ var rootCmd = &cobra.Command{
} }
// Create docker service // Create docker service
docker := docker.NewDocker() docker, err := docker.NewDocker()
// Initialize docker
err = docker.Init()
HandleError(err, "Failed to initialize docker") HandleError(err, "Failed to initialize docker")
// Create auth service // Create auth service
@@ -152,24 +149,19 @@ var rootCmd = &cobra.Command{
// Create OAuth providers service // Create OAuth providers service
providers := providers.NewProviders(oauthConfig) providers := providers.NewProviders(oauthConfig)
// Initialize providers
providers.Init()
// Create hooks service // Create hooks service
hooks := hooks.NewHooks(hooksConfig, auth, providers) hooks := hooks.NewHooks(hooksConfig, auth, providers)
// Create handlers // Create handlers
handlers := handlers.NewHandlers(handlersConfig, auth, hooks, providers, docker) handlers := handlers.NewHandlers(handlersConfig, auth, hooks, providers, docker)
// Create API // Create server
api := api.NewAPI(apiConfig, handlers) srv, err := server.NewServer(serverConfig, handlers)
HandleError(err, "Failed to create server")
// Setup routes // Start server
api.Init() err = srv.Start()
api.SetupRoutes() HandleError(err, "Failed to start server")
// Start
api.Run()
}, },
} }

View File

@@ -16,36 +16,38 @@ import (
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
) )
func NewAuth(config types.AuthConfig, docker *docker.Docker) *Auth {
return &Auth{
Config: config,
Docker: docker,
LoginAttempts: make(map[string]*types.LoginAttempt),
}
}
type Auth struct { type Auth struct {
Config types.AuthConfig Config types.AuthConfig
Docker *docker.Docker Docker *docker.Docker
LoginAttempts map[string]*types.LoginAttempt LoginAttempts map[string]*types.LoginAttempt
LoginMutex sync.RWMutex LoginMutex sync.RWMutex
Store *sessions.CookieStore
} }
func (auth *Auth) GetSession(c *gin.Context) (*sessions.Session, error) { func NewAuth(config types.AuthConfig, docker *docker.Docker) *Auth {
// Create cookie store // Create cookie store
store := sessions.NewCookieStore([]byte(auth.Config.HMACSecret), []byte(auth.Config.EncryptionSecret)) store := sessions.NewCookieStore([]byte(config.HMACSecret), []byte(config.EncryptionSecret))
// Configure cookie store // Configure cookie store
store.Options = &sessions.Options{ store.Options = &sessions.Options{
Path: "/", Path: "/",
MaxAge: auth.Config.SessionExpiry, MaxAge: config.SessionExpiry,
Secure: auth.Config.CookieSecure, Secure: config.CookieSecure,
HttpOnly: true, HttpOnly: true,
Domain: fmt.Sprintf(".%s", auth.Config.Domain), Domain: fmt.Sprintf(".%s", config.Domain),
} }
return &Auth{
Config: config,
Docker: docker,
LoginAttempts: make(map[string]*types.LoginAttempt),
Store: store,
}
}
func (auth *Auth) GetSession(c *gin.Context) (*sessions.Session, error) {
// Get session // Get session
session, err := store.Get(c.Request, auth.Config.SessionCookieName) session, err := auth.Store.Get(c.Request, auth.Config.SessionCookieName)
if err != nil { if err != nil {
log.Warn().Err(err).Msg("Invalid session, clearing cookie and retrying") log.Warn().Err(err).Msg("Invalid session, clearing cookie and retrying")
@@ -54,7 +56,7 @@ func (auth *Auth) GetSession(c *gin.Context) (*sessions.Session, error) {
c.SetCookie(auth.Config.SessionCookieName, "", -1, "/", fmt.Sprintf(".%s", auth.Config.Domain), auth.Config.CookieSecure, true) c.SetCookie(auth.Config.SessionCookieName, "", -1, "/", fmt.Sprintf(".%s", auth.Config.Domain), auth.Config.CookieSecure, true)
// Try to get the session again // Try to get the session again
session, err = store.Get(c.Request, auth.Config.SessionCookieName) session, err = auth.Store.Get(c.Request, auth.Config.SessionCookieName)
if err != nil { if err != nil {
// If we still can't get the session, log the error and return nil // If we still can't get the session, log the error and return nil

View File

@@ -11,35 +11,30 @@ import (
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
func NewDocker() *Docker {
return &Docker{}
}
type Docker struct { type Docker struct {
Client *client.Client Client *client.Client
Context context.Context Context context.Context
} }
func (docker *Docker) Init() error { func NewDocker() (*Docker, error) {
// Create a new docker client // Create a new docker client
client, err := client.NewClientWithOpts(client.FromEnv) client, err := client.NewClientWithOpts(client.FromEnv)
// Check if there was an error // Check if there was an error
if err != nil { if err != nil {
return err return nil, err
} }
// Create the context // Create the context
docker.Context = context.Background() ctx := context.Background()
// Negotiate API version // Negotiate API version
client.NegotiateAPIVersion(docker.Context) client.NegotiateAPIVersion(ctx)
// Set client return &Docker{
docker.Client = client Client: client,
Context: ctx,
// Done }, nil
return nil
} }
func (docker *Docker) GetContainers() ([]container.Summary, error) { func (docker *Docker) GetContainers() ([]container.Summary, error) {

View File

@@ -18,6 +18,14 @@ import (
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
type Handlers struct {
Config types.HandlersConfig
Auth *auth.Auth
Hooks *hooks.Hooks
Providers *providers.Providers
Docker *docker.Docker
}
func NewHandlers(config types.HandlersConfig, auth *auth.Auth, hooks *hooks.Hooks, providers *providers.Providers, docker *docker.Docker) *Handlers { func NewHandlers(config types.HandlersConfig, auth *auth.Auth, hooks *hooks.Hooks, providers *providers.Providers, docker *docker.Docker) *Handlers {
return &Handlers{ return &Handlers{
Config: config, Config: config,
@@ -28,14 +36,6 @@ func NewHandlers(config types.HandlersConfig, auth *auth.Auth, hooks *hooks.Hook
} }
} }
type Handlers struct {
Config types.HandlersConfig
Auth *auth.Auth
Hooks *hooks.Hooks
Providers *providers.Providers
Docker *docker.Docker
}
func (h *Handlers) AuthHandler(c *gin.Context) { func (h *Handlers) AuthHandler(c *gin.Context) {
// Create struct for proxy // Create struct for proxy
var proxy types.Proxy var proxy types.Proxy

View File

@@ -12,6 +12,12 @@ import (
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
type Hooks struct {
Config types.HooksConfig
Auth *auth.Auth
Providers *providers.Providers
}
func NewHooks(config types.HooksConfig, auth *auth.Auth, providers *providers.Providers) *Hooks { func NewHooks(config types.HooksConfig, auth *auth.Auth, providers *providers.Providers) *Hooks {
return &Hooks{ return &Hooks{
Config: config, Config: config,
@@ -20,12 +26,6 @@ func NewHooks(config types.HooksConfig, auth *auth.Auth, providers *providers.Pr
} }
} }
type Hooks struct {
Config types.HooksConfig
Auth *auth.Auth
Providers *providers.Providers
}
func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext { func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext {
// Get session cookie and basic auth // Get session cookie and basic auth
cookie, err := hooks.Auth.GetSessionCookie(c) cookie, err := hooks.Auth.GetSessionCookie(c)

View File

@@ -10,32 +10,24 @@ import (
"golang.org/x/oauth2" "golang.org/x/oauth2"
) )
func NewOAuth(config oauth2.Config, insecureSkipVerify bool) *OAuth {
return &OAuth{
Config: config,
InsecureSkipVerify: insecureSkipVerify,
}
}
type OAuth struct { type OAuth struct {
Config oauth2.Config Config oauth2.Config
Context context.Context Context context.Context
Token *oauth2.Token Token *oauth2.Token
Verifier string Verifier string
InsecureSkipVerify bool
} }
func (oauth *OAuth) Init() { func NewOAuth(config oauth2.Config, insecureSkipVerify bool) *OAuth {
// Create transport with TLS // Create transport with TLS
transport := &http.Transport{ transport := &http.Transport{
TLSClientConfig: &tls.Config{ TLSClientConfig: &tls.Config{
InsecureSkipVerify: oauth.InsecureSkipVerify, InsecureSkipVerify: insecureSkipVerify,
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
}, },
} }
// Create a new context // Create a new context
oauth.Context = context.Background() ctx := context.Background()
// Create the HTTP client with the transport // Create the HTTP client with the transport
httpClient := &http.Client{ httpClient := &http.Client{
@@ -43,9 +35,16 @@ func (oauth *OAuth) Init() {
} }
// Set the HTTP client in the context // Set the HTTP client in the context
oauth.Context = context.WithValue(oauth.Context, oauth2.HTTPClient, httpClient) ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
// Create the verifier // Create the verifier
oauth.Verifier = oauth2.GenerateVerifier() verifier := oauth2.GenerateVerifier()
return &OAuth{
Config: config,
Context: ctx,
Verifier: verifier,
}
} }
func (oauth *OAuth) GetAuthURL(state string) string { func (oauth *OAuth) GetAuthURL(state string) string {

View File

@@ -11,12 +11,6 @@ import (
"golang.org/x/oauth2/endpoints" "golang.org/x/oauth2/endpoints"
) )
func NewProviders(config types.OAuthConfig) *Providers {
return &Providers{
Config: config,
}
}
type Providers struct { type Providers struct {
Config types.OAuthConfig Config types.OAuthConfig
Github *oauth.OAuth Github *oauth.OAuth
@@ -24,60 +18,57 @@ type Providers struct {
Generic *oauth.OAuth Generic *oauth.OAuth
} }
func (providers *Providers) Init() { func NewProviders(config types.OAuthConfig) *Providers {
providers := &Providers{
Config: config,
}
// If we have a client id and secret for github, initialize the oauth provider // If we have a client id and secret for github, initialize the oauth provider
if providers.Config.GithubClientId != "" && providers.Config.GithubClientSecret != "" { if config.GithubClientId != "" && config.GithubClientSecret != "" {
log.Info().Msg("Initializing Github OAuth") log.Info().Msg("Initializing Github OAuth")
// Create a new oauth provider with the github config // Create a new oauth provider with the github config
providers.Github = oauth.NewOAuth(oauth2.Config{ providers.Github = oauth.NewOAuth(oauth2.Config{
ClientID: providers.Config.GithubClientId, ClientID: config.GithubClientId,
ClientSecret: providers.Config.GithubClientSecret, ClientSecret: config.GithubClientSecret,
RedirectURL: fmt.Sprintf("%s/api/oauth/callback/github", providers.Config.AppURL), RedirectURL: fmt.Sprintf("%s/api/oauth/callback/github", config.AppURL),
Scopes: GithubScopes(), Scopes: GithubScopes(),
Endpoint: endpoints.GitHub, Endpoint: endpoints.GitHub,
}, false) }, false)
// Initialize the oauth provider
providers.Github.Init()
} }
// If we have a client id and secret for google, initialize the oauth provider // If we have a client id and secret for google, initialize the oauth provider
if providers.Config.GoogleClientId != "" && providers.Config.GoogleClientSecret != "" { if config.GoogleClientId != "" && config.GoogleClientSecret != "" {
log.Info().Msg("Initializing Google OAuth") log.Info().Msg("Initializing Google OAuth")
// Create a new oauth provider with the google config // Create a new oauth provider with the google config
providers.Google = oauth.NewOAuth(oauth2.Config{ providers.Google = oauth.NewOAuth(oauth2.Config{
ClientID: providers.Config.GoogleClientId, ClientID: config.GoogleClientId,
ClientSecret: providers.Config.GoogleClientSecret, ClientSecret: config.GoogleClientSecret,
RedirectURL: fmt.Sprintf("%s/api/oauth/callback/google", providers.Config.AppURL), RedirectURL: fmt.Sprintf("%s/api/oauth/callback/google", config.AppURL),
Scopes: GoogleScopes(), Scopes: GoogleScopes(),
Endpoint: endpoints.Google, Endpoint: endpoints.Google,
}, false) }, false)
// Initialize the oauth provider
providers.Google.Init()
} }
// If we have a client id and secret for generic oauth, initialize the oauth provider // If we have a client id and secret for generic oauth, initialize the oauth provider
if providers.Config.GenericClientId != "" && providers.Config.GenericClientSecret != "" { if config.GenericClientId != "" && config.GenericClientSecret != "" {
log.Info().Msg("Initializing Generic OAuth") log.Info().Msg("Initializing Generic OAuth")
// Create a new oauth provider with the generic config // Create a new oauth provider with the generic config
providers.Generic = oauth.NewOAuth(oauth2.Config{ providers.Generic = oauth.NewOAuth(oauth2.Config{
ClientID: providers.Config.GenericClientId, ClientID: config.GenericClientId,
ClientSecret: providers.Config.GenericClientSecret, ClientSecret: config.GenericClientSecret,
RedirectURL: fmt.Sprintf("%s/api/oauth/callback/generic", providers.Config.AppURL), RedirectURL: fmt.Sprintf("%s/api/oauth/callback/generic", config.AppURL),
Scopes: providers.Config.GenericScopes, Scopes: config.GenericScopes,
Endpoint: oauth2.Endpoint{ Endpoint: oauth2.Endpoint{
AuthURL: providers.Config.GenericAuthURL, AuthURL: config.GenericAuthURL,
TokenURL: providers.Config.GenericTokenURL, TokenURL: config.GenericTokenURL,
}, },
}, providers.Config.GenericSkipSSL) }, config.GenericSkipSSL)
// Initialize the oauth provider
providers.Generic.Init()
} }
return providers
} }
func (providers *Providers) GetProvider(provider string) *oauth.OAuth { func (providers *Providers) GetProvider(provider string) *oauth.OAuth {

View File

@@ -1,4 +1,4 @@
package api package server
import ( import (
"fmt" "fmt"
@@ -15,20 +15,13 @@ import (
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
func NewAPI(config types.APIConfig, handlers *handlers.Handlers) *API { type Server struct {
return &API{ Config types.ServerConfig
Config: config,
Handlers: handlers,
}
}
type API struct {
Config types.APIConfig
Router *gin.Engine
Handlers *handlers.Handlers Handlers *handlers.Handlers
Router *gin.Engine
} }
func (api *API) Init() { func NewServer(config types.ServerConfig, handlers *handlers.Handlers) (*Server, error) {
// Disable gin logs // Disable gin logs
gin.SetMode(gin.ReleaseMode) gin.SetMode(gin.ReleaseMode)
@@ -42,7 +35,7 @@ func (api *API) Init() {
dist, err := fs.Sub(assets.Assets, "dist") dist, err := fs.Sub(assets.Assets, "dist")
if err != nil { if err != nil {
log.Fatal().Err(err).Msg("Failed to get UI assets") return nil, err
} }
// Create file server // Create file server
@@ -69,41 +62,38 @@ func (api *API) Init() {
} }
}) })
// Set router // Proxy routes
api.Router = router router.GET("/api/auth/:proxy", handlers.AuthHandler)
// Auth routes
router.POST("/api/login", handlers.LoginHandler)
router.POST("/api/totp", handlers.TotpHandler)
router.POST("/api/logout", handlers.LogoutHandler)
// Context routes
router.GET("/api/app", handlers.AppHandler)
router.GET("/api/user", handlers.UserHandler)
// OAuth routes
router.GET("/api/oauth/url/:provider", handlers.OauthUrlHandler)
router.GET("/api/oauth/callback/:provider", handlers.OauthCallbackHandler)
// App routes
router.GET("/api/healthcheck", handlers.HealthcheckHandler)
// Return the server
return &Server{
Config: config,
Handlers: handlers,
Router: router,
}, nil
} }
func (api *API) SetupRoutes() { func (s *Server) Start() error {
// Proxy
api.Router.GET("/api/auth/:proxy", api.Handlers.AuthHandler)
// Auth
api.Router.POST("/api/login", api.Handlers.LoginHandler)
api.Router.POST("/api/totp", api.Handlers.TotpHandler)
api.Router.POST("/api/logout", api.Handlers.LogoutHandler)
// Context
api.Router.GET("/api/app", api.Handlers.AppHandler)
api.Router.GET("/api/user", api.Handlers.UserHandler)
// OAuth
api.Router.GET("/api/oauth/url/:provider", api.Handlers.OauthUrlHandler)
api.Router.GET("/api/oauth/callback/:provider", api.Handlers.OauthCallbackHandler)
// App
api.Router.GET("/api/healthcheck", api.Handlers.HealthcheckHandler)
}
func (api *API) Run() {
log.Info().Str("address", api.Config.Address).Int("port", api.Config.Port).Msg("Starting server")
// Run server // Run server
err := api.Router.Run(fmt.Sprintf("%s:%d", api.Config.Address, api.Config.Port)) log.Info().Str("address", s.Config.Address).Int("port", s.Config.Port).Msg("Starting server")
// Check for errors return s.Router.Run(fmt.Sprintf("%s:%d", s.Config.Address, s.Config.Port))
if err != nil {
log.Fatal().Err(err).Msg("Failed to start server")
}
} }
// zerolog is a middleware for gin that logs requests using zerolog // zerolog is a middleware for gin that logs requests using zerolog

View File

@@ -1,4 +1,4 @@
package api_test package server_test
import ( import (
"encoding/json" "encoding/json"
@@ -8,19 +8,19 @@ import (
"reflect" "reflect"
"strings" "strings"
"testing" "testing"
"tinyauth/internal/api"
"tinyauth/internal/auth" "tinyauth/internal/auth"
"tinyauth/internal/docker" "tinyauth/internal/docker"
"tinyauth/internal/handlers" "tinyauth/internal/handlers"
"tinyauth/internal/hooks" "tinyauth/internal/hooks"
"tinyauth/internal/providers" "tinyauth/internal/providers"
"tinyauth/internal/server"
"tinyauth/internal/types" "tinyauth/internal/types"
"github.com/magiconair/properties/assert" "github.com/magiconair/properties/assert"
) )
// Simple API config for tests // Simple server config for tests
var apiConfig = types.APIConfig{ var serverConfig = types.ServerConfig{
Port: 8080, Port: 8080,
Address: "0.0.0.0", Address: "0.0.0.0",
} }
@@ -68,15 +68,11 @@ var user = types.User{
Password: "$2a$10$AvGHLTYv3xiRJ0xV9xs3XeVIlkGTygI9nqIamFYB5Xu.5.0UWF7B6", // pass Password: "$2a$10$AvGHLTYv3xiRJ0xV9xs3XeVIlkGTygI9nqIamFYB5Xu.5.0UWF7B6", // pass
} }
// We need all this to be able to test the API // We need all this to be able to test the server
func getAPI(t *testing.T) *api.API { func getServer(t *testing.T) *server.Server {
// Create docker service // Create docker service
docker := docker.NewDocker() docker, err := docker.NewDocker()
// Initialize docker
err := docker.Init()
// Check if there was an error
if err != nil { if err != nil {
t.Fatalf("Failed to initialize docker: %v", err) t.Fatalf("Failed to initialize docker: %v", err)
} }
@@ -93,31 +89,29 @@ func getAPI(t *testing.T) *api.API {
// Create providers service // Create providers service
providers := providers.NewProviders(types.OAuthConfig{}) providers := providers.NewProviders(types.OAuthConfig{})
// Initialize providers
providers.Init()
// Create hooks service // Create hooks service
hooks := hooks.NewHooks(hooksConfig, auth, providers) hooks := hooks.NewHooks(hooksConfig, auth, providers)
// Create handlers service // Create handlers service
handlers := handlers.NewHandlers(handlersConfig, auth, hooks, providers, docker) handlers := handlers.NewHandlers(handlersConfig, auth, hooks, providers, docker)
// Create API // Create server
api := api.NewAPI(apiConfig, handlers) srv, err := server.NewServer(serverConfig, handlers)
// Setup routes if err != nil {
api.Init() t.Fatalf("Failed to create server: %v", err)
api.SetupRoutes() }
return api // Return the server
return srv
} }
// Test login (we will need this for the other tests) // Test login
func TestLogin(t *testing.T) { func TestLogin(t *testing.T) {
t.Log("Testing login") t.Log("Testing login")
// Get API // Get server
api := getAPI(t) api := getServer(t)
// Create recorder // Create recorder
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@@ -162,8 +156,8 @@ func TestLogin(t *testing.T) {
func TestAppContext(t *testing.T) { func TestAppContext(t *testing.T) {
t.Log("Testing app context") t.Log("Testing app context")
// Get API // Get server
api := getAPI(t) api := getServer(t)
// Create recorder // Create recorder
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@@ -230,8 +224,8 @@ func TestAppContext(t *testing.T) {
func TestUserContext(t *testing.T) { func TestUserContext(t *testing.T) {
t.Log("Testing user context") t.Log("Testing user context")
// Get API // Get server
api := getAPI(t) api := getServer(t)
// Create recorder // Create recorder
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@@ -288,8 +282,8 @@ func TestUserContext(t *testing.T) {
func TestLogout(t *testing.T) { func TestLogout(t *testing.T) {
t.Log("Testing logout") t.Log("Testing logout")
// Get API // Get server
api := getAPI(t) api := getServer(t)
// Create recorder // Create recorder
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@@ -319,5 +313,3 @@ func TestLogout(t *testing.T) {
t.Fatalf("Cookie not flushed") t.Fatalf("Cookie not flushed")
} }
} }
// TODO: Testing for the oauth stuff

View File

@@ -69,8 +69,8 @@ type OAuthConfig struct {
AppURL string AppURL string
} }
// APIConfig is the configuration for the API // ServerConfig is the configuration for the server
type APIConfig struct { type ServerConfig struct {
Port int Port int
Address string Address string
} }