public class AttentionVertex extends SameDiffVertex

| Modifier and Type | Class and Description |
|---|---|
| static class | AttentionVertex.Builder |

| Modifier and Type | Field and Description |
|---|---|
| protected WeightInit | weightInit |

Fields inherited from class SameDiffVertex: biasUpdater, dataType, gradientNormalization, gradientNormalizationThreshold, regularization, regularizationBias, updater

| Modifier | Constructor and Description |
|---|---|
| protected | AttentionVertex(AttentionVertex.Builder builder) |

| Modifier and Type | Method | Description |
|---|---|---|
| AttentionVertex | clone() | |
| void | defineParametersAndInputs(SDVertexParams params) | Define the parameters (and inputs) for the network. |
| SDVariable | defineVertex(SameDiff sameDiff, Map<String,SDVariable> layerInput, Map<String,SDVariable> paramTable, Map<String,SDVariable> maskVars) | Define the vertex. |
| Pair<INDArray,MaskState> | feedForwardMaskArrays(INDArray[] maskArrays, MaskState currentMaskState, int minibatchSize) | |
| InputType | getOutputType(int layerIndex, InputType... vertexInputs) | Determine the type of output for this GraphVertex, given the specified inputs. |
| void | initializeParameters(Map<String,INDArray> params) | Set the initial parameter values for this layer, if required. |
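
The summary table does not say what defineVertex builds. Assuming, as the class name suggests, that the vertex applies scaled dot-product attention over its queries, keys and values, the core arithmetic for a single head and a single example can be sketched in plain Java. This is an illustration under that assumption, not DL4J's implementation: the real vertex expresses its computation as SameDiff operations on SDVariable inputs and may add learned projections and multiple heads. The class name DotProductAttentionSketch and the boolean key mask (standing in for the maskVars argument) are hypothetical.

```java
import java.util.Arrays;

/**
 * Hypothetical plain-Java illustration of single-head scaled dot-product attention
 * with a key mask. Not DL4J code: AttentionVertex builds its computation as a SameDiff graph.
 */
class DotProductAttentionSketch {

    /**
     * @param q       queries, shape [nQueries][d]
     * @param k       keys, shape [nKeys][d]
     * @param v       values, shape [nKeys][dv]
     * @param keyMask keyMask[j] == false marks key j as padding (ignored)
     * @return attended values, shape [nQueries][dv]
     */
    static double[][] attend(double[][] q, double[][] k, double[][] v, boolean[] keyMask) {
        int nQ = q.length, nK = k.length, d = q[0].length, dv = v[0].length;
        double scale = 1.0 / Math.sqrt(d);
        double[][] out = new double[nQ][dv];
        for (int i = 0; i < nQ; i++) {
            // Scaled dot-product scores for unmasked keys; track the max for a stable softmax
            double[] scores = new double[nK];
            double max = Double.NEGATIVE_INFINITY;
            for (int j = 0; j < nK; j++) {
                if (!keyMask[j]) continue;
                double s = 0;
                for (int t = 0; t < d; t++) s += q[i][t] * k[j][t];
                scores[j] = s * scale;
                max = Math.max(max, scores[j]);
            }
            // Softmax over unmasked keys only; masked keys get weight 0
            double[] w = new double[nK];
            double sum = 0;
            for (int j = 0; j < nK; j++) {
                if (!keyMask[j]) continue;
                w[j] = Math.exp(scores[j] - max);
                sum += w[j];
            }
            // Weighted sum of values (row stays zero if every key is masked)
            for (int j = 0; j < nK; j++) {
                if (!keyMask[j]) continue;
                double weight = w[j] / sum;
                for (int c = 0; c < dv; c++) out[i][c] += weight * v[j][c];
            }
        }
        return out;
    }

    public static void main(String[] args) {
        double[][] q = {{1, 0}};
        double[][] k = {{1, 0}, {0, 1}, {5, 5}};
        double[][] v = {{10, 0}, {0, 10}, {99, 99}};
        boolean[] mask = {true, true, false}; // third key is padding
        System.out.println(Arrays.deepToString(attend(q, k, v, mask)));
    }
}
```

Masked keys receive zero weight, which is the usual role of an input mask in attention; if every key is masked, this sketch leaves the output row at zero rather than dividing by zero.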

Methods inherited from class SameDiffVertex: applyGlobalConfig, applyGlobalConfigToLayer, getGradientNormalization, getGradientNormalizationThreshold, getLayerName, getMemoryReport, getRegularizationByParam, getUpdaterByParam, getVertexParams, instantiate, isPretrainParam, maxVertexInputs, minVertexInputs, numParams, paramReshapeOrder, setDataType, validateInput

Methods inherited from class GraphVertex: equals, hashCode

protected WeightInit weightInit
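
This page does not state how weightInit is used; presumably initializeParameters relies on it when filling the weight arrays. As an illustration, a Xavier-style scheme (Gaussian with mean 0 and variance 2 / (fanIn + fanOut), which is how DL4J describes WeightInit.XAVIER) can be sketched in plain Java. XavierInitSketch and its xavier method are hypothetical names; DL4J itself works on INDArray rather than double[][].

```java
import java.util.Random;

/** Hypothetical helper illustrating Xavier-style initialization; not DL4J's WeightInit code. */
class XavierInitSketch {

    /** Fills a [fanIn][fanOut] matrix with samples from N(0, 2 / (fanIn + fanOut)). */
    static double[][] xavier(int fanIn, int fanOut, Random rng) {
        double std = Math.sqrt(2.0 / (fanIn + fanOut));
        double[][] w = new double[fanIn][fanOut];
        for (int i = 0; i < fanIn; i++) {
            for (int j = 0; j < fanOut; j++) {
                w[i][j] = rng.nextGaussian() * std; // standard normal scaled to the target std
            }
        }
        return w;
    }

    public static void main(String[] args) {
        double[][] w = xavier(64, 32, new Random(42)); // fixed seed for reproducibility
        System.out.println(w.length + "x" + w[0].length);
    }
}
```

Scaling the variance by both fan-in and fan-out keeps activation and gradient magnitudes roughly balanced across layers, which is the motivation usually given for this scheme.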
protected AttentionVertex(AttentionVertex.Builder builder)
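
The constructor is protected and takes an AttentionVertex.Builder, which suggests instances are meant to be created through the builder rather than directly. The plain-Java sketch below shows that pattern in isolation. It is not the real AttentionVertex.Builder: the builder's setters are not listed on this page, so VertexSketch, WeightScheme, nOut and the validation in build() are all hypothetical; only the weightInit field name comes from this page.

```java
/**
 * Hypothetical stand-in for the "protected constructor + static nested Builder" pattern.
 * Names other than weightInit are assumptions, not the real AttentionVertex.Builder API.
 */
class VertexSketch {
    enum WeightScheme { XAVIER, RELU } // hypothetical stand-in for DL4J's WeightInit

    protected final WeightScheme weightInit;
    protected final int nOut; // hypothetical configuration value

    // Protected: callers are expected to go through Builder.build()
    protected VertexSketch(Builder builder) {
        this.weightInit = builder.weightInit;
        this.nOut = builder.nOut;
    }

    static class Builder {
        private WeightScheme weightInit = WeightScheme.XAVIER; // default when not set
        private int nOut = -1;                                 // must be set explicitly

        Builder weightInit(WeightScheme w) { this.weightInit = w; return this; }
        Builder nOut(int n) { this.nOut = n; return this; }

        VertexSketch build() {
            // Validate once, before the immutable object is constructed
            if (nOut <= 0) throw new IllegalStateException("nOut must be set to a positive value");
            return new VertexSketch(this);
        }
    }

    public static void main(String[] args) {
        VertexSketch v = new VertexSketch.Builder().nOut(16).weightInit(WeightScheme.RELU).build();
        System.out.println(v.nOut + " " + v.weightInit);
    }
}
```

Funnelling construction through build() gives one place to apply defaults and validate the configuration before the vertex object exists.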

public AttentionVertex clone()

Specified by: clone in class GraphVertex

public InputType getOutputType(int layerIndex, InputType... vertexInputs) throws InvalidInputTypeException

Description copied from class: GraphVertex. Determine the type of output for this GraphVertex, given the specified inputs.

Overrides: getOutputType in class SameDiffVertex

Parameters:
layerIndex - The index of the layer (if appropriate/necessary).
vertexInputs - The inputs to this vertex.

Throws:
InvalidInputTypeException - If the input type is invalid for this type of GraphVertex.

public void defineParametersAndInputs(SDVertexParams params)

Description copied from class: SameDiffVertex. Define the parameters (and inputs) for the network. Use SDLayerParams.addWeightParam(String, long...) and SDLayerParams.addBiasParam(String, long...). Note that you must also define (and optionally name) the inputs to the vertex; this is required so that DL4J knows how many inputs exist for the vertex.

Specified by: defineParametersAndInputs in class SameDiffVertex

Parameters:
params - Object used to set parameters for this layer

public void initializeParameters(Map<String,INDArray> params)
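
Description copied from class: SameDiffVertex. Set the initial parameter values for this layer, if required.

Specified by: initializeParameters in class SameDiffVertex

Parameters:
params - Parameter arrays that may be initialized

public Pair<INDArray,MaskState> feedForwardMaskArrays(INDArray[] maskArrays, MaskState currentMaskState, int minibatchSize)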

Overrides: feedForwardMaskArrays in class SameDiffVertex

public SDVariable defineVertex(SameDiff sameDiff, Map<String,SDVariable> layerInput, Map<String,SDVariable> paramTable, Map<String,SDVariable> maskVars)

Description copied from class: SameDiffVertex. Define the vertex.

Specified by: defineVertex in class SameDiffVertex

Parameters:
sameDiff - SameDiff instance
layerInput - Input to the layer; keys as defined by SameDiffVertex.defineParametersAndInputs(SDVertexParams)
paramTable - Parameter table; keys as defined by SameDiffVertex.defineParametersAndInputs(SDVertexParams)
maskVars - Masks of input, if available; keys as defined by SameDiffVertex.defineParametersAndInputs(SDVertexParams)

Copyright © 2021. All rights reserved.