#!/usr/bin/env python
"""Ping-pong latency test between MPI ranks 0 and 1 using mpi4py.

For each message size (0, 1, 2, 4, ... up to MAX_MSG_SIZE bytes) rank 0
sends a buffer to rank 1 and waits for a buffer to come back.  After a
number of untimed warm-up iterations, the remaining round trips are timed
and rank 0 prints half of the average round-trip time in microseconds.

Ranks other than 0 and 1 only take part in the per-size barrier.
"""
from mpi4py import MPI
from numpy import char
import sys

MAX_MSG_SIZE = (1 << 22)

# Iteration counts: `skip` warm-up round trips are excluded from timing,
# `loop` round trips are timed.  Messages larger than `large_message_size`
# switch to the smaller `*_large` counts.
skip = 1000
loop = 10000
skip_large = 10
loop_large = 100
large_message_size = 8192

comm = MPI.COMM_WORLD
myid = comm.Get_rank()

# Rank 0 unconditionally exchanges messages with rank 1, so fail early with
# a readable message instead of an MPI error from the first Send.
if comm.Get_size() < 2:
    sys.stderr.write("This test requires at least two MPI processes\n")
    sys.exit(1)

if myid == 0:
    # Single-argument print(...) emits the same text on Python 2 and 3.
    print("MPI Latency Test\n")
    sys.stdout.flush()
    print("Size\t\tLatency (us) \n")
    sys.stdout.flush()

size = 0
while size <= MAX_MSG_SIZE:
    sbuf = char.chararray(size)
    rbuf = char.chararray(size)
    # Fill the whole buffer in one assignment instead of a per-element loop.
    sbuf[:] = 'a'
    rbuf[:] = 'b'

    # mpi4py buffer specification: [data, count, datatype].
    s_buf = [sbuf, size, MPI.CHAR]
    r_buf = [rbuf, size, MPI.CHAR]

    # Sizes only grow, so once the large counts are selected they stay
    # in effect for every following size.
    if size > large_message_size:
        loop = loop_large
        skip = skip_large

    comm.Barrier()

    if myid == 0:
        for i in range(0, loop + skip):
            # Start the clock only after the warm-up iterations.
            if i == skip:
                t_start = MPI.Wtime()
            comm.Send(s_buf, 1, 0)
            comm.Recv(r_buf, 1, 1)
        t_end = MPI.Wtime()
    elif myid == 1:
        # Mirror of rank 0: receive with tag 0, reply with tag 1.
        for i in range(0, loop + skip):
            comm.Recv(r_buf, 0, 0)
            comm.Send(s_buf, 0, 1)

    if myid == 0:
        # Elapsed time covers `loop` round trips; divide by 2 * loop to get
        # the one-way time, scaled by 1e6 for the microsecond column.
        latency = (t_end - t_start) * 1.0e6 / (2.0 * loop)
        print("%d\t\t%0.2f" % (size, latency))
        sys.stdout.flush()

    del s_buf, sbuf
    del r_buf, rbuf

    # Size sequence: 0, 1, 2, 4, 8, ...
    if size == 0:
        size += 1
    else:
        size *= 2