diff --git a/lua/u/range.lua b/lua/u/range.lua index d49787e..fbb0969 100644 --- a/lua/u/range.lua +++ b/lua/u/range.lua @@ -233,39 +233,22 @@ end --- @param opts? { contains_cursor?: boolean } function Range.from_tsquery_caps(bufnr, query, opts) - if bufnr == nil or bufnr == 0 then bufnr = vim.api.nvim_get_current_buf() end - local lang = vim.treesitter.language.get_lang(vim.bo[bufnr].filetype) - if lang == nil then return end - local parser = vim.treesitter.get_parser(bufnr, lang) - if parser == nil then return end - local tree = parser:parse()[1] - if tree == nil then return end - opts = opts or { contains_cursor = true } - local cursor = Pos.from_pos '.' - local root = tree:root() - local root_end = root:end_() - local q = vim.treesitter.query.parse(lang, query) - --- @type table - local ranges = {} - for id, match, _meta in q:iter_captures(root, bufnr, 0, root_end) do - local start_row0, start_col0, stop_row0, stop_col0 = match:range() - local range = Range.new( - Pos.new(bufnr, start_row0 + 1, start_col0 + 1), - Pos.new(bufnr, stop_row0 + 1, stop_col0 + 1), - 'v' - ) - if range.stop.lnum > vim.api.nvim_buf_line_count(bufnr) then - range.stop = range.stop:must_next(-1) - end - if not opts.contains_cursor or opts.contains_cursor and range:contains(cursor) then - local capture_name = q.captures[id] - if not ranges[capture_name] then ranges[capture_name] = {} end - table.insert(ranges[capture_name], range) - end - end - return ranges + local ranges = Range.from_buf_text(bufnr):tsquery(query) + if not ranges then return end + if not opts.contains_cursor then return ranges end + + local cursor = Pos.from_pos '.' + return vim.tbl_map(function(cap_ranges) + return vim + .iter(cap_ranges) + :filter( + --- @param r u.Range + function(r) return r:contains(cursor) end + ) + :totable() + end, ranges) end --- Get range information from the currently selected visual text. @@ -440,14 +423,46 @@ function Range:set_visual_selection() self.stop:save_to_pos '.' end --------------------------------------------------------------------------------- --- Range.from_* functions: --------------------------------------------------------------------------------- - -------------------------------------------------------------------------------- -- Text access/manipulation utilities: -------------------------------------------------------------------------------- +--- @param query string +function Range:tsquery(query) + local bufnr = self.start.bufnr + + local lang = vim.treesitter.language.get_lang(vim.bo[bufnr].filetype) + if lang == nil then return end + local parser = vim.treesitter.get_parser(bufnr, lang) + if parser == nil then return end + local tree = parser:parse()[1] + if tree == nil then return end + + local root = tree:root() + local q = vim.treesitter.query.parse(lang, query) + --- @type table + local ranges = {} + for id, match, _meta in + q:iter_captures(root, bufnr, self.start.lnum - 1, (self.stop or self.start).lnum) + do + local start_row0, start_col0, stop_row0, stop_col0 = match:range() + local range = Range.new( + Pos.new(bufnr, start_row0 + 1, start_col0 + 1), + Pos.new(bufnr, stop_row0 + 1, stop_col0), + 'v' + ) + if range.stop.lnum > vim.api.nvim_buf_line_count(bufnr) then + range.stop = range.stop:must_next(-1) + end + + local capture_name = q.captures[id] + if not ranges[capture_name] then ranges[capture_name] = {} end + if self:contains(range) then table.insert(ranges[capture_name], range) end + end + + return ranges +end + function Range:length() if self:is_empty() then return 0 end diff --git a/spec/range_spec.lua b/spec/range_spec.lua index b0e54b2..7f47ed0 100644 --- a/spec/range_spec.lua +++ b/spec/range_spec.lua @@ -289,6 +289,77 @@ describe('Range', function() end) end) + it('from_tsquery_caps with string array filter', function() + withbuf({ + '{', + ' "sample-key1": "sample-value1",', + ' "sample-key2": "sample-value2"', + '}', + }, function() + vim.cmd.setfiletype 'json' + + -- Place cursor in "sample-value1" + Pos.new(0, 2, 25):save_to_pos '.' + + -- Query that captures both keys and values in pairs + local query = [[ + (pair + key: _ @key + value: _ @value) + ]] + + local ranges = Range.from_line(0, 2):tsquery(query) + + -- Should have both @key and @value captures for the first pair only + -- (since cursor is in sample-value1) + assert(ranges, 'Range should not be nil') + assert(ranges.key, 'Range.key should not be nil') + assert(ranges.value, 'Range.value should not be nil') + + -- Should have exactly one key and one value + assert.are.same(#ranges.key, 1) + assert.are.same(#ranges.value, 1) + + -- Check that we got sample-key1 and sample-value1 + assert.are.same(ranges.key[1]:text(), '"sample-key1"') + assert.are.same(ranges.value[1]:text(), '"sample-value1"') + end) + + -- Make sure this works when the match is on the last line: + withbuf({ + '{"sample-key1": "sample-value1",', + '"sample-key2": "sample-value2"}', + }, function() + vim.cmd.setfiletype 'json' + + -- Place cursor in "sample-value1" + Pos.new(0, 2, 25):save_to_pos '.' + + -- Query that captures both keys and values in pairs + local query = [[ + (pair + key: _ @key + value: _ @value) + ]] + + local ranges = Range.from_line(0, 2):tsquery(query) + + -- Should have both @key and @value captures for the first pair only + -- (since cursor is in sample-value1) + assert(ranges, 'Range should not be nil') + assert(ranges.key, 'Range.key should not be nil') + assert(ranges.value, 'Range.value should not be nil') + + -- Should have exactly one key and one value + assert.are.same(#ranges.key, 1) + assert.are.same(#ranges.value, 1) + + -- Check that we got sample-key2 and sample-value2 + assert.are.same(ranges.key[1]:text(), '"sample-key2"') + assert.are.same(ranges.value[1]:text(), '"sample-value2"') + end) + end) + it('should get nearest block', function() withbuf({ 'this is a {',