40 changes: 40 additions & 0 deletions llama_cpp/llama.py
@@ -70,6 +70,7 @@ def __init__(
        use_mmap: bool = True,
        use_mlock: bool = False,
        kv_overrides: Optional[Dict[str, Union[bool, int, float, str]]] = None,
        tensor_buft_overrides: Optional[List[str]] = None,
        # Context Params
        seed: int = llama_cpp.LLAMA_DEFAULT_SEED,
        n_ctx: int = 512,
@@ -299,6 +300,45 @@ def __init__(
            ].key = b"\0"  # ensure sentinel element is zeroed
            self.model_params.kv_overrides = self._kv_overrides_array

        self.tensor_buft_overrides = tensor_buft_overrides
        if tensor_buft_overrides is not None:
            # _tensor_buft_overrides_array is a ctypes.Array of
            # llama_model_tensor_buft_override structs
            tbo_array_len = len(tensor_buft_overrides) + 1  # +1 for the sentinel element
            self._tensor_buft_overrides_array = (
                llama_cpp.llama_model_tensor_buft_override * tbo_array_len
            )()
            # keep the encoded pattern bytes alive for the lifetime of the model
            self._tensor_buft_overrides_patterns: List[bytes] = []

            for i, override in enumerate(tensor_buft_overrides):
                pattern, buft_name = override.split("=", 1)
                pattern_bytes = pattern.encode("utf-8")
                self._tensor_buft_overrides_patterns.append(pattern_bytes)
                self._tensor_buft_overrides_array[i].pattern = pattern_bytes

                if buft_name == "CUDA":
                    try:
                        self._tensor_buft_overrides_array[i].buft = llama_cpp.ggml_backend_cuda_buffer_type(0)
                    except AttributeError:
                        raise ValueError("CUDA backend not supported")
                elif buft_name == "CPU":
                    try:
                        self._tensor_buft_overrides_array[i].buft = llama_cpp.ggml_backend_cpu_buffer_type()
                    except AttributeError:
                        raise ValueError("CPU backend not supported")
                elif buft_name == "Metal":
                    try:
                        self._tensor_buft_overrides_array[i].buft = llama_cpp.ggml_backend_metal_buffer_type()
                    except AttributeError:
                        raise ValueError("Metal backend not supported")
                else:
                    # TODO: Support other backends
                    raise ValueError(f"Unknown buffer type: {buft_name}")

            # NULL pattern marks the end of the override list for llama.cpp
            self._tensor_buft_overrides_array[-1].pattern = None
            self.model_params.tensor_buft_overrides = self._tensor_buft_overrides_array

        self.n_batch = min(n_ctx, n_batch)  # ???
        self.n_threads = n_threads or max(multiprocessing.cpu_count() // 2, 1)
        self.n_threads_batch = n_threads_batch or multiprocessing.cpu_count()
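For context, a minimal usage sketch of the new constructor parameter. The model path and the regex below are hypothetical; each entry uses the `pattern=BUFFER_TYPE` form parsed above (similar to llama.cpp's `--override-tensor` option), and the buffer type must currently be one of `CPU`, `CUDA`, or `Metal`:

```python
from llama_cpp import Llama

# Hypothetical example: keep tensors whose names match the regex in host (CPU)
# memory while the rest of the model uses the default buffers.
llm = Llama(
    model_path="./models/example.gguf",  # placeholder path
    n_gpu_layers=-1,
    tensor_buft_overrides=[r"blk\..*\.ffn_.*=CPU"],
)
```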
63 changes: 57 additions & 6 deletions llama_cpp/llama_cpp.py
@@ -165,6 +165,9 @@
llama_memory_t = NewType("llama_memory_t", int)
llama_memory_t_ctypes = ctypes.c_void_p

# typedef struct ggml_backend_buffer_type * ggml_backend_buffer_type_t;
ggml_backend_buffer_type_t = ctypes.c_void_p

# struct llama_kv_cache; (DEPRECATED)
llama_kv_cache_p = NewType("llama_kv_cache_p", int)
llama_kv_cache_p_ctypes = ctypes.c_void_p
@@ -655,10 +658,15 @@ class llama_model_kv_override(ctypes.Structure):
        value: Union[int, float, bool, bytes]


# struct llama_model_tensor_buft_override {
#     const char * pattern;
#     ggml_backend_buffer_type_t buft;
# };
class llama_model_tensor_buft_override(ctypes.Structure):
    _fields_ = [
        ("pattern", ctypes.c_char_p),
        ("buft", ggml_backend_buffer_type_t),
    ]

    if TYPE_CHECKING:
        pattern: bytes
        buft: ggml_backend_buffer_type_t


# struct llama_model_params {
@@ -716,7 +724,7 @@ class llama_model_params(ctypes.Structure):

    if TYPE_CHECKING:
        devices: CtypesArray[ctypes.c_void_p]  # NOTE: unused
        tensor_buft_overrides: CtypesArray[llama_model_tensor_buft_override]  # NOTE: unused
        tensor_buft_overrides: CtypesArray[llama_model_tensor_buft_override]
        n_gpu_layers: int
        split_mode: int
        main_gpu: int
@@ -732,7 +740,7 @@ class llama_model_params(ctypes.Structure):

    _fields_ = [
        ("devices", ctypes.c_void_p),  # NOTE: unnused
        ("tensor_buft_overrides", ctypes.c_void_p),  # NOTE: unused
        ("tensor_buft_overrides", ctypes.POINTER(llama_model_tensor_buft_override)),
        ("n_gpu_layers", ctypes.c_int32),
        ("split_mode", ctypes.c_int),
        ("main_gpu", ctypes.c_int32),
@@ -4311,6 +4319,49 @@ def llama_opt_param_filter_all(tensor: ctypes.c_void_p, userdata: ctypes.c_void_
    ...


# //
# // Backend utils
# //

# // ggml_backend_buffer_type_t ggml_backend_cpu_buffer_type(void);
@ctypes_function(
    "ggml_backend_cpu_buffer_type",
    [],
    ggml_backend_buffer_type_t,
)
def ggml_backend_cpu_buffer_type() -> ggml_backend_buffer_type_t:
    """Get the CPU buffer type"""
    ...


# // ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device);
try:
    @ctypes_function(
        "ggml_backend_cuda_buffer_type",
        [ctypes.c_int],
        ggml_backend_buffer_type_t,
    )
    def ggml_backend_cuda_buffer_type(device: int, /) -> ggml_backend_buffer_type_t:
        """Get the CUDA buffer type"""
        ...
except AttributeError:
    pass


# // ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void);
try:
    @ctypes_function(
        "ggml_backend_metal_buffer_type",
        [],
        ggml_backend_buffer_type_t,
    )
    def ggml_backend_metal_buffer_type() -> ggml_backend_buffer_type_t:
        """Get the Metal buffer type"""
        ...
except AttributeError:
    pass


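For illustration, a rough sketch of how the new struct and these backend helpers could be combined to build a sentinel-terminated override array by hand, outside the `Llama` constructor. The tensor-name pattern is hypothetical, and this assumes a build where the CPU backend symbol is available:

```python
import llama_cpp

# One real override plus a NULL-pattern sentinel marking the end of the array.
overrides = (llama_cpp.llama_model_tensor_buft_override * 2)()
overrides[0].pattern = rb"blk\..*\.ffn_.*"  # hypothetical tensor-name regex
overrides[0].buft = llama_cpp.ggml_backend_cpu_buffer_type()  # keep matches in host memory
overrides[1].pattern = None  # NULL pattern terminates the list

params = llama_cpp.llama_model_default_params()
params.tensor_buft_overrides = overrides  # ctypes converts the array to a pointer
```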
# struct llama_opt_params {
# uint32_t n_ctx_train; // assumed context size post training, use context size specified in llama_context if 0
