API: Implement creation and revocation of app passwords #808 #4114

Note that these changes are not production ready yet and must be tested
well before releasing them.

Signed-off-by: Michael Mayer <michael@photoprism.app>
This commit is contained in:
Michael Mayer
2024-04-07 16:44:30 +02:00
parent 6b0abfded1
commit 33fac8f404
46 changed files with 529 additions and 264 deletions

View File

@@ -18,13 +18,13 @@
<v-card-text class="py-0 px-2"> <v-card-text class="py-0 px-2">
<v-layout wrap align-top> <v-layout wrap align-top>
<v-flex xs12 class="pa-2 body-2"> <v-flex xs12 class="pa-2 body-2">
<translate>To create a new app-specific password, please enter the name and authorization scope of the application and select an expiration date:</translate> <translate>To generate a new app-specific password, please enter the name and authorization scope of the application and choose an expiration date:</translate>
</v-flex> </v-flex>
<v-flex xs12 class="pa-2"> <v-flex xs12 class="pa-2">
<v-text-field <v-text-field
v-model="newApp.Name" v-model="newApp.client_name"
:disabled="busy" :disabled="busy"
name="appname" name="client_name"
type="text" type="text"
:label="$gettext('Name')" :label="$gettext('Name')"
required required
@@ -40,10 +40,10 @@
></v-text-field> ></v-text-field>
</v-flex> </v-flex>
<v-flex xs12 sm6 class="pa-2"> <v-flex xs12 sm6 class="pa-2">
<v-select v-model="newApp.Scope" hide-details box :disabled="busy" :items="auth.ScopeOptions()" :label="$gettext('Scope')" :menu-props="{ maxHeight: 346 }" color="secondary-dark" background-color="secondary-light" class="input-scope"></v-select> <v-select v-model="newApp.scope" hide-details box :disabled="busy" :items="auth.ScopeOptions()" :label="$gettext('Scope')" :menu-props="{ maxHeight: 346 }" color="secondary-dark" background-color="secondary-light" class="input-scope"></v-select>
</v-flex> </v-flex>
<v-flex xs12 sm6 class="pa-2"> <v-flex xs12 sm6 class="pa-2">
<v-select v-model="newApp.Expires" :disabled="busy" :label="$gettext('Expires')" browser-autocomplete="off" hide-details box flat color="secondary-dark" class="input-expires" item-text="text" item-value="value" :items="options.Expires()"></v-select> <v-select v-model="newApp.lifetime" :disabled="busy" :label="$gettext('Expires')" browser-autocomplete="off" hide-details box flat color="secondary-dark" class="input-expires" item-text="text" item-value="value" :items="options.Expires()"></v-select>
</v-flex> </v-flex>
</v-layout> </v-layout>
</v-card-text> </v-card-text>
@@ -53,8 +53,8 @@
<v-btn depressed color="secondary-light" class="action-close ml-0" @click.stop="close"> <v-btn depressed color="secondary-light" class="action-close ml-0" @click.stop="close">
<translate>Close</translate> <translate>Close</translate>
</v-btn> </v-btn>
<v-btn depressed color="primary-button" disabled class="action-create white--text compact mr-0" @click.stop="close"> <v-btn depressed color="primary-button" disabled class="action-generate white--text compact mr-0" @click.stop="close">
<translate>Create</translate> <translate>Generate</translate>
</v-btn> </v-btn>
</v-flex> </v-flex>
</v-layout> </v-layout>
@@ -93,9 +93,11 @@ export default {
passwords: [], passwords: [],
user: this.$session.getUser(), user: this.$session.getUser(),
newApp: { newApp: {
Name: "", grant_type: "session",
Scope: "*", password: "",
Expires: 0, client_name: "",
scope: "*",
lifetime: 0,
}, },
}; };
}, },

View File

@@ -212,23 +212,6 @@ export class User extends RestModel {
return this.AuthProvider && this.AuthProvider === "ldap"; return this.AuthProvider && this.AuthProvider === "ldap";
} }
disable2FA() {
if (!this.Name) {
return true;
}
switch (this.AuthProvider) {
case "default":
return false;
case "local":
return false;
case "ldap":
return false;
default:
return true;
}
}
authInfo() { authInfo() {
if (!this || !this.AuthProvider) { if (!this || !this.AuthProvider) {
return $gettext("Default"); return $gettext("Default");
@@ -289,6 +272,21 @@ export class User extends RestModel {
}).then((response) => Promise.resolve(response.data)); }).then((response) => Promise.resolve(response.data));
} }
disablePasscodeSetup() {
if (!this.Name || !this.CanLogin || this.ID < 1) {
return true;
}
switch (this.AuthProvider) {
case "":
case "default":
case "local":
return false;
default:
return true;
}
}
static getCollectionResource() { static getCollectionResource() {
return "users"; return "users";
} }

View File

@@ -65,11 +65,19 @@ export const ScopeOptions = () => {
}, },
{ {
text: $gettext("Read Only"), text: $gettext("Read Only"),
value: "read", value: "read *",
}, },
{ {
text: $gettext("WebDAV"), text: $gettext("WebDAV"),
value: "webdav", value: "webdav",
}, },
{
text: $gettext("Metrics"),
value: "metrics",
},
{
text: $gettext("Custom"),
value: "~",
},
]; ];
}; };

View File

@@ -189,10 +189,10 @@
</v-btn> </v-btn>
</v-flex> </v-flex>
<v-flex xs12 sm6 class="pa-2"> <v-flex xs12 sm6 class="pa-2">
<v-btn block depressed color="secondary-light" class="action-passcode-dialog compact" :disabled="isPublic || isDemo || user.disable2FA()" @click.stop="showDialog('passcode')"> <v-btn block depressed color="secondary-light" class="action-passcode-dialog compact" :disabled="isPublic || isDemo || user.disablePasscodeSetup()" @click.stop="showDialog('passcode')">
<translate>2-Factor Authentication</translate> <translate>2-Factor Authentication</translate>
<v-icon v-if="user.AuthMethod === '2fa'" :right="!rtl" :left="rtl" dark>gpp_good</v-icon> <v-icon v-if="user.AuthMethod === '2fa'" :right="!rtl" :left="rtl" dark>gpp_good</v-icon>
<v-icon v-else-if="user.disable2FA()" :right="!rtl" :left="rtl" dark>shield</v-icon> <v-icon v-else-if="user.disablePasscodeSetup()" :right="!rtl" :left="rtl" dark>shield</v-icon>
<v-icon v-else :right="!rtl" :left="rtl" dark>gpp_maybe</v-icon> <v-icon v-else :right="!rtl" :left="rtl" dark>gpp_maybe</v-icon>
</v-btn> </v-btn>
</v-flex> </v-flex>

View File

@@ -8,6 +8,7 @@ import (
"github.com/photoprism/photoprism/internal/config" "github.com/photoprism/photoprism/internal/config"
"github.com/photoprism/photoprism/internal/get" "github.com/photoprism/photoprism/internal/get"
"github.com/photoprism/photoprism/pkg/authn"
"github.com/photoprism/photoprism/pkg/clean" "github.com/photoprism/photoprism/pkg/clean"
"github.com/photoprism/photoprism/pkg/i18n" "github.com/photoprism/photoprism/pkg/i18n"
) )
@@ -103,3 +104,9 @@ func AbortFeatureDisabled(c *gin.Context) {
func AbortBusy(c *gin.Context) { func AbortBusy(c *gin.Context) {
Abort(c, http.StatusTooManyRequests, i18n.ErrBusy) Abort(c, http.StatusTooManyRequests, i18n.ErrBusy)
} }
func AbortInvalidCredentials(c *gin.Context) {
if c != nil {
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": authn.ErrInvalidCredentials.Error(), "code": i18n.ErrInvalidCredentials, "message": i18n.Msg(i18n.ErrInvalidCredentials)})
}
}

View File

@@ -83,6 +83,7 @@ func CreateAlbum(router *gin.RouterGroup) {
var f form.Album var f form.Album
// Assign and validate request form values.
if err := c.BindJSON(&f); err != nil { if err := c.BindJSON(&f); err != nil {
AbortBadRequest(c) AbortBadRequest(c)
return return
@@ -163,7 +164,8 @@ func UpdateAlbum(router *gin.RouterGroup) {
return return
} }
if err := c.BindJSON(&f); err != nil { // Assign and validate request form values.
if err = c.BindJSON(&f); err != nil {
log.Error(err) log.Error(err)
AbortBadRequest(c) AbortBadRequest(c)
return return
@@ -328,7 +330,7 @@ func DislikeAlbum(router *gin.RouterGroup) {
return return
} }
if err := a.Update("AlbumFavorite", false); err != nil { if err = a.Update("AlbumFavorite", false); err != nil {
Abort(c, http.StatusInternalServerError, i18n.ErrSaveFailed) Abort(c, http.StatusInternalServerError, i18n.ErrSaveFailed)
return return
} }
@@ -374,7 +376,8 @@ func CloneAlbums(router *gin.RouterGroup) {
var f form.Selection var f form.Selection
if err := c.BindJSON(&f); err != nil { // Assign and validate request form values.
if err = c.BindJSON(&f); err != nil {
AbortBadRequest(c) AbortBadRequest(c)
return return
} }
@@ -425,6 +428,7 @@ func AddPhotosToAlbum(router *gin.RouterGroup) {
var f form.Selection var f form.Selection
// Assign and validate request form values.
if err := c.BindJSON(&f); err != nil { if err := c.BindJSON(&f); err != nil {
AbortBadRequest(c) AbortBadRequest(c)
return return
@@ -496,6 +500,7 @@ func RemovePhotosFromAlbum(router *gin.RouterGroup) {
var f form.Selection var f form.Selection
// Assign and validate request form values.
if err := c.BindJSON(&f); err != nil { if err := c.BindJSON(&f); err != nil {
AbortBadRequest(c) AbortBadRequest(c)
return return

View File

@@ -34,6 +34,7 @@ func BatchPhotosArchive(router *gin.RouterGroup) {
var f form.Selection var f form.Selection
// Assign and validate request form values.
if err := c.BindJSON(&f); err != nil { if err := c.BindJSON(&f); err != nil {
AbortBadRequest(c) AbortBadRequest(c)
return return

View File

@@ -28,6 +28,7 @@ func Connect(router *gin.RouterGroup) {
var f form.Connect var f form.Connect
// Assign and validate request form values.
if err := c.BindJSON(&f); err != nil { if err := c.BindJSON(&f); err != nil {
log.Warnf("connect: invalid form values (%s)", clean.Log(name)) log.Warnf("connect: invalid form values (%s)", clean.Log(name))
Abort(c, http.StatusBadRequest, i18n.ErrAccountConnect) Abort(c, http.StatusBadRequest, i18n.ErrAccountConnect)

View File

@@ -52,6 +52,7 @@ func UpdateFace(router *gin.RouterGroup) {
var f form.Face var f form.Face
// Assign and validate request form values.
if err := c.BindJSON(&f); err != nil { if err := c.BindJSON(&f); err != nil {
AbortBadRequest(c) AbortBadRequest(c)
return return

View File

@@ -34,6 +34,7 @@ func SendFeedback(router *gin.RouterGroup) {
var f form.Feedback var f form.Feedback
// Assign and validate request form values.
if err := c.BindJSON(&f); err != nil { if err := c.BindJSON(&f); err != nil {
AbortBadRequest(c) AbortBadRequest(c)
return return

View File

@@ -60,7 +60,7 @@ func ChangeFileOrientation(router *gin.RouterGroup) {
return return
} }
// Update form with values from request // Assign and validate request form values.
if err = c.BindJSON(&f); err != nil { if err = c.BindJSON(&f); err != nil {
Abort(c, http.StatusBadRequest, i18n.ErrBadRequest) Abort(c, http.StatusBadRequest, i18n.ErrBadRequest)
return return

View File

@@ -50,6 +50,7 @@ func StartImport(router *gin.RouterGroup) {
var f form.ImportOptions var f form.ImportOptions
// Assign and validate request form values.
if err := c.BindJSON(&f); err != nil { if err := c.BindJSON(&f); err != nil {
AbortBadRequest(c) AbortBadRequest(c)
return return

View File

@@ -42,6 +42,7 @@ func StartIndexing(router *gin.RouterGroup) {
var f form.IndexOptions var f form.IndexOptions
// Assign and validate request form values.
if err := c.BindJSON(&f); err != nil { if err := c.BindJSON(&f); err != nil {
AbortBadRequest(c) AbortBadRequest(c)
return return

View File

@@ -28,6 +28,7 @@ func UpdateLabel(router *gin.RouterGroup) {
var f form.Label var f form.Label
// Assign and validate request form values.
if err := c.BindJSON(&f); err != nil { if err := c.BindJSON(&f); err != nil {
AbortBadRequest(c) AbortBadRequest(c)
return return

View File

@@ -27,6 +27,7 @@ func UpdateLink(c *gin.Context) {
var f form.Link var f form.Link
// Assign and validate request form values.
if err := c.BindJSON(&f); err != nil { if err := c.BindJSON(&f); err != nil {
log.Debugf("share: %s", err) log.Debugf("share: %s", err)
AbortBadRequest(c) AbortBadRequest(c)

View File

@@ -1,6 +1,8 @@
package api package api
import ( import (
"errors"
"fmt"
"net/http" "net/http"
"github.com/dustin/go-humanize/english" "github.com/dustin/go-humanize/english"
@@ -14,7 +16,6 @@ import (
"github.com/photoprism/photoprism/pkg/authn" "github.com/photoprism/photoprism/pkg/authn"
"github.com/photoprism/photoprism/pkg/clean" "github.com/photoprism/photoprism/pkg/clean"
"github.com/photoprism/photoprism/pkg/header" "github.com/photoprism/photoprism/pkg/header"
"github.com/photoprism/photoprism/pkg/i18n"
) )
// CreateOAuthToken creates a new access token for clients that // CreateOAuthToken creates a new access token for clients that
@@ -31,10 +32,12 @@ func CreateOAuthToken(router *gin.RouterGroup) {
// Get client IP address for logs and rate limiting checks. // Get client IP address for logs and rate limiting checks.
clientIp := ClientIP(c) clientIp := ClientIP(c)
actor := "unknown client"
action := "create token"
// Abort if running in public mode. // Abort if running in public mode.
if get.Config().Public() { if get.Config().Public() {
event.AuditErr([]string{clientIp, "client", "create session", "oauth2", authn.ErrDisabledInPublicMode.Error()}) event.AuditErr([]string{clientIp, "oauth2", actor, action, authn.ErrDisabledInPublicMode.Error()})
AbortForbidden(c) AbortForbidden(c)
return return
} }
@@ -54,15 +57,15 @@ func CreateOAuthToken(router *gin.RouterGroup) {
f.ClientID = clientId f.ClientID = clientId
f.ClientSecret = clientSecret f.ClientSecret = clientSecret
} else if err = c.ShouldBind(&f); err != nil { } else if err = c.ShouldBind(&f); err != nil {
event.AuditWarn([]string{clientIp, "client", "create session", "oauth2", "%s"}, err) event.AuditWarn([]string{clientIp, "oauth2", actor, action, "%s"}, err)
AbortBadRequest(c) AbortBadRequest(c)
return return
} }
// Check the credentials for completeness and the correct format. // Check the credentials for completeness and the correct format.
if err = f.Validate(); err != nil { if err = f.Validate(); err != nil {
event.AuditWarn([]string{clientIp, "client", "create session", "oauth2", "%s"}, err) event.AuditWarn([]string{clientIp, "oauth2", actor, action, "%s"}, err)
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": i18n.Msg(i18n.ErrInvalidCredentials)}) AbortInvalidCredentials(c)
return return
} }
@@ -75,6 +78,14 @@ func CreateOAuthToken(router *gin.RouterGroup) {
return return
} }
if f.ClientID != "" {
actor = fmt.Sprintf("client %s", clean.Log(f.ClientID))
} else if f.Username != "" {
actor = fmt.Sprintf("user %s", clean.Log(f.Username))
} else if f.GrantType == authn.GrantPassword {
actor = "unknown user"
}
// Create a new session (access token) based on the grant type specified in the request. // Create a new session (access token) based on the grant type specified in the request.
switch f.GrantType { switch f.GrantType {
case authn.GrantClientCredentials, authn.GrantUndefined: case authn.GrantClientCredentials, authn.GrantUndefined:
@@ -83,20 +94,20 @@ func CreateOAuthToken(router *gin.RouterGroup) {
// Check if a client has been found, it is enabled, and the credentials are valid. // Check if a client has been found, it is enabled, and the credentials are valid.
if client == nil { if client == nil {
event.AuditWarn([]string{clientIp, "client %s", "create session", "oauth2", authn.ErrInvalidClientID.Error()}, f.ClientID) event.AuditWarn([]string{clientIp, "oauth2", actor, action, authn.ErrInvalidClientID.Error()})
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": i18n.Msg(i18n.ErrInvalidCredentials)}) AbortInvalidCredentials(c)
return return
} else if !client.AuthEnabled { } else if !client.AuthEnabled {
event.AuditWarn([]string{clientIp, "client %s", "create session", "oauth2", authn.ErrAuthenticationDisabled.Error()}, f.ClientID) event.AuditWarn([]string{clientIp, "oauth2", actor, action, authn.ErrAuthenticationDisabled.Error()})
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": i18n.Msg(i18n.ErrInvalidCredentials)}) AbortInvalidCredentials(c)
return return
} else if method := client.Method(); !method.IsDefault() && method != authn.MethodOAuth2 { } else if method := client.Method(); !method.IsDefault() && method != authn.MethodOAuth2 {
event.AuditWarn([]string{clientIp, "client %s", "create session", "oauth2", "method %s not supported"}, f.ClientID, clean.LogQuote(method.String())) event.AuditWarn([]string{clientIp, "oauth2", actor, action, "method %s not supported"}, clean.LogQuote(method.String()))
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": i18n.Msg(i18n.ErrInvalidCredentials)}) AbortInvalidCredentials(c)
return return
} else if client.InvalidSecret(f.ClientSecret) { } else if client.InvalidSecret(f.ClientSecret) {
event.AuditWarn([]string{clientIp, "client %s", "create session", "oauth2", authn.ErrInvalidClientSecret.Error()}, f.ClientID) event.AuditWarn([]string{clientIp, "oauth2", actor, action, authn.ErrInvalidClientSecret.Error()})
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": i18n.Msg(i18n.ErrInvalidCredentials)}) AbortInvalidCredentials(c)
return return
} }
@@ -105,42 +116,88 @@ func CreateOAuthToken(router *gin.RouterGroup) {
// Create new client session. // Create new client session.
sess = client.NewSession(c, authn.GrantClientCredentials) sess = client.NewSession(c, authn.GrantClientCredentials)
case authn.GrantPassword: case authn.GrantPassword, authn.GrantSession:
// Generate an app password for a user account and accept the password for confirmation. // Generate an app password for a user account and check the password for confirmation.
event.AuditWarn([]string{clientIp, "client %s", "create session", "oauth2", "password grant type is not implemented yet"}, f.ClientID) s := Session(clientIp, AuthToken(c))
c.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": i18n.Msg(i18n.ErrInvalidCredentials)})
return if s == nil {
AbortInvalidCredentials(c)
return
} else if s.Username() == "" || s.IsClient() || s.IsRegistered() {
AbortInvalidCredentials(c)
return
}
actor = fmt.Sprintf("user %s", clean.Log(s.Username()))
if s.User().Provider().SupportsPasswordAuthentication() {
loginForm := form.Login{
Username: s.Username(),
Password: f.Password,
}
authUser, authProvider, authMethod, authErr := entity.Auth(loginForm, nil, c)
if authProvider.IsClient() {
event.AuditErr([]string{clientIp, "oauth2", actor, action, authn.Denied})
AbortInvalidCredentials(c)
return
} else if !authUser.Equal(s.User()) {
event.AuditErr([]string{clientIp, "oauth2", actor, action, authn.ErrInvalidUsername.Error()})
AbortInvalidCredentials(c)
return
} else if authMethod.Is(authn.Method2FA) && errors.Is(authErr, authn.ErrPasscodeRequired) {
// Ignore.
} else if authErr != nil {
event.AuditErr([]string{clientIp, "oauth2", actor, action, "%s"}, clean.Error(authErr))
AbortInvalidCredentials(c)
return
}
f.GrantType = authn.GrantPassword
} else {
f.GrantType = authn.GrantSession
}
sess = entity.NewClientAuthentication(f.ClientName, f.Lifetime, f.Scope, f.GrantType, s.User())
// Return the reserved request rate limit tokens after successful authentication.
r.Success()
default: default:
event.AuditErr([]string{clientIp, "client %s", "create session", "oauth2", authn.ErrInvalidGrantType.Error()}, f.ClientID) event.AuditErr([]string{clientIp, "oauth2", actor, action, authn.ErrInvalidGrantType.Error()}, clean.Log(f.ClientID))
c.AbortWithStatusJSON(sess.HttpStatus(), gin.H{"error": i18n.Msg(i18n.ErrInvalidCredentials)}) AbortInvalidCredentials(c)
return return
} }
// Save new session. // Save new session.
if sess, err = get.Session().Save(sess); err != nil { if sess, err = get.Session().Save(sess); err != nil {
event.AuditErr([]string{clientIp, "client %s", "create session", "oauth2", err.Error()}, f.ClientID) event.AuditErr([]string{clientIp, "oauth2", actor, action, err.Error()}, f.ClientID)
c.AbortWithStatusJSON(sess.HttpStatus(), gin.H{"error": i18n.Msg(i18n.ErrInvalidCredentials)}) AbortInvalidCredentials(c)
return return
} else if sess == nil { } else if sess == nil {
event.AuditErr([]string{clientIp, "client %s", "create session", "oauth2", StatusFailed.String()}, f.ClientID) event.AuditErr([]string{clientIp, "oauth2", actor, action, StatusFailed.String()}, f.ClientID)
c.AbortWithStatusJSON(sess.HttpStatus(), gin.H{"error": i18n.Msg(i18n.ErrUnexpected)}) AbortUnexpectedError(c)
return return
} else { } else {
event.AuditInfo([]string{clientIp, "client %s", "session %s", "oauth2", "created"}, f.ClientID, sess.RefID) event.AuditInfo([]string{clientIp, "oauth2", actor, action, authn.Created}, f.ClientID, sess.RefID)
} }
// Delete any existing client sessions above the configured limit. // Delete any existing client sessions above the configured limit.
if client == nil { if client == nil {
// Skip deletion if not created by a client. // Skip deletion if not created by a client.
} else if deleted := client.EnforceAuthTokenLimit(); deleted > 0 { } else if deleted := client.EnforceAuthTokenLimit(); deleted > 0 {
event.AuditInfo([]string{clientIp, "client %s", "session %s", "oauth2", "deleted %s"}, f.ClientID, sess.RefID, english.Plural(deleted, "previously created client session", "previously created client sessions")) event.AuditInfo([]string{clientIp, "oauth2", actor, action, "deleted %s to enforce token limit"}, f.ClientID, sess.RefID, english.Plural(deleted, "session", "sessions"))
} }
// Send response with access token, token type, and token lifetime. // Send response with access token, token type, and token lifetime.
response := gin.H{ response := gin.H{
"status": StatusSuccess,
"session_id": sess.ID,
"access_token": sess.AuthToken(), "access_token": sess.AuthToken(),
"token_type": sess.AuthTokenType(), "token_type": sess.AuthTokenType(),
"expires_in": sess.ExpiresIn(), "expires_in": sess.ExpiresIn(),
"client_name": sess.ClientName,
"scope": sess.Scope(),
} }
c.JSON(http.StatusOK, response) c.JSON(http.StatusOK, response)

View File

@@ -1,10 +1,12 @@
package api package api
import ( import (
"fmt"
"net/http" "net/http"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/photoprism/photoprism/internal/acl"
"github.com/photoprism/photoprism/internal/entity" "github.com/photoprism/photoprism/internal/entity"
"github.com/photoprism/photoprism/internal/event" "github.com/photoprism/photoprism/internal/event"
"github.com/photoprism/photoprism/internal/form" "github.com/photoprism/photoprism/internal/form"
@@ -27,70 +29,128 @@ func RevokeOAuthToken(router *gin.RouterGroup) {
return return
} }
// Disable caching of responses.
c.Header(header.CacheControl, header.CacheControlNoStore)
// Get client IP address for logs and rate limiting checks. // Get client IP address for logs and rate limiting checks.
clientIp := ClientIP(c) clientIp := ClientIP(c)
actor := "unknown client"
action := "revoke token"
// Abort if running in public mode. // Abort if running in public mode.
if get.Config().Public() { if get.Config().Public() {
event.AuditErr([]string{clientIp, "client", "delete session", "oauth2", authn.ErrDisabledInPublicMode.Error()}) event.AuditErr([]string{clientIp, "oauth2", actor, action, authn.ErrDisabledInPublicMode.Error()})
Abort(c, http.StatusForbidden, i18n.ErrForbidden) Abort(c, http.StatusForbidden, i18n.ErrForbidden)
return return
} }
// Session and user information.
var s, sess *entity.Session
var authToken, sUserUID string
var role acl.Role
var err error var err error
// Token revokation request form. // Token revokation request form.
var f form.OAuthRevokeToken var f form.OAuthRevokeToken
// Get token from request header. // Get token and session from request header.
authToken := AuthToken(c) if authToken = AuthToken(c); authToken == "" {
role = acl.RoleNone
} else if s = Session(clientIp, authToken); s != nil {
// Set log role and actor based on the session referenced in request header.
sUserUID = s.UserUID
if s.IsClient() {
role = s.ClientRole()
actor = fmt.Sprintf("client %s", clean.Log(s.ClientInfo()))
} else if username := s.Username(); username != "" {
role = s.UserRole()
actor = fmt.Sprintf("user %s", clean.Log(username))
} else {
role = s.UserRole()
actor = fmt.Sprintf("unknown %s", s.UserRole().String())
}
}
// Get the auth token to be revoked from the submitted form values or the request header. // Get the auth token to be revoked from the submitted form values or the request header.
if err = c.ShouldBind(&f); err != nil && authToken == "" { if err = c.ShouldBind(&f); err != nil && authToken == "" {
event.AuditWarn([]string{clientIp, "client", "delete session", "oauth2", "%s"}, err) event.AuditWarn([]string{clientIp, "oauth2", actor, action, "%s"}, err)
AbortBadRequest(c) AbortBadRequest(c)
return return
} else if f.Empty() { } else if f.Empty() {
f.AuthToken = authToken f.Token = authToken
f.TypeHint = form.ClientAccessToken f.TokenTypeHint = form.AccessToken
} }
// Check the token form values. // Validate revokation form values.
if err = f.Validate(); err != nil { if err = f.Validate(); err != nil {
event.AuditWarn([]string{clientIp, "client", "delete session", "oauth2", "%s"}, err) event.AuditWarn([]string{clientIp, "oauth2", actor, action, "%s"}, err)
AbortBadRequest(c) AbortInvalidCredentials(c)
return return
} }
// Disable caching of responses. // Find session to be revoked.
c.Header(header.CacheControl, header.CacheControlNoStore) switch f.TokenTypeHint {
case form.RefID:
if s == nil || sUserUID == "" || role == acl.RoleNone {
c.AbortWithStatusJSON(http.StatusForbidden, i18n.NewResponse(http.StatusForbidden, i18n.ErrForbidden))
return
} else if sess = entity.FindSessionByRefID(f.Token); sess == nil {
AbortInvalidCredentials(c)
return
}
case form.SessionID:
if s == nil || sUserUID == "" || role == acl.RoleNone {
c.AbortWithStatusJSON(http.StatusForbidden, i18n.NewResponse(http.StatusForbidden, i18n.ErrForbidden))
return
}
// Find session based on auth token. sess, err = entity.FindSession(f.Token)
sess, err := entity.FindSession(rnd.SessionID(f.AuthToken)) case form.AccessToken:
sess, err = entity.FindSession(rnd.SessionID(f.Token))
}
// If not already set, get the log role and actor from the session to be revoked.
if sess != nil && role == acl.RoleNone {
if sess.IsClient() {
role = sess.ClientRole()
actor = fmt.Sprintf("client %s", clean.Log(sess.ClientInfo()))
} else if username := sess.Username(); username != "" {
role = s.UserRole()
actor = fmt.Sprintf("user %s", clean.Log(username))
} else {
role = sess.UserRole()
actor = fmt.Sprintf("unknown %s", sess.UserRole().String())
}
}
// Check revokation request and abort if invalid.
if err != nil { if err != nil {
event.AuditErr([]string{clientIp, "client %s", "session %s", "delete session as %s", "oauth2", "%s"}, clean.Log(sess.ClientInfo()), clean.Log(sess.RefID), sess.ClientRole().String(), err.Error()) event.AuditErr([]string{clientIp, "oauth2", actor, action, "delete %s as %s", "%s"}, clean.Log(sess.RefID), role.String(), err.Error())
c.AbortWithStatusJSON(http.StatusUnauthorized, i18n.NewResponse(http.StatusUnauthorized, i18n.ErrUnauthorized)) AbortInvalidCredentials(c)
return return
} else if sess == nil { } else if sess == nil {
event.AuditErr([]string{clientIp, "client %s", "session %s", "delete session as %s", "oauth2", authn.Denied}, clean.Log(sess.ClientInfo()), clean.Log(sess.RefID), sess.ClientRole().String()) event.AuditErr([]string{clientIp, "oauth2", actor, action, "delete %s as %s", authn.Denied}, clean.Log(sess.RefID), role.String())
c.AbortWithStatusJSON(http.StatusUnauthorized, i18n.NewResponse(http.StatusUnauthorized, i18n.ErrUnauthorized)) AbortInvalidCredentials(c)
return return
} else if sess.Abort(c) { } else if sess.Abort(c) {
event.AuditErr([]string{clientIp, "client %s", "session %s", "delete session as %s", "oauth2", authn.Denied}, clean.Log(sess.ClientInfo()), clean.Log(sess.RefID), sess.ClientRole().String()) event.AuditErr([]string{clientIp, "oauth2", actor, action, "delete %s as %s", authn.Denied}, clean.Log(sess.RefID), role.String())
return return
} else if !sess.IsClient() { } else if !sess.IsClient() {
event.AuditErr([]string{clientIp, "client %s", "session %s", "delete session as %s", "oauth2", authn.Denied}, clean.Log(sess.ClientInfo()), clean.Log(sess.RefID), sess.ClientRole().String()) event.AuditErr([]string{clientIp, "oauth2", actor, action, "delete %s as %s", authn.Denied}, clean.Log(sess.RefID), role.String())
c.AbortWithStatusJSON(http.StatusForbidden, i18n.NewResponse(http.StatusForbidden, i18n.ErrForbidden)) c.AbortWithStatusJSON(http.StatusForbidden, i18n.NewResponse(http.StatusForbidden, i18n.ErrForbidden))
return return
} else if sUserUID != "" && sess.UserUID != sUserUID {
event.AuditErr([]string{clientIp, "oauth2", actor, action, "delete %s as %s", authn.ErrUnauthorized.Error()}, clean.Log(sess.RefID), role.String())
AbortInvalidCredentials(c)
return
} else { } else {
event.AuditInfo([]string{clientIp, "client %s", "session %s", "delete session as %s", "oauth2", authn.Granted}, clean.Log(sess.ClientInfo()), clean.Log(sess.RefID), sess.ClientRole().String()) event.AuditInfo([]string{clientIp, "oauth2", actor, action, "delete %s as %s", authn.Granted}, clean.Log(sess.RefID), role.String())
} }
// Delete session cache and database record. // Delete session cache and database record.
if err = sess.Delete(); err != nil { if err = sess.Delete(); err != nil {
// Log error. // Log error.
event.AuditErr([]string{clientIp, "client %s", "session %s", "delete session as %s", "oauth2", "%s"}, clean.Log(sess.ClientInfo()), clean.Log(sess.RefID), sess.ClientRole().String(), err) event.AuditErr([]string{clientIp, "oauth2", actor, action, "delete %s as %s", "%s"}, clean.Log(sess.RefID), role.String(), err)
// Return JSON error. // Return JSON error.
c.AbortWithStatusJSON(http.StatusNotFound, i18n.NewResponse(http.StatusNotFound, i18n.ErrNotFound)) c.AbortWithStatusJSON(http.StatusNotFound, i18n.NewResponse(http.StatusNotFound, i18n.ErrNotFound))
@@ -98,7 +158,7 @@ func RevokeOAuthToken(router *gin.RouterGroup) {
} }
// Log event. // Log event.
event.AuditInfo([]string{clientIp, "client %s", "session %s", "oauth2", "deleted"}, clean.Log(sess.ClientInfo()), clean.Log(sess.RefID)) event.AuditInfo([]string{clientIp, "oauth2", actor, action, "delete %s as %s", "deleted"}, clean.Log(sess.RefID))
// Send response. // Send response.
c.JSON(http.StatusOK, DeleteSessionResponse(sess.ID)) c.JSON(http.StatusOK, DeleteSessionResponse(sess.ID))

View File

@@ -84,7 +84,7 @@ func TestRevokeOAuthToken(t *testing.T) {
revokeData := url.Values{ revokeData := url.Values{
"token": {authToken}, "token": {authToken},
"token_type_hint": {form.ClientAccessToken}, "token_type_hint": {form.AccessToken},
} }
revokeToken, _ := http.NewRequest("POST", revokePath, strings.NewReader(revokeData.Encode())) revokeToken, _ := http.NewRequest("POST", revokePath, strings.NewReader(revokeData.Encode()))

View File

@@ -40,6 +40,7 @@ func AddPhotoLabel(router *gin.RouterGroup) {
var f form.Label var f form.Label
// Assign and validate request form values.
if err = c.BindJSON(&f); err != nil { if err = c.BindJSON(&f); err != nil {
AbortBadRequest(c) AbortBadRequest(c)
return return

View File

@@ -89,7 +89,7 @@ func UpdatePhoto(router *gin.RouterGroup) {
return return
} }
// 2) Update form with values from request // 2) Assign and validate request form values.
if err := c.BindJSON(&f); err != nil { if err := c.BindJSON(&f); err != nil {
Abort(c, http.StatusBadRequest, i18n.ErrBadRequest) Abort(c, http.StatusBadRequest, i18n.ErrBadRequest)
return return

View File

@@ -125,6 +125,7 @@ func AddService(router *gin.RouterGroup) {
var f form.Service var f form.Service
// Assign and validate request form values.
if err := c.BindJSON(&f); err != nil { if err := c.BindJSON(&f); err != nil {
AbortBadRequest(c) AbortBadRequest(c)
return return

View File

@@ -39,7 +39,8 @@ func UploadToService(router *gin.RouterGroup) {
var f form.SyncUpload var f form.SyncUpload
if err := c.BindJSON(&f); err != nil { // Assign and validate request form values.
if err = c.BindJSON(&f); err != nil {
AbortBadRequest(c) AbortBadRequest(c)
return return
} }

View File

@@ -32,7 +32,7 @@ func CreateSession(router *gin.RouterGroup) {
clientIp := ClientIP(c) clientIp := ClientIP(c)
// Validate request data. // Assign and validate request form values.
if err := c.BindJSON(&f); err != nil { if err := c.BindJSON(&f); err != nil {
event.AuditWarn([]string{clientIp, "create session", "invalid request", "%s"}, err) event.AuditWarn([]string{clientIp, "create session", "invalid request", "%s"}, err)
AbortBadRequest(c) AbortBadRequest(c)

View File

@@ -61,14 +61,15 @@ func UpdateSubject(router *gin.RouterGroup) {
return return
} }
// Initialize form. // Create request value form.
f, err := form.NewSubject(*m) f, err := form.NewSubject(*m)
// Assign and validate request form values.
if err != nil { if err != nil {
log.Errorf("subject: %s (new form)", err) log.Errorf("subject: %s (new form)", err)
AbortSaveFailed(c) AbortSaveFailed(c)
return return
} else if err := c.BindJSON(&f); err != nil { } else if err = c.BindJSON(&f); err != nil {
log.Errorf("subject: %s (update form)", err) log.Errorf("subject: %s (update form)", err)
AbortBadRequest(c) AbortBadRequest(c)
return return

View File

@@ -1,7 +1,6 @@
package api package api
import ( import (
"errors"
"net/http" "net/http"
"github.com/dustin/go-humanize/english" "github.com/dustin/go-humanize/english"
@@ -40,10 +39,13 @@ func CreateUserPasscode(router *gin.RouterGroup) {
return return
} }
// Check password and abort if invalid. // Check password if user authenticates with a local account.
if user.InvalidPassword(frm.Password) { switch user.Provider() {
Abort(c, http.StatusForbidden, i18n.ErrInvalidPassword) case authn.ProviderDefault, authn.ProviderLocal:
return if user.InvalidPassword(frm.Password) {
Abort(c, http.StatusForbidden, i18n.ErrInvalidPassword)
return
}
} }
// Return the reserved request rate limit tokens after successful authentication. // Return the reserved request rate limit tokens after successful authentication.
@@ -55,20 +57,20 @@ func CreateUserPasscode(router *gin.RouterGroup) {
// Generate and save new passcode key. // Generate and save new passcode key.
var passcode *entity.Passcode var passcode *entity.Passcode
if key, err := rnd.AuthKey(conf.AppName(), user.UserName); err != nil { if key, err := rnd.AuthKey(conf.AppName(), user.UserName); err != nil {
event.AuditErr([]string{ClientIP(c), "session %s", "users", user.UserName, "failed to generate passcode", clean.Error(err)}, s.RefID) event.AuditErr([]string{ClientIP(c), "session %s", authn.Users, user.UserName, authn.ErrPasscodeGenerateFailed.Error(), clean.Error(err)}, s.RefID)
Abort(c, http.StatusInternalServerError, i18n.ErrUnexpected) Abort(c, http.StatusInternalServerError, i18n.ErrUnexpected)
return return
} else if passcode, err = entity.NewPasscode(user.UID(), key.String(), rnd.RecoveryCode()); err != nil { } else if passcode, err = entity.NewPasscode(user.UID(), key.String(), rnd.RecoveryCode()); err != nil {
event.AuditErr([]string{ClientIP(c), "session %s", "users", user.UserName, "failed to create passcode", clean.Error(err)}, s.RefID) event.AuditErr([]string{ClientIP(c), "session %s", authn.Users, user.UserName, authn.ErrPasscodeCreateFailed.Error(), clean.Error(err)}, s.RefID)
Abort(c, http.StatusInternalServerError, i18n.ErrUnexpected) Abort(c, http.StatusInternalServerError, i18n.ErrUnexpected)
return return
} else if err = passcode.Save(); err != nil { } else if err = passcode.Save(); err != nil {
event.AuditErr([]string{ClientIP(c), "session %s", "users", user.UserName, "failed to save passcode", clean.Error(err)}, s.RefID) event.AuditErr([]string{ClientIP(c), "session %s", authn.Users, user.UserName, authn.ErrPasscodeSaveFailed.Error(), clean.Error(err)}, s.RefID)
Abort(c, http.StatusConflict, i18n.ErrSaveFailed) Abort(c, http.StatusConflict, i18n.ErrSaveFailed)
return return
} }
event.AuditInfo([]string{ClientIP(c), "session %s", "users", user.UserName, "passcode", "created"}, s.RefID) event.AuditInfo([]string{ClientIP(c), "session %s", authn.Users, user.UserName, authn.Passcode, authn.Created}, s.RefID)
c.JSON(http.StatusOK, passcode) c.JSON(http.StatusOK, passcode)
}) })
@@ -98,11 +100,11 @@ func ConfirmUserPasscode(router *gin.RouterGroup) {
valid, passcode, err := user.VerifyPasscode(frm.Passcode()) valid, passcode, err := user.VerifyPasscode(frm.Passcode())
if err != nil { if err != nil {
event.AuditErr([]string{ClientIP(c), "session %s", "users", user.UserName, "failed to verify passcode", clean.Error(err)}, s.RefID) event.AuditErr([]string{ClientIP(c), "session %s", authn.Users, user.UserName, authn.ErrPasscodeVerificationFailed.Error(), clean.Error(err)}, s.RefID)
Abort(c, http.StatusForbidden, i18n.ErrInvalidPasscode) Abort(c, http.StatusForbidden, i18n.ErrInvalidPasscode)
return return
} else if !valid { } else if !valid {
event.AuditWarn([]string{ClientIP(c), "session %s", "users", user.UserName, authn.ErrInvalidPasscode.Error()}, s.RefID) event.AuditWarn([]string{ClientIP(c), "session %s", authn.Users, user.UserName, authn.ErrInvalidPasscode.Error()}, s.RefID)
Abort(c, http.StatusForbidden, i18n.ErrInvalidPasscode) Abort(c, http.StatusForbidden, i18n.ErrInvalidPasscode)
return return
} }
@@ -110,7 +112,7 @@ func ConfirmUserPasscode(router *gin.RouterGroup) {
// Return the reserved request rate limit tokens after successful authentication. // Return the reserved request rate limit tokens after successful authentication.
r.Success() r.Success()
event.AuditInfo([]string{ClientIP(c), "session %s", "users", user.UserName, "passcode", "verified"}, s.RefID) event.AuditInfo([]string{ClientIP(c), "session %s", authn.Users, user.UserName, authn.Passcode, authn.Verified}, s.RefID)
// Clear session cache. // Clear session cache.
s.ClearCache() s.ClearCache()
@@ -135,18 +137,18 @@ func ActivateUserPasscode(router *gin.RouterGroup) {
passcode, err := user.ActivatePasscode() passcode, err := user.ActivatePasscode()
if err != nil { if err != nil {
event.AuditErr([]string{ClientIP(c), "session %s", "users", user.UserName, "failed to activate passcode", clean.Error(err)}, s.RefID) event.AuditErr([]string{ClientIP(c), "session %s", authn.Users, user.UserName, authn.ErrPasscodeActivationFailed.Error(), clean.Error(err)}, s.RefID)
Abort(c, http.StatusForbidden, i18n.ErrSaveFailed) Abort(c, http.StatusForbidden, i18n.ErrSaveFailed)
return return
} }
// Log event. // Log event.
event.AuditInfo([]string{ClientIP(c), "session %s", "users", user.UserName, "passcode", "activated"}, s.RefID) event.AuditInfo([]string{ClientIP(c), "session %s", authn.Users, user.UserName, authn.Passcode, authn.Activated}, s.RefID)
// Invalidate any other user sessions to protect the account: // Invalidate any other user sessions to protect the account:
// https://cheatsheetseries.owasp.org/cheatsheets/Session_Management_Cheat_Sheet.html // https://cheatsheetseries.owasp.org/cheatsheets/Session_Management_Cheat_Sheet.html
event.AuditInfo([]string{ClientIP(c), "session %s", "users", user.UserName, "invalidated %s"}, s.RefID, event.AuditInfo([]string{ClientIP(c), "session %s", authn.Users, user.UserName, "invalidated %s"}, s.RefID,
english.Plural(user.DeleteSessions([]string{s.ID}), "session", "sessions")) english.Plural(user.DeleteSessions([]string{s.ID}), authn.Session, authn.Sessions))
// Clear session cache. // Clear session cache.
s.ClearCache() s.ClearCache()
@@ -175,10 +177,14 @@ func DeactivateUserPasscode(router *gin.RouterGroup) {
return return
} }
// Check password and abort if invalid. // Check password if user authenticates with a local account.
if user.InvalidPassword(frm.Password) { switch user.Provider() {
Abort(c, http.StatusForbidden, i18n.ErrInvalidPassword) case authn.ProviderDefault, authn.ProviderLocal:
return // Check password and abort if invalid.
if user.InvalidPassword(frm.Password) {
Abort(c, http.StatusForbidden, i18n.ErrInvalidPassword)
return
}
} }
// Return the reserved request rate limit tokens after successful authentication. // Return the reserved request rate limit tokens after successful authentication.
@@ -186,12 +192,12 @@ func DeactivateUserPasscode(router *gin.RouterGroup) {
// Delete passcode. // Delete passcode.
if _, err := user.DeactivatePasscode(); err != nil { if _, err := user.DeactivatePasscode(); err != nil {
event.AuditErr([]string{ClientIP(c), "session %s", "users", user.UserName, "failed to deactivate passcode", clean.Error(err)}, s.RefID) event.AuditErr([]string{ClientIP(c), "session %s", authn.Users, user.UserName, authn.ErrPasscodeDeactivationFailed.Error(), clean.Error(err)}, s.RefID)
Abort(c, http.StatusNotFound, i18n.ErrNotFound) Abort(c, http.StatusNotFound, i18n.ErrNotFound)
return return
} }
event.AuditInfo([]string{ClientIP(c), "session %s", "users", user.UserName, "passcode", "deactivated"}, s.RefID) event.AuditInfo([]string{ClientIP(c), "session %s", authn.Users, user.UserName, authn.Passcode, authn.Deactivated}, s.RefID)
// Clear session cache. // Clear session cache.
s.ClearCache() s.ClearCache()
@@ -210,20 +216,20 @@ func checkUserPasscodeAuth(c *gin.Context, action acl.Permission) (*entity.Sessi
// You cannot change any passwords without authentication and settings enabled. // You cannot change any passwords without authentication and settings enabled.
if conf.Public() || conf.DisableSettings() { if conf.Public() || conf.DisableSettings() {
Abort(c, http.StatusForbidden, i18n.ErrPublic) Abort(c, http.StatusForbidden, i18n.ErrPublic)
return nil, nil, nil, errors.New("unsupported") return nil, nil, nil, authn.ErrPasscodeNotSupported
} }
// Check limit for failed auth requests (max. 10 per minute). // Check limit for failed auth requests (max. 10 per minute).
if limiter.Login.Reject(ClientIP(c)) { if limiter.Login.Reject(ClientIP(c)) {
limiter.AbortJSON(c) limiter.AbortJSON(c)
return nil, nil, nil, errors.New("rate limit exceeded") return nil, nil, nil, authn.ErrRateLimitExceeded
} }
// Get session. // Get session.
s := Auth(c, acl.ResourcePasscode, action) s := Auth(c, acl.ResourcePasscode, action)
if s.Abort(c) { if s.Abort(c) {
return s, nil, nil, errors.New("unauthorized") return s, nil, nil, authn.ErrUnauthorized
} }
// Check if the current user has management privileges. // Check if the current user has management privileges.
@@ -235,24 +241,24 @@ func checkUserPasscodeAuth(c *gin.Context, action acl.Permission) (*entity.Sessi
// Regular users can only set up a passcode for their own account. // Regular users can only set up a passcode for their own account.
if user.UserUID != uid { if user.UserUID != uid {
AbortForbidden(c) AbortForbidden(c)
return s, nil, nil, errors.New("unauthorized") return s, nil, nil, authn.ErrUnauthorized
} }
// Check if the auth provider supports passcodes. // Check if the user's authentication provider supports 2FA passcodes.
if !user.Provider().Supports2FA() { if !user.Provider().SupportsPasscodeAuthentication() {
Abort(c, http.StatusForbidden, i18n.ErrUnsupported) Abort(c, http.StatusForbidden, i18n.ErrUnsupported)
return s, nil, nil, errors.New("unsupported") return s, nil, nil, authn.ErrPasscodeNotSupported
} }
frm := &form.Passcode{} frm := &form.Passcode{}
// Validate request parameters. // Validate request form values.
if err := c.BindJSON(frm); err != nil { if err := c.BindJSON(frm); err != nil {
Error(c, http.StatusBadRequest, err, i18n.ErrInvalidPassword) Error(c, http.StatusBadRequest, err, i18n.ErrInvalidPassword)
return s, nil, nil, errors.New("invalid request") return s, nil, nil, authn.ErrInvalidRequest
} else if authn.KeyTOTP.NotEqual(frm.Type) { } else if authn.KeyTOTP.NotEqual(frm.Type) {
Abort(c, http.StatusBadRequest, i18n.ErrUnsupportedType) Abort(c, http.StatusBadRequest, i18n.ErrUnsupportedType)
return s, nil, nil, errors.New("unsupported") return s, nil, nil, authn.ErrInvalidPasscodeType
} }
return s, user, frm, nil return s, user, frm, nil

View File

@@ -69,6 +69,7 @@ func UpdateUserPassword(router *gin.RouterGroup) {
f := form.ChangePassword{} f := form.ChangePassword{}
// Assign and validate request form values.
if err := c.BindJSON(&f); err != nil { if err := c.BindJSON(&f); err != nil {
Error(c, http.StatusBadRequest, err, i18n.ErrInvalidPassword) Error(c, http.StatusBadRequest, err, i18n.ErrInvalidPassword)
return return

View File

@@ -53,7 +53,7 @@ func UpdateUser(router *gin.RouterGroup) {
return return
} }
// Update form with values from request. // Assign and validate request form values.
if err = c.BindJSON(&f); err != nil { if err = c.BindJSON(&f); err != nil {
log.Error(err) log.Error(err)
AbortBadRequest(c) AbortBadRequest(c)

View File

@@ -169,6 +169,7 @@ func ProcessUserUpload(router *gin.RouterGroup) {
var f form.UploadOptions var f form.UploadOptions
// Assign and validate request form values.
if err := c.BindJSON(&f); err != nil { if err := c.BindJSON(&f); err != nil {
AbortBadRequest(c) AbortBadRequest(c)
return return

View File

@@ -45,6 +45,7 @@ func ZipCreate(router *gin.RouterGroup) {
var f form.Selection var f form.Selection
start := time.Now() start := time.Now()
// Assign and validate request form values.
if err := c.BindJSON(&f); err != nil { if err := c.BindJSON(&f); err != nil {
AbortBadRequest(c) AbortBadRequest(c)
return return

View File

@@ -145,7 +145,7 @@ func DeleteClientSessions(client *Client, authMethod authn.MethodType, limit int
q = q.Where("user_uid = ?", client.UserUID) q = q.Where("user_uid = ?", client.UserUID)
} }
if !authMethod.IsDefault() { if !authMethod.IsUndefined() {
q = q.Where("auth_method = ?", authMethod.String()) q = q.Where("auth_method = ?", authMethod.String())
} }

View File

@@ -19,7 +19,7 @@ import (
) )
// Auth checks if the credentials are valid and returns the user and authentication provider. // Auth checks if the credentials are valid and returns the user and authentication provider.
var Auth = func(f form.Login, m *Session, c *gin.Context) (user *User, provider authn.ProviderType, method authn.MethodType, err error) { var Auth = func(f form.Login, s *Session, c *gin.Context) (user *User, provider authn.ProviderType, method authn.MethodType, err error) {
// Get sanitized username from login form. // Get sanitized username from login form.
nameName := f.CleanUsername() nameName := f.CleanUsername()
@@ -27,7 +27,7 @@ var Auth = func(f form.Login, m *Session, c *gin.Context) (user *User, provider
user = FindUserByName(nameName) user = FindUserByName(nameName)
// Try local authentication. // Try local authentication.
provider, method, err = AuthLocal(user, f, m, c) provider, method, err = AuthLocal(user, f, s, c)
if err != nil { if err != nil {
return user, provider, method, err return user, provider, method, err
@@ -69,7 +69,7 @@ func AuthSession(f form.Login, c *gin.Context) (sess *Session, user *User, err e
} }
// AuthLocal authenticates against the local user database with the specified username and password. // AuthLocal authenticates against the local user database with the specified username and password.
func AuthLocal(user *User, f form.Login, m *Session, c *gin.Context) (provider authn.ProviderType, method authn.MethodType, err error) { func AuthLocal(user *User, f form.Login, s *Session, c *gin.Context) (provider authn.ProviderType, method authn.MethodType, err error) {
// Set defaults. // Set defaults.
provider = authn.ProviderNone provider = authn.ProviderNone
method = authn.MethodUndefined method = authn.MethodUndefined
@@ -84,10 +84,10 @@ func AuthLocal(user *User, f form.Login, m *Session, c *gin.Context) (provider a
if user == nil { if user == nil {
message := authn.ErrAccountNotFound.Error() message := authn.ErrAccountNotFound.Error()
if m != nil { if s != nil {
event.AuditWarn([]string{clientIp, "session %s", "login as %s", message}, m.RefID, clean.LogQuote(username)) event.AuditWarn([]string{clientIp, "session %s", "login as %s", message}, s.RefID, clean.LogQuote(username))
event.LoginError(clientIp, "api", username, m.UserAgent, message) event.LoginError(clientIp, "api", username, s.UserAgent, message)
m.Status = http.StatusUnauthorized s.Status = http.StatusUnauthorized
} }
return provider, method, i18n.Error(i18n.ErrInvalidCredentials) return provider, method, i18n.Error(i18n.ErrInvalidCredentials)
@@ -97,20 +97,20 @@ func AuthLocal(user *User, f form.Login, m *Session, c *gin.Context) (provider a
if !user.Provider().IsDefault() && !user.Provider().IsLocal() { if !user.Provider().IsDefault() && !user.Provider().IsLocal() {
message := fmt.Sprintf("%s authentication disabled", authn.ProviderLocal.String()) message := fmt.Sprintf("%s authentication disabled", authn.ProviderLocal.String())
if m != nil { if s != nil {
event.AuditWarn([]string{clientIp, "session %s", "login as %s", message}, m.RefID, clean.LogQuote(username)) event.AuditWarn([]string{clientIp, "session %s", "login as %s", message}, s.RefID, clean.LogQuote(username))
event.LoginError(clientIp, "api", username, m.UserAgent, message) event.LoginError(clientIp, "api", username, s.UserAgent, message)
m.Status = http.StatusUnauthorized s.Status = http.StatusUnauthorized
} }
return provider, method, i18n.Error(i18n.ErrInvalidCredentials) return provider, method, i18n.Error(i18n.ErrInvalidCredentials)
} else if !user.CanLogIn() { } else if !user.CanLogIn() {
message := authn.ErrAccountDisabled.Error() message := authn.ErrAccountDisabled.Error()
if m != nil { if s != nil {
event.AuditWarn([]string{clientIp, "session %s", "login as %s", message}, m.RefID, clean.LogQuote(username)) event.AuditWarn([]string{clientIp, "session %s", "login as %s", message}, s.RefID, clean.LogQuote(username))
event.LoginError(clientIp, "api", username, m.UserAgent, message) event.LoginError(clientIp, "api", username, s.UserAgent, message)
m.Status = http.StatusUnauthorized s.Status = http.StatusUnauthorized
} }
return provider, method, i18n.Error(i18n.ErrInvalidCredentials) return provider, method, i18n.Error(i18n.ErrInvalidCredentials)
@@ -120,9 +120,13 @@ func AuthLocal(user *User, f form.Login, m *Session, c *gin.Context) (provider a
if authSess, authUser, authErr := AuthSession(f, c); authSess != nil && authUser != nil && authErr == nil { if authSess, authUser, authErr := AuthSession(f, c); authSess != nil && authUser != nil && authErr == nil {
if !authUser.IsRegistered() || authUser.UserUID != user.UserUID { if !authUser.IsRegistered() || authUser.UserUID != user.UserUID {
message := authn.ErrInvalidUsername.Error() message := authn.ErrInvalidUsername.Error()
event.AuditErr([]string{clientIp, "session %s", "login as %s with app password", message}, m.RefID, clean.LogQuote(username))
event.LoginError(clientIp, "api", username, m.UserAgent, message) if s != nil {
m.Status = http.StatusUnauthorized event.AuditErr([]string{clientIp, "session %s", "login as %s with app password", message}, s.RefID, clean.LogQuote(username))
event.LoginError(clientIp, "api", username, s.UserAgent, message)
s.Status = http.StatusUnauthorized
}
return provider, method, i18n.Error(i18n.ErrInvalidCredentials) return provider, method, i18n.Error(i18n.ErrInvalidCredentials)
} else if insufficientScope := authSess.InsufficientScope(acl.ResourceSessions, acl.Permissions{acl.ActionCreate}); insufficientScope || !authSess.IsClient() { } else if insufficientScope := authSess.InsufficientScope(acl.ResourceSessions, acl.Permissions{acl.ActionCreate}); insufficientScope || !authSess.IsClient() {
var message string var message string
@@ -131,19 +135,27 @@ func AuthLocal(user *User, f form.Login, m *Session, c *gin.Context) (provider a
} else { } else {
message = authn.ErrUnauthorized.Error() message = authn.ErrUnauthorized.Error()
} }
event.AuditErr([]string{clientIp, "session %s", "login as %s with app password", message}, m.RefID, clean.LogQuote(username))
event.LoginError(clientIp, "api", username, m.UserAgent, message) if s != nil {
m.Status = http.StatusUnauthorized event.AuditErr([]string{clientIp, "session %s", "login as %s with app password", message}, s.RefID, clean.LogQuote(username))
event.LoginError(clientIp, "api", username, s.UserAgent, message)
s.Status = http.StatusUnauthorized
}
return provider, method, i18n.Error(i18n.ErrInvalidCredentials) return provider, method, i18n.Error(i18n.ErrInvalidCredentials)
} else { } else {
provider = authn.ProviderApplication provider = authn.ProviderApplication
method = authn.MethodSession method = authn.MethodSession
m.ClientUID = authSess.ClientUID
m.ClientName = authSess.ClientName if s != nil {
m.SetScope(authSess.Scope()) s.ClientUID = authSess.ClientUID
m.SetMethod(authn.MethodSession) s.ClientName = authSess.ClientName
event.AuditInfo([]string{clientIp, "session %s", "login as %s with app password", authn.Succeeded}, m.RefID, clean.LogQuote(username)) s.SetScope(authSess.Scope())
event.LoginInfo(clientIp, "api", username, m.UserAgent) s.SetMethod(authn.MethodSession)
event.AuditInfo([]string{clientIp, "session %s", "login as %s with app password", authn.Succeeded}, s.RefID, clean.LogQuote(username))
event.LoginInfo(clientIp, "api", username, s.UserAgent)
}
return provider, method, authErr return provider, method, authErr
} }
} }
@@ -152,10 +164,10 @@ func AuthLocal(user *User, f form.Login, m *Session, c *gin.Context) (provider a
if user.InvalidPassword(f.Password) { if user.InvalidPassword(f.Password) {
message := authn.ErrInvalidPassword.Error() message := authn.ErrInvalidPassword.Error()
if m != nil { if s != nil {
event.AuditErr([]string{clientIp, "session %s", "login as %s", message}, m.RefID, clean.LogQuote(username)) event.AuditErr([]string{clientIp, "session %s", "login as %s", message}, s.RefID, clean.LogQuote(username))
event.LoginError(clientIp, "api", username, m.UserAgent, message) event.LoginError(clientIp, "api", username, s.UserAgent, message)
m.Status = http.StatusUnauthorized s.Status = http.StatusUnauthorized
} }
return provider, method, i18n.Error(i18n.ErrInvalidCredentials) return provider, method, i18n.Error(i18n.ErrInvalidCredentials)
@@ -174,9 +186,9 @@ func AuthLocal(user *User, f form.Login, m *Session, c *gin.Context) (provider a
method = authn.MethodDefault method = authn.MethodDefault
} }
if m != nil { if s != nil {
event.AuditInfo([]string{clientIp, "session %s", "login as %s", authn.Succeeded}, m.RefID, clean.LogQuote(username)) event.AuditInfo([]string{clientIp, "session %s", "login as %s", authn.Succeeded}, s.RefID, clean.LogQuote(username))
event.LoginInfo(clientIp, "api", username, m.UserAgent) event.LoginInfo(clientIp, "api", username, s.UserAgent)
} }
return provider, method, nil return provider, method, nil

View File

@@ -110,7 +110,7 @@ func TestDeleteClientSessions(t *testing.T) {
client.ClientUID = clientUID client.ClientUID = clientUID
// Make sure no sessions exist yet and test missing arguments. // Make sure no sessions exist yet and test missing arguments.
assert.Equal(t, 0, DeleteClientSessions(&Client{}, "", -1)) assert.Equal(t, 0, DeleteClientSessions(&Client{}, authn.MethodUndefined, -1))
assert.Equal(t, 0, DeleteClientSessions(client, authn.MethodOAuth2, -1)) assert.Equal(t, 0, DeleteClientSessions(client, authn.MethodOAuth2, -1))
assert.Equal(t, 0, DeleteClientSessions(client, authn.MethodOAuth2, 0)) assert.Equal(t, 0, DeleteClientSessions(client, authn.MethodOAuth2, 0))
assert.Equal(t, 0, DeleteClientSessions(&Client{}, authn.MethodDefault, 0)) assert.Equal(t, 0, DeleteClientSessions(&Client{}, authn.MethodDefault, 0))
@@ -127,10 +127,11 @@ func TestDeleteClientSessions(t *testing.T) {
// Check if the expected number of sessions is deleted until none are left. // Check if the expected number of sessions is deleted until none are left.
assert.Equal(t, 0, DeleteClientSessions(client, authn.MethodOAuth2, -1)) assert.Equal(t, 0, DeleteClientSessions(client, authn.MethodOAuth2, -1))
assert.Equal(t, 0, DeleteClientSessions(client, authn.MethodOIDC, 1)) assert.Equal(t, 0, DeleteClientSessions(client, authn.MethodDefault, 1))
assert.Equal(t, 9, DeleteClientSessions(client, authn.MethodOAuth2, 1)) assert.Equal(t, 9, DeleteClientSessions(client, authn.MethodOAuth2, 1))
assert.Equal(t, 1, DeleteClientSessions(client, authn.MethodOAuth2, 0)) assert.Equal(t, 1, DeleteClientSessions(client, authn.MethodOAuth2, 0))
assert.Equal(t, 0, DeleteClientSessions(client, authn.MethodOAuth2, 0)) assert.Equal(t, 0, DeleteClientSessions(client, authn.MethodOAuth2, 0))
assert.Equal(t, 0, DeleteClientSessions(client, authn.MethodUndefined, 0))
} }
func TestSessionStatusUnauthorized(t *testing.T) { func TestSessionStatusUnauthorized(t *testing.T) {

View File

@@ -543,7 +543,7 @@ func (m *User) SetProvider(t authn.ProviderType) *User {
m.AuthProvider = t.String() m.AuthProvider = t.String()
if !m.Provider().Supports2FA() && m.Method().Is(authn.Method2FA) { if !m.Provider().SupportsPasscodeAuthentication() && m.Method().Is(authn.Method2FA) {
m.AuthMethod = "" m.AuthMethod = ""
} }
@@ -561,7 +561,8 @@ func (m *User) SetMethod(method authn.MethodType) *User {
return m return m
} }
if !m.Provider().Supports2FA() && method.Is(authn.Method2FA) { // It must not be possible to activate 2FA if the authentication provider does not support passcodes.
if !m.Provider().SupportsPasscodeAuthentication() && method.Is(authn.Method2FA) {
return m return m
} }
@@ -936,7 +937,7 @@ func (m *User) VerifyPasscode(code string) (valid bool, passcode *Passcode, err
func (m *User) ActivatePasscode() (passcode *Passcode, err error) { func (m *User) ActivatePasscode() (passcode *Passcode, err error) {
if m == nil { if m == nil {
err = errors.New("user is nil") err = errors.New("user is nil")
} else if !m.Provider().Supports2FA() { } else if !m.Provider().SupportsPasscodeAuthentication() {
err = authn.ErrPasscodeNotSupported err = authn.ErrPasscodeNotSupported
} else if passcode = m.Passcode(authn.KeyTOTP); passcode == nil { } else if passcode = m.Passcode(authn.KeyTOTP); passcode == nil {
// Cannot enable 2FA if user has no passcode. // Cannot enable 2FA if user has no passcode.

View File

@@ -7,11 +7,11 @@ import (
// Login represents a login form. // Login represents a login form.
type Login struct { type Login struct {
Username string `json:"username,omitempty"` Username string `json:"username,omitempty"` // The local Username or LDAP user principal name (UPN).
Password string `json:"password,omitempty"` Password string `json:"password,omitempty"` // The user's Password.
Code string `json:"code,omitempty"` Code string `json:"code,omitempty"` // 2FA Verification Code (Passcodes).
Token string `json:"token,omitempty"` Token string `json:"token,omitempty"` // Share Token.
Email string `json:"email,omitempty"` Email string `json:"email,omitempty"` // Reserved.
} }
// CleanUsername returns the sanitized and normalized username. // CleanUsername returns the sanitized and normalized username.

View File

@@ -11,6 +11,7 @@ import (
type OAuthCreateToken struct { type OAuthCreateToken struct {
GrantType authn.GrantType `form:"grant_type" json:"grant_type,omitempty"` GrantType authn.GrantType `form:"grant_type" json:"grant_type,omitempty"`
ClientID string `form:"client_id" json:"client_id,omitempty"` ClientID string `form:"client_id" json:"client_id,omitempty"`
ClientName string `form:"client_name" json:"client_name,omitempty"`
ClientSecret string `form:"client_secret" json:" client_secret,omitempty"` ClientSecret string `form:"client_secret" json:" client_secret,omitempty"`
Username string `form:"username" json:"username,omitempty"` Username string `form:"username" json:"username,omitempty"`
Password string `form:"password" json:"password,omitempty"` Password string `form:"password" json:"password,omitempty"`
@@ -19,9 +20,8 @@ type OAuthCreateToken struct {
CodeVerifier string `form:"code_verifier" json:"code_verifier,omitempty"` CodeVerifier string `form:"code_verifier" json:"code_verifier,omitempty"`
RedirectURI string `form:"redirect_uri" json:"redirect_uri,omitempty"` RedirectURI string `form:"redirect_uri" json:"redirect_uri,omitempty"`
Assertion string `form:"assertion" json:"assertion,omitempty"` Assertion string `form:"assertion" json:"assertion,omitempty"`
Name string `form:"name" json:"name,omitempty"`
Scope string `form:"scope" json:"scope,omitempty"` Scope string `form:"scope" json:"scope,omitempty"`
Expires int `form:"expires" json:"expires,omitempty"` Lifetime int64 `form:"lifetime" json:"lifetime,omitempty"`
} }
// Validate verifies the request parameters depending on the grant type. // Validate verifies the request parameters depending on the grant type.
@@ -51,7 +51,7 @@ func (f OAuthCreateToken) Validate() error {
return authn.ErrPasswordRequired return authn.ErrPasswordRequired
} else if len(f.Password) > txt.ClipPassword { } else if len(f.Password) > txt.ClipPassword {
return authn.ErrInvalidCredentials return authn.ErrInvalidCredentials
} else if f.Name == "" { } else if f.ClientName == "" {
return authn.ErrNameRequired return authn.ErrNameRequired
} else if f.Scope == "" { } else if f.Scope == "" {
return authn.ErrScopeRequired return authn.ErrScopeRequired

View File

@@ -68,22 +68,22 @@ func TestOAuthCreateToken_Validate(t *testing.T) {
}) })
t.Run("Password", func(t *testing.T) { t.Run("Password", func(t *testing.T) {
m := OAuthCreateToken{ m := OAuthCreateToken{
GrantType: authn.GrantPassword, GrantType: authn.GrantPassword,
Username: "admin", Username: "admin",
Password: "cs5gfen1bgxz7s9i", Password: "cs5gfen1bgxz7s9i",
Name: "test", ClientName: "test",
Scope: "*", Scope: "*",
} }
assert.NoError(t, m.Validate()) assert.NoError(t, m.Validate())
}) })
t.Run("PasswordRequired", func(t *testing.T) { t.Run("PasswordRequired", func(t *testing.T) {
m := OAuthCreateToken{ m := OAuthCreateToken{
GrantType: authn.GrantPassword, GrantType: authn.GrantPassword,
Username: "admin", Username: "admin",
Password: "", Password: "",
Name: "test", ClientName: "test",
Scope: "*", Scope: "*",
} }
assert.Error(t, m.Validate()) assert.Error(t, m.Validate())

View File

@@ -1,45 +1,71 @@
package form package form
import ( import (
"fmt" "github.com/photoprism/photoprism/pkg/authn"
"github.com/photoprism/photoprism/pkg/rnd" "github.com/photoprism/photoprism/pkg/rnd"
) )
const ( const (
ClientAccessToken = "access_token" RefID = "ref_id"
SessionID = "session_id"
AccessToken = "access_token"
) )
// OAuthRevokeToken represents a token revokation form. // OAuthRevokeToken represents a token revokation form.
type OAuthRevokeToken struct { type OAuthRevokeToken struct {
AuthToken string `form:"token" binding:"required" json:"token,omitempty"` Token string `form:"token" binding:"required" json:"token,omitempty"`
TypeHint string `form:"token_type_hint" json:" token_type_hint,omitempty"` TokenTypeHint string `form:"token_type_hint" json:" token_type_hint,omitempty"`
} }
// Empty checks if all form values are unset. // Empty checks if all form values are unset.
func (f OAuthRevokeToken) Empty() bool { func (f *OAuthRevokeToken) Empty() bool {
switch { switch {
case f.AuthToken != "": case f.Token != "":
return false return false
case f.TypeHint != "": case f.TokenTypeHint != "":
return false return false
} }
return true return true
} }
// Validate checks the token and token type. // Validate checks the revoke token form values and returns an error if invalid.
func (f OAuthRevokeToken) Validate() error { func (f *OAuthRevokeToken) Validate() error {
// Check auth token. // Require a token.
if f.AuthToken == "" { if f.Token == "" {
return fmt.Errorf("missing token") return authn.ErrTokenRequired
} else if !rnd.IsAlnum(f.AuthToken) {
return fmt.Errorf("invalid token")
} }
// Check token type. // Validate token type.
if f.TypeHint != "" && f.TypeHint != ClientAccessToken { isRefID := rnd.IsRefID(f.Token)
return fmt.Errorf("unsupported token type") isSessionID := rnd.IsSessionID(f.Token)
isAuthAny := rnd.IsAuthAny(f.Token)
switch f.TokenTypeHint {
case "":
if !isRefID && !isSessionID && !isAuthAny {
return authn.ErrInvalidToken
} else if isRefID {
f.TokenTypeHint = RefID
} else if isSessionID {
f.TokenTypeHint = SessionID
} else {
f.TokenTypeHint = AccessToken
}
case RefID:
if !isRefID {
return authn.ErrInvalidToken
}
case SessionID:
if !isSessionID {
return authn.ErrInvalidToken
}
case AccessToken:
if !isAuthAny {
return authn.ErrInvalidToken
}
default:
return authn.ErrInvalidTokenType
} }
return nil return nil

View File

@@ -4,27 +4,29 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/photoprism/photoprism/pkg/rnd"
) )
func TestOAuthRevokeToken_Empty(t *testing.T) { func TestOAuthRevokeToken_Empty(t *testing.T) {
t.Run("AuthTokenAndTypeHintEmpty", func(t *testing.T) { t.Run("AuthTokenAndTypeHintEmpty", func(t *testing.T) {
m := OAuthRevokeToken{ m := OAuthRevokeToken{
AuthToken: "", Token: "",
TypeHint: "", TokenTypeHint: "",
} }
assert.True(t, m.Empty()) assert.True(t, m.Empty())
}) })
t.Run("AuthTokenNotEmpty", func(t *testing.T) { t.Run("AuthTokenNotEmpty", func(t *testing.T) {
m := OAuthRevokeToken{ m := OAuthRevokeToken{
AuthToken: "abc", Token: "abc",
TypeHint: "", TokenTypeHint: "",
} }
assert.False(t, m.Empty()) assert.False(t, m.Empty())
}) })
t.Run("TypeHintNotEmpty", func(t *testing.T) { t.Run("TypeHintNotEmpty", func(t *testing.T) {
m := OAuthRevokeToken{ m := OAuthRevokeToken{
AuthToken: "", Token: "",
TypeHint: "test", TokenTypeHint: "test",
} }
assert.False(t, m.Empty()) assert.False(t, m.Empty())
}) })
@@ -33,30 +35,47 @@ func TestOAuthRevokeToken_Empty(t *testing.T) {
func TestOAuthRevokeToken_Validate(t *testing.T) { func TestOAuthRevokeToken_Validate(t *testing.T) {
t.Run("AuthTokenEmpty", func(t *testing.T) { t.Run("AuthTokenEmpty", func(t *testing.T) {
m := OAuthRevokeToken{ m := OAuthRevokeToken{
AuthToken: "", Token: "",
TypeHint: "test", TokenTypeHint: "test",
} }
assert.Error(t, m.Validate()) assert.Error(t, m.Validate())
}) })
t.Run("AuthTokenInvalid", func(t *testing.T) { t.Run("AuthTokenInvalid", func(t *testing.T) {
m := OAuthRevokeToken{ m := OAuthRevokeToken{
AuthToken: "abc 234", Token: "abc 234",
TypeHint: "test", TokenTypeHint: "test",
} }
assert.Error(t, m.Validate()) assert.Error(t, m.Validate())
}) })
t.Run("UnsupportedToken", func(t *testing.T) { t.Run("UnsupportedToken", func(t *testing.T) {
m := OAuthRevokeToken{ m := OAuthRevokeToken{
AuthToken: "abc234", Token: "abc234",
TypeHint: "test", TokenTypeHint: "test",
} }
assert.Error(t, m.Validate()) assert.Error(t, m.Validate())
}) })
t.Run("Valid", func(t *testing.T) { t.Run("AccessToken", func(t *testing.T) {
m := OAuthRevokeToken{ m := OAuthRevokeToken{
AuthToken: "abc234", Token: rnd.AuthToken(),
TypeHint: "access_token", TokenTypeHint: "access_token",
} }
assert.NoError(t, m.Validate()) assert.NoError(t, m.Validate())
assert.Equal(t, AccessToken, m.TokenTypeHint)
})
t.Run("SessionID", func(t *testing.T) {
m := OAuthRevokeToken{
Token: rnd.SessionID(rnd.AuthToken()),
TokenTypeHint: "session_id",
}
assert.NoError(t, m.Validate())
assert.Equal(t, SessionID, m.TokenTypeHint)
})
t.Run("NoTokenTypeHint", func(t *testing.T) {
m := OAuthRevokeToken{
Token: rnd.AuthToken(),
TokenTypeHint: "",
}
assert.NoError(t, m.Validate())
assert.Equal(t, AccessToken, m.TokenTypeHint)
}) })
} }

View File

@@ -2,7 +2,15 @@ package authn
// Generic status messages for authentication and authorization: // Generic status messages for authentication and authorization:
const ( const (
Denied = "denied" Denied = "denied"
Granted = "granted" Granted = "granted"
Succeeded = "succeeded" Created = "created"
Succeeded = "succeeded"
Verified = "verified"
Activated = "activated"
Deactivated = "deactivated"
Passcode = "passcode"
Session = "session"
Sessions = "sessions"
Users = "users"
) )

View File

@@ -13,13 +13,18 @@ var (
ErrAccountAlreadyExists = errors.New("account already exists") ErrAccountAlreadyExists = errors.New("account already exists")
ErrAccountNotFound = errors.New("account not found") ErrAccountNotFound = errors.New("account not found")
ErrAccountDisabled = errors.New("account disabled") ErrAccountDisabled = errors.New("account disabled")
ErrInvalidRequest = errors.New("invalid request")
ErrInvalidCredentials = errors.New("invalid credentials") ErrInvalidCredentials = errors.New("invalid credentials")
ErrInvalidShareToken = errors.New("invalid share token") ErrInvalidShareToken = errors.New("invalid share token")
ErrTokenRequired = errors.New("token required")
ErrInvalidToken = errors.New("invalid token")
ErrInvalidTokenType = errors.New("invalid token type")
ErrInsufficientScope = errors.New("insufficient scope") ErrInsufficientScope = errors.New("insufficient scope")
ErrNameRequired = errors.New("name required") ErrNameRequired = errors.New("name required")
ErrScopeRequired = errors.New("scope required") ErrScopeRequired = errors.New("scope required")
ErrDisabledInPublicMode = errors.New("disabled in public mode") ErrDisabledInPublicMode = errors.New("disabled in public mode")
ErrAuthenticationDisabled = errors.New("authentication disabled") ErrAuthenticationDisabled = errors.New("authentication disabled")
ErrRateLimitExceeded = errors.New("rate limit exceeded")
) )
// OAuth2-related error messages: // OAuth2-related error messages:
@@ -31,7 +36,7 @@ var (
ErrClientSecretRequired = errors.New("client secret required") ErrClientSecretRequired = errors.New("client secret required")
) )
// Username-related error messages: // User-related error messages:
var ( var (
ErrUsernameRequired = errors.New("username required") ErrUsernameRequired = errors.New("username required")
ErrInvalidUsername = errors.New("invalid username") ErrInvalidUsername = errors.New("invalid username")
@@ -39,15 +44,21 @@ var (
// Passcode-related error messages: // Passcode-related error messages:
var ( var (
ErrPasscodeRequired = errors.New("passcode required") ErrPasscodeRequired = errors.New("passcode required")
ErrPasscodeNotSetUp = errors.New("passcode required, but not configured") ErrPasscodeNotSetUp = errors.New("passcode required, but not configured")
ErrPasscodeNotVerified = errors.New("passcode not verified") ErrPasscodeNotVerified = errors.New("passcode not verified")
ErrPasscodeAlreadyActivated = errors.New("passcode already activated") ErrPasscodeAlreadyActivated = errors.New("passcode already activated")
ErrPasscodeNotSupported = errors.New("passcode not supported") ErrPasscodeGenerateFailed = errors.New("failed to generate passcode")
ErrInvalidPasscode = errors.New("invalid passcode") ErrPasscodeCreateFailed = errors.New("failed to create passcode")
ErrInvalidPasscodeFormat = errors.New("invalid passcode format") ErrPasscodeSaveFailed = errors.New("failed to save passcode")
ErrInvalidPasscodeKey = errors.New("invalid passcode key") ErrPasscodeVerificationFailed = errors.New("failed to verify passcode")
ErrInvalidPasscodeType = errors.New("invalid passcode type") ErrPasscodeActivationFailed = errors.New("failed to activate passcode")
ErrPasscodeDeactivationFailed = errors.New("failed to deactivate passcode")
ErrPasscodeNotSupported = errors.New("passcode not supported")
ErrInvalidPasscode = errors.New("invalid passcode")
ErrInvalidPasscodeFormat = errors.New("invalid passcode format")
ErrInvalidPasscodeKey = errors.New("invalid passcode key")
ErrInvalidPasscodeType = errors.New("invalid passcode type")
) )
// Password-related error messages: // Password-related error messages:

View File

@@ -13,6 +13,7 @@ const (
GrantUndefined GrantType = "" GrantUndefined GrantType = ""
GrantCLI GrantType = "cli" GrantCLI GrantType = "cli"
GrantImplicit GrantType = "implicit" GrantImplicit GrantType = "implicit"
GrantSession GrantType = "session"
GrantPassword GrantType = "password" GrantPassword GrantType = "password"
GrantClientCredentials GrantType = "client_credentials" GrantClientCredentials GrantType = "client_credentials"
GrantShareToken GrantType = "share_token" GrantShareToken GrantType = "share_token"
@@ -33,7 +34,9 @@ func Grant(s string) GrantType {
return GrantCLI return GrantCLI
case "implicit": case "implicit":
return GrantImplicit return GrantImplicit
case "password", "passwd", "pass", "user", "username": case "session":
return GrantSession
case "password", "passwd", "pass":
return GrantPassword return GrantPassword
case "client_credentials", "client": case "client_credentials", "client":
return GrantClientCredentials return GrantClientCredentials
@@ -61,6 +64,8 @@ func (t GrantType) Pretty() string {
return "CLI" return "CLI"
case GrantImplicit: case GrantImplicit:
return "Implicit" return "Implicit"
case GrantSession:
return "Session"
case GrantPassword: case GrantPassword:
return "Password" return "Password"
case GrantClientCredentials: case GrantClientCredentials:

View File

@@ -9,6 +9,7 @@ import (
func TestGrantType_String(t *testing.T) { func TestGrantType_String(t *testing.T) {
assert.Equal(t, "", GrantUndefined.String()) assert.Equal(t, "", GrantUndefined.String())
assert.Equal(t, "client_credentials", GrantClientCredentials.String()) assert.Equal(t, "client_credentials", GrantClientCredentials.String())
assert.Equal(t, "session", GrantSession.String())
assert.Equal(t, "password", GrantPassword.String()) assert.Equal(t, "password", GrantPassword.String())
assert.Equal(t, "refresh_token", GrantRefreshToken.String()) assert.Equal(t, "refresh_token", GrantRefreshToken.String())
assert.Equal(t, "authorization_code", GrantAuthorizationCode.String()) assert.Equal(t, "authorization_code", GrantAuthorizationCode.String())
@@ -20,6 +21,7 @@ func TestGrantType_String(t *testing.T) {
func TestGrantType_Is(t *testing.T) { func TestGrantType_Is(t *testing.T) {
assert.Equal(t, true, GrantUndefined.Is(GrantUndefined)) assert.Equal(t, true, GrantUndefined.Is(GrantUndefined))
assert.Equal(t, true, GrantClientCredentials.Is(GrantClientCredentials)) assert.Equal(t, true, GrantClientCredentials.Is(GrantClientCredentials))
assert.Equal(t, true, GrantSession.Is(GrantSession))
assert.Equal(t, true, GrantPassword.Is(GrantPassword)) assert.Equal(t, true, GrantPassword.Is(GrantPassword))
assert.Equal(t, false, GrantClientCredentials.Is(GrantPassword)) assert.Equal(t, false, GrantClientCredentials.Is(GrantPassword))
assert.Equal(t, false, GrantClientCredentials.Is(GrantRefreshToken)) assert.Equal(t, false, GrantClientCredentials.Is(GrantRefreshToken))
@@ -46,6 +48,7 @@ func TestGrantType_IsNot(t *testing.T) {
func TestGrantType_IsUndefined(t *testing.T) { func TestGrantType_IsUndefined(t *testing.T) {
assert.Equal(t, true, GrantUndefined.IsUndefined()) assert.Equal(t, true, GrantUndefined.IsUndefined())
assert.Equal(t, false, GrantClientCredentials.IsUndefined()) assert.Equal(t, false, GrantClientCredentials.IsUndefined())
assert.Equal(t, false, GrantSession.IsUndefined())
assert.Equal(t, false, GrantPassword.IsUndefined()) assert.Equal(t, false, GrantPassword.IsUndefined())
} }
@@ -53,6 +56,7 @@ func TestGrantType_Pretty(t *testing.T) {
assert.Equal(t, "", GrantUndefined.Pretty()) assert.Equal(t, "", GrantUndefined.Pretty())
assert.Equal(t, "CLI", GrantCLI.Pretty()) assert.Equal(t, "CLI", GrantCLI.Pretty())
assert.Equal(t, "Client Credentials", GrantClientCredentials.Pretty()) assert.Equal(t, "Client Credentials", GrantClientCredentials.Pretty())
assert.Equal(t, "Session", GrantSession.Pretty())
assert.Equal(t, "Password", GrantPassword.Pretty()) assert.Equal(t, "Password", GrantPassword.Pretty())
assert.Equal(t, "Refresh Token", GrantRefreshToken.Pretty()) assert.Equal(t, "Refresh Token", GrantRefreshToken.Pretty())
assert.Equal(t, "Authorization Code", GrantAuthorizationCode.Pretty()) assert.Equal(t, "Authorization Code", GrantAuthorizationCode.Pretty())
@@ -66,6 +70,7 @@ func TestGrantType_Equal(t *testing.T) {
assert.True(t, GrantClientCredentials.Equal("client_credentials")) assert.True(t, GrantClientCredentials.Equal("client_credentials"))
assert.True(t, GrantClientCredentials.Equal("client")) assert.True(t, GrantClientCredentials.Equal("client"))
assert.True(t, GrantUndefined.Equal("")) assert.True(t, GrantUndefined.Equal(""))
assert.True(t, GrantSession.Equal("session"))
assert.True(t, GrantPassword.Equal("Password")) assert.True(t, GrantPassword.Equal("Password"))
assert.True(t, GrantPassword.Equal("password")) assert.True(t, GrantPassword.Equal("password"))
assert.True(t, GrantPassword.Equal("pass")) assert.True(t, GrantPassword.Equal("pass"))
@@ -89,6 +94,7 @@ func TestGrant(t *testing.T) {
assert.Equal(t, GrantUndefined, Grant("")) assert.Equal(t, GrantUndefined, Grant(""))
assert.Equal(t, GrantCLI, Grant("cli")) assert.Equal(t, GrantCLI, Grant("cli"))
assert.Equal(t, GrantImplicit, Grant("implicit")) assert.Equal(t, GrantImplicit, Grant("implicit"))
assert.Equal(t, GrantSession, Grant("session"))
assert.Equal(t, GrantPassword, Grant("pass")) assert.Equal(t, GrantPassword, Grant("pass"))
assert.Equal(t, GrantPassword, Grant("password")) assert.Equal(t, GrantPassword, Grant("password"))
assert.Equal(t, GrantClientCredentials, Grant("client credentials")) assert.Equal(t, GrantClientCredentials, Grant("client credentials"))

View File

@@ -16,7 +16,6 @@ const (
MethodDefault MethodType = "default" MethodDefault MethodType = "default"
MethodSession MethodType = "session" MethodSession MethodType = "session"
MethodOAuth2 MethodType = "oauth2" MethodOAuth2 MethodType = "oauth2"
MethodOIDC MethodType = "oidc"
Method2FA MethodType = "2fa" Method2FA MethodType = "2fa"
) )
@@ -30,8 +29,6 @@ func Method(s string) MethodType {
return MethodDefault return MethodDefault
case "oauth2", "oauth": case "oauth2", "oauth":
return MethodOAuth2 return MethodOAuth2
case "sso":
return MethodOIDC
case "2fa", "mfa", "otp", "totp": case "2fa", "mfa", "otp", "totp":
return Method2FA return Method2FA
case "access_token": case "access_token":
@@ -46,8 +43,6 @@ func (t MethodType) Pretty() string {
switch t { switch t {
case MethodOAuth2: case MethodOAuth2:
return "OAuth2" return "OAuth2"
case MethodOIDC:
return "OIDC"
case Method2FA: case Method2FA:
return "2FA" return "2FA"
default: default:
@@ -62,8 +57,6 @@ func (t MethodType) String() string {
return string(MethodDefault) return string(MethodDefault)
case "oauth": case "oauth":
return string(MethodOAuth2) return string(MethodOAuth2)
case "openid":
return string(MethodOIDC)
case "2fa", "otp", "totp": case "2fa", "otp", "totp":
return string(Method2FA) return string(Method2FA)
default: default:

View File

@@ -9,17 +9,13 @@ import (
func TestMethodType_String(t *testing.T) { func TestMethodType_String(t *testing.T) {
assert.Equal(t, "default", MethodDefault.String()) assert.Equal(t, "default", MethodDefault.String())
assert.Equal(t, "oauth2", MethodOAuth2.String()) assert.Equal(t, "oauth2", MethodOAuth2.String())
assert.Equal(t, "oidc", MethodOIDC.String())
assert.Equal(t, "2fa", Method2FA.String()) assert.Equal(t, "2fa", Method2FA.String())
assert.Equal(t, "default", MethodUndefined.String()) assert.Equal(t, "default", MethodUndefined.String())
} }
func TestMethodType_Is(t *testing.T) { func TestMethodType_Is(t *testing.T) {
assert.Equal(t, true, MethodDefault.Is(MethodDefault)) assert.Equal(t, true, MethodDefault.Is(MethodDefault))
assert.Equal(t, false, MethodOIDC.Is(MethodOAuth2))
assert.Equal(t, false, Method2FA.Is(MethodOIDC))
assert.Equal(t, true, MethodOAuth2.Is(MethodOAuth2)) assert.Equal(t, true, MethodOAuth2.Is(MethodOAuth2))
assert.Equal(t, true, MethodOIDC.Is(MethodOIDC))
assert.Equal(t, true, Method2FA.Is(Method2FA)) assert.Equal(t, true, Method2FA.Is(Method2FA))
assert.Equal(t, true, MethodUndefined.Is(MethodUndefined)) assert.Equal(t, true, MethodUndefined.Is(MethodUndefined))
} }
@@ -28,10 +24,7 @@ func TestMethodType_IsNot(t *testing.T) {
assert.Equal(t, true, MethodDefault.IsNot(MethodUndefined)) assert.Equal(t, true, MethodDefault.IsNot(MethodUndefined))
assert.Equal(t, false, MethodDefault.IsNot(MethodDefault)) assert.Equal(t, false, MethodDefault.IsNot(MethodDefault))
assert.Equal(t, false, MethodOAuth2.IsNot(MethodOAuth2)) assert.Equal(t, false, MethodOAuth2.IsNot(MethodOAuth2))
assert.Equal(t, false, MethodOIDC.IsNot(MethodOIDC))
assert.Equal(t, false, Method2FA.IsNot(Method2FA)) assert.Equal(t, false, Method2FA.IsNot(Method2FA))
assert.Equal(t, true, MethodOAuth2.IsNot(MethodOIDC))
assert.Equal(t, true, MethodOIDC.IsNot(MethodOAuth2))
assert.Equal(t, true, Method2FA.IsNot(MethodOAuth2)) assert.Equal(t, true, Method2FA.IsNot(MethodOAuth2))
assert.Equal(t, true, MethodUndefined.IsNot(MethodDefault)) assert.Equal(t, true, MethodUndefined.IsNot(MethodDefault))
} }
@@ -44,7 +37,6 @@ func TestMethodType_IsUndefined(t *testing.T) {
func TestMethodType_IsDefault(t *testing.T) { func TestMethodType_IsDefault(t *testing.T) {
assert.Equal(t, true, MethodDefault.IsDefault()) assert.Equal(t, true, MethodDefault.IsDefault())
assert.Equal(t, false, MethodOAuth2.IsDefault()) assert.Equal(t, false, MethodOAuth2.IsDefault())
assert.Equal(t, false, MethodOIDC.IsDefault())
assert.Equal(t, false, Method2FA.IsDefault()) assert.Equal(t, false, Method2FA.IsDefault())
assert.Equal(t, true, MethodUndefined.IsDefault()) assert.Equal(t, true, MethodUndefined.IsDefault())
} }
@@ -52,7 +44,6 @@ func TestMethodType_IsDefault(t *testing.T) {
func TestMethodType_Pretty(t *testing.T) { func TestMethodType_Pretty(t *testing.T) {
assert.Equal(t, "Default", MethodDefault.Pretty()) assert.Equal(t, "Default", MethodDefault.Pretty())
assert.Equal(t, "OAuth2", MethodOAuth2.Pretty()) assert.Equal(t, "OAuth2", MethodOAuth2.Pretty())
assert.Equal(t, "OIDC", MethodOIDC.Pretty())
assert.Equal(t, "2FA", Method2FA.Pretty()) assert.Equal(t, "2FA", Method2FA.Pretty())
assert.Equal(t, "Default", MethodUndefined.Pretty()) assert.Equal(t, "Default", MethodUndefined.Pretty())
} }
@@ -73,8 +64,6 @@ func TestMethod(t *testing.T) {
assert.Equal(t, MethodDefault, Method("access_token")) assert.Equal(t, MethodDefault, Method("access_token"))
assert.Equal(t, MethodDefault, Method("false")) assert.Equal(t, MethodDefault, Method("false"))
assert.Equal(t, MethodOAuth2, Method("oauth2")) assert.Equal(t, MethodOAuth2, Method("oauth2"))
assert.Equal(t, MethodOIDC, Method("oidc"))
assert.Equal(t, MethodOIDC, Method("sso"))
assert.Equal(t, Method2FA, Method("2fa")) assert.Equal(t, Method2FA, Method("2fa"))
assert.Equal(t, Method2FA, Method("totp")) assert.Equal(t, Method2FA, Method("totp"))
assert.Equal(t, Method2FA, Method("2FA")) assert.Equal(t, Method2FA, Method("2FA"))

View File

@@ -17,6 +17,7 @@ const (
ProviderApplication ProviderType = "application" ProviderApplication ProviderType = "application"
ProviderAccessToken ProviderType = "access_token" ProviderAccessToken ProviderType = "access_token"
ProviderLocal ProviderType = "local" ProviderLocal ProviderType = "local"
ProviderOIDC ProviderType = "oidc"
ProviderLDAP ProviderType = "ldap" ProviderLDAP ProviderType = "ldap"
ProviderLink ProviderType = "link" ProviderLink ProviderType = "link"
ProviderNone ProviderType = "none" ProviderNone ProviderType = "none"
@@ -24,6 +25,7 @@ const (
// RemoteProviders contains remote auth providers. // RemoteProviders contains remote auth providers.
var RemoteProviders = list.List{ var RemoteProviders = list.List{
string(ProviderOIDC),
string(ProviderLDAP), string(ProviderLDAP),
} }
@@ -32,18 +34,25 @@ var LocalProviders = list.List{
string(ProviderLocal), string(ProviderLocal),
} }
// Method2FAProviders contains auth providers that support Method2FA. // ClientProviders contains all client authentication providers.
var Method2FAProviders = list.List{ var ClientProviders = list.List{
string(ProviderClient),
string(ProviderApplication),
string(ProviderAccessToken),
}
// PasswordProviders contains authentication providers that allow a password to be checked for authentication.
var PasswordProviders = list.List{
string(ProviderDefault), string(ProviderDefault),
string(ProviderLocal), string(ProviderLocal),
string(ProviderLDAP), string(ProviderLDAP),
} }
// ClientProviders contains all client auth providers. // PasscodeProviders contains authentication providers that support 2-Factor Authentication (2FA) with a TOTP passcode.
var ClientProviders = list.List{ var PasscodeProviders = list.List{
string(ProviderClient), string(ProviderDefault),
string(ProviderApplication), string(ProviderLocal),
string(ProviderAccessToken), string(ProviderLDAP),
} }
// Provider casts a string to a normalized provider type. // Provider casts a string to a normalized provider type.
@@ -58,6 +67,8 @@ func Provider(s string) ProviderType {
return ProviderLocal return ProviderLocal
case "app", "application": case "app", "application":
return ProviderApplication return ProviderApplication
case "oidc", "openid":
return ProviderOIDC
case "ldap", "ad", "ldap/ad", "ldap\\ad": case "ldap", "ad", "ldap/ad", "ldap\\ad":
return ProviderLDAP return ProviderLDAP
case "client", "client_credentials", "oauth2": case "client", "client_credentials", "oauth2":
@@ -70,6 +81,8 @@ func Provider(s string) ProviderType {
// Pretty returns the provider identifier in an easy-to-read format. // Pretty returns the provider identifier in an easy-to-read format.
func (t ProviderType) Pretty() string { func (t ProviderType) Pretty() string {
switch t { switch t {
case ProviderOIDC:
return "OIDC"
case ProviderLDAP: case ProviderLDAP:
return "LDAP/AD" return "LDAP/AD"
case ProviderClient: case ProviderClient:
@@ -132,11 +145,6 @@ func (t ProviderType) IsLocal() bool {
return list.Contains(LocalProviders, string(t)) return list.Contains(LocalProviders, string(t))
} }
// Supports2FA checks if the provider supports two-factor authentication with a passcode.
func (t ProviderType) Supports2FA() bool {
return list.Contains(Method2FAProviders, string(t))
}
// IsClient checks if the authentication is provided for a client. // IsClient checks if the authentication is provided for a client.
func (t ProviderType) IsClient() bool { func (t ProviderType) IsClient() bool {
return list.Contains(ClientProviders, string(t)) return list.Contains(ClientProviders, string(t))
@@ -151,3 +159,13 @@ func (t ProviderType) IsApplication() bool {
func (t ProviderType) IsDefault() bool { func (t ProviderType) IsDefault() bool {
return t.String() == ProviderDefault.String() return t.String() == ProviderDefault.String()
} }
// SupportsPasswordAuthentication checks if the provider allows a password to be checked for authentication.
func (t ProviderType) SupportsPasswordAuthentication() bool {
return list.Contains(PasswordProviders, string(t))
}
// SupportsPasscodeAuthentication checks if the provider supports two-factor authentication with a passcode.
func (t ProviderType) SupportsPasscodeAuthentication() bool {
return list.Contains(PasscodeProviders, string(t))
}

View File

@@ -11,6 +11,7 @@ func TestProviderType_String(t *testing.T) {
assert.Equal(t, "default", ProviderDefault.String()) assert.Equal(t, "default", ProviderDefault.String())
assert.Equal(t, "none", ProviderNone.String()) assert.Equal(t, "none", ProviderNone.String())
assert.Equal(t, "local", ProviderLocal.String()) assert.Equal(t, "local", ProviderLocal.String())
assert.Equal(t, "oidc", ProviderOIDC.String())
assert.Equal(t, "ldap", ProviderLDAP.String()) assert.Equal(t, "ldap", ProviderLDAP.String())
assert.Equal(t, "link", ProviderLink.String()) assert.Equal(t, "link", ProviderLink.String())
assert.Equal(t, "access_token", ProviderAccessToken.String()) assert.Equal(t, "access_token", ProviderAccessToken.String())
@@ -19,6 +20,8 @@ func TestProviderType_String(t *testing.T) {
func TestProviderType_Is(t *testing.T) { func TestProviderType_Is(t *testing.T) {
assert.False(t, ProviderLocal.Is(ProviderLDAP)) assert.False(t, ProviderLocal.Is(ProviderLDAP))
assert.True(t, ProviderOIDC.Is(ProviderOIDC))
assert.False(t, ProviderOIDC.Is(ProviderLDAP))
assert.True(t, ProviderLDAP.Is(ProviderLDAP)) assert.True(t, ProviderLDAP.Is(ProviderLDAP))
assert.False(t, ProviderClient.Is(ProviderLDAP)) assert.False(t, ProviderClient.Is(ProviderLDAP))
assert.False(t, ProviderApplication.Is(ProviderLDAP)) assert.False(t, ProviderApplication.Is(ProviderLDAP))
@@ -30,6 +33,8 @@ func TestProviderType_Is(t *testing.T) {
func TestProviderType_IsNot(t *testing.T) { func TestProviderType_IsNot(t *testing.T) {
assert.False(t, ProviderLocal.IsNot(ProviderLocal)) assert.False(t, ProviderLocal.IsNot(ProviderLocal))
assert.False(t, ProviderOIDC.IsNot(ProviderOIDC))
assert.True(t, ProviderOIDC.IsNot(ProviderLDAP))
assert.True(t, ProviderLDAP.IsNot(ProviderLocal)) assert.True(t, ProviderLDAP.IsNot(ProviderLocal))
assert.False(t, ProviderClient.IsNot(ProviderClient)) assert.False(t, ProviderClient.IsNot(ProviderClient))
assert.False(t, ProviderApplication.IsNot(ProviderApplication)) assert.False(t, ProviderApplication.IsNot(ProviderApplication))
@@ -41,11 +46,14 @@ func TestProviderType_IsNot(t *testing.T) {
func TestProviderType_IsUndefined(t *testing.T) { func TestProviderType_IsUndefined(t *testing.T) {
assert.True(t, ProviderUndefined.IsUndefined()) assert.True(t, ProviderUndefined.IsUndefined())
assert.True(t, ProviderUndefined.IsDefault())
assert.False(t, ProviderLocal.IsUndefined()) assert.False(t, ProviderLocal.IsUndefined())
assert.False(t, ProviderOIDC.IsUndefined())
} }
func TestProviderType_IsRemote(t *testing.T) { func TestProviderType_IsRemote(t *testing.T) {
assert.False(t, ProviderLocal.IsRemote()) assert.False(t, ProviderLocal.IsRemote())
assert.True(t, ProviderOIDC.IsRemote())
assert.True(t, ProviderLDAP.IsRemote()) assert.True(t, ProviderLDAP.IsRemote())
assert.False(t, ProviderClient.IsRemote()) assert.False(t, ProviderClient.IsRemote())
assert.False(t, ProviderApplication.IsRemote()) assert.False(t, ProviderApplication.IsRemote())
@@ -57,6 +65,7 @@ func TestProviderType_IsRemote(t *testing.T) {
func TestProviderType_IsLocal(t *testing.T) { func TestProviderType_IsLocal(t *testing.T) {
assert.True(t, ProviderLocal.IsLocal()) assert.True(t, ProviderLocal.IsLocal())
assert.False(t, ProviderOIDC.IsLocal())
assert.False(t, ProviderLDAP.IsLocal()) assert.False(t, ProviderLDAP.IsLocal())
assert.False(t, ProviderClient.IsLocal()) assert.False(t, ProviderClient.IsLocal())
assert.False(t, ProviderApplication.IsLocal()) assert.False(t, ProviderApplication.IsLocal())
@@ -67,18 +76,20 @@ func TestProviderType_IsLocal(t *testing.T) {
} }
func TestProviderType_SupportsPasscode(t *testing.T) { func TestProviderType_SupportsPasscode(t *testing.T) {
assert.True(t, ProviderLocal.Supports2FA()) assert.True(t, ProviderLocal.SupportsPasscodeAuthentication())
assert.True(t, ProviderLDAP.Supports2FA()) assert.False(t, ProviderOIDC.SupportsPasscodeAuthentication())
assert.False(t, ProviderClient.Supports2FA()) assert.True(t, ProviderLDAP.SupportsPasscodeAuthentication())
assert.False(t, ProviderApplication.Supports2FA()) assert.False(t, ProviderClient.SupportsPasscodeAuthentication())
assert.False(t, ProviderAccessToken.Supports2FA()) assert.False(t, ProviderApplication.SupportsPasscodeAuthentication())
assert.False(t, ProviderNone.Supports2FA()) assert.False(t, ProviderAccessToken.SupportsPasscodeAuthentication())
assert.True(t, ProviderDefault.Supports2FA()) assert.False(t, ProviderNone.SupportsPasscodeAuthentication())
assert.False(t, ProviderUndefined.Supports2FA()) assert.True(t, ProviderDefault.SupportsPasscodeAuthentication())
assert.False(t, ProviderUndefined.SupportsPasscodeAuthentication())
} }
func TestProviderType_IsDefault(t *testing.T) { func TestProviderType_IsDefault(t *testing.T) {
assert.False(t, ProviderLocal.IsDefault()) assert.False(t, ProviderLocal.IsDefault())
assert.False(t, ProviderOIDC.IsDefault())
assert.False(t, ProviderLDAP.IsDefault()) assert.False(t, ProviderLDAP.IsDefault())
assert.False(t, ProviderNone.IsDefault()) assert.False(t, ProviderNone.IsDefault())
assert.True(t, ProviderDefault.IsDefault()) assert.True(t, ProviderDefault.IsDefault())
@@ -87,6 +98,7 @@ func TestProviderType_IsDefault(t *testing.T) {
func TestProviderType_IsClient(t *testing.T) { func TestProviderType_IsClient(t *testing.T) {
assert.False(t, ProviderLocal.IsClient()) assert.False(t, ProviderLocal.IsClient())
assert.False(t, ProviderOIDC.IsClient())
assert.False(t, ProviderLDAP.IsClient()) assert.False(t, ProviderLDAP.IsClient())
assert.False(t, ProviderNone.IsClient()) assert.False(t, ProviderNone.IsClient())
assert.False(t, ProviderDefault.IsClient()) assert.False(t, ProviderDefault.IsClient())
@@ -94,12 +106,16 @@ func TestProviderType_IsClient(t *testing.T) {
} }
func TestProviderType_Equal(t *testing.T) { func TestProviderType_Equal(t *testing.T) {
assert.True(t, ProviderOIDC.Equal("OIDC"))
assert.True(t, ProviderLDAP.Equal("LDAP"))
assert.True(t, ProviderClient.Equal("Client")) assert.True(t, ProviderClient.Equal("Client"))
assert.True(t, ProviderClient.Equal("Client Credentials")) assert.True(t, ProviderClient.Equal("Client Credentials"))
assert.False(t, ProviderLocal.Equal("Client")) assert.False(t, ProviderLocal.Equal("Client"))
} }
func TestProviderType_NotEqual(t *testing.T) { func TestProviderType_NotEqual(t *testing.T) {
assert.False(t, ProviderOIDC.NotEqual("OIDC"))
assert.False(t, ProviderLDAP.NotEqual("LDAP"))
assert.False(t, ProviderClient.NotEqual("Client")) assert.False(t, ProviderClient.NotEqual("Client"))
assert.False(t, ProviderClient.NotEqual("Client Credentials")) assert.False(t, ProviderClient.NotEqual("Client Credentials"))
assert.True(t, ProviderLocal.NotEqual("Client")) assert.True(t, ProviderLocal.NotEqual("Client"))
@@ -107,6 +123,7 @@ func TestProviderType_NotEqual(t *testing.T) {
func TestProviderType_Pretty(t *testing.T) { func TestProviderType_Pretty(t *testing.T) {
assert.Equal(t, "Local", ProviderLocal.Pretty()) assert.Equal(t, "Local", ProviderLocal.Pretty())
assert.Equal(t, "OIDC", ProviderOIDC.Pretty())
assert.Equal(t, "LDAP/AD", ProviderLDAP.Pretty()) assert.Equal(t, "LDAP/AD", ProviderLDAP.Pretty())
assert.Equal(t, "None", ProviderNone.Pretty()) assert.Equal(t, "None", ProviderNone.Pretty())
assert.Equal(t, "Default", ProviderDefault.Pretty()) assert.Equal(t, "Default", ProviderDefault.Pretty())
@@ -117,6 +134,7 @@ func TestProviderType_Pretty(t *testing.T) {
func TestProvider(t *testing.T) { func TestProvider(t *testing.T) {
assert.Equal(t, ProviderLocal, Provider("pass")) assert.Equal(t, ProviderLocal, Provider("pass"))
assert.Equal(t, ProviderOIDC, Provider("oidc"))
assert.Equal(t, ProviderLDAP, Provider("ad")) assert.Equal(t, ProviderLDAP, Provider("ad"))
assert.Equal(t, ProviderDefault, Provider("")) assert.Equal(t, ProviderDefault, Provider(""))
assert.Equal(t, ProviderLink, Provider("url")) assert.Equal(t, ProviderLink, Provider("url"))