#!/bin/env python

from mpi4py import MPI 
from numpy import char
import sys

# Upper bound of the message-size sweep, counted in MPI.CHAR elements.
MAX_MSG_SIZE = 1 << 22

# Settings in effect until the message size exceeds large_message_size:
#   loop        - iterations that are timed
#   window_size - non-blocking transfers posted per iteration
#   skip        - leading iterations run before the timer is started
loop = 100
window_size = 64
skip = 10

# Replacement settings applied once the message size exceeds
# large_message_size.
loop_large = 20
window_size_large = 64
skip_large = 2

# Message sizes strictly greater than this use the *_large settings.
large_message_size = 8 * 1024

# Placeholder request object; the benchmark calls Waitall through it.
request = MPI.Request()

# One request slot per message in a window; the slots are overwritten with
# the handles returned by Isend/Irecv before each Waitall.
req_buf = [MPI.Request() for _ in range(window_size)]

comm = MPI.COMM_WORLD
myid = comm.Get_rank()

# Four-character acknowledgement payloads: the outgoing one is filled with
# 'a', the incoming one with 'b'.
sack_buf = char.chararray(4)
sack_buf[:] = 'a'
rack_buf = char.chararray(4)
rack_buf[:] = 'b'

# mpi4py buffer specifications of the form [data, count, datatype].
s_ack_buf = [sack_buf, 4, MPI.CHAR]
r_ack_buf = [rack_buf, 4, MPI.CHAR]

# Point-to-point bandwidth sweep between rank 0 (sender) and rank 1 (receiver).
#
# For every message size, doubling from 1 up to MAX_MSG_SIZE, rank 0 posts a
# window of non-blocking sends, waits for all of them, and then blocks on a
# short acknowledgement from rank 1.  Rank 1 mirrors this with non-blocking
# receives followed by the acknowledgement send.  Rank 0 starts its timer once
# `skip` iterations have run and prints size*loop*window_size / 10**6 divided
# by the elapsed time.  Any other rank just steps through the sizes.

# Rank 1 must exist, otherwise the first Isend on rank 0 targets an invalid
# rank; fail early with a readable message instead.
if comm.Get_size() < 2:
  if myid == 0:
    sys.stderr.write("This test requires at least two processes\n")
  sys.exit(1)

if myid == 0:
  # Single-argument print(...) is valid under both Python 2 and Python 3.
  print("MPI Bandwidth Test\n")
  sys.stdout.flush()
  print("Size\t\tBandwidth (MB/s) \n")
  sys.stdout.flush()

size = 1
while size <= MAX_MSG_SIZE:

  sbuf = char.chararray(size)
  rbuf = char.chararray(size)

  # Fill each buffer with one slice assignment instead of a Python-level
  # store per character (sizes reach MAX_MSG_SIZE elements).
  sbuf[:] = 'a'
  rbuf[:] = 'b'

  s_buf = [sbuf, size, MPI.CHAR]
  r_buf = [rbuf, size, MPI.CHAR]

  # Sizes only grow, so once this switch happens it stays in effect.
  if size > large_message_size:
    loop = loop_large
    skip = skip_large
    window_size = window_size_large

  if myid == 0:
    for i in range(loop + skip):
      if i == skip:
        # Iterations before this point are excluded from the measurement.
        t_start = MPI.Wtime()
      # Build the request list per window so its length always equals
      # window_size, independent of how req_buf was sized beforehand.
      req_buf = [comm.Isend(s_buf, 1, 0) for _ in range(window_size)]
      MPI.Request.Waitall(req_buf)
      # Tag 1 carries the acknowledgement; tag 0 carries the payload.
      comm.Recv(r_ack_buf, 1, 1)
    t_end = MPI.Wtime()
    t = t_end - t_start
  elif myid == 1:
    for i in range(loop + skip):
      # Every receive in the window lands in the same buffer.
      req_buf = [comm.Irecv(r_buf, 0, 0) for _ in range(window_size)]
      MPI.Request.Waitall(req_buf)
      comm.Send(s_ack_buf, 0, 1)

  if myid == 0:
    # Volume moved during the timed iterations, in units of 10**6 elements.
    tmp = ((size*1.0)/(1000*1000))*loop*window_size
    print("%d\t\t%0.2f" % (size, tmp/t))
    sys.stdout.flush()

  del s_buf, sbuf
  del r_buf, rbuf

  size *= 2
