[postgres] Take schema into account in pg_table and pg_indexes requests

Closes #100459

author: Julien Cristau <julien.cristau@logilab.fr>
changeset: 82454dcf6a4a
branch: default
phase: draft
hidden: yes
parent revision: #d49278db2104 Add basic support for postgresql database schemas (closes #66133)
child revision: <not specified>
files modified by this revision:
postgres.py
test/unittest_db.py
# HG changeset patch
# User Julien Cristau <julien.cristau@logilab.fr>
# Date 1400256891 -7200
# Fri May 16 18:14:51 2014 +0200
# Node ID 82454dcf6a4a4b4a8368058687ac2a3691317ec9
# Parent d49278db210433d5f0b4c81db6b36e311a715648
[postgres] Take schema into account in pg_table and pg_indexes requests

Closes #100459

diff --git a/postgres.py b/postgres.py
@@ -357,21 +357,28 @@
1      def list_databases(self, cursor):
2          """return the list of existing databases"""
3          cursor.execute('SELECT datname FROM pg_database')
4          return [r[0] for r in cursor.fetchall()]
5 
6 -    def list_tables(self, cursor):
7 +    def list_tables(self, cursor, schema=None):
8          """return the list of tables of a database"""
9 -        cursor.execute("SELECT tablename FROM pg_tables")
10 +        schema = schema or self.dbschema
11 +        where = " WHERE schemaname=%(s)s" if schema else ""  # no schema: all
12 +        cursor.execute("SELECT tablename FROM pg_tables" + where, {'s': schema})
13          return [r[0] for r in cursor.fetchall()]
14 
15      def list_indices(self, cursor, table=None):
16          """return the list of indices of a database, only for the given table if specified"""
17          sql = "SELECT indexname FROM pg_indexes"
18 +        restrictions = []
19          if table:
20 -            sql += " WHERE LOWER(tablename)='%s'" % table.lower()
21 -        cursor.execute(sql)
22 +            restrictions.append('LOWER(tablename)=%(table)s')
23 +        if self.dbschema:
24 +            restrictions.append('schemaname=%(s)s')
25 +        if restrictions:
26 +            sql += ' WHERE %s' % ' AND '.join(restrictions)
27 +        cursor.execute(sql, {'s': self.dbschema, 'table': table and table.lower()})
28          return [r[0] for r in cursor.fetchall()]
29 
30      # full-text search customization ###########################################
31 
32      fti_table = 'appears'
@@ -467,11 +474,11 @@
33          """If necessary, install extensions at database creation time.
34 
35          For postgres, install tsearch2 if not installed by the template.
36          """
37          tstables = []
38 -        for table in self.list_tables(cursor):
39 +        for table in self.list_tables(cursor, schema='pg_catalog'):
40              if table.startswith('pg_ts'):
41                  tstables.append(table)
42          if tstables:
43              self.logger.info('pg_ts_dict already present, do not execute tsearch2.sql')
44              if owner:
diff --git a/test/unittest_db.py b/test/unittest_db.py
@@ -437,8 +437,31 @@
45              self.assertRaises(self.module.Error, cursor.execute, 'SELECT x from public.x')
46          finally:
47              cnx.rollback()
48              cnx.close()
49 
50 +    def test_list_tables(self):
51 +        helper = get_db_helper('postgres')
52 +        cnx = self.get_connection(schema=self.schema)
53 +        cursor = cnx.cursor()
54 +        try:
55 +            cursor.execute('CREATE TABLE x(x integer)')
56 +            self.assertNotIn('x', helper.list_tables(cursor, schema='public'))
57 +            self.assertIn('x', helper.list_tables(cursor, schema=self.schema))
58 +        finally:
59 +            cnx.close()
60 +
61 +    def test_list_indices(self):
62 +        helper = get_db_helper('postgres')
63 +        cnx = self.get_connection(schema=self.schema)
64 +        cursor = cnx.cursor()
65 +        try:
66 +            cursor.execute('CREATE TABLE x(x integer)')
67 +            cursor.execute('CREATE INDEX x_idx ON x(x)')
68 +            self.assertIn('x_idx', helper.list_indices(cursor))
69 +            self.assertIn('x_idx', helper.list_indices(cursor, table='X'))
70 +        finally:
71 +            cnx.close()
72 +
73 
74  if __name__ == '__main__':
75      unittest_main()