Security: Refactor rate limits for failed authentication request #808

Signed-off-by: Michael Mayer <michael@photoprism.app>
This commit is contained in:
Michael Mayer
2024-03-31 14:45:17 +02:00
parent 7336304828
commit 31d1f06ffa
16 changed files with 328 additions and 69 deletions

View File

@@ -19,19 +19,22 @@ func Session(clientIp, authToken string) *entity.Session {
return nil return nil
} }
// Fail if authentication error rate limit is exceeded. // Check request rate limit.
if clientIp != "" && limiter.Auth.Reject(clientIp) { r := limiter.Auth.Request(clientIp)
if r.Reject() {
return nil return nil
} }
// Find the session based on the hashed auth token, or return nil otherwise. // Find the session based on the hashed auth token, or return nil otherwise.
if s, err := entity.FindSession(rnd.SessionID(authToken)); err != nil { s, err := entity.FindSession(rnd.SessionID(authToken))
if clientIp != "" {
limiter.Auth.Reserve(clientIp)
}
if err != nil {
return nil return nil
} else {
return s
} }
// Return the reserved request rate limit tokens after successful authentication.
r.Success()
return s
} }

View File

@@ -56,8 +56,16 @@ func CreateSession(router *gin.RouterGroup) {
return return
} }
// Fail if authentication error rate limit is exceeded. // Check request rate limit.
if clientIp != "" && (limiter.Login.Reject(clientIp) || limiter.Auth.Reject(clientIp)) { var r *limiter.Request
if f.Passcode == "" {
r = limiter.Login.Request(clientIp)
} else {
r = limiter.Login.RequestN(clientIp, 3)
}
// Abort if failure rate limit is exceeded.
if r.Reject() || limiter.Auth.Reject(clientIp) {
limiter.AbortJSON(c) limiter.AbortJSON(c)
return return
} }
@@ -81,6 +89,8 @@ func CreateSession(router *gin.RouterGroup) {
c.AbortWithStatusJSON(sess.HttpStatus(), gin.H{"error": i18n.Msg(i18n.ErrInvalidCredentials)}) c.AbortWithStatusJSON(sess.HttpStatus(), gin.H{"error": i18n.Msg(i18n.ErrInvalidCredentials)})
} else if errors.Is(err, authn.ErrPasscodeRequired) { } else if errors.Is(err, authn.ErrPasscodeRequired) {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": err.Error(), "code": i18n.ErrPasscodeRequired, "message": i18n.Msg(i18n.ErrPasscodeRequired)}) c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": err.Error(), "code": i18n.ErrPasscodeRequired, "message": i18n.Msg(i18n.ErrPasscodeRequired)})
// Return the reserved request rate limit tokens if password is correct, even if the verification code is missing.
r.Success()
} else { } else {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": err.Error(), "code": i18n.ErrInvalidPasscode, "message": i18n.Msg(i18n.ErrInvalidPasscode)}) c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": err.Error(), "code": i18n.ErrInvalidPasscode, "message": i18n.Msg(i18n.ErrInvalidPasscode)})
} }
@@ -98,6 +108,9 @@ func CreateSession(router *gin.RouterGroup) {
event.AuditInfo([]string{clientIp, "session %s", "updated"}, sess.RefID) event.AuditInfo([]string{clientIp, "session %s", "updated"}, sess.RefID)
} }
// Return the reserved request rate limit tokens after successful authentication.
r.Success()
// Response includes user data, session data, and client config values. // Response includes user data, session data, and client config values.
response := CreateSessionResponse(sess.AuthToken(), sess, conf.ClientSession(sess)) response := CreateSessionResponse(sess.AuthToken(), sess, conf.ClientSession(sess))

View File

@@ -35,7 +35,7 @@ func CreateOAuthToken(router *gin.RouterGroup) {
if get.Config().Public() { if get.Config().Public() {
// Abort if running in public mode. // Abort if running in public mode.
event.AuditErr([]string{clientIp, "client", "create session", "oauth2", "disabled in public mode"}) event.AuditErr([]string{clientIp, "client", "create session", "oauth2", authn.ErrDisabledInPublicMode.Error()})
AbortForbidden(c) AbortForbidden(c)
return return
} }
@@ -65,8 +65,11 @@ func CreateOAuthToken(router *gin.RouterGroup) {
// Disable caching of responses. // Disable caching of responses.
c.Header(header.CacheControl, header.CacheControlNoStore) c.Header(header.CacheControl, header.CacheControlNoStore)
// Fail if authentication error rate limit is exceeded. // Check request rate limit.
if clientIp != "" && (limiter.Login.Reject(clientIp) || limiter.Auth.Reject(clientIp)) { r := limiter.Login.Request(clientIp)
// Abort if request rate limit is exceeded.
if r.Reject() || limiter.Auth.Reject(clientIp) {
limiter.AbortJSON(c) limiter.AbortJSON(c)
return return
} }
@@ -76,12 +79,11 @@ func CreateOAuthToken(router *gin.RouterGroup) {
// Abort if the client ID or secret are invalid. // Abort if the client ID or secret are invalid.
if client == nil { if client == nil {
event.AuditWarn([]string{clientIp, "client %s", "create session", "oauth2", "invalid client id"}, f.ClientID) event.AuditWarn([]string{clientIp, "client %s", "create session", "oauth2", authn.ErrInvalidClientID.Error()}, f.ClientID)
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": i18n.Msg(i18n.ErrInvalidCredentials)}) c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": i18n.Msg(i18n.ErrInvalidCredentials)})
limiter.Login.Reserve(clientIp)
return return
} else if !client.AuthEnabled { } else if !client.AuthEnabled {
event.AuditWarn([]string{clientIp, "client %s", "create session", "oauth2", "authentication disabled"}, f.ClientID) event.AuditWarn([]string{clientIp, "client %s", "create session", "oauth2", authn.ErrAuthenticationDisabled.Error()}, f.ClientID)
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": i18n.Msg(i18n.ErrInvalidCredentials)}) c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": i18n.Msg(i18n.ErrInvalidCredentials)})
return return
} else if method := client.Method(); !method.IsDefault() && method != authn.MethodOAuth2 { } else if method := client.Method(); !method.IsDefault() && method != authn.MethodOAuth2 {
@@ -89,16 +91,18 @@ func CreateOAuthToken(router *gin.RouterGroup) {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": i18n.Msg(i18n.ErrInvalidCredentials)}) c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": i18n.Msg(i18n.ErrInvalidCredentials)})
return return
} else if client.WrongSecret(f.ClientSecret) { } else if client.WrongSecret(f.ClientSecret) {
event.AuditWarn([]string{clientIp, "client %s", "create session", "oauth2", "invalid client secret"}, f.ClientID) event.AuditWarn([]string{clientIp, "client %s", "create session", "oauth2", authn.ErrInvalidClientSecret.Error()}, f.ClientID)
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": i18n.Msg(i18n.ErrInvalidCredentials)}) c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": i18n.Msg(i18n.ErrInvalidCredentials)})
limiter.Login.Reserve(clientIp)
return return
} }
// Return the reserved request rate limit tokens after successful authentication.
r.Success()
// Create new client session. // Create new client session.
sess := client.NewSession(c) sess := client.NewSession(c)
// Try to log in and save session if successful. // Save new client session.
if sess, err = get.Session().Save(sess); err != nil { if sess, err = get.Session().Save(sess); err != nil {
event.AuditErr([]string{clientIp, "client %s", "create session", "oauth2", "%s"}, f.ClientID, err) event.AuditErr([]string{clientIp, "client %s", "create session", "oauth2", "%s"}, f.ClientID, err)
c.AbortWithStatusJSON(sess.HttpStatus(), gin.H{"error": i18n.Msg(i18n.ErrInvalidCredentials)}) c.AbortWithStatusJSON(sess.HttpStatus(), gin.H{"error": i18n.Msg(i18n.ErrInvalidCredentials)})

View File

@@ -32,13 +32,23 @@ func CreateUserPasscode(router *gin.RouterGroup) {
return return
} }
// Check request rate limit.
r := limiter.Login.Request(ClientIP(c))
if r.Reject() {
limiter.AbortJSON(c)
return
}
// Check if the account password is correct. // Check if the account password is correct.
if user.WrongPassword(frm.Password) { if user.WrongPassword(frm.Password) {
limiter.Login.Reserve(ClientIP(c))
Abort(c, http.StatusForbidden, i18n.ErrInvalidPassword) Abort(c, http.StatusForbidden, i18n.ErrInvalidPassword)
return return
} }
// Return the reserved request rate limit tokens after successful authentication.
r.Success()
// Get config. // Get config.
conf := get.Config() conf := get.Config()
@@ -76,7 +86,15 @@ func ConfirmUserPasscode(router *gin.RouterGroup) {
return return
} }
// Verify new passcode. // Check request rate limit.
r := limiter.Login.RequestN(ClientIP(c), 3)
if r.Reject() {
limiter.AbortJSON(c)
return
}
// Verify passcode.
valid, passcode, err := user.VerifyPasscode(frm.Passcode) valid, passcode, err := user.VerifyPasscode(frm.Passcode)
if err != nil { if err != nil {
@@ -84,12 +102,14 @@ func ConfirmUserPasscode(router *gin.RouterGroup) {
Abort(c, http.StatusForbidden, i18n.ErrInvalidPasscode) Abort(c, http.StatusForbidden, i18n.ErrInvalidPasscode)
return return
} else if !valid { } else if !valid {
event.AuditWarn([]string{ClientIP(c), "session %s", "users", user.UserName, "incorrect passcode"}, s.RefID) event.AuditWarn([]string{ClientIP(c), "session %s", "users", user.UserName, authn.ErrInvalidPasscode.Error()}, s.RefID)
limiter.Login.ReserveN(ClientIP(c), 3)
Abort(c, http.StatusForbidden, i18n.ErrInvalidPasscode) Abort(c, http.StatusForbidden, i18n.ErrInvalidPasscode)
return return
} }
// Return the reserved request rate limit tokens after successful authentication.
r.Success()
event.AuditInfo([]string{ClientIP(c), "session %s", "users", user.UserName, "passcode", "verified"}, s.RefID) event.AuditInfo([]string{ClientIP(c), "session %s", "users", user.UserName, "passcode", "verified"}, s.RefID)
// Clear session cache. // Clear session cache.
@@ -147,13 +167,23 @@ func DeactivateUserPasscode(router *gin.RouterGroup) {
return return
} }
// Check request rate limit.
r := limiter.Login.Request(ClientIP(c))
if r.Reject() {
limiter.AbortJSON(c)
return
}
// Check if the account password is correct. // Check if the account password is correct.
if user.WrongPassword(frm.Password) { if user.WrongPassword(frm.Password) {
limiter.Login.Reserve(ClientIP(c))
Abort(c, http.StatusForbidden, i18n.ErrInvalidPassword) Abort(c, http.StatusForbidden, i18n.ErrInvalidPassword)
return return
} }
// Return the reserved request rate limit tokens after successful authentication.
r.Success()
// Delete passcode. // Delete passcode.
if _, err := user.DeactivatePasscode(); err != nil { if _, err := user.DeactivatePasscode(); err != nil {
event.AuditErr([]string{ClientIP(c), "session %s", "users", user.UserName, "failed to deactivate passcode", clean.Error(err)}, s.RefID) event.AuditErr([]string{ClientIP(c), "session %s", "users", user.UserName, "failed to deactivate passcode", clean.Error(err)}, s.RefID)

View File

@@ -29,12 +29,6 @@ func UpdateUserPassword(router *gin.RouterGroup) {
return return
} }
// Check limit for failed auth requests (max. 10 per minute).
if limiter.Login.Reject(ClientIP(c)) {
limiter.AbortJSON(c)
return
}
// Get session. // Get session.
s := Auth(c, acl.ResourcePassword, acl.ActionUpdate) s := Auth(c, acl.ResourcePassword, acl.ActionUpdate)
@@ -42,6 +36,17 @@ func UpdateUserPassword(router *gin.RouterGroup) {
return return
} }
// Get client IP address.
clientIp := ClientIP(c)
// Check request rate limit.
r := limiter.Login.Request(clientIp)
if r.Reject() {
limiter.AbortJSON(c)
return
}
// Check if the current user has management privileges. // Check if the current user has management privileges.
isAdmin := acl.Rules.AllowAll(acl.ResourceUsers, s.UserRole(), acl.Permissions{acl.AccessAll, acl.ActionManage}) isAdmin := acl.Rules.AllowAll(acl.ResourceUsers, s.UserRole(), acl.Permissions{acl.AccessAll, acl.ActionManage})
isSuperAdmin := isAdmin && s.User().IsSuperAdmin() isSuperAdmin := isAdmin && s.User().IsSuperAdmin()
@@ -73,11 +78,13 @@ func UpdateUserPassword(router *gin.RouterGroup) {
if isSuperAdmin && f.OldPassword == "" { if isSuperAdmin && f.OldPassword == "" {
// Do nothing. // Do nothing.
} else if u.WrongPassword(f.OldPassword) { } else if u.WrongPassword(f.OldPassword) {
limiter.Login.Reserve(ClientIP(c))
Abort(c, http.StatusBadRequest, i18n.ErrInvalidPassword) Abort(c, http.StatusBadRequest, i18n.ErrInvalidPassword)
return return
} }
// Return the reserved request rate limit tokens after successful authentication.
r.Success()
// Set new password. // Set new password.
if err := u.SetPassword(f.NewPassword); err != nil { if err := u.SetPassword(f.NewPassword); err != nil {
Error(c, http.StatusBadRequest, err, i18n.ErrInvalidPassword) Error(c, http.StatusBadRequest, err, i18n.ErrInvalidPassword)

View File

@@ -5,6 +5,7 @@ import (
"regexp" "regexp"
"strings" "strings"
"github.com/photoprism/photoprism/internal/server/limiter"
"github.com/photoprism/photoprism/internal/ttl" "github.com/photoprism/photoprism/internal/ttl"
"github.com/photoprism/photoprism/pkg/fs" "github.com/photoprism/photoprism/pkg/fs"
"github.com/photoprism/photoprism/pkg/header" "github.com/photoprism/photoprism/pkg/header"
@@ -120,9 +121,9 @@ func (c *Config) HttpVideoMaxAge() ttl.Duration {
// HttpHost returns the built-in HTTP server host name or IP address (empty for all interfaces). // HttpHost returns the built-in HTTP server host name or IP address (empty for all interfaces).
func (c *Config) HttpHost() string { func (c *Config) HttpHost() string {
// when unix socket used as host, make host as default value. or http client will act weirdly. // Set http host to "0.0.0.0" if unix socket is used to serve requests.
if c.options.HttpHost == "" { if c.options.HttpHost == "" {
return "0.0.0.0" return limiter.DefaultIP
} }
return c.options.HttpHost return c.options.HttpHost

View File

@@ -14,6 +14,7 @@ import (
"github.com/photoprism/photoprism/internal/acl" "github.com/photoprism/photoprism/internal/acl"
"github.com/photoprism/photoprism/internal/event" "github.com/photoprism/photoprism/internal/event"
"github.com/photoprism/photoprism/internal/server/limiter"
"github.com/photoprism/photoprism/pkg/authn" "github.com/photoprism/photoprism/pkg/authn"
"github.com/photoprism/photoprism/pkg/clean" "github.com/photoprism/photoprism/pkg/clean"
"github.com/photoprism/photoprism/pkg/header" "github.com/photoprism/photoprism/pkg/header"
@@ -28,7 +29,7 @@ import (
// SessionPrefix for RefID. // SessionPrefix for RefID.
const ( const (
SessionPrefix = "sess" SessionPrefix = "sess"
UnknownIP = "0.0.0.0" UnknownIP = limiter.DefaultIP
) )
// Sessions represents a list of sessions. // Sessions represents a list of sessions.
@@ -976,7 +977,7 @@ func (m *Session) IP() string {
if m.ClientIP != "" { if m.ClientIP != "" {
return m.ClientIP return m.ClientIP
} else { } else {
return "0.0.0.0" return UnknownIP
} }
} }

View File

@@ -10,7 +10,6 @@ import (
"github.com/photoprism/photoprism/internal/acl" "github.com/photoprism/photoprism/internal/acl"
"github.com/photoprism/photoprism/internal/event" "github.com/photoprism/photoprism/internal/event"
"github.com/photoprism/photoprism/internal/form" "github.com/photoprism/photoprism/internal/form"
"github.com/photoprism/photoprism/internal/server/limiter"
"github.com/photoprism/photoprism/pkg/authn" "github.com/photoprism/photoprism/pkg/authn"
"github.com/photoprism/photoprism/pkg/clean" "github.com/photoprism/photoprism/pkg/clean"
"github.com/photoprism/photoprism/pkg/header" "github.com/photoprism/photoprism/pkg/header"
@@ -84,7 +83,6 @@ func AuthLocal(user *User, f form.Login, m *Session, c *gin.Context) (provider a
// Check if user account exists. // Check if user account exists.
if user == nil { if user == nil {
message := authn.ErrAccountNotFound.Error() message := authn.ErrAccountNotFound.Error()
limiter.Login.Reserve(clientIp)
if m != nil { if m != nil {
event.AuditWarn([]string{clientIp, "session %s", "login as %s", message}, m.RefID, clean.LogQuote(userName)) event.AuditWarn([]string{clientIp, "session %s", "login as %s", message}, m.RefID, clean.LogQuote(userName))
@@ -122,7 +120,6 @@ func AuthLocal(user *User, f form.Login, m *Session, c *gin.Context) (provider a
if authSess, authUser, authErr := AuthSession(f, c); authSess != nil && authUser != nil && authErr == nil { if authSess, authUser, authErr := AuthSession(f, c); authSess != nil && authUser != nil && authErr == nil {
if !authUser.IsRegistered() || authUser.UserUID != user.UserUID { if !authUser.IsRegistered() || authUser.UserUID != user.UserUID {
message := authn.ErrInvalidUsername.Error() message := authn.ErrInvalidUsername.Error()
limiter.Login.Reserve(clientIp)
event.AuditErr([]string{clientIp, "session %s", "login as %s with app password", message}, m.RefID, clean.LogQuote(userName)) event.AuditErr([]string{clientIp, "session %s", "login as %s with app password", message}, m.RefID, clean.LogQuote(userName))
event.LoginError(clientIp, "api", userName, m.UserAgent, message) event.LoginError(clientIp, "api", userName, m.UserAgent, message)
m.Status = http.StatusUnauthorized m.Status = http.StatusUnauthorized
@@ -134,7 +131,6 @@ func AuthLocal(user *User, f form.Login, m *Session, c *gin.Context) (provider a
} else { } else {
message = authn.ErrUnauthorized.Error() message = authn.ErrUnauthorized.Error()
} }
limiter.Login.Reserve(clientIp)
event.AuditErr([]string{clientIp, "session %s", "login as %s with app password", message}, m.RefID, clean.LogQuote(userName)) event.AuditErr([]string{clientIp, "session %s", "login as %s with app password", message}, m.RefID, clean.LogQuote(userName))
event.LoginError(clientIp, "api", userName, m.UserAgent, message) event.LoginError(clientIp, "api", userName, m.UserAgent, message)
m.Status = http.StatusUnauthorized m.Status = http.StatusUnauthorized
@@ -155,7 +151,6 @@ func AuthLocal(user *User, f form.Login, m *Session, c *gin.Context) (provider a
// Otherwise, check account password. // Otherwise, check account password.
if user.WrongPassword(f.Password) { if user.WrongPassword(f.Password) {
message := authn.ErrInvalidPassword.Error() message := authn.ErrInvalidPassword.Error()
limiter.Login.Reserve(clientIp)
if m != nil { if m != nil {
event.AuditErr([]string{clientIp, "session %s", "login as %s", message}, m.RefID, clean.LogQuote(userName)) event.AuditErr([]string{clientIp, "session %s", "login as %s", message}, m.RefID, clean.LogQuote(userName))
@@ -171,10 +166,8 @@ func AuthLocal(user *User, f form.Login, m *Session, c *gin.Context) (provider a
// Perform two-factor authentication check, if required. // Perform two-factor authentication check, if required.
if method = user.Method(); method.Is(authn.Method2FA) { if method = user.Method(); method.Is(authn.Method2FA) {
if valid, _, passcodeErr := user.VerifyPasscode(f.Passcode); passcodeErr != nil { if valid, _, passcodeErr := user.VerifyPasscode(f.Passcode); passcodeErr != nil {
limiter.Login.Reserve(clientIp)
return provider, method, passcodeErr return provider, method, passcodeErr
} else if !valid { } else if !valid {
limiter.Login.ReserveN(clientIp, 3)
return provider, method, authn.ErrInvalidPasscode return provider, method, authn.ErrInvalidPasscode
} }
} else if method == authn.MethodUndefined { } else if method == authn.MethodUndefined {
@@ -195,6 +188,9 @@ func (m *Session) LogIn(f form.Login, c *gin.Context) (err error) {
m.SetContext(c) m.SetContext(c)
} }
// r := limiter.Login.Reserve(m.IP())
// r.Cancel()
var user *User var user *User
var provider authn.ProviderType var provider authn.ProviderType
var method authn.MethodType var method authn.MethodType
@@ -225,7 +221,6 @@ func (m *Session) LogIn(f form.Login, c *gin.Context) (err error) {
if user.IsRegistered() { if user.IsRegistered() {
if shares := user.RedeemToken(f.ShareToken); shares == 0 { if shares := user.RedeemToken(f.ShareToken); shares == 0 {
message := authn.ErrInvalidShareToken.Error() message := authn.ErrInvalidShareToken.Error()
limiter.Login.Reserve(m.IP())
event.AuditWarn([]string{m.IP(), "session %s", message}, m.RefID) event.AuditWarn([]string{m.IP(), "session %s", message}, m.RefID)
m.Status = http.StatusNotFound m.Status = http.StatusNotFound
return i18n.Error(i18n.ErrInvalidLink) return i18n.Error(i18n.ErrInvalidLink)
@@ -237,7 +232,6 @@ func (m *Session) LogIn(f form.Login, c *gin.Context) (err error) {
return i18n.Error(i18n.ErrUnexpected) return i18n.Error(i18n.ErrUnexpected)
} else if shares := data.RedeemToken(f.ShareToken); shares == 0 { } else if shares := data.RedeemToken(f.ShareToken); shares == 0 {
message := authn.ErrInvalidShareToken.Error() message := authn.ErrInvalidShareToken.Error()
limiter.Login.Reserve(m.IP())
event.AuditWarn([]string{m.IP(), "session %s", message}, m.RefID) event.AuditWarn([]string{m.IP(), "session %s", message}, m.RefID)
event.LoginError(m.IP(), "api", "", m.UserAgent, message) event.LoginError(m.IP(), "api", "", m.UserAgent, message)
m.Status = http.StatusNotFound m.Status = http.StatusNotFound

View File

@@ -0,0 +1,5 @@
package limiter
const (
DefaultIP = "0.0.0.0"
)

View File

@@ -29,6 +29,10 @@ func NewLimit(r rate.Limit, b int) *Limit {
// AddIP adds a new rate limiter for the specified IP address. // AddIP adds a new rate limiter for the specified IP address.
func (i *Limit) AddIP(ip string) *rate.Limiter { func (i *Limit) AddIP(ip string) *rate.Limiter {
if ip == "" {
ip = DefaultIP
}
i.mu.Lock() i.mu.Lock()
defer i.mu.Unlock() defer i.mu.Unlock()
@@ -41,24 +45,43 @@ func (i *Limit) AddIP(ip string) *rate.Limiter {
// IP returns the rate limiter for the specified IP address. // IP returns the rate limiter for the specified IP address.
func (i *Limit) IP(ip string) *rate.Limiter { func (i *Limit) IP(ip string) *rate.Limiter {
i.mu.Lock() if ip == "" {
ip = DefaultIP
}
i.mu.RLock()
limiter, exists := i.limiters[ip] limiter, exists := i.limiters[ip]
if !exists { if !exists {
i.mu.Unlock() i.mu.RUnlock()
return i.AddIP(ip) return i.AddIP(ip)
} }
i.mu.Unlock() i.mu.RUnlock()
return limiter return limiter
} }
// Allow reports whether the request is allowed at this time and increments the request counter. // Allow checks if a new request is allowed at this time and increments the request counter by 1.
func (i *Limit) Allow(ip string) bool { func (i *Limit) Allow(ip string) bool {
return i.IP(ip).Allow() return i.IP(ip).Allow()
} }
// AllowN checks if a new request is allowed at this time and increments the request counter by n.
func (i *Limit) AllowN(ip string, n int) bool {
return i.IP(ip).AllowN(time.Now(), n)
}
// Request tries to increment the request counter and returns the result as new *Request.
func (i *Limit) Request(ip string) *Request {
return NewRequest(i.IP(ip), 1)
}
// RequestN tries to increment the request counter by n and returns the result as new *Request.
func (i *Limit) RequestN(ip string, n int) *Request {
return NewRequest(i.IP(ip), n)
}
// Reserve increments the request counter and returns a rate.Reservation. // Reserve increments the request counter and returns a rate.Reservation.
func (i *Limit) Reserve(ip string) *rate.Reservation { func (i *Limit) Reserve(ip string) *rate.Reservation {
return i.IP(ip).Reserve() return i.IP(ip).Reserve()
@@ -69,7 +92,7 @@ func (i *Limit) ReserveN(ip string, n int) *rate.Reservation {
return i.IP(ip).ReserveN(time.Now(), n) return i.IP(ip).ReserveN(time.Now(), n)
} }
// Reject reports whether the request limit has been exceeded, but does not change the request counter. // Reject checks if the request rate limit has been exceeded, but does not modify the counter.
func (i *Limit) Reject(ip string) bool { func (i *Limit) Reject(ip string) bool {
return i.IP(ip).Tokens() < 1 return i.IP(ip).Tokens() < 1
} }

View File

@@ -73,4 +73,47 @@ func TestNewLimit(t *testing.T) {
assert.True(t, l.Reject(clientIp)) assert.True(t, l.Reject(clientIp))
} }
}) })
t.Run("Request", func(t *testing.T) {
// 10 per minute.
l := NewLimit(0.166, 10)
// Request counter not increased.
for i := 0; i < 20; i++ {
assert.False(t, l.Reject(clientIp))
}
// Request not exceeded and tokens returned by calling Success().
for i := 1; i <= 20; i++ {
reject := l.Reject(clientIp)
r := l.Request(clientIp)
allow := r.Allow()
r.Success()
t.Logf("(1.%d) Reject: %t, Allow: %t, Tokens: %d", i, reject, allow, r.Tokens)
assert.False(t, reject)
assert.True(t, allow)
assert.False(t, r.Reject())
}
// Limit not exceeded, but tokens not returned.
for i := 1; i <= 10; i++ {
reject := l.Reject(clientIp)
r := l.Request(clientIp)
allow := r.Allow()
t.Logf("(2.%d) Reject: %t, Allow: %t, Tokens: %d", i, reject, allow, r.Tokens)
assert.False(t, reject)
assert.True(t, allow)
assert.False(t, r.Reject())
}
// Limit exceeded and tokens not returned.
for i := 1; i <= 20; i++ {
reject := l.Reject(clientIp)
r := l.Request(clientIp)
allow := r.Allow()
t.Logf("(3.%d) Reject: %t, Allow: %t, Tokens: %d", i, reject, allow, r.Tokens)
assert.True(t, reject)
assert.False(t, allow)
assert.True(t, r.Reject())
}
})
} }

View File

@@ -0,0 +1,50 @@
package limiter
import (
"time"
"golang.org/x/time/rate"
)
// Request represents a request for the specified number of limiter tokens.
type Request struct {
allow bool
limiter *rate.Limiter
Tokens int
}
// NewRequest checks if a request is allowed, reserves the required tokens,
// and returns a new Request to revert the reservation if successful.
func NewRequest(l *rate.Limiter, n int) *Request {
if l.AllowN(time.Now(), n) {
return &Request{
allow: true,
limiter: l,
Tokens: n,
}
} else {
return &Request{
allow: false,
limiter: l,
Tokens: 0,
}
}
}
// Allow checks if the request is allowed.
func (r *Request) Allow() bool {
return r.allow
}
// Reject returns true if the request should be rejected.
func (r *Request) Reject() bool {
return !r.allow
}
// Success returns the rate limit tokens that have been reserved for this request, if any.
func (r *Request) Success() {
if r.Tokens != 0 && r.limiter != nil {
r.limiter.ReserveN(time.Now(), -1*r.Tokens)
r.Tokens = 0
}
}

View File

@@ -0,0 +1,57 @@
package limiter
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestNewRequest(t *testing.T) {
clientIp := "192.0.2.1"
l := NewLimit(0.166, 10).IP(clientIp)
r := NewRequest(l, 9)
assert.True(t, r.Allow())
assert.False(t, r.Reject())
assert.Equal(t, 9, r.Tokens)
r = NewRequest(l, 1)
assert.True(t, r.Allow())
assert.False(t, r.Reject())
assert.Equal(t, 1, r.Tokens)
r = NewRequest(l, 1)
assert.False(t, r.Allow())
assert.True(t, r.Reject())
assert.Equal(t, 0, r.Tokens)
}
func TestRequest(t *testing.T) {
clientIp := "192.0.2.1"
t.Run("Allow", func(t *testing.T) {
r := Request{allow: true}
assert.True(t, r.Allow())
assert.False(t, r.Reject())
})
t.Run("Reject", func(t *testing.T) {
r := Request{allow: false}
assert.False(t, r.Allow())
assert.True(t, r.Reject())
})
t.Run("Success", func(t *testing.T) {
l := NewLimit(0.166, 10).IP(clientIp)
r1 := NewRequest(l, 10)
assert.True(t, r1.Allow())
assert.False(t, r1.Reject())
assert.Equal(t, 10, r1.Tokens)
r2 := NewRequest(l, 10)
assert.False(t, r2.Allow())
assert.True(t, r2.Reject())
assert.Equal(t, 0, r2.Tokens)
r1.Success()
r3 := NewRequest(l, 10)
assert.True(t, r3.Allow())
assert.False(t, r3.Reject())
assert.Equal(t, 10, r3.Tokens)
})
}

View File

@@ -93,11 +93,7 @@ func WebDAVAuth(conf *config.Config) gin.HandlerFunc {
} }
// Check webdav access authorization using an auth token or app password, if provided. // Check webdav access authorization using an auth token or app password, if provided.
if limiter.Auth.Reject(clientIp) { if sess, user, sid, cached := WebDAVAuthSession(c, authToken); user != nil && cached {
c.Header("WWW-Authenticate", BasicAuthRealm)
limiter.Abort(c)
return
} else if sess, user, sid, cached := WebDAVAuthSession(c, authToken); user != nil && cached {
// Add user to request context to signal successful authentication if username is empty or matches. // Add user to request context to signal successful authentication if username is empty or matches.
if username == "" || strings.EqualFold(clean.Username(username), user.Username()) { if username == "" || strings.EqualFold(clean.Username(username), user.Username()) {
c.Set(gin.AuthUserKey, user) c.Set(gin.AuthUserKey, user)
@@ -105,7 +101,6 @@ func WebDAVAuth(conf *config.Config) gin.HandlerFunc {
} }
event.AuditErr([]string{clientIp, "access webdav as %s with authorization granted to %s", authn.Denied}, clean.Log(username), clean.Log(user.Username())) event.AuditErr([]string{clientIp, "access webdav as %s with authorization granted to %s", authn.Denied}, clean.Log(username), clean.Log(user.Username()))
limiter.Auth.Reserve(clientIp)
WebDAVAbortUnauthorized(c) WebDAVAbortUnauthorized(c)
return return
} else if sess == nil { } else if sess == nil {
@@ -131,7 +126,6 @@ func WebDAVAuth(conf *config.Config) gin.HandlerFunc {
// Log warning if WebDAV is disabled for this account. // Log warning if WebDAV is disabled for this account.
message := authn.ErrBasicAuthDoesNotMatch.Error() message := authn.ErrBasicAuthDoesNotMatch.Error()
event.AuditWarn([]string{clientIp, "client %s", "session %s", "access webdav as %s", message}, clean.Log(sess.ClientInfo()), sess.RefID, clean.LogQuote(user.Username())) event.AuditWarn([]string{clientIp, "client %s", "session %s", "access webdav as %s", message}, clean.Log(sess.ClientInfo()), sess.RefID, clean.LogQuote(user.Username()))
limiter.Auth.Reserve(clientIp)
WebDAVAbortUnauthorized(c) WebDAVAbortUnauthorized(c)
return return
} else if err := fs.MkdirAll(filepath.Join(conf.OriginalsPath(), user.GetUploadPath())); err != nil { } else if err := fs.MkdirAll(filepath.Join(conf.OriginalsPath(), user.GetUploadPath())); err != nil {
@@ -158,9 +152,11 @@ func WebDAVAuth(conf *config.Config) gin.HandlerFunc {
return return
} }
// Check the authentication request rate to block the client after // Check request rate limit.
// too many failed attempts (10/req per minute by default). r := limiter.Login.Request(clientIp)
if limiter.Login.Reject(clientIp) {
// Abort if request rate limit is exceeded.
if r.Reject() || limiter.Auth.Reject(clientIp) {
c.Header("WWW-Authenticate", BasicAuthRealm) c.Header("WWW-Authenticate", BasicAuthRealm)
limiter.Abort(c) limiter.Abort(c)
return return
@@ -179,21 +175,25 @@ func WebDAVAuth(conf *config.Config) gin.HandlerFunc {
if user, _, _, err := entity.Auth(f, nil, c); err != nil { if user, _, _, err := entity.Auth(f, nil, c); err != nil {
// Abort if authentication has failed. // Abort if authentication has failed.
message := authn.ErrInvalidCredentials.Error() message := authn.ErrInvalidCredentials.Error()
limiter.Login.Reserve(clientIp)
event.AuditErr([]string{clientIp, "webdav login as %s", message}, clean.LogQuote(username)) event.AuditErr([]string{clientIp, "webdav login as %s", message}, clean.LogQuote(username))
event.LoginError(clientIp, "webdav", username, api.UserAgent(c), message) event.LoginError(clientIp, "webdav", username, api.UserAgent(c), message)
} else if user == nil { } else if user == nil {
// Abort if account was not found. // Abort if account was not found.
message := authn.ErrAccountNotFound.Error() message := authn.ErrAccountNotFound.Error()
limiter.Login.Reserve(clientIp)
event.AuditErr([]string{clientIp, "webdav login as %s", message}, clean.LogQuote(username)) event.AuditErr([]string{clientIp, "webdav login as %s", message}, clean.LogQuote(username))
event.LoginError(clientIp, "webdav", username, api.UserAgent(c), message) event.LoginError(clientIp, "webdav", username, api.UserAgent(c), message)
} else if !user.CanUseWebDAV() { } else if !user.CanUseWebDAV() {
// Return the reserved request rate limit tokens, even if account isn't allowed to use WebDAV.
r.Success()
// Abort if WebDAV is disabled for this account. // Abort if WebDAV is disabled for this account.
message := authn.ErrWebDAVAccessDisabled.Error() message := authn.ErrWebDAVAccessDisabled.Error()
event.AuditWarn([]string{clientIp, "webdav login as %s", message}, clean.LogQuote(username)) event.AuditWarn([]string{clientIp, "webdav login as %s", message}, clean.LogQuote(username))
event.LoginError(clientIp, "webdav", username, api.UserAgent(c), message) event.LoginError(clientIp, "webdav", username, api.UserAgent(c), message)
} else if err = fs.MkdirAll(filepath.Join(conf.OriginalsPath(), user.GetUploadPath())); err != nil { } else if err = fs.MkdirAll(filepath.Join(conf.OriginalsPath(), user.GetUploadPath())); err != nil {
// Return the reserved request rate limit tokens, even if path could not be created.
r.Success()
// Abort if upload path could not be created. // Abort if upload path could not be created.
message := authn.ErrFailedToCreateUploadPath.Error() message := authn.ErrFailedToCreateUploadPath.Error()
event.AuditWarn([]string{clientIp, "webdav login as %s", message}, clean.LogQuote(username)) event.AuditWarn([]string{clientIp, "webdav login as %s", message}, clean.LogQuote(username))
@@ -201,6 +201,9 @@ func WebDAVAuth(conf *config.Config) gin.HandlerFunc {
WebDAVAbortServerError(c) WebDAVAbortServerError(c)
return return
} else { } else {
// Return the reserved request rate limit tokens after successful authentication.
r.Success()
// Log successful authentication. // Log successful authentication.
event.AuditInfo([]string{clientIp, "webdav login as %s", "succeeded"}, clean.LogQuote(username)) event.AuditInfo([]string{clientIp, "webdav login as %s", "succeeded"}, clean.LogQuote(username))
event.LoginInfo(clientIp, "webdav", username, api.UserAgent(c)) event.LoginInfo(clientIp, "webdav", username, api.UserAgent(c))

View File

@@ -5,6 +5,7 @@ import (
"github.com/photoprism/photoprism/internal/entity" "github.com/photoprism/photoprism/internal/entity"
"github.com/photoprism/photoprism/internal/event" "github.com/photoprism/photoprism/internal/event"
"github.com/photoprism/photoprism/internal/server/limiter"
"github.com/photoprism/photoprism/pkg/header" "github.com/photoprism/photoprism/pkg/header"
"github.com/photoprism/photoprism/pkg/rnd" "github.com/photoprism/photoprism/pkg/rnd"
) )
@@ -19,11 +20,25 @@ func WebDAVAuthSession(c *gin.Context, authToken string) (sess *entity.Session,
return nil, nil, "", false return nil, nil, "", false
} }
// Get client IP address.
clientIp := header.ClientIP(c)
// Check request rate limit.
r := limiter.Auth.Request(clientIp)
// Abort if failure rate limit is exceeded.
if r.Reject() {
return nil, nil, "", false
}
// Get session ID for the auth token provided. // Get session ID for the auth token provided.
sid = rnd.SessionID(authToken) sid = rnd.SessionID(authToken)
// Check if client authorization has been cached to improve performance. // Check if client authorization has been cached to improve performance.
if cacheData, found := webdavAuthCache.Get(sid); found && cacheData != nil { if cacheData, found := webdavAuthCache.Get(sid); found && cacheData != nil {
// Return the reserved request rate limit tokens after successful authentication.
r.Success()
// Add cached user information to the request context. // Add cached user information to the request context.
user = cacheData.(*entity.User) user = cacheData.(*entity.User)
return nil, user, sid, true return nil, user, sid, true
@@ -40,6 +55,9 @@ func WebDAVAuthSession(c *gin.Context, authToken string) (sess *entity.Session,
return nil, nil, sid, false return nil, nil, sid, false
} }
// Return the reserved request rate limit tokens after successful authentication.
r.Success()
// Update the client IP and the user agent from // Update the client IP and the user agent from
// the request context if they have changed. // the request context if they have changed.
sess.UpdateContext(c) sess.UpdateContext(c)

View File

@@ -9,14 +9,21 @@ import (
// Generic error messages for authentication and authorization: // Generic error messages for authentication and authorization:
var ( var (
ErrUnauthorized = errors.New("unauthorized") ErrUnauthorized = errors.New("unauthorized")
ErrAccountAlreadyExists = errors.New("account already exists") ErrAccountAlreadyExists = errors.New("account already exists")
ErrAccountNotFound = errors.New("account not found") ErrAccountNotFound = errors.New("account not found")
ErrAccountDisabled = errors.New("account disabled") ErrAccountDisabled = errors.New("account disabled")
ErrInvalidCredentials = errors.New("invalid credentials") ErrInvalidCredentials = errors.New("invalid credentials")
ErrInvalidShareToken = errors.New("invalid share token") ErrInvalidShareToken = errors.New("invalid share token")
ErrInsufficientScope = errors.New("insufficient scope") ErrInsufficientScope = errors.New("insufficient scope")
ErrDisabledInPublicMode = errors.New("disabled in public mode") ErrDisabledInPublicMode = errors.New("disabled in public mode")
ErrAuthenticationDisabled = errors.New("authentication disabled")
)
// OAuth2-related error messages:
var (
ErrInvalidClientID = errors.New("invalid client id")
ErrInvalidClientSecret = errors.New("invalid client secret")
) )
// Username-related error messages: // Username-related error messages: