Compare commits

...

3 Commits

Author SHA1 Message Date
Stavros
59d2bce189 refactor: remove init functions from methods 2025-07-04 02:29:02 +03:00
Stavros
49c4c7a455 feat: add codeconv to ci 2025-07-04 01:48:38 +03:00
Stavros
c10bff55de fix: encrypt the cookie in sessions (#225)
* fix: encrypt the cookie in sessions

* tests: use new auth config in tests

* fix: coderabbit suggestions
2025-07-04 01:43:36 +03:00
12 changed files with 214 additions and 191 deletions

View File

@@ -39,4 +39,9 @@ jobs:
cp -r frontend/dist internal/assets/dist cp -r frontend/dist internal/assets/dist
- name: Run tests - name: Run tests
run: go test -v ./... run: go test -coverprofile=coverage.txt -v ./...
- name: Upload coverage reports to Codecov
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}

View File

@@ -8,13 +8,13 @@ import (
"time" "time"
totpCmd "tinyauth/cmd/totp" totpCmd "tinyauth/cmd/totp"
userCmd "tinyauth/cmd/user" userCmd "tinyauth/cmd/user"
"tinyauth/internal/api"
"tinyauth/internal/auth" "tinyauth/internal/auth"
"tinyauth/internal/constants" "tinyauth/internal/constants"
"tinyauth/internal/docker" "tinyauth/internal/docker"
"tinyauth/internal/handlers" "tinyauth/internal/handlers"
"tinyauth/internal/hooks" "tinyauth/internal/hooks"
"tinyauth/internal/providers" "tinyauth/internal/providers"
"tinyauth/internal/server"
"tinyauth/internal/types" "tinyauth/internal/types"
"tinyauth/internal/utils" "tinyauth/internal/utils"
@@ -74,6 +74,15 @@ var rootCmd = &cobra.Command{
csrfCookieName := fmt.Sprintf("%s-%s", constants.CsrfCookieName, cookieId) csrfCookieName := fmt.Sprintf("%s-%s", constants.CsrfCookieName, cookieId)
redirectCookieName := fmt.Sprintf("%s-%s", constants.RedirectCookieName, cookieId) redirectCookieName := fmt.Sprintf("%s-%s", constants.RedirectCookieName, cookieId)
// Generate HMAC and encryption secrets
log.Debug().Msg("Deriving HMAC and encryption secrets")
hmacSecret, err := utils.DeriveKey(config.Secret, "hmac")
HandleError(err, "Failed to derive HMAC secret")
encryptionSecret, err := utils.DeriveKey(config.Secret, "encryption")
HandleError(err, "Failed to derive encryption secret")
// Create OAuth config // Create OAuth config
oauthConfig := types.OAuthConfig{ oauthConfig := types.OAuthConfig{
GithubClientId: config.GithubClientId, GithubClientId: config.GithubClientId,
@@ -105,8 +114,8 @@ var rootCmd = &cobra.Command{
RedirectCookieName: redirectCookieName, RedirectCookieName: redirectCookieName,
} }
// Create api config // Create server config
apiConfig := types.APIConfig{ serverConfig := types.ServerConfig{
Port: config.Port, Port: config.Port,
Address: config.Address, Address: config.Address,
} }
@@ -115,13 +124,14 @@ var rootCmd = &cobra.Command{
authConfig := types.AuthConfig{ authConfig := types.AuthConfig{
Users: users, Users: users,
OauthWhitelist: config.OAuthWhitelist, OauthWhitelist: config.OAuthWhitelist,
Secret: config.Secret,
CookieSecure: config.CookieSecure, CookieSecure: config.CookieSecure,
SessionExpiry: config.SessionExpiry, SessionExpiry: config.SessionExpiry,
Domain: domain, Domain: domain,
LoginTimeout: config.LoginTimeout, LoginTimeout: config.LoginTimeout,
LoginMaxRetries: config.LoginMaxRetries, LoginMaxRetries: config.LoginMaxRetries,
SessionCookieName: sessionCookieName, SessionCookieName: sessionCookieName,
HMACSecret: hmacSecret,
EncryptionSecret: encryptionSecret,
} }
// Create hooks config // Create hooks config
@@ -130,10 +140,7 @@ var rootCmd = &cobra.Command{
} }
// Create docker service // Create docker service
docker := docker.NewDocker() docker, err := docker.NewDocker()
// Initialize docker
err = docker.Init()
HandleError(err, "Failed to initialize docker") HandleError(err, "Failed to initialize docker")
// Create auth service // Create auth service
@@ -142,24 +149,19 @@ var rootCmd = &cobra.Command{
// Create OAuth providers service // Create OAuth providers service
providers := providers.NewProviders(oauthConfig) providers := providers.NewProviders(oauthConfig)
// Initialize providers
providers.Init()
// Create hooks service // Create hooks service
hooks := hooks.NewHooks(hooksConfig, auth, providers) hooks := hooks.NewHooks(hooksConfig, auth, providers)
// Create handlers // Create handlers
handlers := handlers.NewHandlers(handlersConfig, auth, hooks, providers, docker) handlers := handlers.NewHandlers(handlersConfig, auth, hooks, providers, docker)
// Create API // Create server
api := api.NewAPI(apiConfig, handlers) srv, err := server.NewServer(serverConfig, handlers)
HandleError(err, "Failed to create server")
// Setup routes // Start server
api.Init() err = srv.Start()
api.SetupRoutes() HandleError(err, "Failed to start server")
// Start
api.Run()
}, },
} }

View File

@@ -16,40 +16,54 @@ import (
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
) )
func NewAuth(config types.AuthConfig, docker *docker.Docker) *Auth {
return &Auth{
Config: config,
Docker: docker,
LoginAttempts: make(map[string]*types.LoginAttempt),
}
}
type Auth struct { type Auth struct {
Config types.AuthConfig Config types.AuthConfig
Docker *docker.Docker Docker *docker.Docker
LoginAttempts map[string]*types.LoginAttempt LoginAttempts map[string]*types.LoginAttempt
LoginMutex sync.RWMutex LoginMutex sync.RWMutex
Store *sessions.CookieStore
} }
func (auth *Auth) GetSession(c *gin.Context) (*sessions.Session, error) { func NewAuth(config types.AuthConfig, docker *docker.Docker) *Auth {
// Create cookie store // Create cookie store
store := sessions.NewCookieStore([]byte(auth.Config.Secret)) store := sessions.NewCookieStore([]byte(config.HMACSecret), []byte(config.EncryptionSecret))
// Configure cookie store // Configure cookie store
store.Options = &sessions.Options{ store.Options = &sessions.Options{
Path: "/", Path: "/",
MaxAge: auth.Config.SessionExpiry, MaxAge: config.SessionExpiry,
Secure: auth.Config.CookieSecure, Secure: config.CookieSecure,
HttpOnly: true, HttpOnly: true,
Domain: fmt.Sprintf(".%s", auth.Config.Domain), Domain: fmt.Sprintf(".%s", config.Domain),
} }
return &Auth{
Config: config,
Docker: docker,
LoginAttempts: make(map[string]*types.LoginAttempt),
Store: store,
}
}
func (auth *Auth) GetSession(c *gin.Context) (*sessions.Session, error) {
// Get session // Get session
session, err := store.Get(c.Request, auth.Config.SessionCookieName) session, err := auth.Store.Get(c.Request, auth.Config.SessionCookieName)
if err != nil { if err != nil {
log.Warn().Err(err).Msg("Invalid session, clearing cookie and retrying")
// Delete the session cookie if there is an error
c.SetCookie(auth.Config.SessionCookieName, "", -1, "/", fmt.Sprintf(".%s", auth.Config.Domain), auth.Config.CookieSecure, true)
// Try to get the session again
session, err = auth.Store.Get(c.Request, auth.Config.SessionCookieName)
if err != nil {
// If we still can't get the session, log the error and return nil
log.Error().Err(err).Msg("Failed to get session") log.Error().Err(err).Msg("Failed to get session")
return nil, err return nil, err
} }
}
return session, nil return session, nil
} }

View File

@@ -11,35 +11,30 @@ import (
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
func NewDocker() *Docker {
return &Docker{}
}
type Docker struct { type Docker struct {
Client *client.Client Client *client.Client
Context context.Context Context context.Context
} }
func (docker *Docker) Init() error { func NewDocker() (*Docker, error) {
// Create a new docker client // Create a new docker client
client, err := client.NewClientWithOpts(client.FromEnv) client, err := client.NewClientWithOpts(client.FromEnv)
// Check if there was an error // Check if there was an error
if err != nil { if err != nil {
return err return nil, err
} }
// Create the context // Create the context
docker.Context = context.Background() ctx := context.Background()
// Negotiate API version // Negotiate API version
client.NegotiateAPIVersion(docker.Context) client.NegotiateAPIVersion(ctx)
// Set client return &Docker{
docker.Client = client Client: client,
Context: ctx,
// Done }, nil
return nil
} }
func (docker *Docker) GetContainers() ([]container.Summary, error) { func (docker *Docker) GetContainers() ([]container.Summary, error) {

View File

@@ -18,6 +18,14 @@ import (
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
type Handlers struct {
Config types.HandlersConfig
Auth *auth.Auth
Hooks *hooks.Hooks
Providers *providers.Providers
Docker *docker.Docker
}
func NewHandlers(config types.HandlersConfig, auth *auth.Auth, hooks *hooks.Hooks, providers *providers.Providers, docker *docker.Docker) *Handlers { func NewHandlers(config types.HandlersConfig, auth *auth.Auth, hooks *hooks.Hooks, providers *providers.Providers, docker *docker.Docker) *Handlers {
return &Handlers{ return &Handlers{
Config: config, Config: config,
@@ -28,14 +36,6 @@ func NewHandlers(config types.HandlersConfig, auth *auth.Auth, hooks *hooks.Hook
} }
} }
type Handlers struct {
Config types.HandlersConfig
Auth *auth.Auth
Hooks *hooks.Hooks
Providers *providers.Providers
Docker *docker.Docker
}
func (h *Handlers) AuthHandler(c *gin.Context) { func (h *Handlers) AuthHandler(c *gin.Context) {
// Create struct for proxy // Create struct for proxy
var proxy types.Proxy var proxy types.Proxy

View File

@@ -12,6 +12,12 @@ import (
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
type Hooks struct {
Config types.HooksConfig
Auth *auth.Auth
Providers *providers.Providers
}
func NewHooks(config types.HooksConfig, auth *auth.Auth, providers *providers.Providers) *Hooks { func NewHooks(config types.HooksConfig, auth *auth.Auth, providers *providers.Providers) *Hooks {
return &Hooks{ return &Hooks{
Config: config, Config: config,
@@ -20,12 +26,6 @@ func NewHooks(config types.HooksConfig, auth *auth.Auth, providers *providers.Pr
} }
} }
type Hooks struct {
Config types.HooksConfig
Auth *auth.Auth
Providers *providers.Providers
}
func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext { func (hooks *Hooks) UseUserContext(c *gin.Context) types.UserContext {
// Get session cookie and basic auth // Get session cookie and basic auth
cookie, err := hooks.Auth.GetSessionCookie(c) cookie, err := hooks.Auth.GetSessionCookie(c)

View File

@@ -10,32 +10,24 @@ import (
"golang.org/x/oauth2" "golang.org/x/oauth2"
) )
func NewOAuth(config oauth2.Config, insecureSkipVerify bool) *OAuth {
return &OAuth{
Config: config,
InsecureSkipVerify: insecureSkipVerify,
}
}
type OAuth struct { type OAuth struct {
Config oauth2.Config Config oauth2.Config
Context context.Context Context context.Context
Token *oauth2.Token Token *oauth2.Token
Verifier string Verifier string
InsecureSkipVerify bool
} }
func (oauth *OAuth) Init() { func NewOAuth(config oauth2.Config, insecureSkipVerify bool) *OAuth {
// Create transport with TLS // Create transport with TLS
transport := &http.Transport{ transport := &http.Transport{
TLSClientConfig: &tls.Config{ TLSClientConfig: &tls.Config{
InsecureSkipVerify: oauth.InsecureSkipVerify, InsecureSkipVerify: insecureSkipVerify,
MinVersion: tls.VersionTLS12, MinVersion: tls.VersionTLS12,
}, },
} }
// Create a new context // Create a new context
oauth.Context = context.Background() ctx := context.Background()
// Create the HTTP client with the transport // Create the HTTP client with the transport
httpClient := &http.Client{ httpClient := &http.Client{
@@ -43,9 +35,16 @@ func (oauth *OAuth) Init() {
} }
// Set the HTTP client in the context // Set the HTTP client in the context
oauth.Context = context.WithValue(oauth.Context, oauth2.HTTPClient, httpClient) ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
// Create the verifier // Create the verifier
oauth.Verifier = oauth2.GenerateVerifier() verifier := oauth2.GenerateVerifier()
return &OAuth{
Config: config,
Context: ctx,
Verifier: verifier,
}
} }
func (oauth *OAuth) GetAuthURL(state string) string { func (oauth *OAuth) GetAuthURL(state string) string {

View File

@@ -11,12 +11,6 @@ import (
"golang.org/x/oauth2/endpoints" "golang.org/x/oauth2/endpoints"
) )
func NewProviders(config types.OAuthConfig) *Providers {
return &Providers{
Config: config,
}
}
type Providers struct { type Providers struct {
Config types.OAuthConfig Config types.OAuthConfig
Github *oauth.OAuth Github *oauth.OAuth
@@ -24,60 +18,57 @@ type Providers struct {
Generic *oauth.OAuth Generic *oauth.OAuth
} }
func (providers *Providers) Init() { func NewProviders(config types.OAuthConfig) *Providers {
providers := &Providers{
Config: config,
}
// If we have a client id and secret for github, initialize the oauth provider // If we have a client id and secret for github, initialize the oauth provider
if providers.Config.GithubClientId != "" && providers.Config.GithubClientSecret != "" { if config.GithubClientId != "" && config.GithubClientSecret != "" {
log.Info().Msg("Initializing Github OAuth") log.Info().Msg("Initializing Github OAuth")
// Create a new oauth provider with the github config // Create a new oauth provider with the github config
providers.Github = oauth.NewOAuth(oauth2.Config{ providers.Github = oauth.NewOAuth(oauth2.Config{
ClientID: providers.Config.GithubClientId, ClientID: config.GithubClientId,
ClientSecret: providers.Config.GithubClientSecret, ClientSecret: config.GithubClientSecret,
RedirectURL: fmt.Sprintf("%s/api/oauth/callback/github", providers.Config.AppURL), RedirectURL: fmt.Sprintf("%s/api/oauth/callback/github", config.AppURL),
Scopes: GithubScopes(), Scopes: GithubScopes(),
Endpoint: endpoints.GitHub, Endpoint: endpoints.GitHub,
}, false) }, false)
// Initialize the oauth provider
providers.Github.Init()
} }
// If we have a client id and secret for google, initialize the oauth provider // If we have a client id and secret for google, initialize the oauth provider
if providers.Config.GoogleClientId != "" && providers.Config.GoogleClientSecret != "" { if config.GoogleClientId != "" && config.GoogleClientSecret != "" {
log.Info().Msg("Initializing Google OAuth") log.Info().Msg("Initializing Google OAuth")
// Create a new oauth provider with the google config // Create a new oauth provider with the google config
providers.Google = oauth.NewOAuth(oauth2.Config{ providers.Google = oauth.NewOAuth(oauth2.Config{
ClientID: providers.Config.GoogleClientId, ClientID: config.GoogleClientId,
ClientSecret: providers.Config.GoogleClientSecret, ClientSecret: config.GoogleClientSecret,
RedirectURL: fmt.Sprintf("%s/api/oauth/callback/google", providers.Config.AppURL), RedirectURL: fmt.Sprintf("%s/api/oauth/callback/google", config.AppURL),
Scopes: GoogleScopes(), Scopes: GoogleScopes(),
Endpoint: endpoints.Google, Endpoint: endpoints.Google,
}, false) }, false)
// Initialize the oauth provider
providers.Google.Init()
} }
// If we have a client id and secret for generic oauth, initialize the oauth provider // If we have a client id and secret for generic oauth, initialize the oauth provider
if providers.Config.GenericClientId != "" && providers.Config.GenericClientSecret != "" { if config.GenericClientId != "" && config.GenericClientSecret != "" {
log.Info().Msg("Initializing Generic OAuth") log.Info().Msg("Initializing Generic OAuth")
// Create a new oauth provider with the generic config // Create a new oauth provider with the generic config
providers.Generic = oauth.NewOAuth(oauth2.Config{ providers.Generic = oauth.NewOAuth(oauth2.Config{
ClientID: providers.Config.GenericClientId, ClientID: config.GenericClientId,
ClientSecret: providers.Config.GenericClientSecret, ClientSecret: config.GenericClientSecret,
RedirectURL: fmt.Sprintf("%s/api/oauth/callback/generic", providers.Config.AppURL), RedirectURL: fmt.Sprintf("%s/api/oauth/callback/generic", config.AppURL),
Scopes: providers.Config.GenericScopes, Scopes: config.GenericScopes,
Endpoint: oauth2.Endpoint{ Endpoint: oauth2.Endpoint{
AuthURL: providers.Config.GenericAuthURL, AuthURL: config.GenericAuthURL,
TokenURL: providers.Config.GenericTokenURL, TokenURL: config.GenericTokenURL,
}, },
}, providers.Config.GenericSkipSSL) }, config.GenericSkipSSL)
// Initialize the oauth provider
providers.Generic.Init()
} }
return providers
} }
func (providers *Providers) GetProvider(provider string) *oauth.OAuth { func (providers *Providers) GetProvider(provider string) *oauth.OAuth {

View File

@@ -1,4 +1,4 @@
package api package server
import ( import (
"fmt" "fmt"
@@ -15,20 +15,13 @@ import (
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
) )
func NewAPI(config types.APIConfig, handlers *handlers.Handlers) *API { type Server struct {
return &API{ Config types.ServerConfig
Config: config,
Handlers: handlers,
}
}
type API struct {
Config types.APIConfig
Router *gin.Engine
Handlers *handlers.Handlers Handlers *handlers.Handlers
Router *gin.Engine
} }
func (api *API) Init() { func NewServer(config types.ServerConfig, handlers *handlers.Handlers) (*Server, error) {
// Disable gin logs // Disable gin logs
gin.SetMode(gin.ReleaseMode) gin.SetMode(gin.ReleaseMode)
@@ -42,7 +35,7 @@ func (api *API) Init() {
dist, err := fs.Sub(assets.Assets, "dist") dist, err := fs.Sub(assets.Assets, "dist")
if err != nil { if err != nil {
log.Fatal().Err(err).Msg("Failed to get UI assets") return nil, err
} }
// Create file server // Create file server
@@ -69,41 +62,38 @@ func (api *API) Init() {
} }
}) })
// Set router // Proxy routes
api.Router = router router.GET("/api/auth/:proxy", handlers.AuthHandler)
// Auth routes
router.POST("/api/login", handlers.LoginHandler)
router.POST("/api/totp", handlers.TotpHandler)
router.POST("/api/logout", handlers.LogoutHandler)
// Context routes
router.GET("/api/app", handlers.AppHandler)
router.GET("/api/user", handlers.UserHandler)
// OAuth routes
router.GET("/api/oauth/url/:provider", handlers.OauthUrlHandler)
router.GET("/api/oauth/callback/:provider", handlers.OauthCallbackHandler)
// App routes
router.GET("/api/healthcheck", handlers.HealthcheckHandler)
// Return the server
return &Server{
Config: config,
Handlers: handlers,
Router: router,
}, nil
} }
func (api *API) SetupRoutes() { func (s *Server) Start() error {
// Proxy
api.Router.GET("/api/auth/:proxy", api.Handlers.AuthHandler)
// Auth
api.Router.POST("/api/login", api.Handlers.LoginHandler)
api.Router.POST("/api/totp", api.Handlers.TotpHandler)
api.Router.POST("/api/logout", api.Handlers.LogoutHandler)
// Context
api.Router.GET("/api/app", api.Handlers.AppHandler)
api.Router.GET("/api/user", api.Handlers.UserHandler)
// OAuth
api.Router.GET("/api/oauth/url/:provider", api.Handlers.OauthUrlHandler)
api.Router.GET("/api/oauth/callback/:provider", api.Handlers.OauthCallbackHandler)
// App
api.Router.GET("/api/healthcheck", api.Handlers.HealthcheckHandler)
}
func (api *API) Run() {
log.Info().Str("address", api.Config.Address).Int("port", api.Config.Port).Msg("Starting server")
// Run server // Run server
err := api.Router.Run(fmt.Sprintf("%s:%d", api.Config.Address, api.Config.Port)) log.Info().Str("address", s.Config.Address).Int("port", s.Config.Port).Msg("Starting server")
// Check for errors return s.Router.Run(fmt.Sprintf("%s:%d", s.Config.Address, s.Config.Port))
if err != nil {
log.Fatal().Err(err).Msg("Failed to start server")
}
} }
// zerolog is a middleware for gin that logs requests using zerolog // zerolog is a middleware for gin that logs requests using zerolog

View File

@@ -1,4 +1,4 @@
package api_test package server_test
import ( import (
"encoding/json" "encoding/json"
@@ -8,19 +8,19 @@ import (
"reflect" "reflect"
"strings" "strings"
"testing" "testing"
"tinyauth/internal/api"
"tinyauth/internal/auth" "tinyauth/internal/auth"
"tinyauth/internal/docker" "tinyauth/internal/docker"
"tinyauth/internal/handlers" "tinyauth/internal/handlers"
"tinyauth/internal/hooks" "tinyauth/internal/hooks"
"tinyauth/internal/providers" "tinyauth/internal/providers"
"tinyauth/internal/server"
"tinyauth/internal/types" "tinyauth/internal/types"
"github.com/magiconair/properties/assert" "github.com/magiconair/properties/assert"
) )
// Simple API config for tests // Simple server config for tests
var apiConfig = types.APIConfig{ var serverConfig = types.ServerConfig{
Port: 8080, Port: 8080,
Address: "0.0.0.0", Address: "0.0.0.0",
} }
@@ -44,7 +44,8 @@ var handlersConfig = types.HandlersConfig{
var authConfig = types.AuthConfig{ var authConfig = types.AuthConfig{
Users: types.Users{}, Users: types.Users{},
OauthWhitelist: "", OauthWhitelist: "",
Secret: "super-secret-api-thing-for-tests", // It is 32 chars long HMACSecret: "super-secret-api-thing-for-test1",
EncryptionSecret: "super-secret-api-thing-for-test2",
CookieSecure: false, CookieSecure: false,
SessionExpiry: 3600, SessionExpiry: 3600,
LoginTimeout: 0, LoginTimeout: 0,
@@ -67,15 +68,11 @@ var user = types.User{
Password: "$2a$10$AvGHLTYv3xiRJ0xV9xs3XeVIlkGTygI9nqIamFYB5Xu.5.0UWF7B6", // pass Password: "$2a$10$AvGHLTYv3xiRJ0xV9xs3XeVIlkGTygI9nqIamFYB5Xu.5.0UWF7B6", // pass
} }
// We need all this to be able to test the API // We need all this to be able to test the server
func getAPI(t *testing.T) *api.API { func getServer(t *testing.T) *server.Server {
// Create docker service // Create docker service
docker := docker.NewDocker() docker, err := docker.NewDocker()
// Initialize docker
err := docker.Init()
// Check if there was an error
if err != nil { if err != nil {
t.Fatalf("Failed to initialize docker: %v", err) t.Fatalf("Failed to initialize docker: %v", err)
} }
@@ -92,31 +89,29 @@ func getAPI(t *testing.T) *api.API {
// Create providers service // Create providers service
providers := providers.NewProviders(types.OAuthConfig{}) providers := providers.NewProviders(types.OAuthConfig{})
// Initialize providers
providers.Init()
// Create hooks service // Create hooks service
hooks := hooks.NewHooks(hooksConfig, auth, providers) hooks := hooks.NewHooks(hooksConfig, auth, providers)
// Create handlers service // Create handlers service
handlers := handlers.NewHandlers(handlersConfig, auth, hooks, providers, docker) handlers := handlers.NewHandlers(handlersConfig, auth, hooks, providers, docker)
// Create API // Create server
api := api.NewAPI(apiConfig, handlers) srv, err := server.NewServer(serverConfig, handlers)
// Setup routes if err != nil {
api.Init() t.Fatalf("Failed to create server: %v", err)
api.SetupRoutes()
return api
} }
// Test login (we will need this for the other tests) // Return the server
return srv
}
// Test login
func TestLogin(t *testing.T) { func TestLogin(t *testing.T) {
t.Log("Testing login") t.Log("Testing login")
// Get API // Get server
api := getAPI(t) api := getServer(t)
// Create recorder // Create recorder
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@@ -161,8 +156,8 @@ func TestLogin(t *testing.T) {
func TestAppContext(t *testing.T) { func TestAppContext(t *testing.T) {
t.Log("Testing app context") t.Log("Testing app context")
// Get API // Get server
api := getAPI(t) api := getServer(t)
// Create recorder // Create recorder
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@@ -229,8 +224,8 @@ func TestAppContext(t *testing.T) {
func TestUserContext(t *testing.T) { func TestUserContext(t *testing.T) {
t.Log("Testing user context") t.Log("Testing user context")
// Get API // Get server
api := getAPI(t) api := getServer(t)
// Create recorder // Create recorder
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@@ -287,8 +282,8 @@ func TestUserContext(t *testing.T) {
func TestLogout(t *testing.T) { func TestLogout(t *testing.T) {
t.Log("Testing logout") t.Log("Testing logout")
// Get API // Get server
api := getAPI(t) api := getServer(t)
// Create recorder // Create recorder
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@@ -318,5 +313,3 @@ func TestLogout(t *testing.T) {
t.Fatalf("Cookie not flushed") t.Fatalf("Cookie not flushed")
} }
} }
// TODO: Testing for the oauth stuff

View File

@@ -69,8 +69,8 @@ type OAuthConfig struct {
AppURL string AppURL string
} }
// APIConfig is the configuration for the API // ServerConfig is the configuration for the server
type APIConfig struct { type ServerConfig struct {
Port int Port int
Address string Address string
} }
@@ -80,12 +80,13 @@ type AuthConfig struct {
Users Users Users Users
OauthWhitelist string OauthWhitelist string
SessionExpiry int SessionExpiry int
Secret string
CookieSecure bool CookieSecure bool
Domain string Domain string
LoginTimeout int LoginTimeout int
LoginMaxRetries int LoginMaxRetries int
SessionCookieName string SessionCookieName string
HMACSecret string
EncryptionSecret string
} }
// HooksConfig is the configuration for the hooks service // HooksConfig is the configuration for the hooks service

View File

@@ -1,8 +1,11 @@
package utils package utils
import ( import (
"bytes"
"crypto/sha256"
"encoding/base64" "encoding/base64"
"errors" "errors"
"io"
"net" "net"
"net/url" "net/url"
"os" "os"
@@ -11,6 +14,7 @@ import (
"tinyauth/internal/types" "tinyauth/internal/types"
"github.com/traefik/paerser/parser" "github.com/traefik/paerser/parser"
"golang.org/x/crypto/hkdf"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/rs/zerolog/log" "github.com/rs/zerolog/log"
@@ -405,3 +409,32 @@ func FilterIP(filter string, ip string) (bool, error) {
// If the filter is not a CIDR range or a single IP, return false // If the filter is not a CIDR range or a single IP, return false
return false, nil return false, nil
} }
func DeriveKey(secret string, info string) (string, error) {
// Create hashing function
hash := sha256.New
// Create a new key using the secret and info
hkdf := hkdf.New(hash, []byte(secret), nil, []byte(info)) // I am not using a salt because I just want two different keys from one secret, maybe bad practice
// Create a new key
key := make([]byte, 24)
// Read the key from the HKDF
_, err := io.ReadFull(hkdf, key)
if err != nil {
return "", err
}
// Verify the key is not empty
if bytes.Equal(key, make([]byte, 24)) {
return "", errors.New("derived key is empty")
}
// Encode the key to base64
encodedKey := base64.StdEncoding.EncodeToString(key)
// Return the key as a base64 encoded string
return encodedKey, nil
}