peachy/assets/.lua/migrate.lua

97 lines
2.5 KiB
Lua

local re = require "re"
local path = require "path"
-- Matches migration basenames such as "0001_create_users.sql";
-- capture 1 holds the leading decimal digits (the sequence number).
local filenameRegex = re.compile[=[^([0-9]+).*\.sql$]=]
--- Raise a Lua error if the most recent SQLite call on `db` failed.
-- Result codes OK, ROW and DONE count as success and return silently;
-- any other code raises db:errmsg(), prefixed with "<prefix>: " when a
-- prefix is supplied.
-- @param db lsqlite3 database handle
-- @tparam[opt] string prefix context prepended to the error message
local function throwSqlError(db, prefix)
  local rc = db:errcode()
  local succeeded = rc == sqlite3.OK or rc == sqlite3.ROW or rc == sqlite3.DONE
  if succeeded then
    return
  end
  local message = db:errmsg()
  if prefix then
    message = string.format("%s: %s", prefix, message)
  end
  error(message)
end
--- Run `func` inside a SQLite savepoint on `db`.
-- On success the savepoint is released and func's first return value is
-- returned. If func raises, or the final RELEASE fails, the savepoint is
-- rolled back *and* released so nothing is left open on the connection,
-- and the error is re-raised. String errors are prefixed with
-- "<errorPrefix>: " when errorPrefix is given; non-string error objects
-- are re-raised unchanged.
-- @param db lsqlite3 database handle
-- @tparam function func work to run; called with no arguments
-- @tparam[opt] string errorPrefix context prepended to error messages
-- @return the first value returned by func
local function transact(db, func, errorPrefix)
  -- A positive random suffix keeps the name a valid identifier (the old
  -- math.abs(math.random(0)) could yield a negative mininteger).
  local name = string.format("savepoint_%d", math.random(1, 0x7fffffff))
  db:exec("savepoint " .. name)
  throwSqlError(db, errorPrefix)
  local success, result = pcall(function()
    local value = func()
    -- Release inside the protected call so a failing RELEASE is also
    -- rolled back below instead of leaving the savepoint open.
    db:exec("release " .. name)
    throwSqlError(db)
    return value
  end)
  if not success then
    -- ROLLBACK TO undoes the work but keeps the savepoint (and any
    -- transaction it started) active; RELEASE is required to close it.
    db:exec("rollback to " .. name)
    db:exec("release " .. name)
    if errorPrefix and type(result) == "string" then
      result = string.format("%s: %s", errorPrefix, result)
    end
    -- Level 0: the message already carries the original error's position.
    error(result, 0)
  end
  return result
end
--- Return the highest migration sequence number recorded in `db`.
-- Returns -1 when the migrations table is empty. Sequence numbers parsed
-- from file names are never negative, so -1 makes every migration pending.
-- @param db lsqlite3 database handle
-- @treturn number last applied sequence number, or -1
local function getCurrentMigration(db)
  local current = -1
  -- coalesce() guarantees one non-NULL row. The loop therefore runs to the
  -- end and urows() steps its statement to completion. Returning from
  -- inside the loop, or a NULL first column, would stop iteration early and
  -- leave the statement unfinished.
  for seqnum in db:urows("select coalesce(max(seqnum), -1) from migrations") do
    current = seqnum
  end
  throwSqlError(db, "failed to get last migration sequence number")
  return current
end
--- Insert `seqnum` into the migrations table.
-- Raises "failed to record migration: <sqlite message>" if preparing,
-- binding or executing the insert fails. The prepared statement is always
-- finalized.
-- @param db lsqlite3 database handle
-- @tparam number seqnum sequence number of the migration just applied
local function recordMigration(db, seqnum)
  local prefix = "failed to record migration"
  local stmt = db:prepare("insert into migrations (seqnum) values (?)")
  if not stmt then
    error(string.format("%s: %s", prefix, db:errmsg()))
  end
  local rc = stmt:bind_values(seqnum)
  if rc == sqlite3.OK then
    rc = stmt:step()
  end
  -- Read the message before finalize() so it describes the failing call.
  local message = db:errmsg()
  stmt:finalize()
  -- A successful INSERT steps to DONE; anything else (e.g. a duplicate
  -- primary key) must not be silently ignored.
  if rc ~= sqlite3.DONE then
    error(string.format("%s: %s", prefix, message))
  end
end
--- Collect the migration scripts stored under `dir` in the zip asset store.
-- A script's basename must match filenameRegex; its leading digits are its
-- sequence number. Non-matching entries and repeated sequence numbers are
-- reported with print() and skipped.
-- @tparam string dir zip directory prefix, e.g. "/.migrations/"
-- @return list of { filename = <zip path>, seqnum = <number> }, ascending by seqnum
local function collectMigrations(dir)
  local migrations = {}
  local seen = {} -- seqnum -> zip path of the first file that used it
  for _, p in ipairs(GetZipPaths(dir)) do
    -- presumably GetZipPaths can list the directory entry itself; skip it
    if p ~= dir then
      local _, digits = filenameRegex:search(path.basename(p))
      local seqnum = tonumber(digits)
      if not seqnum then
        print(string.format("found weird migration name: %s", p))
      elseif seen[seqnum] then
        print(string.format(
          "duplicate migration sequence number %s: %s (already used by %s), skipping",
          seqnum, p, seen[seqnum]))
      else
        seen[seqnum] = p
        migrations[#migrations + 1] = { filename = p, seqnum = seqnum }
      end
    end
  end
  table.sort(migrations, function(a, b) return a.seqnum < b.seqnum end)
  return migrations
end

--- Apply every pending migration script to `db`.
-- Ensures the `migrations` bookkeeping table exists. It then runs, in
-- ascending order, each script whose sequence number is greater than the
-- highest one already recorded. Each script and the insert that records it
-- run inside one transact() call, so a failing script is rolled back and
-- not recorded. Failures are raised as Lua errors.
-- @param db lsqlite3 database handle
-- @tparam[opt="/.migrations/"] string migrationPath zip directory holding the scripts
local function migrate(db, migrationPath)
  local migrations = collectMigrations(migrationPath or "/.migrations/")
  db:exec[[
    create table if not exists migrations (
      seqnum integer primary key
    ) without rowid
  ]]
  throwSqlError(db, "failed to create migrations table")
  -- Migrations are sorted ascending and each applied one becomes the new
  -- maximum, so the recorded maximum only needs to be read once.
  local lastSeqnum = getCurrentMigration(db)
  local announced = false
  for _, mig in ipairs(migrations) do
    if mig.seqnum > lastSeqnum then
      if not announced then
        print"Applying migrations:"
        announced = true
      end
      print(mig.filename)
      local migsql = LoadAsset(mig.filename)
      if not migsql then
        error(string.format("failed to load migration %s", mig.filename))
      end
      transact(db, function()
        db:exec(migsql)
        throwSqlError(db, "migration failed")
        recordMigration(db, mig.seqnum)
      end)
      lastSeqnum = mig.seqnum
    end
  end
end
-- Public interface of this module: only migrate() is exported.
local M = {}
M.migrate = migrate
return M