diff --git a/src/Microsoft.AspNetCore.OData/Formatter/Attributes/ReplaceIllegalFieldNameCharactersAttribute.cs b/src/Microsoft.AspNetCore.OData/Formatter/Attributes/ReplaceIllegalFieldNameCharactersAttribute.cs new file mode 100644 index 000000000..ca41feb73 --- /dev/null +++ b/src/Microsoft.AspNetCore.OData/Formatter/Attributes/ReplaceIllegalFieldNameCharactersAttribute.cs @@ -0,0 +1,63 @@ +using System; +using System.Linq; + +namespace Microsoft.AspNetCore.OData.Formatter.Attributes +{ + [AttributeUsage(AttributeTargets.Class, Inherited = false, AllowMultiple = false)] + sealed class ReplaceIllegalFieldNameCharactersAttribute : Attribute + { + //constant collection of illegal characters + private static readonly string[] illegalChars = new string[] { "@", ":", ".", "#" }; + public string ReplaceAt { get; } + public string ReplaceColon { get; } + public string ReplaceDot { get; } + public string ReplaceHash { get; } + + public ReplaceIllegalFieldNameCharactersAttribute(string replaceAt, string replaceColon, string replaceDot, string replaceHash) + { + //check if the replacement characters are not null + if (replaceAt == null || replaceColon == null || replaceDot == null) + { + throw new ArgumentNullException("Replacement characters cannot be null"); + } + + // check if any of the the replacement characters provided contain any of the illegal characters + // ex. if replaceAt contains any of the illegal characters checked one by one + if (illegalChars.Any(illegalChar => replaceAt.Contains(illegalChar) || replaceColon.Contains(illegalChar) || replaceDot.Contains(illegalChar))) + { + throw new ArgumentException("Replacement character cannot be an illegal character"); + } + + ReplaceAt = replaceAt; + ReplaceColon = replaceColon; + ReplaceDot = replaceDot; + ReplaceHash = replaceHash; + } + + public ReplaceIllegalFieldNameCharactersAttribute(string replaceAnyIllegal) + { + if (illegalChars.Any(illegalChar => replaceAnyIllegal.Contains(illegalChar))) + { + throw new ArgumentException("Replacement character cannot be an illegal character"); + } + + ReplaceAt = replaceAnyIllegal; + ReplaceColon = replaceAnyIllegal; + ReplaceDot = replaceAnyIllegal; + ReplaceHash = replaceAnyIllegal; + } + + public ReplaceIllegalFieldNameCharactersAttribute() + { + ReplaceAt = "_"; + ReplaceColon = "_"; + ReplaceDot = "_"; + ReplaceHash = "_"; + } + + public string Replace(string fieldName) + { + return fieldName.Replace("@", ReplaceAt).Replace(":", ReplaceColon).Replace(".", ReplaceDot).Replace("#",ReplaceHash); + } + } +} diff --git a/src/Microsoft.AspNetCore.OData/Formatter/Serialization/ODataResourceSerializer.cs b/src/Microsoft.AspNetCore.OData/Formatter/Serialization/ODataResourceSerializer.cs index 883c62687..b7f02755f 100644 --- a/src/Microsoft.AspNetCore.OData/Formatter/Serialization/ODataResourceSerializer.cs +++ b/src/Microsoft.AspNetCore.OData/Formatter/Serialization/ODataResourceSerializer.cs @@ -25,6 +25,7 @@ using Microsoft.AspNetCore.OData.Common; using System.Threading.Tasks; using Microsoft.AspNetCore.OData.Deltas; +using Microsoft.AspNetCore.OData.Formatter.Attributes; namespace Microsoft.AspNetCore.OData.Formatter.Serialization { @@ -551,6 +552,20 @@ public virtual ODataResource CreateResource(SelectExpandNode selectExpandNode, R // Try to add the dynamic properties if the structural type is open. AppendDynamicProperties(resource, selectExpandNode, resourceContext); + // check if the type is annotated with ReplaceIllegalFieldNameCharactersAttribute and replace the illegal characters in the field names + var resourceInstance = resourceContext.ResourceInstance; + if (resourceInstance != null) + { + var replaceIllegalFieldNameCharactersAttribute = resourceInstance.GetType().GetCustomAttribute(); + if (replaceIllegalFieldNameCharactersAttribute != null) + { + foreach (var property in resource.Properties) + { + property.Name = replaceIllegalFieldNameCharactersAttribute.Replace(property.Name); + } + } + } + if (selectExpandNode.SelectedActions != null) { IEnumerable actions = CreateODataActions(selectExpandNode.SelectedActions, resourceContext); diff --git a/test/Microsoft.AspNetCore.OData.Tests/Formatter/ODataOutFormatterTests.cs b/test/Microsoft.AspNetCore.OData.Tests/Formatter/ODataOutFormatterTests.cs index a54852ec1..eb00121b0 100644 --- a/test/Microsoft.AspNetCore.OData.Tests/Formatter/ODataOutFormatterTests.cs +++ b/test/Microsoft.AspNetCore.OData.Tests/Formatter/ODataOutFormatterTests.cs @@ -7,12 +7,15 @@ using System; using System.Collections.Generic; +using System.ComponentModel.DataAnnotations; +using System.IO; using System.Linq; using System.Text; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Mvc.Formatters; using Microsoft.AspNetCore.OData.Extensions; using Microsoft.AspNetCore.OData.Formatter; +using Microsoft.AspNetCore.OData.Formatter.Attributes; using Microsoft.AspNetCore.OData.Formatter.Value; using Microsoft.AspNetCore.OData.Tests.Commons; using Microsoft.AspNetCore.OData.Tests.Extensions; @@ -232,9 +235,107 @@ public void TryGetContentHeaderODataOutputFormatter_ThrowsArgumentNull_Type() ExceptionAssert.ThrowsArgumentNull(() => ODataOutputFormatter.TryGetContentHeader(null, null, out _), "type"); } + [Fact] + public void SerializeIllegalUnannotatedObject_ThrowsInvalidOperationException() + { + // Arrange + var illegalObject = new IllegalUnannotatedObject + { + DynamicProperties = new Dictionary + { + { "Inv@l:d.", 1 } + } + }; + + ODataConventionModelBuilder builder = new ODataConventionModelBuilder(); + builder.EntitySet("IllegalUnannotatedObjects"); + IEdmModel model = builder.GetEdmModel(); + IEdmEntitySet entitySet = model.EntityContainer.FindEntitySet("IllegalUnannotatedObjects"); + EntitySetSegment entitySetSeg = new EntitySetSegment(entitySet); + HttpRequest request = RequestFactory.Create(opt => opt.AddRouteComponents("odata", model)); + request.ODataFeature().RoutePrefix = "odata"; + request.ODataFeature().Model = model; + request.ODataFeature().Path = new ODataPath(entitySetSeg); + + OutputFormatterWriteContext context = new OutputFormatterWriteContext( + request.HttpContext, + (s, e) => null, + objectType: typeof(IllegalUnannotatedObject), + @object: illegalObject); + + ODataOutputFormatter formatter = new ODataOutputFormatter(new[] { ODataPayloadKind.Resource }); + formatter.SupportedMediaTypes.Add("application/json"); + + // Act & Assert + Assert.Throws(() => formatter.WriteResponseBodyAsync(context, Encoding.UTF8).GetAwaiter().GetResult()); + } + + // positive test as above + [Fact] + public void SerializeIllegalAnnotatedObject_ReturnsFixedValidObject() + { + // Arrange + var illegalObject = new IllegalAnnotatedObject + { + DynamicProperties = new Dictionary + { + { "I#v@l:d.", 1 } + } + }; + + ODataConventionModelBuilder builder = new ODataConventionModelBuilder(); + builder.EntitySet("IllegalAnnotatedObject"); + IEdmModel model = builder.GetEdmModel(); + IEdmEntitySet entitySet = model.EntityContainer.FindEntitySet("IllegalAnnotatedObject"); + EntitySetSegment entitySetSeg = new EntitySetSegment(entitySet); + HttpRequest request = RequestFactory.Create(opt => opt.AddRouteComponents("odata", model)); + request.ODataFeature().RoutePrefix = "odata"; + request.ODataFeature().Model = model; + request.ODataFeature().Path = new ODataPath(entitySetSeg); + + OutputFormatterWriteContext context = new OutputFormatterWriteContext( + request.HttpContext, + (s, e) => null, + objectType: typeof(IllegalAnnotatedObject), + @object: illegalObject); + + ODataOutputFormatter formatter = new ODataOutputFormatter(new[] { ODataPayloadKind.Resource }); + formatter.SupportedMediaTypes.Add("application/json"); + + // Set the Response.Body to a new MemoryStream to capture the response + var memoryStream = new MemoryStream(); + context.HttpContext.Response.Body = memoryStream; + + // Act + formatter.WriteResponseBodyAsync(context, Encoding.UTF8).GetAwaiter().GetResult(); + + memoryStream.Position = 0; + var content = new StreamReader(memoryStream).ReadToEnd(); + var jd = System.Text.Json.JsonDocument.Parse(content); + var root = jd.RootElement; + + // Assert + // check that the JSON response contains the fixed property name and its value is 1 + Assert.Equal(1, root.GetProperty("I_v_l_d_").GetInt32()); + } + private class Customer { public int Id { get; set; } } + + private class IllegalUnannotatedObject + { + [Key] + public int Id { get; set; } + public IDictionary DynamicProperties { get; set; } + } + + [ReplaceIllegalFieldNameCharacters] + private class IllegalAnnotatedObject : IllegalUnannotatedObject + { + + } + } }