@@ -17,9 +17,9 @@ class Ruler
1717"""
1818from __future__ import annotations
1919
20- from collections .abc import Callable , Iterable
20+ from collections .abc import Iterable
2121from dataclasses import dataclass , field
22- from typing import TYPE_CHECKING , TypedDict
22+ from typing import TYPE_CHECKING , Generic , TypedDict , TypeVar
2323import warnings
2424
2525from markdown_it ._compat import DATACLASS_KWARGS
@@ -57,33 +57,30 @@ def srcCharCode(self) -> tuple[int, ...]:
5757 return self ._srcCharCode
5858
5959
60- # The first positional arg is always a subtype of `StateBase`. Other
61- # arguments may or may not exist, based on the rule's type (block,
62- # core, inline). Return type is either `None` or `bool` based on the
63- # rule's type.
64- RuleFunc = Callable # type: ignore
65-
66-
6760class RuleOptionsType (TypedDict , total = False ):
6861 alt : list [str ]
6962
7063
64+ RuleFuncTv = TypeVar ("RuleFuncTv" )
65+ """A rule function, whose signature is dependent on the state type."""
66+
67+
7168@dataclass (** DATACLASS_KWARGS )
72- class Rule :
69+ class Rule ( Generic [ RuleFuncTv ]) :
7370 name : str
7471 enabled : bool
75- fn : RuleFunc = field (repr = False )
72+ fn : RuleFuncTv = field (repr = False )
7673 alt : list [str ]
7774
7875
79- class Ruler :
76+ class Ruler ( Generic [ RuleFuncTv ]) :
8077 def __init__ (self ) -> None :
8178 # List of added rules.
82- self .__rules__ : list [Rule ] = []
79+ self .__rules__ : list [Rule [ RuleFuncTv ] ] = []
8380 # Cached rule chains.
8481 # First level - chain name, '' for default.
8582 # Second level - diginal anchor for fast filtering by charcodes.
86- self .__cache__ : dict [str , list [RuleFunc ]] | None = None
83+ self .__cache__ : dict [str , list [RuleFuncTv ]] | None = None
8784
8885 def __find__ (self , name : str ) -> int :
8986 """Find rule index by name"""
@@ -112,7 +109,7 @@ def __compile__(self) -> None:
112109 self .__cache__ [chain ].append (rule .fn )
113110
114111 def at (
115- self , ruleName : str , fn : RuleFunc , options : RuleOptionsType | None = None
112+ self , ruleName : str , fn : RuleFuncTv , options : RuleOptionsType | None = None
116113 ) -> None :
117114 """Replace rule by name with new function & options.
118115
@@ -133,7 +130,7 @@ def before(
133130 self ,
134131 beforeName : str ,
135132 ruleName : str ,
136- fn : RuleFunc ,
133+ fn : RuleFuncTv ,
137134 options : RuleOptionsType | None = None ,
138135 ) -> None :
139136 """Add new rule to chain before one with given name.
@@ -148,14 +145,16 @@ def before(
148145 options = options or {}
149146 if index == - 1 :
150147 raise KeyError (f"Parser rule not found: { beforeName } " )
151- self .__rules__ .insert (index , Rule (ruleName , True , fn , options .get ("alt" , [])))
148+ self .__rules__ .insert (
149+ index , Rule [RuleFuncTv ](ruleName , True , fn , options .get ("alt" , []))
150+ )
152151 self .__cache__ = None
153152
154153 def after (
155154 self ,
156155 afterName : str ,
157156 ruleName : str ,
158- fn : RuleFunc ,
157+ fn : RuleFuncTv ,
159158 options : RuleOptionsType | None = None ,
160159 ) -> None :
161160 """Add new rule to chain after one with given name.
@@ -171,12 +170,12 @@ def after(
171170 if index == - 1 :
172171 raise KeyError (f"Parser rule not found: { afterName } " )
173172 self .__rules__ .insert (
174- index + 1 , Rule (ruleName , True , fn , options .get ("alt" , []))
173+ index + 1 , Rule [ RuleFuncTv ] (ruleName , True , fn , options .get ("alt" , []))
175174 )
176175 self .__cache__ = None
177176
178177 def push (
179- self , ruleName : str , fn : RuleFunc , options : RuleOptionsType | None = None
178+ self , ruleName : str , fn : RuleFuncTv , options : RuleOptionsType | None = None
180179 ) -> None :
181180 """Push new rule to the end of chain.
182181
@@ -185,7 +184,9 @@ def push(
185184 :param options: new rule options (not mandatory).
186185
187186 """
188- self .__rules__ .append (Rule (ruleName , True , fn , (options or {}).get ("alt" , [])))
187+ self .__rules__ .append (
188+ Rule [RuleFuncTv ](ruleName , True , fn , (options or {}).get ("alt" , []))
189+ )
189190 self .__cache__ = None
190191
191192 def enable (
@@ -252,7 +253,7 @@ def disable(
252253 self .__cache__ = None
253254 return result
254255
255- def getRules (self , chainName : str ) -> list [RuleFunc ]:
256+ def getRules (self , chainName : str = "" ) -> list [RuleFuncTv ]:
256257 """Return array of active functions (rules) for given chain name.
257258 It analyzes rules configuration, compiles caches if not exists and returns result.
258259
0 commit comments