Purpose
ObservableCollection<T> is not strictly a WPF class, but its intended purpose seems to be for use in WPF. It's the standard implementation of INotifyCollectionChanged, an interface which is special-cased by some WPF widgets to provide efficient UI updates and to maintain state effectively. For example, if a ListView is showing an INotifyCollectionChanged then it updates its SelectedIndex automatically when the underlying collection changes.
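To make the notifications concrete, here is a minimal console sketch of what ObservableCollection<T> reports for each kind of change. It uses only the standard library; the class name CollectionChangedDemo and its Run method are made up for illustration and are not part of the code under review.

```csharp
using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Collections.Specialized;

// Hypothetical demo class: logs the action and indices of every change notification.
public static class CollectionChangedDemo
{
    public static List<string> Run()
    {
        var log = new List<string>();
        var items = new ObservableCollection<string>();
        items.CollectionChanged += (sender, e) =>
            log.Add($"{e.Action} old={e.OldStartingIndex} new={e.NewStartingIndex}");

        items.Add("a");     // Add: only NewStartingIndex is meaningful
        items.Add("b");
        items.Move(0, 1);   // Move: both indices are set
        items[0] = "c";     // Replace: both indices are set to the replaced position
        items.RemoveAt(1);  // Remove: only OldStartingIndex is meaningful
        items.Clear();      // Reset: neither index is set
        return log;
    }
}
```

Calling CollectionChangedDemo.Run() and printing the returned strings shows one entry per mutation; an index which doesn't apply to the action is reported as -1.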
I wrote the following class as part of maintenance of an application which was exposing various IEnumerable<T> in its view-model and using the much less powerful INotifyPropertyChanged to notify the ListView of changes to the model. It had to do manual updates of SelectedIndex, and this was a source of bugs.
CollectionView pretty much supports Where filters, but it doesn't support Select and is messy to chain. It might be possible to rewrite the application to use CollectionView, but that would be a more major change than a plug-in replacement which mimics the LINQ queries used in the old code to map the model to the view-model. With this class I can replace List in the model with ObservableCollection, replace LINQ Select in the view-model (to map model classes to view-model classes) with SelectObservable, and remove some PropertyChanged event dispatches and manual SelectedIndex tracking.
Code
using System;
using System.Collections;
using System.Collections.Generic;
using System.Collections.Specialized;
using System.Linq;
using System.Numerics;
using System.Windows;
namespace Org.Cheddarmonk.Utils
{
// The .Net standard library should have included some interface like this. ObservableCollection<T> "implements" it.
public interface IObservableEnumerable<T> : IReadOnlyList<T>, INotifyCollectionChanged
{
}
public static class ObservableEnumerable
{
public static IObservableEnumerable<TResult> SelectObservable<TSource, TResult, TCollection>(this TCollection collection, Func<TSource, TResult> selector)
where TCollection : IReadOnlyList<TSource>, INotifyCollectionChanged
{
if (collection == null) throw new ArgumentNullException(nameof(collection));
if (selector == null) throw new ArgumentNullException(nameof(selector));
return new ObservableSelectIterator<TSource, TResult>(collection, selector);
}
public static IObservableEnumerable<TElement> WhereObservable<TElement, TCollection>(this TCollection collection, Func<TElement, bool> predicate)
where TCollection : IReadOnlyList<TElement>, INotifyCollectionChanged
{
if (collection == null) throw new ArgumentNullException(nameof(collection));
if (predicate == null) throw new ArgumentNullException(nameof(predicate));
return new ObservableWhereIterator<TElement>(collection, predicate);
}
public static IObservableEnumerable<TCast> OfTypeObservable<TSource, TCast, TCollection>(this TCollection collection)
where TCollection : IReadOnlyList<TSource>, INotifyCollectionChanged
{
if (collection == null) throw new ArgumentNullException(nameof(collection));
return WhereObservable<TSource, TCollection>(collection, elt => elt is TCast).
SelectObservable<TSource, TCast, IObservableEnumerable<TSource>>(elt => (TCast)(object)elt);
}
private class ObservableSelectIterator<TSource, TResult> : IObservableEnumerable<TResult>
{
private readonly INotifyCollectionChanged source;
private readonly List<TResult> results;
private readonly Func<TSource, TResult> selector;
internal ObservableSelectIterator(IReadOnlyList<TSource> wrapped, Func<TSource, TResult> selector)
{
source = (INotifyCollectionChanged)wrapped; // Just to keep a hard reference around, lest an intermediate object in a chain get GC'd
this.results = wrapped.Select(selector).ToList();
this.selector = selector;
WeakEventManager<INotifyCollectionChanged, NotifyCollectionChangedEventArgs>.AddHandler(
(INotifyCollectionChanged)wrapped,
nameof(INotifyCollectionChanged.CollectionChanged),
(sender, evt) =>
{
var mangled = Mangle(evt);
CollectionChanged?.Invoke(this, mangled);
});
}
public int Count => results.Count;
public TResult this[int index] => results[index];
public IEnumerator<TResult> GetEnumerator() => results.GetEnumerator();
IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
#region Event handler
public event NotifyCollectionChangedEventHandler CollectionChanged;
private NotifyCollectionChangedEventArgs Mangle(NotifyCollectionChangedEventArgs evt)
{
int oldIndex = evt.OldStartingIndex;
int newIndex = evt.NewStartingIndex;
TResult[] removedItems = null;
if (evt.OldItems != null)
{
removedItems = new TResult[evt.OldItems.Count];
results.CopyTo(oldIndex, removedItems, 0, evt.OldItems.Count);
}
TResult[] addedItems = evt.NewItems != null && evt.Action != NotifyCollectionChangedAction.Move ? evt.NewItems.Cast<TSource>().Select(selector).ToArray() : null;
// Unfortunately, as with so many WPF-related classes in the standard library, the useful constructor is internal-only.
switch (evt.Action)
{
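case NotifyCollectionChangedAction.Reset:
// Reset does not guarantee that the source is now empty, so rebuild the cache from it.
results.Clear();
results.AddRange(((IEnumerable<TSource>)source).Select(selector));
return evt;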
case NotifyCollectionChangedAction.Add:
results.InsertRange(newIndex, addedItems);
return new NotifyCollectionChangedEventArgs(evt.Action, addedItems, newIndex);
case NotifyCollectionChangedAction.Remove:
results.RemoveRange(oldIndex, evt.OldItems.Count);
return new NotifyCollectionChangedEventArgs(evt.Action, removedItems, oldIndex);
case NotifyCollectionChangedAction.Replace:
results.RemoveRange(oldIndex, evt.OldItems.Count);
results.InsertRange(newIndex, addedItems);
return new NotifyCollectionChangedEventArgs(evt.Action, addedItems, removedItems, newIndex);
case NotifyCollectionChangedAction.Move:
results.RemoveRange(oldIndex, evt.OldItems.Count);
results.InsertRange(newIndex, removedItems);
return new NotifyCollectionChangedEventArgs(evt.Action, removedItems, newIndex, oldIndex);
default:
throw new NotImplementedException();
}
}
#endregion
}
private class ObservableWhereIterator<TElement> : IObservableEnumerable<TElement>
{
private readonly IReadOnlyList<TElement> wrapped;
private readonly Func<TElement, bool> predicate;
// For reasonably efficient lookups we cache the indices of the elements which meet the predicate.
private BigInteger indices;
internal ObservableWhereIterator(IReadOnlyList<TElement> wrapped, Func<TElement, bool> predicate)
{
this.wrapped = wrapped;
this.predicate = predicate;
indices = _Index(wrapped);
WeakEventManager<INotifyCollectionChanged, NotifyCollectionChangedEventArgs>.AddHandler(
(INotifyCollectionChanged)wrapped,
nameof(INotifyCollectionChanged.CollectionChanged),
(sender, evt) =>
{
var mangled = Mangle(evt);
if (mangled != null) CollectionChanged?.Invoke(this, mangled);
});
}
private BigInteger _Index(IEnumerable elts) => elts.Cast<TElement>().Aggregate((BigInteger.Zero, BigInteger.One), (accum, elt) => (accum.Item1 + (predicate(elt) ? accum.Item2 : 0), accum.Item2 << 1)).Item1;
public int Count => indices.PopCount();
public TElement this[int index]
{
get
{
if (index < 0) throw new IndexOutOfRangeException($"Index {index} is invalid");
// We need to find the index in wrapped at which we have (index + 1) elements which meet the predicate.
// For maximum efficiency we would have to rewrite to use a tree structure instead of BigInteger, but
// I'm not convinced that it's worthwhile.
int toSkip = index + 1;
int wrappedIndex = 0;
foreach (var b in indices.ToByteArray())
{
int sliceCount = b.PopCount();
if (sliceCount < toSkip)
{
toSkip -= sliceCount;
wrappedIndex += 8;
}
else
{
for (byte slice = b; ; wrappedIndex++, slice >>= 1)
{
if ((slice & 1) == 1)
{
toSkip--;
if (toSkip == 0) return wrapped[wrappedIndex];
}
}
}
}
throw new IndexOutOfRangeException($"Index {index} is invalid; Count = {index + 1 - toSkip}");
}
}
public IEnumerator<TElement> GetEnumerator() => wrapped.Where(predicate).GetEnumerator();
IEnumerator IEnumerable.GetEnumerator() => GetEnumerator();
#region Event handler
public event NotifyCollectionChangedEventHandler CollectionChanged;
private NotifyCollectionChangedEventArgs Mangle(NotifyCollectionChangedEventArgs evt)
{
IList liftItems(IList items) => items?.Cast<TElement>().Where(predicate).ToArray();
var newItems = liftItems(evt.NewItems);
var oldItems = liftItems(evt.OldItems);
var newMask = (BigInteger.One << evt.NewStartingIndex) - 1;
var oldMask = (BigInteger.One << evt.OldStartingIndex) - 1;
var newStartingIndex = (indices & newMask).PopCount();
var oldStartingIndex = (indices & oldMask).PopCount();
switch (evt.Action)
{
case NotifyCollectionChangedAction.Reset:
// Reset does not guarantee that the source is now empty, so re-index it.
indices = _Index(wrapped);
return evt;
case NotifyCollectionChangedAction.Add:
indices = ((indices & ~newMask) << evt.NewItems.Count) | (_Index(evt.NewItems) << evt.NewStartingIndex) | (indices & newMask);
return newItems.Count > 0 ? new NotifyCollectionChangedEventArgs(evt.Action, newItems, newStartingIndex) : null;
case NotifyCollectionChangedAction.Remove:
indices = ((indices >> evt.OldItems.Count) & ~oldMask) | (indices & oldMask);
return oldItems.Count > 0 ? new NotifyCollectionChangedEventArgs(evt.Action, oldItems, oldStartingIndex) : null;
case NotifyCollectionChangedAction.Replace:
indices = (((indices >> evt.OldItems.Count) & ~newMask) << evt.NewItems.Count) |
(_Index(evt.NewItems) << evt.NewStartingIndex) |
(indices & newMask);
if (oldItems.Count > 0)
{
if (newItems.Count > 0) return new NotifyCollectionChangedEventArgs(evt.Action, newItems, oldItems, newStartingIndex);
return new NotifyCollectionChangedEventArgs(NotifyCollectionChangedAction.Remove, oldItems, oldStartingIndex);
}
if (newItems.Count > 0)
{
return new NotifyCollectionChangedEventArgs(NotifyCollectionChangedAction.Add, newItems, newStartingIndex);
}
return null;
case NotifyCollectionChangedAction.Move:
// Update indices in two steps, for the removal and then the insertion.
var movedIndices = (indices >> evt.OldStartingIndex) & ((BigInteger.One << evt.OldItems.Count) - 1);
indices = ((indices >> evt.OldItems.Count) & ~oldMask) | (indices & oldMask);
// NewStartingIndex refers to the list after the removal, so count the preceding matches now rather than before the removal.
var movedTo = (indices & newMask).PopCount();
indices = ((indices & ~newMask) << evt.NewItems.Count) | (movedIndices << evt.NewStartingIndex) | (indices & newMask);
return oldItems.Count > 0 ? new NotifyCollectionChangedEventArgs(evt.Action, oldItems, movedTo, oldStartingIndex) : null;
default:
throw new NotImplementedException();
}
}
#endregion
}
}
}
There is a dependency on some bit-twiddling code:
/// <summary>Population count: how many bits are 1?</summary>
public static int PopCount(this byte v)
{
int x = v - ((v >> 1) & 0x55);
x = (x & 0x33) + ((x >> 2) & 0x33);
return (x + (x >> 4)) & 0x0f;
}
/// <summary>Population count: how many bits are 1?</summary>
public static int PopCount(this uint v)
{
v = v - ((v >> 1) & 0x55555555);
v = (v & 0x33333333) + ((v >> 2) & 0x33333333);
v = (v + (v >> 4) & 0x0f0f0f0f) * 0x01010101;
return (int)v >> 24;
}
/// <summary>Population count: how many bits differ from the sign bit?</summary>
public static int PopCount(this BigInteger n)
{
uint invert = (uint)(n.Sign >> 1);
ReadOnlySpan<byte> rawBytes = n.ToByteArray();
var rawUints = System.Runtime.InteropServices.MemoryMarshal.Cast<byte, uint>(rawBytes);
// 4 bytes to a uint.
System.Diagnostics.Debug.Assert(rawUints.Length == rawBytes.Length >> 2);
int popCount = 0;
foreach (var u in rawUints) popCount += PopCount(u ^ invert);
for (int off = rawUints.Length << 2; off < rawBytes.Length; off++) popCount += PopCount((rawBytes[off] ^ invert) & 0xffu);
return popCount;
}
Tests
using NUnit.Framework;
using System;
using System.Collections.Generic;
using System.Collections.ObjectModel;
using System.Collections.Specialized;
using System.Linq;
namespace Org.Cheddarmonk.Utils.Tests
{
[TestFixture]
public class TestObservableEnumerable
{
[Test]
public void ValidateTracker()
{
// This is to ensure that the tracker we use for testing the main classes isn't itself buggy.
var raw = new ObservableCollection<int>();
var tracker = new ObservableTestTracker<int>(raw);
for (int i = 0; i < 5; i++)
{
raw.Add(i);
tracker.AssertTrackingCorrect();
}
// [0, 1, 2, 3, 4]
raw.RemoveAt(2);
tracker.AssertTrackingCorrect();
// [0, 1, 3, 4]
raw.Move(2, 0);
tracker.AssertTrackingCorrect();
// [3, 0, 1, 4]
raw.Move(0, 2);
tracker.AssertTrackingCorrect();
// [0, 1, 3, 4]
raw[3] = 5;
tracker.AssertTrackingCorrect();
// [0, 1, 3, 5]
Assert.IsTrue(new int[] { 0, 1, 3, 5 }.SequenceEqual(raw));
raw.Clear();
tracker.AssertTrackingCorrect();
}
[Test]
public void TestSelect()
{
var raw = new ObservableCollection<int>();
var select = raw.SelectObservable<int, int, ObservableCollection<int>>(x => 3 * x + 1);
var tracker = new ObservableTestTracker<int>(select);
for (int i = 0; i < 5; i++)
{
raw.Add(i);
tracker.AssertTrackingCorrect();
}
// [0, 1, 2, 3, 4] => [1, 4, 7, 10, 13]
raw.RemoveAt(2);
tracker.AssertTrackingCorrect();
// [0, 1, 3, 4] => [1, 4, 10, 13]
raw.Move(2, 0);
tracker.AssertTrackingCorrect();
// [3, 0, 1, 4] => [10, 1, 4, 13]
raw.Move(0, 2);
tracker.AssertTrackingCorrect();
// [0, 1, 3, 4] => [1, 4, 10, 13]
raw[3] = 5;
tracker.AssertTrackingCorrect();
// [0, 1, 3, 5] => [1, 4, 10, 16]
Assert.IsTrue(new int[] { 0, 1, 3, 5 }.SequenceEqual(raw));
Assert.IsTrue(new int[] { 1, 4, 10, 16 }.SequenceEqual(select));
raw.Clear();
tracker.AssertTrackingCorrect();
}
[Test]
public void TestWhere()
{
var raw = new ObservableCollection<int>();
var where = raw.WhereObservable<int, ObservableCollection<int>>(x => (x & 1) == 0);
var tracker = new ObservableTestTracker<int>(where);
for (int i = 0; i < 5; i++)
{
raw.Add(i);
tracker.AssertTrackingCorrect();
}
// [0, 1, 2, 3, 4] => [0, 2, 4]
raw.RemoveAt(2);
tracker.AssertTrackingCorrect();
// [0, 1, 3, 4] => [0, 4]
raw.Move(2, 0);
tracker.AssertTrackingCorrect();
// [3, 0, 1, 4] => [0, 4]
raw.Move(0, 2);
tracker.AssertTrackingCorrect();
// [0, 1, 3, 4] => [0, 4]
raw[3] = 5;
tracker.AssertTrackingCorrect();
// [0, 1, 3, 5] => [0]
raw[3] = 1;
tracker.AssertTrackingCorrect();
// [0, 1, 3, 1] => [0]
raw[2] = 6;
tracker.AssertTrackingCorrect();
// [0, 1, 6, 1] => [0, 6]
raw[2] = 4;
tracker.AssertTrackingCorrect();
// [0, 1, 4, 1] => [0, 4]
Assert.IsTrue(new int[] { 0, 1, 4, 1 }.SequenceEqual(raw));
Assert.IsTrue(new int[] { 0, 4 }.SequenceEqual(where));
raw.Clear();
tracker.AssertTrackingCorrect();
}
}
class ObservableTestTracker<T>
{
private readonly IReadOnlyList<T> source;
private readonly IList<T> changeTracker;
internal ObservableTestTracker(IReadOnlyList<T> source)
{
this.source = source;
this.changeTracker = new ObservableCollection<T>(source);
(source as INotifyCollectionChanged).CollectionChanged += source_CollectionChanged;
}
private void source_CollectionChanged(object sender, NotifyCollectionChangedEventArgs e)
{
switch (e.Action)
{
case NotifyCollectionChangedAction.Reset:
changeTracker.Clear();
break;
case NotifyCollectionChangedAction.Add:
int i = e.NewStartingIndex;
foreach (T obj in e.NewItems) changeTracker.Insert(i++, obj);
break;
case NotifyCollectionChangedAction.Remove:
foreach (T obj in e.OldItems)
{
Assert.AreEqual(obj, changeTracker[e.OldStartingIndex]);
changeTracker.RemoveAt(e.OldStartingIndex);
}
break;
case NotifyCollectionChangedAction.Replace:
case NotifyCollectionChangedAction.Move:
// This is a remove followed by an add
foreach (T obj in e.OldItems)
{
Assert.AreEqual(obj, changeTracker[e.OldStartingIndex]);
changeTracker.RemoveAt(e.OldStartingIndex);
}
int j = e.NewStartingIndex;
foreach (T obj in e.NewItems) changeTracker.Insert(j++, obj);
break;
default:
throw new NotImplementedException();
}
}
public void AssertTrackingCorrect()
{
// Direct comparison as IEnumerable<T>.
Assert.IsTrue(source.SequenceEqual(changeTracker));
// Assert that the elements returned by source[int] correspond to the elements returned by source.GetEnumerator().
{
var byIndex = new List<T>();
for (int i = 0; i < changeTracker.Count; i++) byIndex.Add(source[i]);
// Assert that we can't get an extra item.
try
{
byIndex.Add(source[changeTracker.Count]);
Assert.Fail("Expected IndexOutOfRangeException or ArgumentOutOfRangeException");
}
catch (ArgumentOutOfRangeException)
{
// This is what's specified in the MSDN for IList<T>. IReadOnlyList<T> doesn't document any exceptions at all.
}
catch (IndexOutOfRangeException)
{
// This makes more sense, and is what the documentation for IndexOutOfRangeException claims should be thrown.
}
catch (Exception ex)
{
Assert.Fail($"Expected IndexOutOfRangeException or ArgumentOutOfRangeException, caught {ex}");
}
Assert.IsTrue(byIndex.SequenceEqual(changeTracker));
}
}
}
}
1 Answer
Move events in ObservableSelectIterator could be expensive if elements are shifted around near the start of the list, as RemoveRange and InsertRange will each end up shifting the whole collection. There might be a significant performance advantage to manually moving only the elements 'in between' the start and end indices, if that is at all a concern.
Similarly with Replace, only a single removal or insertion is needed, which would halve the amount of copying to be done. If the old and new lists are the same length, then there is no need to move any of the other elements at all: this in particular is important, because it covers index assignment, which everyone expects to be constant-time.
I figure the same is true of ObservableWhereIterator, but I shan't pretend I've fully reviewed the bit-twiddling in its Mangle method.
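To illustrate the Move and Replace suggestions for ObservableSelectIterator, here is a rough sketch of in-place helpers for a List<T>. The names ListMoveSketch, MoveRange and ReplaceRange are hypothetical, and the sketch assumes the same index convention as the Move event raised by ObservableCollection (the new index refers to the list after the removal).

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical helpers; not part of the reviewed code.
public static class ListMoveSketch
{
    // Moves `count` items from oldIndex so that they start at newIndex,
    // touching only the moved items and the elements between the two positions.
    public static void MoveRange<T>(List<T> list, int oldIndex, int newIndex, int count)
    {
        if (oldIndex == newIndex || count == 0) return;
        var moved = new T[count];
        list.CopyTo(oldIndex, moved, 0, count);
        if (oldIndex < newIndex)
        {
            // Shift the in-between elements towards the start.
            for (int i = oldIndex; i < newIndex; i++) list[i] = list[i + count];
        }
        else
        {
            // Shift the in-between elements towards the end, working backwards.
            for (int i = oldIndex - 1; i >= newIndex; i--) list[i + count] = list[i];
        }
        for (int i = 0; i < count; i++) list[newIndex + i] = moved[i];
    }

    // Overwrites the overlapping part in place, then performs at most one
    // RemoveRange or InsertRange for the difference in length.
    public static void ReplaceRange<T>(List<T> list, int index, int oldCount, IReadOnlyList<T> added)
    {
        int common = Math.Min(oldCount, added.Count);
        for (int i = 0; i < common; i++) list[index + i] = added[i];
        if (oldCount > common) list.RemoveRange(index + common, oldCount - common);
        else if (added.Count > common) list.InsertRange(index + common, added.Skip(common));
    }
}
```

With equal lengths, ReplaceRange does no shifting at all, which is the index-assignment case. Whether either helper is worth the extra code depends on list sizes; for a handful of items the original RemoveRange/InsertRange pair is probably fine.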
Giving Item1 and Item2 half-meaningful names in _Index wouldn't hurt.
private BigInteger _Index(IEnumerable elts) => elts.Cast<TElement>().Aggregate((indices : BigInteger.Zero, bit : BigInteger.One), (accum, elt) => (accum.indices + (predicate(elt) ? accum.bit : 0), accum.bit << 1)).indices;
Should you be checking for -1 indices in the NotifyCollectionChangedEventArgs? Presumably the constructors which don't take an index are only provided to support unordered collections, but I can't work out if there are any meaningful guarantees.
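If you did want to guard against it, one option is a check at the top of each Mangle method. This is only a sketch: the class ChangeArgsGuard and the method RequireIndices are hypothetical names, and throwing is just one possible policy (falling back to a Reset notification would be another).

```csharp
using System;
using System.Collections.Specialized;

// Hypothetical guard; not part of the reviewed code.
public static class ChangeArgsGuard
{
    // Throws if an index which the given action relies on was left at -1.
    public static void RequireIndices(NotifyCollectionChangedEventArgs evt)
    {
        bool needsNew = evt.Action == NotifyCollectionChangedAction.Add
            || evt.Action == NotifyCollectionChangedAction.Replace
            || evt.Action == NotifyCollectionChangedAction.Move;
        bool needsOld = evt.Action == NotifyCollectionChangedAction.Remove
            || evt.Action == NotifyCollectionChangedAction.Replace
            || evt.Action == NotifyCollectionChangedAction.Move;
        if ((needsNew && evt.NewStartingIndex < 0) || (needsOld && evt.OldStartingIndex < 0))
        {
            throw new NotSupportedException($"{evt.Action} notification without an index");
        }
    }
}
```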
Fair observations. I'm not too concerned about performance at present: my use case has on the order of ten or twenty items. The last question is a good one: in practice, AIUI, WPF never calls that constructor. - Peter Taylor, Mar 25, 2019 at 11:02
[...] var? Outch! The one-liner loops :-]
[...] ObservableCollection exactly?
[...] Select or Where to ObservableCollection, but instead of just getting an IEnumerable you get an INotifyCollectionChanged. The standard library sort-of provides Where with CollectionView, but doesn't (AFAIK) provide Select.