Skip to content

gh-112720: Move dis's cache output code to the Formatter, labels lookup to the arg_resolver. Reduce the number of parameters passed around. #113108

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Dec 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
168 changes: 91 additions & 77 deletions Lib/dis.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,14 @@ def dis(x=None, *, file=None, depth=None, show_caches=False, adaptive=False,
elif hasattr(x, 'co_code'): # Code object
_disassemble_recursive(x, file=file, depth=depth, show_caches=show_caches, adaptive=adaptive, show_offsets=show_offsets)
elif isinstance(x, (bytes, bytearray)): # Raw bytecode
_disassemble_bytes(x, file=file, show_caches=show_caches, show_offsets=show_offsets)
Copy link
Member

@markshannon markshannon Dec 14, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does this need different handling now?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not different handling, just different argument types. If we pass around the arg_resolver and formatter then we can plug in different implementations of them.

labels_map = _make_labels_map(x)
label_width = 4 + len(str(len(labels_map)))
formatter = Formatter(file=file,
offset_width=len(str(max(len(x) - 2, 9999))) if show_offsets else 0,
label_width=label_width,
show_caches=show_caches)
arg_resolver = ArgResolver(labels_map=labels_map)
_disassemble_bytes(x, arg_resolver=arg_resolver, formatter=formatter)
elif isinstance(x, str): # Source code
_disassemble_str(x, file=file, depth=depth, show_caches=show_caches, adaptive=adaptive, show_offsets=show_offsets)
else:
Expand Down Expand Up @@ -394,23 +401,41 @@ def __str__(self):
class Formatter:

def __init__(self, file=None, lineno_width=0, offset_width=0, label_width=0,
line_offset=0):
line_offset=0, show_caches=False):
"""Create a Formatter

*file* where to write the output
*lineno_width* sets the width of the line number field (0 omits it)
*offset_width* sets the width of the instruction offset field
*label_width* sets the width of the label field
*show_caches* is a boolean indicating whether to display cache lines

*line_offset* the line number (within the code unit)
"""
self.file = file
self.lineno_width = lineno_width
self.offset_width = offset_width
self.label_width = label_width

self.show_caches = show_caches

def print_instruction(self, instr, mark_as_current=False):
    """Print *instr* and, when cache display is enabled, its CACHE entries.

    The instruction's own line is always written through
    print_instruction_line().  If *show_caches* is set on the formatter
    and the instruction carries cache_info, one synthetic CACHE line is
    written per cache slot; those lines are never marked as current.
    """
    self.print_instruction_line(instr, mark_as_current)
    if not (self.show_caches and instr.cache_info):
        return

    cache_offset = instr.offset
    for name, size, data in instr.cache_info:
        for slot in range(size):
            # Each cache slot advances the offset by 2.
            cache_offset += 2
            # Only the first slot of a particular cache value gets the
            # descriptive argrepr; the remaining slots are printed bare.
            argrepr = (f"{name}: {int.from_bytes(data, sys.byteorder)}"
                       if slot == 0 else "")
            cache_instr = Instruction("CACHE", CACHE, 0, None, argrepr,
                                      cache_offset, cache_offset,
                                      False, None, None, instr.positions)
            self.print_instruction_line(cache_instr, False)

def print_instruction_line(self, instr, mark_as_current):
"""Format instruction details for inclusion in disassembly output."""
lineno_width = self.lineno_width
offset_width = self.offset_width
Expand Down Expand Up @@ -474,11 +499,14 @@ def print_exception_table(self, exception_entries):


class ArgResolver:
def __init__(self, co_consts, names, varname_from_oparg, labels_map):
def __init__(self, co_consts=None, names=None, varname_from_oparg=None, labels_map=None):
self.co_consts = co_consts
self.names = names
self.varname_from_oparg = varname_from_oparg
self.labels_map = labels_map
self.labels_map = labels_map or {}

def get_label_for_offset(self, offset):
    """Return the label stored in labels_map for *offset*, or None."""
    return self.labels_map.get(offset)

def get_argval_argrepr(self, op, arg, offset):
get_name = None if self.names is None else self.names.__getitem__
Expand Down Expand Up @@ -547,8 +575,7 @@ def get_argval_argrepr(self, op, arg, offset):
argrepr = _intrinsic_2_descs[arg]
return argval, argrepr


def get_instructions(x, *, first_line=None, show_caches=False, adaptive=False):
def get_instructions(x, *, first_line=None, show_caches=None, adaptive=False):
"""Iterator for the opcodes in methods, functions or code

Generates a series of Instruction named tuples giving the details of
Expand All @@ -567,9 +594,10 @@ def get_instructions(x, *, first_line=None, show_caches=False, adaptive=False):
line_offset = 0

original_code = co.co_code
labels_map = _make_labels_map(original_code)
arg_resolver = ArgResolver(co.co_consts, co.co_names, co._varname_from_oparg,
labels_map)
arg_resolver = ArgResolver(co_consts=co.co_consts,
names=co.co_names,
varname_from_oparg=co._varname_from_oparg,
labels_map=_make_labels_map(original_code))
return _get_instructions_bytes(_get_code_array(co, adaptive),
linestarts=linestarts,
line_offset=line_offset,
Expand Down Expand Up @@ -648,7 +676,7 @@ def _is_backward_jump(op):
'ENTER_EXECUTOR')

def _get_instructions_bytes(code, linestarts=None, line_offset=0, co_positions=None,
original_code=None, labels_map=None, arg_resolver=None):
original_code=None, arg_resolver=None):
"""Iterate over the instructions in a bytecode string.

Generates a sequence of Instruction namedtuples giving the details of each
Expand All @@ -661,8 +689,6 @@ def _get_instructions_bytes(code, linestarts=None, line_offset=0, co_positions=N
original_code = original_code or code
co_positions = co_positions or iter(())

labels_map = labels_map or _make_labels_map(original_code)

starts_line = False
local_line_number = None
line_number = None
Expand All @@ -684,10 +710,6 @@ def _get_instructions_bytes(code, linestarts=None, line_offset=0, co_positions=N
else:
argval, argrepr = arg, repr(arg)

instr = Instruction(_all_opname[op], op, arg, argval, argrepr,
offset, start_offset, starts_line, line_number,
labels_map.get(offset, None), positions)

caches = _get_cache_size(_all_opname[deop])
# Advance the co_positions iterator:
for _ in range(caches):
Expand All @@ -701,23 +723,31 @@ def _get_instructions_bytes(code, linestarts=None, line_offset=0, co_positions=N
else:
cache_info = None

label = arg_resolver.get_label_for_offset(offset) if arg_resolver else None
yield Instruction(_all_opname[op], op, arg, argval, argrepr,
offset, start_offset, starts_line, line_number,
labels_map.get(offset, None), positions, cache_info)

label, positions, cache_info)


def disassemble(co, lasti=-1, *, file=None, show_caches=False, adaptive=False,
show_offsets=False):
"""Disassemble a code object."""
linestarts = dict(findlinestarts(co))
exception_entries = _parse_exception_table(co)
_disassemble_bytes(_get_code_array(co, adaptive),
lasti, co._varname_from_oparg,
co.co_names, co.co_consts, linestarts, file=file,
exception_entries=exception_entries,
co_positions=co.co_positions(), show_caches=show_caches,
original_code=co.co_code, show_offsets=show_offsets)
labels_map = _make_labels_map(co.co_code, exception_entries=exception_entries)
label_width = 4 + len(str(len(labels_map)))
formatter = Formatter(file=file,
lineno_width=_get_lineno_width(linestarts),
offset_width=len(str(max(len(co.co_code) - 2, 9999))) if show_offsets else 0,
label_width=label_width,
show_caches=show_caches)
arg_resolver = ArgResolver(co_consts=co.co_consts,
names=co.co_names,
varname_from_oparg=co._varname_from_oparg,
labels_map=labels_map)
_disassemble_bytes(_get_code_array(co, adaptive), lasti, linestarts,
exception_entries=exception_entries, co_positions=co.co_positions(),
original_code=co.co_code, arg_resolver=arg_resolver, formatter=formatter)

def _disassemble_recursive(co, *, file=None, depth=None, show_caches=False, adaptive=False, show_offsets=False):
disassemble(co, file=file, show_caches=show_caches, adaptive=adaptive, show_offsets=show_offsets)
Expand Down Expand Up @@ -764,60 +794,29 @@ def _get_lineno_width(linestarts):
return lineno_width


def _disassemble_bytes(code, lasti=-1, varname_from_oparg=None,
names=None, co_consts=None, linestarts=None,
*, file=None, line_offset=0, exception_entries=(),
co_positions=None, show_caches=False, original_code=None,
show_offsets=False):

offset_width = len(str(max(len(code) - 2, 9999))) if show_offsets else 0

labels_map = _make_labels_map(original_code or code, exception_entries)
label_width = 4 + len(str(len(labels_map)))
def _disassemble_bytes(code, lasti=-1, linestarts=None,
*, line_offset=0, exception_entries=(),
co_positions=None, original_code=None,
arg_resolver=None, formatter=None):

formatter = Formatter(file=file,
lineno_width=_get_lineno_width(linestarts),
offset_width=offset_width,
label_width=label_width,
line_offset=line_offset)
assert formatter is not None
assert arg_resolver is not None

arg_resolver = ArgResolver(co_consts, names, varname_from_oparg, labels_map)
instrs = _get_instructions_bytes(code, linestarts=linestarts,
line_offset=line_offset,
co_positions=co_positions,
original_code=original_code,
labels_map=labels_map,
arg_resolver=arg_resolver)

print_instructions(instrs, exception_entries, formatter,
show_caches=show_caches, lasti=lasti)
print_instructions(instrs, exception_entries, formatter, lasti=lasti)


def print_instructions(instrs, exception_entries, formatter, show_caches=False, lasti=-1):
def print_instructions(instrs, exception_entries, formatter, lasti=-1):
for instr in instrs:
if show_caches:
is_current_instr = instr.offset == lasti
else:
# Each CACHE takes 2 bytes
is_current_instr = instr.offset <= lasti \
<= instr.offset + 2 * _get_cache_size(_all_opname[_deoptop(instr.opcode)])
# Each CACHE takes 2 bytes
is_current_instr = instr.offset <= lasti \
<= instr.offset + 2 * _get_cache_size(_all_opname[_deoptop(instr.opcode)])
formatter.print_instruction(instr, is_current_instr)
deop = _deoptop(instr.opcode)
if show_caches and instr.cache_info:
offset = instr.offset
for name, size, data in instr.cache_info:
for i in range(size):
offset += 2
# Only show the fancy argrepr for a CACHE instruction when it's
# the first entry for a particular cache value:
if i == 0:
argrepr = f"{name}: {int.from_bytes(data, sys.byteorder)}"
else:
argrepr = ""
formatter.print_instruction(
Instruction("CACHE", CACHE, 0, None, argrepr, offset, offset,
False, None, None, instr.positions),
is_current_instr)

formatter.print_exception_table(exception_entries)

Expand Down Expand Up @@ -960,14 +959,15 @@ def __iter__(self):
co = self.codeobj
original_code = co.co_code
labels_map = _make_labels_map(original_code, self.exception_entries)
arg_resolver = ArgResolver(co.co_consts, co.co_names, co._varname_from_oparg,
labels_map)
arg_resolver = ArgResolver(co_consts=co.co_consts,
names=co.co_names,
varname_from_oparg=co._varname_from_oparg,
labels_map=labels_map)
return _get_instructions_bytes(_get_code_array(co, self.adaptive),
linestarts=self._linestarts,
line_offset=self._line_offset,
co_positions=co.co_positions(),
original_code=original_code,
labels_map=labels_map,
arg_resolver=arg_resolver)

def __repr__(self):
Expand Down Expand Up @@ -995,18 +995,32 @@ def dis(self):
else:
offset = -1
with io.StringIO() as output:
_disassemble_bytes(_get_code_array(co, self.adaptive),
varname_from_oparg=co._varname_from_oparg,
names=co.co_names, co_consts=co.co_consts,
code = _get_code_array(co, self.adaptive)
offset_width = len(str(max(len(code) - 2, 9999))) if self.show_offsets else 0


labels_map = _make_labels_map(co.co_code, self.exception_entries)
label_width = 4 + len(str(len(labels_map)))
formatter = Formatter(file=output,
lineno_width=_get_lineno_width(self._linestarts),
offset_width=offset_width,
label_width=label_width,
line_offset=self._line_offset,
show_caches=self.show_caches)

arg_resolver = ArgResolver(co_consts=co.co_consts,
names=co.co_names,
varname_from_oparg=co._varname_from_oparg,
labels_map=labels_map)
_disassemble_bytes(code,
linestarts=self._linestarts,
line_offset=self._line_offset,
file=output,
lasti=offset,
exception_entries=self.exception_entries,
co_positions=co.co_positions(),
show_caches=self.show_caches,
original_code=co.co_code,
show_offsets=self.show_offsets)
arg_resolver=arg_resolver,
formatter=formatter)
return output.getvalue()


Expand Down
17 changes: 13 additions & 4 deletions Lib/test/test_dis.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import contextlib
import dis
import functools
import io
import re
import sys
Expand Down Expand Up @@ -1982,19 +1983,27 @@ def f(opcode, oparg, offset, *init_args):
self.assertEqual(f(opcode.opmap["BINARY_OP"], 3, *args), (3, '<<'))
self.assertEqual(f(opcode.opmap["CALL_INTRINSIC_1"], 2, *args), (2, 'INTRINSIC_IMPORT_STAR'))

def get_instructions(self, code):
    """Return the instructions dis yields for the raw bytecode *code*."""
    instructions = dis._get_instructions_bytes(code)
    return instructions

def test_start_offset(self):
# When no extended args are present,
# start_offset should be equal to offset

instructions = list(dis.Bytecode(_f))
for instruction in instructions:
self.assertEqual(instruction.offset, instruction.start_offset)

def last_item(iterable):
return functools.reduce(lambda a, b : b, iterable)

code = bytes([
opcode.opmap["LOAD_FAST"], 0x00,
opcode.opmap["EXTENDED_ARG"], 0x01,
opcode.opmap["POP_JUMP_IF_TRUE"], 0xFF,
])
jump = list(dis._get_instructions_bytes(code))[-1]
labels_map = dis._make_labels_map(code)
jump = last_item(self.get_instructions(code))
self.assertEqual(4, jump.offset)
self.assertEqual(2, jump.start_offset)

Expand All @@ -2006,7 +2015,7 @@ def test_start_offset(self):
opcode.opmap["POP_JUMP_IF_TRUE"], 0xFF,
opcode.opmap["CACHE"], 0x00,
])
jump = list(dis._get_instructions_bytes(code))[-1]
jump = last_item(self.get_instructions(code))
self.assertEqual(8, jump.offset)
self.assertEqual(2, jump.start_offset)

Expand All @@ -2021,7 +2030,7 @@ def test_start_offset(self):
opcode.opmap["POP_JUMP_IF_TRUE"], 0xFF,
opcode.opmap["CACHE"], 0x00,
])
instructions = list(dis._get_instructions_bytes(code))
instructions = list(self.get_instructions(code))
# 1st jump
self.assertEqual(4, instructions[2].offset)
self.assertEqual(2, instructions[2].start_offset)
Expand All @@ -2042,7 +2051,7 @@ def test_cache_offset_and_end_offset(self):
opcode.opmap["CACHE"], 0x00,
opcode.opmap["CACHE"], 0x00
])
instructions = list(dis._get_instructions_bytes(code))
instructions = list(self.get_instructions(code))
self.assertEqual(2, instructions[0].cache_offset)
self.assertEqual(10, instructions[0].end_offset)
self.assertEqual(12, instructions[1].cache_offset)
Expand Down