import ast
from mkinit.util.orderedset import OrderedSet as oset
import sys
# Public API of this module: only the visitor class is exported by
# ``from <module> import *``; the helper functions stay importable by name.
__all__ = [
"TopLevelVisitor",
]
# Sentinel returned by ``static_truthiness`` when a node's truth value cannot
# be determined statically (compared by identity in ``visit_If``).
_UNHANDLED = None

# Interpreter feature flags. Compare the version as a tuple so that a future
# major release with a small minor number (e.g. 4.0) is still recognised as
# "at least 3.8" / "at least 3.12"; comparing major and minor independently
# would report False for such a version.
IS_PY_GE_308 = sys.version_info >= (3, 8)
IS_PY_GE_312 = sys.version_info >= (3, 12)
[docs]
class TopLevelVisitor(ast.NodeVisitor):
    """
    Parses top-level attribute names

    Walks a module's AST and collects the names that are bound at module
    level: functions, async functions, classes and simple assignments.
    Names bound inside ``if`` / ``try`` statements are only collected when
    every branch that can be taken binds them. Bodies of functions and
    classes are not descended into.

    Attributes:
        attrnames (oset): names currently considered defined.
        removed (oset): names that were registered and later deleted.

    References:
        # For other visit_<classname> values see
        http://greentreesnakes.readthedocs.io/en/latest/nodes.html

    CommandLine:
        python ~/code/mkinit/mkinit/top_level_ast.py TopLevelVisitor:1

    Example:
        >>> from xdoctest import utils
        >>> source = utils.codeblock(
        ...    '''
        ...    def foo():
        ...        def subfunc():
        ...            pass
        ...    def bar():
        ...        pass
        ...    class Spam(object):
        ...        def eggs(self):
        ...            pass
        ...        @staticmethod
        ...        def hams():
        ...            pass
        ...    ''')
        >>> self = TopLevelVisitor.parse(source)
        >>> print('attrnames = {!r}'.format(sorted(self.attrnames)))
        attrnames = ['Spam', 'bar', 'foo']

    Example:
        >>> # xdoctest: +REQUIRES(PY3)
        >>> from mkinit.top_level_ast import *  # NOQA
        >>> from xdoctest import utils
        >>> source = utils.codeblock(
        ...    '''
        ...    async def asyncfoo():
        ...        var = 1
        ...    def bar():
        ...        pass
        ...    class Spam(object):
        ...        def eggs(self):
        ...            pass
        ...        @staticmethod
        ...        def hams():
        ...            pass
        ...    ''')
        >>> self = TopLevelVisitor.parse(source)
        >>> print('attrnames = {!r}'.format(sorted(self.attrnames)))
        attrnames = ['Spam', 'asyncfoo', 'bar']

    Example:
        >>> from xdoctest import utils
        >>> source = utils.codeblock(
        ...    '''
        ...    a = True
        ...    if a:
        ...        b = True
        ...        c = True
        ...    else:
        ...        b = False
        ...        d = True
        ...        del d
        ...    ''')
        >>> self = TopLevelVisitor.parse(source)
        >>> print('attrnames = {!r}'.format(sorted(self.attrnames)))
        attrnames = ['a', 'b']

    Example:
        >>> from xdoctest import utils
        >>> source = utils.codeblock(
        ...    '''
        ...    try:
        ...        d = True
        ...        e = True
        ...    except ImportError:
        ...        raise
        ...    except Exception:
        ...        d = False
        ...        f = False
        ...    else:
        ...        f = True
        ...    ''')
        >>> self = TopLevelVisitor.parse(source)
        >>> print('attrnames = {!r}'.format(sorted(self.attrnames)))
        attrnames = ['d', 'f']
    """

    def __init__(self):
        super(TopLevelVisitor, self).__init__()
        self.attrnames = oset()
        self.removed = oset()  # keep track of which variables were deleted

    def _register(self, name):
        """
        Mark ``name`` as defined.

        Args:
            name (str | list | tuple | oset): a single name, or a collection
                of names which are registered one by one.
        """
        if isinstance(name, (list, tuple, oset)):
            for item in name:
                self._register(item)
        elif name not in self.attrnames:
            self.attrnames.add(name)
            # A re-defined name no longer counts as deleted
            self.removed.discard(name)

    def _unregister(self, name):
        """
        Mark a previously registered ``name`` as deleted.
        """
        if name in self.attrnames:
            self.attrnames.discard(name)
            self.removed.add(name)

    @classmethod
    def parse(cls, source):
        """
        Build a visitor and run it over ``source``.

        Args:
            source (str | bytes): Python source code. Text is encoded as
                UTF-8 before parsing; bytes are handed to the parser as-is.

        Returns:
            TopLevelVisitor: the visitor, with ``attrnames`` populated.

        Raises:
            SyntaxError: if ``source`` is not valid Python (from ast.parse).
        """
        self = cls()
        if isinstance(source, bytes):
            # Already encoded; calling .encode on bytes would fail.
            source_utf8 = source
        else:
            source_utf8 = source.encode("utf8")
        pt = ast.parse(source_utf8)
        self.visit(pt)
        return self

    def visit_FunctionDef(self, node):
        # Only the function name is recorded; its body is not visited.
        self._register(node.name)

    def visit_AsyncFunctionDef(self, node):
        # Only the coroutine name is recorded; its body is not visited.
        self._register(node.name)

    def visit_ClassDef(self, node):
        # Only the class name is recorded; its body is not visited.
        self._register(node.name)

    def visit_Assign(self, node):
        """
        Register every assignment target that carries an ``id`` attribute
        (i.e. plain names), then continue into the child nodes.
        """
        for target in node.targets:
            if hasattr(target, "id"):
                self._register(target.id)
        # TODO: assign constants to self.const_lookup?
        self.generic_visit(node)

    @staticmethod
    def _is_main_guard(test):
        """
        Check whether ``test`` has the shape ``__name__ == '__main__'``.

        Args:
            test (ast.AST): the test expression of an ``if`` statement.

        Returns:
            bool: True only for a comparison whose first operator is ``==``,
                whose left side is the name ``__name__`` and whose first
                comparator is the string ``'__main__'``.
        """
        if not isinstance(test, ast.Compare):
            return False
        if not isinstance(test.ops[0], ast.Eq):
            return False
        left = test.left
        if not isinstance(left, ast.Name) or left.id != "__name__":
            return False
        right = test.comparators[0]
        # ast.Constant exposes ``.value`` from Python 3.8 on; older
        # interpreters represent string literals as ast.Str with ``.s``.
        literal = getattr(right, "value" if IS_PY_GE_308 else "s", None)
        return literal == "__main__"

    def visit_If(self, node):
        """
        Register names that are guaranteed to exist after an if-chain.

        Note:
            elif clauses don't have a special representation in the AST, but
            rather appear as extra If nodes within the orelse section of the
            previous one.
        """
        if self._is_main_guard(node.test):
            # Ignore main block
            return

        # TODO: handled deleted attributes?

        # Find definitions from conditionals that always accept or
        # that are defined in all possible non-rejecting branches (note this
        # requires an else statement). A rejecting branch is one that is
        # unconditionally false or unconditionally raises an exception
        if_node, elif_nodes, else_body = unpack_if_nodes(node)
        test_nodes = [if_node] + elif_nodes

        has_unconditional = False
        required = []
        for item in test_nodes:
            truth = static_truthiness(item.test)
            if truth is _UNHANDLED:
                # Cannot decide statically: the branch might be taken
                required.append(get_conditional_attrnames(item.body))
            elif truth is True:
                # Branch is unconditionally true, no need to check others
                required.append(get_conditional_attrnames(item.body))
                has_unconditional = True
                break
            elif truth is False:
                # Ignore branches that are unconditionally false
                continue
            else:
                raise AssertionError("cannot happen")

        if not has_unconditional and else_body:
            # If we haven't found an unconditional branch we need an else
            if not any(isinstance(n, ast.Raise) for n in else_body):
                # Ignore else branches that simply raise an error
                required.append(get_conditional_attrnames(else_body))
            # Either way the chain is exhaustive: one branch runs or we raise
            has_unconditional = True

        if has_unconditional:
            # We can only guarantee that something will exist if there is at
            # least one path that must be taken
            if len(required) == 0:
                common = oset()
            elif len(required) == 1:
                common = required[0]
            else:
                common = oset.intersection(*required)
            self._register(sorted(common))

    def visit_Try(self, node):
        """
        We only care about checking if (a) a variable is defined in the main
        body, and (b) that the variable is defined in all except blocks that
        **don't** immediately re-raise.
        """
        body_attrs = get_conditional_attrnames(node.body)
        # The else clause runs whenever the body succeeds, so its names
        # count together with the body's names.
        body_attrs.update(get_conditional_attrnames(node.orelse))

        # Require that attributes are defined in all non-error branches
        required = []
        for handler in node.handlers:
            # Ignore any handlers that will always reraise
            if not any(isinstance(n, ast.Raise) for n in handler.body):
                required.append(get_conditional_attrnames(handler.body))

        if len(required) == 0:
            common = body_attrs
        else:
            common = oset.intersection(body_attrs, *required)
        self._register(sorted(common))

    # for python2
    visit_TryExcept = visit_Try

    def visit_Delete(self, node):
        """
        Unregister every plain name removed by a ``del`` statement.
        """
        for item in node.targets:
            if isinstance(item, ast.Name):
                self._unregister(item.id)
        self.generic_visit(node)
def unpack_if_nodes(if_node):
    """
    Extract chain of `<if><elif>*<else>?` statements

    Args:
        if_node (ast.If): the first ``if`` of the chain.

    Returns:
        tuple: ``(if_node, elif_nodes, else_body)`` where ``elif_nodes`` is
            the list of chained ``ast.If`` nodes in source order and
            ``else_body`` is the statement list of the trailing ``else``,
            or None when the chain has no ``else``.
    """
    elif_nodes = []
    else_body = None
    cursor = if_node
    while cursor:
        tail = cursor.orelse
        # Assume the chain ends here unless another elif is found below
        cursor = None
        if len(tail) == 1 and isinstance(tail[0], ast.If):
            # A lone nested If in orelse is how the AST spells ``elif``
            cursor = tail[0]
            elif_nodes.append(cursor)
        elif tail:
            # Anything else that is non-empty is a real ``else`` body
            else_body = tail
    return if_node, elif_nodes, else_body
def static_truthiness(node):
    """
    Extracts static truthiness of a node if possible

    Args:
        node (ast.AST): expression node, e.g. the test of an ``if``.

    Returns:
        bool or None: True or False if a node can be statically bound to a
            truthy value, otherwise returns None (the ``_UNHANDLED``
            sentinel).
    """
    if IS_PY_GE_308:
        # Every literal (str, number, bytes, True/False/None, ...) is an
        # ast.Constant here, so one check covers all literal kinds.
        if isinstance(node, ast.Constant):
            return bool(node.value)
        if isinstance(node, ast.Tuple):
            return bool(node.elts)
        return _UNHANDLED

    # Older interpreters use one node class per literal kind.
    if isinstance(node, ast.Str):
        return bool(node.s)
    if isinstance(node, ast.Tuple):
        return bool(node.elts)
    if isinstance(node, ast.Num):
        return bool(node.n)
    if isinstance(node, ast.Bytes):
        return bool(node.s)
    if isinstance(node, ast.NameConstant):
        return bool(node.value)
    return _UNHANDLED
def get_conditional_attrnames(body):
    """
    Gets attrnames within a list of nodes

    Args:
        body (list): statement nodes making up one branch.

    Returns:
        The ``attrnames`` collection of a fresh ``TopLevelVisitor`` that has
        visited every statement of ``body`` in order.
    """
    branch_visitor = TopLevelVisitor()
    # Check the attributes defined on this branch
    for stmt in body:
        branch_visitor.visit(stmt)
    return branch_visitor.attrnames
if __name__ == "__main__":
    # CommandLine:
    #     python -m mkinit.top_level_ast all
    import xdoctest

    # Run the doctests embedded in this module's docstrings.
    xdoctest.doctest_module(__file__)