如何为PLINQ重新实现Where()运算符

本文关键字:实现 Where 运算符 新实现 PLINQ | 更新日期: 2023-09-27 18:21:12

我想重写我的一个LINQ扩展以利用并行性。不管怎样,我不知道从哪里开始。

作为一个教学示例,我想知道如何重写Where()运算符的实现,但这适用于ParallelQuery

public static ParallelQuery<TSource> Where<TSource>(
   this ParallelQuery<TSource> source, 
   Func<TSource, bool> predicate)
{
    //implementation
}

可以写:

someList.AsParallel().Where(...)

写入串行执行的Where是微不足道的:

public static IEnumerable<TSource> Where<TSource>( 
    this IEnumerable<TSource> source, 
    Func<TSource, bool> predicate) 
{ 
    foreach (TSource item in source) 
    { 
        if (predicate(item)) 
        { 
            yield return item; 
        } 
    } 
}

我想简单地将谓词包装在Parallel.ForEach()上(并将结果推送到List/Array中),但我认为这不是办法。

我不知道写起来是琐碎的(所以它可以作为so答案)还是非常复杂。如果是这样的话,从哪里开始给出一些提示也是很好的。可能有几种方法可以实现这一点,由于特定的优化,它可能会变得非常复杂,但一个简单的可行实现是可以的(这意味着它提供了正确的结果,并且比上面的非多线程实现更快)


正如Scott Chamberlain所建议的,以下是我想重写的LINQ方法的实现:

public static IEnumerable<TSource> WhereContains<TSource, TKey>(
     this IEnumerable<TSource> source, 
     IEnumerable<TKey> values,
     Func<TSource, TKey> keySelector)
{
    HashSet<TKey> elements = new HashSet<TKey>(values);
    foreach (TSource item in source)
    {
        if (elements.Contains(keySelector(item)))
        {
            yield return item;
        }
    }
}

如何为PLINQ重新实现Where()运算符

不幸的是,您无法创建自己的基于ParallelQuery<T>的类,因为虽然ParallelQuery<T>是公共的,但它没有任何公共构造函数。

您可以使用现有的PLINQ基础架构来执行您想要的操作。你真正想做的就是做一个WhereContains是谓词。。。那就这么做吧。

public static ParallelQuery<TSource> WhereContains<TSource, TKey>(
    this ParallelQuery<TSource> source,
    IEnumerable<TKey> values,
    Func<TSource, TKey> keySelector)
{
    HashSet<TKey> elements = new HashSet<TKey>(values);
    return source.Where(item => elements.Contains(keySelector(item)));
}

这将并行执行Where子句,并且(虽然没有文档记录)Contains是线程安全的,只要您不执行任何写操作,并且因为您正在创建一个本地HashSet来执行查找,所以您不需要担心会发生写操作。


这里有一个示例项目,它向控制台打印出它正在处理的线程和项目,您可以看到它正在使用多个线程。

class Program
{
    static void Main(string[] args)
    {
        List<int> items = new List<int>(Enumerable.Range(0,100));
        int[] values = {5, 12, 25, 17, 0};
        Console.WriteLine("thread: {0}", Environment.CurrentManagedThreadId);
        var result = items.AsParallel().WhereContains(values, x=>x).ToList();
        Console.WriteLine("Done");
        Console.ReadLine();
    }
}
static class Extensions
{
    public static ParallelQuery<TSource> WhereContains<TSource, TKey>(
        this ParallelQuery<TSource> source,
        IEnumerable<TKey> values,
        Func<TSource, TKey> keySelector)
    {
        HashSet<TKey> elements = new HashSet<TKey>(values);
        return source.Where(item =>
        {
            Console.WriteLine("item:{0} thread: {1}", item, Environment.CurrentManagedThreadId);
            return elements.Contains(keySelector(item));
        });
    }
}

你不能就这么做吗?

public static ParallelQuery<TSource> Where<TSource>(
    this ParallelQuery<TSource> source, 
    Func<TSource, bool> predicate)
{
    return
        source
            .SelectMany(x =>
                predicate(x)
                ? new TSource[] { x } 
                : Enumerable.Empty<TSource>());
}