Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 39 additions & 29 deletions padiff/abstracts/hooks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,42 +60,52 @@ def _traversal(node, bucket):

class _CallsContext:
"""
A global context for managing forward call counts across multiple PaDiffGuard invocations.
This ensures that max_calls is respected even when PaDiffGuard is re-entered.
A context for managing forward call counts for PaDiffGuard invocations.
Each model instance will have its own independent call count state.
Different files always use independent counters.
"""

_state = contextvars.ContextVar("_calls_context_state", default=None)

def __init__(self):
self._state.set({"count": 0, "limit": 0, "active": False})

@property
def state(self) -> Dict:
s = self._state.get()
if s is None:
s = {"count": 0, "limit": 0, "active": False}
self._state.set(s)
return s

def set_limit(self, limit: int):
self.state["limit"] = limit
self.state["active"] = True

def increment(self) -> int:
if not self.state["active"]:
self._model_states = {} # model_id -> state dict

def _get_model_id(self, model):
"""Get a unique identifier for the model in current context"""
# Simple object identity is sufficient since we don't need cross-file compatibility
return str(id(model))

def get_state(self, model) -> Dict:
"""Get the state for a specific model"""
model_id = self._get_model_id(model)
if model_id not in self._model_states:
self._model_states[model_id] = {"count": 0, "limit": 0, "active": False}
return self._model_states[model_id]

def set_limit(self, model, limit: int):
"""Set the call limit for a specific model"""
state = self.get_state(model)
state["limit"] = limit
state["active"] = True

def increment(self, model) -> int:
"""Increment the call count for a specific model"""
state = self.get_state(model)
if not state["active"]:
return 0
self.state["count"] += 1
return self.state["count"]
state["count"] += 1
return state["count"]

def is_exceeded(self) -> bool:
if not self.state["active"]:
def is_exceeded(self, model) -> bool:
"""Check if the call limit is exceeded for a specific model"""
state = self.get_state(model)
if not state["active"]:
return False
return self.state["count"] >= self.state["limit"]
return state["count"] >= state["limit"]

def reset(self):
self.state["count"] = 0
self.state["limit"] = 0
self.state["active"] = False
def reset(self, model):
"""Reset the state for a specific model"""
model_id = self._get_model_id(model)
if model_id in self._model_states:
del self._model_states[model_id]

@classmethod
def get_current(cls) -> "_CallsContext":
Expand Down
13 changes: 7 additions & 6 deletions padiff/abstracts/hooks/guard.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,10 +211,10 @@ def MaxCallsGuard(max_calls: int, model):
calls_context = get_calls_context()

def pre_hook(m, input):
if calls_context.is_exceeded():
if calls_context.is_exceeded(model):
logger.warning(f"PaDiffGuard: max_calls={max_calls} reached, raising _CallsComplete")
raise _CallsComplete()
count = calls_context.increment()
count = calls_context.increment(model)
logger.info(f"MaxCallsGuard: forward start calling #{count}")

handle = model.register_forward_pre_hook(pre_hook)
Expand All @@ -240,13 +240,14 @@ def PaDiffGuard(
black_list=None,
keys_mapping=None,
):
# moniter number of calls
# get the global calls context
calls_context = get_calls_context()
reset_flag = calls_context.state["count"] == 0
# check if this is the first call for this specific model
reset_flag = calls_context.get_state(model)["count"] == 0

if reset_flag:
# set max calls
calls_context.set_limit(max_calls)
# set max calls for this model
calls_context.set_limit(model, max_calls)

proxy_model = create_model(model, name=name, reset_dir=reset_flag)
model._padiff_proxy = proxy_model
Expand Down
Loading