Skip to content

Assemblymanager thread safety #277

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Nov 17, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 112 additions & 29 deletions src/runtime/assemblymanager.cs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
using System;
using System.IO;
using System.Collections;
using System.Collections.Specialized;
using System.IO;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.Reflection;
using System.Reflection.Emit;
using System.Threading;

namespace Python.Runtime
{
Expand All @@ -15,12 +15,16 @@ namespace Python.Runtime
/// </summary>
internal class AssemblyManager
{
static Dictionary<string, Dictionary<Assembly, string>> namespaces;
// modified from event handlers below, potentially triggered from different .NET threads
// therefore this should be a ConcurrentDictionary
static ConcurrentDictionary<string, ConcurrentDictionary<Assembly, string>> namespaces;
//static Dictionary<string, Dictionary<string, string>> generics;
static AssemblyLoadEventHandler lhandler;
static ResolveEventHandler rhandler;
// updated only under GIL?
static Dictionary<string, int> probed;
static List<Assembly> assemblies;
// modified from event handlers below, potentially triggered from different .NET threads
static AssemblyList assemblies;
internal static List<string> pypath;

private AssemblyManager()
Expand All @@ -36,10 +40,10 @@ private AssemblyManager()
internal static void Initialize()
{
namespaces = new
Dictionary<string, Dictionary<Assembly, string>>(32);
ConcurrentDictionary<string, ConcurrentDictionary<Assembly, string>>();
probed = new Dictionary<string, int>(32);
//generics = new Dictionary<string, Dictionary<string, string>>();
assemblies = new List<Assembly>(16);
assemblies = new AssemblyList(16);
pypath = new List<string>(16);

AppDomain domain = AppDomain.CurrentDomain;
Expand Down Expand Up @@ -105,9 +109,8 @@ static void AssemblyLoadHandler(Object ob, AssemblyLoadEventArgs args)
static Assembly ResolveHandler(Object ob, ResolveEventArgs args)
{
string name = args.Name.ToLower();
for (int i = 0; i < assemblies.Count; i++)
foreach (Assembly a in assemblies)
{
Assembly a = (Assembly)assemblies[i];
string full = a.FullName.ToLower();
if (full.StartsWith(name))
{
Expand Down Expand Up @@ -266,9 +269,8 @@ public static Assembly LoadAssemblyFullPath(string name)

public static Assembly FindLoadedAssembly(string name)
{
for (int i = 0; i < assemblies.Count; i++)
foreach (Assembly a in assemblies)
{
Assembly a = (Assembly)assemblies[i];
if (a.GetName().Name == name)
{
return a;
Expand All @@ -295,15 +297,15 @@ public static bool LoadImplicit(string name, bool warn = true)
bool loaded = false;
string s = "";
Assembly lastAssembly = null;
HashSet<Assembly> assemblies = null;
HashSet<Assembly> assembliesSet = null;
for (int i = 0; i < names.Length; i++)
{
s = (i == 0) ? names[0] : s + "." + names[i];
if (!probed.ContainsKey(s))
{
if (assemblies == null)
if (assembliesSet == null)
{
assemblies = new HashSet<Assembly>(AppDomain.CurrentDomain.GetAssemblies());
assembliesSet = new HashSet<Assembly>(AppDomain.CurrentDomain.GetAssemblies());
}
Assembly a = FindLoadedAssembly(s);
if (a == null)
Expand All @@ -314,7 +316,7 @@ public static bool LoadImplicit(string name, bool warn = true)
{
a = LoadAssembly(s);
}
if (a != null && !assemblies.Contains(a))
if (a != null && !assembliesSet.Contains(a))
{
loaded = true;
lastAssembly = a;
Expand Down Expand Up @@ -362,16 +364,13 @@ internal static void ScanAssembly(Assembly assembly)
for (int n = 0; n < names.Length; n++)
{
s = (n == 0) ? names[0] : s + "." + names[n];
if (!namespaces.ContainsKey(s))
{
namespaces.Add(s, new Dictionary<Assembly, string>());
}
namespaces.TryAdd(s, new ConcurrentDictionary<Assembly, string>());
}
}

if (ns != null && !namespaces[ns].ContainsKey(assembly))
if (ns != null)
{
namespaces[ns].Add(assembly, String.Empty);
namespaces[ns].TryAdd(assembly, String.Empty);
}

if (ns != null && t.IsGenericTypeDefinition)
Expand All @@ -383,14 +382,12 @@ internal static void ScanAssembly(Assembly assembly)

public static AssemblyName[] ListAssemblies()
{
AssemblyName[] names = new AssemblyName[assemblies.Count];
Assembly assembly;
for (int i = 0; i < assemblies.Count; i++)
List<AssemblyName> names = new List<AssemblyName>(assemblies.Count);
foreach (Assembly assembly in assemblies)
{
assembly = assemblies[i];
names.SetValue(assembly.GetName(), i);
names.Add(assembly.GetName());
}
return names;
return names.ToArray();
}

//===================================================================
Expand Down Expand Up @@ -471,9 +468,8 @@ public static List<string> GetNames(string nsname)

public static Type LookupType(string qname)
{
for (int i = 0; i < assemblies.Count; i++)
foreach (Assembly assembly in assemblies)
{
Assembly assembly = (Assembly)assemblies[i];
Type type = assembly.GetType(qname);
if (type != null)
{
Expand All @@ -482,5 +478,92 @@ public static Type LookupType(string qname)
}
return null;
}

/// <summary>
/// Wrapper around List<Assembly> for thread safe access
/// </summary>
private class AssemblyList : IEnumerable<Assembly>{
private readonly List<Assembly> _list;
private readonly ReaderWriterLockSlim _lock;

public AssemblyList(int capacity) {
_list = new List<Assembly>(capacity);
_lock = new ReaderWriterLockSlim();
}

public int Count
{
get
{
_lock.EnterReadLock();
try {
return _list.Count;
}
finally {
_lock.ExitReadLock();
}
}
}

public void Add(Assembly assembly) {
_lock.EnterWriteLock();
try
{
_list.Add(assembly);
}
finally
{
_lock.ExitWriteLock();
}
}

public IEnumerator GetEnumerator()
{
return ((IEnumerable<Assembly>) this).GetEnumerator();
}

/// <summary>
/// Enumerator wrapping around <see cref="AssemblyList._list"/>'s enumerator.
/// Acquires and releases a read lock on <see cref="AssemblyList._lock"/> during enumeration
/// </summary>
private class Enumerator : IEnumerator<Assembly>
{
private readonly AssemblyList _assemblyList;

private readonly IEnumerator<Assembly> _listEnumerator;

public Enumerator(AssemblyList assemblyList)
{
_assemblyList = assemblyList;
_assemblyList._lock.EnterReadLock();
_listEnumerator = _assemblyList._list.GetEnumerator();
}

public void Dispose()
{
_listEnumerator.Dispose();
_assemblyList._lock.ExitReadLock();
}

public bool MoveNext()
{
return _listEnumerator.MoveNext();
}

public void Reset()
{
_listEnumerator.Reset();
}

public Assembly Current { get { return _listEnumerator.Current; } }

object IEnumerator.Current { get { return Current; } }
}

IEnumerator<Assembly> IEnumerable<Assembly>.GetEnumerator()
{
return new Enumerator(this);
}
}
}
}
1 change: 1 addition & 0 deletions src/testing/Python.Test.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@
<Compile Include="indexertest.cs" />
<Compile Include="interfacetest.cs" />
<Compile Include="methodtest.cs" />
<Compile Include="moduletest.cs" />
<Compile Include="propertytest.cs" />
<Compile Include="threadtest.cs" />
<Compile Include="doctest.cs" />
Expand Down
25 changes: 25 additions & 0 deletions src/testing/moduletest.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
using System;
using System.Threading;

namespace Python.Test {
public class ModuleTest {
private static Thread _thread;

public static void RunThreads()
{
_thread = new Thread(() => {
var appdomain = AppDomain.CurrentDomain;
var assemblies = appdomain.GetAssemblies();
foreach (var assembly in assemblies) {
assembly.GetTypes();
}
});
_thread.Start();
}

public static void JoinThreads()
{
_thread.Join();
}
}
}
23 changes: 21 additions & 2 deletions src/tests/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,11 @@ def testModuleInterface(self):
import System
self.assertEquals(type(System.__dict__), type({}))
self.assertEquals(System.__name__, 'System')
# the filename can be any module from the System namespace (eg System.Data.dll or System.dll)
self.assertTrue(fnmatch(System.__file__, "*System*.dll"))
# the filename can be any module from the System namespace
# (eg System.Data.dll or System.dll, but also mscorlib.dll)
system_file = System.__file__
self.assertTrue(fnmatch(system_file, "*System*.dll") or fnmatch(system_file, "*mscorlib.dll"),
"unexpected System.__file__: " + system_file)
self.assertTrue(System.__doc__.startswith("Namespace containing types from the following assemblies:"))
self.assertTrue(self.isCLRClass(System.String))
self.assertTrue(self.isCLRClass(System.Int32))
Expand Down Expand Up @@ -353,6 +356,22 @@ def test_ClrAddReference(self):
self.assertRaises(FileNotFoundException,
AddReference, "somethingtotallysilly")

def test_AssemblyLoadThreadSafety(self):
import time
from Python.Test import ModuleTest
# spin up .NET thread which loads assemblies and triggers AppDomain.AssemblyLoad event
ModuleTest.RunThreads()
time.sleep(1e-5)
for i in range(1, 100):
# call import clr, which in AssemblyManager.GetNames iterates through the loaded types
import clr
# import some .NET types
from System import DateTime
from System import Guid
from System.Collections.Generic import Dictionary
dict = Dictionary[Guid,DateTime]()
ModuleTest.JoinThreads()


def test_suite():
return unittest.makeSuite(ModuleTests)
Expand Down