-
-
Notifications
You must be signed in to change notification settings - Fork 25.8k
Migrate WeightVector to use tempita #13358
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Migrate WeightVector to use tempita #13358
Conversation
scikit-learn/sklearn/utils/weight_vector.pxd Lines 24 to 25 in 984871b
EDIT: number_iter comes from |
@@ -22,7 +22,7 @@ from numpy.math cimport INFINITY | |||
cdef extern from "sgd_fast_helpers.h": | |||
bint skl_isfinite(double) nogil | |||
|
|||
from sklearn.utils.weight_vector cimport WeightVector | |||
from sklearn.utils.weight_vector cimport WeightVector64 as WeightVector |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should not be here but in the template. Maybe we have some sort of clash. I need to get back to that
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
unless you also template sgd_fast
, you'll have to deal with both, and switch with
if floating is float:
do something with WeightVector64
else:
do something with WeightVector32
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also this diff don't belong to this PR
ping: @NicolasHug can you review aswell?? Thx |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
@@ -22,7 +22,7 @@ from numpy.math cimport INFINITY | |||
cdef extern from "sgd_fast_helpers.h": | |||
bint skl_isfinite(double) nogil | |||
|
|||
from sklearn.utils.weight_vector cimport WeightVector | |||
from sklearn.utils.weight_vector cimport WeightVector64 as WeightVector |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
unless you also template sgd_fast
, you'll have to deal with both, and switch with
if floating is float:
do something with WeightVector64
else:
do something with WeightVector32
@@ -22,7 +22,7 @@ from numpy.math cimport INFINITY | |||
cdef extern from "sgd_fast_helpers.h": | |||
bint skl_isfinite(double) nogil | |||
|
|||
from sklearn.utils.weight_vector cimport WeightVector | |||
from sklearn.utils.weight_vector cimport WeightVector64 as WeightVector |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also this diff don't belong to this PR
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not sure about the import but this lgtm
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nitpicks + questions.
Also:
- is
WeightVector32
(going to be) used somewhere else? - it'd be nice to have tests ensuring that the
.pyx
and.pxd
files that are generated by tempita are git-ignored. Or at least that we can't modify them accidentally. As @ogrisel noted these tests should also pass / be ignored on pre-compiled wheels where cython sources aren't available.
"""The L2 norm of the weight vector. """ | ||
return sqrt(self.sq_norm) | ||
|
||
{{endfor}} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing newline
cdef void reset_wscale(self) nogil | ||
cdef {{c_type}} norm(self) nogil | ||
|
||
{{endfor}} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
missing newline
dtypes = [('64', 'double'), | ||
('32', 'float')] | ||
|
||
def get_dispatch(dtypes): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry not familiar with tempita (and the tempita doc links are broken): why do you need get_dispatch? You can't simply iterate over the dtypes
list?
@massich Could you solve the resolve the conflict and the address the comments |
Hi @massich, I saw that your PR has been stalled. Are you still interested in continuing this work? |
WeightVector is used in #13346 and has attributes that cannot be fused.
scikit-learn/sklearn/utils/weight_vector.pxd
Lines 12 to 20 in 984871b
This PR uses Tempita to allow float32 float64.
cross ref: #11000