\$\begingroup\$

I came across the problem below in a coding challenge: Let's define 2 functions, \$F(i)\$ and \$Val(i,j)\$, for an array \$A\$ of size \$N\$ as follows:

$$ \begin{align} F(i) &= \sum_{j=i+1}^{N} Val(i,j) \\ \\ Val(i, j) &= \begin{cases} 1,\qquad & \textrm{if } A[i] < A[j] \\ 0,\qquad & \textrm{otherwise} \end{cases} \end{align} $$

I need to find the number of distinct unordered pairs of elements (a,b) in array \$A\$ such that \$F(a)+F(b) \ge K\$.

Here is my solution:

#include<iostream>
#include<cassert>
#define max 1000000000
int main()
{
 unsigned int N, K;
 std::cin >> N;
 assert((1 <= N) && (N <= 1000000));
 std::cin >> K;
 assert((0 <= K) && (K <= max));
 unsigned int a[N], F[N];
 for(unsigned int i = 0; i < N; i++)
 {
 std::cin >> a[i];
 assert((1 <= a[i]) && (a[i] <= max));
 }
 int sum = 0, val = 0;
 for(unsigned int i = 0; i < N; i++)
 {
 for(unsigned int j = i+1; j < N; j++)
 {
 if(a[i] < a[j])
 val = 1;
 else 
 val = 0;
 sum = sum + val;
 }
 F[i] = sum;
 sum = 0;
 }
 int count = 0;
 for(unsigned int i = 0; i < N; i++)
 {
 for(unsigned int j = i+1; j < N; j++)
 {
 if(F[i] + F[j] >= K)
 count++;
 }
 }
 std::cout << count << std::endl;
}

While this works, I would like to know if there is a way to write optimized code that is fast even for larger inputs.

200_success
asked Jul 2, 2016 at 19:42
\$\endgroup\$
  • I don't have time to write a review, but you should be able to do this in \$O(n \log n)\$ instead of \$O(n^2)\$ time. Hint: if you sort the initial array, you should be able to compute F in linear time. Hint 2: Once you get F, if you sort F, you should be able to compute the final answer in linear time also. Commented Jul 3, 2016 at 5:56
  • Sorting the initial array would not work since the order of the elements is important while calculating F. After sorting, the order of the elements and hence the value of F would change. Commented Jul 3, 2016 at 17:48
  • I added a review that demonstrates how the problem could be solved in linearithmic time. Commented Jul 4, 2016 at 8:47

2 Answers

\$\begingroup\$

The clues in the question

Before I get started with the actual review, I'd like to give a general tip for solving programming challenges. With most of these challenges, the problem creator has a specific algorithmic solution in mind. This solution is usually not the brute force method, but something better. If you know the clues to look for, you can often deduce what kind of solution they are looking for.

In this problem, the maximum value of \$N\$ is given to be 1000000. The straightforward solution is \$O(n^2)\$ (as given by the OP). Notice that this causes the solution to run for approximately \$10^{12}\$ iterations. Even at a 3 GHz clock rate and 1 iteration per clock, this would take 333 seconds, which is very slow.

If the solution were \$O(n \log n)\$, then it would run for roughly 20 million iterations (\$10^6 \cdot \log_2 10^6\$), which is much more reasonable. So it is likely that this is what the problem creator had in mind. Note that if the solution were meant to be linear, then the maximum \$N\$ would probably be pushed higher (probably by about 100 times) in order to make an \$O(n \log n)\$ solution not viable.

A quick review

Before delving into the problem and solution, I'll give a quick review of the code provided.

  • The code is relatively straightforward and solves the problem in the simplest way.
  • There are some indentation issues with one of the if statements.
  • I would have liked to see a function rather than have everything done in main().
  • The variable length arrays that you use for a and F are dangerous (and are a compiler extension rather than standard C++). In fact, I had to change them to heap allocations to avoid a stack overflow, because I was using very large values for N.
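The last two points can be addressed together. Below is a minimal sketch, assuming the input has already been read into a `std::vector`; the function name `computeFNaive` is made up for illustration. It keeps the same quadratic algorithm but moves it out of main() and uses heap-backed storage instead of variable length arrays:

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper name. Same quadratic algorithm as the original post,
// but with heap-backed std::vector storage instead of stack-allocated VLAs.
std::vector<unsigned int> computeFNaive(const std::vector<unsigned int> &a)
{
    const std::size_t n = a.size();
    std::vector<unsigned int> F(n, 0);
    for (std::size_t i = 0; i < n; i++)
    {
        for (std::size_t j = i + 1; j < n; j++)
        {
            if (a[i] < a[j])
            {
                F[i]++;   // a[j] lies to the right of a[i] and is larger
            }
        }
    }
    return F;
}
```

Because the vectors own their memory, a large N no longer risks overflowing the stack, and the function can be tested separately from the input handling.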

Splitting up the problem

The problem is actually composed of two parts. The first part is to compute the array F, as described by the problem. The second part is to compute the number of pairs in F such that F[i] + F[j] >= K.

The first part

To solve the first part in \$O(n \log n)\$ time, you can use a Binary Indexed Tree. A binary indexed tree can update an array entry and retrieve a cumulative sum in \$O(\log n)\$ time per operation. A cumulative sum is like this:

\$sum(n) = \sum_{i=0}^{n-1} Array[i]\$
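As a rough illustration of that definition, here is a minimal sketch of a binary indexed tree in C++. The class and member names are my own choices rather than anything from the question, and `prefixSum(count)` corresponds to \$sum(n)\$ above:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Minimal binary indexed (Fenwick) tree over n slots, all initially zero.
// The class and member names are illustrative assumptions.
class FenwickTree
{
public:
    explicit FenwickTree(std::size_t n) : tree(n + 1, 0) {}

    // Array[index] += delta, in O(log n).
    void add(std::size_t index, long long delta)
    {
        assert(index + 1 < tree.size());
        // i & (~i + 1) isolates the lowest set bit of i.
        for (std::size_t i = index + 1; i < tree.size(); i += i & (~i + 1))
            tree[i] += delta;
    }

    // Returns Array[0] + ... + Array[count - 1], in O(log n).
    long long prefixSum(std::size_t count) const
    {
        assert(count < tree.size());
        long long sum = 0;
        for (std::size_t i = count; i > 0; i -= i & (~i + 1))
            sum += tree[i];
        return sum;
    }

private:
    std::vector<long long> tree;   // 1-based internal indexing
};
```

Each operation touches at most one tree node per bit of the index, which is where the logarithmic bound comes from.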

Why do you need cumulative sums to compute F? Remember that F[i] contains the number of array elements A[j] (j > i) where A[j] > A[i]. Suppose you had an array of size N where the values were in the range 0..N-1. For example:

1 2 0 2 3

The F array we are trying to compute would be:

3 1 2 1 0

One way of computing F (without a binary indexed tree) would be to start at the end of the array and work backwards. For each array element, you take the value of the element and update a count array for each index lower than the value. For example, the last element of the array has value 3. This means you should do count[j]++ for each j in the range 0..2. As you reach each array location, you can compute F for that location by looking at the count array. For example, when you are at i=1, then array[1] is 2, and F[1] is simply count[2], because every time a number to its right was greater than 2, we added one to count[2]. The algorithm would work like this:

Numbers in parentheses are examined/changed on each step:

arr =  1   2   0   2   3    count =  0   0   0   0   0    F =  0   0   0   0   0
arr =  1   2   0   2  (3)   count = (1) (1) (1)  0   0    F =  0   0   0   0  (0)
arr =  1   2   0  (2)  3    count = (2) (2)  1   0   0    F =  0   0   0  (1)  0
arr =  1   2  (0)  2   3    count =  2   2   1   0   0    F =  0   0  (2)  1   0
arr =  1  (2)  0   2   3    count = (3) (3)  1   0   0    F =  0  (1)  2   1   0
arr = (1)  2   0   2   3    count = (4)  3   1   0   0    F = (3)  1   2   1   0

And the code would look like this:

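for (int i = n-1; i >= 0; i--) {
    int val = array[i];
    F[i] = count[val];          // how many later elements were greater than val
    for (int j = 0; j < val; j++)
        count[j]++;             // val is greater than every j below it
}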

Although the above solution works:

  1. If the maximum value in the array is \$M\$, you need a count array of size \$M\$. There was no limit to the array values in the question description.
  2. The above solution is still \$O(n^2)\$.

And the workarounds for these problems are:

  1. You can "normalize" an array of size N to contain only values 0..N-1 if you sort the array and then replace each array element with its index in the sorted array (after removing duplicate entries). The normalized array has the property that Norm[i] compares the same way to Norm[j] as A[i] compares to A[j].
  2. Instead of using a linear time j loop to add one to multiple array entries, you can use a binary indexed tree to do the same thing in logarithmic time. Retrieving the cumulative sum is also done in logarithmic time.
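The two workarounds can be sketched together as follows. This is only one possible arrangement, and it is a variant rather than the exact scheme described above: instead of adding one to a range of counts, it records each rank as a single point in the tree and asks how many ranks seen so far are not greater than the current one. The function names `normalize` and `computeF` are assumptions for this sketch.

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Workaround 1: replace each value by its rank among the distinct values.
std::vector<std::size_t> normalize(const std::vector<unsigned int> &a)
{
    std::vector<unsigned int> sorted(a);
    std::sort(sorted.begin(), sorted.end());
    sorted.erase(std::unique(sorted.begin(), sorted.end()), sorted.end());

    std::vector<std::size_t> norm(a.size());
    for (std::size_t i = 0; i < a.size(); i++)
    {
        norm[i] = static_cast<std::size_t>(
            std::lower_bound(sorted.begin(), sorted.end(), a[i]) - sorted.begin());
    }
    return norm;
}

// Workaround 2: scan right to left, keeping a binary indexed tree that
// counts how many times each rank has been seen so far.
std::vector<unsigned int> computeF(const std::vector<unsigned int> &a)
{
    const std::size_t n = a.size();
    const std::vector<std::size_t> norm = normalize(a);
    std::vector<unsigned int> seen(n + 1, 0);   // 1-based tree of rank counts
    std::vector<unsigned int> F(n, 0);
    unsigned int total = 0;                     // elements to the right of i

    for (std::size_t i = n; i-- > 0; )
    {
        // Count elements to the right whose rank is <= norm[i].
        unsigned int notGreater = 0;
        for (std::size_t k = norm[i] + 1; k > 0; k -= k & (~k + 1))
            notGreater += seen[k];
        F[i] = total - notGreater;

        // Record rank norm[i] as seen.
        for (std::size_t k = norm[i] + 1; k <= n; k += k & (~k + 1))
            seen[k]++;
        total++;
    }
    return F;
}
```

Since the ranks are always below N, the tree never needs more than N + 1 entries, no matter how large the original values are.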

The second part

Once you have the F array, you must count how many pairs of indices (i,j) satisfy the condition F[i] + F[j] >= K. To do this in \$O(n \log n)\$ time, you first sort the F array. Then you can compute the number of pairs by working backwards through the array with one index (call it i), and working forwards through the array with another index (call it j). At each array index i working backwards, you move j forwards until F[i] + F[j] >= K. At that point, you can add i - j to the count, because F[i] + F[k] >= K for all j <= k < i. Since you never back up either index, and you stop when i and j touch, this part runs in linear time, after sorting the array in linearithmic time.
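The two-index scan described above might look like this in C++. It is a sketch under my own naming (`countPairs`), and it uses a 64-bit counter because the number of pairs can exceed what a 32-bit integer holds when N is large:

```cpp
#include <algorithm>
#include <cstddef>
#include <vector>

// Counts unordered pairs of distinct positions whose F values sum to >= K.
// F is taken by value because the function sorts its own copy.
long long countPairs(std::vector<unsigned int> F, unsigned int K)
{
    std::sort(F.begin(), F.end());
    long long count = 0;
    std::size_t j = 0;
    for (std::size_t i = F.size(); i-- > 1; )
    {
        // Move j forwards to the first partner that is large enough for F[i].
        while (j < i && static_cast<unsigned long long>(F[i]) + F[j] < K)
            j++;
        if (j >= i)
            break;   // the indices touched: no smaller i can have a partner
        // Every k with j <= k < i pairs with i.
        count += static_cast<long long>(i - j);
    }
    return count;
}
```

For the example F of 3 1 2 1 0 and K = 4, the sorted array is 0 1 1 2 3 and the qualifying pairs are (3,1), (3,1) and (3,2), so the function returns 3.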

Putting it all together

The final code is quite a bit bigger than the OP's program, but it runs in linearithmic time. I wrote it in C instead of C++ just because I prefer C. As a comparison, I ran it with N being the maximum of 1000000 and compared to the original program:

JS1 : 0.31 sec
Original: 680.00 sec

And here is the program:

#include <stdio.h>
#include <stdlib.h>

#define SWAP(arr, i, j) \
    do { \
        int tmp; \
        tmp = arr[i]; \
        arr[i] = arr[j]; \
        arr[j] = tmp; \
    } while(0)

static void *my_calloc(size_t nelem, size_t elsize);
static long long countPairs(const int *arr, int N, int K);
static int *createNormalizedArray(const int *arr, int N);
static void my_qsort(int *arr, int *indexArr, int low, int high);
static int *computeF(const int *arr, int N);
static long long computeCount(int *F, int N, int K);

int main(int argc, char *argv[])
{
    int N;
    int K;
    int *arr;

    if (argc < 4) {
        printf("Usage: %s N K seed\n", argv[0]);
        return 0;
    }
    N = atoi(argv[1]);
    K = atoi(argv[2]);
    srand(atoi(argv[3]));
    arr = my_calloc(N, sizeof(*arr));
    for (int i=0;i<N;i++)
        arr[i] = rand();
    printf("Count = %lld\n", countPairs(arr, N, K));
    free(arr);
    return 0;
}

static long long countPairs(const int *arr, int N, int K)
{
    int *normalizedArr;
    int *F;
    long long count;

    normalizedArr = createNormalizedArray(arr, N);
    F = computeF(normalizedArr, N);
    free(normalizedArr);
    count = computeCount(F, N, K);
    free(F);
    return count;
}

/**
 * Given an array of size N, return a "normalized array" where each array
 * element is replaced with its index in a sorted array with duplicates
 * removed.
 *
 * Initial array: 6 4 4 7 2
 * Sorted array : 2 4 6 7 (duplicates removed)
 * Output array : 2 1 1 3 0 (index of each element in sorted array).
 */
static int *createNormalizedArray(const int *arr, int N)
{
    int i;
    int *sortedArr = my_calloc(N, sizeof(*sortedArr));
    int *sortedIndex = my_calloc(N, sizeof(*sortedIndex));
    int *outputArr = my_calloc(N, sizeof(*outputArr));
    int numDuplicates = 0;

    // Sort a copy of the array and the index array together.
    for (i=0;i<N;i++) {
        sortedArr[i] = arr[i];
        sortedIndex[i] = i;
    }
    my_qsort(sortedArr, sortedIndex, 0, N-1);

    // From the sorted indexes, we can create the output array.
    for (i=0;i<N-1;i++) {
        outputArr[sortedIndex[i]] = i - numDuplicates;
        if (sortedArr[i] == sortedArr[i+1])
            numDuplicates++;
    }
    outputArr[sortedIndex[N-1]] = N - 1 - numDuplicates;
    free(sortedArr);
    free(sortedIndex);
    return outputArr;
}

/**
 * Sorts arr in ascending order, while also swapping the same
 * elements in indexArr.
 */
static void my_qsort(int *arr, int *indexArr, int low, int high)
{
    if (low >= high)
        return;

    int pivot = arr[low];
    int i = low - 1;
    int j = high + 1;

    while (1) {
        while (arr[++i] < pivot);
        while (arr[--j] > pivot);
        if (i >= j)
            break;
        SWAP(arr, i, j);
        if (indexArr != NULL)
            SWAP(indexArr, i, j);
    }
    my_qsort(arr, indexArr, low, j);
    my_qsort(arr, indexArr, j+1, high);
}

/**
 * Add val to binary indexed tree at index i.
 */
static inline void BIT_add(int *BIT, unsigned int N, unsigned int i, int val)
{
    i++;
    while (i <= N) {
        BIT[i] += val;
        i += i & -i;
    }
}

/**
 * Get cumulative sum from binary indexed tree for index i.
 */
static inline int BIT_get(int *BIT, unsigned int N, unsigned int i)
{
    int sum = 0;

    i++;
    while (i > 0) {
        sum += BIT[i];
        i -= i & -i;
    }
    return sum;
}

/**
 * Given a normalized array, compute F, where
 *
 * F[i] = sum(j = i+1 .. N) Val(i, j)
 * Val(i, j) = 1 if arr[i] < arr[j]
 *           = 0 otherwise
 *
 * The input array must be a normalized array.
 * This function uses a binary indexed tree to compute cumulative sums.
 */
static int *computeF(const int *arr, int N)
{
    int i;
    int *F = my_calloc(N, sizeof(*F));
    int *BIT = my_calloc(N+1, sizeof(*BIT));

    for (i=N-1;i>=0;i--) {
        F[i] = BIT_get(BIT, N, N-1-arr[i]);
        BIT_add(BIT, N, (N-1-arr[i])+1, 1);
    }
    free(BIT);
    return F;
}

int intCompare(const void *a, const void *b)
{
    int x = *(int *) a;
    int y = *(int *) b;

    if (x < y)
        return -1;
    if (x > y)
        return 1;
    return 0;
}

/**
 * Given the F array and a value K, return the number of distinct unordered
 * pairs (i,j) such that F[i] + F[j] >= K.
 *
 * The count is a long long because the number of pairs can exceed the
 * range of a 32-bit int when N is large.
 */
static long long computeCount(int *F, int N, int K)
{
    int i = 0;
    int j = 0;
    long long count = 0;

    my_qsort(F, NULL, 0, N-1);
    for (i = N-1; i > j; i--) {
        do {
            if (F[i] + F[j] >= K) {
                count += (i - j);
                break;
            }
            j++;
        } while (i > j);
    }
    return count;
}

static void *my_calloc(size_t nelem, size_t elsize)
{
    void *ret = calloc(nelem, elsize);

    if (ret == NULL)
        exit(1);
    return ret;
}
answered Jul 4, 2016 at 8:46
\$\endgroup\$
  • Thanks for the detailed explanation! How do I train myself to come up with something as elegant as your solution, instead of my naive one? Commented Jul 4, 2016 at 19:04
  • After your review I tried to understand BIT; it is useful for getting prefix sums. I have the questions below: 1. How does the normalized array work in lieu of the actual array for the BIT? Commented Jul 9, 2016 at 21:14
  • The function F only cares about relative ordering between elements in A. The normalized array simply replaces each value in A by its index within a sorted A. For example, the smallest element in A becomes 0, and the second smallest becomes 1, etc. For the purposes of finding F, the normalized array is equivalent to the original because all relative orderings were preserved. Once you normalize the array, you can create a BIT of size N to help compute F. Commented Jul 9, 2016 at 22:06
  • 2. Why do we need to start from N-1-arr[i] in the snippet below? for (i=N-1;i>=0;i--) { F[i] = BIT_get(BIT, N, N-1-arr[i]); BIT_add(BIT, N, (N-1-arr[i])+1, 1); } Commented Jul 9, 2016 at 22:14
  • Normally you would do the range modification by doing two BIT operations such as BIT_add(0, 1); BIT_add(x, -1); This adds one to the range 0 .. x-1, because the first operation adds one to the whole array, and the second subtracts one from the range x .. N-1. I optimized the code to do only one BIT operation by reversing the direction of the BIT. So by doing BIT_add(N-x, 1), it adds one to the range N-x .. N-1. The BIT direction is now reversed, so a lookup of BIT[j] in the normal case becomes BIT[N-1-j] in the reversed case. Commented Jul 10, 2016 at 0:18
\$\begingroup\$

Scope

 int sum = 0, val = 0;
 for(unsigned int i = 0; i < N; i++)
 {
 for(unsigned int j = i+1; j < N; j++)
 {
 if(a[i] < a[j])
 val = 1;
 else 
 val = 0;
 sum = sum + val;
 }
 F[i] = sum;
 sum = 0;
 }

This won't affect speed, but it would be more simply written as:

    for (unsigned int i = 0; i < N; i++)
    {
        int sum = 0;
        for (unsigned int j = i+1; j < N; j++)
        {
            int val = 0;
            if (a[i] < a[j])
                val = 1;
            else
                val = 0;
            sum += val;
        }
        F[i] = sum;
    }

Now we only declare variables in the scope in which they are used.

Don't create unnecessary variables

But we can actually do even better

    for (unsigned int i = 0; i < N; i++)
    {
        F[i] = 0;
        for (unsigned int j = i+1; j < N; j++)
        {
            if (a[i] < a[j])
            {
                F[i]++;
            }
        }
    }

We don't actually need val or sum at all.

There is the slight possibility that constantly calculating the memory location of F[i] would add time, but I would expect any decent compiler to only set it once. The rest of the time it will work with a register.

This also saves some time in that it only adds if we're incrementing by 1. If not, this does nothing. That has the same effect as adding 0 but with fewer instructions. Of course, your compiler may have already optimized that out.

Many recommend always using the block form of control structures, even when there's only a single statement.

Don't forget what you know

In this case, you are working with segments going from a point in the array to the end. But you are looping from the beginning to the end. What happens when you loop from the end to the beginning?

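    for (unsigned int i = N; i-- > 0; )
    {
        F[i] = 0;
        for (unsigned int j = i + 1; j < N; j++)
        {
            if (a[i] == a[j])
            {
                // F[j] already counts every later element greater than a[j] == a[i]
                F[i] += F[j];
                break;
            }
            else if (a[i] < a[j])
            {
                F[i]++;
            }
        }
    }

Now when we encounter two equal array entries we can stop looping, because F[j] was computed on an earlier pass of the outer loop and already accounts for everything to the right of j. The inner loop scans forwards from i+1 so that every element between i and j has been counted before the shortcut is taken.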

 int count = 0;
 for(unsigned int i = 0; i < N; i++)
 {
 for(unsigned int j = i+1; j < N; j++)
 {
 if(F[i] + F[j] >= K)
 count++;
 }
 }

Same kind of insight here.

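    // No i at or beyond N - K/2 can be part of a qualifying pair.
    unsigned int n = (K / 2 < N) ? N - K / 2 : 0;
    unsigned long long count = 0;
    for (unsigned int i = 0; i < n; i++)
    {
        long long remainder = (long long)K - F[i];
        if (remainder <= 0)
        {
            // if F[i] >= K, every remaining entry will work
            count += N - i - 1;
            continue;
        }
        if (remainder >= N)
        {
            // F[j] is always less than N, so no j can make up the difference
            continue;
        }
        for (unsigned int j = N - remainder; j > i; j--)
        {
            if (F[j] >= remainder)
            {
                count++;
            }
        }
    }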

By precalculating the remainder, we save some math in the inner loop. And we can save some iterations of the inner loop (if remainder is too high, there is no way to add up to it at the end of the loop). But more importantly, we can check remainder against 0. Because if the remainder is less than or equal to zero, we can immediately say that every remaining entry will work.
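One way to see why the inner loop can skip the tail of the array, assuming 0-based indices as in the code and writing \$r\$ for remainder (my notation, with \$r > 0\$): \$F(j)\$ only counts elements to the right of \$j\$, and there are \$N - 1 - j\$ of those, so

$$ \begin{align} F(j) &\le N - 1 - j \\ F(j) \ge r &\implies N - 1 - j \ge r \implies j \le N - 1 - r \end{align} $$

Any index beyond \$N - 1 - r\$ therefore cannot satisfy \$F(j) \ge r\$, so starting the scan at \$N - r\$ loses nothing.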

answered Jul 2, 2016 at 22:12
\$\endgroup\$
