diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b8a7c8e..79518b2 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -59,12 +59,9 @@ jobs: - name: Generate coverage report run: | - # Check if stats file exists (created by busted --coverage) if [ -f "luacov.stats.out" ]; then - # Generate the regular luacov report nix develop .#ci -c luacov - # Create simple lcov.info from luacov.report.out echo "Creating lcov.info from luacov.report.out" { echo "TN:" @@ -81,12 +78,10 @@ jobs: done } > lcov.info - # Create markdown coverage summary for GitHub Actions { echo "## 📊 Test Coverage Report" echo "" - # Extract overall coverage percentage if [ -f "luacov.report.out" ]; then overall_coverage=$(grep -E "Total.*%" luacov.report.out | grep -oE "[0-9]+\.[0-9]+%" | head -1) if [ -n "$overall_coverage" ]; then @@ -94,11 +89,9 @@ jobs: echo "" fi - # Create table header echo "| File | Coverage |" echo "|------|----------|" - # Extract file-by-file coverage grep -E "^[^ ].*:" luacov.report.out | while read -r line; do file=$(echo "$line" | cut -d':' -f1) percent=$(echo "$line" | grep -oE "[0-9]+\.[0-9]+%" | head -1) diff --git a/Makefile b/Makefile index bf93e9f..071b76b 100644 --- a/Makefile +++ b/Makefile @@ -1,26 +1,18 @@ .PHONY: check format test clean # Default target -all: check format +all: format check test # Check for syntax errors check: @echo "Checking Lua files for syntax errors..." - @find lua -name "*.lua" -type f -exec lua -e "assert(loadfile('{}'))" \; + nix develop .#ci -c find lua -name "*.lua" -type f -exec lua -e "assert(loadfile('{}'))" \; @echo "Running luacheck..." - @luacheck lua/ tests/ --no-unused-args --no-max-line-length + nix develop .#ci -c luacheck lua/ tests/ --no-unused-args --no-max-line-length # Format all files format: - @echo "Formatting files..." - @if command -v nix >/dev/null 2>&1; then \ - nix fmt; \ - elif command -v stylua >/dev/null 2>&1; then \ - stylua lua/; \ - else \ - echo "Neither nix nor stylua found. Please install one of them."; \ - exit 1; \ - fi + nix fmt # Run tests test: diff --git a/README.md b/README.md index d398883..ad28941 100644 --- a/README.md +++ b/README.md @@ -49,8 +49,15 @@ Using [lazy.nvim](https://github.com/folke/lazy.nvim): "coder/claudecode.nvim", config = true, keys = { + { "a", nil, desc = "AI/Claude Code" }, { "ac", "ClaudeCode", desc = "Toggle Claude" }, { "as", "ClaudeCodeSend", mode = "v", desc = "Send to Claude" }, + { + "as", + "ClaudeCodeTreeAdd", + desc = "Add file", + ft = { "NvimTree", "neo-tree" }, + }, }, } ``` @@ -60,13 +67,80 @@ That's it! For more configuration options, see [Advanced Setup](#advanced-setup) ## Usage 1. **Launch Claude**: Run `:ClaudeCode` to open Claude in a split terminal -2. **Send context**: Select text and run `:'<,'>ClaudeCodeSend` to send it to Claude +2. **Send context**: + - Select text in visual mode and use `as` to send it to Claude + - In `nvim-tree` or `neo-tree`, press `as` on a file to add it to Claude's context 3. **Let Claude work**: Claude can now: - See your current file and selections in real-time - Open files in your editor - Show diffs with proposed changes - Access diagnostics and workspace info +## Commands + +- `:ClaudeCode` - Toggle the Claude Code terminal window +- `:ClaudeCodeSend` - Send current visual selection to Claude, or add files from tree explorer +- `:ClaudeCodeTreeAdd` - Add selected file(s) from tree explorer to Claude context (also available via ClaudeCodeSend) +- `:ClaudeCodeAdd [start-line] [end-line]` - Add a specific file or directory to Claude context by path with optional line range + +### Tree Integration + +The `as` keybinding has context-aware behavior: + +- **In normal buffers (visual mode)**: Sends selected text to Claude +- **In nvim-tree/neo-tree buffers**: Adds the file under cursor (or selected files) to Claude's context + +This allows you to quickly add entire files to Claude's context for review, refactoring, or discussion. + +#### Features + +- **Single file**: Place cursor on any file and press `as` +- **Multiple files**: Select multiple files (using tree plugin's selection features) and press `as` +- **Smart detection**: Automatically detects whether you're in nvim-tree or neo-tree +- **Error handling**: Clear feedback if no files are selected or if tree plugins aren't available + +### Direct File Addition + +The `:ClaudeCodeAdd` command allows you to add files or directories directly by path, with optional line range specification: + +```vim +:ClaudeCodeAdd src/main.lua +:ClaudeCodeAdd ~/projects/myproject/ +:ClaudeCodeAdd ./README.md +:ClaudeCodeAdd src/main.lua 50 100 " Lines 50-100 only +:ClaudeCodeAdd config.lua 25 " From line 25 to end of file +``` + +#### Features + +- **Path completion**: Tab completion for file and directory paths +- **Path expansion**: Supports `~` for home directory and relative paths +- **Line range support**: Optionally specify start and end lines for files (ignored for directories) +- **Validation**: Checks that files and directories exist before adding, validates line numbers +- **Flexible**: Works with both individual files and entire directories + +#### Examples + +```vim +" Add entire files +:ClaudeCodeAdd src/components/Header.tsx +:ClaudeCodeAdd ~/.config/nvim/init.lua + +" Add entire directories (line numbers ignored) +:ClaudeCodeAdd tests/ +:ClaudeCodeAdd ../other-project/ + +" Add specific line ranges +:ClaudeCodeAdd src/main.lua 50 100 " Lines 50 through 100 +:ClaudeCodeAdd config.lua 25 " From line 25 to end of file +:ClaudeCodeAdd utils.py 1 50 " First 50 lines +:ClaudeCodeAdd README.md 10 20 " Just lines 10-20 + +" Path expansion works with line ranges +:ClaudeCodeAdd ~/project/src/app.js 100 200 +:ClaudeCodeAdd ./relative/path.lua 30 +``` + ## How It Works This plugin creates a WebSocket server that Claude Code CLI connects to, implementing the same protocol as the official VS Code extension. When you launch Claude, it automatically detects Neovim and gains full access to your editor. @@ -132,8 +206,15 @@ See [DEVELOPMENT.md](./DEVELOPMENT.md) for build instructions and development gu }, config = true, keys = { + { "a", nil, desc = "AI/Claude Code" }, { "ac", "ClaudeCode", desc = "Toggle Claude" }, { "as", "ClaudeCodeSend", mode = "v", desc = "Send to Claude" }, + { + "as", + "ClaudeCodeTreeAdd", + desc = "Add file", + ft = { "NvimTree", "neo-tree" }, + }, { "ao", "ClaudeCodeOpen", desc = "Open Claude" }, { "ax", "ClaudeCodeClose", desc = "Close Claude" }, }, diff --git a/flake.lock b/flake.lock index a7e03fe..5509b09 100644 --- a/flake.lock +++ b/flake.lock @@ -20,11 +20,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1748190013, - "narHash": "sha256-R5HJFflOfsP5FBtk+zE8FpL8uqE7n62jqOsADvVshhE=", + "lastModified": 1749143949, + "narHash": "sha256-QuUtALJpVrPnPeozlUG/y+oIMSLdptHxb3GK6cpSVhA=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "62b852f6c6742134ade1abdd2a21685fd617a291", + "rev": "d3d2d80a2191a73d1e86456a751b83aa13085d7d", "type": "github" }, "original": { @@ -77,11 +77,11 @@ "nixpkgs": "nixpkgs_2" }, "locked": { - "lastModified": 1748243702, - "narHash": "sha256-9YzfeN8CB6SzNPyPm2XjRRqSixDopTapaRsnTpXUEY8=", + "lastModified": 1749194973, + "narHash": "sha256-eEy8cuS0mZ2j/r/FE0/LYBSBcIs/MKOIVakwHVuqTfk=", "owner": "numtide", "repo": "treefmt-nix", - "rev": "1f3f7b784643d488ba4bf315638b2b0a4c5fb007", + "rev": "a05be418a1af1198ca0f63facb13c985db4cb3c5", "type": "github" }, "original": { diff --git a/flake.nix b/flake.nix index 0b0d8aa..736e21e 100644 --- a/flake.nix +++ b/flake.nix @@ -40,6 +40,7 @@ luajitPackages.luacov neovim treefmt.config.build.wrapper + findutils ]; # Development packages (additional tools for development) @@ -49,7 +50,7 @@ gnumake websocat jq - claude-code + # claude-code ]; in { diff --git a/lua/claudecode/config.lua b/lua/claudecode/config.lua index 83ff74a..aa5f123 100644 --- a/lua/claudecode/config.lua +++ b/lua/claudecode/config.lua @@ -17,11 +17,7 @@ M.defaults = { }, } ---- Validates the provided configuration table. --- Ensures that all configuration options are of the correct type and within valid ranges. --- @param config table The configuration table to validate. --- @return boolean true if the configuration is valid. --- @error string if any configuration option is invalid. +--- @param config table The configuration table to validate. function M.validate(config) assert( type(config.port_range) == "table" @@ -64,11 +60,7 @@ function M.validate(config) return true end ---- Applies user configuration on top of default settings and validates the result. --- Merges the user-provided configuration with the default configuration, --- then validates the merged configuration. --- @param user_config table|nil The user-provided configuration table. --- @return table The final, validated configuration table. +--- @param user_config table|nil The user-provided configuration table. function M.apply(user_config) local config = vim.deepcopy(M.defaults) diff --git a/lua/claudecode/diff.lua b/lua/claudecode/diff.lua index dae03fe..7a1b437 100644 --- a/lua/claudecode/diff.lua +++ b/lua/claudecode/diff.lua @@ -399,22 +399,41 @@ end --- Apply accepted changes to the original file and reload open buffers -- @param diff_data table The diff state data -- @param final_content string The final content to write +-- @return boolean success Whether the operation succeeded +-- @return string|nil error Error message if operation failed function M._apply_accepted_changes(diff_data, final_content) local old_file_path = diff_data.old_file_path if not old_file_path then - require("claudecode.logger").error("diff", "No old_file_path found in diff_data") - return + local error_msg = "No old_file_path found in diff_data" + require("claudecode.logger").error("diff", error_msg) + return false, error_msg end require("claudecode.logger").debug("diff", "Writing accepted changes to file:", old_file_path) + -- Ensure parent directories exist for new files + if diff_data.is_new_file then + local parent_dir = vim.fn.fnamemodify(old_file_path, ":h") + if parent_dir and parent_dir ~= "" and parent_dir ~= "." then + require("claudecode.logger").debug("diff", "Creating parent directories for new file:", parent_dir) + local mkdir_success, mkdir_err = pcall(vim.fn.mkdir, parent_dir, "p") + if not mkdir_success then + local error_msg = "Failed to create parent directories: " .. parent_dir .. " - " .. tostring(mkdir_err) + require("claudecode.logger").error("diff", error_msg) + return false, error_msg + end + require("claudecode.logger").debug("diff", "Successfully created parent directories:", parent_dir) + end + end + -- Write the content to the actual file local lines = vim.split(final_content, "\n") local success, err = pcall(vim.fn.writefile, lines, old_file_path) if not success then - require("claudecode.logger").error("diff", "Failed to write file:", old_file_path, "error:", err) - return + local error_msg = "Failed to write file: " .. old_file_path .. " - " .. tostring(err) + require("claudecode.logger").error("diff", error_msg) + return false, error_msg end require("claudecode.logger").debug("diff", "Successfully wrote changes to", old_file_path) @@ -434,6 +453,8 @@ function M._apply_accepted_changes(diff_data, final_content) end end end + + return true, nil end --- Resolve diff as accepted with final content @@ -581,8 +602,9 @@ end -- @param old_file_path string Path to the original file -- @param new_buffer number New file buffer ID -- @param tab_name string The diff identifier +-- @param is_new_file boolean Whether this is a new file (doesn't exist yet) -- @return table Info about the created diff layout -function M._create_diff_view_from_window(target_window, old_file_path, new_buffer, tab_name) +function M._create_diff_view_from_window(target_window, old_file_path, new_buffer, tab_name, is_new_file) require("claudecode.logger").debug("diff", "Creating diff view from window", target_window) -- If no target window provided, create a new window in suitable location @@ -608,16 +630,35 @@ function M._create_diff_view_from_window(target_window, old_file_path, new_buffe vim.api.nvim_set_current_win(target_window) end - -- Make sure the window shows the file we want to diff - -- This handles the case where the buffer exists but isn't in the current window - vim.cmd("edit " .. vim.fn.fnameescape(old_file_path)) - - -- Store the original buffer for later - local original_buffer = vim.api.nvim_win_get_buf(target_window) + -- Handle the left side of the diff (original file or empty for new files) + local original_buffer + if is_new_file then + -- Create an empty buffer for new file comparison + require("claudecode.logger").debug("diff", "Creating empty buffer for new file diff") + local empty_buffer = vim.api.nvim_create_buf(false, true) -- unlisted, scratch + vim.api.nvim_buf_set_name(empty_buffer, old_file_path .. " (NEW FILE)") + vim.api.nvim_buf_set_lines(empty_buffer, 0, -1, false, {}) -- Empty content + vim.api.nvim_buf_set_option(empty_buffer, "buftype", "nofile") + vim.api.nvim_buf_set_option(empty_buffer, "modifiable", false) + vim.api.nvim_buf_set_option(empty_buffer, "readonly", true) + + vim.api.nvim_win_set_buf(target_window, empty_buffer) + original_buffer = empty_buffer + else + -- Make sure the window shows the existing file we want to diff + vim.cmd("edit " .. vim.fn.fnameescape(old_file_path)) + original_buffer = vim.api.nvim_win_get_buf(target_window) + end - -- Enable diff mode on the original file + -- Enable diff mode on the original/empty file vim.cmd("diffthis") - require("claudecode.logger").debug("diff", "Enabled diff mode on original file in window", target_window) + require("claudecode.logger").debug( + "diff", + "Enabled diff mode on", + is_new_file and "empty buffer" or "original file", + "in window", + target_window + ) -- Create vertical split for new buffer (proposed changes) vim.cmd("vsplit") @@ -647,6 +688,14 @@ function M._create_diff_view_from_window(target_window, old_file_path, new_buffe -- Accept all changes local new_content = vim.api.nvim_buf_get_lines(new_buffer, 0, -1, false) + -- Ensure parent directories exist for new files + if is_new_file then + local parent_dir = vim.fn.fnamemodify(old_file_path, ":h") + if parent_dir and parent_dir ~= "" and parent_dir ~= "." then + vim.fn.mkdir(parent_dir, "p") + end + end + -- Write to file vim.fn.writefile(new_content, old_file_path) @@ -747,41 +796,49 @@ function M._setup_blocking_diff(params, resolution_callback) params.old_file_path ) - -- Step 1: Check if the file exists + -- Step 1: Check if the file exists (allow new files) local old_file_exists = vim.fn.filereadable(params.old_file_path) == 1 - if not old_file_exists then - error({ - code = -32000, - message = "File access error", - data = "Cannot open file: " .. params.old_file_path .. " (file does not exist)", - }) - end + local is_new_file = not old_file_exists + + require("claudecode.logger").debug( + "diff", + "File existence check - old_file_exists:", + old_file_exists, + "is_new_file:", + is_new_file, + "path:", + params.old_file_path + ) - -- Step 2: Find if the file is already open in a buffer + -- Step 2: Find if the file is already open in a buffer (only for existing files) local existing_buffer = nil local target_window = nil - -- Look for existing buffer with this file - for _, buf in ipairs(vim.api.nvim_list_bufs()) do - if vim.api.nvim_buf_is_valid(buf) and vim.api.nvim_buf_is_loaded(buf) then - local buf_name = vim.api.nvim_buf_get_name(buf) - if buf_name == params.old_file_path then - existing_buffer = buf - require("claudecode.logger").debug("diff", "Found existing buffer", buf, "for file", params.old_file_path) - break + if old_file_exists then + -- Look for existing buffer with this file + for _, buf in ipairs(vim.api.nvim_list_bufs()) do + if vim.api.nvim_buf_is_valid(buf) and vim.api.nvim_buf_is_loaded(buf) then + local buf_name = vim.api.nvim_buf_get_name(buf) + if buf_name == params.old_file_path then + existing_buffer = buf + require("claudecode.logger").debug("diff", "Found existing buffer", buf, "for file", params.old_file_path) + break + end end end - end - -- Find window containing this buffer (if any) - if existing_buffer then - for _, win in ipairs(vim.api.nvim_list_wins()) do - if vim.api.nvim_win_get_buf(win) == existing_buffer then - target_window = win - require("claudecode.logger").debug("diff", "Found window", win, "containing buffer", existing_buffer) - break + -- Find window containing this buffer (if any) + if existing_buffer then + for _, win in ipairs(vim.api.nvim_list_wins()) do + if vim.api.nvim_win_get_buf(win) == existing_buffer then + target_window = win + require("claudecode.logger").debug("diff", "Found window", win, "containing buffer", existing_buffer) + break + end end end + else + require("claudecode.logger").debug("diff", "Skipping buffer search for new file:", params.old_file_path) end -- If no existing buffer/window, find a suitable main editor window @@ -811,17 +868,23 @@ function M._setup_blocking_diff(params, resolution_callback) }) end - local new_unique_name = tab_name .. " (proposed)" + local new_unique_name = is_new_file and (tab_name .. " (NEW FILE - proposed)") or (tab_name .. " (proposed)") vim.api.nvim_buf_set_name(new_buffer, new_unique_name) vim.api.nvim_buf_set_lines(new_buffer, 0, -1, false, vim.split(params.new_file_contents, "\n")) - -- Set buffer options for the new content buffer vim.api.nvim_buf_set_option(new_buffer, "buftype", "acwrite") -- Allows saving but stays as scratch-like vim.api.nvim_buf_set_option(new_buffer, "modifiable", true) -- Step 4: Set up diff view using the target window - require("claudecode.logger").debug("diff", "Creating diff view from window", target_window) - local diff_info = M._create_diff_view_from_window(target_window, params.old_file_path, new_buffer, tab_name) + require("claudecode.logger").debug( + "diff", + "Creating diff view from window", + target_window, + "is_new_file:", + is_new_file + ) + local diff_info = + M._create_diff_view_from_window(target_window, params.old_file_path, new_buffer, tab_name, is_new_file) -- Step 5: Register autocmds for user interaction monitoring require("claudecode.logger").debug("diff", "Registering autocmds") @@ -842,6 +905,7 @@ function M._setup_blocking_diff(params, resolution_callback) status = "pending", resolution_callback = resolution_callback, result_content = nil, + is_new_file = is_new_file, }) require("claudecode.logger").debug("diff", "Setup completed successfully for", tab_name) end diff --git a/lua/claudecode/init.lua b/lua/claudecode/init.lua index 489785c..9ae13e9 100644 --- a/lua/claudecode/init.lua +++ b/lua/claudecode/init.lua @@ -46,7 +46,7 @@ local default_config = { terminal_cmd = nil, log_level = "info", track_selection = true, - visual_demotion_delay_ms = 200, + visual_demotion_delay_ms = 50, -- Reduced from 200ms for better responsiveness in tree navigation diff_opts = { auto_close_on_accept = true, show_diff_stats = true, @@ -245,25 +245,188 @@ function M._create_commands() desc = "Show Claude Code integration status", }) - vim.api.nvim_create_user_command("ClaudeCodeSend", function(opts) + -- Helper function to format file paths for at mentions + local function format_path_for_at_mention(file_path) + local is_directory = vim.fn.isdirectory(file_path) == 1 + local formatted_path = file_path + + -- For directories, convert to relative path and add trailing slash + if is_directory then + -- Get current working directory + local cwd = vim.fn.getcwd() + -- Convert absolute path to relative if it's under the current working directory + if string.find(file_path, cwd, 1, true) == 1 then + local relative_path = string.sub(file_path, #cwd + 2) -- +2 to skip the trailing slash + if relative_path ~= "" then + formatted_path = relative_path + end + end + -- Always add trailing slash for directories + if not string.match(formatted_path, "/$") then + formatted_path = formatted_path .. "/" + end + end + + return formatted_path, is_directory + end + + ---@param file_path string The file path to broadcast + ---@return boolean success Whether the broadcast was successful + ---@return string|nil error Error message if broadcast failed + local function broadcast_at_mention(file_path, start_line, end_line) + if not M.state.server then + return false, "Claude Code integration is not running" + end + + local formatted_path, is_directory = format_path_for_at_mention(file_path) + + if is_directory and (start_line or end_line) then + logger.debug("command", "Line numbers ignored for directory: " .. formatted_path) + start_line = nil + end_line = nil + end + + local params = { + filePath = formatted_path, + lineStart = start_line, + lineEnd = end_line, + } + + local broadcast_success = M.state.server.broadcast("at_mentioned", params) + if broadcast_success then + if logger.is_level_enabled and logger.is_level_enabled("debug") then + local message = "Broadcast success: Added " .. (is_directory and "directory" or "file") .. " " .. formatted_path + if not is_directory and (start_line or end_line) then + local range_info = "" + if start_line and end_line then + range_info = " (lines " .. start_line .. "-" .. end_line .. ")" + elseif start_line then + range_info = " (from line " .. start_line .. ")" + end + message = message .. range_info + end + logger.debug("command", message) + elseif not logger.is_level_enabled then + -- Fallback for tests or environments where logger isn't fully initialized + logger.debug( + "command", + "Broadcast success: Added " .. (is_directory and "directory" or "file") .. " " .. formatted_path + ) + end + return true, nil + else + local error_msg = "Failed to broadcast " .. (is_directory and "directory" or "file") .. " " .. formatted_path + logger.error("command", error_msg) + return false, error_msg + end + end + + ---@param file_paths table List of file paths to add + ---@param options table|nil Optional settings: { delay?: number, show_summary?: boolean, context?: string } + ---@return number success_count Number of successfully added files + ---@return number total_count Total number of files attempted + local function add_paths_to_claude(file_paths, options) + options = options or {} + local delay = options.delay or 0 + local show_summary = options.show_summary ~= false + local context = options.context or "command" + + if not file_paths or #file_paths == 0 then + return 0, 0 + end + + local success_count = 0 + local total_count = #file_paths + + if delay > 0 then + local function send_files_sequentially(index) + if index > total_count then + if show_summary and success_count > 0 then + local message = success_count == 1 and "Added 1 file to Claude context" + or string.format("Added %d files to Claude context", success_count) + if total_count > success_count then + message = message .. string.format(" (%d failed)", total_count - success_count) + end + logger.debug(context, message) + end + return + end + + local file_path = file_paths[index] + local success = broadcast_at_mention(file_path) + if success then + success_count = success_count + 1 + end + + if index < total_count then + vim.defer_fn(function() + send_files_sequentially(index + 1) + end, delay) + else + if show_summary and success_count > 0 then + local message = success_count == 1 and "Added 1 file to Claude context" + or string.format("Added %d files to Claude context", success_count) + if total_count > success_count then + message = message .. string.format(" (%d failed)", total_count - success_count) + end + logger.debug(context, message) + end + end + end + + send_files_sequentially(1) + else + for _, file_path in ipairs(file_paths) do + local success = broadcast_at_mention(file_path) + if success then + success_count = success_count + 1 + end + end + + if show_summary and success_count > 0 then + local message = success_count == 1 and "Added 1 file to Claude context" + or string.format("Added %d files to Claude context", success_count) + if total_count > success_count then + message = message .. string.format(" (%d failed)", total_count - success_count) + end + logger.debug(context, message) + end + end + + return success_count, total_count + end + + local function handle_send_normal(opts) if not M.state.server then logger.error("command", "ClaudeCodeSend: Claude Code integration is not running.") vim.notify("Claude Code integration is not running", vim.log.levels.ERROR) return end - logger.debug( - "command", - "ClaudeCodeSend (new logic) invoked. Mode: " - .. vim.fn.mode(true) - .. ", Neovim's reported range: " - .. tostring(opts and opts.range) - ) - -- We now ignore opts.range and rely on the selection module's state, - -- as opts.range was found to be 0 even when in visual mode for mappings. - if not M.state.server then - logger.error("command", "ClaudeCodeSend: Claude Code integration is not running.") - vim.notify("Claude Code integration is not running", vim.log.levels.ERROR, { title = "ClaudeCode Error" }) + local current_ft = vim.bo.filetype + local current_bufname = vim.api.nvim_buf_get_name(0) + + local is_tree_buffer = current_ft == "NvimTree" + or current_ft == "neo-tree" + or string.match(current_bufname, "neo%-tree") + or string.match(current_bufname, "NvimTree") + + if is_tree_buffer then + local integrations = require("claudecode.integrations") + local files, error = integrations.get_selected_files_from_tree() + + if error then + logger.warn("command", "ClaudeCodeSend->TreeAdd: " .. error) + return + end + + if not files or #files == 0 then + logger.warn("command", "ClaudeCodeSend->TreeAdd: No files selected") + return + end + + add_paths_to_claude(files, { context = "ClaudeCodeSend->TreeAdd" }) + return end @@ -271,13 +434,9 @@ function M._create_commands() if selection_module_ok then local sent_successfully = selection_module.send_at_mention_for_visual_selection() if sent_successfully then - vim.api.nvim_feedkeys(vim.api.nvim_replace_termcodes("", true, false, true), "n", false) - logger.debug("command", "ClaudeCodeSend: Exited visual mode after successful send.") - - -- Focus the Claude Code terminal after sending selection local terminal_ok, terminal = pcall(require, "claudecode.terminal") if terminal_ok then - terminal.open({}) -- Open/focus the terminal + terminal.open({}) logger.debug("command", "ClaudeCodeSend: Focused Claude Code terminal after selection send.") else logger.warn("command", "ClaudeCodeSend: Failed to load terminal module for focusing.") @@ -287,9 +446,203 @@ function M._create_commands() logger.error("command", "ClaudeCodeSend: Failed to load selection module.") vim.notify("Failed to send selection: selection module not loaded.", vim.log.levels.ERROR) end + end + + local function handle_send_visual(visual_data, opts) + if not M.state.server then + logger.error("command", "ClaudeCodeSend_visual: Claude Code integration is not running.") + return + end + + if visual_data then + local visual_commands = require("claudecode.visual_commands") + local files, error = visual_commands.get_files_from_visual_selection(visual_data) + + if not error and files and #files > 0 then + local success_count = add_paths_to_claude(files, { + delay = 10, + context = "ClaudeCodeSend_visual", + show_summary = false, + }) + if success_count > 0 then + local message = success_count == 1 and "Added 1 file to Claude context from visual selection" + or string.format("Added %d files to Claude context from visual selection", success_count) + logger.debug("command", message) + + local terminal_ok, terminal = pcall(require, "claudecode.terminal") + if terminal_ok then + terminal.open({}) + end + end + return + end + end + local selection_module_ok, selection_module = pcall(require, "claudecode.selection") + if selection_module_ok then + local sent_successfully = selection_module.send_at_mention_for_visual_selection() + if sent_successfully then + local terminal_ok, terminal = pcall(require, "claudecode.terminal") + if terminal_ok then + terminal.open({}) + end + end + end + end + + local visual_commands = require("claudecode.visual_commands") + local unified_send_handler = visual_commands.create_visual_command_wrapper(handle_send_normal, handle_send_visual) + + vim.api.nvim_create_user_command("ClaudeCodeSend", unified_send_handler, { + desc = "Send current visual selection as an at_mention to Claude Code (supports tree visual selection)", + range = true, + }) + + local function handle_tree_add_normal() + if not M.state.server then + logger.error("command", "ClaudeCodeTreeAdd: Claude Code integration is not running.") + return + end + + local integrations = require("claudecode.integrations") + local files, error = integrations.get_selected_files_from_tree() + + if error then + logger.warn("command", "ClaudeCodeTreeAdd: " .. error) + return + end + + if not files or #files == 0 then + logger.warn("command", "ClaudeCodeTreeAdd: No files selected") + return + end + + local success_count = add_paths_to_claude(files, { context = "ClaudeCodeTreeAdd" }) + + if success_count == 0 then + logger.error("command", "ClaudeCodeTreeAdd: Failed to add any files") + end + end + + local function handle_tree_add_visual(visual_data) + if not M.state.server then + logger.error("command", "ClaudeCodeTreeAdd_visual: Claude Code integration is not running.") + return + end + + local visual_cmd_module = require("claudecode.visual_commands") + local files, error = visual_cmd_module.get_files_from_visual_selection(visual_data) + + if error then + logger.warn("command", "ClaudeCodeTreeAdd_visual: " .. error) + return + end + + if not files or #files == 0 then + logger.warn("command", "ClaudeCodeTreeAdd_visual: No files selected in visual range") + return + end + + local success_count = add_paths_to_claude(files, { + delay = 10, + context = "ClaudeCodeTreeAdd_visual", + show_summary = false, + }) + if success_count > 0 then + local message = success_count == 1 and "Added 1 file to Claude context from visual selection" + or string.format("Added %d files to Claude context from visual selection", success_count) + logger.debug("command", message) + else + logger.error("command", "ClaudeCodeTreeAdd_visual: Failed to add any files from visual selection") + end + end + + local unified_tree_add_handler = + visual_commands.create_visual_command_wrapper(handle_tree_add_normal, handle_tree_add_visual) + + vim.api.nvim_create_user_command("ClaudeCodeTreeAdd", unified_tree_add_handler, { + desc = "Add selected file(s) from tree explorer to Claude Code context (supports visual selection)", + }) + + vim.api.nvim_create_user_command("ClaudeCodeAdd", function(opts) + if not M.state.server then + logger.error("command", "ClaudeCodeAdd: Claude Code integration is not running.") + return + end + + if not opts.args or opts.args == "" then + logger.error("command", "ClaudeCodeAdd: No file path provided") + return + end + + local args = vim.split(opts.args, "%s+") + local file_path = args[1] + local start_line = args[2] and tonumber(args[2]) or nil + local end_line = args[3] and tonumber(args[3]) or nil + + if #args > 3 then + logger.error( + "command", + "ClaudeCodeAdd: Too many arguments. Usage: ClaudeCodeAdd [start-line] [end-line]" + ) + return + end + + if args[2] and not start_line then + logger.error("command", "ClaudeCodeAdd: Invalid start line number: " .. args[2]) + return + end + + if args[3] and not end_line then + logger.error("command", "ClaudeCodeAdd: Invalid end line number: " .. args[3]) + return + end + + if start_line and start_line < 1 then + logger.error("command", "ClaudeCodeAdd: Start line must be positive: " .. start_line) + return + end + + if end_line and end_line < 1 then + logger.error("command", "ClaudeCodeAdd: End line must be positive: " .. end_line) + return + end + + if start_line and end_line and start_line > end_line then + logger.error( + "command", + "ClaudeCodeAdd: Start line (" .. start_line .. ") must be <= end line (" .. end_line .. ")" + ) + return + end + + file_path = vim.fn.expand(file_path) + if vim.fn.filereadable(file_path) == 0 and vim.fn.isdirectory(file_path) == 0 then + logger.error("command", "ClaudeCodeAdd: File or directory does not exist: " .. file_path) + return + end + + -- Convert 1-indexed user input to 0-indexed for Claude + local claude_start_line = start_line and (start_line - 1) or nil + local claude_end_line = end_line and (end_line - 1) or nil + + local success, error_msg = broadcast_at_mention(file_path, claude_start_line, claude_end_line) + if not success then + logger.error("command", "ClaudeCodeAdd: " .. (error_msg or "Failed to add file")) + else + local message = "ClaudeCodeAdd: Successfully added " .. file_path + if start_line or end_line then + if start_line and end_line then + message = message .. " (lines " .. start_line .. "-" .. end_line .. ")" + elseif start_line then + message = message .. " (from line " .. start_line .. ")" + end + end + logger.debug("command", message) + end end, { - desc = "Send current visual selection as an at_mention to Claude Code", - range = true, -- Important: This makes the command expect a range (visual selection) + nargs = "+", + complete = "file", + desc = "Add specified file or directory to Claude Code context with optional line range", }) local terminal_ok, terminal = pcall(require, "claudecode.terminal") @@ -337,4 +690,38 @@ function M.get_version() } end +--- Format file path for at mention (exposed for testing) +---@param file_path string The file path to format +---@return string formatted_path The formatted path +---@return boolean is_directory Whether the path is a directory +function M._format_path_for_at_mention(file_path) + local is_directory = vim.fn.isdirectory(file_path) == 1 + local formatted_path = file_path + + if is_directory then + local cwd = vim.fn.getcwd() + if string.find(file_path, cwd, 1, true) == 1 then + local relative_path = string.sub(file_path, #cwd + 2) + if relative_path ~= "" then + formatted_path = relative_path + else + formatted_path = "./" + end + end + if not string.match(formatted_path, "/$") then + formatted_path = formatted_path .. "/" + end + else + local cwd = vim.fn.getcwd() + if string.find(file_path, cwd, 1, true) == 1 then + local relative_path = string.sub(file_path, #cwd + 2) + if relative_path ~= "" then + formatted_path = relative_path + end + end + end + + return formatted_path, is_directory +end + return M diff --git a/lua/claudecode/integrations.lua b/lua/claudecode/integrations.lua new file mode 100644 index 0000000..f5adeff --- /dev/null +++ b/lua/claudecode/integrations.lua @@ -0,0 +1,181 @@ +--- +-- Tree integration module for ClaudeCode.nvim +-- Handles detection and selection of files from nvim-tree and neo-tree +-- @module claudecode.integrations +local M = {} + +--- Get selected files from the current tree explorer +--- @return table|nil files List of file paths, or nil if error +--- @return string|nil error Error message if operation failed +function M.get_selected_files_from_tree() + local current_ft = vim.bo.filetype + + if current_ft == "NvimTree" then + return M._get_nvim_tree_selection() + elseif current_ft == "neo-tree" then + return M._get_neotree_selection() + else + return nil, "Not in a supported tree buffer (current filetype: " .. current_ft .. ")" + end +end + +--- Get selected files from nvim-tree +--- Supports both multi-selection (marks) and single file under cursor +--- @return table files List of file paths +--- @return string|nil error Error message if operation failed +function M._get_nvim_tree_selection() + local success, nvim_tree_api = pcall(require, "nvim-tree.api") + if not success then + return {}, "nvim-tree not available" + end + + local files = {} + + local marks = nvim_tree_api.marks.list() + + if marks and #marks > 0 then + for i, mark in ipairs(marks) do + if mark.type == "file" and mark.absolute_path and mark.absolute_path ~= "" then + -- Check if it's not a root-level file (basic protection) + if not string.match(mark.absolute_path, "^/[^/]*$") then + table.insert(files, mark.absolute_path) + end + end + end + + if #files > 0 then + return files, nil + end + end + + local node = nvim_tree_api.tree.get_node_under_cursor() + if node then + if node.type == "file" and node.absolute_path and node.absolute_path ~= "" then + -- Check if it's not a root-level file (basic protection) + if not string.match(node.absolute_path, "^/[^/]*$") then + return { node.absolute_path }, nil + else + return {}, "Cannot add root-level file. Please select a file in a subdirectory." + end + elseif node.type == "directory" and node.absolute_path and node.absolute_path ~= "" then + return { node.absolute_path }, nil + end + end + + return {}, "No file found under cursor" +end + +--- Get selected files from neo-tree +--- Uses neo-tree's own visual selection method when in visual mode +--- @return table files List of file paths +--- @return string|nil error Error message if operation failed +function M._get_neotree_selection() + local success, manager = pcall(require, "neo-tree.sources.manager") + if not success then + return {}, "neo-tree not available" + end + + local state = manager.get_state("filesystem") + if not state then + return {}, "neo-tree filesystem state not available" + end + + local files = {} + + -- Use neo-tree's own visual selection method (like their copy/paste feature) + local mode = vim.fn.mode() + + if mode == "V" or mode == "v" or mode == "\22" then + local current_win = vim.api.nvim_get_current_win() + + if state.winid and state.winid == current_win then + -- Use neo-tree's exact method to get visual range (from their get_selected_nodes implementation) + local start_pos = vim.fn.getpos("'<")[2] + local end_pos = vim.fn.getpos("'>")[2] + + -- Fallback to current cursor and anchor if marks are not valid + if start_pos == 0 or end_pos == 0 then + local cursor_pos = vim.api.nvim_win_get_cursor(0)[1] + local anchor_pos = vim.fn.getpos("v")[2] + if anchor_pos > 0 then + start_pos = math.min(cursor_pos, anchor_pos) + end_pos = math.max(cursor_pos, anchor_pos) + else + start_pos = cursor_pos + end_pos = cursor_pos + end + end + + if end_pos < start_pos then + start_pos, end_pos = end_pos, start_pos + end + + local selected_nodes = {} + + for line = start_pos, end_pos do + local node = state.tree:get_node(line) + if node then + -- Add validation for node types before adding to selection + if node.type and node.type ~= "message" then + table.insert(selected_nodes, node) + end + end + end + + for i, node in ipairs(selected_nodes) do + -- Enhanced validation: check for file type and valid path + if node.type == "file" and node.path and node.path ~= "" then + -- Additional check: ensure it's not a root node (depth protection) + local depth = (node.get_depth and node:get_depth()) and node:get_depth() or 0 + if depth > 1 then + table.insert(files, node.path) + end + end + end + + if #files > 0 then + return files, nil + end + end + end + + if state.tree then + local selection = nil + + if state.tree.get_selection then + selection = state.tree:get_selection() + end + + if (not selection or #selection == 0) and state.selected_nodes then + selection = state.selected_nodes + end + + if selection and #selection > 0 then + for i, node in ipairs(selection) do + if node.type == "file" and node.path then + table.insert(files, node.path) + end + end + + if #files > 0 then + return files, nil + end + end + end + + if state.tree then + local node = state.tree:get_node() + + if node then + if node.type == "file" and node.path then + return { node.path }, nil + elseif node.type == "directory" and node.path then + return { node.path }, nil + end + end + end + + return {}, "No file found under cursor" +end + +return M diff --git a/lua/claudecode/lockfile.lua b/lua/claudecode/lockfile.lua index 4d1ebf4..12792a9 100644 --- a/lua/claudecode/lockfile.lua +++ b/lua/claudecode/lockfile.lua @@ -18,7 +18,6 @@ function M.create(port) return false, "Invalid port number" end - -- Ensure lock directory exists local ok, err = pcall(function() return vim.fn.mkdir(M.lock_dir, "p") end) @@ -27,10 +26,8 @@ function M.create(port) return false, "Failed to create lock directory: " .. (err or "unknown error") end - -- Generate lock file path local lock_path = M.lock_dir .. "/" .. port .. ".lock" - -- Get workspace folders local workspace_folders = M.get_workspace_folders() -- Prepare lock file content @@ -41,7 +38,6 @@ function M.create(port) transport = "ws", } - -- Convert to JSON with error handling local json local ok_json, json_err = pcall(function() json = vim.json.encode(lock_content) @@ -52,7 +48,6 @@ function M.create(port) return false, "Failed to encode lock file content: " .. (json_err or "unknown error") end - -- Write to file local file = io.open(lock_path, "w") if not file then return false, "Failed to create lock file: " .. lock_path @@ -64,7 +59,6 @@ function M.create(port) end) if not write_ok then - -- Try to close file if still open pcall(function() file:close() end) @@ -85,12 +79,10 @@ function M.remove(port) local lock_path = M.lock_dir .. "/" .. port .. ".lock" - -- Check if file exists if vim.fn.filereadable(lock_path) == 0 then return false, "Lock file does not exist: " .. lock_path end - -- Remove the file with error handling local ok, err = pcall(function() return os.remove(lock_path) end) @@ -111,7 +103,6 @@ function M.update(port) return false, "Invalid port number" end - -- First remove existing lock file if it exists local exists = vim.fn.filereadable(M.lock_dir .. "/" .. port .. ".lock") == 1 if exists then local remove_ok, remove_err = M.remove(port) @@ -120,7 +111,6 @@ function M.update(port) end end - -- Then create a new one return M.create(port) end diff --git a/lua/claudecode/logger.lua b/lua/claudecode/logger.lua index 710437c..44418a3 100644 --- a/lua/claudecode/logger.lua +++ b/lua/claudecode/logger.lua @@ -20,8 +20,7 @@ local level_values = { local current_log_level_value = M.levels.INFO ---- Initializes the logger with the provided configuration. --- @param plugin_config table The configuration table (e.g., from claudecode.init.state.config). +--- @param plugin_config table The configuration table (e.g., from claudecode.init.state.config). function M.setup(plugin_config) local conf = plugin_config @@ -83,8 +82,7 @@ local function log(level, component, message_parts) end end ---- Logs a message at the ERROR level. --- @param component string|nil Optional component/module name. +--- @param component string|nil Optional component/module name. -- @param ... any Varargs representing parts of the message. function M.error(component, ...) if type(component) ~= "string" then @@ -94,8 +92,7 @@ function M.error(component, ...) end end ---- Logs a message at the WARN level. --- @param component string|nil Optional component/module name. +--- @param component string|nil Optional component/module name. -- @param ... any Varargs representing parts of the message. function M.warn(component, ...) if type(component) ~= "string" then @@ -105,8 +102,7 @@ function M.warn(component, ...) end end ---- Logs a message at the INFO level. --- @param component string|nil Optional component/module name. +--- @param component string|nil Optional component/module name. -- @param ... any Varargs representing parts of the message. function M.info(component, ...) if type(component) ~= "string" then @@ -116,8 +112,18 @@ function M.info(component, ...) end end ---- Logs a message at the DEBUG level. --- @param component string|nil Optional component/module name. +--- Check if a specific log level is enabled +-- @param level_name string The level name ("error", "warn", "info", "debug", "trace") +-- @return boolean Whether the level is enabled +function M.is_level_enabled(level_name) + local level_value = level_values[level_name] + if not level_value then + return false + end + return level_value <= current_log_level_value +end + +--- @param component string|nil Optional component/module name. -- @param ... any Varargs representing parts of the message. function M.debug(component, ...) if type(component) ~= "string" then @@ -127,8 +133,7 @@ function M.debug(component, ...) end end ---- Logs a message at the TRACE level. --- @param component string|nil Optional component/module name. +--- @param component string|nil Optional component/module name. -- @param ... any Varargs representing parts of the message. function M.trace(component, ...) if type(component) ~= "string" then diff --git a/lua/claudecode/selection.lua b/lua/claudecode/selection.lua index 9b585c9..a2ff7db 100644 --- a/lua/claudecode/selection.lua +++ b/lua/claudecode/selection.lua @@ -1,8 +1,5 @@ --- -- Manages selection tracking and communication with the Claude server. --- This module handles enabling/disabling selection tracking, debouncing updates, --- determining the current selection (visual or cursor position), and sending --- updates to the Claude server. -- @module claudecode.selection local M = {} @@ -13,16 +10,14 @@ M.state = { latest_selection = nil, tracking_enabled = false, debounce_timer = nil, - debounce_ms = 300, + debounce_ms = 100, - -- New state for delayed visual demotion - last_active_visual_selection = nil, -- Stores { bufnr, selection_data, timestamp } + last_active_visual_selection = nil, demotion_timer = nil, - visual_demotion_delay_ms = 50, -- Default, will be overridden by config in M.enable + visual_demotion_delay_ms = 50, } --- Enables selection tracking. --- Sets up autocommands to monitor cursor movements, mode changes, and text changes. -- @param server table The server object to use for communication. -- @param visual_demotion_delay_ms number The delay for visual selection demotion. function M.enable(server, visual_demotion_delay_ms) @@ -209,6 +204,7 @@ function M.update_selection() M.state.demotion_timer:stop() M.state.demotion_timer:close() end + M.state.demotion_timer = vim.loop.new_timer() M.state.demotion_timer:start( M.state.visual_demotion_delay_ms, @@ -271,6 +267,7 @@ function M.handle_selection_demotion(original_bufnr_when_scheduled) end local current_mode_info = vim.api.nvim_get_mode() + -- Condition 2: Back in Visual Mode in the Original Buffer if current_buf == original_bufnr_when_scheduled @@ -296,7 +293,9 @@ function M.handle_selection_demotion(original_bufnr_when_scheduled) M.send_selection_update(M.state.latest_selection) end end + -- No change detected in selection end + -- User switched to different buffer -- Always clear last_active_visual_selection for the original buffer as its pending demotion is resolved. if diff --git a/lua/claudecode/server/frame.lua b/lua/claudecode/server/frame.lua index d8d57bf..2c1d90e 100644 --- a/lua/claudecode/server/frame.lua +++ b/lua/claudecode/server/frame.lua @@ -26,7 +26,6 @@ M.OPCODE = { ---@return WebSocketFrame|nil frame The parsed frame, or nil if incomplete/invalid ---@return number bytes_consumed Number of bytes consumed from input function M.parse_frame(data) - -- Input validation if type(data) ~= "string" then return nil, 0 end @@ -46,14 +45,12 @@ function M.parse_frame(data) pos = pos + 2 - -- Parse first byte local fin = math.floor(byte1 / 128) == 1 local rsv1 = math.floor((byte1 % 128) / 64) == 1 local rsv2 = math.floor((byte1 % 64) / 32) == 1 local rsv3 = math.floor((byte1 % 32) / 16) == 1 local opcode = byte1 % 16 - -- Parse second byte local masked = math.floor(byte2 / 128) == 1 local payload_len = byte2 % 128 diff --git a/lua/claudecode/server/tcp.lua b/lua/claudecode/server/tcp.lua index 21859f8..ef3f30a 100644 --- a/lua/claudecode/server/tcp.lua +++ b/lua/claudecode/server/tcp.lua @@ -22,7 +22,6 @@ function M.find_available_port(min_port, max_port) return nil -- Or handle error appropriately end - -- Create a list of ports in the range local ports = {} for i = min_port, max_port do table.insert(ports, i) @@ -51,13 +50,11 @@ end ---@return TCPServer|nil server The server object, or nil on error ---@return string|nil error Error message if failed function M.create_server(config, callbacks) - -- Find available port local port = M.find_available_port(config.port_range.min, config.port_range.max) if not port then return nil, "No available ports in range " .. config.port_range.min .. "-" .. config.port_range.max end - -- Create TCP server local tcp_server = vim.loop.new_tcp() if not tcp_server then return nil, "Failed to create TCP server" @@ -74,7 +71,6 @@ function M.create_server(config, callbacks) on_error = callbacks.on_error or function() end, } - -- Bind to port local bind_success, bind_err = tcp_server:bind("127.0.0.1", port) if not bind_success then tcp_server:close() diff --git a/lua/claudecode/terminal.lua b/lua/claudecode/terminal.lua index 4880388..77be1f1 100644 --- a/lua/claudecode/terminal.lua +++ b/lua/claudecode/terminal.lua @@ -33,9 +33,7 @@ local managed_fallback_terminal_winid = nil local managed_fallback_terminal_jobid = nil local native_term_tip_shown = false --- Determines the command to run in the terminal. -- Uses the `terminal_cmd` from the module's configuration, or defaults to "claude". --- @local -- @return string The command to execute. local function get_claude_command() local cmd_from_config = term_module_config.terminal_cmd @@ -91,7 +89,6 @@ function M.setup(user_term_config, p_terminal_cmd) end --- Determines the effective terminal provider based on configuration and availability. --- @local -- @return string "snacks" or "native" local function get_effective_terminal_provider() if term_module_config.provider == "snacks" then @@ -117,8 +114,6 @@ local function get_effective_terminal_provider() end end ---- Cleans up state variables for the fallback terminal. --- @local local function cleanup_fallback_terminal_state() managed_fallback_terminal_bufnr = nil managed_fallback_terminal_winid = nil @@ -127,7 +122,6 @@ end --- Checks if the managed fallback terminal is currently valid (window and buffer exist). -- Cleans up state if invalid. --- @local -- @return boolean True if valid, false otherwise. local function is_fallback_terminal_valid() -- First check if we have a valid buffer @@ -158,7 +152,6 @@ local function is_fallback_terminal_valid() end --- Opens a new terminal using native Neovim functions. --- @local -- @param cmd_string string The command string to run. -- @param env_table table Environment variables for the command. -- @param effective_term_config table Configuration for split_side and split_width_percentage. @@ -252,7 +245,6 @@ local function open_fallback_terminal(cmd_string, env_table, effective_term_conf end --- Closes the managed fallback terminal if it's open and valid. --- @local local function close_fallback_terminal() if is_fallback_terminal_valid() then -- Closing the window should trigger on_exit of the job if the process is still running, @@ -265,7 +257,6 @@ local function close_fallback_terminal() end --- Focuses the managed fallback terminal if it's open and valid. --- @local local function focus_fallback_terminal() if is_fallback_terminal_valid() then vim.api.nvim_set_current_win(managed_fallback_terminal_winid) @@ -275,7 +266,6 @@ end --- Builds the effective terminal configuration by merging module defaults with runtime overrides. -- Used by the native fallback. --- @local -- @param opts_override table (optional) Overrides for terminal appearance (split_side, split_width_percentage). -- @return table The effective terminal configuration. local function build_effective_term_config(opts_override) @@ -304,7 +294,6 @@ end --- Builds the options table for Snacks.terminal. -- This function merges the module's current terminal configuration -- with any runtime overrides provided specifically for an open/toggle action. --- @local -- @param effective_term_config_for_snacks table Pre-calculated effective config for split_side, width. -- @param env_table table Environment variables for the command. -- @return table The options table for Snacks. @@ -329,7 +318,6 @@ local function build_snacks_opts(effective_term_config_for_snacks, env_table) end --- Gets the base claude command string and necessary environment variables. --- @local -- @return string|nil cmd_string The command string, or nil on failure. -- @return table|nil env_table The environment variables table, or nil on failure. local function get_claude_command_and_env() @@ -355,7 +343,6 @@ local function get_claude_command_and_env() end --- Find any existing Claude Code terminal buffer by checking terminal job command --- @local -- @return number|nil Buffer number if found, nil otherwise local function find_existing_claude_terminal() local buffers = vim.api.nvim_list_bufs() diff --git a/lua/claudecode/tools/init.lua b/lua/claudecode/tools/init.lua index fd52967..23fb537 100644 --- a/lua/claudecode/tools/init.lua +++ b/lua/claudecode/tools/init.lua @@ -172,23 +172,4 @@ function M.handle_invoke(client, params) -- client needed for blocking tools return { result = handler_return_val1 } end --- Removed M.open_file function, its logic is now in lua/claudecode/tools/impl/open_file.lua - --- Removed M.get_diagnostics function, its logic is now in lua/claudecode/tools/impl/get_diagnostics.lua - --- Removed M.get_open_editors function, its logic is now in lua/claudecode/tools/impl/get_open_editors.lua - --- Removed M.get_workspace_folders function, its logic is now in lua/claudecode/tools/impl/get_workspace_folders.lua - --- Removed M.get_current_selection function, its logic is now in lua/claudecode/tools/impl/get_current_selection.lua --- Removed M.get_latest_selection function as it was redundant with get_current_selection's new implementation - --- Removed M.check_document_dirty function, its logic is now in lua/claudecode/tools/impl/check_document_dirty.lua - --- Removed M.save_document function, its logic is now in lua/claudecode/tools/impl/save_document.lua - --- Removed M.open_diff function, its logic is now in lua/claudecode/tools/impl/open_diff.lua - --- Removed M.close_buffer_by_name function, its logic is now in lua/claudecode/tools/impl/close_buffer_by_name.lua - return M diff --git a/lua/claudecode/visual_commands.lua b/lua/claudecode/visual_commands.lua new file mode 100644 index 0000000..4e76c41 --- /dev/null +++ b/lua/claudecode/visual_commands.lua @@ -0,0 +1,346 @@ +--- +-- Visual command handling module for ClaudeCode.nvim +-- Implements neo-tree-style visual mode exit and command processing +-- @module claudecode.visual_commands +local M = {} + +-- ESC key constant matching neo-tree's implementation +local ESC_KEY +local success = pcall(function() + ESC_KEY = vim.api.nvim_replace_termcodes("", true, false, true) +end) +if not success then + ESC_KEY = "\27" +end + +--- Exit visual mode properly and schedule command execution +--- @param callback function The function to call after exiting visual mode +--- @param ... any Arguments to pass to the callback +function M.exit_visual_and_schedule(callback, ...) + local args = { ... } + + -- Capture visual selection data BEFORE exiting visual mode + local visual_data = M.capture_visual_selection_data() + + pcall(function() + vim.api.nvim_feedkeys(ESC_KEY, "i", true) + end) + + -- Schedule execution until after mode change (neo-tree pattern) + local schedule_fn = vim.schedule or function(fn) + fn() + end -- Fallback for test environments + schedule_fn(function() + -- Pass the captured visual data as the first argument + callback(visual_data, unpack(args)) + end) +end + +--- Validate that we're currently in a visual mode +--- @return boolean true if in visual mode, false otherwise +--- @return string|nil error message if not in visual mode +function M.validate_visual_mode() + local current_mode = "n" -- Default fallback + + -- Use pcall to handle test environments + local mode_success = pcall(function() + current_mode = vim.api.nvim_get_mode().mode + end) + + if not mode_success then + return false, "Cannot determine current mode (test environment)" + end + + local is_visual = current_mode == "v" or current_mode == "V" or current_mode == "\022" + + -- Additional debugging: check visual marks and cursor position + if is_visual then + pcall(function() + vim.api.nvim_win_get_cursor(0) + vim.fn.getpos("'<") + vim.fn.getpos("'>") + vim.fn.getpos("v") + end) + end + + if not is_visual then + return false, "Not in visual mode (current mode: " .. current_mode .. ")" + end + + return true, nil +end + +--- Get visual selection range using vim marks or current cursor position +--- @return number, number start_line, end_line (1-indexed) +function M.get_visual_range() + local start_pos, end_pos = 1, 1 -- Default fallback + + -- Use pcall to handle test environments + local range_success = pcall(function() + -- Check if we're currently in visual mode + local current_mode = vim.api.nvim_get_mode().mode + local is_visual = current_mode == "v" or current_mode == "V" or current_mode == "\022" + + if is_visual then + -- In visual mode, ALWAYS use cursor + anchor (marks are stale until exit) + local cursor_pos = vim.api.nvim_win_get_cursor(0)[1] + local anchor_pos = vim.fn.getpos("v")[2] + + if anchor_pos > 0 then + start_pos = math.min(cursor_pos, anchor_pos) + end_pos = math.max(cursor_pos, anchor_pos) + else + -- Fallback: just use current cursor position + start_pos = cursor_pos + end_pos = cursor_pos + end + else + -- Not in visual mode, try to use the marks (they should be valid now) + local mark_start = vim.fn.getpos("'<")[2] + local mark_end = vim.fn.getpos("'>")[2] + + if mark_start > 0 and mark_end > 0 then + start_pos = mark_start + end_pos = mark_end + else + -- No valid marks, use cursor position + local cursor_pos = vim.api.nvim_win_get_cursor(0)[1] + start_pos = cursor_pos + end_pos = cursor_pos + end + end + end) + + if not range_success then + return 1, 1 + end + + if end_pos < start_pos then + start_pos, end_pos = end_pos, start_pos + end + + -- Ensure we have valid line numbers (at least 1) + start_pos = math.max(1, start_pos) + end_pos = math.max(1, end_pos) + + return start_pos, end_pos +end + +--- Check if we're in a tree buffer and get the tree state +--- @return table|nil, string|nil tree_state, tree_type ("neo-tree" or "nvim-tree") +function M.get_tree_state() + local current_ft = "" -- Default fallback + local current_win = 0 -- Default fallback + + -- Use pcall to handle test environments + local state_success = pcall(function() + current_ft = vim.bo.filetype or "" + current_win = vim.api.nvim_get_current_win() + end) + + if not state_success then + return nil, nil + end + + if current_ft == "neo-tree" then + local manager_success, manager = pcall(require, "neo-tree.sources.manager") + if not manager_success then + return nil, nil + end + + local state = manager.get_state("filesystem") + if not state then + return nil, nil + end + + -- Validate we're in the correct neo-tree window + if state.winid and state.winid == current_win then + return state, "neo-tree" + else + return nil, nil + end + elseif current_ft == "NvimTree" then + local api_success, nvim_tree_api = pcall(require, "nvim-tree.api") + if not api_success then + return nil, nil + end + + return nvim_tree_api, "nvim-tree" + else + return nil, nil + end +end + +--- Create a visual command wrapper that follows neo-tree patterns +--- @param normal_handler function The normal command handler +--- @param visual_handler function The visual command handler +--- @return function The wrapped command function +function M.create_visual_command_wrapper(normal_handler, visual_handler) + return function(...) + local current_mode = vim.api.nvim_get_mode().mode + + if current_mode == "v" or current_mode == "V" or current_mode == "\022" then + -- Use the neo-tree pattern: exit visual mode, then schedule execution + M.exit_visual_and_schedule(visual_handler, ...) + else + normal_handler(...) + end + end +end + +--- Capture visual selection data while still in visual mode +--- @return table|nil visual_data Captured data or nil if not in visual mode +function M.capture_visual_selection_data() + local valid = M.validate_visual_mode() + if not valid then + return nil + end + + local tree_state, tree_type = M.get_tree_state() + if not tree_state then + return nil + end + + local start_pos, end_pos = M.get_visual_range() + + -- Validate that we have a meaningful range + if start_pos == 0 or end_pos == 0 then + return nil + end + + return { + tree_state = tree_state, + tree_type = tree_type, + start_pos = start_pos, + end_pos = end_pos, + } +end + +--- Extract files from visual selection in tree buffers +--- @param visual_data table|nil Pre-captured visual selection data +--- @return table files List of file paths +--- @return string|nil error Error message if failed +function M.get_files_from_visual_selection(visual_data) + -- If we have pre-captured data, use it; otherwise try to get current data + local tree_state, tree_type, start_pos, end_pos + + if visual_data then + tree_state = visual_data.tree_state + tree_type = visual_data.tree_type + start_pos = visual_data.start_pos + end_pos = visual_data.end_pos + else + local valid, err = M.validate_visual_mode() + if not valid then + return {}, err + end + + tree_state, tree_type = M.get_tree_state() + if not tree_state then + return {}, "Not in a supported tree buffer" + end + + start_pos, end_pos = M.get_visual_range() + end + + if not tree_state then + return {}, "Not in a supported tree buffer" + end + + local files = {} + + if tree_type == "neo-tree" then + local selected_nodes = {} + for line = start_pos, end_pos do + -- Neo-tree's tree:get_node() uses the line number directly (1-based) + local node = tree_state.tree:get_node(line) + if node then + if node.type and node.type ~= "message" then + table.insert(selected_nodes, node) + end + end + end + + for _, node in ipairs(selected_nodes) do + if node.type == "file" and node.path and node.path ~= "" then + local depth = (node.get_depth and node:get_depth()) or 0 + if depth > 1 then + table.insert(files, node.path) + end + elseif node.type == "directory" and node.path and node.path ~= "" then + local depth = (node.get_depth and node:get_depth()) or 0 + if depth > 1 then + table.insert(files, node.path) + end + end + end + elseif tree_type == "nvim-tree" then + -- For nvim-tree, we need to manually map visual lines to tree nodes + -- since nvim-tree doesn't have direct line-to-node mapping like neo-tree + require("claudecode.logger").debug( + "visual_commands", + "Processing nvim-tree visual selection from line", + start_pos, + "to", + end_pos + ) + + local nvim_tree_api = tree_state + local current_buf = vim.api.nvim_get_current_buf() + + -- Get all lines in the visual selection + local lines = vim.api.nvim_buf_get_lines(current_buf, start_pos - 1, end_pos, false) + + require("claudecode.logger").debug("visual_commands", "Found", #lines, "lines in visual selection") + + -- For each line in the visual selection, try to get the corresponding node + for i, line_content in ipairs(lines) do + local line_num = start_pos + i - 1 + + -- Set cursor to this line to get the node + pcall(vim.api.nvim_win_set_cursor, 0, { line_num, 0 }) + + -- Get node under cursor for this line + local node_success, node = pcall(nvim_tree_api.tree.get_node_under_cursor) + if node_success and node then + require("claudecode.logger").debug( + "visual_commands", + "Line", + line_num, + "node type:", + node.type, + "path:", + node.absolute_path + ) + + if node.type == "file" and node.absolute_path and node.absolute_path ~= "" then + -- Check if it's not a root-level file (basic protection) + if not string.match(node.absolute_path, "^/[^/]*$") then + table.insert(files, node.absolute_path) + end + elseif node.type == "directory" and node.absolute_path and node.absolute_path ~= "" then + table.insert(files, node.absolute_path) + end + else + require("claudecode.logger").debug("visual_commands", "No valid node found for line", line_num) + end + end + + require("claudecode.logger").debug("visual_commands", "Extracted", #files, "files from nvim-tree visual selection") + + -- Remove duplicates while preserving order + local seen = {} + local unique_files = {} + for _, file_path in ipairs(files) do + if not seen[file_path] then + seen[file_path] = true + table.insert(unique_files, file_path) + end + end + files = unique_files + end + + return files, nil +end + +return M diff --git a/tests/unit/at_mention_spec.lua b/tests/unit/at_mention_spec.lua new file mode 100644 index 0000000..18a5fef --- /dev/null +++ b/tests/unit/at_mention_spec.lua @@ -0,0 +1,361 @@ +-- luacheck: globals expect +require("tests.busted_setup") + +describe("At Mention Functionality", function() + local init_module + local integrations + local mock_vim + + local function setup_mocks() + package.loaded["claudecode.init"] = nil + package.loaded["claudecode.integrations"] = nil + package.loaded["claudecode.logger"] = nil + package.loaded["claudecode.config"] = nil + + -- Mock logger + package.loaded["claudecode.logger"] = { + debug = function() end, + warn = function() end, + error = function() end, + } + + -- Mock config + package.loaded["claudecode.config"] = { + get = function() + return { + debounce_ms = 100, + visual_demotion_delay_ms = 50, + } + end, + } + + -- Extend the existing vim mock instead of replacing it + mock_vim = _G.vim or {} + + -- Add or override specific functions for this test + mock_vim.fn = mock_vim.fn or {} + mock_vim.fn.isdirectory = function(path) + if string.match(path, "/lua$") or string.match(path, "/tests$") or path == "/Users/test/project" then + return 1 + end + return 0 + end + mock_vim.fn.getcwd = function() + return "/Users/test/project" + end + mock_vim.fn.mode = function() + return "n" + end + + mock_vim.api = mock_vim.api or {} + mock_vim.api.nvim_get_current_win = function() + return 1002 + end + mock_vim.api.nvim_get_mode = function() + return { mode = "n" } + end + mock_vim.api.nvim_get_current_buf = function() + return 1 + end + + mock_vim.bo = { filetype = "neo-tree" } + mock_vim.schedule = function(fn) + fn() + end + + _G.vim = mock_vim + end + + before_each(function() + setup_mocks() + end) + + describe("file at mention from neo-tree", function() + before_each(function() + integrations = require("claudecode.integrations") + init_module = require("claudecode.init") + end) + + it("should format single file path correctly", function() + local mock_state = { + tree = { + get_node = function() + return { + type = "file", + path = "/Users/test/project/lua/init.lua", + } + end, + }, + } + + package.loaded["neo-tree.sources.manager"] = { + get_state = function() + return mock_state + end, + } + + local files, err = integrations._get_neotree_selection() + + expect(err).to_be_nil() + expect(files).to_be_table() + expect(#files).to_be(1) + expect(files[1]).to_be("/Users/test/project/lua/init.lua") + end) + + it("should format directory path with trailing slash", function() + local mock_state = { + tree = { + get_node = function() + return { + type = "directory", + path = "/Users/test/project/lua", + } + end, + }, + } + + package.loaded["neo-tree.sources.manager"] = { + get_state = function() + return mock_state + end, + } + + local files, err = integrations._get_neotree_selection() + + expect(err).to_be_nil() + expect(files).to_be_table() + expect(#files).to_be(1) + expect(files[1]).to_be("/Users/test/project/lua") + + local formatted_path = init_module._format_path_for_at_mention(files[1]) + expect(formatted_path).to_be("lua/") + end) + + it("should handle relative path conversion", function() + local file_path = "/Users/test/project/lua/config.lua" + local formatted_path = init_module._format_path_for_at_mention(file_path) + + expect(formatted_path).to_be("lua/config.lua") + end) + + it("should handle root project directory", function() + local dir_path = "/Users/test/project" + local formatted_path = init_module._format_path_for_at_mention(dir_path) + + expect(formatted_path).to_be("./") + end) + end) + + describe("file at mention from nvim-tree", function() + before_each(function() + integrations = require("claudecode.integrations") + init_module = require("claudecode.init") + end) + + it("should get selected file from nvim-tree", function() + package.loaded["nvim-tree.api"] = { + tree = { + get_node_under_cursor = function() + return { + type = "file", + absolute_path = "/Users/test/project/tests/test_spec.lua", + } + end, + }, + marks = { + list = function() + return {} + end, + }, + } + + mock_vim.bo.filetype = "NvimTree" + + local files, err = integrations._get_nvim_tree_selection() + + expect(err).to_be_nil() + expect(files).to_be_table() + expect(#files).to_be(1) + expect(files[1]).to_be("/Users/test/project/tests/test_spec.lua") + end) + + it("should get selected directory from nvim-tree", function() + package.loaded["nvim-tree.api"] = { + tree = { + get_node_under_cursor = function() + return { + type = "directory", + absolute_path = "/Users/test/project/tests", + } + end, + }, + marks = { + list = function() + return {} + end, + }, + } + + mock_vim.bo.filetype = "NvimTree" + + local files, err = integrations._get_nvim_tree_selection() + + expect(err).to_be_nil() + expect(files).to_be_table() + expect(#files).to_be(1) + expect(files[1]).to_be("/Users/test/project/tests") + + local formatted_path = init_module._format_path_for_at_mention(files[1]) + expect(formatted_path).to_be("tests/") + end) + + it("should handle multiple marked files in nvim-tree", function() + package.loaded["nvim-tree.api"] = { + tree = { + get_node_under_cursor = function() + return { + type = "file", + absolute_path = "/Users/test/project/init.lua", + } + end, + }, + marks = { + list = function() + return { + { type = "file", absolute_path = "/Users/test/project/config.lua" }, + { type = "file", absolute_path = "/Users/test/project/utils.lua" }, + } + end, + }, + } + + mock_vim.bo.filetype = "NvimTree" + + local files, err = integrations._get_nvim_tree_selection() + + expect(err).to_be_nil() + expect(files).to_be_table() + expect(#files).to_be(2) + expect(files[1]).to_be("/Users/test/project/config.lua") + expect(files[2]).to_be("/Users/test/project/utils.lua") + end) + end) + + describe("at mention error handling", function() + before_each(function() + integrations = require("claudecode.integrations") + end) + + it("should handle unsupported buffer types", function() + mock_vim.bo.filetype = "text" + + local files, err = integrations.get_selected_files_from_tree() + + expect(files).to_be_nil() + expect(err).to_be_string() + assert_contains(err, "supported") + end) + + it("should handle neo-tree errors gracefully", function() + mock_vim.bo.filetype = "neo-tree" + + package.loaded["neo-tree.sources.manager"] = { + get_state = function() + error("Neo-tree not initialized") + end, + } + + local success, result_or_error = pcall(function() + return integrations._get_neotree_selection() + end) + expect(success).to_be_false() + expect(result_or_error).to_be_string() + assert_contains(result_or_error, "Neo-tree not initialized") + end) + + it("should handle nvim-tree errors gracefully", function() + mock_vim.bo.filetype = "NvimTree" + + package.loaded["nvim-tree.api"] = { + tree = { + get_node_under_cursor = function() + error("NvimTree not available") + end, + }, + marks = { + list = function() + return {} + end, + }, + } + + local success, result_or_error = pcall(function() + return integrations._get_nvim_tree_selection() + end) + expect(success).to_be_false() + expect(result_or_error).to_be_string() + assert_contains(result_or_error, "NvimTree not available") + end) + end) + + describe("integration with main module", function() + before_each(function() + integrations = require("claudecode.integrations") + init_module = require("claudecode.init") + end) + + it("should send files to Claude via at mention", function() + local sent_files = {} + + init_module._test_send_at_mention = function(files) + sent_files = files + end + local mock_state = { + tree = { + get_node = function() + return { + type = "file", + path = "/Users/test/project/src/main.lua", + } + end, + }, + } + + package.loaded["neo-tree.sources.manager"] = { + get_state = function() + return mock_state + end, + } + + local files, err = integrations.get_selected_files_from_tree() + expect(err).to_be_nil() + expect(files).to_be_table() + expect(#files).to_be(1) + if init_module._test_send_at_mention then + init_module._test_send_at_mention(files) + end + + expect(#sent_files).to_be(1) + expect(sent_files[1]).to_be("/Users/test/project/src/main.lua") + end) + + it("should handle mixed file and directory selection", function() + local mixed_files = { + "/Users/test/project/init.lua", + "/Users/test/project/lua", + "/Users/test/project/config.lua", + } + + local formatted_files = {} + for _, file_path in ipairs(mixed_files) do + local formatted_path = init_module._format_path_for_at_mention(file_path) + table.insert(formatted_files, formatted_path) + end + + expect(#formatted_files).to_be(3) + expect(formatted_files[1]).to_be("init.lua") + expect(formatted_files[2]).to_be("lua/") + expect(formatted_files[3]).to_be("config.lua") + end) + end) +end) diff --git a/tests/unit/claudecode_add_command_spec.lua b/tests/unit/claudecode_add_command_spec.lua new file mode 100644 index 0000000..c8d2dce --- /dev/null +++ b/tests/unit/claudecode_add_command_spec.lua @@ -0,0 +1,448 @@ +require("tests.busted_setup") +require("tests.mocks.vim") + +describe("ClaudeCodeAdd command", function() + local claudecode + local mock_server + local mock_logger + local saved_require = _G.require + + local function setup_mocks() + mock_server = { + broadcast = spy.new(function() + return true + end), + } + + mock_logger = { + setup = function() end, + debug = spy.new(function() end), + error = spy.new(function() end), + warn = spy.new(function() end), + } + + -- Override vim.fn functions for our specific tests + vim.fn.expand = spy.new(function(path) + if path == "~/test.lua" then + return "/home/user/test.lua" + elseif path == "./relative.lua" then + return "/current/dir/relative.lua" + end + return path + end) + + vim.fn.filereadable = spy.new(function(path) + if path == "/existing/file.lua" or path == "/home/user/test.lua" or path == "/current/dir/relative.lua" then + return 1 + end + return 0 + end) + + vim.fn.isdirectory = spy.new(function(path) + if path == "/existing/dir" then + return 1 + end + return 0 + end) + + vim.fn.getcwd = function() + return "/current/dir" + end + + vim.api.nvim_create_user_command = spy.new(function() end) + vim.api.nvim_buf_get_name = function() + return "test.lua" + end + + vim.bo = { filetype = "lua" } + vim.notify = spy.new(function() end) + + _G.require = function(mod) + if mod == "claudecode.logger" then + return mock_logger + elseif mod == "claudecode.config" then + return { + apply = function(opts) + return opts or {} + end, + } + elseif mod == "claudecode.diff" then + return { + setup = function() end, + } + elseif mod == "claudecode.terminal" then + return { + setup = function() end, + } + elseif mod == "claudecode.visual_commands" then + return { + create_visual_command_wrapper = function(normal_handler, visual_handler) + return normal_handler + end, + } + else + return saved_require(mod) + end + end + end + + before_each(function() + setup_mocks() + + -- Clear package cache to ensure fresh require + package.loaded["claudecode"] = nil + package.loaded["claudecode.config"] = nil + package.loaded["claudecode.logger"] = nil + package.loaded["claudecode.diff"] = nil + package.loaded["claudecode.visual_commands"] = nil + package.loaded["claudecode.terminal"] = nil + + claudecode = require("claudecode") + + -- Set up the server state manually for testing + claudecode.state.server = mock_server + claudecode.state.port = 12345 + end) + + after_each(function() + _G.require = saved_require + package.loaded["claudecode"] = nil + end) + + describe("command registration", function() + it("should register ClaudeCodeAdd command during setup", function() + claudecode.setup({ auto_start = false }) + + -- Find the ClaudeCodeAdd command registration + local add_command_found = false + for _, call in ipairs(vim.api.nvim_create_user_command.calls) do + if call.vals[1] == "ClaudeCodeAdd" then + add_command_found = true + + local config = call.vals[3] + assert.is_equal("+", config.nargs) + assert.is_equal("file", config.complete) + assert.is_string(config.desc) + assert.is_true(string.find(config.desc, "line range") ~= nil, "Description should mention line range support") + break + end + end + + assert.is_true(add_command_found, "ClaudeCodeAdd command was not registered") + end) + end) + + describe("command execution", function() + local command_handler + + before_each(function() + claudecode.setup({ auto_start = false }) + + for _, call in ipairs(vim.api.nvim_create_user_command.calls) do + if call.vals[1] == "ClaudeCodeAdd" then + command_handler = call.vals[2] + break + end + end + + assert.is_function(command_handler, "Command handler should be a function") + end) + + describe("validation", function() + it("should error when server is not running", function() + claudecode.state.server = nil + + command_handler({ args = "/existing/file.lua" }) + + assert.spy(mock_logger.error).was_called() + end) + + it("should error when no file path is provided", function() + command_handler({ args = "" }) + + assert.spy(mock_logger.error).was_called() + end) + + it("should error when file does not exist", function() + command_handler({ args = "/nonexistent/file.lua" }) + + assert.spy(mock_logger.error).was_called() + end) + end) + + describe("path handling", function() + it("should expand tilde paths", function() + command_handler({ args = "~/test.lua" }) + + assert.spy(vim.fn.expand).was_called_with("~/test.lua") + assert.spy(mock_server.broadcast).was_called() + end) + + it("should expand relative paths", function() + command_handler({ args = "./relative.lua" }) + + assert.spy(vim.fn.expand).was_called_with("./relative.lua") + assert.spy(mock_server.broadcast).was_called() + end) + + it("should handle absolute paths", function() + command_handler({ args = "/existing/file.lua" }) + + assert.spy(mock_server.broadcast).was_called() + end) + end) + + describe("broadcasting", function() + it("should broadcast existing file successfully", function() + command_handler({ args = "/existing/file.lua" }) + + assert.spy(mock_server.broadcast).was_called_with("at_mentioned", { + filePath = "/existing/file.lua", + lineStart = nil, + lineEnd = nil, + }) + assert.spy(mock_logger.debug).was_called() + end) + + it("should broadcast existing directory successfully", function() + command_handler({ args = "/existing/dir" }) + + assert.spy(mock_server.broadcast).was_called_with("at_mentioned", { + filePath = "/existing/dir/", + lineStart = nil, + lineEnd = nil, + }) + assert.spy(mock_logger.debug).was_called() + end) + + it("should handle broadcast failure", function() + mock_server.broadcast = spy.new(function() + return false + end) + + command_handler({ args = "/existing/file.lua" }) + + assert.spy(mock_logger.error).was_called() + end) + end) + + describe("path formatting", function() + it("should handle file broadcasting correctly", function() + -- Set up a file that exists + vim.fn.filereadable = spy.new(function(path) + return path == "/current/dir/src/test.lua" and 1 or 0 + end) + + command_handler({ args = "/current/dir/src/test.lua" }) + + assert.spy(mock_server.broadcast).was_called_with("at_mentioned", match.is_table()) + assert.spy(mock_logger.debug).was_called() + end) + + it("should add trailing slash for directories", function() + command_handler({ args = "/existing/dir" }) + + assert.spy(mock_server.broadcast).was_called_with("at_mentioned", { + filePath = "/existing/dir/", + lineStart = nil, + lineEnd = nil, + }) + end) + end) + + describe("line number conversion", function() + it("should convert 1-indexed user input to 0-indexed for Claude", function() + command_handler({ args = "/existing/file.lua 1 3" }) + + assert.spy(mock_server.broadcast).was_called_with("at_mentioned", { + filePath = "/existing/file.lua", + lineStart = 0, + lineEnd = 2, + }) + end) + end) + + describe("line range functionality", function() + describe("argument parsing", function() + it("should parse single file path correctly", function() + command_handler({ args = "/existing/file.lua" }) + + assert.spy(mock_server.broadcast).was_called_with("at_mentioned", { + filePath = "/existing/file.lua", + lineStart = nil, + lineEnd = nil, + }) + end) + + it("should parse file path with start line", function() + command_handler({ args = "/existing/file.lua 50" }) + + assert.spy(mock_server.broadcast).was_called_with("at_mentioned", { + filePath = "/existing/file.lua", + lineStart = 49, + lineEnd = nil, + }) + end) + + it("should parse file path with start and end lines", function() + command_handler({ args = "/existing/file.lua 50 100" }) + + assert.spy(mock_server.broadcast).was_called_with("at_mentioned", { + filePath = "/existing/file.lua", + lineStart = 49, + lineEnd = 99, + }) + end) + end) + + describe("line number validation", function() + it("should error on invalid start line number", function() + command_handler({ args = "/existing/file.lua abc" }) + + assert.spy(mock_logger.error).was_called() + assert.spy(mock_server.broadcast).was_not_called() + end) + + it("should error on invalid end line number", function() + command_handler({ args = "/existing/file.lua 50 xyz" }) + + assert.spy(mock_logger.error).was_called() + assert.spy(mock_server.broadcast).was_not_called() + end) + + it("should error on negative start line", function() + command_handler({ args = "/existing/file.lua -5" }) + + assert.spy(mock_logger.error).was_called() + assert.spy(mock_server.broadcast).was_not_called() + end) + + it("should error on negative end line", function() + command_handler({ args = "/existing/file.lua 10 -20" }) + + assert.spy(mock_logger.error).was_called() + assert.spy(mock_server.broadcast).was_not_called() + end) + + it("should error on zero line numbers", function() + command_handler({ args = "/existing/file.lua 0 10" }) + + assert.spy(mock_logger.error).was_called() + assert.spy(mock_server.broadcast).was_not_called() + end) + + it("should error when start line > end line", function() + command_handler({ args = "/existing/file.lua 100 50" }) + + assert.spy(mock_logger.error).was_called() + assert.spy(mock_server.broadcast).was_not_called() + end) + + it("should error on too many arguments", function() + command_handler({ args = "/existing/file.lua 10 20 30" }) + + assert.spy(mock_logger.error).was_called() + assert.spy(mock_server.broadcast).was_not_called() + end) + end) + + describe("directory handling with line numbers", function() + it("should ignore line numbers for directories and warn", function() + command_handler({ args = "/existing/dir 50 100" }) + + assert.spy(mock_server.broadcast).was_called_with("at_mentioned", { + filePath = "/existing/dir/", + lineStart = nil, + lineEnd = nil, + }) + assert.spy(mock_logger.debug).was_called() + end) + end) + + describe("valid line range scenarios", function() + it("should handle start line equal to end line", function() + command_handler({ args = "/existing/file.lua 50 50" }) + + assert.spy(mock_server.broadcast).was_called_with("at_mentioned", { + filePath = "/existing/file.lua", + lineStart = 49, + lineEnd = 49, -- 50 - 1 (converted to 0-indexed) + }) + end) + + it("should handle large line numbers", function() + command_handler({ args = "/existing/file.lua 1000 2000" }) + + assert.spy(mock_server.broadcast).was_called_with("at_mentioned", { + filePath = "/existing/file.lua", + lineStart = 999, + lineEnd = 1999, + }) + end) + + it("should handle single line specification", function() + command_handler({ args = "/existing/file.lua 42" }) + + assert.spy(mock_server.broadcast).was_called_with("at_mentioned", { + filePath = "/existing/file.lua", + lineStart = 41, + lineEnd = nil, + }) + end) + end) + + describe("path expansion with line ranges", function() + it("should expand tilde paths with line numbers", function() + command_handler({ args = "~/test.lua 10 20" }) + + assert.spy(vim.fn.expand).was_called_with("~/test.lua") + assert.spy(mock_server.broadcast).was_called_with("at_mentioned", { + filePath = "/home/user/test.lua", + lineStart = 9, + lineEnd = 19, + }) + end) + + it("should expand relative paths with line numbers", function() + command_handler({ args = "./relative.lua 5" }) + + assert.spy(vim.fn.expand).was_called_with("./relative.lua") + assert.spy(mock_server.broadcast).was_called_with("at_mentioned", { + filePath = "/current/dir/relative.lua", + lineStart = 4, + lineEnd = nil, + }) + end) + end) + end) + end) + + describe("integration with broadcast functions", function() + it("should use the extracted broadcast_at_mention function", function() + -- This test ensures that the command uses the centralized function + -- rather than duplicating broadcast logic + claudecode.setup({ auto_start = false }) + + local command_handler + for _, call in ipairs(vim.api.nvim_create_user_command.calls) do + if call.vals[1] == "ClaudeCodeAdd" then + command_handler = call.vals[2] + break + end + end + + -- Mock the _format_path_for_at_mention function to verify it's called + local original_format = claudecode._format_path_for_at_mention + claudecode._format_path_for_at_mention = spy.new(function(path) + return path, false + end) + + command_handler({ args = "/existing/file.lua" }) + + assert.spy(mock_server.broadcast).was_called() + + -- Restore original function + claudecode._format_path_for_at_mention = original_format + end) + end) +end) diff --git a/tests/unit/diff_mcp_spec.lua b/tests/unit/diff_mcp_spec.lua index 463fc42..93468b2 100644 --- a/tests/unit/diff_mcp_spec.lua +++ b/tests/unit/diff_mcp_spec.lua @@ -90,17 +90,30 @@ describe("MCP-compliant diff operations", function() assert.equal("text", result.content[2].type) end) - it("should error on non-existent old file", function() + it("should handle non-existent old file as new file", function() local non_existent_file = "/tmp/non_existent_file.txt" + + -- Set up mock resolution + _G.claude_deferred_responses = { + [tostring(coroutine.running())] = function() + -- Mock resolution + end, + } + local co = coroutine.create(function() diff.open_diff_blocking(non_existent_file, test_new_file, test_content_new, test_tab_name) end) - local success, err = coroutine.resume(co) - assert.is_false(success, "Should fail with non-existent file") - assert.is_table(err) - assert.equal(-32000, err.code) - assert_contains(err.message, "File access error") + local success = coroutine.resume(co) + assert.is_true(success, "Should handle new file scenario successfully") + + -- The coroutine should yield (waiting for user action) + assert.equal("suspended", coroutine.status(co)) + + -- Verify diff state was created for new file + local active_diffs = diff._get_active_diffs() + assert.is_table(active_diffs[test_tab_name]) + assert.is_true(active_diffs[test_tab_name].is_new_file) end) it("should replace existing diff with same tab_name", function() diff --git a/tests/unit/directory_at_mention_spec.lua b/tests/unit/directory_at_mention_spec.lua new file mode 100644 index 0000000..b2e7dd3 --- /dev/null +++ b/tests/unit/directory_at_mention_spec.lua @@ -0,0 +1,188 @@ +-- luacheck: globals expect +require("tests.busted_setup") + +describe("Directory At Mention Functionality", function() + local integrations + local visual_commands + local mock_vim + + local function setup_mocks() + package.loaded["claudecode.integrations"] = nil + package.loaded["claudecode.visual_commands"] = nil + package.loaded["claudecode.logger"] = nil + + -- Mock logger + package.loaded["claudecode.logger"] = { + debug = function() end, + warn = function() end, + error = function() end, + } + + mock_vim = { + fn = { + isdirectory = function(path) + if string.match(path, "/lua$") or string.match(path, "/tests$") or string.match(path, "src") then + return 1 + end + return 0 + end, + getcwd = function() + return "/Users/test/project" + end, + mode = function() + return "n" + end, + }, + api = { + nvim_get_current_win = function() + return 1002 + end, + nvim_get_mode = function() + return { mode = "n" } + end, + }, + bo = { filetype = "neo-tree" }, + } + + _G.vim = mock_vim + end + + before_each(function() + setup_mocks() + end) + + describe("directory handling in integrations", function() + before_each(function() + integrations = require("claudecode.integrations") + end) + + it("should return directory paths from neo-tree", function() + local mock_state = { + tree = { + get_node = function() + return { + type = "directory", + path = "/Users/test/project/lua", + } + end, + }, + } + + package.loaded["neo-tree.sources.manager"] = { + get_state = function() + return mock_state + end, + } + + local files, err = integrations._get_neotree_selection() + + expect(err).to_be_nil() + expect(files).to_be_table() + expect(#files).to_be(1) + expect(files[1]).to_be("/Users/test/project/lua") + end) + + it("should return directory paths from nvim-tree", function() + package.loaded["nvim-tree.api"] = { + tree = { + get_node_under_cursor = function() + return { + type = "directory", + absolute_path = "/Users/test/project/tests", + } + end, + }, + marks = { + list = function() + return {} + end, + }, + } + + mock_vim.bo.filetype = "NvimTree" + + local files, err = integrations._get_nvim_tree_selection() + + expect(err).to_be_nil() + expect(files).to_be_table() + expect(#files).to_be(1) + expect(files[1]).to_be("/Users/test/project/tests") + end) + end) + + describe("visual commands directory handling", function() + before_each(function() + visual_commands = require("claudecode.visual_commands") + end) + + it("should include directories in visual selections", function() + local visual_data = { + tree_state = { + tree = { + get_node = function(self, line) + if line == 1 then + return { + type = "file", + path = "/Users/test/project/init.lua", + get_depth = function() + return 2 + end, + } + elseif line == 2 then + return { + type = "directory", + path = "/Users/test/project/lua", + get_depth = function() + return 2 + end, + } + end + return nil + end, + }, + }, + tree_type = "neo-tree", + start_pos = 1, + end_pos = 2, + } + + local files, err = visual_commands.get_files_from_visual_selection(visual_data) + + expect(err).to_be_nil() + expect(files).to_be_table() + expect(#files).to_be(2) + expect(files[1]).to_be("/Users/test/project/init.lua") + expect(files[2]).to_be("/Users/test/project/lua") + end) + + it("should respect depth protection for directories", function() + local visual_data = { + tree_state = { + tree = { + get_node = function(line) + if line == 1 then + return { + type = "directory", + path = "/Users/test/project", + get_depth = function() + return 1 + end, + } + end + return nil + end, + }, + }, + tree_type = "neo-tree", + start_pos = 1, + end_pos = 1, + } + + local files, err = visual_commands.get_files_from_visual_selection(visual_data) + + expect(err).to_be_nil() + expect(files).to_be_table() + expect(#files).to_be(0) -- Root-level directory should be skipped + end) + end) +end) diff --git a/tests/unit/nvim_tree_visual_selection_spec.lua b/tests/unit/nvim_tree_visual_selection_spec.lua new file mode 100644 index 0000000..46a8db4 --- /dev/null +++ b/tests/unit/nvim_tree_visual_selection_spec.lua @@ -0,0 +1,237 @@ +-- luacheck: globals expect +require("tests.busted_setup") + +describe("NvimTree Visual Selection", function() + local visual_commands + local mock_vim + + local function setup_mocks() + package.loaded["claudecode.visual_commands"] = nil + package.loaded["claudecode.logger"] = nil + + -- Mock logger + package.loaded["claudecode.logger"] = { + debug = function() end, + warn = function() end, + error = function() end, + } + + mock_vim = { + fn = { + mode = function() + return "V" -- Visual line mode + end, + getpos = function(mark) + if mark == "'<" then + return { 0, 2, 0, 0 } -- Start at line 2 + elseif mark == "'>" then + return { 0, 4, 0, 0 } -- End at line 4 + elseif mark == "v" then + return { 0, 2, 0, 0 } -- Anchor at line 2 + end + return { 0, 0, 0, 0 } + end, + }, + api = { + nvim_get_current_win = function() + return 1002 + end, + nvim_get_mode = function() + return { mode = "V" } + end, + nvim_get_current_buf = function() + return 1 + end, + nvim_win_get_cursor = function() + return { 4, 0 } -- Cursor at line 4 + end, + nvim_buf_get_lines = function(buf, start, end_line, strict) + -- Return mock buffer lines for the visual selection + return { + " 📁 src/", + " 📄 init.lua", + " 📄 config.lua", + } + end, + nvim_win_set_cursor = function(win, pos) + -- Mock cursor setting + end, + nvim_replace_termcodes = function(keys, from_part, do_lt, special) + return keys + end, + }, + bo = { filetype = "NvimTree" }, + schedule = function(fn) + fn() + end, + } + + _G.vim = mock_vim + end + + before_each(function() + setup_mocks() + end) + + describe("nvim-tree visual selection handling", function() + before_each(function() + visual_commands = require("claudecode.visual_commands") + end) + + it("should extract files from visual selection in nvim-tree", function() + -- Create a stateful mock that tracks cursor position + local cursor_positions = {} + local expected_nodes = { + [2] = { type = "directory", absolute_path = "/Users/test/project/src" }, + [3] = { type = "file", absolute_path = "/Users/test/project/init.lua" }, + [4] = { type = "file", absolute_path = "/Users/test/project/config.lua" }, + } + + mock_vim.api.nvim_win_set_cursor = function(win, pos) + cursor_positions[#cursor_positions + 1] = pos[1] + end + + local mock_nvim_tree_api = { + tree = { + get_node_under_cursor = function() + local current_line = cursor_positions[#cursor_positions] or 2 + return expected_nodes[current_line] + end, + }, + } + + local visual_data = { + tree_state = mock_nvim_tree_api, + tree_type = "nvim-tree", + start_pos = 2, + end_pos = 4, + } + + local files, err = visual_commands.get_files_from_visual_selection(visual_data) + + expect(err).to_be_nil() + expect(files).to_be_table() + expect(#files).to_be(3) + expect(files[1]).to_be("/Users/test/project/src") + expect(files[2]).to_be("/Users/test/project/init.lua") + expect(files[3]).to_be("/Users/test/project/config.lua") + end) + + it("should handle empty visual selection in nvim-tree", function() + local mock_nvim_tree_api = { + tree = { + get_node_under_cursor = function() + return nil -- No node found + end, + }, + } + + local visual_data = { + tree_state = mock_nvim_tree_api, + tree_type = "nvim-tree", + start_pos = 2, + end_pos = 2, + } + + local files, err = visual_commands.get_files_from_visual_selection(visual_data) + + expect(err).to_be_nil() + expect(files).to_be_table() + expect(#files).to_be(0) + end) + + it("should filter out root-level files in nvim-tree", function() + local mock_nvim_tree_api = { + tree = { + get_node_under_cursor = function() + return { + type = "file", + absolute_path = "/root_file.txt", -- Root-level file should be filtered + } + end, + }, + } + + local visual_data = { + tree_state = mock_nvim_tree_api, + tree_type = "nvim-tree", + start_pos = 1, + end_pos = 1, + } + + local files, err = visual_commands.get_files_from_visual_selection(visual_data) + + expect(err).to_be_nil() + expect(files).to_be_table() + expect(#files).to_be(0) -- Root-level file should be filtered out + end) + + it("should remove duplicate files in visual selection", function() + local call_count = 0 + local mock_nvim_tree_api = { + tree = { + get_node_under_cursor = function() + call_count = call_count + 1 + -- Return the same file path twice to test deduplication + return { + type = "file", + absolute_path = "/Users/test/project/duplicate.lua", + } + end, + }, + } + + local visual_data = { + tree_state = mock_nvim_tree_api, + tree_type = "nvim-tree", + start_pos = 1, + end_pos = 2, -- Two lines, same file + } + + local files, err = visual_commands.get_files_from_visual_selection(visual_data) + + expect(err).to_be_nil() + expect(files).to_be_table() + expect(#files).to_be(1) -- Should have only one instance + expect(files[1]).to_be("/Users/test/project/duplicate.lua") + end) + + it("should handle mixed file and directory selection", function() + local cursor_positions = {} + local expected_nodes = { + [1] = { type = "directory", absolute_path = "/Users/test/project/lib" }, + [2] = { type = "file", absolute_path = "/Users/test/project/main.lua" }, + [3] = { type = "directory", absolute_path = "/Users/test/project/tests" }, + } + + mock_vim.api.nvim_win_set_cursor = function(win, pos) + cursor_positions[#cursor_positions + 1] = pos[1] + end + + local mock_nvim_tree_api = { + tree = { + get_node_under_cursor = function() + local current_line = cursor_positions[#cursor_positions] or 1 + return expected_nodes[current_line] + end, + }, + } + + local visual_data = { + tree_state = mock_nvim_tree_api, + tree_type = "nvim-tree", + start_pos = 1, + end_pos = 3, + } + + local files, err = visual_commands.get_files_from_visual_selection(visual_data) + + expect(err).to_be_nil() + expect(files).to_be_table() + expect(#files).to_be(3) + expect(files[1]).to_be("/Users/test/project/lib") + expect(files[2]).to_be("/Users/test/project/main.lua") + expect(files[3]).to_be("/Users/test/project/tests") + end) + end) +end) diff --git a/tests/unit/tools/open_diff_mcp_spec.lua b/tests/unit/tools/open_diff_mcp_spec.lua index 4073b0c..048e2a6 100644 --- a/tests/unit/tools/open_diff_mcp_spec.lua +++ b/tests/unit/tools/open_diff_mcp_spec.lua @@ -202,7 +202,7 @@ describe("openDiff tool MCP compliance", function() end) describe("error handling", function() - it("should handle file access errors", function() + it("should handle new files successfully", function() local params = { old_file_path = "/tmp/non_existent_file.txt", new_file_path = test_new_file, @@ -210,15 +210,22 @@ describe("openDiff tool MCP compliance", function() tab_name = test_tab_name, } + -- Set up mock resolution to avoid hanging + _G.claude_deferred_responses = { + [tostring(coroutine.running())] = function(result) + -- Mock resolution + end, + } + local co = coroutine.create(function() open_diff_tool.handler(params) end) - local success, err = coroutine.resume(co) - assert.is_false(success) - assert.is_table(err) - assert.equal(-32000, err.code) - assert_contains(err.data, "Cannot open file") + local success = coroutine.resume(co) + assert.is_true(success, "Should handle new file scenario successfully") + + -- The coroutine should yield (waiting for user action) + assert.equal("suspended", coroutine.status(co)) end) it("should handle diff module loading errors", function()