Skip to content

Commit

Permalink
Merge pull request #47 from bobleesj/fix-joss
Browse files Browse the repository at this point in the history
Support Materials Project, CCDC .cif files
  • Loading branch information
bobleesj authored Oct 27, 2024
2 parents a2a6f99 + 0e30bc7 commit bc69d6e
Show file tree
Hide file tree
Showing 10 changed files with 487 additions and 97 deletions.
24 changes: 24 additions & 0 deletions news/MP.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
**Added:**

* CifEnsemble support for ICSD, COD, MP files
* Support CCDC CIF files

**Changed:**

* <news item>

**Deprecated:**

* <news item>

**Removed:**

* <news item>

**Fixed:**

* Preprocess .cif files in CifEnsemble before initializing into CIF objects

**Security:**

* <news item>
24 changes: 4 additions & 20 deletions src/cifkit/models/cif.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,4 @@
"""
Import statements placed bottom to avoid cluttering.
"""

import logging

# Polyhedron
import os

# Bond pair
Expand Down Expand Up @@ -42,18 +36,16 @@
# Coordination number
from cifkit.preprocessors.environment_util import flat_site_connections

# Edit .cif file
from cifkit.preprocessors.format import preprocess_label_element_loop_values

# Supercell generation
from cifkit.preprocessors.supercell import get_supercell_points
from cifkit.preprocessors.supercell_util import get_cell_atom_count
from cifkit.utils.bond_pair import get_bond_pairs, get_pairs_sorted_by_mendeleev
from cifkit.utils.cif_editor import add_hashtag_in_first_line, remove_author_loop

# Edit .cif file
from cifkit.utils.cif_editor import edit_cif_file_based_on_db

# Parser .cif file
from cifkit.utils.cif_parser import (
check_unique_atom_site_labels,
get_cif_block,
get_formula_structure_weight_s_group,
get_loop_values,
Expand Down Expand Up @@ -132,15 +124,7 @@ def _log_info(self, message):
def _preprocess(self):
    """Preprocess the .cif file at ``self.file_path`` before it is loaded.

    Logs the preprocessing message and then hands the file to
    ``edit_cif_file_based_on_db``, which detects the source database
    itself, applies the matching edit and checks the atom site labels.
    """
    self._log_info(CifLog.PREPROCESSING.value)
    # Delegate rather than branch on self.db_source here: repeating the
    # per-database edits inline as well would apply each of them twice.
    edit_cif_file_based_on_db(self.file_path)

def _load_data(self):
"""Load data from the .cif file and process it."""
Expand Down
10 changes: 2 additions & 8 deletions src/cifkit/models/cif_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
from cifkit import Cif
from cifkit.figures.histogram import plot_histogram
from cifkit.preprocessors.error import move_files_based_on_errors
from cifkit.preprocessors.format import preprocess_label_element_loop_values
from cifkit.utils.cif_editor import remove_author_loop
from cifkit.utils.cif_editor import edit_cif_file_based_on_db
from cifkit.utils.folder import copy_files, get_file_paths, move_files
from cifkit.utils.log_messages import CifEnsembleLog

Expand All @@ -30,12 +29,7 @@ def __init__(
if preprocess:
self._log_info(CifEnsembleLog.PREPROCESSING.value)
for file_path in file_paths:
try:
remove_author_loop(file_path)
preprocess_label_element_loop_values(file_path)
except Exception as e:
print(f"Error processing {file_path}: {e}")

edit_cif_file_based_on_db(file_path)
# Move ill-formatted files after processing
move_files_based_on_errors(cif_dir_path, file_paths)

Expand Down
21 changes: 21 additions & 0 deletions src/cifkit/utils/cif_editor.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import os

from cifkit.preprocessors.format import preprocess_label_element_loop_values
from cifkit.utils import cif_parser

# Parser .cif file
from cifkit.utils.cif_parser import check_unique_atom_site_labels
from cifkit.utils.cif_sourcer import get_cif_db_source


def remove_author_loop(file_path: str) -> None:
"""
Expand Down Expand Up @@ -47,3 +52,19 @@ def add_hashtag_in_first_line(file_path: str):
# Write the modified content back to the file
with open(file_path, "w") as file:
file.writelines(lines)


def edit_cif_file_based_on_db(file_path: str):
    """
    Apply the database-specific clean-up steps to a CIF file.

    The source database is looked up with ``get_cif_db_source``:
    ICSD: Add a hashtag in the first line
    PCD: Remove author loop and preprocess label element loop values
    Any other source gets no source-specific edit.

    ``check_unique_atom_site_labels`` is then called on the file,
    whatever the source was.
    """
    # Ordered steps per database; a source that is not listed runs none.
    steps_by_source = {
        "ICSD": (add_hashtag_in_first_line,),
        "PCD": (remove_author_loop, preprocess_label_element_loop_values),
    }
    for apply_step in steps_by_source.get(get_cif_db_source(file_path), ()):
        apply_step(file_path)

    check_unique_atom_site_labels(file_path)
2 changes: 2 additions & 0 deletions src/cifkit/utils/cif_sourcer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ def get_cif_db_source(file_path):
"ICSD": "_database_code_ICSD",
"MS": "'Materials Studio'",
"PCD": "#_database_code_PCD",
"MP": "# generated using pymatgen",
"CCDC": "# Cambridge Structural Database (CSD)",
}

if os.path.exists(file_path) and file_path.endswith(".cif"):
Expand Down
84 changes: 36 additions & 48 deletions tests/core/models/test_cif.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,56 +651,44 @@ def test_init_without_mendeeleve_number():


"""
1. Test ICSD file
Test CIF various db sources
"""


@pytest.mark.fast
def test_init_ICSD_file(tmpdir):
    """Initialize Cif from a copy of an ICSD file and check parsed values."""
    source_path = "tests/data/cif/sources/ICSD/EntryWithCollCode43054.cif"
    # Cif is given a copy inside tmpdir, never the fixture file itself.
    scratch_path = os.path.join(tmpdir, os.path.basename(source_path))
    shutil.copyfile(source_path, scratch_path)

    cif = Cif(scratch_path)

    assert cif.db_source == "ICSD"
    assert cif.unique_elements == {"Fe", "Ge"}
    assert cif.CN_unique_values_by_best_methods == {7, 13}


"""
2. Test MS file
"""


@pytest.mark.fast
def test_init_MS_file(tmpdir):
    """Initialize Cif from a copy of an MS file and check parsed values."""
    source_path = "tests/data/cif/sources/MS/U13Rh4.cif"
    # Cif is given a copy inside tmpdir, never the fixture file itself.
    scratch_path = os.path.join(tmpdir, os.path.basename(source_path))
    shutil.copyfile(source_path, scratch_path)

    cif = Cif(scratch_path)

    assert cif.db_source == "MS"
    assert cif.unique_elements == {"U", "Fe"}
    assert cif.supercell_atom_count == 2988


"""
3. Test COD file
"""


@pytest.mark.fast
def test_init_COD_file(tmpdir):
file_path = "tests/data/cif/sources/COD/1010581.cif"

copied_file_path = os.path.join(tmpdir, "1010581.cif")

@pytest.mark.parametrize(
    "file_path, expected_db_source, expected_elements, expected_atom_count",
    [
        (
            "tests/data/cif/sources/ICSD/EntryWithCollCode43054.cif",
            "ICSD",
            {"Fe", "Ge"},
            216,
        ),
        ("tests/data/cif/sources/MS/U13Rh4.cif", "MS", {"U", "Fe"}, 2988),
        ("tests/data/cif/sources/COD/1010581.cif", "COD", {"Cu", "Se"}, 1383),
        (
            "tests/data/cif/sources/CCDC/2294753.cif",
            "CCDC",
            {"Er", "In", "Co"},
            3844,
        ),
        (
            "tests/data/cif/sources/MP/LiFeP2O7.cif",
            "MP",
            {"Fe", "Li", "O", "P"},
            594,
        ),
    ],
)
# "fast" keeps these cases in the same selection as the other tests in this
# file; "now" is retained so an existing `-m now` selection still finds them.
@pytest.mark.fast
@pytest.mark.now
def test_init_cif_file(
    tmpdir,
    file_path,
    expected_db_source,
    expected_elements,
    expected_atom_count,
):
    """Initialize Cif from one file per database source and check the result.

    Each case copies the fixture into ``tmpdir``, builds a ``Cif`` from the
    copy, and compares the detected source, the set of unique elements and
    the supercell atom count against the expected values for that file.
    """
    # Cif is given a copy inside tmpdir, never the fixture file itself.
    copied_file_path = os.path.join(tmpdir, os.path.basename(file_path))
    shutil.copyfile(file_path, copied_file_path)
    cif = Cif(copied_file_path)

    assert cif.db_source == expected_db_source
    assert cif.unique_elements == expected_elements
    assert cif.supercell_atom_count == expected_atom_count
45 changes: 24 additions & 21 deletions tests/core/models/test_cif_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pytest

from cifkit import CifEnsemble
from cifkit.utils.folder import get_file_count, get_file_paths
from cifkit.utils.folder import copy_files, get_file_count, get_file_paths


@pytest.mark.fast
Expand Down Expand Up @@ -268,26 +268,6 @@ def test_filter_by_CN_best_methods_exact_matching(
"""


# assert cif_ensemble_test.filter_by_CN_min_dist_method_exact_matching(
# [16]
# ) == {
# "tests/data/cif/ensemble_test/300169.cif",
# "tests/data/cif/ensemble_test/300170.cif",
# "tests/data/cif/ensemble_test/300171.cif",
# }

# assert cif_ensemble_test.filter_by_CN_min_dist_method_exact_matching(
# [9, 12, 16]
# ) == {
# "tests/data/cif/ensemble_test/300171.cif",
# }


# """
# Test filter by range
# """


@pytest.mark.fast
def test_filter_by_supercell_count(cif_ensemble_test: CifEnsemble):
result = cif_ensemble_test.filter_by_supercell_count(200, 400)
Expand Down Expand Up @@ -531,3 +511,26 @@ def test_init_without_preprocessing(

with caplog.at_level(logging.INFO):
assert "Preprocessing tests/data/cif/folder" not in caplog.text


@pytest.mark.parametrize(
    "cif_folder_path, expected_file_count, expected_supercell_stats",
    [
        ("tests/data/cif/sources/ICSD", 4, {216: 2, 307: 1, 320: 1}),
        ("tests/data/cif/sources/COD", 2, {519: 1, 1383: 1}),
        ("tests/data/cif/sources/MP", 2, {108: 1, 594: 1}),
        ("tests/data/cif/sources/PCD", 1, {364: 1}),
        ("tests/data/cif/sources/MS", 1, {2988: 1}),
        ("tests/data/cif/sources/CCDC", 1, {3844: 1}),
    ],
)
@pytest.mark.fast
def test_init_cif_files(
    tmpdir, cif_folder_path, expected_file_count, expected_supercell_stats
):
    """Build a CifEnsemble from each source folder and check its statistics.

    The folder's files are copied into ``tmpdir`` and the ensemble is created
    from that copy; its file count and supercell size statistics must match
    the expected values for the folder.
    """
    # The ensemble is pointed at tmpdir, so the fixture folder is only read.
    copy_files(tmpdir, get_file_paths(cif_folder_path))
    cif_ensemble = CifEnsemble(tmpdir)

    assert cif_ensemble.file_count == expected_file_count
    assert cif_ensemble.supercell_size_stats == expected_supercell_stats
Loading

0 comments on commit bc69d6e

Please sign in to comment.