function [Factors,Diagnostics] = PMF3(varargin)
% Fit a PARAFAC model to a three-way array using a damped Gauss-Newton (Levenberg-Marquardt) algorithm
%
% SYNTAX
% [Factors,Diag] = PMF3(X,F,Options,A,B,C);
% 
% INPUTS
% X      : data array
% F      : model's rank
% Options: see ParOptions
% A,B,C  : initial estimates for the loading matrices
%
% OUTPUTS
% Factors: cell vectors with the loading matrices (NB Sign permutation is
%          not fixed according to any convention)
% Diag   : structure with some diagnostics
%          Diag.fit           : residual sum of squares of the returned model
%          Diag.it(1)         : total number of iterations
%          Diag.it(2)         : number of outer iterations (Hessian/gradient computations)
%          Diag.it(3)         : number of iterations where the nonlinear
%                               update is preferred to the Levenberg-Marquardt one
%          Diag.it(4)         : not used
%          Diag.convergence(1): relative fit decrease
%          Diag.convergence(2): relative change in the parameter's norm
%          Diag.convergence(3): relative loss function value lower than machine precision
%          Diag.convergence(4): gradient's infinite norm
%          Diag.convergence(5): not used
%          Diag.convergence(6): max number of iterations reached
%          Diag.convergence(7): not used
%
% Author: Giorgio Tomasi 
%         Royal Agricultural and Veterinary University 
%         Rolighedsvej 30 
%         DK-1958 Frederiksberg C 
%         Denmark 
% 
% Last modified: 10-Jan-2005 15:51
% 
% Contact: Giorgio Tomasi, gt@kvl.dk; Rasmus Bro, rb@kvl.dk 
% 
% References: "A weighted non-negative least squares algorithms for the 3-way PARAFAC factor analysis" 
%             P. Paatero, Chemometrics Intelligent Laboratory Systems 38, 223 (1997).
%             "A comparison of algorithms for fitting the PARAFAC model"
%             G. Tomasi, R. Bro, Computational Statistics and Data Analysis, in press.
%

% Check for minimal input
if ~nargin
    help(mfilename)                                         % Case-safe: shows the help of this very file
    return
end

%Check input values
[X,F,Options,Factors] = Check_GenParafac_Input(varargin{:});
if isempty(F)
   [Factors,Diagnostics] = deal([]);
   return
end

% Compute initial estimates if not given
if isempty(Factors)                                         
    [Factors{1:ndims(X)}] = InitPar(X,F,Options);
end

%Some initial values
ConvCrit       = false(7,1);                                            % Convergence criteria
it             = zeros(4,1);                                            % Number of iterations [total outer(H and g computations) nonlinear_updates not_used]
FactorsNew     = Factors;                                               % Model parameters new estimates 
dimX           = size(X);                                               % X dimensions vector
SSX            = tss(X,false);                                          % Total sum of squares
ConvGrad       = Options.convcrit.grad;                                 % Convergence criteria: gradient infinite norm
ConvRelFit     = Options.convcrit.relfit;                               %                       relative fit decrease
ConvPar        = Options.convcrit.par;                                  %                       relative parameter change
ConvFit        = Options.convcrit.fit;                                  %                       loss function/total sum of squares
ConvMaxIt      = Options.convcrit.maxiter;                              %                       max n of iterations
LamHistory     = zeros(ConvMaxIt,1);                                    % History of the damping parameter, used as a diagnostic
Gamma          = Options.regularisation.gamma;                          % Regularisation term
GammaUpdate    = Options.regularisation.n_updates;                      % Number of updates for the regularisation factor
GammaIt        = Options.regularisation.iter;                           % N of counted iterations required to decrease Gamma
RegSteps       = 10.^(linspace(-1,log10(ConvRelFit * 5),GammaUpdate));  % Thresholds for updating the regularisation factor
RegItCount     = 0;                                                     % N of accepted iterations for which the relative loss decrease 
                                                                        % is lower than the current threshold (reset only when Gamma is updated)
Fit_new        = tss(X - nmodel(Factors),false);                        % Initial least squares loss function value
[Factors{1:3}] = scale_factors(false,Factors{1:3});                     % Scaling affects the penalty value
Penalty_new    = Gamma * tss(cat(1,Factors{:}),false);                  % Initial penalty value
LPMF3_new      = Fit_new + Penalty_new;                                 % Initial Loss function value

if strcmpi(Options.diagnostics,'on')
    FitHistory  = zeros(ConvMaxIt,1);                                % Initialise history of least squares loss function values
    LossHistory = zeros(ConvMaxIt,1);                                % Initialise history of loss function values
end

if ~isequal(Options.display,'none')
    fprintf('\n\n Alpha                Fit           It      EV%%      Lambda             Gamma              Max Gr\n')
end

% Warnings are switched off inside the loop (bad scaling is detected via lastwarn);
% remember the caller's settings so that they can be restored afterwards
WarnState = warning;

%Start fitting
while ~any(ConvCrit)                                        % Start the outer loop
    
    % Outer loop: iterations that require a new computation of H and g

    it(2)                 = it(2) + 1;                      % Update the # iterations for the outer loop 
    [Factors{1:ndims(X)}] = scale_factors(0,Factors{:});    % Scale loadings vectors to equal norm
    FactorsOld            = Factors;                        % Save old factors
    Fit_old               = Fit_new;                        % Save old least squares fit
    LPMF3_old             = LPMF3_new;                      % Save old loss function value
    p                     = vec(cat(1,Factors{:})');        % Vector of parameter estimates
    NormPar               = norm(p);                        % Norm of the vector of parameters
    if ~it(1)
        [g,JtJ,Lambda] = ParafacDer(Factors,X);             % Calculate Hessian and Gradient
        Lambda         = Lambda * Options.lambdainit;       % initialise the damping parameter
    else
        [g,JtJ]        = ParafacDer(Factors,X);             % Calculate Hessian and Gradient
    end
    LamHistory(it(2)) = Lambda;                             % Store the values of the damping parameter to display diagnostic
    if strcmpi(Options.diagnostics,'on')                    % Save some diagnostics
        FitHistory(it(2))   = Fit_old;                      % Store fit values for accepted steps
        LossHistory(it(2))  = LPMF3_old;                    % Store loss function values for accepted steps
    end
    Norm_g        = max(abs(g));                            % Gradient infinite norm
    ConvCrit(4,1) = ConvGrad >= Norm_g;                     % Stop if gradient's infinite norm is smaller than a predefined criterion
    Do_It         = true;                                   % Do inner loop; stays true until a step is accepted
    
    while Do_It && ~any(ConvCrit)                           % Start the inner loop
        
        % Inner loop: if the step is rejected, only the damping parameter is
        %             updated. There is no need to recompute Hessian and gradient
        
        Alpha  = 1;                                         % Initialise step length
        it(1)  = it(1) + 1;                                 % Update # iterations             
        Psi    = JtJ + (Lambda + Gamma) * eye(size(JtJ,1)); % Compute left hand side of the modified normal equations               
        Xsi    = g - Gamma * p;                             % Compute right hand side of the modified normal equations
        
        % Solve the system of modified normal equations using Cholesky factorisation
        [Psi,CFlag] = chol(Psi);                            % Compute the Cholesky factor of Psi   
        if ~CFlag                                           % Psi    is positive definite -> a descent direction can be computed
            
            warning off                                     % Avoid displaying the warning on bad scaling
            lastwarn('')
            deltap = Psi\(Psi'\Xsi);                        % Compute update calculated by back substitution
            if isempty(lastwarn)                            % The matrix is nicely scaled
                
                p_new        = p + deltap;                  % Compute updated parameters

                % Change format of model parameters' estimates to cell
                Count        = 0;
                for m = 1:length(dimX)
                    FactorsNew{m} = reshape(p_new(Count + 1:Count + dimX(m) * F),F,dimX(m))';
                    Count         = Count + dimX(m) * F;
                end
                [FactorsNew{1:3}] = scale_factors(false,FactorsNew{1:3}); % Adjust scaling as it influences the penalty
                p_new             = vec(cat(1,FactorsNew{:})');
                
                % Compute Loss Function value
                Fit_new     = tss(X - nmodel(FactorsNew),false);
                Penalty_new = Gamma * tss(p_new,false);
                LPMF3_new   = Fit_new + Penalty_new; 
                
                % Call line search if the loss function does not decrease
                % Lambda update scheme [cf. P. Paatero, Chemometr. Intell. Lab. Syst. 38, 223 (1997)].
                Do_It  = false;                               % Accept update and exit inner loop (default)
                if LPMF3_old < LPMF3_new
                    
                    [Alpha,LPMF3_new,Fit_new] = LineSearch_PMF3(X,p,deltap,Gamma,LPMF3_old,LPMF3_new,Fit_old,Fit_new);
                    Lambda                    = Lambda * Options.lambdaudpar(1);
                    if ~Alpha
                        Do_It = true;                           % If the line search does not decrease the loss
                                                                % function, reject step and update the damping parameter
                    end                        
                    
                else
                    Lambda = Lambda * Options.lambdaudpar(2);   % Decrease the damping parameter
                end
                                
                NonlinearUpdate = false;                        % Flag indicating preference of the nonlinear update
                if ~Do_It                                       % The step was not rejected by the line search
                    
                    %Compute the non-linear update for the parameters
                    Alpha_NLU = 1;
                    p_NLU     = p + 0.5 * deltap;
                    
                    % Change format to cell
                    Count      = 0;
                    for m = 1:length(dimX)
                        Factors_NLU{m} = reshape(p_NLU(Count + 1:Count + dimX(m) * F),F,dimX(m))';
                        Count          = Count + dimX(m) * F;
                    end
                    
                    % Compute nonlinear update
                    Xsi_NLU    = ParafacDer(Factors_NLU,X) - Gamma * p; % Recompute right-hand side
                    deltap_NLU = Psi \ (Psi' \ Xsi_NLU);                % Compute nonlinear update calculated by back substitution
                    p_NLU      = p + deltap_NLU;                        % Update parameter estimates
                    
                    % Change format to cell
                    Count      = 0;
                    for m = 1:length(dimX)
                        Factors_NLU{m} = reshape(p_NLU(Count + 1:Count + dimX(m) * F),F,dimX(m))';
                        Count          = Count + dimX(m) * F;
                    end
                    [Factors_NLU{1:3}] = scale_factors(false,Factors_NLU{1:3});             % Adjust scaling as it influences the penalty
                    p_NLU              = vec(cat(1,Factors_NLU{:})');
                    Fit_NLU            = tss(X - nmodel(Factors_NLU),false);
                    Penalty_NLU        = Gamma * tss(p_NLU,false);
                    LPMF3_NLU          = Fit_NLU + Penalty_NLU; % Compute loss function value
                    
                    % Call line search if the loss function does not decrease
                    if LPMF3_old < LPMF3_NLU
                        [Alpha_NLU,LPMF3_NLU,Fit_NLU] = LineSearch_PMF3(X,p,deltap_NLU,Gamma,LPMF3_old,LPMF3_NLU,Fit_old,Fit_NLU);
                    end
                    
                    % Check whether the nonlinear update further improves the loss function  
                    if LPMF3_NLU < LPMF3_new
                        
                        % Nonlinear update is better than the standard dGN update
                        % (the dGN step was accepted above, so LPMF3_new is the value to beat)
                        Alpha           = Alpha_NLU;
                        deltap          = deltap_NLU;
                        LPMF3_new       = LPMF3_NLU;
                        Fit_new         = Fit_NLU;
                        it(3)           = it(3) + 1;
                        NonlinearUpdate = true;
                        
                    end
                    %Update regularisation term
                    if abs(LPMF3_old - LPMF3_new)/LPMF3_old < RegSteps(1)
                        RegItCount = RegItCount + 1;
                    end
                    if RegItCount >= GammaIt && length(RegSteps) ~= 1
                        %If the number of counted iterations with a small decrease of the
                        %loss function reaches Options.regularisation.iter and the step has
                        %been accepted, reduce Gamma
                        RegItCount = 0;
                        RegSteps   = RegSteps(min(length(RegSteps),2):end);
                        Gamma      = Gamma / Options.regularisation.gammaupdate;
                    end
                    
                end
                
                %Check convergence
                if ~LPMF3_new
                    ConvCrit(1,1) = true ;                                              % In case model fit the data exactly
                else
                    ConvCrit(1,1) = ConvRelFit >= abs(LPMF3_old - LPMF3_new)/LPMF3_old; % Relative fit decrease
                end
                ConvCrit(2,1) = ConvPar >= norm(deltap)/NormPar;                % Relative change in the parameters
                
            else
                Lambda = Lambda * Options.lambdaudpar(1);                       % Increase damping parameter if Psi is nearly singular
            end
            
        else
            Lambda = Lambda * Options.lambdaudpar(1);                           % Increase the damping parameter because Psi is singular to 
                                                                                % machine precision 
        end
        ConvCrit(6,1) = it(1) == ConvMaxIt;                                     % Max. number of iterations
        
    end % End of the inner loop

    if ~Do_It
        
        % The last inner iteration accepted a step: apply it
        p     = p + Alpha * deltap;                                             % Compute updated parameters
        Count = 0;
        for m = 1:length(dimX)
            Factors{m} = reshape(p(Count + 1:Count + dimX(m) * F),F,dimX(m))';
            Count      = Count + dimX(m) * F;
        end
        
    else
        
        % No step was accepted in this pass (inner loop skipped because the gradient
        % criterion was already met, or it was left on a rejected/failed step because
        % a stopping criterion fired). Alpha and deltap are then undefined or stale:
        % keep the current factors and the loss values that belong to them.
        Fit_new   = Fit_old;
        LPMF3_new = LPMF3_old;
        
    end
    ConvCrit(3,1) = ConvFit  >= Fit_new / SSX;                                % Check if least squares loss function is not too small compared to Frobenius norm of X
    
    % Display some information on iterations 
    if ~isequal(Options.display,'none') && ~rem(it(1),Options.display) && ~any(ConvCrit)
        DisplayIt_PMF3(it(1),Alpha,LPMF3_new,Fit_new,SSX,LamHistory(it(2)),Gamma,Norm_g,NonlinearUpdate,Options)
    end
    
end % End of the outer loop

% Restore the warning settings that were active on entry
warning(WarnState)

% Scale the factors according to the common convention
[Factors{end:-1:1}] = scale_factors(1,Factors{end:-1:1});

% Sort the factors according to their norm (in decreasing order)
[nil,Seq]             = sort(-sum(Factors{1}.^2));
[Factors{1:ndims(X)}] = FacPerm(Seq,Factors{:});

if nargout > 1
    Diagnostics             = struct('fit',tss(X - nmodel(Factors),false),'it',it,'convergence',ConvCrit);    
end
if strcmpi(Options.diagnostics,'on')  
    
    % Display some details about convergence
    ConvMsg = {'Relative fit decrease'
        'Parameters'' update'
        'Loss function value of less than machine precision'
        'Gradient equal to zero'
        ''
        'Max number of iterations reached'};
    ConvMsg = char(ConvMsg(ConvCrit(1:length(ConvMsg))));                       % One row per criterion that was met
    fprintf('\n The algorithm has converged after %i iterations (%i Hessian computations)',it(1:2))
    fprintf('\n Met convergence criteria: %s',ConvMsg(1,:))
    for i_msg = 2:size(ConvMsg,1)                                               % List every additional criterion that was met
        fprintf('\n                           %s',ConvMsg(i_msg,:))
    end   
    fprintf('\n')
    
    figure('name','Diagnostics','number','off')
    %Show the damping parameter and the fit history
    ax = plotyy(1:it(2),LamHistory(1:it(2)),1:it(2),LossHistory(1:it(2)),@semilogy);
    axes(ax(1))
    ylabel('\lambda')
    xlabel('# it')
    axes(ax(2))
    ylabel('L_{PMF3}(\bfA\rm,\bfB\rm,\bfC\rm)')
    title('Lambda and loss function history')
    drawnow
    
end

%----------------------------------------------------------------------------------------------------------------------

function DisplayIt_PMF3(It,Alpha,Loss,Fit,SSX,LambdaOld,Gamma,Gr,NNUD,Options)
% Print one line of iteration information, aligned with the table header
% written by the calling function.
%
% It       : iteration number
% Alpha    : step length
% Loss     : loss function value (received but not printed)
% Fit      : least squares loss function value
% SSX      : total sum of squares, used for the explained variation in percent
% LambdaOld: damping parameter
% Gamma    : regularisation term
% Gr       : gradient norm
% NNUD     : if true the line is flagged with a leading '* '
% Options  : only Options.convcrit.maxiter is read (width of the iteration column)

% Right-align a string in a field of given width; longer strings are left untouched
PadLeft = @(Str,Width) [repmat(' ',1,Width - length(Str)),Str];

% Marker in front of the step length
if NNUD
   Marker = '* ';
else
   Marker = '  ';
end

% Step length: left-aligned in 6 characters, preceded by the marker
StepTxt = num2str(Alpha,'%1.4g');
StepTxt = [Marker,StepTxt,repmat(' ',1,6 - length(StepTxt))];

% Remaining columns are right-aligned
FitTxt  = PadLeft(sprintf('%12.10f',Fit),22);
ItTxt   = PadLeft(num2str(It),length(num2str(Options.convcrit.maxiter)));
VarTxt  = PadLeft(sprintf('%2.4f',100*(1-Fit/SSX)),7);
LamTxt  = PadLeft(sprintf('%12.4f',LambdaOld),17);
GamTxt  = PadLeft(sprintf('%1.3e',Gamma),5);
GrTxt   = PadLeft(sprintf('%8.4f',Gr),15);

% Assemble and print the line
fprintf([' ',StepTxt,'  ',FitTxt,'  ',ItTxt,'  ',VarTxt,'  ',LamTxt,'   ',GamTxt,'   ',GrTxt]);
fprintf('\n');

%----------------------------------------------------------------------------------------------------------------------

function [AlphaMin,Q_new,Fit_new] = LineSearch_PMF3(X,p,deltap,Gamma,Q_old,Q_new,Fit_old,Fit_new)
% Line search along the direction deltap starting from p.
%
% It is called only if Alpha == 1 leads to an increased loss function,
% i.e. Q_new > Q_old (regx(3,4) > regx(1,4) below).
%
% INPUTS
% X      : data array
% p      : current parameter vector
% deltap : update direction
% Gamma  : regularisation term (weight of the penalty in the loss function)
% Q_old  : loss function value at Alpha = 0
% Q_new  : loss function value at Alpha = 1
% Fit_old: least squares loss function value at Alpha = 0
% Fit_new: least squares loss function value at Alpha = 1
%
% OUTPUTS
% AlphaMin: step length with the lowest loss function value among the points tried;
%           0 if 10 successive shortenings of the step did not bring the loss
%           function down to Q_old (the step is to be rejected by the caller)
% Q_new   : loss function value at AlphaMin (unchanged input value if the step is rejected)
% Fit_new : least squares loss function value at AlphaMin (unchanged input value if
%           the step is rejected)

MaxPoints = 13;                                             % 3 initial points (Alpha = 0, 0.5, 1) + at most 10 shortened steps
Alpha     = 0.5;                                            % Initialise step length
dimX      = size(X);
F         = reshape(p,length(p)/sum(dimX),sum(dimX))';      % Factors
U         = reshape(deltap,length(p)/sum(dimX),sum(dimX))'; % Update
cdim      = cumsum(dimX);                                   % Index of loading matrices in F and U

% Table of tried points, one per row: [1, Alpha, Alpha^2, loss function, least squares loss function].
% Unused rows carry an infinite loss so that they are sorted last.
% 12 spare rows = 10 shortened steps + midpoint + minimum of the quadratic model
regx      = [0;0.5;1];
regx      = [[ones(3,1),regx,regx.^2,[Q_old;0;Q_new],[Fit_old;0;Fit_new]];[zeros(12,3),inf * ones(12,2)]];

% Loss function with Alpha = 0.5
[Q,LSFit] = LossAtStep_PMF3(X,F,U,Alpha,Gamma,dimX,cdim);
regx(2,4) = Q;
regx(2,5) = LSFit;

Count = 3;
if regx(2,4) > Q_old
    
    % Half the step is still too long and needs further shortening
    while Q > regx(1,4) && Count < MaxPoints
        
        Count           = Count + 1;
        Alpha           = Alpha * 0.6;
        [Q,LSFit]       = LossAtStep_PMF3(X,F,U,Alpha,Gamma,dimX,cdim);
        regx(Count,5)   = LSFit;
        regx(Count,4)   = Q;
        regx(Count,1:3) = [1,Alpha,Alpha.^2];
        
    end
    if Count == MaxPoints && Q > regx(1,4)
        % All the shortened steps still increase the loss function: reject the step
        AlphaMin = 0;
        return
    end
    
end

% Sort points according to the corresponding loss function values
[nil,t] = sort(regx(:,4));
regx    = regx(t,:);

% Add one point between the two smallest fits
Count           = Count + 1;
Alpha           = (regx(1,2) + regx(2,2))/2;
[Q,LSFit]       = LossAtStep_PMF3(X,F,U,Alpha,Gamma,dimX,cdim);
regx(Count,5)   = LSFit;
regx(Count,4)   = Q;
regx(Count,1:3) = [1,Alpha,Alpha.^2];

% Sort points according to the corresponding loss function values
[nil,t] = sort(regx(:,4));
regx    = regx(t,:);

% Fit a quadratic model of loss function
reg             = regx(1:Count,1:3)\regx(1:Count,4);
Count           = Count + 1;

if reg(3) == 0 || any(~isfinite(reg))
    regx(Count,4) = inf;                    % No usable stationary point: leave this row out of the ranking
else
    
    AlphaMin        = -reg(2) / (2 * reg(3));   % a*x2 + bx + c = fit => 2ax + b = 0 => x=-b/2a
    [Q,LSFit]       = LossAtStep_PMF3(X,F,U,AlphaMin,Gamma,dimX,cdim);
    regx(Count,5)   = LSFit;
    regx(Count,4)   = Q;
    regx(Count,1:3) = [1,AlphaMin,AlphaMin.^2];
    
end

% Sort points according to the corresponding loss function values
[nil,t]  = sort(regx(1:Count,4));
regx     = regx(t,:);

% Return best step length in terms of PMF3 loss function
AlphaMin = regx(1,2);
Q_new    = regx(1,4);
Fit_new  = regx(1,5);

%----------------------------------------------------------------------------------------------------------------------

function [Q,LSFit] = LossAtStep_PMF3(X,F,U,Alpha,Gamma,dimX,cdim)
% Private helper of LineSearch_PMF3: PMF3 loss function (Q) and least squares loss
% function (LSFit) for the factors F + Alpha * U (three-way model).
FU          = F + Alpha * U;
FactorsU    = cell(length(dimX),1);                                                     % Cell vector for updated factors (necessary for scaling)
FactorsU{1} = FU(1:dimX(1),:);
for i_mode = 2:length(dimX)
    FactorsU{i_mode} = FU(cdim(i_mode - 1) + 1:cdim(i_mode),:);
end
[FactorsU{:}] = scale_factors(false,FactorsU{:});                                       % Adjust scaling as it influences the penalty
Penalty       = Gamma * tss(cat(1,FactorsU{:}),false);                                  % Compute regularisation term in the loss function
LSFit         = tss(X - reshape(FactorsU{1} * kr(FactorsU{3},FactorsU{2})',dimX),false); % Least squares loss function
Q             = LSFit + Penalty;
