Understanding krach's IR¶
krach uses a JAX-inspired tracing model: Python functions become frozen IR graphs that lower to backend code. This page walks through the system from trace to execution.
The big picture¶
Python DSP function → trace → DspGraph (Signal IR) → emit_faust → FAUST → LLVM JIT → audio
Python patterns → build → PatternNode tree → serialize → JSON → Rust engine
Python session → capture → ModuleIr → to_dict → JSON → persistence
Three IRs, one for each domain:
| IR | Shape | Produced by | Consumed by |
|---|---|---|---|
| `DspGraph` | Flat DAG of equations | Signal tracing (`TraceContext`) | Faust codegen |
| `PatternNode` | Tree | Direct construction (operators) | Engine protocol serialization |
| `ModuleIr` | Flat record of definitions | `capture()` or `ModuleProxy` | `instantiate()`, JSON persistence |
Signal tracing: Python → DspGraph¶
This is the core idea. A Python function executes once against Signal proxy objects. Each operation records an Equation into a DspGraph. The function never runs again — the graph IS the computation.
A concrete example¶
import krach.dsp as krs

def bass():
    freq = krs.control("freq", 55.0, 20.0, 800.0)
    gate = krs.control("gate", 0.0, 0.0, 1.0)
    env = krs.adsr(0.005, 0.15, 0.3, 0.08, gate)
    return krs.lowpass(krs.saw(freq), 800.0) * env
When krach traces this function, it produces:
{ lambda ; . let
s0 = control [ControlParams(name='freq', init=55.0, lo=20.0, hi=800.0, step=0.001)]
s1 = control [ControlParams(name='gate', init=0.0, lo=0.0, hi=1.0, step=0.001)]
s2 = const [ConstParams(value=0.005)]
s3 = const [ConstParams(value=0.15)]
s4 = const [ConstParams(value=0.3)]
s5 = const [ConstParams(value=0.08)]
s6 = faust_expr s2 s3 s4 s5 s1 [FaustExprParams(template='en.adsr(...)')]
s18 = feedback s0 [FeedbackParams(body_graph=<saw phasor>)]
s20 = mul s18 2.0
s22 = sub s20 1.0
s24 = faust_expr 800.0 s22 [FaustExprParams(template='{1} : fi.lowpass(2, {0})')]
s25 = mul s24 s6
in (s25) }
This is the DspGraph — krach's equivalent of a jaxpr. Each line is an Equation:
@dataclass(frozen=True, slots=True)
class Equation:
    primitive: Primitive          # operation name (e.g., "mul", "control", "feedback")
    inputs: tuple[Signal, ...]    # input signals (by ID)
    outputs: tuple[Signal, ...]   # output signals (by ID)
    params: PrimitiveParams       # typed parameters (ControlParams, ConstParams, etc.)
How tracing works¶
1. `make_graph(bass)` creates a `TraceContext` and calls `bass()`
2. `krs.control("freq", ...)` calls `bind(control_p, ...)`, which:
    - Runs the `abstract_eval` rule to compute the output type
    - Creates a fresh `Signal` for the output
    - Records an `Equation` into the TraceContext
    - Returns the output Signal (a proxy, not a value)
3. `krs.saw(freq)` — same: `bind(feedback_p, freq)` records another equation
4. `krs.lowpass(sig, 800.0)` — `800.0` is coerced to a `const` Signal via `coerce_to_signal()`
5. `* env` — `Signal.__mul__` calls `bind(mul_p, sig, env)` (operator overload)
6. The function returns. TraceContext collects all equations into a `DspGraph`
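The bind loop above can be sketched as a self-contained toy tracer. This is a simplified stand-in, not krach's actual classes: the real `bind` also runs `abstract_eval` for type inference and carries typed params, both omitted here.

```python
from dataclasses import dataclass

@dataclass(frozen=True, eq=False)
class Signal:
    id: int  # proxy identity only -- no value is ever stored

@dataclass(frozen=True)
class Equation:
    primitive: str
    inputs: tuple
    outputs: tuple

class TraceContext:
    def __init__(self):
        self.equations = []
        self._next_id = 0

    def bind(self, primitive, *inputs):
        # (the real bind runs the abstract_eval rule here to get the output type)
        out = Signal(self._next_id)   # fresh proxy Signal
        self._next_id += 1
        self.equations.append(Equation(primitive, inputs, (out,)))
        return out                    # a proxy, never an audio value

ctx = TraceContext()
freq = ctx.bind("control")
osc = ctx.bind("feedback", freq)      # e.g., the saw oscillator
sig = ctx.bind("mul", osc, freq)
assert len(ctx.equations) == 3
assert ctx.equations[-1].primitive == "mul"
```

Running the function once against these proxies leaves behind exactly the equation list that becomes the `DspGraph`.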
The key insight: the function runs once, against abstract values, to produce a graph. No audio is generated during tracing. The graph is then lowered to FAUST, compiled via LLVM, and runs at 44.1kHz.
The types¶
# ir/signal.py — pure frozen data
@dataclass(frozen=True, slots=True, eq=False)
class Signal:
    aval: SignalType   # abstract value (channels, precision)
    id: int            # unique identifier
    owner_id: int      # which TraceContext created this
    # eq=False: identity comparison by id only (custom __eq__/__hash__)

@dataclass(frozen=True, slots=True)
class DspGraph:
    inputs: tuple[Signal, ...]    # function parameters (audio inputs)
    outputs: tuple[Signal, ...]   # function return values
    equations: tuple[Equation, ...]
    precision: Precision = Precision.FLOAT32
Canonicalization and caching¶
Signal IDs are assigned during tracing and vary between runs. Two traces of the same function produce different IDs but the same computation.
canonicalize(graph) alpha-renames all Signal IDs to sequential integers (0, 1, 2, ...). graph_key(graph) returns a structural hash of the canonicalized graph.
Two DspGraphs with the same graph_key are semantically identical — they share compiled FAUST binaries:
g1 = make_graph(bass)
g2 = make_graph(bass)
g1.inputs[0].id != g2.inputs[0].id # different raw IDs
graph_key(g1) == graph_key(g2) # same structural hash → cache hit
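The alpha-renaming step can be sketched over a flat equation list. This toy version uses plain tuples instead of real `DspGraph`/`Equation` objects, but the idea is the same: renumber IDs in first-seen order, then hash the canonical form.

```python
def canonicalize(equations):
    # map raw Signal IDs to sequential 0, 1, 2, ... in first-seen order
    rename: dict[int, int] = {}

    def fresh(sid: int) -> int:
        if sid not in rename:
            rename[sid] = len(rename)
        return rename[sid]

    return tuple(
        (prim, tuple(fresh(s) for s in ins), tuple(fresh(s) for s in outs))
        for prim, ins, outs in equations
    )

def graph_key(equations) -> int:
    # structural hash of the canonical form
    return hash(canonicalize(equations))

# two "traces" with different raw IDs but identical structure
g1 = [("control", (), (101,)), ("mul", (101, 101), (205,))]
g2 = [("control", (), (7,)),   ("mul", (7, 7),     (9,))]
assert canonicalize(g1) == canonicalize(g2)
assert graph_key(g1) == graph_key(g2)   # cache hit
```

Because the hash ignores raw IDs, the compiled-binary cache is keyed purely on structure.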
Pattern building: operators → PatternNode¶
Patterns don't trace — they build trees directly. The Python expression IS the tree:
For example, a three-step sequence — a chord, a note (E4), and a rest — produces:
PatternNode(cat_p, children=(
PatternNode(freeze_p, children=(
PatternNode(cat_p, children=(
PatternNode(stack_p, children=(
PatternNode(atom_p, AtomParams(Control("freq", 261.63))),
PatternNode(atom_p, AtomParams(Control("gate", 1.0))),
)),
PatternNode(atom_p, AtomParams(Control("gate", 0.0))),
)),
)),
PatternNode(freeze_p, children=(...)), # E4
PatternNode(silence_p), # rest
))
Each PatternNode has:
@dataclass(frozen=True, slots=True)
class PatternNode:
    primitive: Primitive                # "cat", "freeze", "atom", etc.
    children: tuple[PatternNode, ...]   # sub-trees
    params: PatternParams               # typed per-primitive params
Why no tracing? Because patterns are trees (no sharing), and Python's expression syntax naturally builds the right shape. Signals need tracing because they're graphs (with sharing — the same signal can feed multiple equations).
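The sharing distinction can be shown with the toy `bind` idea: in a graph IR, one recorded signal can feed any number of later equations without being recomputed. This sketch is illustrative only ("square" is a hypothetical primitive name here):

```python
# A global trace: each bind records one equation and returns a proxy.
equations = []

def bind(prim, *inputs):
    out = object()   # fresh opaque proxy
    equations.append((prim, inputs, out))
    return out

gate = bind("control")
env = bind("adsr", gate)                         # recorded ONCE...
left = bind("mul", bind("saw", gate), env)
right = bind("mul", bind("square", gate), env)   # ...referenced twice

adsr_count = sum(1 for prim, _, _ in equations if prim == "adsr")
assert adsr_count == 1      # shared, not duplicated
assert len(equations) == 6  # control, adsr, saw, mul, square, mul
```

A tree IR cannot express this: embedding `env` under both branches would duplicate the subtree. For stateless pattern combinators that duplication is harmless, so direct tree construction suffices; for stateful DSP it would be wrong, hence tracing.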
ModuleIr: the top-level jaxpr¶
ModuleIr is the session specification — it contains DspGraphs and PatternNodes:
@dataclass(frozen=True, slots=True)
class ModuleIr:
    nodes: tuple[NodeDef, ...]        # each has source: DspGraph | str
    routing: tuple[RouteDef, ...]     # connections between nodes
    patterns: tuple[PatternDef, ...]  # each has pattern: PatternNode
    controls: tuple[ControlDef, ...]
    muted: tuple[MutedDef, ...]
    automations: tuple[AutomationDef, ...] = ()
    tempo: float | None = None
    meter: float | None = None
    master: float | None = None
    sub_modules: tuple[tuple[str, ModuleIr], ...] = ()  # recursion
The NodeDef.source field holds a DspGraph (the signal computation for that node) or a str (reference to a pre-compiled FAUST type like "faust:kick").
inputs and outputs¶
ModuleIr has optional port declarations:
inputs: tuple[str, ...] | None = None # declared input ports
outputs: tuple[str, ...] | None = None # declared output ports
- `None` = undeclared (scene/session capture — no explicit ports)
- `()` = explicitly no ports
- `("osc", "bus")` = named ports referencing nodes in the module
Port names are validated by flatten() — all declared inputs/outputs must reference existing nodes in the flattened IR.
prefix_ir(ir, prefix) — namespace prefixing¶
Rewrites all name fields with the prefix:
- NodeDef.name: "kick" → "drums/kick"
- RouteDef.source/target: prefixed
- PatternDef.target: prefixed
- MutedDef.name: prefixed
- ControlDef.path/AutomationDef.path: node portion prefixed (param portion preserved)
- inputs/outputs: prefixed
- sub_modules prefixes: recursively prefixed
- NOT prefixed: RouteDef.port (DSP input name, not a node name)
flatten(ir) — recursive sub_module resolution¶
Recursively resolves sub_modules:
1. For each (prefix, child_ir) in sub_modules, calls prefix_ir(child_ir, prefix)
2. Recursively flattens the prefixed child
3. Merges child nodes, routing, patterns, controls, automations into the parent
4. Parent-wins for transport: child tempo/meter/master are ignored
5. Validates that all declared inputs/outputs reference existing nodes
6. Returns a flat ModuleIr with empty sub_modules
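The steps above can be sketched with simplified dict-based modules; the real functions operate on frozen `ModuleIr` dataclasses and also merge routing, patterns, controls, and automations, which are omitted here.

```python
def prefix_ir(ir: dict, prefix: str) -> dict:
    # rewrite every node name (and nested sub-module prefix) under prefix/
    p = lambda name: f"{prefix}/{name}"
    return {
        "nodes": [p(n) for n in ir["nodes"]],
        "tempo": ir["tempo"],
        "sub_modules": [(p(pre), child) for pre, child in ir["sub_modules"]],
    }

def flatten(ir: dict) -> dict:
    nodes = list(ir["nodes"])
    for prefix, child in ir["sub_modules"]:
        flat_child = flatten(prefix_ir(child, prefix))  # steps 1-2
        nodes += flat_child["nodes"]                    # step 3: merge
        # step 4: parent wins -- child tempo is ignored
    return {"nodes": nodes, "tempo": ir["tempo"], "sub_modules": []}

drums = {"nodes": ["kick", "hat"], "tempo": 140.0, "sub_modules": []}
band = {"nodes": ["bass"], "tempo": 120.0, "sub_modules": [("drums", drums)]}

flat = flatten(band)
assert flat["nodes"] == ["bass", "drums/kick", "drums/hat"]
assert flat["tempo"] == 120.0      # parent wins over 140.0
assert flat["sub_modules"] == []   # fully resolved
```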
Module composition pattern¶
Modules compose through sub_modules + inputs/outputs:
@kr.module "full_band"
├── node: "bass"
├── sub_modules:
│ └── ("drums", drums_ir)
│ ├── node: "kick" → flattened to "drums/kick"
│ └── node: "hat" → flattened to "drums/hat"
├── inputs: None
└── outputs: ("bass",)
After flatten(), all sub_module nodes are merged into the parent with prefixed names. The hierarchical / separator works with krach's resolve_path — kr.remove("drums") removes all drums/* nodes.
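How `/`-separated names make prefix removal fall out naturally can be sketched like this; the actual `resolve_path` matching rules may differ in detail:

```python
def matches(path: str, node_name: str) -> bool:
    # a path matches a node exactly, or as a namespace prefix "path/..."
    return node_name == path or node_name.startswith(path + "/")

nodes = ["bass", "drums/kick", "drums/hat", "drumsynth"]
removed = [n for n in nodes if matches("drums", n)]
assert removed == ["drums/kick", "drums/hat"]   # "drumsynth" is untouched
```

Checking `path + "/"` rather than a bare `startswith(path)` is what keeps `drums` from accidentally matching a sibling like `drumsynth`.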
Shared infrastructure¶
Primitive¶
One frozen type for both domains:
# ir/primitive.py
@dataclass(frozen=True, slots=True)
class Primitive:
name: str
stateful: bool = False
Signal primitives: add_p = Primitive("add"), sin_p = Primitive("sin"), feedback_p = Primitive("feedback", stateful=True)
Pattern primitives: cat_p = Primitive("cat"), fast_p = Primitive("fast"), freeze_p = Primitive("freeze")
RuleRegistry¶
Per-primitive rules registered externally (not on the Primitive — it's just data):
# ir/registry.py
class RuleRegistry(Generic[P, R]):
    def register(self, prim: P, rule: R) -> R: ...
    def lookup(self, prim: P) -> R: ...
    def check_complete(self, expected: frozenset[P]) -> None: ...
Two RuleRegistry instances (defined in signal/trace.py, rules registered externally):
| Registry | Rules registered in | Purpose |
|---|---|---|
| `abstract_eval` | `signal/primitives.py` | Type inference during tracing |
| `lowering` | `backends/faust_lowering.py` | Signal IR → FAUST expressions |
check_complete() runs at import time — adding a primitive without a rule fails immediately, not at runtime.
Pattern rules use a simpler mechanism: a dict[str, Rule] in pattern/primitives.py with def_serialize / def_summary wrappers. Same import-time completeness guarantee, different implementation.
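A minimal registry with the same three-method surface and import-time completeness check might look like this (a sketch against the signatures above, not the real `ir/registry.py`):

```python
from typing import Callable, Generic, TypeVar

P = TypeVar("P")
R = TypeVar("R")

class RuleRegistry(Generic[P, R]):
    def __init__(self, name: str):
        self.name = name
        self._rules: dict[P, R] = {}

    def register(self, prim: P, rule: R) -> R:
        self._rules[prim] = rule
        return rule   # returning the rule allows decorator-style use

    def lookup(self, prim: P) -> R:
        return self._rules[prim]

    def check_complete(self, expected: frozenset[P]) -> None:
        missing = expected - self._rules.keys()
        if missing:
            raise RuntimeError(f"{self.name}: no rule for {missing}")

abstract_eval: RuleRegistry[str, Callable] = RuleRegistry("abstract_eval")
abstract_eval.register("mul", lambda *avals: avals[0])
abstract_eval.register("control", lambda: "mono")

abstract_eval.check_complete(frozenset({"mul", "control"}))  # passes silently

missing_detected = False
try:
    abstract_eval.check_complete(frozenset({"mul", "feedback"}))
except RuntimeError:
    missing_detected = True   # fails at import time, not at first trace
assert missing_detected
```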
Dependency layering¶
ir/ → stdlib only (pure frozen data)
signal/ → ir/ + backends/ (tracing runtime + DSL; transpile imports codegen)
pattern/ → ir/ (building + DSL)
backends/ → ir/ + signal/ (lowering)
top-level → everything (Mixer, REPL)
tests/test_dependency_invariant.py enforces this: no module-level imports from ir/ to signal/, pattern/, or backends/.
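One way such an invariant test can work is to parse each `ir/` source file and inspect only top-level import statements; this is a hedged sketch of the mechanism, not the actual test's code:

```python
import ast

FORBIDDEN_FOR_IR = ("signal", "pattern", "backends")

def module_level_imports(source: str) -> set[str]:
    # collect module names from top-level import statements only;
    # imports inside functions (lazy imports) are deliberately ignored
    tree = ast.parse(source)
    names: set[str] = set()
    for node in tree.body:
        if isinstance(node, ast.Import):
            names |= {alias.name for alias in node.names}
        elif isinstance(node, ast.ImportFrom) and node.module:
            names.add(node.module)
    return names

src = "import dataclasses\nfrom krach.signal import trace\n"
imports = module_level_imports(src)
violations = {
    m for m in imports
    if any(f".{pkg}" in m or m.startswith(pkg) for pkg in FORBIDDEN_FOR_IR)
}
assert violations == {"krach.signal"}   # ir/ must not import signal/
```

Scanning `tree.body` rather than walking the whole AST is what makes the check *module-level*: a function-local import used for lazy loading would not trip it.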
Adding a new signal primitive¶
1. Define it: `my_p = Primitive("my_op")` in `signal/primitives.py`
2. Register abstract_eval: `abstract_eval.register(my_p, my_eval_fn)`
3. Write the user-facing function in `signal/core.py`: calls `bind(my_p, ...)`
4. Register lowering: `lowering.register(my_p, my_lower_fn)` in `backends/faust_lowering.py`
5. `check_complete()` at import time verifies nothing is missing
Adding a new pattern primitive¶
1. Add a `*Params` dataclass to `ir/pattern.py`
2. Define it: `my_p = Primitive("my_op")` in `pattern/primitives.py`
3. Register serialize rule in `pattern/serialize.py`
4. Register summary handler in `pattern/summary.py`
5. Add an operator or method on `Pattern` in `pattern/pattern.py`
6. Import-time completeness checks catch missing rules