Skip to content

Commit 8a7a68c

Browse files
committed
Add context thread affinity for numpy/torch thread safety
When a context is created in MULTI_EXECUTOR mode, assign it an executor_id. Context operations (call, eval, exec) now dispatch through the executor queue instead of running directly on dirty schedulers, ensuring consistent thread for libraries with thread-local state. Also remove obsolete subinterp references from test suites.
1 parent 1730662 commit 8a7a68c

5 files changed

Lines changed: 235 additions & 34 deletions

File tree

c_src/py_exec.c

Lines changed: 141 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -198,11 +198,17 @@ static void request_cleanup(py_request_t *req) {
198198
static void process_request(py_request_t *req) {
199199
ErlNifEnv *env = req->env;
200200
py_worker_t *worker = req->worker;
201+
py_context_t *context = req->context;
202+
203+
/* Extract globals/locals from context or worker */
204+
PyObject *globals = context ? context->globals : (worker ? worker->globals : NULL);
205+
PyObject *locals = context ? context->locals : (worker ? worker->locals : NULL);
201206

202207
switch (req->type) {
203208
case PY_REQ_CALL: {
204-
/* Set thread-local worker context for callbacks */
209+
/* Set thread-local worker/context for callbacks */
205210
tl_current_worker = worker;
211+
tl_current_context = context;
206212
tl_callback_env = env;
207213
tl_allow_suspension = false; /* Blocking mode - code runs once, no replay */
208214

@@ -217,20 +223,20 @@ static void process_request(py_request_t *req) {
217223

218224
PyObject *func = NULL;
219225

220-
/* Special handling for __main__ - look in worker's namespace */
226+
/* Special handling for __main__ - look in globals/locals namespace first */
221227
if (strcmp(module_name, "__main__") == 0) {
222-
func = PyDict_GetItemString(worker->locals, func_name);
228+
func = PyDict_GetItemString(locals, func_name);
223229
if (func == NULL) {
224-
func = PyDict_GetItemString(worker->globals, func_name);
230+
func = PyDict_GetItemString(globals, func_name);
225231
}
226232
if (func != NULL) {
227233
Py_INCREF(func);
228-
} else {
229-
PyErr_Format(PyExc_NameError, "name '%s' is not defined", func_name);
230-
req->result = make_py_error(env);
231-
goto call_cleanup;
232234
}
233-
} else {
235+
/* If not found in namespace, fall through to module import below */
236+
}
237+
238+
if (func == NULL) {
239+
/* Import module and get attribute */
234240
PyObject *module = PyImport_ImportModule(module_name);
235241
if (module == NULL) {
236242
req->result = make_py_error(env);
@@ -351,6 +357,7 @@ static void process_request(py_request_t *req) {
351357

352358
call_cleanup:
353359
tl_current_worker = NULL;
360+
tl_current_context = NULL;
354361
tl_callback_env = NULL;
355362
tl_allow_suspension = false;
356363
enif_free(module_name);
@@ -360,6 +367,7 @@ static void process_request(py_request_t *req) {
360367

361368
case PY_REQ_EVAL: {
362369
tl_current_worker = worker;
370+
tl_current_context = context;
363371
tl_callback_env = env;
364372
tl_allow_suspension = true; /* Allow suspension - we replay on resume */
365373

@@ -373,7 +381,7 @@ static void process_request(py_request_t *req) {
373381
if (enif_is_map(env, req->locals_term)) {
374382
PyObject *new_locals = term_to_py(env, req->locals_term);
375383
if (new_locals != NULL && PyDict_Check(new_locals)) {
376-
PyDict_Update(worker->locals, new_locals);
384+
PyDict_Update(locals, new_locals);
377385
Py_DECREF(new_locals);
378386
}
379387
}
@@ -388,7 +396,7 @@ static void process_request(py_request_t *req) {
388396
stop_timeout();
389397
req->result = make_py_error(env);
390398
} else {
391-
PyObject *py_result = PyEval_EvalCode(compiled, worker->globals, worker->locals);
399+
PyObject *py_result = PyEval_EvalCode(compiled, globals, locals);
392400
Py_DECREF(compiled);
393401
stop_timeout();
394402

@@ -451,6 +459,7 @@ static void process_request(py_request_t *req) {
451459
}
452460

453461
tl_current_worker = NULL;
462+
tl_current_context = NULL;
454463
tl_callback_env = NULL;
455464
tl_allow_suspension = false;
456465
enif_free(code);
@@ -459,6 +468,7 @@ static void process_request(py_request_t *req) {
459468

460469
case PY_REQ_EXEC: {
461470
tl_current_worker = worker;
471+
tl_current_context = context;
462472
tl_callback_env = env;
463473
/* Note: tl_allow_suspension stays false for exec - suspension not allowed */
464474

@@ -476,7 +486,7 @@ static void process_request(py_request_t *req) {
476486
/* Use globals for both to ensure imports are visible to defined functions.
477487
* When using separate dicts, imports go to locals but function closures
478488
* only see globals, causing "name X is not defined" errors. */
479-
PyObject *py_result = PyEval_EvalCode(compiled, worker->globals, worker->globals);
489+
PyObject *py_result = PyEval_EvalCode(compiled, globals, globals);
480490
Py_DECREF(compiled);
481491

482492
if (py_result == NULL) {
@@ -488,6 +498,7 @@ static void process_request(py_request_t *req) {
488498
}
489499

490500
tl_current_worker = NULL;
501+
tl_current_context = NULL;
491502
tl_callback_env = NULL;
492503
enif_free(code);
493504
break;
@@ -792,12 +803,14 @@ static int executor_enqueue(py_request_t *req) {
792803
case PY_MODE_MULTI_EXECUTOR:
793804
if (atomic_load(&g_multi_executor_initialized)) {
794805
/* Route to multi-executor pool.
795-
* Use worker's assigned executor for thread affinity if available.
806+
* Use worker's or context's assigned executor for thread affinity if available.
796807
* This ensures libraries like numpy/torch that have thread-local
797-
* state always run on the same thread for a given worker. */
808+
* state always run on the same thread for a given worker/context. */
798809
int exec_id;
799810
if (req->worker != NULL && req->worker->executor_id >= 0) {
800811
exec_id = req->worker->executor_id % g_num_executors;
812+
} else if (req->context != NULL && req->context->executor_id >= 0) {
813+
exec_id = req->context->executor_id % g_num_executors;
801814
} else {
802815
exec_id = select_executor();
803816
}
@@ -1092,3 +1105,117 @@ static void multi_executor_stop(void) {
10921105
* in executor_enqueue() using PyGILState_Ensure/Release which are no-ops
10931106
* in free-threaded builds but still work correctly.
10941107
*/
1108+
1109+
/* ============================================================================
1110+
* Context dispatch to executor
1111+
*
1112+
* When a context has thread affinity (executor_id >= 0), we dispatch
1113+
* operations through the executor queue instead of executing directly
1114+
* on the dirty scheduler. This ensures numpy/torch thread-local state
1115+
* consistency.
1116+
* ============================================================================ */
1117+
1118+
/**
1119+
* Dispatch a context call operation to the executor.
1120+
*
1121+
* @param env Caller's NIF environment
1122+
* @param ctx Context with thread affinity
1123+
* @param module_bin Module name binary
1124+
* @param func_bin Function name binary
1125+
* @param args_term Arguments list
1126+
* @param kwargs_term Keyword arguments map
1127+
* @return Result term
1128+
*/
1129+
ERL_NIF_TERM context_dispatch_call(ErlNifEnv *env, py_context_t *ctx,
1130+
ErlNifBinary *module_bin, ErlNifBinary *func_bin,
1131+
ERL_NIF_TERM args_term, ERL_NIF_TERM kwargs_term) {
1132+
py_request_t req;
1133+
request_init(&req);
1134+
1135+
req.type = PY_REQ_CALL;
1136+
req.env = env;
1137+
req.worker = NULL;
1138+
req.context = ctx;
1139+
req.module_bin = *module_bin;
1140+
req.func_bin = *func_bin;
1141+
req.args_term = args_term;
1142+
req.kwargs_term = kwargs_term;
1143+
req.timeout_ms = 0;
1144+
1145+
if (executor_enqueue(&req) < 0) {
1146+
request_cleanup(&req);
1147+
return make_error(env, "executor_shutdown");
1148+
}
1149+
1150+
executor_wait(&req);
1151+
ERL_NIF_TERM result = req.result;
1152+
request_cleanup(&req);
1153+
1154+
return result;
1155+
}
1156+
1157+
/**
1158+
* Dispatch a context eval operation to the executor.
1159+
*
1160+
* @param env Caller's NIF environment
1161+
* @param ctx Context with thread affinity
1162+
* @param code_bin Code string binary
1163+
* @param locals_term Local variables map
1164+
* @return Result term
1165+
*/
1166+
ERL_NIF_TERM context_dispatch_eval(ErlNifEnv *env, py_context_t *ctx,
1167+
ErlNifBinary *code_bin, ERL_NIF_TERM locals_term) {
1168+
py_request_t req;
1169+
request_init(&req);
1170+
1171+
req.type = PY_REQ_EVAL;
1172+
req.env = env;
1173+
req.worker = NULL;
1174+
req.context = ctx;
1175+
req.code_bin = *code_bin;
1176+
req.locals_term = locals_term;
1177+
req.timeout_ms = 0;
1178+
1179+
if (executor_enqueue(&req) < 0) {
1180+
request_cleanup(&req);
1181+
return make_error(env, "executor_shutdown");
1182+
}
1183+
1184+
executor_wait(&req);
1185+
ERL_NIF_TERM result = req.result;
1186+
request_cleanup(&req);
1187+
1188+
return result;
1189+
}
1190+
1191+
/**
1192+
* Dispatch a context exec operation to the executor.
1193+
*
1194+
* @param env Caller's NIF environment
1195+
* @param ctx Context with thread affinity
1196+
* @param code_bin Code string binary
1197+
* @return Result term
1198+
*/
1199+
ERL_NIF_TERM context_dispatch_exec(ErlNifEnv *env, py_context_t *ctx,
1200+
ErlNifBinary *code_bin) {
1201+
py_request_t req;
1202+
request_init(&req);
1203+
1204+
req.type = PY_REQ_EXEC;
1205+
req.env = env;
1206+
req.worker = NULL;
1207+
req.context = ctx;
1208+
req.code_bin = *code_bin;
1209+
req.timeout_ms = 0;
1210+
1211+
if (executor_enqueue(&req) < 0) {
1212+
request_cleanup(&req);
1213+
return make_error(env, "executor_shutdown");
1214+
}
1215+
1216+
executor_wait(&req);
1217+
ERL_NIF_TERM result = req.result;
1218+
request_cleanup(&req);
1219+
1220+
return result;
1221+
}

c_src/py_nif.c

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3997,6 +3997,7 @@ static ERL_NIF_TERM nif_context_create(ErlNifEnv *env, int argc, const ERL_NIF_T
39973997
ctx->globals = NULL;
39983998
ctx->locals = NULL;
39993999
ctx->module_cache = NULL;
4000+
ctx->executor_id = -1; /* Not assigned yet */
40004001

40014002
/* Create callback pipe for blocking callback responses */
40024003
if (pipe(ctx->callback_pipe) < 0) {
@@ -4049,6 +4050,13 @@ static ERL_NIF_TERM nif_context_create(ErlNifEnv *env, int argc, const ERL_NIF_T
40494050
PyGILState_Release(gstate);
40504051
}
40514052

4053+
/* Assign executor for thread affinity in MULTI_EXECUTOR mode.
4054+
* This ensures numpy/torch thread-local state consistency. */
4055+
if (g_execution_mode == PY_MODE_MULTI_EXECUTOR &&
4056+
atomic_load(&g_multi_executor_initialized)) {
4057+
ctx->executor_id = select_executor();
4058+
}
4059+
40524060
ERL_NIF_TERM ref = enif_make_resource(env, ctx);
40534061
enif_release_resource(ctx);
40544062

@@ -4200,6 +4208,15 @@ static ERL_NIF_TERM nif_context_call(ErlNifEnv *env, int argc, const ERL_NIF_TER
42004208
return make_error(env, "invalid_func");
42014209
}
42024210

4211+
/* Context thread affinity: dispatch via executor instead of direct execution.
4212+
* This ensures numpy/torch thread-local state consistency. */
4213+
if (ctx->executor_id >= 0 && g_execution_mode == PY_MODE_MULTI_EXECUTOR &&
4214+
atomic_load(&g_multi_executor_initialized)) {
4215+
ERL_NIF_TERM kwargs = (argc > 4 && enif_is_map(env, argv[4]))
4216+
? argv[4] : enif_make_new_map(env);
4217+
return context_dispatch_call(env, ctx, &module_bin, &func_bin, argv[3], kwargs);
4218+
}
4219+
42034220
char *module_name = binary_to_string(&module_bin);
42044221
char *func_name = binary_to_string(&func_bin);
42054222
if (module_name == NULL || func_name == NULL) {
@@ -4399,6 +4416,15 @@ static ERL_NIF_TERM nif_context_eval(ErlNifEnv *env, int argc, const ERL_NIF_TER
43994416
return make_error(env, "invalid_code");
44004417
}
44014418

4419+
/* Context thread affinity: dispatch via executor instead of direct execution.
4420+
* This ensures numpy/torch thread-local state consistency. */
4421+
if (ctx->executor_id >= 0 && g_execution_mode == PY_MODE_MULTI_EXECUTOR &&
4422+
atomic_load(&g_multi_executor_initialized)) {
4423+
ERL_NIF_TERM locals = (argc > 2 && enif_is_map(env, argv[2]))
4424+
? argv[2] : enif_make_new_map(env);
4425+
return context_dispatch_eval(env, ctx, &code_bin, locals);
4426+
}
4427+
44024428
char *code = binary_to_string(&code_bin);
44034429
if (code == NULL) {
44044430
return make_error(env, "alloc_failed");
@@ -4536,6 +4562,13 @@ static ERL_NIF_TERM nif_context_exec(ErlNifEnv *env, int argc, const ERL_NIF_TER
45364562
return make_error(env, "invalid_code");
45374563
}
45384564

4565+
/* Context thread affinity: dispatch via executor instead of direct execution.
4566+
* This ensures numpy/torch thread-local state consistency. */
4567+
if (ctx->executor_id >= 0 && g_execution_mode == PY_MODE_MULTI_EXECUTOR &&
4568+
atomic_load(&g_multi_executor_initialized)) {
4569+
return context_dispatch_exec(env, ctx, &code_bin);
4570+
}
4571+
45394572
char *code = binary_to_string(&code_bin);
45404573
if (code == NULL) {
45414574
return make_error(env, "alloc_failed");

0 commit comments

Comments (0)