Advanced Models

Klay’s registry already covers the most common GNN, embedding, and convolutional blocks used in MLIPs, and more layers are being added over time. Advanced research, however, often calls for arbitrary PyTorch callables and for weight sharing across multiple stages. This page shows how to achieve both with two small YAML constructs:

  • ArbitraryModule – wrap any importable callable (class or function), given by its dotted import path.

  • Aliases – call an existing layer again, enabling parallel or staged graphs with shared parameters.

End-to-end example

model_inputs:
  x: "(N,16) tensor you feed at runtime"

model_layers:
  dense:                    # one Linear layer
    type: ArbitraryModule
    config:
      target: torch.nn.Linear
      args:  [16, 16]
    inputs: {0: model_inputs.x}

  relu:                     # ReLU module
    type: ArbitraryModule
    config: {target: torch.nn.ReLU}
    inputs: {0: dense}

  second_pass:              #  *alias* -> reuse the same Linear
    alias: dense
    inputs: {0: relu}

  relu2:                    #  functional ReLU
    type: ArbitraryModule
    config: {target: torch.nn.functional.relu}
    inputs: {0: second_pass}

model_outputs:
  preds: relu2

Run it:

from klay.builder import build_model
from klay.io      import load_config
import torch

cfg   = load_config("example/arbitrary_and_alias.yaml")
model = build_model(cfg)
out = model(torch.rand(16))      # torch.Size([16])

1. Arbitrary Layers (ArbitraryModule)

ArbitraryModule turns any importable Python callable into a Klay layer. Its config block mirrors a constructor call:

some_layer:
  type:  ArbitraryModule
  config:
    target: torch.nn.functional.gelu   # dotted import path
    args: []                           # -> *positional* args
    kwargs: {}                         # -> **keyword** args
  inputs:
    0: previous_tensor                 # maps to arg 0
  # output map is optional; omit for single-tensor returns

How it works

  • If target resolves to a class derived from torch.nn.Module, it is instantiated with the given args/kwargs.

  • Otherwise the callable is kept intact; when args or kwargs are non-empty, Klay wraps it in functools.partial.

  • The wrapper still behaves like a standard layer, so you can alias it, trace it with torch.fx.GraphModule, and place it in any branch of a larger DAG.
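
The dispatch described above can be sketched in a few lines. This is only an illustration, not Klay's actual implementation: the helper names resolve_target and build_callable are hypothetical, and to stay dependency-free the sketch instantiates any class, whereas Klay is described as checking specifically for torch.nn.Module subclasses. The standard-library targets stand in for PyTorch ones.

```python
import functools
import importlib
import inspect


def resolve_target(path):
    """Import the object named by a dotted path such as 'math.pow'."""
    module_name, _, attr = path.rpartition(".")
    if not module_name:
        raise ValueError(f"expected a dotted import path, got {path!r}")
    return getattr(importlib.import_module(module_name), attr)


def build_callable(path, args=(), kwargs=None):
    """Hypothetical helper mirroring the three rules listed above."""
    kwargs = dict(kwargs or {})
    target = resolve_target(path)
    if inspect.isclass(target):
        # Assumption: Klay restricts this branch to torch.nn.Module subclasses.
        return target(*args, **kwargs)
    if args or kwargs:
        # Plain function with config arguments: bind them up front.
        return functools.partial(target, *args, **kwargs)
    # Plain function without config arguments: keep it untouched.
    return target


half = build_callable("fractions.Fraction", args=[1, 2])  # class -> instance
power_of_two = build_callable("math.pow", args=[2])       # function -> partial
plain = build_callable("math.sqrt")                       # function -> as-is
```

Note that functools.partial places the configured args before any runtime inputs, which is why power_of_two(3) computes pow(2, 3).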

Input block rules

  • Positional ports use integer keys (0, 1…). They are forwarded in sorted order. Keep them consistent with the callable’s signature.

  • Keyword ports use strings; they map 1-to-1 to argument names.
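
As a rough illustration of these two rules, the sketch below splits an inputs mapping into positional and keyword arguments. The function name split_ports is hypothetical and the real builder may treat edge cases differently; the point is only that integer keys are ordered by value, not by their order in the YAML file.

```python
def split_ports(inputs):
    """Split an inputs mapping into (positional list, keyword dict)."""
    int_keys = sorted(k for k in inputs if isinstance(k, int))
    positional = [inputs[k] for k in int_keys]  # forwarded in sorted key order
    keyword = {k: v for k, v in inputs.items() if isinstance(k, str)}
    return positional, keyword


# Port 1 is declared before port 0, yet port 0 is still passed first.
args, kwargs = split_ports({1: "edge_index", "scale": "scale_tensor", 0: "x"})
print(args)    # -> ['x', 'edge_index']
print(kwargs)  # -> {'scale': 'scale_tensor'}
```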

2. Layer Aliases (weight sharing & staged graphs)

An alias creates a second call-site to an already-declared layer, reusing its parameters:

next_stage:
  alias: previous_dense     # must point to an existing layer name
  inputs:
    0: some_tensor          # new data path
  output: {0: stage_out}

Typical use-cases

  • Parallel branches – e.g. self-attention where the same MLP block serves query/key/value heads.

  • Staged graphs – run the same block (e.g. an embedding) on different input graphs, giving parallel staged graphs that are invariant to domain decomposition, for use with the OpenKIM TorchML driver.

  • Recurrent constructs – feed the output of a block back into itself for another iteration.

Key points

  • The alias appears as a normal node in the FX graph, but shares the original parameters (state_dict stores them only once).

  • You may define a fresh inputs / output map to wire the alias into a new context.

  • Validation – the DAG builder ensures that every alias targets an existing layer and that no cycles are introduced.
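
The two validation checks can be approximated with the standard library alone. The sketch below is an assumption about how such a check could look, not Klay's DAG builder: validate_layers is a hypothetical name, layers are plain dictionaries shaped like the model_layers block on this page, and only data-flow edges (the inputs maps) are considered.

```python
from graphlib import CycleError, TopologicalSorter


def validate_layers(layers):
    """Check alias targets and acyclicity; return one valid execution order."""
    for name, spec in layers.items():
        alias = spec.get("alias")
        if alias is not None and alias not in layers:
            raise ValueError(f"{name!r} aliases unknown layer {alias!r}")

    # Each layer depends on the layers it reads from; references such as
    # 'model_inputs.x' are external inputs and add no edge.
    graph = {
        name: {src for src in spec.get("inputs", {}).values() if src in layers}
        for name, spec in layers.items()
    }
    try:
        return list(TopologicalSorter(graph).static_order())
    except CycleError as err:
        raise ValueError(f"cycle in layer graph: {err.args[1]}") from err


layers = {
    "dense": {"type": "ArbitraryModule", "inputs": {0: "model_inputs.x"}},
    "relu": {"type": "ArbitraryModule", "inputs": {0: "dense"}},
    "second_pass": {"alias": "dense", "inputs": {0: "relu"}},
    "relu2": {"type": "ArbitraryModule", "inputs": {0: "second_pass"}},
}
order = validate_layers(layers)
print(order)  # -> ['dense', 'relu', 'second_pass', 'relu2']
```

Reusing dense through second_pass does not create a cycle here, because the alias is a separate node in the data-flow graph even though it shares the original layer's parameters.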

Caveats & compliance notes

Warning

  • TorchScript support is *best-effort*. Arbitrary callables that rely on Python features not supported by TorchScript (e.g. string operations, dynamic shapes) will break serialization.

  • OpenKIM validators require scripted models. If you intend to register a potential under the KIM framework, re-implement your arbitrary or functional layers as torch.nn.Module and confirm they script cleanly.

  • Functional callables + constructor args. Provide args / kwargs only when the function actually expects them; mismatched signatures surface only at run-time.

  • Alias loops are disallowed. The DAG check stops you from creating recursive references, but be mindful when chaining many aliases in complex graphs.

Tip

This should be enough for most complicated MLIPs, but if it is not, you can always use KLay to generate layers and manually build the model yourself.