|
 
 class LLaDAModel(nn.Module):
     """
-    LLaDA: Large Language Diffusion Model
-    Uses a transformer backbone with timestep conditioning for diffusion-based language modeling
+    A PyTorch language model that performs diffusion-based generation via timestep conditioning.
+    The model supports several text generation strategies, including random sampling,
+    confidence-based sampling, semi-autoregressive generation, and beam search.
+
+    Attributes:
+        config: Configuration object containing model hyperparameters
+        wte (nn.Embedding): Token embeddings
+        wpe (nn.Embedding): Position embeddings
+        dropout (nn.Dropout): Dropout layer
+        h (nn.ModuleList): List of transformer blocks
+        ln_f (nn.LayerNorm): Final layer normalization
+        time_embed (TimeEmbedding): Timestep embedding module
+        time_proj (nn.ModuleList): Time projection layers, one per transformer block
+        lm_head (nn.Linear): Output projection to the vocabulary
+
+    Methods:
+        forward(input_ids, attention_mask, timesteps, labels):
+            Forward pass through the model for training and inference
+        generate(prompt, max_length, num_inference_steps, temperature, strategy, top_p, top_k, num_beams, return_scores):
+            Generate text via the reverse diffusion process using the chosen sampling strategy
+        generate_stream(prompt, max_length, num_inference_steps, temperature, strategy, top_p, top_k, num_beams, callback_fn):
+            Same as generate, but invokes callback_fn with intermediate results at each denoising step
+
+    Example:
+        >>> config = ModelConfig(vocab_size=50257, hidden_size=768)
+        >>> model = LLaDAModel(config)
+        >>> output = model.generate(prompt="Hello", max_length=50, temperature=0.7)
     """
+
     def __init__(self, config):
         super().__init__()
         self.config = config
@@ -61,13 +83,11 @@ def forward( |
     ) -> Dict[str, torch.Tensor]:
         """
         Forward pass through the model
-
         Args:
             input_ids: Tensor of token ids [batch_size, seq_len]
             attention_mask: Mask tensor [batch_size, seq_len]
             timesteps: Current diffusion timesteps [batch_size]
             labels: Target token ids for masked positions [batch_size, seq_len]
-
         Returns:
             dict with loss and logits
         """
|
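For context, here is a minimal sketch of how the documented `forward` signature is typically exercised during diffusion training. It assumes `ModelConfig` and `LLaDAModel` are importable from this module; the `mask_token_id` value, the continuous-timestep convention, and the masking schedule are illustrative assumptions, not taken from this repository.

```python
import torch

# Hypothetical setup; the ModelConfig fields mirror the docstring example.
config = ModelConfig(vocab_size=50257, hidden_size=768)
model = LLaDAModel(config)

batch_size, seq_len = 4, 128
input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_len))
labels = input_ids.clone()

# Forward diffusion: mask a timestep-dependent fraction of tokens.
# Continuous t in [0, 1) is an assumption; adjust if the model expects integer steps.
timesteps = torch.rand(batch_size)
mask = torch.rand(batch_size, seq_len) < timesteps[:, None]
mask_token_id = config.vocab_size - 1  # assumed special token; depends on the tokenizer
noisy_ids = input_ids.masked_fill(mask, mask_token_id)

out = model(
    input_ids=noisy_ids,
    attention_mask=torch.ones(batch_size, seq_len, dtype=torch.long),
    timesteps=timesteps,
    labels=labels,  # per the docstring, loss targets the masked positions
)
out["loss"].backward()  # the forward returns a dict with loss and logits
```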
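Continuing the setup above, a hedged sketch of the two generation entry points the docstring lists. The parameter names come from the docstring; the strategy strings, parameter values, and callback signature are assumptions for illustration.

```python
# Confidence-based sampling with nucleus filtering.
text = model.generate(
    prompt="Hello",
    max_length=50,
    num_inference_steps=32,
    temperature=0.7,
    strategy="confidence",  # assumed string; docstring lists confidence-based sampling
    top_p=0.9,
)

# Streaming variant: callback_fn is assumed to receive the intermediate
# decoded state at each denoising step.
def on_step(partial):
    print(partial)

model.generate_stream(
    prompt="Hello",
    max_length=50,
    num_inference_steps=32,
    strategy="random",  # assumed string; docstring lists random sampling
    callback_fn=on_step,
)
```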