Files
photoprism/internal/config/config_db.go
2025-11-22 20:00:53 +01:00

666 lines
17 KiB
Go

package config
import (
"errors"
"fmt"
"os"
"path/filepath"
"runtime"
"strconv"
"strings"
"time"
"github.com/jinzhu/gorm"
_ "github.com/jinzhu/gorm/dialects/mysql" // register mysql dialect
_ "github.com/jinzhu/gorm/dialects/sqlite"
"golang.org/x/mod/semver"
"github.com/photoprism/photoprism/internal/entity"
"github.com/photoprism/photoprism/internal/entity/migrate"
"github.com/photoprism/photoprism/internal/event"
"github.com/photoprism/photoprism/internal/mutex"
"github.com/photoprism/photoprism/internal/service/cluster"
"github.com/photoprism/photoprism/pkg/clean"
"github.com/photoprism/photoprism/pkg/dsn"
"github.com/photoprism/photoprism/pkg/txt"
)
// SQL Databases.
//
// Driver name constants used when selecting the database backend. Except for
// Auto, they alias the driver names declared in the dsn package.
// TODO: PostgreSQL support requires upgrading GORM, so generic column data types can be used.
const (
Auto = "auto" // NOTE(review): presumably requests automatic driver selection; not referenced in this part of the file.
MySQL = dsn.DriverMySQL // MySQL-compatible driver; DatabaseDriver() also reports MariaDB under this name.
MariaDB = dsn.DriverMariaDB // Accepted as a configuration value and mapped to MySQL by DatabaseDriver().
Postgres = dsn.DriverPostgres // Not handled by DatabaseDriver() yet, see the TODO above.
SQLite3 = dsn.DriverSQLite3 // DatabaseDriver() falls back to this driver for empty or unsupported values.
)
// DatabaseDriver returns the normalized database driver name, which is either MySQL or SQLite3.
//
// The configured value is matched case-insensitively and written back to the options:
// MySQL and MariaDB both resolve to MySQL, while empty values and the SQLite aliases resolve
// to SQLite3. Deprecated or unsupported drivers trigger a warning, fall back to SQLite3,
// and discard the configured DSN because it cannot belong to the fallback driver.
func (c *Config) DatabaseDriver() string {
	c.normalizeDatabaseDSN()

	configured := c.options.DatabaseDriver

	switch name := strings.ToLower(configured); name {
	case SQLite3, "sqlite", "test", "file", "":
		c.options.DatabaseDriver = SQLite3
	case MySQL, MariaDB:
		c.options.DatabaseDriver = MySQL
	default:
		if name == "tidb" {
			log.Warnf("config: database driver 'tidb' is deprecated, using sqlite")
		} else {
			log.Warnf("config: unsupported database driver %s, using sqlite", configured)
		}

		c.options.DatabaseDriver = SQLite3
		c.options.DatabaseDSN = ""
	}

	return c.options.DatabaseDriver
}
// DatabaseDriverName returns a human-readable name for the driver reported by DatabaseDriver,
// or "unsupported database" if the driver is not recognized.
func (c *Config) DatabaseDriverName() string {
	name := "unsupported database"

	switch c.DatabaseDriver() {
	case "tidb":
		name = "TiDB"
	case SQLite3, "sqlite", "test", "file", "":
		name = "SQLite"
	case MySQL, MariaDB:
		name = "MariaDB"
	}

	return name
}
// DatabaseVersion returns the database version string, if known.
// The value is stored by checkDb after querying the server, so it is empty
// before a connection has been checked or if the version check was skipped.
func (c *Config) DatabaseVersion() string {
return c.dbVersion
}
// IsDatabaseVersion reports whether the detected database version is equal to or
// newer than semverVersion. An empty semverVersion is always satisfied. Versions are
// compared with semver.Compare, so semverVersion is expected to use the "v1.2.3" form.
func (c *Config) IsDatabaseVersion(semverVersion string) bool {
	return semverVersion == "" || semver.Compare(c.DatabaseVersion(), semverVersion) >= 0
}
// DatabaseSsl reports whether the database supports SSL connections for backup and restore.
// It returns false if the server version is unknown or the driver is not MySQL.
func (c *Config) DatabaseSsl() bool {
	if c.dbVersion == "" || c.DatabaseDriver() != MySQL {
		return false
	}

	// see https://mariadb.org/mission-impossible-zero-configuration-ssl/
	return c.IsDatabaseVersion("v11.4")
}
// normalizeDatabaseDSN copies the deprecated DatabaseDsn option to DatabaseDSN if only
// the deprecated one is set, clears the deprecated value, and emits a system warning.
// It does nothing if DatabaseDSN is already set or no deprecated value exists.
func (c *Config) normalizeDatabaseDSN() {
	legacy := c.options.Deprecated.DatabaseDsn

	if legacy == "" || c.options.DatabaseDSN != "" {
		return
	}

	c.options.DatabaseDSN = legacy
	c.options.Deprecated.DatabaseDsn = ""

	event.SystemWarn([]string{"config", "options", "DatabaseDsn has been deprecated in favor of DatabaseDSN"})
}
// DatabaseDSN returns the database data source name (DSN).
//
// A manually configured DSN is returned as is, except that the default MySQL/MariaDB
// parameters and the connection timeout are appended (and stored) if it has no query string.
// Without a configured DSN, a matching one is generated from the individual database
// options for the current driver. An empty string is returned if that is not possible.
func (c *Config) DatabaseDSN() string {
	// Use the configured DSN, if any.
	if !c.NoDatabaseDSN() {
		// If missing, add the required parameters to the configured MySQL/MariaDB DSN.
		if c.DatabaseDriver() == MySQL && !strings.Contains(c.options.DatabaseDSN, "?") {
			c.options.DatabaseDSN = fmt.Sprintf(
				"%s?%s&timeout=%ds",
				c.options.DatabaseDSN,
				dsn.Params[dsn.DriverMySQL],
				c.DatabaseTimeout())
		}

		return c.options.DatabaseDSN
	}

	// Otherwise, generate a DSN that matches the configured database driver.
	switch c.DatabaseDriver() {
	case SQLite3:
		storagePath := c.StoragePath()
		return filepath.Join(storagePath, fmt.Sprintf("index.db?%s", dsn.Params[dsn.DriverSQLite3]))
	case MySQL:
		address := c.DatabaseServer()

		// A server that starts with a slash is treated as the path of a Unix domain socket.
		if strings.HasPrefix(address, "/") {
			log.Debugf("mariadb: connecting via Unix domain socket")
			address = fmt.Sprintf("unix(%s)", address)
		} else {
			address = fmt.Sprintf("tcp(%s)", address)
		}

		return fmt.Sprintf(
			"%s:%s@%s/%s?%s&timeout=%ds",
			c.DatabaseUser(),
			c.DatabasePassword(),
			address,
			c.DatabaseName(),
			dsn.Params[dsn.DriverMySQL],
			c.DatabaseTimeout(),
		)
	case Postgres:
		return fmt.Sprintf(
			"user=%s password=%s dbname=%s host=%s port=%d connect_timeout=%d %s",
			c.DatabaseUser(),
			c.DatabasePassword(),
			c.DatabaseName(),
			c.DatabaseHost(),
			c.DatabasePort(),
			c.DatabaseTimeout(),
			dsn.Params[dsn.DriverPostgres],
		)
	}

	log.Errorf("config: empty database dsn")

	return ""
}
// NoDatabaseDSN checks if no manual database data source name (DSN) configuration is set.
// A value found in the deprecated DatabaseDsn option is migrated first, so it counts as configured.
func (c *Config) NoDatabaseDSN() bool {
// Map the deprecated option to its current counterpart before checking.
c.normalizeDatabaseDSN()
return c.options.DatabaseDSN == ""
}
// HasDatabaseDSN checks if a manual database data source name (DSN) configuration is set.
// It is the negation of NoDatabaseDSN.
func (c *Config) HasDatabaseDSN() bool {
return !c.NoDatabaseDSN()
}
// ReportDatabaseDSN checks if the database data source name (DSN) should be reported
// instead of database name, server, user, and password. This is always the case for
// SQLite, and for other drivers only if a DSN has been configured.
func (c *Config) ReportDatabaseDSN() bool {
	return c.DatabaseDriver() == SQLite3 || c.HasDatabaseDSN()
}
// ParseDatabaseDSN parses the configured database DSN and stores the extracted name,
// server, user, and password in the options. Nothing is changed if no DSN is configured,
// or if a database server is already set while the driver is SQLite.
func (c *Config) ParseDatabaseDSN() {
	if c.NoDatabaseDSN() {
		return
	}

	if c.options.DatabaseServer != "" && c.DatabaseDriver() == SQLite3 {
		return
	}

	parsed := dsn.Parse(c.options.DatabaseDSN)

	c.options.DatabaseUser = parsed.User
	c.options.DatabasePassword = parsed.Password
	c.options.DatabaseServer = parsed.Server
	c.options.DatabaseName = parsed.Name
}
// DatabaseFile returns the filename part of a sqlite database DSN, i.e. the DSN
// without a leading "file:" prefix and without the query string that follows the first "?".
func (c *Config) DatabaseFile() string {
	name := strings.TrimPrefix(c.DatabaseDSN(), "file:")

	if i := strings.Index(name, "?"); i >= 0 {
		return name[:i]
	}

	return name
}
// DatabaseServer returns the database server address. It is empty for SQLite and
// defaults to localhost if no server has been configured for other drivers.
func (c *Config) DatabaseServer() string {
	c.ParseDatabaseDSN()

	switch {
	case c.DatabaseDriver() == SQLite3:
		return ""
	case c.options.DatabaseServer == "":
		return localhost
	default:
		return c.options.DatabaseServer
	}
}
// DatabaseHost returns the database server host as parsed from the DSN,
// or an empty string if the driver is SQLite.
func (c *Config) DatabaseHost() string {
	c.ParseDatabaseDSN()

	if c.DatabaseDriver() != SQLite3 {
		parsed := dsn.Parse(c.DatabaseDSN())
		return parsed.Host()
	}

	return ""
}
// DatabasePort returns the database server port as parsed from the DSN,
// or 0 if the driver is SQLite.
func (c *Config) DatabasePort() int {
	c.ParseDatabaseDSN()

	if c.DatabaseDriver() != SQLite3 {
		parsed := dsn.Parse(c.DatabaseDSN())
		return parsed.Port()
	}

	return 0
}
// DatabasePortString returns the database server port formatted as a decimal string,
// or an empty string if the driver is SQLite.
func (c *Config) DatabasePortString() string {
	if c.DatabaseDriver() != SQLite3 {
		return strconv.Itoa(c.DatabasePort())
	}

	return ""
}
// DatabaseName returns the database schema name, which defaults to "photoprism".
// For SQLite, the database DSN is returned instead.
func (c *Config) DatabaseName() string {
	c.ParseDatabaseDSN()

	switch {
	case c.DatabaseDriver() == SQLite3:
		return c.DatabaseDSN()
	case c.options.DatabaseName == "":
		return "photoprism"
	default:
		return c.options.DatabaseName
	}
}
// DatabaseUser returns the database user name, which defaults to "photoprism".
// It is always empty for SQLite.
func (c *Config) DatabaseUser() string {
	if c.DatabaseDriver() == SQLite3 {
		return ""
	}

	c.ParseDatabaseDSN()

	if user := c.options.DatabaseUser; user != "" {
		return user
	}

	return "photoprism"
}
// DatabasePassword returns the database user password.
//
// It is always empty for SQLite. Otherwise, the password from the options (or the
// configured DSN) is used. If none is set, the password is read from the file that
// FlagFilePath reports for "DATABASE_PASSWORD". A missing file setting is not an error;
// an unreadable or empty file is logged as a warning and yields an empty password.
func (c *Config) DatabasePassword() string {
	if c.DatabaseDriver() == SQLite3 {
		return ""
	}

	c.ParseDatabaseDSN()

	if c.options.DatabasePassword != "" {
		return clean.Password(c.options.DatabasePassword)
	}

	// Try to read password from file if c.options.DatabasePassword is not set.
	fileName := FlagFilePath("DATABASE_PASSWORD")

	if fileName == "" {
		// No password set, this is not an error.
		return ""
	}

	b, err := os.ReadFile(fileName) //nolint:gosec // path derived from environment variable for DB password

	if err != nil {
		log.Warnf("config: failed to read database password from %s (%s)", fileName, err)
		return ""
	}

	// Report empty files separately, since there is no error to include in the message.
	if len(b) == 0 {
		log.Warnf("config: database password file %s is empty", fileName)
		return ""
	}

	return clean.Password(string(b))
}
// DatabaseProvisionPrefix returns the sanitized prefix for provisioned database names and users.
//
// The configured value is trimmed and lowercased. Only the letters a-z, digits, and
// underscores are kept: leading digits and separators are dropped, and each run of
// '_', '-', or ' ' becomes a single underscore. Sanitizing stops once the maximum
// length is reached. The sanitized value is written back to the options. If the
// configured or sanitized value is empty, the cluster default is returned instead.
func (c *Config) DatabaseProvisionPrefix() string {
	raw := strings.TrimSpace(c.options.DatabaseProvisionPrefix)

	if raw == "" {
		return cluster.DefaultDatabaseProvisionPrefix
	}

	// Only single-byte ASCII characters are written, so Len() equals the character count.
	var b strings.Builder

	lastWasSeparator := false

	for _, r := range strings.ToLower(raw) {
		isLetter := r >= 'a' && r <= 'z'
		isDigit := r >= '0' && r <= '9'
		isSeparator := r == '_' || r == '-' || r == ' '

		switch {
		case isLetter, isDigit && b.Len() > 0:
			b.WriteRune(r)
			lastWasSeparator = false
		case isSeparator && b.Len() > 0 && !lastWasSeparator:
			b.WriteRune('_')
			lastWasSeparator = true
		default:
			// Skip unsupported characters, leading digits, and repeated or leading separators.
			continue
		}

		if b.Len() >= cluster.DatabaseProvisionPrefixMaxLen {
			break
		}
	}

	if b.Len() == 0 {
		return cluster.DefaultDatabaseProvisionPrefix
	}

	result := b.String()
	c.options.DatabaseProvisionPrefix = result

	return result
}
// ShouldAutoRotateDatabase decides whether callers should request DB rotation automatically.
// It is used by both the CLI and node bootstrap to avoid unnecessary provisioning calls.
// Rotation is never requested on a portal or for drivers other than MySQL; otherwise,
// it is requested if the database name, user, or password is empty.
func (c *Config) ShouldAutoRotateDatabase() bool {
	if c.Portal() || c.DatabaseDriver() != MySQL {
		return false
	}

	return c.DatabaseName() == "" || c.DatabaseUser() == "" || c.DatabasePassword() == ""
}
// DatabaseTimeout returns the TCP timeout in seconds for establishing a database connection:
// - https://github.com/photoprism/photoprism/issues/4059#issuecomment-1989119004
// - https://github.com/go-sql-driver/mysql/blob/master/README.md#timeout
//
// Values of zero or less select the default of 15 seconds, and values above
// 60 are capped at 60 seconds.
func (c *Config) DatabaseTimeout() int {
	const (
		defaultSeconds = 15
		maxSeconds     = 60
	)

	switch seconds := c.options.DatabaseTimeout; {
	case seconds <= 0:
		return defaultSeconds
	case seconds > maxSeconds:
		return maxSeconds
	default:
		return seconds
	}
}
// DatabaseConns returns the maximum number of open connections to the database.
// If no positive limit is configured, it defaults to twice the number of CPUs plus 16.
// The result never exceeds 1024.
func (c *Config) DatabaseConns() int {
	const maxConns = 1024

	conns := c.options.DatabaseConns

	if conns < 1 {
		conns = 16 + 2*runtime.NumCPU()
	}

	if conns <= maxConns {
		return conns
	}

	return maxConns
}
// DatabaseConnsIdle returns the maximum number of idle connections to the database.
// If no positive limit is configured, it defaults to the number of CPUs plus 8.
// The result never exceeds the limit returned by DatabaseConns.
func (c *Config) DatabaseConnsIdle() int {
	idle := c.options.DatabaseConnsIdle

	if idle < 1 {
		idle = 8 + runtime.NumCPU()
	}

	if open := c.DatabaseConns(); idle > open {
		return open
	}

	return idle
}
// Db returns the database connection and logs a fatal error if the
// database has not been connected yet.
func (c *Config) Db() *gorm.DB {
	if c.db != nil {
		return c.db
	}

	log.Fatal("config: database not connected")

	return c.db
}
// CloseDb closes the database connection, if any. After a successful close, the
// connection is discarded and the entity database provider is unset. If closing
// fails, the error is returned and the connection is kept.
func (c *Config) CloseDb() error {
	if c.db == nil {
		return nil
	}

	if err := c.db.Close(); err != nil {
		return err
	}

	c.db = nil
	entity.SetDbProvider(nil)

	return nil
}
// SetDbOptions sets the table options for MySQL/MariaDB so that tables use the InnoDB
// engine with a utf8mb4 Unicode collation. Other drivers do not need any options:
// PostgreSQL is ignored for now, and SQLite uses Unicode by default.
func (c *Config) SetDbOptions() {
	if driver := c.DatabaseDriver(); driver == MySQL || driver == MariaDB {
		// NOTE(review): the value returned by Set is discarded here. In jinzhu/gorm, Set
		// appears to apply the setting to a copy of the handle - confirm that the option
		// takes effect, or whether InstantSet was intended.
		c.Db().Set("gorm:table_options", "ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci")
	}
}
// RegisterDb opens a database connection if needed,
// sets the database options and connection provider.
// If the connection cannot be established, the error is logged
// and neither the options nor the provider are set.
func (c *Config) RegisterDb() {
	if err := c.connectDb(); err != nil {
		// Pass the error to the logger so that the reason for the failure shows up in the logs.
		log.Errorf("config: %s (register db)", err)
		return
	}

	c.SetDbOptions()
	entity.SetDbProvider(c)
}
// InitDb initializes the database without running previously failed migrations.
// It registers the database connection first and then migrates the schema.
func (c *Config) InitDb() {
c.RegisterDb()
// Do not re-run failed migrations and do not restrict the migrations to specific ids.
c.MigrateDb(false, nil)
}
// MigrateDb initializes the database and migrates the schema if necessary. The runFailed
// and ids arguments are passed on to the migration options. It also initializes the admin
// account if an admin password is set, and starts recording warnings and errors afterwards.
func (c *Config) MigrateDb(runFailed bool, ids []string) {
	entity.Admin.UserName = c.AdminUser()

	// Automatically migrate database schema only once per release to reduce startup time.
	db := c.Db()
	version := migrate.FirstOrCreateVersion(db, migrate.NewVersion(c.Version(), c.Edition()))

	opt := migrate.Opt(version.NeedsMigration(), runFailed, ids)
	entity.InitDb(opt)

	if err := version.Migrated(c.Db()); err != nil {
		log.Warnf("config: %s (migrate)", err)
	}

	// Set the password for the initial Super Admin account, if specified.
	if c.AdminPassword() != "" {
		entity.Admin.InitAccount(c.AdminUser(), c.AdminPassword(), c.AdminScope())
	} else {
		log.Warnf("config: %s account cannot be initialized due to missing or invalid password", clean.LogQuote(c.AdminUser()))
	}

	// Start recording warnings and errors after the required database table has been created.
	entity.LogWarningsAndErrors()
}
// InitTestDb resets the test fixtures in the currently configured database, initializes
// the admin account if an admin password is set, and starts recording warnings and errors.
func (c *Config) InitTestDb() {
	entity.ResetTestFixtures()

	// Without an admin password, the account is left untouched.
	if c.AdminPassword() != "" {
		entity.Admin.InitAccount(c.AdminUser(), c.AdminPassword(), c.AdminScope())
	}

	// Start recording warnings and errors after the required tables have been created.
	entity.LogWarningsAndErrors()
}
// checkDb detects the database server version and stores it in the config.
//
// The check is skipped if the DATABASE_SKIP_VERSION_CHECK environment flag is set.
// An error is returned if db is nil, or if a MySQL/MariaDB server reports a version older
// than the supported minimum. Failing to detect the version is logged, but not an error.
func (c *Config) checkDb(db *gorm.DB) error {
	if txt.Bool(os.Getenv(EnvVar("DATABASE_SKIP_VERSION_CHECK"))) {
		log.Debugf("config: skipping database version check")
		return nil
	}

	if db == nil {
		return fmt.Errorf("config: missing database connection")
	}

	// Res receives the "Value" column returned by the version queries.
	type Res struct {
		Value string `gorm:"column:Value;"`
	}

	var res Res

	switch c.DatabaseDriver() {
	case SQLite3:
		// Version query not supported.
		if err := db.Raw("SELECT sqlite_version() AS Value").Scan(&res).Error; err != nil {
			log.Warnf("config: failed to detect database version (%s)", err)
			return nil
		}

		if c.dbVersion = clean.Version(res.Value); c.dbVersion == "" {
			log.Warnf("config: unknown database server version")
		}
	case MySQL:
		err := db.Raw("SELECT VERSION() AS Value").Scan(&res).Error

		// Fall back to the InnoDB version variable if the first query fails.
		if err != nil {
			err = db.Raw("SHOW VARIABLES LIKE 'innodb_version'").Scan(&res).Error
		}

		// Version query not supported.
		if err != nil {
			log.Tracef("config: failed to detect database version (%s)", err)
			return nil
		}

		c.dbVersion = clean.Version(res.Value)

		// Versions below 10 are reported as MySQL, later unsupported versions as MariaDB.
		if c.dbVersion == "" {
			log.Warnf("config: unknown database server version")
		} else if !c.IsDatabaseVersion("v10.0.0") {
			return fmt.Errorf("config: MySQL %s is not supported, see https://docs.photoprism.app/getting-started/#databases", c.dbVersion)
		} else if !c.IsDatabaseVersion("v10.5.12") {
			return fmt.Errorf("config: MariaDB %s is not supported, see https://docs.photoprism.app/getting-started/#databases", c.dbVersion)
		}
	}

	return nil
}
// connectDb establishes a database connection, unless one already exists.
//
// If the first attempt fails, up to 12 more attempts are made with a pause of 5 seconds
// before each of them. The connection limits are applied and the server version is
// checked before the connection is stored. An error is returned if the driver or DSN is
// missing, the database cannot be opened, or the version check fails while unsafe mode
// is disabled; in the latter case, the opened connection is closed again.
func (c *Config) connectDb() error {
	// Make sure this is not running twice.
	mutex.Db.Lock()
	defer mutex.Db.Unlock()

	// Database connection already exists.
	if c.db != nil {
		return nil
	}

	// Get database driver and data source name.
	dbDriver := c.DatabaseDriver()
	dbDsn := c.DatabaseDSN()

	if dbDriver == "" {
		return errors.New("config: database driver not specified")
	}

	if dbDsn == "" {
		return errors.New("config: database DSN not specified")
	}

	// Open database connection.
	db, err := gorm.Open(dbDriver, dbDsn)

	if err != nil || db == nil {
		log.Infof("config: waiting for the database to become available")

		// Wait before each retry, so no time is wasted after the final attempt.
		for retry := 1; retry <= 12 && (err != nil || db == nil); retry++ {
			time.Sleep(5 * time.Second)
			db, err = gorm.Open(dbDriver, dbDsn)
		}

		if err != nil {
			return err
		} else if db == nil {
			// Never report success without a usable connection.
			return errors.New("config: failed to open database connection")
		}
	}

	// Configure database logging.
	db.LogMode(false)
	db.SetLogger(log)

	// Set database connection parameters.
	db.DB().SetMaxOpenConns(c.DatabaseConns())
	db.DB().SetMaxIdleConns(c.DatabaseConnsIdle())
	db.DB().SetConnMaxLifetime(time.Hour)

	// Check database server version.
	if err = c.checkDb(db); err != nil {
		if c.Unsafe() {
			log.Error(err)
		} else {
			// Release the connection, since it will not be stored or used.
			if closeErr := db.Close(); closeErr != nil {
				log.Debugf("config: %s (close database connection)", closeErr)
			}

			return err
		}
	}

	if dbVersion := c.DatabaseVersion(); dbVersion != "" {
		log.Debugf("database: opened connection to %s %s", c.DatabaseDriverName(), dbVersion)
	}

	// Ok.
	c.db = db

	return nil
}
// ImportSQL imports a file to the currently configured database. The file contents are
// split into statements at each ";" followed by a newline, and the statements are run
// one by one. Statements shorter than three bytes or starting with '#' or ';' are
// skipped. A read error is logged, while the results of the statements are not checked.
func (c *Config) ImportSQL(filename string) {
	data, err := os.ReadFile(filename) //nolint:gosec // import path is provided by trusted caller

	if err != nil {
		log.Error(err)
		return
	}

	statements := strings.Split(string(data), ";\n")
	db := c.Db().Unscoped()

	for _, statement := range statements {
		// Skip empty lines and comments
		if len(statement) < 3 {
			continue
		}

		if first := statement[0]; first == '#' || first == ';' {
			continue
		}

		var discard struct{}

		db.Raw(statement).Scan(&discard)
	}
}