
Sometimes we need to customize how models are generated in the database-first approach. For example, we may want all models to inherit from a base class that provides common functionality. One way to do this is with design-time services: Design-time services – EF Core | Microsoft Learn. We could also use ‘Custom Reverse Engineering Templates – EF Core | Microsoft Learn‘, but I found them unfriendly for development. First of all, Visual Studio 2022 has no built-in editor support for T4 templates (no syntax highlighting or IntelliSense without an extension), and we would have to learn a templating engine. It’s better to stay in the C# world as much as possible for faster development.

The docs don’t go into much detail, so in this post I will show how to implement IModelCodeGenerator. I took most of the implementation from the EF Core GitHub repo: GitHub – dotnet/efcore: EF Core is a modern object-database mapper for .NET. It supports LINQ queries, change tracking, updates, and schema migrations.

The repo contains the default implementations of these interfaces, so we can copy them and modify only the parts we need. I have done this for .NET 8 (EF Core 8); the internal source code may differ in other EF Core versions.

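I assume you have already installed the ‘Microsoft.EntityFrameworkCore.Design’ package as a direct (not transitive) dependency and updated the project file as described in the Microsoft docs. The Design package is a development dependency, so by default your code cannot reference its types; the project file change is what makes them available.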

Wrap all the design-time services code in an ‘#if DEBUG’ preprocessor directive so that it is compiled only into Debug builds.

First, implement the IDesignTimeServices interface and register the services you want to override. The EF Core tools discover this class in the startup project when you scaffold.

#if DEBUG
using Microsoft.EntityFrameworkCore.Design;
using Microsoft.EntityFrameworkCore.Scaffolding;
using Microsoft.Extensions.DependencyInjection; // IServiceCollection and AddSingleton

namespace CustomDesignTimeServices
{
    public class MyDesignTimeServices : IDesignTimeServices
    {
        public void ConfigureDesignTimeServices(IServiceCollection serviceCollection)
        {
            // Register the customized model code generator.
            serviceCollection.AddSingleton<IModelCodeGenerator, MyModelCodeGenerator>();
        }
    }
}
#endif

Next, implement the model code generator. Look for ‘MyEntityTypeGenerator’ in the code below: we only modify entity (model) generation. If you also want to customize the DbContext generation, replace ‘CSharpDbContextGenerator’ with your own generator in the same way.

#if DEBUG
using Microsoft.EntityFrameworkCore.Design;
using Microsoft.EntityFrameworkCore.Design.Internal;
using Microsoft.EntityFrameworkCore.Diagnostics;
using Microsoft.EntityFrameworkCore.Infrastructure;
using Microsoft.EntityFrameworkCore.Internal;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Scaffolding;
using Microsoft.EntityFrameworkCore.Scaffolding.Internal;
using System.CodeDom.Compiler;
using System.Diagnostics;

namespace CustomDesignTimeServices
{
    public class MyModelCodeGenerator : CSharpModelGenerator
    {
        private readonly IOperationReporter _reporter;
        private readonly IServiceProvider _serviceProvider;

        /// <summary>
        ///     This is an internal API that supports the Entity Framework Core infrastructure and not subject to
        ///     the same compatibility standards as public APIs. It may be changed or removed without notice in
        ///     any release. You should only use it directly in your code with extreme caution and knowing that
        ///     doing so can result in application failures when updating to a new Entity Framework Core release.
        /// </summary>
        public MyModelCodeGenerator(
            ModelCodeGeneratorDependencies dependencies,
            IOperationReporter reporter,
            IServiceProvider serviceProvider)
            : base(dependencies, reporter, serviceProvider)
        {
            // Debugger.Launch(); // uncomment to get the Just-In-Time debugger prompt when scaffolding starts
            _reporter = reporter;
            _serviceProvider = serviceProvider;
        }

        /// <summary>
        ///     This is an internal API that supports the Entity Framework Core infrastructure and not subject to
        ///     the same compatibility standards as public APIs. It may be changed or removed without notice in
        ///     any release. You should only use it directly in your code with extreme caution and knowing that
        ///     doing so can result in application failures when updating to a new Entity Framework Core release.
        /// </summary>
        public override string Language
            => "C#";

        /// <summary>
        ///     This is an internal API that supports the Entity Framework Core infrastructure and not subject to
        ///     the same compatibility standards as public APIs. It may be changed or removed without notice in
        ///     any release. You should only use it directly in your code with extreme caution and knowing that
        ///     doing so can result in application failures when updating to a new Entity Framework Core release.
        /// </summary>
        public override ScaffoldedModel GenerateModel(
            IModel model,
            ModelCodeGenerationOptions options)
        {
            if (options.ContextName == null)
            {
                throw new ArgumentException(
                    CoreStrings.ArgumentPropertyNull(nameof(options.ContextName), nameof(options)), nameof(options));
            }

            if (options.ConnectionString == null)
            {
                throw new ArgumentException(
                    CoreStrings.ArgumentPropertyNull(nameof(options.ConnectionString), nameof(options)), nameof(options));
            }

            var host = new TextTemplatingEngineHost(_serviceProvider);
            var contextTemplate = new CSharpDbContextGenerator { Host = host, Session = host.CreateSession() };
            contextTemplate.Session.Add("Model", model);
            contextTemplate.Session.Add("Options", options);
            contextTemplate.Session.Add("NamespaceHint", options.ContextNamespace ?? options.ModelNamespace);
            contextTemplate.Session.Add("ProjectDefaultNamespace", options.RootNamespace);
            contextTemplate.Initialize();

            var generatedCode = ProcessTemplate(contextTemplate);

            // output DbContext .cs file
            var dbContextFileName = options.ContextName + host.Extension;
            var resultingFiles = new ScaffoldedModel
            {
                ContextFile = new ScaffoldedFile
                {
                    Path = options.ContextDir != null
                        ? Path.Combine(options.ContextDir, dbContextFileName)
                        : dbContextFileName,
                    Code = generatedCode
                }
            };

            foreach (var entityType in model.GetEntityTypes())
            {
                host.Initialize();
                var entityTypeTemplate = new MyEntityTypeGenerator { Host = host, Session = host.CreateSession() };
                entityTypeTemplate.Session.Add("EntityType", entityType);
                entityTypeTemplate.Session.Add("Options", options);
                entityTypeTemplate.Session.Add("NamespaceHint", options.ModelNamespace);
                entityTypeTemplate.Session.Add("ProjectDefaultNamespace", options.RootNamespace);
                entityTypeTemplate.Initialize();
                generatedCode = ProcessTemplate(entityTypeTemplate);
                if (string.IsNullOrWhiteSpace(generatedCode))
                {
                    continue;
                }


                // output EntityType poco .cs file
                var entityTypeFileName = entityType.Name + host.Extension;
                resultingFiles.AdditionalFiles.Add(
                    new ScaffoldedFile { Path = entityTypeFileName, Code = generatedCode });
            }

            return resultingFiles;
        }

        private string ProcessTemplate(ITextTransformation transformation)
        {
            var output = transformation.TransformText();

            foreach (CompilerError error in transformation.Errors)
            {
                _reporter.Write(error);
            }

            if (transformation.Errors.HasErrors)
            {
                throw new OperationException(DesignStrings.ErrorGeneratingOutput(transformation.GetType().Name));
            }

            return output;
        }
    }
}
#endif

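Then implement the entity generator. Look for #MODIFIED in the code below. If an entity has all the properties of BaseEntity, the generated class inherits from BaseEntity<TKey> (with TKey taken from the entity’s Id type) and those properties are left out, because they are already declared in the base class.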
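The generator reflects over a BaseEntity<TKey> class that isn’t shown here. As a minimal sketch, assuming hypothetical audit columns CreatedAt and UpdatedAt and a hypothetical MyApp.Models namespace, such a base class might look like this. Unlike the design-time code, it must not be inside #if DEBUG, because the generated entities inherit from it in every build. The generator doesn’t add a using for it either, so keep it in the namespace your entities are scaffolded into (or add a global using), and reference that namespace from the generator.

```csharp
using System;

// Hypothetical namespace: use the namespace your entities are scaffolded into.
namespace MyApp.Models
{
    // Hypothetical base class. Its public instance properties are the ones the
    // generator looks for and leaves out of the generated entities.
    // Not wrapped in #if DEBUG: the generated entities need it in every build.
    public abstract class BaseEntity<TKey>
    {
        public TKey Id { get; set; } = default!;

        // Hypothetical audit columns; names must match the scaffolded property names exactly.
        public DateTime CreatedAt { get; set; }
        public DateTime? UpdatedAt { get; set; }
    }
}
```

With this class, a table with Id, Name, CreatedAt and UpdatedAt columns and an int key would be scaffolded as a class deriving from BaseEntity<int> that declares only Name (plus any navigations). The check compares property names only, so the column types should also match the base class property types.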

#if DEBUG
using Microsoft.CodeAnalysis;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Design;
using Microsoft.EntityFrameworkCore.Metadata;
using Microsoft.EntityFrameworkCore.Metadata.Internal;
using Microsoft.EntityFrameworkCore.Scaffolding;
using Microsoft.EntityFrameworkCore.Scaffolding.Internal;
using Microsoft.Extensions.DependencyInjection; // GetRequiredService
using Microsoft.Extensions.Options;
using Microsoft.VisualStudio.TextTemplating;
using System.ComponentModel.DataAnnotations;
using System.Reflection;
using System.Text;

namespace CustomDesignTimeServices
{
    public class MyEntityTypeGenerator : CSharpEntityTypeGenerator
    {
        /// <summary>
        /// Create the template output
        /// </summary>
        public override string TransformText()
        {

            if (EntityType.IsSimpleManyToManyJoinEntityType())
            {
                // Don't scaffold these
                return "";
            }

            var services = (IServiceProvider)Host;
            var annotationCodeGenerator = services.GetRequiredService<IAnnotationCodeGenerator>();
            var code = services.GetRequiredService<ICSharpHelper>();

            var usings = new List<string>
            {
                "System",
                "System.Collections.Generic"
            };

            if (Options.UseDataAnnotations)
            {
                usings.Add("System.ComponentModel.DataAnnotations");
                usings.Add("System.ComponentModel.DataAnnotations.Schema");
                usings.Add("Microsoft.EntityFrameworkCore");
            }

            if (!string.IsNullOrEmpty(NamespaceHint))
            {

                this.Write("namespace ");
                this.Write(this.ToStringHelper.ToStringWithCulture(NamespaceHint));
                this.Write(";\r\n\r\n");

            }

            if (!string.IsNullOrEmpty(EntityType.GetComment()))
            {

                this.Write("/// <summary>\r\n/// ");
                this.Write(this.ToStringHelper.ToStringWithCulture(code.XmlComment(EntityType.GetComment())));
                this.Write("\r\n/// </summary>\r\n");

            }

            if (Options.UseDataAnnotations)
            {
                foreach (var dataAnnotation in EntityType.GetDataAnnotations(annotationCodeGenerator))
                {

                    this.Write(this.ToStringHelper.ToStringWithCulture(code.Fragment(dataAnnotation)));
                    this.Write("\r\n");

                }
            }

// #MODIFIED from here
            var orderedProperties = EntityType.GetProperties().OrderBy(p => p.GetColumnOrder() ?? -1).ToList();

            this.Write("public partial class ");
            this.Write(this.ToStringHelper.ToStringWithCulture(EntityType.Name));


            var propertiesToSkip = new List<string>();
            propertiesToSkip.AddRange(typeof(BaseEntity<string>)
                                                .GetProperties(BindingFlags.Public | BindingFlags.Instance)
                                                .Select(p => p.Name)
                                                .ToList());

            var inheritBaseEntity = false;
            // Compare names exactly (ordinal), like the Id lookup and the skip check below.
            // A case-insensitive match here could accept an entity whose Id lookup then returns null.
            if (propertiesToSkip.All(a => orderedProperties.Any(p => p.Name == a)))
            {
                var idProperty = orderedProperties.First(f => f.Name == nameof(BaseEntity<string>.Id));
                var baseEntityType = this.ToStringHelper.ToStringWithCulture(code.Reference(idProperty.ClrType));
                this.Write($" : BaseEntity<{baseEntityType}>");
                inheritBaseEntity = true;
            }

            this.Write("\r\n{\r\n");

            var firstProperty = true;

            foreach (var property in orderedProperties)
            {

                if (inheritBaseEntity && propertiesToSkip.Contains(property.Name))
                {
                    continue;
                }
// #MODIFIED till here
                if (!firstProperty)
                {
                    WriteLine("");
                }

                if (!string.IsNullOrEmpty(property.GetComment()))
                {

                    this.Write("    /// <summary>\r\n    /// ");
                    this.Write(this.ToStringHelper.ToStringWithCulture(code.XmlComment(property.GetComment(), indent: 1)));
                    this.Write("\r\n    /// </summary>\r\n");

                }

                if (Options.UseDataAnnotations)
                {
                    var dataAnnotations = property.GetDataAnnotations(annotationCodeGenerator)
                        .Where(a => !(a.Type == typeof(RequiredAttribute) && Options.UseNullableReferenceTypes && !property.ClrType.IsValueType));
                    foreach (var dataAnnotation in dataAnnotations)
                    {

                        this.Write("    ");
                        this.Write(this.ToStringHelper.ToStringWithCulture(code.Fragment(dataAnnotation)));
                        this.Write("\r\n");

                    }
                }

                usings.AddRange(code.GetRequiredUsings(property.ClrType));

                var needsNullable = Options.UseNullableReferenceTypes && property.IsNullable && !property.ClrType.IsValueType;
                var needsInitializer = Options.UseNullableReferenceTypes && !property.IsNullable && !property.ClrType.IsValueType;

                this.Write("    public ");
                this.Write(this.ToStringHelper.ToStringWithCulture(code.Reference(property.ClrType)));
                this.Write(this.ToStringHelper.ToStringWithCulture(needsNullable ? "?" : ""));
                this.Write(" ");
                this.Write(this.ToStringHelper.ToStringWithCulture(property.Name));
                this.Write(" { get; set; }");
                this.Write(this.ToStringHelper.ToStringWithCulture(needsInitializer ? " = null!;" : ""));
                this.Write("\r\n");

                firstProperty = false;
            }

            foreach (var navigation in EntityType.GetNavigations())
            {
                WriteLine("");

                if (Options.UseDataAnnotations)
                {
                    foreach (var dataAnnotation in navigation.GetDataAnnotations(annotationCodeGenerator))
                    {

                        this.Write("    ");
                        this.Write(this.ToStringHelper.ToStringWithCulture(code.Fragment(dataAnnotation)));
                        this.Write("\r\n");

                    }
                }

                var targetType = navigation.TargetEntityType.Name;
                if (navigation.IsCollection)
                {

                    this.Write("    public virtual ICollection<");
                    this.Write(this.ToStringHelper.ToStringWithCulture(targetType));
                    this.Write("> ");
                    this.Write(this.ToStringHelper.ToStringWithCulture(navigation.Name));
                    this.Write(" { get; set; } = new List<");
                    this.Write(this.ToStringHelper.ToStringWithCulture(targetType));
                    this.Write(">();\r\n");

                }
                else
                {
                    var needsNullable = Options.UseNullableReferenceTypes && !(navigation.ForeignKey.IsRequired && navigation.IsOnDependent);
                    var needsInitializer = Options.UseNullableReferenceTypes && navigation.ForeignKey.IsRequired && navigation.IsOnDependent;

                    this.Write("    public virtual ");
                    this.Write(this.ToStringHelper.ToStringWithCulture(targetType));
                    this.Write(this.ToStringHelper.ToStringWithCulture(needsNullable ? "?" : ""));
                    this.Write(" ");
                    this.Write(this.ToStringHelper.ToStringWithCulture(navigation.Name));
                    this.Write(" { get; set; }");
                    this.Write(this.ToStringHelper.ToStringWithCulture(needsInitializer ? " = null!;" : ""));
                    this.Write("\r\n");

                }
            }

            foreach (var skipNavigation in EntityType.GetSkipNavigations())
            {
                WriteLine("");

                if (Options.UseDataAnnotations)
                {
                    foreach (var dataAnnotation in skipNavigation.GetDataAnnotations(annotationCodeGenerator))
                    {

                        this.Write("    ");
                        this.Write(this.ToStringHelper.ToStringWithCulture(code.Fragment(dataAnnotation)));
                        this.Write("\r\n");

                    }
                }

                this.Write("    public virtual ICollection<");
                this.Write(this.ToStringHelper.ToStringWithCulture(skipNavigation.TargetEntityType.Name));
                this.Write("> ");
                this.Write(this.ToStringHelper.ToStringWithCulture(skipNavigation.Name));
                this.Write(" { get; set; } = new List<");
                this.Write(this.ToStringHelper.ToStringWithCulture(skipNavigation.TargetEntityType.Name));
                this.Write(">();\r\n");

            }

            this.Write("}\r\n");

            var previousOutput = GenerationEnvironment;
            GenerationEnvironment = new StringBuilder();

            foreach (var ns in usings.Distinct().OrderBy(x => x, new NamespaceComparer()))
            {

                this.Write("using ");
                this.Write(this.ToStringHelper.ToStringWithCulture(ns));
                this.Write(";\r\n");

            }

            WriteLine("");

            GenerationEnvironment.Append(previousOutput);

            return this.GenerationEnvironment.ToString();
        }
        private ITextTemplatingEngineHost hostValue;
        /// <summary>
        /// The current host for the text templating engine
        /// </summary>
        public override ITextTemplatingEngineHost Host
        {
            get
            {
                return this.hostValue;
            }
            set
            {
                this.hostValue = value;
            }
        }

        private IEntityType _EntityTypeField;

        /// <summary>
        /// Access the EntityType parameter of the template.
        /// </summary>
        private IEntityType EntityType
        {
            get
            {
                return this._EntityTypeField;
            }
        }

        private ModelCodeGenerationOptions _OptionsField;

        /// <summary>
        /// Access the Options parameter of the template.
        /// </summary>
        private ModelCodeGenerationOptions Options
        {
            get
            {
                return this._OptionsField;
            }
        }

        private string _NamespaceHintField;

        /// <summary>
        /// Access the NamespaceHint parameter of the template.
        /// </summary>
        private string NamespaceHint
        {
            get
            {
                return this._NamespaceHintField;
            }
        }

        public override void Initialize()
        {
            if ((this.Errors.HasErrors == false))
            {
                bool EntityTypeValueAcquired = false;
                if (this.Session.ContainsKey("EntityType"))
                {
                    this._EntityTypeField = ((global::Microsoft.EntityFrameworkCore.Metadata.IEntityType)(this.Session["EntityType"]));
                    EntityTypeValueAcquired = true;
                }
                if ((EntityTypeValueAcquired == false))
                {
                    string parameterValue = this.Host.ResolveParameterValue("Property", "PropertyDirectiveProcessor", "EntityType");
                    if ((string.IsNullOrEmpty(parameterValue) == false))
                    {
                        global::System.ComponentModel.TypeConverter tc = global::System.ComponentModel.TypeDescriptor.GetConverter(typeof(global::Microsoft.EntityFrameworkCore.Metadata.IEntityType));
                        if (((tc != null)
                                    && tc.CanConvertFrom(typeof(string))))
                        {
                            this._EntityTypeField = ((global::Microsoft.EntityFrameworkCore.Metadata.IEntityType)(tc.ConvertFrom(parameterValue)));
                            EntityTypeValueAcquired = true;
                        }
                        else
                        {
                            this.Error("The type \'Microsoft.EntityFrameworkCore.Metadata.IEntityType\' of the parameter \'E" +
                                    "ntityType\' did not match the type of the data passed to the template.");
                        }
                    }
                }
                if ((EntityTypeValueAcquired == false))
                {
                    //object data = global::System.Runtime.Remoting.Messaging.CallContext.LogicalGetData("EntityType");
                    //if ((data != null))
                    //{
                    //    this._EntityTypeField = ((global::Microsoft.EntityFrameworkCore.Metadata.IEntityType)(data));
                    //}
                }
                bool OptionsValueAcquired = false;
                if (this.Session.ContainsKey("Options"))
                {
                    this._OptionsField = ((global::Microsoft.EntityFrameworkCore.Scaffolding.ModelCodeGenerationOptions)(this.Session["Options"]));
                    OptionsValueAcquired = true;
                }
                if ((OptionsValueAcquired == false))
                {
                    string parameterValue = this.Host.ResolveParameterValue("Property", "PropertyDirectiveProcessor", "Options");
                    if ((string.IsNullOrEmpty(parameterValue) == false))
                    {
                        global::System.ComponentModel.TypeConverter tc = global::System.ComponentModel.TypeDescriptor.GetConverter(typeof(global::Microsoft.EntityFrameworkCore.Scaffolding.ModelCodeGenerationOptions));
                        if (((tc != null)
                                    && tc.CanConvertFrom(typeof(string))))
                        {
                            this._OptionsField = ((global::Microsoft.EntityFrameworkCore.Scaffolding.ModelCodeGenerationOptions)(tc.ConvertFrom(parameterValue)));
                            OptionsValueAcquired = true;
                        }
                        else
                        {
                            this.Error("The type \'Microsoft.EntityFrameworkCore.Scaffolding.ModelCodeGenerationOptions\' o" +
                                    "f the parameter \'Options\' did not match the type of the data passed to the templ" +
                                    "ate.");
                        }
                    }
                }
                if ((OptionsValueAcquired == false))
                {
                    //object data = global::System.Runtime.Remoting.Messaging.CallContext.LogicalGetData("Options");
                    //if ((data != null))
                    //{
                    //    this._OptionsField = ((global::Microsoft.EntityFrameworkCore.Scaffolding.ModelCodeGenerationOptions)(data));
                    //}
                }
                bool NamespaceHintValueAcquired = false;
                if (this.Session.ContainsKey("NamespaceHint"))
                {
                    this._NamespaceHintField = ((string)(this.Session["NamespaceHint"]));
                    NamespaceHintValueAcquired = true;
                }
                if ((NamespaceHintValueAcquired == false))
                {
                    string parameterValue = this.Host.ResolveParameterValue("Property", "PropertyDirectiveProcessor", "NamespaceHint");
                    if ((string.IsNullOrEmpty(parameterValue) == false))
                    {
                        global::System.ComponentModel.TypeConverter tc = global::System.ComponentModel.TypeDescriptor.GetConverter(typeof(string));
                        if (((tc != null)
                                    && tc.CanConvertFrom(typeof(string))))
                        {
                            this._NamespaceHintField = ((string)(tc.ConvertFrom(parameterValue)));
                            NamespaceHintValueAcquired = true;
                        }
                        else
                        {
                            this.Error("The type \'System.String\' of the parameter \'NamespaceHint\' did not match the type " +
                                    "of the data passed to the template.");
                        }
                    }
                }
                if ((NamespaceHintValueAcquired == false))
                {
                    //object data = global::System.Runtime.Remoting.Messaging.CallContext.LogicalGetData("NamespaceHint");
                    //if ((data != null))
                    //{
                    //    this._NamespaceHintField = ((string)(data));
                    //}
                }


            }
        }
    }
}
#endif
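To see the #MODIFIED decision in isolation, here is a standalone sketch of the same rule without any EF Core types (the class and method names are hypothetical): an entity inherits the base class only when it has every public instance property of that class, and those properties are then skipped. Unlike the original check, this sketch compares names ordinally, which keeps it consistent with the exact-match lookups the generator uses for Id and for the skipped properties.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;

// Hypothetical helper mirroring the #MODIFIED logic of MyEntityTypeGenerator.
public static class BaseClassMatcher
{
    // Public instance property names of the base class (what the generator skips).
    public static IReadOnlyList<string> BasePropertyNames(Type baseType) =>
        baseType.GetProperties(BindingFlags.Public | BindingFlags.Instance)
                .Select(p => p.Name)
                .ToList();

    // True when every base property name appears among the entity's property names.
    // Ordinal comparison: the generated code must use exactly the base class names.
    public static bool ShouldInherit(Type baseType, IEnumerable<string> entityPropertyNames)
    {
        var names = new HashSet<string>(entityPropertyNames, StringComparer.Ordinal);
        return BasePropertyNames(baseType).All(names.Contains);
    }

    // Property names that still have to be written into the generated class body.
    public static IEnumerable<string> PropertiesToEmit(Type baseType, IEnumerable<string> entityPropertyNames)
    {
        var list = entityPropertyNames.ToList();
        if (!ShouldInherit(baseType, list))
        {
            return list; // no inheritance: every property is generated
        }

        var skip = new HashSet<string>(BasePropertyNames(baseType), StringComparer.Ordinal);
        return list.Where(name => !skip.Contains(name));
    }
}
```

For example, with the base type KeyValuePair<string, int> (public properties Key and Value), ShouldInherit returns true for the names Key, Value, Name, and PropertiesToEmit then yields only Name; for Key, Name it returns false and every name is kept.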

This should give you a starting point for further customization. All the source code is available in the EF Core GitHub repo; that’s the power of open source (which doesn’t have to mean free :-)). We can read the code and adapt it to our needs.

Now comes the debugging part. It’s not as simple as running the project in debug mode, because the scaffolding code runs inside the EF Core tools process. Instead, we use the Just-In-Time (JIT) debugging capabilities of Visual Studio. First, make sure you have installed the desktop development workload in the Visual Studio Installer, because it includes the JIT debugger. Then you may need to enable it (Debug using the Just-In-Time Debugger – Visual Studio (Windows) | Microsoft Learn).

Next, uncomment the Debugger.Launch() call in the MyModelCodeGenerator constructor, set your breakpoints, and start the scaffolding from the NuGet Package Manager Console. When Debugger.Launch() runs, you will be greeted with a dialog asking which debugger to use. You can choose the current Visual Studio instance or a new one, whichever you like 🙂 That’s all I have to share.
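Uncommenting Debugger.Launch() means editing the code every time you want to debug. A small alternative, sketched here as an assumption rather than anything EF Core provides, is to launch the debugger only when a hypothetical environment variable EF_SCAFFOLD_DEBUG is set. Put the class inside the same #if DEBUG block as the other design-time code.

```csharp
using System;
using System.Diagnostics;

// Hypothetical helper: launch the JIT debugger only when an environment variable asks for it.
public static class DesignTimeDebug
{
    // Hypothetical variable name; any name works as long as you set the same one.
    public const string VariableName = "EF_SCAFFOLD_DEBUG";

    // "1" or "true" (any casing) turns debugging on; anything else, including unset, leaves it off.
    public static bool IsRequested(string? value) =>
        value == "1" || string.Equals(value, "true", StringComparison.OrdinalIgnoreCase);

    public static void LaunchIfRequested()
    {
        if (IsRequested(Environment.GetEnvironmentVariable(VariableName)) && !Debugger.IsAttached)
        {
            // Shows the JIT "choose a debugger" prompt.
            Debugger.Launch();
        }
    }
}
```

Call DesignTimeDebug.LaunchIfRequested() in place of the commented-out Debugger.Launch() in the MyModelCodeGenerator constructor. Then set the variable in the environment that runs the EF Core tools, for example $env:EF_SCAFFOLD_DEBUG = "1" in PowerShell before running dotnet ef dbcontext scaffold; child processes inherit the environment, so the tools process should see it.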

Cheers and Peace out!!!