I am learning about Tasks and also writing PowerShell cmdlets in C#. I found that connecting to remote machines was very slow, so I wrote this bit of code to speed it up using Tasks. I am hoping to get some feedback regarding the parallel processing and any other pointers in general.
The code is committed to a GitHub repo.
TaskCmdlet.cs
using System;
using System.Collections.Generic;
using System.Linq;
using System.Management.Automation;
using System.Threading;
using System.Threading.Tasks;
namespace PoshTasks.Cmdlets
{
public abstract class TaskCmdlet<TIn, TOut> : Cmdlet where TIn : class
where TOut : class
{
#region Parameters
[Parameter(ValueFromPipeline = true)]
public TIn[] InputObject { get; set; }
#endregion
#region Abstract methods
/// <summary>
/// Performs an action on <paramref name="input"/>
/// </summary>
/// <param name="input">The <see cref="object"/> to be processed; null if not processing input</param>
/// <returns>A <typeparamref name="TOut"/></returns>
protected abstract TOut ProcessTask(TIn input = null);
#endregion
#region Virtual methods
/// <summary>
/// Generates a collection of tasks to be processed
/// </summary>
/// <returns>A collection of tasks</returns>
protected virtual IEnumerable<Task<TOut>> GenerateTasks()
{
List<Task<TOut>> tasks = new List<Task<TOut>>();
if (InputObject != null)
foreach (TIn input in InputObject)
tasks.Add(Task.Run(() => ProcessTask(input)));
else
tasks.Add(Task.Run(() => ProcessTask()));
return tasks;
}
/// <summary>
/// Performs the pipeline output for this cmdlet
/// </summary>
/// <param name="result"></param>
protected virtual void PostProcessTask(TOut result)
{
WriteObject(result, true);
}
#endregion
#region Processing
/// <summary>
/// Processes cmdlet operation
/// </summary>
protected override void ProcessRecord()
{
IEnumerable<Task<TOut>> tasks = GenerateTasks();
foreach (Task<Task<TOut>> bucket in Interleaved(tasks))
{
try
{
Task<TOut> task = bucket.Result;
TOut result = task.Result;
PostProcessTask(result);
}
catch (Exception e) when (e is PipelineStoppedException || e is PipelineClosedException)
{
// do nothing if pipeline stops
}
catch (Exception e)
{
WriteError(new ErrorRecord(e, e.GetType().Name, ErrorCategory.NotSpecified, this));
}
}
}
/// <summary>
/// Interleaves the tasks
/// </summary>
/// <param name="tasks">The collection of <see cref="Task{TOut}"/></param>
/// <returns>An array of task tasks</returns>
protected Task<Task<TOut>>[] Interleaved(IEnumerable<Task<TOut>> tasks)
{
TaskCompletionSource<Task<TOut>>[] buckets = new TaskCompletionSource<Task<TOut>>[tasks.Count()];
Task<Task<TOut>>[] results = new Task<Task<TOut>>[buckets.Length];
for (int i = 0; i < buckets.Length; i++)
{
buckets[i] = new TaskCompletionSource<Task<TOut>>();
results[i] = buckets[i].Task;
}
int nextTaskIndex = -1;
foreach (Task<TOut> task in tasks)
task.ContinueWith(completed =>
{
TaskCompletionSource<Task<TOut>> bucket = buckets[Interlocked.Increment(ref nextTaskIndex)];
bucket.TrySetResult(completed);
},
CancellationToken.None,
TaskContinuationOptions.None,
TaskScheduler.Default);
return results;
}
#endregion
}
}
GetRemoteService.cs (a sample implementation)
using PoshTasks.Cmdlets;
using System.Collections.Generic;
using System.Linq;
using System.Management.Automation;
using System.ServiceProcess;
namespace PoshTasks.Sample
{
[Cmdlet(VerbsCommon.Get, "RemoteService")]
public class GetRemoteService : TaskCmdlet<string, ServiceController[]>
{
#region Parameters
/// <summary>
/// Gets or sets the collection of requested service names
/// </summary>
[Parameter]
public string[] Name { get; set; }
#endregion
#region Processing
/// <summary>
/// Processes a single remote service lookup
/// </summary>
/// <param name="server">The remote machine name</param>
/// <returns>A collection of <see cref="ServiceController"/>s from the remote machine</returns>
protected override ServiceController[] ProcessTask(string server)
{
ServiceController[] services = ServiceController.GetServices(server);
if (Name != null)
return services.Where(s => Name.Contains(s.DisplayName)).ToArray();
return services;
}
/// <summary>
/// Generates custom service object and outputs to pipeline
/// </summary>
/// <param name="result">The collection of remote services</param>
protected override void PostProcessTask(ServiceController[] result)
{
List<dynamic> services = new List<dynamic>();
foreach (ServiceController service in result)
services.Add(new
{
Name = service.DisplayName,
Status = service.Status,
ComputerName = service.MachineName,
CanPause = service.CanPauseAndContinue
});
WriteObject(services, true);
}
#endregion
}
}
If you are not cloning the repo, you will need the Microsoft.PowerShell.5.ReferenceAssemblies NuGet package and a reference to System.ServiceProcess.
1 Answer
General conventions
The good old #region. Most people (including me) consider them more bad than good. You should avoid them.
if (InputObject != null)
    foreach (TIn input in InputObject)
        tasks.Add(Task.Run(() => ProcessTask(input)));
else
    tasks.Add(Task.Run(() => ProcessTask()));
There's not even a single curly brace {} ;-) I wouldn't complain if it was Python, but in C# you should always use them. Omitting them can cause a real headache.
protected virtual IEnumerable<Task<TOut>> GenerateTasks()
{
    List<Task<TOut>> tasks = new List<Task<TOut>>();
    if (InputObject != null)
        foreach (TIn input in InputObject)
            tasks.Add(Task.Run(() => ProcessTask(input)));
    else
        tasks.Add(Task.Run(() => ProcessTask()));
    return tasks;
}
In cases like this you can use yield return, which greatly simplifies the code:
protected virtual IEnumerable<Task<TOut>> CreateProcessTasks()
{
if (InputObject == null)
{
yield return Task.Run(() => ProcessTask());
yield break;
}
foreach (var input in InputObject)
{
yield return Task.Run(() => ProcessTask(input));
}
}
I think this method doesn't need to be virtual. Generating tasks is not something you'd like to implement in each derived class. Consider changing its name to CreateProcessTasks, as this is what it does. Generate sounds like it would create some random tasks.
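One caveat with the iterator version: a method built on yield return is evaluated lazily, so no task is started until the sequence is enumerated, and every new enumeration starts a fresh set of tasks. The original Interleaved method calls tasks.Count() and then loops over tasks, which would enumerate the sequence twice. The following is a minimal sketch of that effect, not the cmdlet code itself; CreateTasks, the counter, and the sample inputs are hypothetical stand-ins.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

public static class LazyIteratorSketch
{
    // Counts how many tasks have been started (for illustration only).
    private static int _started;

    // Hypothetical stand-in for CreateProcessTasks: one task per input.
    private static IEnumerable<Task<string>> CreateTasks(string[] inputs)
    {
        foreach (var input in inputs)
        {
            Interlocked.Increment(ref _started);
            yield return Task.Run(() => input.ToUpperInvariant());
        }
    }

    // Enumerates the task sequence twice and reports how many tasks were started.
    public static int StartedAfterTwoEnumerations(bool materialize)
    {
        _started = 0;
        IEnumerable<Task<string>> tasks = CreateTasks(new[] { "a", "b", "c" });
        if (materialize)
        {
            tasks = tasks.ToList(); // runs the iterator once and caches the tasks
        }

        int count = tasks.Count();     // first enumeration (free for a List)
        Task.WaitAll(tasks.ToArray()); // second enumeration
        return _started;
    }

    public static void Main()
    {
        Console.WriteLine(StartedAfterTwoEnumerations(materialize: false)); // 6
        Console.WriteLine(StartedAfterTwoEnumerations(materialize: true));  // 3
    }
}
```

Calling ToList() or ToArray() once, right after creating the sequence, is usually enough to avoid starting the work twice.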
async/await
In order for async/await to work you need to actually await something, but I couldn't find that in your code. Let's try to fix that and introduce a few other changes that make your code look better.
I start with the Interleaved method... and you actually don't need it. If it is acceptable to wait for all tasks before writing any output, everything it does can be reduced to a single line:
var results = await Task.WhenAll(tasks.ToArray());
Where do I put it? I move this one to the ProcessRecordCore method, which after this adjustment now looks like this:
protected override void ProcessRecord()
{
var errorRecords = Task.Run(async () => await ProcessRecordCore()).Result;
foreach (var errorRecord in errorRecords)
{
WriteError(errorRecord);
}
}
// BlockingCollection<T> requires: using System.Collections.Concurrent;
private async Task<BlockingCollection<ErrorRecord>> ProcessRecordCore()
{
var tasks = CreateProcessTasks();
var results = await Task.WhenAll(tasks.ToArray());
var errorRecords = new BlockingCollection<ErrorRecord>();
foreach (var result in results)
{
try
{
PostProcessTask(result);
}
catch (Exception e) when (e is PipelineStoppedException || e is PipelineClosedException)
{
// do nothing if pipeline stops
}
catch (Exception e)
{
errorRecords.Add(new ErrorRecord(e, e.GetType().Name, ErrorCategory.NotSpecified, this));
}
}
return errorRecords;
}
Notice that ProcessRecordCore is now marked as async so that it can await the tasks, while ProcessRecord blocks on it with .Result and then writes the collected errors itself.
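One thing to watch with Task.WhenAll: awaiting it rethrows the first exception if any task faulted, so a single failing input would skip the result loop entirely, whereas the original loop reported failures per task. The following is a minimal sketch of one way to keep per-task error reporting, under the assumption that failures should be collected rather than abort the run; CollectAsync, Lookup, and the server names are hypothetical.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;

public static class WhenAllSketch
{
    // Waits for every task, then sorts the outcomes into results and errors.
    public static async Task<Tuple<List<T>, List<Exception>>> CollectAsync<T>(IEnumerable<Task<T>> tasks)
    {
        var started = tasks.ToList(); // materialize so each task is started only once
        try
        {
            await Task.WhenAll(started);
        }
        catch
        {
            // WhenAll surfaces only the first failure here; each task is inspected below instead.
        }

        var results = new List<T>();
        var errors = new List<Exception>();
        foreach (var task in started)
        {
            if (task.IsFaulted)
            {
                errors.Add(task.Exception.GetBaseException());
            }
            else if (!task.IsCanceled)
            {
                results.Add(task.Result);
            }
        }
        return Tuple.Create(results, errors);
    }

    // Hypothetical workload: fails for one particular server name.
    private static string Lookup(string server)
    {
        if (server == "server2")
        {
            throw new InvalidOperationException(server + " unreachable");
        }
        return server + ": ok";
    }

    public static void Main()
    {
        var tasks = new[] { "server1", "server2", "server3" }
            .Select(server => Task.Run(() => Lookup(server)));

        var outcome = CollectAsync(tasks).GetAwaiter().GetResult();
        Console.WriteLine(outcome.Item1.Count); // 2
        Console.WriteLine(outcome.Item2.Count); // 1
    }
}
```

In the cmdlet, the collected exceptions could then be turned into ErrorRecord instances on the calling thread.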
There are two more methods that can be simplified. The first one is the ProcessTask method, where you can use the ?: ternary operator and don't need the if.
protected override ServiceController[] ProcessTask(string server)
{
var services = ServiceController.GetServices(server);
return
Name == null
? services
: services.Where(s => Name.Contains(s.DisplayName)).ToArray();
}
or you can go crazy and make it a one-liner:
return services.Where(s => Name == null || Name.Contains(s.DisplayName)).ToArray();
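A side note on the filter itself: Name.Contains on an array is a linear, case-sensitive search performed for every service. PowerShell users often expect case-insensitive matching, so a HashSet with StringComparer.OrdinalIgnoreCase may be a better fit. The sketch below uses plain strings instead of ServiceController so that it runs anywhere; FilterByName and the sample names are hypothetical.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class NameFilterSketch
{
    // Returns every candidate when no names are requested,
    // otherwise only the candidates that match a requested name (ignoring case).
    public static string[] FilterByName(IEnumerable<string> displayNames, string[] requested)
    {
        if (requested == null)
        {
            return displayNames.ToArray();
        }

        var wanted = new HashSet<string>(requested, StringComparer.OrdinalIgnoreCase);
        return displayNames.Where(name => wanted.Contains(name)).ToArray();
    }

    public static void Main()
    {
        var all = new[] { "Print Spooler", "Windows Update", "DNS Client" };
        var matches = FilterByName(all, new[] { "print spooler", "DNS Client" });
        Console.WriteLine(string.Join(", ", matches)); // prints "Print Spooler, DNS Client"
    }
}
```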
The other one is the PostProcessTask method, which could use some vars (like the rest of the code):
protected override void PostProcessTask(ServiceController[] result)
{
var services = new List<dynamic>();
foreach (var service in result)
{
services.Add(new
{
Name = service.DisplayName,
Status = service.Status,
ComputerName = service.MachineName,
CanPause = service.CanPauseAndContinue
});
}
WriteObject(services, true);
}
IProgress interface
To see the errors right away you may try another approach with IProgress<ErrorRecord>. Here's an example:
protected override void ProcessRecord()
{
var progress = new Progress<ErrorRecord>(errorRecord =>
{
WriteError(errorRecord);
});
// Block until processing has finished; otherwise ProcessRecord returns while tasks are still running.
Task.Run(() => ProcessRecordCore(progress)).Wait();
}
private async Task ProcessRecordCore(IProgress<ErrorRecord> progress)
{
var tasks = CreateProcessTasks();
var results = await Task.WhenAll(tasks.ToArray());
foreach (var result in results)
{
try
{
PostProcessTask(result);
}
catch (Exception e) when (e is PipelineStoppedException || e is PipelineClosedException)
{
// do nothing if pipeline stops
}
catch (Exception e)
{
progress.Report(new ErrorRecord(e, e.GetType().Name, ErrorCategory.NotSpecified, this));
}
}
}
One important aspect of this class is that it invokes ProgressChanged (and the Action) in the context in which it was constructed.
See Reporting Progress from Async Tasks for more information.
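A caveat about that context: Progress<T> captures SynchronizationContext.Current when it is constructed, and if there is none, its callbacks are invoked on the thread pool. The sketch below assumes the constructing thread has no synchronization context (which may well be the case for a cmdlet's thread, although that is not verified here) and checks which thread the callback runs on. ProgressContextSketch is a hypothetical name.

```csharp
using System;
using System.Threading;

public static class ProgressContextSketch
{
    // Returns true when the Progress<T> callback ran on the thread that created it.
    public static bool CallbackRunsOnCreatingThread()
    {
        var previous = SynchronizationContext.Current;
        try
        {
            // Simulate a thread without a synchronization context (assumption, see above).
            SynchronizationContext.SetSynchronizationContext(null);

            int creatingThread = Thread.CurrentThread.ManagedThreadId;
            int callbackThread = -1;
            using (var done = new ManualResetEventSlim())
            {
                var progress = new Progress<string>(message =>
                {
                    callbackThread = Thread.CurrentThread.ManagedThreadId;
                    done.Set();
                });

                // Report is implemented explicitly, so it is called through the interface.
                ((IProgress<string>)progress).Report("hello");
                done.Wait(); // the callback is posted asynchronously
            }
            return callbackThread == creatingThread;
        }
        finally
        {
            SynchronizationContext.SetSynchronizationContext(previous);
        }
    }

    public static void Main()
    {
        Console.WriteLine(CallbackRunsOnCreatingThread()); // False
    }
}
```

So without a context that marshals back to the original thread, a WriteError call inside the callback would still run on a pool thread.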
- First of all, thank you for taking the time to review this. I'm not sure how review responses work here, so I am commenting. All your style points are valid, it's much cleaner code. I set CreateProcessTasks() to be virtual in case there is some pre-processing needed to be done by inheriting classes. In one example there are more PowerShell Parameters defined and that affects the behaviour of the tasks themselves. I suppose this could be moved to the ProcessTask() function - I will have a think about this. – hsimah, Nov 16, 2016 at 22:53
- I have an interesting problem using await/async. A PowerShell cmdlet can only call WriteObject() and WriteError() on the parent thread. So, I found I couldn't use await in ProcessRecordCore() as a different thread might be executing the yielded object. I fully appreciate the improvement you made to my code, but this is a limitation I can't seem to figure a way around. If you have any suggestions, please do let me know. I found processing the tasks in parallel improved the performance significantly, so the delay in output is similar to existing Microsoft cmdlets and acceptable. – hsimah, Nov 16, 2016 at 22:59
- @hsimah commenting is perfectly fine ;-) and you're absolutely right. I should have looked into the documentation. I've updated the code and changed the pattern. Now the WriteError is properly called from the ProcessRecord where it's allowed to. – t3chb0t, Nov 17, 2016 at 4:07
- @hsimah I've added one more example with immediate progress (here error) reporting. – t3chb0t, Nov 17, 2016 at 4:18
- When I use the new code it's still executing on a different thread in the pool. The Progress code is very cool. What I need is something similar to System.Windows.Threading.Dispatcher. I found a sample which uses events to write progress, but I feel this might be a bit much for what I am doing. The Microsoft cmdlet for Get-Service does it all synchronously and only outputs at the end, so together we've already outdone them. – hsimah, Nov 17, 2016 at 23:58
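One possible way around the threading restriction discussed in these comments, without a dispatcher, is a producer/consumer queue: the worker tasks only enqueue their outcome, and the calling thread drains the queue, so every write happens on the thread that called in. The following is a minimal console sketch under that assumption, not a tested cmdlet; ProcessAll and its write and writeError delegates are hypothetical stand-ins for WriteObject and WriteError.

```csharp
using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;

public static class MainThreadOutputSketch
{
    // Runs 'work' for each input in parallel, but invokes 'write' and 'writeError'
    // only on the calling thread, in completion order.
    public static void ProcessAll<TIn, TOut>(
        IEnumerable<TIn> inputs,
        Func<TIn, TOut> work,
        Action<TOut> write,
        Action<Exception> writeError)
    {
        using (var queue = new BlockingCollection<Action>())
        {
            Task[] workers = inputs.Select(input => Task.Run(() =>
            {
                try
                {
                    TOut result = work(input);
                    queue.Add(() => write(result)); // deferred: executed by the consumer
                }
                catch (Exception e)
                {
                    queue.Add(() => writeError(e));
                }
            })).ToArray();

            // Close the queue once every worker has finished so the loop below ends.
            Task closer = Task.WhenAll(workers)
                .ContinueWith(_ => queue.CompleteAdding(), TaskScheduler.Default);

            foreach (Action output in queue.GetConsumingEnumerable())
            {
                output(); // runs on the calling thread
            }
            closer.Wait();
        }
    }

    public static void Main()
    {
        int caller = Thread.CurrentThread.ManagedThreadId;
        ProcessAll(
            new[] { "server1", "server2", "server3" },
            name => name.ToUpperInvariant(),
            result => Console.WriteLine("{0} on calling thread: {1}",
                result, Thread.CurrentThread.ManagedThreadId == caller),
            error => Console.WriteLine("error: " + error.Message));
    }
}
```

Results are written as soon as each task finishes, which keeps the streaming behaviour of the original Interleaved approach.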