Top-level dialect for interfacing PyTorch and MLIR.
This dialect maintains a fairly isomorphic representation with TorchScript.
This dialect also provides transforms that lower it to the “Torch backend contract”, which is an IR form that we present to later conversions. The Torch backend contract significantly simplifies the IR representation and puts it in a form easier for later lowering to work on. Specifically:
Torch any type
Syntax: !torch.any
Represents any Torch type. All other types are subtypes of the Any type.
Torch BoolType
Syntax: !torch.bool
An immutable boolean taking values 0 or 1.
Torch device
Syntax: !torch.Device
!torch.dict[KT, VT]
Torch Dict type with key and value type.
Parameter | C++ type | Description |
---|---|---|
keyType | ::mlir::Type | |
valueType | ::mlir::Type | |
Torch FloatType
Syntax: !torch.float
The float type is used to model the Python `float` type in TorchScript.
Python and TorchScript use 64-bit floating point for this type at runtime.
Note: This type is not used for modeling tensor dtypes.
Torch Generator for producing valsem-random numbers
Syntax: !torch.Generator
Torch IntType
Syntax: !torch.int
The integer type used to model the Python `int` type in TorchScript.
TorchScript itself models this type as a 64-bit signed integer.
Note: This type is not used for modeling tensor dtypes.
Torch packed linear params type
Syntax: !torch.LinearParams
A weight and optional bias, packed into one value.
This is used to model the `__torch__.torch.classes.quantized.LinearPackedParamsBase` custom C++ class type, which is the input to some Torch `quantized::` ops.
We may eventually want a full set of ops that model the `LinearPackedParamsBase` interface, such as `apply`, `apply_relu`, etc. But we are more likely to instead just expand the `quantized::` ops directly and fold away the instances of this type.
The whole `LinearPackedParamsBase` abstraction as it stands in PyTorch is a very library-call-y, runtime-y thing that embodies a number of assumptions about the structure of how the program will be executed, which need not hold for backends.
!torch.list
Parameter | C++ type | Description |
---|---|---|
containedType | ::mlir::Type | |
torch.nn.Module
Represents an instance of a `torch.nn.Module` with the given `className`.
Parameter | C++ type | Description |
---|---|---|
className | ::llvm::StringRef | class name |
Multi-dimensional array modeling Torch’s Tensor type
Syntax:
tensor-type ::= (`!torch.tensor` | `!torch.vtensor`) tensor-modifiers?
tensor-modifiers ::= `<` sizes-spec `,` dtype-spec `>`
sizes-spec ::= `*` | `[` size-list `]`
size-list ::= /*empty*/ | size-list-nonempty
size-list-nonempty ::= size (`,` size)*
size ::= `?` | decimal-literal
dtype-spec ::= `unk` | type
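The grammar above is small enough to parse by hand. The following Python sketch is illustrative only: it is not torch-mlir's actual C++ parser, and the names `ParsedTensorType` and `parse_tensor_type` are hypothetical. It encodes `*` as `sizes=None`, `?` as a `None` entry, and `unk` as `dtype=None`, and it keeps the dtype as raw MLIR type text without checking it against the dtype table. The `has_sizes`/`has_dtype` helpers loosely mirror the C++ `hasSizes()`/`hasDtype()` accessors named in the cheat sheet of this entry.

```python
import re
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class ParsedTensorType:
    value_semantics: bool                 # True for !torch.vtensor, False for !torch.tensor
    sizes: Optional[List[Optional[int]]]  # None encodes `*`; a None entry encodes `?`
    dtype: Optional[str]                  # None encodes `unk`; otherwise raw MLIR type text

    def has_sizes(self) -> bool:
        return self.sizes is not None

    def has_dtype(self) -> bool:
        return self.dtype is not None

_TYPE_RE = re.compile(r"!torch\.(v?tensor)(?:<(.*)>)?")

def _parse_size(token: str) -> Optional[int]:
    token = token.strip()
    if token == "?":
        return None  # unknown size
    if not re.fullmatch(r"[0-9]+", token):
        raise ValueError(f"invalid size: {token!r}")
    return int(token)

def parse_tensor_type(text: str) -> ParsedTensorType:
    m = _TYPE_RE.fullmatch(text.strip())
    if m is None:
        raise ValueError(f"not a Torch tensor type: {text!r}")
    value_semantics = m.group(1) == "vtensor"
    modifiers = m.group(2)
    if modifiers is None:
        # No tensor-modifiers: equivalent to <*,unk>.
        return ParsedTensorType(value_semantics, None, None)
    modifiers = modifiers.strip()
    if modifiers.startswith("*"):
        sizes = None
        rest = modifiers[1:]
    elif modifiers.startswith("["):
        close = modifiers.find("]")
        if close < 0:
            raise ValueError("unterminated size list")
        body = modifiers[1:close].strip()
        sizes = [] if body == "" else [_parse_size(t) for t in body.split(",")]
        rest = modifiers[close + 1:]
    else:
        raise ValueError("sizes-spec must be `*` or a bracketed list")
    rest = rest.strip()
    if not rest.startswith(","):
        raise ValueError("expected `,` between sizes-spec and dtype-spec")
    dtype_text = rest[1:].strip()
    if not dtype_text:
        raise ValueError("missing dtype-spec")
    return ParsedTensorType(value_semantics, sizes,
                            None if dtype_text == "unk" else dtype_text)
```

For example, `parse_tensor_type("!torch.vtensor<[2,?,4],f32>")` yields sizes `[2, None, 4]` and dtype `"f32"`, while `!torch.tensor` and `!torch.tensor<*,unk>` parse to equal values.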
Represents a multi-dimensional array to model Torch’s `torch.Tensor` type.
If the type is `!torch.tensor`, it represents a general unrestricted `torch.Tensor`, including potential mutability, aliasing, etc.
If the type is `!torch.vtensor`, then the tensor is restricted to operations that have value semantics (“v” = “value semantics”). This helps to maintain a strict separation between the value-semantic and potentially-mutating worlds, as one of our main jobs in the compiler is to isolate the mutating parts as much as possible because most lower levels of the compiler stack are expected to require value semantics. E.g., many backend contracts mostly use linalg-on-tensors for compute-heavy ops, and that requires a conversion to the builtin `tensor` type (which has value semantics).
Some notes about value semantics:
- `!torch.tensor` is a subtype of `!torch.vtensor`. Specifically, both types have the same set of values, but `!torch.tensor` additionally allows aliasing or mutating operations.
- Despite being a subtype, `!torch.tensor` carries less static information than a corresponding `!torch.vtensor`. In particular, `!torch.vtensor` carries the static information “not used in aliasing or mutating operations”.
- `!torch.vtensor` can be trivially converted to the builtin `tensor` type when the dtype is known (the builtin `tensor` type does not allow an unknown dtype).
In the absence of the `tensor-modifiers`, the type contains the minimal amount of static information. That is, `!torch.tensor` is equivalent to `!torch.tensor<*,unk>` (and similarly for `!torch.vtensor`).
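As a rough illustration of the conversion rule in the notes above, the hypothetical Python helper below spells the builtin `tensor` type corresponding to a `!torch.vtensor` with the given sizes and dtype. It is a string-level sketch only, assuming MLIR's textual forms `tensor<2x?xf32>` (ranked) and `tensor<*xf32>` (unranked); it copies the element type through unchanged and does not model any element-type adjustments a real lowering might make.

```python
from typing import List, Optional

def to_builtin_tensor(sizes: Optional[List[Optional[int]]], dtype: Optional[str]) -> str:
    """Spell the builtin tensor type for a !torch.vtensor<sizes,dtype>.

    sizes: None for `*`; otherwise one entry per dimension, an int or None for `?`.
    dtype: None for `unk`; otherwise MLIR element type text such as "f32".
    """
    if dtype is None:
        # The builtin tensor type cannot express an unknown dtype.
        raise ValueError("dtype must be known to convert to a builtin tensor")
    if sizes is None:
        # Unknown rank maps to an unranked builtin tensor.
        return f"tensor<*x{dtype}>"
    dims = "".join(("?" if s is None else str(s)) + "x" for s in sizes)
    return f"tensor<{dims}{dtype}>"
```

Note that a bare `!torch.vtensor` (that is, `<*,unk>`) is rejected by this helper, since its dtype is unknown.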
If `sizes-spec` is not `*`, it indicates additional static information about the sizes of the tensor. It will consist of a list of elements, with length equal to the “rank” (in MLIR parlance) or “ndim” (in Torch parlance). Each element represents a size, with the typical MLIR representation of a number for a statically known size and `?` for a size that is unknown. Thus, the lattice consists of `*` as the least static information, followed by lists containing unknown sizes such as `[?,?,?]`, which contribute rank information, followed by statically specified sizes for some dimensions such as `[?,3,?]`, followed by fully statically specified sizes such as `[2,3,4]`.
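One way to make this ordering operational is a function that combines two sizes-specs into the most static information consistent with both, i.e. their greatest lower bound when `*` is the least element. The Python below is a hypothetical sketch, not a torch-mlir API; it uses `None` for `*` and a `None` entry for `?`.

```python
from typing import List, Optional

Sizes = Optional[List[Optional[int]]]  # None encodes `*`; a None entry encodes `?`

def meet_sizes(a: Sizes, b: Sizes) -> Sizes:
    """Most static sizes-spec that both `a` and `b` refine."""
    if a is None or b is None or len(a) != len(b):
        # Unknown or differing rank: only `*` is compatible with both.
        return None
    # Keep a dimension's size only where both sides agree on it.
    return [x if x == y else None for x, y in zip(a, b)]

def at_least_as_static(a: Sizes, b: Sizes) -> bool:
    """True if `a` carries all of the static size information that `b` carries."""
    return meet_sizes(a, b) == b
```

For example, combining `[2,3,4]` with `[2,5,4]` gives `[2,?,4]`, and combining `[2,3]` with `[2,3,4]` gives `*` because the ranks differ.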
If `dtype-spec` is not `unk` (“unknown”), it contains an MLIR type which contributes static information about the dtype of the tensor. Only types allowed by Torch are permitted.
Torch Type | MLIR Type |
---|---|
torch.float16 | f16 |
torch.bfloat16 | bf16 |
torch.float32 | f32 |
torch.float64 | f64 |
torch.uint8 | ui8 |
torch.int8 | si8 |
torch.int16 | si16 |
torch.int32 | si32 |
torch.int64 | si64 |
torch.bool | i1 |
torch.qint8 | !torch.qint8 |
TODO: Support the full set of Torch dtypes. TODO: Use si1?
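The dtype table can be transcribed into a lookup for printing type strings. The Python below is a hypothetical sketch: `TORCH_TO_MLIR_DTYPE` and `format_tensor_type` are illustrative names, the mapping covers only the rows listed in the table, and the space-free output form follows the examples in this entry.

```python
from typing import List, Optional

# Torch dtype -> MLIR type used in dtype-spec, transcribed from the table.
TORCH_TO_MLIR_DTYPE = {
    "torch.float16": "f16",
    "torch.bfloat16": "bf16",
    "torch.float32": "f32",
    "torch.float64": "f64",
    "torch.uint8": "ui8",
    "torch.int8": "si8",
    "torch.int16": "si16",
    "torch.int32": "si32",
    "torch.int64": "si64",
    "torch.bool": "i1",
    "torch.qint8": "!torch.qint8",
}

def format_tensor_type(value_semantics: bool,
                       sizes: Optional[List[Optional[int]]],
                       torch_dtype: Optional[str]) -> str:
    """Print a tensor type; sizes None means `*`, torch_dtype None means `unk`."""
    kind = "!torch.vtensor" if value_semantics else "!torch.tensor"
    if sizes is None and torch_dtype is None:
        # Minimal static information: the bare type is equivalent to <*,unk>.
        return kind
    if sizes is None:
        sizes_spec = "*"
    else:
        sizes_spec = "[" + ",".join("?" if s is None else str(s) for s in sizes) + "]"
    if torch_dtype is None:
        dtype_spec = "unk"
    elif torch_dtype in TORCH_TO_MLIR_DTYPE:
        dtype_spec = TORCH_TO_MLIR_DTYPE[torch_dtype]
    else:
        raise ValueError(f"no MLIR dtype listed for {torch_dtype}")
    return f"{kind}<{sizes_spec},{dtype_spec}>"
```

Dtypes missing from the table, such as `torch.complex64`, are rejected, in line with the TODO about supporting the full set.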
Note: We do not use the C++ identifier `TensorType`, to avoid name ambiguities with `mlir::TensorType`, since most code is transitively nested in both the `::mlir` and `::mlir::torch::Torch` namespaces.
Note: We use the Torch-aligned terminology “sizes” and “dtype” instead of the MLIR-aligned terminology “rank/shape” and “element type”. The cheat sheet is:
- `hasRank()` -> `hasSizes()`
- `getShape()` -> `getSizes()`
- `getElementType()` -> `getDtype()` (but be sure to check `hasDtype()` first)
Parameter | C++ type | Description |
---|---|---|
optionalSizes | ::llvm::Optional<::llvm::ArrayRef<int64_t>> | sizes of dimensions |
optionalDtype | ::mlir::Type | |
Torch NoneType
Syntax: !torch.none
The singleton “None” type.
Torch number type
Syntax: !torch.number
The Int, Float, and Complex types are subtypes of the Number type.
!torch.optional
Parameter | C++ type | Description |
---|---|---|
containedType | ::mlir::Type | |
Type modeling ScalarType::QInt8
Syntax: !torch.qint8
This is intended to be a 1:1 match for the Torch `ScalarType` types.
Looking at the variety / ad-hocness (e.g. `QUInt4x2`) of that set of types, it is deemed preferable to import them as one-off ad-hoc types instead of a single parameterized type.
Torch StringType
Syntax: !torch.str
An immutable string representing a sequence of characters.
TODO: Figure out the exact TorchScript/PyTorch string semantics. E.g. is it always unicode-encoded, etc.
!torch.tuple<T1, T2, T3>
Tuple type with 0-N ordered contained types.
Parameter | C++ type | Description |
---|---|---|
containedTypes | ::llvm::ArrayRef<::mlir::Type> | contained types |
!torch.union<T1, T2, T3>
Union type with 0-N alternative types.
NOTE: We use the terminology “contained types” for consistency with PyTorch. Strictly speaking, the types aren’t “contained” though.
TODO: Canonicalize unions based on subtype relations, to allow using pointer equality to compare two unions for being the same. For now, `!torch.union<T1, T2>` is different from `!torch.union<T2, T1>`, and the same holds for `!torch.union<T1, SubtypeOfT1>` vs. `!torch.union<T1>`.
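The TODO could be approached along the following lines. This is a hypothetical Python sketch rather than anything torch-mlir implements: union members are plain type strings, `is_subtype` is an assumed caller-supplied subtype oracle, and lexicographic order stands in for whatever canonical ordering a real implementation would pick.

```python
from typing import Callable, List

def canonicalize_union(members: List[str],
                       is_subtype: Callable[[str, str], bool]) -> List[str]:
    """Drop duplicates and members strictly subsumed by another member, then sort."""
    unique = sorted(set(members))

    def strictly_subsumed(t: str) -> bool:
        # t is redundant if some other member is a strict supertype of it.
        return any(
            other != t and is_subtype(t, other) and not is_subtype(other, t)
            for other in unique
        )

    return [t for t in unique if not strictly_subsumed(t)]
```

With an oracle recording that `!torch.int` and `!torch.float` are subtypes of `!torch.number`, a union of `!torch.number` and `!torch.int` reduces to just `!torch.number`, and member order no longer affects the result.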
Parameter | C++ type | Description |
---|---|---|
containedTypes | ::llvm::ArrayRef<::mlir::Type> | contained types |
Multi-dimensional array modeling Torch’s Tensor type
Syntax and description are identical to the preceding “Multi-dimensional array modeling Torch’s Tensor type” entry; this entry is the value-semantic `!torch.vtensor` variant.
Parameter | C++ type | Description |
---|---|---|
optionalSizes | ::llvm::Optional<::llvm::ArrayRef<int64_t>> | sizes of dimensions |
optionalDtype | ::mlir::Type | |
The `xten` dialect is an IR…
xten.add_constant (xilinx::xten::AddConstantOp)
add one operator
Operand | Description |
---|---|
src | Any Torch tensor type |
c | any type |
Result | Description |
---|---|
output | Any Torch tensor type |
xten.add (xilinx::xten::AddOp)
add operator
Operand | Description |
---|---|
input0 | Any Torch tensor type |
input1 | Any Torch tensor type |
Result | Description |
---|---|
output | Any Torch tensor type |
xten.concat (xilinx::xten::ConcatOp)
Concat operator
Interfaces: NoSideEffect (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{}
Operand | Description |
---|---|
inputs | Any Torch tensor type |
dim | scalar |
Result | Description |
---|---|
«unnamed» | Any Torch tensor type |
xten.conv2d_bn_relu (xilinx::xten::Conv2dBatchNormReLUOp)
Convolution BatchNorm ReLU operator
Fused Convolution BatchNorm ReLU operator
Interfaces: NoSideEffect (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{}
Operand | Description |
---|---|
input | Any Torch tensor type |
weight | Any Torch tensor type |
bias | Optional torch tensor type |
stride | !torch.list |
padding | !torch.list |
dilation | !torch.list |
groups | Torch IntType |
bn_weight | Any Torch tensor type |
bn_bias | Any Torch tensor type |
running_mean | Any Torch tensor type |
running_var | Any Torch tensor type |
training | Torch BoolType |
momentum | Torch FloatType |
eps | Torch FloatType |
Result | Description |
---|---|
«unnamed» | Any Torch tensor type |
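If `xten.conv2d_bn_relu` follows PyTorch's batch-norm semantics (an assumption; this entry does not say so), then with `training` false the running statistics are used and the batch norm is a per-output-channel affine map, so it can be folded into the convolution's weight and bias before the ReLU. In PyTorch, `momentum` only affects updates to the running statistics in training mode, so it plays no part here. The hypothetical Python sketch below shows the folding on plain lists, with one flattened kernel per output channel.

```python
import math
from typing import List, Optional, Tuple

def fold_batch_norm(
    weight: List[List[float]],    # conv weight, one flattened kernel per output channel
    bias: Optional[List[float]],  # conv bias, or None when the optional operand is absent
    bn_weight: List[float],       # per-channel scale (gamma)
    bn_bias: List[float],         # per-channel shift (beta)
    running_mean: List[float],
    running_var: List[float],
    eps: float,
) -> Tuple[List[List[float]], List[float]]:
    """Return (weight', bias') so that conv(x, weight', bias') equals
    batch_norm(conv(x, weight, bias)) in inference mode, up to rounding."""
    folded_weight, folded_bias = [], []
    for c, kernel in enumerate(weight):
        scale = bn_weight[c] / math.sqrt(running_var[c] + eps)
        conv_bias = 0.0 if bias is None else bias[c]
        folded_weight.append([w * scale for w in kernel])
        folded_bias.append((conv_bias - running_mean[c]) * scale + bn_bias[c])
    return folded_weight, folded_bias
```

The ReLU is then applied unchanged to the output of the folded convolution.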
xten.conv2d_lrelu_maxpool (xilinx::xten::Conv2dLReLUMaxPoolOp)
Convolution with Leaky ReLU plus MaxPool operator
Fused Convolution followed by Leaky ReLU activation followed by compatible MaxPool operator
Interfaces: NoSideEffect (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{}
Operand | Description |
---|---|
input | Any Torch tensor type |
weight | Any Torch tensor type |
bias | Optional torch tensor type |
stride | !torch.list |
padding | !torch.list |
dilation | !torch.list |
groups | Torch IntType |
alpha | Torch FloatType |
mp_kernel_size | !torch.list |
mp_stride | !torch.list |
mp_padding | !torch.list |
mp_dilation | !torch.list |
mp_ceil_mode | Torch BoolType |
Result | Description |
---|---|
output | Any Torch tensor type |
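Assuming the `mp_*` operands follow `torch.nn.functional.max_pool2d` semantics (not stated in this entry), the pooling stage's output size along each spatial dimension can be computed as in the hypothetical sketch below. It follows PyTorch's pooling shape rule, including the ceil-mode adjustment that drops a final window which would start in the right-hand padding.

```python
def pool_output_size(in_size: int, kernel: int, stride: int,
                     padding: int, dilation: int, ceil_mode: bool) -> int:
    """Output length of a max-pool window sweep along one spatial dimension."""
    span = in_size + 2 * padding - dilation * (kernel - 1) - 1
    if ceil_mode:
        span += stride - 1  # round the division up instead of down
    out = span // stride + 1
    if ceil_mode and (out - 1) * stride >= in_size + padding:
        # The last window would start inside the right-hand padding; drop it.
        out -= 1
    return out
```

For instance, a length-5 input with a 2-wide, stride-2 window yields 2 outputs normally and 3 with `mp_ceil_mode` set.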
xten.conv2d_lrelu (xilinx::xten::Conv2dLReLUOp)
Convolution with Leaky ReLU operator
Fused Convolution followed by Leaky ReLU activation operator
Interfaces: NoSideEffect (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{}
Operand | Description |
---|---|
input | Any Torch tensor type |
weight | Any Torch tensor type |
bias | Optional torch tensor type |
stride | !torch.list |
padding | !torch.list |
dilation | !torch.list |
groups | Torch IntType |
alpha | Torch FloatType |
Result | Description |
---|---|
«unnamed» | Any Torch tensor type |
xten.conv2d (xilinx::xten::Conv2dOp)
Convolution operator
Interfaces: NoSideEffect (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{}
Operand | Description |
---|---|
input | Any Torch tensor type |
weight | Any Torch tensor type |
bias | Optional torch tensor type |
stride | !torch.list |
padding | !torch.list |
dilation | !torch.list |
groups | Torch IntType |
Result | Description |
---|---|
result | Any Torch tensor type |
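The operand list matches that of `torch.nn.functional.conv2d`. Assuming `xten.conv2d` shares its semantics (NCHW input, weight laid out as `[C_out, C_in / groups, kH, kW]`, and two-element `stride`, `padding`, and `dilation` lists), the result shape follows the usual convolution arithmetic. The helper names below are hypothetical.

```python
from typing import List, Sequence

def conv_output_size(in_size: int, kernel: int, stride: int,
                     padding: int, dilation: int) -> int:
    """Output length along one spatial dimension (floor division, as in PyTorch)."""
    return (in_size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

def conv2d_output_shape(input_shape: Sequence[int], weight_shape: Sequence[int],
                        stride: Sequence[int], padding: Sequence[int],
                        dilation: Sequence[int], groups: int) -> List[int]:
    """Result shape [N, C_out, H_out, W_out] under the assumed NCHW layout."""
    n, c_in, h, w = input_shape
    c_out, c_in_per_group, kh, kw = weight_shape
    if c_in != c_in_per_group * groups:
        raise ValueError("input channels must equal weight_shape[1] * groups")
    if c_out % groups != 0:
        raise ValueError("output channels must be divisible by groups")
    return [n, c_out,
            conv_output_size(h, kh, stride[0], padding[0], dilation[0]),
            conv_output_size(w, kw, stride[1], padding[1], dilation[1])]
```

For example, a `[1, 3, 224, 224]` input with a `[64, 3, 7, 7]` weight, stride 2, and padding 3 gives `[1, 64, 112, 112]`.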
xten.conv2d_relu (xilinx::xten::Conv2dReLUOp)
Convolution ReLU operator
Fused Convolution ReLU operator
Interfaces: NoSideEffect (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{}
Operand | Description |
---|---|
input | Any Torch tensor type |
weight | Any Torch tensor type |
bias | Optional torch tensor type |
stride | !torch.list |
padding | !torch.list |
dilation | !torch.list |
groups | Torch IntType |
Result | Description |
---|---|
«unnamed» | Any Torch tensor type |
xten.mm (xilinx::xten::MMOp)
matrix multiply operator
Interfaces: NoSideEffect (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{}
Operand | Description |
---|---|
x | Any Torch tensor type |
y | Any Torch tensor type |
Result | Description |
---|---|
«unnamed» | Any Torch tensor type |
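Assuming `xten.mm` follows `torch.mm` (2-D operands, no broadcasting), its result shape follows directly from the operand shapes. The hypothetical helper below uses `None` for a dimension whose size is unknown, like `?` in a sizes-spec.

```python
from typing import List, Optional

Dim = Optional[int]  # None encodes an unknown (`?`) size

def mm_result_shape(x_shape: List[Dim], y_shape: List[Dim]) -> List[Dim]:
    """Shape of x @ y for 2-D x of shape [n, k] and y of shape [k, m]."""
    if len(x_shape) != 2 or len(y_shape) != 2:
        raise ValueError("mm expects 2-D operands")
    n, k_x = x_shape
    k_y, m = y_shape
    # Only reject a mismatch when both inner sizes are statically known.
    if k_x is not None and k_y is not None and k_x != k_y:
        raise ValueError(f"inner dimensions differ: {k_x} vs {k_y}")
    return [n, m]
```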
xten.mul (xilinx::xten::MulOp)
mul operator
Operand | Description |
---|---|
input0 | Any Torch tensor type |
input1 | Any Torch tensor type |
Result | Description |
---|---|
output | Any Torch tensor type |
xten.noop (xilinx::xten::NoOp)
noop returns its input
noop returns its input or a copy of its input
Operand | Description |
---|---|
x | any type |
Result | Description |
---|---|
«unnamed» | any type |
xten.partialconv2d_bn_relu (xilinx::xten::PartialConv2dBatchNormReLUOp)
Partial Convolution BatchNorm ReLU operator
Fused Convolution BatchNorm ReLU operator
Interfaces: NoSideEffect (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{}
Operand | Description |
---|---|
input | Any Torch tensor type |
PartialIn | Optional torch tensor type |
weight | Any Torch tensor type |
bias | Optional torch tensor type |
stride | !torch.list |
padding | !torch.list |
dilation | !torch.list |
groups | Torch IntType |
bn_weight | Any Torch tensor type |
bn_bias | Any Torch tensor type |
running_mean | Any Torch tensor type |
running_var | Any Torch tensor type |
training | Torch BoolType |
momentum | Torch FloatType |
eps | Torch FloatType |
Result | Description |
---|---|
output | Any Torch tensor type |
forward | Optional torch tensor type |
xten.partialconv2d (xilinx::xten::PartialConv2dOp)
Partial convolution operator
Interfaces: NoSideEffect (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{}
Operand | Description |
---|---|
input | Any Torch tensor type |
PartialIn | Optional torch tensor type |
weight | Any Torch tensor type |
bias | Optional torch tensor type |
stride | !torch.list |
padding | !torch.list |
dilation | !torch.list |
groups | Torch IntType |
Result | Description |
---|---|
output | Any Torch tensor type |
forward | Optional torch tensor type |
xten.partialconv2d_relu (xilinx::xten::PartialConv2dReLUOp)
Partial convolution ReLU operator
Interfaces: NoSideEffect (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{}
Operand | Description |
---|---|
input | Any Torch tensor type |
PartialIn | Optional torch tensor type |
weight | Any Torch tensor type |
bias | Optional torch tensor type |
stride | !torch.list |
padding | !torch.list |
dilation | !torch.list |
groups | Torch IntType |
Result | Description |
---|---|
output | Any Torch tensor type |
forward | Optional torch tensor type |
xten.split (xilinx::xten::SplitOp)
split operator
Interfaces: NoSideEffect (MemoryEffectOpInterface)
Effects: MemoryEffects::Effect{}
Operand | Description |
---|---|
input | Any Torch tensor type |
dim | scalar |
Result | Description |
---|---|
outputs | Any Torch tensor type |