aboutsummaryrefslogtreecommitdiff
path: root/yapf/yapflib/yapf_api.py
blob: 282dea3eb0b5cf3fe8ac34c7ff8c001e18f819c4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Entry points for YAPF.

The main APIs that YAPF exposes to drive the reformatting.

  FormatFile(): reformat a file.
  FormatCode(): reformat a string of code.

These APIs have some common arguments:

  style_config: (string) Either a style name or a path to a file that contains
    formatting style settings. If None is specified, use the default style
    as set in style.DEFAULT_STYLE_FACTORY
  lines: (list of tuples of integers) A list of tuples of lines, [start, end],
    that we want to format. The lines are 1-based indexed. It can be used by
    third-party code (e.g., IDEs) when reformatting a snippet of code rather
    than a whole file.
  print_diff: (bool) Instead of returning the reformatted source, return a
    diff that turns the original source into the reformatted source.
  verify: (bool) True if reformatted code should be verified for syntax.
"""

import difflib
import re
import sys

from lib2to3.pgen2 import parse

from yapf.yapflib import blank_line_calculator
from yapf.yapflib import comment_splicer
from yapf.yapflib import continuation_splicer
from yapf.yapflib import file_resources
from yapf.yapflib import identify_container
from yapf.yapflib import py3compat
from yapf.yapflib import pytree_unwrapper
from yapf.yapflib import pytree_utils
from yapf.yapflib import reformatter
from yapf.yapflib import split_penalty
from yapf.yapflib import style
from yapf.yapflib import subtype_assigner


def FormatFile(filename,
               style_config=None,
               lines=None,
               print_diff=False,
               verify=False,
               in_place=False,
               logger=None):
  """Reformat one Python file, optionally writing the result back to disk.

  Arguments:
    filename: (unicode) Path of the file to reformat.
    in_place: (bool) If True, write the reformatted code back to the file
      instead of returning it.
    logger: (function) Passed on to ReadFile(), which calls it with the error
      when the file cannot be read.
    remaining arguments: see comment at the top of this module.

  Returns:
    Tuple of (reformatted_code, encoding, changed). reformatted_code is None
    whenever in_place is True, and is a diff if print_diff is True. Otherwise
    it is the reformatted source using the line ending reported by ReadFile().

  Raises:
    IOError: raised if there was an error reading the file.
    ValueError: raised if in_place and print_diff are both specified.
  """
  _CheckPythonVersion()

  if in_place and print_diff:
    raise ValueError('Cannot pass both in_place and print_diff.')

  original_source, newline, encoding = ReadFile(filename, logger)
  reformatted_source, changed = FormatCode(
      original_source,
      filename=filename,
      style_config=style_config,
      lines=lines,
      print_diff=print_diff,
      verify=verify)

  # Switch the text back to the file's own line ending and make it end with
  # exactly one of them. Output that is empty apart from newlines is left
  # exactly as FormatCode() produced it.
  body = reformatted_source.rstrip('\n')
  if body:
    reformatted_source = newline.join(body.split('\n')) + newline

  if not in_place:
    return reformatted_source, encoding, changed

  if original_source and original_source != reformatted_source:
    file_resources.WriteReformattedCode(filename, reformatted_source, encoding,
                                        in_place)
  return None, encoding, changed


def FormatCode(unformatted_source,
               filename='<unknown>',
               style_config=None,
               lines=None,
               print_diff=False,
               verify=False):
  """Format a string of Python code.

  This provides an alternative entry point to YAPF.

  Arguments:
    unformatted_source: (unicode) The code to format. A trailing newline is
      appended if it is missing.
    filename: (unicode) The name of the file being reformatted. It is used to
      prefix parse error messages and to label the diff.
    remaining arguments: see comment at the top of this module.

  Returns:
    Tuple of (reformatted_source, changed). reformatted_source conforms to the
    desired formatting style, or is a unified diff when print_diff is True (an
    empty string if nothing changed). changed is True if the source changed.

  Raises:
    parse.ParseError: re-raised from the parser with filename prepended to the
      error message.
  """
  _CheckPythonVersion()
  style.SetGlobalStyle(style.CreateStyleFromConfig(style_config))
  if not unformatted_source.endswith('\n'):
    unformatted_source += '\n'

  try:
    tree = pytree_utils.ParseCodeToTree(unformatted_source)
  except parse.ParseError as err:
    err.msg = filename + ': ' + err.msg
    raise

  # Run passes on the tree, modifying it in place. The order is significant
  # and is kept exactly as listed.
  tree_passes = (
      comment_splicer.SpliceComments,
      continuation_splicer.SpliceContinuations,
      subtype_assigner.AssignSubtypes,
      identify_container.IdentifyContainers,
      split_penalty.ComputeSplitPenalties,
      blank_line_calculator.CalculateBlankLines,
  )
  for tree_pass in tree_passes:
    tree_pass(tree)

  uwlines = pytree_unwrapper.UnwrapPyTree(tree)
  for uwline in uwlines:
    uwline.CalculateFormattingInformation()

  line_set = _LineRangesToSet(lines)
  _MarkLinesToFormat(uwlines, line_set)
  reformatted_source = reformatter.Reformat(
      _SplitSemicolons(uwlines), verify, line_set)

  if unformatted_source == reformatted_source:
    return '' if print_diff else reformatted_source, False

  if not print_diff:
    # The caller wants the code itself, so there is no need to build a diff
    # that would only be thrown away.
    return reformatted_source, True

  code_diff = _GetUnifiedDiff(
      unformatted_source, reformatted_source, filename=filename)
  return code_diff, code_diff.strip() != ''  # pylint: disable=g-explicit-bool-comparison


def _CheckPythonVersion():  # pragma: no cover
  errmsg = 'yapf is only supported for Python 2.7 or 3.4+'
  if sys.version_info[0] == 2:
    if sys.version_info[1] < 7:
      raise RuntimeError(errmsg)
  elif sys.version_info[0] == 3:
    if sys.version_info[1] < 4:
      raise RuntimeError(errmsg)


def ReadFile(filename, logger=None):
  """Read a file and return its text with newline-only line endings.

  An optional logger can be specified to emit messages to your favorite logging
  stream. If specified, it is called with the error before the error is
  re-raised. This is external so that it can be used by third-party
  applications.

  Arguments:
    filename: (unicode) The name of the file.
    logger: (function) A function or lambda that takes the error and emits it.

  Returns:
    Tuple of (source, line_ending, encoding). source is the file's lines with
    trailing carriage returns and newlines stripped, joined by a newline and
    ending with one. line_ending is the value file_resources.LineEnding()
    reports for the lines as read. encoding is the value
    file_resources.FileEncoding() reports for the file.

  Raises:
    IOError: raised if there was an error reading the file.
  """
  try:
    encoding = file_resources.FileEncoding(filename)

    # Open with newline='' so the file's own line endings are preserved and
    # can be inspected below.
    with py3compat.open_with_encoding(
        filename, mode='r', encoding=encoding, newline='') as stream:
      raw_lines = stream.readlines()

    line_ending = file_resources.LineEnding(raw_lines)
    stripped_lines = [raw_line.rstrip('\r\n') for raw_line in raw_lines]
    source = '\n'.join(stripped_lines) + '\n'
    return source, line_ending, encoding
  except IOError as err:  # pragma: no cover
    if logger:
      logger(err)
    raise


def _SplitSemicolons(uwlines):
  res = []
  for uwline in uwlines:
    res.extend(uwline.Split())
  return res


# Regular expressions recognizing the comments that switch formatting off and
# back on, e.g. "# yapf: disable" and "# yapf: enable". The text must start
# with '#'. Within this module they are always applied with re.IGNORECASE.
DISABLE_PATTERN = r'^#.*\byapf:\s*disable\b'
ENABLE_PATTERN = r'^#.*\byapf:\s*enable\b'


def _LineRangesToSet(line_ranges):
  """Return a set of lines in the range."""

  if line_ranges is None:
    return None

  line_set = set()
  for low, high in sorted(line_ranges):
    line_set.update(range(low, high + 1))

  return line_set


def _MarkLinesToFormat(uwlines, lines):
  """Set the disable flag on unwrapped lines that must not be reformatted.

  Two rules are applied, in this order:
    * If lines is a non-empty set, each unwrapped line is disabled unless the
      range from its lineno to its last.lineno overlaps that set.
    * A comment line matching the disable pattern disables every line after it
      up to, but not including, the next comment line matching the enable
      pattern. A non-comment line whose last.value matches the disable pattern
      is disabled on its own.

  Arguments:
    uwlines: (list of unwrapped lines) The lines to mark; modified in place.
    lines: (set of integers) Line numbers to format. None or an empty set
      leaves the first rule unapplied.
  """
  if lines:
    for uwline in uwlines:
      covered = range(uwline.lineno, uwline.last.lineno + 1)
      uwline.disable = not lines.intersection(covered)

  # Both loops below draw from the same iterator: the inner loop consumes the
  # lines of a disabled region (and the enable comment that closes it), and the
  # outer loop then resumes with the line after that.
  remaining = iter(uwlines)
  for uwline in remaining:
    if not uwline.is_comment:
      if re.search(DISABLE_PATTERN, uwline.last.value.strip(), re.IGNORECASE):
        uwline.disable = True
      continue

    if not _DisableYAPF(uwline.first.value.strip()):
      continue

    for following in remaining:
      if following.is_comment and _EnableYAPF(following.first.value.strip()):
        break
      following.disable = True


def _DisableYAPF(line):
  """Match the disable pattern against the first, then the last, line of text.

  Returns:
    The first match object found, or None if neither line matches.
  """
  rows = line.split('\n')
  for candidate in (rows[0], rows[-1]):
    found = re.search(DISABLE_PATTERN, candidate.strip(), re.IGNORECASE)
    if found:
      return found
  return None


def _EnableYAPF(line):
  """Match the enable pattern against the first, then the last, line of text.

  Returns:
    The first match object found, or None if neither line matches.
  """
  rows = line.split('\n')
  for candidate in (rows[0], rows[-1]):
    found = re.search(ENABLE_PATTERN, candidate.strip(), re.IGNORECASE)
    if found:
      return found
  return None


def _GetUnifiedDiff(before, after, filename='code'):
  """Get a unified diff of the changes.

  Arguments:
    before: (unicode) The original source code.
    after: (unicode) The reformatted source code.
    filename: (unicode) The code's filename.

  Returns:
    The unified diff text.
  """
  before = before.splitlines()
  after = after.splitlines()
  return '\n'.join(
      difflib.unified_diff(
          before,
          after,
          filename,
          filename,
          '(original)',
          '(reformatted)',
          lineterm='')) + '\n'