mirror of
https://github.com/steveiliop56/tinyauth.git
synced 2026-01-27 09:42:30 +00:00
refactor: implement oidc following tinyauth patterns
This commit is contained in:
@@ -54,6 +54,10 @@ func NewTinyauthCmdConfiguration() *config.Config {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
OIDC: config.OIDCConfig{
|
||||||
|
PrivateKeyPath: "./tinyauth_oidc_key",
|
||||||
|
PublicKeyPath: "./tinyauth_oidc_key.pub",
|
||||||
|
},
|
||||||
Experimental: config.ExperimentalConfig{
|
Experimental: config.ExperimentalConfig{
|
||||||
ConfigFile: "",
|
ConfigFile: "",
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -64,9 +64,7 @@ export const AuthorizePage = () => {
|
|||||||
toast.info("Authorized", {
|
toast.info("Authorized", {
|
||||||
description: "You will be soon redirected to your application",
|
description: "You will be soon redirected to your application",
|
||||||
});
|
});
|
||||||
window.location.replace(
|
window.location.replace(data.data.redirect_uri);
|
||||||
`${data.data.redirect_uri}?code=${encodeURIComponent(data.data.code)}&state=${encodeURIComponent(data.data.state)}`,
|
|
||||||
);
|
|
||||||
},
|
},
|
||||||
onError: (error) => {
|
onError: (error) => {
|
||||||
window.location.replace(
|
window.location.replace(
|
||||||
|
|||||||
@@ -176,7 +176,7 @@ func (app *BootstrapApp) Setup() error {
|
|||||||
app.context.configuredProviders = configuredProviders
|
app.context.configuredProviders = configuredProviders
|
||||||
|
|
||||||
// Setup router
|
// Setup router
|
||||||
router, err := app.setupRouter(queries)
|
router, err := app.setupRouter()
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to setup routes: %w", err)
|
return fmt.Errorf("failed to setup routes: %w", err)
|
||||||
|
|||||||
@@ -7,14 +7,13 @@ import (
|
|||||||
"github.com/steveiliop56/tinyauth/internal/config"
|
"github.com/steveiliop56/tinyauth/internal/config"
|
||||||
"github.com/steveiliop56/tinyauth/internal/controller"
|
"github.com/steveiliop56/tinyauth/internal/controller"
|
||||||
"github.com/steveiliop56/tinyauth/internal/middleware"
|
"github.com/steveiliop56/tinyauth/internal/middleware"
|
||||||
"github.com/steveiliop56/tinyauth/internal/repository"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
)
|
)
|
||||||
|
|
||||||
var DEV_MODES = []string{"main", "test", "development"}
|
var DEV_MODES = []string{"main", "test", "development"}
|
||||||
|
|
||||||
func (app *BootstrapApp) setupRouter(queries *repository.Queries) (*gin.Engine, error) {
|
func (app *BootstrapApp) setupRouter() (*gin.Engine, error) {
|
||||||
if !slices.Contains(DEV_MODES, config.Version) {
|
if !slices.Contains(DEV_MODES, config.Version) {
|
||||||
gin.SetMode(gin.ReleaseMode)
|
gin.SetMode(gin.ReleaseMode)
|
||||||
}
|
}
|
||||||
@@ -87,10 +86,7 @@ func (app *BootstrapApp) setupRouter(queries *repository.Queries) (*gin.Engine,
|
|||||||
|
|
||||||
oauthController.SetupRoutes()
|
oauthController.SetupRoutes()
|
||||||
|
|
||||||
oidcController := controller.NewOIDCController(controller.OIDCControllerConfig{
|
oidcController := controller.NewOIDCController(controller.OIDCControllerConfig{}, app.services.oidcService, apiRouter)
|
||||||
Clients: app.context.oidcClients,
|
|
||||||
AppURL: app.config.AppURL,
|
|
||||||
}, apiRouter, queries)
|
|
||||||
|
|
||||||
oidcController.SetupRoutes()
|
oidcController.SetupRoutes()
|
||||||
|
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ type Services struct {
|
|||||||
dockerService *service.DockerService
|
dockerService *service.DockerService
|
||||||
ldapService *service.LdapService
|
ldapService *service.LdapService
|
||||||
oauthBrokerService *service.OAuthBrokerService
|
oauthBrokerService *service.OAuthBrokerService
|
||||||
|
oidcService *service.OIDCService
|
||||||
}
|
}
|
||||||
|
|
||||||
func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, error) {
|
func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, error) {
|
||||||
@@ -88,5 +89,20 @@ func (app *BootstrapApp) initServices(queries *repository.Queries) (Services, er
|
|||||||
|
|
||||||
services.oauthBrokerService = oauthBrokerService
|
services.oauthBrokerService = oauthBrokerService
|
||||||
|
|
||||||
|
oidcService := service.NewOIDCService(service.OIDCServiceConfig{
|
||||||
|
Clients: app.config.OIDC.Clients,
|
||||||
|
PrivateKeyPath: app.config.OIDC.PrivateKeyPath,
|
||||||
|
PublicKeyPath: app.config.OIDC.PublicKeyPath,
|
||||||
|
Issuer: app.config.AppURL,
|
||||||
|
}, queries)
|
||||||
|
|
||||||
|
err = oidcService.Init()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return Services{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
services.oidcService = oidcService
|
||||||
|
|
||||||
return services, nil
|
return services, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -62,7 +62,9 @@ type OAuthConfig struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type OIDCConfig struct {
|
type OIDCConfig struct {
|
||||||
Clients map[string]OIDCClientConfig `description:"OIDC clients configuration." yaml:"clients"`
|
PrivateKeyPath string `description:"Path to the private key file." yaml:"privateKeyPath"`
|
||||||
|
PublicKeyPath string `description:"Path to the public key file." yaml:"publicKeyPath"`
|
||||||
|
Clients map[string]OIDCClientConfig `description:"OIDC clients configuration." yaml:"clients"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type UIConfig struct {
|
type UIConfig struct {
|
||||||
@@ -136,7 +138,7 @@ type OIDCClientConfig struct {
|
|||||||
ClientID string `description:"OIDC client ID." yaml:"clientId"`
|
ClientID string `description:"OIDC client ID." yaml:"clientId"`
|
||||||
ClientSecret string `description:"OIDC client secret." yaml:"clientSecret"`
|
ClientSecret string `description:"OIDC client secret." yaml:"clientSecret"`
|
||||||
ClientSecretFile string `description:"Path to the file containing the OIDC client secret." yaml:"clientSecretFile"`
|
ClientSecretFile string `description:"Path to the file containing the OIDC client secret." yaml:"clientSecretFile"`
|
||||||
TrustedRedirectURLs []string `description:"List of trusted redirect URLs." yaml:"trustedRedirectUrls"`
|
TrustedRedirectURIs []string `description:"List of trusted redirect URLs." yaml:"trustedRedirectUrls"`
|
||||||
Name string `description:"Client name in UI." yaml:"name"`
|
Name string `description:"Client name in UI." yaml:"name"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,43 +1,31 @@
|
|||||||
package controller
|
package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
"slices"
|
"slices"
|
||||||
"strconv"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/google/go-querystring/query"
|
"github.com/google/go-querystring/query"
|
||||||
"github.com/steveiliop56/tinyauth/internal/config"
|
"github.com/steveiliop56/tinyauth/internal/service"
|
||||||
"github.com/steveiliop56/tinyauth/internal/repository"
|
|
||||||
"github.com/steveiliop56/tinyauth/internal/utils"
|
"github.com/steveiliop56/tinyauth/internal/utils"
|
||||||
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
|
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
type OIDCControllerConfig struct{}
|
||||||
SupportedResponseTypes = []string{"code"}
|
|
||||||
SupportedScopes = []string{"openid", "profile", "email", "groups"}
|
|
||||||
SupportedGrantTypes = []string{"authorization_code"}
|
|
||||||
)
|
|
||||||
|
|
||||||
type OIDCControllerConfig struct {
|
|
||||||
Clients []config.OIDCClientConfig
|
|
||||||
AppURL string
|
|
||||||
}
|
|
||||||
|
|
||||||
type OIDCController struct {
|
type OIDCController struct {
|
||||||
config OIDCControllerConfig
|
config OIDCControllerConfig
|
||||||
router *gin.RouterGroup
|
router *gin.RouterGroup
|
||||||
queries *repository.Queries
|
oidc *service.OIDCService
|
||||||
}
|
}
|
||||||
|
|
||||||
type AuthorizeRequest struct {
|
type AuthorizeCallback struct {
|
||||||
Scope string `json:"scope" binding:"required"`
|
Code string `url:"code"`
|
||||||
ResponseType string `json:"response_type" binding:"required"`
|
State string `url:"state"`
|
||||||
ClientID string `json:"client_id" binding:"required"`
|
|
||||||
RedirectURI string `json:"redirect_uri" binding:"required"`
|
|
||||||
State string `json:"state" binding:"required"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type TokenRequest struct {
|
type TokenRequest struct {
|
||||||
@@ -52,11 +40,19 @@ type CallbackError struct {
|
|||||||
State string `url:"state"`
|
State string `url:"state"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewOIDCController(config OIDCControllerConfig, router *gin.RouterGroup, queries *repository.Queries) *OIDCController {
|
type ErrorScreen struct {
|
||||||
|
Error string `url:"error"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type ClientRequest struct {
|
||||||
|
ClientID string `uri:"id" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewOIDCController(config OIDCControllerConfig, oidcService *service.OIDCService, router *gin.RouterGroup) *OIDCController {
|
||||||
return &OIDCController{
|
return &OIDCController{
|
||||||
config: config,
|
config: config,
|
||||||
router: router,
|
oidc: oidcService,
|
||||||
queries: queries,
|
router: router,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -68,10 +64,6 @@ func (controller *OIDCController) SetupRoutes() {
|
|||||||
oidcGroup.GET("/userinfo", controller.Userinfo)
|
oidcGroup.GET("/userinfo", controller.Userinfo)
|
||||||
}
|
}
|
||||||
|
|
||||||
type ClientRequest struct {
|
|
||||||
ClientID string `uri:"id" binding:"required"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (controller *OIDCController) GetClientInfo(c *gin.Context) {
|
func (controller *OIDCController) GetClientInfo(c *gin.Context) {
|
||||||
var req ClientRequest
|
var req ClientRequest
|
||||||
|
|
||||||
@@ -85,17 +77,9 @@ func (controller *OIDCController) GetClientInfo(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var client *config.OIDCClientConfig
|
client, ok := controller.oidc.GetClient(req.ClientID)
|
||||||
|
|
||||||
// Inefficient yeah, but it will be good until we have thousands of clients
|
if !ok {
|
||||||
for _, clientCfg := range controller.config.Clients {
|
|
||||||
if clientCfg.ClientID == req.ClientID {
|
|
||||||
client = &clientCfg
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if client == nil {
|
|
||||||
tlog.App.Warn().Str("client_id", req.ClientID).Msg("Client not found")
|
tlog.App.Warn().Str("client_id", req.ClientID).Msg("Client not found")
|
||||||
c.JSON(404, gin.H{
|
c.JSON(404, gin.H{
|
||||||
"status": 404,
|
"status": 404,
|
||||||
@@ -106,206 +90,111 @@ func (controller *OIDCController) GetClientInfo(c *gin.Context) {
|
|||||||
|
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"status": 200,
|
"status": 200,
|
||||||
"client": &client.ClientID,
|
"client": client.ClientID,
|
||||||
"name": &client.Name,
|
"name": client.Name,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (controller *OIDCController) Authorize(c *gin.Context) {
|
func (controller *OIDCController) Authorize(c *gin.Context) {
|
||||||
// Check if we are logged in
|
|
||||||
userContext, err := utils.GetContext(c)
|
userContext, err := utils.GetContext(c)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Err(err).Msg("Failed to get user context")
|
controller.authorizeError(c, err, "Failed to get user context", "User is not logged in or the session is invalid", "", "", "")
|
||||||
c.JSON(401, gin.H{
|
|
||||||
"status": 401,
|
|
||||||
"message": "Unauthorized",
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// OIDC stuff
|
var req service.AuthorizeRequest
|
||||||
var req AuthorizeRequest
|
|
||||||
|
|
||||||
err = c.BindJSON(&req)
|
err = c.BindJSON(&req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Err(err).Msg("Failed to bind JSON")
|
controller.authorizeError(c, err, "Failed to bind JSON", "The client provided an invalid authorization request", "", "", "")
|
||||||
c.JSON(400, gin.H{
|
|
||||||
"status": 400,
|
|
||||||
"message": "Bad Request",
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: All these errors should redirect to the error page with an explanation
|
_, ok := controller.oidc.GetClient(req.ClientID)
|
||||||
|
|
||||||
// Validate client ID
|
if !ok {
|
||||||
var client *config.OIDCClientConfig
|
controller.authorizeError(c, err, "Client not found", "The client ID is invalid", "", "", "")
|
||||||
|
|
||||||
for _, clientCfg := range controller.config.Clients {
|
|
||||||
if clientCfg.ClientID == req.ClientID {
|
|
||||||
client = &clientCfg
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if client == nil {
|
|
||||||
tlog.App.Warn().Str("client_id", req.ClientID).Msg("Client not found")
|
|
||||||
c.JSON(404, gin.H{
|
|
||||||
"status": 404,
|
|
||||||
"message": "Client not found",
|
|
||||||
})
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate redirect URI
|
err = controller.oidc.ValidateAuthorizeParams(req)
|
||||||
if !slices.Contains(client.TrustedRedirectURLs, req.RedirectURI) {
|
|
||||||
tlog.App.Warn().Str("redirect_uri", req.RedirectURI).Msg("Redirect URI not trusted")
|
|
||||||
c.JSON(400, gin.H{
|
|
||||||
"status": 400,
|
|
||||||
"message": "Bad Request",
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Validate scopes
|
if err != nil {
|
||||||
reqScopes := strings.Split(req.Scope, " ")
|
tlog.App.Error().Err(err).Msg("Failed to validate authorize params")
|
||||||
keptScopes := make([]string, 0)
|
if err.Error() != "invalid_request_uri" {
|
||||||
|
controller.authorizeError(c, err, "Failed validate authorize params", "Invalid request parameters", req.RedirectURI, err.Error(), req.State)
|
||||||
if len(reqScopes) == 0 || strings.TrimSpace(req.Scope) == "" {
|
|
||||||
queries, err := query.Values(CallbackError{
|
|
||||||
Error: "invalid_request",
|
|
||||||
ErrorDescription: "Missing scope parameter",
|
|
||||||
State: req.State,
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
tlog.App.Error().Err(err).Msg("Failed to build query")
|
|
||||||
c.Redirect(302, fmt.Sprintf("%s/error", controller.config.AppURL))
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
controller.authorizeError(c, err, "Redirect URI not trusted", "The provided redirect URI is not trusted", "", "", "")
|
||||||
c.Redirect(302, fmt.Sprintf("%s/callback?%s", req.RedirectURI, queries.Encode()))
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, scope := range reqScopes {
|
// WARNING: Since Tinyauth is stateless, we cannot have a sub that never changes. We will just create a uuid out of the username which remains stable, but if username changes then sub changes too.
|
||||||
if slices.Contains(SupportedScopes, scope) {
|
sub := utils.GenerateUUID(userContext.Username)
|
||||||
keptScopes = append(keptScopes, scope)
|
code := rand.Text()
|
||||||
continue
|
|
||||||
}
|
|
||||||
tlog.App.Warn().Str("scope", scope).Msg("Scope not supported, ignoring")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generate a code and a sub
|
err = controller.oidc.StoreCode(c, sub, code, req)
|
||||||
code, err := utils.GetRandomString(32)
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Err(err).Msg("Failed to generate random string")
|
controller.authorizeError(c, err, "Failed to store code", "Failed to store code", req.RedirectURI, "server_error", req.State)
|
||||||
c.Redirect(302, fmt.Sprintf("%s/error", controller.config.AppURL))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
sub, err := utils.GetRandomInt(10)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
tlog.App.Error().Err(err).Msg("Failed to generate random integer")
|
|
||||||
c.Redirect(302, fmt.Sprintf("%s/error", controller.config.AppURL))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
expiresAt := time.Now().Add(time.Minute * time.Duration(10)).Unix()
|
|
||||||
|
|
||||||
// Insert the code into the database
|
|
||||||
_, err = controller.queries.CreateOidcCode(c, repository.CreateOidcCodeParams{
|
|
||||||
Code: code,
|
|
||||||
Sub: strconv.Itoa(int(sub)),
|
|
||||||
Scope: strings.Join(keptScopes, ","),
|
|
||||||
RedirectURI: req.RedirectURI,
|
|
||||||
ClientID: client.ClientID,
|
|
||||||
ExpiresAt: expiresAt,
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
tlog.App.Error().Err(err).Msg("Failed to insert code into database")
|
|
||||||
c.Redirect(302, fmt.Sprintf("%s/error", controller.config.AppURL))
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// We also need a snapshot of the user that authorized this
|
// We also need a snapshot of the user that authorized this
|
||||||
userInfoParams := repository.CreateOidcUserInfoParams{
|
err = controller.oidc.StoreUserinfo(c, sub, userContext, req)
|
||||||
Sub: strconv.Itoa(int(sub)),
|
|
||||||
Name: userContext.Name,
|
|
||||||
Email: userContext.Email,
|
|
||||||
PreferredUsername: userContext.Username,
|
|
||||||
UpdatedAt: time.Now().Unix(),
|
|
||||||
}
|
|
||||||
|
|
||||||
if userContext.Provider == "ldap" {
|
|
||||||
userInfoParams.Groups = userContext.LdapGroups
|
|
||||||
}
|
|
||||||
|
|
||||||
if userContext.OAuth && len(userContext.OAuthGroups) > 0 {
|
|
||||||
userInfoParams.Groups = userContext.OAuthGroups
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = controller.queries.CreateOidcUserInfo(c, userInfoParams)
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Err(err).Msg("Failed to insert user info into database")
|
tlog.App.Error().Err(err).Msg("Failed to insert user info into database")
|
||||||
c.Redirect(302, fmt.Sprintf("%s/error", controller.config.AppURL))
|
controller.authorizeError(c, err, "Failed to store user info", "Failed to store user info", req.RedirectURI, "server_error", req.State)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
queries, err := query.Values(AuthorizeCallback{
|
||||||
|
Code: code,
|
||||||
|
State: req.State,
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
controller.authorizeError(c, err, "Failed to build query", "Failed to build query", req.RedirectURI, "server_error", req.State)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Return code and done
|
|
||||||
c.JSON(200, gin.H{
|
c.JSON(200, gin.H{
|
||||||
"status": 200,
|
"status": 200,
|
||||||
"message": "Authorized",
|
"redirect_uri": fmt.Sprintf("%s?%s", req.RedirectURI, queries.Encode()),
|
||||||
"code": code,
|
|
||||||
"state": req.State,
|
|
||||||
"redirect_uri": req.RedirectURI,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (controller *OIDCController) Token(c *gin.Context) {
|
func (controller *OIDCController) Token(c *gin.Context) {
|
||||||
// Get basic auth
|
rclientId, rclientSecret, ok := c.Request.BasicAuth()
|
||||||
clientId, clientSecret, ok := c.Request.BasicAuth()
|
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
tlog.App.Error().Msg("Missing token verifier")
|
tlog.App.Error().Msg("Missing authorization header")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "invalid_request",
|
"error": "invalid_request",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure client exists
|
client, ok := controller.oidc.GetClient(rclientId)
|
||||||
var client *config.OIDCClientConfig
|
|
||||||
|
|
||||||
for _, clientCfg := range controller.config.Clients {
|
if !ok {
|
||||||
if clientCfg.ClientID == clientId {
|
tlog.App.Warn().Str("client_id", rclientId).Msg("Client not found")
|
||||||
client = &clientCfg
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if client == nil {
|
|
||||||
tlog.App.Warn().Str("client_id", clientId).Msg("Client not found")
|
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "invalid_request",
|
"error": "access_denied",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if client.ClientSecret != clientSecret {
|
if client.ClientSecret != rclientSecret {
|
||||||
tlog.App.Warn().Str("client_id", clientId).Msg("Invalid client secret")
|
tlog.App.Warn().Str("client_id", rclientId).Msg("Invalid client secret")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "invalid_client",
|
"error": "access_denied",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get token
|
|
||||||
var req TokenRequest
|
var req TokenRequest
|
||||||
|
|
||||||
err := c.Bind(&req)
|
err := c.Bind(&req)
|
||||||
@@ -317,93 +206,73 @@ func (controller *OIDCController) Token(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate grant type
|
err = controller.oidc.ValidateGrantType(req.GrantType)
|
||||||
if !slices.Contains(SupportedGrantTypes, req.GrantType) {
|
if err != nil {
|
||||||
tlog.App.Warn().Str("grant_type", req.GrantType).Msg("Unsupported grant type")
|
tlog.App.Warn().Str("grant_type", req.GrantType).Msg("Unsupported grant type")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "unsupported_grant_type",
|
"error": err.Error(),
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Find pending code entry
|
entry, err := controller.oidc.GetCodeEntry(c, req.Code)
|
||||||
entry, err := controller.queries.GetOidcCode(c, req.Code)
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Err(err).Msg("Failed to find code in database")
|
if errors.Is(err, service.ErrCodeExpired) {
|
||||||
|
tlog.App.Warn().Str("code", req.Code).Msg("Code expired")
|
||||||
|
c.JSON(400, gin.H{
|
||||||
|
"error": "access_denied",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if errors.Is(err, service.ErrCodeNotFound) {
|
||||||
|
tlog.App.Warn().Str("code", req.Code).Msg("Code not found")
|
||||||
|
c.JSON(400, gin.H{
|
||||||
|
"error": "access_denied",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
tlog.App.Warn().Err(err).Msg("Failed to get OIDC code entry")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "invalid_request",
|
"error": "server_error",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Ensure redirect URIs match
|
|
||||||
if entry.RedirectURI != req.RedirectURI {
|
if entry.RedirectURI != req.RedirectURI {
|
||||||
tlog.App.Warn().Str("redirect_uri", req.RedirectURI).Msg("Redirect URI mismatch")
|
tlog.App.Warn().Str("redirect_uri", req.RedirectURI).Msg("Redirect URI mismatch")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "invalid_request",
|
"error": "invalid_request_uri",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate access token
|
accessToken, err := controller.oidc.GenerateAccessToken(c, client, entry.Sub, entry.Scope)
|
||||||
genToken, err := utils.GetRandomString(29)
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Err(err).Msg("Failed to generate access token")
|
tlog.App.Error().Err(err).Msg("Failed to generate access token")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "invalid_request",
|
"error": "server_error",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add tinyauth prefix
|
err = controller.oidc.DeleteCodeEntry(c, entry.Code)
|
||||||
token := fmt.Sprintf("ta-%s", genToken)
|
|
||||||
|
|
||||||
// TODO: either add a refresh token or customize token expiry
|
|
||||||
expiresAt := time.Now().Add(time.Duration(3600) * time.Second).Unix()
|
|
||||||
|
|
||||||
// Create token entry
|
|
||||||
_, err = controller.queries.CreateOidcToken(c, repository.CreateOidcTokenParams{
|
|
||||||
Sub: entry.Sub,
|
|
||||||
AccessToken: token,
|
|
||||||
Scope: entry.Scope,
|
|
||||||
ClientID: client.ClientID,
|
|
||||||
ExpiresAt: expiresAt,
|
|
||||||
})
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
tlog.App.Error().Err(err).Msg("Failed to create token in database")
|
|
||||||
c.JSON(400, gin.H{
|
|
||||||
"error": "invalid_request",
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Delete code entry
|
|
||||||
err = controller.queries.DeleteOidcCode(c, entry.Code)
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Error().Err(err).Msg("Failed to delete code in database")
|
tlog.App.Error().Err(err).Msg("Failed to delete code in database")
|
||||||
c.JSON(400, gin.H{
|
c.JSON(400, gin.H{
|
||||||
"error": "invalid_request",
|
"error": "server_error",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Respond with token
|
c.JSON(200, accessToken)
|
||||||
c.JSON(200, gin.H{
|
|
||||||
"access_token": token,
|
|
||||||
"token_type": "bearer",
|
|
||||||
"expires_in": 3600,
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (controller *OIDCController) Userinfo(c *gin.Context) {
|
func (controller *OIDCController) Userinfo(c *gin.Context) {
|
||||||
// Get bearer
|
authorization := c.GetHeader("Authorization")
|
||||||
authorizationHeader := c.GetHeader("Authorization")
|
|
||||||
|
|
||||||
tokenType, token, ok := strings.Cut(authorizationHeader, " ")
|
tokenType, token, ok := strings.Cut(authorization, " ")
|
||||||
|
|
||||||
if !ok {
|
if !ok {
|
||||||
tlog.App.Warn().Msg("OIDC userinfo accessed without authorization header")
|
tlog.App.Warn().Msg("OIDC userinfo accessed without authorization header")
|
||||||
@@ -421,53 +290,36 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get token entry
|
entry, err := controller.oidc.GetAccessToken(c, token)
|
||||||
entry, err := controller.queries.GetOidcToken(c, token)
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if err == service.ErrTokenNotFound {
|
||||||
|
tlog.App.Warn().Msg("OIDC userinfo accessed with invalid token")
|
||||||
|
c.JSON(401, gin.H{
|
||||||
|
"error": "invalid_request",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
tlog.App.Err(err).Msg("Failed to get token entry")
|
tlog.App.Err(err).Msg("Failed to get token entry")
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
"error": "invalid_request",
|
"error": "server_error",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get scopes
|
user, err := controller.oidc.GetUserinfo(c, entry.Sub)
|
||||||
scopes := strings.Split(entry.Scope, ",")
|
|
||||||
|
|
||||||
// Check if token is expired
|
|
||||||
if time.Now().Unix() > entry.ExpiresAt {
|
|
||||||
tlog.App.Warn().Msg("OIDC userinfo accessed with expired token")
|
|
||||||
|
|
||||||
err = controller.queries.DeleteOidcToken(c, entry.AccessToken)
|
|
||||||
if err != nil {
|
|
||||||
tlog.App.Err(err).Msg("Failed to delete expired token")
|
|
||||||
}
|
|
||||||
|
|
||||||
err = controller.queries.DeleteOidcUserInfo(c, entry.Sub)
|
|
||||||
if err != nil {
|
|
||||||
tlog.App.Err(err).Msg("Failed to delete oidc user info")
|
|
||||||
}
|
|
||||||
|
|
||||||
c.JSON(401, gin.H{
|
|
||||||
"error": "invalid_request",
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get user info
|
|
||||||
user, err := controller.queries.GetOidcUserInfo(c, entry.Sub)
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tlog.App.Err(err).Msg("Failed to get user entry")
|
tlog.App.Err(err).Msg("Failed to get user entry")
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
"error": "invalid_request",
|
"error": "server_error",
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// If we don't have the openid scope, return an error
|
// If we don't have the openid scope, return an error
|
||||||
if !slices.Contains(scopes, "openid") {
|
if !slices.Contains(strings.Split(entry.Scope, ","), "openid") {
|
||||||
tlog.App.Warn().Msg("OIDC userinfo accessed without openid scope")
|
tlog.App.Warn().Msg("OIDC userinfo accessed without openid scope")
|
||||||
c.JSON(401, gin.H{
|
c.JSON(401, gin.H{
|
||||||
"error": "invalid_request",
|
"error": "invalid_request",
|
||||||
@@ -475,27 +327,52 @@ func (controller *OIDCController) Userinfo(c *gin.Context) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Let's build the response
|
c.JSON(200, controller.oidc.CompileUserinfo(user, entry.Scope))
|
||||||
res := map[string]any{
|
}
|
||||||
"sub": user.Sub,
|
|
||||||
"updated_at": user.UpdatedAt,
|
func (controller *OIDCController) authorizeError(c *gin.Context, err error, reason string, reasonUser string, callback string, callbackError string, state string) {
|
||||||
}
|
tlog.App.Error().Err(err).Msg(reason)
|
||||||
|
|
||||||
// If we have the profile scope, add the profile stuff
|
if callback != "" {
|
||||||
if slices.Contains(scopes, "profile") {
|
errorQueries := CallbackError{
|
||||||
res["name"] = user.Name
|
Error: callbackError,
|
||||||
res["preferred_username"] = user.PreferredUsername
|
}
|
||||||
}
|
|
||||||
|
if reasonUser != "" {
|
||||||
// If we have the email scope, add the email stuff
|
errorQueries.ErrorDescription = reasonUser
|
||||||
if slices.Contains(scopes, "email") {
|
}
|
||||||
res["email"] = user.Email
|
|
||||||
}
|
if state != "" {
|
||||||
|
errorQueries.State = state
|
||||||
// If we have the groups scope, add the groups stuff
|
}
|
||||||
if slices.Contains(scopes, "groups") {
|
|
||||||
res["groups"] = user.Groups
|
queries, err := query.Values(errorQueries)
|
||||||
}
|
|
||||||
|
if err != nil {
|
||||||
c.JSON(200, res)
|
c.AbortWithStatus(http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(200, gin.H{
|
||||||
|
"status": 200,
|
||||||
|
"redirect_uri": fmt.Sprintf("%s/?%s", callback, queries.Encode()),
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
errorQueries := ErrorScreen{
|
||||||
|
Error: reasonUser,
|
||||||
|
}
|
||||||
|
|
||||||
|
queries, err := query.Values(errorQueries)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
c.AbortWithStatus(http.StatusInternalServerError)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.JSON(200, gin.H{
|
||||||
|
"status": 200,
|
||||||
|
"redirect_uri": fmt.Sprintf("%s/error?%s", controller.oidc.GetIssuer(), queries.Encode()),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
438
internal/service/oidc_service.go
Normal file
438
internal/service/oidc_service.go
Normal file
@@ -0,0 +1,438 @@
|
|||||||
|
package service
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/x509"
|
||||||
|
"database/sql"
|
||||||
|
"encoding/pem"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/steveiliop56/tinyauth/internal/config"
|
||||||
|
"github.com/steveiliop56/tinyauth/internal/repository"
|
||||||
|
"github.com/steveiliop56/tinyauth/internal/utils"
|
||||||
|
"github.com/steveiliop56/tinyauth/internal/utils/tlog"
|
||||||
|
"golang.org/x/exp/slices"
|
||||||
|
|
||||||
|
// Should probably switch to another package but for now this works
|
||||||
|
"golang.org/x/oauth2/jws"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
SupportedScopes = []string{"openid", "profile", "email", "groups"}
|
||||||
|
SupportedResponseTypes = []string{"code"}
|
||||||
|
SupportedGrantTypes = []string{"authorization_code"}
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrCodeExpired = errors.New("code_expired")
|
||||||
|
ErrCodeNotFound = errors.New("code_not_found")
|
||||||
|
ErrTokenNotFound = errors.New("token_not_found")
|
||||||
|
ErrTokenExpired = errors.New("token_expired")
|
||||||
|
)
|
||||||
|
|
||||||
|
type UserinfoResponse struct {
|
||||||
|
Sub string `json:"sub"`
|
||||||
|
Name string `json:"name"`
|
||||||
|
Email string `json:"email"`
|
||||||
|
PreferredUsername string `json:"preferred_username"`
|
||||||
|
Groups []string `json:"groups"`
|
||||||
|
UpdatedAt int64 `json:"updated_at"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type TokenResponse struct {
|
||||||
|
AccessToken string `json:"access_token"`
|
||||||
|
TokenType string `json:"token_type"`
|
||||||
|
ExpiresIn int64 `json:"expires_in"`
|
||||||
|
IDToken string `json:"id_token"`
|
||||||
|
Scope string `json:"scope"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type AuthorizeRequest struct {
|
||||||
|
Scope string `json:"scope" binding:"required"`
|
||||||
|
ResponseType string `json:"response_type" binding:"required"`
|
||||||
|
ClientID string `json:"client_id" binding:"required"`
|
||||||
|
RedirectURI string `json:"redirect_uri" binding:"required"`
|
||||||
|
State string `json:"state" binding:"required"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type OIDCServiceConfig struct {
|
||||||
|
Clients map[string]config.OIDCClientConfig
|
||||||
|
PrivateKeyPath string
|
||||||
|
PublicKeyPath string
|
||||||
|
Issuer string
|
||||||
|
}
|
||||||
|
|
||||||
|
type OIDCService struct {
|
||||||
|
config OIDCServiceConfig
|
||||||
|
queries *repository.Queries
|
||||||
|
clients map[string]config.OIDCClientConfig
|
||||||
|
privateKey *rsa.PrivateKey
|
||||||
|
publicKey crypto.PublicKey
|
||||||
|
issuer string
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewOIDCService(config OIDCServiceConfig, queries *repository.Queries) *OIDCService {
|
||||||
|
return &OIDCService{
|
||||||
|
config: config,
|
||||||
|
queries: queries,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: A cleanup routine is needed to clean up expired tokens/code/userinfo
|
||||||
|
|
||||||
|
func (service *OIDCService) Init() error {
|
||||||
|
// Ensure issuer is https
|
||||||
|
uissuer, err := url.Parse(service.config.Issuer)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if uissuer.Scheme != "https" {
|
||||||
|
return errors.New("issuer must be https")
|
||||||
|
}
|
||||||
|
|
||||||
|
service.issuer = fmt.Sprintf("%s://%s", uissuer.Scheme, uissuer.Host)
|
||||||
|
|
||||||
|
// Create/load private and public keys
|
||||||
|
if strings.TrimSpace(service.config.PrivateKeyPath) == "" ||
|
||||||
|
strings.TrimSpace(service.config.PublicKeyPath) == "" {
|
||||||
|
return errors.New("private key path and public key path are required")
|
||||||
|
}
|
||||||
|
|
||||||
|
var privateKey *rsa.PrivateKey
|
||||||
|
|
||||||
|
fprivateKey, err := os.ReadFile(service.config.PrivateKeyPath)
|
||||||
|
|
||||||
|
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if errors.Is(err, os.ErrNotExist) {
|
||||||
|
privateKey, err = rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
der := x509.MarshalPKCS1PrivateKey(privateKey)
|
||||||
|
encoded := pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "RSA PRIVATE KEY",
|
||||||
|
Bytes: der,
|
||||||
|
})
|
||||||
|
err = os.WriteFile(service.config.PrivateKeyPath, encoded, 0600)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
service.privateKey = privateKey
|
||||||
|
} else {
|
||||||
|
block, _ := pem.Decode(fprivateKey)
|
||||||
|
privateKey, err = x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
service.privateKey = privateKey
|
||||||
|
}
|
||||||
|
|
||||||
|
fpublicKey, err := os.ReadFile(service.config.PublicKeyPath)
|
||||||
|
|
||||||
|
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if errors.Is(err, os.ErrNotExist) {
|
||||||
|
publicKey := service.privateKey.Public()
|
||||||
|
der := x509.MarshalPKCS1PublicKey(publicKey.(*rsa.PublicKey))
|
||||||
|
encoded := pem.EncodeToMemory(&pem.Block{
|
||||||
|
Type: "RSA PUBLIC KEY",
|
||||||
|
Bytes: der,
|
||||||
|
})
|
||||||
|
err = os.WriteFile(service.config.PublicKeyPath, encoded, 0644)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
service.publicKey = publicKey
|
||||||
|
} else {
|
||||||
|
block, _ := pem.Decode(fpublicKey)
|
||||||
|
publicKey, err := x509.ParsePKCS1PublicKey(block.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
service.publicKey = publicKey
|
||||||
|
}
|
||||||
|
|
||||||
|
// We will reorganize the client into a map with the client ID as the key
|
||||||
|
service.clients = make(map[string]config.OIDCClientConfig)
|
||||||
|
|
||||||
|
for id, client := range service.config.Clients {
|
||||||
|
client.ID = id
|
||||||
|
service.clients[client.ClientID] = client
|
||||||
|
}
|
||||||
|
|
||||||
|
// Load the client secrets from files if they exist
|
||||||
|
for id, client := range service.clients {
|
||||||
|
secret := utils.GetSecret(client.ClientSecret, client.ClientSecretFile)
|
||||||
|
if secret != "" {
|
||||||
|
client.ClientSecret = secret
|
||||||
|
}
|
||||||
|
client.ClientSecretFile = ""
|
||||||
|
service.clients[id] = client
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (service *OIDCService) GetIssuer() string {
|
||||||
|
return service.config.Issuer
|
||||||
|
}
|
||||||
|
|
||||||
|
func (service *OIDCService) GetClient(id string) (config.OIDCClientConfig, bool) {
|
||||||
|
client, ok := service.clients[id]
|
||||||
|
return client, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (service *OIDCService) ValidateAuthorizeParams(req AuthorizeRequest) error {
|
||||||
|
// Validate client ID
|
||||||
|
client, ok := service.GetClient(req.ClientID)
|
||||||
|
if !ok {
|
||||||
|
return errors.New("access_denied")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scopes
|
||||||
|
scopes := strings.Split(req.Scope, " ")
|
||||||
|
|
||||||
|
if len(scopes) == 0 || strings.TrimSpace(req.Scope) == "" {
|
||||||
|
return errors.New("invalid_scope")
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, scope := range scopes {
|
||||||
|
if strings.TrimSpace(scope) == "" {
|
||||||
|
return errors.New("invalid_scope")
|
||||||
|
}
|
||||||
|
if !slices.Contains(SupportedScopes, scope) {
|
||||||
|
tlog.App.Warn().Str("scope", scope).Msg("Unsupported OIDC scope, will be ignored")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Response type
|
||||||
|
if !slices.Contains(SupportedResponseTypes, req.ResponseType) {
|
||||||
|
return errors.New("unsupported_response_type")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Redirect URI
|
||||||
|
if !slices.Contains(client.TrustedRedirectURIs, req.RedirectURI) {
|
||||||
|
return errors.New("invalid_request_uri")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (service *OIDCService) filterScopes(scopes []string) []string {
|
||||||
|
return utils.Filter(scopes, func(scope string) bool {
|
||||||
|
return slices.Contains(SupportedScopes, scope)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (service *OIDCService) StoreCode(c *gin.Context, sub string, code string, req AuthorizeRequest) error {
|
||||||
|
// Fixed 10 minutes
|
||||||
|
expiresAt := time.Now().Add(time.Minute * time.Duration(10)).Unix()
|
||||||
|
|
||||||
|
// Insert the code into the database
|
||||||
|
_, err := service.queries.CreateOidcCode(c, repository.CreateOidcCodeParams{
|
||||||
|
Sub: sub,
|
||||||
|
Code: code,
|
||||||
|
// Here it's safe to split and trust the output since, we validated the scopes before
|
||||||
|
Scope: strings.Join(service.filterScopes(strings.Split(req.Scope, " ")), ","),
|
||||||
|
RedirectURI: req.RedirectURI,
|
||||||
|
ClientID: req.ClientID,
|
||||||
|
ExpiresAt: expiresAt,
|
||||||
|
})
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (service *OIDCService) StoreUserinfo(c *gin.Context, sub string, userContext config.UserContext, req AuthorizeRequest) error {
|
||||||
|
userInfoParams := repository.CreateOidcUserInfoParams{
|
||||||
|
Sub: sub,
|
||||||
|
Name: userContext.Name,
|
||||||
|
Email: userContext.Email,
|
||||||
|
PreferredUsername: userContext.Username,
|
||||||
|
UpdatedAt: time.Now().Unix(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Tinyauth will pass through the groups it got from an LDAP or an OIDC server
|
||||||
|
if userContext.Provider == "ldap" {
|
||||||
|
userInfoParams.Groups = userContext.LdapGroups
|
||||||
|
}
|
||||||
|
|
||||||
|
if userContext.OAuth && len(userContext.OAuthGroups) > 0 {
|
||||||
|
userInfoParams.Groups = userContext.OAuthGroups
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err := service.queries.CreateOidcUserInfo(c, userInfoParams)
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (service *OIDCService) ValidateGrantType(grantType string) error {
|
||||||
|
if !slices.Contains(SupportedGrantTypes, grantType) {
|
||||||
|
return errors.New("unsupported_response_type")
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (service *OIDCService) GetCodeEntry(c *gin.Context, code string) (repository.OidcCode, error) {
|
||||||
|
oidcCode, err := service.queries.GetOidcCode(c, code)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
|
return repository.OidcCode{}, ErrCodeNotFound
|
||||||
|
}
|
||||||
|
return repository.OidcCode{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if time.Now().Unix() > oidcCode.ExpiresAt {
|
||||||
|
err = service.queries.DeleteOidcCode(c, code)
|
||||||
|
if err != nil {
|
||||||
|
return repository.OidcCode{}, err
|
||||||
|
}
|
||||||
|
err = service.DeleteUserinfo(c, oidcCode.Sub)
|
||||||
|
if err != nil {
|
||||||
|
return repository.OidcCode{}, err
|
||||||
|
}
|
||||||
|
return repository.OidcCode{}, ErrCodeExpired
|
||||||
|
}
|
||||||
|
|
||||||
|
return oidcCode, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (service *OIDCService) generateIDToken(client config.OIDCClientConfig, sub string) (string, error) {
|
||||||
|
createdAt := time.Now().Unix()
|
||||||
|
|
||||||
|
// TODO: This should probably be user-configured if refresh logic does not exist
|
||||||
|
expiresAt := time.Now().Add(time.Duration(1) * time.Hour).Unix()
|
||||||
|
|
||||||
|
claims := jws.ClaimSet{
|
||||||
|
Iss: service.issuer,
|
||||||
|
Aud: client.ClientID,
|
||||||
|
Sub: sub,
|
||||||
|
Iat: createdAt,
|
||||||
|
Exp: expiresAt,
|
||||||
|
}
|
||||||
|
|
||||||
|
header := jws.Header{
|
||||||
|
Algorithm: "RS256",
|
||||||
|
Typ: "JWT",
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := jws.Encode(&header, &claims, service.privateKey)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (service *OIDCService) GenerateAccessToken(c *gin.Context, client config.OIDCClientConfig, sub string, scope string) (TokenResponse, error) {
|
||||||
|
idToken, err := service.generateIDToken(client, sub)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return TokenResponse{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
accessToken := rand.Text()
|
||||||
|
expiresAt := time.Now().Add(time.Duration(1) * time.Hour).Unix()
|
||||||
|
|
||||||
|
tokenResponse := TokenResponse{
|
||||||
|
AccessToken: accessToken,
|
||||||
|
TokenType: "Bearer",
|
||||||
|
ExpiresIn: int64(time.Hour.Seconds()),
|
||||||
|
IDToken: idToken,
|
||||||
|
Scope: strings.ReplaceAll(scope, ",", " "),
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = service.queries.CreateOidcToken(c, repository.CreateOidcTokenParams{
|
||||||
|
Sub: sub,
|
||||||
|
AccessToken: accessToken,
|
||||||
|
Scope: scope,
|
||||||
|
ExpiresAt: expiresAt,
|
||||||
|
})
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return TokenResponse{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return tokenResponse, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (service *OIDCService) DeleteCodeEntry(c *gin.Context, code string) error {
|
||||||
|
return service.queries.DeleteOidcCode(c, code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (service *OIDCService) DeleteUserinfo(c *gin.Context, sub string) error {
|
||||||
|
return service.queries.DeleteOidcUserInfo(c, sub)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (service *OIDCService) DeleteToken(c *gin.Context, token string) error {
|
||||||
|
return service.queries.DeleteOidcToken(c, token)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (service *OIDCService) GetAccessToken(c *gin.Context, token string) (repository.OidcToken, error) {
|
||||||
|
entry, err := service.queries.GetOidcToken(c, token)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
if err == sql.ErrNoRows {
|
||||||
|
return repository.OidcToken{}, ErrTokenNotFound
|
||||||
|
}
|
||||||
|
return repository.OidcToken{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if entry.ExpiresAt < time.Now().Unix() {
|
||||||
|
err := service.DeleteToken(c, token)
|
||||||
|
if err != nil {
|
||||||
|
return repository.OidcToken{}, err
|
||||||
|
}
|
||||||
|
err = service.DeleteUserinfo(c, entry.Sub)
|
||||||
|
if err != nil {
|
||||||
|
return repository.OidcToken{}, err
|
||||||
|
}
|
||||||
|
return repository.OidcToken{}, ErrTokenExpired
|
||||||
|
}
|
||||||
|
|
||||||
|
return entry, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (service *OIDCService) GetUserinfo(c *gin.Context, sub string) (repository.OidcUserinfo, error) {
|
||||||
|
return service.queries.GetOidcUserInfo(c, sub)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (service *OIDCService) CompileUserinfo(user repository.OidcUserinfo, scope string) UserinfoResponse {
|
||||||
|
scopes := strings.Split(scope, ",") // split by comma since it's a db entry
|
||||||
|
userInfo := UserinfoResponse{
|
||||||
|
Sub: user.Sub,
|
||||||
|
UpdatedAt: user.UpdatedAt,
|
||||||
|
}
|
||||||
|
|
||||||
|
if slices.Contains(scopes, "profile") {
|
||||||
|
userInfo.Name = user.Name
|
||||||
|
userInfo.PreferredUsername = user.PreferredUsername
|
||||||
|
}
|
||||||
|
|
||||||
|
if slices.Contains(scopes, "email") {
|
||||||
|
userInfo.Email = user.Email
|
||||||
|
}
|
||||||
|
|
||||||
|
if slices.Contains(scopes, "groups") {
|
||||||
|
userInfo.Groups = strings.Split(user.Groups, ",")
|
||||||
|
}
|
||||||
|
|
||||||
|
return userInfo
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user