1+ using System ;
12using System . Collections . Generic ;
23using System . Collections . Immutable ;
34using System . Diagnostics . CodeAnalysis ;
@@ -290,6 +291,42 @@ static bool IsInitialized(
290291 return null ;
291292 }
292293
294+ static ITypeSymbol ? FindTypeArgumentFromIEnumerable (
295+ ITypeSymbol ? start ,
296+ Compilation compilation )
297+ {
298+ if ( start is not INamedTypeSymbol namedStart )
299+ return null ;
300+
301+ var iEnumerableOfT = TypeSymbolFactory . IEnumerableOfT ( compilation ) ;
302+ if ( iEnumerableOfT is null )
303+ return null ;
304+
305+ var current = namedStart ;
306+ while ( current != null )
307+ {
308+ // If this class implements IEnumerable<T>, return T type argument
309+ foreach ( var iface in current . AllInterfaces )
310+ if ( SymbolEqualityComparer . Default . Equals ( iface . OriginalDefinition , iEnumerableOfT ) )
311+ return iface . TypeArguments . FirstOrDefault ( ) ;
312+
313+ // Navigate to its Base types to investigate if they implement IEnumerable,
314+ // attempt to resolve the base type from the source syntax (handles `: ValidExamples`).
315+ var declaredBase = ResolveDeclaredBaseTypeFromSyntax ( current , compilation ) as INamedTypeSymbol ;
316+
317+ // Prefer declaredBase when available; otherwise fall back to the BaseType symbol.
318+ var next = declaredBase ?? current . BaseType ;
319+
320+ // Stop if we reached object
321+ if ( next ? . SpecialType == SpecialType . System_Object )
322+ break ;
323+
324+ current = next ;
325+ }
326+
327+ return null ;
328+ }
329+
293330 public static ( INamedTypeSymbol ? TestClass , ITypeSymbol ? MemberClass ) GetClassTypesForAttribute (
294331 AttributeArgumentListSyntax attributeList ,
295332 SemanticModel semanticModel ,
@@ -318,7 +355,7 @@ public static (INamedTypeSymbol? TestClass, ITypeSymbol? MemberClass) GetClassTy
318355 List < AttributeArgumentSyntax > arguments , SemanticModel semanticModel )
319356 {
320357 if ( arguments . Count > 1 )
321- return arguments . Select ( a => a . Expression ) . ToList ( ) ;
358+ return [ .. arguments . Select ( a => a . Expression ) ] ;
322359 if ( arguments . Count != 1 )
323360 return null ;
324361
@@ -672,6 +709,52 @@ static void ReportUseNameof(
672709 ) ;
673710 }
674711
712+ /// <summary>
713+ /// Resolve declared base type and find the first base class that implements IEnumerable
714+ /// (e.g. "ValidExample" from "class SubtypeValidExample : ValidExample {}")
715+ /// </summary>
716+ static ITypeSymbol ? ResolveDeclaredBaseTypeFromSyntax (
717+ ITypeSymbol type ,
718+ Compilation compilation )
719+ {
720+ if ( type is null )
721+ return null ;
722+
723+ // Look through source declarations for an explicit base type in the syntax.
724+ foreach ( var declaringReference in type . DeclaringSyntaxReferences )
725+ {
726+ var syntax = declaringReference . GetSyntax ( ) ;
727+ if ( syntax is ClassDeclarationSyntax cls && cls . BaseList is not null && cls . BaseList . Types . Count > 0 )
728+ {
729+ var baseTypeSyntax = cls . BaseList . Types . First ( ) . Type ;
730+ var baseTypeName = baseTypeSyntax . ToString ( ) ; // may be qualified
731+
732+ // Try direct metadata lookup first (works for namespace-qualified names)
733+ var byMetadata = compilation . GetTypeByMetadataName ( baseTypeName ) ;
734+ if ( byMetadata is not null )
735+ return byMetadata ;
736+
737+ // Fallback: search symbols by the simple name and try to match fully-qualified textual form.
738+ var simpleName = baseTypeName . Split ( '.' ) . Last ( ) ;
739+ var candidates =
740+ compilation
741+ . GetSymbolsWithName ( n => n == simpleName , SymbolFilter . Type )
742+ . OfType < INamedTypeSymbol > ( )
743+ . ToList ( ) ;
744+
745+ // Prefer exact textual match of declared name (handles namespace-qualified baseTypeSyntax)
746+ var exact = candidates . FirstOrDefault ( c => string . Equals ( c . ToDisplayString ( ) , baseTypeName , StringComparison . Ordinal ) ) ;
747+ if ( exact is not null )
748+ return exact ;
749+
750+ // Otherwise return the first candidate in the current compilation with the simple name
751+ return candidates . FirstOrDefault ( ) ;
752+ }
753+ }
754+
755+ return null ;
756+ }
757+
675758 static void VerifyDataMethodParameterUsage (
676759 SemanticModel semanticModel ,
677760 SyntaxNodeAnalysisContext context ,
@@ -801,6 +884,13 @@ static bool VerifyDataSourceReturnType(
801884
802885 var valid = iEnumerableOfObjectArrayType . IsAssignableFrom ( memberType ) ;
803886
887+ // Special‑case handling for IEnumerable<T> where T is not object[]. If T is any array type, it is assignable to object[] and therefore valid.
888+ var memberEnumerableType = memberType . GetEnumerableType ( ) ;
889+ if ( ! valid && memberEnumerableType is not null )
890+ {
891+ valid = memberEnumerableType . TypeKind == TypeKind . Array ;
892+ }
893+
804894 if ( ! valid && v3 && iAsyncEnumerableOfObjectArrayType is not null )
805895 valid = iAsyncEnumerableOfObjectArrayType . IsAssignableFrom ( memberType ) ;
806896
@@ -816,6 +906,15 @@ static bool VerifyDataSourceReturnType(
816906 if ( ! valid && v3 && iAsyncEnumerableOfTupleType is not null )
817907 valid = iAsyncEnumerableOfTupleType . IsAssignableFrom ( memberType ) ;
818908
909+ // If still invalid and the member is a class, check whether it implements IEnumerable<T> and verify that T is an array type.
910+ // This ensures the data can be safely converted to IEnumerable<object[]>.
911+ if ( ! valid && memberType . TypeKind == TypeKind . Class )
912+ {
913+ var resolvedIEnumerableTypeArgument = FindTypeArgumentFromIEnumerable ( memberType , compilation ) ;
914+ if ( resolvedIEnumerableTypeArgument is not null )
915+ valid = ( resolvedIEnumerableTypeArgument . TypeKind == TypeKind . Array ) ;
916+ }
917+
819918 if ( ! valid )
820919 ReportIncorrectReturnType (
821920 context ,
0 commit comments