-- sandbox.lua
-- Units: binary size multipliers (1 kB = 1024)
kB = 1024
MB = kB * 1024
GB = MB * 1024

-- Time: base unit is 1/1000 of a second (sec = 1000)
sec = 1000
second = sec
minute = sec * 60
min = minute
hour = minute * 60
day = hour * 24

-- Logging
-- All helpers take a string.format()-style format string and arguments.

-- Raise a Lua error whose message is the formatted text prefixed by 'error: '.
function panic(fmt, ...)
	local message = string.format('error: '..fmt, ...)
	error(message)
end

-- Write the formatted text plus a newline to standard error.
function warn(fmt, ...)
	local line = string.format(fmt..'\n', ...)
	io.stderr:write(line)
end

-- Print the formatted text to standard output via print().
function log(fmt, ...)
	local line = string.format(fmt, ...)
	print(line)
end

-- Resolver bindings
kres = require('kres')
trust_anchors = require('trust_anchors')
resolve = worker.resolve
-- `todname` is exported only when the kres binding provides `str2dname`
do
	local str2dname = rawget(kres, 'str2dname')
	if str2dname ~= nil then
		todname = str2dname
	end
end

-- Shorthand for aggregated per-worker information:
-- the table returned by worker.stats(), with `pid` set to worker.pid.
function worker.info()
	local info = worker.stats()
	info.pid = worker.pid
	return info
end

-- Resolver mode of operation
local current_mode = 'normal'
local mode_table = { normal=0, strict=1, permissive=2 }

--- Get or set the resolver operating mode.
-- Called without an argument, returns the current mode name.
-- Called with a name listed in mode_table, stores it, sets the STRICT and
-- PERMISSIVE options to match, and returns true. Any other name raises an error.
function mode(m)
	if not m then
		return current_mode
	end
	if not mode_table[m] then
		error('unsupported mode: '..m)
	end
	current_mode = m
	local is_strict = (current_mode == 'strict')
	local is_permissive = (current_mode == 'permissive')
	option('STRICT', is_strict)
	option('PERMISSIVE', is_permissive)
	return true
end

-- Trivial option alias:
-- `reorder_RR(val)` is shorthand for `option('REORDER_RR', val)`.
reorder_RR = function (val)
	return option('REORDER_RR', val)
end

-- Function aliases
-- `env.VAR` returns `os.getenv('VAR')`
env = setmetatable({}, {
	__index = function (_, name)
		return os.getenv(name)
	end
})

-- Quick access to interfaces
-- `net.<iface>` => `net.interfaces()[iface]`
-- `net.ipv{4,6}` => `not option('NO_IPV{4,6}')`
-- `net.ipv{4,6} = {true, false}` => enable/disable IPv{4,6}
-- `net = {addr1, ..}` => `net.listen(addr1)`, ...; an entry that is a key of
--   `net.interfaces()` listens on that interface entry instead
setmetatable(net, {
	__index = function (t, k)
		local raw = rawget(t, k)
		if raw then
			return raw
		end
		if k == 'ipv6' then
			return not option('NO_IPV6')
		end
		if k == 'ipv4' then
			return not option('NO_IPV4')
		end
		return net.interfaces()[k]
	end,
	__newindex = function (t, k, v)
		if k == 'ipv6' then
			return option('NO_IPV6', not v)
		end
		if k == 'ipv4' then
			return option('NO_IPV4', not v)
		end
		-- Any other key: look v up as an interface name first,
		-- otherwise pass it to listen() unchanged
		local iface = rawget(net.interfaces(), v)
		t.listen(iface or v)
	end
})

-- Syntactic sugar for module loading
-- `modules.<name> = <config>` loads <name>, then calls its config(<config>)
-- `modules = { 'name', ... }` (numeric keys) loads each name, calling config(nil)
-- Nothing happens if a global with that name already exists.
setmetatable(modules, {
	__newindex = function (t, k, v)
		-- List form: the value is the module name and there is no config
		if type(k) == 'number' then
			k = v
			v = nil
		end
		if not rawget(_G, k) then
			modules.load(k)
			-- Keep only the leading identifier of `k` for the global lookup.
			-- Lua identifiers may contain '_', which '%w' alone does not
			-- match (it would reduce 'foo_bar' to 'foo').
			k = string.match(k, '[%w_]+')
			local mod = _G[k]
			local config = mod and rawget(mod, 'config')
			if mod ~= nil and config ~= nil then
				if k ~= v then config(v)
				else           config()
				end
			end
		end
	end
})

-- Syntactic sugar for cache
-- `#cache` -> `cache.count()`
-- `cache[n]` (numeric n) -> `cache.get(n)`, only when `current_size` is set
-- `cache.size = v` / `cache.storage = v` -> `cache.open(...)` with the other
--   property taken from the current configuration (or its default)
setmetatable(cache, {
	__len = function (t)
		return t.count()
	end,
	__index = function (t, k)
		if type(k) == 'number' then
			return rawget(t, k) or (rawget(t, 'current_size') and t.get(k))
		end
	end,
	__newindex = function (t,k,v)
		-- Current configuration, falling back to defaults. These must be at
		-- function scope so both open() calls below can see them.
		local storage = rawget(t, 'current_storage') or 'lmdb://'
		local size = rawget(t, 'current_size') or 10*MB
		-- Declarative interface for cache
		if     k == 'size'    then t.open(v, storage)
		elseif k == 'storage' then t.open(size, v)
		else   rawset(t, k, v) end
	end
})

-- Syntactic sugar for TA store
-- `trust_anchors.file = v`     => `trust_anchors.config(v)`
-- `trust_anchors.negative = v` => `trust_anchors.set_insecure(v)`
-- any other new key is stored on the table unchanged
setmetatable(trust_anchors, {
	__newindex = function (t, k, v)
		if k == 'file' then
			t.config(v)
		elseif k == 'negative' then
			t.set_insecure(v)
		else
			rawset(t, k, v)
		end
	end,
})

-- Register module in Lua environment
--- Install get()/set() property sugar on a module table.
-- Reading a missing key `m.k` returns `m.get(k)` when the module has a raw
-- `get` function (otherwise nil).
-- Assigning a new key `m.k = v` calls `m.set(k..' '..v)` when the module has
-- a raw `set` function; otherwise the value is stored on the table itself
-- instead of being silently dropped.
-- @param module the module table to decorate (its metatable is replaced)
function modules_register(module)
	setmetatable(module, {
		-- Metamethods only fire for keys absent from the table, so no
		-- rawget(t, k) check on `k` itself is needed here.
		__index = function (t, k)
			local getter = rawget(t, 'get')
			if getter then
				return getter(k)
			end
		end,
		__newindex = function (t, k, v)
			local setter = rawget(t, 'set')
			if setter then
				setter(k..' '..v)
			else
				rawset(t, k, v)
			end
		end
	})
end

-- Make sandboxed environment
--- Wrap the table of defined globals in a proxy environment.
-- Reads fall through to `defined`. Writes to a protected name (modules,
-- cache, net, trust_anchors) are merged key by key into the existing table,
-- so that table's own metatable sugar applies, instead of replacing it.
-- All other writes go straight to `defined`.
-- The proxy also carries `__orig_name_list`: the keys present in `defined`
-- when the proxy was made, each followed by a newline.
local function make_sandbox(defined)
	local protected = { modules = true, cache = true, net = true, trust_anchors = true }

	-- Compute and export the list of top-level names (hidden otherwise)
	local names = {}
	for name in pairs(defined) do
		names[#names + 1] = name .. "\n"
	end

	local proxy_mt = {
		__index = defined,
		__newindex = function (_, key, value)
			if not protected[key] then
				defined[key] = value
				return
			end
			for field, field_value in pairs(value) do
				defined[key][field] = field_value
			end
		end
	}
	return setmetatable({ __orig_name_list = table.concat(names) }, proxy_mt)
end

-- Compatibility sandbox
-- With setfenv (Lua 5.1 and less): the global `_G` is replaced by the
-- sandbox proxy, which also becomes the thread's environment.
-- Without it (Lua 5.2+): the proxy is only exported as `_SANDBOX`.
if not setfenv then
	_SANDBOX = make_sandbox(_ENV)
else
	local sandbox = make_sandbox(getfenv(0))
	_G = sandbox
	setfenv(0, sandbox)
end

-- Interactive command evaluation
--- Evaluate one line of interactive input and return its results.
-- The line is first compiled as an expression: `return <line>` when `raw`
-- is truthy, otherwise `return table_print(<line>)`. If that does not
-- compile, the line is compiled again as a plain statement.
-- Compile errors are re-raised with error(); runtime errors propagate
-- from the chunk itself.
function eval_cmd(line, raw)
	-- Compatibility sandbox code loading
	local function load_code(code)
		if getfenv then -- Lua 5.1
			return loadstring(code)
		else            -- Lua 5.2+
			return load(code, nil, 't', _ENV)
		end
	end
	local source
	if raw then
		source = 'return '..line
	else
		source = 'return table_print('..line..')'
	end
	local chunk, err = load_code(source)
	if err then
		-- Not a valid expression; try it as a statement instead
		chunk, err = load_code(line)
	end
	if err then
		error(err)
	end
	return chunk()
end

-- Pretty printing

-- Convert a value to a printable string via tostring(). Bytes outside
-- 0x20..0x7e are written as '\' followed by their decimal code. At most 51
-- input bytes are rendered; '...' is appended only when bytes were dropped.
local function printable(value)
	value = tostring(value)
	local bytes = {}
	for i = 1, #value do
		local c = string.byte(value, i)
		if c >= 0x20 and c < 0x7f then table.insert(bytes, string.char(c))
		else                           table.insert(bytes, '\\'..tostring(c))
		end
		if i > 50 and i < #value then table.insert(bytes, '...') break end
	end
	return table.concat(bytes)
end

--- Render a value as human-readable text.
-- A table produces one `[key] => value` line per entry, in pairs() order.
-- A nested table value is expanded between `[key] => {` and `}` with 4 more
-- spaces of indent. If that table was already expanded during this call
-- (tracked in `done`), its tostring() is printed instead, which stops cycles.
-- A non-table value renders as tostring(tt) followed by a newline.
-- @param tt value to render
-- @param indent number of leading spaces per line (default 0)
-- @param done set of tables already expanded (used by the recursion)
-- @return the rendered string
function table_print (tt, indent, done)
	done = done or {}
	indent = indent or 0
	if type(tt) ~= "table" then
		return tostring(tt) .. "\n"
	end
	-- Local buffer: avoids both a global `result` and quadratic concatenation
	local pad = string.rep(" ", indent)
	local out = {}
	for key, value in pairs(tt) do
		if type(value) == "table" and not done[value] then
			done[value] = true
			out[#out + 1] = pad .. string.format("[%s] => {\n", printable(key))
			out[#out + 1] = table_print(value, indent + 4, done)
			out[#out + 1] = pad .. "}\n"
		else
			out[#out + 1] = pad .. string.format("[%s] => %s\n",
			                                     tostring(key), printable(value))
		end
	end
	return table.concat(out)
end