implement CAST function

author: Adrien Di Mascio <Adrien.DiMascio@logilab.fr>
changeset: f4b70e749f70
branch: default
phase: public
hidden: no
parent: revision #0d3439a238b1 — added sql_restart_sequence to dbhelper classes (closes #65317)
child: revision #6497fc969cff — basic support for regexp-based pattern matching using a REGEXP operator
files modified by this revision:
ChangeLog
__init__.py
test/unittest_db.py
# HG changeset patch
# User Adrien Di Mascio <Adrien.DiMascio@logilab.fr>
# Date 1303896891 -7200
# Wed Apr 27 11:34:51 2011 +0200
# Node ID f4b70e749f70bf70c4ca19dcc25eb03b50ce9314
# Parent 0d3439a238b1417945a540c4ee8d14c056a573b1
implement CAST function

diff --git a/ChangeLog b/ChangeLog
@@ -1,14 +1,19 @@
1  Changelog for logilab database package
2  ======================================
3 
4 +
5 +	--
6 +
7 +    * new CAST function
8 +
9  2011-04-01  --  1.5.0
10      * fix deprecation warning and depend on common 0.55.2
11 
12      * TZ datetime and time support
13 
14 -2011-04-28  --  1.4.0	
15 +2011-03-28  --  1.4.0
16      * abstract OFFSET LIMIT support (incl. support for MSSQL)
17 
18 
19  2011-01-13  --  1.3.2
20      * fix index dropping on sql server 2005
diff --git a/__init__.py b/__init__.py
@@ -351,11 +351,13 @@
21 
22  # set of hooks that should be called at connection opening time.
23  # mostly for sqlite'stored procedures that have to be registered...
24  SQL_CONNECT_HOOKS = {}
25  ALL_BACKENDS = object()
26 -
27 +# marker for cases where rtype depends on arguments passed to the function
28 +# In that case, functions should implement dynamic_rtype() method
29 +DYNAMIC_RTYPE = object()
30 
31  class FunctionDescr(object):
32      supported_backends = ALL_BACKENDS
33      rtype = None # None <-> returned type should be the same as the first argument
34      aggregat = False
@@ -465,10 +467,25 @@
35      field = 'MINUTE'
36 
37  class SECOND(ExtractDateField):
38      field = 'SECOND'
39 
40 +class CAST(FunctionDescr):
41 +    """usage is CAST(datatype, expression)
42 +
43 +    sql-92 standard says (CAST <expr> as <type>)
44 +    """
45 +    minargs = maxargs = 2
46 +    supported_backends = ('postgres', 'sqlite', 'mysql', 'sqlserver2005')
47 +    rtype = DYNAMIC_RTYPE
48 +
49 +    def as_sql(self, backend, args):
50 +        yamstype, varname = args
51 +        db_helper = get_db_helper(backend)
52 +        sqltype = db_helper.TYPE_MAPPING[yamstype]
53 +        return 'CAST(%s AS %s)' % (varname, sqltype)
54 +
55 
56  class _FunctionRegistry(object):
57      def __init__(self, registry=None):
58          if registry is None:
59              self.functions = {}
@@ -509,10 +526,12 @@
60      # aggregate functions
61      MIN, MAX, SUM, COUNT, AVG,
62      # transformation functions
63      ABS, UPPER, LOWER, LENGTH, DATE, RANDOM,
64      YEAR, MONTH, DAY, HOUR, MINUTE, SECOND, SUBSTRING,
65 +    # cast functions
66 +    CAST,
67      # keyword function
68      IN):
69      SQL_FUNCTIONS_REGISTRY.register_function(func_class())
70 
71  def register_function(funcdef):
diff --git a/test/unittest_db.py b/test/unittest_db.py
@@ -16,10 +16,12 @@
72  # You should have received a copy of the GNU Lesser General Public License along
73  # with logilab-database. If not, see <http://www.gnu.org/licenses/>.
74  """
75  unit tests for module logilab.common.db
76  """
77 +from __future__ import with_statement
78 +
79  import socket
80 
81  from logilab.common.testlib import TestCase, unittest_main
82  from logilab.common.shellutils import getlogin
83  from logilab.database import *
@@ -49,18 +51,18 @@
84      def testNormal(self):
85          set_prefered_driver('pg','bar', self.drivers)
86          self.assertEqual('bar', self.drivers['pg'][0])
87 
88      def testFailuresDb(self):
89 -        ex = self.assertRaises(UnknownDriver,
90 -                               set_prefered_driver, 'oracle','bar', self.drivers)
91 -        self.assertEqual(ex.args[0], 'Unknown driver oracle')
92 +        with self.assertRaises(UnknownDriver) as cm:
93 +            set_prefered_driver('oracle','bar', self.drivers)
94 +        self.assertEqual(str(cm.exception), 'Unknown driver oracle')
95 
96      def testFailuresDriver(self):
97 -        ex = self.assertRaises(UnknownDriver,
98 -                               set_prefered_driver, 'pg','baz', self.drivers)
99 -        self.assertEqual(ex.args[0], 'Unknown module baz for pg')
100 +        with self.assertRaises(UnknownDriver) as cm:
101 +            set_prefered_driver('pg','baz', self.drivers)
102 +        self.assertEqual(str(cm.exception), 'Unknown module baz for pg')
103 
104      def testGlobalVar(self):
105          # XXX: Is this test supposed to be useful ? Is it supposed to test
106          #      set_prefered_driver ?
107          old_drivers = PREFERED_DRIVERS['postgres'][:]