#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import string
from typing import Any, Dict, Optional, Union, List, Sequence, Mapping, Tuple
import uuid
import warnings

import pandas as pd

from pyspark.pandas.internal import InternalFrame
from pyspark.pandas.namespace import _get_index_map
from pyspark.sql.functions import lit
from pyspark import pandas as ps
from pyspark.sql import SparkSession
from pyspark.pandas.utils import default_session
from pyspark.pandas.frame import DataFrame
from pyspark.pandas.series import Series

__all__ = ["sql"]


# This is not used in this file. It's for legacy sql_processor.
_CAPTURE_SCOPES = 3


def sql(
    query: str,
    index_col: Optional[Union[str, List[str]]] = None,
    args: Dict[str, str] = {},
    **kwargs: Any,
) -> DataFrame:
"""
Execute a SQL query and return the result as a pandas-on-Spark DataFrame.
This function acts as a standard Python string formatter with understanding
the following variable types:
* pandas-on-Spark DataFrame
* pandas-on-Spark Series
* pandas DataFrame
* pandas Series
* string
Also the method can bind named parameters to SQL literals from `args`.
Parameters
----------
    query : str
        the SQL query
    index_col : str or list of str, optional
        Column names to be used in Spark to represent pandas-on-Spark's index. The index name
        in pandas-on-Spark is ignored. By default, the index is always lost.

        .. note:: If you want to preserve the index, explicitly use :func:`DataFrame.reset_index`,
            and pass it to the SQL statement with the `index_col` parameter.

            For example,
            >>> psdf = ps.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}, index=['a', 'b', 'c'])
            >>> new_psdf = psdf.reset_index()
            >>> ps.sql("SELECT * FROM {new_psdf}", index_col="index", new_psdf=new_psdf)
            ... # doctest: +NORMALIZE_WHITESPACE
                   A  B
            index
            a      1  4
            b      2  5
            c      3  6

            For MultiIndex,

            >>> psdf = ps.DataFrame(
            ...     {"A": [1, 2, 3], "B": [4, 5, 6]},
            ...     index=pd.MultiIndex.from_tuples(
            ...         [("a", "b"), ("c", "d"), ("e", "f")], names=["index1", "index2"]
            ...     ),
            ... )
            >>> new_psdf = psdf.reset_index()
            >>> ps.sql(
            ...     "SELECT * FROM {new_psdf}", index_col=["index1", "index2"], new_psdf=new_psdf)
            ... # doctest: +NORMALIZE_WHITESPACE
                           A  B
            index1 index2
            a      b       1  4
            c      d       2  5
            e      f       3  6

            Also note that the index name(s) should match the existing name(s).
    args : dict
        A dictionary of parameter names to string values that are parsed as SQL literal
        expressions. For example, dict keys: "rank", "name", "birthdate";
        dict values: "1", "'Steven'", "DATE'2023-03-21'". Fragments of the string values
        that belong to SQL comments are skipped while parsing.

        .. versionadded:: 3.4.0

    kwargs
        other variables that the user wants to set, which can be referenced in the query

    Returns
    -------
    pandas-on-Spark DataFrame

    Examples
    --------

    Calling a built-in SQL function.

    >>> ps.sql("SELECT * FROM range(10) where id > 7")
       id
    0   8
    1   9

    >>> ps.sql("SELECT * FROM range(10) WHERE id > {bound1} AND id < {bound2}", bound1=7, bound2=9)
       id
    0   8

    >>> mydf = ps.range(10)
    >>> x = tuple(range(4))
    >>> ps.sql("SELECT {ser} FROM {mydf} WHERE id IN {x}", ser=mydf.id, mydf=mydf, x=x)
       id
    0   0
    1   1
    2   2
    3   3

    Mixing pandas-on-Spark and pandas DataFrames in a join operation. Note that the index is
    dropped.

    >>> ps.sql('''
    ...   SELECT m1.a, m2.b
    ...   FROM {table1} m1 INNER JOIN {table2} m2
    ...   ON m1.key = m2.key
    ...   ORDER BY m1.a, m2.b''',
    ...   table1=ps.DataFrame({"a": [1, 2], "key": ["a", "b"]}),
    ...   table2=pd.DataFrame({"b": [3, 4, 5], "key": ["a", "b", "b"]}))
       a  b
    0  1  3
    1  2  4
    2  2  5

    Also, it is possible to query using Series.

    >>> psdf = ps.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]}, index=['a', 'b', 'c'])
    >>> ps.sql("SELECT {mydf.A} FROM {mydf}", mydf=psdf)
       A
    0  1
    1  2
    2  3

    Named parameters with the `:` prefix can also be substituted by SQL literals.

    >>> ps.sql("SELECT * FROM range(10) WHERE id > :bound1", args={"bound1": "7"})
       id
    0   8
    1   9
    """
    if os.environ.get("PYSPARK_PANDAS_SQL_LEGACY") == "1":
        from pyspark.pandas import sql_processor

        warnings.warn(
            "Deprecated in 3.3.0, and the legacy behavior "
            "will be removed in the future releases.",
            FutureWarning,
        )
        return sql_processor.sql(query, index_col=index_col, **kwargs)

    session = default_session()
    formatter = PandasSQLStringFormatter(session)
    try:
        sdf = session.sql(formatter.format(query, **kwargs), args)
    finally:
        formatter.clear()

    index_spark_columns, index_names = _get_index_map(sdf, index_col)

    return DataFrame(
        InternalFrame(
            spark_frame=sdf, index_spark_columns=index_spark_columns, index_names=index_names
        )
    )
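

# A minimal sketch of opting into the legacy path gated above (hedged: the query is
# illustrative, not part of this module's tests). Setting the environment variable
# routes ps.sql() through the deprecated sql_processor and emits a FutureWarning.
#
#     import os
#     os.environ["PYSPARK_PANDAS_SQL_LEGACY"] = "1"
#     ps.sql("SELECT * FROM range(3)")  # handled by sql_processor.sql
#     del os.environ["PYSPARK_PANDAS_SQL_LEGACY"]  # restore the formatter-based path
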

class PandasSQLStringFormatter(string.Formatter):
    """
    A standard ``string.Formatter`` in Python that can understand pandas-on-Spark instances
    in addition to basic Python objects. This object must be cleared after use for a single
    SQL query; it cannot be reused across multiple SQL queries without clearing.
    """

    def __init__(self, session: SparkSession) -> None:
        self._session: SparkSession = session
        self._temp_views: List[Tuple[DataFrame, str]] = []
        self._ref_sers: List[Tuple[Series, str]] = []

    def vformat(self, format_string: str, args: Sequence[Any], kwargs: Mapping[str, Any]) -> str:
        ret = super(PandasSQLStringFormatter, self).vformat(format_string, args, kwargs)

        for ref, n in self._ref_sers:
            if not any((ref is v for v in df._pssers.values()) for df, _ in self._temp_views):
                # If the referred DataFrame does not hold the given Series, raise an error.
                raise ValueError("The series in {%s} does not refer any dataframe specified." % n)
        return ret
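
    # For example (hedged; the variable names are hypothetical), formatting
    # "SELECT {ser} FROM {other_df}" where `ser` is a Series that does not belong
    # to `other_df` fails the check above and raises the ValueError.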

    def get_field(self, field_name: str, args: Sequence[Any], kwargs: Mapping[str, Any]) -> Any:
        obj, first = super(PandasSQLStringFormatter, self).get_field(field_name, args, kwargs)
        return self._convert_value(obj, field_name), first

    def _convert_value(self, val: Any, name: str) -> Optional[str]:
        """
        Converts the given value into a SQL string.
        """
        if isinstance(val, pd.Series):
            # Return the column name from the pandas Series directly.
            return ps.from_pandas(val).to_frame()._to_spark().columns[0]
        elif isinstance(val, Series):
            # Return the column name of the pandas-on-Spark Series iff its DataFrame was
            # referred. The check is done in `vformat` after all fields are parsed.
            self._ref_sers.append((val, name))
            return val.to_frame()._to_spark().columns[0]
        elif isinstance(val, (DataFrame, pd.DataFrame)):
            df_name = "_pandas_api_%s" % str(uuid.uuid4()).replace("-", "")
            if isinstance(val, pd.DataFrame):
                # Don't store a temp view for plain pandas instances,
                # because there is no way to know which pandas DataFrame
                # holds which Series.
                val = ps.from_pandas(val)
            else:
                for df, n in self._temp_views:
                    if df is val:
                        return n
                self._temp_views.append((val, df_name))
            val._to_spark().createOrReplaceTempView(df_name)
            return df_name
        elif isinstance(val, str):
            return lit(val)._jc.expr().sql()  # for escaped characters.
        else:
            return val
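
    # For illustration (hedged; the generated view name is an example form, not an
    # exact value), _convert_value roughly maps:
    #
    #     pd.Series          -> its Spark column name, e.g. "A"
    #     ps.Series          -> its column name, recorded in _ref_sers for vformat()
    #     ps/pd DataFrame    -> a generated temp-view name like "_pandas_api_<uuid>"
    #     str                -> an escaped SQL string literal
    #     anything else      -> returned unchanged for default formatting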

    def clear(self) -> None:
        for _, n in self._temp_views:
            self._session.catalog.dropTempView(n)
        self._temp_views = []
        self._ref_sers = []


def _test() -> None:
    import os
    import doctest
    import sys

    from pyspark.sql import SparkSession
    import pyspark.pandas.sql_formatter

    os.chdir(os.environ["SPARK_HOME"])

    globs = pyspark.pandas.sql_formatter.__dict__.copy()
    globs["ps"] = pyspark.pandas
    spark = (
        SparkSession.builder.master("local[4]")
        .appName("pyspark.pandas.sql_formatter tests")
        .getOrCreate()
    )
    (failure_count, test_count) = doctest.testmod(
        pyspark.pandas.sql_formatter,
        globs=globs,
        optionflags=doctest.ELLIPSIS | doctest.NORMALIZE_WHITESPACE,
    )
    spark.stop()
    if failure_count:
        sys.exit(-1)


if __name__ == "__main__":
    _test()