diff --git a/prodigy_pdf/__init__.py b/prodigy_pdf/__init__.py index 71ba3de..ab66c7d 100644 --- a/prodigy_pdf/__init__.py +++ b/prodigy_pdf/__init__.py @@ -8,8 +8,8 @@ import pypdfium2 as pdfium from prodigy import recipe, set_hashes, ControllerComponentsDict -from prodigy.components.stream import Stream -from prodigy.util import msg +from prodigy.components.stream import Stream, get_stream +from prodigy.util import msg, split_string def page_to_image(page: pdfium.PdfPage) -> str: @@ -132,7 +132,7 @@ def _validate_ocr_example(ex: Dict): # fmt: off dataset=("Dataset to save answers to", "positional", None, str), source=("Source with PDF Annotations", "positional", None, str), - labels=("Labels to consider", "option", "l", str), + labels=("Labels to consider", "option", "l", split_string), scale=("Zoom scale. Increase above 3 to upscale the image for OCR.", "option", "s", int), remove_base64=("Remove base64-encoded image data", "flag", "R", bool), fold_dashes=("Removes dashes at the end of a textline and folds them with the next term.", "flag", "f", bool), @@ -150,11 +150,10 @@ def pdf_ocr_correct( ) -> ControllerComponentsDict: """Applies OCR to annotated segments and gives a textbox for corrections.""" stream = get_stream(source) - labels = labels.split(",") def new_stream(stream): for ex in stream: - useful_spans = [span for span in ex['spans'] if span['label'] in labels] + useful_spans = [span for span in ex.get('spans', []) if span['label'] in labels] if useful_spans: _validate_ocr_example(ex) pdf = pdfium.PdfDocument(ex['meta']['path'])