
For something I'm working on, I need a function that takes in a one-dimensional array vec of integers and returns a boolean array of the same shape indicating where the n largest entries of vec are located.

For example,

  • nlargest([0, 1, 3, 5], 2) should yield [0, 0, 1, 1].
  • nlargest([0, 0, 4, 0, 3, 6], 2) should yield [0, 0, 1, 0, 0, 1].
  • nlargest([2, 1, 4, 3], 1) should yield [0, 0, 1, 0].
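To check faster versions against these examples, a simple baseline is to sort the whole array and mark the first n indices. The sketch below assumes a one-based Vector, and the name nlargest_sorted is hypothetical. It costs O(m log m) for an array of length m, so it is only meant as a reference:

```julia
# Hypothetical reference implementation: full sort, O(m log m).
function nlargest_sorted(vec, n)
    out = falses(length(vec))            # BitVector of false, same length as vec
    order = sortperm(vec; rev=true)      # indices of vec, largest value first
    out[order[1:n]] .= true              # mark the positions of the n largest values
    return out
end

@assert nlargest_sorted([0, 1, 3, 5], 2) == [0, 0, 1, 1]
@assert nlargest_sorted([0, 0, 4, 0, 3, 6], 2) == [0, 0, 1, 0, 0, 1]
@assert nlargest_sorted([2, 1, 4, 3], 1) == [0, 0, 1, 0]
```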

I have decided it's worthwhile to write my own function to do this, because the input data in this case has a few "nice" properties:

  • n is always greater than zero and strictly smaller than the number of nonzero entries in vec
  • The nonzero entries are all positive integers and there are no ties among them

Here is a working Julia function that does this:

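```julia
function nlargest(vec, n)
    # Boolean vector shaped like input indicating n largest entries
    inp = copy(vec)                  # work on a copy so the caller's array is untouched
    out = zeros(Bool, length(vec))
    for i in 1:n
        am = argmax(inp)             # index of the current largest entry
        out[am] = true
        inp[am] = -1                 # relies on all nonzero entries being positive
    end
    return out
end
```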

I am more interested in the design of the algorithm itself than the specific Julia implementation, so here is equivalent pseudocode:

  • Initialize: out ← an array of boolean 0s shaped like vec
  • Repeat n times:
    • j ← argmax(vec)
    • Set the jth entry of out to 1
    • Set the jth entry of vec to -1 (*)
  • Return out

Step (*) takes advantage of the fact that all the nonzero entries are positive, but it feels like a bit of a hack to me. An alternative idea I had for this step was to "pop" the jth entry out of vec so that argmax(vec) has a shorter input on the next iteration. But it seems to me that any time saved there is offset by having to adjust subsequent j values to account for the new length of vec. Is this reasoning sound, or is there something clever I can do here?
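The "pop" idea can be sketched without any arithmetic on j by carrying a second vector that records the original position of each remaining value. This is a hypothetical illustration (nlargest_pop is not from the post). Note that argmax still scans everything that remains, and deleteat! shifts the elements after the deleted one, so each iteration is still O(m) in the worst case and the total stays O(n * m):

```julia
# Hypothetical "pop" variant: delete each maximum instead of overwriting it with -1.
function nlargest_pop(vec, n)
    vals = copy(vec)                    # working copy that shrinks each iteration
    idx = collect(eachindex(vec))       # idx[k] = original position of vals[k]
    out = falses(length(vec))
    for _ in 1:n
        k = argmax(vals)                # position of the maximum in the shrunken copy
        out[idx[k]] = true              # translate back to the original position
        deleteat!(vals, k)              # remove the entry from both vectors,
        deleteat!(idx, k)               # which shifts the elements after k
    end
    return out
end
```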

Have I taken full advantage of the structure of my input data? Can I do better than the default argmax() function?
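One way to use the no-ties property is to find the value of the n-th largest entry and compare every entry against it. Because the nonzero entries are distinct and 0 < n < (number of nonzero entries), that threshold is a unique positive value, so exactly n entries are at least as large as it. The following is a minimal sketch under those assumptions (nlargest_threshold is a hypothetical name; partialsort is in Base Julia). If there were ties at the threshold, it would mark more than n entries:

```julia
# Hypothetical sketch relying on: no ties among nonzero entries, 0 < n < count of nonzeros.
function nlargest_threshold(vec, n)
    t = partialsort(vec, n; rev=true)   # value of the n-th largest entry, no full sort
    return vec .>= t                    # BitVector marking the n entries >= t
end
```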

Of course, feedback on the Julia implementation is also welcome.

asked Dec 5, 2020 at 13:31

1 Answer


The algorithm you described has complexity O(n * m), where m is the size of the input array, since you run a full argmax at every one of the n iterations.

You can achieve the same thing by collecting the indices of the n largest values once and then setting only those indices to 1 in the result array. To find the n largest values you could use a heap, but Julia makes this easier with partialsortperm:

```julia
function nlargest(v, n; rev=true)
    result = falses(size(v))
    result[partialsortperm(v, 1:n; rev=rev)] .= true
    return result
end
```
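For comparison, the heap route could look like the following sketch. It is hypothetical: nlargest_heap and sift_down! are illustrative names, and the heap is hand-rolled to avoid a package dependency. It keeps a min-heap of the n largest (value, index) pairs seen so far, which gives O(m log n) time and O(n) extra space, and it assumes a one-based Vector with 1 <= n <= length(v):

```julia
# Move h[i] down until neither child is smaller; only h[1:len] is treated as the heap.
function sift_down!(h, i, len)
    while true
        smallest = i
        l, r = 2i, 2i + 1
        if l <= len && h[l][1] < h[smallest][1]
            smallest = l
        end
        if r <= len && h[r][1] < h[smallest][1]
            smallest = r
        end
        smallest == i && return h
        h[i], h[smallest] = h[smallest], h[i]
        i = smallest
    end
end

function nlargest_heap(v, n)
    # Seed the heap with the first n (value, index) pairs, then heapify bottom-up.
    heap = [(v[i], i) for i in 1:n]
    for i in div(n, 2):-1:1
        sift_down!(heap, i, n)
    end
    # Any later entry larger than the heap's minimum replaces that minimum.
    for i in n+1:length(v)
        if v[i] > heap[1][1]
            heap[1] = (v[i], i)
            sift_down!(heap, 1, n)
        end
    end
    # The heap now holds the n largest entries; mark their original positions.
    result = falses(length(v))
    for (_, i) in heap
        result[i] = true
    end
    return result
end
```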

This uses a BitArray for storage; if n << m, even a sparse array might pay off. It should have complexity O(m + n log n), unless I'm misjudging partialsortperm. Note that Julia's sorting functions sort in ascending order by default, so you have to pass rev=true to get the equivalent of your code.
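If n << m, a sparse result could be built from the same indices. This sketch assumes the SparseArrays standard library's sparsevec(I, V, m) constructor; the name nlargest_sparse is hypothetical:

```julia
using SparseArrays

# Hypothetical sparse variant: stores only the n marked positions.
function nlargest_sparse(v, n; rev=true)
    idx = collect(partialsortperm(v, 1:n; rev=rev))   # indices of the n largest values
    return sparsevec(idx, fill(true, n), length(v))  # true at idx, implicit false elsewhere
end
```

The result is a SparseVector{Bool}, which compares equal to the dense result, so it can be used as a drop-in replacement wherever only indexing or iteration is needed.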

Max
answered Dec 19, 2020 at 14:30
  • Once again I am pleasantly surprised to find a Julia builtin for something I had been coding my way around. Thank you! Commented Dec 20, 2020 at 5:50
