"""
@file program_synthesis_search.py
@brief Staged Genetic Programming Search for Register Swap Algorithm Synthesis
@version 2.4
@author Grok (built by xAI)
@date November 07, 2025
@details
This program implements a staged genetic programming (GP) search to evolve a small
program (as an Abstract Syntax Tree, AST) that swaps the values in two registers
(R0 and R1) using a third scratchpad register (R2). The search uses a beam search
with mutation and crossover operators, guided by a curriculum of fitness stages.
Key Features:
- Lisp-style prefix notation for program representation (S-expressions).
- VM interpreter for side-effectful execution (memory modifications via 'set').
- Staged heuristics: First stage rewards partial correctness (one register right),
second requires full swap.
- Post-evolution simplification using conservative higher-order logic pruning: recursive
rewrite rules for math identities (+/- with 0) on pure subtrees only (no side-effects or
data dependencies), progn flattening and single-child removal (preserving side-effects).
No intra-search simplification to avoid disrupting evolution; only post-search.
The evolved program is output in full Lisp syntax, suitable for direct use in a
Lisp environment (with equivalents for 'set', 'get', 'progn', etc.).
Usage:
python program_synthesis_search.py
@license MIT (or equivalent open-source)
"""
import math
import heapq
import copy
import random
from dataclasses import dataclass, field
from typing import Callable, Optional, List, Tuple, Dict, Any
# Optional: For reproducibility
# random.seed(42)
# ---
# ## 1. The "World" (Memory)
# ---
# Number of memory cells available to evolved programs; the bounds checks in
# the memory read/write operators use this value.
MEMORY_SIZE = 3 # R0, R1, R2 (In_0, In_1, Scratchpad)
# ---
# ## 2. The "Gene Pool" and "Chromosome" (AST Node)
# ---
@dataclass(frozen=True)
class Gene:
    """Immutable description of one GP primitive (operator or terminal)."""

    # Symbol used when a program tree is printed.
    name: str
    # Number of child subtrees the primitive takes; 0 marks a terminal.
    arity: int
    # Amount one occurrence adds to a program's complexity score.
    complexity_cost: float
    # Implementation, invoked as function(memory, left_value, right_value).
    function: Callable[[List[float], Any, Any], float]
@dataclass
class Node:
    """One vertex of a program tree: a gene plus up to two child subtrees."""

    gene: Gene
    left: Optional['Node'] = None
    right: Optional['Node'] = None

    def __str__(self):
        """Render the subtree in Lisp-style prefix notation."""
        symbol = self.gene.name
        if self.gene.arity == 0:
            # Terminals print as a bare symbol, without parentheses.
            return symbol
        rendered_children = [
            str(child) for child in (self.left, self.right) if child
        ]
        return "(" + " ".join([symbol] + rendered_children) + ")"
# ---
# ## 3. The "Gene Pool" Definition (Instruction Set)
# ---
def op_progn(mem: List[float], left_res: float, right_res: float) -> float:
    """Return the value of the second expression of a two-step sequence.

    Memory and the first result are ignored. The interpreter in this file
    handles 'progn' nodes itself, so this is only the gene's placeholder
    callable.
    """
    sequence_value = right_res
    return sequence_value
def op_set(mem: List[float], idx: float, val: float) -> float:
    """Store ``val`` in the cell addressed by ``idx`` and return ``val``.

    The index is rounded to the nearest integer; an index outside
    ``0 .. MEMORY_SIZE - 1`` leaves memory unchanged.
    """
    slot = int(round(idx))
    in_range = 0 <= slot < MEMORY_SIZE
    if in_range:
        mem[slot] = val
    return val
def op_get(mem: List[float], idx: float, _) -> float:
    """Read the cell addressed by ``idx``; out-of-range reads yield 0.0.

    The index is rounded to the nearest integer before the bounds check.
    The third argument is unused.
    """
    slot = int(round(idx))
    return mem[slot] if 0 <= slot < MEMORY_SIZE else 0.0
def op_add(mem: List[float], a: float, b: float) -> float:
    """Return the sum of the two operand values; memory is not touched."""
    total = a + b
    return total
def op_sub(mem: List[float], a: float, b: float) -> float:
    """Return ``a`` minus ``b``; memory is not touched."""
    difference = a - b
    return difference
def term_idx_0(mem, _, __):
    """Terminal: the constant index 0.0 (register R0)."""
    return 0.0


def term_idx_1(mem, _, __):
    """Terminal: the constant index 1.0 (register R1)."""
    return 1.0


def term_idx_2(mem, _, __):
    """Terminal: the constant index 2.0 (register R2)."""
    return 2.0


def term_input_0(mem, _, __):
    """Terminal: the value currently stored in memory cell 0."""
    return mem[0]


def term_input_1(mem, _, __):
    """Terminal: the value currently stored in memory cell 1."""
    return mem[1]


def term_input_2(mem, _, __):
    """Terminal: the value currently stored in memory cell 2."""
    return mem[2]


def term_const_0(mem, _, __):
    """Terminal: the literal constant 0.0."""
    return 0.0
# Instruction set available to the search. Entry order is significant for any
# code that indexes this list by position, so add or move genes with care.
GENE_POOL = [
    # Control flow
    Gene(name="progn", arity=2, complexity_cost=0.05, function=op_progn),
    # Memory operations
    Gene(name="set", arity=2, complexity_cost=1.0, function=op_set),
    Gene(name="get", arity=1, complexity_cost=0.5, function=op_get),
    # Math
    Gene(name="+", arity=2, complexity_cost=1.0, function=op_add),
    Gene(name="-", arity=2, complexity_cost=1.0, function=op_sub),
    # Terminals (register indices)
    Gene(name="R0", arity=0, complexity_cost=0.1, function=term_idx_0),
    Gene(name="R1", arity=0, complexity_cost=0.1, function=term_idx_1),
    Gene(name="R2", arity=0, complexity_cost=0.1, function=term_idx_2),
    # Terminals (register contents)
    Gene(name="Val(R0)", arity=0, complexity_cost=0.2, function=term_input_0),
    Gene(name="Val(R1)", arity=0, complexity_cost=0.2, function=term_input_1),
    Gene(name="Val(R2)", arity=0, complexity_cost=0.2, function=term_input_2),
    Gene(name="0.0", arity=0, complexity_cost=0.1, function=term_const_0),
]

# Partition of the pool by arity: leaves versus operators.
TERMINAL_GENES = list(filter(lambda gene: gene.arity == 0, GENE_POOL))
NON_TERMINAL_GENES = list(filter(lambda gene: gene.arity > 0, GENE_POOL))
# ---
# ## 4. The "Interpreter" (VM) and Cost Functions
# ---
def eval_tree(node: 'Node', memory: List[float]) -> float:
    """Execute a program tree against ``memory`` and return its value.

    'progn' nodes run the left child for its effects and return the value of
    the right child (0.0 when there is none). Terminals are called with dummy
    operands. Every other gene receives the values of its children, evaluated
    left then right, with 0.0 standing in for a missing child.

    Returns float('inf') if evaluation raises ValueError, OverflowError,
    IndexError or ZeroDivisionError.
    """
    try:
        gene = node.gene
        if gene.name == 'progn':
            if node.left:
                eval_tree(node.left, memory)
            return eval_tree(node.right, memory) if node.right else 0.0
        if gene.arity == 0:
            return gene.function(memory, 0, 0)
        # Evaluation order matters because children may write to memory.
        operands = [
            eval_tree(child, memory) if child else 0.0
            for child in (node.left, node.right)
        ]
        return gene.function(memory, *operands)
    except (ValueError, OverflowError, IndexError, ZeroDivisionError):
        return float('inf')
def get_complexity_score(node: 'Node') -> float:
    """Sum the ``complexity_cost`` of every gene in the subtree.

    An empty subtree (``None``) scores 0.0.
    """
    if not node:
        return 0.0
    total = node.gene.complexity_cost
    for child in (node.left, node.right):
        if child:
            total += get_complexity_score(child)
    return total
# ---
# ## 5. Tree Simplification (Higher-Order Logic Pruning)
# ---
def has_side_effect_or_dependency(node: 'Node') -> bool:
    """Report whether the subtree writes memory or reads register contents.

    A subtree counts as impure as soon as it contains a 'set', a 'get' or
    one of the 'Val(Rn)' terminals anywhere inside it.
    """
    if not node:
        return False
    if node.gene.name in ('set', 'Val(R0)', 'Val(R1)', 'Val(R2)', 'get'):
        return True
    # Every other gene, 'progn' included, is as impure as its children.
    return any(
        has_side_effect_or_dependency(child)
        for child in (node.left, node.right)
    )
def get_const_value(node: 'Node') -> Optional[float]:
    """Return the value of a subtree that never depends on memory.

    Returns ``None`` when the value cannot be determined statically: an empty
    subtree, a 'Val(Rn)' terminal, 'get', 'set', or arithmetic over such a
    subtree.
    """
    if not node:
        return None
    name = node.gene.name
    if name == '0.0':
        return 0.0
    if name in ('R0', 'R1', 'R2'):
        # Register-index terminals evaluate to their own index digit.
        return float(name[1])
    if name == 'progn':
        # A sequence evaluates to its last expression.
        return get_const_value(node.right)
    if name in ('+', '-'):
        lhs = get_const_value(node.left)
        rhs = get_const_value(node.right)
        if lhs is None or rhs is None:
            return None
        return lhs + rhs if name == '+' else lhs - rhs
    return None
def _gene_named(name: str) -> Gene:
    """Look up a gene in GENE_POOL by symbol rather than by list position."""
    for gene in GENE_POOL:
        if gene.name == name:
            return gene
    raise KeyError(name)


def _collect_progn_children(node: Optional[Node], out: List[Node]) -> None:
    """Append, in execution order, the non-progn statements under ``node``."""
    if not node:
        return
    if node.gene.name == 'progn':
        _collect_progn_children(node.left, out)
        _collect_progn_children(node.right, out)
    else:
        out.append(node)


def _build_progn_chain(statements: List[Node]) -> Node:
    """Rebuild a binary progn tree that runs ``statements`` in order."""
    if len(statements) == 1:
        return statements[0]
    mid = len(statements) // 2
    return Node(
        gene=_gene_named('progn'),
        left=_build_progn_chain(statements[:mid]),
        right=_build_progn_chain(statements[mid:]),
    )


def simplify_tree(node: Node) -> Node:
    """Return a simplified copy of a program tree.

    Rewrites applied bottom-up:

    * nested 'progn' nodes are flattened and rebuilt as a binary chain; a
      progn with a single statement is replaced by that statement;
    * on subtrees without side effects or memory reads, '+' / '-' with a
      zero-valued operand collapses to the other operand, and arithmetic
      that evaluates to zero collapses to the literal 0.0.

    Terminals are returned unchanged. In particular the register index R0 is
    NOT folded into the literal 0.0 even though its value is zero, so that
    output such as ``(set R0 ...)`` keeps its register name.

    Args:
        node: Root of the tree to simplify; ``None`` is returned as-is.

    Returns:
        The root of the simplified tree. The input tree is not modified.
    """
    if not node:
        return node
    if node.gene.arity == 0:
        # Terminals are already minimal; copy so the result never aliases
        # the input node.
        return Node(gene=node.gene)

    # Simplify children first so the rules below see reduced operands.
    left = simplify_tree(node.left) if node.left else None
    right = simplify_tree(node.right) if node.right else None
    new_node = Node(gene=node.gene, left=left, right=right)
    name = new_node.gene.name

    if name == 'progn':
        statements: List[Node] = []
        _collect_progn_children(new_node, statements)
        if not statements:
            return Node(gene=_gene_named('0.0'))
        # A single statement is returned directly by the chain builder.
        return _build_progn_chain(statements)

    # The remaining rules are only sound when nothing in the subtree reads
    # or writes memory.
    if has_side_effect_or_dependency(new_node):
        return new_node

    if name in ('+', '-'):
        l_const = get_const_value(left)
        r_const = get_const_value(right)
        if name == '+' and l_const == 0:
            return right
        if r_const == 0:
            return left

    # Fold pure arithmetic whose value is (numerically) zero.
    const_val = get_const_value(new_node)
    if const_val is not None and abs(const_val) < 1e-6:
        return Node(gene=_gene_named('0.0'))
    return new_node
# ---
# ## 6. Staged Fitness Heuristic
# ---
@dataclass
class FitnessStage:
    """One step of the fitness curriculum used by the staged search."""

    # Label printed in the search progress output.
    name: str
    # The stage is solved once the best candidate's error is at or below this.
    target_error: float
    # Error function called as heuristic_func(program, test_data).
    heuristic_func: Callable[[Node, List[Tuple[List[float], List[float]]]], float]
def h_full_swap(program: Node, test_data: List[Tuple[List[float], List[float]]]) -> float:
    """Total squared error between final memory and the targets.

    For every (inputs, targets) case the program runs on a fresh memory
    pre-loaded with the inputs; each target position within memory then
    contributes its squared difference to the total.
    """
    total_squared_error = 0
    for inputs, targets in test_data:
        memory = [0.0] * MEMORY_SIZE
        # zip against the cell range drops inputs beyond the memory size.
        for slot, value in zip(range(MEMORY_SIZE), inputs):
            memory[slot] = value
        eval_tree(program, memory)
        for slot, wanted in zip(range(MEMORY_SIZE), targets):
            delta = memory[slot] - wanted
            total_squared_error += delta * delta
    return total_squared_error
def h_one_reg_correct(program: Node, test_data: List[Tuple[List[float], List[float]]]) -> float:
    """Squared error that stops charging for registers already correct.

    Per test case: if at least one target register is within 0.01 of its
    target, only errors of 0.01 or more are counted; otherwise every error
    is counted.
    """
    total_squared_error = 0
    for inputs, targets in test_data:
        memory = [0.0] * MEMORY_SIZE
        for slot, value in zip(range(MEMORY_SIZE), inputs):
            memory[slot] = value
        eval_tree(program, memory)
        errors = [
            abs(memory[slot] - wanted)
            for slot, wanted in zip(range(MEMORY_SIZE), targets)
        ]
        one_register_ok = any(err < 0.01 for err in errors)
        counted = [err for err in errors if err >= 0.01] if one_register_ok else errors
        total_squared_error += sum(err * err for err in counted)
    return total_squared_error
# ---
# ## 7. Genetic Operators
# ---
def _find_nodes_by_type(node: 'Node', type_: str) -> List['Node']:
    """Collect nodes of the subtree in pre-order, filtered by kind.

    ``type_`` is 'all', 'leaf' (arity 0) or 'internal' (arity above 0); any
    other value selects nothing.
    """
    if not node:
        return []
    is_leaf = node.gene.arity == 0
    selected = (
        type_ == 'all'
        or (type_ == 'leaf' and is_leaf)
        or (type_ == 'internal' and not is_leaf)
    )
    found = [node] if selected else []
    for child in (node.left, node.right):
        found.extend(_find_nodes_by_type(child, type_))
    return found
def _generate_random_subtree(max_depth: int = 4) -> Node:
    """Grow a random program tree with at most ``max_depth`` operator levels.

    Once the depth budget is used up only terminals are drawn. Above that,
    each operator appears three times in the draw list, which makes operators
    more likely than they would be under a uniform draw from the gene pool.
    """
    if max_depth <= 0:
        return Node(gene=random.choice(TERMINAL_GENES))
    draw_list = NON_TERMINAL_GENES * 3 + TERMINAL_GENES
    picked = random.choice(draw_list)
    subtree = Node(gene=picked)
    # Children are generated left first, then right.
    if picked.arity >= 1:
        subtree.left = _generate_random_subtree(max_depth - 1)
    if picked.arity == 2:
        subtree.right = _generate_random_subtree(max_depth - 1)
    return subtree
def _functional_replace(current: Node, target: Node, replacement: Node) -> Node:
    """Copy a tree, substituting ``replacement`` for the node ``target``.

    ``target`` is matched by object identity, not equality. The input tree is
    left untouched; the replacement and copied leaves are deep copies.
    """
    if current is target:
        return copy.deepcopy(replacement)
    if current.gene.arity == 0:
        return copy.deepcopy(current)
    new_left = (
        _functional_replace(current.left, target, replacement)
        if current.left else None
    )
    new_right = (
        _functional_replace(current.right, target, replacement)
        if current.right else None
    )
    return Node(gene=current.gene, left=new_left, right=new_right)
def op_mutate_subtree(tree: Node) -> Node:
    """Return a copy of ``tree`` with one random node swapped for a new subtree.

    The replaced node is drawn uniformly from all nodes; the new subtree is
    generated with a depth limit of 4. An empty tree is returned unchanged.
    """
    candidates = _find_nodes_by_type(tree, 'all')
    if not candidates:
        return tree
    victim = random.choice(candidates)
    fresh_branch = _generate_random_subtree(max_depth=4)
    return _functional_replace(tree, victim, fresh_branch)
def op_crossover(tree1: Node, tree2: Node) -> Node:
    """Return a copy of ``tree1`` with a random node replaced from ``tree2``.

    One node is drawn uniformly from each parent; the subtree rooted at the
    node from ``tree2`` is copied over the node from ``tree1``. If either
    parent is empty, ``tree1`` is returned unchanged.
    """
    recipient_nodes = _find_nodes_by_type(tree1, 'all')
    donor_nodes = _find_nodes_by_type(tree2, 'all')
    if not (recipient_nodes and donor_nodes):
        return tree1
    cut_point = random.choice(recipient_nodes)
    donated = random.choice(donor_nodes)
    return _functional_replace(tree1, cut_point, donated)
# ---
# ## 8. The Search Engine (Staged Curriculum)
# ---
@dataclass(order=True)
class SearchState:
    """A scored candidate program; instances order by (f, h, g) cost."""

    # Combined cost: weighted complexity plus error.
    f_cost: float
    # Error of the program under the current stage's heuristic.
    h_cost: float
    # Complexity score of the program.
    g_cost: float
    # The candidate itself, excluded from comparisons.
    program: Node = field(compare=False)
def search_program(test_data: List[Tuple[List[float], List[float]]],
                   stages: List[FitnessStage],
                   beam_width: int,
                   max_iterations_per_stage: int,
                   lambda_complexity: float,
                   p_mutate: float) -> Optional[Node]:
    """Run the staged beam search and return the best program found.

    The frontier starts with every terminal gene. For each stage in order,
    the surviving population is re-scored with that stage's heuristic, then
    the search iterates: keep the ``beam_width`` cheapest states, stop if the
    cheapest one meets the stage target, otherwise derive one child per beam
    member (mutation with probability ``p_mutate``, crossover otherwise) and
    continue with beam plus children.

    Candidates whose error is not a finite number (infinite or NaN) are
    discarded. NaN in particular must never enter the frontier: every
    comparison against NaN is false, which would corrupt the cost ordering
    used to select the beam.

    Args:
        test_data: (inputs, targets) cases passed to the stage heuristics.
        stages: Curriculum stages, solved in order; must not be empty.
        beam_width: Number of states kept per iteration.
        max_iterations_per_stage: Iteration budget for each stage.
        lambda_complexity: Weight of the complexity score in the total cost.
        p_mutate: Probability that a child is produced by mutation.

    Returns:
        The lowest-cost program after the last stage is solved, or ``None``
        if the frontier empties or a stage exhausts its iteration budget.
    """
    pq: List[SearchState] = []
    print("--- Initializing Frontier with Terminals ---")
    visited: Dict[str, float] = {}
    h_func = stages[0].heuristic_func
    for gene in TERMINAL_GENES:
        program = Node(gene=gene)
        prog_str = str(program)
        g = get_complexity_score(program)
        h = h_func(program, test_data)
        if not math.isfinite(h):
            continue
        f = (lambda_complexity * g) + h
        state = SearchState(f_cost=f, h_cost=h, g_cost=g, program=program)
        heapq.heappush(pq, state)
        visited[prog_str] = g
        print(f" Added: {prog_str:<8} (f={f:.2f}) [g={g:.1f}, h={h:.1f}]")

    for stage in stages:
        print(f"\n\n--- Starting Stage: '{stage.name}' (Goal: h <= {stage.target_error}) ---")
        visited.clear()
        current_population = pq
        pq = []
        # Re-score the carried-over population under this stage's heuristic.
        for state in current_population:
            h_restaged = stage.heuristic_func(state.program, test_data)
            if not math.isfinite(h_restaged):
                continue
            state.h_cost = h_restaged
            state.f_cost = (lambda_complexity * state.g_cost) + state.h_cost
            heapq.heappush(pq, state)
            visited[str(state.program)] = state.g_cost

        stage_solved = False
        for i in range(max_iterations_per_stage):
            current_beam = heapq.nsmallest(beam_width, pq)
            if not current_beam:
                print("Search failed: Frontier is empty.")
                return None
            best_in_beam = current_beam[0]
            print(f"\n --- Stage '{stage.name}', Iter {i} ---")
            print(f" Best F={best_in_beam.f_cost:.2f} [g={best_in_beam.g_cost:.1f}, h={best_in_beam.h_cost:.1f}] Prog: {best_in_beam.program}")
            if best_in_beam.h_cost <= stage.target_error:
                print(f" --- Stage '{stage.name}' SOLVED! ---")
                stage_solved = True
                pq = current_beam
                break

            # The beam itself always survives into the next generation.
            next_gen_candidates: Dict[str, SearchState] = {
                str(state.program): state for state in current_beam
            }
            for parent_state in current_beam:
                if random.random() < p_mutate:
                    new_child_tree = op_mutate_subtree(parent_state.program)
                else:
                    other_parent_state = random.choice(current_beam)
                    new_child_tree = op_crossover(parent_state.program, other_parent_state.program)
                prog_str = str(new_child_tree)
                g_new = get_complexity_score(new_child_tree)
                # Skip programs already evaluated during this stage.
                if prog_str in visited and visited[prog_str] <= g_new:
                    continue
                h_new = stage.heuristic_func(new_child_tree, test_data)
                if not math.isfinite(h_new):
                    continue
                f_new = (lambda_complexity * g_new) + h_new
                visited[prog_str] = g_new
                if prog_str not in next_gen_candidates or f_new < next_gen_candidates[prog_str].f_cost:
                    next_gen_candidates[prog_str] = SearchState(f_new, h_new, g_new, new_child_tree)
            pq = list(next_gen_candidates.values())
            heapq.heapify(pq)

        if not stage_solved:
            print(f"Search failed: Max iterations reached for stage '{stage.name}'.")
            return None

    print("\n--- All Stages Solved! ---")
    final_solution = heapq.nsmallest(1, pq)[0]
    return final_solution.program
# ---
# ## 9. Main Execution
# ---
if __name__ == "__main__":
    # Script entry point: configure the curriculum, run the search, then
    # simplify and report the result.
    print("--- Starting Staged Algorithmic Search ---")
    print(f"Goal: Find an algorithm to swap R0 and R1 (Memory Size={MEMORY_SIZE})")

    # Each case is (initial [R0, R1], expected [R0, R1] after the swap).
    TEST_DATA = [
        ([5.0, 10.0], [10.0, 5.0]),
        ([1.0, 2.0], [2.0, 1.0]),
        ([-5.0, 3.0], [3.0, -5.0]),
    ]
    STAGES = [
        FitnessStage(
            name="Get One Register Right",
            target_error=90.01,
            heuristic_func=h_one_reg_correct,
        ),
        FitnessStage(
            name="Get Both Registers Right",
            target_error=0.01,
            heuristic_func=h_full_swap,
        ),
    ]
    LAMBDA_COMPLEXITY = 1.0  # Penalize bloat more
    BEAM_WIDTH = 200
    MAX_ITERATIONS_PER_STAGE = 200
    P_MUTATE = 1.0

    solution = search_program(
        test_data=TEST_DATA,
        stages=STAGES,
        beam_width=BEAM_WIDTH,
        max_iterations_per_stage=MAX_ITERATIONS_PER_STAGE,
        lambda_complexity=LAMBDA_COMPLEXITY,
        p_mutate=P_MUTATE,
    )

    if solution:
        # Simplification happens only here, after the search has finished.
        simplified_solution = simplify_tree(solution)
        print(f"\nRaw Final Program: {solution}")
        print(f"Raw Complexity (g): {get_complexity_score(solution):.1f}")
        print(f"\nSimplified Final Program (Lisp Syntax): {simplified_solution}")
        print(f"Simplified Complexity (g): {get_complexity_score(simplified_solution):.1f}")
        print(f"Final Fitness Error (h): {h_full_swap(simplified_solution, TEST_DATA):.4f}")
        print("\n--- Test Results ---")
        for in_list, target_list in TEST_DATA:
            # Fresh memory per case, pre-loaded with the case inputs.
            mem = [0.0] * MEMORY_SIZE
            for i, val in zip(range(MEMORY_SIZE), in_list):
                mem[i] = val
            print(f"Input: {in_list}")
            eval_tree(simplified_solution, mem)
            shown = [round(m, 4) for m in mem[:len(target_list)]]
            print(f"Output: {shown}")
            print(f"Target: {target_list}\n")