-
-
Notifications
You must be signed in to change notification settings - Fork 25.8k
[WIP] Implement general naive Bayes #16281
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…ikit-learn into general-naive-bayes
…ikit-learn into general-naive-bayes
Thanks @remykarem. You have a linter failing. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please see the linter build log for a list of errors. Do you need help resolving them?
|
||
# Subtract the class log prior from all the jlls | ||
# but add it back after the summation | ||
jlls = jlls - log_prior |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
log_prior is among a handful of variables you've used without definition.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @jnothman I'm still working on this (switched this PR back to WIP). Some refactoring needed because I'm trying to fit in the remainder
API.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unfortunately this kind of invalid code causes the linter to fail too...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay will fix this.
# convert to feature if callable | ||
self._cols = [] | ||
dict_col2model = {} | ||
if callable(cols): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
cols is not defined here.
Yes please! I have been trying to figure this out. Sorry for not reaching out to you earlier. |
@remykarem have you abandoned this project or has it stalled for some other reason? Would you like someone to take over or chip in? I drafted my own wrapper for Naive Bayes a month ago and was thinking about contributing it, but now I discovered your work, which seems almost complete. |
@avm19 Sorry, I got busy after a while and didn't manage to complete this. I think it would be great for someone to take over this project :) |
take |
Reference Issues/PRs
Fixes #15077. See also #10856.
What does this implement/fix? Explain your changes.
This implements general naive Bayes (
GeneralNB
) in addition to the existing naive Bayes implementations likeGaussianNB
andBernoulliNB
.This implementation allows multiple assumptions on the features, namely the Bernoulli, Gaussian, Multinomial, and Categorical distributions. In the API, the user will be able to specify these distributions and their respective features.
I have divided this description into 3 sections as below:
1. Design and usage
The design of the API is similar to that of
ColumnTransformer
andPipeline
. To specify that columns 0-2 and 3-4 are to be modelled with Gaussian and categorical naive Bayes respectively, indicate these in theGeneralNB
constructor and fit accordingly:It also accepts a list of strings of column names if the data to be fitted are pandas DataFrames:
Lastly, similar to
ColumnTransformer
, it also accepts callables likemake_column_selector
to specify DataFrame columns:The attributes of the fitted estimators can be accessed using the
self.named_models_
attribute. For example, to access thetheta_
parameter of thebernoulli
model,2. Under the hood
For the
GeneralNB.predict()
function, we sum the_joint_log_likelihood()
for each naive Bayes estimator, then subtract (n-1) log P(c) from this sum. Here is a pseudocode:3. Runtime checks
Check
self.models
:X
.Checks on parameter consistency across naive Bayes estimators are performed to ensure that specific parameters across the estimators stay the same. Otherwise, the calculation of the joint log likelihood will be wrong. Such parameters are:
class_prior
* orpriors
^fit_prior
*class_log_prior
* orclass_prior
^*used in BernoulliNB, MultinomialNB, ComplementNB, CategoricalNB
^used in GaussianNB
Data checks:
Methods like
_check_X_y()
and_check_X()
check if the data type used during fitting is the same during prediction (i.e. NumPy array and pandas DataFrame).4. Others
Partial fitting is not supported.
Currently, it is not okay to leave some columns out.
Progress:
Any other comments?
PR submitted.