ENH ascii visualisation for metadata routing #31535


Draft · wants to merge 42 commits into main
Conversation

adrinjalali
Member

Adding a visualisation for metadata routing.

Right now it produces output like the following:

# Imports for the demo script. Everything below is public scikit-learn API,
# except WARN and visualise_routing, which come from this PR's branch (their
# exact import paths may differ there).
from sklearn import set_config
from sklearn.compose import ColumnTransformer
from sklearn.feature_selection import SelectPercentile, chi2
from sklearn.impute import SimpleImputer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import get_scorer
from sklearn.model_selection import GroupKFold, RandomizedSearchCV
from sklearn.pipeline import Pipeline, make_pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.utils.metadata_routing import get_routing_for_object


def run_test_1():
    numeric_features = ["age", "fare"]
    numeric_transformer = Pipeline(
        steps=[
            ("imputer", SimpleImputer(strategy="median")),
            (
                "scaler",
                StandardScaler()
                .set_fit_request(sample_weight="inner_weights")
                .set_transform_request(copy=True),
            ),
        ]
    )

    categorical_features = ["embarked", "sex", "pclass"]
    categorical_transformer = Pipeline(
        steps=[
            ("encoder", OneHotEncoder(handle_unknown="ignore")),
            ("selector", SelectPercentile(chi2, percentile=50)),
        ]
    )
    preprocessor = ColumnTransformer(
        transformers=[
            ("num", numeric_transformer, numeric_features),
            ("cat", categorical_transformer, categorical_features),
        ]
    )

    # %%
    # Append classifier to preprocessing pipeline.
    # Now we have a full prediction pipeline.
    clf = Pipeline(
        steps=[
            ("preprocessor", preprocessor),
            ("classifier", LogisticRegression().set_fit_request(sample_weight=False)),
        ]
    )

    param_grid = {
        "preprocessor__num__imputer__strategy": ["mean", "median"],
        "preprocessor__cat__selector__percentile": [10, 30, 50, 70],
        "classifier__C": [0.1, 1.0, 10, 100],
    }

    scorer = get_scorer("accuracy").set_score_request(sample_weight=True)

    search_cv = RandomizedSearchCV(
        clf, param_grid, cv=GroupKFold(), scoring=scorer, random_state=0
    )

    # Get the routing information
    test = get_routing_for_object(search_cv)

    visualise_routing(test)


def run_test_2():
    est = make_pipeline(
        make_pipeline(StandardScaler().set_fit_request(sample_weight=True)),
        make_pipeline(StandardScaler().set_fit_request(sample_weight=False)),
        make_pipeline(StandardScaler()),
        make_pipeline(StandardScaler().set_fit_request(sample_weight=WARN)),
    )

    visualise_routing(get_routing_for_object(est))


def run_test_3():
    est = RandomizedSearchCV(estimator=LogisticRegression(), param_distributions={})
    visualise_routing(get_routing_for_object(est))

    est = RandomizedSearchCV(
        estimator=LogisticRegression(),
        param_distributions={},
        scoring=get_scorer("accuracy").set_score_request(sample_weight=True),
    )
    visualise_routing(get_routing_for_object(est))


if __name__ == "__main__":
    # Enable metadata routing
    set_config(enable_metadata_routing=True)
    run_test_1()
    run_test_2()
    run_test_3()
$ python sklearn/utils/tests/test_metadata_routing_visualise.py

=== METADATA ROUTING TREE ===
RandomizedSearchCV
├── estimator (Pipeline)
│   ├── preprocessor (ColumnTransformer)
│   │   ├── num (Pipeline)
│   │   │   ├── imputer (SimpleImputer(strategy='median'))
│   │   │   └── scaler (StandardScaler())
│   │   │           ➤ copy[transform✓]
│   │   │           ➤ inner_weights→sample_weight[fit↗]
│   │   ├── cat (Pipeline)
│   │   │   ├── encoder (OneHotEncoder(handle_unknown='ignore'))
│   │   │   └── selector (SelectPercentile(percentile=50, score_func=<function chi2 at 0x7f7b495dfb00>))
│   │   └── remainder (None)
│   └── classifier (LogisticRegression())
│           ➤ sample_weight[fit✗]
├── scorer (_Scorer)
│       ➤ sample_weight[score✓]
└── splitter (GroupKFold(n_splits=5, random_state=None, shuffle=False))
        ➤ groups[split✓]

Parameter summary:
fit
 ├─ ✓ copy
 │   • ✓ requested:
 │       - RandomizedSearchCV/estimator/preprocessor/num/scaler.transform
 ├─ ✓ groups
 │   • ✓ requested:
 │       - RandomizedSearchCV/splitter.split
 ├─ ✓ inner_weights
 │   • ✓ requested:
 │       - RandomizedSearchCV/estimator/preprocessor/num/scaler.fit
 ├─ ✓ sample_weight
 │   • ✓ requested:
 │       - RandomizedSearchCV/scorer.score
 │   • ✗ ignored:
 │       - RandomizedSearchCV/estimator/classifier.fit

score
 ├─ ✓ sample_weight
 │   • ✓ requested:
 │       - RandomizedSearchCV/scorer.score


=== METADATA ROUTING TREE ===
Pipeline
├── pipeline-1 (Pipeline)
│   └── standardscaler (StandardScaler())
│           ➤ copy[transform⛔,inverse_transform⛔]
│           ➤ sample_weight[fit✓]
├── pipeline-2 (Pipeline)
│   └── standardscaler (StandardScaler())
│           ➤ copy[transform⛔,inverse_transform⛔]
│           ➤ sample_weight[fit✗]
├── pipeline-3 (Pipeline)
│   └── standardscaler (StandardScaler())
│           ➤ copy[transform⛔,inverse_transform⛔]
│           ➤ sample_weight[fit⛔]
└── pipeline-4 (Pipeline)
    └── standardscaler (StandardScaler())
            ➤ copy[transform⛔,inverse_transform⛔]
            ➤ sample_weight[fit⚠]

Parameter summary:
fit
 ├─ ⛔ copy
 │   • ⛔ errors:
 │       - Pipeline/pipeline-1/standardscaler.transform
 │       - Pipeline/pipeline-2/standardscaler.transform
 │       - Pipeline/pipeline-3/standardscaler.transform
 ├─ ⛔ sample_weight
 │   • ✓ requested:
 │       - Pipeline/pipeline-1/standardscaler.fit
 │   • ✗ ignored:
 │       - Pipeline/pipeline-2/standardscaler.fit
 │   • ⚠ warns:
 │       - Pipeline/pipeline-4/standardscaler.fit
 │   • ⛔ errors:
 │       - Pipeline/pipeline-3/standardscaler.fit

predict
 ├─ ⛔ copy
 │   • ⛔ errors:
 │       - Pipeline/pipeline-1/standardscaler.transform
 │       - Pipeline/pipeline-2/standardscaler.transform
 │       - Pipeline/pipeline-3/standardscaler.transform

predict_proba
 ├─ ⛔ copy
 │   • ⛔ errors:
 │       - Pipeline/pipeline-1/standardscaler.transform
 │       - Pipeline/pipeline-2/standardscaler.transform
 │       - Pipeline/pipeline-3/standardscaler.transform

predict_log_proba
 ├─ ⛔ copy
 │   • ⛔ errors:
 │       - Pipeline/pipeline-1/standardscaler.transform
 │       - Pipeline/pipeline-2/standardscaler.transform
 │       - Pipeline/pipeline-3/standardscaler.transform

decision_function
 ├─ ⛔ copy
 │   • ⛔ errors:
 │       - Pipeline/pipeline-1/standardscaler.transform
 │       - Pipeline/pipeline-2/standardscaler.transform
 │       - Pipeline/pipeline-3/standardscaler.transform

score
 ├─ ⛔ copy
 │   • ⛔ errors:
 │       - Pipeline/pipeline-1/standardscaler.transform
 │       - Pipeline/pipeline-2/standardscaler.transform
 │       - Pipeline/pipeline-3/standardscaler.transform

transform
 ├─ ⛔ copy
 │   • ⛔ errors:
 │       - Pipeline/pipeline-1/standardscaler.transform
 │       - Pipeline/pipeline-2/standardscaler.transform
 │       - Pipeline/pipeline-3/standardscaler.transform
 │       - Pipeline/pipeline-4/standardscaler.transform

inverse_transform
 ├─ ⛔ copy
 │   • ⛔ errors:
 │       - Pipeline/pipeline-1/standardscaler.inverse_transform
 │       - Pipeline/pipeline-2/standardscaler.inverse_transform
 │       - Pipeline/pipeline-3/standardscaler.inverse_transform
 │       - Pipeline/pipeline-4/standardscaler.inverse_transform


=== METADATA ROUTING TREE ===
RandomizedSearchCV
├── estimator (LogisticRegression())
│       ➤ sample_weight[fit⛔]
├── scorer (_PassthroughScorer)
│       ➤ sample_weight[score⛔]
└── splitter (None)

Parameter summary:
fit
 ├─ ⛔ sample_weight
 │   • ⛔ errors:
 │       - RandomizedSearchCV/estimator.fit
 │       - RandomizedSearchCV/scorer.score

score
 ├─ ⛔ sample_weight
 │   • ⛔ errors:
 │       - RandomizedSearchCV/scorer.score


=== METADATA ROUTING TREE ===
RandomizedSearchCV
├── estimator (LogisticRegression())
│       ➤ sample_weight[fit⛔]
├── scorer (_Scorer)
│       ➤ sample_weight[score✓]
└── splitter (None)

Parameter summary:
fit
 ├─ ⛔ sample_weight
 │   • ✓ requested:
 │       - RandomizedSearchCV/scorer.score
 │   • ⛔ errors:
 │       - RandomizedSearchCV/estimator.fit

score
 ├─ ✓ sample_weight
 │   • ✓ requested:
 │       - RandomizedSearchCV/scorer.score


✔️ Linting Passed

All linting checks passed. Your pull request is in excellent shape! ☀️

Generated for commit: 5296b93.

@glemaitre
Member

I like how it looks. I see that the symbols are defined with their meaning in the parameter summary, but I would be happy to get a small legend right at the start, before going into the tree and the summary. It should take a single line.

A second thought is about being able to filter the output: maybe I am only interested in the tree but not the summary, or vice-versa, and maybe only in a particular parameter. I think the current view should be the default, but having the flexibility to reduce the amount of info could be nice.
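To make the filtering idea concrete, here is one purely hypothetical shape it could take (none of these names or keyword arguments exist in the PR): the defaults reproduce the full view, while the flags let callers drop the tree, drop the summary, or keep only selected parameters in the summary.

```python
def filter_routing_output(tree_text, summary_text, *,
                          show_tree=True, show_summary=True, params=None):
    """Hypothetical filtering wrapper, not part of this PR.

    Assembles the final report from pre-rendered tree and summary text.
    """
    sections = []
    if show_tree:
        sections.append(tree_text)
    if show_summary:
        lines = summary_text.splitlines()
        if params is not None:
            # naive per-parameter filter: keep only summary entries that
            # mention one of the requested parameter names
            lines = [ln for ln in lines
                     if "├─" not in ln or any(p in ln for p in params)]
        sections.append("\n".join(lines))
    return "\n\n".join(sections)
```

With this shape, `filter_routing_output(tree, summary, show_summary=False)` would give just the tree, matching the "tree but not the summary" use case.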

@StefanieSenger
Contributor

StefanieSenger commented Jun 13, 2025

I've already gotten a demonstration and had expressed my appreciation. I think this will be very useful. Now I have some suggestions:

  1. Is it possible to make this colourful? Maybe colours could even be used instead of ✓, ✗, ⛔, ⚠, if it is somehow possible to force colourful terminal output, even for users who don't use shells like zsh?

  2. I am not sure about ⛔, since it is the most pointy and colourful symbol but it just means: implicitly not requested (set_fit_request(metadata=None)). From the symbol it looks a bit like explicitly not requested (set_fit_request(metadata=False)).

  3. I also wonder if ✓, ✗, ⛔, ⚠ are safe to use for all users, or if they could look cryptic to some? Are these ASCII and thus available to any user?
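On item 3, a quick check shows these markers are in fact Unicode, not ASCII: each one is a codepoint above U+007F, so whether they render correctly depends on the user's terminal font and encoding rather than on which shell is used.

```python
# None of the tree markers are 7-bit ASCII; print each one's codepoint.
for symbol in ["✓", "✗", "⛔", "⚠", "➤"]:
    print(symbol, f"U+{ord(symbol):04X}", "ascii:", symbol.isascii())
```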

@adrinjalali
Member Author

  • Legend: yes, it should have a small one, and a link to a page in the docs where things are better explained.
  • Filtering: yep, easy to implement, and makes sense to have.
  • Symbols: I'm happy to have alternative suggestions for symbols, and I'll try to improve what we have. The ⛔ sign is pretty nice since it's the only case where it results in an actual error. But I'm also not happy with these signs and what users understand intuitively from them.
  • Colors: This one is very tricky and I'd rather not have it to start with. It's relatively easy to have colors on a white background with a black font color, but terminals have all sorts of color themes. Mine is green on black, others have white/bright gray on black/dark gray, some use black on white, etc. Having a color setting that works for all of them and is easily visible to a half-colorblind person like me is pretty tricky. So I usually tend to avoid colors in my own work. But I'd be happy to review such suggestions in another PR.
