From 5b10521e13422fd2e1d822a728e32d86675c3af8 Mon Sep 17 00:00:00 2001 From: Thomas Kosiewski Date: Tue, 29 Jul 2025 15:57:13 +0200 Subject: [PATCH] feat: add support for custom terminal providers Change-Id: I2f559e355aa6036ca94b8aca13d53739c6b5e021 Signed-off-by: Thomas Kosiewski --- README.md | 119 +++++++- lua/claudecode/init.lua | 2 +- lua/claudecode/terminal.lua | 96 ++++++- tests/unit/terminal_spec.lua | 523 ++++++++++++++++++++++++++++++++++- 4 files changed, 725 insertions(+), 15 deletions(-) diff --git a/README.md b/README.md index b698beb..d679224 100644 --- a/README.md +++ b/README.md @@ -130,9 +130,6 @@ For deep technical details, see [ARCHITECTURE.md](./ARCHITECTURE.md). ## Advanced Configuration -
-Complete configuration options - ```lua { "coder/claudecode.nvim", @@ -152,7 +149,7 @@ For deep technical details, see [ARCHITECTURE.md](./ARCHITECTURE.md). terminal = { split_side = "right", -- "left" or "right" split_width_percentage = 0.30, - provider = "auto", -- "auto", "snacks", or "native" + provider = "auto", -- "auto", "snacks", "native", or custom provider table auto_close = true, snacks_win_opts = {}, -- Opts to pass to `Snacks.terminal.open()` }, @@ -170,7 +167,119 @@ For deep technical details, see [ARCHITECTURE.md](./ARCHITECTURE.md). } ``` -
+## Custom Terminal Providers + +You can create custom terminal providers by passing a table with the required functions instead of a string provider name: + +```lua +require("claudecode").setup({ + terminal = { + provider = { + -- Required functions + setup = function(config) + -- Initialize your terminal provider + end, + + open = function(cmd_string, env_table, effective_config, focus) + -- Open terminal with command and environment + -- focus parameter controls whether to focus terminal (defaults to true) + end, + + close = function() + -- Close the terminal + end, + + simple_toggle = function(cmd_string, env_table, effective_config) + -- Simple show/hide toggle + end, + + focus_toggle = function(cmd_string, env_table, effective_config) + -- Smart toggle: focus terminal if not focused, hide if focused + end, + + get_active_bufnr = function() + -- Return terminal buffer number or nil + return 123 -- example + end, + + is_available = function() + -- Return true if provider can be used + return true + end, + + -- Optional functions (auto-generated if not provided) + toggle = function(cmd_string, env_table, effective_config) + -- Defaults to calling simple_toggle for backward compatibility + end, + + _get_terminal_for_test = function() + -- For testing only, defaults to return nil + return nil + end, + }, + }, +}) +``` + +### Custom Provider Example + +Here's a complete example using a hypothetical `my_terminal` plugin: + +```lua +local my_terminal_provider = { + setup = function(config) + -- Store config for later use + self.config = config + end, + + open = function(cmd_string, env_table, effective_config, focus) + if focus == nil then focus = true end + + local my_terminal = require("my_terminal") + my_terminal.open({ + cmd = cmd_string, + env = env_table, + width = effective_config.split_width_percentage, + side = effective_config.split_side, + focus = focus, + }) + end, + + close = function() + require("my_terminal").close() + end, + + simple_toggle = function(cmd_string, env_table, effective_config) + require("my_terminal").toggle() + end, + + focus_toggle = function(cmd_string, env_table, effective_config) + local my_terminal = require("my_terminal") + if my_terminal.is_focused() then + my_terminal.hide() + else + my_terminal.focus() + end + end, + + get_active_bufnr = function() + return require("my_terminal").get_bufnr() + end, + + is_available = function() + local ok, _ = pcall(require, "my_terminal") + return ok + end, +} + +require("claudecode").setup({ + terminal = { + provider = my_terminal_provider, + }, +}) +``` + +The custom provider will automatically fall back to the native provider if validation fails or `is_available()` returns false. ## Troubleshooting diff --git a/lua/claudecode/init.lua b/lua/claudecode/init.lua index dcaf16f..875685a 100644 --- a/lua/claudecode/init.lua +++ b/lua/claudecode/init.lua @@ -88,7 +88,7 @@ M.state = { ---@alias ClaudeCode.TerminalOpts { \ --- split_side?: "left"|"right", \ --- split_width_percentage?: number, \ ---- provider?: "auto"|"snacks"|"native", \ +--- provider?: "auto"|"snacks"|"native"|table, \ --- show_native_term_exit_tip?: boolean, \ --- snacks_win_opts?: table } --- diff --git a/lua/claudecode/terminal.lua b/lua/claudecode/terminal.lua index e0db5ac..4d4ee94 100644 --- a/lua/claudecode/terminal.lua +++ b/lua/claudecode/terminal.lua @@ -46,13 +46,88 @@ local function load_provider(provider_name) return providers[provider_name] end +--- Validates and enhances a custom table provider with smart defaults +--- @param provider table The custom provider table to validate +--- @return TerminalProvider|nil provider The enhanced provider, or nil if invalid +--- @return string|nil error Error message if validation failed +local function validate_and_enhance_provider(provider) + if type(provider) ~= "table" then + return nil, "Custom provider must be a table" + end + + -- Required functions that must be implemented + local required_functions = { + "setup", + "open", + "close", + "simple_toggle", + "focus_toggle", + "get_active_bufnr", + "is_available", + } + + -- Validate all required functions exist and are callable + for _, func_name in ipairs(required_functions) do + local func = provider[func_name] + if not func then + return nil, "Custom provider missing required function: " .. func_name + end + -- Check if it's callable (function or table with __call metamethod) + local is_callable = type(func) == "function" + or (type(func) == "table" and getmetatable(func) and getmetatable(func).__call) + if not is_callable then + return nil, "Custom provider field '" .. func_name .. "' must be callable, got: " .. type(func) + end + end + + -- Create enhanced provider with defaults for optional functions + -- Note: Don't deep copy to preserve spy functions in tests + local enhanced_provider = provider + + -- Add default toggle function if not provided (calls simple_toggle for backward compatibility) + if not enhanced_provider.toggle then + enhanced_provider.toggle = function(cmd_string, env_table, effective_config) + return enhanced_provider.simple_toggle(cmd_string, env_table, effective_config) + end + end + + -- Add default test function if not provided + if not enhanced_provider._get_terminal_for_test then + enhanced_provider._get_terminal_for_test = function() + return nil + end + end + + return enhanced_provider, nil +end + --- Gets the effective terminal provider, guaranteed to return a valid provider --- Falls back to native provider if configured provider is unavailable --- @return TerminalProvider provider The terminal provider module (never nil) local function get_provider() local logger = require("claudecode.logger") - if config.provider == "auto" then + -- Handle custom table provider + if type(config.provider) == "table" then + local enhanced_provider, error_msg = validate_and_enhance_provider(config.provider) + if enhanced_provider then + -- Check if custom provider is available + local is_available_ok, is_available = pcall(enhanced_provider.is_available) + if is_available_ok and is_available then + logger.debug("terminal", "Using custom table provider") + return enhanced_provider + else + local availability_msg = is_available_ok and "provider reports not available" or "error checking availability" + logger.warn( + "terminal", + "Custom table provider configured but " .. availability_msg .. ". Falling back to 'native'." + ) + end + else + logger.warn("terminal", "Invalid custom table provider: " .. error_msg .. ". Falling back to 'native'.") + end + -- Fall through to native provider + elseif config.provider == "auto" then -- Try snacks first, then fallback to native silently local snacks_provider = load_provider("snacks") if snacks_provider and snacks_provider.is_available() then @@ -69,8 +144,13 @@ local function get_provider() elseif config.provider == "native" then -- noop, will use native provider as default below logger.debug("terminal", "Using native terminal provider") - else + elseif type(config.provider) == "string" then logger.warn("terminal", "Invalid provider configured: " .. tostring(config.provider) .. ". Defaulting to 'native'.") + else + logger.warn( + "terminal", + "Invalid provider type: " .. type(config.provider) .. ". Must be string or table. Defaulting to 'native'." + ) end local native_provider = load_provider("native") @@ -188,7 +268,7 @@ end -- @param user_term_config table (optional) Configuration options for the terminal. -- @field user_term_config.split_side string 'left' or 'right' (default: 'right'). -- @field user_term_config.split_width_percentage number Percentage of screen width (0.0 to 1.0, default: 0.30). --- @field user_term_config.provider string 'snacks' or 'native' (default: 'snacks'). +-- @field user_term_config.provider string|table 'auto', 'snacks', 'native', or custom provider table (default: 'auto'). -- @field user_term_config.show_native_term_exit_tip boolean Show tip for exiting native terminal (default: true). -- @field user_term_config.snacks_win_opts table Opts to pass to `Snacks.terminal.open()` (default: {}). -- @param p_terminal_cmd string|nil The command to run in the terminal (from main config). @@ -227,7 +307,7 @@ function M.setup(user_term_config, p_terminal_cmd, p_env) config[k] = v elseif k == "split_width_percentage" and type(v) == "number" and v > 0 and v < 1 then config[k] = v - elseif k == "provider" and (v == "snacks" or v == "native") then + elseif k == "provider" and (v == "snacks" or v == "native" or v == "auto" or type(v) == "table") then config[k] = v elseif k == "show_native_term_exit_tip" and type(v) == "boolean" then config[k] = v @@ -314,11 +394,11 @@ end --- Gets the managed terminal instance for testing purposes. -- NOTE: This function is intended for use in tests to inspect internal state. -- The underscore prefix indicates it's not part of the public API for regular use. --- @return snacks.terminal|nil The managed Snacks terminal instance, or nil. +-- @return table|nil The managed terminal instance, or nil. function M._get_managed_terminal_for_test() - local snacks_provider = load_provider("snacks") - if snacks_provider and snacks_provider._get_terminal_for_test then - return snacks_provider._get_terminal_for_test() + local provider = get_provider() + if provider and provider._get_terminal_for_test then + return provider._get_terminal_for_test() end return nil end diff --git a/tests/unit/terminal_spec.lua b/tests/unit/terminal_spec.lua index f0169d1..cd61b70 100644 --- a/tests/unit/terminal_spec.lua +++ b/tests/unit/terminal_spec.lua @@ -228,6 +228,18 @@ describe("claudecode.terminal (wrapper for Snacks.nvim)", function() package.loaded["claudecode.server.init"] = nil package.loaded["snacks"] = nil package.loaded["claudecode.config"] = nil + package.loaded["claudecode.logger"] = nil + + -- Mock logger + package.loaded["claudecode.logger"] = { + debug = function() end, + warn = function(context, message) + vim.notify(message, vim.log.levels.WARN) + end, + error = function(context, message) + vim.notify(message, vim.log.levels.ERROR) + end, + } -- Mock the server module local mock_server_module = { @@ -350,7 +362,7 @@ describe("claudecode.terminal (wrapper for Snacks.nvim)", function() vim.notify = spy.new(function(_msg, _level) end) terminal_wrapper = require("claudecode.terminal") - terminal_wrapper.setup({}) + -- Don't call setup({}) here to allow custom provider tests to work end) after_each(function() @@ -360,6 +372,7 @@ describe("claudecode.terminal (wrapper for Snacks.nvim)", function() package.loaded["claudecode.server.init"] = nil package.loaded["snacks"] = nil package.loaded["claudecode.config"] = nil + package.loaded["claudecode.logger"] = nil if _G.vim and _G.vim._mock and _G.vim._mock.reset then _G.vim._mock.reset() end @@ -700,4 +713,512 @@ describe("claudecode.terminal (wrapper for Snacks.nvim)", function() assert.are.equal("claude", toggle_cmd) end) end) + + describe("custom table provider functionality", function() + describe("valid custom provider", function() + it("should call setup method during terminal wrapper setup", function() + local setup_spy = spy.new(function() end) + local custom_provider = { + setup = setup_spy, + open = spy.new(function() end), + close = spy.new(function() end), + simple_toggle = spy.new(function() end), + focus_toggle = spy.new(function() end), + get_active_bufnr = spy.new(function() + return 123 + end), + is_available = spy.new(function() + return true + end), + } + + terminal_wrapper.setup({ provider = custom_provider }) + + setup_spy:was_called(1) + setup_spy:was_called_with(spy.matching.is_type("table")) + end) + + it("should check is_available during open operation", function() + local is_available_spy = spy.new(function() + return true + end) + local open_spy = spy.new(function() end) + local custom_provider = { + setup = spy.new(function() end), + open = open_spy, + close = spy.new(function() end), + simple_toggle = spy.new(function() end), + focus_toggle = spy.new(function() end), + get_active_bufnr = spy.new(function() + return 123 + end), + is_available = is_available_spy, + } + + terminal_wrapper.setup({ provider = custom_provider }) + terminal_wrapper.open() + + is_available_spy:was_called() + open_spy:was_called() + end) + + it("should auto-generate toggle function if missing", function() + local simple_toggle_spy = spy.new(function() end) + local custom_provider = { + setup = spy.new(function() end), + open = spy.new(function() end), + close = spy.new(function() end), + simple_toggle = simple_toggle_spy, + focus_toggle = spy.new(function() end), + get_active_bufnr = spy.new(function() + return 123 + end), + is_available = spy.new(function() + return true + end), + -- Note: toggle function is intentionally missing + } + + terminal_wrapper.setup({ provider = custom_provider }) + + -- Verify that toggle function was auto-generated and calls simple_toggle + assert.is_function(custom_provider.toggle) + local test_env = {} + local test_config = {} + custom_provider.toggle("test_cmd", test_env, test_config) + simple_toggle_spy:was_called(1) + -- Check that the first argument (command string) is correct + local call_args = simple_toggle_spy:get_call(1).refs + assert.are.equal("test_cmd", call_args[1]) + assert.are.equal(3, #call_args) -- Should have 3 arguments + end) + + it("should auto-generate _get_terminal_for_test function if missing", function() + local custom_provider = { + setup = spy.new(function() end), + open = spy.new(function() end), + close = spy.new(function() end), + simple_toggle = spy.new(function() end), + focus_toggle = spy.new(function() end), + get_active_bufnr = spy.new(function() + return 123 + end), + is_available = spy.new(function() + return true + end), + -- Note: _get_terminal_for_test function is intentionally missing + } + + terminal_wrapper.setup({ provider = custom_provider }) + + -- Verify that _get_terminal_for_test function was auto-generated + assert.is_function(custom_provider._get_terminal_for_test) + assert.is_nil(custom_provider._get_terminal_for_test()) + end) + + it("should pass correct parameters to custom provider functions", function() + local open_spy = spy.new(function() end) + local simple_toggle_spy = spy.new(function() end) + local focus_toggle_spy = spy.new(function() end) + + local custom_provider = { + setup = spy.new(function() end), + open = open_spy, + close = spy.new(function() end), + simple_toggle = simple_toggle_spy, + focus_toggle = focus_toggle_spy, + get_active_bufnr = spy.new(function() + return 123 + end), + is_available = spy.new(function() + return true + end), + } + + terminal_wrapper.setup({ provider = custom_provider }) + + -- Test open with parameters + terminal_wrapper.open({ split_side = "left" }, "test_args") + open_spy:was_called() + local open_call = open_spy:get_call(1) + assert.is_string(open_call.refs[1]) -- cmd_string + assert.is_table(open_call.refs[2]) -- env_table + assert.is_table(open_call.refs[3]) -- effective_config + + -- Test simple_toggle with parameters + terminal_wrapper.simple_toggle({ split_width_percentage = 0.4 }, "toggle_args") + simple_toggle_spy:was_called() + local toggle_call = simple_toggle_spy:get_call(1) + assert.is_string(toggle_call.refs[1]) -- cmd_string + assert.is_table(toggle_call.refs[2]) -- env_table + assert.is_table(toggle_call.refs[3]) -- effective_config + end) + end) + + describe("fallback behavior", function() + it("should fallback to native provider when is_available returns false", function() + local custom_provider = { + setup = spy.new(function() end), + open = spy.new(function() end), + close = spy.new(function() end), + simple_toggle = spy.new(function() end), + focus_toggle = spy.new(function() end), + get_active_bufnr = spy.new(function() + return 123 + end), + is_available = spy.new(function() + return false + end), -- Returns false + } + + terminal_wrapper.setup({ provider = custom_provider }) + terminal_wrapper.open() + + -- Should use native provider instead + mock_native_provider.open:was_called() + custom_provider.open:was_not_called() + end) + + it("should fallback to native provider when is_available throws error", function() + local custom_provider = { + setup = spy.new(function() end), + open = spy.new(function() end), + close = spy.new(function() end), + simple_toggle = spy.new(function() end), + focus_toggle = spy.new(function() end), + get_active_bufnr = spy.new(function() + return 123 + end), + is_available = spy.new(function() + error("Availability check failed") + end), + } + + terminal_wrapper.setup({ provider = custom_provider }) + terminal_wrapper.open() + + -- Should use native provider instead + mock_native_provider.open:was_called() + custom_provider.open:was_not_called() + end) + end) + + describe("invalid provider rejection", function() + it("should reject non-table providers", function() + -- Make snacks provider unavailable to force fallback to native + mock_snacks_provider.is_available = spy.new(function() + return false + end) + mock_native_provider.open:reset() -- Reset the spy before the test + + terminal_wrapper.setup({ provider = "invalid_string" }) + + -- Check that vim.notify was called with the expected warning about invalid value + local notify_calls = vim.notify.calls + local found_warning = false + for _, call in ipairs(notify_calls) do + local message = call.refs[1] + if message and message:match("Invalid value for provider.*invalid_string") then + found_warning = true + break + end + end + assert.is_true(found_warning, "Expected warning about invalid provider value") + + terminal_wrapper.open() + + -- Should fallback to native provider (since snacks is unavailable and invalid string was rejected) + mock_native_provider.open:was_called() + end) + + it("should reject providers missing required functions", function() + local incomplete_provider = { + setup = function() end, + open = function() end, + -- Missing other required functions + } + + terminal_wrapper.setup({ provider = incomplete_provider }) + terminal_wrapper.open() + + -- Should fallback to native provider + mock_native_provider.open:was_called() + end) + + it("should reject providers with non-function required fields", function() + local invalid_provider = { + setup = function() end, + open = "not_a_function", -- Invalid type + close = function() end, + simple_toggle = function() end, + focus_toggle = function() end, + get_active_bufnr = function() + return 123 + end, + is_available = function() + return true + end, + } + + terminal_wrapper.setup({ provider = invalid_provider }) + terminal_wrapper.open() + + -- Should fallback to native provider + mock_native_provider.open:was_called() + end) + end) + + describe("wrapper function invocations", function() + it("should properly invoke all wrapper functions with custom provider", function() + local custom_provider = { + setup = spy.new(function() end), + open = spy.new(function() end), + close = spy.new(function() end), + simple_toggle = spy.new(function() end), + focus_toggle = spy.new(function() end), + get_active_bufnr = spy.new(function() + return 456 + end), + is_available = spy.new(function() + return true + end), + } + + terminal_wrapper.setup({ provider = custom_provider }) + + -- Test all wrapper functions + terminal_wrapper.open() + custom_provider.open:was_called() + + terminal_wrapper.close() + custom_provider.close:was_called() + + terminal_wrapper.simple_toggle() + custom_provider.simple_toggle:was_called() + + terminal_wrapper.focus_toggle() + custom_provider.focus_toggle:was_called() + + local bufnr = terminal_wrapper.get_active_terminal_bufnr() + custom_provider.get_active_bufnr:was_called() + assert.are.equal(456, bufnr) + end) + + it("should handle toggle function (legacy) correctly", function() + local simple_toggle_spy = spy.new(function() end) + local custom_provider = { + setup = spy.new(function() end), + open = spy.new(function() end), + close = spy.new(function() end), + simple_toggle = simple_toggle_spy, + focus_toggle = spy.new(function() end), + get_active_bufnr = spy.new(function() + return 123 + end), + is_available = spy.new(function() + return true + end), + } + + terminal_wrapper.setup({ provider = custom_provider }) + + -- Legacy toggle should call simple_toggle + terminal_wrapper.toggle() + simple_toggle_spy:was_called() + end) + end) + end) + + describe("custom provider validation", function() + it("should reject provider missing required functions", function() + local invalid_provider = { setup = function() end } -- missing other functions + terminal_wrapper.setup({ provider = invalid_provider }) + terminal_wrapper.open() + + -- Verify fallback to native provider + mock_native_provider.open:was_called() + -- Check that the warning was logged (vim.notify gets called with logger output) + local notify_calls = vim.notify.calls + local found_warning = false + for _, call in ipairs(notify_calls) do + local message = call.refs[1] + if message and message:match("Invalid custom table provider.*missing required function") then + found_warning = true + break + end + end + assert.is_true(found_warning, "Expected warning about missing required function") + end) + + it("should handle provider availability check failures", function() + local provider_with_error = { + setup = function() end, + open = function() end, + close = function() end, + simple_toggle = function() end, + focus_toggle = function() end, + get_active_bufnr = function() + return 123 + end, + is_available = function() + error("test error") + end, + } + + vim.notify:reset() + terminal_wrapper.setup({ provider = provider_with_error }) + terminal_wrapper.open() + + -- Verify graceful fallback to native provider + mock_native_provider.open:was_called() + + -- Check that the warning was logged about availability error + local notify_calls = vim.notify.calls + local found_warning = false + for _, call in ipairs(notify_calls) do + local message = call.refs[1] + if message and message:match("error checking availability") then + found_warning = true + break + end + end + assert.is_true(found_warning, "Expected warning about availability check error") + end) + + it("should validate provider function types", function() + local invalid_provider = { + setup = function() end, + open = "not_a_function", -- Wrong type + close = function() end, + simple_toggle = function() end, + focus_toggle = function() end, + get_active_bufnr = function() + return 123 + end, + is_available = function() + return true + end, + } + + vim.notify:reset() + terminal_wrapper.setup({ provider = invalid_provider }) + terminal_wrapper.open() + + -- Should fallback to native provider + mock_native_provider.open:was_called() + + -- Check for function type validation error + local notify_calls = vim.notify.calls + local found_error = false + for _, call in ipairs(notify_calls) do + local message = call.refs[1] + if message and message:match("must be callable.*got.*string") then + found_error = true + break + end + end + assert.is_true(found_error, "Expected error about function type validation") + end) + + it("should verify fallback on availability check failure", function() + local provider_unavailable = { + setup = function() end, + open = function() end, + close = function() end, + simple_toggle = function() end, + focus_toggle = function() end, + get_active_bufnr = function() + return 123 + end, + is_available = function() + return false + end, -- Provider says it's not available + } + + vim.notify:reset() + terminal_wrapper.setup({ provider = provider_unavailable }) + terminal_wrapper.open() + + -- Should use native provider + mock_native_provider.open:was_called() + + -- Check for availability warning + local notify_calls = vim.notify.calls + local found_warning = false + for _, call in ipairs(notify_calls) do + local message = call.refs[1] + if message and message:match("provider reports not available") then + found_warning = true + break + end + end + assert.is_true(found_warning, "Expected warning about provider not available") + end) + + it("should test auto-generated optional functions with working provider", function() + local simple_toggle_called = false + local provider_minimal = { + setup = function() end, + open = function() end, + close = function() end, + simple_toggle = function() + simple_toggle_called = true + end, + focus_toggle = function() end, + get_active_bufnr = function() + return 123 + end, + is_available = function() + return true + end, + -- Missing toggle and _get_terminal_for_test functions + } + + terminal_wrapper.setup({ provider = provider_minimal }) + + -- Test auto-generated toggle function + assert.is_function(provider_minimal.toggle) + provider_minimal.toggle("test_cmd", { TEST = "env" }, { split_side = "left" }) + assert.is_true(simple_toggle_called) + + -- Test auto-generated _get_terminal_for_test function + assert.is_function(provider_minimal._get_terminal_for_test) + assert.is_nil(provider_minimal._get_terminal_for_test()) + end) + + it("should handle edge case where provider returns nil for required function", function() + local provider_with_nil_function = { + setup = function() end, + open = function() end, + close = nil, -- Explicitly nil instead of missing + simple_toggle = function() end, + focus_toggle = function() end, + get_active_bufnr = function() + return 123 + end, + is_available = function() + return true + end, + } + + vim.notify:reset() + terminal_wrapper.setup({ provider = provider_with_nil_function }) + terminal_wrapper.open() + + -- Should fallback to native provider + mock_native_provider.open:was_called() + + -- Check for missing function error + local notify_calls = vim.notify.calls + local found_error = false + for _, call in ipairs(notify_calls) do + local message = call.refs[1] + if message and message:match("missing required function.*close") then + found_error = true + break + end + end + assert.is_true(found_error, "Expected error about missing close function") + end) + end) end)