#!/usr/bin/env python2.5


import sys
import logging
import cx_Oracle


def log(msg):
    """Emit ``msg`` at INFO level through the module-level ``logging`` API."""
    logging.log(logging.INFO, msg)


def get_table_columns(conn, schema, table):
    """Return the column names of ``table`` ordered by ``column_id``.

    The names are read from ``all_tab_columns``.  When ``schema`` is None
    the lookup is restricted to rows whose owner is the connected user;
    otherwise ``schema`` is matched against the owner column.

    conn   -- open database connection (anything providing ``cursor()``)
    schema -- owner of the table, or None for the current user
    table  -- table name as stored in the data dictionary

    Returns a list of column-name strings (empty if nothing matches).
    """
    query = '''
            select column_name
            from all_tab_columns
            where table_name = :t
            and (owner = :o or (:o is null and owner = user))
            order by column_id
        '''
    binds = {'t': table, 'o': schema}

    cursor = conn.cursor()
    cursor.execute(query, binds)

    names = []
    for row in cursor:
        names.append(row[0])
    return names


def get_table_index_sql(conn, owner, table, index):
    """Build a ``create index`` statement that recreates ``index``.

    The uniqueness flag is read from ``all_indexes`` and the indexed
    columns (ordered by ``column_position``) from ``all_ind_columns``.
    When ``owner`` is None both lookups are restricted to the connected
    user.

    conn  -- open database connection
    owner -- owner of the table, or None for the current user
    table -- name of the indexed table
    index -- name of the index

    Returns the SQL text.  An AssertionError is raised unless exactly one
    row in ``all_indexes`` matches.
    """
    binds = {'table_name': table, 'table_owner': owner, 'index_name': index}
    cursor = conn.cursor()

    # Step 1: is the index unique?
    cursor.execute('''
            select uniqueness
            from all_indexes
            where
                table_name = :table_name
                and (table_owner = :table_owner or (:table_owner is null and table_owner = user))
                and index_name = :index_name
        ''', binds)
    index_rows = cursor.fetchall()
    assert len(index_rows) == 1
    # Anything other than 'UNIQUE' yields an empty keyword slot.
    unique_keyword = 'UNIQUE' if index_rows[0][0] == 'UNIQUE' else ''

    # Step 2: which columns does it cover, in order?
    cursor.execute('''
            select column_name
            from all_ind_columns
            where
                table_name = :table_name
                and (table_owner = :table_owner or (:table_owner is null and table_owner = user))
                and index_name = :index_name
            order by
                column_position
        ''', binds)
    column_names = []
    for row in cursor:
        column_names.append(row[0])

    return "create %s index %s on %s (%s)" % (
        unique_keyword, index, table, ', '.join(column_names))


def get_table_index_sqls(conn, owner, table):
    """Return one ``create index`` statement per index defined on ``table``.

    Index names are read from ``all_indexes``; when ``owner`` is None only
    tables owned by the connected user are considered.  Each statement is
    produced by ``get_table_index_sql``.

    conn  -- open database connection
    owner -- owner of the table, or None for the current user
    table -- name of the table

    Returns a list of SQL strings (empty when the table has no indexes).
    """
    cursor = conn.cursor()
    cursor.execute('''
            select index_name
            from all_indexes
            where
                table_name = :table_name
                and (table_owner = :table_owner or (:table_owner is null and table_owner = user))
        ''', {'table_name': table, 'table_owner': owner})

    # Collect all names first, then build the statement for each index.
    index_names = []
    for row in cursor:
        index_names.append(row[0])

    statements = []
    for index_name in index_names:
        statements.append(get_table_index_sql(conn, owner, table, index_name))
    return statements


def get_table_create_sqls(conn, owner, table):
    """Return the statements needed to recreate ``table`` and its indexes.

    Column definitions come from ``all_tab_columns`` (ordered by
    ``column_id``); when ``owner`` is None only the connected user's tables
    are considered.  Only NUMBER, VARCHAR2 and DATE columns are handled;
    any other data type raises an Exception.

    conn  -- open database connection
    owner -- owner of the table, or None for the current user
    table -- name of the table

    Returns a list whose first element is the ``create table`` statement,
    followed by the statements from ``get_table_index_sqls``.
    """
    # Types rendered simply as "<name> <TYPE> <nullability>".
    plain_templates = {
        'NUMBER': "%s NUMBER %s",
        'DATE': "%s DATE %s",
    }

    def render_column(row):
        # One dictionary row -> one column definition line.
        name, data_type, data_length, nullable = row
        null_clause = ''
        if nullable == 'N':
            null_clause = 'not null'
        if data_type == 'VARCHAR2':
            # VARCHAR2 is the only handled type that carries a length.
            return "%s VARCHAR(%d) %s" % (name, data_length, null_clause)
        if data_type in plain_templates:
            return plain_templates[data_type] % (name, null_clause)
        raise Exception("unknown data type: %s" % data_type)

    cursor = conn.cursor()
    cursor.execute('''
            select
                column_name, data_type, data_length,
                nullable
            from all_tab_columns
            where
                table_name = :table_name
                and (owner = :owner or (:owner is null and owner = user))
            order by column_id
        ''', {'owner': owner, 'table_name': table})

    column_defs = []
    for row in cursor.fetchall():
        column_defs.append(render_column(row))
    table_statement = "create table %s\n(%s)" % (table, ',\n'.join(column_defs))

    # Index statements are looked up only after the table DDL has been built.
    index_statements = get_table_index_sqls(conn, owner, table)
    return [table_statement] + index_statements


def table_exists(conn, schema, table):
    """Report whether ``table`` is listed in ``all_tables``.

    When ``schema`` is None only tables owned by the connected user are
    considered; otherwise ``schema`` is matched against the owner column.

    conn   -- open database connection
    schema -- owner of the table, or None for the current user
    table  -- table name; must already be upper case (checked by assert)

    Returns True for exactly one matching row and False for none.  More
    than one matching row raises an Exception.
    """
    assert table == table.upper()

    binds = {'owner': schema, 'table_name': table}
    cursor = conn.cursor()
    cursor.execute('''
            select 1 from all_tables
            where
                table_name = :table_name
                and (owner = :owner or (:owner is null and owner = user))
        ''', binds)

    match_count = len(cursor.fetchall())
    if match_count > 1:
        raise Exception('expected at most one row')
    return match_count == 1


def connect(connstr):
    """Open a cx_Oracle connection described by ``connstr``.

    Two forms are accepted:

    * ``user/password@host:port:sid`` -- split here; the DSN is built with
      ``cx_Oracle.makedsn(host, int(port), sid)``.
    * anything else -- handed to ``cx_Oracle.connect`` unchanged.

    The parsed fields are logged with the password masked, so credentials
    do not end up in the log output.

    Returns the connection object created by ``cx_Oracle.connect``.
    """
    import re
    # user/password@host:port:sid
    match = re.match(r'^([^/]+)/([^@]+)@([^:]+):([\d]+):(.+)$', connstr)
    if match is None:
        return cx_Oracle.connect(connstr)

    user, password, host, port, sid = match.groups()
    # Same shape as before (a list holding one tuple), minus the secret.
    log("conn fields: %s" % [(user, '***', host, port, sid)])
    dsn = cx_Oracle.makedsn(host, int(port), sid)
    return cx_Oracle.connect(user, password, dsn)



def _redact_connstr(connstr):
    """Return ``connstr`` with its password replaced by ``***`` for logging.

    Everything between the first ``/`` and the last ``@`` (or the end of
    the string when no ``@`` follows the slash) is treated as the password.
    Strings without a ``/`` are returned unchanged.
    """
    slash = connstr.find('/')
    if slash < 0:
        return connstr
    at = connstr.rfind('@')
    if at < slash:
        return connstr[:slash + 1] + '***'
    return connstr[:slash + 1] + '***' + connstr[at:]


def main(srcconn, dstconn, tablename):
    """Copy every row of ``tablename`` from one database to another.

    srcconn   -- connection string of the source database
    dstconn   -- connection string of the destination database
    tablename -- table to copy; must already be upper case (checked by
                 assert).  The same name is used on both sides.

    If the destination table does not exist it is created, together with
    its indexes, from the source definition.  Rows are then inserted one at
    a time and committed every 1000 rows, with a final commit at the end.
    LOB values are read fully into memory before being inserted.

    Raises an Exception when no columns are found for the table in the
    source user's schema.
    """
    assert tablename == tablename.upper()
    # The connection strings carry passwords; keep them out of the log.
    log("main(%s, %s, %s)" % (
        _redact_connstr(srcconn), _redact_connstr(dstconn), tablename))
    srcconn = connect(srcconn)
    srccurs = srcconn.cursor()
    log("source conn: %s" % srcconn)

    dstconn = connect(dstconn)
    dstcurs = dstconn.cursor()
    log("destination conn: %s" % dstconn)

    srctable = tablename
    dsttable = tablename

    columns = get_table_columns(srcconn, None, srctable)
    column_list = ',\n'.join(columns)
    log("column list: %s" % column_list)

    # Identifiers cannot be passed as bind variables, hence the string
    # formatting of the table name into the statements below.
    srccurs.execute("select count(*) from %s" % srctable)
    srccnt = srccurs.fetchone()[0]
    log("source table count: %d" % srccnt)

    if not columns:
        # Without this check the next statement would be a malformed
        # "select  from <table>".
        raise Exception(
            "no columns found for table %s in the source schema" % srctable)

    srccurs.execute("""
            select %(column_list)s
                from %(srctable)s
        """ % {'column_list': column_list,
                'srctable': srctable})

    # One bind variable per column: v0, v1, ...  Computed once, not per row.
    bind_names = ["v%d" % i for i in range(len(columns))]
    value_placeholders = [":" + name for name in bind_names]

    insertsql = """
            insert
            into %(dsttable)s (%(column_list)s)
            values (%(value_placeholders)s)
        """ % {
                'dsttable': dsttable,
                'column_list': column_list,
                'value_placeholders': ',\n'.join(value_placeholders),
            }

    if not table_exists(dstconn, None, dsttable):
        log("creating dest table")
        for sql in get_table_create_sqls(srcconn, None, srctable):
            log("execuing sql on dstconn: %s" % sql)
            dstcurs.execute(sql)

    log("insert sql is: %s" % insertsql)

    for n, row in enumerate(srccurs):
        if n % 1000 == 0:
            log("n: %d" % n)
            if n > 0:
                log("commit...")
                dstconn.commit()
                log("commit done.")

        insertmap = {}
        inserttypes = {}
        for name, value in zip(bind_names, row):
            if isinstance(value, cx_Oracle.LOB):
                # NOTE(review): every LOB is bound as BLOB, as before;
                # confirm whether CLOB columns need cx_Oracle.CLOB here.
                inserttypes[name] = cx_Oracle.BLOB
                value = value.read()
            insertmap[name] = value

        # Called for every row (even with no LOBs), exactly as before, so
        # bind types chosen for one row never carry over to the next.
        dstcurs.setinputsizes(**inserttypes)
        dstcurs.execute(insertsql, insertmap)

    log("commit...")
    dstconn.commit()
    log("commit done.")


if __name__ == '__main__':
    # usage: <script> SRC_CONNSTR DST_CONNSTR TABLE_NAME
    # Report missing arguments with a usage line instead of an IndexError.
    # Extra arguments are ignored.
    if len(sys.argv) < 4:
        sys.exit("usage: %s SRC_CONNSTR DST_CONNSTR TABLE_NAME" % sys.argv[0])
    logging.basicConfig(level=logging.DEBUG)
    main(sys.argv[1], sys.argv[2], sys.argv[3])
