Source code for airflow.hooks.dbapi_hook

# -*- coding: utf-8 -*-
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from builtins import str
from past.builtins import basestring
from datetime import datetime
import numpy
import logging
import sys

from sqlalchemy import create_engine

from airflow.hooks.base_hook import BaseHook
from airflow.exceptions import AirflowException


[docs]class DbApiHook(BaseHook): """ Abstract base class for sql hooks. """ # Override to provide the connection name. conn_name_attr = None # Override to have a default connection id for a particular dbHook default_conn_name = 'default_conn_id' # Override if this db supports autocommit. supports_autocommit = False # Override with the object that exposes the connect method connector = None def __init__(self, *args, **kwargs): if not self.conn_name_attr: raise AirflowException("conn_name_attr is not defined") elif len(args) == 1: setattr(self, self.conn_name_attr, args[0]) elif self.conn_name_attr not in kwargs: setattr(self, self.conn_name_attr, self.default_conn_name) else: setattr(self, self.conn_name_attr, kwargs[self.conn_name_attr])
[docs] def get_conn(self): """Returns a connection object """ db = self.get_connection(getattr(self, self.conn_name_attr)) return self.connector.connect( host=db.host, port=db.port, username=db.login, schema=db.schema)
def get_uri(self): conn = self.get_connection(getattr(self, self.conn_name_attr)) login = '' if conn.login: login = '{conn.login}:{conn.password}@'.format(conn=conn) host = conn.host if conn.port is not None: host += ':{port}'.format(port=conn.port) return '{conn.conn_type}://{login}{host}/{conn.schema}'.format( conn=conn, login=login, host=host) def get_sqlalchemy_engine(self, engine_kwargs=None): if engine_kwargs is None: engine_kwargs = {} return create_engine(self.get_uri(), **engine_kwargs)
[docs] def get_pandas_df(self, sql, parameters=None): """ Executes the sql and returns a pandas dataframe :param sql: the sql statement to be executed (str) or a list of sql statements to execute :type sql: str or list :param parameters: The parameters to render the SQL query with. :type parameters: mapping or iterable """ if sys.version_info[0] < 3: sql = sql.encode('utf-8') import pandas.io.sql as psql conn = self.get_conn() df = psql.read_sql(sql, con=conn, params=parameters) conn.close() return df
[docs] def get_records(self, sql, parameters=None): """ Executes the sql and returns a set of records. :param sql: the sql statement to be executed (str) or a list of sql statements to execute :type sql: str or list :param parameters: The parameters to render the SQL query with. :type parameters: mapping or iterable """ if sys.version_info[0] < 3: sql = sql.encode('utf-8') conn = self.get_conn() cur = self.get_cursor() if parameters is not None: cur.execute(sql, parameters) else: cur.execute(sql) rows = cur.fetchall() cur.close() conn.close() return rows
[docs] def get_first(self, sql, parameters=None): """ Executes the sql and returns the first resulting row. :param sql: the sql statement to be executed (str) or a list of sql statements to execute :type sql: str or list :param parameters: The parameters to render the SQL query with. :type parameters: mapping or iterable """ if sys.version_info[0] < 3: sql = sql.encode('utf-8') conn = self.get_conn() cur = conn.cursor() if parameters is not None: cur.execute(sql, parameters) else: cur.execute(sql) rows = cur.fetchone() cur.close() conn.close() return rows
[docs] def run(self, sql, autocommit=False, parameters=None): """ Runs a command or a list of commands. Pass a list of sql statements to the sql parameter to get them to execute sequentially :param sql: the sql statement to be executed (str) or a list of sql statements to execute :type sql: str or list :param autocommit: What to set the connection's autocommit setting to before executing the query. :type autocommit: bool :param parameters: The parameters to render the SQL query with. :type parameters: mapping or iterable """ conn = self.get_conn() if isinstance(sql, basestring): sql = [sql] if self.supports_autocommit: self.set_autocommit(conn, autocommit) cur = conn.cursor() for s in sql: if sys.version_info[0] < 3: s = s.encode('utf-8') logging.info(s) if parameters is not None: cur.execute(s, parameters) else: cur.execute(s) cur.close() conn.commit() conn.close()
def set_autocommit(self, conn, autocommit): conn.autocommit = autocommit
[docs] def get_cursor(self): """ Returns a cursor """ return self.get_conn().cursor()
[docs] def insert_rows(self, table, rows, target_fields=None, commit_every=1000): """ A generic way to insert a set of tuples into a table, the whole set of inserts is treated as one transaction :param table: Name of the target table :type table: str :param rows: The rows to insert into the table :type rows: iterable of tuples :param target_fields: The names of the columns to fill in the table :type target_fields: iterable of strings :param commit_every: The maximum number of rows to insert in one transaction. Set to 0 to insert all rows in one transaction. :type commit_every: int """ if target_fields: target_fields = ", ".join(target_fields) target_fields = "({})".format(target_fields) else: target_fields = '' conn = self.get_conn() if self.supports_autocommit: self.set_autocommit(conn, False) conn.commit() cur = conn.cursor() i = 0 for row in rows: i += 1 l = [] for cell in row: l.append(self._serialize_cell(cell, conn)) values = tuple(l) sql = "INSERT INTO {0} {1} VALUES ({2});".format( table, target_fields, ",".join(values)) cur.execute(sql) if commit_every and i % commit_every == 0: conn.commit() logging.info( "Loaded {i} into {table} rows so far".format(**locals())) conn.commit() cur.close() conn.close() logging.info( "Done loading. Loaded a total of {i} rows".format(**locals()))
@staticmethod def _serialize_cell(cell, conn=None): """ Returns the SQL literal of the cell as a string. :param cell: The cell to insert into the table :type cell: object :param conn: The database connection :type conn: connection object :return: The serialized cell :rtype: str """ if isinstance(cell, basestring): return "'" + str(cell).replace("'", "''") + "'" elif cell is None: return 'NULL' elif isinstance(cell, numpy.datetime64): return "'" + str(cell) + "'" elif isinstance(cell, datetime): return "'" + cell.isoformat() + "'" else: return str(cell)
[docs] def bulk_dump(self, table, tmp_file): """ Dumps a database table into a tab-delimited file :param table: The name of the source table :type table: str :param tmp_file: The path of the target file :type tmp_file: str """ raise NotImplementedError()
[docs] def bulk_load(self, table, tmp_file): """ Loads a tab-delimited file into a database table :param table: The name of the target table :type table: str :param tmp_file: The path of the file to load into the table :type tmp_file: str """ raise NotImplementedError()