Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[mob][fix] Enable face to person assignment for faces with low score #4626

Merged
merged 5 commits into from
Jan 8, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 3 additions & 7 deletions mobile/lib/db/ml/db.dart
Original file line number Diff line number Diff line change
Expand Up @@ -604,8 +604,6 @@ class MLDataDB {
}

Future<List<FaceDbInfoForClustering>> getFaceInfoForClustering({
double minScore = kMinimumQualityFaceScore,
int minClarity = kLaplacianHardThreshold,
int maxFaces = 20000,
int offset = 0,
int batchSize = 10000,
Expand All @@ -622,7 +620,7 @@ class MLDataDB {
// Query a batch of rows
final List<Map<String, dynamic>> maps = await db.getAll(
'SELECT $faceIDColumn, $embeddingColumn, $faceScore, $faceBlur, $isSideways FROM $facesTable'
' WHERE $faceScore > $minScore AND $faceBlur > $minClarity'
' WHERE $faceScore > $kMinimumQualityFaceScore AND $faceBlur > $kLaplacianHardThreshold'
' ORDER BY $faceIDColumn'
' DESC LIMIT $batchSize OFFSET $offset',
);
Expand Down Expand Up @@ -698,12 +696,10 @@ class MLDataDB {
return result;
}

Future<int> getTotalFaceCount({
double minFaceScore = kMinimumQualityFaceScore,
}) async {
Future<int> getTotalFaceCount() async {
final db = await instance.asyncDB;
final List<Map<String, dynamic>> maps = await db.getAll(
'SELECT COUNT(*) as count FROM $facesTable WHERE $faceScore > $minFaceScore AND $faceBlur > $kLaplacianHardThreshold',
'SELECT COUNT(*) as count FROM $facesTable WHERE $faceScore > $kMinimumQualityFaceScore AND $faceBlur > $kLaplacianHardThreshold',
);
return maps.first['count'] as int;
}
Expand Down
9 changes: 2 additions & 7 deletions mobile/lib/services/machine_learning/ml_service.dart
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ import "package:photos/services/filedata/filedata_service.dart";
import "package:photos/services/filedata/model/file_data.dart";
import 'package:photos/services/machine_learning/face_ml/face_clustering/face_clustering_service.dart';
import "package:photos/services/machine_learning/face_ml/face_clustering/face_db_info_for_clustering.dart";
import 'package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart';
import "package:photos/services/machine_learning/face_ml/person/person_service.dart";
import "package:photos/services/machine_learning/ml_indexing_isolate.dart";
import 'package:photos/services/machine_learning/ml_result.dart';
Expand Down Expand Up @@ -238,10 +237,7 @@ class MLService {
}
}

Future<void> clusterAllImages({
double minFaceScore = kMinimumQualityFaceScore,
bool clusterInBuckets = true,
}) async {
Future<void> clusterAllImages({bool clusterInBuckets = true}) async {
if (_cannotRunMLFunction()) return;

_logger.info("`clusterAllImages()` called");
Expand Down Expand Up @@ -269,13 +265,12 @@ class MLService {

// Get a sense of the total number of faces in the database
final int totalFaces =
await MLDataDB.instance.getTotalFaceCount(minFaceScore: minFaceScore);
await MLDataDB.instance.getTotalFaceCount();
final fileIDToCreationTime =
await FilesDB.instance.getFileIDToCreationTime();
final startEmbeddingFetch = DateTime.now();
// read all embeddings
final result = await MLDataDB.instance.getFaceInfoForClustering(
minScore: minFaceScore,
maxFaces: totalFaces,
);
final Set<int> missingFileIDs = {};
Expand Down
28 changes: 22 additions & 6 deletions mobile/lib/ui/viewer/file_details/face_widget.dart
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@ import "package:flutter/cupertino.dart";
import "package:flutter/foundation.dart" show kDebugMode;
import "package:flutter/material.dart";
import "package:photos/db/ml/db.dart";
import "package:photos/extensions/stop_watch.dart";
import "package:photos/generated/l10n.dart";
import "package:photos/models/base/id.dart";
import 'package:photos/models/file/file.dart';
import "package:photos/models/ml/face/face.dart";
import "package:photos/models/ml/face/person.dart";
import "package:photos/services/machine_learning/face_ml/face_detection/detection.dart";
import "package:photos/services/machine_learning/face_ml/face_filtering/face_filtering_constants.dart";
import "package:photos/services/machine_learning/face_ml/feedback/cluster_feedback.dart";
import "package:photos/services/search_service.dart";
import "package:photos/theme/ente_theme.dart";
Expand Down Expand Up @@ -68,16 +69,13 @@ class _FaceWidgetState extends State<FaceWidget> {
if (widget.editMode) return;

log(
"FaceWidget is tapped, with person ${widget.person} and clusterID ${widget.clusterID}",
"FaceWidget is tapped, with person ${widget.person?.data.name} and clusterID ${widget.clusterID}",
name: "FaceWidget",
);
if (widget.person == null && widget.clusterID == null) {
// Get faceID and double check that it doesn't belong to an existing clusterID. If it does, push that cluster page
final w = (kDebugMode ? EnteWatch('FaceWidget') : null)
?..start();
// Double check that it doesn't belong to an existing clusterID.
final existingClusterID = await MLDataDB.instance
.getClusterIDForFaceID(widget.face.faceID);
w?.log('getting existing clusterID for faceID');
if (existingClusterID != null) {
final fileIdsToClusterIds =
await MLDataDB.instance.getFileIdToClusterIds();
Expand All @@ -99,6 +97,24 @@ class _FaceWidgetState extends State<FaceWidget> {
),
),
);
return;
}
if (widget.face.score <= kMinimumQualityFaceScore) {
// The face score is too low for automatic clustering,
// assigning a manual new clusterID so that the user can cluster it manually
final String clusterID = newClusterID();
await MLDataDB.instance.updateFaceIdToClusterId(
{widget.face.faceID: clusterID},
);
await Navigator.of(context).push(
MaterialPageRoute(
builder: (context) => ClusterPage(
[widget.file],
clusterID: clusterID,
),
),
);
return;
}

showShortToast(
Expand Down
Loading