## Data Stuff

We take our data set from [The CMU pronouncing dictionary](https://en.wikipedia.org/wiki/CMU_Pronouncing_Dictionary)

In [1]:
%matplotlib inline
import importlib
import utils2; importlib.reload(utils2)
from utils2 import *
np.set_printoptions(4)
PATH = 'data/spellbee/'

Using TensorFlow backend.


In [2]:
limit_mem()

In [3]:
from sklearn.model_selection import train_test_split

In [4]:
lines = [l.strip().split("  ") for l in open(PATH+"cmudict-0.7b", encoding='latin1') 
         if re.match('^[A-Z]', l)]
lines = [(w, ps.split()) for w, ps in lines]
lines[0], lines[-1]

(('A', ['AH0']), ('ZYWICKI', ['Z', 'IH0', 'W', 'IH1', 'K', 'IY0']))

In [5]:
phonemes = ["_"] + sorted(set(p for w, ps in lines for p in ps))
phonemes[:5]

['_', 'AA0', 'AA1', 'AA2', 'AE0']

In [6]:
len(phonemes)

70

In [7]:
p2i = dict((v, k) for k,v in enumerate(phonemes))
letters = "_abcdefghijklmnopqrstuvwxyz*"
l2i = dict((v, k) for k,v in enumerate(letters))

In [8]:
maxlen=15
pronounce_dict = {w.lower(): [p2i[p] for p in ps] for w, ps in lines
                 if (5<=len(w)<=maxlen) and re.match("^[A-Z]+$", w)}
len(pronounce_dict)

108006

In [9]:
a=['xyz','abc']
[o.upper() for o in a if o[0]=='x'], [[p for p in o] for o in a], [p for o in a for p in o]

(['XYZ'], [['x', 'y', 'z'], ['a', 'b', 'c']], ['x', 'y', 'z', 'a', 'b', 'c'])

Split lines into words, phonemes, convert to indexes (with padding), split into training, validation, test sets.

In [10]:
maxlen_p = max([len(v) for k,v in pronounce_dict.items()])

In [11]:
pairs = np.random.permutation(list(pronounce_dict.keys()))
n = len(pairs)
input_ = np.zeros((n, maxlen_p), np.int32)
labels_ = np.zeros((n, maxlen), np.int32)

for i, k in enumerate(pairs):
    for j, p in enumerate(pronounce_dict[k]): input_[i][j] = p
    for j, letter in enumerate(k): labels_[i][j] = l2i[letter]

In [12]:
go_token = l2i["*"]
dec_input_ = np.concatenate([np.ones((n,1)) * go_token, labels_[:,:-1]], axis=1)

In [13]:
(input_train, input_test, labels_train, labels_test, dec_input_train, dec_input_test
    ) = train_test_split(input_, labels_, dec_input_, test_size=0.1)

In [14]:
input_train.shape

(97205, 16)

In [15]:
labels_train.shape

(97205, 15)

In [16]:
input_vocab_size, output_vocab_size = len(phonemes), len(letters)
input_vocab_size, output_vocab_size

(70, 28)

## Keras code

In [17]:
parms = {'verbose': 0, 'callbacks': [TQDMNotebookCallback(leave_inner=True)]}
lstm_params = {}

In [18]:
dim = 240

### Without attention

In [19]:
def get_rnn(return_sequences= True): 
    return LSTM(dim, dropout_U= 0.1, dropout_W= 0.1, 
               consume_less= 'gpu', return_sequences=return_sequences)

In [20]:
inp = Input((maxlen_p,))
x = Embedding(input_vocab_size, 120)(inp)

x = Bidirectional(get_rnn())(x)
x = get_rnn(False)(x)

x = RepeatVector(maxlen)(x)
x = get_rnn()(x)
x = get_rnn()(x)
x = TimeDistributed(Dense(output_vocab_size, activation='softmax'))(x)

In [21]:
model = Model(inp, x)

In [22]:
model.compile(Adam(), 'sparse_categorical_crossentropy', metrics=['acc'])

In [23]:
hist=model.fit(input_train, np.expand_dims(labels_train,-1), 
          validation_data=[input_test, np.expand_dims(labels_test,-1)], 
          batch_size=64, **parms, nb_epoch=3)

KeyboardInterrupt: 

In [None]:
hist.history['val_loss']

In [None]:
hist.history['val_loss']

In [None]:
def eval_keras(input):
    preds = model.predict(input, batch_size=128)
    predict = np.argmax(preds, axis = 2)
    return (np.mean([all(real==p) for real, p in zip(labels_test, predict)]), predict)

In [None]:
acc, preds = eval_keras(input_test); acc

In [51]:
def print_examples(preds):
    print("pronunciation".ljust(40), "real spelling".ljust(17), 
          "model spelling".ljust(17), "is correct")

    for index in range(20):
        ps = "-".join([phonemes[p] for p in input_test[index]]) 
        real = [letters[l] for l in labels_test[index]] 
        predict = [letters[l] for l in preds[index]]
        print (ps.split("-_")[0].ljust(40), "".join(real).split("_")[0].ljust(17),
            "".join(predict).split("_")[0].ljust(17), str(real == predict))

In [52]:
print_examples(preds)

pronunciation                            real spelling     model spelling    is correct
OW1-SH-AA0-F                             oshaf             ossaf             False
R-IH1-NG-K-AH0-L-D                       wrinkled          rinkkld           False
D-IH0-T-EH1-K-SH-AH0-N                   detection         detection         True
Y-AE1-S-IH0-N                            yassin            yasin             False
AA1-P-HH-AY2-M                           opheim            ophime            False
JH-AH1-S-K-OW0                           jusco             jusko             False
B-L-UW1-P-EH2-N-S-AH0-L-IH0-NG           bluepencilling    bluopensiinn      False
L-AE1-R-OW0                              laroe             larro             False
L-AO1-D-AH0-T-AO2-R-IY0                  laudatory         laudttrr          False
P-ER2-T-ER0-B-EY1-SH-AH0-N-Z             perturbations     perterbations     False
EH1-V-AH0-D-AH0-N-S-AH0-Z                evidences         evedences         False


### Attention model

<img src="https://smerity.com/media/images/articles/2016/bahdanau_attn.png" width="600">

In [62]:
import attention_wrapper; importlib.reload(attention_wrapper)
from attention_wrapper import Attention

In [66]:
inp = Input((maxlen_p,))
inp_dec = Input((maxlen,))
emb_dec = Embedding(output_vocab_size, 120)(inp_dec)
emb_dec = Dense(dim)(emb_dec)

x = Embedding(input_vocab_size, 120)(inp)
x = Bidirectional(get_rnn())(x)
x = get_rnn()(x)
x = get_rnn()(x)
x = Attention(get_rnn, 3)([x, emb_dec])
x = TimeDistributed(Dense(output_vocab_size, activation='softmax'))(x)

In [67]:
model = Model([inp, inp_dec], x)
model.compile(Adam(), 'sparse_categorical_crossentropy', metrics=['acc'])

In [68]:
hist=model.fit([input_train, dec_input_train], np.expand_dims(labels_train,-1), 
          validation_data=[[input_test, dec_input_test], np.expand_dims(labels_test,-1)], 
          batch_size=64, **parms, nb_epoch=3)






In [25]:
hist.history['val_loss']

[1.1220142268966762,
 0.8048337527904541,
 0.5265462693931019,
 0.31194811385601234,
 0.23548489567190725,
 0.22613589595439379,
 0.19347934096944011,
 0.19403484622484754,
 0.18044625614216103,
 0.17990882804400168]

In [998]:
K.set_value(model.optimizer.lr, 1e-4)

In [999]:
hist=model.fit([input_train, dec_input_train], np.expand_dims(labels_train,-1), 
          validation_data=[[input_test, dec_input_test], np.expand_dims(labels_test,-1)], 
          batch_size=64, **parms, nb_epoch=5)




In [1000]:
np.array(hist.history['val_loss'])

array([ 0.1591,  0.1563,  0.1532,  0.1517,  0.1499])

In [1001]:
def eval_keras():
    preds = model.predict([input_test, dec_input_test], batch_size=128)
    predict = np.argmax(preds, axis = 2)
    return (np.mean([all(real==p) for real, p in zip(labels_test, predict)]), predict)

In [895]:
acc, preds = eval_keras(); acc

0.51134154244977315

In [896]:
print("pronunciation".ljust(40), "real spelling".ljust(17), 
      "model spelling".ljust(17), "is correct")

for index in range(20):
    ps = "-".join([phonemes[p] for p in input_test[index]]) 
    real = [letters[l] for l in labels_test[index]] 
    predict = [letters[l] for l in preds[index]]
    print (ps.split("-_")[0].ljust(40), "".join(real).split("_")[0].ljust(17),
        "".join(predict).split("_")[0].ljust(17), str(real == predict))

pronunciation                            real spelling     model spelling    is correct
D-AY0-AE1-F-AH0-N-IH0-S                  diaphanous        diaphoneus        False
T-IH1-R-IY0                              teary             tiary             False
M-IH1-T-AH0-N                            mittan            mitton            False
P-AA1-R-N-AH0-S                          parness           parnass           False
SH-UW1-ER0-M-AH0-N                       schuermann        schuerman         False
P-AA1-R-T-AH0-Z-AH0-N-SH-IH2-P           partisanship      partisanship      True
AE1-SH-IH0-Z                             ashes             ashes             True
S-P-IY1-K                                speak             speek             False
K-AO1-R-B-IH0-T                          corbitt           corbit            False
B-IH1-L-D-ER0-Z                          builders          biilders          False
EH0-S-P-R-IY1                            esprit            espree            False
E

## Recurrentshop

In [629]:
import seq2seq
from seq2seq.models import Seq2Seq

In [631]:
s2s_model = Seq2Seq(batch_input_shape=(64, maxlen_p, embedding_dim), hidden_dim=embedding_dim, 
                output_length=maxlen, output_dim=embedding_dim, depth=3, peek=True)

In [646]:
inp = Input((maxlen_p,))
inp_dec = Input((maxlen,))
x = Embedding(input_vocab_size, embedding_dim)(inp)
x = s2s_model(x)
x = TimeDistributed(Dense(output_vocab_size, activation='softmax'))(x)

In [647]:
model = Model(inp, x)

In [648]:
model.output_shape

((None, 256), 15, 256)

In [636]:
model.compile('adam', 'sparse_categorical_crossentropy', metrics=['acc'])

In [639]:
hist=model.fit(input_train, np.expand_dims(labels_train,-1), 
          validation_data=[input_test, np.expand_dims(labels_test,-1)], 
          batch_size=64, **parms, nb_epoch=5)

InvalidArgumentError: Incompatible shapes: [53,1024] vs. [64,1024]
	 [[Node: add_467 = Add[T=DT_FLOAT, _device="/job:localhost/replica:0/task:0/gpu:0"](MatMul_42, MatMul_43)]]
	 [[Node: Mean_23/_1337 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/cpu:0", send_device="/job:localhost/replica:0/task:0/gpu:0", send_device_incarnation=1, tensor_name="edge_2557_Mean_23", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/cpu:0"]()]]

Caused by op 'add_467', defined at:
  File "/home/jhoward/anaconda3/lib/python3.5/runpy.py", line 184, in _run_module_as_main
    "__main__", mod_spec)
  File "/home/jhoward/anaconda3/lib/python3.5/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/home/jhoward/anaconda3/lib/python3.5/site-packages/ipykernel/__main__.py", line 3, in <module>
    app.launch_new_instance()
  File "/home/jhoward/anaconda3/lib/python3.5/site-packages/traitlets/config/application.py", line 658, in launch_instance
    app.start()
  File "/home/jhoward/anaconda3/lib/python3.5/site-packages/ipykernel/kernelapp.py", line 474, in start
    ioloop.IOLoop.instance().start()
  File "/home/jhoward/anaconda3/lib/python3.5/site-packages/zmq/eventloop/ioloop.py", line 177, in start
    super(ZMQIOLoop, self).start()
  File "/home/jhoward/anaconda3/lib/python3.5/site-packages/tornado/ioloop.py", line 887, in start
    handler_func(fd_obj, events)
  File "/home/jhoward/anaconda3/lib/python3.5/site-packages/tornado/stack_context.py", line 275, in null_wrapper
    return fn(*args, **kwargs)
  File "/home/jhoward/anaconda3/lib/python3.5/site-packages/zmq/eventloop/zmqstream.py", line 440, in _handle_events
    self._handle_recv()
  File "/home/jhoward/anaconda3/lib/python3.5/site-packages/zmq/eventloop/zmqstream.py", line 472, in _handle_recv
    self._run_callback(callback, msg)
  File "/home/jhoward/anaconda3/lib/python3.5/site-packages/zmq/eventloop/zmqstream.py", line 414, in _run_callback
    callback(*args, **kwargs)
  File "/home/jhoward/anaconda3/lib/python3.5/site-packages/tornado/stack_context.py", line 275, in null_wrapper
    return fn(*args, **kwargs)
  File "/home/jhoward/anaconda3/lib/python3.5/site-packages/ipykernel/kernelbase.py", line 276, in dispatcher
    return self.dispatch_shell(stream, msg)
  File "/home/jhoward/anaconda3/lib/python3.5/site-packages/ipykernel/kernelbase.py", line 228, in dispatch_shell
    handler(stream, idents, msg)
  File "/home/jhoward/anaconda3/lib/python3.5/site-packages/ipykernel/kernelbase.py", line 390, in execute_request
    user_expressions, allow_stdin)
  File "/home/jhoward/anaconda3/lib/python3.5/site-packages/ipykernel/ipkernel.py", line 196, in do_execute
    res = shell.run_cell(code, store_history=store_history, silent=silent)
  File "/home/jhoward/anaconda3/lib/python3.5/site-packages/ipykernel/zmqshell.py", line 501, in run_cell
    return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
  File "/home/jhoward/anaconda3/lib/python3.5/site-packages/IPython/core/interactiveshell.py", line 2717, in run_cell
    interactivity=interactivity, compiler=compiler, result=result)
  File "/home/jhoward/anaconda3/lib/python3.5/site-packages/IPython/core/interactiveshell.py", line 2821, in run_ast_nodes
    if self.run_code(code, result):
  File "/home/jhoward/anaconda3/lib/python3.5/site-packages/IPython/core/interactiveshell.py", line 2881, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-632-ab81b0c65b90>", line 5, in <module>
    x = s2s_model(x)
  File "/home/jhoward/anaconda3/lib/python3.5/site-packages/keras/engine/topology.py", line 572, in __call__
    self.add_inbound_node(inbound_layers, node_indices, tensor_indices)
  File "/home/jhoward/anaconda3/lib/python3.5/site-packages/keras/engine/topology.py", line 635, in add_inbound_node
    Node.create_node(self, inbound_layers, node_indices, tensor_indices)
  File "/home/jhoward/anaconda3/lib/python3.5/site-packages/keras/engine/topology.py", line 166, in create_node
    output_tensors = to_list(outbound_layer.call(input_tensors[0], mask=input_masks[0]))
  File "/home/jhoward/anaconda3/lib/python3.5/site-packages/keras/engine/topology.py", line 2247, in call
    output_tensors, output_masks, output_shapes = self.run_internal_graph(inputs, masks)
  File "/home/jhoward/anaconda3/lib/python3.5/site-packages/keras/engine/topology.py", line 2390, in run_internal_graph
    computed_mask))
  File "/home/jhoward/anaconda3/lib/python3.5/site-packages/recurrentshop/engine.py", line 320, in call
    initial_states = self.get_initial_states(x)
  File "/home/jhoward/anaconda3/lib/python3.5/site-packages/recurrentshop/engine.py", line 388, in get_initial_states
    input = layer._step(input, layer_initial_states)[0]
  File "/home/jhoward/anaconda3/lib/python3.5/site-packages/recurrentshop/engine.py", line 116, in _step
    return self.step(*args)
  File "/home/jhoward/anaconda3/lib/python3.5/site-packages/recurrentshop/cells.py", line 130, in step
    z = K.dot(x, W) + K.dot(h_tm1, U) + b
  File "/home/jhoward/anaconda3/lib/python3.5/site-packages/tensorflow/python/ops/math_ops.py", line 884, in binary_op_wrapper
    return func(x, y, name=name)
  File "/home/jhoward/anaconda3/lib/python3.5/site-packages/tensorflow/python/ops/gen_math_ops.py", line 73, in add
    result = _op_def_lib.apply_op("Add", x=x, y=y, name=name)
  File "/home/jhoward/anaconda3/lib/python3.5/site-packages/tensorflow/python/framework/op_def_library.py", line 763, in apply_op
    op_def=op_def)
  File "/home/jhoward/anaconda3/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 2395, in create_op
    original_op=self._default_original_op, op_def=op_def)
  File "/home/jhoward/anaconda3/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 1264, in __init__
    self._traceback = _extract_stack()

InvalidArgumentError (see above for traceback): Incompatible shapes: [53,1024] vs. [64,1024]
	 [[Node: add_467 = Add[T=DT_FLOAT, _device="/job:localhost/replica:0/task:0/gpu:0"](MatMul_42, MatMul_43)]]
	 [[Node: Mean_23/_1337 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/cpu:0", send_device="/job:localhost/replica:0/task:0/gpu:0", send_device_incarnation=1, tensor_name="edge_2557_Mean_23", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/cpu:0"]()]]


          1518/|/[loss: 1.575, acc: 0.542] 100%|| 1518/1519 [02:40<00:00, 10.88it/s]

## Tensorflow code

In [13]:
ops.reset_default_graph()
try: sess.close()
except: pass
sess = tf.InteractiveSession()

In [21]:
import tensorflow.contrib.seq2seq as seq2seq
from tensorflow.contrib.layers import safe_embedding_lookup_sparse as embedding_lookup_unique
from tensorflow.contrib import rnn

In [15]:
data_test  = list(zip(input_test, labels_test))
data_train = list(zip(input_train, labels_train))

In [16]:
input_seq_length = maxlen_p  # Max number of phonemes
output_seq_length = maxlen   # Max number of letters in spelled word
batch_size = 128
embedding_dim = 256

In [17]:
encode_input = [tf.placeholder(tf.int32, shape=(None,), name = "ei_%i" %i)
                                for i in range(input_seq_length)]
labels = [tf.placeholder(tf.int32, shape=(None,), name = "l_%i" %i)
                                for i in range(output_seq_length)]
decode_input = [tf.zeros_like(encode_input[0], dtype=np.int32, name="GO")] + labels[:-1]

In [22]:
keep_prob = tf.placeholder("float")
cells = [rnn.DropoutWrapper(
        rnn.BasicLSTMCell(embedding_dim), output_keep_prob=keep_prob
    ) for i in range(3)]
stacked_lstm = rnn.MultiRNNCell(cells)

In [509]:
with tf.variable_scope("decoders") as scope:
    decode_outputs, decode_state = seq2seq.simple_decoder_fn_inference(
        encode_input, decode_input, stacked_lstm, input_vocab_size, output_vocab_size, 64)
    scope.reuse_variables()
    decode_outputs_test, decode_state_test = seq2seq.embedding_attention_seq2seq(
        encode_input, decode_input, stacked_lstm, input_vocab_size, output_vocab_size, 64,
    feed_previous=True)

In [510]:
loss_weights = [tf.ones_like(l, dtype=tf.float32) for l in labels]
loss = seq2seq.sequence_loss(decode_outputs, labels, loss_weights, True)
train_op = optimizer.minimize(loss)

In [512]:
optimizer = tf.train.AdamOptimizer(1e-3)

In [513]:
sess.run(tf.global_variables_initializer())

## Training model

In [514]:
import batcher; importlib.reload(batcher)
from batcher import Batcher

In [515]:
def trans(x): return x[0].T, x[1].T

In [516]:
(''.join(letters[i] for i in labels_train[0]), 
' '.join(phonemes[i] for i in input_train[0]))

('pfundstein______', 'F AH1 N D S T IY2 N _ _ _ _ _ _ _ _')

In [517]:
train_iter = Batcher(input_train, labels_train, shuffle=False, proc_fn=trans)
test_iter = Batcher(input_test, labels_test, shuffle=True, proc_fn=trans)

In [518]:
def get_feed(X, Y, p):
    feed_dict = {encode_input[t]: X[t] for t in range(input_seq_length)}
    feed_dict.update({labels[t]: Y[t] for t in range(output_seq_length)})
    feed_dict[keep_prob] = p
    return feed_dict

In [519]:
def train_batch(X, Y):
    feed_dict = get_feed(X, Y, 0.5)
    return sess.run([train_op, loss], feed_dict)[1]

In [520]:
def get_eval_batch_data(X, Y):
    feed_dict = get_feed(X, Y, 1.)
    [eval_loss, *d_output] = sess.run([loss] + decode_outputs_test, feed_dict)
    decode_output = np.array(d_output).transpose([1,0,2])
    return eval_loss, decode_output

In [521]:
def eval_batch(X, Y):
    eval_loss, output= get_eval_batch_data(X, Y)
    predict = np.argmax(output, axis = 2)
    acc = [all(real==p) for real, p in zip(Y.T, predict)]
    return eval_loss, np.mean(acc)

In [522]:
def eval_batches(data_iter, num_batches):
    res = [eval_batch(*next(data_iter)) for i in range(num_batches)]
    return list(map(np.mean, zip(*res)))

In [523]:
def eval_test():
    test_loss, test_predict = eval_batches(test_iter, 16)
    train_loss, train_predict = eval_batches(train_iter, 16)
    print (test_loss, test_predict * 100)
    print (train_loss, train_predict * 100)
    print()

In [525]:
fit_gen(train_iter, train_batch, eval_test, 10000)



TypeError: 'bool' object is not iterable

## Examining model outputs

In [277]:
X, Y = next(test_iter)
eval_loss, output = get_eval_batch_data(X, Y)

In [278]:
print("pronunciation".ljust(40), "real spelling".ljust(17), 
      "model spelling".ljust(17), "is correct")

for index in range(len(output)):
    ps = "-".join([phonemes[p] for p in X.T[index]]) 
    real = [letters[l] for l in Y.T[index]] 
    predict = [letters[l] for l in np.argmax(output, axis = 2)[index]]
    
    print (ps.split("-_")[0].ljust(40), "".join(real).split("_")[0].ljust(17),
        "".join(predict).split("_")[0].ljust(17), str(real == predict))

pronunciation                            real spelling     model spelling    is correct
B-AY1-R-AH0-N                            biron             biran             False
R-AY1-NG-G-OW0-L-D                       reingold          rinegold          False
S-AH1-N-D-ER0-L-IH0-N                    sunderlin         sunderlin         True
S-T-AE1-CH                               stach             statch            False
T-R-IY1-P-IY0                            tripi             treepy            False
R-AH1-G-AH0-L-Z                          ruggles           ruggles           True
M-EH0-N-D-OW1-Z-AH0                      mendoza           mendoza           True
P-EH1-R-AH0-S-AY2-T                      parasite          parasite          True
K-AA1-G-L-IY0                            cogley            cogley            True
K-AA1-N-T-AH0-M-P-L-EY2-T-IH0-D          contemplated      contimplated      False
M-IH1-K-IY0                              mickie            micky             False
HH-O

## End

In [301]:
nb_samples, nb_time, input_dim, output_dim = (64, 4, 32, 48)

In [302]:
x = tf.placeholder(np.float32, (nb_samples, nb_time, input_dim))

In [303]:
xr = K.reshape(x,(-1,nb_time,1,input_dim))

In [304]:
W1 = tf.placeholder(np.float32, (input_dim, input_dim)); W1.shape

TensorShape([Dimension(32), Dimension(32)])

In [305]:
W1r = K.reshape(W1, (1, input_dim, input_dim))

In [306]:
W1r2 = K.reshape(W1, (1, 1, input_dim, input_dim))

In [307]:
xW1 = K.conv1d(x,W1r,border_mode='same'); xW1.shape

TensorShape([Dimension(64), Dimension(4), Dimension(32)])

In [308]:
xW12 = K.conv2d(xr,W1r2,border_mode='same'); xW12.shape

TensorShape([Dimension(64), Dimension(4), Dimension(1), Dimension(32)])

In [251]:
xW2 = K.dot(x, W1)

In [245]:
x1 = np.random.normal(size=(nb_samples, nb_time, input_dim))

In [246]:
w1 = np.random.normal(size=(input_dim, input_dim))

In [248]:
res = sess.run(xW1, {x:x1, W1:w1})

In [252]:
res2 = sess.run(xW2, {x:x1, W1:w1})

In [253]:
np.allclose(res, res2)

True

In [283]:
W2 = tf.placeholder(np.float32, (output_dim, input_dim)); W2.shape

TensorShape([Dimension(48), Dimension(32)])

In [295]:
h = tf.placeholder(np.float32, (nb_samples, output_dim))

In [296]:
hW2 = K.dot(h,W2); hW2.shape

TensorShape([Dimension(64), Dimension(32)])

In [297]:
hW2 = K.reshape(hW2,(-1,1,1,input_dim)); hW2.shape

TensorShape([Dimension(64), Dimension(1), Dimension(1), Dimension(32)])