Skip to content

Commit

Permalink
add HF inference endpoints support
Browse files Browse the repository at this point in the history
  • Loading branch information
masci committed Jun 25, 2023
1 parent 6c54814 commit 7d6d7ae
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 3 deletions.
4 changes: 2 additions & 2 deletions src/banks/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,14 @@
# SPDX-License-Identifier: MIT
from jinja2 import Environment, select_autoescape

from banks.extensions import GenerateExtension
from banks.extensions import GenerateExtension, HFInferenceEndpointsExtension
from banks.filters import lemmatize
from banks.loader import MultiLoader

# Init the Jinja env
env = Environment(
loader=MultiLoader(),
extensions=[GenerateExtension],
extensions=[GenerateExtension, HFInferenceEndpointsExtension],
autoescape=select_autoescape(
enabled_extensions=("html", "xml"),
default_for_string=True,
Expand Down
3 changes: 2 additions & 1 deletion src/banks/extensions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@
#
# SPDX-License-Identifier: MIT
from banks.extensions.generate import GenerateExtension
from banks.extensions.inference_endpoint import HFInferenceEndpointsExtension

__all__ = ("GenerateExtension",)
__all__ = ("GenerateExtension", "HFInferenceEndpointsExtension")
51 changes: 51 additions & 0 deletions src/banks/extensions/inference_endpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# SPDX-FileCopyrightText: 2023-present Massimiliano Pippi <[email protected]>
#
# SPDX-License-Identifier: MIT
import os
import html

import requests
from jinja2 import nodes
from jinja2.ext import Extension


class HFInferenceEndpointsExtension(Extension):
    """
    Jinja extension adding the `inference_endpoint` tag.

    `inference_endpoint` can be used to call the Hugging Face Inference Endpoint API
    passing a prompt to get back some content. The tag takes two expressions:
    the prompt text and the URL of the endpoint to call.

    The access token is read from the `HF_ACCESS_TOKEN` environment variable
    every time the tag is rendered.

    Example:
    ```
    {% inference_endpoint "write a tweet with positive sentiment", "https://foo.aws.endpoints.huggingface.cloud" %}
    Life is beautiful, full of opportunities & positivity
    ```
    """

    # a set of names that trigger the extension.
    tags = {"inference_endpoint"}

    # Seconds to wait for the endpoint before giving up. Without a timeout a
    # stalled endpoint would block template rendering forever. Subclasses can
    # override this value.
    request_timeout = 60

    def parse(self, parser) -> nodes.Output:
        """
        Parse `{% inference_endpoint <prompt>, <endpoint> %}` into an output node.

        :param parser: the Jinja parser, positioned on the tag name token.
        :return: an Output node that calls `_call_endpoint` at render time with
            the two parsed expressions.
        """
        # We get the line number of the first token so that we can give
        # that line number to the nodes we create by hand.
        lineno = next(parser.stream).lineno

        # The args passed to the extension:
        # - the prompt text used to generate new text
        args = [parser.parse_expression()]
        # - second param after the comma, the inference endpoint URL.
        #   The comma is skipped only if present, so it stays optional.
        parser.stream.skip_if("comma")
        args.append(parser.parse_expression())

        return nodes.Output([self.call_method("_call_endpoint", args)]).set_lineno(lineno)

    def _call_endpoint(self, text: str, endpoint: str) -> str:
        """
        Helper callback: POST the prompt to the endpoint and return the generated text.

        :param text: the prompt, sent as the `inputs` field of the JSON payload.
        :param endpoint: URL of the inference endpoint.
        :return: the HTML-unescaped `generated_text` of the first result, or an
            empty string when the response carries no generated text.
        :raises requests.HTTPError: if the endpoint answers with an error status.
        :raises requests.Timeout: if no answer arrives within `request_timeout` seconds.
        :raises ValueError: if the response body is not valid JSON.
        """
        headers = {}
        access_token = os.environ.get("HF_ACCESS_TOKEN")
        # Only authenticate when a token is configured; formatting a missing
        # token would send the literal header "Bearer None".
        if access_token:
            headers["Authorization"] = f"Bearer {access_token}"

        response = requests.post(
            endpoint,
            json={"inputs": text},
            headers=headers,
            timeout=self.request_timeout,
        )
        # Surface HTTP failures explicitly instead of failing later while
        # indexing an error payload.
        response.raise_for_status()
        response_body = response.json()

        # The original code expects a list of result objects; tolerate a
        # single bare object as well.
        if isinstance(response_body, list):
            first_result = response_body[0] if response_body else None
        else:
            first_result = response_body

        if isinstance(first_result, dict):
            return html.unescape(first_result.get("generated_text") or "")
        return ""

0 comments on commit 7d6d7ae

Please sign in to comment.