import torch
import torch.nn as nn
from mamba_ssm import Mamba  # requires the mamba-ssm package

# TransformerLayer is the attention block defined elsewhere in the AKT codebase.


class Architecture(nn.Module):
    def __init__(self, n_question, n_blocks, d_model, d_feature,
                 d_ff, n_heads, dropout, kq_same, model_type):
        """Transformer architecture with two types of blocks."""
        super().__init__()
        self.d_model = d_model
        self.model_type = model_type

        if model_type in {'akt'}:
            self.blocks_1 = nn.ModuleList([
                TransformerLayer(d_model=d_model, d_feature=d_model // n_heads,
                                 d_ff=d_ff, dropout=dropout, n_heads=n_heads, kq_same=kq_same)
                for _ in range(n_blocks)
            ])
            self.blocks_2 = nn.ModuleList([
                TransformerLayer(d_model=d_model, d_feature=d_model // n_heads,
                                 d_ff=d_ff, dropout=dropout, n_heads=n_heads, kq_same=kq_same)
                for _ in range(n_blocks * 2)
            ])

        self.model1 = Mamba(
            # This module uses roughly 3 * expand * d_model^2 parameters
            d_model=256,  # Model dimension d_model
            d_state=128,  # SSM state expansion factor
            d_conv=4,     # Local convolution width
            expand=2,     # Block expansion factor
        ).to("cuda")
        self.model2 = Mamba(
            # This module uses roughly 3 * expand * d_model^2 parameters
            d_model=256,  # Model dimension d_model
            d_state=128,  # SSM state expansion factor
            d_conv=4,     # Local convolution width
            expand=2,     # Block expansion factor
        ).to("cuda")
        # self.model3 = Mamba(
        #     # This module uses roughly 3 * expand * d_model^2 parameters
        #     d_model=256,  # Model dimension d_model
        #     d_state=128,  # SSM state expansion factor
        #     d_conv=4,     # Local convolution width
        #     expand=2,     # Block expansion factor
        # ).to("cuda")
        # y = model(x)

    def forward(self, q_embed_data, qa_embed_data):
        seqlen, batch_size = q_embed_data.size(1), q_embed_data.size(0)

        qa_pos_embed = qa_embed_data
        q_pos_embed = q_embed_data

        y = qa_pos_embed
        x = q_pos_embed

        # Block 1: Time-decay attention
        # (the attention call is bypassed; self.model2 is applied twice per iteration instead)
        for block in self.blocks_1:
            # y = block(mask=1, query=y, key=y, values=y, attention_type=1)
            y = self.model2(y)
            y = self.model2(y)

        # Block 2: Mix of attention types
        flag_first = True
        for block in self.blocks_2:
            if flag_first:  # Current question: Time-decay attention (replaced by self.model1)
                # x = block(mask=1, query=x, key=x, values=x, apply_pos=False, attention_type=1)
                x = self.model1(x)
                flag_first = False
            else:  # Don't peek response: Time-decay + Sparse attention
                x = block(mask=0, query=x, key=x, values=y, apply_pos=True, attention_type=1)
                flag_first = True
        return x
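

# Minimal usage sketch (an assumption for illustration, not part of the original code):
# it instantiates Architecture and runs one forward pass on random embeddings.
# The hyper-parameters, batch size, and sequence length below are hypothetical;
# q_embed_data and qa_embed_data are assumed to have shape (batch_size, seqlen, d_model),
# with d_model matching the 256 used by the Mamba blocks. A CUDA device is required,
# since the Mamba blocks are moved to "cuda" in __init__, and TransformerLayer from the
# AKT codebase must be in scope for the blocks_2 branch that still calls attention.
if __name__ == "__main__":
    model = Architecture(n_question=100, n_blocks=1, d_model=256, d_feature=256 // 8,
                         d_ff=1024, n_heads=8, dropout=0.1, kq_same=1,
                         model_type='akt').to("cuda")
    q_embed_data = torch.randn(16, 200, 256, device="cuda")   # question embeddings
    qa_embed_data = torch.randn(16, 200, 256, device="cuda")  # question-answer embeddings
    out = model(q_embed_data, qa_embed_data)
    print(out.shape)  # expected: torch.Size([16, 200, 256])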