-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathencoder_decoder.py
More file actions
43 lines (27 loc) · 1003 Bytes
/
encoder_decoder.py
File metadata and controls
43 lines (27 loc) · 1003 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from torch import nn
# class Classifier():
# pass
class Encoder(nn.Module):
    """The encoder interface for the encoder-decoder architecture.

    Concrete encoders subclass this and override :meth:`forward`.
    """

    def __init__(self):
        super().__init__()

    def forward(self, X, *args):
        """Encode the input ``X``; must be overridden by subclasses.

        Args:
            X: the encoder input (its type/shape is defined by the subclass).
            *args: extra positional arguments forwarded by the caller.

        Raises:
            NotImplementedError: always, in this abstract base.
        """
        raise NotImplementedError

    def foward(self, X, *args):
        """Deprecated misspelled alias kept for backward compatibility.

        Delegates to :meth:`forward`; new code should override and call
        ``forward`` (which is what ``nn.Module.__call__`` dispatches to).
        """
        return self.forward(X, *args)
class Decoder(nn.Module):
    """The decoder interface for the encoder-decoder architecture.

    Concrete decoders subclass this and override :meth:`init_state` and
    :meth:`forward`.
    """

    def __init__(self):
        super().__init__()

    def init_state(self, enc_all_outputs, *args):
        """Build the initial decoder state from the encoder's outputs.

        Args:
            enc_all_outputs: whatever the paired encoder's ``forward`` returned.
            *args: extra positional arguments forwarded by the caller.

        Raises:
            NotImplementedError: always, in this abstract base.
        """
        raise NotImplementedError

    def init_sate(self, enc_all_outputs, *args):
        """Deprecated misspelled alias kept for backward compatibility.

        Delegates to :meth:`init_state`, which is the name callers such as an
        encoder-decoder wrapper invoke.
        """
        return self.init_state(enc_all_outputs, *args)

    def forward(self, X, state):
        """Decode ``X`` given ``state``; must be overridden by subclasses.

        Raises:
            NotImplementedError: always, in this abstract base.
        """
        raise NotImplementedError
class EncoderDecoder(nn.Module):
    """Chain an encoder and a decoder into one module.

    The encoder's result seeds the decoder's state via
    ``decoder.init_state``; the decoder then runs on its own input.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        # When these are nn.Module instances, attribute assignment registers
        # them as child modules (parameters, .to(), state_dict, ...).
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, enc_X, dec_X, *args):
        """Run encoder then decoder and return the decoder's first output.

        Args:
            enc_X: input passed to the encoder.
            dec_X: input passed to the decoder.
            *args: extra positional arguments given to both the encoder call
                and ``decoder.init_state``.

        Returns:
            Element ``[0]`` of the decoder's return value (the decoder is
            expected to return an indexable, e.g. ``(output, state)``).
        """
        encoded = self.encoder(enc_X, *args)
        initial_state = self.decoder.init_state(encoded, *args)
        decoder_result = self.decoder(dec_X, initial_state)
        # Only the decoder output is returned; any trailing state is dropped.
        return decoder_result[0]