@@ -318,6 +318,14 @@ def get_tags(estimator) -> Tags:
318
318
319
319
if hasattr (estimator , "__sklearn_tags__" ):
320
320
tags = estimator .__sklearn_tags__ ()
321
+ elif hasattr (estimator , "_get_tags" ):
322
+ warnings .warn ("BROKEN SOON, IT WILL BE" , FutureWarning )
323
+ tags = _to_new_tags (estimator ._get_tags ())
324
+ elif hasattr (estimator , "_more_tags" ):
325
+ warnings .warn ("BROKEN SOON, IT WILL BE" , FutureWarning )
326
+ tags = _to_old_tags (default_tags (estimator ))
327
+ tags = {** tags , ** estimator ._more_tags ()}
328
+ tags = _to_new_tags (tags )
321
329
else :
322
330
warnings .warn (
323
331
f"Estimator { estimator } has no __sklearn_tags__ attribute, which is "
@@ -332,3 +340,148 @@ def get_tags(estimator) -> Tags:
332
340
tags = default_tags (estimator )
333
341
334
342
return tags
343
+
344
+
345
+ def _to_new_tags (old_tags , estimator_type = None ):
346
+ """Utility function convert old tags (dictionary) to new tags (dataclass)."""
347
+ input_tags = InputTags (
348
+ one_d_array = "1darray" in old_tags ["X_types" ],
349
+ two_d_array = "2darray" in old_tags ["X_types" ],
350
+ three_d_array = "3darray" in old_tags ["X_types" ],
351
+ sparse = "sparse" in old_tags ["X_types" ],
352
+ categorical = "categorical" in old_tags ["X_types" ],
353
+ string = "string" in old_tags ["X_types" ],
354
+ dict = "dict" in old_tags ["X_types" ],
355
+ positive_only = old_tags ["requires_positive_X" ],
356
+ allow_nan = old_tags ["allow_nan" ],
357
+ pairwise = old_tags ["pairwise" ],
358
+ )
359
+ target_tags = TargetTags (
360
+ required = old_tags ["requires_y" ],
361
+ one_d_labels = "1dlabels" in old_tags ["X_types" ],
362
+ two_d_labels = "2dlabels" in old_tags ["X_types" ],
363
+ positive_only = old_tags ["requires_positive_y" ],
364
+ multi_output = old_tags ["multioutput" ] or old_tags ["multioutput_only" ],
365
+ single_output = not old_tags ["multioutput_only" ],
366
+ )
367
+ transformer_tags = TransformerTags (
368
+ preserves_dtype = old_tags ["preserves_dtype" ],
369
+ )
370
+ classifier_tags = ClassifierTags (
371
+ poor_score = old_tags ["poor_score" ],
372
+ multi_class = not old_tags ["binary_only" ],
373
+ multi_label = old_tags ["multilabel" ],
374
+ )
375
+ regressor_tags = RegressorTags (
376
+ poor_score = old_tags ["poor_score" ],
377
+ multi_label = old_tags ["multilabel" ],
378
+ )
379
+ return Tags (
380
+ estimator_type = estimator_type ,
381
+ target_tags = target_tags ,
382
+ transformer_tags = transformer_tags ,
383
+ classifier_tags = classifier_tags ,
384
+ regressor_tags = regressor_tags ,
385
+ input_tags = input_tags ,
386
+ array_api_support = old_tags ["array_api_support" ],
387
+ no_validation = old_tags ["no_validation" ],
388
+ non_deterministic = old_tags ["non_deterministic" ],
389
+ requires_fit = old_tags ["requires_fit" ],
390
+ _skip_test = old_tags ["_skip_test" ],
391
+ )
392
+
393
+
394
+ def _to_old_tags (new_tags ):
395
+ """Utility function convert old tags (dictionary) to new tags (dataclass)."""
396
+ if new_tags .classifier_tags :
397
+ binary_only = not new_tags .classifier_tags .multi_class
398
+ multilabel_clf = new_tags .classifier_tags .multi_label
399
+ poor_score_clf = new_tags .classifier_tags .poor_score
400
+ else :
401
+ binary_only = False
402
+ multilabel_clf = False
403
+ poor_score_clf = False
404
+
405
+ if new_tags .regressor_tags :
406
+ multilabel_reg = new_tags .regressor_tags .multi_label
407
+ poor_score_reg = new_tags .regressor_tags .poor_score
408
+ else :
409
+ multilabel_reg = False
410
+ poor_score_reg = False
411
+
412
+ if new_tags .transformer_tags :
413
+ preserves_dtype = new_tags .transformer_tags .preserves_dtype
414
+ else :
415
+ preserves_dtype = ["float64" ]
416
+
417
+ tags = {
418
+ "allow_nan" : new_tags .input_tags .allow_nan ,
419
+ "array_api_support" : new_tags .array_api_support ,
420
+ "binary_only" : binary_only ,
421
+ "multilabel" : multilabel_clf or multilabel_reg ,
422
+ "multioutput" : new_tags .target_tags .multi_output ,
423
+ "multioutput_only" : (
424
+ not new_tags .target_tags .single_output
425
+ and new_tags .target_tags .multi_output
426
+ ),
427
+ "no_validation" : new_tags .no_validation ,
428
+ "non_deterministic" : new_tags .non_deterministic ,
429
+ "pairwise" : new_tags .input_tags .pairwise ,
430
+ "preserves_dtype" : preserves_dtype ,
431
+ "poor_score" : poor_score_clf or poor_score_reg ,
432
+ "requires_fit" : new_tags .requires_fit ,
433
+ "requires_positive_X" : new_tags .input_tags .positive_only ,
434
+ "requires_y" : new_tags .target_tags .required ,
435
+ "requires_positive_y" : new_tags .target_tags .positive_only ,
436
+ "_skip_test" : new_tags ._skip_test ,
437
+ "stateless" : new_tags .requires_fit ,
438
+ }
439
+ X_types = []
440
+ if new_tags .input_tags .one_d_array :
441
+ X_types .append ("1darray" )
442
+ if new_tags .input_tags .two_d_array :
443
+ X_types .append ("2darray" )
444
+ if new_tags .input_tags .three_d_array :
445
+ X_types .append ("3darray" )
446
+ if new_tags .input_tags .sparse :
447
+ X_types .append ("sparse" )
448
+ if new_tags .input_tags .categorical :
449
+ X_types .append ("categorical" )
450
+ if new_tags .input_tags .string :
451
+ X_types .append ("string" )
452
+ if new_tags .input_tags .dict :
453
+ X_types .append ("dict" )
454
+ if new_tags .target_tags .one_d_labels :
455
+ X_types .append ("1dlabels" )
456
+ if new_tags .target_tags .two_d_labels :
457
+ X_types .append ("2dlabels" )
458
+ tags ["X_types" ] = X_types
459
+ return tags
460
+
461
+
462
+ def _safe_tags (estimator , key = None ):
463
+ warnings .warn (
464
+ "The `_safe_tags` utility function is deprecated in 1.6 and will be removed in "
465
+ "1.7. Use the public `get_tags` function instead and make sure to implement "
466
+ "the `__sklearn_tags__` method." ,
467
+ category = FutureWarning ,
468
+ )
469
+ if hasattr (estimator , "_get_tags" ):
470
+ tags_provider = "_get_tags()"
471
+ tags = estimator ._get_tags ()
472
+ elif hasattr (estimator , "_more_tags" ):
473
+ tags_provider = "_more_tags()"
474
+ tags = _to_old_tags (default_tags (estimator ))
475
+ tags = {** tags , ** estimator ._more_tags ()}
476
+ else :
477
+ tags_provider = "_DEFAULT_TAGS"
478
+ tags = _to_old_tags (default_tags (estimator ))
479
+
480
+ if key is not None :
481
+ if key not in tags :
482
+ raise ValueError (
483
+ f"The key { key } is not defined in { tags_provider } for the "
484
+ f"class { estimator .__class__ .__name__ } ."
485
+ )
486
+ return tags [key ]
487
+ return tags
0 commit comments