Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] - Handle filestream json data in multipart/form-data requests #65

Merged
merged 1 commit into from
Jan 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions flask_pydantic_spec/flask_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,19 +148,28 @@ def request_validation(
else:
parsed_body = request.get_json(silent=True) or {}
elif request.content_type and "multipart/form-data" in request.content_type:
parsed_body = parse_multi_dict(request.form) if request.form else {}
# It's possible there is a binary json object in the files - iterate through and find it
parsed_body = {}
for key, value in request.files.items():
if value.mimetype == "application/json":
parsed_body[key] = json.loads(value.stream.read().decode(encoding="utf-8"))
# Finally, find any JSON objects in the form and add them to the body
parsed_body.update(parse_multi_dict(request.form) or {})
else:
parsed_body = request.get_data() or {}

req_headers: Optional[Headers] = request.headers or None
req_cookies: Optional[Mapping[str, str]] = request.cookies or None
setattr(
request,
"context",
Context(
query=query.parse_obj(req_query) if query else None,
body=getattr(body, "model").parse_obj(parsed_body)
if body and getattr(body, "model")
else None,
body=(
getattr(body, "model").parse_obj(parsed_body)
if body and getattr(body, "model")
else None
),
headers=headers.parse_obj(req_headers or {}) if headers else None,
cookies=cookies.parse_obj(req_cookies or {}) if cookies else None,
),
Expand Down
31 changes: 21 additions & 10 deletions tests/test_plugin_flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@
from io import BytesIO
from random import randint
import gzip
from typing import Union

import pytest
import json
from flask import Flask, jsonify, request
from werkzeug.datastructures import FileStorage
from werkzeug.test import Client

from flask_pydantic_spec.types import Response, MultipartFormRequest
from flask_pydantic_spec import FlaskPydanticSpec
Expand Down Expand Up @@ -111,7 +114,7 @@ def client(request):


@pytest.mark.parametrize("client", [422], indirect=True)
def test_flask_validate(client):
def test_flask_validate(client: Client):
resp = client.get("/ping")
assert resp.status_code == 422
assert resp.headers.get("X-Error") == "Validation Error"
Expand Down Expand Up @@ -158,23 +161,31 @@ def test_flask_validate(client):


@pytest.mark.parametrize("client", [422], indirect=True)
def test_sending_file(client):
@pytest.mark.parametrize(
"data",
[
FileStorage(
BytesIO(json.dumps({"type": "foo", "created_at": str(datetime.now().date())}).encode()),
),
json.dumps({"type": "foo", "created_at": str(datetime.now().date())}),
],
)
def test_sending_file(client: Client, data: Union[FileStorage, str]):
file = FileStorage(BytesIO(b"abcde"), filename="test.jpg", name="test.jpg")
resp = client.post(
"/api/file",
data={
"file": file,
"file_name": "another_test.jpg",
"data": json.dumps({"type": "foo", "created_at": str(datetime.now().date())}),
"data": data,
},
content_type="multipart/form-data",
)
assert resp.status_code == 200
assert resp.json["name"] == "another_test.jpg"


@pytest.mark.parametrize("client", [422], indirect=True)
def test_query_params(client):
def test_query_params(client: Client):
resp = client.get("api/user?name=james&name=bethany&name=claire")
assert resp.status_code == 200
assert len(resp.json["data"]) == 2
Expand All @@ -189,15 +200,15 @@ def test_query_params(client):


@pytest.mark.parametrize("client", [200], indirect=True)
def test_flask_skip_validation(client):
def test_flask_skip_validation(client: Client):
resp = client.get("api/group/test")
assert resp.status_code == 200
assert resp.json["name"] == "test"
assert resp.json["score"] == ["a", "b", "c", "d", "e"]


@pytest.mark.parametrize("client", [422], indirect=True)
def test_flask_doc(client):
def test_flask_doc(client: Client):
resp = client.get("/apidoc/openapi.json")
assert resp.json == api.spec

Expand All @@ -211,7 +222,7 @@ def test_flask_doc(client):


@pytest.mark.parametrize("client", [400], indirect=True)
def test_flask_validate_with_alternative_code(client):
def test_flask_validate_with_alternative_code(client: Client):
resp = client.get("/ping")
assert resp.status_code == 400
assert resp.headers.get("X-Error") == "Validation Error"
Expand All @@ -222,7 +233,7 @@ def test_flask_validate_with_alternative_code(client):


@pytest.mark.parametrize("client", [400], indirect=True)
def test_flask_post_gzip(client):
def test_flask_post_gzip(client: Client):
body = dict(name="flask", limit=10)
compressed = gzip.compress(bytes(json.dumps(body), encoding="utf-8"))

Expand All @@ -240,7 +251,7 @@ def test_flask_post_gzip(client):


@pytest.mark.parametrize("client", [400], indirect=True)
def test_flask_post_gzip_failure(client):
def test_flask_post_gzip_failure(client: Client):
body = dict(name="flask")
compressed = gzip.compress(bytes(json.dumps(body), encoding="utf-8"))

Expand Down
Loading