from gpaw import GPAW, PW, FermiDirac, MixerSum, Mixer
from ase.io import read, Trajectory
from ase.neb import NEB
from ase.constraints import FixAtoms
from ase.optimize.bfgs import BFGS
from ase.visualize import view
from ase.parallel import rank, size
from gpaw.utilities import h2gpts
from ase.optimize import QuasiNewton

# Number of intermediate NEB images; the MPI ranks are divided into one
# equally sized group per image.
nimages = 5
#run='5'

# Validate the rank count BEFORE dividing by the group size: with fewer ranks
# than images the group size would be 0 and the image-number computation
# below would fail with an uninformative ZeroDivisionError.  An explicit
# raise is used instead of assert so the check also survives `python -O`.
if size % nimages != 0:
    raise ValueError('number of MPI ranks (%d) must be a multiple of the '
                     'number of NEB images (%d)' % (size, nimages))

n = size // nimages      # number of cpu's per image
j = 1 + rank // n  # my image number (1-based index into the band)

# End points of the band.  No GPAW calculator is attached to them here.
initial = read("../initial.traj")
final = read("../final.traj")

# Band layout: [initial, image 1, ..., image nimages, final].
# Disabled alternative starting guess: load a stored band with
# read("path.traj@:") and use path[index] instead of initial.copy().
images = [initial]
for index in range(nimages):
    # Every rank holds a copy of every intermediate image ...
    image = initial.copy()
    images.append(image)

    # ... but only the ranks of the group owning this image attach a
    # calculator to it.
    group = range(index * n, (index + 1) * n)
    if rank not in group:
        continue

    # Disabled alternative calculator settings, kept for reference:
    #   GPAW(gpts=h2gpts(0.18, image.get_cell(), idiv=16),
    #        kpts=(3, 3, 1),
    #        xc="RPBE",
    #        occupations=FermiDirac(0.1, fixmagmom=False),
    #        mixer=Mixer(0.02, 5, weight=100.0),
    #        # symmetry='off',
    #        maxiter=400,
    #        spinpol=False,
    #        communicator=group,
    #        txt='neb%d.txt' % j)
    calc = GPAW(
        h=0.18,
        kpts=(3, 3, 1),
        xc='RPBE',
        communicator=group,
        txt='neb%d.txt' % j,
    )
    image.set_calculator(calc)
images.append(final)

# Restart from a previous NEB run: intermediate image i (1-based; images[0]
# and images[-1] are the end points) takes the positions stored in
# <restart_dir>/neb<i>.traj.  Looping over nimages keeps the number of
# restart files in step with the number of images instead of hard-coding one
# read/assignment pair per image.
restart_dir = '../run5'
for i in range(1, nimages + 1):
    # read() without an index is expected to return a single Atoms object
    # (the last frame of the trajectory) -- see the ASE ase.io.read docs.
    previous = read('%s/neb%d.traj' % (restart_dir, i))
    images[i].set_positions(previous.get_positions())



# Climbing-image NEB over the band, with the images distributed over the
# rank groups.  neb.interpolate() is intentionally not called: the
# intermediate images have already been given explicit positions above.
neb = NEB(images, climb=True, parallel=True)

optimizer = BFGS(neb, logfile='neb.log')

# Each rank records its own image (index j); only the first rank of every
# group is flagged as master for the trajectory file.
is_group_master = (rank % n == 0)
traj = Trajectory('neb%s.traj' % j, mode='w', atoms=images[j],
                  master=is_group_master)
optimizer.attach(traj)

# Relax the band until the optimizer's force criterion fmax=0.05 is met.
optimizer.run(fmax=0.05)
