diff --git a/mpy-cross/mpconfigport.h b/mpy-cross/mpconfigport.h index 0b07a5b442..464c9113d5 100644 --- a/mpy-cross/mpconfigport.h +++ b/mpy-cross/mpconfigport.h @@ -40,6 +40,7 @@ #define MICROPY_FLOAT_IMPL (MICROPY_FLOAT_IMPL_DOUBLE) #define MICROPY_CPYTHON_COMPAT (1) +#define MICROPY_PY_ASYNC_AWAIT (1) #define MICROPY_USE_INTERNAL_PRINTF (0) #define MICROPY_PY_BUILTINS_STR_UNICODE (1) diff --git a/py/compile.c b/py/compile.c index 4708110056..9b0d29998a 100644 --- a/py/compile.c +++ b/py/compile.c @@ -1890,7 +1890,7 @@ STATIC void compile_async_stmt(compiler_t *comp, mp_parse_node_struct_t *pns) { // async def compile_funcdef(comp, pns0); scope_t *fscope = (scope_t*)pns0->nodes[4]; - fscope->scope_flags |= MP_SCOPE_FLAG_GENERATOR; + fscope->scope_flags |= MP_SCOPE_FLAG_GENERATOR | MP_SCOPE_FLAG_ASYNC; } else if (MP_PARSE_NODE_STRUCT_KIND(pns0) == PN_for_stmt) { // async for compile_async_for_stmt(comp, pns0); diff --git a/py/emitglue.c b/py/emitglue.c index 3a3174b0f8..7635a73d6a 100644 --- a/py/emitglue.c +++ b/py/emitglue.c @@ -152,7 +152,7 @@ mp_obj_t mp_make_function_from_raw_code(const mp_raw_code_t *rc, mp_obj_t def_ar // check for generator functions and if so wrap in generator object if ((rc->scope_flags & MP_SCOPE_FLAG_GENERATOR) != 0) { - fun = mp_obj_new_gen_wrap(fun); + fun = mp_obj_new_gen_wrap(fun, (rc->scope_flags & MP_SCOPE_FLAG_ASYNC) != 0); } return fun; diff --git a/py/obj.h b/py/obj.h index 8536e33335..e603d4a496 100644 --- a/py/obj.h +++ b/py/obj.h @@ -665,7 +665,7 @@ mp_obj_t mp_obj_new_fun_bc(mp_obj_t def_args, mp_obj_t def_kw_args, const byte * mp_obj_t mp_obj_new_fun_native(mp_obj_t def_args_in, mp_obj_t def_kw_args, const void *fun_data, const mp_uint_t *const_table); mp_obj_t mp_obj_new_fun_viper(size_t n_args, void *fun_data, mp_uint_t type_sig); mp_obj_t mp_obj_new_fun_asm(size_t n_args, void *fun_data, mp_uint_t type_sig); -mp_obj_t mp_obj_new_gen_wrap(mp_obj_t fun); +mp_obj_t mp_obj_new_gen_wrap(mp_obj_t fun, bool is_coroutine); mp_obj_t mp_obj_new_closure(mp_obj_t fun, size_t n_closed, const mp_obj_t *closed); mp_obj_t mp_obj_new_tuple(size_t n, const mp_obj_t *items); mp_obj_t mp_obj_new_list(size_t n, mp_obj_t *items); diff --git a/py/objgenerator.c b/py/objgenerator.c index 6ffcfae46a..57d20e3db8 100644 --- a/py/objgenerator.c +++ b/py/objgenerator.c @@ -42,11 +42,13 @@ typedef struct _mp_obj_gen_wrap_t { mp_obj_base_t base; mp_obj_t *fun; + bool coroutine_generator; } mp_obj_gen_wrap_t; typedef struct _mp_obj_gen_instance_t { mp_obj_base_t base; mp_obj_dict_t *globals; + bool coroutine_generator; mp_code_state_t code_state; } mp_obj_gen_instance_t; @@ -64,6 +66,7 @@ STATIC mp_obj_t gen_wrap_call(mp_obj_t self_in, size_t n_args, size_t n_kw, cons n_state * sizeof(mp_obj_t) + n_exc_stack * sizeof(mp_exc_stack_t)); o->base.type = &mp_type_gen_instance; + o->coroutine_generator = self->coroutine_generator; o->globals = self_fun->globals; o->code_state.fun_bc = self_fun; o->code_state.ip = 0; @@ -78,10 +81,11 @@ const mp_obj_type_t mp_type_gen_wrap = { .unary_op = mp_generic_unary_op, }; -mp_obj_t mp_obj_new_gen_wrap(mp_obj_t fun) { +mp_obj_t mp_obj_new_gen_wrap(mp_obj_t fun, bool is_coroutine) { mp_obj_gen_wrap_t *o = m_new_obj(mp_obj_gen_wrap_t); o->base.type = &mp_type_gen_wrap; o->fun = MP_OBJ_TO_PTR(fun); + o->coroutine_generator = is_coroutine; return MP_OBJ_FROM_PTR(o); } @@ -91,7 +95,11 @@ mp_obj_t mp_obj_new_gen_wrap(mp_obj_t fun) { STATIC void gen_instance_print(const mp_print_t *print, mp_obj_t self_in, mp_print_kind_t kind) { (void)kind; mp_obj_gen_instance_t *self = MP_OBJ_TO_PTR(self_in); - mp_printf(print, "", mp_obj_fun_get_name(MP_OBJ_FROM_PTR(self->code_state.fun_bc)), self); + if (self->coroutine_generator) { + mp_printf(print, "", mp_obj_fun_get_name(MP_OBJ_FROM_PTR(self->code_state.fun_bc)), self); + } else { + mp_printf(print, "", mp_obj_fun_get_name(MP_OBJ_FROM_PTR(self->code_state.fun_bc)), self); + } } mp_vm_return_kind_t mp_obj_gen_resume(mp_obj_t self_in, mp_obj_t send_value, mp_obj_t throw_value, mp_obj_t *ret_val) { @@ -194,6 +202,10 @@ STATIC mp_obj_t gen_resume_and_raise(mp_obj_t self_in, mp_obj_t send_value, mp_o } STATIC mp_obj_t gen_instance_iternext(mp_obj_t self_in) { + mp_obj_gen_instance_t *self = MP_OBJ_TO_PTR(self_in); + if (self->coroutine_generator) { + mp_raise_TypeError(translate("'coroutine' object is not an iterator")); + } return gen_resume_and_raise(self_in, mp_const_none, MP_OBJ_NULL); } diff --git a/py/runtime0.h b/py/runtime0.h index a8089ea646..fb35c8a9f4 100644 --- a/py/runtime0.h +++ b/py/runtime0.h @@ -33,6 +33,7 @@ #define MP_SCOPE_FLAG_VARKEYWORDS (0x02) #define MP_SCOPE_FLAG_GENERATOR (0x04) #define MP_SCOPE_FLAG_DEFKWARGS (0x08) +#define MP_SCOPE_FLAG_ASYNC (0x10) // types for native (viper) function signature #define MP_NATIVE_TYPE_OBJ (0x00) diff --git a/tests/basics/async_coroutine.py b/tests/basics/async_coroutine.py new file mode 100644 index 0000000000..791f6df14c --- /dev/null +++ b/tests/basics/async_coroutine.py @@ -0,0 +1,13 @@ +async def f(): + pass + +try: + f() # Should not crash +except Exception as e: + print('failed to invoke') + +try: + next(f()) + print('This should fail because async def returns a coroutine, and next() is not allowed') +except Exception as e: + print('pass') diff --git a/tests/basics/async_coroutine.py.exp b/tests/basics/async_coroutine.py.exp new file mode 100644 index 0000000000..2ae28399f5 --- /dev/null +++ b/tests/basics/async_coroutine.py.exp @@ -0,0 +1 @@ +pass