Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 41 additions & 29 deletions src/vectorcode/cli_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ async def import_from(cls, config_dict: dict[str, Any]) -> "Config":
default_config = Config()
db_path = config_dict.get("db_path")

expand_envs_in_dict(config_dict)
if db_path is None:
db_path = os.path.expanduser("~/.local/share/vectorcode/chromadb/")
elif not os.path.isdir(db_path):
Expand Down Expand Up @@ -470,7 +471,7 @@ async def parse_cli_args(args: Optional[Sequence[str]] = None):


def expand_envs_in_dict(d: dict):
if not isinstance(d, dict):
if not isinstance(d, dict): # pragma: nocover
return
stack = [d]
while stack:
Expand All @@ -482,31 +483,43 @@ def expand_envs_in_dict(d: dict):
stack.append(curr[k])


async def load_config_file(path: Optional[Union[str, Path]] = None):
"""Load config file from ~/.config/vectorcode/config.json(5)"""
if path is None:
for name in ("config.json5", "config.json"):
p = os.path.join(GLOBAL_CONFIG_DIR, name)
if os.path.isfile(p):
path = str(p)
break
if path and os.path.isfile(path):
logger.debug(f"Loading config from {path}")
with open(path) as fin:
content = fin.read()
if content:
config = json5.loads(content)
if isinstance(config, dict):
expand_envs_in_dict(config)
return await Config.import_from(config)
else:
logger.error("Invalid configuration format!")
raise ValueError("Invalid configuration format!")
else:
logger.debug("Skipping empty json file.")
else:
logger.warning("Loading default config.")
return Config()
async def load_config_file(path: str | Path | None = None) -> Config:
"""
Load config object by merging the project-local and the global config files.
`path` can be a _file path_ or a _project-root_ path.

Raises `ValueError` if the config file is not a valid json dictionary.
"""
valid_config_paths = []
# default to load from the global config
for name in ("config.json5", "config.json"):
p = os.path.join(GLOBAL_CONFIG_DIR, name)
if os.path.isfile(p):
valid_config_paths.append(str(p))
break

if path:
if os.path.isfile((path)):
valid_config_paths.append(path)
elif os.path.isdir(path):
for name in ("config.json5", "config.json"):
p = os.path.join(path, ".vectorcode", name)
if os.path.isfile(p):
valid_config_paths.append(str(p))
break

final_config = Config()

for p in valid_config_paths:
with open(p) as fin:
content = json5.load(fin)
logger.info(f"Loaded config from {p}")
if not isinstance(content, dict):
raise ValueError("Invalid configuration format!")
final_config = await final_config.merge_from(await Config.import_from(content))
logger.debug(f"Merged config: {final_config}")

return final_config


async def find_project_config_dir(start_from: Union[str, Path] = "."):
Expand Down Expand Up @@ -543,13 +556,12 @@ def find_project_root(
start_from = start_from.parent


async def get_project_config(project_root: Union[str, Path]) -> Config:
async def get_project_config(project_root: str | Path) -> Config:
"""
Load config file for `project_root`.
Fallback to global config, and then default config.
"""
if not os.path.isabs(project_root):
project_root = os.path.abspath(project_root)
project_root = os.path.abspath(os.path.expanduser(project_root))
exts = ("json5", "json")
config = None
for ext in exts:
Expand Down
4 changes: 2 additions & 2 deletions src/vectorcode/lsp_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
expand_globs,
expand_path,
find_project_root,
get_project_config,
load_config_file,
parse_cli_args,
)
from vectorcode.common import ClientManager, get_collection, list_collection_files
Expand Down Expand Up @@ -113,7 +113,7 @@ async def execute_command(ls: LanguageServer, args: list[str]):
parsed_args.project_root = os.path.abspath(str(parsed_args.project_root))

final_configs = await (
await get_project_config(parsed_args.project_root)
await load_config_file(parsed_args.project_root)
).merge_from(parsed_args)
final_configs.pipe = True
else:
Expand Down
6 changes: 3 additions & 3 deletions src/vectorcode/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
CliAction,
config_logging,
find_project_root,
get_project_config,
load_config_file,
parse_cli_args,
)
from vectorcode.common import ClientManager
Expand All @@ -24,7 +24,7 @@ async def async_main():
if cli_args.no_stderr:
sys.stderr = open(os.devnull, "w")

if cli_args.debug:
if cli_args.debug: # pragma: nocover
from vectorcode import debugging

debugging.enable()
Expand All @@ -43,7 +43,7 @@ async def async_main():

try:
final_configs = await (
await get_project_config(cli_args.project_root)
await load_config_file(cli_args.project_root)
).merge_from(cli_args)
except IOError as e:
traceback.print_exception(e, file=sys.stderr)
Expand Down
2 changes: 0 additions & 2 deletions src/vectorcode/subcommands/init.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,6 @@ async def init(configs: Config) -> int:
else:
os.makedirs(project_config_dir, exist_ok=True)
for item in (
"config.json5",
"config.json",
"vectorcode.include",
"vectorcode.exclude",
):
Expand Down
6 changes: 5 additions & 1 deletion tests/subcommands/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,11 @@ async def test_init_copies_global_config(capsys):

# Assert files were copied
assert return_code == 0
assert copyfile_mock.call_count == len(config_items)
assert copyfile_mock.call_count == sum(
# not copying `json`s.
"json" not in i
for i in config_items.keys()
)

# Check output messages
captured = capsys.readouterr()
Expand Down
49 changes: 48 additions & 1 deletion tests/test_cli_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,52 @@ async def test_load_config_file_invalid_json():
await load_config_file(config_path)


@pytest.mark.asyncio
async def test_load_config_file_merging():
with tempfile.TemporaryDirectory() as dummy_home:
global_config_dir = os.path.join(dummy_home, ".config", "vectorcode")
os.makedirs(global_config_dir, exist_ok=True)
with open(os.path.join(global_config_dir, "config.json"), mode="w") as fin:
fin.writelines(['{"embedding_function": "DummyEmbeddingFunction"}'])

with tempfile.TemporaryDirectory(dir=dummy_home) as proj_root:
os.makedirs(os.path.join(proj_root, ".vectorcode"), exist_ok=True)
with open(
os.path.join(proj_root, ".vectorcode", "config.json"), mode="w"
) as fin:
fin.writelines(
['{"embedding_function": "AnotherDummyEmbeddingFunction"}']
)

with patch(
"vectorcode.cli_utils.GLOBAL_CONFIG_DIR", new=str(global_config_dir)
):
assert (
await load_config_file()
).embedding_function == "DummyEmbeddingFunction"
assert (
await load_config_file(proj_root)
).embedding_function == "AnotherDummyEmbeddingFunction"


@pytest.mark.asyncio
async def test_load_config_file_with_envs():
with tempfile.TemporaryDirectory() as proj_root:
os.makedirs(os.path.join(proj_root, ".vectorcode"), exist_ok=True)
with (
open(
os.path.join(proj_root, ".vectorcode", "config.json"), mode="w"
) as fin,
):
fin.writelines(['{"embedding_function": "$DUMMY_EMBEDDING_FUNCTION"}'])
with patch.dict(
os.environ, {"DUMMY_EMBEDDING_FUNCTION": "DummyEmbeddingFunction"}
):
assert (
await load_config_file(proj_root)
).embedding_function == "DummyEmbeddingFunction"


@pytest.mark.asyncio
async def test_load_from_default_config():
for name in ("config.json5", "config.json"):
Expand Down Expand Up @@ -261,7 +307,8 @@ async def test_load_config_file_empty_file():
with open(config_path, "w") as f:
f.write("")

assert await load_config_file(config_path) == Config()
with pytest.raises(ValueError):
await load_config_file(config_path)


@pytest.mark.asyncio
Expand Down
24 changes: 12 additions & 12 deletions tests/test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ async def test_async_main_ioerror(monkeypatch):
"vectorcode.main.parse_cli_args", AsyncMock(return_value=mock_cli_args)
)
monkeypatch.setattr(
"vectorcode.main.get_project_config",
"vectorcode.main.load_config_file",
AsyncMock(side_effect=IOError("Test Error")),
)

Expand All @@ -62,7 +62,7 @@ async def test_async_main_cli_action_check(monkeypatch):
mock_check = AsyncMock(return_value=0)
monkeypatch.setattr("vectorcode.subcommands.check", mock_check)
monkeypatch.setattr(
"vectorcode.main.get_project_config",
"vectorcode.main.load_config_file",
AsyncMock(return_value=MagicMock(merge_from=AsyncMock())),
)

Expand All @@ -79,7 +79,7 @@ async def test_async_main_cli_action_init(monkeypatch):
)
mock_init = AsyncMock(return_value=0)
monkeypatch.setattr("vectorcode.subcommands.init", mock_init)
monkeypatch.setattr("vectorcode.main.get_project_config", AsyncMock())
monkeypatch.setattr("vectorcode.main.load_config_file", AsyncMock())

return_code = await async_main()
assert return_code == 0
Expand All @@ -95,7 +95,7 @@ async def test_async_main_cli_action_chunks(monkeypatch):
mock_chunks = AsyncMock(return_value=0)
monkeypatch.setattr("vectorcode.subcommands.chunks", mock_chunks)
monkeypatch.setattr(
"vectorcode.main.get_project_config", AsyncMock(return_value=Config())
"vectorcode.main.load_config_file", AsyncMock(return_value=Config())
)
monkeypatch.setattr("vectorcode.common.try_server", AsyncMock(return_value=True))

Expand Down Expand Up @@ -126,7 +126,7 @@ async def test_async_main_cli_action_prompts(monkeypatch):
mock_prompts = MagicMock(return_value=0)
monkeypatch.setattr("vectorcode.subcommands.prompts", mock_prompts)
monkeypatch.setattr(
"vectorcode.main.get_project_config", AsyncMock(return_value=Config())
"vectorcode.main.load_config_file", AsyncMock(return_value=Config())
)

return_code = await async_main()
Expand All @@ -144,7 +144,7 @@ async def test_async_main_cli_action_query(monkeypatch):
db_url="http://test_host:1234", action=CliAction.query, pipe=False
)
monkeypatch.setattr(
"vectorcode.main.get_project_config",
"vectorcode.main.load_config_file",
AsyncMock(
return_value=AsyncMock(
merge_from=AsyncMock(return_value=mock_final_configs)
Expand Down Expand Up @@ -175,7 +175,7 @@ async def test_async_main_cli_action_vectorise(monkeypatch):
db_url="http://test_host:1234", action=CliAction.vectorise, include_hidden=True
)
monkeypatch.setattr(
"vectorcode.main.get_project_config",
"vectorcode.main.load_config_file",
AsyncMock(
return_value=AsyncMock(
merge_from=AsyncMock(return_value=mock_final_configs)
Expand All @@ -199,7 +199,7 @@ async def test_async_main_cli_action_drop(monkeypatch):
)
mock_final_configs = Config(db_url="http://test_host:1234", action=CliAction.drop)
monkeypatch.setattr(
"vectorcode.main.get_project_config",
"vectorcode.main.load_config_file",
AsyncMock(
return_value=AsyncMock(
merge_from=AsyncMock(return_value=mock_final_configs)
Expand All @@ -223,7 +223,7 @@ async def test_async_main_cli_action_ls(monkeypatch):
)
mock_final_configs = Config(db_url="http://test_host:1234", action=CliAction.ls)
monkeypatch.setattr(
"vectorcode.main.get_project_config",
"vectorcode.main.load_config_file",
AsyncMock(
return_value=AsyncMock(
merge_from=AsyncMock(return_value=mock_final_configs)
Expand Down Expand Up @@ -259,7 +259,7 @@ async def test_async_main_cli_action_update(monkeypatch):
)
mock_final_configs = Config(db_url="http://test_host:1234", action=CliAction.update)
monkeypatch.setattr(
"vectorcode.main.get_project_config",
"vectorcode.main.load_config_file",
AsyncMock(
return_value=AsyncMock(
merge_from=AsyncMock(return_value=mock_final_configs)
Expand All @@ -283,7 +283,7 @@ async def test_async_main_cli_action_clean(monkeypatch):
)
mock_final_configs = Config(db_url="http://test_host:1234", action=CliAction.clean)
monkeypatch.setattr(
"vectorcode.main.get_project_config",
"vectorcode.main.load_config_file",
AsyncMock(
return_value=AsyncMock(
merge_from=AsyncMock(return_value=mock_final_configs)
Expand All @@ -307,7 +307,7 @@ async def test_async_main_exception_handling(monkeypatch):
)
mock_final_configs = Config(db_url="http://test_host:1234", action=CliAction.query)
monkeypatch.setattr(
"vectorcode.main.get_project_config",
"vectorcode.main.load_config_file",
AsyncMock(
return_value=AsyncMock(
merge_from=AsyncMock(return_value=mock_final_configs)
Expand Down
Loading