/* ----------------------------------------------------------------------- *//**
 *
 * @file cross_validation.sql_in
 *
 * @brief SQL functions for cross validation
 * @date January 2011
 *
 * @sa For a brief introduction to the usage of cross validation, see the
 * module description \ref grp_validation.
 *
 *//* ----------------------------------------------------------------------- */


m4_include(`SQLCommon.m4') --'

/**
@addtogroup grp_validation

@about

Cross-validation, sometimes called rotation estimation, is a technique for assessing how the results of a statistical
analysis will generalize to an independent data set. It is mainly used in settings where the goal is prediction, and
one wants to estimate how accurately a predictive model will perform in practice. One round of cross-validation
involves partitioning a sample of data into complementary subsets, performing the analysis on one subset (called
the training set), and validating the analysis on the other subset (called the validation set or testing set). To
reduce variability, multiple rounds of cross-validation are performed using different partitions, and the validation
results are averaged over the rounds.

In k-fold cross-validation, the original sample is randomly partitioned into k equally sized subsamples. Of the k subsamples,
a single subsample is retained as the validation data for testing the model, and the remaining k − 1 subsamples are used
as training data. The cross-validation process is then repeated k times (the folds), with each of the k subsamples used
exactly once as the validation data. The k results from the folds can then be averaged (or otherwise combined) to produce
a single estimate. The advantage of this method over repeated random sub-sampling is that all observations are used for
both training and validation, and each observation is used for validation exactly once. 10-fold cross-validation is
commonly used, but in general k remains an unfixed parameter.
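
The CV functions in this module take care of the splitting internally. Purely as an illustration of the k-fold idea,
a 10-fold split could be expressed in plain SQL roughly as follows (assuming a hypothetical table <em>mydata</em>
with one row per observation):

<pre>-- Assign each row of mydata to one of 10 folds in random order
CREATE TEMP TABLE mydata_folded AS
SELECT *, ntile(10) OVER (ORDER BY random()) AS fold
FROM mydata;

-- Rows with fold = 1 form the validation set of the first round;
-- all other rows form the corresponding training set
SELECT count(*) FROM mydata_folded WHERE fold = 1;
</pre>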

@input

<b>The flexible interface.</b>

The input includes the data set, a training function, a prediction function, and an error metric function.

The training function takes in a data set with independent and dependent variables in it and produces
a model, which is stored in an output table.

The prediction function takes in the model generated by the training function and a different data set with
independent variables in it, and it produces a prediction of the dependent variables based on the model.
The prediction is stored in an output table. The prediction function should take a unique ID column name of
the data table as one of its inputs; otherwise the prediction result cannot be compared with the validation
values.

The error metric function takes in the prediction made by the prediction function and compares it with the known
values of the dependent variables of the data set that was fed into the prediction function. It computes the
error metric defined by the function. The results are stored in a table.

Other inputs include the output table name, the k value for k-fold cross-validation, and how many of the folds the user
wants to run (for example, the user can choose to run a simple validation instead of a full cross-validation).

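For concreteness, a conforming error metric function could be written along the lines of the following minimal
sketch. It is hypothetical (the mse_error function defined later in this file follows the same pattern): it receives
the prediction table, the validation table, the ID and value column names, and the name of an output table, and it
writes a single error value into that output table.

<pre>CREATE OR REPLACE FUNCTION my_abs_error(
    tbl_prediction VARCHAR,  -- table of predictions, with columns id and prediction
    tbl_actual     VARCHAR,  -- validation data
    id_actual      VARCHAR,  -- name of the unique ID column in the validation data
    values_actual  VARCHAR,  -- name of the dependent-variable column
    tbl_error      VARCHAR   -- output table to be created
) RETURNS VOID AS $$
BEGIN
    EXECUTE '
        CREATE TABLE ' || tbl_error || ' AS
        SELECT avg(abs(p.prediction - a.' || values_actual || ')) AS mean_abs_error
        FROM ' || tbl_prediction || ' p, ' || tbl_actual || ' a
        WHERE p.id = a.' || id_actual;
END;
$$ LANGUAGE plpgsql VOLATILE;
</pre>
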
@usage

<b>The flexible interface.</b>

In order to choose the optimum value for a parameter of the model, the user needs to provide the training function,
the prediction function, the error metric function, the parameter and its values to be studied, and the data set.

It is better if the data set has a unique ID for each row, so that it is easier to split the data set into the
training part and the validation part. The user also needs to inform the cross validation (CV) function about whether this
ID value is randomly assigned to each row. If it is not randomly assigned, the CV function will automatically generate
a random ID for each row.

If the data set has no unique ID for each row, the CV function will copy the data set and create a randomly assigned ID
column for the newly created temp table. The new table will be dropped after the computation is finished. To minimize
the copying workload, the user needs to provide the data column names (for independent variables and dependent
variables) that are going to be used in the calculation, and only these columns will be copied.

<pre>SELECT cross_validation_general(
    <em>modelling_func</em>,        -- Name of the function that trains the model
    <em>modelling_params</em>,      -- Array of parameters for the modelling function
    <em>modelling_params_type</em>, -- Type of each parameter for the modelling function
    --
    <em>param_explored</em>,        -- Name of the parameter that will be checked to find the optimum value; the
    ----                               same name must also appear in the array of modelling_params
    <em>explore_values</em>,        -- Values of this parameter that will be studied
    --
    <em>predict_func</em>,          -- Name of the function for prediction
    <em>predict_params</em>,        -- Array of parameters for the prediction function
    <em>predict_params_type</em>,   -- Type of each parameter for the prediction function
    --
    <em>metric_func</em>,           -- Name of the function for measuring errors
    <em>metric_params</em>,         -- Array of parameters for the error metric function
    <em>metric_params_type</em>,    -- Type of each parameter for the metric function
    --
    <em>data_tbl</em>,              -- Data table, which will be split into training and validation parts
    <em>data_id</em>,               -- Name of the unique ID associated with each row. Provide <em>NULL</em>
    ----                               if there is no such column in the data table
    <em>id_is_random</em>,          -- Whether the provided ID is randomly assigned to each row
    --
    <em>validation_result</em>,     -- Table name to store the output of the CV function; see Output below for
    ----                               the format. It will be automatically created by the CV function
    --
    <em>data_cols</em>,             -- Names of the data columns that are going to be used. It is only useful when
    ----                               <em>data_id</em> is NULL, otherwise it is ignored
    <em>fold_num</em>               -- Value of k. How many folds of validation? Each validation uses 1/fold_num
    ----                               of the data for validation. Default value: 10
);</pre>

Special keywords in parameter arrays of modelling, prediction and metric functions:

<em>\%data%</em> : The argument position for training/validation data

<em>\%model%</em> : The argument position for the output/input of modelling/prediction function

<em>\%id%</em> : The argument position of unique ID column (provided by user or generated by CV function as is mentioned above)

<em>\%prediction%</em> : The argument position for the output/input of prediction/metric function

<em>\%error%</em> : The argument position for the output of metric function

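As an illustration of how the keywords are substituted, suppose the modelling function is the cv_linregr_train
wrapper defined later in this file and its parameter array is <em>'{\%data%, val, dep, \%model%}'</em>. For each fold the
CV function replaces <em>\%data%</em> with the name of that fold's training table and <em>\%model%</em> with the name of an
intermediate model table it manages, so the call it issues is of the form shown below (the table names here are
purely illustrative):

<pre>SELECT cv_linregr_train('cv_train_fold_1', 'val', 'dep', 'cv_model_fold_1');
</pre>
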
<b>Note</b>: If the parameter <em>explore_values</em> is NULL or has zero length, then the cross validation function will only run a data folding.

Output:
<pre> param_explored | average error | standard deviation of error
----------------+---------------+-----------------------------
 .......
</pre>

<b>Note:</b>

<em>max_locks_per_transaction</em>, which usually has a default value of 64, limits the number of tables that can be
dropped inside a single transaction (the CV function). Thus the number of different values of <em>param_explored</em>
(i.e., the length of the array <em>explore_values</em>) cannot be too large. For 10-fold cross validation, the limit on
length(<em>explore_values</em>) is around 40. If this number is too large, the user might see an "out of shared memory"
error because <em>max_locks_per_transaction</em> is used up.

One way to overcome this limitation is to run the CV function multiple times, with each run covering a different region
of the parameter's values.

In the future, MADlib will implement cross-validation functions for each individual applicable module, where the calculation can be optimized to avoid dropping tables and thus the <em>max_locks_per_transaction</em> limitation. However, such cross-validation functions need to know the implementation details of the modules to do the optimization, and thus cannot be as flexible as the cross-validation function provided here.

The cross-validation function provided here is very flexible and can work with any algorithm that the user wants to cross-validate, including algorithms written by the user. The price for this flexibility is that the algorithms' details cannot be used to optimize the calculation, and thus the <em>max_locks_per_transaction</em> limitation cannot be avoided.

@examp

Cross validation is used on elastic net regression to find the best value of the regularization parameter.

(1) Populate the table 'cvtest' with a 101-dimensional independent variable 'val' and a dependent
variable 'dep'.

(2) Run the general CV function:
<pre>
SELECT madlib.cross_validation_general (
    'madlib.elastic_net_train',
    '{\%data%, \%model%, dep, val, gaussian, 1, lambda, True, Null, fista, "{eta = 2, max_stepsize = 2, use_active_set = t}", Null, 2000, 1e-6}'::varchar[],
    '{varchar, varchar, varchar, varchar, varchar, double precision, double precision, boolean, varchar, varchar, varchar[], varchar, integer, double precision}'::varchar[],
    --
    'lambda',
    '{0.02, 0.04, 0.06, 0.08, 0.10, 0.12, 0.14, 0.16, 0.18, 0.20, 0.22, 0.24, 0.26, 0.28, 0.30, 0.32, 0.34, 0.36}'::varchar[],
    --
    'madlib.elastic_net_predict',
    '{\%model%, \%data%, \%id%, \%prediction%}'::varchar[],
    '{text, text, text, text}'::varchar[],
    --
    'madlib.mse_error',
    '{\%prediction%, \%data%, \%id%, dep, \%error%}'::varchar[],
    '{varchar, varchar, varchar, varchar, varchar}'::varchar[],
    --
    'cvtest',
    NULL::varchar,
    False,
    --
    'valid_rst_tbl',
    '{val, dep}'::varchar[],
    10
);
</pre>
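
(3) The result table can then be inspected to pick the value of lambda with the smallest error; its columns follow
the output format sketched above:
<pre>
SELECT * FROM valid_rst_tbl;
</pre>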

@sa File cross_validation.sql_in documenting the SQL functions.

*/

------------------------------------------------------------------------
/*
 * @brief Perform cross validation for modules that conform to a fixed SQL API
 * Note: this function is subject to the lock-number limitation described above. It is flexible to use, so that the
 * user can try the CV method on their own functions. The cross_validation function, on the other hand, does not
 * have the lock-number limitation.
 *
 * @param modelling_func Name of the function that trains the model
 * @param modelling_params Array of parameters for the modelling function
 * @param modelling_params_type Type of each parameter for the modelling function
 * @param param_explored Name of the parameter that will be checked to find the optimum value; the same name must also appear in the array of modelling_params
 * @param explore_values Values of this parameter that will be studied
 * @param predict_func Name of the function for prediction
 * @param predict_params Array of parameters for the prediction function
 * @param predict_params_type Type of each parameter for the prediction function
 * @param metric_func Name of the function for measuring errors
 * @param metric_params Array of parameters for the error metric function
 * @param metric_params_type Type of each parameter for the metric function
 * @param data_tbl Data table, which will be split into training and validation parts
 * @param data_id Name of the unique ID associated with each row. Provide <em>NULL</em> if there is no such column in the data table
 * @param id_is_random Whether the provided ID is randomly assigned to each row
 * @param validation_result Table name to store the output of the CV function; see the Output section for the format. It will be automatically created by the CV function
 * @param fold_num Value of k. How many folds of validation? Each validation uses 1/fold_num of the data for validation. Default value: 10
 * @param data_cols Names of the data columns that are going to be used. It is only useful when <em>data_id</em> is NULL, otherwise it is ignored
 */
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.cross_validation_general(
    modelling_func          VARCHAR,    -- function for setting up the model
    modelling_params        VARCHAR[],  -- parameters for modelling
    modelling_params_type   VARCHAR[],  -- parameter types for modelling
    --
    param_explored          VARCHAR,    -- which parameter will be studied using validation
    explore_values          VARCHAR[],  -- values that will be explored for this parameter
    --
    predict_func            VARCHAR,    -- function for predicting using the model
    predict_params          VARCHAR[],  -- parameters for prediction
    predict_params_type     VARCHAR[],  -- parameter types for prediction
    --
    metric_func             VARCHAR,    -- function that computes the error metric
    metric_params           VARCHAR[],  -- parameters for metric
    metric_params_type      VARCHAR[],  -- parameter types for metric
    --
    data_tbl                VARCHAR,    -- table containing the data, which will be split into training and validation parts
    data_id                 VARCHAR,    -- user-provided unique ID for each row
    id_is_random            BOOLEAN,    -- whether the ID provided by the user is random
    --
    validation_result       VARCHAR,    -- table to store the result: param values, error, +/-
    --
    data_cols               VARCHAR[],  -- names of data columns that are going to be used
    fold_num                INTEGER     -- how many folds of validation, default: 10
) RETURNS VOID AS $$
PythonFunction(validation, cross_validation, cross_validation_general)
$$ LANGUAGE plpythonu;

------------------------------------------------------------------------

CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.cross_validation_general(
    modelling_func          VARCHAR,    -- function for setting up the model
    modelling_params        VARCHAR[],  -- parameters for modelling
    modelling_params_type   VARCHAR[],  -- parameter types for modelling
    --
    param_explored          VARCHAR,    -- which parameter will be studied using validation
    explore_values          VARCHAR[],  -- values that will be explored for this parameter
    --
    predict_func            VARCHAR,    -- function for predicting using the model
    predict_params          VARCHAR[],  -- parameters for prediction
    predict_params_type     VARCHAR[],  -- parameter types for prediction
    --
    metric_func             VARCHAR,    -- function that computes the error metric
    metric_params           VARCHAR[],  -- parameters for metric
    metric_params_type      VARCHAR[],  -- parameter types for metric
    --
    data_tbl                VARCHAR,    -- table containing the data, which will be split into training and validation parts
    data_id                 VARCHAR,    -- user-provided unique ID for each row
    id_is_random            BOOLEAN,    -- whether the ID provided by the user is random
    --
    validation_result       VARCHAR,    -- table to store the result: param values, error, +/-
    --
    data_cols               VARCHAR[]   -- names of data columns that are going to be used
) RETURNS VOID AS $$
BEGIN
    PERFORM MADLIB_SCHEMA.cross_validation_general($1,$2,$3,$4,$5,$6,$7,$8,$9,$10,$11,$12,$13,$14,$15,$16,10);
END;
$$ LANGUAGE plpgsql VOLATILE;

------------------------------------------------------------------------
------------------------------------------------------------------------
------------------------------------------------------------------------

/**
 * @brief Simple interface for cross-validation, which has no lock-number limitation
 *
 * @param module_name Module to be cross-validated
 * @param func_args Arguments of the modelling function of the module, including the table name of the data
 * @param param_to_try The name of the parameter that CV runs through
 * @param param_values The values of the parameter that CV will try
 * @param data_id Name of the unique ID associated with each row. Provide <em>NULL</em> if there is no such column in the data table
 * @param id_is_random Whether the provided ID is randomly assigned to each row
 * @param validation_result Table name to store the output of the CV function; see the Output section for the format. It will be automatically created by the CV function
 * @param fold_num How many folds of cross-validation
 */
/*
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.cross_validation(
    module_name         VARCHAR,    -- module to be cross-validated
    func_args           VARCHAR[],
    param_to_try        VARCHAR,
    param_values        DOUBLE PRECISION[],
    data_id             VARCHAR,
    id_is_random        BOOLEAN,
    validation_result   VARCHAR,
    fold_num            INTEGER
) RETURNS VOID AS $$
PythonFunction(validation, cross_validation, cross_validation)
$$ LANGUAGE plpythonu;
*/
-- ------------------------------------------------------------------------
/*
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.cross_validation(
    module_name         VARCHAR,
    func_args           VARCHAR[],
    param_to_try        VARCHAR,
    param_values        DOUBLE PRECISION[],
    data_id             VARCHAR,
    id_is_random        BOOLEAN,
    validation_result   VARCHAR
) RETURNS VOID AS $$
BEGIN
    PERFORM MADLIB_SCHEMA.cross_validation($1, $2, $3, $4, $5, $6, $7, 10);
END;
$$ LANGUAGE plpgsql VOLATILE;
*/
-- ------------------------------------------------------------------------

/**
 * @brief Print the help message for a given module's cross-validation.
 */
/*
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.cross_validation(module_name VARCHAR)
RETURNS VARCHAR AS $$
PythonFunction(validation, cross_validation, cross_validation_help)
$$ LANGUAGE plpythonu;
*/
-- ------------------------------------------------------------------------

/**
 * @brief Print the supported module names for cross_validation
 */
/*
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.cross_validation()
RETURNS VARCHAR AS $$
DECLARE
    msg VARCHAR;
BEGIN
    msg := 'cross_validation function now supports Ridge linear regression';
    return msg;
END;
$$ LANGUAGE plpgsql STRICT;
*/
------------------------------------------------------------------------

/**
 * @brief A wrapper for linear regression
 */
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.cv_linregr_train(
    tbl_source      VARCHAR,
    col_ind_var     VARCHAR,
    col_dep_var     VARCHAR,
    tbl_result      VARCHAR
) RETURNS VOID AS $$
PythonFunction(validation, cross_validation, cv_linregr_train)
$$ LANGUAGE plpythonu;

CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.linregr_predict(
    coef        DOUBLE PRECISION[],
    col_ind     DOUBLE PRECISION[]
) RETURNS DOUBLE PRECISION AS $$
PythonFunction(validation, cross_validation, linregr_predict)
$$ LANGUAGE plpythonu;
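
-- Row-level usage sketch of linregr_predict: it returns the predicted value for a single row. The values
-- below are illustrative only; the coefficient array would normally come from a model table produced by
-- cv_linregr_train.
--
--   SELECT MADLIB_SCHEMA.linregr_predict(
--       ARRAY[0.5, 1.2]::DOUBLE PRECISION[],   -- fitted coefficients
--       ARRAY[1.0, 3.0]::DOUBLE PRECISION[]    -- independent variables of one row
--   );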

/**
 * @brief A wrapper for linear regression prediction
 */
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.cv_linregr_predict(
    tbl_model       VARCHAR,
    tbl_newdata     VARCHAR,
    col_ind_var     VARCHAR,
    col_id          VARCHAR,    -- ID column
    tbl_predict     VARCHAR
) RETURNS VOID AS $$
PythonFunction(validation, cross_validation, cv_linregr_predict)
$$ LANGUAGE plpythonu;

-- compare the prediction and actual values
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.mse_error(
    tbl_prediction  VARCHAR,    -- predicted values
    tbl_actual      VARCHAR,
    id_actual       VARCHAR,
    values_actual   VARCHAR,
    tbl_error       VARCHAR
) RETURNS VOID AS $$
DECLARE
    error           DOUBLE PRECISION;
    old_messages    VARCHAR;
BEGIN
    old_messages := (SELECT setting FROM pg_settings WHERE name = 'client_min_messages');
    EXECUTE 'SET client_min_messages TO warning';

    EXECUTE '
        CREATE TABLE '|| tbl_error ||' AS
        SELECT
            avg(('|| tbl_prediction ||'.prediction - '|| tbl_actual ||'.'|| values_actual ||')^2) AS mean_squared_error
        FROM
            '|| tbl_prediction ||',
            '|| tbl_actual ||'
        WHERE
            '|| tbl_prediction ||'.id = '|| tbl_actual ||'.'|| id_actual;

    EXECUTE 'SET client_min_messages TO ' || old_messages;
END;
$$ LANGUAGE plpgsql VOLATILE;
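
-- The three wrappers above (cv_linregr_train, cv_linregr_predict, mse_error) plug directly into the flexible
-- interface. The commented sketch below is illustrative only: 'houses' stands for a hypothetical data table
-- with an array column 'x', a dependent column 'y', and no unique ID column. Because the wrapped linear
-- regression has no parameter to explore, param_explored and explore_values are passed as NULL; per the note
-- in the module description above, the CV function then only runs the data folding.
--
--   SELECT MADLIB_SCHEMA.cross_validation_general(
--       'MADLIB_SCHEMA.cv_linregr_train',
--       '{%data%, x, y, %model%}'::varchar[],
--       '{varchar, varchar, varchar, varchar}'::varchar[],
--       NULL::varchar, NULL::varchar[],
--       'MADLIB_SCHEMA.cv_linregr_predict',
--       '{%model%, %data%, x, %id%, %prediction%}'::varchar[],
--       '{varchar, varchar, varchar, varchar, varchar}'::varchar[],
--       'MADLIB_SCHEMA.mse_error',
--       '{%prediction%, %data%, %id%, y, %error%}'::varchar[],
--       '{varchar, varchar, varchar, varchar, varchar}'::varchar[],
--       'houses',
--       NULL::varchar,
--       False,
--       'linregr_cv_result',
--       '{x, y}'::varchar[],
--       10
--   );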

------------------------------------------------------------------------

/**
 * @brief A prediction function for logistic regression
 *
 * @param coef Coefficients. Note: unlike the elastic_net_train function, the MADlib logregr_train function
 * does not produce a separate intercept term.
 * @param col_ind Independent variable, which must be an array
 *
 */
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.logregr_predict(
    coef        DOUBLE PRECISION[],
    col_ind     DOUBLE PRECISION[]
) RETURNS BOOLEAN AS $$
PythonFunction(validation, cross_validation, logregr_predict)
$$ LANGUAGE plpythonu;

/**
 * @brief A prediction function for logistic regression
 * The result is stored in the table tbl_predict
 *
 * This function can be used together with cross-validation
 */
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.cv_logregr_predict(
    tbl_model       VARCHAR,
    tbl_newdata     VARCHAR,
    col_ind_var     VARCHAR,
    col_id          VARCHAR,
    tbl_predict     VARCHAR
) RETURNS VOID AS $$
PythonFunction(validation, cross_validation, cv_logregr_predict)
$$ LANGUAGE plpythonu;

/**
 * @brief Metric function for logistic regression
 *
 * @param coef Logistic fitting coefficients. Note: unlike the elastic_net_train function, the MADlib
 * logregr_train function does not produce a separate intercept term.
 * @param col_ind Independent variable, an array
 * @param col_dep Dependent variable
 *
 * Returns 1 if the prediction is the same as col_dep, otherwise 0
 */
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.logregr_accuracy(
    coef        DOUBLE PRECISION[],
    col_ind     DOUBLE PRECISION[],
    col_dep     BOOLEAN
) RETURNS INTEGER AS $$
PythonFunction(validation, cross_validation, logregr_accuracy)
$$ LANGUAGE plpythonu;

/**
 * @brief Metric function for logistic regression
 *
 * It computes the percentage of correct predictions.
 * The result is stored in the table tbl_accuracy
 */
CREATE OR REPLACE FUNCTION MADLIB_SCHEMA.cv_logregr_accuracy(
    tbl_predict     VARCHAR,
    tbl_source      VARCHAR,
    col_id          VARCHAR,
    col_dep_var     VARCHAR,
    tbl_accuracy    VARCHAR
) RETURNS VOID AS $$
PythonFunction(validation, cross_validation, cv_logregr_accuracy)
$$ LANGUAGE plpythonu;
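
-- Usage sketch for the accuracy metric (table and column names below are hypothetical): given a prediction
-- table produced by cv_logregr_predict and the original validation data, cv_logregr_accuracy compares the two
-- using the unique ID column and writes the percentage of correct predictions into the output table.
--
--   SELECT MADLIB_SCHEMA.cv_logregr_accuracy(
--       'logregr_prediction',      -- tbl_predict:  output of cv_logregr_predict
--       'patients',                -- tbl_source:   validation data
--       'id',                      -- col_id:       unique row ID
--       'second_attack',           -- col_dep_var:  Boolean dependent variable
--       'logregr_accuracy_tbl'     -- tbl_accuracy: output table, created by the function
--   );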