# NDL entry points for this file:
#   load - section whose macro definitions are made available before the network is built
#   run  - section that is executed to create the network
load=ndlMacroDefine
run=ndlCreateNetwork_LSTMP_c1024_p256_x3

ndlMacroDefine=[
    # Macro definitions
    # Per-dimension mean and variance normalization.
    #   x - node to normalize
    # xMean and xStdDev are computed from x itself and fed, together with x,
    # into PerDimMeanVarNormalization. xNorm is the last node defined and is
    # what the caller in this file uses as the normalized features.
    MeanVarNorm(x)
    {
        xMean = Mean(x);
        xStdDev = InvStdDev(x);
        xNorm = PerDimMeanVarNormalization(x, xMean, xStdDev);
    }
 
    # Log of the label prior.
    #   labels - label node; Mean(labels) is used as the prior
    # The node that shares the macro name, LogPrior, carries the result.
    LogPrior(labels)
    {
        Prior = Mean(labels);
        LogPrior = Log(Prior);
    }


    # LSTMPComponent: one recurrent LSTM layer with peephole weights and a
    # linear projection applied to the gated cell output.
    #   inputDim  - number of rows of inputx
    #   outputDim - number of rows of the projected output, which is also the
    #               size of the recurrent input dh
    #   cellDim   - number of rows of the cell state ct and of every gate
    #   inputx    - input node of this layer
    # The last node defined, output, is what callers in this file consume as
    # the result of the macro.
    LSTMPComponent(inputDim, outputDim, cellDim, inputx)
    {
        # Input-to-cell weights, one cellDim x inputDim matrix each for the
        # output gate, input gate, forget gate and cell candidate.
        Wxo = Parameter(cellDim, inputDim, init=uniform, initValueScale=1);
        Wxi = Parameter(cellDim, inputDim, init=uniform, initValueScale=1);
        Wxf = Parameter(cellDim, inputDim, init=uniform, initValueScale=1);
        Wxc = Parameter(cellDim, inputDim, init=uniform, initValueScale=1);

        # Biases for the same four paths, all initialized to zero.
        bo = Parameter(cellDim, init=fixedValue, value=0.0);
        bc = Parameter(cellDim, init=fixedValue, value=0.0);
        bi = Parameter(cellDim, init=fixedValue, value=0.0);
        bf = Parameter(cellDim, init=fixedValue, value=0.0);

        # Recurrent weights W-h-* are cellDim x outputDim and multiply the delayed
        # projected output dh. Peephole weights W-c-* are cellDim vectors that
        # are applied to the cell state through DiagTimes.
        Whi = Parameter(cellDim, outputDim, init=uniform, initValueScale=1);

        Wci = Parameter(cellDim, init=uniform, initValueScale=1);


        Whf = Parameter(cellDim, outputDim, init=uniform, initValueScale=1);
        Wcf = Parameter(cellDim, init=uniform, initValueScale=1);
        Who = Parameter(cellDim, outputDim, init=uniform, initValueScale=1);
        Wco = Parameter(cellDim, init=uniform, initValueScale=1);
        Whc = Parameter(cellDim, outputDim, init=uniform, initValueScale=1);

        # dh and dc are output and ct delayed by one step. Both refer to nodes
        # defined further down in this macro, which closes the recurrent loop.
        dh = Delay(outputDim, output, delayTime=1);
        dc = Delay(cellDim, ct, delayTime=1);


        # Input gate: it = Sigmoid(Wxi*inputx + bi + Whi*dh + Wci.dc)
        Wxix = Times(Wxi, inputx);
        Whidh = Times(Whi, dh);
        Wcidc = DiagTimes(Wci, dc);

        it = Sigmoid (Plus ( Plus (Plus (Wxix, bi), Whidh), Wcidc));

        # Gated cell candidate: bit = it .* Tanh(Wxc*inputx + Whc*dh + bc)
        Wxcx = Times(Wxc, inputx);
        Whcdh = Times(Whc, dh);
        bit = ElementTimes(it, Tanh( Plus(Wxcx, Plus(Whcdh, bc))));

        # Forget gate: ft = Sigmoid(Wxf*inputx + bf + Whf*dh + Wcf.dc)
        Wxfx = Times(Wxf, inputx);
        Whfdh = Times(Whf, dh);
        Wcfdc = DiagTimes(Wcf, dc);

        ft = Sigmoid( Plus (Plus (Plus(Wxfx, bf), Whfdh), Wcfdc));

        # Retained part of the previous cell state.
        bft = ElementTimes(ft, dc);

        # New cell state: retained previous state plus gated candidate.
        ct = Plus(bft, bit);

        # Output gate: ot = Sigmoid(Wxo*inputx + bo + Who*dh + Wco.ct)
        # Unlike the input and forget gates, the peephole term here uses the
        # current cell state ct rather than the delayed dc.
        Wxox  = Times(Wxo, inputx);
        Whodh = Times(Who, dh);
        Wcoct = DiagTimes(Wco, ct);

        ot = Sigmoid( Plus( Plus( Plus(Wxox, bo), Whodh), Wcoct));

        # Gated cell output, cellDim rows.
        mt = ElementTimes(ot, Tanh(ct));

        # Linear projection from cellDim down to outputDim rows, no bias.
        Wmr = Parameter(outputDim, cellDim, init=uniform, initValueScale=1);
        output = Times(Wmr, mt); 

    }

]

# Network definition: a slice of the input features is normalized, passed
# through three stacked LSTMPComponent layers and a linear output layer, and
# trained with a softmax cross-entropy criterion.
ndlCreateNetwork_LSTMP_c1024_p256_x3=[

    # Basic input/output dimensions.
    # FeatDim and labelDim are substituted from the featDim and labelDim
    # variables of the configuration that loads this file.
    baseFeatDim=40
    RowSliceStart=400
    FeatDim=$featDim$
    labelDim=$labelDim$
    cellDim=1024
    # NOTE(review): the section name says p256, but hiddenDim is what every
    # LSTMPComponent below receives as its outputDim, and it is 512 here.
    # Confirm which value is intended.
    hiddenDim=512

    features=Input(FeatDim, tag=feature)
    labels=Input(labelDim, tag=label)

    # Keep baseFeatDim rows of the input, starting at row RowSliceStart.
    # Original author note: shift 5 frames right (x(t+5) -> x(t)).
    # NOTE(review): presumably the input stacks consecutive frames of
    # baseFeatDim rows each, so that row 400 starts the 11th frame - confirm
    # against the reader configuration.
    feashift=RowSlice(RowSliceStart, baseFeatDim, features);

    featNorm = MeanVarNorm(feashift)

    # layer 1: normalized features in, hiddenDim rows out
    LSTMoutput1 = LSTMPComponent(baseFeatDim, hiddenDim, cellDim, featNorm);
    # layer 2
    LSTMoutput2 = LSTMPComponent(hiddenDim, hiddenDim, cellDim, LSTMoutput1);
    # layer 3
    LSTMoutput3 = LSTMPComponent(hiddenDim, hiddenDim, cellDim, LSTMoutput2);

    # Linear output layer: labelDim x hiddenDim weights plus a zero-initialized bias.
    # init is spelled fixedValue to match every other Parameter in this file.
    W = Parameter(labelDim, hiddenDim, init=uniform, initValueScale=1);
    b = Parameter(labelDim, init=fixedValue, value=0);
    LSTMoutputW = Plus(Times(W, LSTMoutput3), b);

    # Training criterion and evaluation metric, both computed on the
    # pre-softmax activations LSTMoutputW.
    cr = CrossEntropyWithSoftmax(labels, LSTMoutputW,tag=Criteria);
    Err = ClassificationError(labels,LSTMoutputW,tag=Eval);

    # Network output: pre-softmax activations minus the log of the label prior.
    logPrior = LogPrior(labels)
    ScaledLogLikelihood=Minus(LSTMoutputW,logPrior,tag=Output)

]

