2024-07-05 00:13:09 +00:00
|
|
|
local re = require "re"
|
|
|
|
local path = require "path"
|
|
|
|
|
|
|
|
-- Matches migration file basenames that begin with a run of digits and end
-- in ".sql"; the leading digits are captured and later used as the
-- migration's sequence number.
local filenameRegex = re.compile([=[^([0-9]+).*\.sql$]=])
|
|
|
|
|
2024-07-06 17:41:55 +00:00
|
|
|
--- Raise a Lua error when the connection's current result code is a failure.
-- sqlite3.OK, sqlite3.ROW and sqlite3.DONE count as success; any other code
-- raises an error carrying db:errmsg().
-- @param db      database handle exposing errcode() and errmsg()
-- @param prefix  optional context string; when given the message is
--                formatted as "prefix: message"
local function throwSqlError(db, prefix)
    local status = db:errcode()
    if status == sqlite3.OK or status == sqlite3.ROW or status == sqlite3.DONE then
        return
    end
    if not prefix then
        error(db:errmsg())
    end
    error(string.format("%s: %s", prefix, db:errmsg()))
end
|
|
|
|
|
|
|
|
--- Run `func` inside a uniquely named SQLite savepoint.
-- On success the savepoint is released and the first return value of `func`
-- is returned. If `func` raises, the savepoint is rolled back AND released
-- (ROLLBACK TO alone leaves the savepoint, and thus the enclosing
-- transaction, open) and the error is re-raised.
-- @param db           database handle exposing exec(), errcode(), errmsg()
-- @param func         zero-argument function performing the work
-- @param errorPrefix  optional context string; prepended as "prefix: message"
--                     to SQL errors and to errors raised by `func`
-- @return the first value returned by `func`
local function transact(db, func, errorPrefix)
    local name = string.format("savepoint_%d", math.abs(math.random(0)))
    db:exec("savepoint " .. name)
    throwSqlError(db, errorPrefix)

    local success, result = pcall(func)
    if not success then
        db:exec("rollback to " .. name)
        db:exec("release " .. name)
        -- The second argument of error() is a numeric stack level, not a
        -- prefix; level 0 keeps the message free of an extra position tag.
        if errorPrefix then
            error(string.format("%s: %s", errorPrefix, tostring(result)), 0)
        end
        error(result, 0)
    end

    db:exec("release " .. name)
    -- Releasing the outermost savepoint commits, which can itself fail.
    throwSqlError(db, errorPrefix)
    return result
end
|
|
|
|
|
|
|
|
--- Return the highest sequence number recorded in the migrations table.
-- coalesce() maps the NULL that max() yields for an empty table to -1 in
-- SQL, so the result no longer depends on a nil first value silently
-- terminating the generic `for` loop.
-- @param db  database handle exposing urows(), errcode(), errmsg()
-- @return the largest recorded seqnum, or -1 when none is recorded
local function getCurrentMigration(db)
    for seqnum in db:urows("select coalesce(max(seqnum), -1) from migrations") do
        return seqnum
    end
    -- Only reached when the query produced no row at all.
    throwSqlError(db, "failed to get last migration sequence number")
    return -1
end
|
|
|
|
|
|
|
|
--- Insert `seqnum` into the migrations table.
-- Raises if the statement cannot be prepared, bound or executed. The
-- statement is finalized before the final error check so it is not left
-- open on the connection.
-- @param db      database handle exposing prepare(), errcode(), errmsg()
-- @param seqnum  sequence number of the migration that was just applied
local function recordMigration(db, seqnum)
    local stmt = db:prepare("insert into migrations (seqnum) values (?)")
    if not stmt then
        throwSqlError(db, "failed to record migration")
        -- Fallback in case the connection reports no error code.
        error("failed to record migration: could not prepare statement")
    end

    stmt:bind_values(seqnum)
    throwSqlError(db, "failed to record migration")

    stmt:step()
    -- Finalizing reports the outcome of the step on the connection, so the
    -- check below also catches a failed insert (e.g. duplicate seqnum).
    stmt:finalize()
    throwSqlError(db, "failed to record migration")
end
|
|
|
|
|
|
|
|
--- Apply every pending SQL migration found under "/migrations/".
-- Files are discovered with GetZipPaths; a file's sequence number is the
-- leading run of digits in its basename. Migrations are applied in
-- ascending sequence order, each inside its own savepoint together with the
-- bookkeeping row that records it. Files with no sequence number, or whose
-- sequence number was already seen, are reported and skipped.
-- @param db  database handle the migrations are applied to
local function migrate(db)
    local MIGRATION_PATH = "/migrations/"
    local migrations = {}
    -- Set of sequence numbers already claimed, keyed by number. Appending
    -- with table.insert would key by position and break the lookup below.
    local seqnums = {}

    for _, p in ipairs(GetZipPaths(MIGRATION_PATH)) do
        -- The listing may contain the directory entry itself; skip it.
        if p ~= MIGRATION_PATH then
            local _, seqnum = filenameRegex:search(path.basename(p))
            seqnum = tonumber(seqnum)
            if seqnum and not seqnums[seqnum] then
                seqnums[seqnum] = true
                migrations[#migrations + 1] = {
                    filename = p,
                    seqnum = seqnum,
                }
            else
                print(string.format("found weird migration name: %s", p))
            end
        end
    end

    db:exec[[
        create table if not exists migrations (
            seqnum integer primary key
        ) without rowid
    ]]
    throwSqlError(db, "failed to create migrations table")

    table.sort(migrations, function(a, b) return a.seqnum < b.seqnum end)

    local announced
    for _, mig in ipairs(migrations) do
        local seqnum = mig.seqnum
        local lastSeqnum = getCurrentMigration(db)
        if seqnum > lastSeqnum then
            if not announced then
                print"Applying migrations:"
                announced = true
            end
            print(mig.filename)
            local migsql = LoadAsset(mig.filename)
            if not migsql then
                error(string.format("failed to load migration: %s", mig.filename))
            end
            -- Run the migration and record it atomically.
            transact(db, function()
                db:exec(migsql)
                throwSqlError(db, "migration failed")
                recordMigration(db, seqnum)
            end)
        end
    end
end
|
|
|
|
|
|
|
|
-- Public interface of this module.
return { migrate = migrate }
|