我如何检查看看一个列是否存在于一个SqlDataReader对象?在我的数据访问层,我创建了一个为多个存储过程调用构建相同对象的方法。其中一个存储过程具有其他存储过程不使用的附加列。我想修改方法以适应各种情况。

我的应用程序是用c#编写的。


当前回答

Use:

if (dr.GetSchemaTable().Columns.Contains("accounttype"))
   do something
else
   do something

在循环中它可能没有那么有效。

其他回答

在一行中,在DataReader检索后使用:

var fieldNames = Enumerable.Range(0, dr.FieldCount).Select(i => dr.GetName(i)).ToArray();

然后,

if (fieldNames.Contains("myField"))
{
    var myFieldValue = dr["myField"];
    ...

Edit

更高效的单行程序,不需要加载模式:

var exists = Enumerable.Range(0, dr.FieldCount).Any(i => string.Equals(dr.GetName(i), fieldName, StringComparison.OrdinalIgnoreCase));

正确的代码是:

public static bool HasColumn(DbDataReader Reader, string ColumnName) { 
    foreach (DataRow row in Reader.GetSchemaTable().Rows) { 
        if (row["ColumnName"].ToString() == ColumnName) 
            return true; 
    } //Still here? Column not found. 
    return false; 
}

这是一个相当老的帖子,但我想提供我的意见。

大多数提议的解决方案的挑战在于,它要求您每次都为检查的每一行和每一列枚举所有字段。

其他的则使用GetSchemaTable方法,该方法不受全局支持。

就我个人而言,我对抛出和捕获异常以检查字段是否存在没有问题。事实上,我认为从编程的角度来看,这可能是最直接的解决方案,也是最容易调试和创建扩展的解决方案。我注意到吞咽异常不会对性能造成负面影响,除非涉及到其他事务或奇怪的回滚逻辑。

使用try-catch块实现

using System;
using System.Collections.Generic;
using System.Data.SqlClient;

public class MyModel {
    public int ID { get; set; }
    public int UnknownColumn { get; set; }
}


public IEnumerable<MyModel> ReadData(SqlCommand command) {
    using (SqlDataReader reader = command.ExecuteReader()) {
        try {
            while (reader.Read()) {
                // init the row
                MyModel row = new MyModel();

                // bind the fields
                row.ID = reader.IfDBNull("ID", row.ID);
                row.UnknownColumn = reader.IfDBNull("UnknownColumn", row.UnknownColumn);

                // return the row and move forward
                yield return row;
            }
        } finally {
            // technically the disposer should handle this for you
            if (!reader.IsClosed) reader.Close();
        }
    }
}

// I use a variant of this class everywhere I go to help simplify data binding
public static class IDataReaderExtensions {
    // clearly separate name to ensure I don't accidentally use the wrong method
    public static T IfDBNull<T>(this IDataReader reader, string name, T defaultValue) {
        T value;
        try {
            // attempt to read the value
            // will throw IndexOutOfRangeException if not available
            object objValue = reader[name];

            // the value returned from SQL is NULL
            if (Convert.IsDBNull(objValue)) {
                // use the default value
                objValue = defaultValue;
            }
            else if (typeof(T) == typeof(char)) {
                // chars are returned from SQL as strings
                string strValue = Convert.ToString(objValue);

                if (strValue.Length > 0) objValue = strValue[0];
                else objValue = defaultValue;
            }

            value = (T)objValue;
        } catch (IndexOutOfRangeException) {
            // field does not exist
            value = @defaultValue;
        } catch (InvalidCastException, ex) {
            // The type we are attempting to bind to is not the same as the type returned from the database
            // Personally, I want to know the field name that has the problem
            throw new InvalidCastException(name, ex);
        }

        return value;
    }

    // clearly separate name to ensure I don't accidentally use the wrong method
    // just overloads the other method so I don't need to pass in a default
    public static T IfDBNull<T>(this IDataReader reader, string name) {
        return IfDBNull<T>(reader, name, default(T));
    }
}

如果您想避免异常处理,我建议在初始化阅读器时将结果保存到HashSet<string>中,然后再检查它以查找所需的列。或者,为了进行微优化,您可以将列实现为Dictionary<string, int>,以防止SqlDataReader对象从Name到ordinal的重复解析。

使用HashSet<string>实现

using System;
using System.Collections.Generic;
using System.Data.SqlClient;

public class MyModel {
    public int ID { get; set; }
    public int UnknownColumn { get; set; }
}

public IEnumerable<MyModel> ReadData(SqlCommand command) {
    using (SqlDataReader reader = command.ExecuteReader()) {
        try {
            // first read
            if (reader.Read()) {
                // use whatever *IgnoreCase comparer that you're comfortable with
                HashSet<string> columns = new HashSet<string>(StringComparer.OrdinalIgnoreCase);

                // init the columns HashSet<string, int>
                for (int i = 0; i < reader.FieldCount; i++) {
                    string fieldName = reader.GetName(i);
                    columns.Add(fieldName);
                }

                // implemented as a do/while since we already read the first row
                do {
                    // init a new instance of your class
                    MyModel row = new MyModel();

                    // check if column exists
                    if (columns.Contains("ID") &&
                        // ensure the value is not DBNull
                        !Convert.IsDBNull(reader["ID"])) {
                        // bind value
                        row.ID = (int)reader["ID"];
                    }

                    // check if column exists
                    if (columns.Contains("UnknownColumn") &&
                        // ensure the value is not DBNull
                        !Convert.IsDBNull(reader["UnknownColumn"])) {
                        // bind value
                        row.UnknownColumn = (int)reader["UnknownColumn"];
                    }

                    // return the row and move forward
                    yield return row;
                } while (reader.Read());
            }
        } finally {
            // technically the disposer should handle this for you
            if (!reader.IsClosed) reader.Close();
        }
    }
}

使用Dictionary<string, int>实现

using System;
using System.Collections.Generic;
using System.Data.SqlClient;

public class MyModel {
    public int ID { get; set; }
    public int UnknownColumn { get; set; }
}

public IEnumerable<MyModel> ReadData(SqlCommand command) {
    using (SqlDataReader reader = command.ExecuteReader()) {
        try {
            // first read
            if (reader.Read()) {
                // use whatever *IgnoreCase comparer that you're comfortable with
                Dictionary<string, int> columns = new Dictionary<string, int>(StringComparer.OrdinalIgnoreCase);

                // init the columns Dictionary<string, int>
                for (int i = 0; i < reader.FieldCount; i++) {
                    string fieldName = reader.GetName(i);
                    columns[fieldName] = i;
                }

                // implemented as a do/while since we already read the first row
                do {
                    // init a new instance of your class
                    MyModel row = new MyModel();

                    // stores the resolved ordinal from your dictionary
                    int ordinal;

                    // check if column exists
                    if (columns.TryGetValue("ID", out ordinal) &&
                        // ensure the value is not DBNull
                        !Convert.IsDBNull(reader[ordinal])) {
                        // bind value
                        row.ID = (int)reader[ordinal];
                    }

                    // check if column exists
                    if (columns.TryGetValue("UnknownColumn", out ordinal) &&
                        // ensure the value is not DBNull
                        !Convert.IsDBNull(reader[ordinal])) {
                        // bind value
                        row.UnknownColumn = (int)reader[ordinal];
                    }

                    // return the row and move forward
                    yield return row;
                } while (reader.Read());
            }
        } finally {
            // technically the disposer should handle this for you
            if (!reader.IsClosed) reader.Close();
        }
    }
}

这对我来说很管用:

bool hasColumnName = reader.GetSchemaTable().AsEnumerable().Any(c => c["ColumnName"] == "YOUR_COLUMN_NAME");

TLDR:

有很多关于性能和不良实践的说法,所以我在这里澄清一下。

对于返回的列数较多的情况,异常路由更快,对于返回的列数较低的情况,循环路由更快,交叉点在11列左右。滚动到底部以查看图形和测试代码。

完整的回答:

一些顶级答案的代码可以工作,但是这里存在一个潜在的争论,即基于在逻辑中接受异常处理及其相关性能的“更好的”答案。

为了澄清这一点,我不认为有太多关于捕获异常的指导。微软确实有一些关于抛出异常的指导。他们写道:

如果可能,不要对正常的控制流使用异常。

第一个注意事项是“如果可能”的宽大。更重要的是,描述给出了以下上下文:

框架设计者应该设计api,这样用户就可以编写不抛出异常的代码

这意味着,如果你正在编写一个可能被其他人使用的API,让他们能够在不使用try/catch的情况下导航异常。例如,使用抛出异常的Parse方法提供TryParse。但是这里并没有说不应该捕获异常。

而且,正如另一个用户指出的那样,catch一直允许按类型进行过滤,最近还允许通过when子句进行进一步过滤。这似乎是对语言特性的浪费,如果我们不应该使用它们的话。

可以说,抛出异常是有代价的,而这种代价可能会影响重循环中的性能。然而,也可以说异常的代价在“连接应用程序”中是可以忽略不计的。实际成本在十多年前就已经调查过了:c#中的异常有多昂贵?

换句话说,数据库连接和查询的成本可能会使抛出异常的成本相形见绌。

除此之外,我还想确定哪种方法确实更快。不出所料,没有具体的答案。

任何遍历列的代码都会随着列数的增加而变慢。也可以说,任何依赖于异常的代码都会变慢,这取决于查找查询失败的速率。

使用Chad Grant和Matt Hamilton的答案,我运行了两种方法,最多有20个列,错误率高达50% (OP表明他在不同的存储过程之间使用这个两个测试,所以我假设只有两个)。

以下是用LINQPad绘制的结果:

这里的锯齿形表示每个列计数中的错误率(未找到列)。

对于较窄的结果集,循环是一个不错的选择。然而,GetOrdinal/Exception方法对列数不太敏感,在11列左右开始优于循环方法。

也就是说,我并没有真正的偏好性能,因为11列作为整个应用程序返回的平均列数听起来很合理。无论哪种情况,我们在这里谈论的都是一毫秒的分数。

然而,从代码简单性和别名支持的角度来看,我可能会选择GetOrdinal路线。

下面是LINQPad形式的测试。请随意用你自己的方法转发:

void Main()
{
    var loopResults = new List<Results>();
    var exceptionResults = new List<Results>();
    var totalRuns = 10000;
    for (var colCount = 1; colCount < 20; colCount++)
    {
        using (var conn = new SqlConnection(@"Data Source=(localdb)\MSSQLLocalDb;Initial Catalog=master;Integrated Security=True;"))
        {
            conn.Open();

            //create a dummy table where we can control the total columns
            var columns = String.Join(",",
                (new int[colCount]).Select((item, i) => $"'{i}' as col{i}")
            );
            var sql = $"select {columns} into #dummyTable";
            var cmd = new SqlCommand(sql,conn);
            cmd.ExecuteNonQuery();

            var cmd2 = new SqlCommand("select * from #dummyTable", conn);

            var reader = cmd2.ExecuteReader();
            reader.Read();

            Func<Func<IDataRecord, String, Boolean>, List<Results>> test = funcToTest =>
            {
                var results = new List<Results>();
                Random r = new Random();
                for (var faultRate = 0.1; faultRate <= 0.5; faultRate += 0.1)
                {
                    Stopwatch stopwatch = new Stopwatch();
                    stopwatch.Start();
                    var faultCount=0;
                    for (var testRun = 0; testRun < totalRuns; testRun++)
                    {
                        if (r.NextDouble() <= faultRate)
                        {
                            faultCount++;
                            if(funcToTest(reader, "colDNE"))
                                throw new ApplicationException("Should have thrown false");
                        }
                        else
                        {
                            for (var col = 0; col < colCount; col++)
                            {
                                if(!funcToTest(reader, $"col{col}"))
                                    throw new ApplicationException("Should have thrown true");
                            }
                        }
                    }
                    stopwatch.Stop();
                    results.Add(new UserQuery.Results{
                        ColumnCount = colCount,
                        TargetNotFoundRate = faultRate,
                        NotFoundRate = faultCount * 1.0f / totalRuns,
                        TotalTime=stopwatch.Elapsed
                    });
                }
                return results;
            };
            loopResults.AddRange(test(HasColumnLoop));

            exceptionResults.AddRange(test(HasColumnException));

        }

    }
    "Loop".Dump();
    loopResults.Dump();

    "Exception".Dump();
    exceptionResults.Dump();

    var combinedResults = loopResults.Join(exceptionResults,l => l.ResultKey, e=> e.ResultKey, (l, e) => new{ResultKey = l.ResultKey, LoopResult=l.TotalTime, ExceptionResult=e.TotalTime});
    combinedResults.Dump();
    combinedResults
        .Chart(r => r.ResultKey, r => r.LoopResult.Milliseconds * 1.0 / totalRuns, LINQPad.Util.SeriesType.Line)
        .AddYSeries(r => r.ExceptionResult.Milliseconds * 1.0 / totalRuns, LINQPad.Util.SeriesType.Line)
        .Dump();
}
public static bool HasColumnLoop(IDataRecord dr, string columnName)
{
    for (int i = 0; i < dr.FieldCount; i++)
    {
        if (dr.GetName(i).Equals(columnName, StringComparison.InvariantCultureIgnoreCase))
            return true;
    }
    return false;
}

public static bool HasColumnException(IDataRecord r, string columnName)
{
    try
    {
        return r.GetOrdinal(columnName) >= 0;
    }
    catch (IndexOutOfRangeException)
    {
        return false;
    }
}

public class Results
{
    public double NotFoundRate { get; set; }
    public double TargetNotFoundRate { get; set; }
    public int ColumnCount { get; set; }
    public double ResultKey {get => ColumnCount + TargetNotFoundRate;}
    public TimeSpan TotalTime { get; set; }


}