diff --git a/AUTHORS.md b/AUTHORS.md index 39c2eb180..9e13ca569 100644 --- a/AUTHORS.md +++ b/AUTHORS.md @@ -35,6 +35,7 @@ - Jeff Reback ([@jreback](https://github.com/jreback)) - Joe Frayne ([@jfrayne](https://github.com/jfrayne)) - Joe Lidbetter ([@jmlidbetter](https://github.com/jmlidbetter)) +- Joe Savage ([@s4v4g3](https://github.com/s4v4g3)) - John Burnett ([@johnburnett](https://github.com/johnburnett)) - John Wilkes ([@jbw3](https://github.com/jbw3)) - Luke Stratman ([@lstratman](https://github.com/lstratman)) diff --git a/CHANGELOG.md b/CHANGELOG.md index a545f335c..5cb0ea96c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ This document follows the conventions laid out in [Keep a CHANGELOG][]. - Removes PyLong_GetMax and PyClass_New when targetting Python3 - Added support for converting python iterators to C# arrays - Changed usage of obselete function GetDelegateForFunctionPointer(IntPtr, Type) to GetDelegateForFunctionPointer<TDelegate>(IntPtr) +- Added support for kwarg parameters when calling .NET methods from Python ### Fixed diff --git a/src/runtime/methodbinder.cs b/src/runtime/methodbinder.cs index 95b953555..8a7fc1930 100644 --- a/src/runtime/methodbinder.cs +++ b/src/runtime/methodbinder.cs @@ -2,6 +2,8 @@ using System.Collections; using System.Reflection; using System.Text; +using System.Collections.Generic; +using System.Linq; namespace Python.Runtime { @@ -280,6 +282,22 @@ internal Binding Bind(IntPtr inst, IntPtr args, IntPtr kw, MethodBase info, Meth { // loop to find match, return invoker w/ or /wo error MethodBase[] _methods = null; + + var kwargDict = new Dictionary<string, IntPtr>(); + if (kw != IntPtr.Zero) + { + var pynkwargs = (int)Runtime.PyDict_Size(kw); + IntPtr keylist = Runtime.PyDict_Keys(kw); + IntPtr valueList = Runtime.PyDict_Values(kw); + for (int i = 0; i < pynkwargs; ++i) + { + var keyStr = Runtime.GetManagedString(Runtime.PyList_GetItem(keylist, i)); + kwargDict[keyStr] = Runtime.PyList_GetItem(valueList, i); + } + Runtime.XDecref(keylist); + Runtime.XDecref(valueList); + } + var pynargs = (int)Runtime.PyTuple_Size(args); var isGeneric = false; if (info != null) @@ -303,11 +321,12 @@ internal Binding Bind(IntPtr inst, IntPtr args, IntPtr kw, MethodBase info, Meth ArrayList defaultArgList; bool paramsArray; - if (!MatchesArgumentCount(pynargs, pi, out paramsArray, out defaultArgList)) { + if (!MatchesArgumentCount(pynargs, pi, kwargDict, out paramsArray, out defaultArgList)) + { continue; } var outs = 0; - var margs = TryConvertArguments(pi, paramsArray, args, pynargs, defaultArgList, + var margs = TryConvertArguments(pi, paramsArray, args, pynargs, kwargDict, defaultArgList, needsResolution: _methods.Length > 1, outs: out outs); @@ -351,19 +370,21 @@ internal Binding Bind(IntPtr inst, IntPtr args, IntPtr kw, MethodBase info, Meth } /// <summary> - /// Attempts to convert Python argument tuple into an array of managed objects, - /// that can be passed to a method. + /// Attempts to convert Python positional argument tuple and keyword argument table + /// into an array of managed objects, that can be passed to a method. /// </summary> /// <param name="pi">Information about expected parameters</param> /// <param name="paramsArray"><c>true</c>, if the last parameter is a params array.</param> /// <param name="args">A pointer to the Python argument tuple</param> /// <param name="pyArgCount">Number of arguments, passed by Python</param> + /// <param name="kwargDict">Dictionary of keyword argument name to python object pointer</param> /// <param name="defaultArgList">A list of default values for omitted parameters</param> /// <param name="needsResolution"><c>true</c>, if overloading resolution is required</param> /// <param name="outs">Returns number of output parameters</param> /// <returns>An array of .NET arguments, that can be passed to a method.</returns> static object[] TryConvertArguments(ParameterInfo[] pi, bool paramsArray, IntPtr args, int pyArgCount, + Dictionary<string, IntPtr> kwargDict, ArrayList defaultArgList, bool needsResolution, out int outs) @@ -374,7 +395,10 @@ static object[] TryConvertArguments(ParameterInfo[] pi, bool paramsArray, for (int paramIndex = 0; paramIndex < pi.Length; paramIndex++) { - if (paramIndex >= pyArgCount) + var parameter = pi[paramIndex]; + bool hasNamedParam = kwargDict.ContainsKey(parameter.Name); + + if (paramIndex >= pyArgCount && !hasNamedParam) { if (defaultArgList != null) { @@ -384,12 +408,19 @@ static object[] TryConvertArguments(ParameterInfo[] pi, bool paramsArray, continue; } - var parameter = pi[paramIndex]; - IntPtr op = (arrayStart == paramIndex) - // map remaining Python arguments to a tuple since - // the managed function accepts it - hopefully :] - ? Runtime.PyTuple_GetSlice(args, arrayStart, pyArgCount) - : Runtime.PyTuple_GetItem(args, paramIndex); + IntPtr op; + if (hasNamedParam) + { + op = kwargDict[parameter.Name]; + } + else + { + op = (arrayStart == paramIndex) + // map remaining Python arguments to a tuple since + // the managed function accepts it - hopefully :] + ? Runtime.PyTuple_GetSlice(args, arrayStart, pyArgCount) + : Runtime.PyTuple_GetItem(args, paramIndex); + } bool isOut; if (!TryConvertArgument(op, parameter.ParameterType, needsResolution, out margs[paramIndex], out isOut)) @@ -505,7 +536,8 @@ static Type TryComputeClrArgumentType(Type parameterType, IntPtr argument, bool return clrtype; } - static bool MatchesArgumentCount(int argumentCount, ParameterInfo[] parameters, + static bool MatchesArgumentCount(int positionalArgumentCount, ParameterInfo[] parameters, + Dictionary<string, IntPtr> kwargDict, out bool paramsArray, out ArrayList defaultArgList) { @@ -513,21 +545,40 @@ static bool MatchesArgumentCount(int argumentCount, ParameterInfo[] parameters, var match = false; paramsArray = false; - if (argumentCount == parameters.Length) + if (positionalArgumentCount == parameters.Length) { match = true; - } else if (argumentCount < parameters.Length) + } + else if (positionalArgumentCount < parameters.Length) { + // every parameter past 'positionalArgumentCount' must have either + // a corresponding keyword argument or a default parameter match = true; defaultArgList = new ArrayList(); - for (var v = argumentCount; v < parameters.Length; v++) { - if (parameters[v].DefaultValue == DBNull.Value) { + for (var v = positionalArgumentCount; v < parameters.Length; v++) + { + if (kwargDict.ContainsKey(parameters[v].Name)) + { + // we have a keyword argument for this parameter, + // no need to check for a default parameter, but put a null + // placeholder in defaultArgList + defaultArgList.Add(null); + } + else if (parameters[v].IsOptional) + { + // IsOptional will be true if the parameter has a default value, + // or if the parameter has the [Optional] attribute specified. + // The GetDefaultValue() extension method will return the value + // to be passed in as the parameter value + defaultArgList.Add(parameters[v].GetDefaultValue()); + } + else + { match = false; - } else { - defaultArgList.Add(parameters[v].DefaultValue); } } - } else if (argumentCount > parameters.Length && parameters.Length > 0 && + } + else if (positionalArgumentCount > parameters.Length && parameters.Length > 0 && Attribute.IsDefined(parameters[parameters.Length - 1], typeof(ParamArrayAttribute))) { // This is a `foo(params object[] bar)` style method @@ -722,4 +773,33 @@ internal Binding(MethodBase info, object inst, object[] args, int outs) this.outs = outs; } } + + + static internal class ParameterInfoExtensions + { + public static object GetDefaultValue(this ParameterInfo parameterInfo) + { + // parameterInfo.HasDefaultValue is preferable but doesn't exist in .NET 4.0 + bool hasDefaultValue = (parameterInfo.Attributes & ParameterAttributes.HasDefault) == + ParameterAttributes.HasDefault; + + if (hasDefaultValue) + { + return parameterInfo.DefaultValue; + } + else + { + // [OptionalAttribute] was specified for the parameter. + // See https://stackoverflow.com/questions/3416216/optionalattribute-parameters-default-value + // for rules on determining the value to pass to the parameter + var type = parameterInfo.ParameterType; + if (type == typeof(object)) + return Type.Missing; + else if (type.IsValueType) + return Activator.CreateInstance(type); + else + return null; + } + } + } } diff --git a/src/testing/methodtest.cs b/src/testing/methodtest.cs index cf653f9f9..91836b727 100644 --- a/src/testing/methodtest.cs +++ b/src/testing/methodtest.cs @@ -1,5 +1,6 @@ using System; using System.IO; +using System.Runtime.InteropServices; namespace Python.Test { @@ -651,6 +652,38 @@ public static string Casesensitive() { return "Casesensitive"; } + + public static string DefaultParams(int a=0, int b=0, int c=0, int d=0) + { + return string.Format("{0}{1}{2}{3}", a, b, c, d); + } + + public static string OptionalParams([Optional]int a, [Optional]int b, [Optional]int c, [Optional] int d) + { + return string.Format("{0}{1}{2}{3}", a, b, c, d); + } + + public static bool OptionalParams_TestMissing([Optional]object a) + { + return a == Type.Missing; + } + + public static bool OptionalParams_TestReferenceType([Optional]string a) + { + return a == null; + } + + public static string OptionalAndDefaultParams([Optional]int a, [Optional]int b, int c=0, int d=0) + { + return string.Format("{0}{1}{2}{3}", a, b, c, d); + } + + public static string OptionalAndDefaultParams2([Optional]int a, [Optional]int b, [Optional, DefaultParameterValue(1)]int c, int d = 2) + { + return string.Format("{0}{1}{2}{3}", a, b, c, d); + } + + } diff --git a/src/tests/test_method.py b/src/tests/test_method.py index ad678611b..34f460d59 100644 --- a/src/tests/test_method.py +++ b/src/tests/test_method.py @@ -776,6 +776,9 @@ def test_no_object_in_param(): res = MethodTest.TestOverloadedNoObject(5) assert res == "Got int" + + res = MethodTest.TestOverloadedNoObject(i=7) + assert res == "Got int" with pytest.raises(TypeError): MethodTest.TestOverloadedNoObject("test") @@ -787,9 +790,15 @@ def test_object_in_param(): res = MethodTest.TestOverloadedObject(5) assert res == "Got int" + + res = MethodTest.TestOverloadedObject(i=7) + assert res == "Got int" res = MethodTest.TestOverloadedObject("test") assert res == "Got object" + + res = MethodTest.TestOverloadedObject(o="test") + assert res == "Got object" def test_object_in_multiparam(): @@ -813,6 +822,42 @@ def test_object_in_multiparam(): res = MethodTest.TestOverloadedObjectTwo(7.24, 7.24) assert res == "Got object-object" + res = MethodTest.TestOverloadedObjectTwo(a=5, b=5) + assert res == "Got int-int" + + res = MethodTest.TestOverloadedObjectTwo(5, b=5) + assert res == "Got int-int" + + res = MethodTest.TestOverloadedObjectTwo(a=5, b="foo") + assert res == "Got int-string" + + res = MethodTest.TestOverloadedObjectTwo(5, b="foo") + assert res == "Got int-string" + + res = MethodTest.TestOverloadedObjectTwo(a="foo", b=7.24) + assert res == "Got string-object" + + res = MethodTest.TestOverloadedObjectTwo("foo", b=7.24) + assert res == "Got string-object" + + res = MethodTest.TestOverloadedObjectTwo(a="foo", b="bar") + assert res == "Got string-string" + + res = MethodTest.TestOverloadedObjectTwo("foo", b="bar") + assert res == "Got string-string" + + res = MethodTest.TestOverloadedObjectTwo(a="foo", b=5) + assert res == "Got string-int" + + res = MethodTest.TestOverloadedObjectTwo("foo", b=5) + assert res == "Got string-int" + + res = MethodTest.TestOverloadedObjectTwo(a=7.24, b=7.24) + assert res == "Got object-object" + + res = MethodTest.TestOverloadedObjectTwo(7.24, b=7.24) + assert res == "Got object-object" + def test_object_in_multiparam_exception(): """Test method with object multiparams behaves""" @@ -966,3 +1011,120 @@ def test_getting_overloaded_constructor_binding_does_not_leak_ref_count(): # simple test refCount = sys.getrefcount(PlainOldClass.Overloads[int]) assert refCount == 1 + + +def test_default_params(): + # all positional parameters + res = MethodTest.DefaultParams(1,2,3,4) + assert res == "1234" + + res = MethodTest.DefaultParams(1, 2, 3) + assert res == "1230" + + res = MethodTest.DefaultParams(1, 2) + assert res == "1200" + + res = MethodTest.DefaultParams(1) + assert res == "1000" + + res = MethodTest.DefaultParams(a=2) + assert res == "2000" + + res = MethodTest.DefaultParams(b=3) + assert res == "0300" + + res = MethodTest.DefaultParams(c=4) + assert res == "0040" + + res = MethodTest.DefaultParams(d=7) + assert res == "0007" + + res = MethodTest.DefaultParams(a=2, c=5) + assert res == "2050" + + res = MethodTest.DefaultParams(1, d=7, c=3) + assert res == "1037" + + with pytest.raises(TypeError): + MethodTest.DefaultParams(1,2,3,4,5) + +def test_optional_params(): + res = MethodTest.OptionalParams(1, 2, 3, 4) + assert res == "1234" + + res = MethodTest.OptionalParams(1, 2, 3) + assert res == "1230" + + res = MethodTest.OptionalParams(1, 2) + assert res == "1200" + + res = MethodTest.OptionalParams(1) + assert res == "1000" + + res = MethodTest.OptionalParams(a=2) + assert res == "2000" + + res = MethodTest.OptionalParams(b=3) + assert res == "0300" + + res = MethodTest.OptionalParams(c=4) + assert res == "0040" + + res = MethodTest.OptionalParams(d=7) + assert res == "0007" + + res = MethodTest.OptionalParams(a=2, c=5) + assert res == "2050" + + res = MethodTest.OptionalParams(1, d=7, c=3) + assert res == "1037" + + res = MethodTest.OptionalParams_TestMissing() + assert res == True + + res = MethodTest.OptionalParams_TestMissing(None) + assert res == False + + res = MethodTest.OptionalParams_TestMissing(a = None) + assert res == False + + res = MethodTest.OptionalParams_TestMissing(a='hi') + assert res == False + + res = MethodTest.OptionalParams_TestReferenceType() + assert res == True + + res = MethodTest.OptionalParams_TestReferenceType(None) + assert res == True + + res = MethodTest.OptionalParams_TestReferenceType(a=None) + assert res == True + + res = MethodTest.OptionalParams_TestReferenceType('hi') + assert res == False + + res = MethodTest.OptionalParams_TestReferenceType(a='hi') + assert res == False + +def test_optional_and_default_params(): + + res = MethodTest.OptionalAndDefaultParams() + assert res == "0000" + + res = MethodTest.OptionalAndDefaultParams(1) + assert res == "1000" + + res = MethodTest.OptionalAndDefaultParams(1, c=4) + assert res == "1040" + + res = MethodTest.OptionalAndDefaultParams(b=4, c=7) + assert res == "0470" + + res = MethodTest.OptionalAndDefaultParams2() + assert res == "0012" + + res = MethodTest.OptionalAndDefaultParams2(a=1,b=2,c=3,d=4) + assert res == "1234" + + res = MethodTest.OptionalAndDefaultParams2(b=2, c=3) + assert res == "0232"