Disclaimer: The code below carries no guarantees — it may not even work — so use it at your own risk.
/// <summary>
/// Reflection helpers for <see cref="MethodInfo"/>.
/// </summary>
public static class MethodInfoExtensions
{
    // adapted from: http://www.thomaslevesque.com/2010/10/03/entity-framework-using-include-with-lambda-expressions/
    /// <summary>
    /// Determines whether <paramref name="method"/> is a LINQ standard query operator, i.e. an
    /// extension method declared directly on <see cref="Enumerable"/> or <see cref="Queryable"/>.
    /// </summary>
    /// <param name="method">The method to inspect. Must not be null.</param>
    /// <returns>
    /// <c>true</c> when the declaring type is <see cref="Enumerable"/> or <see cref="Queryable"/> and the
    /// method carries <see cref="ExtensionAttribute"/>; otherwise <c>false</c>.
    /// </returns>
    public static bool IsLinqOperator(this MethodInfo method)
    {
        Type owner = method.DeclaringType;
        bool ownedByLinq = owner == typeof(Enumerable) || owner == typeof(Queryable);
        if (!ownedByLinq)
        {
            // Skip the (comparatively expensive) attribute lookup for everything else.
            return false;
        }

        return Attribute.IsDefined(method, typeof(ExtensionAttribute));
    }
}
// adapted from: http://www.thomaslevesque.com/2010/10/03/entity-framework-using-include-with-lambda-expressions/
/// <summary>
/// Walks an expression and builds a delimited path from the names of the members it accesses,
/// stepping through LINQ operators so that, e.g., <c>x =&gt; x.Orders.Select(o =&gt; o.Product)</c>
/// yields <c>"Orders.Product"</c>.
/// </summary>
/// <remarks>
/// The path is built eagerly by the constructor, which runs the (virtual) <c>Visit</c> pipeline.
/// A subclass that overrides the visit methods will therefore have them invoked before its own
/// constructor body has run.
/// </remarks>
public class PropertyPathBuildingExpressionVisitor : ExpressionVisitor
{
    /// <summary>Delimiter placed between member names when the caller supplies none.</summary>
    public const string DefaultPropertyPathDelimiter = ".";

    // Collects member names only while BuildPropertyPath runs; null otherwise, so any later
    // external call to Visit does not accumulate names.
    private Stack<string> propertyNameStack = null;

    /// <summary>The expression the property path was built from.</summary>
    public Expression Expression
    { get; private set; }

    /// <summary>
    /// The delimited member path built from <see cref="Expression"/>; empty when no members were visited.
    /// </summary>
    public string PropertyPath
    { get; private set; }

    /// <summary>
    /// Visits <see cref="Expression"/> and stores the resulting path in <see cref="PropertyPath"/>.
    /// </summary>
    /// <param name="propertyPathDelimiter">Text placed between consecutive member names.</param>
    private void BuildPropertyPath(string propertyPathDelimiter)
    {
        propertyNameStack = new Stack<string>();
        try
        {
            Visit(Expression);
            // Leaf members are pushed before the members they hang off, so enumerating the stack
            // (top to bottom) yields the names in root-to-leaf order.
            PropertyPath = string.Join(propertyPathDelimiter, propertyNameStack);
        }
        finally
        {
            // Always stop collecting, even if visiting threw.
            propertyNameStack = null;
        }
    }

    /// <summary>Records the accessed member's name (while building) and continues into its target.</summary>
    protected override Expression VisitMember(MemberExpression expression)
    {
        if (propertyNameStack != null)
            propertyNameStack.Push(expression.Member.Name);
        return base.VisitMember(expression);
    }

    /// <summary>
    /// For LINQ operators, visits the non-source arguments before the source (argument 0) so that
    /// names from e.g. a selector lambda end up after the source's names in the path; other method
    /// calls are visited normally.
    /// </summary>
    protected override Expression VisitMethodCall(MethodCallExpression expression)
    {
        if (!expression.Method.IsLinqOperator())
            return base.VisitMethodCall(expression);

        // Names pushed later sit nearer the top of the stack and thus earlier in the path,
        // so the source (the extension's `this` argument) is visited last.
        for (int i = 1; i < expression.Arguments.Count; i++)
            Visit(expression.Arguments[i]);
        Visit(expression.Arguments[0]);
        return expression;
    }

    /// <summary>
    /// Creates the visitor and immediately builds <see cref="PropertyPath"/> from <paramref name="expression"/>.
    /// </summary>
    /// <param name="expression">The expression to walk; a null expression yields an empty path.</param>
    /// <param name="delimitedBy">Delimiter between names; null means <see cref="DefaultPropertyPathDelimiter"/>.</param>
    public PropertyPathBuildingExpressionVisitor(Expression expression, string delimitedBy = null)
    {
        Expression = expression;
        BuildPropertyPath(delimitedBy ?? DefaultPropertyPathDelimiter);
    }
}
/// <summary>
/// Conversions from member-access expressions to delimited property-path strings.
/// </summary>
public static class ExpressionExtensions
{
    /// <summary>
    /// Builds a delimited member path from <paramref name="expression"/> using
    /// <see cref="PropertyPathBuildingExpressionVisitor"/>.
    /// </summary>
    /// <param name="expression">The expression to walk.</param>
    /// <param name="delimitedBy">
    /// Delimiter between member names; null selects
    /// <see cref="PropertyPathBuildingExpressionVisitor.DefaultPropertyPathDelimiter"/>.
    /// </param>
    /// <returns>The member path built by the visitor.</returns>
    public static string AsPropertyPath(this Expression expression, string delimitedBy = null)
    {
        var pathBuilder = new PropertyPathBuildingExpressionVisitor(expression, delimitedBy);
        return pathBuilder.PropertyPath;
    }

    /// <summary>
    /// Builds the member path of <paramref name="expression"/> delimited by a double underscore ("__").
    /// </summary>
    /// <param name="expression">The expression to walk.</param>
    /// <returns>The "__"-delimited member path.</returns>
    public static string AsConventionalHtmlInputName(this Expression expression)
    {
        const string inputNameDelimiter = "__";
        return AsPropertyPath(expression, inputNameDelimiter);
    }
}
/// <summary>
/// Reflection helpers for <see cref="Type"/>, used to recognise Entity Framework configuration classes.
/// </summary>
public static class TypeExtensions
{
    /// <summary>
    /// Determines whether some base class of <paramref name="type"/> (excluding <paramref name="type"/>
    /// itself) is constructed from <paramref name="genericTypeDefinition"/>.
    /// Only the class hierarchy (<see cref="Type.BaseType"/>) is searched; implemented interfaces are not.
    /// </summary>
    /// <param name="type">The type whose ancestors are inspected. Must not be null.</param>
    /// <param name="genericTypeDefinition">An open generic type such as <c>typeof(List&lt;&gt;)</c>. Must not be null.</param>
    /// <returns><c>true</c> when a matching base type is found; otherwise <c>false</c>.</returns>
    public static bool IsSubclassOfGenericTypeWithDefinition(this Type type, Type genericTypeDefinition)
    {
        Guard.AgainstNull(type, "type");
        // Reject null up front; the predicate below would otherwise dereference it.
        Guard.AgainstNull(genericTypeDefinition, "genericTypeDefinition");
        Guard.Against(genericTypeDefinition, gtd => !gtd.IsGenericTypeDefinition, "Must be a generic type definition", "genericTypeDefinition");

        for (var baseType = type.BaseType; baseType != null; baseType = baseType.BaseType)
        {
            if (baseType.IsGenericType && baseType.GetGenericTypeDefinition() == genericTypeDefinition)
                return true;
        }
        return false;
    }

    /// <summary>Whether <paramref name="type"/> derives (at any depth) from <c>EntityTypeConfiguration&lt;T&gt;</c>.</summary>
    public static bool IsEntityTypeConfiguration(this Type type)
    { return type.IsSubclassOfGenericTypeWithDefinition(typeof(EntityTypeConfiguration<>)); }

    /// <summary>Whether <paramref name="type"/> derives (at any depth) from <c>ComplexTypeConfiguration&lt;T&gt;</c>.</summary>
    public static bool IsComplexTypeConfiguration(this Type type)
    { return type.IsSubclassOfGenericTypeWithDefinition(typeof(ComplexTypeConfiguration<>)); }
}
/// <summary>
/// Strongly-typed <c>Include</c> overloads for <c>DbQuery&lt;TResult&gt;</c> that accept several
/// member-access lambdas at once.
/// </summary>
public static class DbQueryExtensions
{
    /// <summary>
    /// Calls <c>query.Include(path)</c> once per expression, where each path is produced by
    /// <c>AsPropertyPath()</c> with its default delimiter.
    /// </summary>
    /// <param name="query">The query to extend. Must not be null.</param>
    /// <param name="propertyAccessingExpressions">The member-access lambdas to include. Must not be null.</param>
    /// <returns>The query after all includes have been applied.</returns>
    public static DbQuery<TResult> Include<TResult>(this DbQuery<TResult> query, IEnumerable<Expression<Func<TResult, object>>> propertyAccessingExpressions)
    {
        // Validate both arguments like the file's other public entry points do, instead of failing
        // with a NullReferenceException (or silently returning a null query when nothing is included).
        Guard.AgainstNull(query, "query");
        Guard.AgainstNull(propertyAccessingExpressions, "propertyAccessingExpressions");

        foreach (var propertyAccessingExpression in propertyAccessingExpressions)
            // Reassign so each Include is applied on top of the previous result.
            query = query.Include(propertyAccessingExpression.AsPropertyPath());
        return query;
    }

    /// <summary>
    /// <c>params</c> convenience overload; forwards to the <see cref="IEnumerable{T}"/> overload.
    /// </summary>
    public static DbQuery<TResult> Include<TResult>(this DbQuery<TResult> query, params Expression<Func<TResult, object>>[] propertyAccessingExpressions)
    { return query.Include((IEnumerable<Expression<Func<TResult, object>>>) propertyAccessingExpressions); }
}
/// <summary>
/// Registers Entity Framework configuration classes that are declared as nested types of a given type.
/// </summary>
public static class DbModelBuilderExtensions
{
    // Nested types of any accessibility, but only those declared directly on the searched type.
    private const BindingFlags NestedTypeBindingFlags = BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.DeclaredOnly;

    /// <summary>
    /// Finds the generic <c>ConfigurationRegistrar.Add&lt;T&gt;</c> overload whose single parameter is
    /// constructed from <paramref name="configurationType"/> (an open generic type definition).
    /// </summary>
    private static MethodInfo GetConfigurationRegistrarAddMethodWithParameter(Type configurationType)
    {
        Guard.AgainstNull(configurationType, "configurationType");
        return typeof(ConfigurationRegistrar).GetMethods().First(method =>
        {
            // Only a generic, single-parameter Add can be closed with MakeGenericMethod and invoked with
            // one configuration instance. Checking this first also avoids calling GetGenericTypeDefinition
            // on a non-generic parameter type or First() on an empty parameter list, both of which throw.
            if (!string.Equals(method.Name, "Add", StringComparison.Ordinal) || !method.IsGenericMethodDefinition)
                return false;

            var parameters = method.GetParameters();
            if (parameters.Length != 1)
                return false;

            var parameterType = parameters[0].ParameterType;
            return parameterType.IsGenericType && parameterType.GetGenericTypeDefinition() == configurationType;
        });
    }

    private static readonly MethodInfo AddEntityTypeConfigurationMethod =
        GetConfigurationRegistrarAddMethodWithParameter(typeof(EntityTypeConfiguration<>));
    private static readonly MethodInfo AddComplexTypeConfigurationMethod =
        GetConfigurationRegistrarAddMethodWithParameter(typeof(ComplexTypeConfiguration<>));

    /// <summary>
    /// Instantiates every concrete nested <c>ComplexTypeConfiguration&lt;T&gt;</c> subclass of
    /// <paramref name="type"/>, then every concrete nested <c>EntityTypeConfiguration&lt;T&gt;</c> subclass,
    /// and adds each instance to <c>dbModelBuilder.Configurations</c>.
    /// </summary>
    /// <remarks>
    /// Abstract nested types and nested types with unbound generic parameters are skipped because they
    /// cannot be instantiated. Instances are created with <see cref="Activator.CreateInstance(Type)"/>,
    /// so each registered type needs a public parameterless constructor.
    /// </remarks>
    /// <param name="dbModelBuilder">The model builder to configure. Must not be null.</param>
    /// <param name="type">The type whose nested types are searched. Must not be null.</param>
    public static void AddConfigurationsDefinedWithin(this DbModelBuilder dbModelBuilder, Type type)
    {
        Guard.AgainstNull(dbModelBuilder, "dbModelBuilder");
        Guard.AgainstNull(type, "type");

        var nestedTypes = type.GetNestedTypes(NestedTypeBindingFlags);
        // Complex type configurations are added before entity type configurations.
        AddNestedConfigurations(dbModelBuilder, nestedTypes, typeof(ComplexTypeConfiguration<>), AddComplexTypeConfigurationMethod);
        AddNestedConfigurations(dbModelBuilder, nestedTypes, typeof(EntityTypeConfiguration<>), AddEntityTypeConfigurationMethod);
    }

    /// <summary>
    /// Convenience overload that searches the nested types of <paramref name="object"/>'s runtime type.
    /// </summary>
    /// <param name="dbModelBuilder">The model builder to configure.</param>
    /// <param name="object">The instance whose runtime type is searched. Must not be null.</param>
    public static void AddConfigurationsDefinedWithin(this DbModelBuilder dbModelBuilder, object @object)
    {
        Guard.AgainstNull(@object, "object");
        dbModelBuilder.AddConfigurationsDefinedWithin(@object.GetType());
    }

    /// <summary>
    /// Adds an instance of each instantiable candidate type that derives from
    /// <paramref name="configurationDefinition"/>, closing <paramref name="addMethod"/> over the configured type.
    /// </summary>
    private static void AddNestedConfigurations(DbModelBuilder dbModelBuilder, Type[] candidateTypes, Type configurationDefinition, MethodInfo addMethod)
    {
        foreach (var configurationType in candidateTypes)
        {
            if (configurationType.IsAbstract || configurationType.ContainsGenericParameters)
                continue;

            var configuredType = FindConfiguredType(configurationType, configurationDefinition);
            if (configuredType == null)
                continue;

            addMethod.MakeGenericMethod(configuredType)
                .Invoke(dbModelBuilder.Configurations, new object[] { Activator.CreateInstance(configurationType) });
        }
    }

    /// <summary>
    /// Returns the first type argument of the nearest base class of <paramref name="configurationType"/>
    /// that is constructed from <paramref name="configurationDefinition"/>, or null when there is none.
    /// </summary>
    private static Type FindConfiguredType(Type configurationType, Type configurationDefinition)
    {
        // Walk the whole hierarchy rather than only the immediate base type, so configurations that
        // derive through an intermediate (possibly non-generic) base class resolve the right argument.
        for (var baseType = configurationType.BaseType; baseType != null; baseType = baseType.BaseType)
        {
            if (baseType.IsGenericType && baseType.GetGenericTypeDefinition() == configurationDefinition)
                return baseType.GenericTypeArguments[0];
        }
        return null;
    }
}
No comments:
Post a Comment