diff --git a/Lib/test/test_traceback.py b/Lib/test/test_traceback.py index bffc03e663..e84f9a0394 100644 --- a/Lib/test/test_traceback.py +++ b/Lib/test/test_traceback.py @@ -18,6 +18,12 @@ test_code = namedtuple('code', ['co_filename', 'co_name']) test_frame = namedtuple('frame', ['f_code', 'f_globals', 'f_locals']) test_tb = namedtuple('tb', ['tb_frame', 'tb_lineno', 'tb_next']) +def count_frames(tb): + count = 1 + while tb.tb_next is not None: + count += 1 + tb = tb.tb_next + return count class TracebackCases(unittest.TestCase): # For now, a very minimal set of tests. I want to be sure that @@ -1092,6 +1098,65 @@ class TestTracebackException(unittest.TestCase): exc = traceback.TracebackException(Exception, Exception("haven"), None) self.assertEqual(list(exc.format()), ["Exception: haven\n"]) + def test_traceback_filter_nop(self): + def recurse(n): + current_n = n + if n: + recurse(n-1) + else: + 1/0 + try: + recurse(10) + except Exception: + ex_type, ex, tb = sys.exc_info() + + def filter_none(tb): + return False + + n_original = count_frames(tb) + tb.filter(filter_none) + self.assertEqual(count_frames(tb), n_original) + + def test_traceback_filter_all(self): + def recurse(n): + current_n = n + if n: + recurse(n-1) + else: + 1/0 + try: + recurse(10) + except Exception: + ex_type, ex, tb = sys.exc_info() + + def filter_all(tb): + return True + + n_original = count_frames(tb) + tb.filter(filter_all) + self.assertEqual(count_frames(tb), n_original - 11) + + + + def test_traceback_filter_some(self): + def recurse(n): + current_n = n + if n: + recurse(n-1) + else: + 1/0 + try: + recurse(10) + except Exception: + ex_type, ex, tb = sys.exc_info() + + def filter_odd(tb): + return True if tb.tb_frame.f_locals["current_n"] % 2 else False + + n_original = count_frames(tb) + tb.filter(filter_odd) + self.assertEqual(count_frames(tb), n_original - 5) + class MiscTest(unittest.TestCase): diff --git a/Python/traceback.c b/Python/traceback.c index 21b36b1471..3498cbc9a6 100644 --- a/Python/traceback.c +++ b/Python/traceback.c @@ -34,8 +34,43 @@ tb_dir(PyTracebackObject *self) "tb_lasti", "tb_lineno"); } +static PyObject * +tb_filter(PyTracebackObject *self, PyObject* callback) +{ + if (!PyCallable_Check(callback)) { + PyErr_SetString(PyExc_TypeError, "parameter must be callable"); + return NULL; + } + + + PyTracebackObject* current = self; + PyTracebackObject* last = NULL; + while(current->tb_next != NULL){ + Py_XDECREF(last); + last = current; + Py_INCREF(last); + current = current->tb_next; + Py_INCREF(current); + PyObject* args = PyTuple_Pack(1,current); + PyObject* result = PyEval_CallObject(callback,args); + + Py_INCREF(result); + Py_INCREF(Py_True); + if(Py_True == result){ + last->tb_next = current->tb_next; + Py_DECREF(current); + current = last; + } + Py_DECREF(result); + Py_DECREF(Py_True); + + } + Py_RETURN_NONE; +} + static PyMethodDef tb_methods[] = { {"__dir__", (PyCFunction)tb_dir, METH_NOARGS}, + {"filter", (PyCFunction)tb_filter, METH_O}, {NULL, NULL, 0, NULL}, };