# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA
# 02110-1301  USA

from db_sql92_re_grt import Sql92ReverseEngineering
from wb import DefineModule
from workbench.utils import find_object_with_name
import grt

# Module descriptor for this plugin. The @ModuleInfo.export(...) decorators
# used further down in this file attach the module-level wrapper functions
# to it, together with their declared return/argument types.
ModuleInfo = DefineModule(name= "DbPostgresqlRE", author= "Oracle Corp.", version="1.0")

class PostgresqlReverseEngineering(Sql92ReverseEngineering):

    @classmethod
    def getTargetDBMSName(cls):
        """Return the name identifying the DBMS handled by this class."""
        dbms_name = 'Postgresql'
        return dbms_name

    @classmethod
    def serverVersion(cls, connection):
        """Return the version object stored for *connection*.

        The value is read from the per-connection record kept in
        ``cls._connections`` (keyed by ``connection.__id__``); ``connect()``
        in this class is what stores it there.
        """
        connection_record = cls._connections[connection.__id__]
        return connection_record["version"]

    @classmethod
    def connect(cls, connection, password):
        """Connect through the base class and record the server version.

        After a successful base-class connect, ``select version()`` is run
        and its result is parsed into a ``GrtVersion`` that is stored in the
        per-connection record under the ``"version"`` key.

        Parameters:
            connection: the connection object to open.
            password:   the password handed to the base-class ``connect``.

        Returns:
            Whatever the base-class ``connect`` returned.

        Raises:
            RuntimeError: if the detected major version is lower than 8
                (this includes the case where no version number could be
                parsed from the server's answer).
        """
        import re  # only needed for the version parsing below

        r = super(PostgresqlReverseEngineering, cls).connect(connection, password)
        if r:
            ver = cls.execute_query(connection, "select version()").fetchone()[0]
            grt.log_info("PostgreSQL RE", "Connected to %s, %s" % (connection.name, ver))

            # The answer looks like "PostgreSQL 9.1.5 on x86_64-..., compiled by ...".
            # The second token carries the version; fall back to the whole
            # string if there is no second token.
            tokens = ver.split()
            version_token = tokens[1] if len(tokens) > 1 else ver

            # Take the leading digits of each dot-separated component and stop
            # at the first component that is not purely numeric, so that
            # tokens such as "9.5beta1" or "11devel" yield 9.5 / 11 instead
            # of raising ValueError.
            ver_parts = []
            for component in version_token.rstrip(",").split("."):
                digits = re.match(r"\d+", component)
                if not digits:
                    break
                ver_parts.append(int(digits.group()))
                if digits.end() != len(component):
                    break
            ver_parts += 4 * [0]  # pad so that four numbers are always available

            version = grt.classes.GrtVersion()
            version.majorNumber, version.minorNumber, version.releaseNumber, version.buildNumber = ver_parts[:4]
            cls._connections[connection.__id__]["version"] = version
            if version.majorNumber < 8:
                raise RuntimeError("PostgreSQL version %s is not a supported migration source.\nAt least version 8 is required." % ver)
        return r

    @classmethod
    def getSchemaNames(cls, connection, catalog_name):
        """Return the base-class schema list without the system schemata.

        Names equal to ``information_schema`` or ``pg_catalog`` (compared
        case-insensitively) are dropped; every other name is kept in the
        order the base class returned it.
        """
        excluded = ['INFORMATION_SCHEMA', 'PG_CATALOG']
        every_schema = super(PostgresqlReverseEngineering, cls).getSchemaNames(connection, catalog_name)
        visible = []
        for name in every_schema:
            if name.upper() in excluded:
                continue
            visible.append(name)
        return visible

    ######### Reverse Engineering functions #########

    @classmethod
    def reverseEngineerSequences(cls, connection, schema):
        """Rebuild ``schema.sequences`` from the sequences found on the server.

        The existing list is emptied first. One ``db_Sequence`` is appended
        per sequence found in the schema named ``schema.name``.

        Parameters:
            connection: the connection used to run the catalog queries.
            schema:     the schema object whose ``sequences`` list is filled.

        Notes:
            * On servers reporting major version 10 or later the data is
              read from the ``pg_catalog.pg_sequences`` view, because the
              per-sequence relation no longer exposes min/max/increment
              columns there.
            * On older servers each sequence relation is queried directly;
              the schema and sequence names are double-quoted so that
              mixed-case or otherwise special names resolve correctly.
        """
        schema.sequences.remove_all()

        def quote_ident(name):
            # Double-quoted identifier; embedded double quotes are doubled.
            return '"%s"' % name.replace('"', '""')

        # Escape single quotes so the name is safe inside a '...' literal.
        schema_literal = schema.name.replace("'", "''")

        if cls.serverVersion(connection).majorNumber >= 10:
            sequences_query = """SELECT sequencename, min_value, max_value, start_value, increment_by,
                                        last_value, cycle, cache_size
                                 FROM pg_catalog.pg_sequences
                                 WHERE schemaname = '%s'""" % schema_literal
            sequence_rows = [tuple(row) for row in cls.execute_query(connection, sequences_query).fetchall()]
        else:
            seq_names_query = """SELECT c.relname
                                 FROM pg_catalog.pg_class c
                                 JOIN pg_catalog.pg_namespace n ON (c.relnamespace = n.oid)
                                 WHERE n.nspname = '%s' AND c.relkind in ('S', 's')""" % schema_literal

            seq_details_query = """SELECT min_value, max_value, start_value, increment_by, last_value, is_cycled, cache_value
                                   FROM %s.%s"""

            sequence_rows = []
            for (seq_name, ) in cls.execute_query(connection, seq_names_query).fetchall():
                details_sql = seq_details_query % (quote_ident(schema.name), quote_ident(seq_name))
                details = cls.execute_query(connection, details_sql).fetchone()
                sequence_rows.append((seq_name, ) + tuple(details))

        for seq_name, min_value, max_value, start_value, increment_by, last_value, is_cycled, ncache in sequence_rows:
            if last_value is None:
                # pg_sequences reports NULL for a sequence that was never
                # read from; its next value is then the start value.
                last_value = start_value

            sequence = grt.classes.db_Sequence()
            sequence.name = seq_name
            sequence.owner = schema
            sequence.minValue = str(min_value)
            sequence.maxValue = str(max_value)
            sequence.startValue = str(start_value)
            sequence.incrementBy = str(increment_by)
            sequence.lastNumber = str(last_value)
            sequence.cycleFlag = int(is_cycled)
            sequence.cacheSize = str(ncache)
            schema.sequences.append(sequence)

    @classmethod
    def reverseEngineerTableIndices(cls, connection, table):
        schema = table.owner
        catalog = schema.owner

        if len(table.columns) == 0:
            grt.send_error('%s: reverseEngineerTableIndices', 'Reverse engineer of table %s.%s was attempted but the table has no columns attribute' % (cls.getTargetDBMSName(), schema.name, table.name) )
            return 1  # Table must have columns reverse engineered before we can rev eng its indices

        all_indices_query = """SELECT c2.relname, i.indisunique::int, i.indisclustered::int, i.indnatts, i.indkey
                               FROM pg_catalog.pg_class c, pg_catalog.pg_class c2, pg_catalog.pg_namespace n, pg_catalog.pg_index i
                               WHERE c.oid = i.indrelid AND i.indexrelid = c2.oid AND c.relnamespace = n.oid
                                     AND n.nspname = '%s' AND c.relname = '%s' AND i.indisprimary = False
                               ORDER BY c2.relname""" % (schema.name, table.name)

        index_columns_query = """SELECT a.attname
                                 FROM unnest(ARRAY%r) attrid
                                 JOIN pg_catalog.pg_attribute a ON attrid=a.attnum
                                 JOIN pg_catalog.pg_class c ON c.oid = a.attrelid
                                 JOIN pg_catalog.pg_namespace n ON c.relnamespace = n.oid
                                 WHERE n.nspname = '%s' AND c.relname = '%s'"""

        index_rows = cls.execute_query(connection, all_indices_query).fetchall()
        for index_name, is_unique, is_clustered, column_count, column_refs in index_rows:
            index = grt.classes.db_Index()
            index.name = index_name
            index.isPrimary = 0
            index.unique = is_unique
            index.indexType = ('UNIQUE' if is_unique else 'INDEX')
            #index.clustered = is_clustered

            # Get the columns for the index:
            cols = [ int(col) for col in column_refs.split() ]
            if column_count != len(cols):
                grt.send_warning('%s: reverseEngineerTableIndices' % cls.getTargetDBMSName(),
                                 'Reverse engineer of index %s.%s was attempted but the referenced columns count differs '
                                 'from the number of its referenced columns. Skipping index!' % (schema.name, index_name) ) continue for (column_name, ) in cls.execute_query(connection, index_columns_query % (cols, schema.name, table.name)): column = find_object_with_name(table.columns, column_name) if column: index_column = grt.classes.db_IndexColumn() index_column.name = index_name + '.' + column_name #index_column.descend = is_descending_key index_column.referencedColumn = column index.columns.append(index_column) else: grt.send_warning('%s: reverseEngineerTableIndices' % cls.getTargetDBMSName(), 'Reverse engineer of index %s.%s was attempted but the referenced column %s ' 'could not be found on table %s. Skipping index!' % (schema.name, index_name, column_name, table.name) ) continue table.addIndex(index) return 0 @classmethod def getColumnDatatype(cls, connection, table, column, type_name): if type_name == 'USER-DEFINED': query = """SELECT a.attname, pg_catalog.format_type(a.atttypid, a.atttypmod) FROM pg_catalog.pg_attribute a LEFT JOIN pg_catalog.pg_class c ON a.attrelid = c.oid LEFT JOIN pg_catalog.pg_namespace n ON c.relnamespace = n.oid WHERE n.nspname = '%s' AND c.relname = '%s' AND a.attname = '%s' AND NOT a.attisdropped; """ % (table.owner.name, table.name, column.name) udtype = cls.execute_query(connection, query).fetchall() if udtype: type_name = udtype[0][1] return super(PostgresqlReverseEngineering, cls).getColumnDatatype(connection, table, column, type_name) @classmethod def reverseEngineerUserDatatypes(cls, connection, catalog): """ There are several kinds of user datatypes in Postgres, including: - domains - tuples/composite (table like structure) - ranges (numeric ranges with fancy definition, only in 9.2+) - base types - enums - others As of now, we're only supporting domains and enums. Ranges can be migrated to their underlying type. Composite types should be migrated to StructuredTypes at some point. 
""" version = cls.serverVersion(connection) catalog.userDatatypes.remove_all() query_composite = """SELECT t.typname, at.attname, pg_catalog.format_type(at.atttypid, at.atttypmod) FROM pg_type t JOIN pg_class on (reltype = t.oid) JOIN pg_attribute at on (at.attrelid = pg_class.oid) JOIN pg_type a on (at.atttypid = a.oid) JOIN pg_namespace n on n.oid = t.typnamespace WHERE n.nspname NOT IN ('information_schema', 'pg_catalog') AND pg_class.relkind = 'c' """ # TODO query_domains = """SELECT t.typname, pg_catalog.format_type(t.typbasetype, t.typtypmod) FROM pg_catalog.pg_type t LEFT JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace WHERE n.nspname NOT IN ('information_schema', 'pg_catalog') AND t.typtype = 'd' """ domain_types = cls.execute_query(connection, query_domains) for type_name, type_def in domain_types: datatype = grt.classes.db_UserDatatype() datatype.name = type_name datatype.sqlDefinition = type_def if '(' in type_def: base_type = type_def[:type_def.find('(')] else: base_type = type_def up_type_name = base_type.upper() for stype in cls._rdbms.simpleDatatypes: if stype.name.upper() == up_type_name or up_type_name in [s.upper() for s in stype.synonyms]: datatype.actualType = stype break datatype.owner = catalog catalog.userDatatypes.append(datatype) query_ranges = """ """ # TODO query_enums = """SELECT t.typname, e.enumlabel FROM pg_catalog.pg_type t LEFT JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace LEFT JOIN pg_catalog.pg_enum e ON e.enumtypid = t.oid WHERE t.typrelid = 0 AND t.typtype = 'e' AND n.nspname NOT IN ('information_schema', 'pg_catalog') AND NOT EXISTS(SELECT 1 FROM pg_catalog.pg_type el WHERE el.oid = t.typelem AND el.typarray = t.oid) ORDER BY e.enumsortorder""" query_enums_80 = """SELECT t.typname, e.enumlabel FROM pg_catalog.pg_type t LEFT JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace LEFT JOIN pg_catalog.pg_enum e ON e.enumtypid = t.oid WHERE t.typrelid = 0 AND t.typtype = 'e' AND n.nspname NOT IN 
('information_schema', 'pg_catalog') AND NOT EXISTS(SELECT 1 FROM pg_catalog.pg_type el WHERE el.oid = t.typelem AND el.typarray = t.oid) """ enum_types = cls.execute_query(connection, query_enums if version.majorNumber >= 9 else query_enums_80) ltype = None types = [] values = [] for type_name, enum_label in enum_types: if type_name != ltype: ltype = type_name values = [] types.append((type_name, values)) values.append(enum_label) enumType = None for type_name, enum_labels in types: datatype = grt.classes.db_UserDatatype() datatype.name = type_name datatype.sqlDefinition = 'enum(%s)' % (', '.join(["'%s'" % l.replace("'", "''") for l in enum_labels])) datatype.actualType = enumType datatype.owner = catalog catalog.userDatatypes.append(datatype) ########################################################################################### @ModuleInfo.export(grt.classes.db_mgmt_Rdbms) def initializeDBMSInfo(): return PostgresqlReverseEngineering.initializeDBMSInfo('postgresql_rdbms_info.xml') @ModuleInfo.export((grt.LIST, grt.STRING)) def getDataSourceNames(): return PostgresqlReverseEngineering.getDataSourceNames() @ModuleInfo.export(grt.STRING, grt.STRING) def quoteIdentifier(name): return PostgresqlReverseEngineering.quoteIdentifier(name) @ModuleInfo.export(grt.STRING, grt.classes.GrtNamedObject) def fullyQualifiedObjectName(obj): return PostgresqlReverseEngineering.fullyQualifiedObjectName(obj) @ModuleInfo.export(grt.INT, grt.classes.db_mgmt_Connection, grt.STRING) def connect(connection, password): return PostgresqlReverseEngineering.connect(connection, password) @ModuleInfo.export(grt.INT, grt.classes.db_mgmt_Connection) def disconnect(connection): return PostgresqlReverseEngineering.disconnect(connection) @ModuleInfo.export(grt.INT, grt.classes.db_mgmt_Connection) def isConnected(connection): return PostgresqlReverseEngineering.isConnected(connection) @ModuleInfo.export(grt.STRING) def getTargetDBMSName(): return PostgresqlReverseEngineering.getTargetDBMSName() 
@ModuleInfo.export(grt.LIST)
def getSupportedObjectTypes():
    """Forward to PostgresqlReverseEngineering.getSupportedObjectTypes."""
    return PostgresqlReverseEngineering.getSupportedObjectTypes()


@ModuleInfo.export(grt.classes.GrtVersion, grt.classes.db_mgmt_Connection)
def getServerVersion(connection):
    """Forward to PostgresqlReverseEngineering.getServerVersion."""
    return PostgresqlReverseEngineering.getServerVersion(connection)


@ModuleInfo.export(grt.LIST, grt.classes.db_mgmt_Connection)
def getCatalogNames(connection):
    """Forward to PostgresqlReverseEngineering.getCatalogNames."""
    return PostgresqlReverseEngineering.getCatalogNames(connection)


@ModuleInfo.export(grt.LIST, grt.classes.db_mgmt_Connection, grt.STRING)
def getSchemaNames(connection, catalog_name):
    """Forward to PostgresqlReverseEngineering.getSchemaNames."""
    return PostgresqlReverseEngineering.getSchemaNames(connection, catalog_name)


@ModuleInfo.export(grt.LIST, grt.classes.db_mgmt_Connection, grt.STRING, grt.STRING)
def getTableNames(connection, catalog_name, schema_name):
    """Forward to PostgresqlReverseEngineering.getTableNames."""
    return PostgresqlReverseEngineering.getTableNames(connection, catalog_name, schema_name)


@ModuleInfo.export(grt.LIST, grt.classes.db_mgmt_Connection, grt.STRING, grt.STRING)
def getViewNames(connection, catalog_name, schema_name):
    """Forward to PostgresqlReverseEngineering.getViewNames."""
    return PostgresqlReverseEngineering.getViewNames(connection, catalog_name, schema_name)


@ModuleInfo.export(grt.LIST, grt.classes.db_mgmt_Connection, grt.STRING, grt.STRING)
def getTriggerNames(connection, catalog_name, schema_name):
    """Forward to PostgresqlReverseEngineering.getTriggerNames."""
    return PostgresqlReverseEngineering.getTriggerNames(connection, catalog_name, schema_name)


@ModuleInfo.export(grt.LIST, grt.classes.db_mgmt_Connection, grt.STRING, grt.STRING)
def getProcedureNames(connection, catalog_name, schema_name):
    """Forward to PostgresqlReverseEngineering.getProcedureNames."""
    return PostgresqlReverseEngineering.getProcedureNames(connection, catalog_name, schema_name)


@ModuleInfo.export(grt.LIST, grt.classes.db_mgmt_Connection, grt.STRING, grt.STRING)
def getFunctionNames(connection, catalog_name, schema_name):
    """Forward to PostgresqlReverseEngineering.getFunctionNames."""
    return PostgresqlReverseEngineering.getFunctionNames(connection, catalog_name, schema_name)


@ModuleInfo.export(grt.classes.db_Catalog, grt.classes.db_mgmt_Connection, grt.STRING, (grt.LIST, grt.STRING), grt.DICT)
def reverseEngineer(connection, catalog_name, schemata_list, context):
    """Forward to PostgresqlReverseEngineering.reverseEngineer."""
    return PostgresqlReverseEngineering.reverseEngineer(connection, catalog_name, schemata_list, context)


@ModuleInfo.export(grt.INT, grt.classes.db_mgmt_Connection, grt.classes.db_Catalog)
def reverseEngineerUserDatatypes(connection, catalog):
    """Forward to PostgresqlReverseEngineering.reverseEngineerUserDatatypes."""
    return PostgresqlReverseEngineering.reverseEngineerUserDatatypes(connection, catalog)


@ModuleInfo.export(grt.classes.db_Catalog, grt.classes.db_mgmt_Connection, grt.STRING)
def reverseEngineerCatalog(connection, catalog_name):
    """Forward to PostgresqlReverseEngineering.reverseEngineerCatalog."""
    return PostgresqlReverseEngineering.reverseEngineerCatalog(connection, catalog_name)


@ModuleInfo.export(grt.INT, grt.classes.db_mgmt_Connection, grt.classes.db_Schema)
def reverseEngineerTables(connection, schema):
    """Forward to PostgresqlReverseEngineering.reverseEngineerTables."""
    return PostgresqlReverseEngineering.reverseEngineerTables(connection, schema)


@ModuleInfo.export(grt.INT, grt.classes.db_mgmt_Connection, grt.classes.db_Schema)
def reverseEngineerViews(connection, schema):
    """Forward to PostgresqlReverseEngineering.reverseEngineerViews."""
    return PostgresqlReverseEngineering.reverseEngineerViews(connection, schema)


@ModuleInfo.export(grt.INT, grt.classes.db_mgmt_Connection, grt.classes.db_Schema)
def reverseEngineerProcedures(connection, schema):
    """Forward to PostgresqlReverseEngineering.reverseEngineerProcedures."""
    return PostgresqlReverseEngineering.reverseEngineerProcedures(connection, schema)


@ModuleInfo.export(grt.INT, grt.classes.db_mgmt_Connection, grt.classes.db_Schema)
def reverseEngineerFunctions(connection, schema):
    """Forward to PostgresqlReverseEngineering.reverseEngineerFunctions."""
    return PostgresqlReverseEngineering.reverseEngineerFunctions(connection, schema)


@ModuleInfo.export(grt.INT, grt.classes.db_mgmt_Connection, grt.classes.db_Schema)
def reverseEngineerTriggers(connection, schema):
    """Forward to PostgresqlReverseEngineering.reverseEngineerTriggers."""
    return PostgresqlReverseEngineering.reverseEngineerTriggers(connection, schema)