Skip to content

Assemblymanager thread safety #277

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Nov 17, 2016
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Wrap assemblies in thread-safe class, make sure forked thread joins i…
…n test_AssemblyLoadThreadSafety
  • Loading branch information
abessen committed Oct 27, 2016
commit 2742b1287035c97a3cf5d1321bd74f7259604c79
141 changes: 98 additions & 43 deletions src/runtime/assemblymanager.cs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
using System;
using System.Collections;
using System.IO;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.Reflection;
using System.Threading;

namespace Python.Runtime
{
Expand All @@ -22,8 +24,7 @@ internal class AssemblyManager
// updated only under GIL?
static Dictionary<string, int> probed;
// modified from event handlers below, potentially triggered from different .NET threads
// we guard access to assemblies via lock(assemblies) { ... } blocks
static List<Assembly> assemblies;
static AssemblyList assemblies;
internal static List<string> pypath;

private AssemblyManager()
Expand All @@ -42,7 +43,7 @@ internal static void Initialize()
ConcurrentDictionary<string, ConcurrentDictionary<Assembly, string>>();
probed = new Dictionary<string, int>(32);
//generics = new Dictionary<string, Dictionary<string, string>>();
assemblies = new List<Assembly>(16);
assemblies = new AssemblyList(16);
pypath = new List<string>(16);

AppDomain domain = AppDomain.CurrentDomain;
Expand Down Expand Up @@ -92,10 +93,7 @@ internal static void Shutdown()
static void AssemblyLoadHandler(Object ob, AssemblyLoadEventArgs args)
{
Assembly assembly = args.LoadedAssembly;
lock (assemblies)
{
assemblies.Add(assembly);
}
assemblies.Add(assembly);
ScanAssembly(assembly);
}

Expand All @@ -111,16 +109,12 @@ static void AssemblyLoadHandler(Object ob, AssemblyLoadEventArgs args)
static Assembly ResolveHandler(Object ob, ResolveEventArgs args)
{
string name = args.Name.ToLower();
lock (assemblies)
foreach (Assembly a in assemblies)
{
for (int i = 0; i < assemblies.Count; i++)
string full = a.FullName.ToLower();
if (full.StartsWith(name))
{
Assembly a = (Assembly) assemblies[i];
string full = a.FullName.ToLower();
if (full.StartsWith(name))
{
return a;
}
return a;
}
}
return LoadAssemblyPath(args.Name);
Expand Down Expand Up @@ -275,15 +269,11 @@ public static Assembly LoadAssemblyFullPath(string name)

public static Assembly FindLoadedAssembly(string name)
{
lock (assemblies)
foreach (Assembly a in assemblies)
{
for (int i = 0; i < assemblies.Count; i++)
if (a.GetName().Name == name)
{
Assembly a = (Assembly) assemblies[i];
if (a.GetName().Name == name)
{
return a;
}
return a;
}
}
return null;
Expand All @@ -307,15 +297,15 @@ public static bool LoadImplicit(string name, bool warn = true)
bool loaded = false;
string s = "";
Assembly lastAssembly = null;
HashSet<Assembly> assemblies = null;
HashSet<Assembly> assembliesSet = null;
for (int i = 0; i < names.Length; i++)
{
s = (i == 0) ? names[0] : s + "." + names[i];
if (!probed.ContainsKey(s))
{
if (assemblies == null)
if (assembliesSet == null)
{
assemblies = new HashSet<Assembly>(AppDomain.CurrentDomain.GetAssemblies());
assembliesSet = new HashSet<Assembly>(AppDomain.CurrentDomain.GetAssemblies());
}
Assembly a = FindLoadedAssembly(s);
if (a == null)
Expand All @@ -326,7 +316,7 @@ public static bool LoadImplicit(string name, bool warn = true)
{
a = LoadAssembly(s);
}
if (a != null && !assemblies.Contains(a))
if (a != null && !assembliesSet.Contains(a))
{
loaded = true;
lastAssembly = a;
Expand Down Expand Up @@ -392,17 +382,12 @@ internal static void ScanAssembly(Assembly assembly)

public static AssemblyName[] ListAssemblies()
{
lock (assemblies)
List<AssemblyName> names = new List<AssemblyName>(assemblies.Count);
foreach (Assembly assembly in assemblies)
{
AssemblyName[] names = new AssemblyName[assemblies.Count];
Assembly assembly;
for (int i = 0; i < assemblies.Count; i++)
{
assembly = assemblies[i];
names.SetValue(assembly.GetName(), i);
}
return names;
names.Add(assembly.GetName());
}
return names.ToArray();
}

//===================================================================
Expand Down Expand Up @@ -483,19 +468,89 @@ public static List<string> GetNames(string nsname)

public static Type LookupType(string qname)
{
lock (assemblies)
foreach (Assembly assembly in assemblies)
{
for (int i = 0; i < assemblies.Count; i++)
Type type = assembly.GetType(qname);
if (type != null)
{
Assembly assembly = (Assembly) assemblies[i];
Type type = assembly.GetType(qname);
if (type != null)
{
return type;
}
return type;
}
}
return null;
}

/// <summary>
/// Wrapper around List<Assembly> for thread safe access
/// </summary>
private class AssemblyList : IEnumerable<Assembly>{
private readonly List<Assembly> _list;
private readonly ReaderWriterLockSlim _lock;

public AssemblyList(int capacity) {
_list = new List<Assembly>(capacity);
_lock = new ReaderWriterLockSlim(LockRecursionPolicy.SupportsRecursion);
}

public int Count { get { return _list.Count; } }

public void Add(Assembly assembly) {
_lock.EnterWriteLock();
try
{
_list.Add(assembly);
}
finally
{
_lock.ExitWriteLock();
}
}

public IEnumerator GetEnumerator()
{
return ((IEnumerable<Assembly>) this).GetEnumerator();
}

/// <summary>
/// Enumerator wrapping around <see cref="AssemblyList._list"/>'s enumerator.
/// Acquires and releases a read lock on <see cref="AssemblyList._lock"/> during enumeration
/// </summary>
private class Enumerator : IEnumerator<Assembly>
{
private readonly AssemblyList _assemblyList;

private readonly IEnumerator<Assembly> _listEnumerator;

public Enumerator(AssemblyList assemblyList)
{
_assemblyList = assemblyList;
_listEnumerator = _assemblyList._list.GetEnumerator();
_assemblyList._lock.EnterReadLock();
}

public void Dispose()
{
_assemblyList._lock.ExitReadLock();
}

public bool MoveNext()
{
return _listEnumerator.MoveNext();
}

public void Reset()
{
_listEnumerator.Reset();
}

public Assembly Current { get { return _listEnumerator.Current; } }

object IEnumerator.Current { get { return Current; } }
}

IEnumerator<Assembly> IEnumerable<Assembly>.GetEnumerator()
{
return new Enumerator(this);
}
}
}
}
14 changes: 11 additions & 3 deletions src/testing/moduletest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,23 @@

namespace Python.Test {
public class ModuleTest {
public static void RunThreads() {
var thread = new Thread(() => {
private static Thread _thread;

public static void RunThreads()
{
_thread = new Thread(() => {
var appdomain = AppDomain.CurrentDomain;
var assemblies = appdomain.GetAssemblies();
foreach (var assembly in assemblies) {
assembly.GetTypes();
}
});
thread.Start();
_thread.Start();
}

public static void JoinThreads()
{
_thread.Join();
}
}
}
1 change: 1 addition & 0 deletions src/tests/test_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,7 @@ def test_AssemblyLoadThreadSafety(self):
# call import clr, which in AssemblyManager.GetNames iterates through the loaded types
for i in range(1, 100):
import clr
ModuleTest.JoinThreads()


def test_suite():
Expand Down