/**
 * An example solution to Assignment 4 in C Programming course.
 *
 * Partial sums are not used, since their benefit in the overall
 * complexity is negligible, and also to make the solution easier
 * to understand.
 *
 * [Scratch that... it's to make _my_ life easier...]
 *
 * Michael Orlov, Dec 1, 2003
 */

#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/* Boolean values stored in used[] and returned by check_sum(). */
enum { FALSE = 0, TRUE = 1 };

/* Printed width (in characters) of every cell except the last one in a row. */
enum { TRIPLE_WIDTH = 14 };

/**
 * Function prototypes.
 *
 * Functions without parameters are declared with (void): before C23,
 * an empty "()" declares a function with unspecified arguments and
 * disables argument checking at call sites.
 */

static int  search(int level, int k);

static int  check_sum(int row, int col);
static int  sum_row(int row);
static int  sum_col(int col);
static int  sum_slash(void);
static int  sum_backslash(void);

static void show(void);
static void print_triple(int a, int b, int c, int width);
static int  print_triple_raw(int a, int b, int c);
static void print_array(const char *header, const int *arr, int n);

static void error(const char *message);

/**
 * File-scope state shared by main(), the search and the printers.
 */

static int n;                   // square size (cells per row/column)
static int nn;                  // n * n: number of cells and of triples
static int sum;                 // target sum for rows (h), columns (v), diagonals (d)
static int *h;                  // h component of every input triple
static int *v;                  // v component of every input triple
static int *d;                  // d component of every input triple
static int *m;                  // matrix cells, each an index into h, v and d
static int *used;               // used[i] is TRUE while triple i is placed in m

/** 
 * Program entry point.
 *
 * Input/Output as per Assignment 4: reads the square size n, the
 * number of requested solutions k and n^2 triples (h v d), echoes
 * the h, v and d arrays, then searches for up to k magic squares,
 * printing each one found and a summary line.
 *
 * Malformed or out-of-range input is reported through error(),
 * which terminates the program.
 * 
 * @return program status
 */
int
main(void)
{
  int i, k, k_found, sum_check;

  /**
   * Get matrix size.
   */

  printf("Matrix size: ");

  // reject non-numeric input and sizes whose square would overflow int
  if (scanf("%d", &n) != 1 || n <= 0 || n > INT_MAX / n)
    error("bad n");

  nn = n * n;

  printf("\n");

  /**
   * Get requested solutions number.
   */

  printf("Solutions number: ");

  if (scanf("%d", &k) != 1 || k < 0)
    error("bad k");

  printf("\n");

  /**
   * Allocate space for the arrays and the matrix.
   *
   * calloc() is used because it checks the nn * sizeof(int)
   * multiplication for overflow.
   */

  if ((h = calloc((size_t) nn, sizeof *h)) == NULL
      || (v = calloc((size_t) nn, sizeof *v)) == NULL
      || (d = calloc((size_t) nn, sizeof *d)) == NULL
      || (m = calloc((size_t) nn, sizeof *m)) == NULL
      || (used = calloc((size_t) nn, sizeof *used)) == NULL)
    error("insufficient memory");

  /** 
   * Mark every triple as unused explicitly, using the symbolic
   * FALSE value rather than relying on zero-filled memory.
   */

  for (i = 0; i < nn; ++i)
    used[i] = FALSE;

  /**
   * Get h, v and d values.
   */

  for (i = 0; i < nn; ++i) {
    printf("Triple %d: ", i+1);

    if (scanf("%d%d%d", h+i, v+i, d+i) != 3)
      error("bad triple");
  }

  printf("\n");

  /**
   * Print h,v,d arrays
   */

  print_array("h values:", h, nn);
  print_array("v values:", v, nn);
  print_array("d values:", d, nn);

  printf("\n");

  /**
   * Calculate sum and check that it is possible
   * for h and v simultaneously.
   */

  // do preliminary checks only if k > 0
  if (k > 0) {
    for (sum = 0, i = 0; i < nn; ++i)
      sum += h[i];

    if (sum % n != 0)
      printf("Failed to find magic square - no integer sum possible.\n");

    else {
      for (sum_check = 0, i = 0; i < nn; ++i)
        sum_check += v[i];

      if (sum != sum_check)
        printf("Failed to find magic square - "
               "horizontal and vertical sums don't match.\n");

      else {
        // every row must add up to a 1/n share of the h total
        sum /= n;

        /**
         * Search, printing every solution found (at most k), then
         * report how many solutions were found and how many are
         * missing.
         */
        k_found = search(0, k);

        printf("Found %d solutions.\n", k - k_found);
        
        if (k_found != 0)
          printf("Failed to find %d magic squares.\n", k_found);
      }
    }
  }
  else
    printf("Found 0 solutions.\n");

  /**
   * Free memory.
   */
  free(used);
  free(m);
  free(d);
  free(v);
  free(h);

  return EXIT_SUCCESS;
}

/**
 * Depth-first search that fills one matrix cell per recursion level.
 *
 * Each still-unused triple index is tried in the current cell. Every
 * row, column or diagonal completed by that cell is validated right
 * away by check_sum(). Reaching level n^2 therefore means the square
 * is magic, and it is printed with show().
 *
 * @param level flat index of the cell being filled (0 at the start)
 * @param k     number of solutions still wanted
 *
 * @return number of solutions still wanted after exploring this subtree
 */
int
search(int level, int k)
{
  int row, col;                 // 0-based position of the current cell
  int candidate;                // triple index being tried

  if (level == nn) {
    // all cells placed and every sum check passed: a solution
    show();
    return k - 1;
  }

  row = level / n;
  col = level % n;

  for (candidate = 0; candidate < nn && k > 0; ++candidate) {
    if (used[candidate])
      continue;

    m[level] = candidate;

    // prune as soon as a completed line has the wrong sum
    if (check_sum(row, col)) {
      used[candidate] = TRUE;
      k = search(level + 1, k);
      used[candidate] = FALSE;
    }
  }

  return k;
}

/**
 * Validates every line that is completed by the cell just placed at
 * (row, col).
 *
 * The lines checked are: the row, if col is the last column (h values);
 * the column, if row is the last row (v values); the [/] diagonal, if
 * (row, col) is the bottom-left cell; and the [\] diagonal, if it is the
 * bottom-right cell (d values). Each completed line is compared against
 * the global target sum computed by main() before the search starts.
 *
 * @param row current entry row (0-based)
 * @param col current entry column (0-based)
 *
 * @return TRUE - all completed sums match, FALSE - some sum doesn't match
 */
int
check_sum(int row, int col)
{
  int ends_row = (col == n-1);  // cell closes its row
  int ends_col = (row == n-1);  // cell closes its column

  if (ends_row && sum_row(row) != sum)
    return FALSE;

  if (ends_col) {
    if (sum_col(col) != sum)
      return FALSE;

    // bottom-left closes [/], bottom-right closes [\]
    if (col == 0 && sum_slash() != sum)
      return FALSE;
    if (ends_row && sum_backslash() != sum)
      return FALSE;
  }

  return TRUE;
}

/**
 * Adds up the h values of the triples placed in one matrix row.
 *
 * @param row which row (0-based)
 *
 * @return the sum
 */
int
sum_row(int row)
{
  const int *cells = m + row * n;   // first cell of the requested row
  int total = 0;
  int j;

  for (j = 0; j < n; ++j)
    total += h[cells[j]];

  return total;
}

/**
 * Adds up the v values of the triples placed in one matrix column.
 *
 * @param col which column (0-based)
 *
 * @return the sum
 */
int
sum_col(int col)
{
  int total = 0;
  int r;
  int idx = col;                // flat index of (r, col); advances one row per step

  for (r = 0; r < n; ++r, idx += n)
    total += v[m[idx]];

  return total;
}

/**
 * Adds up the d values along the [/] diagonal, walking from the
 * top-right cell (0, n-1) down to the bottom-left cell (n-1, 0).
 *
 * @return the sum
 */
int
sum_slash()
{
  int total = 0;
  int r;
  int idx = n - 1;              // flat index of the top-right cell
  int step = n - 1;             // one row down, one column left

  for (r = 0; r < n; ++r, idx += step)
    total += d[m[idx]];

  return total;
}

/**
 * Adds up the d values along the [\] diagonal, walking from the
 * top-left cell (0, 0) down to the bottom-right cell (n-1, n-1).
 *
 * @return the sum
 */
int
sum_backslash()
{
  int total = 0;
  int r;
  int idx = 0;                  // flat index of the top-left cell
  int step = n + 1;             // one row down, one column right

  for (r = 0; r < n; ++r, idx += step)
    total += d[m[idx]];

  return total;
}

/** 
 * Prints the matrix in tabular form.
 */
void
show()
{
  int i, j;                     // temporaries
  int index = 0;                // flat index into m

  for (i = 0; i < n; ++i) {
    for (j = 0; j < n; ++j, ++index)
      print_triple(h[m[index]], v[m[index]], d[m[index]],
                   (j+1 < n) ? TRIPLE_WIDTH : 0);

    putchar('\n');
  }

  putchar('\n');
}

/**
 * Prints one triple via print_triple_raw(), then pads with spaces up
 * to the requested column width.
 *
 * @param a first component
 * @param b second component
 * @param c third component
 * @param width exact width to pad to (0 - do not add spaces)
 */
void
print_triple(int a, int b, int c, int width)
{
  int pad = width - print_triple_raw(a, b, c);  // spaces still needed

  if (width == 0)
    return;

  while (pad-- > 0)
    putchar(' ');
}

/**
 * Prints a triple as "(a,b,c)", or only "a" when all three
 * components are equal.
 *
 * @param a first component
 * @param b second component
 * @param c third component
 *
 * @return number of characters printed (printf's return value)
 */
int
print_triple_raw(int a, int b, int c)
{
  int uniform = (a == b && b == c);   // collapse to a single number

  return uniform ? printf("%d", a)
                 : printf("(%d,%d,%d)", a, b, c);
}

/**
 * Prints an array on one line in the form
 * header [a1, a2, ..., a_n]
 * An empty (or negative-length) array prints as "header []".
 *
 * @param header text printed before the opening bracket
 * @param arr array to print
 * @param n number of elements to print
 */
void
print_array(const char *header, const int *arr, int n)
{
  const char *sep = "";         // nothing before the first element
  int i;

  printf("%s [", header);

  for (i = 0; i < n; ++i) {
    printf("%s%d", sep, arr[i]);
    sep = ", ";
  }

  printf("]\n");
}


/**
 * Prints "Error: <message>." to stderr and terminates the program
 * with a failure status.
 *
 * @param message description of the problem; printed verbatim
 *                (it is never interpreted as a printf format)
 */
void
error(const char *message)
{
  // pass the message as an argument, never as the format string:
  // a '%' inside it would otherwise be undefined behaviour
  fprintf(stderr, "Error: %s.\n", message);

  exit(EXIT_FAILURE);
}

