Modeling Women’s Health Risk Assessment

Posted in Competition Notes, Machine Learning, scikit-learn

Women’s Health Risk Assessment is a multiclass classification competition whose goal is an optimized machine learning solution that accurately assigns a young woman (age 15 to 30) to her particular health-risk category. Based on the category a patient falls within, healthcare providers can offer appropriate education and training programs to help reduce the patient’s reproductive health risks. This post walks through modeling the Women’s Health Risk Assessment dataset.

In [1]:
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import pandas as pd
import warnings
warnings.filterwarnings('ignore') 
In [2]:
# importing the dataset
df = pd.read_csv("WomenHealth_Training.csv")
df.head(5)
Out[2]:
patientID geo christian muslim hindu other cellphone motorcycle radio cooker ModCon usecondom hivknow lowlit highlit urban rural single segment subgroup
0 4835 9 0.0 0.0 1.0 0.0 0 0 0 0 NaN NaN NaN NaN NaN 0 1 NaN 3 1
1 6719 4 1.0 0.0 0.0 0.0 1 0 1 0 1.0 0.0 1.0 1.0 0.0 0 1 0.0 1 1
2 5450 2 0.0 1.0 0.0 0.0 0 0 0 1 0.0 0.0 1.0 0.0 1.0 0 1 0.0 3 1
3 1207 2 1.0 0.0 0.0 0.0 0 0 0 0 0.0 0.0 1.0 1.0 0.0 0 1 0.0 2 1
4 7290 9 0.0 0.0 1.0 0.0 1 0 0 0 0.0 NaN NaN 0.0 0.0 1 0 0.0 3 1

5 rows × 50 columns
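
By default pandas truncates wide frames to 20 display columns, which is why only a subset of the 50 columns appears above. A quick way to see them all (a minor display tweak, not in the original notebook):

pd.set_option('display.max_columns', None)  # show all 50 columns in later displays
df.head(5)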

Review input features

In [3]:
print ("\n\n---------------------")
print ("TRAIN SET INFORMATION")
print ("---------------------")
print ("Shape of training set:", df.shape, "\n")
print ("Column Headers:", list(df.columns.values), "\n")
print (df.dtypes)
---------------------
TRAIN SET INFORMATION
---------------------
Shape of training set: (5283, 50) 

Column Headers: ['patientID', 'geo', 'christian', 'muslim', 'hindu', 'other', 'cellphone', 'motorcycle', 'radio', 'cooker', 'fridge', 'furniture', 'computer', 'cart', 'irrigation', 'thrasher', 'car', 'generator', 'INTNR', 'REGION_PROVINCE', 'DISTRICT', 'electricity', 'age', 'tribe', 'foodinsecurity', 'EVER_HAD_SEX', 'EVER_BEEN_PREGNANT', 'CHILDREN', 'india', 'married', 'multpart', 'educ', 'inschool', 'ownincome', 'literacy', 'religion', 'urbanicity', 'LaborDeliv', 'babydoc', 'Debut', 'ModCon', 'usecondom', 'hivknow', 'lowlit', 'highlit', 'urban', 'rural', 'single', 'segment', 'subgroup'] 

patientID               int64
geo                     int64
christian             float64
muslim                float64
hindu                 float64
other                 float64
cellphone               int64
motorcycle              int64
radio                   int64
cooker                  int64
fridge                  int64
furniture               int64
computer                int64
cart                    int64
irrigation              int64
thrasher                int64
car                     int64
generator               int64
INTNR                   int64
REGION_PROVINCE         int64
DISTRICT                int64
electricity             int64
age                     int64
tribe                   int64
foodinsecurity          int64
EVER_HAD_SEX            int64
EVER_BEEN_PREGNANT      int64
CHILDREN                int64
india                   int64
married               float64
multpart              float64
educ                  float64
inschool              float64
ownincome             float64
literacy              float64
religion               object
urbanicity              int64
LaborDeliv            float64
babydoc               float64
Debut                 float64
ModCon                float64
usecondom             float64
hivknow               float64
lowlit                float64
highlit               float64
urban                   int64
rural                   int64
single                float64
segment                 int64
subgroup                int64
dtype: object
In [4]:
import re
missing_values = []
nonumeric_values = []

print("TRAINING SET INFORMATION")
print("========================\n")

for column in df:
    # Find all the unique feature values
    uniq = df[column].unique()
    print("'{}' has {} unique values".format(column, uniq.size))
    if uniq.size > 10:
        print("~~Listing up to 10 unique values~~")
    print(uniq[0:10])
    print("\n-----------------------------------------------------------------------\n")

    # Find features with missing values
    if pd.isnull(uniq).any():
        s = "{} has {} missing".format(column, pd.isnull(df[column]).sum())
        missing_values.append(s)

    # Find features with non-numeric values (NaN entries are skipped,
    # rather than aborting the scan as the earlier version did)
    for value in uniq:
        if re.match('nan', str(value)):
            continue
        if not re.search(r'(^\d+\.?\d*$)|(^\d*\.?\d+$)', str(value)):
            nonumeric_values.append(column)
            break
  
print ("\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n")
print ("Features with missing values:\n{}\n\n" .format(missing_values))
print ("Features with non-numeric values:\n{}" .format(nonumeric_values))
print ("\n~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n")
TRAINING SET INFORMATION
========================

'patientID' has 5283 unique values
~~Listing up to 10 unique values~~
[4835 6719 5450 1207 7290  620 7760 5437  983 7491]

-----------------------------------------------------------------------

'geo' has 9 unique values
[9 4 2 3 8 6 1 5 7]

-----------------------------------------------------------------------

'christian' has 3 unique values
[  0.   1.  nan]

-----------------------------------------------------------------------

'muslim' has 3 unique values
[  0.   1.  nan]

-----------------------------------------------------------------------

'hindu' has 3 unique values
[  1.   0.  nan]

-----------------------------------------------------------------------

'other' has 3 unique values
[  0.   1.  nan]

-----------------------------------------------------------------------

'cellphone' has 2 unique values
[0 1]

-----------------------------------------------------------------------

'motorcycle' has 2 unique values
[0 1]

-----------------------------------------------------------------------

'radio' has 2 unique values
[0 1]

-----------------------------------------------------------------------

'cooker' has 2 unique values
[0 1]

-----------------------------------------------------------------------

'fridge' has 2 unique values
[0 1]

-----------------------------------------------------------------------

'furniture' has 2 unique values
[0 1]

-----------------------------------------------------------------------

'computer' has 2 unique values
[0 1]

-----------------------------------------------------------------------

'cart' has 2 unique values
[0 1]

-----------------------------------------------------------------------

'irrigation' has 2 unique values
[0 1]

-----------------------------------------------------------------------

'thrasher' has 2 unique values
[0 1]

-----------------------------------------------------------------------

'car' has 2 unique values
[0 1]

-----------------------------------------------------------------------

'generator' has 2 unique values
[0 1]

-----------------------------------------------------------------------

'INTNR' has 3907 unique values
~~Listing up to 10 unique values~~
[11180660     2958      737     1053  4270306     1450      144      313
 14460825     1275]

-----------------------------------------------------------------------

'REGION_PROVINCE' has 67 unique values
~~Listing up to 10 unique values~~
[48 14 43 37 23 44 38 31 50 33]

-----------------------------------------------------------------------

'DISTRICT' has 181 unique values
~~Listing up to 10 unique values~~
[163 126 158 152 194 151 153 142 165 147]

-----------------------------------------------------------------------

'electricity' has 3 unique values
[1 2 3]

-----------------------------------------------------------------------

'age' has 16 unique values
~~Listing up to 10 unique values~~
[30 25 19 18 28 21 29 16 26 22]

-----------------------------------------------------------------------

'tribe' has 59 unique values
~~Listing up to 10 unique values~~
[23 38 52 25 50 12 13 63 22  6]

-----------------------------------------------------------------------

'foodinsecurity' has 4 unique values
[ 2  3  1 99]

-----------------------------------------------------------------------

'EVER_HAD_SEX' has 2 unique values
[1 0]

-----------------------------------------------------------------------

'EVER_BEEN_PREGNANT' has 2 unique values
[1 0]

-----------------------------------------------------------------------

'CHILDREN' has 2 unique values
[1 0]

-----------------------------------------------------------------------

'india' has 2 unique values
[1 0]

-----------------------------------------------------------------------

'married' has 3 unique values
[  0.   1.  nan]

-----------------------------------------------------------------------

'multpart' has 4 unique values
[ nan   1.   0.   2.]

-----------------------------------------------------------------------

'educ' has 7 unique values
[  0.   1.   2.   3.   4.   5.  nan]

-----------------------------------------------------------------------

'inschool' has 3 unique values
[  0.   1.  nan]

-----------------------------------------------------------------------

'ownincome' has 3 unique values
[  1.   0.  nan]

-----------------------------------------------------------------------

'literacy' has 7 unique values
[ nan   0.   4.   3.   2.   5.   1.]

-----------------------------------------------------------------------

'religion' has 11 unique values
~~Listing up to 10 unique values~~
['Hindu' 'Evangelical/Bo' 'Muslim' 'Roman Catholic' 'Other Christia'
 'Buddhist' 'Russian/Easter' 'Traditional/An' 'Other' nan]

-----------------------------------------------------------------------

'urbanicity' has 3 unique values
[0 2 1]

-----------------------------------------------------------------------

'LaborDeliv' has 3 unique values
[  1.   0.  nan]

-----------------------------------------------------------------------

'babydoc' has 11 unique values
~~Listing up to 10 unique values~~
[ 0.  2.  4.  3.  5.  9.  1.  8.  6.  7.]

-----------------------------------------------------------------------

'Debut' has 22 unique values
~~Listing up to 10 unique values~~
[ 17.  14.  18.  nan  15.  16.  21.  20.  25.  29.]

-----------------------------------------------------------------------

'ModCon' has 3 unique values
[ nan   1.   0.]

-----------------------------------------------------------------------

'usecondom' has 3 unique values
[ nan   0.   1.]

-----------------------------------------------------------------------

'hivknow' has 3 unique values
[ nan   1.   0.]

-----------------------------------------------------------------------

'lowlit' has 3 unique values
[ nan   1.   0.]

-----------------------------------------------------------------------

'highlit' has 3 unique values
[ nan   0.   1.]

-----------------------------------------------------------------------

'urban' has 2 unique values
[0 1]

-----------------------------------------------------------------------

'rural' has 2 unique values
[1 0]

-----------------------------------------------------------------------

'single' has 2 unique values
[ nan   0.]

-----------------------------------------------------------------------

'segment' has 4 unique values
[3 1 2 4]

-----------------------------------------------------------------------

'subgroup' has 2 unique values
[1 2]

-----------------------------------------------------------------------


~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Features with missing values:
['christian has 15 missing', 'muslim has 15 missing', 'hindu has 15 missing', 'other has 15 missing', 'married has 22 missing', 'multpart has 1157 missing', 'educ has 90 missing', 'inschool has 17 missing', 'ownincome has 60 missing', 'literacy has 116 missing', 'religion has 15 missing', 'LaborDeliv has 2520 missing', 'babydoc has 67 missing', 'Debut has 572 missing', 'ModCon has 319 missing', 'usecondom has 2367 missing', 'hivknow has 608 missing', 'lowlit has 116 missing', 'highlit has 116 missing', 'single has 1386 missing']


Features with non-numeric values:
['religion']

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Select the rows with NaN values in the religion indicator columns

In [5]:
df[np.isnan(df['christian'])]
Out[5]:
patientID geo christian muslim hindu other cellphone motorcycle radio cooker ModCon usecondom hivknow lowlit highlit urban rural single segment subgroup
323 7060 7 NaN NaN NaN NaN 0 0 0 0 0.0 0.0 0.0 NaN NaN 0 1 0.0 3 1
453 1319 8 NaN NaN NaN NaN 1 0 1 0 0.0 0.0 1.0 1.0 0.0 0 1 0.0 4 1
685 7159 7 NaN NaN NaN NaN 1 0 0 0 1.0 0.0 0.0 NaN NaN 1 0 0.0 3 1
1191 679 2 NaN NaN NaN NaN 0 0 1 0 0.0 NaN 0.0 0.0 0.0 0 1 0.0 2 1
1887 4085 2 NaN NaN NaN NaN 1 1 1 0 0.0 NaN 0.0 1.0 0.0 1 0 0.0 2 1
1950 3582 3 NaN NaN NaN NaN 1 1 0 0 1.0 0.0 1.0 0.0 0.0 1 0 0.0 2 2
3321 2253 5 NaN NaN NaN NaN 1 0 1 0 0.0 NaN 0.0 0.0 1.0 0 1 0.0 1 2
3777 5106 8 NaN NaN NaN NaN 1 1 0 0 0.0 NaN 0.0 0.0 0.0 0 1 NaN 4 1
4036 5896 9 NaN NaN NaN NaN 1 0 1 0 1.0 NaN 0.0 1.0 0.0 0 1 0.0 1 1
4421 8043 7 NaN NaN NaN NaN 1 0 0 0 0.0 NaN NaN 0.0 0.0 0 0 NaN 2 1
4461 1965 5 NaN NaN NaN NaN 1 0 1 0 1.0 0.0 1.0 0.0 0.0 1 0 0.0 3 1
4722 344 8 NaN NaN NaN NaN 1 1 0 0 0.0 0.0 0.0 0.0 0.0 0 0 0.0 4 1
4860 7198 2 NaN NaN NaN NaN 0 0 0 0 0.0 0.0 0.0 NaN NaN 0 1 0.0 1 2
4960 4485 5 NaN NaN NaN NaN 0 0 1 0 0.0 NaN 0.0 0.0 0.0 1 0 NaN 1 1
5125 8007 5 NaN NaN NaN NaN 1 0 0 0 NaN 0.0 1.0 1.0 0.0 0 1 0.0 3 1

15 rows × 50 columns

Looking at the dataset, religion has already been split into separate indicator columns (christian, muslim, hindu, other), so there is no need to convert the string column to numerical values; I will drop it. Also, those religion indicators are missing in only 15 rows of the whole dataset, so I will drop these rows as well.

In [5]:
df.drop('religion', axis=1, inplace=True)
In [6]:
df = df.dropna(subset = ['christian'])
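
As a quick check (my addition, not in the original notebook), dropping those 15 rows and the religion column should leave 5268 rows and 49 columns:

print(df.shape)  # expected: (5268, 49)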

For the remaining missing values, I make imputation assumptions informed by the competition’s subject matter (sexual and reproductive health): married (22 missing): Yes; multpart (1154 missing): 1; educ (90 missing): median; inschool (17 missing): Yes; ownincome (59 missing): No; literacy (113 missing): median; LaborDeliv (2513 missing): No; babydoc (67 missing): median; Debut (566 missing): median; ModCon (318 missing): No; usecondom (2360 missing): No; hivknow (607 missing): No; lowlit (113 missing): median; highlit (113 missing): median; single (1383 missing): No. (These counts are slightly lower than in the summary above because the 15 rows with missing religion indicators have already been dropped.)

In [7]:
df['married'].fillna(1, inplace=True)
df['multpart'].fillna(1, inplace=True)
df['educ'].fillna(df['educ'].median(), inplace=True)
df['inschool'].fillna(1, inplace=True)
df['ownincome'].fillna(0, inplace=True)
df['literacy'].fillna(df['literacy'].median(), inplace=True)
df['LaborDeliv'].fillna(0, inplace=True)
df['babydoc'].fillna(df['babydoc'].median(), inplace=True)
df['Debut'].fillna(df['Debut'].median(), inplace=True)
df['ModCon'].fillna(0, inplace=True)
df['usecondom'].fillna(0, inplace=True)
df['hivknow'].fillna(0, inplace=True)
df['lowlit'].fillna(df['lowlit'].median(), inplace=True)
df['highlit'].fillna(df['highlit'].median(), inplace=True)
df['single'].fillna(0, inplace=True)
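
With the imputations applied, a one-line sanity check (a small addition of mine) confirms that no missing values remain:

assert df.isnull().sum().sum() == 0  # every remaining cell is populated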
We will work with this dataset in all of the examples below, namely with the feature matrix X and the target variable y.
In [8]:
# separate the data from the target attributes
X = df[['patientID', 'christian', 'muslim', 'hindu', 'other',
       'cellphone', 'motorcycle', 'radio', 'cooker', 'fridge', 'furniture',
       'computer', 'cart', 'irrigation', 'thrasher', 'car', 'generator',
       'INTNR', 'REGION_PROVINCE', 'DISTRICT', 'electricity', 'age', 'tribe',
       'foodinsecurity', 'EVER_HAD_SEX', 'EVER_BEEN_PREGNANT', 'CHILDREN',
       'india', 'married', 'multpart', 'educ', 'inschool', 'ownincome',
       'literacy', 'urbanicity', 'LaborDeliv', 'babydoc', 'Debut', 'ModCon',
       'usecondom', 'hivknow', 'lowlit', 'highlit', 'urban', 'rural', 'single']]
y = df[['geo', 'segment', 'subgroup']].copy()  # .copy() avoids chained-assignment issues when these columns are weighted later

Data Normalization

Most gradient-based methods (on which many machine learning algorithms rely) are sensitive to data scaling, so before running an algorithm we usually perform either normalization or so-called standardization. Standardization transforms each feature to have zero mean and unit variance. Note that scikit-learn’s preprocessing.normalize rescales each sample (row) to unit norm rather than mapping features into the range from 0 to 1, while preprocessing.scale performs the standardization just described.

In [9]:
from sklearn import preprocessing
# normalize the data attributes
normalized_X = preprocessing.normalize(X)
# standardize the data attributes
standardized_X = preprocessing.scale(X)
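
To see the difference between the two (a small check I am adding): standardization centers each feature column, while normalization gives each sample row unit length:

print(standardized_X.mean(axis=0).round(6))       # per-feature means, all ~0
print(np.linalg.norm(normalized_X, axis=1)[:5])   # per-row L2 norms, all 1.0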
I will combine the three target variables (geo, segment, subgroup) into a single label for predicting the outcome, weighting each column by its position and summing:
In [10]:
for i, col in enumerate(y.columns.tolist(), 1):
    y.loc[:, col] *= i
y = y.sum(axis=1)
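
The combined label is geo*1 + segment*2 + subgroup*3. One caveat (my observation): different (geo, segment, subgroup) triples can map to the same sum, so this encoding can merge distinct categories. A quick look at the resulting classes:

print(np.unique(y))  # the distinct combined class labels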

Feature Selection

In [12]:
from sklearn.feature_selection import RFE
from sklearn.linear_model import LogisticRegression
model = LogisticRegression()
# create the RFE model and select 3 attributes
rfe = RFE(model, n_features_to_select=3)
rfe = rfe.fit(X, y)
Summarize the selection of the attributes
In [13]:
print(rfe.support_)
[False False False False False False False False False False False False
 False False False False False False  True False False  True False False
 False False False False False False False False False  True False False
 False False False False False False False False False False]
In [14]:
print(rfe.ranking_)
[14 15 33 29 41 22 35 18 32 36 28 40 37 43 38 42 39  5  1  4  8  1  2  6 10
 11 12 27 13  9  7 24 20  1 16 17 21  3 23 30 19 31 34 26 25 44]
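
To make the boolean mask readable, it helps to map it back to column names (a small helper I am adding); given the mask above, the three selected features are REGION_PROVINCE, age, and literacy:

selected = [col for col, keep in zip(X.columns, rfe.support_) if keep]
print(selected)  # ['REGION_PROVINCE', 'age', 'literacy']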

Algorithm Development

Several classifiers are fitted below on the standardized features and compared on accuracy. The first, logistic regression, is most often used for binary classification, but multiclass classification (via the so-called one-vs-all method) is also supported. An advantage of this algorithm is that it outputs, for each object, the probability of belonging to each class.

In [32]:
from sklearn import metrics
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import GaussianNB
from sklearn.neighbors import KNeighborsClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.svm import SVC
from sklearn.linear_model import SGDClassifier

Logistic Regression

In [33]:
model = LogisticRegression()
model.fit(standardized_X, y)
print(model)
# make predictions
expected = y
predicted = model.predict(standardized_X)
# summarize the fit of the model
#print(metrics.classification_report(expected, predicted))
print(metrics.accuracy_score(expected, predicted))
LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,
          intercept_scaling=1, max_iter=100, multi_class='ovr', n_jobs=1,
          penalty='l2', random_state=None, solver='liblinear', tol=0.0001,
          verbose=0, warm_start=False)
0.675398633257
SGDClassifier

Stochastic Gradient Descent (SGD) is a simple yet very efficient approach to discriminative learning of linear classifiers under convex loss functions such as (linear) Support Vector Machines and Logistic Regression. SGD has been successfully applied to large-scale and sparse machine learning problems often encountered in text classification and natural language processing.

In [34]:
model = SGDClassifier(random_state=1)
model.fit(standardized_X, y)
print(model)
# make predictions
expected = y
predicted = model.predict(standardized_X)
# summarize the fit of the model
#print(metrics.classification_report(expected, predicted))
print(metrics.accuracy_score(expected, predicted))
SGDClassifier(alpha=0.0001, average=False, class_weight=None, epsilon=0.1,
       eta0=0.0, fit_intercept=True, l1_ratio=0.15,
       learning_rate='optimal', loss='hinge', n_iter=5, n_jobs=1,
       penalty='l2', power_t=0.5, random_state=1, shuffle=True, verbose=0,
       warm_start=False)
0.526195899772

Naive Bayes

Naive Bayes is also one of the best-known machine learning algorithms; its main task is to estimate the density of the data distribution of the training sample. This method often provides good quality in multiclass classification problems.

In [35]:
model = GaussianNB()
model.fit(standardized_X, y)
print(model)
# make predictions
expected = y
predicted = model.predict(standardized_X)
# summarize the fit of the model
#print(metrics.classification_report(expected, predicted))
print(metrics.accuracy_score(expected, predicted))
GaussianNB()
0.276006074412

k-Nearest Neighbours

The kNN (k-Nearest Neighbors) method is often used as part of a more complex classification algorithm; for instance, its estimate can be used as a feature of an object. Sometimes, a simple kNN on well-chosen features provides great quality. With well-tuned parameters (mostly the distance metric), the algorithm often also gives good quality in regression problems.

In [36]:
# fit a k-nearest neighbor model to the data
model = KNeighborsClassifier()
model.fit(standardized_X, y)
print(model)
# make predictions
expected = y
predicted = model.predict(standardized_X)
# summarize the fit of the model
#print(metrics.classification_report(expected, predicted))
print(metrics.accuracy_score(expected, predicted))
KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
           metric_params=None, n_jobs=1, n_neighbors=5, p=2,
           weights='uniform')
0.739749430524

Decision Trees

Classification and Regression Trees (CART) are often used in problems where objects have categorical features, and they work for both regression and classification. Trees are very well suited for multiclass classification.

In [37]:
# fit a CART model to the data
model = DecisionTreeClassifier()
model.fit(standardized_X, y)
print(model)
# make predictions
expected = y
predicted = model.predict(standardized_X)
# summarize the fit of the model
#print(metrics.classification_report(expected, predicted))
print(metrics.accuracy_score(expected, predicted))
DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,
            max_features=None, max_leaf_nodes=None, min_samples_leaf=1,
            min_samples_split=2, min_weight_fraction_leaf=0.0,
            presort=False, random_state=None, splitter='best')
1.0
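
The perfect score above is training-set accuracy: every model here is evaluated on the same data it was fitted on, and an unpruned decision tree can simply memorize the training sample. A minimal cross-validation sketch (assuming a scikit-learn where cross_val_score lives in sklearn.model_selection; older releases ship it in sklearn.cross_validation) gives a more honest estimate:

from sklearn.model_selection import cross_val_score
scores = cross_val_score(DecisionTreeClassifier(), standardized_X, y, cv=5)
print(scores.mean())  # held-out accuracy, typically well below 1.0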

Support Vector Machines

SVM (Support Vector Machines) is one of the most popular machine learning algorithms, used mainly for classification. Like logistic regression, SVM supports multiclass classification via the one-vs-all method.

In [46]:
# fit a SVM model to the data
model = SVC()
model.fit(standardized_X, y)
print(model)
# make predictions
expected = y
predicted = model.predict(standardized_X)
# summarize the fit of the model
#print(metrics.classification_report(expected, predicted))
print(metrics.accuracy_score(expected, predicted))
SVC(C=1.0, cache_size=200, class_weight=None, coef0=0.0,
  decision_function_shape=None, degree=3, gamma='auto', kernel='rbf',
  max_iter=-1, probability=False, random_state=None, shrinking=True,
  tol=0.001, verbose=False)
0.836750189825

Optimize Algorithm Parameters

Let’s take a look at selecting the regularization parameter, where several values are searched in turn. Note that with no scoring specified, GridSearchCV uses the estimator’s default score, which for Ridge is the R² coefficient of determination:

In [29]:
from sklearn.linear_model import Ridge
from sklearn.grid_search import GridSearchCV  # sklearn.model_selection in newer scikit-learn
# prepare a range of alpha values to test
alphas = np.array([1,0.1,0.01,0.001,0.0001,0])
# create and fit a ridge regression model, testing each alpha
model = Ridge()
grid = GridSearchCV(estimator=model, param_grid=dict(alpha=alphas))
grid.fit(X, y)
print(grid)
# summarize the results of the grid search
print(grid.best_score_)
print(grid.best_estimator_.alpha)
GridSearchCV(cv=None, error_score='raise',
       estimator=Ridge(alpha=1.0, copy_X=True, fit_intercept=True, max_iter=None,
   normalize=False, random_state=None, solver='auto', tol=0.001),
       fit_params={}, iid=True, n_jobs=1,
       param_grid={'alpha': array([  1.00000e+00,   1.00000e-01,   1.00000e-02,   1.00000e-03,
         1.00000e-04,   0.00000e+00])},
       pre_dispatch='2*n_jobs', refit=True, scoring=None, verbose=0)
0.425103400262
1.0

Sometimes it is more efficient to sample a parameter at random from a given range, estimate the algorithm’s quality for each sampled value, and choose the best one.

In [48]:
from scipy.stats import uniform as sp_rand
from sklearn.grid_search import RandomizedSearchCV  # sklearn.model_selection in newer scikit-learn
# prepare a uniform distribution to sample for the alpha parameter
param_grid = {'alpha': sp_rand()}
# create and fit a ridge regression model, testing random alpha values
model = Ridge()
rsearch = RandomizedSearchCV(estimator=model, param_distributions=param_grid, n_iter=100)
rsearch.fit(standardized_X, y)
print(rsearch)
# summarize the results of the random parameter search
print(rsearch.best_score_)
print(rsearch.best_estimator_.alpha)
RandomizedSearchCV(cv=None, error_score='raise',
          estimator=Ridge(alpha=1.0, copy_X=True, fit_intercept=True, max_iter=None,
   normalize=False, random_state=None, solver='auto', tol=0.001),
          fit_params={}, iid=True, n_iter=100, n_jobs=1,
          param_distributions={'alpha': <scipy.stats._distn_infrastructure.rv_frozen object at 0x0000020B1775FBA8>},
          pre_dispatch='2*n_jobs', random_state=None, refit=True,
          scoring=None, verbose=0)
0.425095237385
0.9958047356368019
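
Because refit=True by default, the search object refits the best model on the full dataset, so the tuned estimator can be used directly (a short usage note I am adding):

best_model = rsearch.best_estimator_          # Ridge refit with the best sampled alpha
predictions = best_model.predict(standardized_X)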
