diff --git a/Lib/sqlite3/test/dbapi.py b/Lib/sqlite3/test/dbapi.py index 39c9bf5b611..ab331353394 100644 --- a/Lib/sqlite3/test/dbapi.py +++ b/Lib/sqlite3/test/dbapi.py @@ -135,6 +135,26 @@ class ConnectionTests(unittest.TestCase): def test_close(self): self.cx.close() + def test_use_after_close(self): + sql = "select 1" + cu = self.cx.cursor() + res = cu.execute(sql) + self.cx.close() + self.assertRaises(sqlite.ProgrammingError, res.fetchall) + self.assertRaises(sqlite.ProgrammingError, cu.execute, sql) + self.assertRaises(sqlite.ProgrammingError, cu.executemany, sql, []) + self.assertRaises(sqlite.ProgrammingError, cu.executescript, sql) + self.assertRaises(sqlite.ProgrammingError, self.cx.execute, sql) + self.assertRaises(sqlite.ProgrammingError, + self.cx.executemany, sql, []) + self.assertRaises(sqlite.ProgrammingError, self.cx.executescript, sql) + self.assertRaises(sqlite.ProgrammingError, + self.cx.create_function, "t", 1, lambda x: x) + self.assertRaises(sqlite.ProgrammingError, self.cx.cursor) + with self.assertRaises(sqlite.ProgrammingError): + with self.cx: + pass + def test_exceptions(self): # Optional DB-API extension. self.assertEqual(self.cx.Warning, sqlite.Warning) diff --git a/Modules/_sqlite/connection.c b/Modules/_sqlite/connection.c index e15629c8aba..62c4dc3bbb3 100644 --- a/Modules/_sqlite/connection.c +++ b/Modules/_sqlite/connection.c @@ -258,6 +258,16 @@ connection_clear(pysqlite_Connection *self) return 0; } +static void +connection_close(pysqlite_Connection *self) +{ + if (self->db) { + int rc = sqlite3_close_v2(self->db); + assert(rc == SQLITE_OK); + self->db = NULL; + } +} + static void connection_dealloc(pysqlite_Connection *self) { @@ -266,9 +276,7 @@ connection_dealloc(pysqlite_Connection *self) tp->tp_clear((PyObject *)self); /* Clean up if user has not called .close() explicitly. */ - if (self->db) { - sqlite3_close_v2(self->db); - } + connection_close(self); tp->tp_free(self); Py_DECREF(tp); @@ -353,24 +361,12 @@ static PyObject * pysqlite_connection_close_impl(pysqlite_Connection *self) /*[clinic end generated code: output=a546a0da212c9b97 input=3d58064bbffaa3d3]*/ { - int rc; - if (!pysqlite_check_thread(self)) { return NULL; } pysqlite_do_all_statements(self, ACTION_FINALIZE, 1); - - if (self->db) { - rc = sqlite3_close_v2(self->db); - - if (rc != SQLITE_OK) { - _pysqlite_seterror(self->db); - return NULL; - } else { - self->db = NULL; - } - } + connection_close(self); Py_RETURN_NONE; } @@ -1820,6 +1816,9 @@ static PyObject * pysqlite_connection_enter_impl(pysqlite_Connection *self) /*[clinic end generated code: output=457b09726d3e9dcd input=127d7a4f17e86d8f]*/ { + if (!pysqlite_check_connection(self)) { + return NULL; + } return Py_NewRef((PyObject *)self); }