How to Build a Privacy-Preserving Federated Pipeline to Fine-Tune Large Language Models with LoRA Using Flower and PEFT

In this tutorial, we demonstrate how to fine-tune a large language model with LoRA in a federated setting, without ever centralizing private text data. We simulate multiple organizations as virtual clients and show how each client adapts a shared base model locally while exchanging only lightweight LoRA adapter parameters. By combining Flower’s federated learning simulation engine with parameter-efficient fine-tuning, we arrive at a practical, scalable approach for organizations that want to customize LLMs on sensitive data while preserving privacy and reducing communication and compute costs.

!pip -q install -U "protobuf<5" "flwr[simulation]" transformers peft accelerate datasets sentencepiece
import torch
if torch.cuda.is_available():
   !pip -q install -U bitsandbytes
import os
os.environ["RAY_DISABLE_USAGE_STATS"] = "1"
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import math
import random
import numpy as np
from typing import Dict, List, Tuple, Optional
from torch.utils.data import DataLoader
from datasets import Dataset
import flwr as fl
from flwr.common import Context
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, DataCollatorForLanguageModeling
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
SEED = 7
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", DEVICE)
GPU_MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
CPU_MODEL_ID = "distilgpt2"
MODEL_ID = GPU_MODEL_ID if DEVICE == "cuda" else CPU_MODEL_ID
MAX_LEN = 256 if DEVICE == "cuda" else 192
NUM_CLIENTS = 3
ROUNDS = 3
LOCAL_EPOCHS = 1
BATCH_SIZE = 2
GRAD_ACCUM = 4
LR = 2e-4
WARMUP_STEPS = 5
WEIGHT_DECAY = 0.0
LOG_EVERY = 10
CLIENT_TEXTS: Dict[int, List[str]] = {
   0: [
       "Policy memo: Employees must rotate on-call weekly and document incidents in the internal tracker.",
       "Runbook: If latency spikes, check the database connection pool and recent deploys, then roll back if needed.",
       "Security note: Never paste customer identifiers into public issue trackers. Use redacted tokens.",
       "Engineering guideline: Prefer idempotent retries for event processing; avoid duplicate side-effects.",
       "Postmortem template: impact, timeline, root cause, contributing factors, action items, owners, deadlines."
   ],
   1: [
       "Credit risk review: monitor delinquency curves by cohort and compare against seasonal baselines.",
       "Fraud signals: repeated small authorizations, device changes, and sudden merchant-category shifts require review.",
       "Portfolio strategy: tighten limits on volatile segments while maintaining service levels for stable accounts.",
       "Operational note: reconcile chargebacks weekly and track win-rate by reason code.",
       "Internal SOP: escalation path is analyst -> manager -> compliance for high-risk cases."
   ],
   2: [
       "Fleet ops: preventive maintenance reduces downtime; prioritize vehicles with repeated fault codes.",
       "Dispatch note: optimize routes by time windows and driver hours to reduce empty miles.",
       "Safety policy: enforce rest breaks and log inspections before long-haul trips.",
       "Inventory update: track spare parts usage; reorder thresholds should reflect lead time and seasonality.",
       "Customer SLA: late deliveries require proactive notifications and documented root cause."
   ],
}
for cid in list(CLIENT_TEXTS.keys()):
   base = CLIENT_TEXTS[cid]
   CLIENT_TEXTS[cid] = base + [f"Q: Summarize this for leadership. A: {t}" for t in base]
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
if tokenizer.pad_token is None:
   tokenizer.pad_token = tokenizer.eos_token
bnb_config: Optional[BitsAndBytesConfig] = None
if DEVICE == "cuda":
   compute_dtype = torch.bfloat16 if torch.cuda.get_device_capability(0)[0] >= 8 else torch.float16
   bnb_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=compute_dtype)
if "gpt2" in MODEL_ID.lower():
   TARGET_MODULES = ["c_attn", "c_proj"]
else:
   TARGET_MODULES = ["q_proj", "k_proj", "v_proj", "o_proj"]
LORA_R = 16
LORA_ALPHA = 32
LORA_DROPOUT = 0.05
lora_config = LoraConfig(r=LORA_R, lora_alpha=LORA_ALPHA, lora_dropout=LORA_DROPOUT, bias="none", task_type="CAUSAL_LM", target_modules=TARGET_MODULES)
def model_primary_device(model) -> torch.device:
   return next(model.parameters()).device
def build_model_with_lora():
   if DEVICE == "cuda":
       model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto", quantization_config=bnb_config, torch_dtype="auto")
       model = prepare_model_for_kbit_training(model)
   else:
       model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.float32)
       model.to("cpu")
   model = get_peft_model(model, lora_config)
   model.train()
   return model
def make_dataset(texts: List[str]) -> Dataset:
   ds = Dataset.from_dict({"text": texts})
   def tok(batch):
       return tokenizer(batch["text"], truncation=True, max_length=MAX_LEN, padding="max_length")
   ds = ds.map(tok, batched=True, remove_columns=["text"])
   return ds
collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
def lora_state_keys(model) -> List[str]:
   sd = model.state_dict()
   keys = sorted([k for k in sd.keys() if "lora_" in k])
   if not keys:
       raise RuntimeError("No LoRA keys found. Your model might not have the target_modules specified. " f"Current TARGET_MODULES={TARGET_MODULES}, MODEL_ID={MODEL_ID}")
   return keys
def get_lora_ndarrays(model) -> List[np.ndarray]:
   sd = model.state_dict()
   keys = lora_state_keys(model)
   return [sd[k].detach().float().cpu().numpy() for k in keys]
def set_lora_ndarrays(model, arrays: List[np.ndarray]) -> None:
   keys = lora_state_keys(model)
   if len(keys) != len(arrays):
       raise ValueError(f"Mismatch: got {len(arrays)} arrays but expected {len(keys)}.")
   sd = model.state_dict()
   for k, arr in zip(keys, arrays):
       t = torch.from_numpy(arr).to(sd[k].device).to(sd[k].dtype)
       sd[k].copy_(t)
def cosine_warmup_lr(step: int, total_steps: int, base_lr: float, warmup_steps: int) -> float:
   if step < warmup_steps:
       return base_lr * (step + 1) / max(1, warmup_steps)
   progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
   return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))
@torch.no_grad()
def eval_loss(model, ds: Dataset, max_batches: int = 20) -> float:
   model.eval()
   dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collator)
   losses = []
   dev = model_primary_device(model)
   for i, batch in enumerate(dl):
       if i >= max_batches:
           break
       batch = {k: v.to(dev) for k, v in batch.items()}
       out = model(**batch)  # the collator already supplies labels (pad tokens masked to -100)
       losses.append(float(out.loss.detach().cpu()))
   model.train()
   return float(np.mean(losses)) if losses else float("nan")
def train_one_client_round(model, ds: Dataset, epochs: int, lr: float, grad_accum: int, warmup_steps: int) -> Tuple[float, int]:
   dl = DataLoader(ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collator)
   total_steps = max(1, (len(dl) * epochs) // max(1, grad_accum))
   step = 0
   optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=WEIGHT_DECAY)
   optimizer.zero_grad(set_to_none=True)
   running = []
   examples = 0
   dev = model_primary_device(model)
   for _ in range(epochs):
       for bi, batch in enumerate(dl):
           batch = {k: v.to(dev) for k, v in batch.items()}
           out = model(**batch)  # labels are provided by the collator
           loss = out.loss / grad_accum
           loss.backward()
           running.append(float(loss.detach().cpu()) * grad_accum)
           examples += batch["input_ids"].shape[0]
           if (bi + 1) % grad_accum == 0:
               lr_t = cosine_warmup_lr(step, total_steps, lr, warmup_steps)
               for pg in optimizer.param_groups:
                   pg["lr"] = lr_t
               optimizer.step()
               optimizer.zero_grad(set_to_none=True)
               step += 1
               if step % LOG_EVERY == 0:
                   print(f"  step={step}/{total_steps} loss={np.mean(running[-LOG_EVERY:]):.4f} lr={lr_t:.2e}")
   return float(np.mean(running)) if running else float("nan"), examples

We set up the full execution environment and define all global configurations required for the experiment. We prepare the private client text silos, tokenizer, LoRA configuration, and model-loading logic so they automatically adapt to CPU or GPU availability. We also establish all helper utilities that enable parameter-efficient fine-tuning and safe device handling across federated clients.
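
As an optional sanity check (a minimal sketch that reuses the build_model_with_lora helper defined above, not part of the original walkthrough), we can confirm how small the trainable LoRA footprint is relative to the frozen base model, which is exactly what keeps the exchanged updates lightweight; PEFT's built-in print_trainable_parameters() reports the same information.

check_model = build_model_with_lora()
trainable = sum(p.numel() for p in check_model.parameters() if p.requires_grad)
total = sum(p.numel() for p in check_model.parameters())
print(f"Trainable (LoRA) params: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)")
del check_model  # release this instance before the simulation starts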

class FedLoRAClient(fl.client.NumPyClient):
   def __init__(self, cid: int):
       self.cid = cid
       self._model = None
       self._ds_train = None
       self._ds_eval = None
   def _ensure(self):
       if self._model is None:
           print(f"[Client {self.cid}] Loading model + LoRA (MODEL_ID={MODEL_ID})...")
           self._model = build_model_with_lora()
           texts = CLIENT_TEXTS[self.cid].copy()
           random.shuffle(texts)
           split = max(1, int(0.8 * len(texts)))
           self._ds_train = make_dataset(texts[:split])
           self._ds_eval = make_dataset(texts[split:])
   def get_parameters(self, config):
       self._ensure()
       return get_lora_ndarrays(self._model)
   def fit(self, parameters, config):
       self._ensure()
       set_lora_ndarrays(self._model, parameters)
       loss_before = eval_loss(self._model, self._ds_eval, max_batches=10)
       print(f"[Client {self.cid}] eval_loss_before={loss_before:.4f}")
       train_loss, n_examples = train_one_client_round(self._model, self._ds_train, epochs=int(config.get("local_epochs", LOCAL_EPOCHS)), lr=float(config.get("lr", LR)), grad_accum=int(config.get("grad_accum", GRAD_ACCUM)), warmup_steps=int(config.get("warmup_steps", WARMUP_STEPS)))
       loss_after = eval_loss(self._model, self._ds_eval, max_batches=10)
       print(f"[Client {self.cid}] train_loss={train_loss:.4f} eval_loss_after={loss_after:.4f}")
       new_params = get_lora_ndarrays(self._model)
       metrics = {"eval_loss_before": loss_before, "eval_loss_after": loss_after, "train_loss": train_loss}
       return new_params, n_examples, metrics
   def evaluate(self, parameters, config):
       self._ensure()
       set_lora_ndarrays(self._model, parameters)
       loss = eval_loss(self._model, self._ds_eval, max_batches=20)
       return float(loss), len(self._ds_eval), {"eval_loss": float(loss)}
def client_fn(context: Context):
   cid = None
   try:
       cid = int(context.node_config.get("partition-id"))
   except Exception:
       try:
           cid = int(context.node_id) % NUM_CLIENTS  # map arbitrary node ids onto a valid client id
       except Exception:
           cid = 0
   return FedLoRAClient(cid).to_client()

We define the federated client logic that simulates independent organizations participating in training. We initialize a LoRA-augmented language model per client and ensure that local datasets remain isolated. We implement client-side training, evaluation, and parameter exchange while exposing only LoRA adapter weights to the server.
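
Before launching the full simulation, it can help to smoke-test a single client in isolation. The snippet below is a minimal sketch (not part of the original tutorial) that calls get_parameters and fit directly on client 0 using the constants defined earlier; we delete the client afterwards to release its model.

smoke_client = FedLoRAClient(cid=0)
init_params = smoke_client.get_parameters(config={})
new_params, n_examples, metrics = smoke_client.fit(init_params, config={"local_epochs": 1, "lr": LR, "grad_accum": GRAD_ACCUM, "warmup_steps": WARMUP_STEPS})
print(f"Client 0 trained on {n_examples} examples, metrics={metrics}")
del smoke_client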

def fit_config(server_round: int):
   return {"local_epochs": LOCAL_EPOCHS, "lr": LR, "grad_accum": GRAD_ACCUM, "warmup_steps": WARMUP_STEPS}
strategy = fl.server.strategy.FedAvg(fraction_fit=1.0, fraction_evaluate=1.0, min_fit_clients=NUM_CLIENTS, min_evaluate_clients=NUM_CLIENTS, min_available_clients=NUM_CLIENTS, on_fit_config_fn=fit_config)
print("nStarting Flower simulation...n")
client_resources = {"num_cpus": 2, "num_gpus": 0.0}
if DEVICE == "cuda":
   client_resources = {"num_cpus": 2, "num_gpus": 0.25}
history = fl.simulation.start_simulation(client_fn=client_fn, num_clients=NUM_CLIENTS, config=fl.server.ServerConfig(num_rounds=ROUNDS), strategy=strategy, client_resources=client_resources, ray_init_args={"include_dashboard": False, "ignore_reinit_error": True})
print("nSimulation done.")

We configure the federated learning strategy and orchestrate the global training process. We specify how many clients participate, how parameters are aggregated, and how training rounds are scheduled. We then launch the Flower simulation to coordinate communication and aggregation across all virtual clients.
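
By default, FedAvg aggregates only the evaluation loss; per-client metrics such as eval_loss_before and eval_loss_after are dropped unless we supply an aggregation function. The snippet below is an optional sketch (not part of the code above) of a weighted-average aggregator that can be passed to the same FedAvg constructor via evaluate_metrics_aggregation_fn.

def weighted_eval_loss(client_metrics):
    # client_metrics is a list of (num_examples, metrics_dict) tuples, one per client
    total_examples = sum(n for n, _ in client_metrics)
    avg = sum(n * m["eval_loss"] for n, m in client_metrics) / max(1, total_examples)
    return {"eval_loss": avg}
# e.g., FedAvg(..., evaluate_metrics_aggregation_fn=weighted_eval_loss)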

demo_model = build_model_with_lora()
demo_model.eval()
prompt = "Summarize this internal note for leadership in 2 bullets:nDispatch note: optimize routes by time windows and driver hours to reduce empty miles.nnAnswer:"
inputs = tokenizer(prompt, return_tensors="pt")
dev = model_primary_device(demo_model)
inputs = {k: v.to(dev) for k, v in inputs.items()}
with torch.no_grad():
   out = demo_model.generate(**inputs, max_new_tokens=80, do_sample=True, temperature=0.8, top_p=0.95, repetition_penalty=1.05, eos_token_id=tokenizer.eos_token_id, pad_token_id=tokenizer.pad_token_id)
print("n=== Generation output ===n")
print(tokenizer.decode(out[0], skip_special_tokens=True))

We load a fresh LoRA-augmented model instance to demonstrate inference with the same architecture used during federated training. We prepare a realistic prompt and run text generation to verify that the pipeline executes end-to-end and produces coherent, task-aligned outputs.
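
Note that build_model_with_lora() returns a fresh adapter, so this demo exercises the architecture and generation path rather than the aggregated federated weights. If we want the demo model to reflect the federated result, one possible approach (a sketch; the class and attribute names below are illustrative and not part of the original code) is to subclass FedAvg so it keeps the last aggregated LoRA arrays, then load them with set_lora_ndarrays before generating.

from flwr.common import parameters_to_ndarrays

class SavingFedAvg(fl.server.strategy.FedAvg):
    """FedAvg variant that remembers the most recently aggregated parameters."""
    def aggregate_fit(self, server_round, results, failures):
        aggregated, metrics = super().aggregate_fit(server_round, results, failures)
        if aggregated is not None:
            self.final_lora_ndarrays = parameters_to_ndarrays(aggregated)
        return aggregated, metrics

# Use SavingFedAvg in place of FedAvg above, then after the simulation:
# set_lora_ndarrays(demo_model, strategy.final_lora_ndarrays)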

print(type(history))
print(history.__dict__.keys())

We inspect the training artifacts and simulation outputs produced by the federated run. We examine the returned history object to confirm that rounds, metrics, and aggregation completed successfully. We use this step to validate the overall integrity and reproducibility of the federated fine-tuning workflow.
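
Assuming the standard attributes of Flower's History object (losses_distributed holds the per-round aggregated evaluation loss as a list of (round, loss) tuples), we can also print the loss trajectory directly:

for rnd, loss in history.losses_distributed:
    print(f"Round {rnd}: federated eval loss = {loss:.4f}")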

In conclusion, we showed that federated fine-tuning of LLMs is not only a research concept but something we can run end-to-end in a Colab environment today. We successfully coordinated client-side LoRA training, server-side aggregation, and evaluation without sharing raw text or full model weights. This workflow highlights how federated learning, when paired with modern PEFT techniques, enables privacy-preserving adaptation of generative models and provides a strong foundation for extending the system toward personalization, robustness, and real-world enterprise deployment.

