# Copyright (c) 2012, 2019, Oracle and/or its affiliates. All rights reserved.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License, version 2.0,
# as published by the Free Software Foundation.
#
# This program is also distributed with certain software (including
# but not limited to OpenSSL) that is licensed under separate terms, as
# designated in a particular file or component or in included license
# documentation. The authors of MySQL hereby grant you an additional
# permission to link the program and your derivative works with the
# separately licensed software that they have included with MySQL.
# This program is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See
# the GNU General Public License, version 2.0, for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA

# import the wb module
from wb import DefineModule, wbinputs
# import the grt module
import grt
# import the mforms module for GUI stuff
import mforms

import os
from workbench.log import log_error
from workbench.notifications import NotificationCenter

from sql_reformatter import formatter_for_statement_ast
from text_output import TextOutputTab
from run_script import RunScriptForm
from sqlide_catalogman_ext import show_schema_manager
from sqlide_tableman_ext import show_table_inspector
from sqlide_resultset_ext import handleResultsetContextMenu
import sqlide_catalogman_ext
import sqlide_tableman_ext
import sqlide_schematree_ext
import sqlide_import_spatial
import sqlide_power_import_wizard
import sqlide_power_export_wizard

# define this Python module as a GRT module
ModuleInfo = DefineModule(name= "SQLIDEUtils", author= "Oracle Corp.", version="1.1")

# Characters treated as insignificant padding around a SQL statement.
_STATEMENT_PADDING = " \t\r\n"


@ModuleInfo.export(grt.INT)
def initialize0():
    """Register this module's context-menu observers with the notification center."""
    nc = NotificationCenter()

    nc.add_observer(handleResultsetContextMenu, name = "GRNSQLResultsetMenuWillShow")

    sqlide_schematree_ext.init()

    # register the handlers for when the SQLIDE live schema tree context menu
    # is about to be shown; the schema tree handler must be 1st
    live_tree_handlers = (
        sqlide_schematree_ext.handleLiveTreeContextMenu,
        sqlide_catalogman_ext.handleLiveTreeContextMenu,
        sqlide_tableman_ext.handleLiveTreeContextMenu,
        sqlide_import_spatial.handleContextMenu,
        sqlide_power_import_wizard.handleContextMenu,
        sqlide_power_export_wizard.handleContextMenu,
    )
    for handler in live_tree_handlers:
        nc.add_observer(handler, name = "GRNLiveDBObjectMenuWillShow")


@ModuleInfo.export(grt.INT, grt.classes.db_query_Editor)
def launchPowerImport(editor):
    """Open the power import wizard using the editor's default schema and no table."""
    sqlide_power_import_wizard.showPowerImport(editor, {'table': None, 'schema': editor.defaultSchema})
    return 0


@ModuleInfo.export(grt.INT, grt.classes.db_query_EditableResultset)
def importRecordsetDataFromFile(resultset):
    """Ask the user for a CSV file and append its rows to *resultset*.

    Each CSV field is converted with the Python type mapped to the column's
    ``columnType``.  Rows shorter than the column count are padded with the
    default value of each missing column's type; rows that fail conversion
    are logged and skipped.  Always returns 0.
    """
    import csv

    file_chooser = mforms.newFileChooser(None, mforms.OpenFile)
    file_chooser.set_title('Import Recordset From CSV File')
    file_chooser.set_directory(os.path.expanduser('~'))
    file_chooser.set_extensions('CSV Files (*.csv)|*.csv', 'import')
    if not file_chooser.run_modal():
        return 0

    type_classes = {
        'string': str,
        'int': int,
        'real': float,
        'blob': str,
        'date': str,
        'time': str,
        'datetime': str,
        'geo': str,
    }

    with open(file_chooser.get_path(), 'r') as import_file:
        ext = os.path.splitext(import_file.name)[1].lower()
        if ext == '.sql':
            # Here will go our not yet written .sql reader
            return 0

        # Every other extension (including .csv) is read as CSV.
        reader = csv.reader(import_file)
        column_count = len(resultset.columns)
        converters = tuple(type_classes[column.columnType] for column in resultset.columns)

        for row in reader:
            if len(row) < column_count:
                # Fill with default values
                row.extend(converter() for converter in converters[len(row):])
            try:
                converted_values = [converter(value) for converter, value in zip(converters, row)]
            except ValueError as error:
                log_error("Skipping line %i of %s: %s\n" % (reader.line_num, import_file.name, error))
                continue

            resultset.addNewRow()
            for column, value in enumerate(converted_values):
                if isinstance(value, str):
                    resultset.setStringFieldValue(column, value)
                elif isinstance(value, int):
                    resultset.setIntFieldValue(column, value)
                elif isinstance(value, float):
                    resultset.setFloatFieldValue(column, value)
                else:
                    resultset.setFieldNull(column)
        resultset.addNewRow()  # needed in Windows to refresh display for last row
    return 0


@ModuleInfo.plugin("wb.sqlide.executeToTextOutput", caption= "Execute Query Into Text Output", input= [wbinputs.currentQueryEditor()], accessibilityName="Execute Into Text Output")
@ModuleInfo.export(grt.INT, grt.classes.db_query_QueryEditor)
def executeQueryAsText(qbuffer):
    """Run the selected text (or the whole script) and show the results as text tables.

    The output goes into a new "Query Output" tab docked at the buffer's
    result docking point.  Always returns 0.
    """
    import time

    editor = qbuffer.owner
    sql = qbuffer.selectedText or qbuffer.script
    resultsets = editor.executeScript(sql)
    if resultsets:
        view = TextOutputTab("")

        dock = mforms.fromgrt(qbuffer.resultDockingPoint)
        dock.dock_view(view, "", 0)
        view.set_title("Query Output")
        dock.select_view(view)

        for result in resultsets:
            # Every column is padded to its name's length plus 5 spaces.
            column_lengths = [len(column.name) + 5 for column in result.columns]
            header = " | ".join(column.name + " "*5 for column in result.columns)
            separator = " + ".join("-"*width for width in column_lengths)

            output = [
                "Execute:",
                "> %s\n" % result.sql,
                "+ "+separator+" +",
                "| "+header+" |",
                "+ "+separator+" +\n",
            ]

            ok = result.goToFirstRow()
            if ok:
                view.textbox.append_text('\n'.join(output))

            last_flush = 0
            rows = []
            while ok:
                line = []
                for i, width in enumerate(column_lengths):
                    value = result.stringFieldValue(i)
                    if value is None:
                        value = "NULL"
                    line.append(value.ljust(width))
                rows.append("| "+" | ".join(line)+" |")

                # flush text every 1/2s
                if time.time() - last_flush >= 0.5:
                    last_flush = time.time()
                    view.textbox.append_text("\n".join(rows)+"\n")
                    rows = []
                ok = result.nextRow()

            if rows:
                view.textbox.append_text("\n".join(rows)+"\n")

            view.textbox.append_text("+ "+separator+" +\n")
            view.textbox.append_text("%i rows\n\n" % (result.currentRow + 1))

    return 0


@ModuleInfo.plugin('wb.sqlide.verticalOutput', caption='Vertical Output', input=[wbinputs.currentQueryEditor()], accessibilityName="Execute Into Vertical Output")
@ModuleInfo.export(grt.INT, grt.classes.db_query_QueryEditor)
def verticalOutput(editor):
    """Run the current statement and show every row as "column: value" lines.

    The output goes into a new "Vertical Output" tab.  Nothing happens when
    there is no current statement.  Always returns 0.
    """
    statement = editor.currentStatement
    if statement:
        rsets = editor.owner.executeScript(statement)
        output = ['> %s\n' % statement]
        several_results = len(rsets) > 1
        for idx, rset in enumerate(rsets):
            if several_results:
                output.append('Result set %i' % (idx+1))
            # "or [0]" keeps max() from raising when a result has no columns
            column_name_length = max([len(col.name) for col in rset.columns] or [0])
            ok = rset.goToFirstRow()
            while ok:
                output.append('******************** %s. row *********************' % (rset.currentRow + 1))
                for i, column in enumerate(rset.columns):
                    col_name, col_value = column.name.rjust(column_name_length), rset.stringFieldValue(i)
                    output.append('%s: %s' % (col_name, col_value if col_value is not None else 'NULL'))
                ok = rset.nextRow()
            output.append('%d rows in set' % (rset.currentRow + 1))
            rset.reset_references()
            if several_results:
                output.append('')
        view = TextOutputTab('\n'.join(output) + '\n')

        dock = mforms.fromgrt(editor.resultDockingPoint)
        dock.dock_view(view, '', 0)
        dock.select_view(view)
        view.set_title('Vertical Output')

    return 0


def doReformatSQLStatement(text, return_none_if_unsupported):
    """Return *text* reformatted by the formatter matching its statement type.

    Raises Exception when *text* does not parse into exactly one statement
    or when the parser reports an error.  When no formatter exists for the
    statement, returns None if *return_none_if_unsupported* is true and the
    unchanged *text* otherwise.
    """
    ast_list = grt.modules.MysqlSqlFacade.parseAstFromSqlScript(text)
    if len(ast_list) != 1:
        raise Exception("Error parsing statement")
    if isinstance(ast_list[0], str):
        raise Exception("Error parsing statement: %s" % ast_list[0])

    ast = ast_list[0]

    def trim_ast_fix_bq(text, node, add_rollup):
        """Reduce *node* to (symbol, value, children), restoring backquotes and ROLLUP."""
        s = node[0]
        v = node[1]
        c = node[2]
        # put back backquotes to identifiers, if there's any
        if s in ("ident", "ident_or_text"):
            begin = node[3] + node[4]
            end = node[3] + node[5]
            # one-character slices: no IndexError if the identifier ends the text
            if begin > 0 and text[begin-1:begin] == '`' and text[end:end+1] == '`':
                v = "`%s`" % v.replace("`", "``")
        l = []
        for i, nc in enumerate(c):
            l.append(trim_ast_fix_bq(text, nc, add_rollup))
            if add_rollup and nc[0] == "olap_opt" and nc[1].upper() == "WITH" and (i == len(c)-1 or c[i+1][1].upper() != "ROLLUP"):
                l.append(("olap_opt", "ROLLUP", []))
        return (s, v, l)

    formatter = formatter_for_statement_ast(ast)
    if formatter:
        # workaround a bug in parser where WITH ROLLUP is turned into WITH
        add_rollup = "WITH ROLLUP" in text.upper()
        p = formatter(trim_ast_fix_bq(text, ast, add_rollup))
        return p.run()

    if return_none_if_unsupported:
        return None
    return text


@ModuleInfo.export(grt.STRING, grt.STRING)
def reformatSQLStatement(text):
    """Return *text* reformatted, or unchanged when its statement type is unsupported."""
    return doReformatSQLStatement(text, False)


def _split_statement_padding(statement):
    """Return (leading, core, trailing): *statement* split at its outer whitespace."""
    without_leading = statement.lstrip(_STATEMENT_PADDING)
    leading = statement[:len(statement) - len(without_leading)]
    core = without_leading.rstrip(_STATEMENT_PADDING)
    trailing = without_leading[len(core):]
    return leading, core, trailing


@ModuleInfo.plugin("wb.sqlide.enbeautificate", caption = "Reformat SQL Query", input=[wbinputs.currentQueryBuffer()], accessibilityName="Reformat Query")
@ModuleInfo.export(grt.INT, grt.classes.db_query_QueryBuffer)
def enbeautificate(editor):
    """Reformat the selected SQL statements or the one under the cursor.

    Statements that cannot be reformatted (unsupported type or a formatter
    error, which is logged) are kept verbatim.  The status bar reports how
    many statements were formatted and skipped.  Always returns 0.
    """
    import traceback

    text = editor.selectedText
    selectionOnly = True
    if not text:
        selectionOnly = False
        text = editor.currentStatement

    ok_count = 0
    bad_count = 0

    prev_end = 0
    new_text = []
    ranges = grt.modules.MysqlSqlFacade.getSqlStatementRanges(text)
    for begin, end in ranges:
        end = begin + end  # the second item of each range is added to begin to get the end offset
        if begin > prev_end:
            new_text.append(text[prev_end:begin])

        leading, statement, trailing = _split_statement_padding(text[begin:end])

        # if there's a comment at the start, then skip the comment until its end
        while True:
            if statement.startswith("-- "):
                comment, _, rest = statement.partition("\n")
                leading += comment+"\n"
                statement = rest
            elif statement.startswith("/*"):
                pos = statement.find("*/")
                if pos < 0:
                    break
                leading += statement[:pos+2]
                statement = statement[pos+2:]
            else:
                break
            more_leading, statement, more_trailing = _split_statement_padding(statement)
            leading += more_leading
            trailing += more_trailing

        try:
            result = doReformatSQLStatement(statement, True)
        except Exception:
            # Best effort: a statement the formatter chokes on is kept as is.
            log_error("Error reformating SQL: %s\n%s\n" % (statement, traceback.format_exc()))
            result = None

        if result:
            ok_count += 1
            if leading:
                new_text.append(leading.strip(" "))
            new_text.append(result)
            if trailing:
                new_text.append(trailing.strip(" "))
        else:
            bad_count += 1
            new_text.append(text[begin:end])
        prev_end = end
    new_text.append(text[prev_end:])
    new_text = "".join(new_text)

    if selectionOnly:
        editor.replaceSelection(new_text)
    else:
        editor.replaceCurrentStatement(new_text)

    if bad_count > 0:
        mforms.App.get().set_status_text("Formatted %i statements, %i unsupported statement types skipped."%(ok_count, bad_count))
    else:
        mforms.App.get().set_status_text("Formatted %i statements."%ok_count)

    return 0


def apply_to_keywords(editor, callable):
    """Replace every keyword in the selection (or whole script) with ``callable(keyword)``.

    Every valued AST node whose symbol is not in the non-keyword list below
    is treated as a keyword.  Returns 1, leaving the text untouched, when the
    SQL cannot be parsed; returns 0 otherwise.
    """
    non_keywords = frozenset([
        "ident",
        "ident_or_text",
        "TEXT_STRING",
        "text_string",
        "TEXT_STRING_filesystem",
        "TEXT_STRING_literal",
        "TEXT_STRING_sys",
        "part_name",
    ])

    def collect_keyword_offsets(offsets, node):
        """Append the (begin, end) text offsets of all keyword nodes below *node*."""
        symbol, value, children, base, begin, end = node
        if value and symbol not in non_keywords:
            offsets.append((base + begin, base + end))
        for child in children:
            collect_keyword_offsets(offsets, child)

    text = editor.selectedText
    selectionOnly = True
    if not text:
        selectionOnly = False
        text = editor.script

    pieces = []
    copied_up_to = 0
    for ast in grt.modules.MysqlSqlFacade.parseAstFromSqlScript(text):
        if isinstance(ast, str):
            # error
            print(ast)
            mforms.App.get().set_status_text("Cannot format invalid SQL: %s"%ast)
            return 1

        offsets = []
        collect_keyword_offsets(offsets, ast)
        for begin, end in offsets:
            pieces.append(text[copied_up_to:begin])
            pieces.append(callable(text[begin:end]))
            copied_up_to = end
    pieces.append(text[copied_up_to:])
    new_text = "".join(pieces)

    if selectionOnly:
        editor.replaceSelection(new_text)
    else:
        editor.replaceContents(new_text)

    mforms.App.get().set_status_text("SQL code reformatted.")
    return 0


@ModuleInfo.plugin("wb.sqlide.upcaseKeywords", caption = "Make keywords in query uppercase", input=[wbinputs.currentQueryEditor()], accessibilityName="Uppercase Query Keywords")
@ModuleInfo.export(grt.INT, grt.classes.db_query_QueryEditor)
def upcaseKeywords(editor):
    """Uppercase the keywords of the selection or of the whole script."""
    return apply_to_keywords(editor, lambda s: s.upper())


@ModuleInfo.plugin("wb.sqlide.lowercaseKeywords", caption = "Make keywords in query lowercase", input=[wbinputs.currentQueryEditor()], accessibilityName="Lowercase Query Keywords")
@ModuleInfo.export(grt.INT, grt.classes.db_query_QueryEditor)
def lowercaseKeywords(editor):
    """Lowercase the keywords of the selection or of the whole script."""
    return apply_to_keywords(editor, lambda s: s.lower())


def get_lines_in_range(text, range_start, range_end):
    """Return the lines of *text* touched by the closed range [range_start, range_end].

    Lines are split on "\\n" only and keep their newline.  The result is
    ``(first_line_start, last_line_end, lines)`` where the two offsets
    delimit the returned lines inside *text*; both offsets are None when no
    line intersects the range.
    """
    def intersects_range(start1, end1, start2, end2):
        return (start1 <= start2 <= end1 or start1 <= end2 <= end1 or
                start2 <= start1 <= end2 or start2 <= end1 <= end2)

    def split(text):
        # Linear-time split that keeps the "\n" at the end of each line.
        pieces = text.split("\n")
        lines = [piece + "\n" for piece in pieces[:-1]]
        if pieces[-1]:
            lines.append(pieces[-1])
        return lines

    offs = 0
    lines = []
    first_line_start = None
    last_line_end = None
    for line in split(text):
        line_start = offs
        line_end = offs+len(line)
        if intersects_range(range_start, range_end, line_start, line_end-1):
            if first_line_start is None:
                first_line_start = line_start
            last_line_end = line_end
            lines.append(line)
        offs = line_end
    return (first_line_start, last_line_end, lines)


@ModuleInfo.plugin("wb.sqlide.indent", caption = "Indent Selected Lines", input=[wbinputs.currentQueryEditor()], accessibilityName="Indent Selected Lines")
@ModuleInfo.export(grt.INT, grt.classes.db_query_QueryEditor)
def indent(editor):
    """Indent the selected lines, or the line holding the cursor, by 4 spaces.

    Always returns 0.
    """
    # indent and unindent handle selection a bit differently:
    # - if there is no selection, only the line where the cursor is should be indented
    # - if there is a selection, all selected lines should be indented, even if selected partially
    indentation = " "*4

    start = editor.selectionStart
    end = editor.selectionEnd
    full_text = editor.script
    if end > start:
        # The selection is treated as half-open: its last character is at
        # end - 1.  Passing `end` itself would also pick up the next line
        # when the selection stops right after a newline, and that line's
        # text would then be inserted a second time.
        first_line_start, last_line_end, lines = get_lines_in_range(full_text, start, end - 1)
        if not lines:
            return 0
        new_text = indentation + indentation.join(lines)
        # Only [first_line_start, end) gets replaced, so drop the part of the
        # last line that lies after the selection.
        unselected_tail = last_line_end - end
        if unselected_tail > 0:
            new_text = new_text[:-unselected_tail]
        delta = len(lines) * len(indentation)
        # select from the beginning of the first line, so that the whole block is replaced
        editor.selectionStart = first_line_start
        editor.replaceSelection(new_text)
        # update cursor position
        editor.selectionEnd = end + delta
        editor.selectionStart = start + len(indentation)
    else:
        pos = editor.insertionPoint
        line_start = full_text.rfind("\n", 0, pos) + 1
        editor.replaceContents(full_text[:line_start] + indentation + full_text[line_start:])
        editor.insertionPoint = pos + len(indentation)

    return 0


@ModuleInfo.plugin("wb.sqlide.unindent", caption = "Unindent Selected Lines", input=[wbinputs.currentQueryEditor()], accessibilityName="Unindent Selected Lines")
@ModuleInfo.export(grt.INT, grt.classes.db_query_QueryEditor)
def unindent(editor):
    """Remove 4 leading spaces from the selected lines or the line holding the cursor.

    Lines that do not start with 4 spaces are left alone.  Always returns 0.
    """
    indentation = " "*4

    start = editor.selectionStart
    end = editor.selectionEnd
    full_text = editor.script
    if end > start:
        # Half-open selection, see indent(): the last selected character is end - 1.
        first_line_start, last_line_end, lines = get_lines_in_range(full_text, start, end - 1)
        # Number of characters dropped from the front of each line.
        removed = [len(indentation) if line.startswith(indentation) else 0 for line in lines]
        if not any(removed):
            return 0
        last_line_start = last_line_end - len(lines[-1])
        new_text = "".join(line[count:] for line, count in zip(lines, removed))

        # Replace whole lines: the selection may end inside the indentation
        # that is being removed, so a partial replacement could not express
        # the change.
        editor.selectionStart = first_line_start
        editor.selectionEnd = last_line_end
        editor.replaceSelection(new_text)

        # update cursor position: move both ends left by what was actually
        # removed in front of them, never past the start of their own line
        new_start = first_line_start + max(start - first_line_start - removed[0], 0)
        new_last_line_start = last_line_start - sum(removed[:-1])
        new_end = new_last_line_start + max(end - last_line_start - removed[-1], 0)
        editor.selectionEnd = new_end
        editor.selectionStart = new_start
    else:
        pos = editor.insertionPoint
        line_start = full_text.rfind("\n", 0, pos) + 1
        if full_text.startswith(indentation, line_start):
            editor.replaceContents(full_text[:line_start] + full_text[line_start+len(indentation):])
            # a cursor that was inside the removed indentation goes to the line start
            editor.insertionPoint = max(pos - len(indentation), line_start)

    return 0
@ModuleInfo.plugin("wb.sqlide.comment", caption = "Un/Comment Selection", input=[wbinputs.currentQueryEditor()], accessibilityName="Comment or Uncomment Selection")
@ModuleInfo.export(grt.INT, grt.classes.db_query_QueryEditor)
def commentText(editor):
    """Toggle the configured comment marker on the selection or at the cursor.

    The marker is the "DbSqlEditor:SQLCommentTypeForHotkey" option followed
    by one space.  With a selection, the first selected line decides whether
    the lines get uncommented or commented; empty lines are never commented.
    Without a selection, a marker found up to 3 characters after or before
    the cursor is deleted, otherwise a marker is inserted at the cursor.
    Always returns 0.
    """
    marker = grt.root.wb.options.options["DbSqlEditor:SQLCommentTypeForHotkey"]
    prefix = "%s " % marker
    prefix_length = len(marker) + 1

    selected = editor.selectedText
    if selected:
        selected_lines = selected.split("\n")
        if selected_lines[0].startswith(prefix):
            def toggle(line):
                # uncomment: strip the marker where present
                if line.startswith(prefix):
                    return line[prefix_length:]
                return line
        else:
            def toggle(line):
                # comment: prefix everything except empty lines
                if line != "":
                    return prefix + line
                return line
        editor.replaceSelection("\n".join([toggle(line) for line in selected_lines]))
        return 0

    caret = editor.insertionPoint
    script = editor.script

    # if cursor is before or after a comment sequence, then delete that;
    # nearer positions win and "after the cursor" wins over "before" on a tie
    marker_at = None
    for distance in range(4):
        ahead = caret + distance
        if script[ahead:ahead + prefix_length] == prefix:
            marker_at = ahead
            break
        behind = caret - distance
        if behind >= 0 and script[behind:behind + prefix_length] == prefix:
            marker_at = behind
            caret = behind
            break

    if marker_at is None:
        editor.replaceSelection(prefix)
    else:
        editor.replaceContents(script[:marker_at] + script[marker_at + prefix_length:])
    editor.insertionPoint = caret
    return 0


@ModuleInfo.export(grt.INT, grt.classes.db_query_Editor, grt.LIST)
def showInspector(editor, selection):
    """Open the schema manager / table inspector for the selected objects.

    Schemas go to the schema manager; tables and views go to the table
    inspector; indexes open the table inspector of their owner on the
    "indexes" page.  Other object types are logged as unsupported.
    Always returns 0.
    """
    schemas = []
    tables = []
    index_owners = []

    for item in selection:
        kind = item.type
        if kind == "db.Schema":
            schemas.append(item.schemaName)
        elif kind in ("db.Table", "db.View"):
            tables.append((item.schemaName, item.name))
        elif kind == "db.Index":
            index_owners.append((item.schemaName, item.owner.name))
        else:
            log_error("Unsupported inspector type: %s\n" % kind)

    if schemas:
        show_schema_manager(editor, schemas, False)
    if tables:
        show_table_inspector(editor, tables)
    if index_owners:
        show_table_inspector(editor, index_owners, "indexes")

    return 0


#@ModuleInfo.plugin("wb.sqlide.refactor.renameSchema", caption = "Rename References to Schema Name", input=[wbinputs.currentQueryBuffer()])  # pluginMenu="SQL/Utilities"
#@ModuleInfo.export(grt.INT, grt.classes.db_query_QueryBuffer)
#def refactorRenameSchema(editor):
#    pass


@ModuleInfo.plugin("wb.sqlide.runScript", caption = "Run SQL Script", input=[wbinputs.currentSQLEditor()], accessibilityName="Run SQL Script")
@ModuleInfo.export(grt.INT, grt.classes.db_query_Editor)
def runSQLScript(editor):
    """Show the run-script form for *editor* and return the result of its run()."""
    script_form = RunScriptForm(editor)
    return script_form.run()


@ModuleInfo.export(grt.INT, grt.classes.db_query_Editor, grt.STRING)
def runSQLScriptFile(editor, path):
    """Run the script file at *path* through the run-script form and return its result."""
    script_form = RunScriptForm(editor)
    return script_form.run_file(path)