Dataset: Deehan1866/WiC_actual
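Dual-head model for the Word-in-Context (WiC) binary sense disambiguation task: a classification head plus a masked-language-modeling (MLM) head over the rationale. The training loss is a weighted sum of the two: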
L_total = beta * L_cls + alpha * L_rationale_mlm
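
The weighted sum above might be computed roughly as in the sketch below. This is an illustration, not the training code. The helper name `combined_loss`, the tensor shapes, and the use of cross-entropy with `ignore_index=-100` for unmasked rationale positions are assumptions. Only the weights alpha = 0.5 and beta = 1.0 come from this card.

```python
import torch
import torch.nn.functional as F


def combined_loss(cls_logits, cls_labels, mlm_logits, mlm_labels, alpha=0.5, beta=1.0):
    """Hypothetical helper: L_total = beta * L_cls + alpha * L_rationale_mlm."""
    # Classification loss over the two WiC labels.
    l_cls = F.cross_entropy(cls_logits, cls_labels)
    # MLM loss over rationale tokens; positions labelled -100 are skipped (assumption).
    l_mlm = F.cross_entropy(
        mlm_logits.reshape(-1, mlm_logits.size(-1)),
        mlm_labels.reshape(-1),
        ignore_index=-100,
    )
    return beta * l_cls + alpha * l_mlm, l_cls, l_mlm


# Toy tensors: batch of 2, sequence length 4, vocabulary of 10.
torch.manual_seed(0)
cls_logits = torch.randn(2, 2)
cls_labels = torch.tensor([1, 0])
mlm_logits = torch.randn(2, 4, 10)
mlm_labels = torch.tensor([[-100, 3, -100, 7], [5, -100, -100, -100]])

total, l_cls, l_mlm = combined_loss(cls_logits, cls_labels, mlm_logits, mlm_labels)
print(f"total={total.item():.4f} cls={l_cls.item():.4f} mlm={l_mlm.item():.4f}")
```

With alpha below beta, the rationale objective acts as an auxiliary signal and the classification loss dominates the gradient.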
During training the input is:
[CLS] sentence1_marked [SEP] sentence2_marked [SEP] <REASON> rationale_masked [SEP]
During inference the rationale is absent, so only the classification head is used:
[CLS] sentence1_marked [SEP] sentence2_marked [SEP]
The target word is wrapped with <TGT>word</TGT> using the exact 0-indexed word
position from the dataset (start1/start2 columns).
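
The marking step can be sketched as follows. This is a minimal sketch that assumes words are delimited by whitespace, so punctuation attached to the target word ends up inside the tags. The helper name `mark_target` is hypothetical, and the indices are chosen by hand here rather than read from the start1/start2 columns.

```python
def mark_target(sentence: str, start: int) -> str:
    """Wrap the whitespace-delimited word at 0-indexed position `start` in <TGT> tags."""
    words = sentence.split()
    if not 0 <= start < len(words):
        raise IndexError(f"word index {start} out of range for {len(words)} words")
    words[start] = f"<TGT>{words[start]}</TGT>"
    return " ".join(words)


# Indices are hand-picked for illustration; in practice they would come from the dataset.
s1 = mark_target("The bank raised its interest rates.", 1)
s2 = mark_target("She visited her local bank to deposit a cheque.", 4)
print(s1)  # The <TGT>bank</TGT> raised its interest rates.
print(s2)  # She visited her local <TGT>bank</TGT> to deposit a cheque.
```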

Training hyperparameters:

| Parameter | Value |
|---|---|
| alpha (MLM loss weight) | 0.5 |
| beta (CLS loss weight) | 1.0 |
| mlm_mask_prob | 0.3 |
| max_length_train | 512 |
| max_length_cls | 256 |
| learning_rate | 1e-05 |
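
The `mlm_mask_prob` value above controls how much of the rationale is hidden during training. The card does not describe the exact masking scheme, so the sketch below assumes the simplest one: each rationale token is independently replaced by the mask token with probability 0.3, and unmasked positions get the label -100 so a loss function can skip them. The helper name `mask_rationale`, the toy token ids, and the placeholder `mask_token_id=0` are all hypothetical.

```python
import random


def mask_rationale(token_ids, mask_token_id, mask_prob=0.3, seed=None):
    """Hypothetical helper: mask each token independently with probability mask_prob."""
    rng = random.Random(seed)
    masked, labels = [], []
    for tok in token_ids:
        if rng.random() < mask_prob:
            masked.append(mask_token_id)  # hide the token from the model
            labels.append(tok)            # the MLM head must recover it
        else:
            masked.append(tok)            # token stays visible
            labels.append(-100)           # position is ignored by the MLM loss
    return masked, labels


# Toy ids standing in for a tokenized rationale.
rationale_ids = [101, 102, 103, 104, 105, 106, 107, 108, 109, 110]
masked, labels = mask_rationale(rationale_ids, mask_token_id=0, seed=13)
print(masked)
print(labels)
```

In a real pipeline the mask id would come from `tokenizer.mask_token_id`, and only the rationale span after `<REASON>` would be eligible for masking, leaving the two marked sentences intact.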

Evaluation accuracy:

| Split | Accuracy |
|---|---|
| Validation | 0.6912 |
| Test | 0.6479 |

Example inference with the classification head:

```python
import torch
import torch.nn as nn
from transformers import AutoTokenizer, AutoModel


# Rebuild the classifier head
class Classifier(nn.Module):
    def __init__(self, hidden_size=1024, num_labels=2):
        super().__init__()
        self.net = nn.Sequential(nn.Dropout(0.1), nn.Linear(hidden_size, num_labels))

    def forward(self, x):
        return self.net(x)


repo_id = "Deehan1866/deberta-wic-angle1-withbannedwords"
tokenizer = AutoTokenizer.from_pretrained(repo_id)
encoder = AutoModel.from_pretrained(repo_id)
encoder.eval()

clf = Classifier()
clf.net.load_state_dict(torch.load(
    "classifier_head.pt",  # download from hub first
    map_location="cpu",
))
clf.eval()  # disable dropout so predictions are deterministic

s1 = "The <TGT>bank</TGT> raised its interest rates."
s2 = "She visited her local <TGT>bank</TGT> to deposit a cheque."
enc = tokenizer(s1, s2, return_tensors="pt", truncation=True, max_length=256)

with torch.no_grad():
    # Use the hidden state at the [CLS] position as the pair representation
    hidden = encoder(**enc).last_hidden_state[:, 0, :]
    logits = clf(hidden)

pred = logits.argmax(dim=-1).item()
print("Same sense" if pred == 1 else "Different sense")
```