我怎样才能做得快呢?
当然我可以这样做:
static bool ByteArrayCompare(byte[] a1, byte[] a2)
{
if (a1.Length != a2.Length)
return false;
for (int i=0; i<a1.Length; i++)
if (a1[i]!=a2[i])
return false;
return true;
}
但我正在寻找一个BCL函数或一些高度优化的已证明的方法来做到这一点。
java.util.Arrays.equals((sbyte[])(Array)a1, (sbyte[])(Array)a2);
工作得很好,但这似乎不适用于x64。
注意我的快速回答。
你可以使用Enumerable。SequenceEqual方法。
using System;
using System.Linq;
...
var a1 = new int[] { 1, 2, 3};
var a2 = new int[] { 1, 2, 3};
var a3 = new int[] { 1, 2, 4};
var x = a1.SequenceEqual(a2); // true
var y = a1.SequenceEqual(a3); // false
如果你因为某些原因不能使用. net 3.5,你的方法是可以的。
编译器运行时环境会优化你的循环,所以你不需要担心性能。
编辑:现代的快速方法是使用a1.SequenceEquals(a2)
用户gil提出了不安全的代码,产生了这个解决方案:
// Copyright (c) 2008-2013 Hafthor Stefansson
// Distributed under the MIT/X11 software license
// Ref: http://www.opensource.org/licenses/mit-license.php.
static unsafe bool UnsafeCompare(byte[] a1, byte[] a2) {
unchecked {
if(a1==a2) return true;
if(a1==null || a2==null || a1.Length!=a2.Length)
return false;
fixed (byte* p1=a1, p2=a2) {
byte* x1=p1, x2=p2;
int l = a1.Length;
for (int i=0; i < l/8; i++, x1+=8, x2+=8)
if (*((long*)x1) != *((long*)x2)) return false;
if ((l & 4)!=0) { if (*((int*)x1)!=*((int*)x2)) return false; x1+=4; x2+=4; }
if ((l & 2)!=0) { if (*((short*)x1)!=*((short*)x2)) return false; x1+=2; x2+=2; }
if ((l & 1)!=0) if (*((byte*)x1) != *((byte*)x2)) return false;
return true;
}
}
}
它对尽可能多的数组进行基于64位的比较。这依赖于数组以qword对齐开始的事实。它会工作,如果不是qword对齐,只是没有那么快,如果它是。
它比简单的“for”循环快了大约7个计时器。使用j#库执行相当于原来的' for '循环。使用.SequenceEqual会慢7倍左右;我想只是因为它使用了ienumerator。movenext。我认为基于linq的解决方案至少会这么慢,甚至更糟。
这与其他方法类似,但这里的不同之处在于,不存在我可以一次检查的下一个最高字节数,例如,如果我有63个字节(在我的SIMD示例中),我可以检查前32个字节的相等性,然后是后32个字节,这比检查32个字节、16个字节、8个字节等等要快。您输入的第一个检查是比较所有字节所需要的唯一检查。
这确实在我的测试中名列前茅,但仅以微弱之差。
下面的代码正是我在airbreather/ArrayComparePerf.cs中测试它的方式。
public unsafe bool SIMDNoFallThrough() #requires System.Runtime.Intrinsics.X86
{
if (a1 == null || a2 == null)
return false;
int length0 = a1.Length;
if (length0 != a2.Length) return false;
fixed (byte* b00 = a1, b01 = a2)
{
byte* b0 = b00, b1 = b01, last0 = b0 + length0, last1 = b1 + length0, last32 = last0 - 31;
if (length0 > 31)
{
while (b0 < last32)
{
if (Avx2.MoveMask(Avx2.CompareEqual(Avx.LoadVector256(b0), Avx.LoadVector256(b1))) != -1)
return false;
b0 += 32;
b1 += 32;
}
return Avx2.MoveMask(Avx2.CompareEqual(Avx.LoadVector256(last0 - 32), Avx.LoadVector256(last1 - 32))) == -1;
}
if (length0 > 15)
{
if (Sse2.MoveMask(Sse2.CompareEqual(Sse2.LoadVector128(b0), Sse2.LoadVector128(b1))) != 65535)
return false;
return Sse2.MoveMask(Sse2.CompareEqual(Sse2.LoadVector128(last0 - 16), Sse2.LoadVector128(last1 - 16))) == 65535;
}
if (length0 > 7)
{
if (*(ulong*)b0 != *(ulong*)b1)
return false;
return *(ulong*)(last0 - 8) == *(ulong*)(last1 - 8);
}
if (length0 > 3)
{
if (*(uint*)b0 != *(uint*)b1)
return false;
return *(uint*)(last0 - 4) == *(uint*)(last1 - 4);
}
if (length0 > 1)
{
if (*(ushort*)b0 != *(ushort*)b1)
return false;
return *(ushort*)(last0 - 2) == *(ushort*)(last1 - 2);
}
return *b0 == *b1;
}
}
如果没有首选的SIMD,与现有的longpointer算法相同的方法:
public unsafe bool LongPointersNoFallThrough()
{
if (a1 == null || a2 == null || a1.Length != a2.Length)
return false;
fixed (byte* p1 = a1, p2 = a2)
{
byte* x1 = p1, x2 = p2;
int l = a1.Length;
if ((l & 8) != 0)
{
for (int i = 0; i < l / 8; i++, x1 += 8, x2 += 8)
if (*(long*)x1 != *(long*)x2) return false;
return *(long*)(x1 + (l - 8)) == *(long*)(x2 + (l - 8));
}
if ((l & 4) != 0)
{
if (*(int*)x1 != *(int*)x2) return false; x1 += 4; x2 += 4;
return *(int*)(x1 + (l - 4)) == *(int*)(x2 + (l - 4));
}
if ((l & 2) != 0)
{
if (*(short*)x1 != *(short*)x2) return false; x1 += 2; x2 += 2;
return *(short*)(x1 + (l - 2)) == *(short*)(x2 + (l - 2));
}
return *x1 == *x2;
}
}
我发布了一个类似的关于检查byte[]是否全是0的问题。(SIMD代码被打败了,所以我从这个答案中删除了它。)下面是我比较过的最快的代码:
static unsafe bool EqualBytesLongUnrolled (byte[] data1, byte[] data2)
{
if (data1 == data2)
return true;
if (data1.Length != data2.Length)
return false;
fixed (byte* bytes1 = data1, bytes2 = data2) {
int len = data1.Length;
int rem = len % (sizeof(long) * 16);
long* b1 = (long*)bytes1;
long* b2 = (long*)bytes2;
long* e1 = (long*)(bytes1 + len - rem);
while (b1 < e1) {
if (*(b1) != *(b2) || *(b1 + 1) != *(b2 + 1) ||
*(b1 + 2) != *(b2 + 2) || *(b1 + 3) != *(b2 + 3) ||
*(b1 + 4) != *(b2 + 4) || *(b1 + 5) != *(b2 + 5) ||
*(b1 + 6) != *(b2 + 6) || *(b1 + 7) != *(b2 + 7) ||
*(b1 + 8) != *(b2 + 8) || *(b1 + 9) != *(b2 + 9) ||
*(b1 + 10) != *(b2 + 10) || *(b1 + 11) != *(b2 + 11) ||
*(b1 + 12) != *(b2 + 12) || *(b1 + 13) != *(b2 + 13) ||
*(b1 + 14) != *(b2 + 14) || *(b1 + 15) != *(b2 + 15))
return false;
b1 += 16;
b2 += 16;
}
for (int i = 0; i < rem; i++)
if (data1 [len - 1 - i] != data2 [len - 1 - i])
return false;
return true;
}
}
测量两个256MB字节数组:
UnsafeCompare : 86,8784 ms
EqualBytesSimd : 71,5125 ms
EqualBytesSimdUnrolled : 73,1917 ms
EqualBytesLongUnrolled : 39,8623 ms
找不到一个我完全满意的解决方案(合理的性能,但没有不安全的代码/pinvoke),所以我想出了这个,没有真正的原创,但工作:
/// <summary>
///
/// </summary>
/// <param name="array1"></param>
/// <param name="array2"></param>
/// <param name="bytesToCompare"> 0 means compare entire arrays</param>
/// <returns></returns>
public static bool ArraysEqual(byte[] array1, byte[] array2, int bytesToCompare = 0)
{
if (array1.Length != array2.Length) return false;
var length = (bytesToCompare == 0) ? array1.Length : bytesToCompare;
var tailIdx = length - length % sizeof(Int64);
//check in 8 byte chunks
for (var i = 0; i < tailIdx; i += sizeof(Int64))
{
if (BitConverter.ToInt64(array1, i) != BitConverter.ToInt64(array2, i)) return false;
}
//check the remainder of the array, always shorter than 8 bytes
for (var i = tailIdx; i < length; i++)
{
if (array1[i] != array2[i]) return false;
}
return true;
}
与本页上的其他解决方案相比,性能:
简单循环:19837滴答,1.00
*位收敛器:4886 ticks, 4.06
unsafcompare: 1636 ticks, 12.12
EqualBytesLongUnrolled: 637 tick, 31.09
P/Invoke memcmp: 369 ticks, 53.67
在linqpad上测试,1000000字节的相同数组(最坏的情况),每个数组500次迭代。