about | summary | refs | log | tree | commit | diff
path: root/pw_presubmit/py/pw_presubmit/format_code.py
diff options
context:
space:
mode:
Diffstat (limited to 'pw_presubmit/py/pw_presubmit/format_code.py')
-rwxr-xr-x pw_presubmit/py/pw_presubmit/format_code.py 317
1 files changed, 225 insertions, 92 deletions
diff --git a/pw_presubmit/py/pw_presubmit/format_code.py b/pw_presubmit/py/pw_presubmit/format_code.py
index 9b073ea05..df81ad3f9 100755
--- a/pw_presubmit/py/pw_presubmit/format_code.py
+++ b/pw_presubmit/py/pw_presubmit/format_code.py
@@ -22,10 +22,12 @@ code. These tools must be available on the path when this script is invoked!
import argparse
import collections
import difflib
+import json
import logging
import os
from pathlib import Path
import re
+import shutil
import subprocess
import sys
import tempfile
@@ -44,26 +46,33 @@ from typing import (
Union,
)
-try:
- import pw_presubmit
-except ImportError:
- # Append the pw_presubmit package path to the module search path to allow
- # running this module without installing the pw_presubmit package.
- sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
- import pw_presubmit
-
import pw_cli.color
import pw_cli.env
-from pw_presubmit.presubmit import FileFilter
-from pw_presubmit import (
- cli,
+import pw_env_setup.config_file
+from pw_presubmit.presubmit import (
+ FileFilter,
+ filter_paths,
+)
+from pw_presubmit.presubmit_context import (
FormatContext,
FormatOptions,
+ PresubmitContext,
+ PresubmitFailure,
+)
+from pw_presubmit import (
+ cli,
git_repo,
owners_checks,
- PresubmitContext,
+ presubmit_context,
)
-from pw_presubmit.tools import exclude_paths, file_summary, log_run, plural
+from pw_presubmit.tools import (
+ exclude_paths,
+ file_summary,
+ log_run,
+ plural,
+ colorize_diff,
+)
+from pw_presubmit.rst_format import reformat_rst
_LOG: logging.Logger = logging.getLogger(__name__)
_COLOR = pw_cli.color.colors()
@@ -72,27 +81,15 @@ _DEFAULT_PATH = Path('out', 'format')
_Context = Union[PresubmitContext, FormatContext]
-def _colorize_diff_line(line: str) -> str:
- if line.startswith('--- ') or line.startswith('+++ '):
- return _COLOR.bold_white(line)
- if line.startswith('-'):
- return _COLOR.red(line)
- if line.startswith('+'):
- return _COLOR.green(line)
- if line.startswith('@@ '):
- return _COLOR.cyan(line)
- return line
-
-
-def colorize_diff(lines: Iterable[str]) -> str:
- """Takes a diff str or list of str lines and returns a colorized version."""
- if isinstance(lines, str):
- lines = lines.splitlines(True)
-
- return ''.join(_colorize_diff_line(line) for line in lines)
+def _ensure_newline(orig: bytes) -> bytes:
+    """Returns ``orig`` in a form that is safe to split into diff lines.
+
+    Content that already ends with a newline is returned unchanged. Content
+    whose last line is unterminated gets a newline plus a
+    ``No newline at end of file`` marker line appended, so the missing
+    terminator shows up in the generated diff.
+
+    Args:
+      orig: Raw file contents.
+
+    Returns:
+      ``orig`` itself, or ``orig`` followed by the marker text.
+    """
+    # Empty content has no last line, so there is no missing newline to flag;
+    # appending the marker would add two phantom lines to the diff.
+    if not orig or orig.endswith(b'\n'):
+        return orig
+    return orig + b'\nNo newline at end of file\n'
def _diff(path, original: bytes, formatted: bytes) -> str:
+ original = _ensure_newline(original)
+ formatted = _ensure_newline(formatted)
return ''.join(
difflib.unified_diff(
original.decode(errors='replace').splitlines(True),
@@ -103,24 +100,31 @@ def _diff(path, original: bytes, formatted: bytes) -> str:
)
-Formatter = Callable[[str, bytes], bytes]
+FormatterT = Callable[[str, bytes], bytes]
-def _diff_formatted(path, formatter: Formatter) -> Optional[str]:
+def _diff_formatted(
+ path, formatter: FormatterT, dry_run: bool = False
+) -> Optional[str]:
"""Returns a diff comparing a file to its formatted version."""
with open(path, 'rb') as fd:
original = fd.read()
formatted = formatter(path, original)
+ if dry_run:
+ return None
+
return None if formatted == original else _diff(path, original, formatted)
-def _check_files(files, formatter: Formatter) -> Dict[Path, str]:
+def _check_files(
+ files, formatter: FormatterT, dry_run: bool = False
+) -> Dict[Path, str]:
errors = {}
for path in files:
- difference = _diff_formatted(path, formatter)
+ difference = _diff_formatted(path, formatter, dry_run)
if difference:
errors[path] = difference
@@ -138,7 +142,11 @@ def _clang_format(*args: Union[Path, str], **kwargs) -> bytes:
def clang_format_check(ctx: _Context) -> Dict[Path, str]:
"""Checks formatting; returns {path: diff} for files with bad formatting."""
- return _check_files(ctx.paths, lambda path, _: _clang_format(path))
+ return _check_files(
+ ctx.paths,
+ lambda path, _: _clang_format(path),
+ ctx.dry_run,
+ )
def clang_format_fix(ctx: _Context) -> Dict[Path, str]:
@@ -147,6 +155,30 @@ def clang_format_fix(ctx: _Context) -> Dict[Path, str]:
return {}
+def _typescript_format(*args: Union[Path, str], **kwargs) -> bytes:
+    """Runs Prettier (pinned to 3.0.1, launched through npx); returns stdout.
+
+    Positional arguments are appended to the Prettier command line. Keyword
+    arguments are forwarded to log_run, which is also given check=True and a
+    piped stdout.
+    """
+    command = ['npx', 'prettier@3.0.1', *args]
+    process = log_run(command, stdout=subprocess.PIPE, check=True, **kwargs)
+    return process.stdout
+
+
+def typescript_format_check(ctx: _Context) -> Dict[Path, str]:
+    """Checks formatting; returns {path: diff} for files with bad formatting."""
+
+    def _prettier_output(path, _original: bytes) -> bytes:
+        # Prettier reads the file itself, so the in-memory contents are unused.
+        return _typescript_format(path)
+
+    return _check_files(ctx.paths, _prettier_output, ctx.dry_run)
+
+
+def typescript_format_fix(ctx: _Context) -> Dict[Path, str]:
+    """Fixes formatting for the provided files in place."""
+    # A single Prettier invocation with --write covers every path in ctx.
+    prettier_args = ('--write', *ctx.paths)
+    _typescript_format(*prettier_args)
+    return {}
+
+
def check_gn_format(ctx: _Context) -> Dict[Path, str]:
"""Checks formatting; returns {path: diff} for files with bad formatting."""
return _check_files(
@@ -157,6 +189,7 @@ def check_gn_format(ctx: _Context) -> Dict[Path, str]:
stdout=subprocess.PIPE,
check=True,
).stdout,
+ ctx.dry_run,
)
@@ -185,7 +218,7 @@ def check_bazel_format(ctx: _Context) -> Dict[Path, str]:
errors[Path(path)] = stderr
return build.read_bytes()
- result = _check_files(ctx.paths, _format_temp)
+ result = _check_files(ctx.paths, _format_temp, ctx.dry_run)
result.update(errors)
return result
@@ -215,6 +248,7 @@ def check_go_format(ctx: _Context) -> Dict[Path, str]:
lambda path, _: log_run(
['gofmt', path], stdout=subprocess.PIPE, check=True
).stdout,
+ ctx.dry_run,
)
@@ -224,7 +258,7 @@ def fix_go_format(ctx: _Context) -> Dict[Path, str]:
return {}
-# TODO(b/259595799) Remove yapf support.
+# TODO: b/259595799 - Remove yapf support.
def _yapf(*args, **kwargs) -> subprocess.CompletedProcess:
return log_run(
['python', '-m', 'yapf', '--parallel', *args],
@@ -268,6 +302,20 @@ def fix_py_format_yapf(ctx: _Context) -> Dict[Path, str]:
def _enumerate_black_configs() -> Iterable[Path]:
+ config = pw_env_setup.config_file.load()
+ black_config_file = (
+ config.get('pw', {})
+ .get('pw_presubmit', {})
+ .get('format', {})
+ .get('black_config_file', {})
+ )
+ if black_config_file:
+ explicit_path = Path(black_config_file)
+ if not explicit_path.is_file():
+ raise ValueError(f'Black config file not found: {explicit_path}')
+ yield explicit_path
+ return # If an explicit path is provided, don't try implicit paths.
+
if directory := os.environ.get('PW_PROJECT_ROOT'):
yield Path(directory, '.black.toml')
yield Path(directory, 'pyproject.toml')
@@ -335,6 +383,7 @@ def check_py_format_black(ctx: _Context) -> Dict[Path, str]:
result = _check_files(
[x for x in ctx.paths if str(x).endswith(paths)],
_format_temp,
+ ctx.dry_run,
)
result.update(errors)
return result
@@ -399,6 +448,47 @@ def _check_trailing_space(paths: Iterable[Path], fix: bool) -> Dict[Path, str]:
return errors
+def _format_json(contents: bytes) -> bytes:
+    """Re-serializes JSON with a two-space indent and one trailing newline.
+
+    Raises:
+      json.JSONDecodeError: if ``contents`` is not valid JSON.
+    """
+    parsed = json.loads(contents)
+    pretty = json.dumps(parsed, indent=2)
+    return pretty.encode() + b'\n'
+
+
+def _json_error(exc: json.JSONDecodeError, path: Path) -> str:
+    """Formats a JSON parse failure as '<path>: <message> <line>:<column>'."""
+    location = f'{exc.lineno}:{exc.colno}'
+    return f'{path}: {exc.msg} {location}\n'
+
+
+def check_json_format(ctx: _Context) -> Dict[Path, str]:
+    """Checks JSON formatting without modifying any file.
+
+    Returns:
+      A dict mapping each offending path to either a parse-error message (the
+      file is not valid JSON) or a diff against its formatted contents.
+    """
+    problems: Dict[Path, str] = {}
+
+    for json_path in ctx.paths:
+        current = json_path.read_bytes()
+        try:
+            expected = _format_json(current)
+        except json.JSONDecodeError as exc:
+            problems[json_path] = _json_error(exc, json_path)
+        else:
+            if expected != current:
+                problems[json_path] = _diff(json_path, current, expected)
+
+    return problems
+
+
+def fix_json_format(ctx: _Context) -> Dict[Path, str]:
+    """Rewrites badly formatted JSON files in place.
+
+    Returns:
+      A dict mapping each path that could not be parsed as JSON to its error
+      message. Files that were reformatted successfully are not reported.
+    """
+    failures: Dict[Path, str] = {}
+    for json_path in ctx.paths:
+        current = json_path.read_bytes()
+        try:
+            expected = _format_json(current)
+        except json.JSONDecodeError as exc:
+            failures[json_path] = _json_error(exc, json_path)
+        else:
+            # Only touch files whose contents would actually change.
+            if expected != current:
+                json_path.write_bytes(expected)
+
+    return failures
+
+
def check_trailing_space(ctx: _Context) -> Dict[Path, str]:
return _check_trailing_space(ctx.paths, fix=False)
@@ -408,6 +498,24 @@ def fix_trailing_space(ctx: _Context) -> Dict[Path, str]:
return {}
+def rst_format_check(ctx: _Context) -> Dict[Path, str]:
+    """Checks reStructuredText formatting without modifying any file.
+
+    Returns:
+      A dict mapping each path for which reformat_rst returned a non-empty
+      result to that result joined into one string (presumably diff lines,
+      since diff=True is requested; confirm against reformat_rst).
+    """
+    findings: Dict[Path, str] = {}
+    for rst_path in ctx.paths:
+        diff_lines = reformat_rst(rst_path, diff=True, in_place=False)
+        if not diff_lines:
+            continue
+        findings[rst_path] = ''.join(diff_lines)
+    return findings
+
+
+def rst_format_fix(ctx: _Context) -> Dict[Path, str]:
+    """Reformats reStructuredText files in place.
+
+    Returns:
+      A dict mapping each path for which reformat_rst returned a non-empty
+      result to that result joined into one string.
+    """
+    # NOTE(review): diff=True is still requested here, so a file that was just
+    # rewritten presumably ends up in the returned dict as well; confirm that
+    # callers do not treat those entries as failures.
+    findings: Dict[Path, str] = {}
+    for rst_path in ctx.paths:
+        diff_lines = reformat_rst(rst_path, diff=True, in_place=True)
+        if not diff_lines:
+            continue
+        findings[rst_path] = ''.join(diff_lines)
+    return findings
+
+
def print_format_check(
errors: Dict[Path, str],
show_fix_commands: bool,
@@ -457,7 +565,7 @@ class CodeFormat(NamedTuple):
@property
def extensions(self):
- # TODO(b/23842636): Switch calls of this to using 'filter' and remove.
+ # TODO: b/23842636 - Switch calls of this to using 'filter' and remove.
return self.filter.endswith
@@ -467,7 +575,7 @@ CPP_SOURCE_EXTS = frozenset(
)
CPP_EXTS = CPP_HEADER_EXTS.union(CPP_SOURCE_EXTS)
CPP_FILE_FILTER = FileFilter(
- endswith=CPP_EXTS, exclude=(r'\.pb\.h$', r'\.pb\.c$')
+ endswith=CPP_EXTS, exclude=[r'\.pb\.h$', r'\.pb\.c$']
)
C_FORMAT = CodeFormat(
@@ -476,102 +584,133 @@ C_FORMAT = CodeFormat(
PROTO_FORMAT: CodeFormat = CodeFormat(
'Protocol buffer',
- FileFilter(endswith=('.proto',)),
+ FileFilter(endswith=['.proto']),
clang_format_check,
clang_format_fix,
)
JAVA_FORMAT: CodeFormat = CodeFormat(
'Java',
- FileFilter(endswith=('.java',)),
+ FileFilter(endswith=['.java']),
clang_format_check,
clang_format_fix,
)
JAVASCRIPT_FORMAT: CodeFormat = CodeFormat(
'JavaScript',
- FileFilter(endswith=('.js',)),
- clang_format_check,
- clang_format_fix,
+ FileFilter(endswith=['.js']),
+ typescript_format_check,
+ typescript_format_fix,
+)
+
+# TypeScript sources (*.ts) are checked and fixed with the Prettier-based
+# typescript_format_check / typescript_format_fix functions.
+TYPESCRIPT_FORMAT: CodeFormat = CodeFormat(
+ 'TypeScript',
+ FileFilter(endswith=['.ts']),
+ typescript_format_check,
+ typescript_format_fix,
+)
+
+# TODO: b/308948504 - Add real code formatting support for CSS
+# Until then, *.css files only get the trailing-whitespace check and fix.
+CSS_FORMAT: CodeFormat = CodeFormat(
+ 'css',
+ FileFilter(endswith=['.css']),
+ check_trailing_space,
+ fix_trailing_space,
)
GO_FORMAT: CodeFormat = CodeFormat(
- 'Go', FileFilter(endswith=('.go',)), check_go_format, fix_go_format
+ 'Go', FileFilter(endswith=['.go']), check_go_format, fix_go_format
)
PYTHON_FORMAT: CodeFormat = CodeFormat(
'Python',
- FileFilter(endswith=('.py',)),
+ FileFilter(endswith=['.py']),
check_py_format,
fix_py_format,
)
GN_FORMAT: CodeFormat = CodeFormat(
- 'GN', FileFilter(endswith=('.gn', '.gni')), check_gn_format, fix_gn_format
+ 'GN', FileFilter(endswith=['.gn', '.gni']), check_gn_format, fix_gn_format
)
BAZEL_FORMAT: CodeFormat = CodeFormat(
'Bazel',
- FileFilter(endswith=('BUILD', '.bazel', '.bzl'), name=('WORKSPACE')),
+ FileFilter(endswith=['.bazel', '.bzl'], name=['^BUILD$', '^WORKSPACE$']),
check_bazel_format,
fix_bazel_format,
)
COPYBARA_FORMAT: CodeFormat = CodeFormat(
'Copybara',
- FileFilter(endswith=('.bara.sky',)),
+ FileFilter(endswith=['.bara.sky']),
check_bazel_format,
fix_bazel_format,
)
-# TODO(b/234881054): Add real code formatting support for CMake
+# TODO: b/234881054 - Add real code formatting support for CMake
CMAKE_FORMAT: CodeFormat = CodeFormat(
'CMake',
- FileFilter(endswith=('CMakeLists.txt', '.cmake')),
+ FileFilter(endswith=['.cmake'], name=['^CMakeLists.txt$']),
check_trailing_space,
fix_trailing_space,
)
RST_FORMAT: CodeFormat = CodeFormat(
'reStructuredText',
- FileFilter(endswith=('.rst',)),
- check_trailing_space,
- fix_trailing_space,
+ FileFilter(endswith=['.rst']),
+ rst_format_check,
+ rst_format_fix,
)
MARKDOWN_FORMAT: CodeFormat = CodeFormat(
'Markdown',
- FileFilter(endswith=('.md',)),
+ FileFilter(endswith=['.md']),
check_trailing_space,
fix_trailing_space,
)
OWNERS_CODE_FORMAT = CodeFormat(
'OWNERS',
- filter=FileFilter(name=('OWNERS',)),
+ filter=FileFilter(name=['^OWNERS$']),
check=check_owners_format,
fix=fix_owners_format,
)
-CODE_FORMATS: Tuple[CodeFormat, ...] = (
- # keep-sorted: start
- BAZEL_FORMAT,
- CMAKE_FORMAT,
- COPYBARA_FORMAT,
- C_FORMAT,
- GN_FORMAT,
- GO_FORMAT,
- JAVASCRIPT_FORMAT,
- JAVA_FORMAT,
- MARKDOWN_FORMAT,
- OWNERS_CODE_FORMAT,
- PROTO_FORMAT,
- PYTHON_FORMAT,
- RST_FORMAT,
- # keep-sorted: end
+# JSON files (*.json) are normalized by check_json_format / fix_json_format,
+# which re-serialize each file with the standard-library json module.
+JSON_FORMAT: CodeFormat = CodeFormat(
+ 'JSON',
+ FileFilter(endswith=['.json']),
+ check=check_json_format,
+ fix=fix_json_format,
)
-# TODO(b/264578594) Remove these lines when these globals aren't referenced.
+# All formats applied by default. filter(None, ...) drops the None
+# placeholders, so the JavaScript and TypeScript entries are present only when
+# an 'npx' executable is found on the PATH at the time this module is imported.
+CODE_FORMATS: Tuple[CodeFormat, ...] = tuple(
+ filter(
+ None,
+ (
+ # keep-sorted: start
+ BAZEL_FORMAT,
+ CMAKE_FORMAT,
+ COPYBARA_FORMAT,
+ CSS_FORMAT,
+ C_FORMAT,
+ GN_FORMAT,
+ GO_FORMAT,
+ JAVASCRIPT_FORMAT if shutil.which('npx') else None,
+ JAVA_FORMAT,
+ JSON_FORMAT,
+ MARKDOWN_FORMAT,
+ OWNERS_CODE_FORMAT,
+ PROTO_FORMAT,
+ PYTHON_FORMAT,
+ RST_FORMAT,
+ TYPESCRIPT_FORMAT if shutil.which('npx') else None,
+ # keep-sorted: end
+ ),
+ )
+)
+
+
+# TODO: b/264578594 - Remove these lines when these globals aren't referenced.
CODE_FORMATS_WITH_BLACK: Tuple[CodeFormat, ...] = CODE_FORMATS
CODE_FORMATS_WITH_YAPF: Tuple[CodeFormat, ...] = CODE_FORMATS
@@ -591,8 +730,9 @@ def presubmit_check(
file_filter = FileFilter(**vars(code_format.filter))
file_filter.exclude += tuple(re.compile(e) for e in exclude)
- @pw_presubmit.filter_paths(file_filter=file_filter)
- def check_code_format(ctx: pw_presubmit.PresubmitContext):
+ @filter_paths(file_filter=file_filter)
+ def check_code_format(ctx: PresubmitContext):
+ ctx.paths = presubmit_context.apply_exclusions(ctx)
errors = code_format.check(ctx)
print_format_check(
errors,
@@ -610,7 +750,7 @@ def presubmit_check(
file=outs,
)
- raise pw_presubmit.PresubmitFailure
+ raise PresubmitFailure
language = code_format.language.lower().replace('+', 'p').replace(' ', '_')
check_code_format.name = f'{language}_format'
@@ -646,10 +786,16 @@ class CodeFormatter:
package_root: Optional[Path] = None,
):
self.root = root
- self.paths = list(files)
self._formats: Dict[CodeFormat, List] = collections.defaultdict(list)
self.root_output_dir = output_dir
self.package_root = package_root or output_dir / 'packages'
+ self._format_options = FormatOptions.load()
+ raw_paths = files
+ self.paths: Tuple[Path, ...] = self._format_options.filter_paths(files)
+
+ filtered_paths = set(raw_paths) - set(self.paths)
+ for path in sorted(filtered_paths):
+ _LOG.debug('filtered out %s', path)
for path in self.paths:
for code_format in code_formats:
@@ -671,7 +817,7 @@ class CodeFormatter:
output_dir=outdir,
paths=tuple(self._formats[code_format]),
package_root=self.package_root,
- format_options=FormatOptions.load(),
+ format_options=self._format_options,
)
def check(self) -> Dict[Path, str]:
@@ -882,19 +1028,6 @@ def main() -> int:
return format_paths_in_repo(**vars(arguments(git_paths=True).parse_args()))
-def _pigweed_upstream_main() -> int:
- """Check and fix formatting for source files in upstream Pigweed.
-
- Excludes third party sources.
- """
- args = arguments(git_paths=True).parse_args()
-
- # Exclude paths with third party code from formatting.
- args.exclude.append(re.compile('^third_party/fuchsia/repo/'))
-
- return format_paths_in_repo(**vars(args))
-
-
if __name__ == '__main__':
try:
# If pw_cli is available, use it to initialize logs.