calc_lambda_regtools

PURPOSE ^

% CALC_LAMBDA_REGTOOLS: Find optimal hyperparameter by the L-curve (LCC)

SYNOPSIS ^

function lambdas = calc_lambda_regtools(imdl, vh, vi, type, doPlot)

DESCRIPTION ^

% CALC_LAMBDA_REGTOOLS: Find optimal hyperparameter by the L-curve (LCC) 
 criterion or the generalized cross-validation (GCV).
   lambdas = calc_lambda_regtools(imdl, vh, vi, type, doPlot);

 Output:
   lambdas   - "optimal" hyperparameter(s) determined using LCC or GCV

 Input:
   imdl      - inverse model (EIDORS struct)
   vh        - homogeneous voltage matrix (of size nVtg x 1)
   vi        - inhomogeneous voltage matrix (of size nVtg x nFrames) including noise(!)
   type      - type of approach used, either:
               'LCC' (default), the L-curve criterion
               'GCV', generalized cross-validation
   doPlot    - will enable plotting if set to true (default = false)

 Example:
   calc_lambda_regtools('unit_test');  

 NOTE
   if vi contains multiple frames, the returned values will contain an
   "optimal" hyperparameter for each frame. An appropriate single lambda can
   then be determined from a summary statistic (e.g. the median) of these values.

 See also: RTv4manual.pdf (please note that all page numbers listed
 correspond to the ones printed in the upper right corner of each page; the
 corresponding PDF page number is 2 higher).

 Nomenclature: Jacobian J is A; Prior R (not RtR) is L; Voltage v is b

 Fabian Braun, December 2016

 CITATION_REQUEST:
 AUTHOR: P C Hansen
 TITLE: Regularization tools version 4.0 for Matlab 7.3.
 JOURNAL: Numerical algorithms
 YEAR: 2007
 VOL: 46
 NUM: 2
 PAGE: S189-194
 DOI: 10.1007/s11075-007-9136-9

CROSS-REFERENCE INFORMATION ^

This function calls: This function is called by:

SUBFUNCTIONS ^

SOURCE CODE ^

function lambdas = calc_lambda_regtools(imdl, vh, vi, type, doPlot)
%% CALC_LAMBDA_REGTOOLS: Find optimal hyperparameter by the L-curve (LCC)
% criterion or the generalized cross-validation (GCV).
%   lambdas = calc_lambda_regtools(imdl, vh, vi, type, doPlot);
%
% Output:
%   lambdas   - "optimal" hyperparameter(s) determined using LCC or GCV,
%               a column vector with one value per (filtered) data frame
%
% Input:
%   imdl      - inverse model (EIDORS struct)
%   vh        - homogeneous voltage matrix (of size nVtg x 1)
%   vi        - inhomogeneous voltage matrix (of size nVtg x nFrames) including noise(!)
%   type      - type of approach used (case-insensitive), either:
%               'LCC' (default), the L-curve criterion
%               'GCV', generalized cross-validation
%               any other value raises the error 'type not supported!'
%   doPlot    - will enable plotting if set to true (default = false)
%
% Example:
%   calc_lambda_regtools('unit_test');
%
% NOTE
%   if vi contains multiple frames, the returned values will contain an
%   "optimal" hyperparameter for each frame. An appropriate single lambda can
%   then be determined from a summary statistic (e.g. the median) of these values.
%
%   The measurement inverse covariance of imdl must be the identity matrix
%   (this is asserted), as the regtools routines used here take no data
%   weighting.
%
% See also: RTv4manual.pdf (please note that all page numbers listed
% correspond to the ones printed in the upper right corner of each page; the
% corresponding PDF page number is 2 higher).
%
% Nomenclature: Jacobian J is A; Prior R (not RtR) is L; Voltage v is b
%
% Fabian Braun, December 2016
%
% CITATION_REQUEST:
% AUTHOR: P C Hansen
% TITLE: Regularization tools version 4.0 for Matlab 7.3.
% JOURNAL: Numerical algorithms
% YEAR: 2007
% VOL: 46
% NUM: 2
% PAGE: S189-194
% DOI: 10.1007/s11075-007-9136-9
%

% (C) 2016 Fabian Braun. License: GPL version 2 or version 3
% $Id: calc_lambda_regtools.m 5930 2019-04-13 04:55:41Z alistair_boyle $

citeme(mfilename);

%% unit testing?
if ischar(imdl) && strcmpi(imdl, 'unit_test')
   doUnitTest();
   lambdas = []; % keep the output defined when called with an output argument
   return;
end


%% set default inputs
if ~exist('type', 'var') || isempty(type)
    type = 'LCC';
end
if ~exist('doPlot', 'var') || isempty(doPlot)
    doPlot = false;
end

% validate the method up front, before any expensive computation is done
if ~any(strcmpi(type, {'lcc', 'gcv'}))
    error('type not supported!');
end
useLcc = strcmpi(type, 'lcc');

%% check for existence of the regtools package
regtoolsLocal = fullfile(fileparts(mfilename('fullpath')), 'regtools');
if exist('regudemo.m', 'file') == 2
    % regtools are already on the path, nothing to do
elseif exist('./regtools', 'dir')
    % regtools stored in a subfolder of the current working directory
    addpath('./regtools');
elseif exist(regtoolsLocal, 'dir')
    % regtools stored in a subfolder next to this file
    addpath(regtoolsLocal);
else
    error('Regtools are required but are not available, please download them from <a href="matlab: web http://www.mathworks.com/matlabcentral/fileexchange/52-regtools -browser">File Exchange</a> or <a href="matlab: web http://www2.compute.dtu.dk/~pcha/Regutools/ -browser">P.C. Hansen''s website</a> and store them in the subfolder called ''regtools''. In order to allow for a fast execution it is recommended to disable (comment out) all plotting functions in l_curve.m and gcv.m.');
end

% AA: 3feb2017: Please make changes so
% 1. we don't call get_RM
% 2. we call calc_R_prior
%     fix calc_R_prior so it does what you want
% 3. rename to calc_lambda_regtools
% 4. change tutorial to call new name
% 5. Make changes to mk_GREIT_model
% 6. Merge these changes into mainline (if it works)
%    OR: delete mainline and svn mv

%% prepare imdl
imdlTmp = imdl;
imdlTmp.prior_use_fwd_not_rec = 0;
% if isfield(imdl.fwd_model,'coarse2fine')
%     imdlTmp.fwd_model = rmfield(imdlTmp.fwd_model,'coarse2fine');
% end
% if isfield(imdl, 'rec_model') && isfield(imdl.rec_model,'coarse2fine')
%     imdlTmp.rec_model = rmfield(imdlTmp.rec_model,'coarse2fine');
% end
img_bkgnd = calc_jacobian_bkgnd(imdlTmp);
A = calc_jacobian(img_bkgnd);
W = calc_meas_icov(imdlTmp);
L = calc_R_prior(imdlTmp);

% the regtools routines below take no data weighting, so W must be identity
assert(isequal(W, speye(size(W))), ...
    'calc_lambda_regtools: the measurement inverse covariance must be identity');


%% (IMPORTANT!) bring generalized to standard form (section 2.6 p.21 of RTv4manual.pdf)
% L-curves of the generalized and the standard form are equal because
% they have identical norms (see section 2.6.3 p.24 of RTv4manual.pdf).
% Only A_s is needed, so a dummy right-hand side of matching length is passed.
[A_s, ~, ~] = std_form(A, L, nan(size(A,1),1));
% NOTE: We need it in standard form as l_curve and gcv routines only accept this
[U_s, s_s] = csvd(A_s);

%% build the (filtered) difference data
data_width = max(num_frames(vi), num_frames(vh));
vi = filt_data(imdl, vi, data_width);
vh = filt_data(imdl, vh, data_width);
% the sign of the difference is irrelevant: the solution is linear in b,
% so both the LCC and GCV criteria (which only use norms) are unaffected
B = vh - vi;

% count frames on the filtered data (vi may be a data object or expanded)
nFrames = size(B, 2);
lambdas = nan(nFrames, 1);

progress_msg('Calculating lambda for each frame:', 0, nFrames);

if doPlot
    figure();
end

%% iterate through all frames to get one lambda per frame
for iFrame = 1:nFrames
    progress_msg(iFrame, nFrames);
    b = B(:, iFrame);
    if useLcc
        % L-curve (see section 2.5 p.20 of RTv4manual.pdf)
        % continuous l-curve (documentation on p.83 of RTv4manual.pdf)
        lambdas(iFrame) = l_curve(U_s, s_s, b);
    else
        % gcv (see p.37 of RTv4manual.pdf), documentation on p.65
        lambdas(iFrame) = gcv(U_s, s_s, b);
    end
end

% overlay our own curve for the last frame, for validation purposes
if doPlot && nFrames > 0
    plot_validation_curve(imdl, A, L, B(:, end), useLcc);
end

progress_msg('Calculating lambda for each frame:', inf);

end

% Overlay, on the current axes, an L-curve (useLcc = true: residual norm vs.
% solution seminorm) or a GCV function (useLcc = false: G vs. lambda) that is
% computed directly with EIDORS (get_RM) for 10 hyperparameters spanning three
% decades on either side of imdl.hyperparameter.value (blue circles). The
% current value itself is plotted as a green circle.
function plot_validation_curve(imdl, A, L, b, useLcc)
    lInit = imdl.hyperparameter.value;
    lams = [flip(logspace(log10(lInit*1E-3), log10(lInit*1E3), 10)), lInit];
    nLams = numel(lams);
    xVal = zeros(1, nLams);
    yVal = zeros(1, nLams);
    for i = 1:nLams
        imdl.hyperparameter.value = lams(i);
        RM = get_RM(imdl);
        x = RM*b;
        if useLcc
            xVal(i) = norm(A*x - b);
            yVal(i) = norm(L*x);
        else
            % G(lambda) = ||A x - b||^2 / trace(I - A*RM)^2
            xVal(i) = lams(i);
            yVal(i) = norm(A*x - b)^2 / (trace(eye(size(RM,2)) - A*RM)^2);
        end
    end
    hold on;
    loglog(xVal(1:end-1), yVal(1:end-1), 'ob');
    loglog(xVal(end), yVal(end), 'og');
end

% TODO: this code really needs to be cleaned, but not before eidors 3.4

% NUM_FRAMES: number of frames (columns) in d0.
%   d0 - a numeric matrix, or a non-empty array of EIDORS structs with
%        .type == 'data' whose .meas fields are concatenated horizontally
%   nf - number of columns (frames)
% Raises an error for any other input.
function nf = num_frames(d0)
   if isnumeric(d0)
      nf = size(d0, 2);
   elseif isstruct(d0) && ~isempty(d0) && isfield(d0, 'type') && ...
          strcmp(d0(1).type, 'data')
      % strcmp (not ==) so that type strings of other lengths don't error out
      nf = size(horzcat(d0(:).meas), 2);
   else
      error('Problem calculating number of frames. Expecting numeric or data object');
   end
end

% FILT_DATA: turn d0 into a numeric matrix, apply fwd_model.meas_select if
% present, and optionally replicate a single column to data_width frames.
%   inv_model  - inverse model; only inv_model.fwd_model.meas_select is read
%   d0         - numeric matrix, or an array of structs with .type == 'data'
%                whose .meas fields are concatenated horizontally
%   data_width - (optional) required number of columns of the output
%   d2         - filtered data (double)
function d2 = filt_data(inv_model, d0, data_width)
   % --- gather the raw measurements into a numeric matrix ---
   if isnumeric(d0)
      % we have a matrix of data. Hope for the best
      d1 = d0;
   else
      % we probably have a 'data' object (array)
      d1 = [];
      for k = 1:length(d0)
         if ~strcmp(d0(k).type, 'data')
            error('expecting an object of type data');
         end
         d1 = [d1, d0(k).meas]; %#ok<AGROW>
      end
   end
   d1 = double(d1); % ensure we can do math on our object

   % --- apply a non-empty fwd_model.meas_select, if there is one ---
   fmdl = inv_model.fwd_model;
   if ~(isfield(fmdl, 'meas_select') && ~isempty(fmdl.meas_select))
      d2 = d1;
   else
      meas_select = fmdl.meas_select;
      nRows = size(d1, 1);
      if nRows == length(meas_select)
         d2 = d1(meas_select, :);   % full data: keep selected rows only
      elseif nRows == sum(meas_select == 1)
         d2 = d1;                   % already reduced to the selected rows
      else
         error('inconsistent difference data: (%d ~= %d). Maybe check fwd_model.meas_select',  ...
               nRows, length(meas_select));
      end
   end

   % --- expand a single column to the requested data width ---
   if nargin == 3
      nCols = size(d2, 2);
      if nCols == 1
         d2 = repmat(d2, 1, data_width);
      elseif nCols ~= data_width
         error('inconsistent difference data: (%d ~= %d)',  ...
               nCols, data_width);
      end
   end
end

function doUnitTest()
% DOUNITTEST: estimate lambda per frame with both LCC and GCV on the
% iirc_data_2006 recordings, check that the results are finite and positive,
% and reconstruct a few frames with the median lambda of each method.
% inspired by the tutorial mentioned below:
% http://eidors3d.sourceforge.net/tutorial/EIDORS_basics/tutorial110.shtml
%

% Load some data (provides v_rotate and v_reference used below)
load iirc_data_2006

stim = mk_stim_patterns(16,1,[0,1],[0,1],{'meas_current'},1);

for iRun = 1
% for iRun = [0 1]
    % Get an image reconstruction model
    if iRun
        % more advanced 3D model which includes a coarse2fine mapping
        % (NOTE: this was previously reported to make things crash)
        fmdl = mk_library_model('adult_male_16el');
        fmdl.stimulation = stim;
        opts = [];
        opts.noise_figure = 0.5;
        imdl = mk_GN_model(fmdl, opts, []);
    else
        % simple 2D one
        imdl = mk_common_model('c2c');
        imdl.fwd_model.stimulation = stim;
        imdl.fwd_model = rmfield(imdl.fwd_model, 'meas_select');
        imdl.RtR_prior = @prior_tikhonov;
    end

    % load the real data
    vi = real(v_rotate)/1e4; vh = real(v_reference)/1e4;
    % allow double precision, else we run into (unexplainable) problems
    vi = double(vi); vh = double(vh);

    % get the hyperparameter value via L-curve
    % (calc_lambda_regtools opens its own figure when plotting is enabled)
    lambdas_lcc = calc_lambda_regtools(imdl,vh,vi,'LCC',true);

    % get the hyperparameter value via GCV
    lambdas_gcv = calc_lambda_regtools(imdl,vh,vi,'GCV',true);

    % sanity checks: every frame must yield a finite, positive lambda
    assert(~isempty(lambdas_lcc) && all(isfinite(lambdas_lcc)) && all(lambdas_lcc > 0), ...
        'calc_lambda_regtools unit test: invalid LCC lambdas');
    assert(~isempty(lambdas_gcv) && all(isfinite(lambdas_gcv)) && all(lambdas_gcv > 0), ...
        'calc_lambda_regtools unit test: invalid GCV lambdas');

    % visualize in a new figure, so the L-curve / GCV plots are not overwritten
    FramesOfInterest = [10 35 60 85];
    figure();
    subplot(121);
    imdl.hyperparameter.value = median(lambdas_lcc);
    imgs_lcc = inv_solve(imdl, vh, vi(:,FramesOfInterest));
    imgs_lcc.show_slices.img_cols = 1;
    show_slices(imgs_lcc);
    title('L-curve');

    subplot(122);
    imdl.hyperparameter.value = median(lambdas_gcv);
    imgs_gcv = inv_solve(imdl, vh, vi(:,FramesOfInterest));
    imgs_gcv.show_slices.img_cols = 1;
    show_slices(imgs_gcv);
    title('GCV');

end

end

Generated on Fri 30-Dec-2022 19:44:54 by m2html © 2005