mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2025-12-25 17:42:30 +00:00
Compare commits
3 Commits
pushpinder
...
refactor/s
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c8e86d8536 | ||
|
|
7269fa1b95 | ||
|
|
ef25872fc3 |
3
.coderabbit.yaml
Normal file
3
.coderabbit.yaml
Normal file
@@ -0,0 +1,3 @@
|
||||
issue_enrichment:
|
||||
auto_enrich:
|
||||
enabled: false
|
||||
@@ -23,7 +23,7 @@ cd tinyauth
|
||||
Although you will not need the requirements in your machine since the development will happen in docker, I still recommend to install them because this way you will not have import errors. To install the go requirements run:
|
||||
|
||||
```sh
|
||||
go mod tidy
|
||||
go mod download
|
||||
```
|
||||
|
||||
You also need to download the frontend dependencies, this can be done like so:
|
||||
|
||||
@@ -42,7 +42,6 @@ services:
|
||||
volumes:
|
||||
- ./internal:/tinyauth/internal
|
||||
- ./cmd:/tinyauth/cmd
|
||||
- ./main.go:/tinyauth/main.go
|
||||
- /var/run/docker.sock:/var/run/docker.sock
|
||||
- ./data:/data
|
||||
ports:
|
||||
|
||||
7
go.mod
7
go.mod
@@ -7,7 +7,6 @@ toolchain go1.24.3
|
||||
require (
|
||||
github.com/cenkalti/backoff/v5 v5.0.3
|
||||
github.com/gin-gonic/gin v1.11.0
|
||||
github.com/glebarez/sqlite v1.11.0
|
||||
github.com/golang-migrate/migrate/v4 v4.19.1
|
||||
github.com/google/go-querystring v1.1.0
|
||||
github.com/google/uuid v1.6.0
|
||||
@@ -17,8 +16,8 @@ require (
|
||||
github.com/weppos/publicsuffix-go v0.50.1
|
||||
golang.org/x/crypto v0.46.0
|
||||
golang.org/x/exp v0.0.0-20250620022241-b7579e27df2b
|
||||
gorm.io/gorm v1.31.1
|
||||
gotest.tools/v3 v3.5.2
|
||||
modernc.org/sqlite v1.38.2
|
||||
)
|
||||
|
||||
require (
|
||||
@@ -32,15 +31,12 @@ require (
|
||||
github.com/containerd/errdefs v1.0.0 // indirect
|
||||
github.com/containerd/errdefs/pkg v0.3.0 // indirect
|
||||
github.com/containerd/log v0.1.0 // indirect
|
||||
github.com/glebarez/go-sqlite v1.21.2 // indirect
|
||||
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 // indirect
|
||||
github.com/go-playground/validator/v10 v10.28.0 // indirect
|
||||
github.com/goccy/go-yaml v1.18.0 // indirect
|
||||
github.com/google/go-cmp v0.7.0 // indirect
|
||||
github.com/huandu/xstrings v1.5.0 // indirect
|
||||
github.com/imdario/mergo v0.3.11 // indirect
|
||||
github.com/jinzhu/inflection v1.0.0 // indirect
|
||||
github.com/jinzhu/now v1.1.5 // indirect
|
||||
github.com/mattn/go-sqlite3 v1.14.32 // indirect
|
||||
github.com/mitchellh/copystructure v1.2.0 // indirect
|
||||
github.com/mitchellh/reflectwalk v1.0.2 // indirect
|
||||
@@ -59,7 +55,6 @@ require (
|
||||
modernc.org/libc v1.66.3 // indirect
|
||||
modernc.org/mathutil v1.7.1 // indirect
|
||||
modernc.org/memory v1.11.0 // indirect
|
||||
modernc.org/sqlite v1.38.2 // indirect
|
||||
rsc.io/qr v0.2.0 // indirect
|
||||
)
|
||||
|
||||
|
||||
10
go.sum
10
go.sum
@@ -101,10 +101,6 @@ github.com/gin-contrib/sse v1.1.0 h1:n0w2GMuUpWDVp7qSpvze6fAu9iRxJY4Hmj6AmBOU05w
|
||||
github.com/gin-contrib/sse v1.1.0/go.mod h1:hxRZ5gVpWMT7Z0B0gSNYqqsSCNIJMjzvm6fqCz9vjwM=
|
||||
github.com/gin-gonic/gin v1.11.0 h1:OW/6PLjyusp2PPXtyxKHU0RbX6I/l28FTdDlae5ueWk=
|
||||
github.com/gin-gonic/gin v1.11.0/go.mod h1:+iq/FyxlGzII0KHiBGjuNn4UNENUlKbGlNmc+W50Dls=
|
||||
github.com/glebarez/go-sqlite v1.21.2 h1:3a6LFC4sKahUunAmynQKLZceZCOzUthkRkEAl9gAXWo=
|
||||
github.com/glebarez/go-sqlite v1.21.2/go.mod h1:sfxdZyhQjTM2Wry3gVYWaW072Ri1WMdWJi0k6+3382k=
|
||||
github.com/glebarez/sqlite v1.11.0 h1:wSG0irqzP6VurnMEpFGer5Li19RpIRi2qvQz++w0GMw=
|
||||
github.com/glebarez/sqlite v1.11.0/go.mod h1:h8/o8j5wiAsqSPoWELDUdJXhjAhsVliSn7bWZjOhrgQ=
|
||||
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 h1:BP4M0CvQ4S3TGls2FvczZtj5Re/2ZzkV9VwqPHH/3Bo=
|
||||
github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667/go.mod h1:hEBeB/ic+5LoWskz+yKT7vGhhPYkProFKoKdwZRWMe0=
|
||||
github.com/go-ldap/ldap/v3 v3.4.12 h1:1b81mv7MagXZ7+1r7cLTWmyuTqVqdwbtJSjC0DAp9s4=
|
||||
@@ -161,10 +157,6 @@ github.com/jcmturner/gokrb5/v8 v8.4.4 h1:x1Sv4HaTpepFkXbt2IkL29DXRf8sOfZXo8eRKh6
|
||||
github.com/jcmturner/gokrb5/v8 v8.4.4/go.mod h1:1btQEpgT6k+unzCwX1KdWMEwPPkkgBtP+F6aCACiMrs=
|
||||
github.com/jcmturner/rpc/v2 v2.0.3 h1:7FXXj8Ti1IaVFpSAziCZWNzbNuZmnvw/i6CqLNdWfZY=
|
||||
github.com/jcmturner/rpc/v2 v2.0.3/go.mod h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJki3ELg/Hc=
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
|
||||
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||
github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y=
|
||||
@@ -378,8 +370,6 @@ gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg=
|
||||
gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs=
|
||||
gotest.tools/v3 v3.5.2 h1:7koQfIKdy+I8UTetycgUqXWSDwpgv193Ka+qRsmBY8Q=
|
||||
gotest.tools/v3 v3.5.2/go.mod h1:LtdLGcnqToBH83WByAAi/wiwSFCArdFIUV/xxN4pcjA=
|
||||
modernc.org/cc/v4 v4.26.2 h1:991HMkLjJzYBIfha6ECZdjrIYz2/1ayr+FL8GN+CNzM=
|
||||
|
||||
@@ -13,11 +13,10 @@ import (
|
||||
"time"
|
||||
"tinyauth/internal/config"
|
||||
"tinyauth/internal/controller"
|
||||
"tinyauth/internal/model"
|
||||
"tinyauth/internal/repository"
|
||||
"tinyauth/internal/utils"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type BootstrapApp struct {
|
||||
@@ -107,8 +106,18 @@ func (app *BootstrapApp) Setup() error {
|
||||
log.Trace().Str("csrfCookieName", app.context.csrfCookieName).Msg("CSRF cookie name")
|
||||
log.Trace().Str("redirectCookieName", app.context.redirectCookieName).Msg("Redirect cookie name")
|
||||
|
||||
// Database
|
||||
db, err := app.setupDatabase(app.config.DatabasePath)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to setup database: %w", err)
|
||||
}
|
||||
|
||||
// Queries
|
||||
queries := repository.New(db)
|
||||
|
||||
// Services
|
||||
services, err := app.initServices()
|
||||
services, err := app.initServices(queries)
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to initialize services: %w", err)
|
||||
@@ -154,9 +163,9 @@ func (app *BootstrapApp) Setup() error {
|
||||
return fmt.Errorf("failed to setup routes: %w", err)
|
||||
}
|
||||
|
||||
// Start DB cleanup routine
|
||||
// Start db cleanup routine
|
||||
log.Debug().Msg("Starting database cleanup routine")
|
||||
go app.dbCleanup(services.databaseService.GetDatabase())
|
||||
go app.dbCleanup(queries)
|
||||
|
||||
// If analytics are not disabled, start heartbeat
|
||||
if !app.config.DisableAnalytics {
|
||||
@@ -246,16 +255,16 @@ func (app *BootstrapApp) heartbeat() {
|
||||
}
|
||||
}
|
||||
|
||||
func (app *BootstrapApp) dbCleanup(db *gorm.DB) {
|
||||
func (app *BootstrapApp) dbCleanup(queries *repository.Queries) {
|
||||
ticker := time.NewTicker(time.Duration(30) * time.Minute)
|
||||
defer ticker.Stop()
|
||||
ctx := context.Background()
|
||||
|
||||
for ; true; <-ticker.C {
|
||||
log.Debug().Msg("Cleaning up old database sessions")
|
||||
_, err := gorm.G[model.Session](db).Where("expiry < ?", time.Now().Unix()).Delete(ctx)
|
||||
err := queries.DeleteExpiredSessions(ctx, time.Now().Unix())
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to cleanup old sessions")
|
||||
log.Error().Err(err).Msg("Failed to clean up old database sessions")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
52
internal/bootstrap/db_bootstrap.go
Normal file
52
internal/bootstrap/db_bootstrap.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package bootstrap
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"tinyauth/internal/assets"
|
||||
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
"github.com/golang-migrate/migrate/v4/database/sqlite3"
|
||||
"github.com/golang-migrate/migrate/v4/source/iofs"
|
||||
_ "modernc.org/sqlite"
|
||||
)
|
||||
|
||||
func (app *BootstrapApp) setupDatabase(databasePath string) (*sql.DB, error) {
|
||||
dir := filepath.Dir(databasePath)
|
||||
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return nil, fmt.Errorf("failed to create database directory %s: %w", dir, err)
|
||||
}
|
||||
|
||||
db, err := sql.Open("sqlite", databasePath)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to open database: %w", err)
|
||||
}
|
||||
|
||||
migrations, err := iofs.New(assets.Migrations, "migrations")
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create migrations: %w", err)
|
||||
}
|
||||
|
||||
target, err := sqlite3.WithInstance(db, &sqlite3.Config{})
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create sqlite3 instance: %w", err)
|
||||
}
|
||||
|
||||
migrator, err := migrate.NewWithInstance("iofs", migrations, "sqlite3", target)
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create migrator: %w", err)
|
||||
}
|
||||
|
||||
if err := migrator.Up(); err != nil && err != migrate.ErrNoChange {
|
||||
return nil, fmt.Errorf("failed to migrate database: %w", err)
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package bootstrap
|
||||
|
||||
import (
|
||||
"tinyauth/internal/repository"
|
||||
"tinyauth/internal/service"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
@@ -9,27 +10,14 @@ import (
|
||||
type Services struct {
|
||||
accessControlService *service.AccessControlsService
|
||||
authService *service.AuthService
|
||||
databaseService *service.DatabaseService
|
||||
dockerService *service.DockerService
|
||||
ldapService *service.LdapService
|
||||
oauthBrokerService *service.OAuthBrokerService
|
||||
}
|
||||
|
||||
func (app *BootstrapApp) initServices() (Services, error) {
|
||||
func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, error) {
|
||||
services := Services{}
|
||||
|
||||
databaseService := service.NewDatabaseService(service.DatabaseServiceConfig{
|
||||
DatabasePath: app.config.DatabasePath,
|
||||
})
|
||||
|
||||
err := databaseService.Init()
|
||||
|
||||
if err != nil {
|
||||
return Services{}, err
|
||||
}
|
||||
|
||||
services.databaseService = databaseService
|
||||
|
||||
ldapService := service.NewLdapService(service.LdapServiceConfig{
|
||||
Address: app.config.Ldap.Address,
|
||||
BindDN: app.config.Ldap.BindDN,
|
||||
@@ -39,7 +27,7 @@ func (app *BootstrapApp) initServices() (Services, error) {
|
||||
SearchFilter: app.config.Ldap.SearchFilter,
|
||||
})
|
||||
|
||||
err = ldapService.Init()
|
||||
err := ldapService.Init()
|
||||
|
||||
if err == nil {
|
||||
services.ldapService = ldapService
|
||||
@@ -76,7 +64,7 @@ func (app *BootstrapApp) initServices() (Services, error) {
|
||||
LoginTimeout: app.config.Auth.LoginTimeout,
|
||||
LoginMaxRetries: app.config.Auth.LoginMaxRetries,
|
||||
SessionCookieName: app.context.sessionCookieName,
|
||||
}, dockerService, ldapService, databaseService.GetDatabase())
|
||||
}, dockerService, ldapService, queries)
|
||||
|
||||
err = authService.Init()
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ package controller
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strings"
|
||||
"tinyauth/internal/config"
|
||||
"tinyauth/internal/service"
|
||||
@@ -13,6 +14,8 @@ import (
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
var SupportedProxies = []string{"nginx", "traefik", "caddy", "envoy"}
|
||||
|
||||
type Proxy struct {
|
||||
Proxy string `uri:"proxy" binding:"required"`
|
||||
}
|
||||
@@ -39,7 +42,7 @@ func NewProxyController(config ProxyControllerConfig, router *gin.RouterGroup, a
|
||||
|
||||
func (controller *ProxyController) SetupRoutes() {
|
||||
proxyGroup := controller.router.Group("/auth")
|
||||
proxyGroup.GET("/:proxy", controller.proxyHandler)
|
||||
proxyGroup.Any("/:proxy", controller.proxyHandler)
|
||||
}
|
||||
|
||||
func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
||||
@@ -55,7 +58,7 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if req.Proxy != "nginx" && req.Proxy != "traefik" && req.Proxy != "caddy" {
|
||||
if !slices.Contains(SupportedProxies, req.Proxy) {
|
||||
log.Warn().Str("proxy", req.Proxy).Msg("Invalid proxy")
|
||||
c.JSON(400, gin.H{
|
||||
"status": 400,
|
||||
@@ -64,6 +67,15 @@ func (controller *ProxyController) proxyHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
if req.Proxy != "envoy" && c.Request.Method != http.MethodGet {
|
||||
log.Warn().Str("method", c.Request.Method).Msg("Invalid method for proxy")
|
||||
c.JSON(405, gin.H{
|
||||
"status": 405,
|
||||
"message": "Method Not Allowed",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
isBrowser := strings.Contains(c.Request.Header.Get("Accept"), "text/html")
|
||||
|
||||
if isBrowser {
|
||||
|
||||
@@ -80,6 +80,13 @@ func TestProxyHandler(t *testing.T) {
|
||||
|
||||
assert.Equal(t, 400, recorder.Code)
|
||||
|
||||
// Test invalid method
|
||||
recorder = httptest.NewRecorder()
|
||||
req = httptest.NewRequest("POST", "/api/auth/traefik", nil)
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 405, recorder.Code)
|
||||
|
||||
// Test logged out user (traefik/caddy)
|
||||
recorder = httptest.NewRecorder()
|
||||
req = httptest.NewRequest("GET", "/api/auth/traefik", nil)
|
||||
@@ -92,6 +99,18 @@ func TestProxyHandler(t *testing.T) {
|
||||
assert.Equal(t, 307, recorder.Code)
|
||||
assert.Equal(t, "http://localhost:8080/login?redirect_uri=https%3A%2F%2Fexample.com%2Fsomepath", recorder.Header().Get("Location"))
|
||||
|
||||
// Test logged out user (envoy)
|
||||
recorder = httptest.NewRecorder()
|
||||
req = httptest.NewRequest("POST", "/api/auth/envoy", nil)
|
||||
req.Header.Set("X-Forwarded-Proto", "https")
|
||||
req.Header.Set("X-Forwarded-Host", "example.com")
|
||||
req.Header.Set("X-Forwarded-Uri", "/somepath")
|
||||
req.Header.Set("Accept", "text/html")
|
||||
router.ServeHTTP(recorder, req)
|
||||
|
||||
assert.Equal(t, 307, recorder.Code)
|
||||
assert.Equal(t, "http://localhost:8080/login?redirect_uri=https%3A%2F%2Fexample.com%2Fsomepath", recorder.Header().Get("Location"))
|
||||
|
||||
// Test logged out user (nginx)
|
||||
recorder = httptest.NewRecorder()
|
||||
req = httptest.NewRequest("GET", "/api/auth/nginx", nil)
|
||||
|
||||
@@ -1,13 +0,0 @@
|
||||
package model
|
||||
|
||||
type Session struct {
|
||||
UUID string `gorm:"column:uuid;primaryKey"`
|
||||
Username string `gorm:"column:username"`
|
||||
Email string `gorm:"column:email"`
|
||||
Name string `gorm:"column:name"`
|
||||
Provider string `gorm:"column:provider"`
|
||||
TOTPPending bool `gorm:"column:totp_pending"`
|
||||
OAuthGroups string `gorm:"column:oauth_groups"`
|
||||
Expiry int64 `gorm:"column:expiry"`
|
||||
OAuthName string `gorm:"column:oauth_name"`
|
||||
}
|
||||
31
internal/repository/db.go
Normal file
31
internal/repository/db.go
Normal file
@@ -0,0 +1,31 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.30.0
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
)
|
||||
|
||||
type DBTX interface {
|
||||
ExecContext(context.Context, string, ...interface{}) (sql.Result, error)
|
||||
PrepareContext(context.Context, string) (*sql.Stmt, error)
|
||||
QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error)
|
||||
QueryRowContext(context.Context, string, ...interface{}) *sql.Row
|
||||
}
|
||||
|
||||
func New(db DBTX) *Queries {
|
||||
return &Queries{db: db}
|
||||
}
|
||||
|
||||
type Queries struct {
|
||||
db DBTX
|
||||
}
|
||||
|
||||
func (q *Queries) WithTx(tx *sql.Tx) *Queries {
|
||||
return &Queries{
|
||||
db: tx,
|
||||
}
|
||||
}
|
||||
17
internal/repository/models.go
Normal file
17
internal/repository/models.go
Normal file
@@ -0,0 +1,17 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.30.0
|
||||
|
||||
package repository
|
||||
|
||||
type Session struct {
|
||||
UUID string
|
||||
Username string
|
||||
Email string
|
||||
Name string
|
||||
Provider string
|
||||
TotpPending bool
|
||||
OAuthGroups string
|
||||
Expiry int64
|
||||
OAuthName string
|
||||
}
|
||||
161
internal/repository/query.sql.go
Normal file
161
internal/repository/query.sql.go
Normal file
@@ -0,0 +1,161 @@
|
||||
// Code generated by sqlc. DO NOT EDIT.
|
||||
// versions:
|
||||
// sqlc v1.30.0
|
||||
// source: query.sql
|
||||
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
const createSession = `-- name: CreateSession :one
|
||||
INSERT INTO sessions (
|
||||
"uuid",
|
||||
"username",
|
||||
"email",
|
||||
"name",
|
||||
"provider",
|
||||
"totp_pending",
|
||||
"oauth_groups",
|
||||
"expiry",
|
||||
"oauth_name"
|
||||
) VALUES (
|
||||
?, ?, ?, ?, ?, ?, ?, ?, ?
|
||||
)
|
||||
RETURNING uuid, username, email, name, provider, totp_pending, oauth_groups, expiry, oauth_name
|
||||
`
|
||||
|
||||
type CreateSessionParams struct {
|
||||
UUID string
|
||||
Username string
|
||||
Email string
|
||||
Name string
|
||||
Provider string
|
||||
TotpPending bool
|
||||
OAuthGroups string
|
||||
Expiry int64
|
||||
OAuthName string
|
||||
}
|
||||
|
||||
func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (Session, error) {
|
||||
row := q.db.QueryRowContext(ctx, createSession,
|
||||
arg.UUID,
|
||||
arg.Username,
|
||||
arg.Email,
|
||||
arg.Name,
|
||||
arg.Provider,
|
||||
arg.TotpPending,
|
||||
arg.OAuthGroups,
|
||||
arg.Expiry,
|
||||
arg.OAuthName,
|
||||
)
|
||||
var i Session
|
||||
err := row.Scan(
|
||||
&i.UUID,
|
||||
&i.Username,
|
||||
&i.Email,
|
||||
&i.Name,
|
||||
&i.Provider,
|
||||
&i.TotpPending,
|
||||
&i.OAuthGroups,
|
||||
&i.Expiry,
|
||||
&i.OAuthName,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const deleteExpiredSessions = `-- name: DeleteExpiredSessions :exec
|
||||
DELETE FROM "sessions"
|
||||
WHERE "expiry" < ?
|
||||
`
|
||||
|
||||
func (q *Queries) DeleteExpiredSessions(ctx context.Context, expiry int64) error {
|
||||
_, err := q.db.ExecContext(ctx, deleteExpiredSessions, expiry)
|
||||
return err
|
||||
}
|
||||
|
||||
const deleteSession = `-- name: DeleteSession :exec
|
||||
DELETE FROM "sessions"
|
||||
WHERE "uuid" = ?
|
||||
`
|
||||
|
||||
func (q *Queries) DeleteSession(ctx context.Context, uuid string) error {
|
||||
_, err := q.db.ExecContext(ctx, deleteSession, uuid)
|
||||
return err
|
||||
}
|
||||
|
||||
const getSession = `-- name: GetSession :one
|
||||
SELECT uuid, username, email, name, provider, totp_pending, oauth_groups, expiry, oauth_name FROM "sessions"
|
||||
WHERE "uuid" = ?
|
||||
`
|
||||
|
||||
func (q *Queries) GetSession(ctx context.Context, uuid string) (Session, error) {
|
||||
row := q.db.QueryRowContext(ctx, getSession, uuid)
|
||||
var i Session
|
||||
err := row.Scan(
|
||||
&i.UUID,
|
||||
&i.Username,
|
||||
&i.Email,
|
||||
&i.Name,
|
||||
&i.Provider,
|
||||
&i.TotpPending,
|
||||
&i.OAuthGroups,
|
||||
&i.Expiry,
|
||||
&i.OAuthName,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
|
||||
const updateSession = `-- name: UpdateSession :one
|
||||
UPDATE "sessions" SET
|
||||
"username" = ?,
|
||||
"email" = ?,
|
||||
"name" = ?,
|
||||
"provider" = ?,
|
||||
"totp_pending" = ?,
|
||||
"oauth_groups" = ?,
|
||||
"expiry" = ?,
|
||||
"oauth_name" = ?
|
||||
WHERE "uuid" = ?
|
||||
RETURNING uuid, username, email, name, provider, totp_pending, oauth_groups, expiry, oauth_name
|
||||
`
|
||||
|
||||
type UpdateSessionParams struct {
|
||||
Username string
|
||||
Email string
|
||||
Name string
|
||||
Provider string
|
||||
TotpPending bool
|
||||
OAuthGroups string
|
||||
Expiry int64
|
||||
OAuthName string
|
||||
UUID string
|
||||
}
|
||||
|
||||
func (q *Queries) UpdateSession(ctx context.Context, arg UpdateSessionParams) (Session, error) {
|
||||
row := q.db.QueryRowContext(ctx, updateSession,
|
||||
arg.Username,
|
||||
arg.Email,
|
||||
arg.Name,
|
||||
arg.Provider,
|
||||
arg.TotpPending,
|
||||
arg.OAuthGroups,
|
||||
arg.Expiry,
|
||||
arg.OAuthName,
|
||||
arg.UUID,
|
||||
)
|
||||
var i Session
|
||||
err := row.Scan(
|
||||
&i.UUID,
|
||||
&i.Username,
|
||||
&i.Email,
|
||||
&i.Name,
|
||||
&i.Provider,
|
||||
&i.TotpPending,
|
||||
&i.OAuthGroups,
|
||||
&i.Expiry,
|
||||
&i.OAuthName,
|
||||
)
|
||||
return i, err
|
||||
}
|
||||
@@ -2,6 +2,7 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
@@ -9,14 +10,13 @@ import (
|
||||
"sync"
|
||||
"time"
|
||||
"tinyauth/internal/config"
|
||||
"tinyauth/internal/model"
|
||||
"tinyauth/internal/repository"
|
||||
"tinyauth/internal/utils"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/rs/zerolog/log"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type LoginAttempt struct {
|
||||
@@ -42,17 +42,17 @@ type AuthService struct {
|
||||
loginAttempts map[string]*LoginAttempt
|
||||
loginMutex sync.RWMutex
|
||||
ldap *LdapService
|
||||
database *gorm.DB
|
||||
queries *repository.Queries
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func NewAuthService(config AuthServiceConfig, docker *DockerService, ldap *LdapService, database *gorm.DB) *AuthService {
|
||||
func NewAuthService(config AuthServiceConfig, docker *DockerService, ldap *LdapService, queries *repository.Queries) *AuthService {
|
||||
return &AuthService{
|
||||
config: config,
|
||||
docker: docker,
|
||||
loginAttempts: make(map[string]*LoginAttempt),
|
||||
ldap: ldap,
|
||||
database: database,
|
||||
queries: queries,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -205,19 +205,19 @@ func (auth *AuthService) CreateSessionCookie(c *gin.Context, data *config.Sessio
|
||||
expiry = auth.config.SessionExpiry
|
||||
}
|
||||
|
||||
session := model.Session{
|
||||
session := repository.CreateSessionParams{
|
||||
UUID: uuid.String(),
|
||||
Username: data.Username,
|
||||
Email: data.Email,
|
||||
Name: data.Name,
|
||||
Provider: data.Provider,
|
||||
TOTPPending: data.TotpPending,
|
||||
TotpPending: data.TotpPending,
|
||||
OAuthGroups: data.OAuthGroups,
|
||||
Expiry: time.Now().Add(time.Duration(expiry) * time.Second).Unix(),
|
||||
OAuthName: data.OAuthName,
|
||||
}
|
||||
|
||||
err = gorm.G[model.Session](auth.database).Create(auth.ctx, &session)
|
||||
_, err = auth.queries.CreateSession(c, session)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -235,7 +235,7 @@ func (auth *AuthService) DeleteSessionCookie(c *gin.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = gorm.G[model.Session](auth.database).Where("uuid = ?", cookie).Delete(auth.ctx)
|
||||
err = auth.queries.DeleteSession(auth.ctx, cookie)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -253,20 +253,20 @@ func (auth *AuthService) GetSessionCookie(c *gin.Context) (config.SessionCookie,
|
||||
return config.SessionCookie{}, err
|
||||
}
|
||||
|
||||
session, err := gorm.G[model.Session](auth.database).Where("uuid = ?", cookie).First(auth.ctx)
|
||||
session, err := auth.queries.GetSession(auth.ctx, cookie)
|
||||
|
||||
if err != nil {
|
||||
return config.SessionCookie{}, err
|
||||
}
|
||||
|
||||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return config.SessionCookie{}, fmt.Errorf("session not found")
|
||||
}
|
||||
|
||||
currentTime := time.Now().Unix()
|
||||
|
||||
if currentTime > session.Expiry {
|
||||
_, err = gorm.G[model.Session](auth.database).Where("uuid = ?", cookie).Delete(auth.ctx)
|
||||
err = auth.queries.DeleteSession(auth.ctx, cookie)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to delete expired session")
|
||||
}
|
||||
@@ -279,7 +279,7 @@ func (auth *AuthService) GetSessionCookie(c *gin.Context) (config.SessionCookie,
|
||||
Email: session.Email,
|
||||
Name: session.Name,
|
||||
Provider: session.Provider,
|
||||
TotpPending: session.TOTPPending,
|
||||
TotpPending: session.TotpPending,
|
||||
OAuthGroups: session.OAuthGroups,
|
||||
OAuthName: session.OAuthName,
|
||||
}, nil
|
||||
|
||||
@@ -1,91 +0,0 @@
|
||||
package service
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"tinyauth/internal/assets"
|
||||
|
||||
"github.com/glebarez/sqlite"
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
sqliteMigrate "github.com/golang-migrate/migrate/v4/database/sqlite3"
|
||||
"github.com/golang-migrate/migrate/v4/source/iofs"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type DatabaseServiceConfig struct {
|
||||
DatabasePath string
|
||||
}
|
||||
|
||||
type DatabaseService struct {
|
||||
config DatabaseServiceConfig
|
||||
database *gorm.DB
|
||||
}
|
||||
|
||||
func NewDatabaseService(config DatabaseServiceConfig) *DatabaseService {
|
||||
return &DatabaseService{
|
||||
config: config,
|
||||
}
|
||||
}
|
||||
|
||||
func (ds *DatabaseService) Init() error {
|
||||
dbPath := ds.config.DatabasePath
|
||||
if dbPath == "" {
|
||||
dbPath = "/data/tinyauth.db"
|
||||
}
|
||||
|
||||
dir := filepath.Dir(dbPath)
|
||||
if err := os.MkdirAll(dir, 0755); err != nil {
|
||||
return fmt.Errorf("failed to create database directory %s: %w", dir, err)
|
||||
}
|
||||
|
||||
gormDB, err := gorm.Open(sqlite.Open(dbPath), &gorm.Config{})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sqlDB, err := gormDB.DB()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sqlDB.SetMaxOpenConns(1)
|
||||
|
||||
err = ds.migrateDatabase(sqlDB)
|
||||
|
||||
if err != nil && err != migrate.ErrNoChange {
|
||||
return err
|
||||
}
|
||||
|
||||
ds.database = gormDB
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ds *DatabaseService) migrateDatabase(sqlDB *sql.DB) error {
|
||||
data, err := iofs.New(assets.Migrations, "migrations")
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
target, err := sqliteMigrate.WithInstance(sqlDB, &sqliteMigrate.Config{})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
migrator, err := migrate.NewWithInstance("iofs", data, "tinyauth", target)
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return migrator.Up()
|
||||
}
|
||||
|
||||
func (ds *DatabaseService) GetDatabase() *gorm.DB {
|
||||
return ds.database
|
||||
}
|
||||
40
query.sql
Normal file
40
query.sql
Normal file
@@ -0,0 +1,40 @@
|
||||
-- name: CreateSession :one
|
||||
INSERT INTO sessions (
|
||||
"uuid",
|
||||
"username",
|
||||
"email",
|
||||
"name",
|
||||
"provider",
|
||||
"totp_pending",
|
||||
"oauth_groups",
|
||||
"expiry",
|
||||
"oauth_name"
|
||||
) VALUES (
|
||||
?, ?, ?, ?, ?, ?, ?, ?, ?
|
||||
)
|
||||
RETURNING *;
|
||||
|
||||
-- name: GetSession :one
|
||||
SELECT * FROM "sessions"
|
||||
WHERE "uuid" = ?;
|
||||
|
||||
-- name: DeleteSession :exec
|
||||
DELETE FROM "sessions"
|
||||
WHERE "uuid" = ?;
|
||||
|
||||
-- name: UpdateSession :one
|
||||
UPDATE "sessions" SET
|
||||
"username" = ?,
|
||||
"email" = ?,
|
||||
"name" = ?,
|
||||
"provider" = ?,
|
||||
"totp_pending" = ?,
|
||||
"oauth_groups" = ?,
|
||||
"expiry" = ?,
|
||||
"oauth_name" = ?
|
||||
WHERE "uuid" = ?
|
||||
RETURNING *;
|
||||
|
||||
-- name: DeleteExpiredSessions :exec
|
||||
DELETE FROM "sessions"
|
||||
WHERE "expiry" < ?;
|
||||
11
schema.sql
Normal file
11
schema.sql
Normal file
@@ -0,0 +1,11 @@
|
||||
CREATE TABLE IF NOT EXISTS "sessions" (
|
||||
"uuid" TEXT NOT NULL PRIMARY KEY UNIQUE,
|
||||
"username" TEXT NOT NULL,
|
||||
"email" TEXT NOT NULL,
|
||||
"name" TEXT NOT NULL,
|
||||
"provider" TEXT NOT NULL,
|
||||
"totp_pending" BOOLEAN NOT NULL,
|
||||
"oauth_groups" TEXT NULL,
|
||||
"expiry" INTEGER NOT NULL,
|
||||
"oauth_name" TEXT NULL
|
||||
);
|
||||
18
sqlc.yml
Normal file
18
sqlc.yml
Normal file
@@ -0,0 +1,18 @@
|
||||
version: "2"
|
||||
sql:
|
||||
- engine: "sqlite"
|
||||
queries: "query.sql"
|
||||
schema: "schema.sql"
|
||||
gen:
|
||||
go:
|
||||
package: "repository"
|
||||
out: "internal/repository"
|
||||
rename:
|
||||
uuid: "UUID"
|
||||
oauth_groups: "OAuthGroups"
|
||||
oauth_name: "OAuthName"
|
||||
overrides:
|
||||
- column: "sessions.oauth_groups"
|
||||
go_type: "string"
|
||||
- column: "sessions.oauth_name"
|
||||
go_type: "string"
|
||||
Reference in New Issue
Block a user