Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 92 additions & 69 deletions sdks/python/apache_beam/yaml/yaml_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,11 @@
#

"""This module defines the basic MapToFields operation."""
import datetime
import itertools
import json
Comment thread
derrickaw marked this conversation as resolved.
import re
import threading
from collections import abc
from collections.abc import Callable
from collections.abc import Collection
Expand Down Expand Up @@ -53,13 +56,11 @@
from apache_beam.yaml.yaml_errors import maybe_with_exception_handling_transform_fn
from apache_beam.yaml.yaml_provider import dicts_to_rows

# Import js2py package if it exists
# Import quickjs package if it exists
try:
import js2py
from js2py.base import JsObjectWrapper
import quickjs
except ImportError:
js2py = None
JsObjectWrapper = object
quickjs = None

_str_expression_fields = {
'AssignTimestamps': 'timestamp',
Expand Down Expand Up @@ -178,18 +179,41 @@ def _check_mapping_arguments(
raise ValueError(f'{transform_name} cannot specify "name" without "path"')


# js2py's JsObjectWrapper object has a self-referencing __dict__ property
# that cannot be pickled without implementing the __getstate__ and
# __setstate__ methods.
class _CustomJsObjectWrapper(JsObjectWrapper):
def __init__(self, js_obj):
super().__init__(js_obj.__dict__['_obj'])
class _QuickJsCallable:
  """A picklable wrapper that runs a JavaScript function through QuickJS.

  QuickJS contexts are not thread-safe. This class uses thread-local storage
  to ensure each thread has its own QuickJS context, while reusing it across
  multiple calls on the same thread.

  Only the JavaScript source and the optional function name are pickled; the
  per-thread contexts are rebuilt lazily after unpickling.

  Args:
    source: JavaScript source text to evaluate in each new context.
    name: Name of the global to fetch from the context after evaluating
      ``source``. When falsy, the value produced by evaluating ``source``
      itself is used as the function.
  """
  def __init__(self, source, name=None):
    self.source = source
    self.name = name
    self._local = threading.local()

  def _get_func(self):
    """Returns the calling thread's function, compiling it on first use.

    Raises:
      ValueError: if the quickjs package could not be imported.
    """
    try:
      return self._local.func
    except AttributeError:
      pass  # First call on this thread: build the context below.

    if quickjs is None:
      raise ValueError("quickjs is not installed.")
    context = quickjs.Context()
    if self.name:
      context.eval(self.source)
      func = context.get(self.name)
    else:
      func = context.eval(self.source)
    # Hold a reference to the context so it stays alive as long as the
    # function object derived from it.
    self._local.context = context
    self._local.func = func
    return func

  def __call__(self, *args):
    return self._get_func()(*args)

  def __getstate__(self):
    # Deliberately excludes ``_local``: threading.local objects (and the
    # QuickJS objects stored on them) cannot be pickled.
    return {'source': self.source, 'name': self.name}

  def __setstate__(self, state):
    self.source = state['source']
    self.name = state['name']
    # Fresh thread-local storage; contexts are re-created on demand.
    self._local = threading.local()


# TODO(yaml) Improve type inferencing for JS UDF's
Expand All @@ -210,78 +234,77 @@ def py_value_to_js_dict(py_value):
def _expand_javascript_mapping_func(
    original_fields, expression=None, callable=None, path=None, name=None):
  """Builds a Python function that applies a JavaScript UDF to a row.

  The first of ``expression``, ``callable`` and ``path`` that is set selects
  how the JavaScript is obtained. In every case the row is handed to QuickJS
  as a JSON string and parsed back into an object on the JavaScript side.

  Args:
    original_fields: Field names of the input rows. Used only with
      ``expression``, where each name that is a valid identifier and occurs
      in the expression text is bound as a JavaScript constant.
    expression: JavaScript expression evaluated against the row's fields.
    callable: JavaScript source of a function that is called with the row.
    path: Path of a ``.js`` file, read through ``FileSystems``.
    name: Name of the function inside the file at ``path`` to call.

  Returns:
    A function taking a row and returning the JavaScript result passed
    through ``dicts_to_rows``. It raises RuntimeError if evaluation fails.

  Raises:
    ValueError: if quickjs is not installed, or ``path`` does not end in
      ``.js``.
  """
  # Check for installed quickjs package
  if quickjs is None:
    raise ValueError(
        "Javascript mapping functions require the 'quickjs' package.")

  def unused_name(base, taken):
    # Appends underscores until the identifier no longer clashes.
    while base in taken:
      base += "_"
    return base

  def make_bridge_source(func_name, call_expr, row_var="row", bindings=()):
    # The bridge function facilitates data transfer from Python to QuickJS by
    # parsing a JSON string representing the row object. A Date result is
    # tagged so that the Python side can rebuild a datetime from it.
    #
    # The bridge's own identifiers must not collide with the row fields bound
    # as constants, otherwise the generated `const` declarations are a
    # JavaScript SyntaxError (e.g. for a field called "result").
    taken = {row_var, *bindings}
    json_id = unused_name("row_json", taken)
    taken.add(json_id)
    out_id = unused_name("result", taken)
    return "\n".join([
        f"function {func_name}({json_id}) {{",
        f"  const {row_var} = JSON.parse({json_id});",
        *[f"  const {field} = {row_var}.{field};" for field in bindings],
        f"  const {out_id} = {call_expr};",
        f"  if ({out_id} instanceof Date)",
        f"    return {{__type__: 'date', value: {out_id}.toISOString()}};",
        f"  return {out_id};",
        "}",
    ])

  def parse_js_date(iso_text):
    # Date.prototype.toISOString() ends in "Z", which
    # datetime.fromisoformat only accepts from Python 3.11 onwards; an
    # explicit UTC offset parses on older versions too.
    if iso_text.endswith('Z'):
      iso_text = iso_text[:-1] + '+00:00'
    return datetime.datetime.fromisoformat(iso_text)

  # User code is placed on its own lines so that a trailing `//` comment in
  # it cannot swallow the closing parenthesis of the generated wrapper.
  if expression:
    args = [
        field for field in original_fields
        if field.isidentifier() and field in expression
    ]
    source = make_bridge_source(
        "fn",
        f"(\n{expression}\n)",
        row_var=unused_name("row", args),
        bindings=args)
    js_func = _QuickJsCallable(source, "fn")

  elif callable:
    source = make_bridge_source("fn", f"(\n{callable}\n)(row)")
    js_func = _QuickJsCallable(source, "fn")

  else:
    if not path.endswith('.js'):
      raise ValueError(f'File "{path}" is not a valid .js file.')
    # Close the handle once the UDF source has been read.
    with FileSystems.open(path) as js_file:
      udf_code = js_file.read().decode()
    bridge_source = udf_code + "\n" + make_bridge_source(
        "bridge_fn", f"{name}(row)")
    js_func = _QuickJsCallable(bridge_source, "bridge_fn")

  def js_wrapper(row):
    # Serialize the entire row to JSON to pass to QuickJS.
    row_json = json.dumps(py_value_to_js_dict(row))

    try:
      js_result = js_func(row_json)
    except Exception as exn:
      raise RuntimeError(
          f"Error evaluating javascript expression: {exn}") from exn

    if isinstance(js_result, quickjs.Object):
      # Use native json() method to transfer complex types from JS to Python.
      obj = json.loads(js_result.json())
      # Handle special tagged types like Date
      if isinstance(obj, dict) and obj.get('__type__') == 'date':
        js_result = parse_js_date(obj['value'])
      else:
        js_result = obj

    return dicts_to_rows(js_result)

  return js_wrapper

Expand Down
64 changes: 57 additions & 7 deletions sdks/python/apache_beam/yaml/yaml_udf_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@
from apache_beam.yaml.yaml_transform import YamlTransform

try:
import js2py
import quickjs
except ImportError:
js2py = None
logging.warning('js2py is not installed; some tests will be skipped.')
quickjs = None
logging.warning('quickjs is not installed; some tests will be skipped.')


def as_rows():
Expand Down Expand Up @@ -63,7 +63,7 @@ def setUp(self):
def tearDown(self):
# Recursively delete the directory at self.tmpdir (presumably the scratch
# directory created in setUp, which is not visible here -- confirm).
shutil.rmtree(self.tmpdir)

@unittest.skipIf(js2py is None, 'js2py not installed.')
@unittest.skipIf(quickjs is None, 'quickjs not installed.')
def test_map_to_fields_filter_inline_js(self):
with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
pickle_library='cloudpickle', yaml_experimental_features=['javascript'
Expand Down Expand Up @@ -109,6 +109,56 @@ def test_map_to_fields_filter_inline_js(self):
row=beam.Row(rank=2, values=[7, 8, 9, 12])),
]))

# Checks that a JavaScript callable returning a Date yields a timezone-aware
# UTC datetime.datetime field in the output row. Skipped without quickjs.
@unittest.skipIf(quickjs is None, 'quickjs not installed.')
def test_map_to_fields_date_js(self):
import datetime
# JavaScript support is enabled through the 'javascript' experimental
# feature flag in the pipeline options.
with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
pickle_library='cloudpickle', yaml_experimental_features=['javascript'
])) as p:
# One input row; the callable below ignores its argument.
elements = p | beam.Create([beam.Row(val=1)])
result = elements | YamlTransform(
'''
type: MapToFields
config:
language: javascript
fields:
date:
callable: |
function get_date(x) {
return new Date("2026-04-16T12:00:00.000Z")
}
''')
# The fixed JS Date must arrive as the same instant, tagged with UTC.
assert_that(
result,
equal_to([
beam.Row(
date=datetime.datetime(
2026, 4, 16, 12, 0, 0, tzinfo=datetime.timezone.utc)),
]))

# Checks that JavaScript callables returning an array and an object are
# converted to a Python list and a nested beam.Row. Skipped without quickjs.
@unittest.skipIf(quickjs is None, 'quickjs not installed.')
def test_map_to_fields_new_complex_types_js(self):
# JavaScript support is enabled through the 'javascript' experimental
# feature flag in the pipeline options.
with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
pickle_library='cloudpickle', yaml_experimental_features=['javascript'
])) as p:
# One input row; both callables below ignore their argument.
elements = p | beam.Create([beam.Row(val=1)])
result = elements | YamlTransform(
'''
type: MapToFields
config:
language: javascript
fields:
arr:
callable: "function(x) { return [1, 2, 3]; }"
obj:
callable: "function(x) { return {a: 1, b: 'two'}; }"
''')
# The JS array maps to a list; the JS object maps to a nested Row.
assert_that(
result,
equal_to([
beam.Row(arr=[1, 2, 3], obj=beam.Row(a=1, b='two')),
]))

def test_map_to_fields_filter_inline_py(self):
with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
pickle_library='cloudpickle')) as p:
Expand Down Expand Up @@ -197,7 +247,7 @@ def test_map_to_fields_sql_reserved_keyword_append():
beam.Row(label='389a', timestamp=2, label_copy="389a"),
]))

@unittest.skipIf(js2py is None, 'js2py not installed.')
@unittest.skipIf(quickjs is None, 'quickjs not installed.')
def test_filter_inline_js(self):
with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
pickle_library='cloudpickle', yaml_experimental_features=['javascript'
Expand Down Expand Up @@ -252,7 +302,7 @@ def test_filter_inline_py(self):
row=beam.Row(rank=2, values=[7, 8, 9])),
]))

@unittest.skipIf(js2py is None, 'js2py not installed.')
@unittest.skipIf(quickjs is None, 'quickjs not installed.')
def test_filter_expression_js(self):
with beam.Pipeline(options=beam.options.pipeline_options.PipelineOptions(
pickle_library='cloudpickle', yaml_experimental_features=['javascript'
Expand Down Expand Up @@ -296,7 +346,7 @@ def test_filter_expression_py(self):
row=beam.Row(rank=0, values=[1, 2, 3])),
]))

@unittest.skipIf(js2py is None, 'js2py not installed.')
@unittest.skipIf(quickjs is None, 'quickjs not installed.')
def test_filter_inline_js_file(self):
data = '''
function f(x) {
Expand Down
3 changes: 2 additions & 1 deletion sdks/python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,7 +616,8 @@ def get_portability_package_data():
'jinja2>=3.0,<3.2',
'virtualenv-clone>=0.5,<1.0',
# quickjs replaces js2py, which is not supported on Python 3.12 or later:
# https://github.com/PiotrDabkowski/Js2Py/issues/317
'js2py>=0.74,<1; python_version<"3.12"',
'quickjs; '
'python_version < "3.13" or platform_system != "Windows"',
'jsonschema>=4.0.0,<5.0.0',
] + dataframe_dependency,
# Keep the following dependencies in line with what we test against
Expand Down
Loading