diff --git a/.github/workflows/ruff_lint.yml b/.github/workflows/ruff_lint.yml index 745d697..665f0bc 100644 --- a/.github/workflows/ruff_lint.yml +++ b/.github/workflows/ruff_lint.yml @@ -19,4 +19,4 @@ jobs: run: pip install ruff - name: Run Ruff Linter with auto-fix - run: ruff check --fix . \ No newline at end of file + run: ruff check --fix imputegap/ diff --git a/.idea/workspace.xml b/.idea/workspace.xml index 93aa710..085af45 100644 --- a/.idea/workspace.xml +++ b/.idea/workspace.xml @@ -2,8 +2,11 @@ - - + + + + + diff --git a/imputegap/wrapper/AlgoPython/priSTI/models/layers.py b/imputegap/wrapper/AlgoPython/priSTI/models/layers.py index c050093..e744d14 100644 --- a/imputegap/wrapper/AlgoPython/priSTI/models/layers.py +++ b/imputegap/wrapper/AlgoPython/priSTI/models/layers.py @@ -1,7 +1,8 @@ import copy +import math +import torch import torch.nn as nn import torch.nn.functional as F -import math from imputegap.wrapper.AlgoPython.priSTI.models.generate_adj import * @@ -345,7 +346,9 @@ def forward(self, x, itp_x=None, **kwargs): # merge head into batch for queries and key / values queries = queries.reshape(b, n, h, -1).transpose(1, 2) - merge_key_values = lambda t: t.reshape(b, k, -1, d_h).transpose(1, 2).expand(-1, h, -1, -1) + def merge_key_values(t): + return t.reshape(b, k, -1, d_h).transpose(1, 2).expand(-1, h, -1, -1) + keys, values = map(merge_key_values, (keys, values)) # attention diff --git a/imputegap/wrapper/AlgoPython/priSTI/models/pristi.py b/imputegap/wrapper/AlgoPython/priSTI/models/pristi.py index a3a7098..5043565 100644 --- a/imputegap/wrapper/AlgoPython/priSTI/models/pristi.py +++ b/imputegap/wrapper/AlgoPython/priSTI/models/pristi.py @@ -106,7 +106,7 @@ def calc_loss( return loss def set_input_to_diffmodel(self, noisy_data, observed_data, cond_mask): - if self.is_unconditional == True: + if self.is_unconditional: total_input = noisy_data.unsqueeze(1) else: if not self.use_guide: @@ -124,7 +124,7 @@ def impute(self, observed_data, cond_mask, side_info, n_samples, itp_info): imputed_samples = torch.zeros(B, n_samples, K).to(self.device) for i in range(n_samples): # generate noisy observation for unconditional model - if self.is_unconditional == True: + if self.is_unconditional: noisy_obs = observed_data noisy_cond_history = [] for t in range(self.num_steps): @@ -135,7 +135,7 @@ def impute(self, observed_data, cond_mask, side_info, n_samples, itp_info): current_sample = torch.randn_like(observed_data) for t in range(self.num_steps - 1, -1, -1): - if self.is_unconditional == True: + if self.is_unconditional: diff_input = cond_mask * noisy_cond_history[t] + (1.0 - cond_mask) * current_sample diff_input = diff_input.unsqueeze(1) # (B,1,K,L) else: diff --git a/requirements.txt b/requirements.txt index ab46f18..9ac4026 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +# Core Libraries numpy==1.26.4 matplotlib==3.7.5 toml==0.10.2 @@ -5,6 +6,8 @@ scikit-learn==1.3.2 scipy==1.14.1 setuptools==75.1.0 tensorflow==2.17.0 + +# Additional Libraries shap==0.44.1 pycatch22==0.4.5 scikit-optimize==0.10.2 @@ -15,15 +18,19 @@ types-toml types-setuptools wheel -# POST ALGO +# PyTorch and Related Libraries torch==2.5.1 # brits -ujson==5.10.0 # brits -torchaudio==2.5.1 # brits torchvision==0.20.1 # brits -ipdb==0.13.13 # brits -pandas==2.0.3 # brits +torchaudio==2.5.1 # brits +# PyTorch Geometric and Extensions (Install via prebuilt binaries) torch-geometric==2.6.1 # MPIN -torch-cluster==1.6.3 # MPIN +torch-cluster==1.6.3 # MPIN; install from prebuilt binaries +-f https://data.pyg.org/whl/torch-2.5.1.html + +# Other Dependencies +ujson==5.10.0 # brits +ipdb==0.13.13 # brits +pandas==2.0.3 # brits torchcde==0.2.5 # PRISTI -tables==3.10.2 # SSA \ No newline at end of file +tables==3.10.2 # SSA