-- -- ultralight webframework for [Redbean web server](https://redbean.dev/) -- Copyright 2021-23 Paul Kulchenko -- -- Permission is hereby granted, free of charge, to any person obtaining a copy -- of this software and associated documentation files (the "Software"), to deal -- in the Software without restriction, including without limitation the rights -- to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -- copies of the Software, and to permit persons to whom the Software is -- furnished to do so, subject to the following conditions: -- The above copyright notice and this permission notice shall be included in -- all copies or substantial portions of the Software. -- THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -- IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -- FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -- AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -- LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN -- THE SOFTWARE. 
--

local NAME, VERSION = "fullmoon", "0.383"

--[[-- support functions --]]--

local unpack = table.unpack or unpack
local load = load or loadstring

if not setfenv then -- Lua 5.2+; this assumes f is a function
  -- based on http://lua-users.org/lists/lua-l/2010-06/msg00314.html
  -- and https://leafo.net/guides/setfenv-in-lua52-and-above.html
  local function findenv(f)
    local idx = 1
    repeat
      local name, value = debug.getupvalue(f, idx)
      if name == '_ENV' then return idx, value end
      idx = idx + 1
    until not name
  end
  getfenv = function (f)
    -- if the function is not provided, use the caller
    return(select(2, findenv(f or debug.getinfo(2,"f").func)) or _G)
  end
  setfenv = function (f, t)
    local level = findenv(f)
    if level then debug.upvaluejoin(f, level, function() return t end, 1) end
    return f
  end
end

-- load and run a chunk, failing as soon as the chunk makes any function call
-- (the call hook counts calls; the first two are the `pcall` plumbing itself);
-- returns the `pcall` results (`ok, res`) or `nil, err` if loading failed
local function loadsafe(data)
  local f, err = load(data)
  if not f then return f, err end
  local c = -2
  local hf, hm, hc = debug.gethook()
  debug.sethook(function() c=c+1; if c>0 then error("failed safety check") end end, "c")
  local ok, res = pcall(f)
  c = -1
  debug.sethook(hf, hm, hc)
  return ok, res
end

-- return `cond` if it's truthy; otherwise raise "bad argument #narg" error
-- attributed to the caller of the function that called `argerror`
local function argerror(cond, narg, extramsg, name)
  if cond then return cond end
  name = name or debug.getinfo(2, "n").name or "?"
  local msg = ("bad argument #%d to %s%s"):format(narg, name, extramsg and " "..extramsg or "")
  return error(msg, 3)
end

-- prefix the log message with "(fm) "; only format when parameters are passed
local function logFormat(fmt, ...)
  argerror(type(fmt) == "string", 1, "(string expected)")
  return "(fm) "..(select('#', ...) == 0 and fmt or (fmt or ""):format(...))
end

-- escape Lua pattern magic characters (returns gsub results: string and count)
local function quote(s) return s:gsub('([%(%)%.%%%+%-%*%?%[%^%$%]])','%%%1') end

-- convert the numeric Redbean version (0xMMmmpp) to "major.minor.patch"
local function getRBVersion()
  local v = GetRedbeanVersion()
  local major = math.floor(v / 2^16)
  local minor = math.floor((v / 2^16 - major) * 2^8)
  return ("%d.%d.%d"):format(major, minor, v % 2^8)
end

local LogDebug = function(...) return Log(kLogDebug, logFormat(...)) end
local LogVerbose = function(...) return Log(kLogVerbose, logFormat(...)) end
local LogInfo = function(...) return Log(kLogInfo, logFormat(...)) end
local LogWarn = function(...) return Log(kLogWarn, logFormat(...)) end

-- build a checker that tests whether bit `b` (a power of two) is set in `mode`
local istype = function(b)
  return function(mode) return math.floor((mode % (2*b)) / b) == 1 end
end
local isregfile = unix and unix.S_ISREG or istype(2^15)

-- register names from the list `v` as either bit flags ("x2": 1, 2, 4, ...)
-- or sequential values ("p1": 1, 2, 3, ...)
local function reg(func, v)
  local t = {n = 1,
    x2 = function(t, v) t[v] = t.n; t.n = t.n * 2 end,
    p1 = function(t, v) t[v] = t.n; t.n = t.n + 1 end,
  }
  for _, p in ipairs(v) do t[func](t, p) end
  return t
end
local function reg2x(v) return reg("x2", v) end
local function reg1p(v) return reg("p1", v) end

-- return current time as a {seconds, nanoseconds} pair
local getTimeNano = (unix
  and function() return {unix.clock_gettime()} end
  or function() return {GetTime(), 0} end)

-- return the difference (in seconds) between two {sec, nsec} timestamps;
-- the current time is used when the end time is not provided
local function getTimeDiff(st, et)
  if not et then et = getTimeNano() end
  return et[1] - st[1] + (et[2] - st[2]) * 1e-9
end

-- install a deprecated alias `old` for method `new` that warns on first use
-- and then replaces itself; the alias is not installed starting from `ver`
local function obsolete(obj, old, new, ver)
  obj[old] = VERSION < ver and function(...)
    LogWarn(("method %s has been replaced by %s and will be removed in v%s.")
      :format(old, new, ver))
    obj[old] = obj[new]
    return obj[new](...)
  end or nil
end

-- headers that are not allowed to be set, as Redbean may
-- also set them, leading to conflicts and improper handling
local noHeaderMap = {
  ["content-length"] = true,
  ["transfer-encoding"] = true,
  ["content-encoding"] = true,
  date = true,
  -- "close" is the only value that is allowed for "Connection",
  -- but only in Redbean before v2.2
  connection = GetRedbeanVersion() < 0x20200 and "close" or nil,
}

-- request headers based on https://datatracker.ietf.org/doc/html/rfc7231#section-5
-- response headers based on https://datatracker.ietf.org/doc/html/rfc7231#section-7
-- this allows the user to use `.ContentType` instead of `["Content-Type"]`
-- Host is listed to allow retrieving Host header even in the presence of host attribute
local headerMap = {}
(function(s) for h in s:gmatch("[%w%-]+") do headerMap[h:gsub("-","")] = h end end)([[
  Cache-Control Host Max-Forwards Proxy-Authorization User-Agent
  Accept-Charset Accept-Encoding Accept-Language Content-Disposition
  If-Match If-None-Match If-Modified-Since If-Unmodified-Since If-Range
  Content-Type Content-Encoding Content-Language Content-Location
  Retry-After Last-Modified WWW-Authenticate Proxy-Authenticate Accept-Ranges
  Content-Length Transfer-Encoding
]])

local htmlVoidTags = {} -- from https://html.spec.whatwg.org/#void-elements
(function(s) for h in s:gmatch("%w+") do htmlVoidTags[h] = true end end)([[
  area base br col embed hr img input link meta param source track wbr
]])

local default500 = [[{%& status %} {%& reason %}

{%& status %} {%& reason %}

{% if message then %}
{%& message %}
{% end %}]]

--[[-- route path generation --]]--

local PARAM = "([:*])([%w_]*)"
local routes = {}

-- generate a path from a route name (or a route expression) by substituting
-- `:param` and `*splat` placeholders with values from `params`;
-- raises an error if a required parameter is missing
local function makePath(name, params)
  argerror(type(name) == "string", 1, "(string expected)")
  params = params or {}
  -- name can be the name or the route itself (even not registered)
  local pos = routes[name]
  local route = pos and routes[pos].route or name
  -- replace :foo and *splat with provided parameters
  route = route:gsub(PARAM.."([^(*:]*)", function(sigil, param, rest)
      if sigil == "*" and param == "" then param = "splat" end
      -- ignore everything that doesn't match `:%w` pattern
      if sigil == ":" and param == "" then return sigil..param..rest end
      -- if the parameter value is `false`, replace it with an empty string
      return ((params[param] or (params[param] == false and "" or sigil..param))
        ..rest:gsub("^%b[]",""))
    end)
  -- remove all optional groups
  local function findopt(route)
    return route:gsub("(%b())", function(optroute)
        optroute = optroute:sub(2, -2)
        local s = optroute:find("[:*]")
        if s then
          local p = optroute:find("%b()")
          if not p or s < p then return "" end
        end
        return findopt(optroute)
      end)
  end
  route = findopt(route)
  local param = route:match(":(%a[%w_]*)") or route:match("*([%w_]*)")
  argerror(not param, 2, "(missing required parameter "
    ..(param and #param > 0 and param or "splat")..")")
  return route
end

-- build a URL from `url` (the current request URL by default) with
-- its components overwritten by `opts`; `false` option values remove components
local function makeUrl(url, opts)
  if type(url) == "table" and opts == nil then url, opts = nil, url end
  if not url then url = GetUrl() end
  if not opts then opts = {} end
  argerror(type(url) == "string", 1, "(string expected)")
  argerror(type(opts) == "table", 2, "(table expected)")
  -- check if params are in the hash table format and
  -- convert to the array format that Redbean expects
  if opts.params and not opts.params[1] and next(opts.params) then
    local tbl = {}
    for k, v in pairs(opts.params) do
      table.insert(tbl, v == true and {k} or {k, v})
    end
    table.sort(tbl, function(a, b) return a[1] < b[1] end)
    opts.params = tbl
  end
  local parts = ParseUrl(url)
  -- copy options, but remove those that have `false` values
  for k, v in pairs(opts) do parts[k] = v or nil end
  return EncodeUrl(parts)
end

local org, ref = {}, {} -- some unique key values to index template parameters

-- request functions (`request.write()`)
local reqenv = {
  escapeHtml = EscapeHtml, escapePath = EscapePath,
  formatIp = FormatIp, formatHttpDateTime = FormatHttpDateTime,
  makePath = makePath, makeUrl = makeUrl,
}
-- request properties (`request.authority`)
local reqapi = {
  authority = function()
    local parts = ParseUrl(GetUrl())
    return EncodeUrl({scheme = parts.scheme, host = parts.host, port = parts.port})
  end,
}

-- generate an `__index` handler for template/handler environments;
-- `opt.request` maps properties to `Get*` calls and
-- `opt.autotag` turns unresolved names into pseudo-tags
local function genEnv(opt)
  opt = opt or {}
  return function(t, key)
    local val = reqenv[key] or rawget(t, ref) and rawget(t, ref)[key]
    -- can cache the value, since it's not passed as a parameter
    local cancache = val == nil
    if not opt.request and val == nil then val = _G[key] end
    if opt.request and val == nil and type(key) == "string" then
      local func = reqapi[key] or _G["Get"..key:sub(1,1):upper()..key:sub(2)]
      -- map a property (like `.host`) to a function call (`GetHost()`)
      if type(func) == "function" then val = func() else val = func end
    end
    -- allow pseudo-tags, but only if used in a template environment;
    -- provide fallback for `table` to make `table{}` and `table.concat` work
    local istable = key == "table"
    if opt.autotag and (val == nil or istable) then
      -- nothing was resolved; this is either undefined value or
      -- a pseudo-tag (like `div{}` or `span{}`), so add support for them
      val = setmetatable({key}, {
          -- support the case of printing/concatenating undefined values
          -- tostring handles conversion to a string
          __tostring = function() return "" end,
          -- concat handles concatenation with a string
          __concat = function(a, _) return a end,
          __index = (istable and table or nil),
          __call = function(t, v, ...)
            if type(v) == "table" then
              table.insert(v, 1, key)
              return v
            end
            return {t[1], v, ...}
          end})
    elseif cancache then
      t[key] = val -- cache the calculated value for future use
    end
    return val
  end
end
local tmplTagHandlerEnv = {__index = genEnv({autotag = true}) }
local tmplRegHandlerEnv = {__index = genEnv() }
local tmplReqHandlerEnv = {__index = genEnv({request = true}) }
local req
local function getRequest() return req end

-- guess the content type from the first non-whitespace character
local function detectType(s)
  local ch = s:match("^%s*(%S)")
  return ch and (ch == "<" and "text/html" or ch == "{" and "application/json") or "text/plain"
end

-- return a handler that sets the status, (optional) headers, and (optional) body;
-- accepts `(status, headers, body)`, `(status, body)`, and `(body)` forms
local function serveResponse(status, headers, body)
  -- since headers is optional, handle the case when headers are not present
  if type(headers) == "string" and body == nil then
    body, headers = headers, nil
  end
  if type(status) == "string" and body == nil and headers == nil then
    body, status = status, 200
  end
  argerror(type(status) == "number", 1, "(number expected)")
  argerror(not headers or type(headers) == "table", 2, "(table expected)")
  argerror(not body or type(body) == "string", 3, "(string expected)")
  return function()
    SetStatus(status)
    if headers then
      -- make sure that the metatable gets transferred as well
      local r = getRequest()
      r.headers = setmetatable(headers, getmetatable(r.headers))
    end
    if body then Write(body) end
    return true
  end
end

--[[-- multipart parsing --]]--

local patts = {}
-- extract the value of parameter `name` (case-insensitive) from a header value
local function getParameter(header, name)
  local function optignorecase(s)
    if not patts[s] then
      patts[s] = (";%s*"
        ..s:gsub("%w", function(s) return ("[%s%s]"):format(s:upper(), s:lower()) end)
        ..[[=["']?([^;"']*)["']?]])
    end
    return patts[s]
  end
  return header:match(optignorecase(name))
end
local CRLF, TAIL = "\r\n", "--"
local CRLFlen = #CRLF
local MULTIVAL = "%[%]$"

-- parse a multipart body using the boundary from the Content-Type value `ctype`;
-- returns a table with parts (as an array and indexed by part names)
-- or `nil, error message` when the body is malformed
local function parseMultipart(body, ctype)
  argerror(type(ctype) == "string", 2, "(string expected)")
  local parts = {
    boundary = getParameter(ctype, "boundary"),
    start = getParameter(ctype, "start"),
  }
  local boundary = "--"..argerror(parts.boundary, 2, "(boundary expected in Content-Type)")
  local bol, eol, eob = 1
  while true do
    -- search from a separate position that moves past rejected matches;
    -- searching from `bol` again would find the same match
    -- and never terminate when the boundary is not at the line start
    local pos = bol
    repeat
      eol, eob = string.find(body, boundary, pos, true)
      if not eol then return nil, "missing expected boundary at position "..bol end
      pos = eol + 1
    until eol == 1 or eol > CRLFlen and body:sub(eol-CRLFlen, eol-1) == CRLF
    if eol > CRLFlen then eol = eol - CRLFlen end
    local headers, name, filename = {}
    if bol > 1 then
      -- find the header (if any)
      if string.sub(body, bol, bol+CRLFlen-1) == CRLF then -- no headers
        bol = bol + CRLFlen
      else -- headers
        -- find the end of headers (CRLF+CRLF)
        local boh, eoh = 1, string.find(body, CRLF..CRLF, bol, true)
        if not eoh then return nil, "missing expected end of headers at position "..bol end
        -- join multi-line header values back if present
        local head = string.sub(body, bol, eoh+1):gsub(CRLF.."%s+", " ")
        while (string.find(head, CRLF, boh, true) or 0) > boh do
          local p, e, header, value = head:find("([^:]+)%s*:%s*(.-)%s*\r\n", boh)
          if p ~= boh then return nil, "invalid header syntax at position "..bol+boh end
          header = header:lower()
          if header == "content-disposition" then
            name = getParameter(value, "name")
            filename = getParameter(value, "filename")
          end
          headers[header] = value
          boh = e + 1
        end
        bol = eoh + CRLFlen*2
      end
      -- epilogue is processed, but not returned
      local ct = headers["content-type"]
      local b, err = string.sub(body, bol, eol-1)
      if ct and ct:lower():find("^multipart/") then
        b, err = parseMultipart(b, ct) -- handle multipart/* recursively
        if not b then return b, err end
      end
      local first = parts.start and parts.start == headers["content-id"] and 1
      local v = {name = name, headers = headers, filename = filename, data = b}
      table.insert(parts, first or #parts+1, v)
      if name then
        if string.find(name, MULTIVAL) then
          parts[name] = parts[name] or {}
          table.insert(parts[name], first or #parts[name]+1, v)
        else
          parts[name] = parts[name] or v
        end
      end
    end
    local tail = body:sub(eob+1, eob+#TAIL)
    -- check if the encapsulation or regular boundary is present
    if tail == TAIL then break end
    if tail ~= CRLF then return nil, "missing closing boundary at position "..eol end
    bol = eob + #tail + 1
  end
  return parts
end

--[[-- template engine --]]--

local templates, vars = {}, {}
local stack, blocks = {}, {}
-- `blocks` is a table behind proxy tables indexed by actual `block` tables in each template
-- where the values are defined blocks and their functions
-- blocks are retrieved in the order in which templates are rendered
-- (first called is first used), as stored in `stack` during rendering.
-- the `stack` table also includes mapping from template names to their proxy tables
local metablock = {
  __newindex = function(t, k, v)
    if not blocks[t] then
      blocks[t] = {}
      stack[t[blocks]] = blocks[t]
    end
    blocks[t][k] = v
  end,
  __index = function(t, k)
    if not blocks[t] then
      blocks[t] = {}
      stack[t[blocks]] = blocks[t]
    end
    -- use the "earliest" template after the one signified with `t`
    for _, name in ipairs(stack) do
      -- if a template name is requested, then return its table
      -- this is needed for direct access to template blocks to simulate `super()`
      local tbl = stack[name]
      -- don't look beyond the template that made this call
      if name == k then return tbl end
      local blk = name and tbl and tbl[k]
      if blk then return blk end
    end
  end,
}

-- render the template `name` with (optional) parameters `opt`;
-- returns the template result (or an empty string) and
-- a table with headers (the template ContentType by default)
local function render(name, opt)
  argerror(type(name) == "string", 1, "(string expected)")
  argerror(templates[name], 1, "(unknown template name '"..tostring(name).."')")
  argerror(not opt or type(opt) == "table", 2, "(table expected)")
  -- assign default parameters, but allow to overwrite
  local params = {vars = vars, block = setmetatable({[blocks] = name}, metablock)}
  local env = getfenv(templates[name].handler)
  -- add "original" template parameters
  for k, v in pairs(rawget(env, org) or {}) do params[k] = v end
  -- add "passed" template parameters
  for k, v in pairs(opt or {}) do params[k] = v end
  LogDebug("render template '%s'", name)
  -- set the calculated parameters for the current rendering and restore
  -- the previous ones afterwards, as the environment is shared between
  -- nested/recursive renderings of the same template;
  -- `rawget` is used to avoid triggering the environment `__index` handler
  local refcopy = rawget(env, ref)
  env[ref] = params
  table.insert(stack, name)
  local res, more = templates[name].handler(opt)
  table.remove(stack)
  -- reset block cache when the render stack becomes empty
  if #stack == 0 then stack, blocks = {}, {} end
  env[ref] = refcopy
  -- return template results or an empty string to indicate completion
  -- this is useful when the template does direct write to the output buffer
  return res or "", more or {ContentType = templates[name].ContentType}
end

-- register the template `name` with `code` (string, function, or a table with
-- the code as its first element and options as hash values);
-- returns a table with the names of registered templates as its keys
local function setTemplate(name, code, opt)
  -- name as a table designates a list of prefixes for assets paths
  -- to load templates from;
  -- its hash values provide mapping from extensions to template types
  if type(name) == "table" then
    local tmpls = {}
    for _, prefix in ipairs(name) do
      local paths = GetZipPaths(prefix)
      for _, path in ipairs(paths) do
        local tmplname, ext = path:gsub("^"..prefix.."/?",""):match("(.+)%.(%w+)$")
        if ext and name[ext] then
          local asset = LoadAsset(path) or error("Can't load asset: "..path)
          setTemplate(tmplname, {asset, type = name[ext], path = path}, opt)
          tmpls[tmplname] = true
        end
      end
    end
    return tmpls
  end
  argerror(type(name) == "string", 1, "(string or table expected)")
  local params = {}
  if type(code) == "table" then params, code = code, table.remove(code, 1) end
  local ctype = type(code)
  argerror(ctype == "string" or ctype == "function", 2, "(string, table or function expected)")
  LogVerbose("set template '%s'", name)
  local tmpl = templates[params.type or "fmt"]
  local env = setmetatable({render = render, [org] = opt},
    -- get the metatable from the template that this one is based on,
    -- to make sure the correct environment is being served
    tmpl and getmetatable(getfenv(tmpl.handler)) or
    params.autotag and tmplTagHandlerEnv or tmplRegHandlerEnv)
  if ctype == "string" then
    argerror(tmpl ~= nil, 2, "(unknown template type/name)")
    argerror(tmpl.parser ~= nil, 2, "(referenced template doesn't have a parser)")
    local path = "@" .. (params.path or name)
    -- assign proper environment in case parser needs it
    if tmpl.autotag then tmpl.parser = setfenv(tmpl.parser, env) end
    local func = tmpl.parser(code, path)
    -- if the parser returns function, use it as is
    -- if it returns some code, then load and use it
    code = type(func) == "function" and func or assert(load(func, path))
  end
  params.handler = setfenv(code, env)
  templates[name] = params
  -- use the template name as the key (same as in the table form above)
  return {[name] = true}
end

local function setTemplateVar(name, value) vars[name] = value end

--[[-- routing engine --]]--

local setmap = {}
(function(s) for pat, reg in s:gmatch("(%S+)=([^%s,]+),?") do setmap[pat] = reg end end)([[
  d=0-9, ]=[.].], -=[.-.], a=[:alpha:], l=[:lower:], u=[:upper:], w=[:alnum:], x=[:xdigit:],
]])
-- map a Lua character class (without `%`) to its regex equivalent;
-- punctuation is passed through; anything else raises an error
local function findset(s)
  return setmap[s] or s:match("%p") and s or error("Invalid escape sequence %"..s)
end

-- convert a route expression to an anchored regex and a list of
-- capture names (`false` is used for optional groups)
local function route2regex(route)
  -- foo/bar, foo/*, foo/:bar, foo/:bar[%d], foo(/:bar(/:more))(.:ext)
  local params = {}
  local regex = route:gsub("%)", "%1?") -- update optional groups from () to ()?
    :gsub("%.", "\\.") -- escape dots (.)
    :gsub(PARAM, function(sigil, param)
        if sigil == "*" and param == "" then param = "splat" end
        -- ignore everything that doesn't match `:%w` pattern
        if sigil == ":" and param == "" then return sigil..param end
        table.insert(params, param)
        return sigil == "*" and "(.*)" or "([^/]+)"
      end)
    :gsub("%b[](%+%))(%b[])([^/:*%[]*)", function(sep, pat, rest)
        local leftover, more = rest:match("(.-])(.*)")
        if leftover then pat = pat..leftover; rest = more end
        -- replace Lua character classes with regex ones
        return pat:gsub("%%(.)", findset)..sep..rest
      end)
  -- mark optional captures, as they are going to be returned during match
  local subnum = 1
  local s, _, capture = 0
  while true do
    s, _, capture = regex:find("%b()([?]?)", s+1)
    if not s then break end
    if capture > "" then table.insert(params, subnum, false) end
    subnum = subnum + 1
  end
  return "^"..regex.."$", params
end

-- find the index of an already registered route with the same path and methods
local function findRoute(route, opts)
  for i, r in ipairs(routes) do
    local ometh = opts.method
    local rmeth = (r.options or {}).method
    if route == r.route and
      (type(ometh) == "table" and table.concat(ometh, ",") or ometh) ==
      (type(rmeth) == "table" and table.concat(rmeth, ",") or rmeth) then
      return i
    end
  end
end

-- register one or more routes (a string or a table with routes and filters)
-- with an optional handler (a function or a path to redirect to internally)
local function setRoute(opts, ...)
  local ot = type(opts)
  if ot == "string" then
    opts = {opts}
  elseif ot == "table" then
    argerror(#opts > 0, 1, "(one or more routes expected)", "setRoute")
  else
    argerror(false, 1, "(string or table expected)", "setRoute")
  end
  -- as the handler is optional, allow it to be skipped
  local pnum, handler = select('#', ...), ...
  local ht = type(handler)
  -- allow empty, but not `nil` handler (so `setRoute('foo')`, but not `setRoute('foo', nil)`)
  -- this protects against typos in handler names being silently accepted
  argerror(ht == "function" or ht == "string" or (ht == "nil" and pnum == 0),
    2, "(function or string expected)", "setRoute")
  if ht == "string" then
    -- if `handler` is a string, then turn it into a handler that does
    -- internal redirect (to an existing path), but not a directory.
    -- This is to avoid failing on a missing directory index.
    -- If directory index is still desired, then use `serveIndex()`.
    local newroute = handler
    handler = function(r)
      local path = r.makePath(newroute, r.params)
      local mode = GetAssetMode(path)
      return mode and isregfile(mode) and RoutePath(path)
    end
  end
  if ot == "table" then
    -- remap filters to hash if presented as an (array) table
    for k, v in pairs(opts) do
      if type(v) == "table" then
        -- {"POST", "PUT"} => {"POST", "PUT", PUT = true, POST = true}
        for i = 1, #v do v[v[i]] = true end
        -- if GET is allowed, then also allow HEAD, unless `HEAD=false` exists
        if k == "method" and v.GET and v.HEAD == nil then
          table.insert(v, "HEAD") -- add to the list to generate a proper list of methods
          v.HEAD = v.GET
        end
        if v.regex then
          v.regex = argerror(re.compile(v.regex), 1,
            ("(valid regex expected for '%s')"):format(k), "setRoute")
        end
      elseif headerMap[k] and type(v) == "string" then
        -- only string values can be turned into a pattern; other values
        -- (like functions) are left as is to be handled by `matchCondition`
        opts[k] = {pattern = "%f[%w]"..quote(v).."%f[%W]"}
      end
    end
  end
  -- process 1+ routes as specified
  while true do
    local route = table.remove(opts, 1)
    if not route then break end
    argerror(type(route) == "string", 1, "(route string expected)", "setRoute")
    local pos = findRoute(route, opts) or #routes+1
    if opts.routeName then
      if routes[opts.routeName] then
        LogWarn("route '%s' already registered", opts.routeName)
      end
      routes[opts.routeName], opts.routeName = pos, nil
    end
    local regex, params = route2regex(route)
    local tmethod = type(opts.method)
    local methods = tmethod == "table" and opts.method
      or tmethod == "string" and {opts.method} or {'ANY'}
    LogVerbose("set route '%s' (%s) at index %d",
      route, table.concat(methods,','), pos)
    routes[pos] = {route = route, handler = handler, options = opts,
      comp = re.compile(regex), params = params}
    routes[route] = pos
  end
end

-- check `value` against the condition, which can be a function,
-- a plain value, or a table (with a function, a hash, a regex, or a pattern)
local function matchCondition(value, cond)
  if type(cond) == "function" then return cond(value) end
  if type(cond) ~= "table" then return value == cond end
  -- allow `{function() end, otherwise = ...}` as well
  if type(cond[1]) == "function" then return cond[1](value) end
  if value == nil or cond[value] then return true end
  if cond.regex then return cond.regex:search(value) ~= nil end
  if cond.pattern then return value:match(cond.pattern) ~= nil end
  return false
end

-- generate the value for the `Allow` header from the methods of matched routes
local function getAllowedMethod(matchedRoutes)
  local methods = {}
  for _, idx in ipairs(matchedRoutes) do
    local routeMethod = routes[idx].options and routes[idx].options.method
    if routeMethod then
      for _, method in ipairs(type(routeMethod) == "table" and routeMethod or {routeMethod}) do
        if not methods[method] then
          methods[method] = true
          table.insert(methods, method)
        end
      end
    end
  end
  table.sort(methods)
  return (#methods > 0
    and table.concat(methods, ", ")..(methods.OPTIONS == nil and ", OPTIONS" or "")
    or "GET, HEAD, POST, PUT, DELETE, OPTIONS")
end

-- match `path` against registered routes (in the registration order) and
-- return the results of the first handler that returns a truthy value
local function matchRoute(path, req)
  assert(type(req) == "table", "bad argument #2 to match (table expected)")
  LogDebug("match %d route(s) against '%s'", #routes, path)
  local matchedRoutes = {}
  for idx, route in ipairs(routes) do
    -- skip static routes that are only used for path generation
    local opts = route.options
    if route.handler or opts and opts.otherwise then
      local res = {route.comp:search(path)}
      local matched = table.remove(res, 1)
      LogDebug("route '%s' %smatched", route.route, matched and "" or "not ")
      if matched then -- path matched
        table.insert(matchedRoutes, idx)
        for ind, val in ipairs(route.params) do
          if val and res[ind] then req.params[val] = res[ind] > "" and res[ind] or false end
        end
        -- check if there are any additional options to filter by
        local otherwise
        matched = true
        if opts and next(opts) then
          for filter, cond in pairs(opts) do
            if filter ~= "otherwise" then
              local header = headerMap[filter]
              -- check "dashed" headers, params, properties (method, port, host, etc.), and then headers again
              local value = (filter == "r" and req -- special request value
                or header and req.headers[header] -- an existing header
                or req.params[filter] or req[filter] or req.headers[filter])
              -- condition can be a value (to compare with) or a table/hash with multiple values
              local resCond, err = matchCondition(value, cond)
              if not resCond then
                otherwise = type(cond) == "table" and cond.otherwise or opts.otherwise
                LogDebug("route '%s' filter '%s%s' didn't match value '%s'%s",
                  route.route, filter, type(cond) == "string" and "="..cond or "",
                  value, tonumber(otherwise) and " and returned "..otherwise or "")
                if otherwise then
                  if type(otherwise) == "function" then
                    return otherwise(err, value)
                  else
                    if otherwise == 405 and not req.headers.Allow then
                      req.headers.Allow = getAllowedMethod(matchedRoutes)
                    end
                    return serveResponse(otherwise)
                  end
                end
                matched = false
                break
              end
            end
          end
        end
        if matched and route.handler then
          local res, more = route.handler(req)
          if res then return res, more end
          path = rawget(req, "path") or path -- assign path for subsequent checks
        end
      end
    end
  end
end

--[[-- storage engine --]]--

local sqlite3
local NONE = {}
local dbmt = { -- share one metatable among all DBM objects
  -- simple __index = db doesn't work, as it gets `dbm` passed instead of `db`,
  -- so remapping is needed to proxy this to `t.db` instead
  __index = function(t,k)
    if sqlite3[k] then return sqlite3[k] end
    local db = rawget(t, "db")
    return db and db[k] and function(self,...) return db[k](db,...) end or nil
  end,
  __gc = function(t) return t:close() end,
  __close = function(t) return t:close() end
}

-- create a DBM object for the database `dbname` that is set up with
-- `sqlsetup` statements; `opts` may include `OPEN_*` flags and `trace` function
local function makeStorage(dbname, sqlsetup, opts)
  sqlite3 = sqlite3 or require "lsqlite3"
  if type(sqlsetup) == "table" and opts == nil then
    sqlsetup, opts = nil, sqlsetup
  end
  local flags = 0
  for flagname, val in pairs(opts or {}) do
    local flagcode = flagname:find("^OPEN_") and (
      sqlite3[flagname] or error("unknown option "..flagname))
    flags = flags | (val and flagcode or 0)
  end
  argerror(not opts or not opts.trace or type(opts.trace) == "function",
    3 , "(function expected as trace option value)")
  -- check if any of the required flags are set; set defaults if not
  if flags & (sqlite3.OPEN_READWRITE + sqlite3.OPEN_READONLY) == 0 then
    flags = flags | (sqlite3.OPEN_READWRITE + sqlite3.OPEN_CREATE)
  end
  local dbm = {NONE = NONE, prepcache = {}, pragmas = {},
    name = dbname, sql = sqlsetup, opts = opts or {}}
  local msgdelete = "use delete option to force"

  -- collect the leading `pragma` statements from the setup SQL
  local function getPragmas(sql)
    local pragmas = {}
    for p in (sql or ""):gmatch("%s*([^;]+)") do
      if not p:lower():find("pragma ") then break end
      table.insert(pragmas, p)
    end
    return table.concat(pragmas, ";")
  end

  -- open (or re-open) the connection and run the setup SQL;
  -- raises an error when the DB can't be opened or set up
  function dbm:init(reopen)
    local db = self.db
    -- if this is a forked process with a connection opened elsewhere,
    -- then need to re-open the connection to avoid bad things:
    -- https://sqlite.org/howtocorrupt.html#_carrying_an_open_database_connection_across_a_fork_
    -- (but don't re-open temp/inmemory and read-only DBs, as it's not needed)
    if db and reopen then
      -- reset the pid even when re-opening is skipped, so that it's not repeated
      self.pid = unix.getpid()
      if db:db_filename("main") ~= "" and (not db.readonly or not db:readonly()) then
        db = false
      end
    end
    if db then return self end
    local skipexec = db == false
    local code, msg
    db, code, msg = sqlite3.open(self.name, flags)
    if not db then error(("%s (code: %d)"):format(msg, code)) end
    -- __gc handler on the DB object will close it, which can happen multiple times
    -- for forked connections and needs to be prevented, as closing one may affect
    -- the others due to the way POSIX advisory locks behave on file handlers):
    -- https://sqlite.org/howtocorrupt.html#_posix_advisory_locks_canceled_by_a_separate_thread_doing_close_
    if debug.getmetatable(db) then debug.getmetatable(db).__gc = nil end
    db:busy_timeout(1000) -- configure wait on busy DB to allow serialized writes
    -- skipexec indicates that a shortcut can be taken to set up the DB,
    -- but the pragmas still need to be processed to have the correct configuration
    local pragmas = skipexec and self.sql and getPragmas(self.sql)
    if pragmas and db:exec(pragmas) > 0
      or not skipexec and self.sql and db:exec(self.sql) > 0 then
      error("can't setup db: "..db:errmsg())
    end
    self.db = db
    self.prepcache = {}
    self.pid = unix.getpid()
    return setmetatable(self, dbmt)
  end

  -- normalize an SQL statement (comments, whitespaces, quotes) for comparison
  local function norm(sql)
    return (sql:gsub("%-%-[^\n]*\n?",""):gsub("^%s+",""):gsub("%s+$",""):gsub("%s+"," ")
      :gsub("%s*([(),])%s*","%1"):gsub('"(%w+)"',"%1"))
  end

  -- return a (cached) prepared statement; multi-statement strings are not cached
  -- and are returned with the rest of the string as the second value
  local function prepstmt(dbm, stmt)
    if not dbm.prepcache[stmt] then
      assert(dbm.db)
      local st, tail = dbm.db:prepare(stmt)
      -- if there is tail, then return as is, don't cache
      if st and tail and #tail > 0 then return st, tail end
      dbm.prepcache[stmt] = st
    end
    return dbm.prepcache[stmt]
  end

  function dbm:close()
    -- only close the DB when done from the same process that opened it
    -- to avoid closing the DB from a forked process,
    -- which is likely to mess up POSIX locks
    local db = self.db
    if db and self.pid == unix.getpid() then
      self.db = false -- mark it as re-openable
      return db:close()
    end
  end

  -- execute the query with bound parameters; `one` controls the result:
  -- `nil` returns the number of changes, `true` returns the first row, and
  -- `false` returns all rows (`NONE` is returned when there are no rows)
  local function fetch(self, query, one, ...)
    -- re-open the connection if this is a forked process; see comment in `dbm:init()`
    if self.pid ~= unix.getpid() then self:init(true) end
    if not self.db then self:init() end
    local trace = self.opts.trace
    local start = trace and getTimeNano()
    local rows = {}
    local stmt, tail = query, nil
    repeat
      if type(stmt) == "string" then
        stmt, tail = prepstmt(self, stmt)
      end
      if not stmt then return nil, "can't prepare: "..self.db:errmsg() end
      -- if the last statement is incomplete
      if not stmt:isopen() then break end
      -- if the first parameter is a table, then use it to bind parameters by name
      local tbl = select(1, ...)
      if (type(tbl) == "table" and stmt:bind_names(tbl) or stmt:bind_values(...)) ~= sqlite3.OK then
        return nil, "can't bind values: "..self.db:errmsg()
      end
      for row in stmt:nrows() do
        table.insert(rows, row)
        if one then break end
      end
      stmt:reset()
      stmt = tail -- get multi-statement ready for processing
    until (one or not tail)
    if trace then trace(self, query, {...}, getTimeDiff(start)) end
    if one == nil then return self.db:changes() end -- return execute results
    -- return self.NONE instead of an empty table to indicate no rows
    return not one and (rows[1] and rows or self.NONE) or rows[1] or self.NONE
  end
  local function exec(self, stmt, ...) return fetch(self, stmt, nil, ...) end
  local function dberr(db) return nil, db:errmsg() end

  -- execute one statement or a list of statements (within a savepoint);
  -- returns the number of changes or `nil, error message`
  function dbm:execute(list, ...)
    -- if the first parameter is not table, use regular `exec`
    if type(list) ~= "table" then return exec(self, list, ...) end
    -- re-open the connection if this is a forked process; see comment in `dbm:init()`
    if self.pid ~= unix.getpid() then self:init(true) end
    if not self.db then self:init() end
    local db = self.db
    local changes = 0
    if db:exec("savepoint execute") ~= sqlite3.OK then return dberr(db) end
    for _, sql in ipairs(list) do
      if type(sql) ~= "table" then sql = {sql} end
      local ok, err = exec(self, unpack(sql))
      if not ok then
        if db:exec("rollback to execute") ~= sqlite3.OK then return dberr(db) end
        return nil, err
      end
      changes = changes + ok
    end
    if db:exec("release execute") ~= sqlite3.OK then return dberr(db) end
    return changes
  end

  -- use `self` (instead of the captured `dbm`) to be consistent with other methods
  function dbm:fetchAll(stmt, ...) return fetch(self, stmt, false, ...) end
  function dbm:fetchOne(stmt, ...) return fetch(self, stmt, true, ...) end

  -- execute a pragma statement after checking the pragma name;
  -- returns the first value from the result (or `NONE`) or `nil, error message`
  function dbm:pragma(stmt)
    local pragma = stmt:match("[_%w]+")
    if not self.pragmas[pragma] then
      if self:fetchOne("select * from pragma_pragma_list() where name = ?",
        pragma or "") == self.NONE then
        return nil, "missing or invalid pragma name"
      end
      self.pragmas[pragma] = true
    end
    local row = self:fetchOne("PRAGMA "..stmt)
    if not row then return nil, self.db:errmsg() end
    return select(2, next(row)) or self.NONE
  end

  obsolete(dbm, "fetchone", "fetchOne", "0.40")
  obsolete(dbm, "fetchall", "fetchAll", "0.40")

  --[[-- dbm upgrade --]]--

  -- upgrade the schema of the current DB to match the one from the setup SQL
  -- (as created in a "pristine" in-memory DB); returns the list of applied
  -- (or planned when `opts.dryrun` is set) changes or `nil, error message`
  function dbm:upgrade(opts)
    opts = opts or {}
    local actual = self.db and self or error("can't upgrade non initialized db")
    local pristine = makeStorage(":memory:", self.sql)
    local sqltbl = [[SELECT name, sql FROM sqlite_schema WHERE type = 'table' AND name not like 'sqlite_%']]
    local ok, err
    local changes, legacyalter = {}, false
    local actbl, prtbl = {}, {}
    for r in pristine:nrows(sqltbl) do prtbl[r.name] = r.sql end
    for r in actual:nrows(sqltbl) do
      actbl[r.name] = true
      if prtbl[r.name] then
        if norm(r.sql) ~= norm(prtbl[r.name]) then
          local namepatt = '%f[^%s"]'..r.name:gsub("%p","%%%1")..'%f[%s"(]'
          local tmpname = r.name.."__new"
          local createtbl = prtbl[r.name]:gsub(namepatt, tmpname, 1)
          table.insert(changes, createtbl)

          local sqlcol = ("PRAGMA table_info(%s)"):format(r.name)
          local common, prcol = {}, {}
          for c in pristine:nrows(sqlcol) do prcol[c.name] = true end
          for c in actual:nrows(sqlcol) do
            if prcol[c.name] then
              table.insert(common, c.name)
            elseif not opts.delete then
              err = err or ("Not allowed to remove '%s' from '%s'; %s"
                ):format(c.name, r.name, msgdelete)
            end
          end
          local cols = table.concat(common, ",")
          table.insert(changes, ("INSERT INTO %s (%s) SELECT %s FROM %s")
            :format(tmpname, cols, cols, r.name))
          table.insert(changes, ("DROP TABLE %s"):format(r.name))
          table.insert(changes, ("ALTER TABLE %s RENAME TO %s"):format(tmpname, r.name))
          legacyalter = true
        end
      else
        if opts.delete == nil then
          err = err or ("Not allowed to drop table '%s'; %s"
            ):format(r.name, msgdelete)
        end
        if opts.delete == true then
          table.insert(changes, ("DROP table %s"):format(r.name))
        end
      end
    end
    if err then return nil, err end
    -- `alter table` may require legacy_alter_table pragma
    -- if depending triggers/views exist
    -- see https://sqlite.org/forum/forumpost/0e2390093fbb8fd6
    -- and https://www.sqlite.org/pragma.html#pragma_legacy_alter_table
    if legacyalter then
      table.insert(changes, 1, "PRAGMA legacy_alter_table=1")
      table.insert(changes, "PRAGMA legacy_alter_table=0")
    end
    for k in pairs(prtbl) do
      if not actbl[k] then table.insert(changes, prtbl[k]) end
    end

    local sqlidx = [[SELECT name, sql, type FROM sqlite_schema WHERE type in ('index', 'trigger', 'view') AND name not like 'sqlite_%']]
    actbl, prtbl = {}, {}
    for r in pristine:nrows(sqlidx) do prtbl[r.type..r.name] = r.sql end
    for r in actual:nrows(sqlidx) do
      actbl[r.type..r.name] = true
      if prtbl[r.type..r.name] then
        if r.sql ~= prtbl[r.type..r.name] then
          table.insert(changes, ("DROP %s IF EXISTS %s"):format(r.type, r.name))
          table.insert(changes, prtbl[r.type..r.name])
        end
      else
        table.insert(changes, ("DROP %s IF EXISTS %s"):format(r.type, r.name))
      end
    end
    for k in pairs(prtbl) do
      if not actbl[k] then table.insert(changes, prtbl[k]) end
    end

    -- get the current value of `PRAGMA foreign_keys` to restore if needed
    local acpfk = assert(actual:pragma"foreign_keys")
    -- get the pristine value of `PRAGMA foreign_keys` to set later
    local prpfk = assert(pristine:pragma"foreign_keys")
    if opts.integritycheck ~= false then
      local ic = self:pragma"integrity_check(1)"
      if ic ~= "ok" then return nil, ic end
      -- check existing foreign key violations if the foreign key setting is enabled
      local fkc = prpfk ~= "0" and self:pragma"foreign_key_check"
      if fkc and fkc ~= self.NONE then return nil, "foreign key check failed" end
    end
    if opts.dryrun then return changes end
    if #changes == 0 then return changes end

    -- disable `pragma foreign_keys`, to avoid triggerring cascading deletes
    ok, err = self:pragma"foreign_keys=0"
    if not ok then return ok, err end

    -- execute the changes (within a savepoint)
    ok, err = self:execute(changes)
    -- restore `PRAGMA foreign_keys` value:
    -- (1) to the original value after failure
    -- (2) to the "pristine" value after normal execution
    local pfk = "foreign_keys="..(ok and prpfk or acpfk)
    if self:pragma(pfk) and ok then table.insert(changes, "PRAGMA "..pfk) end
    if not ok then return ok, err end

    -- clean up the database
    ok, err = self:execute("VACUUM")
    if not ok then return ok, err end
    return changes
  end

  return dbm:init()
end

--[[-- hook management --]]--

local hooks = {}
local function onHook(hookName, ...)
  for _, v in ipairs(hooks[hookName]) do
    local res = v[1](...)
if res ~= nil then return res end
  end
end

-- find the handler registered for `hookName` with the specified suffix;
-- returns its position in the list and the {function, suffix} record (if found)
local function findHook(hookName, suffix)
  for i, v in ipairs(hooks[hookName]) do
    if v[2] == suffix then return i, v end
  end
end

-- register (or remove, when `func` is not provided) a handler for a redbean hook.
-- `name` is the hook name with an optional suffix (`OnWorkerStart[.suffix]`);
-- a handler registered with a suffix can later be replaced or removed by
-- using the same name, while handlers without a suffix are always added.
-- Returns the function previously registered under the same name (if any).
local function setHook(name, func)
  -- name: OnWorkerStart[.suffix]
  argerror(type(name) == "string", 1, "(string expected)")
  local main, suffix = name:match("([^.]+)%.?(.*)")
  -- register redbean hook even without handler;
  -- this is needed to set up a handler later, as for some
  -- hooks redbean only checks before the main loop is started
  if not hooks[main] then
    hooks[main] = {}
    local orig = _G[main]
    _G[main] = function(...)
      -- the handler that was assigned before still gets all the hook arguments;
      -- its result is ignored, as the result of the hook is provided by `onHook`
      if orig then orig(...) end
      return onHook(main, ...)
    end
  end
  local idx, val = findHook(main, suffix)
  local res = val and val[1]
  local isQualified = #suffix > 0
  if not func then
    -- remove the current hook if it's a fully qualified hook
    if idx and isQualified then table.remove(hooks[main], idx) end
  else -- set the new function
    local hook = {func, suffix}
    if idx and isQualified then -- update existing qualified hook
      hooks[main][idx] = hook
    else -- add a new one
      table.insert(hooks[main], hook)
    end
  end
  return res -- return the old hook value (if any)
end

--[[-- scheduling engine --]]--

-- build a lookup table for one cron field with the values from `min` to `max`:
-- a number maps to its name from `vals` (if provided) or to its two-digit form,
-- a name from `vals` maps back to its number, and `*` maps to the full range
local function expand(min, max, vals)
  local tbl = {MIN = min, MAX = max, ['*'] = min.."-"..max}
  for i = min, max do
    tbl[i] = vals and vals[i] or ("%02d"):format(i)
  end
  for k, v in pairs(vals or {}) do tbl[v] = k end
  return tbl
end

-- lookup tables for the five cron fields, in the order the fields are listed:
-- ranges 0-59, 0-23, 1-31, 1-12 (with month names), and 0-7 (with day names)
local expressions = {
  expand(0,59),
  expand(0,23),
  expand(1,31),
  expand(1,12, {"jan","feb","mar","apr","may","jun","jul","aug","sep","oct","nov","dec"}),
  expand(0,7, {[0]="sun","mon","tue","wed","thu","fri","sat","sun"}),
}

-- convert a cron record with five fields into a list of five hash tables
-- that have a key for each of the values the field matches;
-- returns `nil` and an error message when the record can't be parsed
local function cron2hash(rec)
  local cronrec = {rec:lower():match("%s*(%S+)%s+(%S+)%s+(%S+)%s+(%S+)%s+(%S+)%s*")}
  local tbl = {{},{},{},{},{}}
  if #cronrec ~= #tbl then return nil, "invalid format" end
  for exppos, exps in ipairs(cronrec) do
    local map = expressions[exppos]
    for e in exps:gmatch("([^,]+)") do
      -- replace names and `*` with their numeric values/ranges
      local exp = e:gsub("[^%d%-/]+", map)
      local min, rng, max,
step = exp:match("^(%d+)(%-?)(%d*)/?(%d*)$")
      -- also accept a range without the lower value (`-max[/step]`)
      if not min then max, step = exp:match("^%-(%d+)/?(%d*)$") end
      if not min and not max then return nil, "invalid expression: "..e end
      min = math.max(map.MIN, tonumber(min) or map.MIN)
      max = math.min(map.MAX, tonumber(max) or #rng==0 and min or map.MAX)
      step = tonumber(step) or 1
      -- a zero step makes the numeric `for` loop raise an error,
      -- so report it the same way as any other invalid expression
      if step == 0 then return nil, "invalid expression: "..e end
      for i = min, max, step do tbl[exppos][map[i]] = true end
    end
  end
  return tbl
end

-- registered schedules (cron expression -> {parsed expression, function, sameProc})
-- and the (minute-aligned) time of the last check
local schedules, lasttime = {}, 0
local scheduleHook = "OnServerHeartbeat.fm-setSchedule"

-- run all the scheduled functions with the cron expression matching `time`;
-- each function is executed in a forked process unless `sameProc` is set for it
local function checkSchedule(time)
  local times = FormatHttpDateTime(time)
  local dow, dom, mon, h, m = times:lower():match("^(%S+), (%S+) (%S+) %S+ (%S+):(%S+):")
  for _, v in pairs(schedules) do
    local cront, func, sameproc = v[1], v[2], v[3]
    if cront[1][m] and cront[2][h] and cront[3][dom] and cront[4][mon] and cront[5][dow] then
      if sameproc or assert(unix.fork()) == 0 then
        local ok, err = pcall(func)
        -- the error value may not be a string; a failed concatenation here would
        -- skip `unix.exit` below and leave the forked process running
        if not ok then LogWarn("scheduled task failed: "..tostring(err)) end
        if not sameproc then unix.exit(0) end
      end
    end
  end
end

-- heartbeat handler: check the schedules once for each new minute
local function scheduler()
  local time = math.floor(GetTime()/60)*60
  if time == lasttime then return else lasttime = time end
  checkSchedule(time)
end

-- schedule `func` to be executed according to the cron expression `exp`;
-- can also be called with one table that has the expression and the function
-- as its first two elements and the options as its fields.
-- The only option used is `sameProc`, which requests the function to be
-- executed in the same process instead of a forked one.
local function setSchedule(exp, func, opts)
  if type(exp) == "table" then opts, exp, func = exp, unpack(exp) end
  opts = opts or {}
  argerror(type(opts) == "table", 3, "(table expected)")
  local res, err = cron2hash(exp)
  argerror(res ~= nil, 1, err)
  schedules[exp] = {res, func, opts.sameProc}
  if not setHook(scheduleHook, scheduler) then -- first schedule hook
    if ProgramHeartbeatInterval then
      -- make sure the heartbeat is triggered at least once a minute
      local min = 60*1000
      if ProgramHeartbeatInterval() > min then ProgramHeartbeatInterval(min) end
    else
      LogWarn("OnServerHeartbeat is required for setSchedule to work,"..
        " but may not be available; you need redbean v2.0.16+.")
    end
  end
end

--[[-- filters --]]--

-- make a filter that checks the value it is applied to (`ifModifiedSince`)
-- against the last modified time of the asset; when the asset is modified,
-- the `LastModified` header is set on the request; otherwise 304 is served
local function makeLastModified(asset)
  argerror(type(asset) == "string", 1, "(string expected)")
  local lastModified = GetLastModifiedTime(asset)
  return {
    function(ifModifiedSince)
      local isModified = (not ifModifiedSince
        or ParseHttpDateTime(ifModifiedSince) < lastModified)
      if isModified then
        getRequest().headers.LastModified = FormatHttpDateTime(lastModified)
      end
      return isModified
    end,
    otherwise = 304, -- serve 304 if not modified
  }
end

local trueval = function() return true end
-- checks available in validator rules; each check gets the value and
-- the parameter of the check and returns the result and an (optional)
-- message, in which `%s` is replaced with the name of the field;
-- `msg` and `optional` are rule options, so they always pass
local validators = {
  msg = trueval,
  optional = trueval,
  minlen = function(s, num)
    return #tostring(s or "") >= num, "%s is shorter than "..num.." chars"
  end,
  maxlen = function(s, num)
    return #tostring(s or "") <= num, "%s is longer than "..num.." chars"
  end,
  pattern = function(s, pat)
    return tostring(s or ""):match(pat), "invalid %s format"
  end,
  test = function(s, fun) return fun(s) end,
  oneof = function(s, list)
    if type(list) ~= "table" then list = {list} end
    for _, v in ipairs(list) do
      if s == v then return true end
    end
    return nil, "%s must be one of: "..EncodeLua(list):sub(2, -2)
  end,
}

-- make a validator from a list of rules; each rule is a table with the name
-- of the field as its first element and the checks from `validators` as its fields.
-- `rules.all` requests all the errors to be collected (instead of returning
-- the first one), `rules.key` requests the errors to be keyed by the field name,
-- and `rules.otherwise` is passed through as the `otherwise` value of the filter.
-- The validator returns `true` or `nil` and the error message(s).
local function makeValidator(rules)
  argerror(type(rules) == "table", 1, "(table expected)")
  for i, rule in ipairs(rules) do
    argerror(type(rule) == "table", 1, "(table expected at position "..i..")")
    argerror(type(rule[1]) == "string", 1, "(rule with name expected at position "..i..")")
    argerror(not rule.test or type(rule.test) == "function", 1,
      "(rule with test as function expected at position "..i..")")
  end
  return setmetatable({
      function(val)
        -- validator can be called in three ways:
        -- (1) directly with a params-like table passed
        -- (2) as a filter on an existing (scalar) field
        -- (3) as a filter on a non-existing field (to get request.params table)
        if val == nil then val = getRequest().params end -- case (3)
        if type(val) ~= "table" and #rules > 0 then -- case (2)
          -- convert the passed value into a hash based on the name in the first rule
          val = {[rules[1][1]] = val}
        end
        local errors = {}
        -- `repeat ... until true` is only used to allow `break` to work as `continue`
        for _, rule in ipairs(rules) do repeat
          local param, err = rule[1], rule.msg
          local value = val[param]
          if value == nil and rule.optional == true then break end -- continue
          for checkname, checkval in pairs(rule) do
            if type(checkname) == "string" then
              local validator = validators[checkname]
              if not validator then argerror(false, 1, "unknown validator "..checkname) end
              local success, msg = validator(value, checkval)
              if not success then
                local key = rules.key and param or #errors+1
                -- the message from the rule takes precedence over the one from the check
                local errmsg = (err or msg or "%s check failed"):format(param)
                errors[key] = errors[key] or errmsg
                if not rules.all then
                  -- report an error as a single message, unless key is asked for
                  return nil, rules.key and errors or errmsg
                end
              end
            end
          end
        until true end
        if #errors > 0 or next(errors) then return nil, errors end
        return true
      end,
      otherwise = rules.otherwise,
    }, {__call = function(t, r) return t[1](r) end})
end

--[[-- security --]]--

-- make a filter that implements Basic authentication.
-- `authtable` maps user names to passwords or password hashes;
-- `opts.hash` is the name of the hash the stored values are compared with
-- ("ARGON2" values are verified with `argon2.verify` and any other name is
-- passed to `GetCryptoHash` along with `opts.key`); when no hash is specified,
-- the stored value is compared with the password itself;
-- `opts.realm` is added to the `WWW-Authenticate` header of the 401 response.
local function makeBasicAuth(authtable, opts)
  argerror(type(authtable) == "table", 1, "(table expected)")
  argerror(opts == nil or type(opts) == "table", 2, "(table expected)")
  opts = opts or {}
  local realm = opts.realm and (" Realm=%q"):format(opts.realm) or ""
  local hash, key = opts.hash, opts.key
  -- the hash is optional, so only normalize its name when it's specified
  local hashname = hash and hash:upper()
  return {
    function(authorization)
      if not authorization then return false end
      local pass, user = GetPass(), GetUser()
      if not pass or not user or not authtable[user] then return false end
      if hashname == "ARGON2" then return argon2.verify(authtable[user], pass) end
      return authtable[user] == (hashname and GetCryptoHash(hashname, pass, key) or pass)
    end,
    -- if authentication is not present or fails, return 401
    otherwise = serveResponse(401, {WWWAuthenticate = "Basic" ..
realm}),
  }
end

-- make a function that checks an IP address (as a number) against a list of
-- addresses and subnets (`a.b.c.d` or `a.b.c.d/mask`); an item that starts with `!`
-- is a negative match. The items are checked in the order they are listed and
-- the first matching item determines the result; the result is `false` when
-- nothing matches or the address is invalid (-1).
local function makeIpMatcher(list)
  if type(list) == "string" then list = {list} end
  argerror(type(list) == "table", 1, "(table or string expected)")
  local subnets = {}
  for _, ip in ipairs(list) do
    -- `neg` is the number of removed `!` prefixes (0 or 1)
    local v, neg = ip:gsub("^!","")
    local addr, mask = v:match("^(%d+%.%d+%.%d+%.%d+)/(%d+)$")
    if not addr then addr, mask = v, 32 end
    addr = ParseIp(addr)
    argerror(addr ~= -1, 1, ("(invalid IP address %s)"):format(ip))
    mask = tonumber(mask)
    argerror(mask and mask >= 0 and mask <=32, 1, ("invalid mask in %s"):format(ip))
    -- convert the number of bits to a bit mask
    mask = ~0 << (32 - mask)
    -- apply mask to addr in case addr/subnet is not properly aligned
    table.insert(subnets, {addr & mask, mask, neg > 0})
  end
  return function(ip)
    if ip == -1 then return false end -- fail the check on invalid IP
    for _, v in ipairs(subnets) do
      local match = v[1] == (ip & v[2])
      if match then return not v[3] end
    end
    return false
  end
end

--[[-- core engine --]]--

-- set the status and render the template named after it (for example, "404");
-- if rendering fails, `ServeError` is used instead, unless the error is
-- something other than an unknown template (and the status is not 500),
-- in which case the error is raised again
local function error2tmpl(status, reason, message)
  if not reason then reason = GetHttpReason(status) end
  SetStatus(status, reason) -- set status, but allow template handlers to overwrite it
  local ok, res = pcall(render, tostring(status),
    {status = status, reason = reason, message = message})
  if not ok and status ~= 500 and not res:find("unknown template name") then
    error(res)
  end
  return ok and res or ServeError(status, reason) or true
end

-- use the passed value as the path if it's a string and the request path if not
local function checkPath(path) return type(path) == "string" and path or GetPath() end

-- the module table; the names that are not present in the table are
-- resolved (and cached) by the `__index` handler
local fm = setmetatable({ _VERSION = VERSION, _NAME = NAME, _COPYRIGHT = "Paul Kulchenko",
  reg2x = reg2x, reg1p = reg1p,
  getBrand = function() return ("%s/%s %s/%s"):format("redbean", getRBVersion(), NAME, VERSION) end,
  setTemplate = setTemplate, setTemplateVar = setTemplateVar,
  setRoute = setRoute, setSchedule = setSchedule, setHook = setHook,
  parseMultipart = parseMultipart,
  makeStorage = makeStorage,
  makePath = makePath, makeUrl = makeUrl,
  makeBasicAuth = makeBasicAuth, makeIpMatcher = makeIpMatcher,
  makeLastModified = makeLastModified,
  makeValidator = makeValidator,
  getAsset = LoadAsset,
  getRequest = getRequest,
  render = render,
  -- options
  cookieOptions = {HttpOnly = true, SameSite = "Strict"},
  sessionOptions = {name = "fullmoon_session", hash = "SHA256", secret = true, format = "lua"},
  -- serve* methods that take path can be served as a route handler (with request passed)
  -- or as a method called from a route handler (with the path passed);
  -- serve index.lua or index.html if available; continue if not
  serveIndex = function(path) return function() return ServeIndex(checkPath(path)) end end,
  -- handle and serve existing path, including asset, Lua, folder/index, and pre-configured redirect
  servePath = function(path) return function() return RoutePath(checkPath(path)) end end,
  -- return asset (de/compressed) along with checking for asset range and last/not-modified
  serveAsset = function(path) return function() return ServeAsset(checkPath(path)) end end,
  serveError = function(status, reason, msg)
    return function() return error2tmpl(status, reason, msg) end
  end,
  serveContent = function(tmpl, params) return function() return render(tmpl, params) end end,
  serveRedirect = function(status, loc) return function()
      -- if no status or location is specified, then redirect to the original URL with 303
      -- this is useful for switching to GET after POST/PUT to an endpoint
      -- in all other cases, use the specified status or 307 (temp redirect)
      local ts, tl = type(status), type(loc)
      -- swap parameters if needed (as they used to be in a different order before 0.368)
      -- and also allow both status and location to be optional
      if (ts == "string" or ts == "nil") and (tl == "nil" or tl == "number") then
        status, loc = loc, status
      end
      return ServeRedirect(status or loc and 307 or 303, loc or GetPath())
    end end,
  serveResponse = serveResponse,
}, {__index =
  function(t, key)
    -- save the resolved value in the table, so it's only resolved once
    local function cache(f) t[key] = f return f end
    -- a name with three or more capital letters (`GET`, `POST`) is a method function
    -- that adds the method condition to a route
    local method = key:match("^[A-Z][A-Z][A-Z]+$")
    if method then return cache(function(route)
        if type(route) == "string" then return {route, method = method} end
        argerror(type(route) == "table", 1, "(string or table expected)")
        route.method = method
        return route
      end)
    end
    -- handle serve204 and similar calls
    local serveStatus = key:match("^serve(%d%d%d)$")
    if serveStatus then return cache(t.serveError(tonumber(serveStatus))) end
    -- handle logVerbose and other log calls
    local kVal = not _G[key] and _G[key:gsub("^l(og%w*)$", function(name) return "kL"..name end)]
    if kVal then return cache(function(...) return Log(kVal, logFormat(...)) end) end
    -- return upper camel case version if exists
    return cache(_G[key] or _G[key:sub(1,1):upper()..key:sub(2)])
  end})

local isfresh = {} -- some unique key value

-- delete the cookie by setting it to an empty value with zero max age;
-- the `maxage`/`MaxAge` values in the passed options are changed for
-- the call and restored after it
local function deleteCookie(name, copts)
  local maxage, MaxAge = copts.maxage, copts.MaxAge
  copts.maxage, copts.MaxAge = 0, nil
  SetCookie(name, "", copts)
  copts.maxage, copts.MaxAge = maxage, MaxAge
end

-- return `fm.sessionOptions` after checking that the required values are present
local function getSessionOptions()
  local sopts = fm.sessionOptions or {}
  if not sopts.name then error("missing session name") end
  if not sopts.hash then error("missing session hash") end
  if not sopts.format then error("missing session format") end
  -- check for session secret and hash
  if sopts.secret and type(sopts.secret) ~= "string" then
    error("sessionOptions.secret is expected to be a string")
  end
  return sopts
end

-- save the session table in the session cookie as `msg.format.hash.signature`,
-- where `msg` is the encoded session table and the signature is calculated
-- from it using the session hash and secret;
-- the cookie is deleted when the session is empty
local function setSession(session)
  -- if the session hasn't been touched (read or updated), do nothing
  if session and session[isfresh] then return end
  local sopts = getSessionOptions()
  local cookie
  if session and next(session) then
    local msg = EncodeBase64(EncodeLua(session))
    local sig = EncodeBase64(
      GetCryptoHash(sopts.hash, msg, sopts.secret or ""))
    cookie = msg.."."..sopts.format.."."..sopts.hash.."."..sig
  end
  local copts = fm.cookieOptions or {}
  if cookie then
    SetCookie(sopts.name, cookie, copts)
  else
    fm.logDebug("delete session cookie")
    deleteCookie(sopts.name, copts)
  end
end

-- load the session table from the session cookie
local function getSession()
  local sopts = getSessionOptions()
  local session =
GetCookie(sopts.name)
  -- an empty session is returned when the cookie is missing or fails any of the checks
  if not session then return {} end
  local msg, format, hash, sig = session:match("(.-)%.(.-)%.(.-)%.(.+)")
  if not msg then return {} end
  if not pcall(GetCryptoHash, hash, "") then
    LogWarn("invalid session crypto hash: "..hash)
    return {}
  end
  -- use the same key as the one used to sign the cookie in `setSession`,
  -- which falls back to an empty string when the secret is not set
  if DecodeBase64(sig) ~= GetCryptoHash(hash, msg, sopts.secret or "") then
    LogWarn("invalid session signature: "..sig)
    return {}
  end
  if format ~= "lua" then
    LogWarn("invalid session format: "..format)
    return {}
  end
  local ok, val = loadsafe("return "..DecodeBase64(msg))
  if not ok then LogWarn("invalid session content: "..val) end
  return ok and val or {}
end

-- output the headers; the names are mapped using `headerMap` and the headers
-- listed in `noHeaderMap` are only set when they have the value listed there
local function setHeaders(headers)
  for name, value in pairs(headers or {}) do
    local val = tostring(value)
    if type(value) ~= "string" and type(value) ~= "number" then
      LogWarn("header '%s' is assigned non-string value '%s'", name, val)
    end
    local hname = headerMap[name] or name
    local noheader = noHeaderMap[hname:lower()]
    if not noheader or val:lower() == noheader then
      SetHeader(hname, val)
    else
      LogDebug("header '%s' with value '%s' is skipped to avoid conflict", name, val)
    end
  end
end

-- output the cookies; a cookie value can be a table with the value as its
-- first element and the cookie options as its fields (`fm.cookieOptions` are
-- used otherwise); the cookie is deleted if the value is `false`
local function setCookies(cookies)
  local copts = fm.cookieOptions or {}
  for cname, cvalue in pairs(cookies or {}) do
    local value, opts = cvalue, copts
    if type(cvalue) == "table" then value, opts = cvalue[1], cvalue end
    if value == false then
      deleteCookie(cname, opts)
    else
      SetCookie(cname, value, opts)
    end
  end
end

-- call the handler and handle any Lua error by returning Server Error;
-- the handler is executed as a coroutine (an existing coroutine is resumed).
-- Returns the coroutine if it's suspended (or `false` otherwise)
-- followed by the first two values returned or yielded by the handler.
local function hcall(func, ...)
  local co = type(func) == "thread" and func or coroutine.create(func)
  local ok, res, more = coroutine.resume(co, ...)
  -- return normally on successful execution or an unsuccessful one
  -- that returns a function (to allow calling `error(fm.serve404)`)
  if ok or type(res) == "function" then
    return coroutine.status(co) == "suspended" and co or false, res, more
  end
  local err = debug.traceback(co, res)
  Log(kLogError, logFormat("Lua error: %s", err))
  -- the error details are only included in the response for loopback clients
  return false, error2tmpl(500, nil, IsLoopbackIp(GetRemoteAddr()) and err or nil)
end

local MPKEY = "multipart"

-- handle the request: set up the request table, find and execute the matching
-- route, and output the result along with the headers, cookies, and session
local function handleRequest(path)
  path = path or GetPath()
  req = setmetatable({
      params = setmetatable({}, {__index = function(t, k)
            local mk = k.."[]"
            -- if the multi-key exists, then use it instead
            if not HasParam(k) and HasParam(mk) then k = mk end
            if not HasParam(k) and not rawget(t, MPKEY) then
              local ct = GetHeader("Content-Type")
              if not ct or not string.find(ct, "^multipart/") then return end
              -- check the multipart body for the requested parameter
              t[MPKEY] = parseMultipart(GetBody(), ct)
              -- return pseudo parameter with the parsed multipart message;
              -- subsequent calls will retrieve it from the table directly
              if k == MPKEY then return t[MPKEY] end
            end
            local mp = rawget(t, MPKEY)
            if mp then
              local m = mp[k] or mp[mk]
              -- if a multipart parameter, then return individual value or list of elements
              if m then return m.data or m end
            end
            -- GetParam may return `nil` for empty parameters (`foo` in `foo&bar=1`),
            -- but `params` needs to return `false` instead
            if not string.find(k, MULTIVAL) then return GetParam(k) or false end
            local array={}
            for _, v in ipairs(GetParams()) do
              if v[1] == k then table.insert(array, v[2] or false) end
            end
            return array
          end}),
      -- check headers table first to allow using `.ContentType` instead of `["Content-Type"]`
      headers = setmetatable({}, {__index = function(_, k) return GetHeader(headerMap[k] or k) end}),
      cookies = setmetatable({}, {__index = function(_, k) return GetCookie(k) end}),
      -- the session is loaded from the cookie on the first read; the `isfresh` value
      -- is `true` until then and the loaded session table after that;
      -- the first assignment replaces this table with a regular one
      session = setmetatable({[isfresh] = true}, {
          __index = function(t, k)
            if t[isfresh] == true then t[isfresh] = getSession() end
            return t[isfresh] and t[isfresh][k]
          end,
          __newindex = function(t, k, v)
            if t[isfresh] then
              -- copy the already processed table if available
              req.session = type(t[isfresh]) == "table" and t[isfresh] or getSession()
            end
            req.session[k] = v
          end,
        }),
    }, tmplReqHandlerEnv)
  SetStatus(200) -- set default status; can be reset later
  -- find a match and handle any Lua errors in handlers
  local co, res, headers = hcall(matchRoute, path, req)
  -- execute the (deferred) function and handle any errors
  while type(res) == "function" do co, res, headers = hcall(res) end
  local tres = type(res)
  if res == true then
    -- do nothing, as this request was already handled
  elseif not res and not co then
    -- this request wasn't handled, so report 404
    return error2tmpl(404) -- use 404 template if available
  elseif tres == "string" then
    if #res > 0 then
      if not headers then headers = {ContentType = detectType(res)} end
      Write(res) -- output content as is
    end
  elseif not co then
    LogWarn("unexpected result from action handler: '%s' (%s)", tostring(res), tres)
  end
  -- set the headers as returned by the render
  if type(headers) == "table" then
    if not req.headers then req.headers = {} end
    for name, value in pairs(headers) do req.headers[name] = value end
  elseif headers then
    LogWarn("non-table headers returned from action handler (%s)", tostring(headers))
  end
  setHeaders(req.headers) -- output specified headers
  setCookies(req.cookies) -- output specified cookies
  setSession(req.session) -- add a session cookie if needed

  -- continue with the suspended handler (streaming): yield to get
  -- the content written so far sent and then resume the handler
  while co do
    coroutine.yield()
    co, res = hcall(co)
    -- if the function is returned, which may happen if serve* is used
    -- as the last call, then process it to get its result
    while type(res) == "function" do co, res = hcall(res) end
    if type(res) == "string" then Write(res) end
  end
end

-- wrap a serve* function into its stream* version, which executes
-- the handler the function returns and yields its results
local function streamWrap(func)
  return function(...)
return coroutine.yield(func(...)()) or true
  end
end

fm.streamResponse = streamWrap(fm.serveResponse)
fm.streamContent = streamWrap(fm.serveContent)

-- add internal functions for test support
fm.test = {
  reqenv = reqenv, route2regex = route2regex, routes = routes,
  matchRoute = matchRoute, handleRequest = handleRequest, getRequest = getRequest,
  headerMap = headerMap, detectType = detectType, matchCondition = matchCondition,
  setSession = setSession,
}

-- configure redbean using the passed options and set up request handling:
-- `headers` are set with `ProgramHeader`, `*Options` tables replace the module
-- defaults of the same name, and any other option `name` is passed to
-- the `ProgramName` function (a table value is used to call it several times)
function fm.run(opts)
  opts = opts or {}
  ProgramBrand(fm.getBrand())
  -- configure logPath first to capture all subsequent messages
  -- in the log file, as the order is randomized otherwise
  local logpath = opts.logPath or opts.LogPath
  if logpath then ProgramLogPath(logpath) end
  for key, v in pairs(opts) do
    if key == "headers" and type(v) == "table" then
      for h, val in pairs(v) do ProgramHeader(headerMap[h] or h, val) end
    elseif key:find("Options$") and type(v) == "table" then
      -- if *Options is assigned, then overwrite the provided default
      if fm[key] then
        fm[key] = opts[key]
      else -- if there is no default, it's some wrong option
        argerror(false, 1, ("(unknown option '%s')"):format(key))
      end
    else
      local name = "Program"..key:sub(1,1):upper()..key:sub(2)
      if name ~= "ProgramLogPath" then -- this is already handled earlier
        local func = _G[name]
        argerror(type(func) == "function", 1,
          ("(unknown option '%s' with value '%s')"):format(key, v))
        for _, val in pairs(type(v) == "table" and v or {v}) do
          -- accept a table to pass multiple values and unpack it
          func(unpack(type(val) ~= "table" and {val} or val))
        end
      end
    end
  end
  -- replace the log functions for the levels above the current one with
  -- empty functions, so their messages are not formatted
  if GetLogLevel then
    local level, none = GetLogLevel(), function() end
    if level < kLogWarn then LogWarn = none end
    if level < kLogInfo then LogInfo = none end
    if level < kLogVerbose then LogVerbose = none end
    if level < kLogDebug then LogDebug = none end
  end
  LogInfo("started "..fm.getBrand())
  -- `secret = true` (the default) requests a random secret to be generated
  local sopts = fm.sessionOptions
  if sopts.secret == true then
    sopts.secret = GetRandomBytes(32)
    LogVerbose("applied random session secret; set `fm.sessionOptions.secret`"
      ..(" to `fm.decodeBase64('%s')` to continue using this value")
        :format(EncodeBase64(sopts.secret))
      .." or to `false` to disable")
  end
  -- assign Redbean handler to execute on each request
  OnHttpRequest = function() handleRequest(GetPath()) end

  collectgarbage() -- clean up no longer used memory to reduce image size
end

-- setTemplate will do some logging, so provide the log function
-- if this is executed outside of redbean
Log = Log or function() end

-- "fmt" template: the parser turns the template text into Lua code;
-- the text outside of `{% %}` tags is written as is, `{% code %}` is included
-- as code, and `{%= expr %}`/`{%& expr %}` write the value of the expression
-- (`&` also passes the value through `escapeHtml`)
fm.setTemplate("fmt", {
    parser = function (tmpl)
      -- marker for the tag appended to the template to get the trailing text processed
      local EOT = "\0"
      local function writer(s) return #s > 0 and ("Write(%q)"):format(s) or "" end
      -- remove a `--` comment, but keep it if it includes a long bracket
      local function decomment(s) return s:match("%[=*%[") and "--"..s or "" end
      local tupd = (tmpl.."{%"..EOT.."%}"):gsub("(.-){%%([=&]*)%s*(.-)%s*%%}",
        function(htm, pref, val)
          return writer(htm)
          ..(val ~= EOT -- this is not the suffix
            and (pref == "" -- this is a code fragment
              and val:gsub("%-%-([^\r\n]*)", decomment).." "
              or ("Write(%s(tostring(%s or '')))")
                :format(pref == "&" and "escapeHtml" or "", val:gsub("%-%-(.*)", decomment)))
            or "")
        end)
      return tupd
    end,
    function() end,
  })
fm.setTemplate("500", default500) -- register default 500 status template
-- "json" template: encode the passed value as JSON
fm.setTemplate("json", {ContentType = "application/json",
    function(val) return EncodeJson(val, {useoutput = true}) end})
-- "sse" template: write the passed event (or a list of events) as server-sent events;
-- each event is a table with the event fields (`data`, `comment`, and others)
fm.setTemplate("sse", function(val)
    argerror(type(val) == "table", 1, "(table expected)")
    -- a table without list elements is handled as one event
    if #val == 0 then val = {val} end
    for _, event in ipairs(val) do
      for etype, eval in pairs(event) do
        -- a comment is written with an empty field name and
        -- each line of multi-line data is prefixed with `data:`
        Write(("%s: %s\n"):format(
            etype == "comment" and "" or etype,
            etype == "data" and eval:gsub("\n", "\ndata: ") or eval
          ))
      end
      Write("\n")
    end
    return "", {
      ContentType = "text/event-stream",
      CacheControl = "no-store",
      ["X-Accel-Buffering"] = "no",
    }
  end)

-- render a table with HTML elements described as Lua tables
-- (`{"tag", attr = value, child, ...}`)
local function fmgRender(val)
  argerror(type(val) == "table", 1, "(table expected)")
  -- write the attributes (the string keys of the table);
  -- returns `true` if the table has any string keys
  local function writeAttrs(opt)
    local doneattr = false
    for attrname, attrval in pairs(opt) do
      if type(attrname) == "string" then
        local valtype =
type(attrval)
        -- the values marked as `raw` are written without escaping
        local escape = not(valtype == "table" and attrval[1] == "raw")
        if valtype == "table" then
          -- this handles `_=raw"some"`
          if #attrval > 1 then
            attrval = attrval[2]
          else
            -- the following turns `tx={post="x", get="y"}`
            -- into `["tx-post"]="x", ["tx-get"]="y"`
            for k, v in pairs(attrval) do
              if type(k) == "string" then
                if escape then v = EscapeHtml(v) end
                Write((' %s="%s"'):format(attrname.."-"..k, v))
              end
            end
          end
        elseif attrval == true then
          -- this turns `checked=true` into `checked="checked"`
          attrval = attrname
        elseif attrval == false then
          -- write nothing here
        end
        if type(attrval) == "string" or type(attrval) == "number" then
          if escape then attrval = EscapeHtml(attrval) end
          Write((' %s="%s"'):format(attrname, attrval))
        end
        doneattr = true
      end
    end
    return doneattr
  end
  -- write one value: a table is written as an element (`include`, `raw`, `each`,
  -- and `doctype` are handled as special tags), a function is called to get
  -- the value to write, and anything else is written as (escaped) text
  local function writeVal(opt, escape)
    if type(opt) == "function" then opt = opt() end
    if type(opt) == "table" then
      local tag = opt[1]
      argerror(tag ~= nil, 1, "(tag name expected)")
      if tag == "include" then return(fm.render(opt[2], opt[3])) end
      if tag == "raw" then
        for i = 2, #opt do writeVal(opt[i], false) end
        return
      end
      if tag == "each" then
        -- rewrite messages to point to `each` function
        argerror(type(opt[2]) == "function", 1, "(function expected)", "each")
        argerror(type(opt[3]) == "table", 2, "(table expected)", "each")
        for _, v in ipairs(opt[3]) do writeVal(opt[2](v), false) end
        return
      end
      if tag:lower() == "doctype" then
        -- write the document type declaration (using `html` by default)
        Write("<!doctype "..(opt[2] or "html")..">")
        return
      end
      if getmetatable(opt) and not htmlVoidTags[tag:lower()] then
        LogWarn("rendering '%s' with `nil` value", tag)
        return
      end
      Write("<"..tag)
      -- the following allows both `{"elem", attr = 1, {"sub-elem"}}`
      -- and `{"elem", {attr = 1}, {"sub-elem"}}` to be handled;
      -- It allows Lua-based languages that don't mix arrays and hashes
      -- (like fennel) to use the latter syntax and
      -- requires no attributes to be present and the table in the second
      -- index to only have attributes and no array values
      local validx = (not writeAttrs(opt) -- if there are no attributes
        and type(opt[2]) == "table" -- and the second element is a table
        and #opt[2] == 0 -- with no values
        and writeAttrs(opt[2]) -- that has attributes
        and 3 or 2) -- then shift sub-element processing by one
      if htmlVoidTags[tag:lower()] then Write("/>") return end
      Write(">")
      -- the content of `script` elements is written without escaping
      local escape = tag ~= "script"
      for i = validx, #opt do writeVal(opt[i], escape) end
      -- close the element that was opened above
      Write("</"..tag..">")
    else
      local val = tostring(opt or "")
      -- escape by default if not requested not to
      if escape ~= false then val = EscapeHtml(val) end
      Write(val)
    end
  end
  for _, v in pairs(val) do writeVal(v) end
end

-- "fmg" template: the template text is loaded as a Lua expression
-- and its value is rendered with `fmgRender`
fm.setTemplate("fmg", {
    autotag = true,
    parser = function(s, path)
      local code = setfenv(assert(load("return "..s, path, "t")), getfenv())
      return function() fmgRender(assert(code())) end
    end,
    fmgRender,
  })

-- "cgi" template: execute the command passed as a list (`cmd[1]` is the program)
-- and send its output as the response. Options (fields of the same table):
--   nph: close the connection when done and send the content as it's received
--   maxtime: the time limit for the command (5 by default)
--   env: the environment variables; a hash is added to the default (CGI) variables
--     and the request headers, while a non-empty list is used as is
fm.setTemplate("cgi", function(cmd)
    if not cmd or not cmd[1] then error('missing command') end
    local nph = cmd.nph
    local maxtime = cmd.maxtime or 5
    local envu = cmd.env or {}
    -- a non-empty list is passed to the process as is; in all other
    -- cases the list is populated from the default values below
    local env = #envu > 0 and envu or {}
    local envd = #envu > 0 and envu or {
      SERVER_ADDR = FormatIp(GetServerAddr()),
      SERVER_PORT = select(2, GetServerAddr()),
      SERVER_SOFTWARE = fm.getBrand(),
      REMOTE_HOST = GetHost(),
      REMOTE_ADDR = FormatIp(GetRemoteAddr()),
      -- NOTE(review): CGI/1.1 (RFC 3875) names this variable REMOTE_IDENT;
      -- confirm if this spelling is intentional before changing it
      REMOTE_INDENT = ParseUrl(GetUrl()).user,
      REQUEST_SCHEME = GetScheme(),
      REQUEST_METHOD = GetMethod(),
      GATEWAY_INTERFACE = "CGI/1.1",
      SERVER_PROTOCOL = ("HTTP/%01.1f"):format(GetHttpVersion()/10),
      QUERY_STRING = EncodeUrl({params = ParseUrl(GetUrl()).params}):sub(2),
      REQUEST_URI = (
        function(u) return EncodeUrl({params = u.params, path = u.path})end)(
        ParseUrl(GetUrl())),
      HTTPS = GetSslIdentity() and "on" or nil,
      PATH_INFO = GetPath(),
      PATH_TRANSLATED = GetEffectivePath(),
      CONTENT_TYPE = GetHeader("Content-Type"),
      CONTENT_LENGTH = GetHeader("Content-Length"),
    }
    if #envu == 0 then
      -- provide all the headers
      for k, v in pairs(GetHeaders()) do envd["HTTP_"..k:upper()] = v end
      -- overwrite all defaults
      for k, v in pairs(envu) do envd[k] = v end
      -- convert to strings unless the value is false
      for k, v in pairs(envd) do
if v then table.insert(env, k.."="..v) end
      end
    end
    -- the first pipe is used to read the output of the process (stdout and stderr)
    -- and the second one is used to pass the request body to it (stdin)
    local rfd1, wfd1 = unix.pipe(unix.O_CLOEXEC)
    local rfd2, wfd2 = unix.pipe(unix.O_CLOEXEC)
    -- NOTE(review): the results of `unix.pipe` and `unix.fork` are not checked
    -- for failures; a failed fork is handled by the `else` branch below
    local pid = unix.fork()
    if pid == 0 then -- forked child
      assert(unix.close(GetClientFd())) -- close client fd to not send anything
      assert(unix.dup(rfd2, 0)) -- redirect stdin
      assert(unix.dup(wfd1, 1)) -- redirect stdout
      assert(unix.dup(wfd1, 2)) -- redirect stderr
      assert(unix.close(wfd1))
      assert(unix.close(rfd1))
      assert(unix.close(wfd2))
      assert(unix.close(rfd2))
      assert(unix.sigaction(unix.SIGQUIT, unix.exit))
      unix.execve(cmd[1], cmd, env)
      -- this is only reached if the command has not been executed
      unix.exit(127)
    else
      unix.write(wfd2, GetBody())
      local isactive = true -- reset by the timer after `maxtime`
      local header = true -- the headers are still being processed
      local done = false -- set when `unix.wait` below returns a false value
      assert(unix.sigaction(unix.SIGALRM, function() isactive = false end))
      assert(unix.setitimer(unix.ITIMER_REAL, 0, 0, maxtime, 0))
      local cfd = GetClientFd()
      while true do
        -- block for a bit until there is output from the launched process
        local se = (unix.poll({[rfd1] = unix.POLLIN}, maxtime*1000/10) or {})[rfd1] or 0
        if se & (unix.POLLHUP + unix.POLLERR) > 0 then break end
        -- check if the client is (still) writable
        local re = (unix.poll({[cfd] = unix.POLLOUT}) or {})[cfd] or 0
        if re & (unix.POLLHUP + unix.POLLERR) > 0 then break end
        -- check if the process took too long to respond
        if not isactive then break end
        local data, errno
        if se & unix.POLLIN > 0 then data, errno = unix.read(rfd1) end
        -- check for end of file or any descriptor errors
        if data == "" or errno and errno:errno() == unix.EBADF then break end
        if data then
          if header then
            -- process the status line and the headers one line at a time;
            -- `pos` is the position of the start of the current line
            local pos = 1
            while true do
              -- allow both CRLF and LF to be handled
              local spos, epos = data:find("\r?\n", pos, false)
              if not spos then break end
              -- found an empty string, which signals the end of headers
              if spos == pos then header = false; pos = epos + 1; break end
              local status, reason
              -- the status line is only expected at the very start of the data
              if pos == 1 then
                status, reason = data:sub(pos, spos-1):match("^HTTP/%d%.%d%s+(%d%d%d)%s+(.+)")
              end
              if status then -- found the status line
                SetStatus(status, reason)
              else -- found something that may be a header
                local name, value = data:sub(pos, spos-1):match("^(%a[%w_-]+):%s*(.+)")
                -- check if the header has a valid syntax
                -- if not, skip it and handle it as part of the data;
                -- this may happen if an error is thrown instead of a valid context
                local ok, err = pcall(SetHeader, name, value)
                if not name or not value or not ok then header = false break end
              end
              pos = epos + 1
            end
            -- finished parsing headers, so remove them from the data
            if not header then
              -- if nph is requested, then close the connection when done
              if nph then SetHeader("Connection", "close") end
              if pos > 1 then data = data:sub(pos) end
            end
          end
          -- if the headers are set, write the response
          if not header then
            Write(data)
            -- yield to get the data that has been written so far sent
            if nph then coroutine.yield() end
          end
        -- check if reading was interrupted or wasn't done
        elseif not data and (not errno or errno:errno() == unix.EINTR) then
          -- closing connection will fail on send
          -- if it doesn't fail, then redo the interrupted read
        else
          -- report an error and stop
          -- TODO: log error
          break
        end
        -- stop reading if the process is already done
        -- NOTE(review): `done` is only set when `unix.wait` returns a false value
        -- (presumably when it fails) -- confirm what it returns for
        -- a process that is still running or has just exited
        done = not unix.wait(pid, unix.WNOHANG)
        if done then break end
      end
      unix.close(rfd1)
      unix.close(wfd1)
      unix.close(rfd2)
      unix.close(wfd2)
      -- kill launched process if it's still running, as it's been running for too long
      if not done then assert(unix.kill(pid, unix.SIGTERM)) end
    end
  end)

return fm