Token claims enrichment Azure AD B2C

Spread the love

Sometimes we need to add extra claims to the token while a user is signing in through the Azure AD B2C identity provider.

For that we can use API connectors.

First add the API connector. There will be two options for authentication:

  1. Basic
  2. Certificate

For this post we will be using Basic authentication

So, add your endpoint, username and password, and save. Then go to the user sign-in flow and select the API connector. Note that the API connector is called only on the first authentication; it is not called for silent token requests. Also, your endpoint must meet some requirements: TLS and cipher suite requirements – Azure AD B2C | Microsoft Learn

Now, in your .NET Core app, add the Basic authentication scheme and, if you need some additional requirements to be fulfilled while authorizing the caller, an authorization policy.

Create a basic authentication handler

/// <summary>
/// Scheme options for the API connector Basic authentication handler.
/// Declares no members of its own; it only gives the scheme a dedicated options type
/// derived from <see cref="AuthenticationSchemeOptions"/>.
/// </summary>
public class AzureAdConnectorBasicAuthenticationSchemeOptions : AuthenticationSchemeOptions { }

/// <summary>
/// Authenticates incoming requests that carry HTTP Basic credentials by comparing them with the
/// user name and password held in <see cref="AzureAdConnectorBasicAuthOptions"/>.
/// </summary>
/// <remarks>
/// Outcomes:
/// <list type="bullet">
/// <item>An Authorization header is present but none of its values uses the Basic scheme: <c>NoResult</c>.</item>
/// <item>The credentials decode and match the configured values: <c>Success</c> with an empty-claims identity
/// whose authentication type is the scheme name.</item>
/// <item>Anything else (header missing, malformed Base64, no ':' separator, mismatch): <c>Fail</c>.</item>
/// </list>
/// </remarks>
public class AzureAdConnectorBasicAuthenticationHandler(IOptionsMonitor<AzureAdConnectorBasicAuthenticationSchemeOptions> options, 
    ILoggerFactory logger, 
    UrlEncoder encoder, 
    ISystemClock clock,
    IOptionsMonitor<AzureAdConnectorBasicAuthOptions> azureAdOptions) : 
    AuthenticationHandler<AzureAdConnectorBasicAuthenticationSchemeOptions>(options, logger, encoder, clock)
{
    private readonly IOptionsMonitor<AzureAdConnectorBasicAuthOptions> _azureAdOptionsMonitor = azureAdOptions;

    /// <summary>
    /// Validates the Basic credentials of the current request.
    /// </summary>
    /// <returns>A completed task holding the authentication result described on the class.</returns>
    protected override Task<AuthenticateResult> HandleAuthenticateAsync()
    {
        if (Request.Headers.TryGetValue("Authorization", out Microsoft.Extensions.Primitives.StringValues value) && value.Count > 0)
        {
            // Scheme names are case-insensitive; the trailing space separates scheme from credentials.
            var basicPrefix = AuthenticationSchemes.Basic + " ";
            var authHeader = value.FirstOrDefault(h => h?.StartsWith(basicPrefix, StringComparison.OrdinalIgnoreCase) == true);

            if (authHeader is null)
            {
                return Task.FromResult(AuthenticateResult.NoResult());
            }

            // Strip the prefix from the header value that actually matched, not from the first value.
            var encodedCredentials = authHeader[basicPrefix.Length..].Trim();

            if (TryDecodeCredentials(encodedCredentials, out var userName, out var password))
            {
                var expected = _azureAdOptionsMonitor.CurrentValue;

                // Evaluate both comparisons (non-short-circuit '&') so timing does not reveal which one failed.
                var userNameMatches = SecretEquals(expected.UserName, userName);
                var passwordMatches = SecretEquals(expected.Password, password);

                if (userNameMatches & passwordMatches)
                {
                    var identity = new ClaimsIdentity(new List<Claim>(), Scheme.Name);

                    return Task.FromResult(AuthenticateResult.Success(new AuthenticationTicket(new ClaimsPrincipal(identity), Scheme.Name)));
                }
            }
        }

        // Never log the Authorization header itself: it contains the credentials.
        Logger.LogError("AzureADConnectorAuthenticationFailure. Basic credentials were missing, malformed or incorrect.");
        return Task.FromResult(AuthenticateResult.Fail("Token is not provided or is invalid"));
    }

    /// <summary>
    /// Decodes a Base64 "user:password" pair, splitting on the first ':' only so that
    /// passwords containing ':' are preserved.
    /// </summary>
    /// <param name="encodedCredentials">The Base64 text that followed the scheme name.</param>
    /// <param name="userName">The decoded user name, or an empty string on failure.</param>
    /// <param name="password">The decoded password, or an empty string on failure.</param>
    /// <returns><c>true</c> when the input was valid Base64 and contained a ':' separator.</returns>
    private static bool TryDecodeCredentials(string encodedCredentials, out string userName, out string password)
    {
        userName = string.Empty;
        password = string.Empty;

        if (string.IsNullOrWhiteSpace(encodedCredentials))
        {
            return false;
        }

        string decoded;
        try
        {
            decoded = Encoding.UTF8.GetString(Convert.FromBase64String(encodedCredentials));
        }
        catch (FormatException)
        {
            // Malformed Base64 is an authentication failure, not a server error.
            return false;
        }

        var separatorIndex = decoded.IndexOf(':');
        if (separatorIndex < 0)
        {
            return false;
        }

        userName = decoded[..separatorIndex];
        password = decoded[(separatorIndex + 1)..];
        return true;
    }

    /// <summary>
    /// Compares a configured secret with a supplied value byte-for-byte in fixed time.
    /// A null or empty configured secret never matches, so unconfigured credentials cannot authenticate.
    /// </summary>
    private static bool SecretEquals(string? expected, string actual)
    {
        if (string.IsNullOrEmpty(expected))
        {
            return false;
        }

        return System.Security.Cryptography.CryptographicOperations.FixedTimeEquals(
            Encoding.UTF8.GetBytes(expected),
            Encoding.UTF8.GetBytes(actual));
    }
}

AzureAdConnectorBasicAuthOptions contains the username and password.

You can add an Authorization policy too. You can have various checks here e.g. claims check or any header check etc.

e.g.

/// <summary>
/// Marker requirement for the API connector authorization policy.
/// It carries no data; it exists so a handler for this requirement type can be attached to a policy.
/// </summary>
public class AzureAdConnectorBasicAuthorizationRequirement : IAuthorizationRequirement { }

/// <summary>
/// Satisfies <see cref="AzureAdConnectorBasicAuthorizationRequirement"/> when the current HTTP user
/// has an authenticated identity whose authentication type equals <c>AuthenticationSchemes.Basic</c>.
/// </summary>
public class AzureAdConnectorBasicAuthorizationHandler(
    ILogger logger,
    IOptionsMonitor<AzureAdConnectorBasicAuthOptions> azureAdBasicOptions) 
    : AuthorizationHandler<AzureAdConnectorBasicAuthorizationRequirement>
{
    // Not read by the check below; kept so additional checks (claims, headers, ...) can use them.
    private readonly ILogger _logger = logger;
    private readonly AzureAdConnectorBasicAuthOptions _azureAdBasicOptions = azureAdBasicOptions.CurrentValue;

    /// <summary>
    /// Marks the requirement as succeeded when the resource is an HTTP context whose user was
    /// authenticated with the Basic scheme; otherwise leaves the requirement unmet.
    /// </summary>
    /// <param name="context">The authorization context; its resource is expected to be the HTTP context.</param>
    /// <param name="requirement">The requirement being evaluated.</param>
    /// <returns>A completed task; the method does no asynchronous work.</returns>
    protected override Task HandleRequirementAsync(AuthorizationHandlerContext context, AzureAdConnectorBasicAuthorizationRequirement requirement)
    {
        // Pattern-match instead of casting: a null or non-HTTP resource leaves the requirement unmet
        // rather than throwing InvalidCastException.
        if (context.Resource is Microsoft.AspNetCore.Http.HttpContext httpContext)
        {
            var authenticatedWithBasic = httpContext.User.Identities.Any(identity =>
                identity.IsAuthenticated
                && string.Equals(identity.AuthenticationType, AuthenticationSchemes.Basic, StringComparison.Ordinal));

            if (authenticatedWithBasic)
            {
                context.Succeed(requirement);
            }
        }

        return Task.CompletedTask;
    }
}

Register them:

 // Register the Basic scheme under the same constant the handler, the policy and the
 // [Authorize] attribute use, so the scheme name cannot drift between them.
 builder.Services.AddAuthentication(opt =>
 {
     opt.DefaultScheme = null;
 })
 .AddScheme<AzureAdConnectorBasicAuthenticationSchemeOptions, AzureAdConnectorBasicAuthenticationHandler>(AuthenticationSchemes.Basic, null);

 builder.Services.AddSingleton<IAuthorizationHandler, AzureAdConnectorBasicAuthorizationHandler>();

 // The policy name must match the one referenced by [Authorize(Policy = "...")] on the endpoint;
 // a mismatch makes the policy lookup fail when the endpoint is called.
 builder.Services.AddAuthorizationBuilder()
     .AddPolicy("Your policy name", policy => policy
         .AddAuthenticationSchemes(AuthenticationSchemes.Basic)
         .AddRequirements(new AzureAdConnectorBasicAuthorizationRequirement()));

Now the endpoint example:

Use the authorize attribute on controller/method:

[Authorize(Policy = "Your policy name", AuthenticationSchemes = AuthenticationSchemes.Basic)]

Method:

 /// <summary>
 /// API connector endpoint: validates the posted claims and returns the extra claim to add to the token.
 /// </summary>
 /// <param name="currentInformation">
 /// The JSON body deserialized to a dictionary; values are expected to be boxed <see cref="JsonElement"/>s.
 /// </param>
 /// <returns>
 /// <c>Ok(null)</c> when the body is missing, when "client_id" is absent, not a string or does not match the
 /// configured web application id, or when "objectId" is absent, not a string or blank;
 /// otherwise <c>Ok</c> with the extension claim, the requested API version and the "Continue" action.
 /// </returns>
 public async Task<IActionResult> GetUserAttributes([FromBody] Dictionary<string, object> currentInformation)
 {
     // do some validations
     if (currentInformation is null)
     {
         return Ok(null);
     }

     var clientId = ReadString(currentInformation, "client_id");

     if (clientId is null
         || !string.Equals(clientId, _azureAdOptions.Applications.Web.Id, StringComparison.OrdinalIgnoreCase))
     {
         return Ok(null);
     }

     // objectId identifies the signing-in user; a real implementation would look up (and await) the claim values for it.
     var objectId = ReadString(currentInformation, "objectId");

     if (string.IsNullOrWhiteSpace(objectId))
     {
         return Ok(null);
     }

     // Extension attribute names embed the extension app id without dashes.
     var extensionAppId = _azureAdOptions.Applications.Extension.Id.Replace("-", "");

     var propertyName = $"extension_{extensionAppId}_claimName";

     // create response
     var response = new Dictionary<string, Object?>
     {
         [propertyName] = JsonSerializer.Serialize(new List<string>()),
         ["Version"] = HttpContext.GetRequestedApiVersion()?.ToString(),
         ["Action"] = "Continue"
     };

     return Ok(response);
 }

 /// <summary>
 /// Reads <paramref name="key"/> from the payload as a JSON string.
 /// Returns null when the key is missing or the value is not a JSON string element,
 /// instead of throwing on an invalid cast or a non-string value kind.
 /// </summary>
 private static string? ReadString(Dictionary<string, object> payload, string key) =>
     payload.TryGetValue(key, out object? raw) && raw is JsonElement { ValueKind: JsonValueKind.String } element
         ? element.GetString()
         : null;

Now unit test it:

using Asp.Versioning;
using AutoFixture;
using AutoMapper;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Http.Features;
using Microsoft.AspNetCore.Mvc;
using Microsoft.Extensions.Options;
using Moq;
using System.Text.Json;

namespace Tests.Controllers
{

    /// <summary>
    /// Unit tests for the API connector endpoint <c>AzureADClaimsController.GetUserAttributes</c>.
    /// </summary>
    public class AzureADControllerTest
    {
        private readonly IFixture _fixture;
        private readonly Mock<IService> _serviceMock;
        private readonly Mock<ILogger> _loggerMock;
        private readonly AzureADClaimsController _sut;
        private readonly IOptions<AzureADOptions> _azureADOptions;

        public AzureADControllerTest()
        {
            _fixture = new Fixture();
            _serviceMock = _fixture.Freeze<Mock<IService>>();
            // NOTE(review): the options are created with defaults here; populate the application ids
            // so that they line up with the client ids used in the test below.
            _azureADOptions = Options.Create(new AzureADOptions());
            _loggerMock = new Mock<ILogger>();
            _sut = new AzureADClaimsController(_azureADOptions, _serviceMock.Object, _loggerMock.Object);
        }

        /// <summary>
        /// A request from the expected application yields a response body;
        /// a request from another application yields a null body.
        /// </summary>
        [Fact]
        public async Task GetUserClaims()
        {
            // Arrange

            // do the setup of your service functions

            // Bind the API-versioning feature to the same HttpContext the controller will see,
            // so the requested version is read from the context under test.
            var httpContext = new DefaultHttpContext();
            IApiVersioningFeature feature = new ApiVersioningFeature(httpContext);
            feature.RequestedApiVersion = new ApiVersion(2, 0);
            httpContext.Features.Set(feature);
            _sut.ControllerContext = new ControllerContext { HttpContext = httpContext };

            // Values are JsonElements, which is what model binding produces for Dictionary<string, object>.
            var matchingRequest = new Dictionary<string, object>
            {
                ["objectId"] = JsonSerializer.SerializeToElement("You user id"),
                ["client_id"] = JsonSerializer.SerializeToElement("your app id")
            };

            var otherApplicationRequest = new Dictionary<string, object>
            {
                ["objectId"] = JsonSerializer.SerializeToElement("3800baa3-92aa-4438-b473-7e829f4b6631"),
                ["client_id"] = JsonSerializer.SerializeToElement("4f8d7eed-1006-44c2-8776-8db2e682d816")
            };

            // Act
            var result1 = await _sut.GetUserAttributes(matchingRequest);
            var result2 = await _sut.GetUserAttributes(otherApplicationRequest);

            // Assert
            Assert.NotNull(((OkObjectResult)result1).Value);
            Assert.Null(((OkObjectResult)result2).Value);
        }
    }
}

I hope this gives you a clear idea of how it works, so that you can implement it in no time.

The code above was written with .NET 8.

Cheers and Peace out!!!