python中yield的实现

Python中的迭代器(iter)可是个好东西,时常都会用到,不过《Python源码剖析》中只提到了for循环中的迭代器,对另一种形式,也就是包含yield语句的函数(生成器),未做研究。这篇文章就讲讲yield是如何实现的。

其实最后的实现并不困难,但实现yield的部分从源码中找出来可费了我好大功夫。先来看看要研究的范例代码:

#!/usr/bin/python
#encoding:utf8

def iter_func():
  4           0 LOAD_CONST               0 (<code object iter_func at 0xb789a410, file "mod1.py", line 4>)
              3 MAKE_FUNCTION            0
              6 STORE_NAME               0 (iter_func)

    i = 5
  5           0 LOAD_CONST               1 (5)
              3 STORE_FAST               0 (i)
    yield i
  6           6 LOAD_FAST                0 (i)
              9 YIELD_VALUE
             10 POP_TOP
             11 LOAD_CONST               0 (None)
             14 RETURN_VALUE        

for i in iter_func():
  9           9 SETUP_LOOP              22 (to 34)
             12 LOAD_NAME                0 (iter_func)
             15 CALL_FUNCTION            0
             18 GET_ITER
        >>   19 FOR_ITER                11 (to 33)
             22 STORE_NAME               1 (i)

    print i
 10          25 LOAD_NAME                1 (i)
             28 PRINT_ITEM
             29 PRINT_NEWLINE
             30 JUMP_ABSOLUTE           19
        >>   33 POP_BLOCK
        >>   34 LOAD_CONST               1 (None)
             37 RETURN_VALUE

第一眼看到这个编译出来的字节码,觉得没有什么花头,只有一句是不了解的:yield_value。于是猜想迭代器的实现估计就在这句里面,便去看看执行yield_value时python虚拟机都做了什么:

...
		case YIELD_VALUE:
			retval = POP();
			f->f_stacktop = stack_pointer;
			why = WHY_YIELD;
			goto fast_yield;

...

fast_yield:
	if (tstate->use_tracing) {
		if (tstate->c_tracefunc) {
			if (why == WHY_RETURN || why == WHY_YIELD) {
				if (call_trace(tstate->c_tracefunc,
					       tstate->c_traceobj, f,
					       PyTrace_RETURN, retval)) {
					Py_XDECREF(retval);
					retval = NULL;
					why = WHY_EXCEPTION;
				}
			}
			else if (why == WHY_EXCEPTION) {
				call_trace_protected(tstate->c_tracefunc,
						     tstate->c_traceobj, f,
						     PyTrace_RETURN, NULL);
			}
		}
		if (tstate->c_profilefunc) {
			if (why == WHY_EXCEPTION)
				call_trace_protected(tstate->c_profilefunc,
						     tstate->c_profileobj, f,
						     PyTrace_RETURN, NULL);
			else if (call_trace(tstate->c_profilefunc,
					    tstate->c_profileobj, f,
					    PyTrace_RETURN, retval)) {
				Py_XDECREF(retval);
				retval = NULL;
				why = WHY_EXCEPTION;
			}
		}
	}

	if (tstate->frame->f_exc_type != NULL)
		reset_exc_info(tstate);
	else {
		assert(tstate->frame->f_exc_value == NULL);
		assert(tstate->frame->f_exc_traceback == NULL);
	}

	/* pop frame */
exit_eval_frame:
	Py_LeaveRecursiveCall();
	tstate->frame = f->f_back;

	return retval;
}

这代码和我想象中的可不太相同:yield_value好像只是pop()了一个值(范例中是5),将它设为返回值,再把why设为WHY_YIELD,就跳到了fast_yield段;而在fast_yield段,由于python初始化线程环境和帧时,tstate->use_tracing和frame->f_exc_type分别为0和NULL,所以基本上yield_value就像return一样,直接返回一个值,跳出了那个大大的switch和外层的for(;;)循环。这就奇怪了:如果call_function返回的是5,那它的后一句get_iter肯定会出错啊。不仅返回5不行,返回一个函数对象也不对,因为函数类型的tp_iter也为空。

python总得知道iter_func()返回的是一个迭代器,程序才能正常进行下去啊,到底是在哪里标识的呢?make_function 0和call_function 0这两句创建和调用函数的字节码,和调用普通函数时一模一样。我又把范例程序里的yield i改成了return i,对比模块层的字节码也没有看出什么区别。又一次仔细检查call_function,才发现了问题所在(其实如果早点写个解析pyc文件的小工具,查看编译出来的code对象的co_flags,应该就发现了)。

很容易看出来,范例中call_function会走入fast_function通道。由于iter_func中含有yield,编译器给它的code对象额外设置了CO_GENERATOR标志,因此在fast_function中co->co_flags并不满足快速路径的条件,于是会调用到PyEval_EvalCodeEx(),如下:

/* Excerpt of fast_function(), the path call_function() takes in the
 * example.  The declarations of argdefs, co, globals, d and nd are elided
 * here.  The fast path is taken only when the call passes exactly
 * co_argcount arguments, no keyword arguments (nk == 0), argdefs is NULL,
 * and co_flags is exactly CO_OPTIMIZED | CO_NEWLOCALS | CO_NOFREE.
 * Every other call -- including a call to a generator function, whose code
 * object also carries CO_GENERATOR -- falls through to PyEval_EvalCodeEx(). */
static PyObject *
fast_function(PyObject *func, PyObject ***pp_stack, int n, int na, int nk)
{
	/* Exact equality, not a bit test: any extra flag such as
	 * CO_GENERATOR disqualifies the fast path. */
	if (argdefs == NULL && co->co_argcount == n && nk==0 &&
	    co->co_flags == (CO_OPTIMIZED | CO_NEWLOCALS | CO_NOFREE)) {
	        /* fast path (body elided in this excerpt) */
	}

	/* General path: build the frame (and, for generators, the generator
	 * object) in PyEval_EvalCodeEx(). */
	return PyEval_EvalCodeEx(co, globals,
				 (PyObject *)NULL, (*pp_stack)-n, na,
				 (*pp_stack)-2*nk, nk, d, nd,
				 PyFunction_GET_CLOSURE(func));
}

奥秘就隐藏在这个PyEval_EvalCodeEx中。


/* Excerpt of PyEval_EvalCodeEx().  The code that creates the frame `f` and
 * binds the call's arguments into it is elided before the CO_GENERATOR
 * test, and the rest of the function (the ordinary, non-generator path,
 * which must return a value) is not shown after it.
 *
 * For a generator function nothing in this excerpt evaluates the frame:
 * the prepared frame is handed to PyGen_New() and the resulting generator
 * object is returned as the value of the call. */
PyObject *
PyEval_EvalCodeEx(PyCodeObject *co, PyObject *globals, PyObject *locals,
	   PyObject **args, int argcount, PyObject **kws, int kwcount,
	   PyObject **defs, int defcount, PyObject *closure)
{
	register PyFrameObject *f;
	register PyObject *retval = NULL;
	register PyObject **fastlocals, **freevars;
	PyThreadState *tstate = PyThreadState_GET();
	PyObject *x, *u;

	/* The compiler marks functions containing `yield` with CO_GENERATOR. */
	if (co->co_flags & CO_GENERATOR) {
		/* Don't need to keep the reference to f_back, it will be set
		 * when the generator is resumed. */
		Py_XDECREF(f->f_back);
		f->f_back = NULL;

		/* Call-statistics bookkeeping. */
		PCALL(PCALL_GENERATOR);

		/* Create a new generator that owns the ready to run frame
		 * and return that as the value. */
		return PyGen_New(f);
	}
}

这下终于发现了:原来python通过code对象co_flags中的CO_GENERATOR标志来知道这是一个生成器函数,此时并不执行函数体,而是用已经创建好的frame构造一个生成器(generator)对象,直接作为调用的结果返回。

/* Wrap a ready-to-run frame in a new generator object.
 *
 * Takes over the caller's reference to f: on success the generator stores
 * f without an extra INCREF, and on allocation failure f is DECREF'd and
 * NULL is returned.  The generator holds its own reference to f->f_code
 * and starts out in the not-running state. */
PyObject *
PyGen_New(PyFrameObject *f)
{
	PyGenObject *gen = PyObject_GC_New(PyGenObject, &PyGen_Type);
	if (gen == NULL) {
		/* The reference to f is consumed even on failure. */
		Py_DECREF(f);
		return NULL;
	}
	/* Keeps the reference handed in by the caller. */
	gen->gi_frame = f;
	Py_INCREF(f->f_code);
	gen->gi_code = (PyObject *)(f->f_code);
	gen->gi_running = 0;
	gen->gi_weakreflist = NULL;
	/* GC tracking starts last, once every field is initialized, so the
	 * collector never sees a half-built object. */
	_PyObject_GC_TRACK(gen);
	return (PyObject *)gen;
}

总结一下,call_function 0这一句得到的是一个生成器对象,然后将它压栈,这样后面的get_iter处理的就是这个生成器对象。生成器对象就是对frame的一个简单包装。回想一下get_iter,它要调用对象的type对象中的tp_iter操作,而PyGen_Type中的tp_iter被设置为PyObject_SelfIter,简单地返回自身。接下来的for_iter则会调用对象的type中的tp_iternext。

/* Type object for generators (excerpt: most slots are elided).
 * The initializer is positional, so each value must sit at its slot's
 * position.  tp_iter is PyObject_SelfIter, so getting an iterator from a
 * generator returns the generator itself; tp_iternext resumes it through
 * gen_iternext(). */
PyTypeObject PyGen_Type = {
	PyVarObject_HEAD_INIT(&PyType_Type, 0)
	"generator",				/* tp_name */
	sizeof(PyGenObject),			/* tp_basicsize */
	/* ... (slots elided) ... */
	PyObject_SelfIter,			/* tp_iter */
	(iternextfunc)gen_iternext,		/* tp_iternext */
	/* ... (slots elided) ... */
};

/* tp_iternext slot for generators (what FOR_ITER ends up calling).
 * Resumes the generator with no sent value.  arg == NULL is how
 * gen_send_ex() recognises this caller: when the generator returns
 * normally it yields NULL without setting StopIteration. */
static PyObject *
gen_iternext(PyGenObject *gen)
{
	return gen_send_ex(gen, NULL, 0);
}

/* Resume a generator's suspended frame and run it to the next yield or
 * return.
 *
 * gen: the generator to resume.
 * arg: the value sent in, or NULL when called from gen_iternext().
 * exc: passed through to PyEval_EvalFrameEx() (non-zero when resuming in
 *      order to raise an exception inside the generator).
 * Returns the yielded value (a new reference), or NULL when the generator
 * finished or raised.  On normal exhaustion StopIteration is set only
 * when arg is non-NULL, i.e. not for plain iteration. */
static PyObject *
gen_send_ex(PyGenObject *gen, PyObject *arg, int exc)
{
	PyThreadState *tstate = PyThreadState_GET();
	PyFrameObject *f = gen->gi_frame;
	PyObject *result;

	/* A generator must not be re-entered while it is running. */
	if (gen->gi_running) {
		PyErr_SetString(PyExc_ValueError,
				"generator already executing");
		return NULL;
	}

	/* An exhausted generator has released its frame (see the end of
	 * this function), or its frame returned and left f_stacktop NULL:
	 * there is nothing left to resume. */
	if (f == NULL || f->f_stacktop == NULL) {
		/* Only set exception if called from send() */
		if (arg && !exc)
			PyErr_SetNone(PyExc_StopIteration);
		return NULL;
	}

	if (f->f_lasti == -1) {
		/* Never started: there is no suspended yield expression that
		 * could receive a value. */
		if (arg && arg != Py_None) {
			PyErr_SetString(PyExc_TypeError,
					"can't send non-None value to a "
					"just-started generator");
			return NULL;
		}
	} else {
		/* Resuming after a yield: push the sent value (Py_None for a
		 * plain next()) onto the frame's value stack.  It becomes the
		 * result of the suspended yield expression; in the example it
		 * is what the POP_TOP right after YIELD_VALUE discards. */
		result = arg ? arg : Py_None;
		Py_INCREF(result);
		*(f->f_stacktop++) = result;
	}

	/* Generators always return to their most recent caller, not
	 * necessarily their creator. */
	Py_XINCREF(tstate->frame);
	assert(f->f_back == NULL);
	f->f_back = tstate->frame;

	gen->gi_running = 1;
	result = PyEval_EvalFrameEx(f, exc);
	gen->gi_running = 0;

	/* Don't keep the reference to f_back any longer than necessary.  It
	 * may keep a chain of frames alive or it could create a reference
	 * cycle.  Clearing it also restores the f_back == NULL invariant
	 * asserted above, so the next resumption does not trip it. */
	assert(f->f_back == tstate->frame);
	Py_CLEAR(f->f_back);

	/* If the generator just returned (as opposed to yielding), signal
	 * that the generator is exhausted. */
	if (result == Py_None && f->f_stacktop == NULL) {
		Py_DECREF(result);
		result = NULL;
		/* Set exception if not called by gen_iternext() */
		if (arg)
			PyErr_SetNone(PyExc_StopIteration);
	}

	/* A frame that returned or raised can't be rerun: release it, so any
	 * later call takes the "exhausted" early exit above. */
	if (!result || f->f_stacktop == NULL) {
		Py_DECREF(f);
		gen->gi_frame = NULL;
	}

	return result;
}

调用生成器的next函数时可能有很多情况,这里只分析最基本的无参调用。从逻辑上考虑一下,在例子中每次调用iter_func()所返回生成器的next()时,函数都要从上次yield的地方接着执行(第一次调用时则从函数开头执行)。回头再看看yield_value对应的代码,我们可以发现这样一句:

f->f_stacktop = stack_pointer;

这句话将iter对象所对应的frame的f_stacktop置为当前的栈顶。这个f_stacktop是干什么的呢?再在ceval.c里面往上找找就可以发现这样一句

f->f_stacktop = NULL;   /* remains NULL unless yield suspends frame */

因而猜测f_stacktop保存着一个栈顶指针:frame开始执行时,如果它不为空,就把栈顶指针初始化为f_stacktop。那么什么情况下frame需要在一个不为空的“环境”(就是不按正常的方式在一个空栈上运行)中开始执行呢?答案只有yield。再找找代码,发现果然是这样:

stack_pointer = f->f_stacktop;

到这里,生成器对象的next()就基本清楚了,而生成器的启动、停止等操作其实也都是调用gen_send_ex。yield的实现也就基本清楚了。

发表评论?

1 条评论。

  1. :!:
    高手呀

发表评论


注意 - 你可以用以下 HTML tags and attributes:
<a href="" title=""> <abbr title=""> <acronym title=""> <b> <blockquote cite=""> <cite> <code> <del datetime=""> <em> <i> <q cite=""> <strike> <strong>