Python on-the-fly function update

I recently decided to learn a little about the CPython implementation in preparation for a future project. During the learning process I realised you have much more control over the internals of Python than I originally believed, and decided to see what I could get away with.

As is usual in these circumstances, I started thinking about hypothetical debugging tools, and more specifically how one would go about implementing some of the features of DTrace in Python.

One of my favourite features of DTrace is its absolutely minimal overhead. As an example, if you want to trace just one function, there is zero overhead on calls to any other function. This sits in contrast to Python's sys.settrace() feature, which calls the global trace function (which would then have to check whether this is the function we're interested in or not) on every function call.
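As a concrete illustration of that cost, here is a minimal settrace-based tracer (the function names are my own, illustrative ones): even though we only care about one function, the interpreter invokes the trace function for every Python-level call, and the filtering happens inside it.

```python
import sys

calls = []  # record of the calls we actually wanted to trace

def tracer(frame, event, arg):
    # The interpreter invokes this for *every* Python function call;
    # filtering for the one function we care about happens here, at a
    # cost paid on every call in the process.
    if event == 'call' and frame.f_code.co_name == 'interesting':
        calls.append(frame.f_code.co_name)
    return None  # no per-line tracing needed

def interesting():
    return 42

def uninteresting():
    return 0

sys.settrace(tracer)
uninteresting()  # the tracer still runs, then discards this event
interesting()    # the tracer runs and records this call
sys.settrace(None)
print(calls)     # ['interesting']
```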

To be clear: This post talks about a proof-of-concept (read: hackish, unfinished, and with unavoidable edge cases) implementation of one use of DTrace. It's pretty much useless to anyone, and has nothing like the power of DTrace. I do think it's quite cool nevertheless.

Introduction to the problem

The aim of this post is to be able to dynamically tell a running Python process to print out the arguments passed to, and value returned from, every call to a given function. The caveat is that the Python program should have nothing special about it to help make this happen, so that we could do this to any process we come across.

We use the Pyrasite tool to handle executing Python code in a running process, so that all we have to do is write the code that ensures a function gets traced.

You might be thinking this is an easy task: just write a wrapper function and rebind the function's name to this wrapper.


def wrapper_func(fn):
    def return_func(*args, **kwargs):
        print(args)
        print(kwargs)
        return_value = fn(*args, **kwargs)
        print(return_value)
        return return_value
    return return_func


def trace_this(fn):
    fn.__trace_orig_fn = fn
    orig_name = fn.__name__
    globals()[orig_name] = wrapper_func(fn)


if __name__ == "__main__":
    trace_this(test_function)
      

This is insufficient: it updates the binding in the namespace where you found the function, but not any other references that may exist elsewhere.
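A minimal demonstration of the stale-reference problem (the names here are made up; a SimpleNamespace stands in for a module namespace):

```python
import types

module = types.SimpleNamespace()  # stands in for a module namespace

def greet():
    return 'original'

module.greet = greet
alias = module.greet   # a second reference, like a closure or dict would hold

def traced_greet():
    return 'traced'

module.greet = traced_greet  # the "naive" rebinding in the namespace

print(module.greet())  # 'traced'   -- lookups through the namespace see it
print(alias())         # 'original' -- the stale reference is untouched
```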

As a demonstration of this problem we write a simple test program, which we will use throughout the post to show the benefits and problems of each approach we take.

      
#!/usr/bin/env python
import sys

def test_function(a, b):
    print('Calling inner function')
    if a == 10:
        return 'Never'
    return a * b

def make_closure(fn):
    def returned_function(a):
        return fn(a, 2)
    return returned_function

called_via_closure = make_closure(test_function)

class LocalCopy:
    local_copy = test_function
    def __init__(self):
        type(self).local_copy(3, 2)

dictionary_container = {'fn': test_function}

def main():
    test_function(1, 2)
    dictionary_container['fn'](2, 2)
    LocalCopy()
    called_via_closure(5)
    test_function(10, 20)
    print('\n')

if __name__ == "__main__":
    import time
    while True:
        time.sleep(3)
        main()
      

This test program demonstrates some of the ways that a program can end up using a local reference to a function instead of accessing it via the module namespace. Any of these can stymie redefinition done the naive way, as running the commands below demonstrates.

To run the test, save the naive pyrasite payload as naive_trace_payload.py, run the test script above in one terminal, and inject the payload in another.

The injection command:


vshcmd: > pyrasite $(pgrep -a python | grep trace_problems | cut -d' ' -f1) naive_trace_payload.py
(tracing) python_trace [12:18:35] $
      

What is seen as the output of the test program:


vshcmd: > ./trace_problems.py
Calling inner function
Calling inner function
Calling inner function
Calling inner function
Calling inner function


(1, 2)
{}
Calling inner function
2
Calling inner function
Calling inner function
Calling inner function
(10, 20)
{}
Calling inner function
Never


Traceback (most recent call last):
  File "./trace_problems.py", line 39, in <module>
    time.sleep(3)
KeyboardInterrupt
python_trace [12:18:38] $
      

We can see that only the direct calls through the module namespace were traced.

Modification Of Code Object

The observation that pushed me off on this tangent was that you can simply reassign the fn.__code__ attribute to something else, and it is this attribute that defines the behaviour of fn.

This object contains the bytecode that is evaluated by the Python VM, and the relevant constants etc that the bytecode needs. Replacing it with a __code__ object that does something different changes the behaviour when that function is called.

As this modifies the existing function object, rather than changing what is referenced by a namespace, any alternate references to the function immediately see the new behaviour. You can see this in action by injecting the code below into a running trace_problems.py process:


def replace_function(fn, replaced_with):
    fn.__code__ = replaced_with.__code__

def print_wrapper(*args, **kwargs):
    '''Wrap the print builtin.
    As it's defined in C it doesn't have a __code__ attribute for us to use.'''
    print(*args, **kwargs)

replace_function(test_function, print_wrapper)
     

Every call now simply prints its arguments.
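A quick check in a plain interpreter confirms that aliased references pick up the change too (made-up names; a sketch, not part of the payload):

```python
def original(x):
    return x + 1

alias = original  # a reference that a naive rebinding would miss

def replacement(x):
    return x * 10

original.__code__ = replacement.__code__  # modify the object in place

print(original(3))  # 30
print(alias(3))     # 30 -- the alias sees the new behaviour immediately
```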

So how about we create our own __code__ object that behaves the same as the original, but also prints out the arguments and return value. We can then replace the __code__ attribute on the function we want to trace, and see the effects take hold.

In order to create a __code__ object we need to understand the attributes that make one up. This is documented in the inspect module, but I recommend using the dis module on a number of test functions to really get the hang of it.
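For instance, poking at a small function shows the attributes we'll need (the exact bytecode printed varies between CPython versions):

```python
import dis

def sample(a, b):
    return a + b

code = sample.__code__
print(code.co_argcount)    # 2
print(code.co_varnames)    # ('a', 'b')
print(code.co_consts)      # constants the bytecode refers to
print(code.co_names)       # global names used (empty here)
dis.dis(sample)            # the bytecode itself
```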

For our first attempt, we'll simply insert the bytecode to print all arguments at the start of the function, and insert a print statement just before the return at the end of the function.

We can figure out the bytecode required by disassembling a few choice functions and correlating what we see with the documentation.

In order to print arguments on entry we need to know the number of arguments, which can be read off the original __code__ object, and we need to make sure we can access the print() global by putting its name in the __code__.co_names tuple.


import itertools as itt
import opcode

def assemble(instructions):
    '''Hackish assemble function.

    I expect there are many problems with this definition as I'm just a
    beginner in python bytecode.

    '''
    return bytes(itt.chain(*[(opcode.opmap[instruction], argument)
                             for instruction, argument in instructions]))


def trace_this(fn):
    orig_code = fn.__code__
    new_names = (orig_code.co_names
                 if 'print' in orig_code.co_names else
                 orig_code.co_names + ('print',))
    num_args = orig_code.co_argcount + orig_code.co_kwonlyargcount
    print_index = new_names.index('print')

    init_instructions = [('LOAD_GLOBAL', print_index)]
    for local_index in range(0, num_args):
        init_instructions.append(('LOAD_FAST', local_index))

    init_instructions.extend([
              ('CALL_FUNCTION', num_args),
              ('POP_TOP', 0),
    ])

    entry_code = assemble(init_instructions)

     

After this we need to print the return value of the function before it exits. For now I'll just assume the last instruction of the original function is the only RETURN_VALUE instruction. This won't always be correct, but we can fix it later.
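As an aside, when we come to fix this, the positions of all the RETURN_VALUE instructions can be listed with dis rather than assumed (a sketch, not part of the payload):

```python
import dis

def has_two_returns(a):
    if a > 0:
        return a + 1
    return a - 1

return_offsets = [instr.offset
                  for instr in dis.get_instructions(has_two_returns)
                  if instr.opname == 'RETURN_VALUE']
print(return_offsets)  # two offsets: the early return and the final one
```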


    exit_code = assemble([
        # Note: The following appears to work instead of calling print().
        # ('DUP_TOP', 0),
        # ('PRINT_EXPR', 0),
        # but this feels like a hack that's not *needed* to implement the hack
        # I want, and seeing as I'm not a Python expert it could have unforeseen
        # consequences that would complicate checking the tracers.
        ('DUP_TOP', 0),
        ('LOAD_GLOBAL', print_index),
        ('ROT_TWO', 0),
        ('CALL_FUNCTION', 1),
        ('POP_TOP', 0),
        ('RETURN_VALUE', 0),
    ])

    full_instructions = (entry_code
                         + orig_code.co_code[:-2]
                         + exit_code)
     

We use the new bytecode to create a types.CodeType object and assign it to the fn.__code__ attribute.


    new_code = types.CodeType(
        fn.__code__.co_argcount,
        fn.__code__.co_kwonlyargcount,
        fn.__code__.co_nlocals,
        # We push num_args onto the stack in the initialisation code, and we
        # push an extra return value and print statement onto the stack in the
        # exit code.
        # Hence the stack can be up to max(2, 1 + num_args) greater than the
        # previous greatest depth.
        # n.b. From those functions I've disassembled it appears that each
        # python function ensures no extra stack variables are left on the
        # stack when it returns, in which case the greatest depth would be
        # max(3, fn.__code__.co_stacksize, 1 + num_args).
        # I don't know for certain that this happens (though the definition of
        # RETURN_VALUE in dis.rst doesn't say anything about clearing the
        # current stack frame, so it looks like it does), so I'm
        # playing it safe.
        fn.__code__.co_stacksize + num_args + 2,
        fn.__code__.co_flags,
        full_instructions,
        fn.__code__.co_consts,
        new_names,
        fn.__code__.co_varnames,
        fn.__code__.co_filename,
        fn.__code__.co_name,
        fn.__code__.co_firstlineno,
        fn.__code__.co_lnotab,
        fn.__code__.co_freevars,
        fn.__code__.co_cellvars
    )

    if getattr(fn, '__trace_orig_code', None) is None:
        fn.__trace_orig_code = orig_code
    else:
        raise ValueError('Tracing function with __trace_orig_code attribute.\n'
                         'Probably an already traced function.')

    fn.__code__ = new_code
    return
     

If we give this a go in the terminal, we see ...


vshcmd: > ./trace_problems.py
Calling inner function
Calling inner function
Calling inner function
Calling inner function
Calling inner function


1 2
Calling inner function
Segmentation fault (core dumped)
python_trace [14:21:00] $
     

... oh.

It turns out we forgot to account for the if clause in test_function(). This if clause is encoded as a POP_JUMP_IF_FALSE which jumps to an absolute bytecode position. Since we added code at the start of this function, the position of the jump is now incorrect.

In this particular case we get a segfault as the new jump target attempts to run a binary operator when only one value is on the stack. This results in the Python evaluator attempting to access a non-existent value on the stack, and hence going out of bounds.

You can see the disassembly of the function before and after tracing in the output below.


vshcmd: > python
vshcmd: > import dis
vshcmd: > import trace_problems
vshcmd: > import segfault_tracer
vshcmd: > print('Before Tracing')
vshcmd: > dis.dis(trace_problems.test_function)
vshcmd: > segfault_tracer.trace_this(trace_problems.test_function)
vshcmd: > print('After Tracing')
vshcmd: > dis.dis(trace_problems.test_function)
Python 3.6.1 (default, Mar 27 2017, 00:27:06)
[GCC 6.3.1 20170306] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> >>> >>> >>> Before Tracing
>>>   5           0 LOAD_GLOBAL              0 (print)
              2 LOAD_CONST               1 ('Calling inner function')
              4 CALL_FUNCTION            1
              6 POP_TOP

  6           8 LOAD_FAST                0 (a)
             10 LOAD_CONST               2 (10)
             12 COMPARE_OP               2 (==)
             14 POP_JUMP_IF_FALSE       20

  7          16 LOAD_CONST               3 ('Never')
             18 RETURN_VALUE

  8     >>   20 LOAD_FAST                0 (a)
             22 LOAD_FAST                1 (b)
             24 BINARY_MULTIPLY
             26 RETURN_VALUE
>>> >>> After Tracing
>>>   5           0 LOAD_GLOBAL              0 (print)
              2 LOAD_FAST                0 (a)
              4 LOAD_FAST                1 (b)
              6 CALL_FUNCTION            2

  6           8 POP_TOP
             10 LOAD_GLOBAL              0 (print)
             12 LOAD_CONST               1 ('Calling inner function')
             14 CALL_FUNCTION            1

  7          16 POP_TOP
             18 LOAD_FAST                0 (a)

  8     >>   20 LOAD_CONST               2 (10)
             22 COMPARE_OP               2 (==)
             24 POP_JUMP_IF_FALSE       20
             26 LOAD_CONST               3 ('Never')
             28 RETURN_VALUE
             30 LOAD_FAST                0 (a)
             32 LOAD_FAST                1 (b)
             34 BINARY_MULTIPLY
             36 DUP_TOP
             38 LOAD_GLOBAL              0 (print)
             40 ROT_TWO
             42 CALL_FUNCTION            1
             44 POP_TOP
             46 RETURN_VALUE
>>>
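One way out would be to walk the wordcode and retarget every absolute jump by the size of the inserted prologue. A minimal sketch of the idea is below (the opcode numbers are hardcoded from CPython 3.6 purely for illustration, and EXTENDED_ARG is ignored), though the next section takes a different route instead.

```python
def shift_absolute_jumps(code_bytes, jump_opcodes, offset):
    # Sketch for 3.6-style wordcode (2 bytes per instruction): add
    # `offset` to the argument of every instruction whose opcode is in
    # `jump_opcodes` (e.g. set(opcode.hasjabs) on CPython 3.6).
    # Ignores EXTENDED_ARG, so shifted arguments must stay below 256.
    out = bytearray(code_bytes)
    for i in range(0, len(out), 2):
        if out[i] in jump_opcodes:
            out[i + 1] += offset
    return bytes(out)

# Made-up instruction stream: LOAD_CONST 0; JUMP_ABSOLUTE 20; RETURN_VALUE 0
# (opcode numbers 100, 113, 83 are CPython 3.6's, purely illustrative).
fake_jump = 113
raw = bytes([100, 0, fake_jump, 20, 83, 0])
print(list(shift_absolute_jumps(raw, {fake_jump}, 6)))
# the jump argument 20 becomes 26; everything else is untouched
```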
     

Avoiding shifts

In order to avoid the effects of shifting the bytecode around, we take a cue from the way that DTrace goes about tracing a function. As shown in this tutorial, DTrace replaces the first instruction of the function to trace with an interrupt into a debugging process, which means that none of the following code need be modified.

Accordingly, we replace the first instruction with a jump to some code to be executed on entry, and replace all RETURN_VALUE instructions with a jump to code to be executed on exit.

There's a little complication in ensuring the first instruction gets handled correctly, but other than that we can get by with reasonably minor modifications to the previous attempt.


    # Print all arguments.
    entry_code = assemble(init_instructions)
    # Do whatever the first instruction of the original function was.
    # Have to account for the EXTENDED_ARG opcode.
    first_instr = first_full_instruction(orig_code.co_code)
    entry_code += first_instr
    # Jump to the start of the second instruction
    # Doesn't matter if there's a jump as the first instruction, we just would
    # not execute this one.
    # XXX Unless it's a JUMP_FORWARD instruction.
    entry_code += assemble([('JUMP_ABSOLUTE', len(first_instr))])

    # Place the entry and exit code after the function's bytecode.
    # That way we don't change the offset into the function of any existing
    # bytecode.
    exit_address = len(orig_code.co_code)
    initmes_address = exit_address + len(exit_code)

    # Replace first two bytes with a jump to initmes_address, and all
    # RETURN_VALUE instructions with a jump to exit_address.
    alternate_code = modified_code(orig_code.co_code,
                                   return_address=exit_address,
                                   entry_address=initmes_address)
    full_instructions = alternate_code + exit_code + entry_code
      

vshcmd: > pyrasite $(pgrep -a python | grep trace_problems | cut -d' ' -f1) modification_payload.py
(tracing) python_trace [15:17:35] $
      

vshcmd: > ./trace_problems.py
Calling inner function
Calling inner function
Calling inner function
Calling inner function
Calling inner function


1 2
Calling inner function
2
2 2
Calling inner function
4
3 2
Calling inner function
6
5 2
Calling inner function
10
10 20
Calling inner function
Never


Traceback (most recent call last):
  File "./trace_problems.py", line 39, in <module>
    time.sleep(3)
KeyboardInterrupt
python_trace [15:17:39] $
      

This works pretty well with our test program, but is still missing quite a few edge cases.

For example, if the first instruction is a relative jump then we're in trouble, and there are quite a few instructions that perform a relative jump.


vshcmd: > import opcode
vshcmd: > import pprint
vshcmd: > pprint.pprint([opcode.opname[x] for x in opcode.hasjrel])
>>> >>> ['FOR_ITER',
 'JUMP_FORWARD',
 'SETUP_LOOP',
 'SETUP_EXCEPT',
 'SETUP_FINALLY',
 'SETUP_WITH',
 'SETUP_ASYNC_WITH']
>>>
      

Moreover, if the original function is large enough, we're going to need a jump to an absolute position greater than 255, which will require an EXTENDED_ARG instruction, and then we'll have to shift the remaining jumps anyway.
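The encoding itself isn't complicated — each EXTENDED_ARG prefix supplies a higher byte of the following instruction's argument — but emitting one grows the code, which is exactly the shifting we were trying to avoid. A sketch of the encoding (the jump opcode number is CPython 3.6's, purely illustrative):

```python
import opcode

def encode_with_extended_arg(op, argument):
    # Sketch: encode one wordcode instruction, prefixing EXTENDED_ARG
    # bytes (most significant first) when the argument exceeds one byte.
    ext = opcode.opmap['EXTENDED_ARG']
    chunks = [argument & 0xFF]
    argument >>= 8
    while argument:
        chunks.append(argument & 0xFF)
        argument >>= 8
    encoded = []
    for high_byte in reversed(chunks[1:]):
        encoded += [ext, high_byte]
    encoded += [op, chunks[0]]
    return bytes(encoded)

jump = 113  # JUMP_ABSOLUTE's number on CPython 3.6, purely illustrative
print(list(encode_with_extended_arg(jump, 20)))   # fits in one byte: [113, 20]
print(list(encode_with_extended_arg(jump, 300)))  # prefixed: [EXTENDED_ARG, 1, 113, 44]
```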

I guess we try a different approach.

Special wrapper function

Instead of trying to write Python bytecode ourselves, another tack is to get Python to generate it for us, and then modify the function to execute this alternate code once we're done.

To do this we go back to our first attempt, where we created a closure that simply called the original function after printing our arguments.

If we make this closure, but instead of storing it in the global namespace we modify the original function object to point to it, we shouldn't have to worry about updating Python bytecode, which would be nice.

There are two complications to worry about:

  1. The closure from the wrapper function contains a reference to the function we're modifying.
  2. The __closure__ and __globals__ attributes of a function are read-only.

The first problem means that when we update the function in place, our closure will call that modified version, which will result in an infinite loop. This can be avoided by creating a new function with all the same attributes of the old one.


copy_fn = types.FunctionType(fn.__code__, fn.__globals__, fn.__name__,
                             fn.__defaults__, fn.__closure__)
      

The second problem is more awkward. We need the wrapper function to have access to our copy_fn so it can emulate the traced code. After definition, the wrapper holds this reference through its __closure__ attribute, but we can't simply put that __closure__ attribute onto test_function, as it's read-only.
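It's easy to confirm the read-only behaviour in a plain interpreter (a small sketch with made-up names; the exact error message varies between versions):

```python
def make_closure_fn():
    secret = 'closed over'
    def inner():
        return secret
    return inner

fn = make_closure_fn()
print(fn.__closure__)  # a tuple holding one cell object

try:
    fn.__closure__ = ()       # CPython refuses the assignment
except AttributeError as err:
    print('cannot assign:', err)
```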

To get around these problems we cheat. Instead of storing the closed-over variable in the function's __closure__ attribute, we put it in the code object's co_consts attribute. We then have to change the LOAD_DEREF instruction that loads from the closure environment into a LOAD_CONST instruction.

This doesn't have the same problems we came across when modifying bytecode before, as we're always acting on the same wrapper function, which means we know the bytecode we're operating on. It happens that the number of constants isn't greater than 255, so we don't have to use an EXTENDED_ARG instruction.
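Rather than hardcoding a byte offset (as the code below does at offset 20), the LOAD_DEREF instruction could be located programmatically with dis. A sketch, using a stand-in wrapper rather than the real one:

```python
import dis

def find_first_offset(fn, opname):
    # Return the byte offset of the first instruction named `opname`.
    for instr in dis.get_instructions(fn):
        if instr.opname == opname:
            return instr.offset
    raise ValueError('no {} instruction found'.format(opname))

def make_wrapper(inner):
    # Stand-in for the real wrapper: loads `inner` from its closure.
    def wrapper(*args, **kwargs):
        return inner(*args, **kwargs)
    return wrapper

wrapper = make_wrapper(len)
offset = find_first_offset(wrapper, 'LOAD_DEREF')
print(offset)  # offset of the instruction we'd rewrite; varies by version
```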


def trace_this(fn):
    '''Modify `fn` to print arguments and return value on execution.'''
    # Implementation is done by modifying the function *in-place*.
    # This is necessary to ensure that all local references to the function
    # (e.g. in closures, different namespaces, or collections) are traced too.

    # TODO
    #   Known problems:
    #       sys._getframe(1) now returns the frame from inside the closure.

    # Create a wrapper function to print arguments and return value.
    wrap_function = wrapper_func(fn)
    orig_wrap_code = wrap_function.__code__

    # Happen to know the current bytecode of my wrapper function.
    # Probably should parse it programmatically, but for demonstration purposes
    # this is fine.
    # Change the bytecode instruction that loads a freevar into an instruction
    # to load a constant.
    # Happen to know I'll want the 5th constant in co_consts (because I know
    # how many constants the wrapper function uses).
    #
    # Bytecode is version specific.
    # Am *very* apprehensive about this, I can easily imagine the LOAD_CONST
    # bytecode implementation assuming it's loading something that's constant.
    # It appears to work ...
    alt = bytes([opcode.opmap['LOAD_CONST'], 4])
    alternate_code = orig_wrap_code.co_code[:20] + alt + orig_wrap_code.co_code[22:]

    # Create a copy of the original function.
    # This has the same functionality as the original one, but is a different
    # object.
    # We need a different object so that the modification done below doesn't
    # change the behaviour of our wrapper function.
    copy_fn = types.FunctionType(fn.__code__, fn.__globals__, fn.__name__,
                                 fn.__defaults__, fn.__closure__)

    new_codeobj = types.CodeType(
        orig_wrap_code.co_argcount,
        orig_wrap_code.co_kwonlyargcount,
        orig_wrap_code.co_nlocals,
        orig_wrap_code.co_stacksize,
        orig_wrap_code.co_flags,
        alternate_code,
        orig_wrap_code.co_consts + (copy_fn,),
        orig_wrap_code.co_names,
        orig_wrap_code.co_varnames,
        orig_wrap_code.co_filename,
        orig_wrap_code.co_name,
        orig_wrap_code.co_firstlineno,
        orig_wrap_code.co_lnotab,
        # Take freevars from the original function.
        # This is so the code object matches the __closure__ object from the
        # original function.
        # If they don't match, python raises an exception upon assignment at
        # the end of this function.
        # We can't change the __closure__ member, as this is a read-only
        # attribute enforced in the C core.
        # This shouldn't matter either way, the closed over variables aren't
        # used in the wrapper code.
        fn.__code__.co_freevars,
        orig_wrap_code.co_cellvars
    )

    if getattr(fn, '__trace_orig_code', None) is None:
        fn.__trace_orig_code = fn.__code__
    else:
        raise ValueError('Tracing function with __trace_orig_code attribute.\n'
                         'Probably an already traced function.')

    fn.__code__ = new_codeobj
    return

      

This works just fine.


vshcmd: > pyrasite $(pgrep -a python | grep trace_problems | cut -d' ' -f1) wrapper_payload.py
(tracing) python_trace [16:22:52] $
      

vshcmd: > ./trace_problems.py
Calling inner function
Calling inner function
Calling inner function
Calling inner function
Calling inner function


Calling with args (1, 2)
Calling with kwargs {}
Calling inner function
Returned: 2
Calling with args (2, 2)
Calling with kwargs {}
Calling inner function
Returned: 4
Calling with args (3, 2)
Calling with kwargs {}
Calling inner function
Returned: 6
Calling with args (5, 2)
Calling with kwargs {}
Calling inner function
Returned: 10
Calling with args (10, 20)
Calling with kwargs {}
Calling inner function
Returned: Never


Traceback (most recent call last):
  File "./trace_problems.py", line 39, in <module>
    time.sleep(3)
KeyboardInterrupt
python_trace [16:22:54] $
      

And it can be untraced with a simple function.


def untrace_this(fn):
    '''Remove tracing on `fn`.'''
    if not hasattr(fn, '__trace_orig_code'):
        return
    fn.__code__ = fn.__trace_orig_code
    delattr(fn, '__trace_orig_code')
      

Conclusion

The above is a very hacky approach, but I think it's a pretty cool method of adding tracing without adding overhead to every function invocation. For an approach that uses the more conventional sys.settrace() function, we would probably want to do something like lptrace does: attach to the process and install a trace handler.

I haven't looked into what problems we might have with multiple threads or processes. I expect multiple processes would have to be updated separately, and the fact that pyrasite holds the GIL around running our code means we're safe there, but I wouldn't bet much on that.

Future work

One thing that came to mind while working on this post is the possibility of completely redefining a function, not just adding tracing to it. This brings to mind one of the (many) cool and (probably not 100%) unique things about Common Lisp: the ability to completely redefine a function in a running process.

If it weren't for the problems with the read-only __globals__ and __closure__ attributes, redefinition would be simple.

Modifying the CPython interpreter to allow reassignment of these attributes, and to allow creation of cell objects from Python code, would enable this functionality.

The patch below appears to work, but being a complete novice to the CPython implementation and only spending a few minutes on it, I'd be surprised if there weren't any bugs.


diff --git a/Objects/cellobject.c b/Objects/cellobject.c
index 6af93b0030..60b15409ce 100644
--- a/Objects/cellobject.c
+++ b/Objects/cellobject.c
@@ -2,6 +2,13 @@

 #include "Python.h"

+/*[clinic input]
+class cell "PyCellObject *" "&PyCell_Type"
+[clinic start generated code]*/
+/*[clinic end generated code: output=da39a3ee5e6b4b0d input=3d6baba8cf810af0]*/
+
+#include "clinic/cellobject.c.h"
+
 PyObject *
 PyCell_New(PyObject *obj)
 {
@@ -154,6 +161,27 @@ static PyGetSetDef cell_getsetlist[] = {
     {NULL} /* sentinel */
 };

+/*[clinic input]
+@classmethod
+cell.__new__ as cell_new
+    contents: object
+        an object to be stored in the cell
+
+Create a cell object.
+[clinic start generated code]*/
+
+static PyObject *
+cell_new_impl(PyTypeObject *type, PyObject *contents)
+/*[clinic end generated code: output=0b7db264f0621bb4 input=1b029ce829d4061b]*/
+{
+    PyCellObject *newcell;
+    newcell = (PyCellObject *)PyCell_New(contents);
+    if (newcell == NULL)
+        return NULL;
+    // Already tracking this from the PyCell_New() function
+    return (PyObject *)newcell;
+}
+
 PyTypeObject PyCell_Type = {
     PyVarObject_HEAD_INIT(&PyType_Type, 0)
     "cell",
@@ -185,4 +213,12 @@ PyTypeObject PyCell_Type = {
     0,                                          /* tp_methods */
     0,                                          /* tp_members */
     cell_getsetlist,                            /* tp_getset */
+    0,                                          /* tp_base */
+    0,                                          /* tp_dict */
+    0,                                          /* tp_descr_get */
+    0,                                          /* tp_descr_set */
+    0,                                          /* tp_dictoffset */
+    0,                                          /* tp_init */
+    0,                                          /* tp_alloc */
+    cell_new,                                   /* tp_new */
 };
diff --git a/Objects/funcobject.c b/Objects/funcobject.c
index e440258d7d..b2e4fb26d7 100644
--- a/Objects/funcobject.c
+++ b/Objects/funcobject.c
@@ -230,11 +230,9 @@ PyFunction_SetAnnotations(PyObject *op, PyObject *annotations)
 #define OFF(x) offsetof(PyFunctionObject, x)

 static PyMemberDef func_memberlist[] = {
-    {"__closure__",   T_OBJECT,     OFF(func_closure),
-     RESTRICTED|READONLY},
+    {"__closure__",   T_OBJECT,     OFF(func_closure), RESTRICTED},
     {"__doc__",       T_OBJECT,     OFF(func_doc), PY_WRITE_RESTRICTED},
-    {"__globals__",   T_OBJECT,     OFF(func_globals),
-     RESTRICTED|READONLY},
+    {"__globals__",   T_OBJECT,     OFF(func_globals), RESTRICTED},
     {"__module__",    T_OBJECT,     OFF(func_module), PY_WRITE_RESTRICTED},
     {NULL}  /* Sentinel */
 };
      

With that patch, function redefinition is as simple as defining your new function, then copying its attributes onto the original function object. Obviously, some sort of lock would be needed to avoid the function being called while in an inconsistent state, but nevertheless it's a step in the right direction.
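For functions that don't need the patched __globals__ or __closure__, the copying step can already be sketched on stock CPython (names here are illustrative, and this ignores the locking concern):

```python
def redefine(old_fn, new_fn):
    # Sketch: make `old_fn` behave like `new_fn` in place, so every
    # existing reference picks up the new definition.  Without the patch
    # this only works when the two functions share __globals__ and have
    # matching free variables, since those attributes stay read-only.
    old_fn.__code__ = new_fn.__code__
    old_fn.__defaults__ = new_fn.__defaults__
    old_fn.__kwdefaults__ = new_fn.__kwdefaults__
    old_fn.__doc__ = new_fn.__doc__

def greet(name):
    return 'hello ' + name

held_elsewhere = greet  # e.g. a reference captured by a closure

def greet_v2(name, punctuation='!'):
    return 'HELLO ' + name + punctuation

redefine(greet, greet_v2)
print(held_elsewhere('world'))  # HELLO world!
```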