package classify

import (
	"bytes"
	"fmt"
	"image/color"
	"math"
	"os"
	"path"
	"runtime/debug"
	"sort"
	"strings"
	"sync"

	"github.com/disintegration/imaging"
	tf "github.com/wamuir/graft/tensorflow"

	"github.com/photoprism/photoprism/internal/ai/tensorflow"
	"github.com/photoprism/photoprism/pkg/clean"
	"github.com/photoprism/photoprism/pkg/http/scheme"
	"github.com/photoprism/photoprism/pkg/media"
)

// Model represents a TensorFlow classification model.
type Model struct {
	model             *tf.SavedModel
	name              string
	modelsPath        string
	defaultLabelsPath string
	labels            []string
	disabled          bool
	meta              *tensorflow.ModelInfo
	builder           *tensorflow.ImageTensorBuilder
	mutex             sync.Mutex
}

// NewModel returns a new TensorFlow classification model instance.
func NewModel(modelsPath, name, defaultLabelsPath string, meta *tensorflow.ModelInfo, disabled bool) *Model {
	if meta == nil {
		meta = new(tensorflow.ModelInfo)
	}

	return &Model{
		name:              name,
		modelsPath:        modelsPath,
		defaultLabelsPath: defaultLabelsPath,
		meta:              meta,
		disabled:          disabled,
	}
}

// NewNasnet returns a new Nasnet TensorFlow classification model instance.
func NewNasnet(modelsPath string, disabled bool) *Model {
	return NewModel(modelsPath, "nasnet", "", &tensorflow.ModelInfo{
		TFVersion: "1.12.0",
		Tags:      []string{"photoprism"},
		Input: &tensorflow.PhotoInput{
			Name:              "input_1",
			Height:            224,
			Width:             224,
			ResizeOperation:   tensorflow.CenterCrop,
			ColorChannelOrder: tensorflow.RGB,
			Shape:             tensorflow.DefaultPhotoInputShape(),
			Intervals: []tensorflow.Interval{
				{
					Start: -1,
					End:   1,
				},
			},
			OutputIndex: 0,
		},
		Output: &tensorflow.ModelOutput{
			Name:          "predictions/Softmax",
			NumOutputs:    1000,
			OutputIndex:   0,
			OutputsLogits: false,
		},
	}, disabled)
}

// Init initializes the TensorFlow model if it is not disabled.
func (m *Model) Init() (err error) {
	if m.disabled {
		return nil
	}

	return m.loadModel()
}

// File returns matching labels for a local JPEG file.
func (m *Model) File(fileName string, confidenceThreshold int) (result Labels, err error) {
	if m.disabled {
		return nil, nil
	}

	var data []byte

	if data, err = os.ReadFile(fileName); err != nil { //nolint:gosec // fileName is provided by trusted callers; reading arbitrary local files is expected behavior
		return nil, err
	}

	return m.Run(data, confidenceThreshold)
}

// Url returns matching labels for a remote JPEG file.
func (m *Model) Url(imgUrl string, confidenceThreshold int) (result Labels, err error) {
	if m.disabled {
		return nil, nil
	}

	var data []byte

	if data, err = media.ReadUrl(imgUrl, scheme.HttpsData); err != nil {
		return nil, err
	}

	return m.Run(data, confidenceThreshold)
}

// Run returns matching labels for the specified JPEG image.
func (m *Model) Run(img []byte, confidenceThreshold int) (result Labels, err error) {
	defer func() {
		if r := recover(); r != nil {
			err = fmt.Errorf("classify: %s (inference panic)\nstack: %s", r, debug.Stack())
		}
	}()

	if m.disabled {
		return result, nil
	}

	if loadErr := m.loadModel(); loadErr != nil {
		return nil, loadErr
	}

	// Create input tensor from image.
	tensor, err := m.createTensor(img)

	if err != nil {
		return nil, err
	}

	// Run inference.
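	// Session.Run feeds the image tensor to the model's input operation and
	// fetches the output operation; the operation names and output indices
	// are taken from the model metadata in m.meta.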
	output, err := m.model.Session.Run(
		map[tf.Output]*tf.Tensor{
			m.model.Graph.Operation(m.meta.Input.Name).Output(m.meta.Input.OutputIndex): tensor,
		},
		[]tf.Output{
			m.model.Graph.Operation(m.meta.Output.Name).Output(m.meta.Output.OutputIndex),
		},
		nil)

	if err != nil {
		return result, fmt.Errorf("classify: %s (run inference)", clean.Error(err))
	}

	if len(output) < 1 {
		return result, fmt.Errorf("classify: inference failed, no output")
	}

	// Return the best labels.
	result = m.bestLabels(output[0].Value().([][]float32)[0], confidenceThreshold)

	if len(result) > 0 {
		log.Tracef("classify: image classified as %+v", result)
	} else {
		result = Labels{}
	}

	return result, nil
}

func (m *Model) loadLabels(modelPath string) (err error) {
	numLabels := int(m.meta.Output.NumOutputs)

	m.labels, err = tensorflow.LoadLabels(modelPath, numLabels)

	if os.IsNotExist(err) {
		log.Infof("vision: model does not seem to have tags at %s, trying %s", clean.Log(modelPath), clean.Log(m.defaultLabelsPath))
		m.labels, err = tensorflow.LoadLabels(m.defaultLabelsPath, numLabels)
	}

	if err != nil {
		return fmt.Errorf("classify: could not load tags: %v", err)
	}

	return nil
}

// ModelLoaded tests if the TensorFlow model is loaded.
func (m *Model) ModelLoaded() bool {
	return m.model != nil
}

func (m *Model) loadModel() (err error) {
	// Use mutex to prevent the model from being loaded and
	// initialized twice by different indexing workers.
	m.mutex.Lock()
	defer m.mutex.Unlock()

	if m.ModelLoaded() {
		return nil
	}

	modelPath := path.Join(m.modelsPath, m.name)

	if len(m.meta.Tags) == 0 {
		infos, modelErr := tensorflow.GetModelTagsInfo(modelPath)

		switch {
		case modelErr != nil:
			log.Errorf("classify: could not get info from model in %s (%s)", clean.Log(modelPath), clean.Error(modelErr))
		case len(infos) == 1:
			log.Debugf("classify: model info: %+v", infos[0])
			m.meta.Merge(&infos[0])
		case len(infos) > 1:
			log.Warnf("classify: found %d metagraphs, which is too many", len(infos))
		default:
			log.Warnf("classify: no metagraphs found in %s", clean.Log(modelPath))
		}
	}

	m.model, err = tensorflow.SavedModel(modelPath, m.meta.Tags)

	if err != nil {
		return fmt.Errorf("classify: %s. Path: %s", clean.Error(err), modelPath)
	}
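
	// If the metadata collected so far does not fully describe the model,
	// derive the input and output definitions from the SavedModel signatures,
	// falling back to a heuristic guess if the signatures cannot be read.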
Path: %s", clean.Error(err), modelPath) } if !m.meta.IsComplete() { input, output, modelErr := tensorflow.GetInputAndOutputFromSavedModel(m.model) if modelErr != nil { log.Errorf("classify: could not get info from signatures (%s)", clean.Error(modelErr)) input, output, modelErr = tensorflow.GuessInputAndOutput(m.model) if modelErr != nil { return fmt.Errorf("classify: %s", clean.Error(modelErr)) } } m.meta.Merge(&tensorflow.ModelInfo{ Input: input, Output: output, }) } if m.meta.Output.OutputsLogits { _, err = tensorflow.AddSoftmax(m.model.Graph, m.meta) if err != nil { return fmt.Errorf("classify: could not add softmax (%s)", clean.Error(err)) } } m.builder, err = tensorflow.NewImageTensorBuilder(m.meta.Input) if err != nil { return fmt.Errorf("classify: could not create the tensor builder (%s)", clean.Error(err)) } return m.loadLabels(modelPath) } // bestLabels returns the best 5 labels (if enough high probability labels) from the prediction of the model func (m *Model) bestLabels(probabilities []float32, confidenceThreshold int) Labels { var result Labels for i, p := range probabilities { if i >= len(m.labels) { // break if probabilities and labels does not match break } confidence := int(math.Round(float64(p * 100))) // discard labels with low probabilities if confidence < confidenceThreshold { continue } labelText := strings.ToLower(m.labels[i]) rule, _ := Rules.Find(labelText) // discard labels that don't met the threshold if p < rule.Threshold { continue } // Get rule label name instead of t.labels name if it exists if rule.Label != "" { labelText = rule.Label } labelText = strings.TrimSpace(labelText) result = append(result, Label{Name: labelText, Source: SrcImage, Uncertainty: 100 - confidence, Priority: rule.Priority, Categories: rule.Categories}) } // Sort by probability sort.Sort(result) // Return the best labels only. if l := len(result); l < 5 { return result[:l] } else { return result[:5] } } // createTensor converts bytes jpeg image in a tensor object required as tensorflow model input func (m *Model) createTensor(image []byte) (*tf.Tensor, error) { img, err := imaging.Decode(bytes.NewReader(image), imaging.AutoOrientation(true)) if err != nil { return nil, err } // Resize the image only if its resolution does not match the model. if img.Bounds().Dx() != m.meta.Input.Resolution() || img.Bounds().Dy() != m.meta.Input.Resolution() { switch m.meta.Input.ResizeOperation { case tensorflow.ResizeBreakAspectRatio: img = imaging.Resize(img, m.meta.Input.Resolution(), m.meta.Input.Resolution(), imaging.Lanczos) case tensorflow.CenterCrop: img = imaging.Fill(img, m.meta.Input.Resolution(), m.meta.Input.Resolution(), imaging.Center, imaging.Lanczos) case tensorflow.Padding: resized := imaging.Fit(img, m.meta.Input.Resolution(), m.meta.Input.Resolution(), imaging.Lanczos) dst := imaging.New(m.meta.Input.Resolution(), m.meta.Input.Resolution(), color.NRGBA{0, 0, 0, 255}) img = imaging.PasteCenter(dst, resized) default: img = imaging.Fill(img, m.meta.Input.Resolution(), m.meta.Input.Resolution(), imaging.Center, imaging.Lanczos) } } return tensorflow.Image(img, m.meta.Input, m.builder) }