Sunday, June 10, 2012

Utilizing the command pattern to support CQRS on top of a domain framework

Building on top of the domain framework I posted last time, I want to share a command pattern implementation I came up with to make enabling CQRS scenarios easier.

Disclaimer: The code below carries no guarantees-- heck, it may not even work-- so use at your own risk.



    /// <summary>
    /// Something that can be run. Common base of <see cref="Command"/> and <see cref="Query"/>.
    /// </summary>
    public interface Executable
    {
        /// <summary>Runs the operation.</summary>
        void Execute();
    }


    /// <summary>
    /// Marker interface for an <see cref="Executable"/> that represents a command
    /// (the write side in the CQRS split this post describes). Declares no members of its own.
    /// </summary>
    public interface Command : Executable
    { }


    /// <summary>
    /// Marker interface for an <see cref="Executable"/> that represents a query
    /// (the read side in the CQRS split this post describes). Declares no members of its own.
    /// </summary>
    public interface Query : Executable
    { }


    /// <summary>
    /// Base class for commands that operate on a <see cref="UnitOfWork"/>.
    /// Concrete commands put their logic in <see cref="ExecuteCommand"/>.
    /// </summary>
    public abstract class UnitOfWorkBasedCommand : UnitOfWorkBasedObject, Command
    {
        /// <summary>Hands the unit of work to the base class, which exposes it to derived commands.</summary>
        /// <param name="unitOfWork">The unit of work this command runs against.</param>
        public UnitOfWorkBasedCommand(UnitOfWork unitOfWork)
            : base(unitOfWork)
        {
        }


        /// <summary>Runs the command by delegating to <see cref="ExecuteCommand"/>.</summary>
        public void Execute()
        {
            this.ExecuteCommand();
        }


        /// <summary>The command's actual work, supplied by the concrete command.</summary>
        protected abstract void ExecuteCommand();
    }


    /// <summary>
    /// Base class for queries that operate on a <see cref="UnitOfWork"/>.
    /// Concrete queries put their logic in <see cref="ExecuteQuery"/>.
    /// </summary>
    public abstract class UnitOfWorkBasedQuery : UnitOfWorkBasedObject, Query
    {
        // Captured once at construction; nothing in this class changes it afterwards.
        private readonly bool shouldExecuteAsReadOnly;


        /// <summary>Stores the unit of work (via the base class) and the read-only preference.</summary>
        /// <param name="unitOfWork">The unit of work this query runs against.</param>
        /// <param name="shouldExecuteAsReadOnly">Read-only preference exposed to derived queries; defaults to true.</param>
        public UnitOfWorkBasedQuery(UnitOfWork unitOfWork, bool shouldExecuteAsReadOnly = true)
            : base(unitOfWork)
        {
            this.shouldExecuteAsReadOnly = shouldExecuteAsReadOnly;
        }


        /// <summary>
        /// The value passed to the constructor. This class only stores it; acting on it is left to derived queries.
        /// </summary>
        protected bool ShouldExecuteAsReadOnly
        {
            get { return shouldExecuteAsReadOnly; }
        }


        /// <summary>Runs the query by delegating to <see cref="ExecuteQuery"/>.</summary>
        public void Execute()
        {
            this.ExecuteQuery();
        }


        /// <summary>The query's actual work, supplied by the concrete query.</summary>
        protected abstract void ExecuteQuery();
    }


    /// <summary>
    /// Command base class that obtains repositories from a <see cref="RepositoryFactory"/> and
    /// registers each one with the command's unit of work before handing it out.
    /// </summary>
    public abstract class RepositoryAccessingCommand : UnitOfWorkBasedCommand
    {
        private readonly RepositoryFactory repositories;


        /// <summary>Stores the factory and passes the unit of work to the base class.</summary>
        /// <param name="unitOfWork">The unit of work repositories are registered with.</param>
        /// <param name="repositoryFactory">Source of the repositories handed to derived commands.</param>
        public RepositoryAccessingCommand(UnitOfWork unitOfWork, RepositoryFactory repositoryFactory)
            : base(unitOfWork)
        {
            this.repositories = repositoryFactory;
        }


        /// <summary>Gets a read/write repository from the factory and adds it to the unit of work's contexts.</summary>
        protected Repository<TEntity> GetRepositoryAddedToUnitOfWorkContexts<TEntity>() where TEntity : class, new()
        {
            return RegisteredWithUnitOfWork(repositories.Get<TEntity>());
        }


        /// <summary>Gets a read-only repository from the factory and adds it to the unit of work's contexts.</summary>
        protected ReadOnlyRepository<TEntity> GetReadOnlyRepositoryAddedToUnitOfWorkContexts<TEntity>() where TEntity : class, new()
        {
            return RegisteredWithUnitOfWork(repositories.GetReadOnly<TEntity>());
        }


        // Adds the given object to the unit of work's contexts and returns it unchanged.
        private TContext RegisteredWithUnitOfWork<TContext>(TContext context)
        {
            UnitOfWork.AddContext(context);
            return context;
        }
    }


    /// <summary>
    /// Query base class that obtains repositories from a <see cref="RepositoryFactory"/> and
    /// registers each one with the query's unit of work before handing it out.
    /// </summary>
    public abstract class RepositoryAccessingQuery : UnitOfWorkBasedQuery
    {
        private readonly RepositoryFactory repositories;


        /// <summary>Stores the factory and passes the remaining arguments to the base class.</summary>
        /// <param name="unitOfWork">The unit of work repositories are registered with.</param>
        /// <param name="repositoryFactory">Source of the repositories handed to derived queries.</param>
        /// <param name="shouldExecuteAsReadOnly">Forwarded to <see cref="UnitOfWorkBasedQuery"/>; defaults to true.</param>
        public RepositoryAccessingQuery(
            UnitOfWork unitOfWork,
            RepositoryFactory repositoryFactory,
            bool shouldExecuteAsReadOnly = true
        )
            : base(unitOfWork, shouldExecuteAsReadOnly)
        {
            this.repositories = repositoryFactory;
        }


        /// <summary>Gets a read/write repository from the factory and adds it to the unit of work's contexts.</summary>
        protected Repository<TEntity> GetRepositoryAddedToUnitOfWorkContexts<TEntity>() where TEntity : class, new()
        {
            return RegisteredWithUnitOfWork(repositories.Get<TEntity>());
        }


        /// <summary>Gets a read-only repository from the factory and adds it to the unit of work's contexts.</summary>
        protected ReadOnlyRepository<TEntity> GetReadOnlyRepositoryAddedToUnitOfWorkContexts<TEntity>() where TEntity : class, new()
        {
            return RegisteredWithUnitOfWork(repositories.GetReadOnly<TEntity>());
        }


        // Adds the given object to the unit of work's contexts and returns it unchanged.
        private TContext RegisteredWithUnitOfWork<TContext>(TContext context)
        {
            UnitOfWork.AddContext(context);
            return context;
        }
    }
    

Sunday, June 3, 2012

Decoupling your application domain from your ORM of choice

Domain-driven design and development (DDD/DDDD) is a very popular approach to building software. A variety of ORMs, such as Entity Framework and LightSpeed, make implementing a domain model easier. It is worth decoupling your application from your chosen ORM in case you need to switch to a different one later. Here is an example domain framework that enables common domain design patterns such as Repository and Unit of Work. I use Entity Framework a lot, so I've included adapters for it too. Note: the ExceptionManager is similar to one I shared in an earlier post.

Disclaimer: The code below carries no guarantees-- heck, it may not even work-- so use at your own risk.

    /// <summary>
    /// Something whose pending work can be committed. Implemented in this file by units of work
    /// and domain contexts.
    /// </summary>
    public interface Committable
    {
        /// <summary>Commits the pending work.</summary>
        void Commit();
    }

    
    /// <summary>
    /// Implemented by objects that sit on top of another context object (for example a repository
    /// on top of a domain context). <c>UnitOfWorkBase.GetUniqueContexts</c> follows
    /// <see cref="GetContext"/> repeatedly until it reaches an object that is not a
    /// <see cref="ContextProvider"/>.
    /// </summary>
    public interface ContextProvider
    {
        /// <summary>Returns the underlying context object this instance is built on.</summary>
        object GetContext();
    }


    /// <summary>
    /// A committable, disposable unit of work to which context objects can be added.
    /// </summary>
    public interface UnitOfWork : Committable, IDisposable
    {
        /// <summary>
        /// Registers a context object with this unit of work.
        /// </summary>
        /// <param name="context">The object to register (repositories are passed here by the command/query base classes).</param>
        /// <returns>A unit of work; <c>UnitOfWorkBase</c> returns itself so calls can be chained.</returns>
        UnitOfWork AddContext(object context);
    }


    
    /// <summary>
    /// Base <see cref="UnitOfWork"/> that keeps the set of registered context objects and
    /// commits on dispose. Derived classes decide what committing means and may release
    /// resources in <see cref="OnDisposedExplicitAfterCommit"/>.
    /// </summary>
    public abstract class UnitOfWorkBase : IDisposable, UnitOfWork
    {
        private readonly ICollection<object> contexts = new Collection<object>();

        // Guards against committing and cleaning up more than once.
        private bool disposed = false;

        /// <summary>The context objects registered so far, exactly as they were added.</summary>
        protected IEnumerable<object> Contexts
        { get { return contexts; } }


        /// <summary>
        /// Resolves every registered context to its root (following <see cref="ContextProvider.GetContext"/>
        /// until the result is no longer a <see cref="ContextProvider"/>) and returns the distinct roots
        /// in first-seen order.
        /// </summary>
        protected IEnumerable<object> GetUniqueContexts()
        {
            var uniqueContexts = new Collection<object>();

            foreach (var context in Contexts)
            {
                var rootContext = context;
                while (rootContext is ContextProvider)
                    rootContext = ((ContextProvider)rootContext).GetContext();

                // Several repositories usually share one domain context; keep it only once.
                if (!uniqueContexts.Contains(rootContext))
                    uniqueContexts.Add(rootContext);
            }

            return uniqueContexts;
        }


        /// <summary>Commits the work tracked by this unit of work.</summary>
        public abstract void Commit();


        /// <summary>
        /// Invoked by <see cref="Dispose"/> after the commit attempt and before the registered contexts
        /// are forgotten. The default does nothing; overrides release whatever the contexts hold.
        /// </summary>
        protected virtual void OnDisposedExplicitAfterCommit()
        { }


        /// <summary>Registers a context object; adding the same object twice has no effect.</summary>
        /// <param name="context">The object to register.</param>
        /// <returns>This instance, so calls can be chained.</returns>
        public UnitOfWork AddContext(object context)
        {
            if (!contexts.Contains(context))
                contexts.Add(context);
            return this;
        }


        /// <summary>
        /// Commits, runs <see cref="OnDisposedExplicitAfterCommit"/>, then clears the registered contexts.
        /// Calls after the first are ignored. If the commit throws, the cleanup hook still runs and the
        /// contexts are still cleared before the exception propagates.
        /// </summary>
        public void Dispose()
        {
            if (disposed)
                return;
            disposed = true;

            try
            {
                Commit();
            }
            finally
            {
                // Cleanup must not depend on the commit succeeding, or a failed commit leaks the contexts.
                try
                {
                    OnDisposedExplicitAfterCommit();
                }
                finally
                {
                    contexts.Clear();
                }
            }
        }
    }


    /// <summary>
    /// Unit of work that commits and disposes the distinct root contexts registered with it.
    /// </summary>
    public class ContextManagingUnitOfWork : UnitOfWorkBase
    {
        /// <summary>Calls <c>Commit</c> on every unique root context that implements <see cref="Committable"/>.</summary>
        public override void Commit()
        {
            foreach (var rootContext in GetUniqueContexts())
            {
                if (rootContext is Committable)
                    ((Committable)rootContext).Commit();
            }
        }

        /// <summary>
        /// Disposes every unique root context that implements <see cref="IDisposable"/>, then chains to the base hook.
        /// </summary>
        // NOTE(review): confirm where the base class declares and invokes this hook.
        protected override void OnDisposedExplicitAfterCommit()
        {
            foreach (var rootContext in GetUniqueContexts())
            {
                if (rootContext is IDisposable)
                    ((IDisposable)rootContext).Dispose();
            }

            base.OnDisposedExplicitAfterCommit();
        }
    }


    /// <summary>
    /// Base class for objects that are handed a <see cref="UnitOfWork"/> at construction
    /// and expose it to derived classes.
    /// </summary>
    public abstract class UnitOfWorkBasedObject
    {
        private readonly UnitOfWork unitOfWork;


        /// <summary>Stores the supplied unit of work as-is.</summary>
        /// <param name="unitOfWork">The unit of work derived classes will work against.</param>
        public UnitOfWorkBasedObject(UnitOfWork unitOfWork)
        {
            this.unitOfWork = unitOfWork;
        }


        /// <summary>The unit of work supplied to the constructor.</summary>
        protected UnitOfWork UnitOfWork
        {
            get { return unitOfWork; }
        }
    }


    /// <summary>
    /// Abstraction over the ORM context used by the repositories: entity-set access, entity state
    /// management, and raw command/query/function/stored-procedure execution. It is committable and
    /// disposable. Signatures use Entity Framework types (ObjectSet, ObjectResult, EntityState,
    /// ObjectParameter).
    /// </summary>
    public interface DomainContext : Committable, IDisposable
    {
        // Entity-set access and change-tracking state for a single entity.
        ObjectSet<TEntity> GetContextFor<TEntity>() where TEntity : class, new();
        void SetStateOf<TEntity>(TEntity entity, EntityState state) where TEntity : class, new();
        EntityState GetStateOf<TEntity>(TEntity entity) where TEntity : class, new();
        // Non-throwing variant: the bool result reports whether a state could be obtained.
        bool TryGetStateOf<TEntity>(TEntity entity, out EntityState state) where TEntity : class, new();

        // Executes a store command; ObjectContextBasedDomainContext forwards to ExecuteStoreCommand.
        int ExecuteCommand(string commandText, params object[] parameters);

        // Executes a store query. In ObjectContextBasedDomainContext, asReadOnly selects
        // MergeOption.NoTracking, otherwise MergeOption.AppendOnly.
        ObjectResult<TElement> ExecuteQuery<TElement>(string queryText, params object[] parameters) where TElement : new();
        ObjectResult<TEntity> ExecuteQuery<TEntity>(string queryText, string entitySetName, bool asReadOnly, params object[] parameters) where TEntity : class, new();

        // Executes a function by name; asReadOnly has the same meaning as for ExecuteQuery.
        int ExecuteFunction(string functionName, params ObjectParameter[] parameters);
        ObjectResult<TElement> ExecuteFunction<TElement>(string functionName, params ObjectParameter[] parameters) where TElement : new();
        ObjectResult<TEntity> ExecuteFunction<TEntity>(string functionName, bool asReadOnly, params ObjectParameter[] parameters) where TEntity : class, new();

        // Stored-procedure counterparts; ObjectContextBasedDomainContext routes these through ExecuteFunction as well.
        int ExecuteStoredProcedure(string storedProcedureName, params ObjectParameter[] parameters);
        ObjectResult<TElement> ExecuteStoredProcedure<TElement>(string storedProcedureName, params ObjectParameter[] parameters) where TElement : new();
        ObjectResult<TEntity> ExecuteStoredProcedure<TEntity>(string storedProcedureName, bool asReadOnly, params ObjectParameter[] parameters) where TEntity : class, new();
    }


    /// <summary>
    /// Read/write repository over entities of type <typeparamref name="TEntity"/>.
    /// Parameters that recur on most members:
    /// <c>asReadOnly</c>: request results that are not change-tracked (DomainContextBasedRepository
    /// maps this to MergeOption.NoTracking);
    /// <c>eagerlyLoad</c>: property expressions to load together with the entity
    /// (DomainContextBasedRepository turns each into an Include path); null means none;
    /// <c>orderBy</c>: optional function that applies an ordering to the query.
    /// </summary>
    public interface Repository<TEntity> where TEntity : class, new()
    {
        // Exposes the underlying query for further composition by the caller.
        IQueryable<TEntity> AsQueryable(bool asReadOnly = false, IEnumerable<Expression<Func<TEntity, object>>> eagerlyLoad = null);

        // Looks an entity up by its key values.
        TEntity Get(IEnumerable<object> keyValues, bool asReadOnly = false, IEnumerable<Expression<Func<TEntity, object>>> eagerlyLoad = null);
        TEntity Create();
        TDerivedEntity Create<TDerivedEntity>() where TDerivedEntity : class, TEntity, new();
        TEntity Add(TEntity entity);
        TEntity Remove(IEnumerable<object> keyValues);
        TEntity Remove(TEntity entity);
        TEntity Update(TEntity entity);

        // Entities matching the criteria, optionally ordered.
        IEnumerable<TEntity> Where(
            Expression<Func<TEntity, bool>> criteria,
            Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
            bool asReadOnly = false,
            IEnumerable<Expression<Func<TEntity, object>>> eagerlyLoad = null
        );

        // All entities, optionally ordered.
        IEnumerable<TEntity> GetAll(
            Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
            bool asReadOnly = false,
            IEnumerable<Expression<Func<TEntity, object>>> eagerlyLoad = null
        );
        

        // Single-result lookups named after the LINQ operators; criteria is optional (null by default).
        TEntity Single(Expression<Func<TEntity, bool>> criteria = null, bool asReadOnly = false, IEnumerable<Expression<Func<TEntity, object>>> eagerlyLoad = null);
        TEntity SingleOrDefault(Expression<Func<TEntity, bool>> criteria = null, bool asReadOnly = false, IEnumerable<Expression<Func<TEntity, object>>> eagerlyLoad = null);
        TEntity First(Expression<Func<TEntity, bool>> criteria = null, bool asReadOnly = false, IEnumerable<Expression<Func<TEntity, object>>> eagerlyLoad = null);
        TEntity FirstOrDefault(Expression<Func<TEntity, bool>> criteria = null, bool asReadOnly = false, IEnumerable<Expression<Func<TEntity, object>>> eagerlyLoad = null);
    }


    /// <summary>
    /// Query-only view of a repository: the read members of <c>Repository&lt;TEntity&gt;</c> without the
    /// <c>asReadOnly</c> switch and without Create/Add/Remove/Update. <c>eagerlyLoad</c> lists property
    /// expressions to load together with the entity (null means none); <c>orderBy</c> optionally
    /// applies an ordering.
    /// </summary>
    public interface ReadOnlyRepository<TEntity> where TEntity : class, new()
    {
        // Exposes the underlying query for further composition by the caller.
        IQueryable<TEntity> AsQueryable();

        // Looks an entity up by its key values.
        TEntity Get(IEnumerable<object> keyValues, IEnumerable<Expression<Func<TEntity, object>>> eagerlyLoad = null);
        
        // Entities matching the criteria, optionally ordered.
        IEnumerable<TEntity> Where(
            Expression<Func<TEntity, bool>> criteria,
            Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
            IEnumerable<Expression<Func<TEntity, object>>> eagerlyLoad = null
        );

        // All entities, optionally ordered.
        IEnumerable<TEntity> GetAll(
            Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
            IEnumerable<Expression<Func<TEntity, object>>> eagerlyLoad = null
        );


        // Single-result lookups named after the LINQ operators; criteria is optional (null by default).
        TEntity Single(Expression<Func<TEntity, bool>> criteria = null, IEnumerable<Expression<Func<TEntity, object>>> eagerlyLoad = null);
        TEntity SingleOrDefault(Expression<Func<TEntity, bool>> criteria = null, IEnumerable<Expression<Func<TEntity, object>>> eagerlyLoad = null);
        TEntity First(Expression<Func<TEntity, bool>> criteria = null, IEnumerable<Expression<Func<TEntity, object>>> eagerlyLoad = null);
        TEntity FirstOrDefault(Expression<Func<TEntity, bool>> criteria = null, IEnumerable<Expression<Func<TEntity, object>>> eagerlyLoad = null);
    }


    /// <summary>
    /// Creates repositories for an entity type, in read/write or read-only form.
    /// </summary>
    public interface RepositoryFactory
    {
        // Read/write repository for T.
        Repository<T> Get<T>() where T : class, new();
        // Read-only repository for T (DomainContextBasedRepositoryFactory wraps the read/write one).
        ReadOnlyRepository<T> GetReadOnly<T>() where T : class, new();
    }


    /// <summary>
    /// Adapts a <see cref="Repository{TEntity}"/> to <see cref="ReadOnlyRepository{TEntity}"/> by forwarding
    /// every call with <c>asReadOnly: true</c>. As a <see cref="ContextProvider"/> it reports the wrapped
    /// repository as its context.
    /// </summary>
    public class ReadOnlyRepositoryWrapper<TEntity> : ReadOnlyRepository<TEntity>, ContextProvider where TEntity : class, new()
    {
        private readonly Repository<TEntity> repository = null;


        /// <summary>Returns the wrapped repository's query in read-only mode.</summary>
        public IQueryable<TEntity> AsQueryable()
        { return repository.AsQueryable(asReadOnly: true); }

        /// <summary>Looks an entity up by key values in read-only mode.</summary>
        public TEntity Get(IEnumerable<object> keyValues, IEnumerable<Expression<Func<TEntity, object>>> eagerlyLoad = null)
        { return repository.Get(keyValues, asReadOnly: true, eagerlyLoad: eagerlyLoad); }

        /// <summary>Returns the entities matching <paramref name="criteria"/>, optionally ordered, in read-only mode.</summary>
        public IEnumerable<TEntity> Where(Expression<Func<TEntity, bool>> criteria, Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null, IEnumerable<Expression<Func<TEntity, object>>> eagerlyLoad = null)
        { return repository.Where(criteria, orderBy: orderBy, asReadOnly: true, eagerlyLoad: eagerlyLoad); }

        /// <summary>Returns all entities, optionally ordered, in read-only mode.</summary>
        public IEnumerable<TEntity> GetAll(Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null, IEnumerable<Expression<Func<TEntity, object>>> eagerlyLoad = null)
        { return repository.GetAll(orderBy: orderBy, asReadOnly: true, eagerlyLoad: eagerlyLoad); }

        // The four single-result members below declare "criteria = null" so that calls made through the
        // class type see the same optional parameter the ReadOnlyRepository interface declares.

        /// <summary>Forwards to the wrapped repository's <c>Single</c> in read-only mode.</summary>
        public TEntity Single(Expression<Func<TEntity, bool>> criteria = null, IEnumerable<Expression<Func<TEntity, object>>> eagerlyLoad = null)
        { return repository.Single(criteria: criteria, asReadOnly: true, eagerlyLoad: eagerlyLoad); }

        /// <summary>Forwards to the wrapped repository's <c>SingleOrDefault</c> in read-only mode.</summary>
        public TEntity SingleOrDefault(Expression<Func<TEntity, bool>> criteria = null, IEnumerable<Expression<Func<TEntity, object>>> eagerlyLoad = null)
        { return repository.SingleOrDefault(criteria: criteria, asReadOnly: true, eagerlyLoad: eagerlyLoad); }

        /// <summary>Forwards to the wrapped repository's <c>First</c> in read-only mode.</summary>
        public TEntity First(Expression<Func<TEntity, bool>> criteria = null, IEnumerable<Expression<Func<TEntity, object>>> eagerlyLoad = null)
        { return repository.First(criteria: criteria, asReadOnly: true, eagerlyLoad: eagerlyLoad); }

        /// <summary>Forwards to the wrapped repository's <c>FirstOrDefault</c> in read-only mode.</summary>
        public TEntity FirstOrDefault(Expression<Func<TEntity, bool>> criteria = null, IEnumerable<Expression<Func<TEntity, object>>> eagerlyLoad = null)
        { return repository.FirstOrDefault(criteria: criteria, asReadOnly: true, eagerlyLoad: eagerlyLoad); }


        /// <summary>Returns the wrapped repository.</summary>
        object ContextProvider.GetContext()
        { return repository; }


        /// <summary>Wraps the given read/write repository.</summary>
        /// <param name="repository">The repository all calls are forwarded to.</param>
        public ReadOnlyRepositoryWrapper(Repository<TEntity> repository)
        {
            this.repository = repository;
        }
    }


    /// <summary>
    /// <see cref="Repository{TEntity}"/> implemented on top of a <see cref="DomainContext"/>. Every public
    /// operation runs through the supplied ExceptionManager under <see cref="DefaultExceptionPolicy"/>.
    /// As a <see cref="ContextProvider"/> it reports the domain context as its context.
    /// </summary>
    public class DomainContextBasedRepository<TEntity> : Repository<TEntity>, ContextProvider where TEntity : class, new()
    {
        protected static readonly object KeyPropertiesLock = new object();

        // Key properties of TEntity in key order, discovered by the first constructed instance.
        // Static field of a generic class, so there is one cache per TEntity.
        protected static IEnumerable<PropertyInfo> KeyProperties = null;

        /// <summary>The domain context supplied to the constructor.</summary>
        protected DomainContext DomainContext
        { get; private set; }

        /// <summary>The exception manager supplied to the constructor.</summary>
        protected ExceptionManager ExceptionManager
        { get; private set; }

        /// <summary>Policy name used by the overloads that take no explicit policy; null after construction.</summary>
        protected string DefaultExceptionPolicy
        { get; set; }


        /// <summary>Runs <paramref name="action"/> through the exception manager under the given policy.</summary>
        protected void ProcessWithinExceptionPolicy(Action action, string exceptionPolicy)
        { ExceptionManager.Handle(action, exceptionPolicy); }

        /// <summary>
        /// Runs <paramref name="func"/> through the exception manager under the given policy and returns its
        /// result, or <c>default(TResult)</c> if the function never assigned one.
        /// </summary>
        protected TResult ProcessedWithinExceptionPolicy<TResult>(Func<TResult> func, string exceptionPolicy)
        {
            var result = default(TResult);
            ExceptionManager.Handle(() => result = func(), exceptionPolicy);
            return result;
        }


        /// <summary>Runs <paramref name="action"/> under <see cref="DefaultExceptionPolicy"/>.</summary>
        protected void ProcessWithinExceptionPolicy(Action action)
        { ProcessWithinExceptionPolicy(action, DefaultExceptionPolicy); }

        /// <summary>Runs <paramref name="func"/> under <see cref="DefaultExceptionPolicy"/> and returns its result.</summary>
        protected TResult ProcessedWithinExceptionPolicy<TResult>(Func<TResult> func)
        { return ProcessedWithinExceptionPolicy(func, DefaultExceptionPolicy); }


        /// <summary>
        /// Builds <c>entity => entity.Key1 == v1 &amp;&amp; entity.Key2 == v2 ...</c>, pairing
        /// <see cref="KeyProperties"/> with <paramref name="keyValues"/> positionally.
        /// </summary>
        /// <param name="keyValues">One value per key property, in key order.</param>
        /// <exception cref="ArgumentNullException"><paramref name="keyValues"/> is null.</exception>
        /// <exception cref="ArgumentException">The number of values differs from the number of key properties.</exception>
        /// <exception cref="InvalidOperationException">No key properties are known for the entity type.</exception>
        protected Expression<Func<TEntity, bool>> GetKeyComparisonExpression(IEnumerable<object> keyValues)
        {
            if (keyValues == null)
                throw new ArgumentNullException("keyValues");

            var keyProperties = KeyProperties.ToArray();
            var values = keyValues.ToArray();

            if (keyProperties.Length == 0)
                throw new InvalidOperationException("No key properties are known for " + typeof(TEntity).Name + ".");

            // A partial key would match more than one row, and Get/Remove would then act on whichever came first.
            if (values.Length != keyProperties.Length)
                throw new ArgumentException(
                    "Expected " + keyProperties.Length + " key value(s) for " + typeof(TEntity).Name + " but received " + values.Length + ".",
                    "keyValues"
                );

            var entityParam = Expression.Parameter(typeof(TEntity), "entity");

            Expression expression = null;
            for (var i = 0; i < keyProperties.Length; i++)
            {
                var equality = Expression.Equal(
                    Expression.Property(entityParam, keyProperties[i]),
                    Expression.Constant(values[i])
                );
                expression = (expression == null) ? equality : Expression.AndAlso(expression, equality);
            }

            return Expression.Lambda<Func<TEntity, bool>>(expression, entityParam);
        }


        /// <summary>
        /// Gets the entity set from the domain context, switching it to <c>MergeOption.NoTracking</c> when
        /// <paramref name="asReadOnly"/> is true.
        /// </summary>
        protected virtual ObjectSet<TEntity> GetObjectSet(bool asReadOnly = false)
        {
            var objectSet = DomainContext.GetContextFor<TEntity>();
            if (asReadOnly)
                objectSet.MergeOption = MergeOption.NoTracking;
            return objectSet;
        }

        /// <summary>Returns the entity-set query with an Include for each expression in <paramref name="eagerlyLoad"/> (null means none).</summary>
        protected virtual ObjectQuery<TEntity> GetQuery(bool asReadOnly = false, IEnumerable<Expression<Func<TEntity, object>>> eagerlyLoad = null)
        {
            var query = (ObjectQuery<TEntity>) GetObjectSet(asReadOnly);
            if (eagerlyLoad != null)
                foreach (var propertyToEagerLoad in eagerlyLoad)
                    query = query.Include(propertyToEagerLoad.AsPropertyPath());
            return query;
        }


        /// <summary>Same as the two-argument overload, additionally filtered by <paramref name="criteria"/> when it is not null.</summary>
        protected ObjectQuery<TEntity> GetQuery(Expression<Func<TEntity, bool>> criteria, bool asReadOnly, IEnumerable<Expression<Func<TEntity, object>>> eagerlyLoad)
        {
            var query = GetQuery(asReadOnly, eagerlyLoad);
            return (ObjectQuery<TEntity>)(criteria != null ? query.Where(criteria) : query);
        }


        /// <summary>Exposes the (optionally read-only, optionally eager-loading) query for further composition.</summary>
        public virtual IQueryable<TEntity> AsQueryable(bool asReadOnly = false, IEnumerable<Expression<Func<TEntity, object>>> eagerlyLoad = null)
        { return ProcessedWithinExceptionPolicy(() => GetQuery(asReadOnly, eagerlyLoad)); }


        /// <summary>Returns the first entity in <paramref name="objectSet"/> whose key matches, or null when none does.</summary>
        protected virtual TEntity Get(ObjectQuery<TEntity> objectSet, IEnumerable<object> keyValues)
        {
            return objectSet.Where(GetKeyComparisonExpression(keyValues)).FirstOrDefault();
        }


        /// <summary>Returns the entity with the given key values, or null when none matches.</summary>
        public TEntity Get(IEnumerable<object> keyValues, bool asReadOnly = false, IEnumerable<Expression<Func<TEntity, object>>> eagerlyLoad = null)
        {
            return ProcessedWithinExceptionPolicy(() => Where(GetKeyComparisonExpression(keyValues), asReadOnly: asReadOnly, eagerlyLoad: eagerlyLoad).FirstOrDefault());
        }

        /// <summary>Creates a new entity instance through the entity set; the instance is not added to the set.</summary>
        public virtual TEntity Create()
        { return ProcessedWithinExceptionPolicy(() => GetObjectSet().CreateObject()); }

        /// <summary>Creates a new instance of a derived entity type through the entity set.</summary>
        public virtual TDerivedEntity Create<TDerivedEntity>() where TDerivedEntity : class, TEntity, new()
        { return ProcessedWithinExceptionPolicy(() => GetObjectSet().CreateObject<TDerivedEntity>()); }

        /// <summary>Adds the entity to the entity set and returns it.</summary>
        public virtual TEntity Add(TEntity entity)
        {
            return ProcessedWithinExceptionPolicy(() =>
            {
                GetObjectSet().AddObject(entity);
                return entity;
            });
        }

        /// <summary>
        /// Looks the entity up by key and marks it for deletion.
        /// </summary>
        /// <returns>The entity marked for deletion, or null when no entity has the given key.</returns>
        public virtual TEntity Remove(IEnumerable<object> keyValues)
        {
            return ProcessedWithinExceptionPolicy(() =>
            {
                var objectSet = GetObjectSet();
                var entity = Get(objectSet, keyValues);

                // Nothing matched: there is nothing to delete, and DeleteObject rejects null.
                if (entity != null)
                    objectSet.DeleteObject(entity);
                return entity;
            });
        }

        /// <summary>Marks the entity for deletion, attaching it first when the context is not tracking it.</summary>
        public virtual TEntity Remove(TEntity entity)
        {
            return ProcessedWithinExceptionPolicy(() =>
            {
                var objectSet = GetObjectSet();

                // Use the non-throwing lookup: asking for the state of an entity the context has
                // never seen fails instead of reporting Detached.
                EntityState state;
                if (!DomainContext.TryGetStateOf(entity, out state) || state == EntityState.Detached)
                    objectSet.Attach(entity);

                objectSet.DeleteObject(entity);
                return entity;
            });
        }

        /// <summary>
        /// Attaches the entity and marks it modified; if that raises InvalidOperationException, falls back to
        /// applying the entity's current values through the entity set.
        /// </summary>
        public virtual TEntity Update(TEntity entity)
        {
            return ProcessedWithinExceptionPolicy(() =>
            {
                try
                {
                    GetObjectSet().Attach(entity);
                    DomainContext.SetStateOf(entity, EntityState.Modified);
                }
                catch (InvalidOperationException)
                { entity = GetObjectSet().ApplyCurrentValues(entity); }

                return entity;
            });
        }


        /// <summary>Returns the matching entities as a materialized array, ordered by <paramref name="orderBy"/> when supplied.</summary>
        public virtual IEnumerable<TEntity> Where(
            Expression<Func<TEntity, bool>> criteria,
            Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
            bool asReadOnly = false,
            IEnumerable<Expression<Func<TEntity, object>>> eagerlyLoad = null
        )
        {
            return ProcessedWithinExceptionPolicy(() =>
            {
                var query = GetQuery(criteria, asReadOnly, eagerlyLoad);
                return ((orderBy != null) ? orderBy(query) : query).ToArray();
            });
        }


        /// <summary>Returns all entities by delegating to <see cref="Where"/> with an always-true criteria.</summary>
        public virtual IEnumerable<TEntity> GetAll(
            Func<IQueryable<TEntity>, IOrderedQueryable<TEntity>> orderBy = null,
            bool asReadOnly = false,
            IEnumerable<Expression<Func<TEntity, object>>> eagerlyLoad = null
        )
        { return ProcessedWithinExceptionPolicy(() => Where(entity => true, orderBy, asReadOnly, eagerlyLoad)); }


        // In the four members below GetQuery has already applied criteria (when not null), so the
        // parameterless LINQ operator is used; the predicate overloads reject a null predicate.

        /// <summary>Returns the only matching entity; criteria is optional.</summary>
        public virtual TEntity Single(Expression<Func<TEntity, bool>> criteria = null, bool asReadOnly = false, IEnumerable<Expression<Func<TEntity, object>>> eagerlyLoad = null)
        { return ProcessedWithinExceptionPolicy(() => GetQuery(criteria, asReadOnly, eagerlyLoad).Single()); }


        /// <summary>Returns the only matching entity, or null when there is none; criteria is optional.</summary>
        public virtual TEntity SingleOrDefault(Expression<Func<TEntity, bool>> criteria = null, bool asReadOnly = false, IEnumerable<Expression<Func<TEntity, object>>> eagerlyLoad = null)
        { return ProcessedWithinExceptionPolicy(() => GetQuery(criteria, asReadOnly, eagerlyLoad).SingleOrDefault()); }


        /// <summary>Returns the first matching entity; criteria is optional.</summary>
        public virtual TEntity First(Expression<Func<TEntity, bool>> criteria = null, bool asReadOnly = false, IEnumerable<Expression<Func<TEntity, object>>> eagerlyLoad = null)
        { return ProcessedWithinExceptionPolicy(() => GetQuery(criteria, asReadOnly, eagerlyLoad).First()); }


        /// <summary>Returns the first matching entity, or null when there is none; criteria is optional.</summary>
        public virtual TEntity FirstOrDefault(Expression<Func<TEntity, bool>> criteria = null, bool asReadOnly = false, IEnumerable<Expression<Func<TEntity, object>>> eagerlyLoad = null)
        { return ProcessedWithinExceptionPolicy(() => GetQuery(criteria, asReadOnly, eagerlyLoad).FirstOrDefault()); }


        /// <summary>Returns the domain context this repository works against.</summary>
        object ContextProvider.GetContext()
        { return DomainContext; }



        /// <summary>
        /// Stores the collaborators and, for the first instance per entity type, discovers the key
        /// properties from the entity set's metadata, ordered by each property's ColumnAttribute.Order
        /// (0 when the attribute is absent).
        /// </summary>
        /// <param name="domainContext">The context all operations go through.</param>
        /// <param name="exceptionManager">Handles exceptions raised by the operations.</param>
        public DomainContextBasedRepository(DomainContext domainContext, ExceptionManager exceptionManager)
        {
            this.DomainContext = domainContext;
            this.ExceptionManager = exceptionManager;
            this.DefaultExceptionPolicy = null;

            if (KeyProperties == null)
            {
                lock (KeyPropertiesLock)
                {
                    if (KeyProperties == null)
                    {
                        var defaultKeyOrder = 0;

                        // NOTE(review): GetObjectSet is virtual and is called here from the constructor.
                        var keyProperties = GetObjectSet().EntitySet.ElementType.KeyMembers
                            .Select(m => typeof(TEntity).GetProperty(m.Name))
                            .Select(property =>
                            {
                                var columnAttribute = property.GetCustomAttribute<ColumnAttribute>();
                                return new Tuple<int, PropertyInfo>((columnAttribute == null) ? defaultKeyOrder : columnAttribute.Order, property);
                            });

                        // Materialize before publishing: a deferred query stored in the static cache
                        // would redo the reflection and sorting on every enumeration.
                        KeyProperties = keyProperties
                            .OrderBy(tuple => tuple.Item1)
                            .Select(tuple => tuple.Item2)
                            .ToArray();
                    }
                }
            }
        }
    }


    /// <summary>
    /// <see cref="RepositoryFactory"/> that builds repositories over one shared <see cref="DomainContext"/>
    /// and exception manager.
    /// </summary>
    public class DomainContextBasedRepositoryFactory : RepositoryFactory
    {
        private readonly DomainContext sharedContext;
        private readonly ExceptionManager sharedExceptionManager;


        /// <summary>Stores the collaborators handed to every repository this factory creates.</summary>
        /// <param name="domainContext">The context each created repository works against.</param>
        /// <param name="exceptionManager">The exception manager each created repository uses.</param>
        public DomainContextBasedRepositoryFactory(DomainContext domainContext, ExceptionManager exceptionManager)
        {
            sharedContext = domainContext;
            sharedExceptionManager = exceptionManager;
        }


        /// <summary>Creates a new read/write repository for <typeparamref name="T"/>.</summary>
        public Repository<T> Get<T>() where T : class, new()
        {
            return new DomainContextBasedRepository<T>(sharedContext, sharedExceptionManager);
        }

        /// <summary>Creates a new read/write repository for <typeparamref name="T"/> and wraps it in a read-only view.</summary>
        public ReadOnlyRepository<T> GetReadOnly<T>() where T : class, new()
        {
            var writable = Get<T>();
            return new ReadOnlyRepositoryWrapper<T>(writable);
        }
    }
    

    /// <summary>
    /// <see cref="DomainContext"/> implemented by forwarding to an Entity Framework ObjectContext that the
    /// derived class supplies through <see cref="ObjectContext"/>.
    /// </summary>
    public abstract class ObjectContextBasedDomainContext : ScopeLimitedObject, DomainContext
    {
        /// <summary>The ObjectContext every member forwards to.</summary>
        protected abstract ObjectContext ObjectContext { get; }


        /// <summary>Creates the entity set for <typeparamref name="TEntity"/>.</summary>
        public virtual ObjectSet<TEntity> GetContextFor<TEntity>() where TEntity : class, new()
        { return ObjectContext.CreateObjectSet<TEntity>(); }

        /// <summary>Changes the tracked state of <paramref name="entity"/>.</summary>
        public virtual void SetStateOf<TEntity>(TEntity entity, EntityState state) where TEntity : class, new()
        { ObjectContext.ObjectStateManager.ChangeObjectState(entity, state); }

        /// <summary>
        /// Returns the tracked state of <paramref name="entity"/>, or <c>EntityState.Detached</c> when the
        /// context is not tracking it.
        /// </summary>
        public virtual EntityState GetStateOf<TEntity>(TEntity entity) where TEntity : class, new()
        {
            // The state manager has no entry for untracked entities and the throwing lookup fails on them;
            // callers (see DomainContextBasedRepository.Remove) expect Detached in that case.
            EntityState state;
            return TryGetStateOf(entity, out state) ? state : EntityState.Detached;
        }

        /// <summary>
        /// Looks up the tracked state of <paramref name="entity"/> without throwing when it is not tracked.
        /// </summary>
        /// <returns>True when a state entry exists; otherwise false and <paramref name="entityState"/> is the default value.</returns>
        public bool TryGetStateOf<TEntity>(TEntity entity, out EntityState entityState) where TEntity : class, new()
        {
            var objectStateEntry = (ObjectStateEntry) null;
            var canOrNot = ObjectContext.ObjectStateManager.TryGetObjectStateEntry(entity, out objectStateEntry);
            entityState = canOrNot ? objectStateEntry.State : default(EntityState);
            return canOrNot;
        }


        /// <summary>Executes a store command and returns the ObjectContext's integer result.</summary>
        public int ExecuteCommand(string commandText, params object[] parameters)
        { return ObjectContext.ExecuteStoreCommand(commandText, parameters); }


        /// <summary>Executes a store query returning <typeparamref name="TElement"/> results.</summary>
        public ObjectResult<TElement> ExecuteQuery<TElement>(string queryText, params object[] parameters) where TElement : new()
        { return ObjectContext.ExecuteStoreQuery<TElement>(queryText, parameters); }

        /// <summary>
        /// Executes a store query against an entity set; <paramref name="asReadOnly"/> selects
        /// <c>MergeOption.NoTracking</c>, otherwise <c>MergeOption.AppendOnly</c>.
        /// </summary>
        public ObjectResult<TEntity> ExecuteQuery<TEntity>(string queryText, string entitySetName, bool asReadOnly, params object[] parameters) where TEntity : class, new()
        { return ObjectContext.ExecuteStoreQuery<TEntity>(queryText, entitySetName, asReadOnly ? MergeOption.NoTracking : MergeOption.AppendOnly, parameters); }


        /// <summary>Executes a function by name and returns the ObjectContext's integer result.</summary>
        public int ExecuteFunction(string functionName, params ObjectParameter[] parameters)
        { return ObjectContext.ExecuteFunction(functionName, parameters); }

        /// <summary>Executes a function by name returning <typeparamref name="TElement"/> results.</summary>
        public ObjectResult<TElement> ExecuteFunction<TElement>(string functionName, params ObjectParameter[] parameters) where TElement : new()
        { return ObjectContext.ExecuteFunction<TElement>(functionName, parameters); }

        /// <summary>Executes a function by name; <paramref name="asReadOnly"/> selects the merge option as for ExecuteQuery.</summary>
        public ObjectResult<TEntity> ExecuteFunction<TEntity>(string functionName, bool asReadOnly, params ObjectParameter[] parameters) where TEntity : class, new()
        { return ObjectContext.ExecuteFunction<TEntity>(functionName, asReadOnly ? MergeOption.NoTracking : MergeOption.AppendOnly, parameters); }


        // The stored-procedure members go through ObjectContext.ExecuteFunction, exactly like the function members.

        /// <summary>Executes a stored procedure by name and returns the ObjectContext's integer result.</summary>
        public int ExecuteStoredProcedure(string storedProcedureName, params ObjectParameter[] parameters)
        { return ObjectContext.ExecuteFunction(storedProcedureName, parameters); }

        /// <summary>Executes a stored procedure by name returning <typeparamref name="TElement"/> results.</summary>
        public ObjectResult<TElement> ExecuteStoredProcedure<TElement>(string storedProcedureName, params ObjectParameter[] parameters) where TElement : new()
        { return ObjectContext.ExecuteFunction<TElement>(storedProcedureName, parameters); }

        /// <summary>Executes a stored procedure by name; <paramref name="asReadOnly"/> selects the merge option as for ExecuteQuery.</summary>
        public ObjectResult<TEntity> ExecuteStoredProcedure<TEntity>(string storedProcedureName, bool asReadOnly, params ObjectParameter[] parameters) where TEntity : class, new()
        { return ObjectContext.ExecuteFunction<TEntity>(storedProcedureName, asReadOnly ? MergeOption.NoTracking : MergeOption.AppendOnly, parameters); }


        /// <summary>Saves the ObjectContext's pending changes.</summary>
        public virtual void Commit()
        { ObjectContext.SaveChanges(); }
    }
    

    // adapted from: http://www.thomaslevesque.com/2010/10/03/entity-framework-using-include-with-lambda-expressions/
    /// <summary>
    /// Walks an expression such as <c>x => x.Orders.Select(o => o.Lines)</c> and produces the delimited
    /// member path (<c>"Orders.Lines"</c>). The path is computed once, in the constructor, and exposed
    /// through <see cref="PropertyPath"/>.
    /// </summary>
    public class PropertyPathBuildingExpressionVisitor : ExpressionVisitor
    {
        public const string DefaultPropertyPathDelimiter = ".";

        // Non-null only while the path is being built; member names are pushed as they are visited.
        private Stack<string> propertyNameStack = null;


        /// <summary>Visits <paramref name="expression"/> immediately and stores the resulting path.</summary>
        /// <param name="expression">The expression to turn into a path.</param>
        /// <param name="delimitedBy">Separator between member names; null selects <see cref="DefaultPropertyPathDelimiter"/>.</param>
        public PropertyPathBuildingExpressionVisitor(Expression expression, string delimitedBy = null)
        {
            Expression = expression;
            BuildPropertyPath(delimitedBy ?? DefaultPropertyPathDelimiter);
        }


        /// <summary>The expression supplied to the constructor.</summary>
        public Expression Expression
        { get; private set; }

        /// <summary>The member names found in <see cref="Expression"/>, joined by the delimiter.</summary>
        public string PropertyPath
        { get; private set; }


        /// <summary>Records the member's name, then continues into the expression the member is accessed on.</summary>
        protected override Expression VisitMember(MemberExpression expression)
        {
            if (propertyNameStack != null)
                propertyNameStack.Push(expression.Member.Name);
            return base.VisitMember(expression);
        }


        /// <summary>
        /// For LINQ operators, visits the trailing arguments before the first (source) argument and returns
        /// the call unchanged; every other call gets the default visit.
        /// </summary>
        protected override Expression VisitMethodCall(MethodCallExpression expression)
        {
            if (!expression.Method.IsLinqOperator())
                return base.VisitMethodCall(expression);

            // Names are read back in reverse push order, so the source must be pushed last to come out first.
            var arguments = expression.Arguments;
            for (var position = 1; position < arguments.Count; position++)
                Visit(arguments[position]);
            Visit(arguments[0]);

            return expression;
        }


        // Runs the visit with a fresh stack, then joins the collected names in the stack's pop order.
        private void BuildPropertyPath(string delimiter)
        {
            propertyNameStack = new Stack<string>();

            Visit(Expression);

            var path = new StringBuilder();
            foreach (var memberName in propertyNameStack)
            {
                if (path.Length > 0)
                    path.Append(delimiter);
                path.Append(memberName);
            }
            PropertyPath = path.ToString();

            propertyNameStack = null;
        }
    }


    /// <summary>Extension methods on <see cref="Expression"/>.</summary>
    public static class ExpressionExtensions
    {
        /// <summary>
        /// Converts a member-access expression into its delimited property path using
        /// <see cref="PropertyPathBuildingExpressionVisitor"/>.
        /// </summary>
        /// <param name="expression">The expression to convert.</param>
        /// <param name="delimitedBy">Separator between member names; null lets the visitor use its default.</param>
        public static string AsPropertyPath(this Expression expression, string delimitedBy = null)
        {
            var pathBuilder = new PropertyPathBuildingExpressionVisitor(expression, delimitedBy);
            return pathBuilder.PropertyPath;
        }
    }
    

    /// <summary>
    /// Raised when an entity cannot be removed because of relationship constraints;
    /// DbContextBasedDomainContext.Commit throws it when a delete conflicts in the store.
    /// </summary>
    public class UnableToRemoveEntityDueToRelationshipConstraintsException : Exception
    {
        /// <summary>Creates the exception with no message and no inner exception.</summary>
        public UnableToRemoveEntityDueToRelationshipConstraintsException()
            : base()
        {
        }

        /// <summary>Creates the exception with the given message.</summary>
        public UnableToRemoveEntityDueToRelationshipConstraintsException(string message)
            : base(message)
        {
        }

        /// <summary>Creates the exception with the standard resource message and the given inner exception.</summary>
        public UnableToRemoveEntityDueToRelationshipConstraintsException(Exception exception)
            : this(Resources.CannotRemoveEntityDueToRelationshipConstraintsMessage, exception)
        {
        }

        /// <summary>Creates the exception with the given message and inner exception.</summary>
        public UnableToRemoveEntityDueToRelationshipConstraintsException(string message, Exception exception)
            : base(message, exception)
        {
        }
    }


    /// <summary>
    /// <see cref="ObjectContextBasedDomainContext"/> backed by a DbContext. Commits run through the
    /// exception manager, and store-level delete conflicts are translated into
    /// <see cref="UnableToRemoveEntityDueToRelationshipConstraintsException"/>.
    /// </summary>
    public class DbContextBasedDomainContext : ObjectContextBasedDomainContext
    {
        // Start of the inner exception message used to recognize a delete conflict.
        private static readonly string DeleteStatementConflictedMessagePrefix = "The DELETE statement conflicted";


        private DbContext dbContext = null;


        /// <summary>The ObjectContext underlying the wrapped DbContext.</summary>
        protected override ObjectContext ObjectContext
        { get { return ((IObjectContextAdapter)dbContext).ObjectContext; } }


        /// <summary>The exception manager supplied to the constructor.</summary>
        protected ExceptionManager ExceptionManager
        { get; private set; }

        /// <summary>Policy name passed to the exception manager by <see cref="Commit"/>.</summary>
        protected string DefaultExceptionPolicy
        { get; set; }


        /// <summary>Runs the base cleanup, then disposes and forgets the wrapped DbContext.</summary>
        protected override void OnDisposeExplicit()
        {
            base.OnDisposeExplicit();

            // The field is nulled below, so a repeated call must not dereference it.
            if (dbContext != null)
            {
                dbContext.Dispose();
                dbContext = null;
            }
        }


        /// <summary>
        /// Saves changes under <see cref="DefaultExceptionPolicy"/>. An UpdateException whose inner message
        /// starts with the delete-conflict prefix is rethrown as
        /// <see cref="UnableToRemoveEntityDueToRelationshipConstraintsException"/>; any other UpdateException
        /// is rethrown unchanged.
        /// </summary>
        public override void Commit()
        {
            ExceptionManager.Process(
                () =>
                {
                    try
                    { base.Commit(); }
                    catch (UpdateException ex)
                    {
                        // Ordinal comparison: this matches a fixed message prefix, not user-facing text.
                        if ((ex.InnerException != null) && (ex.InnerException.Message.StartsWith(DeleteStatementConflictedMessagePrefix, StringComparison.Ordinal)))
                            throw new UnableToRemoveEntityDueToRelationshipConstraintsException(ex);
                        else
                            throw;
                    }
                },
                DefaultExceptionPolicy
            );
        }


        /// <summary>
        /// Wraps the given DbContext. Both arguments are checked with Guard.AgainstNull, and the default
        /// exception policy is set to ShieldFromDataAccessExceptions.
        /// </summary>
        /// <param name="dbContext">The DbContext to wrap; disposed together with this instance.</param>
        /// <param name="exceptionManager">Processes exceptions raised during <see cref="Commit"/>.</param>
        public DbContextBasedDomainContext(DbContext dbContext, ExceptionManager exceptionManager)
        {
            Guard.AgainstNull(dbContext, "dbContext");
            Guard.AgainstNull(exceptionManager, "exceptionManager");

            this.dbContext = dbContext;
            this.ExceptionManager = exceptionManager;

            this.DefaultExceptionPolicy = Rollins.Framework.ExceptionPolicy.ShieldFromDataAccessExceptions; 
        }
    }


    /// <summary>
    /// Lambda-based <c>Include</c> overloads for <see cref="DbQuery{TResult}"/>: each expression is
    /// converted to a property path with <c>AsPropertyPath()</c> and handed to the string-based
    /// <c>Include</c>.
    /// </summary>
    public static class DbQueryExtensions
    {
        /// <summary>
        /// Applies one <c>Include</c> per expression, in enumeration order, and returns the resulting query.
        /// </summary>
        /// <param name="query">The query to extend.</param>
        /// <param name="propertyAccessingExpressions">Expressions selecting the properties to include.</param>
        /// <returns>The query returned by the last <c>Include</c> call, or <paramref name="query"/> itself when no expressions are given.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="query"/> or <paramref name="propertyAccessingExpressions"/> is null.</exception>
        public static DbQuery<TResult> Include<TResult>(this DbQuery<TResult> query, IEnumerable<Expression<Func<TResult, object>>> propertyAccessingExpressions)
        {
            // Fail with a named argument instead of a NullReferenceException from inside the loop.
            if (query == null)
                throw new ArgumentNullException("query");
            if (propertyAccessingExpressions == null)
                throw new ArgumentNullException("propertyAccessingExpressions");

            foreach (var propertyAccessingExpression in propertyAccessingExpressions)
                query = query.Include(propertyAccessingExpression.AsPropertyPath());
            return query;
        }

        /// <summary>
        /// <c>params</c> convenience overload; forwards to the <see cref="IEnumerable{T}"/> overload.
        /// </summary>
        /// <param name="query">The query to extend.</param>
        /// <param name="propertyAccessingExpressions">Expressions selecting the properties to include.</param>
        /// <returns>The query with every requested include applied.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="query"/> or <paramref name="propertyAccessingExpressions"/> is null.</exception>
        public static DbQuery<TResult> Include<TResult>(this DbQuery<TResult> query, params Expression<Func<TResult, object>>[] propertyAccessingExpressions)
        {
            // The cast selects the IEnumerable overload; without it this call would bind to itself.
            return query.Include((IEnumerable<Expression<Func<TResult, object>>>) propertyAccessingExpressions);
        }
    }
    

Friday, June 1, 2012

Some useful Entity Framework extension methods

As I've been working with the latest version of Entity Framework, I wanted to go ahead and share some useful code I've found online and created for myself.

Disclaimer: The code below carries no guarantees-- heck, it may not even work-- so use at your own risk.

    /// <summary>Reflection helpers for <see cref="MethodInfo"/>.</summary>
    public static class MethodInfoExtensions
    {
        // adapted from: http://www.thomaslevesque.com/2010/10/03/entity-framework-using-include-with-lambda-expressions/
        /// <summary>
        /// Tells whether a method is an extension method declared on <see cref="Queryable"/> or
        /// <see cref="Enumerable"/>.
        /// </summary>
        /// <param name="method">The method to inspect.</param>
        /// <returns>
        /// True when the declaring type is <see cref="Queryable"/> or <see cref="Enumerable"/> and the
        /// method carries <see cref="ExtensionAttribute"/>; otherwise false.
        /// </returns>
        public static bool IsLinqOperator(this MethodInfo method)
        {
            var declaringType = method.DeclaringType;
            var isDeclaredOnLinqType = (declaringType == typeof(Queryable)) || (declaringType == typeof(Enumerable));

            // The attribute is only looked up for methods that live on one of the two LINQ types.
            if (!isDeclaredOnLinqType)
                return false;

            return Attribute.IsDefined(method, typeof(ExtensionAttribute));
        }
    }


    // adapted from: http://www.thomaslevesque.com/2010/10/03/entity-framework-using-include-with-lambda-expressions/
    /// <summary>
    /// Walks an expression and records the name of every member access it visits, producing a
    /// delimited property path. The path is computed once, in the constructor.
    /// </summary>
    public class PropertyPathBuildingExpressionVisitor : ExpressionVisitor
    {
        public const string DefaultPropertyPathDelimiter = ".";

        // Holds member names while a walk is in progress; null at all other times.
        private Stack<string> propertyNameStack = null;


        /// <summary>The expression handed to the constructor.</summary>
        public Expression Expression
        { get; private set; }

        /// <summary>The member names collected from <see cref="Expression"/>, joined by the chosen delimiter.</summary>
        public string PropertyPath
        { get; private set; }


        /// <summary>
        /// Collects names while visiting a member access; the base visitor then continues into the
        /// expression the member is accessed on.
        /// </summary>
        protected override Expression VisitMember(MemberExpression expression)
        {
            var collectedNames = propertyNameStack;
            if (collectedNames != null)
                collectedNames.Push(expression.Member.Name);

            return base.VisitMember(expression);
        }


        /// <summary>
        /// For LINQ operators, visits every argument after the first and only then the first one,
        /// returning the call expression itself; any other method call is left to the base visitor.
        /// </summary>
        protected override Expression VisitMethodCall(MethodCallExpression expression)
        {
            if (!expression.Method.IsLinqOperator())
                return base.VisitMethodCall(expression);

            // Trailing arguments go first so that names from the first argument (the operator's
            // source) are pushed last and therefore come out first when the stack is enumerated.
            foreach (var trailingArgument in expression.Arguments.Skip(1))
                Visit(trailingArgument);
            Visit(expression.Arguments[0]);

            return expression;
        }


        /// <summary>Visits <see cref="Expression"/> and stores the joined member names in <see cref="PropertyPath"/>.</summary>
        private void BuildPropertyPath(string propertyPathDelimiter)
        {
            propertyNameStack = new Stack<string>();
            Visit(Expression);

            // Enumerating a stack yields the most recently pushed name first.
            PropertyPath = string.Join(propertyPathDelimiter, propertyNameStack);
            propertyNameStack = null;
        }


        /// <summary>
        /// Stores the expression and immediately builds its property path.
        /// </summary>
        /// <param name="expression">The expression to walk.</param>
        /// <param name="delimitedBy">Separator placed between names; <see cref="DefaultPropertyPathDelimiter"/> when null.</param>
        public PropertyPathBuildingExpressionVisitor(Expression expression, string delimitedBy = null)
        {
            Expression = expression;

            var delimiter = delimitedBy ?? DefaultPropertyPathDelimiter;
            BuildPropertyPath(delimiter);
        }
    }


    /// <summary>Shortcuts for turning an expression into a delimited property path.</summary>
    public static class ExpressionExtensions
    {
        /// <summary>
        /// Builds the property path of <paramref name="expression"/> with a
        /// <see cref="PropertyPathBuildingExpressionVisitor"/>.
        /// </summary>
        /// <param name="expression">The expression to convert.</param>
        /// <param name="delimitedBy">Separator between names; passed through to the visitor, which substitutes its default when null.</param>
        /// <returns>The visitor's <c>PropertyPath</c>.</returns>
        public static string AsPropertyPath(this Expression expression, string delimitedBy = null)
        {
            var pathBuilder = new PropertyPathBuildingExpressionVisitor(expression, delimitedBy);
            return pathBuilder.PropertyPath;
        }

        /// <summary>
        /// Builds the property path of <paramref name="expression"/> using a double underscore as the separator.
        /// </summary>
        /// <param name="expression">The expression to convert.</param>
        /// <returns>The double-underscore-delimited property path.</returns>
        public static string AsConventionalHtmlInputName(this Expression expression)
        {
            return AsPropertyPath(expression, "__");
        }
    }


    /// <summary>Type-hierarchy checks used to recognise Entity Framework configuration classes.</summary>
    public static class TypeExtensions
    {
        /// <summary>
        /// Tells whether any base type of <paramref name="type"/> is a generic type built from
        /// <paramref name="genericTypeDefinition"/>. Only base types are examined; <paramref name="type"/>
        /// itself is not.
        /// </summary>
        /// <param name="type">The type whose ancestors are searched; checked with <c>Guard.AgainstNull</c>.</param>
        /// <param name="genericTypeDefinition">The open generic type to look for; checked with <c>Guard.Against</c> to be a generic type definition.</param>
        /// <returns>True when a matching ancestor exists; otherwise false.</returns>
        public static bool IsSubclassOfGenericTypeWithDefinition(this Type type, Type genericTypeDefinition)
        {
            Guard.AgainstNull(type, "type");
            Guard.Against(genericTypeDefinition, gtd => !gtd.IsGenericTypeDefinition, "Must be a generic type definition", "genericTypeDefinition");

            // Climb the inheritance chain and stop at the first ancestor built from the definition.
            for (var ancestor = type.BaseType; ancestor != null; ancestor = ancestor.BaseType)
            {
                if (ancestor.IsGenericType && ancestor.GetGenericTypeDefinition() == genericTypeDefinition)
                    return true;
            }

            return false;
        }

        /// <summary>True when <paramref name="type"/> has <c>EntityTypeConfiguration&lt;&gt;</c> among its base types.</summary>
        public static bool IsEntityTypeConfiguration(this Type type)
        {
            return IsSubclassOfGenericTypeWithDefinition(type, typeof(EntityTypeConfiguration<>));
        }

        /// <summary>True when <paramref name="type"/> has <c>ComplexTypeConfiguration&lt;&gt;</c> among its base types.</summary>
        public static bool IsComplexTypeConfiguration(this Type type)
        {
            return IsSubclassOfGenericTypeWithDefinition(type, typeof(ComplexTypeConfiguration<>));
        }
    }


    /// <summary>
    /// Lambda-based <c>Include</c> overloads for <see cref="DbQuery{TResult}"/>: each expression is
    /// converted to a property path with <c>AsPropertyPath()</c> and handed to the string-based
    /// <c>Include</c>.
    /// </summary>
    public static class DbQueryExtensions
    {
        /// <summary>
        /// Applies one <c>Include</c> per expression, in enumeration order, and returns the resulting query.
        /// </summary>
        /// <param name="query">The query to extend.</param>
        /// <param name="propertyAccessingExpressions">Expressions selecting the properties to include.</param>
        /// <returns>The query returned by the last <c>Include</c> call, or <paramref name="query"/> itself when no expressions are given.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="query"/> or <paramref name="propertyAccessingExpressions"/> is null.</exception>
        public static DbQuery<TResult> Include<TResult>(this DbQuery<TResult> query, IEnumerable<Expression<Func<TResult, object>>> propertyAccessingExpressions)
        {
            // Fail with a named argument instead of a NullReferenceException from inside the loop.
            if (query == null)
                throw new ArgumentNullException("query");
            if (propertyAccessingExpressions == null)
                throw new ArgumentNullException("propertyAccessingExpressions");

            foreach (var propertyAccessingExpression in propertyAccessingExpressions)
                query = query.Include(propertyAccessingExpression.AsPropertyPath());
            return query;
        }

        /// <summary>
        /// <c>params</c> convenience overload; forwards to the <see cref="IEnumerable{T}"/> overload.
        /// </summary>
        /// <param name="query">The query to extend.</param>
        /// <param name="propertyAccessingExpressions">Expressions selecting the properties to include.</param>
        /// <returns>The query with every requested include applied.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="query"/> or <paramref name="propertyAccessingExpressions"/> is null.</exception>
        public static DbQuery<TResult> Include<TResult>(this DbQuery<TResult> query, params Expression<Func<TResult, object>>[] propertyAccessingExpressions)
        {
            // The cast selects the IEnumerable overload; without it this call would bind to itself.
            return query.Include((IEnumerable<Expression<Func<TResult, object>>>) propertyAccessingExpressions);
        }
    }


    /// <summary>
    /// Registers the Entity Framework configuration classes that are nested inside a given type
    /// with a <see cref="DbModelBuilder"/>.
    /// </summary>
    public static class DbModelBuilderExtensions
    {
        /// <summary>
        /// Finds the generic, single-parameter <c>Add</c> method on <see cref="ConfigurationRegistrar"/>
        /// whose parameter type is built from <paramref name="configurationType"/>.
        /// </summary>
        /// <param name="configurationType">Open generic configuration type, e.g. <c>EntityTypeConfiguration&lt;&gt;</c>.</param>
        /// <returns>The matching open generic method; <c>First</c> throws when none matches.</returns>
        private static MethodInfo GetConfigurationRegistrarAddMethodWithParameter(Type configurationType)
        {
            Guard.AgainstNull(configurationType, "configurationType");
            return typeof(ConfigurationRegistrar).GetMethods().First(method =>
            {
                // MakeGenericMethod is called on the result, so it must be an open generic method.
                if (!string.Equals(method.Name, "Add") || !method.IsGenericMethodDefinition)
                    return false;

                // The method is invoked with exactly one argument.
                var parameters = method.GetParameters();
                if (parameters.Length != 1)
                    return false;

                // GetGenericTypeDefinition throws on non-generic types, so test IsGenericType first.
                var parameterType = parameters[0].ParameterType;
                return parameterType.IsGenericType
                    && (parameterType.GetGenericTypeDefinition() == configurationType);
            });
        }

        private static readonly MethodInfo AddEntityTypeConfigurationMethod =
            GetConfigurationRegistrarAddMethodWithParameter(typeof(EntityTypeConfiguration<>));

        private static readonly MethodInfo AddComplexTypeConfigurationMethod =
            GetConfigurationRegistrarAddMethodWithParameter(typeof(ComplexTypeConfiguration<>));


        /// <summary>
        /// Returns the nearest base type of <paramref name="type"/> that is built from
        /// <paramref name="genericTypeDefinition"/>, or null when there is none.
        /// </summary>
        private static Type FindBaseTypeBuiltFrom(Type type, Type genericTypeDefinition)
        {
            for (var baseType = type.BaseType; baseType != null; baseType = baseType.BaseType)
            {
                if (baseType.IsGenericType && (baseType.GetGenericTypeDefinition() == genericTypeDefinition))
                    return baseType;
            }

            return null;
        }


        /// <summary>
        /// Instantiates every concrete nested type of <paramref name="type"/> that derives, directly or
        /// indirectly, from <paramref name="configurationTypeDefinition"/> and registers it via
        /// <paramref name="addMethod"/>.
        /// </summary>
        private static void AddNestedConfigurations(DbModelBuilder dbModelBuilder, Type type, Type configurationTypeDefinition, MethodInfo addMethod)
        {
            var nestedTypes = type.GetNestedTypes(BindingFlags.Public | BindingFlags.NonPublic | BindingFlags.DeclaredOnly);

            foreach (var nestedType in nestedTypes)
            {
                // Abstract or open generic types cannot be instantiated; they can only act as
                // shared bases for the concrete configurations, so they are skipped.
                if (nestedType.IsAbstract || nestedType.ContainsGenericParameters)
                    continue;

                // Read the configured type from the ancestor that is actually built from the
                // configuration definition. The immediate base type is not necessarily that ancestor.
                var configurationBaseType = FindBaseTypeBuiltFrom(nestedType, configurationTypeDefinition);
                if (configurationBaseType == null)
                    continue;

                addMethod.MakeGenericMethod(configurationBaseType.GenericTypeArguments.First())
                    .Invoke(dbModelBuilder.Configurations, new object[] { Activator.CreateInstance(nestedType) });
            }
        }


        /// <summary>
        /// Registers the configurations declared as nested types of <paramref name="type"/>:
        /// first all complex type configurations, then all entity type configurations.
        /// </summary>
        /// <param name="dbModelBuilder">The model builder whose <c>Configurations</c> registrar receives the instances.</param>
        /// <param name="type">The type whose public and non-public nested types are scanned (declared members only).</param>
        public static void AddConfigurationsDefinedWithin(this DbModelBuilder dbModelBuilder, Type type)
        {
            Guard.AgainstNull(dbModelBuilder, "dbModelBuilder");
            Guard.AgainstNull(type, "type");

            AddNestedConfigurations(dbModelBuilder, type, typeof(ComplexTypeConfiguration<>), AddComplexTypeConfigurationMethod);
            AddNestedConfigurations(dbModelBuilder, type, typeof(EntityTypeConfiguration<>), AddEntityTypeConfigurationMethod);
        }


        /// <summary>
        /// Registers the configurations nested inside the runtime type of <paramref name="object"/>.
        /// </summary>
        /// <param name="dbModelBuilder">The model builder to register with.</param>
        /// <param name="object">An instance whose <c>GetType()</c> result is scanned.</param>
        public static void AddConfigurationsDefinedWithin(this DbModelBuilder dbModelBuilder, object @object)
        {
            Guard.AgainstNull(@object, "object");
            dbModelBuilder.AddConfigurationsDefinedWithin(@object.GetType());
        }
    }