39 changes: 39 additions & 0 deletions dataframe_test.py
@@ -0,0 +1,39 @@
from unittest import TestCase
from e6data_python_connector import Connection

import logging

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.DEBUG)

class TestDataFrame(TestCase):
def setUp(self) -> None:
self._host = "localhost"
self._catalog = "demogluecatalog"
self._database = "tpcds_1000_delta"
logger.debug('Trying to connect to engine')
self.e6x_connection = Connection(
host=self._host,
port=9001,
username='[email protected]',
password='Dummy@123',
database=self._database,
catalog=self._catalog
)
logger.debug('Successfully connected to engine.')

def disconnect(self):
self.e6x_connection.close()
self.assertFalse(self.e6x_connection.check_connection())

def tearDown(self) -> None:
self.disconnect()

def test_table_creation(self):
try:
self._dataframe = self.e6x_connection.load_parquet('<filepath>')
rows = self._dataframe.show()
for row in rows:
print(row)
        except Exception as e:
            self.fail(f'test_table_creation raised an exception: {e}')
4 changes: 2 additions & 2 deletions e6data_python_connector/__init__.py
@@ -1,3 +1,3 @@
from e6data_python_connector.e6data_grpc import Connection, Cursor
from e6data_python_connector.e6data_grpc import Connection, Cursor, DataFrame

__all__ = ['Connection', 'Cursor']
__all__ = ['Connection', 'Cursor', 'DataFrame']
269 changes: 269 additions & 0 deletions e6data_python_connector/e6data_grpc.py
@@ -15,6 +15,7 @@
from decimal import Decimal
from io import BytesIO
from ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED
from typing import overload

import grpc
from grpc._channel import _InactiveRpcError
@@ -24,6 +25,7 @@
from e6data_python_connector.constants import *
from e6data_python_connector.datainputstream import get_query_columns_info, read_rows_from_chunk
from e6data_python_connector.server import e6x_engine_pb2_grpc, e6x_engine_pb2
from e6data_python_connector.server.e6x_engine_pb2 import AggregateFunction
from e6data_python_connector.typeId import *

apilevel = '2.0'
@@ -186,6 +188,8 @@ def __init__(
self.cluster_uuid = cluster_uuid
self._session_id = None
self._host = host
# Engine IP, used to pin follow-up requests to the same engine (stickiness)
self._engine_ip = None
self._port = port

self._secure_channel = secure
@@ -206,6 +210,9 @@ def __init__(
self.grpc_auto_resume_timeout_seconds = self._grpc_options.pop('grpc_auto_resume_timeout_seconds')
self._create_client()

        # Initialize the session that tracks DataFrames created on this connection
        self._dataframe_session = DataFrameSession(self)

@property
def _get_grpc_options(self):
"""
@@ -324,6 +331,7 @@ def get_session_id(self):
metadata=_get_grpc_header(cluster=self.cluster_uuid)
)
self._session_id = authenticate_response.sessionId
self._engine_ip = authenticate_response.engineIP
else:
raise e
else:
@@ -357,6 +365,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
exc_val (BaseException): The exception instance raised (if any).
exc_tb (Traceback): The traceback object of the exception (if any).
"""
self._dataframe_session.terminate()
self.close()

def close(self):
@@ -365,6 +374,7 @@ def close(self):

This method ensures that the gRPC channel is properly closed and the session ID is reset to None.
"""
self._dataframe_session.terminate()
if self._channel is not None:
self._channel.close()
self._channel = None
@@ -533,6 +543,28 @@ def cursor(self, catalog_name=None, db_name=None):
"""
return Cursor(self, database=db_name, catalog_name=catalog_name)

def load_parquet(self, parquet_path) -> "DataFrame":
    """Create a DataFrame backed by a parquet file visible to the engine."""
    dataframe = DataFrame(
        self,
        file_path=parquet_path,
        dataframe_number=self._dataframe_session.get_dataframe_number,
        table_name=None,
        engine_ip=self._engine_ip
    )
    self._dataframe_session.update_dataframe_map(dataframe=dataframe)
    return dataframe

def load_table(self, table_name, database=None, catalog=None) -> "DataFrame":
    """Create a DataFrame backed by an existing table.

    Falls back to the connection's database and catalog when none are given.
    """
    dataframe = DataFrame(
        self,
        file_path=None,
        dataframe_number=self._dataframe_session.get_dataframe_number,
        table_name=table_name,
        engine_ip=self._engine_ip,
        database=database,
        catalog=catalog
    )
    self._dataframe_session.update_dataframe_map(dataframe=dataframe)
    return dataframe
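
A minimal usage sketch for these two entry points, assuming a reachable engine; host, port, credentials, and the table name below are illustrative placeholders, not values from this PR:

    from e6data_python_connector import Connection

    conn = Connection(host='localhost', port=9001, username='user@example.com',
                      password='secret', database='db', catalog='catalog')
    df_from_file = conn.load_parquet('/tmp/data.parquet')  # parquet-backed DataFrame
    df_from_table = conn.load_table('orders')              # table-backed DataFrame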

def rollback(self):
"""
Rolls back the current transaction.
@@ -552,6 +584,10 @@ def client(self):
"""
return self._client

@property
def host(self):
    """Returns the host this connection was created with."""
    return self._host


class Cursor(DBAPICursor):
"""
@@ -1032,6 +1068,239 @@ def explain_analyse(self):
)


class DataFrame:

def __init__(self, connection: Connection, file_path, dataframe_number, table_name, engine_ip, database=None, catalog=None):
self._dataframe_number = dataframe_number
self._connection = connection
self._catalog = self._connection.catalog_name if catalog is None else catalog
self._database = self._connection.database if database is None else database
self._table_name = table_name
self._file_path = file_path
self._sessionId = connection.get_session_id
self._engine_ip = engine_ip
self._is_metadata_updated = False
self._query_id = None
self._data = None
self._batch = None
self._create_dataframe(self._file_path is not None)

def _create_dataframe(self, create_dataframe_from_parquet: bool):
client = self._connection.client

create_dataframe_request = e6x_engine_pb2.CreateDataFrameRequest(
parquetFilePath=self._file_path,
catalog=self._catalog,
schema=self._database,
table=self._table_name,
sessionId=self._sessionId,
engineIP=self._engine_ip,
dataframeNumber=self._dataframe_number,
createFromParquet=create_dataframe_from_parquet
)

create_dataframe_response = client.createDataFrame(
create_dataframe_request
)
self._query_id = create_dataframe_response.queryId

def select(self, *fields: str) -> "DataFrame":
    projection_fields = list(fields)

client = self._connection.client
projection_on_dataframe_request = e6x_engine_pb2.ProjectionOnDataFrameRequest(
queryId=self._query_id,
dataframeNumber=self._dataframe_number,
sessionId=self._sessionId,
engineIP=self._engine_ip,
field=projection_fields
)

client.projectionOnDataFrame(projection_on_dataframe_request)

return self

def aggregate(self, agg_function: dict[str, str], group_by: list[str] | None = None) -> "DataFrame":
    def get_agg_enum(function_name: str) -> AggregateFunction | None:
        match function_name.lower():
            case 'sum':
                return e6x_engine_pb2.AggregateFunction.SUM
            case 'count':
                return e6x_engine_pb2.AggregateFunction.COUNT
            case 'count_star':
                return e6x_engine_pb2.AggregateFunction.COUNT_STAR
            case 'count_distinct':
                return e6x_engine_pb2.AggregateFunction.COUNT_DISTINCT
            case _:
                return None

    agg_function_map = {}
    # Columns mapped to an unrecognised aggregate function are skipped.
    for column, function_name in agg_function.items():
        fun = get_agg_enum(function_name)
        if fun is not None:
            agg_function_map[column] = fun

client = self._connection.client
aggregate_on_dataframe_request = e6x_engine_pb2.AggregateOnDataFrameRequest(
queryId=self._query_id,
dataframeNumber=self._dataframe_number,
sessionId=self._sessionId,
engineIP=self._engine_ip,
aggregateFunctionMap=agg_function_map,
groupBy=group_by
)

client.aggregateOnDataFrame(aggregate_on_dataframe_request)

return self
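
A hedged sketch of this aggregate API; the column and table names are illustrative, reusing the conn placeholder from the earlier sketch:

    # SUM(amount) and COUNT(DISTINCT customer_id), grouped by region.
    df = conn.load_table('orders').aggregate(
        {'amount': 'sum', 'customer_id': 'count_distinct'},
        group_by=['region']
    )
    # Only sum, count, count_star, and count_distinct are recognised;
    # other function names are dropped for that column.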

def where(self, where_clause: str) -> "DataFrame":
client = self._connection.client
filter_on_dataframe_request = e6x_engine_pb2.FilterOnDataFrameRequest(
queryId=self._query_id,
dataframeNumber=self._dataframe_number,
sessionId=self._sessionId,
engineIP=self._engine_ip,
whereClause=where_clause
)

client.filterOnDataFrame(filter_on_dataframe_request)

return self

def order_by(self, *field_list: str) -> "DataFrame":
    # The default sort direction is ASCENDING for every column.
    order_by_map = {column: e6x_engine_pb2.SortDirection.ASC for column in field_list}

client = self._connection.client

orderby_on_dataframe_request = e6x_engine_pb2.OrderByOnDataFrameRequest(
queryId=self._query_id,
dataframeNumber=self._dataframe_number,
sessionId=self._sessionId,
engineIP=self._engine_ip,
orderByFieldMap=order_by_map
)

client.orderByOnDataFrame(orderby_on_dataframe_request)
return self

def limit(self, fetch_limit: int) -> "DataFrame":
client = self._connection.client
limit_on_dataframe_request = e6x_engine_pb2.LimitOnDataFrameRequest(
queryId=self._query_id,
dataframeNumber=self._dataframe_number,
sessionId=self._sessionId,
engineIP=self._engine_ip,
fetchLimit=fetch_limit
)

client.limitOnDataFrame(limit_on_dataframe_request)

return self
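
Because each builder method applies its operation server-side and returns self, calls chain; a sketch under the same illustrative assumptions as above:

    rows = conn.load_table('orders') \
        .select('order_id', 'amount') \
        .where('amount > 100') \
        .order_by('amount') \
        .limit(10) \
        .show()
    for row in rows:
        print(row)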

def show(self):
    """Execute the accumulated operations and fetch all result rows."""
    self.execute()
    return self.fetchall()

def execute(self):
client = self._connection.client
execute_dataframe_request = e6x_engine_pb2.ExecuteDataFrameRequest(
queryId=self._query_id,
dataframeNumber=self._dataframe_number,
sessionId=self._sessionId,
engineIP=self._engine_ip,
)
client.executeDataFrame(execute_dataframe_request)

def _update_meta_data(self):
result_meta_data_request = e6x_engine_pb2.GetResultMetadataRequest(
engineIP=self._engine_ip,
sessionId=self._sessionId,
queryId=self._query_id
)
get_result_metadata_response = self._connection.client.getResultMetadata(
result_meta_data_request,
)
buffer = BytesIO(get_result_metadata_response.resultMetaData)
self._rowcount, self._query_columns_description = get_query_columns_info(buffer)
self._is_metadata_updated = True

def _fetch_batch(self):
client = self._connection.client
get_next_result_batch_request = e6x_engine_pb2.GetNextResultBatchRequest(
engineIP=self._engine_ip,
sessionId=self._sessionId,
queryId=self._query_id
)
get_next_result_batch_response = client.getNextResultBatch(
get_next_result_batch_request,
)
buffer = get_next_result_batch_response.resultBatch
if not self._is_metadata_updated:
self._update_meta_data()
if not buffer or len(buffer) == 0:
return None
# one batch retrieves the predefined set of rows
return read_rows_from_chunk(self._query_columns_description, buffer)

def fetchall(self):
    rows = []
    while True:
        batch = self._fetch_batch()
        if batch is None:
            break
        rows.extend(batch)
    return rows
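
fetchall() drains the result set one batch at a time; when results are large, the same loop can be consumed incrementally against the internal batch helper (a sketch, not a public contract of this PR; process() is a hypothetical callback):

    df = conn.load_table('orders').limit(1000)
    df.execute()
    while (batch := df._fetch_batch()) is not None:
        for row in batch:  # each batch holds a server-defined number of rows
            process(row)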

class DataFrameSession:
def __init__(self, connection: Connection, planner_ip=None):
    self._connection = connection
    self._dataframe_count = 0
    self._dataframe_map = dict()
    self._is_terminated = False
    self._session_id = connection.get_session_id
    # Connection creates this session without a planner IP; accessing
    # get_session_id above authenticates, so fall back to the engine IP
    # captured during authentication.
    self._planner_ip = planner_ip if planner_ip is not None else connection._engine_ip

def __exit__(self, exc_type, exc_val, exc_tb):
self.terminate()

def update_dataframe_map(self, dataframe: "DataFrame"):
    self._dataframe_map[self._dataframe_count] = dataframe
    self._dataframe_count += 1

@property
def get_dataframe_number(self) -> int:
return self._dataframe_count

@property
def is_terminated(self) -> bool:
return self._is_terminated

@property
def planner_ip(self):
return self._planner_ip

def terminate(self):
    """Drop the server-side user context; safe to call more than once."""
    if not self._is_terminated:
drop_user_context_request = e6x_engine_pb2.DropUserContextRequest(
sessionId=self._session_id,
engineIP=self._planner_ip
)

self._connection.client.dropUserContext(drop_user_context_request)
self._is_terminated = True
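
Both Connection.close() and Connection.__exit__() call terminate(), so the DataFrame session is cleaned up automatically when the connection is used as a context manager; a sketch with the same placeholder parameters as earlier:

    with Connection(host='localhost', port=9001, username='user@example.com',
                    password='secret', database='db', catalog='catalog') as conn:
        print(conn.load_table('orders').limit(5).show())
    # On exit, the session's user context is dropped exactly once.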



def poll(self, get_progress_update=True):
"""Poll for and return the raw status data provided by the Hive Thrift REST API.
:returns: ``ttypes.TGetOperationStatusResp``