@@ -2478,7 +2478,7 @@ metadata: !LinalgOpMetadata
2478
2478
The partial multiplication results are reduced into a 2D output.
2479
2479
2480
2480
Numeric casting is performed on the operands to the inner multiply, promoting
2481
- them to the same data type as the accumulator/output."
2481
+ them to the same data type as the accumulator/output.
2482
2482
implements :
2483
2483
- LinalgContractionOpInterface
2484
2484
structured_op : !LinalgStructuredOpConfig
@@ -4096,38 +4096,39 @@ structured_op: !LinalgStructuredOpConfig
4096
4096
name : I
4097
4097
kind : input_tensor
4098
4098
type_var : T1
4099
- shape_map : affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s9, s1 *
4100
- s2 + s3 * s4, s5 * s6 + s7 * s8 )>
4099
+ shape_map : affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1, s2
4100
+ * s3 + s4 * s5, s6 * s7 + s8 * s9 )>
4101
4101
- !LinalgOperandDefConfig
4102
4102
name : K
4103
4103
kind : input_tensor
4104
4104
type_var : T2
4105
- shape_map : affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s9, s3, s7 )>
4105
+ shape_map : affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s1, s4, s8 )>
4106
4106
- !LinalgOperandDefConfig
4107
4107
name : O
4108
4108
kind : output_tensor
4109
4109
type_var : U
4110
- shape_map : affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s9, s1, s5)>
4110
+ shape_map : affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s0, s1, s2,
4111
+ s6)>
4111
4112
- !LinalgOperandDefConfig
4112
4113
name : strides
4113
4114
kind : index_attr
4114
- index_attr_map : affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s2 ,
4115
- s6 )>
4115
+ index_attr_map : affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s3 ,
4116
+ s7 )>
4116
4117
default_indices :
4117
4118
- 1
4118
4119
- 1
4119
4120
- !LinalgOperandDefConfig
4120
4121
name : dilations
4121
4122
kind : index_attr
4122
- index_attr_map : affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s4 ,
4123
- s8 )>
4123
+ index_attr_map : affine_map<()[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9] -> (s5 ,
4124
+ s9 )>
4124
4125
default_indices :
4125
4126
- 1
4126
4127
- 1
4127
4128
indexing_maps : !LinalgIndexingMapsConfig
4128
4129
static_indexing_maps :
4129
4130
- affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9]
4130
- -> (d0, d3, d1 * s2 + d4 * s4 , d2 * s6 + d5 * s8 )>
4131
+ -> (d0, d3, d1 * s3 + d4 * s5 , d2 * s7 + d5 * s9 )>
4131
4132
- affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9]
4132
4133
-> (d3, d4, d5)>
4133
4134
- affine_map<(d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6, s7, s8, s9]
@@ -5766,3 +5767,74 @@ structured_op: !LinalgStructuredOpConfig
5766
5767
scalar_const : ' 2.3283063999999999E-10 : f64'
5767
5768
- !ScalarExpression
5768
5769
scalar_arg : min
5770
+ --- !LinalgOpConfig
5771
+ metadata : !LinalgOpMetadata
5772
+ name : linear_relu
5773
+ cpp_class_name : LinearReluOp
5774
+ doc : |-
5775
+ Performs a linear/fully-connected + relu operation
5776
+
5777
+ This is a long description that I'll fill later
5778
+
5779
+ Layout:
5780
+ * I: WH (Input)
5781
+ * W: WH (Weights)
5782
+ * B: H (Bias)
5783
+ structured_op : !LinalgStructuredOpConfig
5784
+ args :
5785
+ - !LinalgOperandDefConfig
5786
+ name : I
5787
+ kind : input_tensor
5788
+ type_var : T1
5789
+ shape_map : affine_map<()[s0, s1, s2] -> (s0, s1)>
5790
+ - !LinalgOperandDefConfig
5791
+ name : W
5792
+ kind : input_tensor
5793
+ type_var : T1
5794
+ shape_map : affine_map<()[s0, s1, s2] -> (s2, s1)>
5795
+ - !LinalgOperandDefConfig
5796
+ name : B
5797
+ kind : input_tensor
5798
+ type_var : T1
5799
+ shape_map : affine_map<()[s0, s1, s2] -> (s2)>
5800
+ - !LinalgOperandDefConfig
5801
+ name : O
5802
+ kind : output_tensor
5803
+ type_var : T1
5804
+ shape_map : affine_map<()[s0, s1, s2] -> (s0, s2)>
5805
+ indexing_maps : !LinalgIndexingMapsConfig
5806
+ static_indexing_maps :
5807
+ - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d1)>
5808
+ - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2, d1)>
5809
+ - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2)>
5810
+ - affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0, d2)>
5811
+ iterator_types :
5812
+ - parallel
5813
+ - reduction
5814
+ - parallel
5815
+ assignments :
5816
+ - !ScalarAssign
5817
+ arg : O
5818
+ value : !ScalarExpression
5819
+ scalar_fn :
5820
+ kind : binary
5821
+ fn_name : add
5822
+ operands :
5823
+ - !ScalarExpression
5824
+ scalar_arg : O
5825
+ - !ScalarExpression
5826
+ scalar_fn :
5827
+ kind : binary
5828
+ fn_name : add
5829
+ operands :
5830
+ - !ScalarExpression
5831
+ scalar_fn :
5832
+ kind : binary
5833
+ fn_name : mul
5834
+ operands :
5835
+ - !ScalarExpression
5836
+ scalar_arg : I
5837
+ - !ScalarExpression
5838
+ scalar_arg : W
5839
+ - !ScalarExpression
5840
+ scalar_arg : B
0 commit comments