[postgres] Take schema into account in pg_table and pg_indexes requests

Closes #100459

author: Julien Cristau <julien.cristau@logilab.fr>
changeset: d9a772cdbb19
branch: default
phase: public
hidden: no
parent revision: #77bcb93ccf50 Add basic support for postgresql database schemas (closes #66133)
child revision: #fa3cb6e47777 [test] skip if we can't connect to a local mysql
files modified by this revision
postgres.py
test/unittest_db.py
# HG changeset patch
# User Julien Cristau <julien.cristau@logilab.fr>
# Date 1400256891 -7200
# Fri May 16 18:14:51 2014 +0200
# Node ID d9a772cdbb19f95e156fd45682ecd7f5eef89af8
# Parent 77bcb93ccf50b9633ee0610b0d731093ec6c57b9
[postgres] Take schema into account in pg_table and pg_indexes requests

Closes #100459

diff --git a/postgres.py b/postgres.py
@@ -319,21 +319,30 @@
1      def list_databases(self, cursor):
2          """return the list of existing databases"""
3          cursor.execute('SELECT datname FROM pg_database')
4          return [r[0] for r in cursor.fetchall()]
5 
6 -    def list_tables(self, cursor):
7 +    def list_tables(self, cursor, schema=None):  # schema: optional schemaname filter
8          """return the list of tables of a database"""
9 -        cursor.execute("SELECT tablename FROM pg_tables")
10 +        schema = schema or self.dbschema  # an explicit schema wins over the helper's configured one
11 +        sql = "SELECT tablename FROM pg_tables"
12 +        if schema:  # no schema at all: the query is left unrestricted
13 +            sql += " WHERE schemaname=%(s)s"
14 +        cursor.execute(sql, {'s': schema})  # NOTE(review): mapping is passed even when sql has no placeholder; confirm the driver accepts that
15          return [r[0] for r in cursor.fetchall()]
16 
17      def list_indices(self, cursor, table=None):
18          """return the list of indices of a database, only for the given table if specified"""
19          sql = "SELECT indexname FROM pg_indexes"
20 +        restrictions = []
21          if table:
22 -            sql += " WHERE LOWER(tablename)='%s'" % table.lower()
23 -        cursor.execute(sql)
24 +            restrictions.append('LOWER(tablename)=%(table)s')  # compared with table.lower(), as before
25 +        if self.dbschema:  # schema comes from the helper only; no per-call override as in list_tables
26 +            restrictions.append('schemaname=%(s)s')
27 +        if restrictions:
28 +            sql += ' WHERE %s' % ' AND '.join(restrictions)
29 +        cursor.execute(sql, {'s': self.dbschema, 'table': table and table.lower()})  # lower-case so the LOWER(tablename) match stays case-insensitive
30          return [r[0] for r in cursor.fetchall()]
31 
32      # full-text search customization ###########################################
33 
34      fti_table = 'appears'
@@ -429,11 +438,11 @@
35          """If necessary, install extensions at database creation time.
36 
37          For postgres, install tsearch2 if not installed by the template.
38          """
39          tstables = []
40 -        for table in self.list_tables(cursor):
41 +        for table in self.list_tables(cursor, schema='pg_catalog'):
42              if table.startswith('pg_ts'):
43                  tstables.append(table)
44          if tstables:
45              self.logger.info('pg_ts_dict already present, do not execute tsearch2.sql')
46              if owner:
diff --git a/test/unittest_db.py b/test/unittest_db.py
@@ -379,8 +379,31 @@
47              self.assertRaises(self.module.Error, cursor.execute, 'SELECT x from public.x')
48          finally:
49              cnx.rollback()
50              cnx.close()
51 
52 +    def test_list_tables(self):
53 +        helper = get_db_helper('postgres')
54 +        cnx = self.get_connection(schema=self.schema)
55 +        cursor = cnx.cursor()
56 +        try:
57 +            cursor.execute('CREATE TABLE x(x integer)')
58 +            self.assertNotIn('x', helper.list_tables(cursor))  # NOTE(review): with helper.dbschema unset list_tables() is unfiltered and would include x; confirm helper.dbschema here
59 +            self.assertIn('x', helper.list_tables(cursor, schema=self.schema))  # explicit schema argument
60 +        finally:
61 +            cnx.close()  # unlike the previous test, no explicit rollback() before close()
62 +
63 +    def test_list_indices(self):
64 +        helper = get_db_helper('postgres')
65 +        cnx = self.get_connection(schema=self.schema)
66 +        cursor = cnx.cursor()
67 +        try:
68 +            cursor.execute('CREATE TABLE x(x integer)')
69 +            cursor.execute('CREATE INDEX x_idx ON x(x)')
70 +            self.assertIn('x_idx', helper.list_indices(cursor))  # schema filter applies only if helper.dbschema is set
71 +            self.assertIn('x_idx', helper.list_indices(cursor, table='x'))  # lower-case name only; a mixed-case table= argument is not exercised
72 +        finally:
73 +            cnx.close()  # unlike the previous test, no explicit rollback() before close()
74 +
75 
76  if __name__ == '__main__':
77      unittest_main()