#!/bin/env python

from mpi4py import MPI
from numpy import char
import sys

# Largest message size in the sweep, in bytes (4 MiB).
MAX_MSG_SIZE = 4 * 1024 * 1024

# Iteration counts for messages up to large_message_size bytes: `skip`
# warm-up round trips are excluded from timing, `loop` round trips are timed.
skip, loop = 1000, 10000
# Reduced counts used once the message size exceeds large_message_size.
skip_large, loop_large = 10, 100
# Threshold (bytes) above which the *_large iteration counts apply.
large_message_size = 8 * 1024

comm = MPI.COMM_WORLD
myid = comm.Get_rank()

# The ping-pong exchange runs between ranks 0 and 1, so at least two
# processes are needed. With a single process, rank 0 would try to talk to a
# rank 1 that does not exist; fail early with a clear message instead.
# (When this branch is taken only rank 0 exists, so it alone reports.)
if comm.Get_size() < 2:
  sys.stderr.write("This test requires at least two processes\n")
  sys.stderr.flush()
  sys.exit(1)

# Only rank 0 prints the report header.
# Single-argument print(...) behaves identically on Python 2 and Python 3.
if myid == 0:
  print("MPI Latency Test\n")
  sys.stdout.flush()
  print("Size\t\tLatency (us) \n")
  sys.stdout.flush()

# Ping-pong latency sweep over message sizes 0, 1, 2, 4, ... MAX_MSG_SIZE.
# For each size, rank 0 sends a message to rank 1 and waits for the reply,
# and rank 1 echoes every message back. Rank 0 then prints the one-way
# latency (half the mean round-trip time) in microseconds. Ranks other than
# 0 and 1 only take part in the barrier.
size = 0
while size <= MAX_MSG_SIZE:

  sbuf = char.chararray(size)
  rbuf = char.chararray(size)
  # Fill both buffers with one vectorised assignment instead of a per-element
  # Python loop (up to MAX_MSG_SIZE interpreted iterations per buffer).
  sbuf[:] = b'a'
  rbuf[:] = b'b'

  # mpi4py buffer specifications: [data, count, MPI datatype].
  s_buf = [sbuf, size, MPI.CHAR]
  r_buf = [rbuf, size, MPI.CHAR]

  # Large messages take longer per round trip, so fewer warm-up and timed
  # iterations are used. Sizes only grow, so the switch persists once made.
  if size > large_message_size:
    loop = loop_large
    skip = skip_large

  comm.Barrier()

  if myid == 0:
    for i in range(loop + skip):
      # Start the clock only after the first `skip` warm-up round trips.
      if i == skip:
        t_start = MPI.Wtime()
      comm.Send(s_buf, 1, 0)
      comm.Recv(r_buf, 1, 1)
    t_end = MPI.Wtime()

    # Each timed iteration is one round trip (two messages), hence the
    # 2 * loop divisor. The factor 1e6 converts seconds to microseconds.
    latency = (t_end - t_start) * 1.0e6 / (2.0 * loop)
    print("%d\t\t%0.2f" % (size, latency))
    sys.stdout.flush()
  elif myid == 1:
    for _ in range(loop + skip):
      comm.Recv(r_buf, 0, 0)
      comm.Send(s_buf, 0, 1)

  # Release this size's buffers before the next, twice-as-large allocation
  # to keep peak memory down.
  del s_buf, sbuf
  del r_buf, rbuf

  size = 1 if size == 0 else size * 2
