Skip to content

Commit

Permalink
1.b upload wrappers integration alpha alg
Browse files Browse the repository at this point in the history
  • Loading branch information
qnater committed Jan 20, 2025
1 parent f055085 commit 3c9a1f9
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 16 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ruff_lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,4 @@ jobs:
run: pip install ruff

- name: Run Ruff Linter with auto-fix
run: ruff check --fix .
run: ruff check --fix imputegap/
9 changes: 6 additions & 3 deletions .idea/workspace.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 5 additions & 2 deletions imputegap/wrapper/AlgoPython/priSTI/models/layers.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
import copy
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from imputegap.wrapper.AlgoPython.priSTI.models.generate_adj import *


Expand Down Expand Up @@ -345,7 +346,9 @@ def forward(self, x, itp_x=None, **kwargs):
# merge head into batch for queries and key / values
queries = queries.reshape(b, n, h, -1).transpose(1, 2)

merge_key_values = lambda t: t.reshape(b, k, -1, d_h).transpose(1, 2).expand(-1, h, -1, -1)
def merge_key_values(t):
return t.reshape(b, k, -1, d_h).transpose(1, 2).expand(-1, h, -1, -1)

keys, values = map(merge_key_values, (keys, values))

# attention
Expand Down
6 changes: 3 additions & 3 deletions imputegap/wrapper/AlgoPython/priSTI/models/pristi.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def calc_loss(
return loss

def set_input_to_diffmodel(self, noisy_data, observed_data, cond_mask):
if self.is_unconditional == True:
if self.is_unconditional:
total_input = noisy_data.unsqueeze(1)
else:
if not self.use_guide:
Expand All @@ -124,7 +124,7 @@ def impute(self, observed_data, cond_mask, side_info, n_samples, itp_info):
imputed_samples = torch.zeros(B, n_samples, K).to(self.device)
for i in range(n_samples):
# generate noisy observation for unconditional model
if self.is_unconditional == True:
if self.is_unconditional:
noisy_obs = observed_data
noisy_cond_history = []
for t in range(self.num_steps):
Expand All @@ -135,7 +135,7 @@ def impute(self, observed_data, cond_mask, side_info, n_samples, itp_info):
current_sample = torch.randn_like(observed_data)

for t in range(self.num_steps - 1, -1, -1):
if self.is_unconditional == True:
if self.is_unconditional:
diff_input = cond_mask * noisy_cond_history[t] + (1.0 - cond_mask) * current_sample
diff_input = diff_input.unsqueeze(1) # (B,1,K,L)
else:
Expand Down
21 changes: 14 additions & 7 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
# Core Libraries
numpy==1.26.4
matplotlib==3.7.5
toml==0.10.2
scikit-learn==1.3.2
scipy==1.14.1
setuptools==75.1.0
tensorflow==2.17.0

# Additional Libraries
shap==0.44.1
pycatch22==0.4.5
scikit-optimize==0.10.2
Expand All @@ -15,15 +18,19 @@ types-toml
types-setuptools
wheel

# POST ALGO
# PyTorch and Related Libraries
torch==2.5.1 # brits
ujson==5.10.0 # brits
torchaudio==2.5.1 # brits
torchvision==0.20.1 # brits
ipdb==0.13.13 # brits
pandas==2.0.3 # brits
torchaudio==2.5.1 # brits

# PyTorch Geometric and Extensions (Install via prebuilt binaries)
torch-geometric==2.6.1 # MPIN
torch-cluster==1.6.3 # MPIN
torch-cluster==1.6.3 # MPIN; install from prebuilt binaries
-f https://data.pyg.org/whl/torch-2.5.1.html

# Other Dependencies
ujson==5.10.0 # brits
ipdb==0.13.13 # brits
pandas==2.0.3 # brits
torchcde==0.2.5 # PRISTI
tables==3.10.2 # SSA
tables==3.10.2 # SSA

0 comments on commit 3c9a1f9

Please sign in to comment.