tests: use require instead of assert where previous step is required

This commit is contained in:
Stavros
2026-05-09 13:28:22 +03:00
parent 9fccb63097
commit c7e9fade03
4 changed files with 73 additions and 71 deletions
@@ -8,6 +8,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/controller"
"github.com/tinyauthapp/tinyauth/internal/model" "github.com/tinyauthapp/tinyauth/internal/model"
"github.com/tinyauthapp/tinyauth/internal/utils" "github.com/tinyauthapp/tinyauth/internal/utils"
@@ -44,7 +45,7 @@ func TestContextController(t *testing.T) {
WarningsEnabled: cfg.UI.WarningsEnabled, WarningsEnabled: cfg.UI.WarningsEnabled,
} }
bytes, err := json.Marshal(expectedAppContextResponse) bytes, err := json.Marshal(expectedAppContextResponse)
assert.NoError(t, err) require.NoError(t, err)
return string(bytes) return string(bytes)
}(), }(),
}, },
@@ -58,7 +59,7 @@ func TestContextController(t *testing.T) {
Message: "Unauthorized", Message: "Unauthorized",
} }
bytes, err := json.Marshal(expectedUserContextResponse) bytes, err := json.Marshal(expectedUserContextResponse)
assert.NoError(t, err) require.NoError(t, err)
return string(bytes) return string(bytes)
}(), }(),
}, },
@@ -91,7 +92,7 @@ func TestContextController(t *testing.T) {
Provider: "local", Provider: "local",
} }
bytes, err := json.Marshal(expectedUserContextResponse) bytes, err := json.Marshal(expectedUserContextResponse)
assert.NoError(t, err) require.NoError(t, err)
return string(bytes) return string(bytes)
}(), }(),
}, },
@@ -113,7 +114,7 @@ func TestContextController(t *testing.T) {
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
request, err := http.NewRequest("GET", test.path, nil) request, err := http.NewRequest("GET", test.path, nil)
assert.NoError(t, err) require.NoError(t, err)
router.ServeHTTP(recorder, request) router.ServeHTTP(recorder, request)
@@ -8,6 +8,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/tinyauthapp/tinyauth/internal/controller" "github.com/tinyauthapp/tinyauth/internal/controller"
) )
@@ -28,7 +29,7 @@ func TestHealthController(t *testing.T) {
"message": "Healthy", "message": "Healthy",
} }
bytes, err := json.Marshal(expectedHealthResponse) bytes, err := json.Marshal(expectedHealthResponse)
assert.NoError(t, err) require.NoError(t, err)
return string(bytes) return string(bytes)
}(), }(),
}, },
@@ -42,7 +43,7 @@ func TestHealthController(t *testing.T) {
"message": "Healthy", "message": "Healthy",
} }
bytes, err := json.Marshal(expectedHealthResponse) bytes, err := json.Marshal(expectedHealthResponse)
assert.NoError(t, err) require.NoError(t, err)
return string(bytes) return string(bytes)
}(), }(),
}, },
@@ -59,7 +60,7 @@ func TestHealthController(t *testing.T) {
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
request, err := http.NewRequest(test.method, test.path, nil) request, err := http.NewRequest(test.method, test.path, nil)
assert.NoError(t, err) require.NoError(t, err)
router.ServeHTTP(recorder, request) router.ServeHTTP(recorder, request)
+50 -50
View File
@@ -89,7 +89,7 @@ func TestOIDCController(t *testing.T) {
var res map[string]any var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res) err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, res["redirect_uri"], "https://tinyauth.example.com/error?error=User+is+not+logged+in+or+the+session+is+invalid") assert.Equal(t, res["redirect_uri"], "https://tinyauth.example.com/error?error=User+is+not+logged+in+or+the+session+is+invalid")
}, },
@@ -109,7 +109,7 @@ func TestOIDCController(t *testing.T) {
Nonce: "some-nonce", Nonce: "some-nonce",
} }
reqBodyBytes, err := json.Marshal(reqBody) reqBodyBytes, err := json.Marshal(reqBody)
assert.NoError(t, err) require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes))) req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
@@ -117,7 +117,7 @@ func TestOIDCController(t *testing.T) {
var res map[string]any var res map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &res) err = json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, res["redirect_uri"], "https://test.example.com/callback?error=unsupported_response_type&error_description=Invalid+request+parameters&state=some-state") assert.Equal(t, res["redirect_uri"], "https://test.example.com/callback?error=unsupported_response_type&error_description=Invalid+request+parameters&state=some-state")
}, },
@@ -137,7 +137,7 @@ func TestOIDCController(t *testing.T) {
Nonce: "some-nonce", Nonce: "some-nonce",
} }
reqBodyBytes, err := json.Marshal(reqBody) reqBodyBytes, err := json.Marshal(reqBody)
assert.NoError(t, err) require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes))) req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
@@ -146,11 +146,11 @@ func TestOIDCController(t *testing.T) {
var res map[string]any var res map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &res) err = json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err) require.NoError(t, err)
redirectURI := res["redirect_uri"].(string) redirectURI := res["redirect_uri"].(string)
url, err := url.Parse(redirectURI) url, err := url.Parse(redirectURI)
assert.NoError(t, err) require.NoError(t, err)
queryParams := url.Query() queryParams := url.Query()
assert.Equal(t, queryParams.Get("state"), "some-state") assert.Equal(t, queryParams.Get("state"), "some-state")
@@ -169,7 +169,7 @@ func TestOIDCController(t *testing.T) {
RedirectURI: "https://test.example.com/callback", RedirectURI: "https://test.example.com/callback",
} }
reqBodyEncoded, err := query.Values(reqBody) reqBodyEncoded, err := query.Values(reqBody)
assert.NoError(t, err) require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
@@ -177,7 +177,7 @@ func TestOIDCController(t *testing.T) {
var res map[string]any var res map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &res) err = json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, res["error"], "unsupported_grant_type") assert.Equal(t, res["error"], "unsupported_grant_type")
}, },
@@ -192,7 +192,7 @@ func TestOIDCController(t *testing.T) {
RedirectURI: "https://test.example.com/callback", RedirectURI: "https://test.example.com/callback",
} }
reqBodyEncoded, err := query.Values(reqBody) reqBodyEncoded, err := query.Values(reqBody)
assert.NoError(t, err) require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
@@ -230,7 +230,7 @@ func TestOIDCController(t *testing.T) {
RedirectURI: "https://test.example.com/callback", RedirectURI: "https://test.example.com/callback",
} }
reqBodyEncoded, err := query.Values(reqBody) reqBodyEncoded, err := query.Values(reqBody)
assert.NoError(t, err) require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
@@ -253,11 +253,11 @@ func TestOIDCController(t *testing.T) {
var authorizeRes map[string]any var authorizeRes map[string]any
err := json.Unmarshal(authorizeTestRecorder.Body.Bytes(), &authorizeRes) err := json.Unmarshal(authorizeTestRecorder.Body.Bytes(), &authorizeRes)
assert.NoError(t, err) require.NoError(t, err)
redirectURI := authorizeRes["redirect_uri"].(string) redirectURI := authorizeRes["redirect_uri"].(string)
url, err := url.Parse(redirectURI) url, err := url.Parse(redirectURI)
assert.NoError(t, err) require.NoError(t, err)
queryParams := url.Query() queryParams := url.Query()
code := queryParams.Get("code") code := queryParams.Get("code")
@@ -269,7 +269,7 @@ func TestOIDCController(t *testing.T) {
RedirectURI: "https://test.example.com/callback", RedirectURI: "https://test.example.com/callback",
} }
reqBodyEncoded, err := query.Values(reqBody) reqBodyEncoded, err := query.Values(reqBody)
assert.NoError(t, err) require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
@@ -292,7 +292,7 @@ func TestOIDCController(t *testing.T) {
var tokenRes map[string]any var tokenRes map[string]any
err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes) err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes)
assert.NoError(t, err) require.NoError(t, err)
_, ok := tokenRes["refresh_token"] _, ok := tokenRes["refresh_token"]
assert.True(t, ok, "Expected refresh token in response") assert.True(t, ok, "Expected refresh token in response")
@@ -306,7 +306,7 @@ func TestOIDCController(t *testing.T) {
ClientSecret: "some-client-secret", ClientSecret: "some-client-secret",
} }
reqBodyEncoded, err := query.Values(reqBody) reqBodyEncoded, err := query.Values(reqBody)
assert.NoError(t, err) require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
@@ -318,7 +318,7 @@ func TestOIDCController(t *testing.T) {
assert.Equal(t, 200, recorder.Code) assert.Equal(t, 200, recorder.Code)
var refreshRes map[string]any var refreshRes map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &refreshRes) err = json.Unmarshal(recorder.Body.Bytes(), &refreshRes)
assert.NoError(t, err) require.NoError(t, err)
_, ok = refreshRes["access_token"] _, ok = refreshRes["access_token"]
assert.True(t, ok, "Expected access token in refresh response") assert.True(t, ok, "Expected access token in refresh response")
@@ -339,11 +339,11 @@ func TestOIDCController(t *testing.T) {
var authorizeRes map[string]any var authorizeRes map[string]any
err := json.Unmarshal(authorizeTestRecorder.Body.Bytes(), &authorizeRes) err := json.Unmarshal(authorizeTestRecorder.Body.Bytes(), &authorizeRes)
assert.NoError(t, err) require.NoError(t, err)
redirectURI := authorizeRes["redirect_uri"].(string) redirectURI := authorizeRes["redirect_uri"].(string)
url, err := url.Parse(redirectURI) url, err := url.Parse(redirectURI)
assert.NoError(t, err) require.NoError(t, err)
queryParams := url.Query() queryParams := url.Query()
code := queryParams.Get("code") code := queryParams.Get("code")
@@ -355,7 +355,7 @@ func TestOIDCController(t *testing.T) {
RedirectURI: "https://test.example.com/callback", RedirectURI: "https://test.example.com/callback",
} }
reqBodyEncoded, err := query.Values(reqBody) reqBodyEncoded, err := query.Values(reqBody)
assert.NoError(t, err) require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
@@ -375,7 +375,7 @@ func TestOIDCController(t *testing.T) {
var secondRes map[string]any var secondRes map[string]any
err = json.Unmarshal(secondRecorder.Body.Bytes(), &secondRes) err = json.Unmarshal(secondRecorder.Body.Bytes(), &secondRes)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "invalid_grant", secondRes["error"]) assert.Equal(t, "invalid_grant", secondRes["error"])
}, },
@@ -403,7 +403,7 @@ func TestOIDCController(t *testing.T) {
var tokenRes map[string]any var tokenRes map[string]any
err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes) err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes)
assert.NoError(t, err) require.NoError(t, err)
accessToken := tokenRes["access_token"].(string) accessToken := tokenRes["access_token"].(string)
assert.NotEmpty(t, accessToken) assert.NotEmpty(t, accessToken)
@@ -415,7 +415,7 @@ func TestOIDCController(t *testing.T) {
var userInfoRes map[string]any var userInfoRes map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &userInfoRes) err = json.Unmarshal(recorder.Body.Bytes(), &userInfoRes)
assert.NoError(t, err) require.NoError(t, err)
_, ok := userInfoRes["sub"] _, ok := userInfoRes["sub"]
assert.True(t, ok, "Expected sub claim in userinfo response") assert.True(t, ok, "Expected sub claim in userinfo response")
@@ -435,7 +435,7 @@ func TestOIDCController(t *testing.T) {
var res map[string]any var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res) err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "invalid_request", res["error"]) assert.Equal(t, "invalid_request", res["error"])
}, },
}, },
@@ -450,7 +450,7 @@ func TestOIDCController(t *testing.T) {
var res map[string]any var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res) err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "invalid_request", res["error"]) assert.Equal(t, "invalid_request", res["error"])
}, },
}, },
@@ -465,7 +465,7 @@ func TestOIDCController(t *testing.T) {
var res map[string]any var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res) err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "invalid_request", res["error"]) assert.Equal(t, "invalid_request", res["error"])
}, },
}, },
@@ -480,7 +480,7 @@ func TestOIDCController(t *testing.T) {
var res map[string]any var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res) err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "invalid_grant", res["error"]) assert.Equal(t, "invalid_grant", res["error"])
}, },
}, },
@@ -495,7 +495,7 @@ func TestOIDCController(t *testing.T) {
var res map[string]any var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res) err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "invalid_request", res["error"]) assert.Equal(t, "invalid_request", res["error"])
}, },
}, },
@@ -510,7 +510,7 @@ func TestOIDCController(t *testing.T) {
var res map[string]any var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res) err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "invalid_request", res["error"]) assert.Equal(t, "invalid_request", res["error"])
}, },
}, },
@@ -527,7 +527,7 @@ func TestOIDCController(t *testing.T) {
var tokenRes map[string]any var tokenRes map[string]any
err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes) err := json.Unmarshal(tokenRecorder.Body.Bytes(), &tokenRes)
assert.NoError(t, err) require.NoError(t, err)
accessToken := tokenRes["access_token"].(string) accessToken := tokenRes["access_token"].(string)
assert.NotEmpty(t, accessToken) assert.NotEmpty(t, accessToken)
@@ -541,7 +541,7 @@ func TestOIDCController(t *testing.T) {
var userInfoRes map[string]any var userInfoRes map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &userInfoRes) err = json.Unmarshal(recorder.Body.Bytes(), &userInfoRes)
assert.NoError(t, err) require.NoError(t, err)
_, ok := userInfoRes["sub"] _, ok := userInfoRes["sub"]
assert.True(t, ok, "Expected sub claim in userinfo response") assert.True(t, ok, "Expected sub claim in userinfo response")
@@ -565,7 +565,7 @@ func TestOIDCController(t *testing.T) {
CodeChallengeMethod: "", CodeChallengeMethod: "",
} }
reqBodyBytes, err := json.Marshal(reqBody) reqBodyBytes, err := json.Marshal(reqBody)
assert.NoError(t, err) require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes))) req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
@@ -574,11 +574,11 @@ func TestOIDCController(t *testing.T) {
var res map[string]any var res map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &res) err = json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err) require.NoError(t, err)
redirectURI := res["redirect_uri"].(string) redirectURI := res["redirect_uri"].(string)
url, err := url.Parse(redirectURI) url, err := url.Parse(redirectURI)
assert.NoError(t, err) require.NoError(t, err)
queryParams := url.Query() queryParams := url.Query()
assert.Equal(t, queryParams.Get("state"), "some-state") assert.Equal(t, queryParams.Get("state"), "some-state")
@@ -595,7 +595,7 @@ func TestOIDCController(t *testing.T) {
CodeVerifier: "some-challenge", CodeVerifier: "some-challenge",
} }
reqBodyEncoded, err := query.Values(tokenReqBody) reqBodyEncoded, err := query.Values(tokenReqBody)
assert.NoError(t, err) require.NoError(t, err)
req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
@@ -626,7 +626,7 @@ func TestOIDCController(t *testing.T) {
CodeChallengeMethod: "S256", CodeChallengeMethod: "S256",
} }
reqBodyBytes, err := json.Marshal(reqBody) reqBodyBytes, err := json.Marshal(reqBody)
assert.NoError(t, err) require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes))) req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
@@ -635,11 +635,11 @@ func TestOIDCController(t *testing.T) {
var res map[string]any var res map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &res) err = json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err) require.NoError(t, err)
redirectURI := res["redirect_uri"].(string) redirectURI := res["redirect_uri"].(string)
url, err := url.Parse(redirectURI) url, err := url.Parse(redirectURI)
assert.NoError(t, err) require.NoError(t, err)
queryParams := url.Query() queryParams := url.Query()
assert.Equal(t, queryParams.Get("state"), "some-state") assert.Equal(t, queryParams.Get("state"), "some-state")
@@ -656,7 +656,7 @@ func TestOIDCController(t *testing.T) {
CodeVerifier: "some-challenge", CodeVerifier: "some-challenge",
} }
reqBodyEncoded, err := query.Values(tokenReqBody) reqBodyEncoded, err := query.Values(tokenReqBody)
assert.NoError(t, err) require.NoError(t, err)
req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
@@ -687,7 +687,7 @@ func TestOIDCController(t *testing.T) {
CodeChallengeMethod: "S256", CodeChallengeMethod: "S256",
} }
reqBodyBytes, err := json.Marshal(reqBody) reqBodyBytes, err := json.Marshal(reqBody)
assert.NoError(t, err) require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes))) req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
@@ -696,11 +696,11 @@ func TestOIDCController(t *testing.T) {
var res map[string]any var res map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &res) err = json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err) require.NoError(t, err)
redirectURI := res["redirect_uri"].(string) redirectURI := res["redirect_uri"].(string)
url, err := url.Parse(redirectURI) url, err := url.Parse(redirectURI)
assert.NoError(t, err) require.NoError(t, err)
queryParams := url.Query() queryParams := url.Query()
assert.Equal(t, queryParams.Get("state"), "some-state") assert.Equal(t, queryParams.Get("state"), "some-state")
@@ -717,7 +717,7 @@ func TestOIDCController(t *testing.T) {
CodeVerifier: "some-challenge-1", CodeVerifier: "some-challenge-1",
} }
reqBodyEncoded, err := query.Values(tokenReqBody) reqBodyEncoded, err := query.Values(tokenReqBody)
assert.NoError(t, err) require.NoError(t, err)
req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req = httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
@@ -748,7 +748,7 @@ func TestOIDCController(t *testing.T) {
CodeChallengeMethod: "foo", CodeChallengeMethod: "foo",
} }
reqBodyBytes, err := json.Marshal(reqBody) reqBodyBytes, err := json.Marshal(reqBody)
assert.NoError(t, err) require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes))) req := httptest.NewRequest("POST", "/api/oidc/authorize", strings.NewReader(string(reqBodyBytes)))
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
@@ -757,11 +757,11 @@ func TestOIDCController(t *testing.T) {
var res map[string]any var res map[string]any
err = json.Unmarshal(recorder.Body.Bytes(), &res) err = json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err) require.NoError(t, err)
redirectURI := res["redirect_uri"].(string) redirectURI := res["redirect_uri"].(string)
url, err := url.Parse(redirectURI) url, err := url.Parse(redirectURI)
assert.NoError(t, err) require.NoError(t, err)
queryParams := url.Query() queryParams := url.Query()
error := queryParams.Get("error") error := queryParams.Get("error")
@@ -780,11 +780,11 @@ func TestOIDCController(t *testing.T) {
var res map[string]any var res map[string]any
err := json.Unmarshal(recorder.Body.Bytes(), &res) err := json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err) require.NoError(t, err)
redirectURI := res["redirect_uri"].(string) redirectURI := res["redirect_uri"].(string)
url, err := url.Parse(redirectURI) url, err := url.Parse(redirectURI)
assert.NoError(t, err) require.NoError(t, err)
queryParams := url.Query() queryParams := url.Query()
code := queryParams.Get("code") code := queryParams.Get("code")
@@ -796,7 +796,7 @@ func TestOIDCController(t *testing.T) {
RedirectURI: "https://test.example.com/callback", RedirectURI: "https://test.example.com/callback",
} }
reqBodyEncoded, err := query.Values(reqBody) reqBodyEncoded, err := query.Values(reqBody)
assert.NoError(t, err) require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode())) req := httptest.NewRequest("POST", "/api/oidc/token", strings.NewReader(reqBodyEncoded.Encode()))
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
@@ -807,7 +807,7 @@ func TestOIDCController(t *testing.T) {
assert.Equal(t, 200, recorder.Code) assert.Equal(t, 200, recorder.Code)
err = json.Unmarshal(recorder.Body.Bytes(), &res) err = json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err) require.NoError(t, err)
accessToken := res["access_token"].(string) accessToken := res["access_token"].(string)
assert.NotEmpty(t, accessToken) assert.NotEmpty(t, accessToken)
@@ -832,7 +832,7 @@ func TestOIDCController(t *testing.T) {
assert.Equal(t, 401, recorder.Code) assert.Equal(t, 401, recorder.Code)
err = json.Unmarshal(recorder.Body.Bytes(), &res) err = json.Unmarshal(recorder.Body.Bytes(), &res)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "invalid_grant", res["error"]) assert.Equal(t, "invalid_grant", res["error"])
}, },
}, },
+14 -14
View File
@@ -95,7 +95,7 @@ func TestUserController(t *testing.T) {
Password: "password", Password: "password",
} }
loginReqBody, err := json.Marshal(loginReq) loginReqBody, err := json.Marshal(loginReq)
assert.NoError(t, err) require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody))) req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
@@ -103,7 +103,7 @@ func TestUserController(t *testing.T) {
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code) assert.Equal(t, 200, recorder.Code)
assert.Len(t, recorder.Result().Cookies(), 1) require.Len(t, recorder.Result().Cookies(), 1)
cookie := recorder.Result().Cookies()[0] cookie := recorder.Result().Cookies()[0]
assert.Equal(t, "tinyauth-session", cookie.Name) assert.Equal(t, "tinyauth-session", cookie.Name)
@@ -123,7 +123,7 @@ func TestUserController(t *testing.T) {
Password: "wrongpassword", Password: "wrongpassword",
} }
loginReqBody, err := json.Marshal(loginReq) loginReqBody, err := json.Marshal(loginReq)
assert.NoError(t, err) require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody))) req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
@@ -144,7 +144,7 @@ func TestUserController(t *testing.T) {
Password: "wrongpassword", Password: "wrongpassword",
} }
loginReqBody, err := json.Marshal(loginReq) loginReqBody, err := json.Marshal(loginReq)
assert.NoError(t, err) require.NoError(t, err)
for range 3 { for range 3 {
recorder := httptest.NewRecorder() recorder := httptest.NewRecorder()
@@ -179,7 +179,7 @@ func TestUserController(t *testing.T) {
Password: "password", Password: "password",
} }
loginReqBody, err := json.Marshal(loginReq) loginReqBody, err := json.Marshal(loginReq)
assert.NoError(t, err) require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody))) req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
@@ -190,12 +190,12 @@ func TestUserController(t *testing.T) {
decodedBody := make(map[string]any) decodedBody := make(map[string]any)
err = json.Unmarshal(recorder.Body.Bytes(), &decodedBody) err = json.Unmarshal(recorder.Body.Bytes(), &decodedBody)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, decodedBody["totpPending"], true) assert.Equal(t, decodedBody["totpPending"], true)
// should set the session cookie // should set the session cookie
assert.Len(t, recorder.Result().Cookies(), 1) require.Len(t, recorder.Result().Cookies(), 1)
cookie := recorder.Result().Cookies()[0] cookie := recorder.Result().Cookies()[0]
assert.Equal(t, "tinyauth-session", cookie.Name) assert.Equal(t, "tinyauth-session", cookie.Name)
assert.True(t, cookie.HttpOnly) assert.True(t, cookie.HttpOnly)
@@ -216,7 +216,7 @@ func TestUserController(t *testing.T) {
Password: "password", Password: "password",
} }
loginReqBody, err := json.Marshal(loginReq) loginReqBody, err := json.Marshal(loginReq)
assert.NoError(t, err) require.NoError(t, err)
req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody))) req := httptest.NewRequest("POST", "/api/user/login", strings.NewReader(string(loginReqBody)))
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
@@ -225,7 +225,7 @@ func TestUserController(t *testing.T) {
assert.Equal(t, 200, recorder.Code) assert.Equal(t, 200, recorder.Code)
cookies := recorder.Result().Cookies() cookies := recorder.Result().Cookies()
assert.Len(t, cookies, 1) require.Len(t, cookies, 1)
cookie := cookies[0] cookie := cookies[0]
assert.Equal(t, "tinyauth-session", cookie.Name) assert.Equal(t, "tinyauth-session", cookie.Name)
@@ -239,7 +239,7 @@ func TestUserController(t *testing.T) {
assert.Equal(t, 200, recorder.Code) assert.Equal(t, 200, recorder.Code)
cookies = recorder.Result().Cookies() cookies = recorder.Result().Cookies()
assert.Len(t, cookies, 1) require.Len(t, cookies, 1)
cookie = cookies[0] cookie = cookies[0]
assert.Equal(t, "tinyauth-session", cookie.Name) assert.Equal(t, "tinyauth-session", cookie.Name)
@@ -266,14 +266,14 @@ func TestUserController(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now()) code, err := totp.GenerateCode("JPIEBDKJH6UGWJMX66RR3S55UFP2SGKK", time.Now())
assert.NoError(t, err) require.NoError(t, err)
totpReq := controller.TotpRequest{ totpReq := controller.TotpRequest{
Code: code, Code: code,
} }
totpReqBody, err := json.Marshal(totpReq) totpReqBody, err := json.Marshal(totpReq)
assert.NoError(t, err) require.NoError(t, err)
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody))) req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))
@@ -288,7 +288,7 @@ func TestUserController(t *testing.T) {
router.ServeHTTP(recorder, req) router.ServeHTTP(recorder, req)
assert.Equal(t, 200, recorder.Code) assert.Equal(t, 200, recorder.Code)
assert.Len(t, recorder.Result().Cookies(), 1) require.Len(t, recorder.Result().Cookies(), 1)
// should set a new session cookie with totp pending removed // should set a new session cookie with totp pending removed
totpCookie := recorder.Result().Cookies()[0] totpCookie := recorder.Result().Cookies()[0]
@@ -311,7 +311,7 @@ func TestUserController(t *testing.T) {
} }
totpReqBody, err := json.Marshal(totpReq) totpReqBody, err := json.Marshal(totpReq)
assert.NoError(t, err) require.NoError(t, err)
recorder = httptest.NewRecorder() recorder = httptest.NewRecorder()
req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody))) req := httptest.NewRequest("POST", "/api/user/totp", strings.NewReader(string(totpReqBody)))