Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sqlalchemy_bigquery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
FLOAT64,
INT64,
INTEGER,
JSON,
NUMERIC,
RECORD,
STRING,
Expand All @@ -58,6 +59,7 @@
"FLOAT64",
"INT64",
"INTEGER",
"JSON",
"NUMERIC",
"RECORD",
"STRING",
Expand Down
135 changes: 135 additions & 0 deletions sqlalchemy_bigquery/_json.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
from enum import auto, Enum
import sqlalchemy
from sqlalchemy.sql import sqltypes
import json


class JSON(sqltypes.JSON):
    """BigQuery ``JSON`` type.

    Values are bound as STRING query parameters and converted server-side via
    ``PARSE_JSON``; result values are deserialized with the dialect's JSON
    deserializer, falling back to ``json.loads``.
    """

    # Default JSON serializer/deserializer
    _json_deserializer = json.loads

    def bind_expression(self, bindvalue):
        # JSON query parameters are STRINGs; wrap the bind parameter in
        # PARSE_JSON so BigQuery receives a JSON value.
        return sqlalchemy.func.PARSE_JSON(bindvalue, type_=self)

    def literal_processor(self, dialect):
        """Render an inline literal: serialize the value with the bind
        processor, then quote it via ``repr`` (single-quoted string)."""
        super_proc = self.bind_processor(dialect)

        def process(value):
            value = super_proc(value)
            return repr(value)

        return process

    def get_col_spec(self):
        # DDL column type spec.
        return "JSON"

    def _compiler_dispatch(self, visitor, **kw):
        # NOTE(review): bypasses normal visitor dispatch and returns the
        # rendered type name directly — confirm this interacts correctly
        # with custom type compilers.
        # Handle struct_field parameter for STRUCT field types
        if kw.get("struct_field", False):
            return "JSON"
        # For DDL statements
        if "type_expression" in kw:
            return "JSON"
        # For DBAPI parameter binding, use STRING
        return "STRING"

    def result_processor(self, dialect, coltype):
        """Return a processor that deserializes JSON column values.

        ``None`` passes through; dict values are returned as-is (BigQuery
        may already have decoded them); anything else is deserialized.
        """
        json_deserializer = dialect._json_deserializer or self._json_deserializer

        def process(value):
            if value is None:
                return None
            # Handle case where BigQuery already returns a dictionary
            if isinstance(value, dict):
                return value
            return json_deserializer(value)

        return process

    class Comparator(sqltypes.JSON.Comparator):
        """Adds BigQuery JSON accessor converters (``as_*`` methods)."""

        def _generate_converter(self, name, lax):
            # lax=True selects the lenient LAX_-prefixed converter function.
            prefix = "LAX_" if lax else ""
            func_ = getattr(sqlalchemy.func, f"{prefix}{name}")
            return func_

        def as_boolean(self, lax=False):
            func_ = self._generate_converter("BOOL", lax)
            return func_(self.expr, type_=sqltypes.Boolean)

        def as_string(self, lax=False):
            func_ = self._generate_converter("STRING", lax)
            return func_(self.expr, type_=sqltypes.String)

        def as_integer(self, lax=False):
            func_ = self._generate_converter("INT64", lax)
            return func_(self.expr, type_=sqltypes.Integer)

        def as_float(self, lax=False):
            func_ = self._generate_converter("FLOAT64", lax)
            return func_(self.expr, type_=sqltypes.Float)

        def as_numeric(self, precision, scale, asdecimal=True):
            # No converter available in BigQuery
            raise NotImplementedError()

    comparator_factory = Comparator

    class JSONPathMode(Enum):
        # NOTE(review): assumed to be nested inside JSON — it is referenced
        # as JSON.JSONPathMode by JSONPathType; confirm original indentation.
        LAX = auto()
        LAX_RECURSIVE = auto()


# Patch SQLAlchemy's JSONStrIndexType so string-index bind parameters
# compile to the BigQuery STRING type.
def _json_str_index_compiler_dispatch(self, visitor, **kw):
    return "STRING"


sqltypes.JSON.JSONStrIndexType._compiler_dispatch = _json_str_index_compiler_dispatch


class JSONPathType(sqltypes.JSON.JSONPathType):
    """BigQuery JSONPath type: renders index tuples as JSONPath strings."""

    def _mode_prefix(self, mode):
        """Map a JSONPathMode to its BigQuery JSONPath mode keyword.

        Raises NotImplementedError for unknown modes.
        """
        if mode == JSON.JSONPathMode.LAX:
            return "lax"
        if mode == JSON.JSONPathMode.LAX_RECURSIVE:
            return "lax recursive"
        raise NotImplementedError(f"Unhandled JSONPathMode: {mode}")

    def _format_value(self, value):
        """Convert a path tuple into a JSONPath string.

        An optional leading JSONPathMode element selects a mode prefix.
        Integer elements become array subscripts (``[0]``); all other
        elements become quoted member access (``."name"``). An empty
        path renders as ``$``.
        """
        # Guard the empty path: value[0] would raise IndexError.
        if value and isinstance(value[0], JSON.JSONPathMode):
            mode_prefix = self._mode_prefix(value[0])
            value = value[1:]
        else:
            mode_prefix = ""

        path = "".join(
            "[%s]" % elem if isinstance(elem, int) else '."%s"' % elem
            for elem in value
        )
        return "%s$%s" % (mode_prefix + " " if mode_prefix else "", path)

    def bind_processor(self, dialect):
        """Format the path tuple, then apply the string bind processor."""
        super_proc = self.string_bind_processor(dialect)

        def process(value):
            value = self._format_value(value)
            if super_proc:
                value = super_proc(value)
            return value

        return process

    def literal_processor(self, dialect):
        """Format the path tuple, then apply the string literal processor."""
        super_proc = self.string_literal_processor(dialect)

        def process(value):
            value = self._format_value(value)
            if super_proc:
                value = super_proc(value)
            return value

        return process
149 changes: 143 additions & 6 deletions sqlalchemy_bigquery/_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,11 @@ def _get_subtype_col_spec(type_):

type_compiler = base.dialect.type_compiler(base.dialect())
_get_subtype_col_spec = type_compiler.process

# Pass struct_field=True for JSON types in STRUCT fields
if hasattr(type_, "__class__") and type_.__class__.__name__ == "JSON":
return type_compiler.process(type_, struct_field=True)

return _get_subtype_col_spec(type_)


Expand Down Expand Up @@ -77,14 +82,136 @@ def __repr__(self):
return f"STRUCT({fields})"

def get_col_spec(self, **kw):
    """Return the DDL spec for this STRUCT, e.g. ``STRUCT<a INT64, b JSON>``.

    The scraped diff left the superseded one-line implementation above the
    real body, making everything after its ``return`` unreachable; only the
    current implementation is kept here.
    """
    fields = []
    for name, type_ in self._STRUCT_fields:
        # JSON fields must render the literal "JSON" spec; the generic
        # subtype compiler would emit the bind-parameter type (STRING).
        if hasattr(type_, "__class__") and type_.__class__.__name__ == "JSON":
            fields.append(f"{name} JSON")
        else:
            fields.append(f"{name} {_get_subtype_col_spec(type_)}")

    return f"STRUCT<{', '.join(fields)}>"

def bind_processor(self, dialect):
    """Return a bind processor that JSON-serializes JSON-typed fields.

    For STRUCTs without JSON fields, returns ``dict`` (the historical
    behavior). Otherwise the returned processor walks the value, recursing
    into nested STRUCTs and serializing JSON field values to strings.

    Note: the scraped diff left the old one-line body (``return dict``)
    above the new implementation, making it unreachable; the early return
    is preserved only for the no-JSON-fields fast path below.
    """
    import json

    # Check if any field in the STRUCT is a JSON type.
    has_json_fields = any(
        type_.__class__.__name__ == "JSON" for _, type_ in self._STRUCT_fields
    )

    # If no JSON fields, return dict for backward compatibility.
    if not has_json_fields:
        return dict

    def process_value(value, struct_type):
        if value is None:
            return None

        result = {}
        for key, val in value.items():
            # Case-insensitive field lookup against the STRUCT schema.
            field_type = struct_type._STRUCT_byname.get(key.lower())

            if field_type is None:
                # Field not found in schema: pass through unchanged.
                result[key] = val
                continue

            type_name = field_type.__class__.__name__
            if type_name == "STRUCT" and isinstance(val, dict):
                # Process nested STRUCT recursively.
                result[key] = process_value(val, field_type)
            elif type_name == "JSON" and val is not None and not isinstance(val, str):
                # Serialize JSON field data; strings are assumed pre-encoded.
                result[key] = json.dumps(val)
            else:
                result[key] = val

        return result

    def process(value):
        if value is None:
            return None
        return process_value(value, self)

    return process

def result_processor(self, dialect, coltype):
    """Return a processor that deserializes JSON-typed fields of a STRUCT.

    When no field is a JSON type, returns None (no post-processing, the
    historical behavior). Otherwise the processor recursively walks nested
    STRUCTs and JSON-decodes string values of JSON fields, leaving values
    that fail to decode untouched.
    """
    import json

    # Fast path: without JSON fields there is nothing to post-process.
    if not any(
        field_type.__class__.__name__ == "JSON"
        for _, field_type in self._STRUCT_fields
    ):
        return None

    def decode(raw, schema):
        if raw is None:
            return None

        # Some paths hand the whole struct back as a JSON string.
        if isinstance(raw, str):
            try:
                raw = json.loads(raw)
            except (ValueError, TypeError):
                return raw

        if not isinstance(raw, dict):
            return raw

        out = {}
        for key, item in raw.items():
            # Schema lookup is case-insensitive.
            member_type = schema._STRUCT_byname.get(key.lower())
            if member_type is None:
                # Unknown field: pass through unchanged.
                out[key] = item
                continue

            kind = member_type.__class__.__name__
            if kind == "STRUCT":
                # Recurse only into dict values of nested STRUCTs.
                out[key] = decode(item, member_type) if isinstance(item, dict) else item
            elif kind == "JSON" and isinstance(item, str):
                try:
                    out[key] = json.loads(item)
                except (ValueError, TypeError):
                    out[key] = item  # keep as-is if not valid JSON
            else:
                out[key] = item

        return out

    def process(value):
        return None if value is None else decode(value, self)

    return process

class Comparator(sqlalchemy.sql.sqltypes.Indexable.Comparator):
def _setup_getitem(self, name):
Expand Down Expand Up @@ -137,10 +264,20 @@ def struct_getitem_op(a, b):
raise NotImplementedError()


def json_getitem_op(a, b):
    # This is a placeholder function that will be handled by the compiler
    # The actual implementation is in visit_json_getitem_op_binary
    return None


sqlalchemy.sql.default_comparator.operator_lookup[
    struct_getitem_op.__name__
] = sqlalchemy.sql.default_comparator.operator_lookup["json_getitem_op"]

# NOTE(review): json_getitem_op.__name__ is "json_getitem_op", so this
# assignment writes the existing operator_lookup entry back to itself — a
# no-op. Confirm whether a distinct lookup key (or the local placeholder
# function) was intended here.
sqlalchemy.sql.default_comparator.operator_lookup[
    json_getitem_op.__name__
] = sqlalchemy.sql.default_comparator.operator_lookup["json_getitem_op"]


class SQLCompiler:
def visit_struct_getitem_op_binary(self, binary, operator_, **kw):
Expand Down
3 changes: 3 additions & 0 deletions sqlalchemy_bigquery/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
except ImportError: # pragma: NO COVER
pass

from ._json import JSON
from ._struct import STRUCT

_type_map = {
Expand All @@ -41,6 +42,7 @@
"FLOAT": sqlalchemy.types.Float,
"INT64": sqlalchemy.types.Integer,
"INTEGER": sqlalchemy.types.Integer,
"JSON": JSON,
"NUMERIC": sqlalchemy.types.Numeric,
"RECORD": STRUCT,
"STRING": sqlalchemy.types.String,
Expand All @@ -61,6 +63,7 @@
FLOAT = _type_map["FLOAT"]
INT64 = _type_map["INT64"]
INTEGER = _type_map["INTEGER"]
JSON = _type_map["JSON"]
NUMERIC = _type_map["NUMERIC"]
RECORD = _type_map["RECORD"]
STRING = _type_map["STRING"]
Expand Down
Loading