From b1b9382d91e4b2e863225179cde4a61f0300a233 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gerhard=20H=C3=A4ring?= Date: Sat, 29 Mar 2008 00:41:18 +0000 Subject: [PATCH] Added missing files for new iterdump method. --- Lib/sqlite3/dump.py | 63 ++++++++++++++++++++++++++++++++++++++++ Lib/sqlite3/test/dump.py | 52 +++++++++++++++++++++++++++++++++ 2 files changed, 115 insertions(+) create mode 100644 Lib/sqlite3/dump.py create mode 100644 Lib/sqlite3/test/dump.py diff --git a/Lib/sqlite3/dump.py b/Lib/sqlite3/dump.py new file mode 100644 index 00000000000..409a405cf81 --- /dev/null +++ b/Lib/sqlite3/dump.py @@ -0,0 +1,63 @@ +# Mimic the sqlite3 console shell's .dump command +# Author: Paul Kippes + +def _iterdump(connection): + """ + Returns an iterator to the dump of the database in an SQL text format. + + Used to produce an SQL dump of the database. Useful to save an in-memory + database for later restoration. This function should not be called + directly but instead called from the Connection method, iterdump(). + """ + + cu = connection.cursor() + yield('BEGIN TRANSACTION;') + + # sqlite_master table contains the SQL CREATE statements for the database. + q = """ + SELECT name, type, sql + FROM sqlite_master + WHERE sql NOT NULL AND + type == 'table' + """ + schema_res = cu.execute(q) + for table_name, type, sql in schema_res.fetchall(): + if table_name == 'sqlite_sequence': + yield('DELETE FROM sqlite_sequence;') + elif table_name == 'sqlite_stat1': + yield('ANALYZE sqlite_master;') + elif table_name.startswith('sqlite_'): + continue + # NOTE: Virtual table support not implemented + #elif sql.startswith('CREATE VIRTUAL TABLE'): + # qtable = table_name.replace("'", "''") + # yield("INSERT INTO sqlite_master(type,name,tbl_name,rootpage,sql)"\ + # "VALUES('table','%s','%s',0,'%s');" % + # qtable, + # qtable, + # sql.replace("''")) + else: + yield('%s;' % sql) + + # Build the insert statement for each row of the current table + res = cu.execute("PRAGMA table_info('%s')" % table_name) + column_names = [str(table_info[1]) for table_info in res.fetchall()] + q = "SELECT 'INSERT INTO \"%(tbl_name)s\" VALUES(" + q += ",".join(["'||quote(" + col + ")||'" for col in column_names]) + q += ")' FROM '%(tbl_name)s'" + query_res = cu.execute(q % {'tbl_name': table_name}) + for row in query_res: + yield("%s;" % row[0]) + + # Now when the type is 'index', 'trigger', or 'view' + q = """ + SELECT name, type, sql + FROM sqlite_master + WHERE sql NOT NULL AND + type IN ('index', 'trigger', 'view') + """ + schema_res = cu.execute(q) + for name, type, sql in schema_res.fetchall(): + yield('%s;' % sql) + + yield('COMMIT;') diff --git a/Lib/sqlite3/test/dump.py b/Lib/sqlite3/test/dump.py new file mode 100644 index 00000000000..f40876a8e7e --- /dev/null +++ b/Lib/sqlite3/test/dump.py @@ -0,0 +1,52 @@ +# Author: Paul Kippes + +import unittest +import sqlite3 as sqlite + +class DumpTests(unittest.TestCase): + def setUp(self): + self.cx = sqlite.connect(":memory:") + self.cu = self.cx.cursor() + + def tearDown(self): + self.cx.close() + + def CheckTableDump(self): + expected_sqls = [ + "CREATE TABLE t1(id integer primary key, s1 text, " \ + "t1_i1 integer not null, i2 integer, unique (s1), " \ + "constraint t1_idx1 unique (i2));" + , + "INSERT INTO \"t1\" VALUES(1,'foo',10,20);" + , + "INSERT INTO \"t1\" VALUES(2,'foo2',30,30);" + , + "CREATE TABLE t2(id integer, t2_i1 integer, " \ + "t2_i2 integer, primary key (id)," \ + "foreign key(t2_i1) references t1(t1_i1));" + , + "CREATE TRIGGER trigger_1 update of t1_i1 on t1 " \ + "begin " \ + "update t2 set t2_i1 = new.t1_i1 where t2_i1 = old.t1_i1; " \ + "end;" + , + "CREATE VIEW v1 as select * from t1 left join t2 " \ + "using (id);" + ] + [self.cu.execute(s) for s in expected_sqls] + i = self.cx.iterdump() + actual_sqls = [s for s in i] + expected_sqls = ['BEGIN TRANSACTION;'] + expected_sqls + \ + ['COMMIT;'] + [self.assertEqual(expected_sqls[i], actual_sqls[i]) + for i in range(len(expected_sqls))] + +def suite(): + return unittest.TestSuite(unittest.makeSuite(DumpTests, "Check")) + +def test(): + runner = unittest.TextTestRunner() + runner.run(suite()) + +if __name__ == "__main__": + test()