diff --git a/lua/u/buffer.lua b/lua/u/buffer.lua index 06424f4..41cb645 100644 --- a/lua/u/buffer.lua +++ b/lua/u/buffer.lua @@ -1,16 +1,21 @@ local Range = require 'u.range' +local Renderer = require 'u.renderer'.Renderer ---@class Buffer ---@field buf number +---@field private renderer Renderer local Buffer = {} ---@param buf? number ---@return Buffer function Buffer.from_nr(buf) if buf == nil or buf == 0 then buf = vim.api.nvim_get_current_buf() end - local b = { buf = buf } - setmetatable(b, { __index = Buffer }) - return b + + local renderer = Renderer.new(buf) + return setmetatable({ + buf = buf, + renderer = renderer, + }, { __index = Buffer }) end ---@return Buffer @@ -69,4 +74,13 @@ function Buffer:text_object(txt_obj, opts) return Range.from_text_object(txt_obj, opts) end +--- @param event string|string[] +--- @param opts vim.api.keyset.create_autocmd +function Buffer:autocmd(event, opts) + vim.api.nvim_create_autocmd(event, vim.tbl_extend('force', opts, { buffer = self.buf })) +end + +--- @param tree Tree +function Buffer:render(tree) return self.renderer:render(tree) end + return Buffer diff --git a/lua/u/range.lua b/lua/u/range.lua index b7d4d34..4ac45d3 100644 --- a/lua/u/range.lua +++ b/lua/u/range.lua @@ -1,9 +1,9 @@ local Pos = require 'u.pos' local State = require 'u.state' -local orig_on_yank = vim.highlight.on_yank +local orig_on_yank = (vim.hl or vim.highlight).on_yank local on_yank_enabled = true; -(vim.highlight --[[@as any]]).on_yank = function(opts) +((vim.hl or vim.highlight) --[[@as any]]).on_yank = function(opts) if not on_yank_enabled then return end return orig_on_yank(opts) end @@ -493,7 +493,7 @@ function Range:highlight(group, opts) State.run(self.start.buf, function(s) if not in_macro then s:track_winview() end - vim.highlight.range( + (vim.hl or vim.highlight).range( self.start.buf, ns, group, diff --git a/lua/u/renderer.lua b/lua/u/renderer.lua new file mode 100644 index 0000000..f4030b0 --- /dev/null +++ b/lua/u/renderer.lua @@ -0,0 +1,432 @@ +local utils = require 'u.utils' + +local M = {} + +--- @alias Tag { kind: 'tag'; name: string, attributes: table, children: Tree } +--- @alias Node nil | boolean | string | Tag +--- @alias Tree Node | Node[] +local TagMetaTable = {} + +--- @param name string +--- @param attributes? table +--- @param children? Node | Node[] +--- @return Tag +function M.h(name, attributes, children) + return setmetatable({ + kind = 'tag', + name = name, + attributes = attributes or {}, + children = children, + }, TagMetaTable) +end + +-------------------------------------------------------------------------------- +-- Renderer class +-------------------------------------------------------------------------------- +--- @alias RendererExtmark { id?: number; start: [number, number]; stop: [number, number]; opts: any; tag: any } + +--- @class Renderer +--- @field bufnr number +--- @field ns number +--- @field changedtick number +--- @field old { lines: string[]; extmarks: RendererExtmark[] } +--- @field curr { lines: string[]; extmarks: RendererExtmark[] } +local Renderer = {} +Renderer.__index = Renderer +M.Renderer = Renderer + +--- @param x any +--- @return boolean +function Renderer.is_tag(x) return type(x) == 'table' and getmetatable(x) == TagMetaTable end + +--- @param x any +--- @return boolean +function Renderer.is_tag_arr(x) + if type(x) ~= 'table' then return false end + return #x == 0 or not Renderer.is_tag(x) +end + +--- @param bufnr number|nil +function Renderer.new(bufnr) + if bufnr == nil then bufnr = vim.api.nvim_get_current_buf() end + + if vim.b[bufnr]._renderer_ns == nil then + vim.b[bufnr]._renderer_ns = vim.api.nvim_create_namespace('my.renderer:' .. tostring(bufnr)) + end + + local self = setmetatable({ + bufnr = bufnr, + ns = vim.b[bufnr]._renderer_ns, + changedtick = 0, + old = { lines = {}, extmarks = {} }, + curr = { lines = {}, extmarks = {} }, + }, Renderer) + return self +end + +--- @param opts { +--- tree: Tree; +--- on_tag?: fun(tag: Tag, start0: [number, number], stop0: [number, number]): any; +--- } +function Renderer.markup_to_lines(opts) + --- @type string[] + local lines = {} + + local curr_line1 = 1 + local curr_col1 = 1 -- exclusive: sits one position **beyond** the last inserted text + --- @param s string + local function put(s) + lines[curr_line1] = (lines[curr_line1] or '') .. s + curr_col1 = #lines[curr_line1] + 1 + end + local function put_line() + table.insert(lines, '') + curr_line1 = curr_line1 + 1 + curr_col1 = 1 + end + + --- @param node Node + local function visit(node) + if node == nil or type(node) == 'boolean' then return end + + if type(node) == 'string' then + local node_lines = vim.split(node, '\n') + for lnum, s in ipairs(node_lines) do + if lnum > 1 then put_line() end + put(s) + end + elseif Renderer.is_tag(node) then + local start0 = { curr_line1 - 1, curr_col1 - 1 } + + -- visit the children: + if Renderer.is_tag_arr(node.children) then + for _, child in ipairs(node.children) do + -- newlines are not controlled by array entries, do NOT output a line here: + visit(child) + end + else + visit(node.children) + end + + local stop0 = { curr_line1 - 1, curr_col1 - 1 } + if opts.on_tag then opts.on_tag(node, start0, stop0) end + elseif Renderer.is_tag_arr(node) then + for _, child in ipairs(node) do + -- newlines are not controlled by array entries, do NOT output a line here: + visit(child) + end + end + end + visit(opts.tree) + + return lines +end + +--- @param opts { +--- tree: string; +--- format_tag?: fun(tag: Tag): string; +--- } +function Renderer.markup_to_string(opts) return table.concat(Renderer.markup_to_lines(opts), '\n') end + +--- @param tree Tree +function Renderer:render(tree) + local changedtick = vim.b[self.bufnr].changedtick + if changedtick ~= self.changedtick then + self.curr = { lines = vim.api.nvim_buf_get_lines(self.bufnr, 0, -1, false) } + self.changedtick = changedtick + end + + --- @type RendererExtmark[] + local extmarks = {} + + --- @type string[] + local lines = Renderer.markup_to_lines { + tree = tree, + + on_tag = function(tag, start0, stop0) + if tag.name == 'text' then + local hl = tag.attributes.hl + if type(hl) == 'string' then + tag.attributes.extmark = tag.attributes.extmark or {} + tag.attributes.extmark.hl_group = tag.attributes.extmark.hl_group or hl + end + + local extmark = tag.attributes.extmark + + -- Force creating an extmark if there are key handlers. To accurately + -- sense the bounds of the text, we need an extmark: + if tag.attributes.on_key or tag.attributes.on_typed then extmark = extmark or {} end + + if extmark then + table.insert(extmarks, { + start = start0, + stop = stop0, + opts = extmark, + tag = tag, + }) + end + end + end, + } + + self.old = self.curr + self.curr = { lines = lines, extmarks = extmarks } + self:_reconcile() +end + +--- @private +--- @param info string +--- @param start integer +--- @param end_ integer +--- @param strict_indexing boolean +--- @param replacement string[] +function Renderer:_set_lines(info, start, end_, strict_indexing, replacement) + self:_log { 'set_lines', self.bufnr, start, end_, strict_indexing, replacement } + vim.api.nvim_buf_set_lines(self.bufnr, start, end_, strict_indexing, replacement) + self:_log { 'after(' .. info .. ')', vim.api.nvim_buf_get_lines(self.bufnr, 0, -1, false) } +end + +--- @private +--- @param info string +--- @param start_row integer +--- @param start_col integer +--- @param end_row integer +--- @param end_col integer +--- @param replacement string[] +function Renderer:_set_text(info, start_row, start_col, end_row, end_col, replacement) + self:_log { 'set_text', self.bufnr, start_row, start_col, end_row, end_col, replacement } + vim.api.nvim_buf_set_text(self.bufnr, start_row, start_col, end_row, end_col, replacement) + self:_log { 'after(' .. info .. ')', vim.api.nvim_buf_get_lines(self.bufnr, 0, -1, false) } +end + +--- @private +function Renderer:_log(...) + -- + -- vim.print(...) +end + +--- @private +function Renderer:_reconcile() + local line_changes = utils.levenshtein(self.old.lines, self.curr.lines) + self.old = self.curr + + -- + -- Step 1: morph the text to the desired state: + -- + self:_log { line_changes = line_changes } + for _, line_change in ipairs(line_changes) do + local lnum0 = line_change.index - 1 + + if line_change.kind == 'add' then + self:_set_lines('add-line', lnum0, lnum0, true, { line_change.item }) + elseif line_change.kind == 'change' then + -- Compute inter-line diff, and apply: + self:_log '--------------------------------------------------------------------------------' + local col_changes = utils.levenshtein(vim.split(line_change.from, ''), vim.split(line_change.to, '')) + + for _, col_change in ipairs(col_changes) do + local cnum0 = col_change.index - 1 + self:_log { line_change = col_change, cnum = cnum0, lnum = lnum0 } + if col_change.kind == 'add' then + self:_set_text('add-char', lnum0, cnum0, lnum0, cnum0, { col_change.item }) + elseif col_change.kind == 'change' then + self:_set_text('change-char', lnum0, cnum0, lnum0, cnum0 + 1, { col_change.to }) + elseif col_change.kind == 'delete' then + self:_set_text('del-char', lnum0, cnum0, lnum0, cnum0 + 1, {}) + else + -- No change + end + end + elseif line_change.kind == 'delete' then + self:_set_lines('del-line', lnum0, lnum0 + 1, true, {}) + else + -- No change + end + end + self.changedtick = vim.b[self.bufnr].changedtick + + -- + -- Step 2: reconcile extmarks: + -- + -- Clear current extmarks: + vim.api.nvim_buf_clear_namespace(self.bufnr, self.ns, 0, -1) + -- Set current extmarks: + for _, extmark in ipairs(self.curr.extmarks) do + extmark.id = vim.api.nvim_buf_set_extmark( + self.bufnr, + self.ns, + extmark.start[1], + extmark.start[2], + vim.tbl_extend('force', { + id = extmark.id, + end_row = extmark.stop[1], + end_col = extmark.stop[2], + }, extmark.opts) + ) + end + + -- + -- Step 3: setup an updated on_key handler: + -- + + vim.on_key(nil, self.ns) + vim.on_key(function(key, typed) + -- Discard if not in the current buffer: + if + -- do not capture keys in the wrong buffer: + vim.api.nvim_get_current_buf() ~= self.bufnr + -- do not capture keys in COMMAND mode: + or vim.startswith(vim.api.nvim_get_mode().mode, 'c') + then + return + end + + -- find the tag with the smallest intersection that contains the cursor: + local pos0 = vim.api.nvim_win_get_cursor(0) + pos0[1] = pos0[1] - 1 -- make it actually 0-based + local pos_infos = self:get_pos_infos(pos0) + + -- Check the attributes for each matching_tag and fire events if they are + -- listening: + for _, pos_info in ipairs(pos_infos) do + local tag = pos_info.tag + + -- is the tag listening? + if tag.attributes.on_key and type(tag.attributes.on_key[key]) == 'function' then + -- key: + local result = tag.attributes.on_key[key]() + if result == '' then return '' end + elseif tag.attributes.on_typed and type(tag.attributes.on_typed[typed]) == 'function' then + -- typed: + local result = tag.attributes.on_typed[typed]() + if result == '' then return '' end + end + end + end, self.ns) +end + +--- Returns pairs of extmarks and tags associate with said extmarks. The +--- returned tags/extmarks are sorted smallest (innermost) to largest +--- (outermost). +--- +--- @private (private for now) +--- @param pos0 [number; number] +--- @return { extmark: RendererExtmark; tag: Tag; }[] +function Renderer:get_pos_infos(pos0) + local cursor_line0, cursor_col0 = pos0[1], pos0[2] + + -- The cursor (block) occupies **two** extmark spaces: one for it's left + -- edge, and one for it's right. We need to do our own intersection test, + -- because the NeoVim API is over-inclusive in what it returns: + --- @type RendererExtmark[] + local intersecting_extmarks = vim + .iter(vim.api.nvim_buf_get_extmarks(self.bufnr, self.ns, pos0, pos0, { details = true, overlap = true })) + --- @return RendererExtmark + :map(function(ext) + --- @type number, number, number, { end_row?: number; end_col?: number }|nil + local id, line0, col0, details = unpack(ext) + local start = { line0, col0 } + local stop = { line0, col0 } + if details and details.end_row ~= nil and details.end_col ~= nil then + stop = { details.end_row, details.end_col } + end + return { id = id, start = start, stop = stop, opts = details } + end) + --- @param ext RendererExtmark + :filter(function(ext) + if ext.stop[1] ~= nil and ext.stop[2] ~= nil then + return cursor_line0 >= ext.start[1] + and cursor_col0 >= ext.start[2] + and cursor_line0 <= ext.stop[1] + and cursor_col0 < ext.stop[2] + else + return true + end + end) + :totable() + + -- Sort the tags into smallest (inner) to largest (outer): + table.sort( + intersecting_extmarks, + --- @param x1 RendererExtmark + --- @param x2 RendererExtmark + function(x1, x2) + if + x1.start[1] == x2.start[1] + and x1.start[2] == x2.start[2] + and x1.stop[1] == x2.stop[1] + and x1.stop[2] == x2.stop[2] + then + return x1.id < x2.id + end + + return x1.start[1] >= x2.start[1] + and x1.start[2] >= x2.start[2] + and x1.stop[1] <= x2.stop[1] + and x1.stop[2] <= x2.stop[2] + end + ) + + -- When we set the extmarks in the step above, we captured the IDs of the + -- created extmarks in self.curr.extmarks, which also has which tag each + -- extmark is associated with. Cross-reference with that list to get a list + -- of tags that we need to fire events for: + --- @type { extmark: RendererExtmark; tag: Tag }[] + local matching_tags = vim + .iter(intersecting_extmarks) + --- @param ext RendererExtmark + :map(function(ext) + for _, extmark_cache in ipairs(self.curr.extmarks) do + if extmark_cache.id == ext.id then return { extmark = ext, tag = extmark_cache.tag } end + end + end) + :totable() + + return matching_tags +end + +-------------------------------------------------------------------------------- +-- TreeBuilder class +-------------------------------------------------------------------------------- + +--- @class TreeBuilder +--- @field private nodes Node[] +local TreeBuilder = {} +TreeBuilder.__index = TreeBuilder +M.TreeBuilder = TreeBuilder + +function TreeBuilder.new() + local self = setmetatable({ nodes = {} }, TreeBuilder) + return self +end + +--- @param nodes Tree +--- @return TreeBuilder +function TreeBuilder:put(nodes) + table.insert(self.nodes, nodes) + return self +end + +--- @param name string +--- @param attributes? table +--- @param children? Node | Node[] +--- @return TreeBuilder +function TreeBuilder:put_h(name, attributes, children) + local tag = M.h(name, attributes, children) + table.insert(self.nodes, tag) + return self +end + +--- @param fn fun(TreeBuilder): any +--- @return TreeBuilder +function TreeBuilder:nest(fn) + local nested_writer = TreeBuilder.new() + fn(nested_writer) + table.insert(self.nodes, nested_writer.nodes) + return self +end + +--- @return Tree +function TreeBuilder:tree() return self.nodes end + +return M diff --git a/lua/u/tracker.lua b/lua/u/tracker.lua new file mode 100644 index 0000000..5bfe5f0 --- /dev/null +++ b/lua/u/tracker.lua @@ -0,0 +1,294 @@ +local M = {} + +M.debug = false + +-------------------------------------------------------------------------------- +-- class Signal +-------------------------------------------------------------------------------- + +--- @class Signal +--- @field name? string +--- @field private changing boolean +--- @field private value any +--- @field private subscribers table +--- @field private on_dispose_callbacks function[] +local Signal = {} +M.Signal = Signal +Signal.__index = Signal + +--- @param value any +--- @param name? string +--- @return Signal +function Signal:new(value, name) + local obj = setmetatable({ + name = name, + changing = false, + value = value, + subscribers = {}, + on_dispose_callbacks = {}, + }, self) + return obj +end + +--- @param value any +function Signal:set(value) + self.value = value + + -- We don't handle cyclic updates: + if self.changing then + if M.debug then + vim.notify('circular dependency detected' .. (self.name and (' in ' .. self.name) or ''), vim.log.levels.WARN) + end + return + end + + local prev_changing = self.changing + self.changing = true + local ok = true + local err = nil + for _, cb in ipairs(self.subscribers) do + local ok2, err2 = pcall(cb, value) + if not ok2 then + ok = false + err = err or err2 + end + end + self.changing = prev_changing + + if not ok then + vim.notify( + 'error notifying' .. (self.name and (' in ' .. self.name) or '') .. ': ' .. tostring(err), + vim.log.levels.WARN + ) + error(err) + end +end + +--- @return any +function Signal:get() + local ctx = M.ExecutionContext.current() + if ctx then ctx:track(self) end + return self.value +end + +--- @param fn function +function Signal:update(fn) self:set(fn(self.value)) end + +--- @generic U +--- @param fn fun(value: T): U +--- @return Signal -- +function Signal:map(fn) + local mapped_signal = M.create_memo(function() + local value = self:get() + return fn(value) + end, self.name and self.name .. ':mapped' or nil) + return mapped_signal +end + +--- @return Signal +function Signal:clone() + return self:map(function(x) return x end) +end + +--- @param fn fun(value: T): boolean +--- @return Signal -- +function Signal:filter(fn) + local filtered_signal = M.create_signal(nil, self.name and self.name .. ':filtered' or nil) + local unsubscribe_from_self = self:subscribe(function(value) + if fn(value) then filtered_signal:set(value) end + end) + filtered_signal:on_dispose(unsubscribe_from_self) + return filtered_signal +end + +--- @param ms number +--- @return Signal -- +function Signal:debounce(ms) + local function set_timeout(timeout, callback) + local timer = (vim.uv or vim.loop).new_timer() + timer:start(timeout, 0, function() + timer:stop() + timer:close() + callback() + end) + return timer + end + + local filtered = M.create_signal(self.value, self.name and self.name .. ':debounced' or nil) + + --- @type { + -- queued: { value: T, ts: number }[] + -- timer?: uv_timer_t + -- } + local state = { queued = {}, timer = nil } + local function clear_timeout() + if state.timer == nil then return end + pcall(function() + state.timer:stop() + state.timer:close() + end) + state.timer = nil + end + + local unsubscribe_from_self = self:subscribe(function(value) + -- Stop any previously running timer: + if state.timer then clear_timeout() end + local now_ms = (vim.uv or vim.loop).hrtime() / 1e6 + + -- If there is anything older than `ms` in our queue, emit it: + local older_than_ms = vim.iter(state.queued):filter(function(item) return now_ms - item.ts > ms end):totable() + local last_older_than_ms = older_than_ms[#older_than_ms] + if last_older_than_ms then + filtered:set(last_older_than_ms.value) + state.queued = {} + end + + -- overwrite anything young enough + table.insert(state.queued, { value = value, ts = now_ms }) + state.timer = set_timeout(ms, function() + filtered:set(value) + -- If a timer was allowed to run to completion, that means that no other + -- item has been queued, since the timer is reset every time a new item + -- comes in. This means we can reset the queue + clear_timeout() + state.queued = {} + end) + end) + filtered:on_dispose(unsubscribe_from_self) + + return filtered +end + +--- @param callback function +function Signal:subscribe(callback) + table.insert(self.subscribers, callback) + return function() self:unsubscribe(callback) end +end + +--- @param callback function +function Signal:on_dispose(callback) table.insert(self.on_dispose_callbacks, callback) end + +--- @param callback function +function Signal:unsubscribe(callback) + for i, cb in ipairs(self.subscribers) do + if cb == callback then + table.remove(self.subscribers, i) + break + end + end +end + +function Signal:dispose() + self.subscribers = {} + for _, callback in ipairs(self.on_dispose_callbacks) do + callback() + end +end + +-------------------------------------------------------------------------------- +-- class ExecutionContext +-------------------------------------------------------------------------------- + +CURRENT_CONTEXT = nil + +--- @class ExecutionContext +--- @field signals table +local ExecutionContext = {} +M.ExecutionContext = ExecutionContext +ExecutionContext.__index = ExecutionContext + +--- @return ExecutionContext +function ExecutionContext:new() + return setmetatable({ + signals = {}, + subscribers = {}, + }, ExecutionContext) +end + +function ExecutionContext.current() return CURRENT_CONTEXT end + +--- @param fn function +--- @param ctx ExecutionContext +function ExecutionContext:run(fn, ctx) + local oldCtx = CURRENT_CONTEXT + CURRENT_CONTEXT = ctx + local result + local success, err = pcall(function() result = fn() end) + + CURRENT_CONTEXT = oldCtx + + if not success then error(err) end + + return result +end + +function ExecutionContext:track(signal) self.signals[signal] = true end + +--- @param callback function +function ExecutionContext:subscribe(callback) + local wrapped_callback = function() callback() end + for signal in pairs(self.signals) do + signal:subscribe(wrapped_callback) + end + + return function() + for signal in pairs(self.signals) do + signal:unsubscribe(wrapped_callback) + end + end +end + +function ExecutionContext:dispose() + for signal, _ in pairs(self.signals) do + signal:dispose() + end + self.signals = {} +end + +-------------------------------------------------------------------------------- +-- Helpers +-------------------------------------------------------------------------------- + +--- @param value any +--- @param name? string +--- @return Signal +function M.create_signal(value, name) return Signal:new(value, name) end + +--- @param fn function +--- @param name? string +--- @return Signal +function M.create_memo(fn, name) + --- @type Signal + local result + local unsubscribe = M.create_effect(function() + local value = fn() + if name and M.debug then vim.notify(name) end + if result then + result:set(value) + else + result = M.create_signal(value, name and ('m.s:' .. name) or nil) + end + end, name) + result:on_dispose(unsubscribe) + return result +end + +--- @param fn function +--- @param name? string +function M.create_effect(fn, name) + local ctx = M.ExecutionContext:new() + M.ExecutionContext:run(fn, ctx) + return ctx:subscribe(function() + if name and M.debug then + local deps = vim + .iter(vim.tbl_keys(ctx.signals)) + :map(function(s) return s.name end) + :filter(function(nm) return nm ~= nil end) + :join ',' + vim.notify(name .. '(deps=' .. deps .. ')') + end + fn() + end) +end + +return M diff --git a/lua/u/utils.lua b/lua/u/utils.lua index 228a9d2..ffe1d88 100644 --- a/lua/u/utils.lua +++ b/lua/u/utils.lua @@ -8,6 +8,18 @@ local M = {} ---@alias KeyMaps table } ---@alias CmdArgs { args: string; bang: boolean; count: number; fargs: string[]; line1: number; line2: number; mods: string; name: string; range: 0|1|2; reg: string; smods: any; info: Range|nil } +--- @generic T +--- @param x `T` +--- @param message? string +--- @return T +function M.dbg(x, message) + local t = {} + if message ~= nil then table.insert(t, message) end + table.insert(t, x) + vim.print(t) + return x +end + --- A utility for creating user commands that also pre-computes useful information --- and attaches it to the arguments. --- @@ -107,28 +119,97 @@ function M.repeatablemap(mode, lhs, rhs, opts) end, vim.tbl_extend('force', opts or {}, { expr = true })) end -function M.get_editor_dimensions() - local w = 0 - local h = 0 - local tabnr = vim.api.nvim_get_current_tabpage() - for _, winid in ipairs(vim.api.nvim_list_wins()) do - local tabpage = vim.api.nvim_win_get_tabpage(winid) - if tabpage == tabnr then - local pos = vim.api.nvim_win_get_position(winid) - local r, c = pos[1], pos[2] - local win_w = vim.api.nvim_win_get_width(winid) - local win_h = vim.api.nvim_win_get_height(winid) - local right = c + win_w - local bottom = r + win_h - if right > w then w = right end - if bottom > h then h = bottom end +function M.get_editor_dimensions() return { width = vim.go.columns, height = vim.go.lines } end + +--- @alias LevenshteinChange ({ kind: 'add'; item: T; index: number; } | { kind: 'delete'; item: T; index: number; } | { kind: 'change'; from: T; to: T; index: number; }) +--- @private +--- @generic T +--- @param x `T`[] +--- @param y T[] +--- @param cost? { of_delete?: fun(x: T): number; of_add?: fun(x: T): number; of_change?: fun(x: T, y: T): number; } +--- @return LevenshteinChange[] +function M.levenshtein(x, y, cost) + cost = cost or {} + local cost_of_delete_f = cost.of_delete or function() return 1 end + local cost_of_add_f = cost.of_add or function() return 1 end + local cost_of_change_f = cost.of_change or function() return 1 end + + local m, n = #x, #y + -- Initialize the distance matrix + local dp = {} + for i = 0, m do + dp[i] = {} + for j = 0, n do + dp[i][j] = 0 end end - if w == 0 or h == 0 then - w = vim.api.nvim_win_get_width(0) - h = vim.api.nvim_win_get_height(0) + + -- Fill the base cases + for i = 0, m do + dp[i][0] = i end - return { width = w, height = h } + for j = 0, n do + dp[0][j] = j + end + + -- Compute the Levenshtein distance dynamically + for i = 1, m do + for j = 1, n do + if x[i] == y[j] then + dp[i][j] = dp[i - 1][j - 1] -- no cost if items are the same + else + local costDelete = dp[i - 1][j] + cost_of_delete_f(x[i]) + local costAdd = dp[i][j - 1] + cost_of_add_f(y[j]) + local costChange = dp[i - 1][j - 1] + cost_of_change_f(x[i], y[j]) + dp[i][j] = math.min(costDelete, costAdd, costChange) + end + end + end + + -- Backtrack to find the changes + local i = m + local j = n + --- @type LevenshteinChange[] + local changes = {} + + while i > 0 or j > 0 do + local default_cost = dp[i][j] + local cost_of_change = (i > 0 and j > 0) and dp[i - 1][j - 1] or default_cost + local cost_of_add = j > 0 and dp[i][j - 1] or default_cost + local cost_of_delete = i > 0 and dp[i - 1][j] or default_cost + + --- @param u number + --- @param v number + --- @param w number + local function is_first_min(u, v, w) return u <= v and u <= w end + + if is_first_min(cost_of_change, cost_of_add, cost_of_delete) then + -- potential change + if x[i] ~= y[j] then + --- @type LevenshteinChange + local change = { kind = 'change', from = x[i], index = i, to = y[j] } + table.insert(changes, change) + end + i = i - 1 + j = j - 1 + elseif is_first_min(cost_of_add, cost_of_change, cost_of_delete) then + -- addition + --- @type LevenshteinChange + local change = { kind = 'add', item = y[j], index = i + 1 } + table.insert(changes, change) + j = j - 1 + elseif is_first_min(cost_of_delete, cost_of_change, cost_of_add) then + -- deletion + --- @type LevenshteinChange + local change = { kind = 'delete', item = x[i], index = i } + table.insert(changes, change) + i = i - 1 + else + error 'unreachable' + end + end + + return changes end return M diff --git a/spec/tracker_spec.lua b/spec/tracker_spec.lua new file mode 100644 index 0000000..5da1a85 --- /dev/null +++ b/spec/tracker_spec.lua @@ -0,0 +1,206 @@ +local tracker = require 'u.tracker' +local Signal = tracker.Signal +local ExecutionContext = tracker.ExecutionContext + +describe('Signal', function() + local signal + + before_each(function() signal = Signal:new(0, 'testSignal') end) + + it('should initialize with correct parameters', function() + assert.is.equal(signal.value, 0) + assert.is.equal(signal.name, 'testSignal') + assert.is.not_nil(signal.subscribers) + assert.is.equal(#signal.subscribers, 0) + assert.is.equal(signal.changing, false) + end) + + it('should set new value and notify subscribers', function() + local called = false + signal:subscribe(function(value) + called = true + assert.is.equal(value, 42) + end) + + signal:set(42) + assert.is.equal(called, true) + end) + + it('should not notify subscribers during circular dependency', function() + signal.changing = true + local notified = false + + signal:subscribe(function() notified = true end) + + signal:set(42) + assert.is.equal(notified, false) -- No notification should occur + end) + + it('should get current value', function() + signal:set(100) + assert.is.equal(signal:get(), 100) + end) + + it('should update value with function', function() + signal:set(10) + signal:update(function(value) return value * 2 end) + assert.is.equal(signal:get(), 20) + end) + + it('should dispose subscribers', function() + local called = false + local unsubscribe = signal:subscribe(function() called = true end) + + unsubscribe() + signal:set(10) + assert.is.equal(called, false) -- Should not be notified + end) + + describe('Signal:map', function() + it('should transform the signal value', function() + local signal = Signal:new(5) + local mapped_signal = signal:map(function(value) return value * 2 end) + + assert.is.equal(mapped_signal:get(), 10) -- Initial transformation + signal:set(10) + assert.is.equal(mapped_signal:get(), 20) -- Updated transformation + end) + + it('should handle empty transformations', function() + local signal = Signal:new(nil) + local mapped_signal = signal:map(function(value) return value or 'default' end) + + assert.is.equal(mapped_signal:get(), 'default') -- Return default + signal:set 'new value' + assert.is.equal(mapped_signal:get(), 'new value') -- Return new value + end) + end) + + describe('Signal:filter', function() + it('should only emit values that pass the filter', function() + local signal = Signal:new(5) + local filtered_signal = signal:filter(function(value) return value > 10 end) + + assert.is.equal(filtered_signal:get(), nil) -- Initial value should not pass + signal:set(15) + assert.is.equal(filtered_signal:get(), 15) -- Now filtered + signal:set(8) + assert.is.equal(filtered_signal:get(), 15) -- Does not pass the filter + end) + + it('should handle empty initial values', function() + local signal = Signal:new(nil) + local filtered_signal = signal:filter(function(value) return value ~= nil end) + + assert.is.equal(filtered_signal:get(), nil) -- Should be nil + signal:set(10) + assert.is.equal(filtered_signal:get(), 10) -- Should pass now + end) + end) + + describe('create_memo', function() + it('should compute a derived value and update when dependencies change', function() + local signal = Signal:new(2) + local memoized_signal = tracker.create_memo(function() return signal:get() * 2 end) + + assert.is.equal(memoized_signal:get(), 4) -- Initially compute 2 * 2 + + signal:set(3) + assert.is.equal(memoized_signal:get(), 6) -- Update to 3 * 2 = 6 + + signal:set(5) + assert.is.equal(memoized_signal:get(), 10) -- Update to 5 * 2 = 10 + end) + + it('should not recompute if the dependencies do not change', function() + local call_count = 0 + local signal = Signal:new(10) + local memoized_signal = tracker.create_memo(function() + call_count = call_count + 1 + return signal:get() + 1 + end) + + assert.is.equal(memoized_signal:get(), 11) -- Compute first value + assert.is.equal(call_count, 1) -- Should compute once + + memoized_signal:get() -- Call again, should use memoized value + assert.is.equal(call_count, 1) -- Still should only be one call + + signal:set(10) -- Set the same value + assert.is.equal(memoized_signal:get(), 11) + assert.is.equal(call_count, 2) + + signal:set(20) + assert.is.equal(memoized_signal:get(), 21) + assert.is.equal(call_count, 3) + end) + end) + + describe('create_effect', function() + it('should track changes and execute callback', function() + local signal = Signal:new(5) + local call_count = 0 + + tracker.create_effect(function() + signal:get() -- track as a dependency + call_count = call_count + 1 + end) + + assert.is.equal(call_count, 1) + signal:set(10) + assert.is.equal(call_count, 2) + end) + + it('should clean up signals and not call after dispose', function() + local signal = Signal:new(5) + local call_count = 0 + + local unsubscribe = tracker.create_effect(function() + call_count = call_count + 1 + return signal:get() * 2 + end) + + assert.is.equal(call_count, 1) -- Initially calls + unsubscribe() -- Unsubscribe the effect + signal:set(10) -- Update signal value + assert.is.equal(call_count, 1) -- Callback should not be called again + end) + end) +end) + +describe('ExecutionContext', function() + local context + + before_each(function() context = ExecutionContext:new() end) + + it('should initialize a new context', function() + assert.is.table(context.signals) + assert.is.table(context.subscribers) + end) + + it('should track signals', function() + local signal = Signal:new(0) + context:track(signal) + + assert.is.equal(next(context.signals), signal) -- Check if signal is tracked + end) + + it('should subscribe to signals', function() + local signal = Signal:new(0) + local callback_called = false + + context:track(signal) + context:subscribe(function() callback_called = true end) + + signal:set(100) + assert.is.equal(callback_called, true) -- Callback should be called + end) + + it('should dispose tracked signals', function() + local signal = Signal:new(0) + context:track(signal) + + context:dispose() + assert.is.falsy(next(context.signals)) -- Should not have any tracked signals + end) +end) diff --git a/spec/utils_spec.lua b/spec/utils_spec.lua new file mode 100644 index 0000000..75bfdf9 --- /dev/null +++ b/spec/utils_spec.lua @@ -0,0 +1,70 @@ +local utils = require 'u.utils' + +--- @param s string +local function split(s) return vim.split(s, '') end + +--- @param original string +--- @param changes LevenshteinChange[] +local function morph(original, changes) + local t = split(original) + for _, change in ipairs(changes) do + if change.kind == 'add' then + table.insert(t, change.index, change.item) + elseif change.kind == 'delete' then + table.remove(t, change.index) + elseif change.kind == 'change' then + t[change.index] = change.to + end + end + return vim.iter(t):join '' +end + +describe('utils', function() + it('levenshtein', function() + local original = 'abc' + local result = 'absece' + local changes = utils.levenshtein(split(original), split(result)) + assert.are.same(changes, { + { + item = 'e', + kind = 'add', + index = 4, + }, + { + item = 'e', + kind = 'add', + index = 3, + }, + { + item = 's', + kind = 'add', + index = 3, + }, + }) + assert.are.same(morph(original, changes), result) + + original = 'jonathan' + result = 'ajoanthan' + changes = utils.levenshtein(split(original), split(result)) + assert.are.same(changes, { + { + from = 'a', + index = 4, + kind = 'change', + to = 'n', + }, + { + from = 'n', + index = 3, + kind = 'change', + to = 'a', + }, + { + index = 1, + item = 'a', + kind = 'add', + }, + }) + assert.are.same(morph(original, changes), result) + end) +end)