-
Notifications
You must be signed in to change notification settings - Fork 4
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Sourcery refactored master branch #32
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,7 +22,7 @@ def signal_ndarray(df: pd.DataFrame) -> Tuple[np.ndarray, np.ndarray]: | |
X = np.zeros((n_trials, n_channels, n_samples)) | ||
y = np.empty((n_trials)) | ||
|
||
catmap = dict(((cls, i) for i, cls in enumerate(df["class"].cat.categories))) | ||
catmap = {cls: i for i, cls in enumerate(df["class"].cat.categories)} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
# pprint(catmap) | ||
|
||
i_t = 0 | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -18,7 +18,7 @@ def test_unison_shuffled_copies(): | |
b = np.array(range(1, 11)) | ||
c = np.array(range(2, 12)) | ||
sa, sb, sc = unison_shuffled_copies(a, b, c) | ||
assert all([v1 == v2 - 1 == v3 - 2 for v1, v2, v3 in zip(sa, sb, sc)]) | ||
assert all(v1 == v2 - 1 == v3 - 2 for v1, v2, v3 in zip(sa, sb, sc)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
assert all(a == np.array(range(10))), "input array was mutated" | ||
assert all(c == np.array(range(2, 12))), "input array was mutated" | ||
|
||
|
@@ -114,10 +114,8 @@ def aggregate_windows_to_epochs( | |
map(lambda v: map_cls[v], predicted[test][i_start : i_stop + 1]) | ||
) | ||
vote = sum(y_preds) / len(y_preds) | ||
votes.append(vote > 0.5) | ||
else: | ||
# Use the mean probability | ||
vote = np.mean(predicted_proba[test][i_start : i_stop + 1, 1]) | ||
votes.append(vote > 0.5) | ||
|
||
votes.append(vote > 0.5) | ||
Comment on lines
-117
to
+120
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
return np.array(ys_epoch), np.array(votes) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -56,7 +56,7 @@ def start(self, filename: str = None, duration=None, extras: dict = None) -> Non | |
def record(): | ||
# NOTE: This runs in a seperate process/thread | ||
self.board.start_stream() | ||
for i in range(duration): | ||
for _ in range(duration): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
sleep(1) | ||
self._save() | ||
self._stop_brainflow() | ||
|
@@ -87,8 +87,7 @@ def check(self, max_uv_abs=200) -> List[str]: | |
channel_names = BoardShim.get_eeg_names(self.brainflow_id) | ||
# FIXME: _check_samples expects different (Muse) inputs | ||
checked = _check_samples(data.T, channel_names, max_uv_abs=max_uv_abs) # type: ignore | ||
bads = [ch for ch, ok in checked.items() if not ok] | ||
return bads | ||
return [ch for ch, ok in checked.items() if not ok] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
||
def _init_brainflow(self) -> None: | ||
""" | ||
|
@@ -201,13 +200,10 @@ def get_data(self, clear_buffer=True) -> pd.DataFrame: | |
total_data, stim_array, 1 | ||
) # Append the stim array to data. | ||
|
||
# Subtract five seconds of settling time from beginning | ||
# total_data = total_data[5 * self.sfreq :] | ||
df = pd.DataFrame( | ||
return pd.DataFrame( | ||
total_data, | ||
columns=["timestamps"] + ch_names + (["stim"] if self.markers else []), | ||
) | ||
return df | ||
Comment on lines
-204
to
-210
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
This removes the following comments ( why? ):
|
||
|
||
def _save(self) -> None: | ||
"""Saves the data to a CSV file.""" | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -45,9 +45,7 @@ def __init__(self, device_name: str): | |
|
||
@property | ||
def started(self) -> bool: | ||
if self.stream_process: | ||
return self.stream_process.exitcode is None | ||
return False | ||
return self.stream_process.exitcode is None if self.stream_process else False | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
||
def start(self, filename: str = None, duration=None, extras: dict = None): | ||
""" | ||
|
@@ -123,15 +121,14 @@ def _read_buffer(self) -> np.ndarray: | |
|
||
inlets = _get_inlets(verbose=False) | ||
|
||
for i in range(5): | ||
for _ in range(5): | ||
for inlet in inlets: | ||
inlet.pull(timeout=0.5) # type: ignore | ||
inlets = [inlet for inlet in inlets if inlet.buffer.any()] # type: ignore | ||
if inlets: | ||
break | ||
else: | ||
logger.info("No inlets with data, trying again in a second...") | ||
sleep(1) | ||
logger.info("No inlets with data, trying again in a second...") | ||
sleep(1) | ||
Comment on lines
-126
to
+131
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
||
if not inlets: | ||
raise Exception("No inlets found") | ||
|
@@ -149,5 +146,4 @@ def check(self, max_uv_abs: float) -> List[str]: | |
channels=["TP9", "AF7", "AF8", "TP10"], | ||
max_uv_abs=max_uv_abs, | ||
) | ||
bads = [ch for ch, ok in checked.items() if not ok] | ||
return bads | ||
return [ch for ch, ok in checked.items() if not ok] | ||
Comment on lines
-152
to
+149
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -33,8 +33,7 @@ def load_demo(duration=1) -> mne.io.RawArray: | |
ch_names = ["T7", "CP5", "FC5", "C3", "C4", "FC6", "CP6", "T8"] | ||
sfreq = BoardShim.get_sampling_rate(BoardIds.SYNTHETIC_BOARD.value) | ||
info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types) | ||
raw = mne.io.RawArray(eeg_data, info) | ||
return raw | ||
return mne.io.RawArray(eeg_data, info) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
||
|
||
def test_load_demo(): | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -38,9 +38,9 @@ def _get_inlets(plt=None, verbose=True) -> List["Inlet"]: | |
info.nominal_srate() != pylsl.IRREGULAR_RATE | ||
or info.channel_format() != pylsl.cf_string | ||
): | ||
logger.warning("Invalid marker stream " + info.name()) | ||
logger.warning(f"Invalid marker stream {info.name()}") | ||
if verbose: | ||
logger.info("Adding marker inlet: " + info.name()) | ||
logger.info(f"Adding marker inlet: {info.name()}") | ||
Comment on lines
-41
to
+43
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
inlets.append(MarkerInlet(info)) | ||
elif ( | ||
info.nominal_srate() != pylsl.IRREGULAR_RATE | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -101,15 +101,14 @@ def connect( | |
sleep(5) | ||
continue | ||
except Exception as e: | ||
if "No Muses found" in str(e): | ||
msg = "No Muses found, trying again in 5s..." | ||
logger.warning(msg) | ||
notify("Couldn't connect", msg) | ||
sleep(5) | ||
continue | ||
else: | ||
if "No Muses found" not in str(e): | ||
raise | ||
|
||
msg = "No Muses found, trying again in 5s..." | ||
logger.warning(msg) | ||
notify("Couldn't connect", msg) | ||
sleep(5) | ||
continue | ||
Comment on lines
-104
to
+111
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
loud = True | ||
started = time() | ||
stop = started + duration | ||
|
@@ -169,11 +168,10 @@ def check(device_name: str): | |
if all_good: | ||
if not last_good: | ||
logger.info("All channels good!") | ||
else: | ||
if bads != last_bads: | ||
logger.warning( | ||
"Warning, bad signal for channels: " + ", ".join(bads) | ||
) | ||
elif bads != last_bads: | ||
logger.warning( | ||
"Warning, bad signal for channels: " + ", ".join(bads) | ||
) | ||
Comment on lines
-172
to
+174
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
last_good = all_good | ||
last_check = time() | ||
last_bads = bads | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -32,13 +32,11 @@ def generate_save_fn( | |
# create the directory if it doesn't exist | ||
recording_dir.mkdir(parents=True, exist_ok=True) | ||
|
||
# generate filename based on recording date-and-timestamp and then append to recording_dir | ||
save_fp = recording_dir / ( | ||
"recording_%s" % time.strftime("%Y-%m-%d-%H.%M.%S", time.gmtime()) + ".csv" | ||
return recording_dir / ( | ||
f'recording_{time.strftime("%Y-%m-%d-%H.%M.%S", time.gmtime())}' | ||
+ ".csv" | ||
) | ||
|
||
return save_fp | ||
Comment on lines
-35
to
-40
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
This removes the following comments ( why? ):
|
||
|
||
|
||
def print_statusline(msg: str): | ||
"""From: https://stackoverflow.com/a/43952192/965332""" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Function
_clean_signal_quality
refactored with the following changes:list-comprehension
)This removes the following comments ( why? ):