function lambdas = calc_lambda_regtools(imdl, vh, vi, type, doPlot)
%CALC_LAMBDA_REGTOOLS  Choose a regularization parameter for every data frame.
%
%   lambdas = calc_lambda_regtools(imdl, vh, vi, type, doPlot)
%
%   Inputs
%     imdl   - inverse model. The Jacobian, the measurement weighting and
%              the prior matrix are obtained from it through
%              calc_jacobian, calc_meas_icov and calc_R_prior.
%              Passing the string 'unit_test' runs the built-in demo.
%     vh, vi - the two measurement sets whose difference (vh - vi) is
%              analysed. Each is either a numeric matrix with one frame
%              per column or an array of structs with type 'data' and a
%              'meas' field. A single frame is replicated to match the
%              frame count of the other set.
%     type   - 'LCC' (regtools l_curve, default) or 'GCV' (regtools gcv);
%              case-insensitive.
%     doPlot - when true, a figure is opened and, for the last frame, the
%              criterion is additionally evaluated with get_RM for a sweep
%              of hyperparameter values around imdl.hyperparameter.value
%              and drawn as circles (default: false).
%
%   Output
%     lambdas - column vector holding one value per frame, as returned by
%               l_curve or gcv.
%
%   The regtools package must be on the path, in ./regtools, or in a
%   'regtools' folder next to this file; otherwise an error is raised.
%   Only an identity measurement weighting is supported (asserted below).

citeme(mfilename);

if ischar(imdl) && strcmpi(imdl, 'unit_test')
    doUnitTest();
    return;
end

% ---- default arguments ---------------------------------------------------
if ~exist('type', 'var') || isempty(type)
    type = 'LCC';
end
if ~exist('doPlot', 'var') || isempty(doPlot)
    doPlot = false;
end

% Reject an unsupported criterion before any expensive work is done.
type = lower(type);
if ~any(strcmp(type, {'lcc', 'gcv'}))
    error('type not supported!');
end
useGcv = strcmp(type, 'gcv');

% ---- make sure regtools is reachable -------------------------------------
localRegtools = [fileparts(mfilename('fullpath')), filesep, 'regtools'];
if exist('regudemo.m') == 2 %#ok<EXIST>
    % regtools is already on the path
elseif exist('./regtools', 'dir')
    addpath('./regtools');
elseif exist(localRegtools, 'dir')
    addpath(localRegtools);
else
    error(['Regtools are required but are not available, please download them from ' ...
           '<a href="matlab: web http://www.mathworks.com/matlabcentral/fileexchange/52-regtools -browser">File Exchange</a> ' ...
           'or <a href="matlab: web http://www2.compute.dtu.dk/~pcha/Regutools/ -browser">P.C. Hansen''s website</a> ' ...
           'and store them in the subfolder called ''regtools''. ' ...
           'In order to allow for a fast execution it is recommended to disable (comment out) ' ...
           'all plotting functions in l_curve.m and gcv.m.']);
end

% ---- system matrices -----------------------------------------------------
% Work on a copy so the prior flag is only changed for the matrix setup.
imdlTmp = imdl;
imdlTmp.prior_use_fwd_not_rec = 0;

img_bkgnd = calc_jacobian_bkgnd(imdlTmp);
A = calc_jacobian(img_bkgnd);
W = calc_meas_icov(imdlTmp);
L = calc_R_prior(imdlTmp);

% The standard-form transformation below ignores W, so it has to be identity.
assert(isequal(W, speye(size(W))), ...
    'calc_lambda_regtools: only an identity measurement weighting is supported');

% ---- difference data -----------------------------------------------------
data_width = max(num_frames(vi), num_frames(vh));
vi = filt_data(imdl, vi, data_width);
vh = filt_data(imdl, vh, data_width);
B = vh - vi;

% Count frames on the filtered data: the raw inputs may be struct arrays or
% a single frame that filt_data has just replicated.
nFrames = size(B, 2);
lambdas = nan(nFrames, 1);

% ---- standard form and its SVD -------------------------------------------
% Only A_s is needed; the right-hand side is a placeholder that merely has
% to have as many rows as A.
[A_s, ~, ~] = std_form(A, L, nan(size(A, 1), 1));
[U_s, s_s] = csvd(A_s);

progress_msg('Calculating lambda for each frame:', 0, nFrames);

if doPlot
    figure();
end

for iFrame = 1:nFrames
    progress_msg(iFrame, nFrames);
    b = B(:, iFrame);
    if useGcv
        lambdas(iFrame) = gcv(U_s, s_s, b);
    else
        lambdas(iFrame) = l_curve(U_s, s_s, b);
    end
end

% Cross-check for the last frame: overlay points computed with get_RM.
if doPlot && nFrames > 0
    if useGcv
        plot_gcv_sweep(imdl, A, B(:, nFrames));
    else
        plot_lcurve_sweep(imdl, A, L, B(:, nFrames));
    end
end

progress_msg('Calculating lambda for each frame:', inf);

end


function lams = hyperparameter_sweep(lInit)
% Ten log-spaced values from 1e3*lInit down to 1e-3*lInit, followed by lInit.
lams = [flip(logspace(log10(lInit*1E-3), log10(lInit*1E3), 10)), lInit];
end


function plot_lcurve_sweep(imdl, A, L, b)
% Draw (residual norm, prior norm) pairs obtained with get_RM for a sweep of
% hyperparameters; the model's own hyperparameter is drawn in green.
lams = hyperparameter_sweep(imdl.hyperparameter.value);
rho = zeros(1, numel(lams));
eta = zeros(1, numel(lams));
for i = 1:numel(lams)
    imdl.hyperparameter.value = lams(i);
    x = get_RM(imdl) * b;
    rho(i) = norm(A*x - b);
    eta(i) = norm(L*x);
end
hold on;
loglog(rho(1:end-1), eta(1:end-1), 'ob');
loglog(rho(end), eta(end), 'og');
end


function plot_gcv_sweep(imdl, A, b)
% Draw the GCV function value obtained with get_RM for a sweep of
% hyperparameters; the model's own hyperparameter is drawn in green.
lams = hyperparameter_sweep(imdl.hyperparameter.value);
G = zeros(1, numel(lams));
for i = 1:numel(lams)
    imdl.hyperparameter.value = lams(i);
    RM = get_RM(imdl);
    rho2 = norm(A*(RM*b) - b)^2;
    G(i) = rho2 / (trace(eye(size(RM,2)) - A*RM)^2);
end
hold on;
loglog(lams(1:end-1), G(1:end-1), 'ob');
loglog(lams(end), G(end), 'og');
end
0198
0199
function nf= num_frames(d0)
% NUM_FRAMES  Number of measurement frames contained in d0.
%   d0 is either a numeric matrix (one frame per column) or an array of
%   structs whose 'type' is 'data'; for structs the 'meas' fields are
%   concatenated horizontally and the resulting columns are counted.
%   Any other input raises an error.
   if isnumeric( d0 )
      nf= size(d0,2);
   elseif isstruct(d0) && ~isempty(d0) && isfield(d0,'type') && ...
          strcmp( d0(1).type, 'data' )
      % strcmp rather than ==: an element-wise char comparison fails for
      % type strings whose length differs from that of 'data'.
      nf= size( horzcat( d0(:).meas ), 2);
   else
      error('Problem calculating number of frames. Expecting numeric or data object');
   end
end
0209
0210
function d2= filt_data(inv_model, d0, data_width )
% FILT_DATA  Convert measurement input to a double matrix of selected rows.
%   d0 is a numeric matrix or an array of structs of type 'data' whose
%   'meas' fields are concatenated column-wise. If
%   inv_model.fwd_model.meas_select is present and non-empty, rows are
%   picked with it when the row count equals its length; data whose row
%   count already equals the number of selected entries is passed through.
%   When data_width is supplied, a single column is replicated to that
%   width; any other width mismatch is an error.

   % --- gather the raw measurements ---
   if isnumeric( d0 )
      raw = d0;
   else
      raw = [];
      for k = 1:length(d0)
         if ~strcmp( d0(k).type, 'data' )
            error('expecting an object of type data');
         end
         raw = [raw, d0(k).meas]; %#ok<AGROW>
      end
   end
   raw = double(raw);

   % --- apply the measurement selection, if the model defines one ---
   fmdl = inv_model.fwd_model;
   d2 = raw;
   if isfield(fmdl,'meas_select') && ~isempty(fmdl.meas_select)
      sel = fmdl.meas_select;
      n_rows = size(raw,1);
      if n_rows == length(sel)
         d2 = raw(sel,:);
      elseif n_rows ~= sum(sel==1)
         error('inconsistent difference data: (%d ~= %d). Maybe check fwd_model.meas_select', ...
               n_rows, length(sel));
      end
   end

   % --- widen a single frame to the requested number of columns ---
   if nargin==3
      n_cols = size(d2,2);
      if n_cols ~= data_width
         if n_cols ~= 1
            error('inconsistent difference data: (%d ~= %d)', ...
                  n_cols, data_width);
         end
         d2 = repmat(d2, 1, data_width);
      end
   end
end
0260
function doUnitTest()
% DOUNITTEST  Demo / smoke test for calc_lambda_regtools.
%   Loads v_rotate and v_reference from iirc_data_2006, computes per-frame
%   hyperparameters with both the L-curve and the GCV criterion (with
%   plotting enabled) and reconstructs four frames using the median of
%   each set of values.

   % Load into a struct so the variables used below are explicit.
   data = load('iirc_data_2006');

   stim = mk_stim_patterns(16,1,[0,1],[0,1],{'meas_current'},1);

   % iRun selects the model: non-zero -> library thorax model with a
   % Gauss-Newton solver, zero -> 'c2c' common model with Tikhonov prior.
   for iRun = 1

      if iRun
         fmdl = mk_library_model('adult_male_16el');
         fmdl.stimulation = stim;
         opts = [];
         opts.noise_figure = 0.5;
         imdl = mk_GN_model(fmdl, opts, []);
      else
         imdl= mk_common_model('c2c');
         imdl.fwd_model.stimulation = stim;
         imdl.fwd_model = rmfield( imdl.fwd_model, 'meas_select');
         imdl.RtR_prior = @prior_tikhonov;
      end

      vi = double(real(data.v_rotate)/1e4);
      vh = double(real(data.v_reference)/1e4);

      % Each call opens its own figure when plotting is enabled, so no
      % figure is created here.
      lambdas_lcc = calc_lambda_regtools(imdl,vh,vi,'LCC',true);
      lambdas_gcv = calc_lambda_regtools(imdl,vh,vi,'GCV',true);

      FramesOfInterest = [10 35 60 85];

      % Use a fresh figure: a fixed figure number could coincide with one
      % of the diagnostic figures created above and overwrite it.
      figure();

      subplot(121);
      imdl.hyperparameter.value = median(lambdas_lcc);
      imgs_lcc = inv_solve(imdl, vh, vi(:,FramesOfInterest));
      imgs_lcc.show_slices.img_cols = 1;
      show_slices(imgs_lcc);
      title('L-curve');

      subplot(122);
      imdl.hyperparameter.value = median(lambdas_gcv);
      imgs_gcv = inv_solve(imdl, vh, vi(:,FramesOfInterest));
      imgs_gcv.show_slices.img_cols = 1;
      show_slices(imgs_gcv);
      title('GCV');

   end

end