I have a real use case from a fish farm, where the growth a farm gains from a feeding depends on the average size of its fish at the time they are fed. I have reduced the problem to what I believe is the core of what I am unable to express in PostgreSQL: an aggregate function containing a condition that depends on the value the same aggregate produced for the previous row.
The data operated on is a series of transactions.
create table transactions (
id bigserial primary key,
feed_g bigint
);
insert into transactions
(feed_g)
values
(50),
(50),
(50),
(50);
Calculating a sum over these rows is simple.
select
id,
feed_g,
sum(feed_g) over (order by id) as simple_sum
from transactions;
-- id | feed_g | simple_sum
-- ----+--------+------------
-- 1 | 50 | 50
-- 2 | 50 | 100
-- 3 | 50 | 150
-- 4 | 50 | 200
Calculating a sum with a conditional that depends on the value of the input row is also simple. In the query below the else branch is always used, since every feed_g is 50.
select
id,
feed_g,
sum(
case when feed_g > 75 then feed_g
else feed_g * 0.5
end
) over (order by id) as row_weighted_sum
from transactions;
-- id | feed_g | row_weighted_sum
-- ----+--------+------------------
-- 1 | 50 | 25.0
-- 2 | 50 | 50.0
-- 3 | 50 | 75.0
-- 4 | 50 | 100.0
What I cannot figure out how to do is to write a query where the conditional in the aggregate function depends on the output calculated by the same aggregate function for the previous row.
Below is some non-working pseudo-SQL for that.
select
id,
feed_g,
sum(
case when lag(recursive_sum) + feed_g > 75 then feed_g
else feed_g * 0.5
end
) over (order by id) as recursive_sum
from transactions;
-- The imagined output would be the following:
-- id | feed_g | recursive_sum
-- ----+--------+------------------
-- 1 | 50 | 25.0
-- 2 | 50 | 50.0
-- 3 | 50 | 100.0
-- 4 | 50 | 150.0
Using simple_sum as the input to recursive_sum does not seem like a viable solution, as the two will drift apart over time. In this small example dataset the drift already shows on row 2: simple_sum crosses the threshold there, when the crossing should not occur until row 3.
with estimate as (
select
id,
feed_g,
sum(feed_g) over (order by id) as simple_sum
from transactions
)
select
id,
feed_g,
simple_sum,
sum(
case when simple_sum > 75 then feed_g
else feed_g * 0.5
end
) over (order by id) as simple_sum_weighted_sum
from estimate;
-- id | feed_g | simple_sum | simple_sum_weighted_sum
-- ----+--------+------------+-------------------------
-- 1 | 50 | 50 | 25.0
-- 2 | 50 | 100 | 75.0
-- 3 | 50 | 150 | 125.0
-- 4 | 50 | 200 | 175.0
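To make the source of the drift concrete, here is a small Python sketch that applies the threshold test both ways to the same four rows: once against the plain running sum (what the CTE above does) and once against the weighted running sum itself (the intended behavior). The function names and the THRESHOLD constant are made up for illustration.

```python
from itertools import accumulate

THRESHOLD = 75
feed = [50, 50, 50, 50]

def weighted_by_simple_sum(values):
    """Mimics the CTE attempt: the condition tests the plain running sum."""
    simple_sums = accumulate(values)
    weights = [v if s > THRESHOLD else v * 0.5
               for v, s in zip(values, simple_sums)]
    return list(accumulate(weights))

def weighted_by_own_sum(values):
    """Intended behavior: the condition tests the weighted running sum."""
    total = 0
    result = []
    for v in values:
        total += v if total + v > THRESHOLD else v * 0.5
        result.append(total)
    return result

print(weighted_by_simple_sum(feed))  # [25.0, 75.0, 125.0, 175.0]
print(weighted_by_own_sum(feed))     # [25.0, 50.0, 100.0, 150.0]
```

The two lists diverge from row 2 onward, because the plain sum has already passed 75 while the weighted sum has only reached 50.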
A third step that uses simple_sum_weighted_sum as input in a call to lag does not work either, as it "forgets" the weighting of everything but the last row.
with estimate as (
select
id,
feed_g,
sum(feed_g) over (order by id) as simple_sum
from transactions
),
est2 as (
select
id,
feed_g,
simple_sum,
sum(
case when simple_sum > 75 then feed_g
else feed_g * 0.5
end
) over (order by id) as simple_sum_weighted_sum
from estimate)
select
id,
feed_g,
simple_sum,
simple_sum_weighted_sum,
coalesce(lag(simple_sum_weighted_sum) over (order by id), 0)
+ case when simple_sum_weighted_sum > 75 then feed_g
else feed_g * 0.5
end as row_weighted_sum
from est2;
-- id | feed_g | simple_sum | simple_sum_weighted_sum | row_weighted_sum
-- ----+--------+------------+-------------------------+------------------
-- 1 | 50 | 50 | 25.0 | 25.0
-- 2 | 50 | 100 | 75.0 | 50.0
-- 3 | 50 | 150 | 125.0 | 125.0
-- 4 | 50 | 200 | 175.0 | 175.0
I wrote two working implementations of the algorithm in Python for reference. The first is in an imperative style.
data = (50, 50, 50, 50)
sum = 0
for value in data:
    if sum + value > 75:
        sum = sum + value
    else:
        sum = sum + value * 0.5
    print(value, sum)
# 50 25.0
# 50 50.0
# 50 100.0
# 50 150.0
The second is in a somewhat stunted functional style.
data = (50, 50, 50, 50)

def data_dependant_recursive_sum(iterator, last_sum):
    try:
        value = next(iterator)
    except StopIteration:
        return
    recursively_weighted_value = value if last_sum + value > 75 else value * 0.5
    recursive_sum = recursively_weighted_value + last_sum
    print(value, recursive_sum)
    data_dependant_recursive_sum(iterator, recursive_sum)

data_dependant_recursive_sum(iter(data), 0)
# 50 25.0
# 50 50.0
# 50 100.0
# 50 150.0
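Both versions are really a scan (a running fold) in which each step sees only the previous result and the current value. As a minimal sketch of that view, assuming Python 3.8+ for the initial argument, the same recurrence can be written with itertools.accumulate; the step function and its threshold parameter are names invented here.

```python
from itertools import accumulate

def step(last_sum, value, threshold=75):
    """One step of the recurrence: weight the value based on the previous sum."""
    weighted = value if last_sum + value > threshold else value * 0.5
    return last_sum + weighted

data = (50, 50, 50, 50)

# accumulate emits the initial 0 first, so drop it to get one result per row.
running = list(accumulate(data, step, initial=0))[1:]

for value, total in zip(data, running):
    print(value, total)
# 50 25.0
# 50 50.0
# 50 100.0
# 50 150.0
```

Unlike the recursive function, this form does not grow the call stack, so it is not bound by Python's recursion limit on long inputs.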
If this exercise feels contrived and nonsensical, a much more complicated but complete version of this question can be found here: https://stackoverflow.com/questions/70158295
I am currently using Postgres 12 but an upgrade to 14 would be easy if that is required.
3 Answers
This requires a recursive CTE. Here is an example in T-SQL (PostgreSQL should be similar):
declare @transactions as table (
id integer primary key identity(1,1),
feed_g integer
);
insert into @transactions
values (50), (50), (50), (50);
with indexed_transactions as (
select *, row_number() over (order by id) as rn
from @transactions
),
cte as (
select cast(0 as bigint) as rn, 0 as id, 0 as feed_g, cast(0.0 as float) as row_weighted_sum
union all
select
a.rn,
a.id,
a.feed_g,
case when cte.row_weighted_sum + a.feed_g > 75 then cte.row_weighted_sum + a.feed_g
else cte.row_weighted_sum + a.feed_g * 0.5 end as row_weighted_sum
from indexed_transactions a
join cte on cte.rn = a.rn - 1
)
select * from cte where id > 0
Results:
rn id feed_g row_weighted_sum
1 1 50 25
2 2 50 50
3 3 50 100
4 4 50 150
For posterity, here is Isak's answer translated to PostgreSQL.
with recursive indexed_transactions as (
select *, row_number() over (order by id)
from transactions
),
cte as (
select 0::bigint as row_number, 0::bigint as id, 0::bigint as feed_g, 0::float as row_weighted_sum
union all
select
a.row_number,
a.id,
a.feed_g,
case when cte.row_weighted_sum + a.feed_g > 75
then cte.row_weighted_sum + a.feed_g
else cte.row_weighted_sum + a.feed_g * 0.5
end as row_weighted_sum
from indexed_transactions a
join cte on cte.row_number = a.row_number - 1
)
select * from cte where id > 0;
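For readers without a PostgreSQL instance at hand, the same recursive-CTE shape can be tried from Python using the standard-library sqlite3 module. This is only a sketch under assumptions: it runs on SQLite rather than PostgreSQL, so the `::` casts are replaced by plain literals, the row number column is aliased as rn, and the bundled SQLite must be 3.25 or newer for window functions.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute(
    "create table transactions (id integer primary key, feed_g integer)"
)
conn.executemany(
    "insert into transactions (feed_g) values (?)",
    [(50,), (50,), (50,), (50,)],
)

# Same structure as the PostgreSQL query: number the rows, then walk them
# one at a time, carrying the weighted sum forward from the previous row.
query = """
with recursive indexed_transactions as (
    select id, feed_g, row_number() over (order by id) as rn
    from transactions
),
cte as (
    select 0 as rn, 0 as id, 0 as feed_g, 0.0 as row_weighted_sum
    union all
    select
        a.rn,
        a.id,
        a.feed_g,
        case when cte.row_weighted_sum + a.feed_g > 75
            then cte.row_weighted_sum + a.feed_g
            else cte.row_weighted_sum + a.feed_g * 0.5
        end
    from indexed_transactions a
    join cte on cte.rn = a.rn - 1
)
select id, feed_g, row_weighted_sum from cte where id > 0 order by id
"""

rows = conn.execute(query).fetchall()
for row in rows:
    print(row)
```

Joining on the row number rather than on id keeps the walk intact even if the ids have gaps.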
If you have a working implementation in Python, you could simply convert it into a PL/Python function in PostgreSQL. While it should be possible to come up with a pure SQL solution, the task is effectively a procedural one, so a procedural solution might be the best fit.
Calculate the SUM() window function first, then calculate your adjusted recursive_sum with the LAG() window function second.