gh-104050: Add more annotations to `Tools/clinic.py` (#104544)

This commit is contained in:
Nikita Sobolev 2023-05-16 20:18:28 +03:00 committed by GitHub
parent 1163782868
commit a454a6651b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 59 additions and 29 deletions

View File

@ -28,7 +28,7 @@ import traceback
from collections.abc import Callable from collections.abc import Callable
from types import FunctionType, NoneType from types import FunctionType, NoneType
from typing import Any, NamedTuple from typing import Any, NamedTuple, NoReturn, Literal, overload
# TODO: # TODO:
# #
@ -59,21 +59,21 @@ CLINIC_PREFIXED_ARGS = {
} }
class Unspecified: class Unspecified:
def __repr__(self): def __repr__(self) -> str:
return '<Unspecified>' return '<Unspecified>'
unspecified = Unspecified() unspecified = Unspecified()
class Null: class Null:
def __repr__(self): def __repr__(self) -> str:
return '<Null>' return '<Null>'
NULL = Null() NULL = Null()
class Unknown: class Unknown:
def __repr__(self): def __repr__(self) -> str:
return '<Unknown>' return '<Unknown>'
unknown = Unknown() unknown = Unknown()
@ -81,15 +81,15 @@ unknown = Unknown()
sig_end_marker = '--' sig_end_marker = '--'
Appender = Callable[[str], None] Appender = Callable[[str], None]
Outputter = Callable[[None], str] Outputter = Callable[[], str]
class _TextAccumulator(NamedTuple): class _TextAccumulator(NamedTuple):
text: list[str] text: list[str]
append: Appender append: Appender
output: Outputter output: Outputter
def _text_accumulator(): def _text_accumulator() -> _TextAccumulator:
text = [] text: list[str] = []
def output(): def output():
s = ''.join(text) s = ''.join(text)
text.clear() text.clear()
@ -98,10 +98,10 @@ def _text_accumulator():
class TextAccumulator(NamedTuple): class TextAccumulator(NamedTuple):
text: list[str]
append: Appender append: Appender
output: Outputter
def text_accumulator(): def text_accumulator() -> TextAccumulator:
""" """
Creates a simple text accumulator / joiner. Creates a simple text accumulator / joiner.
@ -115,8 +115,28 @@ def text_accumulator():
text, append, output = _text_accumulator() text, append, output = _text_accumulator()
return TextAccumulator(append, output) return TextAccumulator(append, output)
@overload
def warn_or_fail(
*args: object,
fail: Literal[True],
filename: str | None = None,
line_number: int | None = None,
) -> NoReturn: ...
def warn_or_fail(fail=False, *args, filename=None, line_number=None): @overload
def warn_or_fail(
*args: object,
fail: Literal[False] = False,
filename: str | None = None,
line_number: int | None = None,
) -> None: ...
def warn_or_fail(
*args: object,
fail: bool = False,
filename: str | None = None,
line_number: int | None = None,
) -> None:
joined = " ".join([str(a) for a in args]) joined = " ".join([str(a) for a in args])
add, output = text_accumulator() add, output = text_accumulator()
if fail: if fail:
@ -139,14 +159,22 @@ def warn_or_fail(fail=False, *args, filename=None, line_number=None):
sys.exit(-1) sys.exit(-1)
def warn(*args, filename=None, line_number=None): def warn(
return warn_or_fail(False, *args, filename=filename, line_number=line_number) *args: object,
filename: str | None = None,
line_number: int | None = None,
) -> None:
return warn_or_fail(*args, filename=filename, line_number=line_number, fail=False)
def fail(*args, filename=None, line_number=None): def fail(
return warn_or_fail(True, *args, filename=filename, line_number=line_number) *args: object,
filename: str | None = None,
line_number: int | None = None,
) -> NoReturn:
warn_or_fail(*args, filename=filename, line_number=line_number, fail=True)
def quoted_for_c_string(s): def quoted_for_c_string(s: str) -> str:
for old, new in ( for old, new in (
('\\', '\\\\'), # must be first! ('\\', '\\\\'), # must be first!
('"', '\\"'), ('"', '\\"'),
@ -155,13 +183,13 @@ def quoted_for_c_string(s):
s = s.replace(old, new) s = s.replace(old, new)
return s return s
def c_repr(s): def c_repr(s: str) -> str:
return '"' + s + '"' return '"' + s + '"'
is_legal_c_identifier = re.compile('^[A-Za-z_][A-Za-z0-9_]*$').match is_legal_c_identifier = re.compile('^[A-Za-z_][A-Za-z0-9_]*$').match
def is_legal_py_identifier(s): def is_legal_py_identifier(s: str) -> bool:
return all(is_legal_c_identifier(field) for field in s.split('.')) return all(is_legal_c_identifier(field) for field in s.split('.'))
# identifiers that are okay in Python but aren't a good idea in C. # identifiers that are okay in Python but aren't a good idea in C.
@ -174,7 +202,7 @@ register return short signed sizeof static struct switch
typedef typeof union unsigned void volatile while typedef typeof union unsigned void volatile while
""".strip().split()) """.strip().split())
def ensure_legal_c_identifier(s): def ensure_legal_c_identifier(s: str) -> str:
# for now, just complain if what we're given isn't legal # for now, just complain if what we're given isn't legal
if not is_legal_c_identifier(s): if not is_legal_c_identifier(s):
fail("Illegal C identifier: {}".format(s)) fail("Illegal C identifier: {}".format(s))
@ -183,7 +211,7 @@ def ensure_legal_c_identifier(s):
return s + "_value" return s + "_value"
return s return s
def rstrip_lines(s): def rstrip_lines(s: str) -> str:
text, add, output = _text_accumulator() text, add, output = _text_accumulator()
for line in s.split('\n'): for line in s.split('\n'):
add(line.rstrip()) add(line.rstrip())
@ -191,14 +219,14 @@ def rstrip_lines(s):
text.pop() text.pop()
return output() return output()
def format_escape(s): def format_escape(s: str) -> str:
# double up curly-braces, this string will be used # double up curly-braces, this string will be used
# as part of a format_map() template later # as part of a format_map() template later
s = s.replace('{', '{{') s = s.replace('{', '{{')
s = s.replace('}', '}}') s = s.replace('}', '}}')
return s return s
def linear_format(s, **kwargs): def linear_format(s: str, **kwargs: str) -> str:
""" """
Perform str.format-like substitution, except: Perform str.format-like substitution, except:
* The strings substituted must be on lines by * The strings substituted must be on lines by
@ -242,7 +270,7 @@ def linear_format(s, **kwargs):
return output()[:-1] return output()[:-1]
def indent_all_lines(s, prefix): def indent_all_lines(s: str, prefix: str) -> str:
""" """
Returns 's', with 'prefix' prepended to all lines. Returns 's', with 'prefix' prepended to all lines.
@ -263,7 +291,7 @@ def indent_all_lines(s, prefix):
final.append(last) final.append(last)
return ''.join(final) return ''.join(final)
def suffix_all_lines(s, suffix): def suffix_all_lines(s: str, suffix: str) -> str:
""" """
Returns 's', with 'suffix' appended to all lines. Returns 's', with 'suffix' appended to all lines.
@ -283,7 +311,7 @@ def suffix_all_lines(s, suffix):
return ''.join(final) return ''.join(final)
def version_splitter(s): def version_splitter(s: str) -> tuple[int, ...]:
"""Splits a version string into a tuple of integers. """Splits a version string into a tuple of integers.
The following ASCII characters are allowed, and employ The following ASCII characters are allowed, and employ
@ -294,7 +322,7 @@ def version_splitter(s):
(This permits Python-style version strings such as "1.4b3".) (This permits Python-style version strings such as "1.4b3".)
""" """
version = [] version = []
accumulator = [] accumulator: list[str] = []
def flush(): def flush():
if not accumulator: if not accumulator:
raise ValueError('Unsupported version string: ' + repr(s)) raise ValueError('Unsupported version string: ' + repr(s))
@ -314,7 +342,7 @@ def version_splitter(s):
flush() flush()
return tuple(version) return tuple(version)
def version_comparitor(version1, version2): def version_comparitor(version1: str, version2: str) -> Literal[-1, 0, 1]:
iterator = itertools.zip_longest(version_splitter(version1), version_splitter(version2), fillvalue=0) iterator = itertools.zip_longest(version_splitter(version1), version_splitter(version2), fillvalue=0)
for i, (a, b) in enumerate(iterator): for i, (a, b) in enumerate(iterator):
if a < b: if a < b:

View File

@ -1,6 +1,7 @@
import re import re
import sys import sys
from collections.abc import Callable from collections.abc import Callable
from typing import NoReturn
TokenAndCondition = tuple[str, str] TokenAndCondition = tuple[str, str]
@ -30,7 +31,7 @@ class Monitor:
is_a_simple_defined: Callable[[str], re.Match[str] | None] is_a_simple_defined: Callable[[str], re.Match[str] | None]
is_a_simple_defined = re.compile(r'^defined\s*\(\s*[A-Za-z0-9_]+\s*\)$').match is_a_simple_defined = re.compile(r'^defined\s*\(\s*[A-Za-z0-9_]+\s*\)$').match
def __init__(self, filename=None, *, verbose: bool = False): def __init__(self, filename: str | None = None, *, verbose: bool = False) -> None:
self.stack: TokenStack = [] self.stack: TokenStack = []
self.in_comment = False self.in_comment = False
self.continuation: str | None = None self.continuation: str | None = None
@ -55,7 +56,7 @@ class Monitor:
""" """
return " && ".join(condition for token, condition in self.stack) return " && ".join(condition for token, condition in self.stack)
def fail(self, *a): def fail(self, *a: object) -> NoReturn:
if self.filename: if self.filename:
filename = " " + self.filename filename = " " + self.filename
else: else:
@ -64,7 +65,7 @@ class Monitor:
print(" ", ' '.join(str(x) for x in a)) print(" ", ' '.join(str(x) for x in a))
sys.exit(-1) sys.exit(-1)
def close(self): def close(self) -> None:
if self.stack: if self.stack:
self.fail("Ended file while still in a preprocessor conditional block!") self.fail("Ended file while still in a preprocessor conditional block!")

View File

@ -8,4 +8,5 @@ strict_concatenate = True
warn_redundant_casts = True warn_redundant_casts = True
warn_unused_ignores = True warn_unused_ignores = True
warn_unused_configs = True warn_unused_configs = True
warn_unreachable = True
files = Tools/clinic/ files = Tools/clinic/