MADlib 0.7 User Documentation
/* ----------------------------------------------------------------------- *//**
 *
 * @file cross_validation.sql_in
 *
 * @brief SQL functions for cross validation
 * @date January 2011
 *
 * @sa For a brief introduction to the usage of cross validation, see the
 *     module description \ref grp_validation.
 *
 *//* ----------------------------------------------------------------------- */

m4_include(`SQLCommon.m4') --'

/**
@addtogroup grp_validation

@about

Cross-validation, sometimes called rotation estimation, is a technique for assessing how the results of a statistical analysis will generalize to an independent data set. It is mainly used in settings where the goal is prediction, and one wants to estimate how accurately a predictive model will perform in practice. One round of cross-validation involves partitioning a sample of data into complementary subsets, performing the analysis on one subset (called the training set), and validating the analysis on the other subset (called the validation set or testing set). To reduce variability, multiple rounds of cross-validation are performed using different partitions, and the validation results are averaged over the rounds.

In k-fold cross-validation, the original sample is randomly partitioned into k equal-size subsamples. Of the k subsamples, a single subsample is retained as the validation data for testing the model, and the remaining k − 1 subsamples are used as training data. The cross-validation process is then repeated k times (the folds), with each of the k subsamples used exactly once as the validation data. The k results from the folds can then be averaged (or otherwise combined) to produce a single estimate.
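The k-fold scheme just described can be sketched in ordinary Python. This is a toy illustration only, not MADlib code; `train`, `predict`, and `error` stand in for the user-supplied modelling, prediction, and error metric functions described in this module.

```python
# Minimal sketch of k-fold cross-validation (illustrative only, not MADlib API).
import random
import statistics

def k_fold_cv(rows, train, predict, error, k=10, seed=0):
    """Split 'rows' into k folds; each fold serves as the validation set once."""
    rows = rows[:]                        # do not mutate the caller's data
    random.Random(seed).shuffle(rows)     # random partition of the sample
    folds = [rows[i::k] for i in range(k)]
    errors = []
    for i in range(k):
        validation = folds[i]
        training = [r for j, fold in enumerate(folds) if j != i for r in fold]
        model = train(training)                            # fit on k-1 folds
        predictions = [predict(model, row) for row in validation]
        errors.append(error(predictions, validation))      # score held-out fold
    # combine the k per-fold results into a single estimate
    return statistics.mean(errors), statistics.pstdev(errors)
```

Each observation lands in exactly one fold, so it is used for validation exactly once, which is the property discussed next.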
The advantage of this method over repeated random sub-sampling is that all observations are used for both training and validation, and each observation is used for validation exactly once. 10-fold cross-validation is commonly used, but in general k remains an unfixed parameter.

@input

<b>The flexible interface.</b>

The input includes the data set, a training function, a prediction function, and an error metric function.

The training function takes in a given data set with independent and dependent variables in it and produces a model, which is stored in an output table.

The prediction function takes in the model generated by the training function and a different data set with independent variables in it, and it produces a prediction of the dependent variables based on the model. The prediction is stored in an output table. The prediction function should take a unique ID column name of the data table as one of its inputs; otherwise the prediction result cannot be compared with the validation values.

The error metric function takes in the prediction made by the prediction function and compares it with the known values of the dependent variables of the data set that was fed into the prediction function. It computes the error metric defined by the function. The results are stored in a table.

Other inputs include the output table name, the k value for the k-fold cross-validation, and how many folds the user wants to try (for example, the user can choose to run a simple validation instead of a full cross-validation).

@usage

<b>The flexible interface.</b>

In order to choose the optimum value for a parameter of the model, the user needs to provide the training function, prediction function, error metric function, the parameter and its values to be studied, and the data set.
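Conceptually, this parameter study plugs the three user-supplied functions into a loop over the candidate values and keeps the value with the lowest error. The self-contained sketch below uses a single hold-out split for brevity (the real CV function uses k folds); all names are illustrative, not part of MADlib's API.

```python
# Illustrative sketch of sweeping an explored parameter (not MADlib API).
def grid_search_holdout(rows, train, predict, error, explore_values, split=0.8):
    """For each candidate value, fit on the training split, predict on the
    held-out split, and return the (value, error) pair with the lowest error."""
    cut = int(len(rows) * split)
    training, validation = rows[:cut], rows[cut:]
    results = []
    for value in explore_values:
        model = train(training, value)                      # modelling function
        preds = [predict(model, row) for row in validation] # prediction function
        results.append((value, error(preds, validation)))   # error metric function
    return min(results, key=lambda t: t[1])
```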
It is better if the data set has a unique ID for each row, so that it is easier to cut the data set into the training part and the validation part. The user also needs to inform the cross-validation (CV) function about whether this ID value is randomly assigned to each row. If it is not randomly assigned, the CV function will automatically generate a random ID for each row.

If the data set has no unique ID for each row, the CV function will copy the data set and create a randomly assigned ID column for the newly created temp table. The new table will be dropped after the computation is finished. To minimize the copying workload, the user needs to provide the data column names (for independent variables and dependent variables) that are going to be used in the calculation, and only these columns will be copied.

<pre>SELECT cross_validation_general(
    <em>modelling_func</em>,         -- Name of the function that trains the model
    <em>modelling_params</em>,       -- Array of parameters for the modelling function
    <em>modelling_params_type</em>,  -- Type of each parameter for the modelling function
    --
    <em>param_explored</em>,         -- Name of the parameter that will be checked to find the optimum value;
    ----                               the same name must also appear in the array of modelling_params
    <em>explore_values</em>,         -- Values of this parameter that will be studied
    --
    <em>predict_func</em>,           -- Name of the function for prediction
    <em>predict_params</em>,         -- Array of parameters for the prediction function
    <em>predict_params_type</em>,    -- Type of each parameter for the prediction function
    --
    <em>metric_func</em>,            -- Name of the function for measuring errors
    <em>metric_params</em>,          -- Array of parameters for the error metric function
    <em>metric_params_type</em>,     -- Type of each parameter for the metric function
    --
    <em>data_tbl</em>,               -- Data table, which will be split into training and validation parts
    <em>data_id</em>,                -- Name of the unique ID associated with each row; provide <em>NULL</em>
    ----                               if there is no such column in the data table
    <em>id_is_random</em>,           -- Whether the provided ID is randomly assigned to each row
    --
    <em>validation_result</em>,      -- Name of the table that stores the output of the CV function (see Output
    ----                               for the format); it will be created automatically by the CV function
    --
    <em>data_cols</em>,              -- Names of the data columns that are going to be used; only useful when
    ----                               <em>data_id</em> is NULL, otherwise ignored
    <em>fold_num</em>                -- Value of k, i.e. how many folds of validation to run; each validation
    ----                               uses a 1/fold_num fraction of the data. Default value: 10
);</pre>

Special keywords in the parameter arrays of the modelling, prediction and metric functions:

<em>\%data%</em> : The argument position for the training/validation data

<em>\%model%</em> : The argument position for the output/input of the modelling/prediction function

<em>\%id%</em> : The argument position of the unique ID column (provided by the user or generated by the CV function, as mentioned above)

<em>\%prediction%</em> : The argument position for the output/input of the prediction/metric function

<em>\%error%</em> : The argument position for the output of the metric function

<b>Note</b>: If the parameter <em>explore_values</em> is NULL or has zero length, then the cross-validation function will only run a data folding.

Output:
<pre> param_explored | average error | standard deviation of error
----------------+---------------+-----------------------------
 .......
</pre>

<b>Note:</b>

<em>max_locks_per_transaction</em>, which usually has a default value of 64, limits the number of tables that can be dropped inside a single transaction (the CV function). Thus the number of different values of <em>param_explored</em> (the length of the array <em>explore_values</em>) cannot be too large. For 10-fold cross-validation, the limit on length(<em>explore_values</em>) is around 40. If this number is too large, the user might see an "out of shared memory" error, because <em>max_locks_per_transaction</em> is used up.

One way to overcome this limitation is to run the CV function multiple times, with each run covering a different region of values of the parameter.

In the future, MADlib will implement cross-validation functions for each individual applicable module, where the calculation can be optimized to avoid dropping tables and hence this max_locks_per_transaction limitation. However, such cross-validation functions need to know the implementation details of the modules to do the optimization, and thus cannot be as flexible as the cross-validation function provided here.

The cross-validation function provided here is very flexible and can work with any algorithm the user wants to cross-validate, including algorithms written by the user. The price for this flexibility is that the algorithms' details cannot be utilized to optimize the calculation, so the <em>max_locks_per_transaction</em> limitation cannot be avoided.

@examp

Cross-validation is used on elastic net regression to find the best value of the regularization parameter.

(1) Populate the table 'cvtest' with a 101-dimensional independent variable 'val' and a dependent variable 'dep'.
(2) Run the general CV function:
<pre>
SELECT madlib.cross_validation_general (
    'madlib.elastic_net_train',
    '{\%data%, \%model%, dep, val, gaussian, 1, lambda, True, Null, fista, "{eta = 2, max_stepsize = 2, use_active_set = t}", Null, 2000, 1e-6}'::varchar[],
    '{varchar, varchar, varchar, varchar, varchar, double precision, double precision, boolean, varchar, varchar, varchar[], varchar, integer, double precision}'::varchar[],
    --
    'lambda',
    '{0.02, 0.04, 0.06, 0.08, 0.10, 0.12, 0.14, 0.16, 0.18, 0.20, 0.22, 0.24, 0.26, 0.28, 0.30, 0.32, 0.34, 0.36}'::varchar[],
    --
    'madlib.elastic_net_predict',
    '{\%model%, \%data%, \%id%, \%prediction%}'::varchar[],
    '{text, text, text, text}'::varchar[],
    --
    'madlib.mse_error',
    '{\%prediction%, \%data%, \%id%, dep, \%error%}'::varchar[],
    '{varchar, varchar, varchar, varchar, varchar}'::varchar[],
    --
    'cvtest',
    NULL::varchar,
    False,
    --
    'valid_rst_tbl',
    '{val, dep}'::varchar[],
    10
);
</pre>

@sa File cross_validation.sql_in documenting the SQL functions.

*/

------------------------------------------------------------------------
/*
 * @brief Perform cross-validation for modules that conform to a fixed SQL API.
 * Note: this function has a lock-number limitation, but it is flexible to use, so the user can
 * try the CV method on their own functions. The cross_validation function, on the other hand, does not have the
 * lock-number limitation.
 *
 * @param modelling_func Name of the function that trains the model
 * @param modelling_params Array of parameters for the modelling function
 * @param modelling_params_type Type of each parameter for the modelling function
 * @param param_explored Name of the parameter that will be checked to find the optimum value; the same name must also appear in the array of modelling_params
 * @param explore_values Values of this parameter that will be studied
 * @param predict_func Name of the function for prediction
 * @param predict_params Array of parameters for the prediction function
 * @param predict_params_type Type of each parameter for the prediction function
 * @param metric_func Name of the function for measuring errors
 * @param metric_params Array of parameters for the error metric function
 * @param metric_params_type Type of each parameter for the metric function
 * @param data_tbl Data table, which will be split into training and validation parts
 * @param data_id Name of the unique ID associated with each row. Provide <em>NULL</em> if there is no such column in the data table
 * @param id_is_random Whether the provided ID is randomly assigned to each row
 * @param validation_result Name of the table that stores the output of the CV function (see Output for the format); it will be created automatically by the CV function
 * @param fold_num Value of k, i.e. how many folds of validation to run; each validation uses a 1/fold_num fraction of the data. Default value: 10
 * @param data_cols Names of the data columns that are going to be used; only useful when <em>data_id</em> is NULL, otherwise ignored
 */
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.cross_validation_general(
    modelling_func          VARCHAR,    -- function that trains the model
    modelling_params        VARCHAR[],  -- parameters for modelling
    modelling_params_type   VARCHAR[],  -- parameter types for modelling
    --
    param_explored          VARCHAR,    -- which parameter will be studied using validation
    explore_values          VARCHAR[],  -- values that will be explored for this parameter
    --
    predict_func            VARCHAR,    -- function for predicting using the model
    predict_params          VARCHAR[],  -- parameters for prediction
    predict_params_type     VARCHAR[],  -- parameter types for prediction
    --
    metric_func             VARCHAR,    -- function that computes the error metric
    metric_params           VARCHAR[],  -- parameters for the metric function
    metric_params_type      VARCHAR[],  -- parameter types for the metric function
    --
    data_tbl                VARCHAR,    -- table containing the data, which will be split into training and validation parts
    data_id                 VARCHAR,    -- user-provided unique ID for each row
    id_is_random            BOOLEAN,    -- whether the ID provided by the user is random
    --
    validation_result       VARCHAR,    -- stores the result: parameter values, error, +/-
    --
    data_cols               VARCHAR[],  -- names of the data columns that are going to be used
    fold_num                INTEGER     -- how many folds of validation, default: 10
) RETURNS VOID AS $$
PythonFunction(validation, cross_validation, cross_validation_general)
$$ LANGUAGE plpythonu;

------------------------------------------------------------------------

-- Overloaded version without fold_num, which defaults to 10-fold cross-validation
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.cross_validation_general(
    modelling_func          VARCHAR,    -- function that trains the model
    modelling_params        VARCHAR[],  -- parameters for modelling
    modelling_params_type   VARCHAR[],  -- parameter types for modelling
    --
    param_explored          VARCHAR,    -- which parameter will be studied using validation
    explore_values          VARCHAR[],  -- values that will be explored for this parameter
    --
    predict_func            VARCHAR,    -- function for predicting using the model
    predict_params          VARCHAR[],  -- parameters for prediction
    predict_params_type     VARCHAR[],  -- parameter types for prediction
    --
    metric_func             VARCHAR,    -- function that computes the error metric
    metric_params           VARCHAR[],  -- parameters for the metric function
    metric_params_type      VARCHAR[],  -- parameter types for the metric function
    --
    data_tbl                VARCHAR,    -- table containing the data, which will be split into training and validation parts
    data_id                 VARCHAR,    -- user-provided unique ID for each row
    id_is_random            BOOLEAN,    -- whether the ID provided by the user is random
    --
    validation_result       VARCHAR,    -- stores the result: parameter values, error, +/-
    --
    data_cols               VARCHAR[]   -- names of the data columns that are going to be used
) RETURNS VOID AS $$
BEGIN
    PERFORM MADLIB_SCHEMA.cross_validation_general($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, 10);
END;
$$ LANGUAGE plpgsql VOLATILE;

------------------------------------------------------------------------
------------------------------------------------------------------------
------------------------------------------------------------------------

/**
 * @brief Simple interface for cross-validation, which has no limitation on the lock number
 *
 * @param module_name Module to be cross-validated
 * @param func_args Arguments of the modelling function of the module, including the table name of the data
 * @param param_to_try Name of the parameter that CV runs through
 * @param param_values Values of the parameter that CV will try
 * @param data_id Name of the unique ID associated with each row.
 *        Provide <em>NULL</em> if there is no such column in the data table
 * @param id_is_random Whether the provided ID is randomly assigned to each row
 * @param validation_result Name of the table that stores the output of the CV function (see Output for the format); it will be created automatically by the CV function
 * @param fold_num How many folds of cross-validation to run
 */
/*
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.cross_validation(
    module_name         VARCHAR,    -- module to be cross-validated
    func_args           VARCHAR[],
    param_to_try        VARCHAR,
    param_values        DOUBLE PRECISION[],
    data_id             VARCHAR,
    id_is_random        BOOLEAN,
    validation_result   VARCHAR,
    fold_num            INTEGER
) RETURNS VOID AS $$
PythonFunction(validation, cross_validation, cross_validation)
$$ LANGUAGE plpythonu;
*/
-- ------------------------------------------------------------------------
/*
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.cross_validation(
    module_name         VARCHAR,
    func_args           VARCHAR[],
    param_to_try        VARCHAR,
    param_values        DOUBLE PRECISION[],
    data_id             VARCHAR,
    id_is_random        BOOLEAN,
    validation_result   VARCHAR
) RETURNS VOID AS $$
BEGIN
    PERFORM MADLIB_SCHEMA.cross_validation($1, $2, $3, $4, $5, $6, $7, 10);
END;
$$ LANGUAGE plpgsql VOLATILE;
*/
-- ------------------------------------------------------------------------

/**
 * @brief Print the help message for a given module's cross-validation.
 */
/*
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.cross_validation(module_name VARCHAR)
RETURNS VARCHAR AS $$
PythonFunction(validation, cross_validation, cross_validation_help)
$$ LANGUAGE plpythonu;
*/
-- ------------------------------------------------------------------------

/**
 * @brief Print the supported module names for cross_validation
 */
/*
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.cross_validation()
RETURNS VARCHAR AS $$
DECLARE
    msg VARCHAR;
BEGIN
    msg := 'cross_validation function now supports Ridge linear regression';
    RETURN msg;
END;
$$ LANGUAGE plpgsql STRICT;
*/
------------------------------------------------------------------------

/**
 * @brief A wrapper for linear regression training
 */
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.cv_linregr_train(
    tbl_source      VARCHAR,
    col_ind_var     VARCHAR,
    col_dep_var     VARCHAR,
    tbl_result      VARCHAR
) RETURNS VOID AS $$
PythonFunction(validation, cross_validation, cv_linregr_train)
$$ LANGUAGE plpythonu;

CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.linregr_predict(
    coef        DOUBLE PRECISION[],
    col_ind     DOUBLE PRECISION[]
) RETURNS DOUBLE PRECISION AS $$
PythonFunction(validation, cross_validation, linregr_predict)
$$ LANGUAGE plpythonu;

/**
 * @brief A wrapper for linear regression prediction
 */
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.cv_linregr_predict(
    tbl_model       VARCHAR,
    tbl_newdata     VARCHAR,
    col_ind_var     VARCHAR,
    col_id          VARCHAR,    -- ID column
    tbl_predict     VARCHAR
) RETURNS VOID AS $$
PythonFunction(validation, cross_validation, cv_linregr_predict)
$$ LANGUAGE plpythonu;

-- Compare the predicted and actual values
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.mse_error(
    tbl_prediction  VARCHAR,    -- table of predicted values
    tbl_actual      VARCHAR,
    id_actual       VARCHAR,
    values_actual   VARCHAR,
    tbl_error       VARCHAR
) RETURNS VOID AS $$
DECLARE
    error DOUBLE PRECISION;
    old_messages VARCHAR;
BEGIN
    old_messages := (SELECT setting FROM pg_settings WHERE name = 'client_min_messages');
    EXECUTE 'SET client_min_messages TO warning';

    EXECUTE '
        CREATE TABLE '|| tbl_error ||' AS
        SELECT
            avg(('|| tbl_prediction ||'.prediction - '|| tbl_actual ||'.'|| values_actual ||')^2) AS mean_squared_error
        FROM
            '|| tbl_prediction ||',
            '|| tbl_actual ||'
        WHERE
            '|| tbl_prediction ||'.id = '|| tbl_actual ||'.'|| id_actual;

    EXECUTE 'SET client_min_messages TO ' || old_messages;
END;
$$ LANGUAGE plpgsql VOLATILE;

------------------------------------------------------------------------

/**
 * @brief A prediction function for logistic regression
 *
 * @param coef Coefficients. Note: the MADlib logregr_train function does not produce a separate intercept term,
 *             unlike the elastic_net_train function.
 * @param col_ind Independent variable, which must be an array
 *
 */
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.logregr_predict(
    coef        DOUBLE PRECISION[],
    col_ind     DOUBLE PRECISION[]
) RETURNS BOOLEAN AS $$
PythonFunction(validation, cross_validation, logregr_predict)
$$ LANGUAGE plpythonu;

/**
 * @brief A prediction function for logistic regression
 * The result is stored in the table tbl_predict
 *
 * This function can be used together with cross-validation
 */
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.cv_logregr_predict(
    tbl_model       VARCHAR,
    tbl_newdata     VARCHAR,
    col_ind_var     VARCHAR,
    col_id          VARCHAR,
    tbl_predict     VARCHAR
) RETURNS VOID AS $$
PythonFunction(validation, cross_validation, cv_logregr_predict)
$$ LANGUAGE plpythonu;

/**
 * @brief Metric function for logistic regression
 *
 * @param coef Logistic fitting coefficients. Note: the MADlib logregr_train function does not produce a separate intercept term,
 *             unlike the elastic_net_train function.
 * @param col_ind Independent variable, an array
 * @param col_dep Dependent variable
 *
 * Returns 1 if the prediction is the same as col_dep, otherwise 0
 */
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.logregr_accuracy(
    coef        DOUBLE PRECISION[],
    col_ind     DOUBLE PRECISION[],
    col_dep     BOOLEAN
) RETURNS INTEGER AS $$
PythonFunction(validation, cross_validation, logregr_accuracy)
$$ LANGUAGE plpythonu;

/**
 * @brief Metric function for logistic regression
 *
 * It computes the percentage of correct predictions.
 * The result is stored in the table tbl_accuracy
 */
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.cv_logregr_accuracy(
    tbl_predict     VARCHAR,
    tbl_source      VARCHAR,
    col_id          VARCHAR,
    col_dep_var     VARCHAR,
    tbl_accuracy    VARCHAR
) RETURNS VOID AS $$
PythonFunction(validation, cross_validation, cv_logregr_accuracy)
$$ LANGUAGE plpythonu;
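For intuition, the two metric computations in this file can be mirrored in plain Python: mse_error joins predictions to actual values on the row ID and averages the squared differences, while the logistic-regression accuracy is the fraction of predictions matching the true labels. This is a sketch for illustration only, not the MADlib implementation.

```python
# Pure-Python sketches of the metric computations above (illustrative only).

def mse_error(predictions, actual):
    """predictions/actual: dicts mapping row id -> value (mirrors the SQL ID join)."""
    diffs = [(predictions[i] - actual[i]) ** 2 for i in actual]
    return sum(diffs) / len(diffs)

def logregr_accuracy(predictions, labels):
    """Fraction of boolean predictions equal to the true labels."""
    correct = sum(1 for p, y in zip(predictions, labels) if p == y)
    return correct / len(labels)
```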