Friday, June 1, 2012

Some useful Entity Framework extension methods

As I've been working with the latest version of Entity Framework, I wanted to go ahead and share some useful code I've found online and created for myself.

Disclaimer: The code below carries no guarantees (heck, it may not even work), so use it at your own risk.

    /// <summary>
    /// Reflection helpers for <see cref="MethodInfo"/>.
    /// </summary>
    public static class MethodInfoExtensions
    {
        // adapted from: http://www.thomaslevesque.com/2010/10/03/entity-framework-using-include-with-lambda-expressions/
        /// <summary>
        /// Tells whether <paramref name="method"/> is an extension method declared
        /// on <see cref="Queryable"/> or <see cref="Enumerable"/>.
        /// </summary>
        /// <param name="method">The method to inspect.</param>
        /// <returns>
        /// <c>true</c> when the declaring type is one of the two LINQ operator classes
        /// and the method carries <see cref="ExtensionAttribute"/>; otherwise <c>false</c>.
        /// </returns>
        public static bool IsLinqOperator(this MethodInfo method)
        {
            var declaringType = method.DeclaringType;

            // Anything not declared on the two LINQ operator classes is rejected up front.
            if (declaringType != typeof(Queryable) && declaringType != typeof(Enumerable))
            {
                return false;
            }

            // Plain static helpers such as Enumerable.Range are not extension methods.
            return Attribute.IsDefined(method, typeof(ExtensionAttribute));
        }
    }


    // adapted from: http://www.thomaslevesque.com/2010/10/03/entity-framework-using-include-with-lambda-expressions/
    /// <summary>
    /// Visits an expression tree, collects the name of every member access it
    /// encounters, and joins those names into a single delimited path
    /// (for example <c>x => x.Orders.Select(o => o.Items)</c> yields <c>Orders.Items</c>).
    /// The path is computed once, in the constructor.
    /// </summary>
    public class PropertyPathBuildingExpressionVisitor : ExpressionVisitor
    {
        /// <summary>Delimiter used when the caller does not supply one.</summary>
        public const string DefaultPropertyPathDelimiter = ".";

        // Non-null only while BuildPropertyPath is running; VisitMember checks this
        // so that visits outside of a build do not record anything.
        private Stack<string> propertyNameStack = null;


        /// <summary>The expression the path was built from.</summary>
        public Expression Expression
        { get; private set; }

        /// <summary>The delimited member-name path computed from <see cref="Expression"/>.</summary>
        public string PropertyPath
        { get; private set; }


        /// <summary>
        /// Creates the visitor and immediately builds <see cref="PropertyPath"/>.
        /// </summary>
        /// <param name="expression">The expression to walk.</param>
        /// <param name="delimitedBy">
        /// Separator placed between member names; <see cref="DefaultPropertyPathDelimiter"/> when null.
        /// </param>
        public PropertyPathBuildingExpressionVisitor(Expression expression, string delimitedBy = null)
        {
            Expression = expression;
            BuildPropertyPath(delimitedBy ?? DefaultPropertyPathDelimiter);
        }


        /// <summary>
        /// Walks <see cref="Expression"/> and stores the joined member names in <see cref="PropertyPath"/>.
        /// </summary>
        private void BuildPropertyPath(string propertyPathDelimiter)
        {
            propertyNameStack = new Stack<string>();

            Visit(Expression);

            // Outer members are pushed before inner ones, and Stack<T> enumerates
            // last-in-first-out, so joining the stack yields the names root-first.
            PropertyPath = string.Join(propertyPathDelimiter, propertyNameStack);

            propertyNameStack = null;
        }


        /// <summary>Records the accessed member's name, then continues into the member's target.</summary>
        protected override Expression VisitMember(MemberExpression expression)
        {
            if (propertyNameStack != null)
            {
                propertyNameStack.Push(expression.Member.Name);
            }

            return base.VisitMember(expression);
        }


        /// <summary>
        /// For LINQ operators, visits the trailing arguments before the source
        /// argument so that names pushed for the source end up on top of the stack,
        /// i.e. ahead of the names found in the operator's lambdas. Other method
        /// calls get the default traversal.
        /// </summary>
        protected override Expression VisitMethodCall(MethodCallExpression expression)
        {
            if (!expression.Method.IsLinqOperator())
            {
                return base.VisitMethodCall(expression);
            }

            for (int i = 1; i < expression.Arguments.Count; i++)
            {
                Visit(expression.Arguments[i]);
            }

            // Argument 0 is the sequence the extension method is invoked on.
            Visit(expression.Arguments[0]);

            // The node is only inspected, never rewritten.
            return expression;
        }
    }


    /// <summary>
    /// Shortcuts for turning an expression into a delimited member-name path
    /// via <see cref="PropertyPathBuildingExpressionVisitor"/>.
    /// </summary>
    public static class ExpressionExtensions
    {
        /// <summary>
        /// Builds the member-name path for <paramref name="expression"/>.
        /// </summary>
        /// <param name="expression">The expression to walk.</param>
        /// <param name="delimitedBy">Separator between names; passed through to the visitor unchanged.</param>
        /// <returns>The visitor's <c>PropertyPath</c>.</returns>
        public static string AsPropertyPath(this Expression expression, string delimitedBy = null)
        {
            var pathBuilder = new PropertyPathBuildingExpressionVisitor(expression, delimitedBy);
            return pathBuilder.PropertyPath;
        }

        /// <summary>
        /// Builds the member-name path for <paramref name="expression"/> using a
        /// double underscore as the separator.
        /// </summary>
        /// <param name="expression">The expression to walk.</param>
        /// <returns>The path with names separated by <c>__</c>.</returns>
        public static string AsConventionalHtmlInputName(this Expression expression)
        {
            return expression.AsPropertyPath(delimitedBy: "__");
        }
    }


    /// <summary>
    /// Helpers for asking whether a type descends from a particular open generic type.
    /// </summary>
    public static class TypeExtensions
    {
        /// <summary>
        /// Checks whether any base class of <paramref name="type"/> is a constructed
        /// form of <paramref name="genericTypeDefinition"/>. Only the base-class chain
        /// is examined; <paramref name="type"/> itself is not.
        /// </summary>
        /// <param name="type">The type whose ancestors are searched.</param>
        /// <param name="genericTypeDefinition">An open generic type such as <c>typeof(List&lt;&gt;)</c>.</param>
        /// <returns><c>true</c> if a matching ancestor exists; otherwise <c>false</c>.</returns>
        public static bool IsSubclassOfGenericTypeWithDefinition(this Type type, Type genericTypeDefinition)
        {
            // NOTE(review): Guard is defined elsewhere; presumably these throw on invalid arguments.
            Guard.AgainstNull(type, "type");
            Guard.Against(genericTypeDefinition, gtd => !gtd.IsGenericTypeDefinition, "Must be a generic type definition", "genericTypeDefinition");

            for (var ancestor = type.BaseType; ancestor != null; ancestor = ancestor.BaseType)
            {
                if (ancestor.IsGenericType && ancestor.GetGenericTypeDefinition() == genericTypeDefinition)
                {
                    return true;
                }
            }

            return false;
        }

        /// <summary>
        /// Checks whether <paramref name="type"/> descends from <c>EntityTypeConfiguration&lt;&gt;</c>.
        /// </summary>
        public static bool IsEntityTypeConfiguration(this Type type)
        {
            var entityConfigurationDefinition = typeof(EntityTypeConfiguration<>);
            return type.IsSubclassOfGenericTypeWithDefinition(entityConfigurationDefinition);
        }

        /// <summary>
        /// Checks whether <paramref name="type"/> descends from <c>ComplexTypeConfiguration&lt;&gt;</c>.
        /// </summary>
        public static bool IsComplexTypeConfiguration(this Type type)
        {
            var complexConfigurationDefinition = typeof(ComplexTypeConfiguration<>);
            return type.IsSubclassOfGenericTypeWithDefinition(complexConfigurationDefinition);
        }
    }


    /// <summary>
    /// Lets <c>Include</c> be called with property-accessing lambdas instead of string paths.
    /// </summary>
    public static class DbQueryExtensions
    {
        /// <summary>
        /// Applies one string-based <c>Include</c> per expression, using the path
        /// produced by <c>AsPropertyPath</c> for each.
        /// </summary>
        /// <param name="query">The query to extend.</param>
        /// <param name="propertyAccessingExpressions">Lambdas selecting the members to include.</param>
        /// <returns>The query returned by the last <c>Include</c> call, or <paramref name="query"/> if the sequence is empty.</returns>
        public static DbQuery<TResult> Include<TResult>(this DbQuery<TResult> query, IEnumerable<Expression<Func<TResult, object>>> propertyAccessingExpressions)
        {
            var result = query;

            foreach (var accessor in propertyAccessingExpressions)
            {
                var path = accessor.AsPropertyPath();
                result = result.Include(path);
            }

            return result;
        }

        /// <summary>
        /// <c>params</c> convenience overload; forwards to the sequence-based overload.
        /// </summary>
        /// <param name="query">The query to extend.</param>
        /// <param name="propertyAccessingExpressions">Lambdas selecting the members to include.</param>
        /// <returns>The query with all includes applied.</returns>
        public static DbQuery<TResult> Include<TResult>(this DbQuery<TResult> query, params Expression<Func<TResult, object>>[] propertyAccessingExpressions)
        {
            // Typing the local as IEnumerable selects the sequence overload rather than recursing into this one.
            IEnumerable<Expression<Func<TResult, object>>> accessors = propertyAccessingExpressions;
            return query.Include(accessors);
        }
    }


    /// <summary>
    /// Registers configuration classes that are declared as nested types of a
    /// given type with a <c>DbModelBuilder</c>.
    /// </summary>
    public static class DbModelBuilderExtensions
    {
        /// <summary>
        /// Finds the public <c>ConfigurationRegistrar.Add</c> overload whose single parameter
        /// is a constructed form of <paramref name="configurationType"/>.
        /// </summary>
        /// <param name="configurationType">Open generic parameter type to match.</param>
        /// <returns>The matching method; <c>First</c> throws if none exists.</returns>
        private static MethodInfo GetConfigurationRegistrarAddMethodWithParameter(Type configurationType)
        {
            Guard.AgainstNull(configurationType, "configurationType");
            return typeof(ConfigurationRegistrar).GetMethods().First(method =>
            {
                if (!string.Equals(method.Name, "Add", StringComparison.Ordinal))
                {
                    return false;
                }

                var parameters = method.GetParameters();
                if (parameters.Length != 1)
                {
                    return false;
                }

                // GetGenericTypeDefinition throws on non-generic types, so test IsGenericType first.
                var parameterType = parameters[0].ParameterType;
                return parameterType.IsGenericType
                    && parameterType.GetGenericTypeDefinition() == configurationType;
            });
        }

        private static readonly MethodInfo AddEntityTypeConfigurationMethod =
            GetConfigurationRegistrarAddMethodWithParameter(typeof(EntityTypeConfiguration<>));

        private static readonly MethodInfo AddComplexTypeConfigurationMethod =
            GetConfigurationRegistrarAddMethodWithParameter(typeof(ComplexTypeConfiguration<>));


        /// <summary>
        /// Returns the nearest base class of <paramref name="type"/> that is a constructed
        /// form of <paramref name="genericTypeDefinition"/>, or null when there is none.
        /// </summary>
        private static Type FindClosedBaseType(Type type, Type genericTypeDefinition)
        {
            for (var baseType = type.BaseType; baseType != null; baseType = baseType.BaseType)
            {
                if (baseType.IsGenericType && baseType.GetGenericTypeDefinition() == genericTypeDefinition)
                {
                    return baseType;
                }
            }

            return null;
        }


        /// <summary>
        /// Instantiates every candidate that descends from <paramref name="genericTypeDefinition"/>
        /// and passes it to <paramref name="addMethod"/> on the builder's configuration registrar.
        /// </summary>
        private static void AddNestedConfigurations(DbModelBuilder dbModelBuilder, Type[] candidates, Type genericTypeDefinition, MethodInfo addMethod)
        {
            foreach (var candidate in candidates)
            {
                // Abstract and open generic types cannot be instantiated; they are
                // typically shared base classes for the concrete configurations.
                if (candidate.IsAbstract || candidate.ContainsGenericParameters)
                {
                    continue;
                }

                // The configuration may derive from the EF base class through intermediate
                // classes, so the type argument is read from the matching ancestor rather
                // than from the immediate base type.
                var closedBaseType = FindClosedBaseType(candidate, genericTypeDefinition);
                if (closedBaseType == null)
                {
                    continue;
                }

                addMethod.MakeGenericMethod(closedBaseType.GenericTypeArguments[0])
                    .Invoke(dbModelBuilder.Configurations, new object[] { Activator.CreateInstance(candidate) });
            }
        }


        /// <summary>
        /// Registers every nested type of <paramref name="type"/> that descends from
        /// <c>ComplexTypeConfiguration&lt;&gt;</c> or <c>EntityTypeConfiguration&lt;&gt;</c>.
        /// Complex type configurations are added first, then entity type configurations.
        /// </summary>
        /// <param name="dbModelBuilder">The builder whose <c>Configurations</c> registrar receives the instances.</param>
        /// <param name="type">The type whose directly declared nested types (public and non-public) are scanned.</param>
        public static void AddConfigurationsDefinedWithin(this DbModelBuilder dbModelBuilder, Type type)
        {
            Guard.AgainstNull(dbModelBuilder, "dbModelBuilder");
            Guard.AgainstNull(type, "type");

            var nestedTypes = type.GetNestedTypes(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.DeclaredOnly);

            AddNestedConfigurations(dbModelBuilder, nestedTypes, typeof(ComplexTypeConfiguration<>), AddComplexTypeConfigurationMethod);
            AddNestedConfigurations(dbModelBuilder, nestedTypes, typeof(EntityTypeConfiguration<>), AddEntityTypeConfigurationMethod);
        }


        /// <summary>
        /// Registers the configurations nested within the runtime type of <paramref name="object"/>.
        /// </summary>
        /// <param name="dbModelBuilder">The builder to register with.</param>
        /// <param name="object">Instance whose runtime type is scanned.</param>
        public static void AddConfigurationsDefinedWithin(this DbModelBuilder dbModelBuilder, object @object)
        {
            Guard.AgainstNull(@object, "object");
            dbModelBuilder.AddConfigurationsDefinedWithin(@object.GetType());
        }
    }

No comments:

Post a Comment