Add basic support for postgresql database schemas (closes #66133)

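Add a 'schema' keyword argument to get_connection. All backends except PostgreSQL ignore it (the MySQL, SQLite and SQL Server adapters emit a warning). The PostgreSQL backend uses it to set the search path on connect and passes it down to the backup command; restore commands accept the argument but do not use it yet.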

author: Julien Cristau <julien.cristau@logilab.fr>
changeset: 77bcb93ccf50
branch: default
phase: public
hidden: no
parent revision: #900d91d54f45 [mysql] Fix SQL command in sql_change_col_type()
child revision: #d9a772cdbb19 [postgres] Take schema into account in pg_table and pg_indexes requests
files modified by this revision:
__init__.py
mysql.py
postgres.py
sqlite.py
sqlserver.py
sqlserver2005.py
test/unittest_db.py
# HG changeset patch
# User Julien Cristau <julien.cristau@logilab.fr>
# Date 1400255594 -7200
# Fri May 16 17:53:14 2014 +0200
# Node ID 77bcb93ccf50b9633ee0610b0d731093ec6c57b9
# Parent 900d91d54f4563beb0ed12285f006f5a376d54f9
Add basic support for postgresql database schemas (closes #66133)

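Add a 'schema' keyword argument to get_connection. All backends except
PostgreSQL ignore it (the MySQL, SQLite and SQL Server adapters emit a
warning). The PostgreSQL backend uses it to set the search path on
connect and passes it down to the backup command; restore commands
accept the argument but do not use it yet.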

diff --git a/__init__.py b/__init__.py
@@ -83,11 +83,11 @@
1          mod = err.adapted_obj
2      return mod
3 
4  def get_connection(driver='postgres', host='', database='', user='',
5                    password='', port='', quiet=False, drivers=_PREFERED_DRIVERS,
6 -                  pywrap=False, extra_args=None):
7 +                  pywrap=False, schema=None, extra_args=None):
8      """return a db connection according to given arguments
9 
10      extra_args is an optional string that is appended to the DSN"""
11      _ensure_module_loaded(driver)
12      module, modname = _import_driver_module(driver, drivers)
@@ -106,11 +106,12 @@
13          except ValueError:
14              pass
15      if port:
16          port = int(port)
17      return adapted_module.connect(host, database, user, password,
18 -                                  port=port, extra_args=extra_args)
19 +                                  port=port, schema=schema,
20 +                                  extra_args=extra_args)
21 
22  def set_prefered_driver(driver, module, _drivers=_PREFERED_DRIVERS):
23      """sets the preferred driver module for driver
24      driver is the name of the db engine (postgresql, mysql...)
25      module is the name of the module providing the connect function
@@ -337,11 +338,11 @@
26              except AttributeError:
27                  self.logger.warning('%s adapter has no %s type code',
28                                      self, typecode)
29 
30      def connect(self, host='', database='', user='', password='', port='',
31 -                extra_args=None):
32 +                schema=None, extra_args=None):
33          """Wraps the native module connect method"""
34          kwargs = {'host' : host, 'port' : port, 'database' : database,
35                    'user' : user, 'password' : password}
36          return self._wrap_if_needed(self._native_module.connect(**kwargs))
37 
@@ -753,11 +754,11 @@
38      alter_column_support = True
39      case_sensitive = False
40 
41      # allow call to [backup|restore]_commands without previous call to
42      # record_connection_information but by specifying argument explicitly
43 -    dbname = dbhost = dbport = dbuser = dbpassword = dbextraargs = dbencoding = None
44 +    dbname = dbhost = dbport = dbuser = dbpassword = dbextraargs = dbencoding = dbschema = None
45 
46      def __init__(self, encoding='utf-8', _cnx=None):
47          self.dbencoding = encoding
48          self._cnx = _cnx
49          self.dbapi_module = get_dbapi_compliant_module(self.backend_name)
@@ -769,19 +770,20 @@
50                                                       self.backend_name, id(self))
51          return super(_GenericAdvFuncHelper, self).__repr__()
52 
53      def record_connection_info(self, dbname, dbhost=None, dbport=None,
54                                 dbuser=None, dbpassword=None, dbextraargs=None,
55 -                               dbencoding=None):
56 +                               dbencoding=None, dbschema=None):
57          self.dbname = dbname
58          self.dbhost = dbhost
59          self.dbport = dbport
60          self.dbuser = dbuser
61          self.dbpasswd = dbpassword
62          self.dbextraargs = dbextraargs
63          if dbencoding:
64              self.dbencoding = dbencoding
65 +        self.dbschema = dbschema
66 
67      def get_connection(self, initcnx=True):
68          """open and return a connection to the database
69 
70          you should first call record_connection_info to set connection
@@ -794,10 +796,11 @@
71              self.logger.info('connecting to %s@%s', self.dbname,
72                           self.dbhost or 'localhost')
73          cnx = self.dbapi_module.connect(self.dbhost, self.dbname,
74                                          self.dbuser,self.dbpasswd,
75                                          port=self.dbport,
76 +                                        schema=self.dbschema,
77                                          extra_args=self.dbextraargs)
78          if initcnx:
79              for hook in SQL_CONNECT_HOOKS.get(self.backend_name, ()):
80                  hook(cnx)
81          return cnx
@@ -819,11 +822,11 @@
82      def system_database(self):
83          """return the system database for the given driver"""
84          raise NotImplementedError('not supported by this DBMS')
85 
86      def backup_commands(self, backupfile, keepownership=True,
87 -                        dbname=None, dbhost=None, dbport=None, dbuser=None):
88 +                        dbname=None, dbhost=None, dbport=None, dbuser=None, dbschema=None):
89          """Return a list of commands to backup the given database.
90 
91          Each command may be given as a list or as a string. In the latter case,
92          expected to be used with a subshell (for instance using `os.system(cmd)`
93          or `subprocess.call(cmd, shell=True)`
@@ -1056,10 +1059,18 @@
94 
95      def create_database(self, cursor, dbname, owner=None, dbencoding=None):
96          """create a new database"""
97          raise NotImplementedError('not supported by this DBMS')
98 
99 +    def create_schema(self, cursor, schema, granted_user=None):
100 +        """create a new database schema"""
101 +        raise NotImplementedError('not supported by this DBMS')
102 +
103 +    def drop_schema(self, cursor, schema):
104 +        """drop a database schema"""
105 +        raise NotImplementedError('not supported by this DBMS')
106 +
107      def list_databases(self):
108          """return the list of existing databases"""
109          raise NotImplementedError('not supported by this DBMS')
110 
111      def list_users(self, cursor):
diff --git a/mysql.py b/mysql.py
@@ -19,10 +19,12 @@
112 
113  Full-text search based on MyISAM full text search capabilities.
114  """
115  __docformat__ = "restructuredtext en"
116 
117 +from warnings import warn
118 +
119  from logilab import database as db
120  from logilab.database.fti import normalize_words, tokenize
121 
122  class _MySqlDBAdapter(db.DBAPIAdapter):
123      """Simple mysql Adapter to DBAPI
@@ -49,15 +51,18 @@
124              times.TimeDelta = times.timedelta = mxdt.TimeDelta
125              times.DateTimeType = mxdt.DateTimeType
126              times.DateTimeDeltaType = mxdt.DateTimeDeltaType
127 
128      def connect(self, host='', database='', user='', password='', port=None,
129 -                unicode=True, charset='utf8', extra_args=None):
130 +                unicode=True, charset='utf8', schema=None, extra_args=None):
131          """Handles mysqldb connection format
132          the unicode named argument asks to use Unicode objects for strings
133          in result sets and query parameters
134          """
135 +        if schema is not None:
136 +            warn('schema support is not implemented on mysql backends, ignoring schema %s'
137 +                 % schema)
138          kwargs = {'host' : host or '', 'db' : database,
139                    'user' : user, 'passwd' : password,
140                    'use_unicode' : unicode}
141          # MySQLdb doesn't support None port
142          if port:
@@ -160,18 +165,18 @@
143      def system_database(self):
144          """return the system database for the given driver"""
145          return ''
146 
147      def backup_commands(self, backupfile, keepownership=True,
148 -                        dbname=None, dbhost=None, dbport=None, dbuser=None):
149 +                        dbname=None, dbhost=None, dbport=None, dbuser=None, dbschema=None):
150          cmd = self.mycmd('mysqldump', dbhost, dbport, dbuser)
151          cmd += ('-p', '-r', backupfile, dbname or self.dbname)
152          return [cmd]
153 
154      def restore_commands(self, backupfile, keepownership=True, drop=True,
155                           dbname=None, dbhost=None, dbport=None, dbuser=None,
156 -                         dbencoding=None):
157 +                         dbencoding=None, dbschema=None):
158          dbname = dbname or self.dbname
159          cmds = []
160          mysqlcmd = ' '.join(self.mycmd('mysql', dbhost, dbport, dbuser))
161          if drop:
162              cmd = 'echo "DROP DATABASE %s;" | %s -p' % (
diff --git a/postgres.py b/postgres.py
@@ -50,10 +50,13 @@
163                         'tsearch2.sql')
164 
165 
166  class _Psycopg2Adapter(db.DBAPIAdapter):
167      """Simple Psycopg2 Adapter to DBAPI (cnx_string differs from classical ones)
168 +
169 +    It provides basic support for postgresql schemas :
170 +    cf. http://www.postgresql.org/docs/current/static/ddl-schemas.html
171      """
172      # not defined in psycopg2.extensions
173      # "select typname from pg_type where oid=705";
174      UNKNOWN = 705
175      returns_unicode = True
@@ -71,11 +74,11 @@
176          extensions.register_type(unicodearray)
177          self.BOOLEAN = extensions.BOOLEAN
178          db.DBAPIAdapter.__init__(self, native_module, pywrap)
179          self._init_psycopg2()
180 
181 -    def connect(self, host='', database='', user='', password='', port='', extra_args=None):
182 +    def connect(self, host='', database='', user='', password='', port='', schema=None, extra_args=None):
183          """Handles psycopg connection format"""
184          args = {}
185          if host:
186              args.setdefault('host', host)
187          if database:
@@ -89,12 +92,35 @@
188          cnx_string = ' '.join('%s=%s' % item for item in args.iteritems())
189          if extra_args is not None:
190              cnx_string += ' ' + extra_args
191          cnx = self._native_module.connect(cnx_string)
192          cnx.set_isolation_level(1)
193 +        self.set_search_path(cnx, schema)
194          return self._wrap_if_needed(cnx)
195 
196 +    def _schema_exists(self, cursor, schema):
197 +        cursor.execute('SELECT nspname FROM pg_namespace WHERE nspname=%(s)s',
198 +                       {'s': schema})
199 +        return cursor.fetchone() is not None
200 +
201 +    def set_search_path(self, cnx, schema):
202 +        if schema:
203 +            cursor = cnx.cursor()
204 +            try:
205 +                if not self._schema_exists(cursor, schema):
206 +                    warn("%s schema doesn't exist, search path can't be set" % schema,
207 +                         UserWarning)
208 +                    return
209 +                cursor.execute('SHOW search_path')
210 +                # entries come back as e.g. '"$user", public': strip the blanks
211 +                current = [s.strip() for s in cursor.fetchone()[0].split(',')]
212 +                schemas = [schema] + [s for s in current if s and s != schema]
213 +                cursor.execute('SET search_path TO %s;' % ','.join(schemas))
214 +                cnx.commit()  # SET is transactional: a later rollback would undo it
215 +            finally:
216 +                cursor.close()
217 +
218      def _init_psycopg2(self):
219          """initialize psycopg2 to use mx.DateTime for date and timestamps
220          instead for datetime.datetime"""
221          psycopg2 = self._native_module
222          if hasattr(psycopg2, '_lc_initialized'):
@@ -162,38 +188,41 @@
223      TYPE_MAPPING.update({
224          'TZTime' :   'time with time zone',
225          'TZDatetime':'timestamp with time zone'})
226      TYPE_CONVERTERS = db._GenericAdvFuncHelper.TYPE_CONVERTERS.copy()
227 
228 -    def pgdbcmd(self, cmd, dbhost, dbport, dbuser, *args):
229 +    def pgdbcmd(self, cmd, dbhost, dbport, dbuser, *args, **kwargs):
230          cmd = [cmd]
231          cmd += args
232          if dbhost or self.dbhost:
233              cmd.append('--host=%s' % (dbhost or self.dbhost))
234          if dbport or self.dbport:
235              cmd.append('--port=%s' % (dbport or self.dbport))
236          if dbuser or self.dbuser:
237              cmd.append('--username=%s' % (dbuser or self.dbuser))
238 +        if kwargs.get('dbschema'):  # opt-in only: dropdb/createdb reject --schema
239 +            cmd.append('--schema=%s' % kwargs['dbschema'])
240          return cmd
241 
242      def system_database(self):
243          """return the system database for the given driver"""
244          return 'template1'
245 
246      def backup_commands(self, backupfile, keepownership=True,
247 -                        dbname=None, dbhost=None, dbport=None, dbuser=None):
248 -        cmd = self.pgdbcmd('pg_dump', dbhost, dbport, dbuser, '-Fc')
249 +                        dbname=None, dbhost=None, dbport=None, dbuser=None, dbschema=None):
250 +        cmd = self.pgdbcmd('pg_dump', dbhost, dbport, dbuser, '-Fc', dbschema=dbschema or self.dbschema)
251          if not keepownership:
252              cmd.append('--no-owner')
253          cmd.append('--file')
254          cmd.append(backupfile)
255          cmd.append(dbname or self.dbname)
256          return [cmd]
257 
258      def restore_commands(self, backupfile, keepownership=True, drop=True,
259                           dbname=None, dbhost=None, dbport=None, dbuser=None,
260 -                         dbencoding=None):
261 +                         dbencoding=None, dbschema=None):
262 +        # XXX what about dbschema ?
263          dbname = dbname or self.dbname
264          cmds = []
265          if drop:
266              cmd = self.pgdbcmd('dropdb', dbhost, dbport, dbuser)
267              cmd.append(dbname)
@@ -259,10 +288,21 @@
268          dbencoding = dbencoding or self.dbencoding
269          if dbencoding:
270              sql += " ENCODING='%(dbencoding)s'"
271          cursor.execute(sql % locals())
272 
273 +    def create_schema(self, cursor, schema, granted_user=None):
274 +        """create a new database schema"""
275 +        sql = 'CREATE SCHEMA %s' % schema
276 +        if granted_user is not None:
277 +            sql += ' AUTHORIZATION %s' % granted_user
278 +        cursor.execute(sql)
279 +
280 +    def drop_schema(self, cursor, schema):
281 +        """drop a database schema"""
282 +        cursor.execute('DROP SCHEMA %s CASCADE' % schema)
283 +
284      def create_language(self, cursor, extlang):
285          """postgres specific method to install a procedural language on a database"""
286          # make sure plpythonu is not directly in template1
287          cursor.execute("SELECT * FROM pg_language WHERE lanname='%s';" % extlang)
288          if cursor.fetchall():
diff --git a/sqlite.py b/sqlite.py
@@ -146,13 +146,17 @@
289                                   microseconds)
290 
291              sqlite.register_converter('interval', convert_timedelta)
292 
293 
294 -    def connect(self, host='', database='', user='', password='', port=None, extra_args=None):
295 +    def connect(self, host='', database='', user='', password='', port=None,
296 +                schema=None, extra_args=None):
297          """Handles sqlite connection format"""
298          sqlite = self._native_module
299 +        if schema is not None:
300 +            warn('schema support is not implemented on sqlite backends, ignoring schema %s'
301 +                 % schema)
302 
303          class Sqlite3Cursor(sqlite.Cursor):
304              """cursor adapting usual dict format to pysqlite named format
305              in SQL queries
306              """
@@ -224,17 +228,17 @@
307      alter_column_support = False
308 
309      TYPE_CONVERTERS = db._GenericAdvFuncHelper.TYPE_CONVERTERS.copy()
310 
311      def backup_commands(self, backupfile, keepownership=True,
312 -                        dbname=None, dbhost=None, dbport=None, dbuser=None):
313 +                        dbname=None, dbhost=None, dbport=None, dbuser=None, dbschema=None):
314          dbname = dbname or self.dbname
315          return ['gzip -c %s > %s' % (dbname, backupfile)]
316 
317      def restore_commands(self, backupfile, keepownership=True, drop=True,
318                           dbname=None, dbhost=None, dbport=None, dbuser=None,
319 -                         dbencoding=None):
320 +                         dbencoding=None, dbschema=None):
321          return ['zcat %s > %s' % (backupfile, dbname or self.dbname)]
322 
323      def sql_create_index(self, table, column, unique=False):
324          idx = self._index_name(table, column, unique)
325          if unique:
diff --git a/sqlserver.py b/sqlserver.py
@@ -25,10 +25,11 @@
326  """
327  __docformat__ = "restructuredtext en"
328 
329  import datetime
330  import re
331 +from warnings import warn
332 
333  from logilab import database as db
334 
335  class _BaseSqlServerAdapter(db.DBAPIAdapter):
336      driver = 'Override in subclass'
@@ -57,21 +58,27 @@
337          if 'trusted_connection' in arguments:
338              cls.use_trusted_connection(True)
339          if 'autocommit' in arguments:
340              cls.use_autocommit(True)
341 
342 -    def connect(self, host='', database='', user='', password='', port=None, extra_args=None):
343 +    def connect(self, host='', database='', user='', password='', port=None,
344 +                schema=None, extra_args=None):
345          """Handles pyodbc connection format
346 
347          If extra_args is not None, it is expected to be a string
348          containing a list of semicolon separated keywords. The only
349          keyword currently supported is Trusted_Connection : if found
350          the connection string to the database will include
351          Trusted_Connection=yes (which for SqlServer will trigger using
352          Windows Authentication, and therefore no login/password is
353          required.
354          """
355 +        if schema is not None:
356 +            # NOTE: SQLServer supports schemas
357 +            # cf. http://msdn.microsoft.com/en-us/library/ms189462%28v=SQL.90%29.aspx
358 +            warn('schema support is not implemented on sqlserver backends, ignoring schema %s'
359 +                 % schema)
360          class SqlServerCursor(object):
361              """cursor adapting usual dict format to pyodbc/adobdapi format
362              in SQL queries
363              """
364              def __init__(self, cursor):
diff --git a/sqlserver2005.py b/sqlserver2005.py
@@ -80,18 +80,18 @@
365                     % table)
366          cursor.execute(sql)
367          return [r[0] for r in cursor.fetchall()]
368 
369      def backup_commands(self, backupfile, keepownership=True,
370 -                        dbname=None, dbhost=None, dbport=None, dbuser=None):
371 +                        dbname=None, dbhost=None, dbport=None, dbuser=None, dbschema=None):
372          return [[sys.executable, os.path.normpath(__file__),
373                   "_SqlServer2005FuncHelper._do_backup", dbhost or self.dbhost,
374                   dbname or self.dbname, backupfile]]
375 
376      def restore_commands(self, backupfile, keepownership=True, drop=True,
377                           dbname=None, dbhost=None, dbport=None, dbuser=None,
378 -                         dbencoding=None):
379 +                         dbencoding=None, dbschema=None):
380          return [[sys.executable, os.path.normpath(__file__),
381                  "_SqlServer2005FuncHelper._do_restore", dbhost or self.dbhost,
382                   dbname or self.dbname, backupfile],
383                  ]
384 
diff --git a/test/unittest_db.py b/test/unittest_db.py
@@ -78,15 +78,15 @@
385          self.assertEqual(PREFERED_DRIVERS['postgres'], expected)
386 
387 
388  class GetCnxTC(TestCase):
389      def setUp(self):
390 +        self.host = 'localhost'
391          try:
392 -            socket.gethostbyname('centaurus')
393 +            socket.gethostbyname(self.host)
394          except:
395              self.skipTest("those tests require specific DB configuration")
396 -        self.host = 'centaurus'
397          self.db = 'template1'
398          self.user = getlogin()
399          self.passwd = getlogin()
400          self.old_drivers = PREFERED_DRIVERS['postgres'][:]
401 
@@ -320,7 +320,67 @@
402 
403          def test_varbinary_none(self):
404              self.varbinary_none()
405 
406 
407 +class PostgresqlDatabaseSchemaTC(TestCase):
408 +    host = 'localhost'
409 +    database = 'template1'
410 +    user = password = getlogin()
411 +    schema = 'tests'
412 +
413 +    def setUp(self):
414 +        try:
415 +            self.module = get_dbapi_compliant_module('postgres')
416 +        except ImportError:
417 +            self.skipTest('postgresql dbapi module not installed')
418 +        try:
419 +            cnx = self.get_connection()
420 +        except Exception:
421 +            self.skipTest('could not connect to %s:%s@%s/%s'
422 +                          % (self.user, self.password, self.host, self.database))
423 +        self._execute(cnx, 'CREATE SCHEMA %s' % self.schema)
424 +        cnx.close()
425 +
426 +    def tearDown(self):
427 +        cnx = self.get_connection()
428 +        self._execute(cnx, 'DROP SCHEMA %s' % self.schema)
429 +        cnx.close()
430 +
431 +    def _execute(self, cnx, sql):
432 +        cursor = cnx.cursor()
433 +        cursor.execute(sql)
434 +        cnx.commit()
435 +        cursor.close()
436 +        cnx.close()
437 +
438 +    def get_connection(self, schema=None):
439 +        return self.module.connect(host=self.host, database=self.database,
440 +                                   user=self.user, password=self.password,
441 +                                   schema=schema)
442 +
443 +    def assertRsetEqual(self, rset, expected_rset):
444 +        # NOTE: different drivers will use different result structures
445 +        #       (list of lists, list of tuples, etc.)
446 +        self.assertEqual(len(rset), len(expected_rset))
447 +        for line, expected_line in zip(rset, expected_rset):
448 +            self.assertSequenceEqual(line, expected_line)
449 +
450 +    def test_database_schema(self):
451 +        """Tests database schema support"""
452 +        cnx = self.get_connection(schema=self.schema)
453 +        cursor = cnx.cursor()
454 +        try:
455 +            cursor.execute('CREATE TABLE x(x integer)')
456 +            cursor.execute('INSERT INTO x VALUES(12)')
457 +            cursor.execute('SELECT x from x')
458 +            self.assertRsetEqual(cursor.fetchall(), [[12]])
459 +            cursor.execute('SELECT x from tests.x')
460 +            self.assertRsetEqual(cursor.fetchall(), [[12]])
461 +            self.assertRaises(self.module.Error, cursor.execute, 'SELECT x from public.x')
462 +        finally:
463 +            cnx.rollback()
464 +            cnx.close()
465 +
466 +
467  if __name__ == '__main__':
468      unittest_main()