#include <math.h>
#include <maya/MIOStream.h>
#include <maya/MSimple.h>
#include <maya/MTimer.h>
#include <maya/MGlobal.h>
#include <maya/MThreadPool.h>
// Declares the threadTestCmd command class through the helper macro from
// <maya/MSimple.h>; its doIt() member is implemented at the bottom of this file.
// NOTE(review): the macro presumably also emits the plug-in registration
// boilerplate, with PLUGIN_COMPANY as vendor and "2008" as version string --
// confirm against MSimple.h.
DeclareSimpleCommand( threadTestCmd, PLUGIN_COMPANY, "2008");
// Per-task work record. DecomposePrimes() fills one of these for every task
// it creates and Primes() receives it through its void* argument.
typedef struct _threadDataTag
{
    // Task index, 0 .. NUM_TASKS-1; selects the interleaved slice of
    // candidates this task tests.
    int  threadNo;

    // Number of primes this task has counted (zeroed before the task runs).
    long primesFound;

    // First candidate of the overall range (inclusive).
    long start;

    // Last candidate of the overall range (inclusive).
    long end;
} threadData;
// Description of one parallel region: the inclusive candidate range to
// search and the accumulated result. Filled by ParallelPrimes() and
// consumed/updated by DecomposePrimes().
typedef struct _taskDataTag
{
    // First candidate of the range (inclusive).
    long start;

    // Last candidate of the range (inclusive).
    long end;

    // Sum of the per-task prime counts; the caller zeroes it beforehand.
    long totalPrimes;
} taskData;
// Number of tasks DecomposePrimes() creates per parallel region. Primes()
// also uses it for its stride (2*NUM_TASKS), so the tasks together cover
// every second candidate exactly once.
#define NUM_TASKS   16
// Returns true when val is a prime number, false otherwise.
//
// Values below 2 (including negatives) and even values other than 2 are
// rejected up front; without these guards 0, 1 and all even numbers would be
// reported as prime, and a negative value would feed a NaN square root into
// an integer conversion. The remaining odd candidates are tested by trial
// division with odd divisors up to the rounded square root.
static bool TestForPrime(int val)
{
    if( val < 2 )
        return false;
    if( val < 4 )
        return true;                    // 2 and 3
    if( (val % 2) == 0 )
        return false;

    // Rounding to nearest can only make the limit one larger than the true
    // integer square root, which costs at most one extra division.
    const int limit = static_cast<int>(sqrt(static_cast<double>(val)) + 0.5);

    int factor = 3;
    while( (factor <= limit) && (val % factor) )
        factor += 2;                    // even divisors were ruled out above

    // The loop ran past the limit only if no divisor was found.
    return (factor > limit);
}
// Task body handed to MThreadPool::createTask(). 'data' points to a
// threadData record. The task tests the candidates
//     start + 2*threadNo, start + 2*threadNo + 2*NUM_TASKS, ...
// up to and including 'end', and adds the number of primes found to the
// record's primesFound field. Always returns 0.
MThreadRetVal Primes(void *data)
{
    threadData *myData = static_cast<threadData *>(data);

    // Iterate in a wider type: with an int counter the first candidate and
    // the "+= stride" step can exceed INT_MAX when the range ends near it,
    // which wraps/overflows and produces bogus counts or a runaway loop.
    const long long stride = 2LL * NUM_TASKS;
    const long long last   = myData->end;
    long long candidate =
        static_cast<long long>(myData->start) + 2LL * myData->threadNo;

    // Count locally and store once. The task records are adjacent elements
    // of one array (see DecomposePrimes), so this avoids repeated writes
    // next to counters that other tasks may be updating.
    long found = 0;
    for( ; candidate <= last; candidate += stride )
    {
        // candidate <= last here, so it is within the range the caller gave.
        if( TestForPrime(static_cast<int>(candidate)) )
            ++found;
    }
    myData->primesFound += found;

    return (MThreadRetVal)0;
}
void DecomposePrimes(void *data, MThreadRootTask *root)
{
    taskData *taskD = (taskData *)data;
    
    threadData tdata[NUM_TASKS];
    for( int i = 0; i < NUM_TASKS; ++i )
    {
        tdata[i].threadNo    = i;
        tdata[i].primesFound = 0;
        tdata[i].start       = taskD->start;
        tdata[i].end         = taskD->end;
        MThreadPool::createTask(Primes, (void *)&tdata[i], root);
    }
    MThreadPool::executeAndJoin(root);
    for( int i = 0; i < NUM_TASKS; ++i )
    {
        taskD->totalPrimes += tdata[i].primesFound;
    }
}
// Single-threaded reference count: tests start, start+2, start+4, ... up to
// and including 'end' with TestForPrime() and returns how many passed.
// An empty range (start > end) yields 0.
int SerialPrimes(int start, int end)
{
    int primesFound = 0;

    // The loop variable is wider than int so that "+= 2" cannot overflow
    // when 'end' is INT_MAX or INT_MAX-1 (an int counter could never exceed
    // such an 'end', and incrementing past INT_MAX is undefined behaviour).
    for( long long candidate = start; candidate <= end; candidate += 2 )
    {
        // candidate is within [start, end], so it always fits in an int.
        if( TestForPrime(static_cast<int>(candidate)) )
            ++primesFound;
    }
    return primesFound;
}
int ParallelPrimes(int start, int end)
{
    MStatus stat = MThreadPool::init();
    if( MStatus::kSuccess != stat ) {
        MString str = MString("Error creating threadpool");
        MGlobal::displayError(str);
        return 0;
    }
    taskData tdata;
    tdata.totalPrimes = 0;
    tdata.start       = start;
    tdata.end         = end;
    MThreadPool::newParallelRegion(DecomposePrimes, (void *)&tdata);
    
    MThreadPool::release();
    
    MThreadPool::release();
    return tdata.totalPrimes;
}
// Command entry point. Expects exactly two integer arguments, the start and
// end of the range to search. Runs the serial and the parallel prime count
// over the same range, timing each, fails if the two counts differ, and
// otherwise displays both timings and the resulting speedup.
// Returns kFailure on bad arguments or inconsistent counts, else kSuccess.
MStatus threadTestCmd::doIt( const MArgList& args )
{
    MGlobal::displayInfo( MString("Computation of primes using the Maya API") );

    // --- argument validation -------------------------------------------
    if( args.length() != 2 ) {
        MGlobal::displayError(
            MString("Invalid number of arguments, usage: threadTestCmd 1 10000") );
        return MStatus::kFailure;
    }

    MStatus parseStatus;

    int start = args.asInt( 0, &parseStatus );
    if( MS::kSuccess != parseStatus ) {
        MGlobal::displayError(
            MString("Invalid argument 1, usage: threadTestCmd 1 10000") );
        return MStatus::kFailure;
    }

    const int end = args.asInt( 1, &parseStatus );
    if( MS::kSuccess != parseStatus ) {
        MGlobal::displayError(
            MString("Invalid argument 2, usage: threadTestCmd 1 10000") );
        return MStatus::kFailure;
    }

    // Both counters step by two, so begin on an odd candidate.
    if( (start % 2) == 0 )
        ++start;

    // --- timed runs (the same timer object is used for both) -------------
    MTimer stopwatch;

    stopwatch.beginTimer();
    const int serialPrimes = SerialPrimes(start, end);
    stopwatch.endTimer();
    const double serialTime = stopwatch.elapsedTime();

    stopwatch.beginTimer();
    const int parallelPrimes = ParallelPrimes(start, end);
    stopwatch.endTimer();
    const double parallelTime = stopwatch.elapsedTime();

    // The two strategies must agree before the timings mean anything.
    if( serialPrimes != parallelPrimes ) {
        MGlobal::displayError( MString("Error: Computations inconsistent") );
        return MStatus::kFailure;
    }

    // --- report ------------------------------------------------------------
    // Without a positive parallel time the speedup ratio cannot be computed.
    if( !(parallelTime > 0.0) ) {
        MGlobal::displayInfo(
            MString("\nParallel time zero, no scaling measurement possible\n") );
        return MStatus::kSuccess;
    }

    const double ratio = serialTime/parallelTime;
    MString report = MString("\nElapsed time for serial computation: ") + serialTime + MString("s\n");
    report += MString("Elapsed time for parallel computation: ") + parallelTime + MString("s\n");
    report += MString("Speedup: ") + ratio + MString("x\n");
    MGlobal::displayInfo(report);

    return MStatus::kSuccess;
}