Skip to content

Update to new lua 5.1 module syntax and make luacheck happy #10

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .luacheckrc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
std = "ngx_lua"
81 changes: 38 additions & 43 deletions lib/resty/postgres.lua
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@ local bit = bit
local ngx = ngx
local tonumber = tonumber
local setmetatable = setmetatable
local error = error

module(...)
local _M = {}

_VERSION = '0.2'
_M.VERSION = '0.2'

local STATE_CONNECTED = 1
local STATE_COMMAND_SENT = 2
Expand All @@ -34,7 +33,7 @@ converters[701] = tonumber
-- NUMERICOID
converters[1700] = tonumber

function new(self)
function _M.new()
local sock, err = ngx.socket.tcp()
if not sock then
return nil, err
Expand All @@ -43,7 +42,7 @@ function new(self)
return setmetatable({ sock = sock, env = {}}, mt)
end

function set_timeout(self, timeout)
function _M.set_timeout(self, timeout)
local sock = self.sock
if not sock then
return nil, "not initialized"
Expand All @@ -66,9 +65,9 @@ local function _get_data_n(data, len, i)
return d, i+len
end

local function _set_byte2(n)
return string.char(bit.band(bit.rshift(n, 8), 0xff), bit.band(n, 0xff))
end
-- local function _set_byte2(n)
-- return string.char(bit.band(bit.rshift(n, 8), 0xff), bit.band(n, 0xff))
-- end

local function _set_byte4(n)
return string.char(bit.band(bit.rshift(n, 24), 0xff), bit.band(bit.rshift(n, 16), 0xff),
Expand All @@ -87,7 +86,7 @@ local function _to_cstring(data)
return {data, "\0"}
end

function _send_packet(self, data, len, typ)
local function _send_packet(self, data, len, typ)
local sock = self.sock
local packet
if typ then
Expand All @@ -105,7 +104,7 @@ function _send_packet(self, data, len, typ)
return sock:send(packet)
end

function _parse_error_packet(packet)
local function _parse_error_packet(packet)
local pos = 1
local flg, value, msg
msg = {}
Expand All @@ -128,15 +127,16 @@ function _parse_error_packet(packet)
return msg
end

function _recv_packet(self)
local function _recv_packet(self)
-- receive type
local sock = self.sock
local typ, err = sock:receive(1)
if not typ then
return nil, nil, "failed to receive packet type: " .. err
end
-- receive length
local data, err = sock:receive(4)
local data
data, err = sock:receive(4)
if not data then
return nil, nil , "failed to read packet length: " .. err
end
Expand All @@ -152,26 +152,25 @@ function _recv_packet(self)
return data, typ
end

function _compute_token(self, user, password, salt)
local function _compute_token(user, password, salt)
local token1 = ngx.md5(password .. user)
local token2 = ngx.md5(token1 .. salt)
return "md5" .. token2
end

function connect(self, opts)
function _M.connect(self, opts)
local sock = self.sock
if not sock then
return nil, "not initialized"
end

local ok, err

self.compact = opts.compact

local host = opts.host
local database = opts.database or ""
local user = opts.user or ""
local host = opts.host
local pool = opts.pool
local password = opts.password

Expand Down Expand Up @@ -217,10 +216,11 @@ function connect(self, opts)
-- packet_len + PG_PROTOCOL + user + database + end
-- req_len = 4 + 4 + string.len(user) + 6 + string.len(database) + 10 + 1
req_len = string.len(user) + string.len(database) + 25
local bytes, err = _send_packet(self, req, req_len)
local bytes
bytes, err = _send_packet(self, req, req_len)
if not bytes then
return nil, "failed to send client authentication packet1: " .. err
end
end
-- receive salt packet (len + data) no type
local packet, typ
packet, typ, err = _recv_packet(self)
Expand All @@ -230,12 +230,12 @@ function connect(self, opts)
if typ ~= 'R' then
return nil, "handshake error, got packet type:" .. typ
end
local auth_type = string.sub(packet, 1, 4)
-- local auth_type = string.sub(packet, 1, 4)
local salt = string.sub(packet, 5, 8)
-- send passsowrd
req = {_to_cstring(_compute_token(self, user, password, salt))}
req = {_to_cstring(_compute_token(user, password, salt))}
req_len = 40
local bytes, err = _send_packet(self, req, req_len, 'p')
bytes, err = _send_packet(self, req, req_len, 'p')
if not bytes then
return nil, "failed to send client authentication packet2: " .. err
end
Expand All @@ -244,7 +244,9 @@ function connect(self, opts)
if typ ~= 'R' then
return nil, "auth return type not support"
end
if packet ~= AUTH_REQ_OK then
if not packet then
return nil, "read packet error:" .. err
elseif packet ~= AUTH_REQ_OK then
return nil, "authentication failed"
end
while true do
Expand All @@ -254,9 +256,10 @@ function connect(self, opts)
end
-- env
if typ == 'S' then
local k, v
local pos = 1
local k, pos = _from_cstring(packet, pos)
local v, pos = _from_cstring(packet, pos)
k, pos = _from_cstring(packet, pos)
v = _from_cstring(packet, pos)
self.env[k] = v
end
-- secret key
Expand All @@ -279,7 +282,7 @@ function connect(self, opts)
end
end

function set_keepalive(self, ...)
function _M.set_keepalive(self, ...)
local sock = self.sock
if not sock then
return nil, "not initialized"
Expand All @@ -292,15 +295,15 @@ function set_keepalive(self, ...)
return sock:setkeepalive(...)
end

function get_reused_times(self)
function _M.get_reused_times(self)
local sock = self.sock
if not sock then
return nil, "not initialized"
end
return sock:getreusedtimes()
end

function close(self)
function _M.close(self)
local sock = self.sock
if not sock then
return nil, "not initialized"
Expand All @@ -309,7 +312,7 @@ function close(self)
return sock:close()
end

function send_query(self, query)
local function send_query(self, query)
if self.state ~= STATE_CONNECTED then
return nil, "cannot send query in the current context: "
.. (self.state or "nil")
Expand All @@ -327,7 +330,7 @@ function send_query(self, query)
return bytes, err
end

function read_result(self)
local function read_result(self)
if self.state ~= STATE_COMMAND_SENT then
return nil, "cannot read result in the current context: " .. self.state
end
Expand All @@ -348,7 +351,7 @@ function read_result(self)
-- packet of fields
if typ == 'T' then
local field_num, pos = _get_byte2(packet, 1)
for i=1, field_num do
for _=1, field_num do
local field = {}
field.name, pos = _from_cstring(packet, pos)
field.table_id, pos = _get_byte4(packet, pos)
Expand Down Expand Up @@ -388,7 +391,7 @@ function read_result(self)
local name = field.name
row[name] = data
end
end
end
table.insert(res, row)
end
if typ == 'E' then
Expand All @@ -408,11 +411,11 @@ function read_result(self)
self.state = STATE_CONNECTED
break
end
end
end
return res, err
end

function query(self, query)
function _M.query(self, query)
local bytes, err = send_query(self, query)
if not bytes then
return nil, "failed to send query: " .. err
Expand All @@ -421,17 +424,9 @@ function query(self, query)
return read_result(self)
end

function escape_string(str)
function _M.escape_string(str)
local new = string.gsub(str, "['\\]", "%0%0")
return new
end

local class_mt = {
-- to prevent use of casual module global variables
__newindex = function (table, key, val)
error('attempt to write to undeclared variable "' .. key .. '"')
end
}

setmetatable(_M, class_mt)

return _M