Skip to content

Commit

Permalink
Use Literal type and lower case for encodings
Browse files Browse the repository at this point in the history
  • Loading branch information
trossi committed Jun 17, 2024
1 parent 442fd3a commit 42ac9b8
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 16 deletions.
5 changes: 3 additions & 2 deletions rdata/_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import os
from typing import Any

from .conversion.to_r import Encoding
from .unparser import Compression, FileFormat


Expand All @@ -20,7 +21,7 @@ def write_rds(
*,
file_format: FileFormat = "xdr",
compression: Compression = "gzip",
encoding: str = "UTF-8",
encoding: Encoding = "utf-8",
format_version: int = DEFAULT_FORMAT_VERSION,
) -> None:
"""
Expand Down Expand Up @@ -75,7 +76,7 @@ def write_rda(
*,
file_format: FileFormat = "xdr",
compression: Compression = "gzip",
encoding: str = "UTF-8",
encoding: Encoding = "utf-8",
format_version: int = DEFAULT_FORMAT_VERSION,
) -> None:
"""
Expand Down
24 changes: 14 additions & 10 deletions rdata/conversion/to_r.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,15 @@

if TYPE_CHECKING:
from collections.abc import Mapping
from typing import Any, Final, Protocol
from typing import Any, Final, Literal, Protocol

Encoding = Literal["utf-8", "cp1252"]


class Converter(Protocol):
"""Protocol for Py-to-R conversion."""

def __call__(self, data: Any, *, encoding: str) -> RObject: # noqa: ANN401
def __call__(self, data: Any, *, encoding: Encoding) -> RObject: # noqa: ANN401
"""Convert Python object to R object."""


Expand Down Expand Up @@ -91,7 +94,7 @@ def build_r_object(
def build_r_list(
data: Mapping[str, Any] | list[Any],
*,
encoding: str,
encoding: Encoding,
convert_value: Converter | None = None,
) -> RObject:
"""
Expand Down Expand Up @@ -138,7 +141,7 @@ def build_r_list(
def build_r_sym(
data: str,
*,
encoding: str,
encoding: Encoding,
) -> RObject:
"""
Build R object representing symbol.
Expand All @@ -158,7 +161,7 @@ def build_r_sym(
def build_r_data(
r_object: RObject,
*,
encoding: str = "UTF-8",
encoding: Encoding = "utf-8",
format_version: int = DEFAULT_FORMAT_VERSION,
r_version_serialized: int = DEFAULT_R_VERSION_SERIALIZED,
) -> RData:
Expand All @@ -184,7 +187,8 @@ def build_r_data(
)

minimum_version_with_encoding = 3
extra = (RExtraInfo(encoding) if versions.format >= minimum_version_with_encoding
extra = (RExtraInfo(encoding.upper())
if versions.format >= minimum_version_with_encoding
else RExtraInfo(None))

return RData(versions, extra, r_object)
Expand All @@ -193,7 +197,7 @@ def build_r_data(
def convert_to_r_object_for_rda(
data: Mapping[str, Any],
*,
encoding: str = "UTF-8",
encoding: Encoding = "utf-8",
) -> RObject:
"""
Convert Python dictionary to R object for RDA file.
Expand All @@ -217,7 +221,7 @@ def convert_to_r_object_for_rda(
def convert_to_r_object( # noqa: C901, PLR0912, PLR0915
data: Any, # noqa: ANN401
*,
encoding: str = "UTF-8",
encoding: Encoding = "utf-8",
) -> RObject:
"""
Convert Python data to R object.
Expand Down Expand Up @@ -318,9 +322,9 @@ def convert_to_r_object( # noqa: C901, PLR0912, PLR0915
r_type = RObjectType.CHAR
if all(chr(byte) in string.printable for byte in data):
gp = CharFlags.ASCII
elif encoding == "UTF-8":
elif encoding == "utf-8":
gp = CharFlags.UTF8
elif encoding == "CP1252":
elif encoding == "cp1252":
# Note!
# CP1252 and Latin1 are not the same.
# Does CharFlags.LATIN1 mean actually CP1252
Expand Down
12 changes: 8 additions & 4 deletions rdata/tests/test_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from rdata.unparser import unparse_data

if TYPE_CHECKING:
from rdata.conversion.to_r import Encoding
from rdata.unparser import Compression, FileFormat, FileType


Expand Down Expand Up @@ -118,9 +119,12 @@ def test_convert_to_r(fname: str) -> None:
except NotImplementedError as e:
pytest.skip(str(e))

encoding = r_data.extra.encoding
encoding: Encoding
encoding = r_data.extra.encoding # type: ignore [assignment]
if encoding is None:
encoding = "CP1252" if "win" in fname else "UTF-8"
encoding = "cp1252" if "win" in fname else "utf-8"
else:
encoding = encoding.lower() # type: ignore [assignment]

try:
if file_type == "rds":
Expand Down Expand Up @@ -168,13 +172,13 @@ def test_unparse_bad_rda() -> None:
def test_convert_to_r_bad_encoding() -> None:
"""Test checking encoding."""
with pytest.raises(LookupError, match="(?i)unknown encoding"):
rdata.conversion.convert_to_r_object("ä", encoding="non-existent")
rdata.conversion.convert_to_r_object("ä", encoding="non-existent") # type: ignore [arg-type]


def test_convert_to_r_unsupported_encoding() -> None:
"""Test checking encoding."""
with pytest.raises(ValueError, match="(?i)unsupported encoding"):
rdata.conversion.convert_to_r_object("ä", encoding="CP1250")
rdata.conversion.convert_to_r_object("ä", encoding="cp1250") # type: ignore [arg-type]


def test_unparse_big_int() -> None:
Expand Down

0 comments on commit 42ac9b8

Please sign in to comment.