diff --git a/src/banks/env.py b/src/banks/env.py
index e74c460..7ab5b38 100644
--- a/src/banks/env.py
+++ b/src/banks/env.py
@@ -3,14 +3,14 @@
 # SPDX-License-Identifier: MIT
 from jinja2 import Environment, select_autoescape
 
-from banks.extensions import GenerateExtension
+from banks.extensions import GenerateExtension, HFInferenceEndpointsExtension
 from banks.filters import lemmatize
 from banks.loader import MultiLoader
 
 # Init the Jinja env
 env = Environment(
     loader=MultiLoader(),
-    extensions=[GenerateExtension],
+    extensions=[GenerateExtension, HFInferenceEndpointsExtension],
     autoescape=select_autoescape(
         enabled_extensions=("html", "xml"),
         default_for_string=True,
diff --git a/src/banks/extensions/__init__.py b/src/banks/extensions/__init__.py
index 2f5bf74..6046ecd 100644
--- a/src/banks/extensions/__init__.py
+++ b/src/banks/extensions/__init__.py
@@ -2,5 +2,6 @@
 #
 # SPDX-License-Identifier: MIT
 from banks.extensions.generate import GenerateExtension
+from banks.extensions.inference_endpoint import HFInferenceEndpointsExtension
 
-__all__ = ("GenerateExtension",)
+__all__ = ("GenerateExtension", "HFInferenceEndpointsExtension")
diff --git a/src/banks/extensions/inference_endpoint.py b/src/banks/extensions/inference_endpoint.py
new file mode 100644
index 0000000..4e5b597
--- /dev/null
+++ b/src/banks/extensions/inference_endpoint.py
@@ -0,0 +1,81 @@
+# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi
+#
+# SPDX-License-Identifier: MIT
+import html
+import os
+
+import requests
+from jinja2 import nodes
+from jinja2.ext import Extension
+
+# Upper bound, in seconds, on how long a single endpoint call may block
+# template rendering; `requests` waits forever when no timeout is given.
+DEFAULT_TIMEOUT = 60
+
+
+class HFInferenceEndpointsExtension(Extension):
+    """
+    `inference_endpoint` can be used to call the Hugging Face Inference Endpoint API
+    passing a prompt to get back some content.
+
+    The access token is read from the `HF_ACCESS_TOKEN` environment variable and,
+    when set, sent as a bearer token.
+
+    Example:
+    ```
+    {% inference_endpoint "write a tweet with positive sentiment", "https://foo.aws.endpoints.huggingface.cloud" %}
+    Life is beautiful, full of opportunities & positivity
+    ```
+    """
+
+    # a set of names that trigger the extension.
+    tags = {"inference_endpoint"}
+
+    def parse(self, parser):
+        """
+        Turn the `inference_endpoint` tag into an output node that renders
+        the result of calling `_call_endpoint` with the two tag arguments.
+        """
+        # We get the line number of the first token so that we can give
+        # that line number to the nodes we create by hand.
+        lineno = next(parser.stream).lineno
+
+        # The args passed to the extension:
+        # - the prompt text used to generate new text
+        args = [parser.parse_expression()]
+        # - second param after the comma, the inference endpoint URL
+        parser.stream.skip_if("comma")
+        args.append(parser.parse_expression())
+
+        return nodes.Output([self.call_method("_call_endpoint", args)]).set_lineno(lineno)
+
+    def _call_endpoint(self, text, endpoint):
+        """
+        Helper callback: post `text` to `endpoint` and return the generated text.
+
+        Raises `requests.HTTPError` when the endpoint answers with an error
+        status, and returns an empty string when the response carries no
+        `generated_text`.
+        """
+        headers = {}
+        access_token = os.environ.get("HF_ACCESS_TOKEN")
+        if access_token:
+            # Skip the header entirely rather than sending "Bearer None"
+            headers["Authorization"] = f"Bearer {access_token}"
+
+        response = requests.post(
+            endpoint,
+            json={"inputs": text},
+            headers=headers,
+            timeout=DEFAULT_TIMEOUT,
+        )
+        # Fail loudly on 4xx/5xx instead of tripping over the error payload below
+        response.raise_for_status()
+        response_body = response.json()
+
+        # A list response is reduced to its first entry before the lookup
+        if isinstance(response_body, list) and response_body:
+            response_body = response_body[0]
+        if isinstance(response_body, dict):
+            return html.unescape(response_body.get("generated_text", ""))
+        return ""