function model=scream_pred(varargin)
%Applies a SCREAM model to an unknown set of samples and calculates
%predictions of the dependent block based on the model, as described in F. Marini
%& R. Bro, "SCREAM: a novel method for multi-way regression problems
%with shifts and shape changes in one mode", Chemometr. Intell. Lab. Syst.
%
%I/O:   pred=scream_pred(Xnew,ynew,screammodel,plots)
% where Xnew is the independent array X measured on new samples, given either
%       as a 3-way array (one sample per slice along the third dimension)
%       or as a cell array holding one matrix per sample
%       ynew (optional) is the actual value of Y for new samples (one row
%       per sample)
%       screammodel is the SCREAM model calculated using scream_mod
%       plots governs the final plotting ('on' or 'off'; default is 'on').
% and   pred is a structure containing the model predictions:
%       pred.loads{1}    scores C of the new samples (one row per sample)
%       pred.loads{2}.P  cell array with the P matrix of each sample
%       pred.loads{2}.H  H matrix, copied from the model
%       pred.loads{3}    screammodel.Xloads{3}, copied from the model
%       pred.Yloads      Y loadings, copied from the model
%       pred.Xweights    X weights, copied from the model
%       pred.Ypred       predicted Y (C*Yloads)
%       If ynew is input, pred.bias and pred.rmsep (one value per column
%       of Y) are also calculated.
%
%Examples:
%pred=scream_pred(Xnew,ynew,screammodel,'off')
%pred=scream_pred(Xnew,ynew,screammodel)
%pred=scream_pred(Xnew,screammodel,'off')
%pred=scream_pred(Xnew,screammodel)
%

if nargin<2 || nargin>4
    error('scream_pred:nargin', ...
        'scream_pred requires 2 to 4 input arguments (see help scream_pred).')
end

X=varargin{1};                          %Independent block X for new samples

if nargin==4                            %if loop over the possible syntaxes
    y=varargin{2};                      %Dependent block Y for new samples (optional)
    pcmodel=varargin{3};                %SCREAM model
    plots=varargin{4};                  %plots on/off
elseif nargin==3
    if isstruct(varargin{2})            %scream_pred(Xnew,screammodel,plots)
        y=[];
        pcmodel=varargin{2};
        plots=varargin{3};
    else                                %scream_pred(Xnew,ynew,screammodel)
        y=varargin{2};
        pcmodel=varargin{3};
        plots='on';
    end
else                                    %scream_pred(Xnew,screammodel)
    y=[];
    pcmodel=varargin{2};
    plots='on';
end

if ~iscell(X)                           %Transforms X into cell array if needed
    x=cell(1,size(X,3));
    for k = 1:size(X,3)
        x{k} = X(:,:,k);
    end
    X = x;
    clear x
end

K = length(X);                          %Number of new samples

%retrieves the needed matrices and parameters from the calculated SCREAM
%model
H=pcmodel.Xloads{2}.H;                  %H matrix
A=pcmodel.Xloads{3};                    %Loadings along the third mode
W=pcmodel.Xweights;                     %X weights W for the calculation of the C scores
Py=pcmodel.Yloads;                      %Y loadings for prediction of the Y
prntlag=pcmodel.opts.prntlag;
crit=pcmodel.opts.crit;
maxit=pcmodel.opts.maxit;

F=size(A,2);                            %Number of SCREAM components
C = rand(K,F);                          %Random initialization of C

%first update of P and C, starting from the random C
[P,Xha,C]=scream_update(X,A,H,C,W);

%compute the loss function. Only the columns of kron(A,H) pairing
%component f of A with component f of H are kept, so that each row of
%Xha is modelled as sum_f C(k,f)*kron(A(:,f),H(:,f))'
Z=kron(A,H);
Z=Z(:,1:F+1:F*F);
fit=sum(sum((Xha-C*Z').^2));
fitold=2*fit;                           %guarantees at least one iteration unless fit is ~0

iter=0;
while abs((fit-fitold)/fitold)>crit && iter<maxit && fit>10*eps
    iter=iter+1;
    fitold=fit;

    [P,Xha,C]=scream_update(X,A,H,C,W);
    fit=sum(sum((Xha-C*Z').^2));

    if rem(iter,prntlag)==0
        fprintf('f = %12.8f after %g iters; \n',fit,iter)
    end
end

model.loads{1}=C;
model.loads{2}.P=P;
model.loads{2}.H=H;
model.loads{3}=A;
model.Yloads=Py;
model.Xweights=W;
model.Ypred=C*Py;

if ~isempty(y)                          %if the actual values of Y are provided, then bias and RMSEP can be calculated.
    res=y-model.Ypred;
    %statistics are taken over the samples (dimension 1) explicitly, so that
    %a single new sample still yields one value per Y column
    model.bias=mean(res,1);
    model.rmsep=sqrt(sum(res.^2,1)./K);
end

if strcmp(plots,'on')
    %First figure: decomposition of the X block

    figure
    subplot(2,2,1)
    plot(C, 'LineWidth', 1.5)
    axis tight
    xlabel('Sample Index')
    ylabel('SCREAM components')
    title('Loadings along the sample mode')

    subplot(2,2,2)
    if F>1
        plot(C(:,1), C(:,2), '.r', 'MarkerSize', 8)
        xlabel('SCREAM comp.1')
        ylabel('SCREAM comp.2')
        title('Scores plot')
    else
        bar(C(:,1))
        xlabel('Sample Index')
        ylabel('SCREAM comp.1')
        title('Scores plot')
    end
    axis tight

    subplot(2,2,3)
    for i=1:K
        plot(P{i}*H, 'LineWidth', 1.5), hold on
    end
    hold off
    axis tight
    xlabel('Variable Index')
    ylabel('SCREAM components')
    title('Loadings along the shifted mode')

    subplot(2,2,4)
    plot(A, 'LineWidth', 1.5)
    axis tight
    xlabel('Variable Index')
    ylabel('SCREAM components')
    title('Loadings along the third mode')

    if ~isempty(y)   %If y is provided: observed vs predicted plots

        %Second figure: predictions
        ny1=ceil(sqrt(size(y,2)));    %1st dimension for subplot
        ny2=ceil(size(y,2)./ny1);     %2nd dimension for subplot
        figure
        for j=1:size(y,2)
            subplot(ny2,ny1,j)
            plot(y(:,j), model.Ypred(:,j), '.r', 'MarkerSize', 8)
            axis tight
            xlabel('Y measured')
            ylabel('Y predicted')
            title(['Y variable ' num2str(j)])
        end
    end

end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [P,Xha,C]=scream_update(X,A,H,C,W)
%One update step of the SCREAM prediction: given the current scores C,
%re-estimates the P matrix of every sample, then re-estimates C from the
%projected data.
% X   cell array with one data matrix per sample
% A   loadings (rows match the rows of each X{k})
% H   H matrix of the model
% C   current scores (one row per sample, one column per component)
% W   X weights of the model
% P   cell array, P{k} = Qk*psqrt(Qk'*Qk) with Qk = X{k}'*A*diag(C(k,:))*H'
% Xha projected data unfolded with one row per sample
% C   updated scores, Xha*W

I=size(X{1},1);
K=length(X);
F=size(A,2);

P=cell(1,K);
Xpcov=zeros(I,F,K);
for k = 1:K
    Qk=X{k}'*(A*diag(C(k,:))*H');
    P{k}= Qk*psqrt(Qk'*Qk);
    Xpcov(:,:,k) = X{k}*P{k};
end

%unfold to K x (I*F), with the component index running fastest in each row
X3w=permute(Xpcov,[3,2,1]);
Xha=reshape(X3w,K,I*F);
C=Xha*W;

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function X = psqrt(A,tol)
% Allows to calculate A^(-.5) even in the presence of rank-problems.
% X = psqrt(A) returns V(:,1:r)*diag(1./sqrt(s(1:r)))*U(:,1:r)', where
% [U,S,V] = svd(A,0), s holds the singular values of A and r is the number
% of singular values larger than tol. For a symmetric positive
% semi-definite A this is the (pseudo-)inverse square root of A.
% X = psqrt(A,tol) uses the given tolerance. The default is
% tol = max(size(A))*s(1)*eps.
% X has the size of A'. It is all zeros when A is empty or when no
% singular value exceeds tol.

if isempty(A)                  % no singular values: there is no S(1) to build tol from
   X = zeros(size(A'));
   return
end

[U,S,V] = svd(A,0);
if min(size(S)) == 1           % S is a vector here, and diag() would expand it into a matrix
   S = S(1);
else
   S = diag(S);
end
if (nargin == 1)
   tol = max(size(A)) * S(1) * eps;
end
r = sum(S > tol);              % numerical rank of A
if (r == 0)
   X = zeros(size(A'));
else
   S = diag(ones(r,1)./sqrt(S(1:r)));
   X = V(:,1:r)*S*U(:,1:r)';
end
