diff --git a/examples/svm/svm.py b/examples/svm/svm.py index 22a90cd..90e28be 100644 --- a/examples/svm/svm.py +++ b/examples/svm/svm.py @@ -12,6 +12,7 @@ # imported: as long as the module we want exists on the worker, it can be # imported and used. import numpy as np +import sklearn from sklearn.svm import SVC from sklearn.ensemble import BaggingClassifier from sklearn.multiclass import OneVsRestClassifier @@ -58,6 +59,7 @@ def main(params): out : dict Dictionary of performance metrics to send to SHADHO. """ + # Extract the kernel name and parameters. This is just a short expression # to get the only dictionary entry, which should have our hyperparameters. kernel_params = list(params.values())[0] @@ -68,14 +70,8 @@ def main(params): X_train = X_train.astype(np.float32) / 255.0 X_test = X_test.astype(np.float32) / 255.0 - # Set up the SVM with its parameterized kernel. The long form of instantiation - # is done here to show what `kernel_params` looks like internally. - # This can be shortened to `svc = SVC(**kernel_params)` - svc = SVC(kernel=kernel_params['kernel'], - C=kernel_params['C'], - gamma=kernel_params['gamma'] if 'gamma' in kernel_params else None, - coef0=kernel_params['coef0'] if 'coef0' in kernel_params else None, - degree=kernel_params['degree'] if 'degree' in kernel_params else None) + # Set up the SVM with its parameterized kernel. + svc = SVC(**kernel_params) # Set up parallel training across as many cores as are available on the # worker. @@ -94,7 +90,7 @@ def main(params): # Generate and encode testing set predictions, along with prediction time. start = time.time() predictions = s.predict(X_test) - test_time = time.time() + test_time = time.time() - start encoder = LabelBinarizer() loss_labels = encoder.fit_transform(predictions) @@ -113,7 +109,7 @@ def main(params): 'accuracy': acc, 'precision': p, 'recall': r, - 'params': svm, + 'params': kernel_params, 'train_time': train_time, 'test_time': test_time } diff --git a/shadho/spaces.py b/shadho/spaces.py index d429dce..4963d6b 100644 --- a/shadho/spaces.py +++ b/shadho/spaces.py @@ -1,4 +1,4 @@ -from shadho.scaling import linear, ln, log_10, log_2 +from .scaling import linear, ln, log_10, log_2 from pyrameter import Scope, ContinuousDomain, DiscreteDomain import scipy.stats