-
Notifications
You must be signed in to change notification settings - Fork 88
default pattern search #1259
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
default pattern search #1259
Conversation
jansel
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fix lints
| from ..runtime.kernel import BoundKernel | ||
|
|
||
|
|
||
| class DefaultPatternSearch(PatternSearch): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
| class DefaultPatternSearch(PatternSearch): | |
| class QuickPatternSearch(PatternSearch): |
I worry people won't understand Default means the default config.
| def _autotune(self) -> Config: | ||
| self.log( | ||
| f"Starting PatternSearch with initial_population={self.initial_population}, copies={self.copies}, max_generations={self.max_generations}" | ||
| ) | ||
| visited = set() | ||
| self.population = [] | ||
| for flat_config in [self.config_gen.default_flat()]: | ||
| member = self.make_unbenchmarked(flat_config) | ||
| if member.config not in visited: | ||
| visited.add(member.config) | ||
| self.population.append(member) | ||
| self.set_generation(0) | ||
| self.parallel_benchmark_population(self.population, desc="Initial population") | ||
| # again with higher accuracy | ||
| self.rebenchmark_population(self.population, desc="Verifying initial results") | ||
| self.population.sort(key=performance) | ||
| starting_points = [] | ||
| for member in self.population[: self.copies]: | ||
| if math.isfinite(member.perf): # filter failed compiles | ||
| starting_points.append(member) | ||
| self.log( | ||
| f"Initial random population of {len(self.population)}, {len(starting_points)} starting points:", | ||
| self.statistics, | ||
| ) | ||
| if not starting_points: | ||
| raise exc.NoConfigFound | ||
|
|
||
| search_copies = [self._pattern_search_from(m, visited) for m in starting_points] | ||
| for generation in range(1, self.max_generations + 1): | ||
| prior_best = self.best | ||
| new_population = {id(prior_best): prior_best} | ||
| num_neighbors = 0 | ||
| num_active = 0 | ||
| for search_copy in search_copies: | ||
| added = next(search_copy, ()) | ||
| if added: | ||
| assert len(added) > 1 | ||
| num_active += 1 | ||
| num_neighbors += len(added) - 1 | ||
| for member in added: | ||
| new_population[id(member)] = member | ||
| if num_active == 0: | ||
| break | ||
|
|
||
| # Log generation header before compiling/benchmarking | ||
| self.log( | ||
| f"Generation {generation} starting: {num_neighbors} neighbors, {num_active} active search path(s)" | ||
| ) | ||
|
|
||
| self.population = [*new_population.values()] | ||
| # compile any unbenchmarked members in parallel | ||
| unbenchmarked = [m for m in self.population if len(m.perfs) == 0] | ||
| if unbenchmarked: | ||
| self.set_generation(generation) | ||
| self.parallel_benchmark_population( | ||
| unbenchmarked, desc=f"Generation {generation}:" | ||
| ) | ||
| # higher-accuracy rebenchmark | ||
| self.rebenchmark_population( | ||
| self.population, desc=f"Generation {generation}: verifying top configs" | ||
| ) | ||
| # Log final statistics for this generation | ||
| self.log(f"Generation {generation} complete:", self.statistics) | ||
| return self.best.config |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we refactor things so this shares more code with the base class?
No description provided.