Skip to content

Commit 8059c4a

Browse files
authored
Fix X1019: Added special-cases handling to correctly validate types convertible to IEnumerable<T[]> (#199)
1 parent 154d3f6 commit 8059c4a

File tree

2 files changed

+134
-1
lines changed

2 files changed

+134
-1
lines changed

src/xunit.analyzers.tests/Analyzers/X1000/MemberDataShouldReferenceValidMemberTests.cs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,16 +203,44 @@ public async ValueTask V2_and_V3()
203203
#pragma warning disable xUnit1053
204204
205205
using System;
206+
using System.Collections;
206207
using System.Collections.Generic;
207208
using System.Threading.Tasks;
208209
using Xunit;
209210
211+
public class NamedTypeForIEnumerableStringArray : IEnumerable<string[]>
212+
{
213+
private readonly List<string[]> _items = new List<string[]>();
214+
215+
public void Add(string sdl)
216+
{
217+
_items.Add(new string[] { sdl });
218+
}
219+
220+
public IEnumerator<string[]> GetEnumerator()
221+
{
222+
return _items.GetEnumerator();
223+
}
224+
225+
IEnumerator IEnumerable.GetEnumerator()
226+
{
227+
return GetEnumerator();
228+
}
229+
}
230+
231+
public class NamedSubtypeForIEnumerableStringArray : NamedTypeForIEnumerableStringArray {}
232+
210233
public class TestClass {
211234
public static IEnumerable<object> ObjectSource;
212235
public static object NakedObjectSource;
213236
public static object[] NakedObjectArraySource;
237+
public static object[][] NakedObjectMatrixSource;
214238
215239
public static IEnumerable<object[]> ObjectArraySource;
240+
public static IEnumerable<string[]> StringArraySource;
241+
public static NamedTypeForIEnumerableStringArray NamedTypeForIEnumerableStringArraySource;
242+
public static NamedSubtypeForIEnumerableStringArray NamedSubtypeForIEnumerableStringArraySource;
243+
216244
public static Task<IEnumerable<object[]>> TaskObjectArraySource;
217245
public static ValueTask<IEnumerable<object[]>> ValueTaskObjectArraySource;
218246
@@ -243,8 +271,13 @@ public class TestClass {
243271
[{|#0:MemberData(nameof(ObjectSource))|}]
244272
[{|#1:MemberData(nameof(NakedObjectSource))|}]
245273
[{|#2:MemberData(nameof(NakedObjectArraySource))|}]
274+
[MemberData(nameof(NakedObjectMatrixSource))]
246275
247276
[MemberData(nameof(ObjectArraySource))]
277+
[MemberData(nameof(StringArraySource))]
278+
[MemberData(nameof(NamedTypeForIEnumerableStringArraySource))]
279+
[MemberData(nameof(NamedSubtypeForIEnumerableStringArraySource))]
280+
248281
[{|#10:MemberData(nameof(TaskObjectArraySource))|}]
249282
[{|#11:MemberData(nameof(ValueTaskObjectArraySource))|}]
250283
@@ -1397,6 +1430,7 @@ public void TestMethod1(int _) { }
13971430
[{|#1:MemberData(nameof(PropertyUntypedData))|}]
13981431
[{|#2:MemberData(nameof(MethodUntypedData))|}]
13991432
[{|#3:MemberData(nameof(MethodWithArgsUntypedData), 42)|}]
1433+
14001434
public void TestMethod2(int _) { }
14011435
}
14021436
""";

src/xunit.analyzers/X1000/MemberDataShouldReferenceValidMember.cs

Lines changed: 100 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
using System;
12
using System.Collections.Generic;
23
using System.Collections.Immutable;
34
using System.Diagnostics.CodeAnalysis;
@@ -290,6 +291,42 @@ static bool IsInitialized(
290291
return null;
291292
}
292293

294+
static ITypeSymbol? FindTypeArgumentFromIEnumerable(
295+
ITypeSymbol? start,
296+
Compilation compilation)
297+
{
298+
if (start is not INamedTypeSymbol namedStart)
299+
return null;
300+
301+
var iEnumerableOfT = TypeSymbolFactory.IEnumerableOfT(compilation);
302+
if (iEnumerableOfT is null)
303+
return null;
304+
305+
var current = namedStart;
306+
while (current != null)
307+
{
308+
// If this class implements IEnumerable<T>, return T type argument
309+
foreach (var iface in current.AllInterfaces)
310+
if (SymbolEqualityComparer.Default.Equals(iface.OriginalDefinition, iEnumerableOfT))
311+
return iface.TypeArguments.FirstOrDefault();
312+
313+
// Navigate to its Base types to investigate if they implement IEnumerable,
314+
// attempt to resolve the base type from the source syntax (handles `: ValidExamples`).
315+
var declaredBase = ResolveDeclaredBaseTypeFromSyntax(current, compilation) as INamedTypeSymbol;
316+
317+
// Prefer declaredBase when available; otherwise fall back to the BaseType symbol.
318+
var next = declaredBase ?? current.BaseType;
319+
320+
// Stop if we reached object
321+
if (next?.SpecialType == SpecialType.System_Object)
322+
break;
323+
324+
current = next;
325+
}
326+
327+
return null;
328+
}
329+
293330
public static (INamedTypeSymbol? TestClass, ITypeSymbol? MemberClass) GetClassTypesForAttribute(
294331
AttributeArgumentListSyntax attributeList,
295332
SemanticModel semanticModel,
@@ -318,7 +355,7 @@ public static (INamedTypeSymbol? TestClass, ITypeSymbol? MemberClass) GetClassTy
318355
List<AttributeArgumentSyntax> arguments, SemanticModel semanticModel)
319356
{
320357
if (arguments.Count > 1)
321-
return arguments.Select(a => a.Expression).ToList();
358+
return [.. arguments.Select(a => a.Expression)];
322359
if (arguments.Count != 1)
323360
return null;
324361

@@ -672,6 +709,52 @@ static void ReportUseNameof(
672709
);
673710
}
674711

712+
/// <summary>
713+
/// Resolve declared base type and find the first base class that implements IEnumerable
714+
/// (e.g. "ValidExample" from "class SubtypeValidExample : ValidExample {}")
715+
/// </summary>
716+
static ITypeSymbol? ResolveDeclaredBaseTypeFromSyntax(
717+
ITypeSymbol type,
718+
Compilation compilation)
719+
{
720+
if (type is null)
721+
return null;
722+
723+
// Look through source declarations for an explicit base type in the syntax.
724+
foreach (var declaringReference in type.DeclaringSyntaxReferences)
725+
{
726+
var syntax = declaringReference.GetSyntax();
727+
if (syntax is ClassDeclarationSyntax cls && cls.BaseList is not null && cls.BaseList.Types.Count > 0)
728+
{
729+
var baseTypeSyntax = cls.BaseList.Types.First().Type;
730+
var baseTypeName = baseTypeSyntax.ToString(); // may be qualified
731+
732+
// Try direct metadata lookup first (works for namespace-qualified names)
733+
var byMetadata = compilation.GetTypeByMetadataName(baseTypeName);
734+
if (byMetadata is not null)
735+
return byMetadata;
736+
737+
// Fallback: search symbols by the simple name and try to match fully-qualified textual form.
738+
var simpleName = baseTypeName.Split('.').Last();
739+
var candidates =
740+
compilation
741+
.GetSymbolsWithName(n => n == simpleName, SymbolFilter.Type)
742+
.OfType<INamedTypeSymbol>()
743+
.ToList();
744+
745+
// Prefer exact textual match of declared name (handles namespace-qualified baseTypeSyntax)
746+
var exact = candidates.FirstOrDefault(c => string.Equals(c.ToDisplayString(), baseTypeName, StringComparison.Ordinal));
747+
if (exact is not null)
748+
return exact;
749+
750+
// Otherwise return the first candidate in the current compilation with the simple name
751+
return candidates.FirstOrDefault();
752+
}
753+
}
754+
755+
return null;
756+
}
757+
675758
static void VerifyDataMethodParameterUsage(
676759
SemanticModel semanticModel,
677760
SyntaxNodeAnalysisContext context,
@@ -801,6 +884,13 @@ static bool VerifyDataSourceReturnType(
801884

802885
var valid = iEnumerableOfObjectArrayType.IsAssignableFrom(memberType);
803886

887+
// Special‑case handling for IEnumerable<T> where T is not object[]. If T is any array type, it is assignable to object[] and therefore valid.
888+
var memberEnumerableType = memberType.GetEnumerableType();
889+
if (!valid && memberEnumerableType is not null)
890+
{
891+
valid = memberEnumerableType.TypeKind == TypeKind.Array;
892+
}
893+
804894
if (!valid && v3 && iAsyncEnumerableOfObjectArrayType is not null)
805895
valid = iAsyncEnumerableOfObjectArrayType.IsAssignableFrom(memberType);
806896

@@ -816,6 +906,15 @@ static bool VerifyDataSourceReturnType(
816906
if (!valid && v3 && iAsyncEnumerableOfTupleType is not null)
817907
valid = iAsyncEnumerableOfTupleType.IsAssignableFrom(memberType);
818908

909+
// If still invalid and the member is a class, check whether it implements IEnumerable<T> and verify that T is an array type.
910+
// This ensures the data can be safely converted to IEnumerable<object[]>.
911+
if (!valid && memberType.TypeKind == TypeKind.Class)
912+
{
913+
var resolvedIEnumerableTypeArgument = FindTypeArgumentFromIEnumerable(memberType, compilation);
914+
if (resolvedIEnumerableTypeArgument is not null)
915+
valid = (resolvedIEnumerableTypeArgument.TypeKind == TypeKind.Array);
916+
}
917+
819918
if (!valid)
820919
ReportIncorrectReturnType(
821920
context,

0 commit comments

Comments
 (0)