mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-05-04 19:38:12 +00:00
feat(db): add memory storage driver
removes the sqlite dependency for tests, also brings back the option for users to run zero persistence instances of tinyauth. adds new mapErr fn for sqlc wrapper gen to prevent sql errors from leaking out of the store implementation.
This commit is contained in:
@@ -449,18 +449,18 @@ func buildBody(m methodInfo) string {
|
||||
|
||||
// no repo-typed result → direct return
|
||||
if len(m.Results) == 0 || m.Results[0].RepoType == "" {
|
||||
return "\treturn " + call + "\n"
|
||||
return "\treturn mapErr(" + call + ")\n"
|
||||
}
|
||||
|
||||
r := m.Results[0]
|
||||
if r.IsSlice {
|
||||
return fmt.Sprintf(
|
||||
"\trows, err := %s\n\tif err != nil {\n\t\treturn nil, err\n\t}\n\tout := make([]%s, len(rows))\n\tfor i, row := range rows {\n\t\tout[i] = %s(row)\n\t}\n\treturn out, nil\n",
|
||||
"\trows, err := %s\n\tif err != nil {\n\t\treturn nil, mapErr(err)\n\t}\n\tout := make([]%s, len(rows))\n\tfor i, row := range rows {\n\t\tout[i] = %s(row)\n\t}\n\treturn out, nil\n",
|
||||
call, r.RepoType, converterFn(r.TypeStr),
|
||||
)
|
||||
}
|
||||
return fmt.Sprintf(
|
||||
"\tr, err := %s\n\tif err != nil {\n\t\treturn %s{}, err\n\t}\n\treturn %s(r), nil\n",
|
||||
"\tr, err := %s\n\tif err != nil {\n\t\treturn %s{}, mapErr(err)\n\t}\n\treturn %s(r), nil\n",
|
||||
call, r.RepoType, converterFn(r.TypeStr),
|
||||
)
|
||||
}
|
||||
@@ -477,6 +477,8 @@ package {{.PkgName}}
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
|
||||
"{{.RepoPkg}}"
|
||||
)
|
||||
@@ -491,6 +493,22 @@ func NewStore(q *Queries) repository.Store {
|
||||
return &Store{q: q}
|
||||
}
|
||||
|
||||
var errMap = []struct {
|
||||
from error
|
||||
to error
|
||||
}{
|
||||
{sql.ErrNoRows, repository.ErrNotFound},
|
||||
}
|
||||
|
||||
func mapErr(err error) error {
|
||||
for _, e := range errMap {
|
||||
if errors.Is(err, e.from) {
|
||||
return e.to
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
{{range .ModelTypes -}}
|
||||
func {{converterFn .}}(v {{.}}) repository.{{.}} {
|
||||
return repository.{{.}}(v)
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
|
||||
"github.com/tinyauthapp/tinyauth/internal/assets"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository/sqlite"
|
||||
|
||||
"github.com/golang-migrate/migrate/v4"
|
||||
@@ -17,14 +18,14 @@ import (
|
||||
)
|
||||
|
||||
func (app *BootstrapApp) SetupStore() (repository.Store, error) {
|
||||
return app.setupSQLite(app.config.Database.Path)
|
||||
}
|
||||
|
||||
// NewSQLiteStore opens a SQLite database at the given path, runs migrations, and returns a Store.
|
||||
// Useful for testing or when constructing a store outside of a BootstrapApp.
|
||||
func NewSQLiteStore(databasePath string) (repository.Store, error) {
|
||||
app := &BootstrapApp{}
|
||||
return app.setupSQLite(databasePath)
|
||||
switch app.config.Database.Driver {
|
||||
case "memory":
|
||||
return memory.New(), nil
|
||||
case "sqlite", "":
|
||||
return app.setupSQLite(app.config.Database.Path)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown database driver %q: valid values are sqlite, memory", app.config.Database.Driver)
|
||||
}
|
||||
}
|
||||
|
||||
func (app *BootstrapApp) setupSQLite(databasePath string) (repository.Store, error) {
|
||||
|
||||
@@ -4,7 +4,8 @@ package config
|
||||
func NewDefaultConfiguration() *Config {
|
||||
return &Config{
|
||||
Database: DatabaseConfig{
|
||||
Path: "./tinyauth.db",
|
||||
Driver: "sqlite",
|
||||
Path: "./tinyauth.db",
|
||||
},
|
||||
Analytics: AnalyticsConfig{
|
||||
Enabled: true,
|
||||
@@ -95,7 +96,8 @@ type Config struct {
|
||||
}
|
||||
|
||||
type DatabaseConfig struct {
|
||||
Path string `description:"The path to the SQLite database, including file name." yaml:"path"`
|
||||
Driver string `description:"The database driver to use. Valid values: sqlite, memory." yaml:"driver"`
|
||||
Path string `description:"The path to the SQLite database, including file name. Only used when driver is sqlite." yaml:"path"`
|
||||
}
|
||||
|
||||
type AnalyticsConfig struct {
|
||||
|
||||
@@ -12,9 +12,9 @@ import (
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/go-querystring/query"
|
||||
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
|
||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -847,11 +847,10 @@ func TestOIDCController(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
store, err := bootstrap.NewSQLiteStore(path.Join(tempDir, "tinyauth.db"))
|
||||
require.NoError(t, err)
|
||||
store := memory.New()
|
||||
|
||||
oidcService := service.NewOIDCService(oidcServiceCfg, store)
|
||||
err = oidcService.Init()
|
||||
err := oidcService.Init()
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, test := range tests {
|
||||
|
||||
@@ -2,13 +2,12 @@ package controller_test
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"path"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
|
||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -17,7 +16,6 @@ import (
|
||||
|
||||
func TestProxyController(t *testing.T) {
|
||||
tlog.NewTestLogger().Init()
|
||||
tempDir := t.TempDir()
|
||||
|
||||
authServiceCfg := service.AuthServiceConfig{
|
||||
Users: []config.User{
|
||||
@@ -392,11 +390,10 @@ func TestProxyController(t *testing.T) {
|
||||
|
||||
oauthBrokerCfgs := make(map[string]config.OAuthServiceConfig)
|
||||
|
||||
store, err := bootstrap.NewSQLiteStore(path.Join(tempDir, "tinyauth.db"))
|
||||
require.NoError(t, err)
|
||||
store := memory.New()
|
||||
|
||||
docker := service.NewDockerService()
|
||||
err = docker.Init()
|
||||
err := docker.Init()
|
||||
require.NoError(t, err)
|
||||
|
||||
ldap := service.NewLdapService(service.LdapServiceConfig{})
|
||||
|
||||
@@ -3,16 +3,15 @@ package controller_test
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http/httptest"
|
||||
"path"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/pquerna/otp/totp"
|
||||
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
|
||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -21,7 +20,6 @@ import (
|
||||
|
||||
func TestUserController(t *testing.T) {
|
||||
tlog.NewTestLogger().Init()
|
||||
tempDir := t.TempDir()
|
||||
|
||||
authServiceCfg := service.AuthServiceConfig{
|
||||
Users: []config.User{
|
||||
@@ -350,11 +348,10 @@ func TestUserController(t *testing.T) {
|
||||
|
||||
oauthBrokerCfgs := make(map[string]config.OAuthServiceConfig)
|
||||
|
||||
store, err := bootstrap.NewSQLiteStore(path.Join(tempDir, "tinyauth.db"))
|
||||
require.NoError(t, err)
|
||||
store := memory.New()
|
||||
|
||||
docker := service.NewDockerService()
|
||||
err = docker.Init()
|
||||
err := docker.Init()
|
||||
require.NoError(t, err)
|
||||
|
||||
ldap := service.NewLdapService(service.LdapServiceConfig{})
|
||||
|
||||
@@ -8,9 +8,9 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tinyauthapp/tinyauth/internal/bootstrap"
|
||||
"github.com/tinyauthapp/tinyauth/internal/config"
|
||||
"github.com/tinyauthapp/tinyauth/internal/controller"
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository/memory"
|
||||
"github.com/tinyauthapp/tinyauth/internal/service"
|
||||
"github.com/tinyauthapp/tinyauth/internal/utils/tlog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@@ -100,11 +100,10 @@ func TestWellKnownController(t *testing.T) {
|
||||
},
|
||||
}
|
||||
|
||||
store, err := bootstrap.NewSQLiteStore(path.Join(tempDir, "tinyauth.db"))
|
||||
require.NoError(t, err)
|
||||
store := memory.New()
|
||||
|
||||
oidcService := service.NewOIDCService(oidcServiceCfg, store)
|
||||
err = oidcService.Init()
|
||||
err := oidcService.Init()
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, test := range tests {
|
||||
|
||||
@@ -0,0 +1,241 @@
|
||||
package memory
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
)
|
||||
|
||||
func (s *Store) CreateOidcCode(_ context.Context, arg repository.CreateOidcCodeParams) (repository.OidcCode, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
// Enforce sub UNIQUE constraint
|
||||
for _, c := range s.oidcCodes {
|
||||
if c.Sub == arg.Sub {
|
||||
return repository.OidcCode{}, fmt.Errorf("UNIQUE constraint failed: oidc_codes.sub")
|
||||
}
|
||||
}
|
||||
code := repository.OidcCode(arg)
|
||||
s.oidcCodes[arg.CodeHash] = code
|
||||
return code, nil
|
||||
}
|
||||
|
||||
// GetOidcCode is a destructive read: it deletes and returns the code (mirrors SQLite's DELETE...RETURNING).
|
||||
func (s *Store) GetOidcCode(_ context.Context, codeHash string) (repository.OidcCode, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
c, ok := s.oidcCodes[codeHash]
|
||||
if !ok {
|
||||
return repository.OidcCode{}, repository.ErrNotFound
|
||||
}
|
||||
delete(s.oidcCodes, codeHash)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// GetOidcCodeBySub is a destructive read: it deletes and returns the code (mirrors SQLite's DELETE...RETURNING).
|
||||
func (s *Store) GetOidcCodeBySub(_ context.Context, sub string) (repository.OidcCode, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for k, c := range s.oidcCodes {
|
||||
if c.Sub == sub {
|
||||
delete(s.oidcCodes, k)
|
||||
return c, nil
|
||||
}
|
||||
}
|
||||
return repository.OidcCode{}, repository.ErrNotFound
|
||||
}
|
||||
|
||||
// GetOidcCodeUnsafe is a non-destructive read (mirrors SQLite's SELECT).
|
||||
func (s *Store) GetOidcCodeUnsafe(_ context.Context, codeHash string) (repository.OidcCode, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
c, ok := s.oidcCodes[codeHash]
|
||||
if !ok {
|
||||
return repository.OidcCode{}, repository.ErrNotFound
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// GetOidcCodeBySubUnsafe is a non-destructive read (mirrors SQLite's SELECT).
|
||||
func (s *Store) GetOidcCodeBySubUnsafe(_ context.Context, sub string) (repository.OidcCode, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
for _, c := range s.oidcCodes {
|
||||
if c.Sub == sub {
|
||||
return c, nil
|
||||
}
|
||||
}
|
||||
return repository.OidcCode{}, repository.ErrNotFound
|
||||
}
|
||||
|
||||
func (s *Store) DeleteOidcCode(_ context.Context, codeHash string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.oidcCodes, codeHash)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) DeleteOidcCodeBySub(_ context.Context, sub string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for k, c := range s.oidcCodes {
|
||||
if c.Sub == sub {
|
||||
delete(s.oidcCodes, k)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) DeleteExpiredOidcCodes(_ context.Context, expiresAt int64) ([]repository.OidcCode, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
var deleted []repository.OidcCode
|
||||
for k, c := range s.oidcCodes {
|
||||
if c.ExpiresAt < expiresAt {
|
||||
deleted = append(deleted, c)
|
||||
delete(s.oidcCodes, k)
|
||||
}
|
||||
}
|
||||
return deleted, nil
|
||||
}
|
||||
|
||||
func (s *Store) CreateOidcToken(_ context.Context, arg repository.CreateOidcTokenParams) (repository.OidcToken, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
// Enforce sub UNIQUE constraint
|
||||
for _, t := range s.oidcTokens {
|
||||
if t.Sub == arg.Sub {
|
||||
return repository.OidcToken{}, fmt.Errorf("UNIQUE constraint failed: oidc_tokens.sub")
|
||||
}
|
||||
}
|
||||
tok := repository.OidcToken{
|
||||
Sub: arg.Sub,
|
||||
AccessTokenHash: arg.AccessTokenHash,
|
||||
RefreshTokenHash: arg.RefreshTokenHash,
|
||||
CodeHash: arg.CodeHash,
|
||||
Scope: arg.Scope,
|
||||
ClientID: arg.ClientID,
|
||||
TokenExpiresAt: arg.TokenExpiresAt,
|
||||
RefreshTokenExpiresAt: arg.RefreshTokenExpiresAt,
|
||||
Nonce: arg.Nonce,
|
||||
}
|
||||
s.oidcTokens[arg.AccessTokenHash] = tok
|
||||
return tok, nil
|
||||
}
|
||||
|
||||
func (s *Store) GetOidcToken(_ context.Context, accessTokenHash string) (repository.OidcToken, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
t, ok := s.oidcTokens[accessTokenHash]
|
||||
if !ok {
|
||||
return repository.OidcToken{}, repository.ErrNotFound
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func (s *Store) GetOidcTokenByRefreshToken(_ context.Context, refreshTokenHash string) (repository.OidcToken, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
for _, t := range s.oidcTokens {
|
||||
if t.RefreshTokenHash == refreshTokenHash {
|
||||
return t, nil
|
||||
}
|
||||
}
|
||||
return repository.OidcToken{}, repository.ErrNotFound
|
||||
}
|
||||
|
||||
func (s *Store) GetOidcTokenBySub(_ context.Context, sub string) (repository.OidcToken, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
for _, t := range s.oidcTokens {
|
||||
if t.Sub == sub {
|
||||
return t, nil
|
||||
}
|
||||
}
|
||||
return repository.OidcToken{}, repository.ErrNotFound
|
||||
}
|
||||
|
||||
func (s *Store) UpdateOidcTokenByRefreshToken(_ context.Context, arg repository.UpdateOidcTokenByRefreshTokenParams) (repository.OidcToken, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for k, t := range s.oidcTokens {
|
||||
if t.RefreshTokenHash == arg.RefreshTokenHash_2 {
|
||||
delete(s.oidcTokens, k)
|
||||
t.AccessTokenHash = arg.AccessTokenHash
|
||||
t.RefreshTokenHash = arg.RefreshTokenHash
|
||||
t.TokenExpiresAt = arg.TokenExpiresAt
|
||||
t.RefreshTokenExpiresAt = arg.RefreshTokenExpiresAt
|
||||
s.oidcTokens[arg.AccessTokenHash] = t
|
||||
return t, nil
|
||||
}
|
||||
}
|
||||
return repository.OidcToken{}, repository.ErrNotFound
|
||||
}
|
||||
|
||||
func (s *Store) DeleteOidcToken(_ context.Context, accessTokenHash string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.oidcTokens, accessTokenHash)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) DeleteOidcTokenBySub(_ context.Context, sub string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for k, t := range s.oidcTokens {
|
||||
if t.Sub == sub {
|
||||
delete(s.oidcTokens, k)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) DeleteOidcTokenByCodeHash(_ context.Context, codeHash string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for k, t := range s.oidcTokens {
|
||||
if t.CodeHash == codeHash {
|
||||
delete(s.oidcTokens, k)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) DeleteExpiredOidcTokens(_ context.Context, arg repository.DeleteExpiredOidcTokensParams) ([]repository.OidcToken, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
var deleted []repository.OidcToken
|
||||
for k, t := range s.oidcTokens {
|
||||
if t.TokenExpiresAt < arg.TokenExpiresAt || t.RefreshTokenExpiresAt < arg.RefreshTokenExpiresAt {
|
||||
deleted = append(deleted, t)
|
||||
delete(s.oidcTokens, k)
|
||||
}
|
||||
}
|
||||
return deleted, nil
|
||||
}
|
||||
|
||||
func (s *Store) CreateOidcUserInfo(_ context.Context, arg repository.CreateOidcUserInfoParams) (repository.OidcUserinfo, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
u := repository.OidcUserinfo(arg)
|
||||
s.oidcUsers[arg.Sub] = u
|
||||
return u, nil
|
||||
}
|
||||
|
||||
func (s *Store) GetOidcUserInfo(_ context.Context, sub string) (repository.OidcUserinfo, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
u, ok := s.oidcUsers[sub]
|
||||
if !ok {
|
||||
return repository.OidcUserinfo{}, repository.ErrNotFound
|
||||
}
|
||||
return u, nil
|
||||
}
|
||||
|
||||
func (s *Store) DeleteOidcUserInfo(_ context.Context, sub string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.oidcUsers, sub)
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,63 @@
|
||||
package memory
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
)
|
||||
|
||||
func (s *Store) CreateSession(_ context.Context, arg repository.CreateSessionParams) (repository.Session, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
sess := repository.Session(arg)
|
||||
s.sessions[arg.UUID] = sess
|
||||
return sess, nil
|
||||
}
|
||||
|
||||
func (s *Store) GetSession(_ context.Context, uuid string) (repository.Session, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
sess, ok := s.sessions[uuid]
|
||||
if !ok {
|
||||
return repository.Session{}, repository.ErrNotFound
|
||||
}
|
||||
return sess, nil
|
||||
}
|
||||
|
||||
func (s *Store) UpdateSession(_ context.Context, arg repository.UpdateSessionParams) (repository.Session, error) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
sess, ok := s.sessions[arg.UUID]
|
||||
if !ok {
|
||||
return repository.Session{}, repository.ErrNotFound
|
||||
}
|
||||
sess.Username = arg.Username
|
||||
sess.Email = arg.Email
|
||||
sess.Name = arg.Name
|
||||
sess.Provider = arg.Provider
|
||||
sess.TotpPending = arg.TotpPending
|
||||
sess.OAuthGroups = arg.OAuthGroups
|
||||
sess.Expiry = arg.Expiry
|
||||
sess.OAuthName = arg.OAuthName
|
||||
sess.OAuthSub = arg.OAuthSub
|
||||
s.sessions[arg.UUID] = sess
|
||||
return sess, nil
|
||||
}
|
||||
|
||||
func (s *Store) DeleteSession(_ context.Context, uuid string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
delete(s.sessions, uuid)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) DeleteExpiredSessions(_ context.Context, expiry int64) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
for k, v := range s.sessions {
|
||||
if v.Expiry < expiry {
|
||||
delete(s.sessions, k)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -0,0 +1,27 @@
|
||||
// Package memory provides an in-memory implementation of repository.Store for use in tests.
|
||||
package memory
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
)
|
||||
|
||||
// Store is a thread-safe in-memory implementation of repository.Store.
|
||||
type Store struct {
|
||||
mu sync.RWMutex
|
||||
sessions map[string]repository.Session
|
||||
oidcCodes map[string]repository.OidcCode
|
||||
oidcTokens map[string]repository.OidcToken
|
||||
oidcUsers map[string]repository.OidcUserinfo
|
||||
}
|
||||
|
||||
// New returns a new empty in-memory Store.
|
||||
func New() repository.Store {
|
||||
return &Store{
|
||||
sessions: make(map[string]repository.Session),
|
||||
oidcCodes: make(map[string]repository.OidcCode),
|
||||
oidcTokens: make(map[string]repository.OidcToken),
|
||||
oidcUsers: make(map[string]repository.OidcUserinfo),
|
||||
}
|
||||
}
|
||||
@@ -3,6 +3,8 @@ package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
|
||||
"github.com/tinyauthapp/tinyauth/internal/repository"
|
||||
)
|
||||
@@ -17,6 +19,22 @@ func NewStore(q *Queries) repository.Store {
|
||||
return &Store{q: q}
|
||||
}
|
||||
|
||||
var errMap = []struct {
|
||||
from error
|
||||
to error
|
||||
}{
|
||||
{sql.ErrNoRows, repository.ErrNotFound},
|
||||
}
|
||||
|
||||
func mapErr(err error) error {
|
||||
for _, e := range errMap {
|
||||
if errors.Is(err, e.from) {
|
||||
return e.to
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func oidcCodeToRepo(v OidcCode) repository.OidcCode {
|
||||
return repository.OidcCode(v)
|
||||
}
|
||||
@@ -32,7 +50,7 @@ func sessionToRepo(v Session) repository.Session {
|
||||
func (s *Store) CreateOidcCode(ctx context.Context, arg repository.CreateOidcCodeParams) (repository.OidcCode, error) {
|
||||
r, err := s.q.CreateOidcCode(ctx, CreateOidcCodeParams(arg))
|
||||
if err != nil {
|
||||
return repository.OidcCode{}, err
|
||||
return repository.OidcCode{}, mapErr(err)
|
||||
}
|
||||
return oidcCodeToRepo(r), nil
|
||||
}
|
||||
@@ -40,7 +58,7 @@ func (s *Store) CreateOidcCode(ctx context.Context, arg repository.CreateOidcCod
|
||||
func (s *Store) CreateOidcToken(ctx context.Context, arg repository.CreateOidcTokenParams) (repository.OidcToken, error) {
|
||||
r, err := s.q.CreateOidcToken(ctx, CreateOidcTokenParams(arg))
|
||||
if err != nil {
|
||||
return repository.OidcToken{}, err
|
||||
return repository.OidcToken{}, mapErr(err)
|
||||
}
|
||||
return oidcTokenToRepo(r), nil
|
||||
}
|
||||
@@ -48,7 +66,7 @@ func (s *Store) CreateOidcToken(ctx context.Context, arg repository.CreateOidcTo
|
||||
func (s *Store) CreateOidcUserInfo(ctx context.Context, arg repository.CreateOidcUserInfoParams) (repository.OidcUserinfo, error) {
|
||||
r, err := s.q.CreateOidcUserInfo(ctx, CreateOidcUserInfoParams(arg))
|
||||
if err != nil {
|
||||
return repository.OidcUserinfo{}, err
|
||||
return repository.OidcUserinfo{}, mapErr(err)
|
||||
}
|
||||
return oidcUserinfoToRepo(r), nil
|
||||
}
|
||||
@@ -56,7 +74,7 @@ func (s *Store) CreateOidcUserInfo(ctx context.Context, arg repository.CreateOid
|
||||
func (s *Store) CreateSession(ctx context.Context, arg repository.CreateSessionParams) (repository.Session, error) {
|
||||
r, err := s.q.CreateSession(ctx, CreateSessionParams(arg))
|
||||
if err != nil {
|
||||
return repository.Session{}, err
|
||||
return repository.Session{}, mapErr(err)
|
||||
}
|
||||
return sessionToRepo(r), nil
|
||||
}
|
||||
@@ -64,7 +82,7 @@ func (s *Store) CreateSession(ctx context.Context, arg repository.CreateSessionP
|
||||
func (s *Store) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]repository.OidcCode, error) {
|
||||
rows, err := s.q.DeleteExpiredOidcCodes(ctx, expiresAt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, mapErr(err)
|
||||
}
|
||||
out := make([]repository.OidcCode, len(rows))
|
||||
for i, row := range rows {
|
||||
@@ -76,7 +94,7 @@ func (s *Store) DeleteExpiredOidcCodes(ctx context.Context, expiresAt int64) ([]
|
||||
func (s *Store) DeleteExpiredOidcTokens(ctx context.Context, arg repository.DeleteExpiredOidcTokensParams) ([]repository.OidcToken, error) {
|
||||
rows, err := s.q.DeleteExpiredOidcTokens(ctx, DeleteExpiredOidcTokensParams(arg))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, mapErr(err)
|
||||
}
|
||||
out := make([]repository.OidcToken, len(rows))
|
||||
for i, row := range rows {
|
||||
@@ -86,41 +104,41 @@ func (s *Store) DeleteExpiredOidcTokens(ctx context.Context, arg repository.Dele
|
||||
}
|
||||
|
||||
func (s *Store) DeleteExpiredSessions(ctx context.Context, expiry int64) error {
|
||||
return s.q.DeleteExpiredSessions(ctx, expiry)
|
||||
return mapErr(s.q.DeleteExpiredSessions(ctx, expiry))
|
||||
}
|
||||
|
||||
func (s *Store) DeleteOidcCode(ctx context.Context, codeHash string) error {
|
||||
return s.q.DeleteOidcCode(ctx, codeHash)
|
||||
return mapErr(s.q.DeleteOidcCode(ctx, codeHash))
|
||||
}
|
||||
|
||||
func (s *Store) DeleteOidcCodeBySub(ctx context.Context, sub string) error {
|
||||
return s.q.DeleteOidcCodeBySub(ctx, sub)
|
||||
return mapErr(s.q.DeleteOidcCodeBySub(ctx, sub))
|
||||
}
|
||||
|
||||
func (s *Store) DeleteOidcToken(ctx context.Context, accessTokenHash string) error {
|
||||
return s.q.DeleteOidcToken(ctx, accessTokenHash)
|
||||
return mapErr(s.q.DeleteOidcToken(ctx, accessTokenHash))
|
||||
}
|
||||
|
||||
func (s *Store) DeleteOidcTokenByCodeHash(ctx context.Context, codeHash string) error {
|
||||
return s.q.DeleteOidcTokenByCodeHash(ctx, codeHash)
|
||||
return mapErr(s.q.DeleteOidcTokenByCodeHash(ctx, codeHash))
|
||||
}
|
||||
|
||||
func (s *Store) DeleteOidcTokenBySub(ctx context.Context, sub string) error {
|
||||
return s.q.DeleteOidcTokenBySub(ctx, sub)
|
||||
return mapErr(s.q.DeleteOidcTokenBySub(ctx, sub))
|
||||
}
|
||||
|
||||
func (s *Store) DeleteOidcUserInfo(ctx context.Context, sub string) error {
|
||||
return s.q.DeleteOidcUserInfo(ctx, sub)
|
||||
return mapErr(s.q.DeleteOidcUserInfo(ctx, sub))
|
||||
}
|
||||
|
||||
func (s *Store) DeleteSession(ctx context.Context, uuid string) error {
|
||||
return s.q.DeleteSession(ctx, uuid)
|
||||
return mapErr(s.q.DeleteSession(ctx, uuid))
|
||||
}
|
||||
|
||||
func (s *Store) GetOidcCode(ctx context.Context, codeHash string) (repository.OidcCode, error) {
|
||||
r, err := s.q.GetOidcCode(ctx, codeHash)
|
||||
if err != nil {
|
||||
return repository.OidcCode{}, err
|
||||
return repository.OidcCode{}, mapErr(err)
|
||||
}
|
||||
return oidcCodeToRepo(r), nil
|
||||
}
|
||||
@@ -128,7 +146,7 @@ func (s *Store) GetOidcCode(ctx context.Context, codeHash string) (repository.Oi
|
||||
func (s *Store) GetOidcCodeBySub(ctx context.Context, sub string) (repository.OidcCode, error) {
|
||||
r, err := s.q.GetOidcCodeBySub(ctx, sub)
|
||||
if err != nil {
|
||||
return repository.OidcCode{}, err
|
||||
return repository.OidcCode{}, mapErr(err)
|
||||
}
|
||||
return oidcCodeToRepo(r), nil
|
||||
}
|
||||
@@ -136,7 +154,7 @@ func (s *Store) GetOidcCodeBySub(ctx context.Context, sub string) (repository.Oi
|
||||
func (s *Store) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (repository.OidcCode, error) {
|
||||
r, err := s.q.GetOidcCodeBySubUnsafe(ctx, sub)
|
||||
if err != nil {
|
||||
return repository.OidcCode{}, err
|
||||
return repository.OidcCode{}, mapErr(err)
|
||||
}
|
||||
return oidcCodeToRepo(r), nil
|
||||
}
|
||||
@@ -144,7 +162,7 @@ func (s *Store) GetOidcCodeBySubUnsafe(ctx context.Context, sub string) (reposit
|
||||
func (s *Store) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (repository.OidcCode, error) {
|
||||
r, err := s.q.GetOidcCodeUnsafe(ctx, codeHash)
|
||||
if err != nil {
|
||||
return repository.OidcCode{}, err
|
||||
return repository.OidcCode{}, mapErr(err)
|
||||
}
|
||||
return oidcCodeToRepo(r), nil
|
||||
}
|
||||
@@ -152,7 +170,7 @@ func (s *Store) GetOidcCodeUnsafe(ctx context.Context, codeHash string) (reposit
|
||||
func (s *Store) GetOidcToken(ctx context.Context, accessTokenHash string) (repository.OidcToken, error) {
|
||||
r, err := s.q.GetOidcToken(ctx, accessTokenHash)
|
||||
if err != nil {
|
||||
return repository.OidcToken{}, err
|
||||
return repository.OidcToken{}, mapErr(err)
|
||||
}
|
||||
return oidcTokenToRepo(r), nil
|
||||
}
|
||||
@@ -160,7 +178,7 @@ func (s *Store) GetOidcToken(ctx context.Context, accessTokenHash string) (repos
|
||||
func (s *Store) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash string) (repository.OidcToken, error) {
|
||||
r, err := s.q.GetOidcTokenByRefreshToken(ctx, refreshTokenHash)
|
||||
if err != nil {
|
||||
return repository.OidcToken{}, err
|
||||
return repository.OidcToken{}, mapErr(err)
|
||||
}
|
||||
return oidcTokenToRepo(r), nil
|
||||
}
|
||||
@@ -168,7 +186,7 @@ func (s *Store) GetOidcTokenByRefreshToken(ctx context.Context, refreshTokenHash
|
||||
func (s *Store) GetOidcTokenBySub(ctx context.Context, sub string) (repository.OidcToken, error) {
|
||||
r, err := s.q.GetOidcTokenBySub(ctx, sub)
|
||||
if err != nil {
|
||||
return repository.OidcToken{}, err
|
||||
return repository.OidcToken{}, mapErr(err)
|
||||
}
|
||||
return oidcTokenToRepo(r), nil
|
||||
}
|
||||
@@ -176,7 +194,7 @@ func (s *Store) GetOidcTokenBySub(ctx context.Context, sub string) (repository.O
|
||||
func (s *Store) GetOidcUserInfo(ctx context.Context, sub string) (repository.OidcUserinfo, error) {
|
||||
r, err := s.q.GetOidcUserInfo(ctx, sub)
|
||||
if err != nil {
|
||||
return repository.OidcUserinfo{}, err
|
||||
return repository.OidcUserinfo{}, mapErr(err)
|
||||
}
|
||||
return oidcUserinfoToRepo(r), nil
|
||||
}
|
||||
@@ -184,7 +202,7 @@ func (s *Store) GetOidcUserInfo(ctx context.Context, sub string) (repository.Oid
|
||||
func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session, error) {
|
||||
r, err := s.q.GetSession(ctx, uuid)
|
||||
if err != nil {
|
||||
return repository.Session{}, err
|
||||
return repository.Session{}, mapErr(err)
|
||||
}
|
||||
return sessionToRepo(r), nil
|
||||
}
|
||||
@@ -192,7 +210,7 @@ func (s *Store) GetSession(ctx context.Context, uuid string) (repository.Session
|
||||
func (s *Store) UpdateOidcTokenByRefreshToken(ctx context.Context, arg repository.UpdateOidcTokenByRefreshTokenParams) (repository.OidcToken, error) {
|
||||
r, err := s.q.UpdateOidcTokenByRefreshToken(ctx, UpdateOidcTokenByRefreshTokenParams(arg))
|
||||
if err != nil {
|
||||
return repository.OidcToken{}, err
|
||||
return repository.OidcToken{}, mapErr(err)
|
||||
}
|
||||
return oidcTokenToRepo(r), nil
|
||||
}
|
||||
@@ -200,7 +218,7 @@ func (s *Store) UpdateOidcTokenByRefreshToken(ctx context.Context, arg repositor
|
||||
func (s *Store) UpdateSession(ctx context.Context, arg repository.UpdateSessionParams) (repository.Session, error) {
|
||||
r, err := s.q.UpdateSession(ctx, UpdateSessionParams(arg))
|
||||
if err != nil {
|
||||
return repository.Session{}, err
|
||||
return repository.Session{}, mapErr(err)
|
||||
}
|
||||
return sessionToRepo(r), nil
|
||||
}
|
||||
|
||||
@@ -1,6 +1,12 @@
|
||||
package repository
|
||||
|
||||
import "context"
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
)
|
||||
|
||||
// ErrNotFound is returned by Store methods when the requested record does not exist.
|
||||
var ErrNotFound = errors.New("not found")
|
||||
|
||||
// Store is the interface that all storage drivers must implement.
|
||||
// The sqlc-generated *Queries struct satisfies this interface for SQLite.
|
||||
|
||||
@@ -2,7 +2,6 @@ package service
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"regexp"
|
||||
@@ -411,7 +410,7 @@ func (auth *AuthService) GetSessionCookie(c *gin.Context) (repository.Session, e
|
||||
session, err := auth.queries.GetSession(c, cookie)
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
if errors.Is(err, repository.ErrNotFound) {
|
||||
return repository.Session{}, fmt.Errorf("session not found")
|
||||
}
|
||||
return repository.Session{}, err
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"crypto/x509"
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
@@ -420,7 +419,7 @@ func (service *OIDCService) GetCodeEntry(c *gin.Context, codeHash string, client
|
||||
oidcCode, err := service.queries.GetOidcCode(c, codeHash)
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
if errors.Is(err, repository.ErrNotFound) {
|
||||
return repository.OidcCode{}, ErrCodeNotFound
|
||||
}
|
||||
return repository.OidcCode{}, err
|
||||
@@ -564,7 +563,7 @@ func (service *OIDCService) RefreshAccessToken(c *gin.Context, refreshToken stri
|
||||
entry, err := service.queries.GetOidcTokenByRefreshToken(c, service.Hash(refreshToken))
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
if errors.Is(err, repository.ErrNotFound) {
|
||||
return TokenResponse{}, ErrTokenNotFound
|
||||
}
|
||||
return TokenResponse{}, err
|
||||
@@ -643,7 +642,7 @@ func (service *OIDCService) GetAccessToken(c *gin.Context, tokenHash string) (re
|
||||
entry, err := service.queries.GetOidcToken(c, tokenHash)
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
if errors.Is(err, repository.ErrNotFound) {
|
||||
return repository.OidcToken{}, ErrTokenNotFound
|
||||
}
|
||||
return repository.OidcToken{}, err
|
||||
@@ -731,15 +730,15 @@ func (service *OIDCService) Hash(token string) string {
|
||||
|
||||
func (service *OIDCService) DeleteOldSession(ctx context.Context, sub string) error {
|
||||
err := service.queries.DeleteOidcCodeBySub(ctx, sub)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
if err != nil && !errors.Is(err, repository.ErrNotFound) {
|
||||
return err
|
||||
}
|
||||
err = service.queries.DeleteOidcTokenBySub(ctx, sub)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
if err != nil && !errors.Is(err, repository.ErrNotFound) {
|
||||
return err
|
||||
}
|
||||
err = service.queries.DeleteOidcUserInfo(ctx, sub)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
if err != nil && !errors.Is(err, repository.ErrNotFound) {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
@@ -784,7 +783,7 @@ func (service *OIDCService) Cleanup() {
|
||||
token, err := service.queries.GetOidcTokenBySub(ctx, expiredCode.Sub)
|
||||
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
if errors.Is(err, repository.ErrNotFound) {
|
||||
continue
|
||||
}
|
||||
tlog.App.Warn().Err(err).Msg("Failed to get OIDC token by sub")
|
||||
|
||||
Reference in New Issue
Block a user