diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 67a25b3..094fe66 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -8,6 +8,7 @@ import ( "reflect" "strings" "testing" + "time" "tinyauth/internal/auth" "tinyauth/internal/docker" "tinyauth/internal/handlers" @@ -17,6 +18,7 @@ import ( "tinyauth/internal/types" "github.com/magiconair/properties/assert" + "github.com/pquerna/otp/totp" ) // Simple server config for tests @@ -33,7 +35,7 @@ var handlersConfig = types.HandlersConfig{ CookieSecure: false, Title: "Tinyauth", GenericName: "Generic", - ForgotPasswordMessage: "Some message", + ForgotPasswordMessage: "Message", CsrfCookieName: "tinyauth-csrf", RedirectCookieName: "tinyauth-redirect", BackgroundImage: "https://example.com/image.png", @@ -44,8 +46,8 @@ var handlersConfig = types.HandlersConfig{ var authConfig = types.AuthConfig{ Users: types.Users{}, OauthWhitelist: "", - HMACSecret: "super-secret-api-thing-for-test1", - EncryptionSecret: "super-secret-api-thing-for-test2", + HMACSecret: "4bZ9K.*:;zH=,9zG!meUxu.B5-S[7.V.", // Complex on purpose + EncryptionSecret: "\\:!R(u[Sbv6ZLm.7es)H|OqH4y}0u\\rj", CookieSecure: false, SessionExpiry: 3600, LoginTimeout: 0, @@ -60,7 +62,7 @@ var hooksConfig = types.HooksConfig{ } // Cookie -var cookie string +var cookie = "MTc1MTkyMzM5MnxiME9aTzlGQjZMNEJMdDZMc0lHMk9zcXQyME9SR1ZnUmlaYWZNcWplek5vcVNpdkdHRTZqb09YWkVUYUN6NEt4MkEyOGEyX2hFQWZEUEYtbllDX0h5eDBCb3VyT2phQlRpZWFfRFdTMGw2WUg2VWw4RGdNbEhQclotOUJjblJGaWFQcmhyaWFna0dXRWNud2c1akg5eEpLZ3JzS0pfWktscVZyckZFR1VDX0R5QjFOT0hzMTNKb18ySEMxZlluSWNxa1ByM0VhSzNyMkRtdDNORWJXVGFYSnMzWjFGa0lrZlhSTWduRmttMHhQUXN4UFhNbHFXY0lBWjBnUWpKU0xXMHRubjlKbjV0LXBGdjk0MmpJX0xMX1ZYblVJVW9LWUJoWmpNanVXNkNjamhYWlR2V29rY0RNYWkxY2lMQnpqLUI2cHMyYTZkWWgtWnlFdGN0amh2WURUeUNGT3ZLS1FJVUFIb0NWR1RPMlRtY2c9PXwerwFtb9urOXnwA02qXbLeorMloaK_paQd0in4BAesmg==" // User var user = types.User{ @@ -68,7 +70,7 @@ var user = types.User{ Password: "$2a$10$AvGHLTYv3xiRJ0xV9xs3XeVIlkGTygI9nqIamFYB5Xu.5.0UWF7B6", // pass } -// We need all this to be able to test the server +// Initialize the server for tests func getServer(t *testing.T) *server.Server { // Create docker service docker, err := docker.NewDocker() @@ -80,8 +82,9 @@ func getServer(t *testing.T) *server.Server { // Create auth service authConfig.Users = types.Users{ { - Username: user.Username, - Password: user.Password, + Username: user.Username, + Password: user.Password, + TotpSecret: user.TotpSecret, }, } auth := auth.NewAuth(authConfig, docker, nil) @@ -111,7 +114,7 @@ func TestLogin(t *testing.T) { t.Log("Testing login") // Get server - api := getServer(t) + srv := getServer(t) // Create recorder recorder := httptest.NewRecorder() @@ -138,18 +141,21 @@ func TestLogin(t *testing.T) { } // Serve the request - api.Router.ServeHTTP(recorder, req) + srv.Router.ServeHTTP(recorder, req) // Assert assert.Equal(t, recorder.Code, http.StatusOK) - // Get the cookie - cookie = recorder.Result().Cookies()[0].Value + // Get the result cookie + cookies := recorder.Result().Cookies() // Check if the cookie is set - if cookie == "" { + if len(cookies) == 0 { t.Fatalf("Cookie not set") } + + // Set the cookie for further tests + cookie = cookies[0].Value } // Test app context @@ -157,7 +163,7 @@ func TestAppContext(t *testing.T) { t.Log("Testing app context") // Get server - api := getServer(t) + srv := getServer(t) // Create recorder recorder := httptest.NewRecorder() @@ -177,7 +183,7 @@ func TestAppContext(t *testing.T) { }) // Serve the request - api.Router.ServeHTTP(recorder, req) + srv.Router.ServeHTTP(recorder, req) // Assert assert.Equal(t, recorder.Code, http.StatusOK) @@ -208,7 +214,7 @@ func TestAppContext(t *testing.T) { DisableContinue: false, Title: "Tinyauth", GenericName: "Generic", - ForgotPasswordMessage: "Some message", + ForgotPasswordMessage: "Message", BackgroundImage: "https://example.com/image.png", OAuthAutoRedirect: "none", Domain: "localhost", @@ -222,10 +228,13 @@ func TestAppContext(t *testing.T) { // Test user context func TestUserContext(t *testing.T) { + // Refresh the cookie + TestLogin(t) + t.Log("Testing user context") // Get server - api := getServer(t) + srv := getServer(t) // Create recorder recorder := httptest.NewRecorder() @@ -245,7 +254,7 @@ func TestUserContext(t *testing.T) { }) // Serve the request - api.Router.ServeHTTP(recorder, req) + srv.Router.ServeHTTP(recorder, req) // Assert assert.Equal(t, recorder.Code, http.StatusOK) @@ -280,10 +289,13 @@ func TestUserContext(t *testing.T) { // Test logout func TestLogout(t *testing.T) { + // Refresh the cookie + TestLogin(t) + t.Log("Testing logout") // Get server - api := getServer(t) + srv := getServer(t) // Create recorder recorder := httptest.NewRecorder() @@ -298,18 +310,212 @@ func TestLogout(t *testing.T) { // Set the cookie req.AddCookie(&http.Cookie{ - Name: "tinyauth", + Name: "tinyauth-session", Value: cookie, }) // Serve the request - api.Router.ServeHTTP(recorder, req) + srv.Router.ServeHTTP(recorder, req) // Assert assert.Equal(t, recorder.Code, http.StatusOK) - // Check if the cookie is different (means go sessions flushed it) + // Check if the cookie is different (means the cookie is gone) if recorder.Result().Cookies()[0].Value == cookie { t.Fatalf("Cookie not flushed") } } + +// Test auth endpoint +func TestAuth(t *testing.T) { + // Refresh the cookie + TestLogin(t) + + t.Log("Testing auth endpoint") + + // Get server + srv := getServer(t) + + // Create recorder + recorder := httptest.NewRecorder() + + // Create request + req, err := http.NewRequest("GET", "/api/auth/traefik", nil) + + // Set the accept header + req.Header.Set("Accept", "text/html") + + // Check if there was an error + if err != nil { + t.Fatalf("Error creating request: %v", err) + } + + // Serve the request + srv.Router.ServeHTTP(recorder, req) + + // Assert + assert.Equal(t, recorder.Code, http.StatusTemporaryRedirect) + + // Recreate recorder + recorder = httptest.NewRecorder() + + // Recreate the request + req, err = http.NewRequest("GET", "/api/auth/traefik", nil) + + // Check if there was an error + if err != nil { + t.Fatalf("Error creating request: %v", err) + } + + // Test with the cookie + req.AddCookie(&http.Cookie{ + Name: "tinyauth-session", + Value: cookie, + }) + + // Serve the request again + srv.Router.ServeHTTP(recorder, req) + + // Assert + assert.Equal(t, recorder.Code, http.StatusOK) + + // Recreate recorder + recorder = httptest.NewRecorder() + + // Recreate the request + req, err = http.NewRequest("GET", "/api/auth/nginx", nil) + + // Check if there was an error + if err != nil { + t.Fatalf("Error creating request: %v", err) + } + + // Serve the request again + srv.Router.ServeHTTP(recorder, req) + + // Assert + assert.Equal(t, recorder.Code, http.StatusUnauthorized) + + // Recreate recorder + recorder = httptest.NewRecorder() + + // Recreate the request + req, err = http.NewRequest("GET", "/api/auth/nginx", nil) + + // Check if there was an error + if err != nil { + t.Fatalf("Error creating request: %v", err) + } + + // Test with the cookie + req.AddCookie(&http.Cookie{ + Name: "tinyauth-session", + Value: cookie, + }) + + // Serve the request again + srv.Router.ServeHTTP(recorder, req) + + // Assert + assert.Equal(t, recorder.Code, http.StatusOK) +} + +func TestTOTP(t *testing.T) { + t.Log("Testing TOTP") + + // Generate totp secret + key, err := totp.Generate(totp.GenerateOpts{ + Issuer: "Tinyauth", + AccountName: user.Username, + }) + + if err != nil { + t.Fatalf("Failed to generate TOTP secret: %v", err) + } + + // Create secret + secret := key.Secret() + + // Set the user's TOTP secret + user.TotpSecret = secret + + // Get server + srv := getServer(t) + + // Create request + user := types.LoginRequest{ + Username: "user", + Password: "pass", + } + + loginJson, err := json.Marshal(user) + + // Check if there was an error + if err != nil { + t.Fatalf("Error marshalling json: %v", err) + } + + // Create recorder + recorder := httptest.NewRecorder() + + // Create request + req, err := http.NewRequest("POST", "/api/login", strings.NewReader(string(loginJson))) + + // Check if there was an error + if err != nil { + t.Fatalf("Error creating request: %v", err) + } + + // Serve the request + srv.Router.ServeHTTP(recorder, req) + + // Assert + assert.Equal(t, recorder.Code, http.StatusOK) + + // Set the cookie for next test + cookie = recorder.Result().Cookies()[0].Value + + // Create TOTP code + code, err := totp.GenerateCode(secret, time.Now()) + + // Check if there was an error + if err != nil { + t.Fatalf("Failed to generate TOTP code: %v", err) + } + + // Create TOTP request + totpRequest := types.TotpRequest{ + Code: code, + } + + // Marshal the TOTP request + totpJson, err := json.Marshal(totpRequest) + + // Check if there was an error + if err != nil { + t.Fatalf("Error marshalling TOTP request: %v", err) + } + + // Create recorder + recorder = httptest.NewRecorder() + + // Create request + req, err = http.NewRequest("POST", "/api/totp", strings.NewReader(string(totpJson))) + + // Check if there was an error + if err != nil { + t.Fatalf("Error creating request: %v", err) + } + + // Set the cookie + req.AddCookie(&http.Cookie{ + Name: "tinyauth-session", + Value: cookie, + }) + + // Serve the request + srv.Router.ServeHTTP(recorder, req) + + // Assert + assert.Equal(t, recorder.Code, http.StatusOK) +} diff --git a/internal/utils/utils_test.go b/internal/utils/utils_test.go index 552e27f..57423ec 100644 --- a/internal/utils/utils_test.go +++ b/internal/utils/utils_test.go @@ -315,25 +315,6 @@ func TestGetLabels(t *testing.T) { } } -// Test the filter function -func TestFilter(t *testing.T) { - t.Log("Testing filter helper") - - // Create variables - data := []string{"", "val1", "", "val2", "", "val3", ""} - expected := []string{"val1", "val2", "val3"} - - // Test the filter function - result := utils.Filter(data, func(val string) bool { - return val != "" - }) - - // Check if the result is equal to the expected - if !reflect.DeepEqual(expected, result) { - t.Fatalf("Expected %v, got %v", expected, result) - } -} - // Test parse user func TestParseUser(t *testing.T) { t.Log("Testing parse user with a valid user") @@ -474,37 +455,6 @@ func TestCheckWhitelist(t *testing.T) { } } -// Test capitalize -func TestCapitalize(t *testing.T) { - t.Log("Testing capitalize with a valid string") - - // Create variables - str := "test" - expected := "Test" - - // Test the capitalize function - result := utils.Capitalize(str) - - // Check if the result is equal to the expected - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } - - t.Log("Testing capitalize with an empty string") - - // Create variables - str = "" - expected = "" - - // Test the capitalize function - result = utils.Capitalize(str) - - // Check if the result is equal to the expected - if result != expected { - t.Fatalf("Expected %v, got %v", expected, result) - } -} - // Test the header sanitizer func TestSanitizeHeader(t *testing.T) { t.Log("Testing sanitize header with a valid string") @@ -535,3 +485,170 @@ func TestSanitizeHeader(t *testing.T) { t.Fatalf("Expected %v, got %v", expected, result) } } + +// Test the parse headers function +func TestParseHeaders(t *testing.T) { + t.Log("Testing parse headers with a valid string") + + // Create variables + headers := []string{"X-Hea\tder1=value1", "X-Header2=value\n2"} + expected := map[string]string{ + "X-Header1": "value1", + "X-Header2": "value2", + } + + // Test the parse headers function + result := utils.ParseHeaders(headers) + + // Check if the result is equal to the expected + if !reflect.DeepEqual(expected, result) { + t.Fatalf("Expected %v, got %v", expected, result) + } + + t.Log("Testing parse headers with an invalid string") + + // Create variables + headers = []string{"X-Header1=", "X-Header2", "=value", "X-Header3=value3"} + expected = map[string]string{"X-Header3": "value3"} + + // Test the parse headers function + result = utils.ParseHeaders(headers) + + // Check if the result is equal to the expected + if !reflect.DeepEqual(expected, result) { + t.Fatalf("Expected %v, got %v", expected, result) + } +} + +// Test the parse secret file function +func TestParseSecretFile(t *testing.T) { + t.Log("Testing parse secret file with a valid file") + + // Create variables + content := "\n\n \n\n\n secret \n\n \n " + expected := "secret" + + // Test the parse secret file function + result := utils.ParseSecretFile(content) + + // Check if the result is equal to the expected + if result != expected { + t.Fatalf("Expected %v, got %v", expected, result) + } +} + +// Test the filter IP function +func TestFilterIP(t *testing.T) { + t.Log("Testing filter IP with an IP and a valid CIDR") + + // Create variables + ip := "10.10.10.10" + filter := "10.10.10.0/24" + expected := true + + // Test the filter IP function + result, err := utils.FilterIP(filter, ip) + + // Check if there was an error + if err != nil { + t.Fatalf("Error filtering IP: %v", err) + } + + // Check if the result is equal to the expected + if result != expected { + t.Fatalf("Expected %v, got %v", expected, result) + } + + t.Log("Testing filter IP with an IP and a valid IP") + + // Create variables + filter = "10.10.10.10" + expected = true + + // Test the filter IP function + result, err = utils.FilterIP(filter, ip) + + // Check if there was an error + if err != nil { + t.Fatalf("Error filtering IP: %v", err) + } + + // Check if the result is equal to the expected + if result != expected { + t.Fatalf("Expected %v, got %v", expected, result) + } + + t.Log("Testing filter IP with an IP and an non matching CIDR") + + // Create variables + filter = "10.10.15.0/24" + expected = false + + // Test the filter IP function + result, err = utils.FilterIP(filter, ip) + + // Check if there was an error + if err != nil { + t.Fatalf("Error filtering IP: %v", err) + } + + // Check if the result is equal to the expected + if result != expected { + t.Fatalf("Expected %v, got %v", expected, result) + } + + t.Log("Testing filter IP with a non matching IP and a valid CIDR") + + // Create variables + filter = "10.10.10.11" + expected = false + + // Test the filter IP function + result, err = utils.FilterIP(filter, ip) + + // Check if there was an error + if err != nil { + t.Fatalf("Error filtering IP: %v", err) + } + + // Check if the result is equal to the expected + if result != expected { + t.Fatalf("Expected %v, got %v", expected, result) + } + + t.Log("Testing filter IP with an IP and an invalid CIDR") + + // Create variables + filter = "10.../83" + + // Test the filter IP function + _, err = utils.FilterIP(filter, ip) + + // Check if there was an error + if err == nil { + t.Fatalf("Expected error filtering IP") + } +} + +// Test the derive key function +func TestDeriveKey(t *testing.T) { + t.Log("Testing the derive key function") + + // Create variables + master := "master" + info := "info" + expected := "gdrdU/fXzclYjiSXRexEatVgV13qQmKl" + + // Test the derive key function + result, err := utils.DeriveKey(master, info) + + // Check if there was an error + if err != nil { + t.Fatalf("Error deriving key: %v", err) + } + + // Check if the result is equal to the expected + if result != expected { + t.Fatalf("Expected %v, got %v", expected, result) + } +}