diff options
Diffstat (limited to 'pw_presubmit/py/pw_presubmit/format_code.py')
-rwxr-xr-x | pw_presubmit/py/pw_presubmit/format_code.py | 317 |
1 files changed, 225 insertions, 92 deletions
diff --git a/pw_presubmit/py/pw_presubmit/format_code.py b/pw_presubmit/py/pw_presubmit/format_code.py index 9b073ea05..df81ad3f9 100755 --- a/pw_presubmit/py/pw_presubmit/format_code.py +++ b/pw_presubmit/py/pw_presubmit/format_code.py @@ -22,10 +22,12 @@ code. These tools must be available on the path when this script is invoked! import argparse import collections import difflib +import json import logging import os from pathlib import Path import re +import shutil import subprocess import sys import tempfile @@ -44,26 +46,33 @@ from typing import ( Union, ) -try: - import pw_presubmit -except ImportError: - # Append the pw_presubmit package path to the module search path to allow - # running this module without installing the pw_presubmit package. - sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - import pw_presubmit - import pw_cli.color import pw_cli.env -from pw_presubmit.presubmit import FileFilter -from pw_presubmit import ( - cli, +import pw_env_setup.config_file +from pw_presubmit.presubmit import ( + FileFilter, + filter_paths, +) +from pw_presubmit.presubmit_context import ( FormatContext, FormatOptions, + PresubmitContext, + PresubmitFailure, +) +from pw_presubmit import ( + cli, git_repo, owners_checks, - PresubmitContext, + presubmit_context, ) -from pw_presubmit.tools import exclude_paths, file_summary, log_run, plural +from pw_presubmit.tools import ( + exclude_paths, + file_summary, + log_run, + plural, + colorize_diff, +) +from pw_presubmit.rst_format import reformat_rst _LOG: logging.Logger = logging.getLogger(__name__) _COLOR = pw_cli.color.colors() @@ -72,27 +81,15 @@ _DEFAULT_PATH = Path('out', 'format') _Context = Union[PresubmitContext, FormatContext] -def _colorize_diff_line(line: str) -> str: - if line.startswith('--- ') or line.startswith('+++ '): - return _COLOR.bold_white(line) - if line.startswith('-'): - return _COLOR.red(line) - if line.startswith('+'): - return _COLOR.green(line) - if line.startswith('@@ '): - return _COLOR.cyan(line) - return line - - -def colorize_diff(lines: Iterable[str]) -> str: - """Takes a diff str or list of str lines and returns a colorized version.""" - if isinstance(lines, str): - lines = lines.splitlines(True) - - return ''.join(_colorize_diff_line(line) for line in lines) +def _ensure_newline(orig: bytes) -> bytes: + if orig.endswith(b'\n'): + return orig + return orig + b'\nNo newline at end of file\n' def _diff(path, original: bytes, formatted: bytes) -> str: + original = _ensure_newline(original) + formatted = _ensure_newline(formatted) return ''.join( difflib.unified_diff( original.decode(errors='replace').splitlines(True), @@ -103,24 +100,31 @@ def _diff(path, original: bytes, formatted: bytes) -> str: ) -Formatter = Callable[[str, bytes], bytes] +FormatterT = Callable[[str, bytes], bytes] -def _diff_formatted(path, formatter: Formatter) -> Optional[str]: +def _diff_formatted( + path, formatter: FormatterT, dry_run: bool = False +) -> Optional[str]: """Returns a diff comparing a file to its formatted version.""" with open(path, 'rb') as fd: original = fd.read() formatted = formatter(path, original) + if dry_run: + return None + return None if formatted == original else _diff(path, original, formatted) -def _check_files(files, formatter: Formatter) -> Dict[Path, str]: +def _check_files( + files, formatter: FormatterT, dry_run: bool = False +) -> Dict[Path, str]: errors = {} for path in files: - difference = _diff_formatted(path, formatter) + difference = _diff_formatted(path, formatter, dry_run) if difference: errors[path] = difference @@ -138,7 +142,11 @@ def _clang_format(*args: Union[Path, str], **kwargs) -> bytes: def clang_format_check(ctx: _Context) -> Dict[Path, str]: """Checks formatting; returns {path: diff} for files with bad formatting.""" - return _check_files(ctx.paths, lambda path, _: _clang_format(path)) + return _check_files( + ctx.paths, + lambda path, _: _clang_format(path), + ctx.dry_run, + ) def clang_format_fix(ctx: _Context) -> Dict[Path, str]: @@ -147,6 +155,30 @@ def clang_format_fix(ctx: _Context) -> Dict[Path, str]: return {} +def _typescript_format(*args: Union[Path, str], **kwargs) -> bytes: + return log_run( + ['npx', 'prettier@3.0.1', *args], + stdout=subprocess.PIPE, + check=True, + **kwargs, + ).stdout + + +def typescript_format_check(ctx: _Context) -> Dict[Path, str]: + """Checks formatting; returns {path: diff} for files with bad formatting.""" + return _check_files( + ctx.paths, + lambda path, _: _typescript_format(path), + ctx.dry_run, + ) + + +def typescript_format_fix(ctx: _Context) -> Dict[Path, str]: + """Fixes formatting for the provided files in place.""" + _typescript_format('--write', *ctx.paths) + return {} + + def check_gn_format(ctx: _Context) -> Dict[Path, str]: """Checks formatting; returns {path: diff} for files with bad formatting.""" return _check_files( @@ -157,6 +189,7 @@ def check_gn_format(ctx: _Context) -> Dict[Path, str]: stdout=subprocess.PIPE, check=True, ).stdout, + ctx.dry_run, ) @@ -185,7 +218,7 @@ def check_bazel_format(ctx: _Context) -> Dict[Path, str]: errors[Path(path)] = stderr return build.read_bytes() - result = _check_files(ctx.paths, _format_temp) + result = _check_files(ctx.paths, _format_temp, ctx.dry_run) result.update(errors) return result @@ -215,6 +248,7 @@ def check_go_format(ctx: _Context) -> Dict[Path, str]: lambda path, _: log_run( ['gofmt', path], stdout=subprocess.PIPE, check=True ).stdout, + ctx.dry_run, ) @@ -224,7 +258,7 @@ def fix_go_format(ctx: _Context) -> Dict[Path, str]: return {} -# TODO(b/259595799) Remove yapf support. +# TODO: b/259595799 - Remove yapf support. def _yapf(*args, **kwargs) -> subprocess.CompletedProcess: return log_run( ['python', '-m', 'yapf', '--parallel', *args], @@ -268,6 +302,20 @@ def fix_py_format_yapf(ctx: _Context) -> Dict[Path, str]: def _enumerate_black_configs() -> Iterable[Path]: + config = pw_env_setup.config_file.load() + black_config_file = ( + config.get('pw', {}) + .get('pw_presubmit', {}) + .get('format', {}) + .get('black_config_file', {}) + ) + if black_config_file: + explicit_path = Path(black_config_file) + if not explicit_path.is_file(): + raise ValueError(f'Black config file not found: {explicit_path}') + yield explicit_path + return # If an explicit path is provided, don't try implicit paths. + if directory := os.environ.get('PW_PROJECT_ROOT'): yield Path(directory, '.black.toml') yield Path(directory, 'pyproject.toml') @@ -335,6 +383,7 @@ def check_py_format_black(ctx: _Context) -> Dict[Path, str]: result = _check_files( [x for x in ctx.paths if str(x).endswith(paths)], _format_temp, + ctx.dry_run, ) result.update(errors) return result @@ -399,6 +448,47 @@ def _check_trailing_space(paths: Iterable[Path], fix: bool) -> Dict[Path, str]: return errors +def _format_json(contents: bytes) -> bytes: + return json.dumps(json.loads(contents), indent=2).encode() + b'\n' + + +def _json_error(exc: json.JSONDecodeError, path: Path) -> str: + return f'{path}: {exc.msg} {exc.lineno}:{exc.colno}\n' + + +def check_json_format(ctx: _Context) -> Dict[Path, str]: + errors = {} + + for path in ctx.paths: + orig = path.read_bytes() + try: + formatted = _format_json(orig) + except json.JSONDecodeError as exc: + errors[path] = _json_error(exc, path) + continue + + if orig != formatted: + errors[path] = _diff(path, orig, formatted) + + return errors + + +def fix_json_format(ctx: _Context) -> Dict[Path, str]: + errors = {} + for path in ctx.paths: + orig = path.read_bytes() + try: + formatted = _format_json(orig) + except json.JSONDecodeError as exc: + errors[path] = _json_error(exc, path) + continue + + if orig != formatted: + path.write_bytes(formatted) + + return errors + + def check_trailing_space(ctx: _Context) -> Dict[Path, str]: return _check_trailing_space(ctx.paths, fix=False) @@ -408,6 +498,24 @@ def fix_trailing_space(ctx: _Context) -> Dict[Path, str]: return {} +def rst_format_check(ctx: _Context) -> Dict[Path, str]: + errors = {} + for path in ctx.paths: + result = reformat_rst(path, diff=True, in_place=False) + if result: + errors[path] = ''.join(result) + return errors + + +def rst_format_fix(ctx: _Context) -> Dict[Path, str]: + errors = {} + for path in ctx.paths: + result = reformat_rst(path, diff=True, in_place=True) + if result: + errors[path] = ''.join(result) + return errors + + def print_format_check( errors: Dict[Path, str], show_fix_commands: bool, @@ -457,7 +565,7 @@ class CodeFormat(NamedTuple): @property def extensions(self): - # TODO(b/23842636): Switch calls of this to using 'filter' and remove. + # TODO: b/23842636 - Switch calls of this to using 'filter' and remove. return self.filter.endswith @@ -467,7 +575,7 @@ CPP_SOURCE_EXTS = frozenset( ) CPP_EXTS = CPP_HEADER_EXTS.union(CPP_SOURCE_EXTS) CPP_FILE_FILTER = FileFilter( - endswith=CPP_EXTS, exclude=(r'\.pb\.h$', r'\.pb\.c$') + endswith=CPP_EXTS, exclude=[r'\.pb\.h$', r'\.pb\.c$'] ) C_FORMAT = CodeFormat( @@ -476,102 +584,133 @@ C_FORMAT = CodeFormat( PROTO_FORMAT: CodeFormat = CodeFormat( 'Protocol buffer', - FileFilter(endswith=('.proto',)), + FileFilter(endswith=['.proto']), clang_format_check, clang_format_fix, ) JAVA_FORMAT: CodeFormat = CodeFormat( 'Java', - FileFilter(endswith=('.java',)), + FileFilter(endswith=['.java']), clang_format_check, clang_format_fix, ) JAVASCRIPT_FORMAT: CodeFormat = CodeFormat( 'JavaScript', - FileFilter(endswith=('.js',)), - clang_format_check, - clang_format_fix, + FileFilter(endswith=['.js']), + typescript_format_check, + typescript_format_fix, +) + +TYPESCRIPT_FORMAT: CodeFormat = CodeFormat( + 'TypeScript', + FileFilter(endswith=['.ts']), + typescript_format_check, + typescript_format_fix, +) + +# TODO: b/308948504 - Add real code formatting support for CSS +CSS_FORMAT: CodeFormat = CodeFormat( + 'css', + FileFilter(endswith=['.css']), + check_trailing_space, + fix_trailing_space, ) GO_FORMAT: CodeFormat = CodeFormat( - 'Go', FileFilter(endswith=('.go',)), check_go_format, fix_go_format + 'Go', FileFilter(endswith=['.go']), check_go_format, fix_go_format ) PYTHON_FORMAT: CodeFormat = CodeFormat( 'Python', - FileFilter(endswith=('.py',)), + FileFilter(endswith=['.py']), check_py_format, fix_py_format, ) GN_FORMAT: CodeFormat = CodeFormat( - 'GN', FileFilter(endswith=('.gn', '.gni')), check_gn_format, fix_gn_format + 'GN', FileFilter(endswith=['.gn', '.gni']), check_gn_format, fix_gn_format ) BAZEL_FORMAT: CodeFormat = CodeFormat( 'Bazel', - FileFilter(endswith=('BUILD', '.bazel', '.bzl'), name=('WORKSPACE')), + FileFilter(endswith=['.bazel', '.bzl'], name=['^BUILD$', '^WORKSPACE$']), check_bazel_format, fix_bazel_format, ) COPYBARA_FORMAT: CodeFormat = CodeFormat( 'Copybara', - FileFilter(endswith=('.bara.sky',)), + FileFilter(endswith=['.bara.sky']), check_bazel_format, fix_bazel_format, ) -# TODO(b/234881054): Add real code formatting support for CMake +# TODO: b/234881054 - Add real code formatting support for CMake CMAKE_FORMAT: CodeFormat = CodeFormat( 'CMake', - FileFilter(endswith=('CMakeLists.txt', '.cmake')), + FileFilter(endswith=['.cmake'], name=['^CMakeLists.txt$']), check_trailing_space, fix_trailing_space, ) RST_FORMAT: CodeFormat = CodeFormat( 'reStructuredText', - FileFilter(endswith=('.rst',)), - check_trailing_space, - fix_trailing_space, + FileFilter(endswith=['.rst']), + rst_format_check, + rst_format_fix, ) MARKDOWN_FORMAT: CodeFormat = CodeFormat( 'Markdown', - FileFilter(endswith=('.md',)), + FileFilter(endswith=['.md']), check_trailing_space, fix_trailing_space, ) OWNERS_CODE_FORMAT = CodeFormat( 'OWNERS', - filter=FileFilter(name=('OWNERS',)), + filter=FileFilter(name=['^OWNERS$']), check=check_owners_format, fix=fix_owners_format, ) -CODE_FORMATS: Tuple[CodeFormat, ...] = ( - # keep-sorted: start - BAZEL_FORMAT, - CMAKE_FORMAT, - COPYBARA_FORMAT, - C_FORMAT, - GN_FORMAT, - GO_FORMAT, - JAVASCRIPT_FORMAT, - JAVA_FORMAT, - MARKDOWN_FORMAT, - OWNERS_CODE_FORMAT, - PROTO_FORMAT, - PYTHON_FORMAT, - RST_FORMAT, - # keep-sorted: end +JSON_FORMAT: CodeFormat = CodeFormat( + 'JSON', + FileFilter(endswith=['.json']), + check=check_json_format, + fix=fix_json_format, ) -# TODO(b/264578594) Remove these lines when these globals aren't referenced. +CODE_FORMATS: Tuple[CodeFormat, ...] = tuple( + filter( + None, + ( + # keep-sorted: start + BAZEL_FORMAT, + CMAKE_FORMAT, + COPYBARA_FORMAT, + CSS_FORMAT, + C_FORMAT, + GN_FORMAT, + GO_FORMAT, + JAVASCRIPT_FORMAT if shutil.which('npx') else None, + JAVA_FORMAT, + JSON_FORMAT, + MARKDOWN_FORMAT, + OWNERS_CODE_FORMAT, + PROTO_FORMAT, + PYTHON_FORMAT, + RST_FORMAT, + TYPESCRIPT_FORMAT if shutil.which('npx') else None, + # keep-sorted: end + ), + ) +) + + +# TODO: b/264578594 - Remove these lines when these globals aren't referenced. CODE_FORMATS_WITH_BLACK: Tuple[CodeFormat, ...] = CODE_FORMATS CODE_FORMATS_WITH_YAPF: Tuple[CodeFormat, ...] = CODE_FORMATS @@ -591,8 +730,9 @@ def presubmit_check( file_filter = FileFilter(**vars(code_format.filter)) file_filter.exclude += tuple(re.compile(e) for e in exclude) - @pw_presubmit.filter_paths(file_filter=file_filter) - def check_code_format(ctx: pw_presubmit.PresubmitContext): + @filter_paths(file_filter=file_filter) + def check_code_format(ctx: PresubmitContext): + ctx.paths = presubmit_context.apply_exclusions(ctx) errors = code_format.check(ctx) print_format_check( errors, @@ -610,7 +750,7 @@ def presubmit_check( file=outs, ) - raise pw_presubmit.PresubmitFailure + raise PresubmitFailure language = code_format.language.lower().replace('+', 'p').replace(' ', '_') check_code_format.name = f'{language}_format' @@ -646,10 +786,16 @@ class CodeFormatter: package_root: Optional[Path] = None, ): self.root = root - self.paths = list(files) self._formats: Dict[CodeFormat, List] = collections.defaultdict(list) self.root_output_dir = output_dir self.package_root = package_root or output_dir / 'packages' + self._format_options = FormatOptions.load() + raw_paths = files + self.paths: Tuple[Path, ...] = self._format_options.filter_paths(files) + + filtered_paths = set(raw_paths) - set(self.paths) + for path in sorted(filtered_paths): + _LOG.debug('filtered out %s', path) for path in self.paths: for code_format in code_formats: @@ -671,7 +817,7 @@ class CodeFormatter: output_dir=outdir, paths=tuple(self._formats[code_format]), package_root=self.package_root, - format_options=FormatOptions.load(), + format_options=self._format_options, ) def check(self) -> Dict[Path, str]: @@ -882,19 +1028,6 @@ def main() -> int: return format_paths_in_repo(**vars(arguments(git_paths=True).parse_args())) -def _pigweed_upstream_main() -> int: - """Check and fix formatting for source files in upstream Pigweed. - - Excludes third party sources. - """ - args = arguments(git_paths=True).parse_args() - - # Exclude paths with third party code from formatting. - args.exclude.append(re.compile('^third_party/fuchsia/repo/')) - - return format_paths_in_repo(**vars(args)) - - if __name__ == '__main__': try: # If pw_cli is available, use it to initialize logs. |