Skip to content

Commit 4206d14

Browse files
MNT little refactor and doc improvement for metadata routing consumes() methods (scikit-learn#31703)
Co-authored-by: Lucy Liu <[email protected]>
1 parent 46f5423 commit 4206d14

File tree

2 files changed

+38
-26
lines changed

2 files changed

+38
-26
lines changed

sklearn/tests/test_metadata_routing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -638,7 +638,7 @@ class Consumer(BaseEstimator):
638638
@config_context(enable_metadata_routing=True)
639639
def test_metadata_request_consumes_method():
640640
"""Test that MetadataRequest().consumes() method works as expected."""
641-
request = MetadataRouter(owner="test")
641+
request = MetadataRequest(owner="test")
642642
assert request.consumes(method="fit", params={"foo"}) == set()
643643

644644
request = MetadataRequest(owner="test")

sklearn/utils/_metadata_requests.py

Lines changed: 37 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -501,26 +501,26 @@ def _route_params(self, params, parent, caller):
501501
return res
502502

503503
def _consumes(self, params):
504-
"""Check whether the given metadata are consumed by this method.
504+
"""Return subset of `params` consumed by the method that owns this instance.
505505
506506
Parameters
507507
----------
508508
params : iterable of str
509-
An iterable of parameters to check.
509+
An iterable of parameter names to test for consumption.
510510
511511
Returns
512512
-------
513-
consumed : set of str
514-
A set of parameters which are consumed by this method.
513+
consumed_params : set of str
514+
A subset of parameters from `params` which are consumed by this method.
515515
"""
516516
params = set(params)
517-
res = set()
518-
for prop, alias in self._requests.items():
519-
if alias is True and prop in params:
520-
res.add(prop)
517+
consumed_params = set()
518+
for metadata_name, alias in self._requests.items():
519+
if alias is True and metadata_name in params:
520+
consumed_params.add(metadata_name)
521521
elif isinstance(alias, str) and alias in params:
522-
res.add(alias)
523-
return res
522+
consumed_params.add(alias)
523+
return consumed_params
524524

525525
def _serialize(self):
526526
"""Serialize the object.
@@ -571,22 +571,27 @@ def __init__(self, owner):
571571
)
572572

573573
def consumes(self, method, params):
574-
"""Check whether the given metadata are consumed by the given method.
574+
"""Return params consumed as metadata in a :term:`consumer`.
575+
576+
This method returns the subset of given `params` that are consumed by the
577+
given `method`. It can be used to check if parameters are used as metadata in
578+
the specified method of the :term:`consumer` that owns this `MetadataRequest`
579+
instance.
575580
576581
.. versionadded:: 1.4
577582
578583
Parameters
579584
----------
580585
method : str
581-
The name of the method to check.
586+
The name of the method for which to determine consumed parameters.
582587
583588
params : iterable of str
584-
An iterable of parameters to check.
589+
An iterable of parameter names to test for consumption.
585590
586591
Returns
587592
-------
588-
consumed : set of str
589-
A set of parameters which are consumed by the given method.
593+
consumed_params : set of str
594+
A subset of parameters from `params` which are consumed by the given method.
590595
"""
591596
return getattr(self, method)._consumes(params=params)
592597

@@ -900,35 +905,42 @@ def add(self, *, method_mapping, **objs):
900905
return self
901906

902907
def consumes(self, method, params):
903-
"""Check whether the given metadata is consumed by the given method.
908+
"""Return params consumed as metadata in a :term:`router` or its sub-estimators.
909+
910+
This method returns the subset of `params` that are consumed by the
911+
`method`. A `param` is considered consumed if it is used in the specified
912+
method of the :term:`router` itself or any of its sub-estimators (or their
913+
sub-estimators).
904914
905915
.. versionadded:: 1.4
906916
907917
Parameters
908918
----------
909919
method : str
910-
The name of the method to check.
920+
The name of the method for which to determine consumed parameters.
911921
912922
params : iterable of str
913-
An iterable of parameters to check.
923+
An iterable of parameter names to test for consumption.
914924
915925
Returns
916926
-------
917-
consumed : set of str
918-
A set of parameters which are consumed by the given method.
927+
consumed_params : set of str
928+
A subset of parameters from `params` which are consumed by this method.
919929
"""
920-
res = set()
930+
consumed_params = set()
921931
if self._self_request:
922-
res = res | self._self_request.consumes(method=method, params=params)
932+
consumed_params.update(
933+
self._self_request.consumes(method=method, params=params)
934+
)
923935

924936
for _, route_mapping in self._route_mappings.items():
925937
for caller, callee in route_mapping.mapping:
926938
if caller == method:
927-
res = res | route_mapping.router.consumes(
928-
method=callee, params=params
939+
consumed_params.update(
940+
route_mapping.router.consumes(method=callee, params=params)
929941
)
930942

931-
return res
943+
return consumed_params
932944

933945
def _get_param_names(self, *, method, return_alias, ignore_self_request):
934946
"""Get names of all metadata that can be consumed or routed by specified \

0 commit comments

Comments
 (0)