Auth: Refactor grant, method, and provider types in pkg/authn #808 #4114

Signed-off-by: Michael Mayer <michael@photoprism.app>
This commit is contained in:
Michael Mayer
2024-04-06 15:13:15 +02:00
parent cdd435f97c
commit b11491c9d6
18 changed files with 295 additions and 268 deletions

View File

@@ -8,6 +8,7 @@ import (
"github.com/photoprism/photoprism/internal/config"
"github.com/photoprism/photoprism/internal/entity"
"github.com/photoprism/photoprism/pkg/authn"
"github.com/photoprism/photoprism/pkg/clean"
"github.com/photoprism/photoprism/pkg/report"
"github.com/photoprism/photoprism/pkg/rnd"
@@ -92,7 +93,7 @@ func authAddAction(ctx *cli.Context) error {
}
// Create session and show the authentication secret.
sess, err := entity.AddClientAuthentication(clientName, ctx.Int64("expires"), authScope, user)
sess, err := entity.AddClientAuthentication(clientName, ctx.Int64("expires"), authScope, authn.GrantCLI, user)
if err != nil {
return fmt.Errorf("failed to create authentication secret: %s", err)

View File

@@ -63,7 +63,7 @@ var ClientAddFlags = []cli.Flag{
cli.StringFlag{
Name: "provider, p",
Usage: ClientAuthProvider,
Value: authn.ProviderClientCredentials.String(),
Value: authn.ProviderClient.String(),
Hidden: true,
},
cli.StringFlag{
@@ -107,7 +107,7 @@ var ClientModFlags = []cli.Flag{
cli.StringFlag{
Name: "provider, p",
Usage: ClientAuthProvider,
Value: authn.ProviderClientCredentials.String(),
Value: authn.ProviderClient.String(),
Hidden: true,
},
cli.StringFlag{

View File

@@ -20,7 +20,7 @@ func TestCientsRemoveCommand(t *testing.T) {
//t.Logf(output0)
assert.NoError(t, err)
assert.NotContains(t, output0, "not found")
assert.Contains(t, output0, "client_credentials")
assert.Contains(t, output0, "client")
// Create test context with flags and arguments.
ctx := NewTestContext([]string{"rm", "cs7pvt5h8rw9aaqj"})
@@ -44,7 +44,7 @@ func TestCientsRemoveCommand(t *testing.T) {
//t.Logf(output2)
assert.NoError(t, err)
assert.NotContains(t, output2, "not found")
assert.Contains(t, output2, "client_credentials")
assert.Contains(t, output2, "client")
})
t.Run("RemoveClient", func(t *testing.T) {
var err error
@@ -58,7 +58,7 @@ func TestCientsRemoveCommand(t *testing.T) {
//t.Logf(output0)
assert.NoError(t, err)
assert.NotContains(t, output0, "not found")
assert.Contains(t, output0, "client_credentials")
assert.Contains(t, output0, "client")
// Create test context with flags and arguments.
ctx := NewTestContext([]string{"rm", "--force", "cs7pvt5h8rw9aaqj"})

View File

@@ -63,7 +63,7 @@ func NewClient() *Client {
ClientType: authn.ClientConfidential,
ClientURL: "",
CallbackURL: "",
AuthProvider: authn.ProviderClientCredentials.String(),
AuthProvider: authn.ProviderClient.String(),
AuthMethod: authn.MethodOAuth2.String(),
AuthScope: "",
AuthExpires: unix.Hour,
@@ -547,7 +547,7 @@ func (m *Client) SetFormValues(frm form.Client) *Client {
// Replace empty values with defaults.
if m.AuthProvider == "" {
m.AuthProvider = authn.ProviderClientCredentials.String()
m.AuthProvider = authn.ProviderClient.String()
}
if m.AuthMethod == "" {

View File

@@ -35,7 +35,7 @@ var ClientFixtures = ClientMap{
ClientType: authn.ClientConfidential,
ClientURL: "",
CallbackURL: "",
AuthProvider: authn.ProviderClientCredentials.String(),
AuthProvider: authn.ProviderClient.String(),
AuthMethod: authn.MethodOAuth2.String(),
AuthScope: "*",
AuthExpires: unix.Day,
@@ -53,7 +53,7 @@ var ClientFixtures = ClientMap{
ClientType: authn.ClientPublic,
ClientURL: "",
CallbackURL: "",
AuthProvider: authn.ProviderClientCredentials.String(),
AuthProvider: authn.ProviderClient.String(),
AuthMethod: authn.MethodOAuth2.String(),
AuthScope: "*",
AuthExpires: 0,
@@ -71,7 +71,7 @@ var ClientFixtures = ClientMap{
ClientType: authn.ClientConfidential,
ClientURL: "",
CallbackURL: "",
AuthProvider: authn.ProviderClientCredentials.String(),
AuthProvider: authn.ProviderClient.String(),
AuthMethod: authn.MethodOAuth2.String(),
AuthScope: "metrics",
AuthExpires: unix.Hour,
@@ -89,7 +89,7 @@ var ClientFixtures = ClientMap{
ClientType: authn.ClientUnknown,
ClientURL: "",
CallbackURL: "",
AuthProvider: authn.ProviderClientCredentials.String(),
AuthProvider: authn.ProviderClient.String(),
AuthMethod: authn.MethodUndefined.String(),
AuthScope: "*",
AuthExpires: unix.Hour,
@@ -107,7 +107,7 @@ var ClientFixtures = ClientMap{
ClientType: authn.ClientConfidential,
ClientURL: "",
CallbackURL: "",
AuthProvider: authn.ProviderClientCredentials.String(),
AuthProvider: authn.ProviderClient.String(),
AuthMethod: authn.MethodOAuth2.String(),
AuthScope: "metrics",
AuthExpires: unix.Hour,
@@ -125,7 +125,7 @@ var ClientFixtures = ClientMap{
ClientType: authn.ClientConfidential,
ClientURL: "",
CallbackURL: "",
AuthProvider: authn.ProviderClientCredentials.String(),
AuthProvider: authn.ProviderClient.String(),
AuthMethod: authn.MethodOAuth2.String(),
AuthScope: "statistics",
AuthExpires: unix.Hour,

View File

@@ -303,15 +303,15 @@ func TestClient_SetSecret(t *testing.T) {
func TestClient_Provider(t *testing.T) {
t.Run("New", func(t *testing.T) {
client := NewClient()
assert.Equal(t, authn.ProviderClientCredentials, client.Provider())
assert.Equal(t, authn.ProviderClient, client.Provider())
})
t.Run("Alice", func(t *testing.T) {
client := ClientFixtures.Get("alice")
assert.Equal(t, authn.ProviderClientCredentials, client.Provider())
assert.Equal(t, authn.ProviderClient, client.Provider())
})
t.Run("Bob", func(t *testing.T) {
client := ClientFixtures.Get("bob")
assert.Equal(t, authn.ProviderClientCredentials, client.Provider())
assert.Equal(t, authn.ProviderClient, client.Provider())
})
}
@@ -497,13 +497,13 @@ func TestClient_UserInfo(t *testing.T) {
func TestClient_AuthInfo(t *testing.T) {
t.Run("New", func(t *testing.T) {
assert.Equal(t, "Client Credentials (OAuth2)", NewClient().AuthInfo())
assert.Equal(t, "Client (OAuth2)", NewClient().AuthInfo())
})
t.Run("Alice", func(t *testing.T) {
assert.Equal(t, "Client Credentials (OAuth2)", ClientFixtures.Pointer("alice").AuthInfo())
assert.Equal(t, "Client (OAuth2)", ClientFixtures.Pointer("alice").AuthInfo())
})
t.Run("Metrics", func(t *testing.T) {
assert.Equal(t, "Client Credentials (OAuth2)", ClientFixtures.Pointer("metrics").AuthInfo())
assert.Equal(t, "Client (OAuth2)", ClientFixtures.Pointer("metrics").AuthInfo())
})
}
@@ -526,7 +526,7 @@ func TestClient_SetFormValues(t *testing.T) {
var values = form.Client{
ClientName: "New Name",
AuthProvider: authn.ProviderClientCredentials.String(),
AuthProvider: authn.ProviderClient.String(),
AuthMethod: authn.MethodOAuth2.String(),
AuthScope: "test",
AuthExpires: 4000,
@@ -550,7 +550,7 @@ func TestClient_SetFormValues(t *testing.T) {
var values = form.Client{
ClientName: "Annika",
AuthProvider: authn.ProviderClientCredentials.String(),
AuthProvider: authn.ProviderClient.String(),
AuthMethod: authn.MethodOAuth2.String(),
AuthScope: "metrics",
AuthExpires: -4000,
@@ -574,7 +574,7 @@ func TestClient_SetFormValues(t *testing.T) {
var values = form.Client{
ClientName: "Friend",
AuthProvider: authn.ProviderClientCredentials.String(),
AuthProvider: authn.ProviderClient.String(),
AuthMethod: authn.MethodOAuth2.String(),
AuthScope: "test",
AuthExpires: 4000000,
@@ -614,7 +614,7 @@ func TestClient_SetFormValues(t *testing.T) {
assert.Equal(t, int64(3600), c.AuthExpires)
assert.Equal(t, "*", c.AuthScope)
assert.Equal(t, "oauth2", c.AuthMethod)
assert.Equal(t, "client_credentials", c.AuthProvider)
assert.Equal(t, "client", c.AuthProvider)
})
}
@@ -624,7 +624,7 @@ func TestClient_Validate(t *testing.T) {
m := Client{
ClientName: "test",
ClientType: "test",
AuthProvider: authn.ProviderClientCredentials.String(),
AuthProvider: authn.ProviderClient.String(),
AuthMethod: "basic",
AuthScope: "all",
}
@@ -639,7 +639,7 @@ func TestClient_Validate(t *testing.T) {
m := Client{
ClientName: "",
ClientType: "test",
AuthProvider: authn.ProviderClientCredentials.String(),
AuthProvider: authn.ProviderClient.String(),
AuthMethod: "basic",
AuthScope: "all",
}
@@ -654,7 +654,7 @@ func TestClient_Validate(t *testing.T) {
m := Client{
ClientName: "test",
ClientType: "",
AuthProvider: authn.ProviderClientCredentials.String(),
AuthProvider: authn.ProviderClient.String(),
AuthMethod: "basic",
AuthScope: "all",
}
@@ -669,7 +669,7 @@ func TestClient_Validate(t *testing.T) {
m := Client{
ClientName: "test",
ClientType: "test",
AuthProvider: authn.ProviderClientCredentials.String(),
AuthProvider: authn.ProviderClient.String(),
AuthMethod: "",
AuthScope: "all",
}
@@ -684,7 +684,7 @@ func TestClient_Validate(t *testing.T) {
m := Client{
ClientName: "test",
ClientType: "test",
AuthProvider: authn.ProviderClientCredentials.String(),
AuthProvider: authn.ProviderClient.String(),
AuthMethod: "basic",
AuthScope: "",
}

View File

@@ -6,7 +6,7 @@ import (
)
// NewClientAuthentication returns a new session that authenticates a client application.
func NewClientAuthentication(clientName string, lifetime int64, scope string, user *User) *Session {
func NewClientAuthentication(clientName string, lifetime int64, scope string, grantType authn.GrantType, user *User) *Session {
sess := NewSession(lifetime, 0)
if clientName == "" {
@@ -15,6 +15,7 @@ func NewClientAuthentication(clientName string, lifetime int64, scope string, us
sess.SetClientName(clientName)
sess.SetScope(scope)
sess.SetGrantType(grantType)
if user != nil {
sess.SetUser(user)
@@ -30,8 +31,8 @@ func NewClientAuthentication(clientName string, lifetime int64, scope string, us
}
// AddClientAuthentication creates a new session for authenticating a client application.
func AddClientAuthentication(clientName string, lifetime int64, scope string, user *User) (*Session, error) {
sess := NewClientAuthentication(clientName, lifetime, scope, user)
func AddClientAuthentication(clientName string, lifetime int64, scope string, grantType authn.GrantType, user *User) (*Session, error) {
sess := NewClientAuthentication(clientName, lifetime, scope, grantType, user)
if err := sess.Create(); err != nil {
return nil, err

View File

@@ -5,12 +5,13 @@ import (
"github.com/stretchr/testify/assert"
"github.com/photoprism/photoprism/pkg/authn"
"github.com/photoprism/photoprism/pkg/unix"
)
func TestNewClientAuthentication(t *testing.T) {
t.Run("Anonymous", func(t *testing.T) {
sess := NewClientAuthentication("Anonymous", unix.Day, "metrics", nil)
sess := NewClientAuthentication("Anonymous", unix.Day, "metrics", authn.GrantClientCredentials, nil)
if sess == nil {
t.Fatal("session must not be nil")
@@ -25,7 +26,7 @@ func TestNewClientAuthentication(t *testing.T) {
t.Fatal("user must not be nil")
}
sess := NewClientAuthentication("alice", unix.Day, "metrics", user)
sess := NewClientAuthentication("alice", unix.Day, "metrics", authn.GrantPassword, user)
if sess == nil {
t.Fatal("session must not be nil")
@@ -40,7 +41,7 @@ func TestNewClientAuthentication(t *testing.T) {
t.Fatal("user must not be nil")
}
sess := NewClientAuthentication("alice", unix.Day, "", user)
sess := NewClientAuthentication("alice", unix.Day, "", authn.GrantCLI, user)
if sess == nil {
t.Fatal("session must not be nil")
@@ -55,7 +56,7 @@ func TestNewClientAuthentication(t *testing.T) {
t.Fatal("user must not be nil")
}
sess := NewClientAuthentication("", 0, "metrics", user)
sess := NewClientAuthentication("", 0, "metrics", authn.GrantCLI, user)
if sess == nil {
t.Fatal("session must not be nil")
@@ -67,7 +68,7 @@ func TestNewClientAuthentication(t *testing.T) {
func TestAddClientAuthentication(t *testing.T) {
t.Run("Anonymous", func(t *testing.T) {
sess, err := AddClientAuthentication("", unix.Day, "metrics", nil)
sess, err := AddClientAuthentication("", unix.Day, "metrics", authn.GrantClientCredentials, nil)
assert.NoError(t, err)
@@ -84,7 +85,7 @@ func TestAddClientAuthentication(t *testing.T) {
t.Fatal("user must not be nil")
}
sess, err := AddClientAuthentication("My Client App Token", unix.Day, "metrics", user)
sess, err := AddClientAuthentication("My Client App Token", unix.Day, "metrics", authn.GrantCLI, user)
assert.NoError(t, err)

View File

@@ -32,6 +32,7 @@ var SessionFixtures = SessionMap{
RefID: "sessxkkcabcd",
SessTimeout: unix.Day * 3,
SessExpires: unix.Time() + unix.Week,
GrantType: authn.GrantPassword.String(),
user: UserFixtures.Pointer("alice"),
UserUID: UserFixtures.Pointer("alice").UserUID,
UserName: UserFixtures.Pointer("alice").UserName,
@@ -45,6 +46,7 @@ var SessionFixtures = SessionMap{
AuthScope: clean.Scope("*"),
AuthProvider: authn.ProviderAccessToken.String(),
AuthMethod: authn.MethodDefault.String(),
GrantType: authn.GrantCLI.String(),
ClientName: "alice_token",
LastActive: -1,
user: UserFixtures.Pointer("alice"),
@@ -59,7 +61,8 @@ var SessionFixtures = SessionMap{
SessExpires: unix.Time() + unix.Day,
AuthScope: clean.Scope("*"),
AuthProvider: authn.ProviderAccessToken.String(),
AuthMethod: authn.MethodPersonal.String(),
AuthMethod: authn.MethodDefault.String(),
GrantType: authn.GrantPassword.String(),
ClientName: "alice_token_personal",
LastActive: -1,
user: UserFixtures.Pointer("alice"),
@@ -74,7 +77,8 @@ var SessionFixtures = SessionMap{
SessExpires: unix.Time() + unix.Day,
AuthScope: clean.Scope("webdav"),
AuthProvider: authn.ProviderAccessToken.String(),
AuthMethod: authn.MethodPersonal.String(),
AuthMethod: authn.MethodDefault.String(),
GrantType: authn.GrantPassword.String(),
ClientName: "alice_token_webdav",
LastActive: -1,
user: UserFixtures.Pointer("alice"),
@@ -90,6 +94,7 @@ var SessionFixtures = SessionMap{
AuthScope: clean.Scope("metrics photos albums videos"),
AuthProvider: authn.ProviderAccessToken.String(),
AuthMethod: authn.MethodDefault.String(),
GrantType: authn.GrantPassword.String(),
ClientName: "alice_token_scope",
user: UserFixtures.Pointer("alice"),
UserUID: UserFixtures.Pointer("alice").UserUID,
@@ -103,6 +108,7 @@ var SessionFixtures = SessionMap{
RefID: "sessxkkcabce",
SessTimeout: unix.Day * 3,
SessExpires: unix.Time() + unix.Week,
GrantType: authn.GrantPassword.String(),
user: UserFixtures.Pointer("bob"),
UserUID: UserFixtures.Pointer("bob").UserUID,
UserName: UserFixtures.Pointer("bob").UserName,
@@ -113,6 +119,7 @@ var SessionFixtures = SessionMap{
RefID: "sessxkkcabcf",
SessTimeout: unix.Day * 3,
SessExpires: unix.Time() + unix.Week,
GrantType: authn.GrantImplicit.String(),
user: UserFixtures.Pointer("unauthorized"),
UserUID: UserFixtures.Pointer("unauthorized").UserUID,
UserName: UserFixtures.Pointer("unauthorized").UserName,
@@ -126,6 +133,7 @@ var SessionFixtures = SessionMap{
user: &Visitor,
UserUID: Visitor.UserUID,
UserName: Visitor.UserName,
GrantType: authn.GrantShareToken.String(),
DataJSON: []byte(`{"tokens":["1jxf3jfn2k"],"shares":["as6sg6bxpogaaba8"]}`),
data: &SessionData{
Tokens: []string{"1jxf3jfn2k"},
@@ -141,6 +149,7 @@ var SessionFixtures = SessionMap{
AuthScope: clean.Scope("metrics"),
AuthProvider: authn.ProviderAccessToken.String(),
AuthMethod: authn.MethodDefault.String(),
GrantType: authn.GrantShareToken.String(),
ClientName: "visitor_token_metrics",
user: &Visitor,
UserUID: Visitor.UserUID,
@@ -163,8 +172,9 @@ var SessionFixtures = SessionMap{
SessTimeout: 0,
SessExpires: unix.Time() + unix.Week,
AuthScope: clean.Scope("metrics"),
AuthProvider: authn.ProviderClientCredentials.String(),
AuthProvider: authn.ProviderClient.String(),
AuthMethod: authn.MethodOAuth2.String(),
GrantType: authn.GrantClientCredentials.String(),
ClientUID: ClientFixtures.Get("metrics").ClientUID,
ClientName: ClientFixtures.Get("metrics").ClientName,
user: nil,
@@ -182,6 +192,7 @@ var SessionFixtures = SessionMap{
AuthScope: clean.Scope("metrics"),
AuthProvider: authn.ProviderAccessToken.String(),
AuthMethod: authn.MethodDefault.String(),
GrantType: authn.GrantCLI.String(),
ClientName: "token_metrics",
user: nil,
UserUID: "",
@@ -198,6 +209,7 @@ var SessionFixtures = SessionMap{
AuthScope: clean.Scope("settings"),
AuthProvider: authn.ProviderAccessToken.String(),
AuthMethod: authn.MethodDefault.String(),
GrantType: authn.GrantCLI.String(),
ClientName: "token_settings",
user: nil,
UserUID: "",
@@ -212,8 +224,9 @@ var SessionFixtures = SessionMap{
SessTimeout: 0,
SessExpires: unix.Time() + unix.Week,
AuthScope: clean.Scope("statistics"),
AuthProvider: authn.ProviderClientCredentials.String(),
AuthProvider: authn.ProviderClient.String(),
AuthMethod: authn.MethodOAuth2.String(),
GrantType: authn.GrantCLI.String(),
ClientUID: ClientFixtures.Get("analytics").ClientUID,
ClientName: ClientFixtures.Get("analytics").ClientName,
user: nil,

View File

@@ -1,11 +1,11 @@
package entity
import (
"github.com/gin-gonic/gin"
"net/http"
"testing"
"time"
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/photoprism/photoprism/internal/acl"
@@ -515,7 +515,7 @@ func TestSession_AuthInfo(t *testing.T) {
i := m.AuthInfo()
assert.Equal(t, "Access Token (Personal)", i)
assert.Equal(t, "Access Token", i)
})
}
@@ -550,20 +550,20 @@ func TestSession_SetMethod(t *testing.T) {
UserName: "test",
RefID: "sessxkkcxxxz",
AuthProvider: authn.ProviderAccessToken.String(),
AuthMethod: authn.MethodPersonal.String(),
AuthMethod: authn.MethodDefault.String(),
}
m := s.SetMethod("")
assert.Equal(t, authn.ProviderAccessToken, m.Provider())
assert.Equal(t, authn.MethodPersonal, m.Method())
assert.Equal(t, authn.MethodDefault, m.Method())
})
t.Run("Test", func(t *testing.T) {
s := &Session{
UserName: "test",
RefID: "sessxkkcxxxz",
AuthProvider: authn.ProviderAccessToken.String(),
AuthMethod: authn.MethodPersonal.String(),
AuthMethod: authn.MethodDefault.String(),
}
m := s.SetMethod("Test")
@@ -576,7 +576,7 @@ func TestSession_SetMethod(t *testing.T) {
UserName: "test",
RefID: "sessxkkcxxxz",
AuthProvider: authn.ProviderAccessToken.String(),
AuthMethod: authn.MethodPersonal.String(),
AuthMethod: authn.MethodDefault.String(),
}
m := s.SetMethod(authn.MethodSession)

View File

@@ -35,7 +35,7 @@ func NewClient() Client {
ClientSecret: "",
ClientName: "",
ClientRole: acl.RoleClient.String(),
AuthProvider: authn.ProviderClientCredentials.String(),
AuthProvider: authn.ProviderClient.String(),
AuthMethod: authn.MethodOAuth2.String(),
AuthScope: "",
AuthExpires: 3600,
@@ -79,7 +79,7 @@ func AddClientFromCli(ctx *cli.Context) Client {
f.AuthProvider = authn.Provider(ctx.String("provider")).String()
if f.AuthProvider == "" {
f.AuthProvider = authn.ProviderClientCredentials.String()
f.AuthProvider = authn.ProviderClient.String()
}
f.AuthMethod = authn.Method(ctx.String("method")).String()

View File

@@ -13,7 +13,7 @@ import (
func TestNewClient(t *testing.T) {
t.Run("Defaults", func(t *testing.T) {
client := NewClient()
assert.Equal(t, authn.ProviderClientCredentials, client.Provider())
assert.Equal(t, authn.ProviderClient, client.Provider())
assert.Equal(t, authn.MethodOAuth2, client.Method())
assert.Equal(t, "", client.Scope())
assert.Equal(t, "", client.Name())
@@ -52,7 +52,7 @@ func TestAddClientFromCli(t *testing.T) {
client := AddClientFromCli(ctx)
// Check form values.
assert.Equal(t, authn.ProviderClientCredentials, client.Provider())
assert.Equal(t, authn.ProviderClient, client.Provider())
assert.Equal(t, authn.MethodOAuth2, client.Method())
assert.Equal(t, "*", client.Scope())
assert.Equal(t, "Test", client.Name())
@@ -125,7 +125,7 @@ func TestModClientFromCli(t *testing.T) {
client := ModClientFromCli(ctx)
// Check form values.
assert.Equal(t, authn.ProviderClientCredentials, client.Provider())
assert.Equal(t, authn.ProviderClient, client.Provider())
assert.Equal(t, authn.MethodOAuth2, client.Method())
assert.Equal(t, "*", client.Scope())
assert.Equal(t, "Test", client.Name())

View File

@@ -1,8 +1,6 @@
package authn
import (
"strings"
"github.com/photoprism/photoprism/pkg/clean"
"github.com/photoprism/photoprism/pkg/txt"
)
@@ -13,8 +11,10 @@ type GrantType string
// Standard authentication grant types.
const (
GrantUndefined GrantType = ""
GrantClientCredentials GrantType = "client_credentials"
GrantCLI GrantType = "cli"
GrantImplicit GrantType = "implicit"
GrantPassword GrantType = "password"
GrantClientCredentials GrantType = "client_credentials"
GrantShareToken GrantType = "share_token"
GrantRefreshToken GrantType = "refresh_token"
GrantAuthorizationCode GrantType = "authorization_code"
@@ -23,70 +23,22 @@ const (
GrantTokenExchange GrantType = "urn:ietf:params:oauth:grant-type:token-exchange"
)
// String returns the provider identifier as a string.
func (t GrantType) String() string {
return clean.TypeLowerUnderscore(string(t))
}
// Is compares the method with another type.
func (t GrantType) Is(method GrantType) bool {
return t == method
}
// IsNot checks if the method is not the specified type.
func (t GrantType) IsNot(method GrantType) bool {
return t != method
}
// IsUndefined checks if the method is undefined.
func (t GrantType) IsUndefined() bool {
return t == ""
}
// Equal checks if the type matches.
func (t GrantType) Equal(s string) bool {
return strings.EqualFold(s, t.String())
}
// NotEqual checks if the type is different.
func (t GrantType) NotEqual(s string) bool {
return !t.Equal(s)
}
// Pretty returns the provider identifier in an easy-to-read format.
func (t GrantType) Pretty() string {
switch t {
case GrantShareToken:
return "Share Token"
case GrantRefreshToken:
return "Refresh Token"
case GrantClientCredentials:
return "Client Credentials"
case GrantAuthorizationCode:
return "Authorization Code"
case GrantJwtBearer:
return "JWT Bearer Assertion"
case GrantSamlBearer:
return "SAML2 Bearer Assertion"
case GrantTokenExchange:
return "Token Exchange"
default:
return txt.UpperFirst(t.String())
}
}
// Grant casts a string to a normalized grant type.
func Grant(s string) GrantType {
s = clean.TypeLowerUnderscore(s)
switch s {
case "", "-", "null", "nil", "0", "false":
case "", "_", "-", "null", "nil", "0", "false":
return GrantUndefined
case "client_credentials", "client":
return GrantClientCredentials
case "cli", "terminal", "command":
return GrantCLI
case "implicit":
return GrantImplicit
case "password", "passwd", "pass", "user", "username":
return GrantPassword
case "share_token", "share":
case "client_credentials", "client":
return GrantClientCredentials
case "share_token", "share":
return GrantShareToken
case "refresh_token", "refresh":
return GrantRefreshToken
case "authorization_code", "auth_code":
@@ -101,3 +53,61 @@ func Grant(s string) GrantType {
return GrantType(s)
}
}
// Pretty returns the grant type in a human-readable format.
func (t GrantType) Pretty() string {
switch t {
case GrantCLI:
return "CLI"
case GrantImplicit:
return "Implicit"
case GrantPassword:
return "Password"
case GrantClientCredentials:
return "Client Credentials"
case GrantShareToken:
return "Share Token"
case GrantRefreshToken:
return "Refresh Token"
case GrantAuthorizationCode:
return "Authorization Code"
case GrantJwtBearer:
return "JWT Bearer Assertion"
case GrantSamlBearer:
return "SAML2 Bearer Assertion"
case GrantTokenExchange:
return "Token Exchange"
default:
return txt.UpperFirst(t.String())
}
}
// String returns the grant type as a string.
func (t GrantType) String() string {
return clean.TypeLowerUnderscore(string(t))
}
// Equal checks if the type matches the specified string.
func (t GrantType) Equal(s string) bool {
return t == Grant(s)
}
// NotEqual checks if the type does mot match the specified string.
func (t GrantType) NotEqual(s string) bool {
return !t.Equal(s)
}
// Is compares the grant with another type.
func (t GrantType) Is(grantType GrantType) bool {
return t == grantType
}
// IsNot checks if the grant is not the specified type.
func (t GrantType) IsNot(grantType GrantType) bool {
return t != grantType
}
// IsUndefined checks if the grant is undefined.
func (t GrantType) IsUndefined() bool {
return t == ""
}

View File

@@ -51,6 +51,7 @@ func TestGrantType_IsUndefined(t *testing.T) {
func TestGrantType_Pretty(t *testing.T) {
assert.Equal(t, "", GrantUndefined.Pretty())
assert.Equal(t, "CLI", GrantCLI.Pretty())
assert.Equal(t, "Client Credentials", GrantClientCredentials.Pretty())
assert.Equal(t, "Password", GrantPassword.Pretty())
assert.Equal(t, "Refresh Token", GrantRefreshToken.Pretty())
@@ -61,33 +62,49 @@ func TestGrantType_Pretty(t *testing.T) {
func TestGrantType_Equal(t *testing.T) {
assert.True(t, GrantClientCredentials.Equal("Client_Credentials"))
assert.False(t, GrantClientCredentials.Equal("Client Credentials"))
assert.True(t, GrantClientCredentials.Equal("Client Credentials"))
assert.True(t, GrantClientCredentials.Equal("client_credentials"))
assert.False(t, GrantClientCredentials.Equal("client"))
assert.True(t, GrantClientCredentials.Equal("client"))
assert.True(t, GrantUndefined.Equal(""))
assert.True(t, GrantPassword.Equal("Password"))
assert.True(t, GrantPassword.Equal("password"))
assert.False(t, GrantPassword.Equal("pass"))
assert.True(t, GrantPassword.Equal("pass"))
}
func TestGrantType_NotEqual(t *testing.T) {
assert.False(t, GrantClientCredentials.NotEqual("Client_Credentials"))
assert.True(t, GrantClientCredentials.NotEqual("Client Credentials"))
assert.False(t, GrantClientCredentials.NotEqual("Client Credentials"))
assert.False(t, GrantClientCredentials.NotEqual("client_credentials"))
assert.True(t, GrantClientCredentials.NotEqual("client"))
assert.False(t, GrantClientCredentials.NotEqual("client"))
assert.True(t, GrantClientCredentials.NotEqual("access_token"))
assert.True(t, GrantClientCredentials.NotEqual(""))
assert.False(t, GrantUndefined.NotEqual(""))
assert.False(t, GrantPassword.NotEqual("Password"))
assert.False(t, GrantPassword.NotEqual("password"))
assert.True(t, GrantPassword.NotEqual("pass"))
assert.False(t, GrantPassword.NotEqual("pass"))
assert.True(t, GrantPassword.NotEqual("passw"))
}
func TestGrant(t *testing.T) {
assert.Equal(t, GrantUndefined, Grant(""))
assert.Equal(t, GrantClientCredentials, Grant("client credentials"))
assert.Equal(t, GrantCLI, Grant("cli"))
assert.Equal(t, GrantImplicit, Grant("implicit"))
assert.Equal(t, GrantPassword, Grant("pass"))
assert.Equal(t, GrantPassword, Grant("password"))
assert.Equal(t, GrantClientCredentials, Grant("client credentials"))
assert.Equal(t, GrantClientCredentials, Grant("client_credentials"))
assert.Equal(t, GrantShareToken, Grant("share_token"))
assert.Equal(t, GrantRefreshToken, Grant("refresh_token"))
assert.Equal(t, GrantAuthorizationCode, Grant("auth_code"))
assert.Equal(t, GrantAuthorizationCode, Grant("authorization_code"))
assert.Equal(t, GrantAuthorizationCode, Grant("authorization code"))
assert.Equal(t, GrantJwtBearer, Grant("jwt-bearer"))
assert.Equal(t, GrantJwtBearer, Grant("jwt_bearer"))
assert.Equal(t, GrantJwtBearer, Grant("jwt bearer"))
assert.Equal(t, GrantSamlBearer, Grant("saml"))
assert.Equal(t, GrantSamlBearer, Grant("saml2"))
assert.Equal(t, GrantSamlBearer, Grant("saml2-bearer"))
assert.Equal(t, GrantTokenExchange, Grant("token-exchange"))
assert.Equal(t, GrantTokenExchange, Grant("token_exchange"))
assert.Equal(t, GrantTokenExchange, Grant("token exchange"))
}

View File

@@ -15,35 +15,44 @@ const (
MethodUndefined MethodType = ""
MethodDefault MethodType = "default"
MethodSession MethodType = "session"
MethodPersonal MethodType = "personal"
MethodOAuth2 MethodType = "oauth2"
MethodOIDC MethodType = "oidc"
Method2FA MethodType = "2fa"
)
// Is compares the method with another type.
func (t MethodType) Is(method MethodType) bool {
return t == method
// Method casts a string to a normalized method type.
func Method(s string) MethodType {
s = clean.TypeLowerUnderscore(s)
switch s {
case "":
return MethodUndefined
case "_", "-", "null", "nil", "0", "false":
return MethodDefault
case "oauth2", "oauth":
return MethodOAuth2
case "sso":
return MethodOIDC
case "2fa", "mfa", "otp", "totp":
return Method2FA
case "access_token":
return MethodDefault
default:
return MethodType(s)
}
}
// IsNot checks if the method is not the specified type.
func (t MethodType) IsNot(method MethodType) bool {
return t != method
}
// IsUndefined checks if the method is undefined.
func (t MethodType) IsUndefined() bool {
return t == ""
}
// IsDefault checks if this is the default method.
func (t MethodType) IsDefault() bool {
return t.String() == MethodDefault.String()
}
// IsSession checks if this is the session method.
func (t MethodType) IsSession() bool {
return t.String() == MethodSession.String()
// Pretty returns the provider identifier in an easy-to-read format.
func (t MethodType) Pretty() string {
switch t {
case MethodOAuth2:
return "OAuth2"
case MethodOIDC:
return "OIDC"
case Method2FA:
return "2FA"
default:
return txt.UpperFirst(t.String())
}
}
// String returns the provider identifier as a string.
@@ -62,47 +71,37 @@ func (t MethodType) String() string {
}
}
// Equal checks if the type matches.
// Equal checks if the type matches the specified string.
func (t MethodType) Equal(s string) bool {
return strings.EqualFold(s, t.String())
}
// NotEqual checks if the type is different.
// NotEqual checks if the type does not match the specified string.
func (t MethodType) NotEqual(s string) bool {
return !t.Equal(s)
}
// Pretty returns the provider identifier in an easy-to-read format.
func (t MethodType) Pretty() string {
switch t {
case MethodOAuth2:
return "OAuth2"
case MethodOIDC:
return "OIDC"
case Method2FA:
return "2FA"
default:
return txt.UpperFirst(t.String())
}
// Is compares the method with another type.
func (t MethodType) Is(methodType MethodType) bool {
return t == methodType
}
// Method casts a string to a normalized method type.
func Method(s string) MethodType {
s = clean.TypeLower(s)
switch s {
case "":
return MethodUndefined
case "-", "null", "nil", "0", "false":
return MethodDefault
case "oauth2", "oauth":
return MethodOAuth2
case "sso":
return MethodOIDC
case "2fa", "mfa", "otp", "totp":
return Method2FA
case "access_token":
return MethodDefault
default:
return MethodType(s)
}
// IsNot checks if the method is not the specified type.
func (t MethodType) IsNot(methodType MethodType) bool {
return t != methodType
}
// IsUndefined checks if the method is undefined.
func (t MethodType) IsUndefined() bool {
return t == ""
}
// IsDefault checks if this is the default method.
func (t MethodType) IsDefault() bool {
return t.String() == MethodDefault.String()
}
// IsSession checks if this is the session method.
func (t MethodType) IsSession() bool {
return t.String() == MethodSession.String()
}

View File

@@ -8,7 +8,6 @@ import (
func TestMethodType_String(t *testing.T) {
assert.Equal(t, "default", MethodDefault.String())
assert.Equal(t, "personal", MethodPersonal.String())
assert.Equal(t, "oauth2", MethodOAuth2.String())
assert.Equal(t, "oidc", MethodOIDC.String())
assert.Equal(t, "2fa", Method2FA.String())
@@ -17,8 +16,6 @@ func TestMethodType_String(t *testing.T) {
func TestMethodType_Is(t *testing.T) {
assert.Equal(t, true, MethodDefault.Is(MethodDefault))
assert.Equal(t, false, MethodPersonal.Is(MethodDefault))
assert.Equal(t, false, MethodOAuth2.Is(MethodPersonal))
assert.Equal(t, false, MethodOIDC.Is(MethodOAuth2))
assert.Equal(t, false, Method2FA.Is(MethodOIDC))
assert.Equal(t, true, MethodOAuth2.Is(MethodOAuth2))
@@ -30,7 +27,6 @@ func TestMethodType_Is(t *testing.T) {
func TestMethodType_IsNot(t *testing.T) {
assert.Equal(t, true, MethodDefault.IsNot(MethodUndefined))
assert.Equal(t, false, MethodDefault.IsNot(MethodDefault))
assert.Equal(t, false, MethodPersonal.IsNot(MethodPersonal))
assert.Equal(t, false, MethodOAuth2.IsNot(MethodOAuth2))
assert.Equal(t, false, MethodOIDC.IsNot(MethodOIDC))
assert.Equal(t, false, Method2FA.IsNot(Method2FA))
@@ -47,7 +43,6 @@ func TestMethodType_IsUndefined(t *testing.T) {
func TestMethodType_IsDefault(t *testing.T) {
assert.Equal(t, true, MethodDefault.IsDefault())
assert.Equal(t, false, MethodPersonal.IsDefault())
assert.Equal(t, false, MethodOAuth2.IsDefault())
assert.Equal(t, false, MethodOIDC.IsDefault())
assert.Equal(t, false, Method2FA.IsDefault())
@@ -56,7 +51,6 @@ func TestMethodType_IsDefault(t *testing.T) {
func TestMethodType_Pretty(t *testing.T) {
assert.Equal(t, "Default", MethodDefault.Pretty())
assert.Equal(t, "Personal", MethodPersonal.Pretty())
assert.Equal(t, "OAuth2", MethodOAuth2.Pretty())
assert.Equal(t, "OIDC", MethodOIDC.Pretty())
assert.Equal(t, "2FA", Method2FA.Pretty())

View File

@@ -1,8 +1,6 @@
package authn
import (
"strings"
"github.com/photoprism/photoprism/pkg/clean"
"github.com/photoprism/photoprism/pkg/list"
"github.com/photoprism/photoprism/pkg/txt"
@@ -16,7 +14,6 @@ const (
ProviderUndefined ProviderType = ""
ProviderDefault ProviderType = "default"
ProviderClient ProviderType = "client"
ProviderClientCredentials ProviderType = "client_credentials"
ProviderApplication ProviderType = "application"
ProviderAccessToken ProviderType = "access_token"
ProviderLocal ProviderType = "local"
@@ -45,19 +42,79 @@ var Method2FAProviders = list.List{
// ClientProviders contains all client auth providers.
var ClientProviders = list.List{
string(ProviderClient),
string(ProviderClientCredentials),
string(ProviderApplication),
string(ProviderAccessToken),
}
// Provider casts a string to a normalized provider type.
func Provider(s string) ProviderType {
s = clean.TypeLowerUnderscore(s)
switch s {
case "", "_", "-", "null", "nil", "0", "false":
return ProviderDefault
case "token", "url":
return ProviderLink
case "pass", "passwd", "password":
return ProviderLocal
case "app", "application":
return ProviderApplication
case "ldap", "ad", "ldap/ad", "ldap\\ad":
return ProviderLDAP
case "client", "client_credentials", "oauth2":
return ProviderClient
default:
return ProviderType(s)
}
}
// Pretty returns the provider identifier in an easy-to-read format.
func (t ProviderType) Pretty() string {
switch t {
case ProviderLDAP:
return "LDAP/AD"
case ProviderClient:
return "Client"
case ProviderAccessToken:
return "Access Token"
default:
return txt.UpperFirst(t.String())
}
}
// String returns the provider identifier as a string.
func (t ProviderType) String() string {
switch t {
case "":
return string(ProviderDefault)
case "token":
return string(ProviderLink)
case "password":
return string(ProviderLocal)
case "client", "client credentials", "client_credentials", "oauth2":
return string(ProviderClient)
default:
return string(t)
}
}
// Equal checks if the type matches the specified string.
func (t ProviderType) Equal(s string) bool {
return t == Provider(s)
}
// NotEqual checks if the type does not match the specified string.
func (t ProviderType) NotEqual(s string) bool {
return !t.Equal(s)
}
// Is compares the provider with another type.
func (t ProviderType) Is(provider ProviderType) bool {
return t == provider
func (t ProviderType) Is(providerType ProviderType) bool {
return t == providerType
}
// IsNot checks if the provider is not the specified type.
func (t ProviderType) IsNot(provider ProviderType) bool {
return t != provider
func (t ProviderType) IsNot(providerType ProviderType) bool {
return t != providerType
}
// IsUndefined checks if the provider is undefined.
@@ -94,64 +151,3 @@ func (t ProviderType) IsApplication() bool {
func (t ProviderType) IsDefault() bool {
return t.String() == ProviderDefault.String()
}
// String returns the provider identifier as a string.
func (t ProviderType) String() string {
switch t {
case "":
return string(ProviderDefault)
case "token":
return string(ProviderLink)
case "password":
return string(ProviderLocal)
case "oauth2", "client credentials":
return string(ProviderClientCredentials)
default:
return string(t)
}
}
// Equal checks if the type matches.
func (t ProviderType) Equal(s string) bool {
return strings.EqualFold(s, t.String())
}
// NotEqual checks if the type is different.
func (t ProviderType) NotEqual(s string) bool {
return !t.Equal(s)
}
// Pretty returns the provider identifier in an easy-to-read format.
func (t ProviderType) Pretty() string {
switch t {
case ProviderLDAP:
return "LDAP/AD"
case ProviderClient:
return "Client"
case ProviderAccessToken:
return "Access Token"
case ProviderClientCredentials:
return "Client Credentials"
default:
return txt.UpperFirst(t.String())
}
}
// Provider casts a string to a normalized provider type.
func Provider(s string) ProviderType {
s = clean.TypeLower(s)
switch s {
case "", "-", "null", "nil", "0", "false":
return ProviderDefault
case "token", "url":
return ProviderLink
case "pass", "passwd", "password":
return ProviderLocal
case "ldap", "ad", "ldap/ad", "ldap\\ad":
return ProviderLDAP
case "oauth2", "client credentials":
return ProviderClientCredentials
default:
return ProviderType(s)
}
}

View File

@@ -14,14 +14,13 @@ func TestProviderType_String(t *testing.T) {
assert.Equal(t, "ldap", ProviderLDAP.String())
assert.Equal(t, "link", ProviderLink.String())
assert.Equal(t, "access_token", ProviderAccessToken.String())
assert.Equal(t, "client_credentials", ProviderClientCredentials.String())
assert.Equal(t, "client", ProviderClient.String())
}
func TestProviderType_Is(t *testing.T) {
assert.False(t, ProviderLocal.Is(ProviderLDAP))
assert.True(t, ProviderLDAP.Is(ProviderLDAP))
assert.False(t, ProviderClient.Is(ProviderLDAP))
assert.False(t, ProviderClientCredentials.Is(ProviderLDAP))
assert.False(t, ProviderApplication.Is(ProviderLDAP))
assert.False(t, ProviderAccessToken.Is(ProviderLDAP))
assert.False(t, ProviderNone.Is(ProviderLDAP))
@@ -33,7 +32,6 @@ func TestProviderType_IsNot(t *testing.T) {
assert.False(t, ProviderLocal.IsNot(ProviderLocal))
assert.True(t, ProviderLDAP.IsNot(ProviderLocal))
assert.False(t, ProviderClient.IsNot(ProviderClient))
assert.False(t, ProviderClientCredentials.IsNot(ProviderClientCredentials))
assert.False(t, ProviderApplication.IsNot(ProviderApplication))
assert.False(t, ProviderAccessToken.IsNot(ProviderAccessToken))
assert.False(t, ProviderNone.IsNot(ProviderNone))
@@ -50,7 +48,6 @@ func TestProviderType_IsRemote(t *testing.T) {
assert.False(t, ProviderLocal.IsRemote())
assert.True(t, ProviderLDAP.IsRemote())
assert.False(t, ProviderClient.IsRemote())
assert.False(t, ProviderClientCredentials.IsRemote())
assert.False(t, ProviderApplication.IsRemote())
assert.False(t, ProviderAccessToken.IsRemote())
assert.False(t, ProviderNone.IsRemote())
@@ -62,7 +59,6 @@ func TestProviderType_IsLocal(t *testing.T) {
assert.True(t, ProviderLocal.IsLocal())
assert.False(t, ProviderLDAP.IsLocal())
assert.False(t, ProviderClient.IsLocal())
assert.False(t, ProviderClientCredentials.IsLocal())
assert.False(t, ProviderApplication.IsLocal())
assert.False(t, ProviderAccessToken.IsLocal())
assert.False(t, ProviderNone.IsLocal())
@@ -74,7 +70,6 @@ func TestProviderType_SupportsPasscode(t *testing.T) {
assert.True(t, ProviderLocal.Supports2FA())
assert.True(t, ProviderLDAP.Supports2FA())
assert.False(t, ProviderClient.Supports2FA())
assert.False(t, ProviderClientCredentials.Supports2FA())
assert.False(t, ProviderApplication.Supports2FA())
assert.False(t, ProviderAccessToken.Supports2FA())
assert.False(t, ProviderNone.Supports2FA())
@@ -96,16 +91,17 @@ func TestProviderType_IsClient(t *testing.T) {
assert.False(t, ProviderNone.IsClient())
assert.False(t, ProviderDefault.IsClient())
assert.True(t, ProviderClient.IsClient())
assert.True(t, ProviderClientCredentials.IsClient())
}
func TestProviderType_Equal(t *testing.T) {
assert.True(t, ProviderClient.Equal("Client"))
assert.True(t, ProviderClient.Equal("Client Credentials"))
assert.False(t, ProviderLocal.Equal("Client"))
}
func TestProviderType_NotEqual(t *testing.T) {
assert.False(t, ProviderClient.NotEqual("Client"))
assert.False(t, ProviderClient.NotEqual("Client Credentials"))
assert.True(t, ProviderLocal.NotEqual("Client"))
}
@@ -115,9 +111,8 @@ func TestProviderType_Pretty(t *testing.T) {
assert.Equal(t, "None", ProviderNone.Pretty())
assert.Equal(t, "Default", ProviderDefault.Pretty())
assert.Equal(t, "Default", ProviderUndefined.Pretty())
assert.Equal(t, "Client", ProviderClient.Pretty())
assert.Equal(t, "Access Token", ProviderAccessToken.Pretty())
assert.Equal(t, "Client Credentials", ProviderClientCredentials.Pretty())
assert.Equal(t, "Client", ProviderClient.Pretty())
}
func TestProvider(t *testing.T) {
@@ -126,7 +121,7 @@ func TestProvider(t *testing.T) {
assert.Equal(t, ProviderDefault, Provider(""))
assert.Equal(t, ProviderLink, Provider("url"))
assert.Equal(t, ProviderDefault, Provider("default"))
assert.Equal(t, ProviderClientCredentials, Provider("oauth2"))
assert.Equal(t, ProviderClient, Provider("oauth2"))
}
func TestProviderType_IsApplication(t *testing.T) {