From 31d1f06ffaf0f547257eab79b5096ae14964aa8b Mon Sep 17 00:00:00 2001 From: Michael Mayer Date: Sun, 31 Mar 2024 14:45:17 +0200 Subject: [PATCH] Security: Refactor rate limits for failed authentication request #808 Signed-off-by: Michael Mayer --- internal/api/session.go | 19 +++++---- internal/api/session_create.go | 17 +++++++- internal/api/session_oauth.go | 22 ++++++---- internal/api/users_passcode.go | 40 ++++++++++++++--- internal/api/users_password.go | 21 ++++++--- internal/config/config_server.go | 5 ++- internal/entity/auth_session.go | 5 ++- internal/entity/auth_session_login.go | 12 ++---- internal/server/limiter/const.go | 5 +++ internal/server/limiter/limit.go | 33 +++++++++++--- internal/server/limiter/limit_test.go | 43 +++++++++++++++++++ internal/server/limiter/request.go | 50 ++++++++++++++++++++++ internal/server/limiter/request_test.go | 57 +++++++++++++++++++++++++ internal/server/webdav_auth.go | 27 ++++++------ internal/server/webdav_auth_session.go | 18 ++++++++ pkg/authn/errors.go | 23 ++++++---- 16 files changed, 328 insertions(+), 69 deletions(-) create mode 100644 internal/server/limiter/const.go create mode 100644 internal/server/limiter/request.go create mode 100644 internal/server/limiter/request_test.go diff --git a/internal/api/session.go b/internal/api/session.go index 471f1a59c..2b7cf2ec2 100644 --- a/internal/api/session.go +++ b/internal/api/session.go @@ -19,19 +19,22 @@ func Session(clientIp, authToken string) *entity.Session { return nil } - // Fail if authentication error rate limit is exceeded. - if clientIp != "" && limiter.Auth.Reject(clientIp) { + // Check request rate limit. + r := limiter.Auth.Request(clientIp) + + if r.Reject() { return nil } // Find the session based on the hashed auth token, or return nil otherwise. - if s, err := entity.FindSession(rnd.SessionID(authToken)); err != nil { - if clientIp != "" { - limiter.Auth.Reserve(clientIp) - } + s, err := entity.FindSession(rnd.SessionID(authToken)) + if err != nil { return nil - } else { - return s } + + // Return the reserved request rate limit tokens after successful authentication. + r.Success() + + return s } diff --git a/internal/api/session_create.go b/internal/api/session_create.go index fdbe1eebb..2d09623e2 100644 --- a/internal/api/session_create.go +++ b/internal/api/session_create.go @@ -56,8 +56,16 @@ func CreateSession(router *gin.RouterGroup) { return } - // Fail if authentication error rate limit is exceeded. - if clientIp != "" && (limiter.Login.Reject(clientIp) || limiter.Auth.Reject(clientIp)) { + // Check request rate limit. + var r *limiter.Request + if f.Passcode == "" { + r = limiter.Login.Request(clientIp) + } else { + r = limiter.Login.RequestN(clientIp, 3) + } + + // Abort if failure rate limit is exceeded. + if r.Reject() || limiter.Auth.Reject(clientIp) { limiter.AbortJSON(c) return } @@ -81,6 +89,8 @@ func CreateSession(router *gin.RouterGroup) { c.AbortWithStatusJSON(sess.HttpStatus(), gin.H{"error": i18n.Msg(i18n.ErrInvalidCredentials)}) } else if errors.Is(err, authn.ErrPasscodeRequired) { c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": err.Error(), "code": i18n.ErrPasscodeRequired, "message": i18n.Msg(i18n.ErrPasscodeRequired)}) + // Return the reserved request rate limit tokens if password is correct, even if the verification code is missing. + r.Success() } else { c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": err.Error(), "code": i18n.ErrInvalidPasscode, "message": i18n.Msg(i18n.ErrInvalidPasscode)}) } @@ -98,6 +108,9 @@ func CreateSession(router *gin.RouterGroup) { event.AuditInfo([]string{clientIp, "session %s", "updated"}, sess.RefID) } + // Return the reserved request rate limit tokens after successful authentication. + r.Success() + // Response includes user data, session data, and client config values. response := CreateSessionResponse(sess.AuthToken(), sess, conf.ClientSession(sess)) diff --git a/internal/api/session_oauth.go b/internal/api/session_oauth.go index 3907b3f47..fcc0c834e 100644 --- a/internal/api/session_oauth.go +++ b/internal/api/session_oauth.go @@ -35,7 +35,7 @@ func CreateOAuthToken(router *gin.RouterGroup) { if get.Config().Public() { // Abort if running in public mode. - event.AuditErr([]string{clientIp, "client", "create session", "oauth2", "disabled in public mode"}) + event.AuditErr([]string{clientIp, "client", "create session", "oauth2", authn.ErrDisabledInPublicMode.Error()}) AbortForbidden(c) return } @@ -65,8 +65,11 @@ func CreateOAuthToken(router *gin.RouterGroup) { // Disable caching of responses. c.Header(header.CacheControl, header.CacheControlNoStore) - // Fail if authentication error rate limit is exceeded. - if clientIp != "" && (limiter.Login.Reject(clientIp) || limiter.Auth.Reject(clientIp)) { + // Check request rate limit. + r := limiter.Login.Request(clientIp) + + // Abort if request rate limit is exceeded. + if r.Reject() || limiter.Auth.Reject(clientIp) { limiter.AbortJSON(c) return } @@ -76,12 +79,11 @@ func CreateOAuthToken(router *gin.RouterGroup) { // Abort if the client ID or secret are invalid. if client == nil { - event.AuditWarn([]string{clientIp, "client %s", "create session", "oauth2", "invalid client id"}, f.ClientID) + event.AuditWarn([]string{clientIp, "client %s", "create session", "oauth2", authn.ErrInvalidClientID.Error()}, f.ClientID) c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": i18n.Msg(i18n.ErrInvalidCredentials)}) - limiter.Login.Reserve(clientIp) return } else if !client.AuthEnabled { - event.AuditWarn([]string{clientIp, "client %s", "create session", "oauth2", "authentication disabled"}, f.ClientID) + event.AuditWarn([]string{clientIp, "client %s", "create session", "oauth2", authn.ErrAuthenticationDisabled.Error()}, f.ClientID) c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": i18n.Msg(i18n.ErrInvalidCredentials)}) return } else if method := client.Method(); !method.IsDefault() && method != authn.MethodOAuth2 { @@ -89,16 +91,18 @@ func CreateOAuthToken(router *gin.RouterGroup) { c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": i18n.Msg(i18n.ErrInvalidCredentials)}) return } else if client.WrongSecret(f.ClientSecret) { - event.AuditWarn([]string{clientIp, "client %s", "create session", "oauth2", "invalid client secret"}, f.ClientID) + event.AuditWarn([]string{clientIp, "client %s", "create session", "oauth2", authn.ErrInvalidClientSecret.Error()}, f.ClientID) c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": i18n.Msg(i18n.ErrInvalidCredentials)}) - limiter.Login.Reserve(clientIp) return } + // Return the reserved request rate limit tokens after successful authentication. + r.Success() + // Create new client session. sess := client.NewSession(c) - // Try to log in and save session if successful. + // Save new client session. if sess, err = get.Session().Save(sess); err != nil { event.AuditErr([]string{clientIp, "client %s", "create session", "oauth2", "%s"}, f.ClientID, err) c.AbortWithStatusJSON(sess.HttpStatus(), gin.H{"error": i18n.Msg(i18n.ErrInvalidCredentials)}) diff --git a/internal/api/users_passcode.go b/internal/api/users_passcode.go index 865dbb31e..1cfd0d993 100644 --- a/internal/api/users_passcode.go +++ b/internal/api/users_passcode.go @@ -32,13 +32,23 @@ func CreateUserPasscode(router *gin.RouterGroup) { return } + // Check request rate limit. + r := limiter.Login.Request(ClientIP(c)) + + if r.Reject() { + limiter.AbortJSON(c) + return + } + // Check if the account password is correct. if user.WrongPassword(frm.Password) { - limiter.Login.Reserve(ClientIP(c)) Abort(c, http.StatusForbidden, i18n.ErrInvalidPassword) return } + // Return the reserved request rate limit tokens after successful authentication. + r.Success() + // Get config. conf := get.Config() @@ -76,7 +86,15 @@ func ConfirmUserPasscode(router *gin.RouterGroup) { return } - // Verify new passcode. + // Check request rate limit. + r := limiter.Login.RequestN(ClientIP(c), 3) + + if r.Reject() { + limiter.AbortJSON(c) + return + } + + // Verify passcode. valid, passcode, err := user.VerifyPasscode(frm.Passcode) if err != nil { @@ -84,12 +102,14 @@ func ConfirmUserPasscode(router *gin.RouterGroup) { Abort(c, http.StatusForbidden, i18n.ErrInvalidPasscode) return } else if !valid { - event.AuditWarn([]string{ClientIP(c), "session %s", "users", user.UserName, "incorrect passcode"}, s.RefID) - limiter.Login.ReserveN(ClientIP(c), 3) + event.AuditWarn([]string{ClientIP(c), "session %s", "users", user.UserName, authn.ErrInvalidPasscode.Error()}, s.RefID) Abort(c, http.StatusForbidden, i18n.ErrInvalidPasscode) return } + // Return the reserved request rate limit tokens after successful authentication. + r.Success() + event.AuditInfo([]string{ClientIP(c), "session %s", "users", user.UserName, "passcode", "verified"}, s.RefID) // Clear session cache. @@ -147,13 +167,23 @@ func DeactivateUserPasscode(router *gin.RouterGroup) { return } + // Check request rate limit. + r := limiter.Login.Request(ClientIP(c)) + + if r.Reject() { + limiter.AbortJSON(c) + return + } + // Check if the account password is correct. if user.WrongPassword(frm.Password) { - limiter.Login.Reserve(ClientIP(c)) Abort(c, http.StatusForbidden, i18n.ErrInvalidPassword) return } + // Return the reserved request rate limit tokens after successful authentication. + r.Success() + // Delete passcode. if _, err := user.DeactivatePasscode(); err != nil { event.AuditErr([]string{ClientIP(c), "session %s", "users", user.UserName, "failed to deactivate passcode", clean.Error(err)}, s.RefID) diff --git a/internal/api/users_password.go b/internal/api/users_password.go index f25907122..dd43bd53f 100644 --- a/internal/api/users_password.go +++ b/internal/api/users_password.go @@ -29,12 +29,6 @@ func UpdateUserPassword(router *gin.RouterGroup) { return } - // Check limit for failed auth requests (max. 10 per minute). - if limiter.Login.Reject(ClientIP(c)) { - limiter.AbortJSON(c) - return - } - // Get session. s := Auth(c, acl.ResourcePassword, acl.ActionUpdate) @@ -42,6 +36,17 @@ func UpdateUserPassword(router *gin.RouterGroup) { return } + // Get client IP address. + clientIp := ClientIP(c) + + // Check request rate limit. + r := limiter.Login.Request(clientIp) + + if r.Reject() { + limiter.AbortJSON(c) + return + } + // Check if the current user has management privileges. isAdmin := acl.Rules.AllowAll(acl.ResourceUsers, s.UserRole(), acl.Permissions{acl.AccessAll, acl.ActionManage}) isSuperAdmin := isAdmin && s.User().IsSuperAdmin() @@ -73,11 +78,13 @@ func UpdateUserPassword(router *gin.RouterGroup) { if isSuperAdmin && f.OldPassword == "" { // Do nothing. } else if u.WrongPassword(f.OldPassword) { - limiter.Login.Reserve(ClientIP(c)) Abort(c, http.StatusBadRequest, i18n.ErrInvalidPassword) return } + // Return the reserved request rate limit tokens after successful authentication. + r.Success() + // Set new password. if err := u.SetPassword(f.NewPassword); err != nil { Error(c, http.StatusBadRequest, err, i18n.ErrInvalidPassword) diff --git a/internal/config/config_server.go b/internal/config/config_server.go index 257dab3a1..8355bb6c1 100644 --- a/internal/config/config_server.go +++ b/internal/config/config_server.go @@ -5,6 +5,7 @@ import ( "regexp" "strings" + "github.com/photoprism/photoprism/internal/server/limiter" "github.com/photoprism/photoprism/internal/ttl" "github.com/photoprism/photoprism/pkg/fs" "github.com/photoprism/photoprism/pkg/header" @@ -120,9 +121,9 @@ func (c *Config) HttpVideoMaxAge() ttl.Duration { // HttpHost returns the built-in HTTP server host name or IP address (empty for all interfaces). func (c *Config) HttpHost() string { - // when unix socket used as host, make host as default value. or http client will act weirdly. + // Set http host to "0.0.0.0" if unix socket is used to serve requests. if c.options.HttpHost == "" { - return "0.0.0.0" + return limiter.DefaultIP } return c.options.HttpHost diff --git a/internal/entity/auth_session.go b/internal/entity/auth_session.go index 5d7440342..9c53cb72c 100644 --- a/internal/entity/auth_session.go +++ b/internal/entity/auth_session.go @@ -14,6 +14,7 @@ import ( "github.com/photoprism/photoprism/internal/acl" "github.com/photoprism/photoprism/internal/event" + "github.com/photoprism/photoprism/internal/server/limiter" "github.com/photoprism/photoprism/pkg/authn" "github.com/photoprism/photoprism/pkg/clean" "github.com/photoprism/photoprism/pkg/header" @@ -28,7 +29,7 @@ import ( // SessionPrefix for RefID. const ( SessionPrefix = "sess" - UnknownIP = "0.0.0.0" + UnknownIP = limiter.DefaultIP ) // Sessions represents a list of sessions. @@ -976,7 +977,7 @@ func (m *Session) IP() string { if m.ClientIP != "" { return m.ClientIP } else { - return "0.0.0.0" + return UnknownIP } } diff --git a/internal/entity/auth_session_login.go b/internal/entity/auth_session_login.go index dd971f169..a979dfb1d 100644 --- a/internal/entity/auth_session_login.go +++ b/internal/entity/auth_session_login.go @@ -10,7 +10,6 @@ import ( "github.com/photoprism/photoprism/internal/acl" "github.com/photoprism/photoprism/internal/event" "github.com/photoprism/photoprism/internal/form" - "github.com/photoprism/photoprism/internal/server/limiter" "github.com/photoprism/photoprism/pkg/authn" "github.com/photoprism/photoprism/pkg/clean" "github.com/photoprism/photoprism/pkg/header" @@ -84,7 +83,6 @@ func AuthLocal(user *User, f form.Login, m *Session, c *gin.Context) (provider a // Check if user account exists. if user == nil { message := authn.ErrAccountNotFound.Error() - limiter.Login.Reserve(clientIp) if m != nil { event.AuditWarn([]string{clientIp, "session %s", "login as %s", message}, m.RefID, clean.LogQuote(userName)) @@ -122,7 +120,6 @@ func AuthLocal(user *User, f form.Login, m *Session, c *gin.Context) (provider a if authSess, authUser, authErr := AuthSession(f, c); authSess != nil && authUser != nil && authErr == nil { if !authUser.IsRegistered() || authUser.UserUID != user.UserUID { message := authn.ErrInvalidUsername.Error() - limiter.Login.Reserve(clientIp) event.AuditErr([]string{clientIp, "session %s", "login as %s with app password", message}, m.RefID, clean.LogQuote(userName)) event.LoginError(clientIp, "api", userName, m.UserAgent, message) m.Status = http.StatusUnauthorized @@ -134,7 +131,6 @@ func AuthLocal(user *User, f form.Login, m *Session, c *gin.Context) (provider a } else { message = authn.ErrUnauthorized.Error() } - limiter.Login.Reserve(clientIp) event.AuditErr([]string{clientIp, "session %s", "login as %s with app password", message}, m.RefID, clean.LogQuote(userName)) event.LoginError(clientIp, "api", userName, m.UserAgent, message) m.Status = http.StatusUnauthorized @@ -155,7 +151,6 @@ func AuthLocal(user *User, f form.Login, m *Session, c *gin.Context) (provider a // Otherwise, check account password. if user.WrongPassword(f.Password) { message := authn.ErrInvalidPassword.Error() - limiter.Login.Reserve(clientIp) if m != nil { event.AuditErr([]string{clientIp, "session %s", "login as %s", message}, m.RefID, clean.LogQuote(userName)) @@ -171,10 +166,8 @@ func AuthLocal(user *User, f form.Login, m *Session, c *gin.Context) (provider a // Perform two-factor authentication check, if required. if method = user.Method(); method.Is(authn.Method2FA) { if valid, _, passcodeErr := user.VerifyPasscode(f.Passcode); passcodeErr != nil { - limiter.Login.Reserve(clientIp) return provider, method, passcodeErr } else if !valid { - limiter.Login.ReserveN(clientIp, 3) return provider, method, authn.ErrInvalidPasscode } } else if method == authn.MethodUndefined { @@ -195,6 +188,9 @@ func (m *Session) LogIn(f form.Login, c *gin.Context) (err error) { m.SetContext(c) } + // r := limiter.Login.Reserve(m.IP()) + // r.Cancel() + var user *User var provider authn.ProviderType var method authn.MethodType @@ -225,7 +221,6 @@ func (m *Session) LogIn(f form.Login, c *gin.Context) (err error) { if user.IsRegistered() { if shares := user.RedeemToken(f.ShareToken); shares == 0 { message := authn.ErrInvalidShareToken.Error() - limiter.Login.Reserve(m.IP()) event.AuditWarn([]string{m.IP(), "session %s", message}, m.RefID) m.Status = http.StatusNotFound return i18n.Error(i18n.ErrInvalidLink) @@ -237,7 +232,6 @@ func (m *Session) LogIn(f form.Login, c *gin.Context) (err error) { return i18n.Error(i18n.ErrUnexpected) } else if shares := data.RedeemToken(f.ShareToken); shares == 0 { message := authn.ErrInvalidShareToken.Error() - limiter.Login.Reserve(m.IP()) event.AuditWarn([]string{m.IP(), "session %s", message}, m.RefID) event.LoginError(m.IP(), "api", "", m.UserAgent, message) m.Status = http.StatusNotFound diff --git a/internal/server/limiter/const.go b/internal/server/limiter/const.go new file mode 100644 index 000000000..5f85150b4 --- /dev/null +++ b/internal/server/limiter/const.go @@ -0,0 +1,5 @@ +package limiter + +const ( + DefaultIP = "0.0.0.0" +) diff --git a/internal/server/limiter/limit.go b/internal/server/limiter/limit.go index c13a37193..782fad3a5 100644 --- a/internal/server/limiter/limit.go +++ b/internal/server/limiter/limit.go @@ -29,6 +29,10 @@ func NewLimit(r rate.Limit, b int) *Limit { // AddIP adds a new rate limiter for the specified IP address. func (i *Limit) AddIP(ip string) *rate.Limiter { + if ip == "" { + ip = DefaultIP + } + i.mu.Lock() defer i.mu.Unlock() @@ -41,24 +45,43 @@ func (i *Limit) AddIP(ip string) *rate.Limiter { // IP returns the rate limiter for the specified IP address. func (i *Limit) IP(ip string) *rate.Limiter { - i.mu.Lock() + if ip == "" { + ip = DefaultIP + } + + i.mu.RLock() limiter, exists := i.limiters[ip] if !exists { - i.mu.Unlock() + i.mu.RUnlock() return i.AddIP(ip) } - i.mu.Unlock() + i.mu.RUnlock() return limiter } -// Allow reports whether the request is allowed at this time and increments the request counter. +// Allow checks if a new request is allowed at this time and increments the request counter by 1. func (i *Limit) Allow(ip string) bool { return i.IP(ip).Allow() } +// AllowN checks if a new request is allowed at this time and increments the request counter by n. +func (i *Limit) AllowN(ip string, n int) bool { + return i.IP(ip).AllowN(time.Now(), n) +} + +// Request tries to increment the request counter and returns the result as new *Request. +func (i *Limit) Request(ip string) *Request { + return NewRequest(i.IP(ip), 1) +} + +// RequestN tries to increment the request counter by n and returns the result as new *Request. +func (i *Limit) RequestN(ip string, n int) *Request { + return NewRequest(i.IP(ip), n) +} + // Reserve increments the request counter and returns a rate.Reservation. func (i *Limit) Reserve(ip string) *rate.Reservation { return i.IP(ip).Reserve() @@ -69,7 +92,7 @@ func (i *Limit) ReserveN(ip string, n int) *rate.Reservation { return i.IP(ip).ReserveN(time.Now(), n) } -// Reject reports whether the request limit has been exceeded, but does not change the request counter. +// Reject checks if the request rate limit has been exceeded, but does not modify the counter. func (i *Limit) Reject(ip string) bool { return i.IP(ip).Tokens() < 1 } diff --git a/internal/server/limiter/limit_test.go b/internal/server/limiter/limit_test.go index 5e5cd154a..72f97ee3f 100644 --- a/internal/server/limiter/limit_test.go +++ b/internal/server/limiter/limit_test.go @@ -73,4 +73,47 @@ func TestNewLimit(t *testing.T) { assert.True(t, l.Reject(clientIp)) } }) + t.Run("Request", func(t *testing.T) { + // 10 per minute. + l := NewLimit(0.166, 10) + + // Request counter not increased. + for i := 0; i < 20; i++ { + assert.False(t, l.Reject(clientIp)) + } + + // Request not exceeded and tokens returned by calling Success(). + for i := 1; i <= 20; i++ { + reject := l.Reject(clientIp) + r := l.Request(clientIp) + allow := r.Allow() + r.Success() + t.Logf("(1.%d) Reject: %t, Allow: %t, Tokens: %d", i, reject, allow, r.Tokens) + assert.False(t, reject) + assert.True(t, allow) + assert.False(t, r.Reject()) + } + + // Limit not exceeded, but tokens not returned. + for i := 1; i <= 10; i++ { + reject := l.Reject(clientIp) + r := l.Request(clientIp) + allow := r.Allow() + t.Logf("(2.%d) Reject: %t, Allow: %t, Tokens: %d", i, reject, allow, r.Tokens) + assert.False(t, reject) + assert.True(t, allow) + assert.False(t, r.Reject()) + } + + // Limit exceeded and tokens not returned. + for i := 1; i <= 20; i++ { + reject := l.Reject(clientIp) + r := l.Request(clientIp) + allow := r.Allow() + t.Logf("(3.%d) Reject: %t, Allow: %t, Tokens: %d", i, reject, allow, r.Tokens) + assert.True(t, reject) + assert.False(t, allow) + assert.True(t, r.Reject()) + } + }) } diff --git a/internal/server/limiter/request.go b/internal/server/limiter/request.go new file mode 100644 index 000000000..9af3dd6e6 --- /dev/null +++ b/internal/server/limiter/request.go @@ -0,0 +1,50 @@ +package limiter + +import ( + "time" + + "golang.org/x/time/rate" +) + +// Request represents a request for the specified number of limiter tokens. +type Request struct { + allow bool + limiter *rate.Limiter + Tokens int +} + +// NewRequest checks if a request is allowed, reserves the required tokens, +// and returns a new Request to revert the reservation if successful. +func NewRequest(l *rate.Limiter, n int) *Request { + if l.AllowN(time.Now(), n) { + return &Request{ + allow: true, + limiter: l, + Tokens: n, + } + } else { + return &Request{ + allow: false, + limiter: l, + Tokens: 0, + } + } +} + +// Allow checks if the request is allowed. +func (r *Request) Allow() bool { + return r.allow +} + +// Reject returns true if the request should be rejected. +func (r *Request) Reject() bool { + return !r.allow +} + +// Success returns the rate limit tokens that have been reserved for this request, if any. +func (r *Request) Success() { + if r.Tokens != 0 && r.limiter != nil { + r.limiter.ReserveN(time.Now(), -1*r.Tokens) + r.Tokens = 0 + } +} diff --git a/internal/server/limiter/request_test.go b/internal/server/limiter/request_test.go new file mode 100644 index 000000000..2e5083463 --- /dev/null +++ b/internal/server/limiter/request_test.go @@ -0,0 +1,57 @@ +package limiter + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNewRequest(t *testing.T) { + clientIp := "192.0.2.1" + l := NewLimit(0.166, 10).IP(clientIp) + r := NewRequest(l, 9) + assert.True(t, r.Allow()) + assert.False(t, r.Reject()) + assert.Equal(t, 9, r.Tokens) + r = NewRequest(l, 1) + assert.True(t, r.Allow()) + assert.False(t, r.Reject()) + assert.Equal(t, 1, r.Tokens) + r = NewRequest(l, 1) + assert.False(t, r.Allow()) + assert.True(t, r.Reject()) + assert.Equal(t, 0, r.Tokens) +} + +func TestRequest(t *testing.T) { + clientIp := "192.0.2.1" + + t.Run("Allow", func(t *testing.T) { + r := Request{allow: true} + assert.True(t, r.Allow()) + assert.False(t, r.Reject()) + }) + + t.Run("Reject", func(t *testing.T) { + r := Request{allow: false} + assert.False(t, r.Allow()) + assert.True(t, r.Reject()) + }) + + t.Run("Success", func(t *testing.T) { + l := NewLimit(0.166, 10).IP(clientIp) + r1 := NewRequest(l, 10) + assert.True(t, r1.Allow()) + assert.False(t, r1.Reject()) + assert.Equal(t, 10, r1.Tokens) + r2 := NewRequest(l, 10) + assert.False(t, r2.Allow()) + assert.True(t, r2.Reject()) + assert.Equal(t, 0, r2.Tokens) + r1.Success() + r3 := NewRequest(l, 10) + assert.True(t, r3.Allow()) + assert.False(t, r3.Reject()) + assert.Equal(t, 10, r3.Tokens) + }) +} diff --git a/internal/server/webdav_auth.go b/internal/server/webdav_auth.go index f7dda9a93..5157499c9 100644 --- a/internal/server/webdav_auth.go +++ b/internal/server/webdav_auth.go @@ -93,11 +93,7 @@ func WebDAVAuth(conf *config.Config) gin.HandlerFunc { } // Check webdav access authorization using an auth token or app password, if provided. - if limiter.Auth.Reject(clientIp) { - c.Header("WWW-Authenticate", BasicAuthRealm) - limiter.Abort(c) - return - } else if sess, user, sid, cached := WebDAVAuthSession(c, authToken); user != nil && cached { + if sess, user, sid, cached := WebDAVAuthSession(c, authToken); user != nil && cached { // Add user to request context to signal successful authentication if username is empty or matches. if username == "" || strings.EqualFold(clean.Username(username), user.Username()) { c.Set(gin.AuthUserKey, user) @@ -105,7 +101,6 @@ func WebDAVAuth(conf *config.Config) gin.HandlerFunc { } event.AuditErr([]string{clientIp, "access webdav as %s with authorization granted to %s", authn.Denied}, clean.Log(username), clean.Log(user.Username())) - limiter.Auth.Reserve(clientIp) WebDAVAbortUnauthorized(c) return } else if sess == nil { @@ -131,7 +126,6 @@ func WebDAVAuth(conf *config.Config) gin.HandlerFunc { // Log warning if WebDAV is disabled for this account. message := authn.ErrBasicAuthDoesNotMatch.Error() event.AuditWarn([]string{clientIp, "client %s", "session %s", "access webdav as %s", message}, clean.Log(sess.ClientInfo()), sess.RefID, clean.LogQuote(user.Username())) - limiter.Auth.Reserve(clientIp) WebDAVAbortUnauthorized(c) return } else if err := fs.MkdirAll(filepath.Join(conf.OriginalsPath(), user.GetUploadPath())); err != nil { @@ -158,9 +152,11 @@ func WebDAVAuth(conf *config.Config) gin.HandlerFunc { return } - // Check the authentication request rate to block the client after - // too many failed attempts (10/req per minute by default). - if limiter.Login.Reject(clientIp) { + // Check request rate limit. + r := limiter.Login.Request(clientIp) + + // Abort if request rate limit is exceeded. + if r.Reject() || limiter.Auth.Reject(clientIp) { c.Header("WWW-Authenticate", BasicAuthRealm) limiter.Abort(c) return @@ -179,21 +175,25 @@ func WebDAVAuth(conf *config.Config) gin.HandlerFunc { if user, _, _, err := entity.Auth(f, nil, c); err != nil { // Abort if authentication has failed. message := authn.ErrInvalidCredentials.Error() - limiter.Login.Reserve(clientIp) event.AuditErr([]string{clientIp, "webdav login as %s", message}, clean.LogQuote(username)) event.LoginError(clientIp, "webdav", username, api.UserAgent(c), message) } else if user == nil { // Abort if account was not found. message := authn.ErrAccountNotFound.Error() - limiter.Login.Reserve(clientIp) event.AuditErr([]string{clientIp, "webdav login as %s", message}, clean.LogQuote(username)) event.LoginError(clientIp, "webdav", username, api.UserAgent(c), message) } else if !user.CanUseWebDAV() { + // Return the reserved request rate limit tokens, even if account isn't allowed to use WebDAV. + r.Success() + // Abort if WebDAV is disabled for this account. message := authn.ErrWebDAVAccessDisabled.Error() event.AuditWarn([]string{clientIp, "webdav login as %s", message}, clean.LogQuote(username)) event.LoginError(clientIp, "webdav", username, api.UserAgent(c), message) } else if err = fs.MkdirAll(filepath.Join(conf.OriginalsPath(), user.GetUploadPath())); err != nil { + // Return the reserved request rate limit tokens, even if path could not be created. + r.Success() + // Abort if upload path could not be created. message := authn.ErrFailedToCreateUploadPath.Error() event.AuditWarn([]string{clientIp, "webdav login as %s", message}, clean.LogQuote(username)) @@ -201,6 +201,9 @@ func WebDAVAuth(conf *config.Config) gin.HandlerFunc { WebDAVAbortServerError(c) return } else { + // Return the reserved request rate limit tokens after successful authentication. + r.Success() + // Log successful authentication. event.AuditInfo([]string{clientIp, "webdav login as %s", "succeeded"}, clean.LogQuote(username)) event.LoginInfo(clientIp, "webdav", username, api.UserAgent(c)) diff --git a/internal/server/webdav_auth_session.go b/internal/server/webdav_auth_session.go index b044fdacc..be9864a7d 100644 --- a/internal/server/webdav_auth_session.go +++ b/internal/server/webdav_auth_session.go @@ -5,6 +5,7 @@ import ( "github.com/photoprism/photoprism/internal/entity" "github.com/photoprism/photoprism/internal/event" + "github.com/photoprism/photoprism/internal/server/limiter" "github.com/photoprism/photoprism/pkg/header" "github.com/photoprism/photoprism/pkg/rnd" ) @@ -19,11 +20,25 @@ func WebDAVAuthSession(c *gin.Context, authToken string) (sess *entity.Session, return nil, nil, "", false } + // Get client IP address. + clientIp := header.ClientIP(c) + + // Check request rate limit. + r := limiter.Auth.Request(clientIp) + + // Abort if failure rate limit is exceeded. + if r.Reject() { + return nil, nil, "", false + } + // Get session ID for the auth token provided. sid = rnd.SessionID(authToken) // Check if client authorization has been cached to improve performance. if cacheData, found := webdavAuthCache.Get(sid); found && cacheData != nil { + // Return the reserved request rate limit tokens after successful authentication. + r.Success() + // Add cached user information to the request context. user = cacheData.(*entity.User) return nil, user, sid, true @@ -40,6 +55,9 @@ func WebDAVAuthSession(c *gin.Context, authToken string) (sess *entity.Session, return nil, nil, sid, false } + // Return the reserved request rate limit tokens after successful authentication. + r.Success() + // Update the client IP and the user agent from // the request context if they have changed. sess.UpdateContext(c) diff --git a/pkg/authn/errors.go b/pkg/authn/errors.go index 5f024c479..3231c2a40 100644 --- a/pkg/authn/errors.go +++ b/pkg/authn/errors.go @@ -9,14 +9,21 @@ import ( // Generic error messages for authentication and authorization: var ( - ErrUnauthorized = errors.New("unauthorized") - ErrAccountAlreadyExists = errors.New("account already exists") - ErrAccountNotFound = errors.New("account not found") - ErrAccountDisabled = errors.New("account disabled") - ErrInvalidCredentials = errors.New("invalid credentials") - ErrInvalidShareToken = errors.New("invalid share token") - ErrInsufficientScope = errors.New("insufficient scope") - ErrDisabledInPublicMode = errors.New("disabled in public mode") + ErrUnauthorized = errors.New("unauthorized") + ErrAccountAlreadyExists = errors.New("account already exists") + ErrAccountNotFound = errors.New("account not found") + ErrAccountDisabled = errors.New("account disabled") + ErrInvalidCredentials = errors.New("invalid credentials") + ErrInvalidShareToken = errors.New("invalid share token") + ErrInsufficientScope = errors.New("insufficient scope") + ErrDisabledInPublicMode = errors.New("disabled in public mode") + ErrAuthenticationDisabled = errors.New("authentication disabled") +) + +// OAuth2-related error messages: +var ( + ErrInvalidClientID = errors.New("invalid client id") + ErrInvalidClientSecret = errors.New("invalid client secret") ) // Username-related error messages: