Jump to content

Module:Fun: Difference between revisions

From Wikipedia, the free encyclopedia
Content deleted Content added
"range"
allow map to accept iterator like mapIter (redirect it to mapIter if second argument is function)
Line 59: Line 59:
-- map(function (char) return string.char(string.byte(char) - 0x20) end,
-- map(function (char) return string.char(string.byte(char) - 0x20) end,
-- "abc") --> { "A", "B", "C" }
-- "abc") --> { "A", "B", "C" }
function p.map(func, iterable)
function p.map(func, ...)
-- If second parameter is function, redirect to mapIter.
if type(...) == "function" then
return p.mapIter(func, ...)
end
local iterable = ...
local check = _check 'map'
local check = _check 'map'
check(1, func, "function")
check(1, func, "function")

Revision as of 00:09, 2 July 2018

local p = {}

local ustring = mw.ustring
local libraryUtil = require "libraryUtil"
local checkType = libraryUtil.checkType
local checkTypeMulti = libraryUtil.checkTypeMulti

local iterableTypes = { "table", "string" }

local _checkCache = {}
local function _check(funcName, expectType)
	if type(expectType) == "string" then
		return function(argIndex, arg, nilOk)
			return checkType(funcName, argIndex, arg, expectType, nilOk)
		end
	else
		local checkFunc = _checkCache[funcName] -- Lua 5.1 doesn't cache functions as Lua 5.3 does.
			or function(argIndex, arg, expectType, nilOk)
				if type(expectType) == "table" then
					if not (nilOk and arg == nil) then
						return checkTypeMulti(funcName, argIndex, arg, expectType)
					end
				else
					return checkType(funcName, argIndex, arg, expectType, nilOk)
				end
			end
		_checkCache[funcName] = checkFunc
		return checkFunc
	end
end

-- Iterate over UTF-8-encoded codepoints in string.
local function iterString(str)
	local iter = string.gmatch(str, "[%z\1-\127\194-\244][\128-\191]*")
	local i = 0
	local function iterator()
		i = i + 1
		local char = iter()
		if char then
			return i, char
		end
	end
	
	return iterator
end

local function getIterator(iterable)
	return type(iterable) == "string" and iterString
		or iterable[1] ~= nil and ipairs
		or pairs
end

function p.chain(func1, func2, ...)
	return func1(func2(...))
end

--	map(function(number) return number ^ 2 end,
--		{ 1, 2, 3 })									--> { 1, 4, 9 }
--	map(function (char) return string.char(string.byte(char) - 0x20) end,
--		"abc")											--> { "A", "B", "C" }
function p.map(func, ...)
	-- If second parameter is function, redirect to mapIter.
	if type(...) == "function" then
		return p.mapIter(func, ...)
	end
	
	local iterable = ...
	local check = _check 'map'
	check(1, func, "function")
	check(2, iterable, iterableTypes)
	
	local array = {}
	local iterator = getIterator(iterable)
	for i_or_k, val in iterator(iterable) do
		array[i_or_k] = func(val, i_or_k, iterable)
	end
	return array
end

function p.mapIter(func, iter, iterable, initVal)
	local check = _check 'mapIter'
	check(1, func, "function")
	check(2, iter, "function")
	check(3, iterable, iterableTypes, true)
	
	-- initVal could be anything
	
	local array = {}
	local i = 0
	for x, y in iter, iterable, initVal do
		i = i + 1
		array[i] = func(y, x, iterable)
	end
	return array
end

local function fold(func, iterable, result, checked)
	if not checked then
		local check = _check 'fold'
		check(1, func, "function")
		check(2, iterable, iterableTypes)
		-- Result can be anything.
	end
	
	local iterator = getIterator(iterable)
	for i_or_k, val in iterator(iterable) do
		result = func(result, val, i_or_k)
	end
	
	return result
end
p.fold = fold

function p.count(func, iterable)
	local check = _check 'count'
	check(1, func, "function")
	check(2, iterable, iterableTypes)
	
	return fold(
		function (count, val)
			if func(val) then
				return count + 1
			end
			return count
		end,
		iterable,
		0,
		true)
end

function p.forEach(func, iterable)
	local check = _check 'forEach'
	check(1, func, "function")
	check(2, iterable, iterableTypes)
	
	local iterator = getIterator(iterable)
	for i_or_k, val in iterator(iterable) do
		func(val, i_or_k, iterable)
	end
	return nil
end

-------------------------------------------------
-- From [[http://lua-users.org/wiki/CurriedLua]].
-- reverse(...) : take some tuple and return a tuple of elements in reverse order
--
-- e.g. "reverse(1,2,3)" returns 3,2,1
local function reverse(...)
	-- reverse args by building a function to do it, similar to the unpack() example
	local function reverseHelper(acc, v, ...)
		if select('#', ...) == 0 then
			return v, acc()
		else
			return reverseHelper(function() return v, acc() end, ...)
		end
	end
	
	-- initial acc is the end of the list
	return reverseHelper(function() return end, ...)
end

function p.curry(func, numArgs)
	-- currying 2-argument functions seems to be the most popular application
	numArgs = numArgs or 2
	
	-- no sense currying for 1 arg or less
	if numArgs <= 1 then return func end
	
	-- helper takes an argTrace function, and number of arguments remaining to be applied
	local function curryHelper(argTrace, n)
		if n == 0 then
			-- kick off argTrace, reverse argument list, and call the original function
			return func(reverse(argTrace()))
		else
			-- "push" argument (by building a wrapper function) and decrement n
			return function(onearg)
				return curryHelper(function() return onearg, argTrace() end, n - 1)
			end
		end
	end
	
	-- push the terminal case of argTrace into the function first
	return curryHelper(function() return end, numArgs)
end

-------------------------------------------------

--	some(function(val) return val % 2 == 0 end,
--		{ 2, 3, 5, 7, 11 })						--> true
function p.some(func, t)
	if t[1] ~= nil then -- array
		for i, v in ipairs(t) do
			if func(v, i, t) then
				return true
			end
		end
	else
		for k, v in pairs(t) do
			if func(v, k, t) then
				return true
			end
		end
	end
	return false
end

--	all(function(val) return val % 2 == 0 end,
--		{ 2, 4, 8, 10, 12 })					--> true
function p.all(func, t)
	if t[1] ~= nil then -- array
		for i, v in ipairs(t) do
			if not func(v, i, t) then
				return false
			end
		end
	else
		for k, v in pairs(t) do
			if not func(v, k, t) then
				return false
			end
		end
	end
	return true
end

function p.filter(func, iterable)
	local check = _check 'filter'
	check(1, func, "function")
	check(2, iterable, iterableTypes)
	
	local new_t = {}
	local new_i = 0
	local iterator = getIterator(iterable)
	for v1, v2 in iterator(iterable) do
		if func(v2, v1, t) then
			new_i = new_i + 1
			new_t[new_i] = v
		end
	end
	
	return new_t
end

function p.range(low, high)
	low = low - 1
	return function ()
		low = low + 1
		if low <= high then
			return low
		end
	end
end


-------------------------------
-- Fancy stuff
local function capture(...)
	local vals = { ... }
	return function()
		return unpack(vals)
	end
end

-- Log input and output of function.
-- Receives a function and returns a modified form of that function.
function p.logReturnValues(func, prefix)
	return function(...)
		local inputValues = capture(...)
		local returnValues = capture(func(...))
		if prefix then
			mw.log(prefix, inputValues())
			mw.log(returnValues())
		else
			mw.log(inputValues())
			mw.log(returnValues())
		end
		return returnValues()
	end
end

p.log = p.logReturnValues

-- Convenience function to make all functions in a table log their input and output.
function p.logAll(t)
	for k, v in pairs(t) do
		if type(v) == "function" then
			t[k] = p.logReturnValues(v, tostring(k))
		end
	end
	return t
end

----- M E M O I Z A T I O N-----
-- metamethod that does the work
-- Currently supports one argument and one return value.
local func_key = {}
local function callMethod(self, x)
	local output = self[x]
	if not output then
		output = self[func_key](x)
		self[x] = output
	end
	return output
end

-- shared metatable
local mt = { __call = callMethod }

-- Create callable table.
function p.memoize(func)
	return setmetatable({ [func_key] = func }, mt)
end

-------------------------------

return p