Skip to content

Commit

Permalink
Merge pull request #133 from unicef/hotfix/send_notification
Browse files — browse the repository at this point in the history
chg ! send_notification
  • Loading branch information
ntrncic authored Jan 15, 2025
2 parents 1396c59 + b7f8b9e commit 549d86b
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 48 deletions.
2 changes: 1 addition & 1 deletion src/hope_dedup_engine/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from hope_dedup_engine.config.celery import app as celery_app

VERSION = __version__ = "0.2.0"
VERSION = __version__ = "0.3.0"

__all__ = ("celery_app",)
2 changes: 1 addition & 1 deletion src/hope_dedup_engine/apps/api/admin/deduplicationset.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ class DeduplicationSetAdmin(AdminFiltersMixin, ModelAdmin):
"updated_by",
"deleted",
)
search_fields = ("name",)
search_fields = ("name", "id")
list_filter = (
("state", ChoicesFieldComboFilter),
("created_at", DateRangeFilter),
Expand Down
9 changes: 5 additions & 4 deletions src/hope_dedup_engine/apps/api/deduplication/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
callback_encodings,
encode_chunk,
get_chunks,
handle_error,
)

# def _sort_keys(pair: DuplicateKeyPair) -> DuplicateKeyPair:
Expand Down Expand Up @@ -80,8 +81,8 @@ def update_job_progress(job: DedupJob, progress: int) -> None:
@shared_task(soft_time_limit=0.5 * HOUR, time_limit=1 * HOUR)
def find_duplicates(dedup_job_id: int, version: int) -> None:
dedup_job: DedupJob = DedupJob.objects.get(pk=dedup_job_id, version=version)
deduplication_set = dedup_job.deduplication_set
try:
deduplication_set = dedup_job.deduplication_set

deduplication_set.state = DeduplicationSet.State.DIRTY
deduplication_set.save(update_fields=["state"])
Expand Down Expand Up @@ -129,6 +130,6 @@ def find_duplicates(dedup_job_id: int, version: int) -> None:
"chord_id": str(chord_id),
"chunks": len(chunks),
}

finally:
send_notification(dedup_job.deduplication_set.notification_url)
except Exception:
handle_error(deduplication_set)
raise
118 changes: 76 additions & 42 deletions src/hope_dedup_engine/apps/faces/celery_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from celery.utils.imports import qualname

from hope_dedup_engine.apps.api.models import DedupJob, DeduplicationSet
from hope_dedup_engine.apps.api.utils.notification import send_notification
from hope_dedup_engine.apps.faces.managers import FileSyncManager
from hope_dedup_engine.apps.faces.services.facial import dedupe_images, encode_faces
from hope_dedup_engine.config.celery import DedupeTask, app
Expand Down Expand Up @@ -45,6 +46,12 @@ def shadow_name(task, args, kwargs, options):
return str(e)


def handle_error(ds: DeduplicationSet) -> None:
    """Shared error path for the dedup Celery tasks.

    Marks *ds* as ``DIRTY`` (persisting only the ``state`` column via
    ``update_fields``) and then calls ``send_notification`` with the
    set's ``notification_url``.  Every task in this module invokes this
    from its ``except Exception`` block and then re-raises, so the
    original exception still propagates to Celery.
    """
    ds.state = DeduplicationSet.State.DIRTY
    ds.save(update_fields=["state"])
    send_notification(ds.notification_url)


@signals.task_prerun.connect
def handle_task_progress(sender=None, task_id=None, dedup_job_id=None, **kwargs):
if not dedup_job_id:
Expand All @@ -63,9 +70,15 @@ def encode_chunk(
) -> tuple[EncodingType, int, int]:
"""Encode faces in a chunk of files."""
ds = DeduplicationSet.objects.get(pk=config.get("deduplication_set_id"))
callback = partial(notify_status, task=self, dedup_job_id=ds.dedupjob.pk)
pre_encodings = ds.get_encodings()
return encode_faces(files, config.get("encoding"), pre_encodings, progress=callback)
try:
callback = partial(notify_status, task=self, dedup_job_id=ds.dedupjob.pk)
pre_encodings = ds.get_encodings()
return encode_faces(
files, config.get("encoding"), pre_encodings, progress=callback
)
except Exception:
handle_error(ds)
raise


@app.task(bind=True, base=DedupeTask)
Expand All @@ -76,17 +89,21 @@ def dedupe_chunk(
) -> FindingType:
"""Deduplicate faces in a chunk of files."""
ds = DeduplicationSet.objects.get(pk=config.get("deduplication_set_id"))
callback = partial(notify_status, task=self, dedup_job_id=ds.dedupjob.pk)
encoded = ds.get_encodings()
ignored_pairs = set(ds.get_ignored_pairs())
return dedupe_images(
files,
encoded,
ignored_pairs,
dedupe_threshold=config.get("deduplicate", {}).get("threshold"),
options=config.get("deduplicate"),
progress=callback,
)
try:
callback = partial(notify_status, task=self, dedup_job_id=ds.dedupjob.pk)
encoded = ds.get_encodings()
ignored_pairs = set(ds.get_ignored_pairs())
return dedupe_images(
files,
encoded,
ignored_pairs,
dedupe_threshold=config.get("deduplicate", {}).get("threshold"),
options=config.get("deduplicate"),
progress=callback,
)
except Exception:
handle_error(ds)
raise


@app.task(bind=True, base=DedupeTask)
Expand All @@ -97,20 +114,29 @@ def callback_findings(
) -> dict[str, Any]:
"""Aggregate and save findings."""
ds = DeduplicationSet.objects.get(pk=config.get("deduplication_set_id"))
seen_pairs = set()
findings = [
record
for d in results
for record in d
if not (pair := tuple(sorted(record[:2]))) in seen_pairs
and not seen_pairs.add(pair)
]
ds.update_findings(findings)
return {
"Files": len(ds.image_set.all()),
"Config": config.get("deduplicate"),
"Findings": len(findings),
}
try:
seen_pairs = set()
findings = [
record
for d in results
for record in d
if not (pair := tuple(sorted(record[:2]))) in seen_pairs
and not seen_pairs.add(pair)
]
ds.update_findings(findings)

ds.state = DeduplicationSet.State.CLEAN
ds.save(update_fields=["state"])
send_notification(ds.notification_url)

return {
"Files": len(ds.image_set.all()),
"Config": config.get("deduplicate"),
"Findings": len(findings),
}
except Exception:
handle_error(ds)
raise


@app.task(bind=True, base=DedupeTask)
Expand All @@ -121,12 +147,16 @@ def callback_encodings(
) -> dict[str, Any]:
"""Aggregate and save encodings."""
ds = DeduplicationSet.objects.get(pk=config.get("deduplication_set_id"))
encodings = dict(ChainMap(*[result[0] for result in results]))
ds.update_encodings(encodings)
deduplicate_dataset.delay(config)
return {
"Encoded": len(encodings),
}
try:
encodings = dict(ChainMap(*[result[0] for result in results]))
ds.update_encodings(encodings)
deduplicate_dataset.delay(config)
return {
"Encoded": len(encodings),
}
except Exception:
handle_error(ds)
raise


@app.task(bind=True, base=DedupeTask)
Expand All @@ -136,14 +166,18 @@ def deduplicate_dataset(
) -> dict[str, Any]:
"""Deduplicate the dataset."""
ds = DeduplicationSet.objects.get(pk=config.get("deduplication_set_id"))
chunks = get_chunks(list(ds.get_encodings().keys()))
tasks = [dedupe_chunk.s(chunk, config) for chunk in chunks]
chord_id = chord(tasks)(callback_findings.s(config=config))
return {
"deduplication_set": str(ds),
"chord_id": str(chord_id),
"chunks": len(chunks),
}
try:
chunks = get_chunks(list(ds.get_encodings().keys()))
tasks = [dedupe_chunk.s(chunk, config) for chunk in chunks]
chord_id = chord(tasks)(callback_findings.s(config=config))
return {
"deduplication_set": str(ds),
"chord_id": str(chord_id),
"chunks": len(chunks),
}
except Exception:
handle_error(ds)
raise


@shared_task(bind=True)
Expand Down

0 comments on commit 549d86b

Please sign in to comment.