function model = ahimbu(X, options)
% AHIMBU Build a hierarchy of PLS-DA models for multi-class classification
%
% Usage:
%  model = AHIMBU(X, options)
%  options = AHIMBU('options')
%
% Classes that are hard to tell apart in cross-validation are merged
% pairwise until at most two groups remain or every pair separates
% perfectly. A PLS-DA decision rule is fitted on the merged groups, and
% every group that still contains several original classes is handled by
% a recursive call to AHIMBU on the samples of that group.
%
% Arguments:
%  X: a dataset object
%  options: a structure; missing fields are filled in from the defaults:
%   - preprocessing: a two-element cell array of preprocessing objects,
%     the first for the predictors, the second for the class indicator
%     matrix.
%     Default: mean centering for both.
%   - maxlv: integer upper bound on the number of PLS latent variables
%     tried for each model.
%     Default: 6.
%   - classset: index of the row class set of X that is to be predicted.
%     Default: 1.
%   - preserve_class_labels: when true, the leaves of the model return
%     the string label of the class (or zero) rather than the integer
%     class number.
%     Default: false.
%   - use_choosecomp: when true, choosecomp() is asked for the number of
%     components; when it gives no suggestion (e.g. when maxlv < 7), or
%     when this option is false, the number with the smallest sum of
%     false-positive and false-negative rates is used.
%     Default: false.
%   - cvi: cross-validation method, passed on as the "cvi" argument of
%     crossval().
%     Default: random subsets, 3 subsets and 20 replicates.
%   - verbose: when true, progress information is printed while the
%     model is being built.
%     Default: true.
%
% Returns:
%  model: an evrimodel 'modelselector' object
%
% See also: PLSDA, PREPROCESS, CHOOSECOMP, CROSSVAL

% The caller only wants to see the default options
if nargin == 1 && ischar(X) && strcmp(X, 'options')
    model = defaults();
    return
end

% Complete the caller's options with defaults, or take the defaults as-is
if nargin >= 2
    options = reconopts(options, defaults());
else
    options = defaults();
end

labels = X.class{1,options.classset};

verbose(options, ...
    'size(X) = (%d,%d); %d unique labels', ...
    size(X,1), size(X,2), length(unique(labels)) ...
)

% Start from the original labels; a positive dummy "confusion" value lets
% the loop condition pass on its first evaluation
merged = labels;
confusion = 1;
Xcur = X;
% Keep merging the worst pair while more than two groups remain and some
% pair is still being confused
while length(unique(merged)) > 2 && max(confusion(:)) > 0
    [confusion, Xcur, merged] = mergethem(Xcur, merged, options);
    verbose(options, ...
        '%dx%d confusion matrix, max = %g', ...
        size(confusion, 1), size(confusion, 2), max(confusion(:)) ...
    );
end

% modelselector expects its targets in order of increasing class number
groups = sort(unique(merged));
ngroups = length(groups);

% One target per merged group plus a trailing "otherwise" target (zero)
targets = cell(ngroups + 1, 1);
targets{ngroups + 1} = 0;

for g = 1:ngroups
    % Samples in this merged group and the original classes they carry
    members = merged == groups(g);
    original = unique(labels(members));
    noriginal = length(original);

    verbose(options, ...
        'new class %d contains %d class(es),', groups(g), noriginal ...
    );
    if noriginal ~= 1
        % Several original classes hide in this group: recurse on it
        verbose(options, 'calling myself to split/merge it again')
        targets{g} = ahimbu(X(members, :), options);
    elseif options.preserve_class_labels
        % A single original class: the target is its string label.
        % find() gives back a cell array, so unwrap the string.
        verbose(options, 'so we are done with it')
        found = X.classlookup{1,options.classset}.find(original);
        targets{g} = found{1};
    else
        % A single original class: the target is its class number
        verbose(options, 'so we are done with it')
        targets{g} = original;
    end
end

% The PLS-DA model that routes a sample to one of the targets
[~, ~, trigger] = optiplsda(X, merged, options);
verbose(options, 'chose %d components for the decision rule', trigger.ncomp);

model = modelselector(trigger, targets{:});
end

% DEFAULTS Build the options structure AHIMBU uses when none (or only a
% partial one) is supplied
function opts = defaults()
    % Cell-valued fields are wrapped in an extra pair of braces so that
    % struct() stores the cell itself instead of creating a struct array
    opts = struct( ...
        'preprocessing', {{preprocess('meancenter'), preprocess('meancenter')}}, ...
        'maxlv', 6, ...
        'classset', 1, ...
        'preserve_class_labels', false, ...
        'verbose', true, ...
        'use_choosecomp', false, ...
        'cvi', {{'rnd', 3, 20}} ...
    );
end

function [RES,A,labels] = mergethem(A, labels, opts)
% MERGETHEM Fuse the two classes that are confused most in cross-validation
%
% Usage: [RES, A, labels] = mergethem(A, labels, opts)
%
% A: a dataset object
% labels: class labels to predict
% opts: the AHIMBU options structure
%
% RES: pairwise misclassification matrix, as computed by testall
% A: the dataset that was passed in (returned as received)
% labels: the labels with the worst-separated pair relabelled to a single
%         class (the one belonging to the row of the maximum)
nclasses = length(unique(labels));
verbose(opts, '%d classes', nclasses);
assert(nclasses > 2, 'Need at least three different classes');

RES = testall(A, labels, opts);
verbose(opts, '%dx%d confusion matrix', size(RES, 1), size(RES, 2))

% Locate the largest pairwise error; on ties the first hit is used
worst = max(RES(:));
[rows, cols] = find(RES == worst);
r = rows(1);
c = cols(1);

% Class numbers that index the rows/columns of RES. Class zero is dropped:
% per the original author's note it stands for samples without a class.
% NOTE(review): testall builds RES from unique(labels) without dropping
% zero, so labels that do contain zero end up in the error below.
present = unique(labels);
present = present(present ~= 0);
if length(present) ~= size(RES, 2)
    error('Something wrong')
end

% Relabel every sample of either class with the class of the row index
tomerge = (labels == present(r)) | (labels == present(c));
labels(tomerge) = present(r);
end

function [misclassed, ncomp, model] = optiplsda(X, y, opts)
% OPTIPLSDA Cross-validate a PLS-DA model
%
% Usage: [misclassed, ncomp, model] = OPTIPLSDA(X, y, opts)
%
% X: a dataset
% y: a vector of classes to predict
% opts: the AHIMBU options structure; the fields preprocessing,
%       use_choosecomp, cvi and maxlv are read here
%
% misclassed: a cell array containing false positive and false negative
%             rates per the number of latent variables for each class
% ncomp: the number of components deemed optimal
% model: the PLS-DA model with the optimal number of components; it is
%        only fitted when the third output is requested
%
% Depending on the options, the number of components is chosen using the
% choosecomp function, or according to minimal sum of false
% positive and false negative rate in cross-validation. The latter is
% also the fallback when choosecomp returns an empty suggestion.
%
% See also: plsda, choosecomp

% {opts.preprocessing} wraps the cell array once more so that struct()
% stores it as a single field value
plsopts = struct( ...
    'display', 'off', 'plots', 'none', ...
    'preprocessing', {opts.preprocessing} ...
);
ncomp = [];
if opts.use_choosecomp
    % NOTE: choosecomp() only works on evrimodel objects, which requires
    % us to (1) fit it first and then (2) cross-validate it. Later,
    % when we may want to return the optimal model, we have to re-fit it
    % again.
    model = plsda(X, y, 1, plsopts);
    % NOTE: we can't use model.crossvalidate() here because we need to
    % pass a "y" argument separate from the class labels in X. There is
    % no documented way to do that with a method call, and the
    % undocumented way ignores the "y" argument.
    model = crossval(X, y, model, opts.cvi, opts.maxlv, plsopts);

    % remember the cross-validation results because we need to change the
    % number of components by re-creating the model
    misclassed = model.detail.misclassedcv;

    % may come back empty, in which case the fallback below is used
    ncomp = choosecomp(model);
else
    cvopts = struct( ...
        'preprocessing', { opts.preprocessing }, ...
        'plots', 'none', 'discrim', 'yes', 'display', 'no' ...
    );
    res = crossval(X, y(:), 'sim', opts.cvi, opts.maxlv, cvopts);
    misclassed = res.misclassed;
end

if isempty(ncomp)
    % FPR+FNR per number of LVs, one row vector per class. The dimension
    % is given explicitly: a bare sum() works along the first
    % non-singleton dimension and would collapse a single row to a scalar.
    per_class = cellfun( ...
        @(rates) sum(rates, 1), misclassed(:), 'UniformOutput', false ...
    );
    % Stack the classes as rows and add them up per number of LVs; again
    % along dimension 1 so that a single class still yields one value per
    % number of LVs rather than one grand total.
    total = sum(cell2mat(per_class), 1);
    % Take the number of LVs with the smallest total (first one on ties)
    [~, ncomp] = min(total);
end

% recreate the model from scratch because we can't just change model.ncomp
if nargout >= 3
    model = plsda(X, y, ncomp, plsopts);
end
end

function RES = testall(A,labels,opts)
% TESTALL Cross-validated PLS-DA error for every pair of classes
%
% Usage: RES = TESTALL(A, labels, opts)
%
% A: a dataset
% labels: a vector of integer class labels
% opts: the AHIMBU options structure
%
% RES: a symmetric matrix with a zero diagonal; entry (i,j) is the
%      misclassification error of the two-class model for the i-th and
%      j-th value of unique(labels)
%
% See also: plsda, optiplsda

classes = unique(labels);
nclasses = length(classes);

RES = zeros(nclasses);
for i = 1:nclasses-1
    for j = i+1:nclasses
        % Restrict the data to the samples of the two classes at hand
        pair = labels == classes(i) | labels == classes(j);
        [rates, nlv] = optiplsda(A(pair,:), labels(pair), opts);
        % rates{1} belongs to the first class of the pair; its rows are
        % the false positive and false negative rate (row 2 being the
        % false negative rate, per the original author's note), its
        % columns the number of components. Only the column picked by
        % optiplsda is summed up.
        pair_error = sum(rates{1}(:, nlv));
        RES(i,j) = pair_error;
        RES(j,i) = pair_error;
    end
end
end

% VERBOSE Pass the arguments to FPRINTF when o.verbose is true; the line is
% prefixed with the name of the calling function and ends with a newline
function verbose(o, format, varargin)
    if ~o.verbose
        return
    end
    % Skip our own stack frame so that the first entry is the caller
    callers = dbstack(1);
    template = ['%s(): ', format, '\n'];
    fprintf(template, callers(1).name, varargin{:});
end
