@@ -745,8 +745,8 @@ class SparseTensorAllocConverter
745
745
const auto resType = getSparseTensorType (op);
746
746
if (!resType.hasEncoding ())
747
747
return failure ();
748
- Location loc = op.getLoc ();
749
748
749
+ Location loc = op.getLoc ();
750
750
// Deal with copy.
751
751
if (op.getCopy ()) {
752
752
auto desc = getDescriptorFromTensorTuple (adaptor.getCopy ());
@@ -768,16 +768,14 @@ class SparseTensorAllocConverter
768
768
return success ();
769
769
}
770
770
771
- // Construct the dim/lvl sizes and the (unused) dim2lvl/lvl2dim buffers.
772
- SmallVector<Value> dimSizesValues;
771
+ if (!resType.isIdentity ()) {
772
+ return rewriter.notifyMatchFailure (
773
+ op, " try run --sparse-reinterpret-map before codegen" );
774
+ }
775
+ // Level size equals to dimension size since lvl2dim map is an identity map.
773
776
SmallVector<Value> lvlSizesValues;
774
- Value dimSizesBuffer;
775
- Value dim2lvlBuffer;
776
- Value lvl2dimBuffer;
777
777
createDimSizes (rewriter, loc, resType, adaptor.getDynamicSizes (),
778
- dimSizesValues);
779
- genMapBuffers (rewriter, loc, resType, dimSizesValues, dimSizesBuffer,
780
- lvlSizesValues, dim2lvlBuffer, lvl2dimBuffer);
778
+ /* dimSizesValues=*/ lvlSizesValues);
781
779
782
780
// Construct allocation for each field.
783
781
Value sizeHint = op.getSizeHint ();
@@ -809,19 +807,17 @@ class SparseTensorEmptyConverter : public OpConversionPattern<tensor::EmptyOp> {
809
807
const auto resType = getSparseTensorType (op);
810
808
if (!resType.hasEncoding ())
811
809
return failure ();
812
- Location loc = op.getLoc ();
813
810
814
- // Construct the dim/lvl sizes and the (unused) dim2lvl/lvl2dim buffers.
815
- SmallVector<Value> dimSizesValues;
811
+ if (!resType.isIdentity ()) {
812
+ return rewriter.notifyMatchFailure (
813
+ op, " try run --sparse-reinterpret-map before codegen" );
814
+ }
815
+
816
+ Location loc = op.getLoc ();
817
+ // Level size equals to dimension size since lvl2dim map is an identity map.
816
818
SmallVector<Value> lvlSizesValues;
817
- Value dimSizesBuffer;
818
- Value dim2lvlBuffer;
819
- Value lvl2dimBuffer;
820
819
createDimSizes (rewriter, loc, resType, adaptor.getDynamicSizes (),
821
- dimSizesValues);
822
- genMapBuffers (rewriter, loc, resType, dimSizesValues, dimSizesBuffer,
823
- lvlSizesValues, dim2lvlBuffer, lvl2dimBuffer);
824
-
820
+ /* dimSizesValues=*/ lvlSizesValues);
825
821
// Construct allocation for each field.
826
822
Value sizeHint; // none
827
823
SmallVector<Value> fields;
0 commit comments