function [Factors,Diag] = swatld(varargin)
% Fit a PARAFAC model to a three-way array using the SWATLD algorithm
%
% SYNTAX
% [Factors,Diag] = SWATLD (X,F,Options,A,B,C);
% 
% INPUTS
% X      : data array
% F      : model's rank
% Options: see ParOptions
% A,B,C  : initial estimates for the loading matrices
%
% OUTPUTS
% Factors: cell vector with the loading matrices (NB Sign permutation is
%          not fixed according to any convention)
% Diag   : structure with some diagnostics
%          Diag.fit(1)        : value of the loss function at convergence
%          Diag.fit(2)        : not used
%          Diag.fit(3)        : not used
%          Diag.it(1)         : total number of iterations
%          Diag.it(2)         : not used
%          Diag.it(3)         : not used
%          Diag.it(4)         : not used
%          Diag.convergence(1): relative fit decrease
%          Diag.convergence(2): not used
%          Diag.convergence(3): relative loss function value
%          Diag.convergence(4): not used
%          Diag.convergence(5): not used
%          Diag.convergence(6): max number of iterations reached
%          Diag.convergence(7): not used
%
% NOTES
% Called without input arguments, the function displays this help text and
% returns. If the input check yields an empty rank F, both outputs are
% returned empty.
%
% Author: Giorgio Tomasi 
%         Royal Agricultural and Veterinary University 
%         Rolighedsvej 30 
%         DK-1958 Frederiksberg C 
%         Denmark 
% 
% Last modified: 09-Jan-2005
% 
% Contact: Giorgio Tomasi, gt@kvl.dk; Rasmus Bro, rb@kvl.dk 
% 
% Reference: "A novel trilinear decomposition algorithm for second-order linear calibration"
%            Chen et al. Chemometrics and Intelligent Laboratory Systems, 52 (2000) 75-86.
%            "A comparison of algorithms for fitting the PARAFAC model"
%            G. Tomasi, R. Bro, Computational Statistics and Data Analysis, in press
%

if ~nargin
    help swatld
    return
end

%Check input values
[X,F,Options,Factors] = Check_GenParafac_Input(varargin{:});
if isempty(F)
   [Factors,Diag] = deal([]);
   return
end

% Compute initial estimates if not given
if isempty(Factors)                                         
    [Factors{1:ndims(X)}] = InitPar(X,F,Options);
end

% Initialization:
dimX     = size(X);                                                             % Array's dimensions vector
[A,B,C]  = deal(Factors{:});                                                    % Change format from cell to double
Diag     = struct('fit',zeros(3,1),'it',zeros(4,1),'convergence',false(7,1));   % Initialise diagnostics structure  
X        = reshape(X,dimX(1),dimX(2) * dimX(3));                                % Unfold X with mode 1 along the rows
SSX      = tss(X,false);                                                        % Total Sum of Squares
FitNew   = tss(X - A * kr(C,B)',false);                                         % Initial fit (uses C before it is normalised)
it       = 0;
C        = normit(C);                                                           % Necessary for the first iteration
pinvC    = pinv(C);

if ~isequal(Options.display,'none')                                             % Display titles for iteration's information     
    fprintf('\n\n            Fit           It      EV%%\n')
end
% Iterative refinement:
% Each sweep re-unfolds X (transpose + reshape rotates the mode placed along
% the rows), normalises the loading matrix updated last, and computes the next
% one as the average of two terms built from the other two loading matrices
% and their pseudoinverses.
while ~any(Diag.convergence)
    
    FitOld = FitNew; 

    % Compute B
    X      = reshape(X',dimX(2),dimX(3) * dimX(1));             % Permute array to have mode 2 as first
    A      = normit(A);                                         % Column-wise normalise A
    pinvA  = pinv(A);                                           % Compute pseudoinverse of A
    B      = .5 * (X * (kr(A,pinvC') + kr(pinvA',C)));          % New estimate of B
    
    % Compute C
    X      = reshape(X',dimX(3),dimX(1) * dimX(2));             % Permute array to have mode 3 as first
    B      = normit(B);                                         % Column-wise normalise B
    pinvB  = pinv(B);                                           % Compute pseudoinverse of B
    C      = .5 * (X * (kr(B,pinvA') + kr(pinvB',A)));          % New estimate of C
    
    % Compute A
    X      = reshape(X',dimX(1),dimX(2) * dimX(3));             % Permute array to have mode 1 as first
    C      = normit(C);                                         % Column-wise normalise C
    pinvC  = pinv(C);                                           % Compute pseudoinverse of C
    A      = .5 * (X * (kr(pinvC',B) + kr(C,pinvB')));          % New estimate of A
    
    it     = it + 1;                                            % Update n. of iterations
    FitNew = tss(X - A * kr(C,B)',false);                       % Compute least squares loss function
    
    % Check convergence
    % NOTE(review): the strict '>' below lets the loop run maxiter + 1 sweeps
    % before flagging; kept unchanged so results stay identical - confirm intent.
    Diag.convergence(1) = (abs(FitNew - FitOld) / FitOld) <= Options.convcrit.relfit;   % Relative fit decrease
    Diag.convergence(3) = (FitNew / SSX) <= Options.convcrit.fit;                       % Value of Loss function compared to SSX
    Diag.convergence(6) = it > Options.convcrit.maxiter;                                % N of iterations
    if ~isequal(Options.display,'none') && ~rem(it(1),Options.display) && ~any(Diag.convergence)
        DisplayIt_SWATLD(it(1),FitNew,SSX,Options)
    end
    
end
% Save diagnostics
Diag.fit   = FitNew;                                            % Replaces the preallocated 3x1 vector with the scalar loss value
Diag.it(1) = it;

% Scale the factors according to the common convention
[Factors{3:-1:1}] = scale_factors(1,C,B,A);

% Sort the factors according to their norm (in decreasing order)
[nil,Seq]                    = sort(-sum(Factors{1}.^2));       % First output is not needed, only the ordering
[Factors{1:length(Factors)}] = FacPerm(Seq,Factors{:});

if strcmpi(Options.diagnostics,'on')  
    
    % Display some details about convergence.
    % One message per element of Diag.convergence (unused criteria are blank),
    % so that the logical mask and the message list always have the same length.
    ConvMsg = {'Relative fit decrease'
        ''
        'Loss function value of less than machine precision'
        ''
        ''
        'Max number of iterations reached'
        ''};
    ConvMsg = char(ConvMsg(Diag.convergence));
    fprintf('\n The algorithm has converged after %i iterations',it(1))
    fprintf('\n Met convergence criteria: %s',ConvMsg(1,:))
    % List every additional criterion met in the same iteration (previously
    % the first message was printed again instead of the remaining ones)
    for iMsg = 2:size(ConvMsg,1)
        fprintf('\n                           %s',ConvMsg(iMsg,:))
    end   
    fprintf('\n')
    drawnow
    
end

%----------------------------------------------------------------------------------------------------------------------

function DisplayIt_SWATLD(It,Fit,SSX,Options)
% Print one line of iteration information to the command window: the loss
% function value, the iteration number and the percentage of explained
% variation, each left-padded with blanks to a fixed column width.
%
% It     : current iteration number
% Fit    : current value of the loss function
% SSX    : total sum of squares of the data
% Options: options structure; only Options.convcrit.maxiter is read here, to
%          size the iteration column so that all iteration numbers line up

% Loss function value, padded to 22 characters
lossTxt  = sprintf('%12.10f',Fit);
lossTxt  = [repmat(' ',1,22 - length(lossTxt)),lossTxt];

% Iteration number, padded to the number of digits of the max n. of iterations
iterTxt  = num2str(It);
iterWdth = length(num2str(Options.convcrit.maxiter));
iterTxt  = [repmat(' ',1,iterWdth - length(iterTxt)),iterTxt];

% Explained variation (in percent), padded to 7 characters
pctTxt   = sprintf('%2.4f',100*(1-Fit/SSX));
pctTxt   = [repmat(' ',1,7 - length(pctTxt)),pctTxt];

% Emit the assembled line followed by a newline
fprintf('%s\n',[' ',lossTxt,'  ',iterTxt,'  ',pctTxt]);
