/*
 * Copyright (C) NHR@FAU, University Erlangen-Nuremberg.
 * All rights reserved. This file is part of nusif-solver.
 * Use of this source code is governed by a MIT style
 * license that can be found in the LICENSE file.
 */
#include "comm.h"
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#ifdef _MPI
// subroutines local to this module
/*
 * Add up `coord` entries of sizes[], walking backwards from index `init` in
 * steps of `offset`: sizes[init - offset] + sizes[init - 2 * offset] + ...
 * Returns 0 when coord is not positive (sizes[] is not read in that case).
 *
 * assembleResult() uses this to turn the gathered per-rank extents into the
 * start position of a rank's patch inside the global array.
 */
static int sum(int* sizes, int init, int offset, int coord)
{
    int total = 0;
    int idx   = init;

    while (coord > 0) {
        idx -= offset;
        total += sizes[idx];
        coord--;
    }

    return total;
}

/*
 * Collect one distributed field into a single global array on rank 0.
 *
 * Every rank, rank 0 included, sends the interior of its local array `src` to
 * rank 0. For each direction in which commIsBoundary() reports a boundary the
 * adjacent outer layer is sent as well. Rank 0 posts one receive per rank
 * whose datatype places the incoming patch at its position inside `dst`.
 *
 * c     communicator description (rank, size, dims, local extents)
 * src   local array, addressed as (jmaxLocal + 2) x (imaxLocal + 2) doubles
 * dst   global array, addressed as (jmax + 2) x (imax + 2) doubles; it is only
 *       accessed on rank 0
 * imax  global extent in i
 * jmax  global extent in j
 *
 * The function is collective: it contains MPI_Gather calls on c->comm.
 */
static void assembleResult(Comm* c, double* src, double* dst, int imax, int jmax)
{
    /* slot 0 holds the local send; rank 0 adds one receive slot per rank */
    int numRequests = (c->rank == 0) ? c->size + 1 : 1;

    MPI_Request* requests = malloc((size_t)numRequests * sizeof *requests);
    if (requests == NULL) {
        fprintf(stderr,
            "assembleResult: cannot allocate %d MPI requests\n",
            numRequests);
        exit(EXIT_FAILURE);
    }

    /* all ranks send their bulk array, including the external boundary layer */
    MPI_Datatype bulkType;
    int oldSizes[NDIMS] = { c->jmaxLocal + 2, c->imaxLocal + 2 };
    int newSizes[NDIMS] = { c->jmaxLocal, c->imaxLocal };
    int starts[NDIMS]   = { 1, 1 };

    /* grow the patch by one layer on every side that is a boundary */
    if (commIsBoundary(c, L)) {
        newSizes[CIDIM] += 1;
        starts[CIDIM] = 0;
    }
    if (commIsBoundary(c, R)) {
        newSizes[CIDIM] += 1;
    }
    if (commIsBoundary(c, B)) {
        newSizes[CJDIM] += 1;
        starts[CJDIM] = 0;
    }
    if (commIsBoundary(c, T)) {
        newSizes[CJDIM] += 1;
    }

    MPI_Type_create_subarray(NDIMS,
        oldSizes,
        newSizes,
        starts,
        MPI_ORDER_C,
        MPI_DOUBLE,
        &bulkType);
    MPI_Type_commit(&bulkType);
    MPI_Isend(src, 1, bulkType, 0, 0, c->comm, &requests[0]);

    /*
     * Gather on c->comm so that entry i belongs to the same rank i that is
     * passed to MPI_Cart_coords and MPI_Irecv below.
     */
    int newSizesI[c->size];
    int newSizesJ[c->size];
    MPI_Gather(&newSizes[CIDIM], 1, MPI_INT, newSizesI, 1, MPI_INT, 0, c->comm);
    MPI_Gather(&newSizes[CJDIM], 1, MPI_INT, newSizesJ, 1, MPI_INT, 0, c->comm);

    /* rank 0 assembles the subdomains */
    if (c->rank == 0) {
        for (int i = 0; i < c->size; i++) {
            MPI_Datatype domainType;
            int globalSizes[NDIMS] = { jmax + 2, imax + 2 };
            int patchSizes[NDIMS]  = { newSizesJ[i], newSizesI[i] };
            int coords[NDIMS];
            MPI_Cart_coords(c->comm, i, NDIMS, coords);

            /* start = summed extents of the ranks preceding rank i per dimension */
            int patchStarts[NDIMS] = { sum(newSizesJ, i, 1, coords[JDIM]),
                sum(newSizesI, i, c->dims[JDIM], coords[IDIM]) };
            printf(
                "Rank: %d, Coords(i,j): %d %d, Size(i,j): %d %d, Target Size(i,j): %d %d "
                "Starts(i,j): %d %d\n",
                i,
                coords[IDIM],
                coords[JDIM],
                globalSizes[CIDIM],
                globalSizes[CJDIM],
                patchSizes[CIDIM],
                patchSizes[CJDIM],
                patchStarts[CIDIM],
                patchStarts[CJDIM]);

            MPI_Type_create_subarray(NDIMS,
                globalSizes,
                patchSizes,
                patchStarts,
                MPI_ORDER_C,
                MPI_DOUBLE,
                &domainType);
            MPI_Type_commit(&domainType);

            MPI_Irecv(dst, 1, domainType, i, 0, c->comm, &requests[i + 1]);
            /* only marks the type for deallocation; the pending receive keeps it usable */
            MPI_Type_free(&domainType);
        }
    }

    MPI_Waitall(numRequests, requests, MPI_STATUSES_IGNORE);

    /* release per-call resources: this function runs once per collected field */
    MPI_Type_free(&bulkType);
    free(requests);
}
#endif // defined _MPI

// exported subroutines
/*
 * Tell whether the calling process lies at the edge of the process grid in
 * the given direction.
 *
 * With MPI: L and B test for coordinate 0 in IDIM and JDIM respectively,
 * R and T test for the last coordinate (dims - 1) in IDIM and JDIM.
 * Any other direction value yields 1.
 * Without MPI the result is always 1.
 *
 * Returns 1 if the test holds, 0 otherwise.
 */
int commIsBoundary(Comm* c, int direction)
{
#ifdef _MPI
    if (direction == L) {
        return c->coords[IDIM] == 0;
    }
    if (direction == R) {
        return c->coords[IDIM] == (c->dims[IDIM] - 1);
    }
    if (direction == B) {
        return c->coords[JDIM] == 0;
    }
    if (direction == T) {
        return c->coords[JDIM] == (c->dims[JDIM] - 1);
    }
#endif

    /* serial build or unrecognised direction */
    return 1;
}

/*
 * Exchange the ghost layers of `grid` with the neighbours in all NDIRS
 * directions.
 *
 * For every direction one non-blocking receive and one non-blocking send are
 * posted, using the datatype and the receive/send displacements stored in `c`
 * for that direction. The function returns after all of them have completed.
 * Without MPI this is a no-op.
 *
 * c     communicator description (neighbours, bufferTypes, sdispls, rdispls)
 * grid  local array the displacements refer to
 */
void commExchange(Comm* c, double* grid)
{
#ifdef _MPI
    /* one receive and one send per direction; sized from NDIRS, not a literal */
    MPI_Request requests[2 * NDIRS];
    for (int i = 0; i < 2 * NDIRS; i++) {
        requests[i] = MPI_REQUEST_NULL;
    }

    for (int i = 0; i < NDIRS; i++) {
        double* sbuf = grid + c->sdispls[i];
        double* rbuf = grid + c->rdispls[i];

        /*
         * Messages carry the sender's rank as tag, so the receive expects the
         * neighbour's rank. Fall back to tag 0 when there is no neighbour, so
         * that MPI_PROC_NULL is never used as a tag.
         */
        int tag = 0;
        if (c->neighbours[i] != MPI_PROC_NULL) {
            tag = c->neighbours[i];
        }

        MPI_Irecv(rbuf,
            1,
            c->bufferTypes[i],
            c->neighbours[i],
            tag,
            c->comm,
            &requests[i * 2]);
        MPI_Isend(sbuf,
            1,
            c->bufferTypes[i],
            c->neighbours[i],
            c->rank,
            c->comm,
            &requests[i * 2 + 1]);
    }

    MPI_Waitall(2 * NDIRS, requests, MPI_STATUSES_IGNORE);
#endif
}

/*
 * One-directional ghost transfer for the two arrays f and g.
 *
 * g: receive from the bottom neighbour into row 0, send row jmaxLocal to
 *    the top neighbour (both starting at column 1).
 * f: receive from the left neighbour into column 0, send column imaxLocal
 *    to the right neighbour (both starting at row 1).
 *
 * Rows are addressed with a length of imaxLocal + 2 elements. All four
 * operations are non-blocking and completed before returning. Without MPI
 * this is a no-op.
 */
void commShift(Comm* c, double* f, double* g)
{
#ifdef _MPI
    enum { TAG_G = 0, TAG_F = 1, NUM_REQUESTS = 4 };

    const int rowLength = c->imaxLocal + 2;

    MPI_Request requests[NUM_REQUESTS];
    for (int i = 0; i < NUM_REQUESTS; i++) {
        requests[i] = MPI_REQUEST_NULL;
    }

    /* shift G: bottom neighbour -> this rank -> top neighbour */
    double* gRecv = g + 1;
    double* gSend = g + (c->jmaxLocal) * rowLength + 1;

    MPI_Irecv(gRecv, 1, c->bufferTypes[B], c->neighbours[B], TAG_G, c->comm, &requests[0]);
    MPI_Isend(gSend, 1, c->bufferTypes[T], c->neighbours[T], TAG_G, c->comm, &requests[1]);

    /* shift F: left neighbour -> this rank -> right neighbour */
    double* fRecv = f + rowLength;
    double* fSend = f + rowLength + (c->imaxLocal);

    MPI_Irecv(fRecv, 1, c->bufferTypes[L], c->neighbours[L], TAG_F, c->comm, &requests[2]);
    MPI_Isend(fSend, 1, c->bufferTypes[R], c->neighbours[R], TAG_F, c->comm, &requests[3]);

    MPI_Waitall(NUM_REQUESTS, requests, MPI_STATUSES_IGNORE);
#endif
}

/*
 * Assemble the three local fields p, u and v (in that order) into the global
 * arrays pg, ug and vg by calling assembleResult() once per field.
 * imax and jmax are forwarded unchanged. Without MPI this is a no-op.
 */
void commCollectResult(Comm* c,
    double* ug,
    double* vg,
    double* pg,
    double* u,
    double* v,
    double* p,
    int imax,
    int jmax)
{
#ifdef _MPI
    /* matching local source / global destination pairs: P, U, V */
    double* local[]  = { p, u, v };
    double* global[] = { pg, ug, vg };

    for (int field = 0; field < 3; field++) {
        assembleResult(c, local[field], global[field], imax, jmax);
    }
#endif
}

/*
 * Set up the Cartesian process grid and the halo-exchange description in `c`.
 *
 * With MPI:
 *   - creates a non-periodic NDIMS-dimensional Cartesian communicator from
 *     MPI_COMM_WORLD (no rank reordering) and stores it in c->comm,
 *   - queries the neighbours in L/R (IDIM) and B/T (JDIM) and the grid
 *     dims/coords,
 *   - obtains the local extents from sizeOfRank() and stores them,
 *   - commits a contiguous type of imaxLocal doubles (used for B and T) and a
 *     vector type of jmaxLocal doubles with stride imaxLocal + 2 (used for
 *     L and R),
 *   - stores the send/receive displacements for each direction.
 * Without MPI the local extents are simply the global ones.
 *
 * Note the parameter order: jmax comes before imax.
 */
void commPartition(Comm* c, int jmax, int imax)
{
#ifdef _MPI
    int dims[NDIMS]    = { 0, 0 };
    int periods[NDIMS] = { 0, 0 };

    /* process grid and neighbourhood */
    MPI_Dims_create(c->size, NDIMS, dims);
    MPI_Cart_create(MPI_COMM_WORLD, NDIMS, dims, periods, 0, &c->comm);
    MPI_Cart_shift(c->comm, IDIM, 1, &c->neighbours[L], &c->neighbours[R]);
    MPI_Cart_shift(c->comm, JDIM, 1, &c->neighbours[B], &c->neighbours[T]);
    MPI_Cart_get(c->comm, NDIMS, c->dims, periods, c->coords);

    /* local extents of this rank */
    const int ni = sizeOfRank(c->coords[IDIM], dims[IDIM], imax);
    const int nj = sizeOfRank(c->coords[JDIM], dims[JDIM], jmax);
    c->imaxLocal = ni;
    c->jmaxLocal = nj;

    /* length of one row of the local array including both ghost columns */
    const int stride = ni + 2;

    MPI_Datatype rowType;
    MPI_Type_contiguous(ni, MPI_DOUBLE, &rowType);
    MPI_Type_commit(&rowType);

    MPI_Datatype columnType;
    MPI_Type_vector(nj, 1, stride, MPI_DOUBLE, &columnType);
    MPI_Type_commit(&columnType);

    /* left / right exchange columns */
    c->bufferTypes[L] = columnType;
    c->sdispls[L]     = stride + 1;
    c->rdispls[L]     = stride;

    c->bufferTypes[R] = columnType;
    c->sdispls[R]     = stride + ni;
    c->rdispls[R]     = stride + (ni + 1);

    /* bottom / top exchange rows */
    c->bufferTypes[B] = rowType;
    c->sdispls[B]     = stride + 1;
    c->rdispls[B]     = 1;

    c->bufferTypes[T] = rowType;
    c->sdispls[T]     = nj * stride + 1;
    c->rdispls[T]     = (nj + 1) * stride + 1;
#else
    c->imaxLocal = imax;
    c->jmaxLocal = jmax;
#endif
}

/*
 * Fill `newcomm` from `oldcomm` for a local grid of the given extents.
 *
 * With MPI:
 *   - duplicates oldcomm->comm into newcomm->comm (a message is printed if
 *     the duplication returns MPI_ERR_COMM),
 *   - copies rank, size, neighbours, coords and dims,
 *   - commits a contiguous type of imaxLocal doubles (B, T) and a vector type
 *     of jmaxLocal doubles with stride imaxLocal + 2 (L, R),
 *   - stores the send/receive displacements for each direction.
 * Without MPI only the local extents are stored.
 *
 * NOTE(review): the MPI branch stores imaxLocal / 2 and jmaxLocal / 2 in
 * newcomm while the datatypes and displacements use the full extents, and the
 * serial branch stores the extents unhalved. Confirm against the callers that
 * this difference is intended.
 */
void commUpdateDatatypes(Comm* oldcomm, Comm* newcomm, int imaxLocal, int jmaxLocal)
{
#if defined _MPI
    newcomm->comm = MPI_COMM_NULL;

    if (MPI_Comm_dup(oldcomm->comm, &newcomm->comm) == MPI_ERR_COMM) {
        printf("\nNull communicator. Duplication failed !!\n");
    }

    /* topology information is identical to the source communicator */
    newcomm->rank = oldcomm->rank;
    newcomm->size = oldcomm->size;
    memcpy(&newcomm->neighbours, &oldcomm->neighbours, sizeof(oldcomm->neighbours));
    memcpy(&newcomm->coords, &oldcomm->coords, sizeof(oldcomm->coords));
    memcpy(&newcomm->dims, &oldcomm->dims, sizeof(oldcomm->dims));

    newcomm->imaxLocal = imaxLocal / 2;
    newcomm->jmaxLocal = jmaxLocal / 2;

    /* length of one row of the local array including both ghost columns */
    const int stride = imaxLocal + 2;

    MPI_Datatype rowType;
    MPI_Type_contiguous(imaxLocal, MPI_DOUBLE, &rowType);
    MPI_Type_commit(&rowType);

    MPI_Datatype columnType;
    MPI_Type_vector(jmaxLocal, 1, stride, MPI_DOUBLE, &columnType);
    MPI_Type_commit(&columnType);

    /* left / right exchange columns */
    newcomm->bufferTypes[L] = columnType;
    newcomm->sdispls[L]     = stride + 1;
    newcomm->rdispls[L]     = stride;

    newcomm->bufferTypes[R] = columnType;
    newcomm->sdispls[R]     = stride + imaxLocal;
    newcomm->rdispls[R]     = stride + (imaxLocal + 1);

    /* bottom / top exchange rows */
    newcomm->bufferTypes[B] = rowType;
    newcomm->sdispls[B]     = stride + 1;
    newcomm->rdispls[B]     = 1;

    newcomm->bufferTypes[T] = rowType;
    newcomm->sdispls[T]     = jmaxLocal * stride + 1;
    newcomm->rdispls[T]     = (jmaxLocal + 1) * stride + 1;
#else
    newcomm->imaxLocal = imaxLocal;
    newcomm->jmaxLocal = jmaxLocal;
#endif
}

/*
 * Release the MPI communicator held by `comm`. Without MPI this is a no-op.
 *
 * A handle equal to MPI_COMM_NULL is skipped: commUpdateDatatypes() presets
 * the handle to MPI_COMM_NULL and only prints a message when duplication
 * fails, so such a handle can reach this function. MPI_Comm_free resets the
 * handle to MPI_COMM_NULL, which makes a repeated call harmless as well.
 */
void commFreeCommunicator(Comm* comm)
{
#ifdef _MPI
    if (comm->comm != MPI_COMM_NULL) {
        MPI_Comm_free(&comm->comm);
    }
#endif
}