From a10de60f71f38edbc9e8849db04d6c66f3d0dece Mon Sep 17 00:00:00 2001 From: Daniel Abrahamsson Date: Mon, 12 Oct 2020 13:52:39 +0200 Subject: [PATCH] Return interface objects when iterating over interface collections --- src/runtime/classbase.cs | 19 +++++++++++++++++-- src/runtime/iterator.cs | 6 ++++-- src/tests/test_interface.py | 16 ++++++++++++++++ 3 files changed, 37 insertions(+), 4 deletions(-) diff --git a/src/runtime/classbase.cs b/src/runtime/classbase.cs index 66153fbe1..09adf5afe 100644 --- a/src/runtime/classbase.cs +++ b/src/runtime/classbase.cs @@ -1,5 +1,6 @@ using System; using System.Collections; +using System.Collections.Generic; using System.Diagnostics; using System.Runtime.InteropServices; using System.Runtime.Serialization; @@ -184,7 +185,6 @@ public static IntPtr tp_iter(IntPtr ob) var e = co.inst as IEnumerable; IEnumerator o; - if (e != null) { o = e.GetEnumerator(); @@ -199,7 +199,22 @@ public static IntPtr tp_iter(IntPtr ob) } } - return new Iterator(o).pyHandle; + var elemType = typeof(object); + var iterType = co.inst.GetType(); + foreach(var ifc in iterType.GetInterfaces()) + { + if (ifc.IsGenericType) + { + var genTypeDef = ifc.GetGenericTypeDefinition(); + if (genTypeDef == typeof(IEnumerable<>) || genTypeDef == typeof(IEnumerator<>)) + { + elemType = ifc.GetGenericArguments()[0]; + break; + } + } + } + + return new Iterator(o, elemType).pyHandle; } diff --git a/src/runtime/iterator.cs b/src/runtime/iterator.cs index f9cf10178..089e8538a 100644 --- a/src/runtime/iterator.cs +++ b/src/runtime/iterator.cs @@ -10,10 +10,12 @@ namespace Python.Runtime internal class Iterator : ExtensionType { private IEnumerator iter; + private Type elemType; - public Iterator(IEnumerator e) + public Iterator(IEnumerator e, Type elemType) { iter = e; + this.elemType = elemType; } @@ -41,7 +43,7 @@ public static IntPtr tp_iternext(IntPtr ob) return IntPtr.Zero; } object item = self.iter.Current; - return Converter.ToPythonImplicit(item); + return Converter.ToPython(item, self.elemType); } public static IntPtr tp_iter(IntPtr ob) diff --git a/src/tests/test_interface.py b/src/tests/test_interface.py index e6c6ba64b..4546471f2 100644 --- a/src/tests/test_interface.py +++ b/src/tests/test_interface.py @@ -120,3 +120,19 @@ def test_implementation_access(): assert 100 == i.__implementation__ assert clrVal == i.__raw_implementation__ assert i.__implementation__ != i.__raw_implementation__ + + +def test_interface_collection_iteration(): + """Test interface type is used when iterating over interface collection""" + import System + from System.Collections.Generic import List + elem = System.IComparable(System.Int32(100)) + typed_list = List[System.IComparable]() + typed_list.Add(elem) + for e in typed_list: + assert type(e).__name__ == "IComparable" + + untyped_list = System.Collections.ArrayList() + untyped_list.Add(elem) + for e in untyped_list: + assert type(e).__name__ == "int"