[mlir][Vector] Add vector bitwidth target to xfer op flattening #81966
@@ -19,7 +19,6 @@
#include "mlir/Dialect/Vector/Transforms/LoweringPatterns.h"
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "llvm/ADT/STLExtras.h"

@@ -535,9 +534,17 @@ namespace {
/// memref.collapse_shape on the source so that the resulting
/// vector.transfer_read has a 1D source. Requires the source shape to be
/// already reduced i.e. without unit dims.
/// If `targetVectorBitwidth` is provided, the flattening will only happen if
/// the trailing dimension of the vector read is smaller than the provided
/// bitwidth.
class FlattenContiguousRowMajorTransferReadPattern
    : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern::OpRewritePattern;
public:
  FlattenContiguousRowMajorTransferReadPattern(MLIRContext *context,
                                               unsigned vectorBitwidth,
                                               PatternBenefit benefit)
      : OpRewritePattern<vector::TransferReadOp>(context, benefit),
        targetVectorBitwidth(vectorBitwidth) {}

  LogicalResult matchAndRewrite(vector::TransferReadOp transferReadOp,
                                PatternRewriter &rewriter) const override {

@@ -554,6 +561,12 @@ class FlattenContiguousRowMajorTransferReadPattern
    // If this is already 0D/1D, there's nothing to do.
    if (vectorType.getRank() <= 1)
      return failure();
    if (!vectorType.getElementType().isSignlessIntOrFloat())
      return failure();
    unsigned trailingVectorDimBitwidth =
        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
      return failure();
    if (!vector::isContiguousSlice(sourceType, vectorType))
      return failure();
    // TODO: generalize this pattern, relax the requirements here.
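For example, with the 128-bit target exercised by the tests below, a read of vector<5x4x3x2xi8> has a trailing dimension of 2 x 8 = 16 bits (< 128), so it is still flattened, whereas a read of vector<1x2x6xi32> has a trailing dimension of 6 x 32 = 192 bits (>= 128) and is left untouched.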
@@ -642,6 +655,11 @@ class FlattenContiguousRowMajorTransferReadPattern
        transferReadOp, cast<VectorType>(vector.getType()), flatRead);
    return success();
  }

private:
  // Minimum bitwidth that the trailing vector dimension should have after
  // flattening.
  unsigned targetVectorBitwidth;
};

/// Rewrites contiguous row-major vector.transfer_write ops by inserting

@@ -650,7 +668,12 @@ class FlattenContiguousRowMajorTransferReadPattern
/// already reduced i.e. without unit dims.
class FlattenContiguousRowMajorTransferWritePattern
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern::OpRewritePattern;
public:
  FlattenContiguousRowMajorTransferWritePattern(MLIRContext *context,
                                                unsigned vectorBitwidth,
                                                PatternBenefit benefit)
      : OpRewritePattern<vector::TransferWriteOp>(context, benefit),
        targetVectorBitwidth(vectorBitwidth) {}

  LogicalResult matchAndRewrite(vector::TransferWriteOp transferWriteOp,
                                PatternRewriter &rewriter) const override {

@@ -665,6 +688,12 @@ class FlattenContiguousRowMajorTransferWritePattern
    if (vectorType.getRank() <= 1)
      // Already 0D/1D, nothing to do.
      return failure();
    if (!vectorType.getElementType().isSignlessIntOrFloat())
      return failure();
    unsigned trailingVectorDimBitwidth =
        vectorType.getShape().back() * vectorType.getElementTypeBitWidth();
    if (trailingVectorDimBitwidth >= targetVectorBitwidth)
      return failure();
    if (!vector::isContiguousSlice(sourceType, vectorType))
      return failure();
    int64_t firstContiguousInnerDim =

@@ -702,6 +731,11 @@ class FlattenContiguousRowMajorTransferWritePattern
    rewriter.eraseOp(transferWriteOp);
    return success();
  }

private:
  // Minimum bitwidth that the trailing vector dimension should have after
  // flattening.
  unsigned targetVectorBitwidth;
Comment on lines +736 to +738:

How about update it to
};

/// Base class for `vector.extract/vector.extract_element(vector.transfer_read)`

@@ -917,10 +951,11 @@ void mlir::vector::populateVectorTransferDropUnitDimsPatterns(
}

void mlir::vector::populateFlattenVectorTransferPatterns(
    RewritePatternSet &patterns, PatternBenefit benefit) {
    RewritePatternSet &patterns, unsigned targetVectorBitwidth,
    PatternBenefit benefit) {
  patterns.add<FlattenContiguousRowMajorTransferReadPattern,
               FlattenContiguousRowMajorTransferWritePattern>(
      patterns.getContext(), benefit);
      patterns.getContext(), targetVectorBitwidth, benefit);
  populateShapeCastFoldingPatterns(patterns, benefit);
  populateDropUnitDimWithShapeCastPatterns(patterns, benefit);
}
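As a usage illustration, a caller now threads the target bitwidth through the populate call. This is a minimal sketch, not part of this patch: the wrapper function name and the 128-bit choice are made up for the example, and headers/error handling are abbreviated.

static void applyXferFlattening(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  // Only transfers whose trailing vector dim is narrower than 128 bits get
  // flattened; wider ones are left as they are.
  mlir::vector::populateFlattenVectorTransferPatterns(
      patterns, /*targetVectorBitwidth=*/128, /*benefit=*/1);
  // Apply the collected patterns greedily until fixpoint.
  (void)mlir::applyPatternsAndFoldGreedily(root, std::move(patterns));
}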
@@ -1,4 +1,5 @@
// RUN: mlir-opt %s -test-vector-transfer-flatten-patterns -split-input-file | FileCheck %s
// RUN: mlir-opt %s -test-vector-transfer-flatten-patterns=target-vector-bitwidth=128 -split-input-file | FileCheck %s --check-prefix=CHECK-128B
Comment on lines 1 to +2:

Suggested change. This way, you don't need to add
https://llvm.org/docs/CommandGuide/FileCheck.html#options

[nit] Side note: IMHO, once "prefixes" are involved, "CHECK" in the prefix name becomes noise.

Reply: I'm not sure I follow. I added the check for wider coverage. We can leave tests without check rules for a particular run. The only difference I see is that we could have a single label check that is checked in both runs, but that doesn't seem to be what you meant.

Reply: I was referring to the CHECK lines that are duplicated, e.g. for @transfer_read_0d. But now I see that the duplication is only happening for negative tests, right? Sorry, was reading this in a rush and missed the finer details :(

func.func @transfer_read_dims_match_contiguous(
    %arg : memref<5x4x3x2xi8, strided<[24, 6, 2, 1], offset: ?>>) -> vector<5x4x3x2xi8> {

@@ -16,6 +17,9 @@ func.func @transfer_read_dims_match_contiguous(
// CHECK: %[[VEC2D:.+]] = vector.shape_cast %[[READ1D]] : vector<120xi8> to vector<5x4x3x2xi8>
// CHECK: return %[[VEC2D]]

// CHECK-128B-LABEL: func @transfer_read_dims_match_contiguous
// CHECK-128B: memref.collapse_shape

// -----

func.func @transfer_read_dims_match_contiguous_empty_stride(

@@ -27,13 +31,16 @@ func.func @transfer_read_dims_match_contiguous_empty_stride(
  return %v : vector<5x4x3x2xi8>
}

// CHECK-LABEL: func @transfer_read_dims_match_contiguous_empty_stride
// CHECK-LABEL: func @transfer_read_dims_match_contiguous_empty_stride(
// CHECK-SAME: %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
// CHECK: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]
// CHECK: %[[READ1D:.+]] = vector.transfer_read %[[COLLAPSED]]
// CHECK: %[[VEC2D:.+]] = vector.shape_cast %[[READ1D]] : vector<120xi8> to vector<5x4x3x2xi8>
// CHECK: return %[[VEC2D]]

// CHECK-128B-LABEL: func @transfer_read_dims_match_contiguous_empty_stride(
// CHECK-128B: memref.collapse_shape

// -----

// The shape of the memref and the vector don't match, but the vector is a

@@ -57,6 +64,9 @@ func.func @transfer_read_dims_mismatch_contiguous(
// CHECK: %[[VAL_5:.*]] = vector.shape_cast %[[VAL_4]] : vector<4xi8> to vector<1x1x2x2xi8>
// CHECK: return %[[VAL_5]] : vector<1x1x2x2xi8>

// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_contiguous(
// CHECK-128B: memref.collapse_shape

// -----

func.func @transfer_read_dims_mismatch_non_zero_indices(

@@ -66,7 +76,7 @@ func.func @transfer_read_dims_mismatch_non_zero_indices(
    %m_out: memref<1x2x6xi32>) {
  %c0 = arith.constant 0 : index
  %c0_i32 = arith.constant 0 : i32
  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
    memref<1x43x4x6xi32>, vector<1x2x6xi32>
  vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
    vector<1x2x6xi32>, memref<1x2x6xi32>
Comment on lines +79 to 82:

Something for the future... This example, on a machine with vectors that are 128 bits wide (e.g. Arm), would actually benefit from flattening. With 6 elements we'd use 1.5 vector registers, and with 12 elements we'd use 3. That would be better utilization. Would that make sense as a TODO? (Not asking for it in this patch.)

Reply: Yes, that's the ultimate goal: partial and more targeted flattening, but one step at a time. We first have to flatten the producers/consumers of these xfer ops to make sure we don't generate ops to reshape the vector.
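For concreteness, the arithmetic behind that remark (assuming 128-bit vector registers and the i32 element type used in this test): an un-flattened trailing dimension of vector<1x2x6xi32> carries 6 x 32 = 192 bits, i.e. 1.5 registers per row, while the flattened vector<12xi32> carries 12 x 32 = 384 bits, i.e. exactly 3 full registers.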

@@ -87,6 +97,9 @@ func.func @transfer_read_dims_mismatch_non_zero_indices(
// CHECK: %[[COLLAPSED_OUT:.*]] = memref.collapse_shape %[[M_OUT]] {{\[}}[0, 1, 2]] : memref<1x2x6xi32> into memref<12xi32>
// CHECK: vector.transfer_write %[[READ]], %[[COLLAPSED_OUT]][%[[C_0_IDX]]] {in_bounds = [true]} : vector<12xi32>, memref<12xi32>

// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices(
// CHECK-128B-NOT: memref.collapse_shape

// -----

// The input memref has a dynamic trailing shape and hence is not flattened.

@@ -99,7 +112,7 @@ func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
    %m_out: memref<1x2x6xi32>) {
  %c0 = arith.constant 0 : index
  %c0_i32 = arith.constant 0 : i32
  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
  %2 = vector.transfer_read %m_in[%c0, %idx_1, %idx_2, %c0], %c0_i32 {in_bounds = [true, true, true]} :
    memref<1x?x4x6xi32>, vector<1x2x6xi32>
  vector.transfer_write %2, %m_out[%c0, %c0, %c0] {in_bounds = [true, true, true]} :
    vector<1x2x6xi32>, memref<1x2x6xi32>

@@ -115,6 +128,9 @@ func.func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
// CHECK: %[[SC:.*]] = vector.shape_cast %[[READ]] : vector<1x2x6xi32> to vector<12xi32>
// CHECK: vector.transfer_write %[[SC]], %[[COLLAPSED]]{{.*}} : vector<12xi32>, memref<12xi32>

// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_zero_indices_dynamic_shapes(
// CHECK-128B-NOT: memref.collapse_shape

// -----

func.func @transfer_read_dims_mismatch_non_contiguous(

@@ -130,6 +146,9 @@ func.func @transfer_read_dims_mismatch_non_contiguous(
// CHECK-NOT: memref.collapse_shape
// CHECK-NOT: vector.shape_cast

// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous(
// CHECK-128B-NOT: memref.collapse_shape

// -----

func.func @transfer_read_dims_mismatch_non_contiguous_empty_stride(

@@ -141,10 +160,13 @@ func.func @transfer_read_dims_mismatch_non_contiguous_empty_stride(
  return %v : vector<2x1x2x2xi8>
}

// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous_empty_stride
// CHECK-LABEL: func.func @transfer_read_dims_mismatch_non_contiguous_empty_stride(
// CHECK-NOT: memref.collapse_shape
// CHECK-NOT: vector.shape_cast

// CHECK-128B-LABEL: func @transfer_read_dims_mismatch_non_contiguous_empty_stride(
// CHECK-128B-NOT: memref.collapse_shape

// -----

func.func @transfer_write_dims_match_contiguous(
@@ -155,13 +177,16 @@ func.func @transfer_write_dims_match_contiguous(
  return
}

// CHECK-LABEL: func @transfer_write_dims_match_contiguous
// CHECK-LABEL: func @transfer_write_dims_match_contiguous(
// CHECK-SAME: %[[ARG:[0-9a-zA-Z]+]]: memref<5x4x3x2xi8
// CHECK-SAME: %[[VEC:[0-9a-zA-Z]+]]: vector<5x4x3x2xi8>
// CHECK-DAG: %[[COLLAPSED:.+]] = memref.collapse_shape %[[ARG]] {{.}}[0, 1, 2, 3]{{.}} : memref<5x4x3x2xi8, {{.+}}> into memref<120xi8, {{.+}}>
// CHECK-DAG: %[[VEC1D:.+]] = vector.shape_cast %[[VEC]] : vector<5x4x3x2xi8> to vector<120xi8>
// CHECK: vector.transfer_write %[[VEC1D]], %[[COLLAPSED]]

// CHECK-128B-LABEL: func @transfer_write_dims_match_contiguous(
// CHECK-128B: memref.collapse_shape

// -----

func.func @transfer_write_dims_mismatch_contiguous(

@@ -182,6 +207,9 @@ func.func @transfer_write_dims_mismatch_contiguous(
// CHECK: return
// CHECK: }

// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_contiguous(
// CHECK-128B: memref.collapse_shape

// -----

func.func @transfer_write_dims_mismatch_non_contiguous(

@@ -196,6 +224,9 @@ func.func @transfer_write_dims_mismatch_non_contiguous(
// CHECK-NOT: memref.collapse_shape
// CHECK-NOT: vector.shape_cast

// CHECK-128B-LABEL: func @transfer_write_dims_mismatch_non_contiguous(
// CHECK-128B-NOT: memref.collapse_shape

// -----

func.func @transfer_write_0d(%arg : memref<i8>, %vec : vector<i8>) {

@@ -207,6 +238,10 @@ func.func @transfer_write_0d(%arg : memref<i8>, %vec : vector<i8>) {
// CHECK-NOT: memref.collapse_shape
// CHECK-NOT: vector.shape_cast

// CHECK-128B-LABEL: func @transfer_write_0d(
// CHECK-128B-NOT: memref.collapse_shape
// CHECK-128B-NOT: vector.shape_cast

// -----

func.func @transfer_read_0d(%arg : memref<i8>) -> vector<i8> {

@@ -219,6 +254,10 @@ func.func @transfer_read_0d(%arg : memref<i8>) -> vector<i8> {
// CHECK-NOT: memref.collapse_shape
// CHECK-NOT: vector.shape_cast

// CHECK-128B-LABEL: func @transfer_read_0d(
// CHECK-128B-NOT: memref.collapse_shape
// CHECK-128B-NOT: vector.shape_cast

// -----

func.func @transfer_read_flattenable_with_dynamic_dims_and_indices(%arg0 : memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>, %arg1 : index, %arg2 : index) -> vector<8x4xi8> {
@@ -241,6 +280,9 @@ func.func @transfer_read_flattenable_with_dynamic_dims_and_indices(%arg0 : memre
// CHECK: %[[VEC2D:.+]] = vector.shape_cast %[[VEC1D]] : vector<32xi8> to vector<8x4xi8>
// CHECK: return %[[VEC2D]] : vector<8x4xi8>

// CHECK-128B-LABEL: func @transfer_read_flattenable_with_dynamic_dims_and_indices(
// CHECK-128B: memref.collapse_shape

// -----

func.func @transfer_write_flattenable_with_dynamic_dims_and_indices(%vec : vector<8x4xi8>, %dst : memref<?x?x8x4xi8, strided<[?, 32, 4, 1], offset: ?>>, %arg1 : index, %arg2 : index) {

@@ -260,6 +302,9 @@ func.func @transfer_write_flattenable_with_dynamic_dims_and_indices(%vec : vecto
// CHECK-SAME: {in_bounds = [true]}
// CHECK-SAME: : vector<32xi8>, memref<?x?x32xi8, {{.+}}>

// CHECK-128B-LABEL: func @transfer_write_flattenable_with_dynamic_dims_and_indices(
// CHECK-128B: memref.collapse_shape

// -----

func.func @transfer_read_flattenable_negative(

@@ -274,6 +319,9 @@ func.func @transfer_read_flattenable_negative(
// CHECK-LABEL: func @transfer_read_flattenable_negative
// CHECK: vector.transfer_read {{.*}} vector<2x2x2x2xi8>

// CHECK-128B-LABEL: func @transfer_read_flattenable_negative(
// CHECK-128B-NOT: memref.collapse_shape

// -----

func.func @transfer_read_flattenable_negative2(

@@ -288,6 +336,9 @@ func.func @transfer_read_flattenable_negative2(
// CHECK-LABEL: func @transfer_read_flattenable_negative2
// CHECK: vector.transfer_read {{.*}} vector<5x4x3x2xi8>

// CHECK-128B-LABEL: func @transfer_read_flattenable_negative2(
// CHECK-128B-NOT: memref.collapse_shape

// -----

func.func @fold_unit_dim_add_basic(%arg0 : vector<1x8xi32>) -> vector<1x8xi32> {

@@ -302,6 +353,9 @@ func.func @fold_unit_dim_add_basic(%arg0 : vector<1x8xi32>) -> vector<1x8xi32> {
// CHECK: %[[VAL_4:.*]] = vector.shape_cast %[[VAL_3]] : vector<8xi32> to vector<1x8xi32>
// CHECK: return %[[VAL_4]] : vector<1x8xi32>

// CHECK-128B-LABEL: func @fold_unit_dim_add_basic(
// CHECK-128B-NOT: memref.collapse_shape

// -----

func.func @fold_unit_dim_add_leading_and_trailing(%arg0 : vector<1x8x1xi32>) -> vector<1x8x1xi32> {
@@ -316,6 +370,9 @@ func.func @fold_unit_dim_add_leading_and_trailing(%arg0 : vector<1x8x1xi32>) ->
// CHECK: %[[VAL_4:.*]] = vector.shape_cast %[[VAL_3]] : vector<8xi32> to vector<1x8x1xi32>
// CHECK: return %[[VAL_4]] : vector<1x8x1xi32>

// CHECK-128B-LABEL: func @fold_unit_dim_add_leading_and_trailing(
// CHECK-128B-NOT: memref.collapse_shape

// -----

func.func @fold_unit_dim_add(%arg0 : vector<8x1xi32>,

@@ -334,6 +391,9 @@ func.func @fold_unit_dim_add(%arg0 : vector<8x1xi32>,
// CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_2]], %[[VAL_3]] : vector<8xi32>
// CHECK: return %[[VAL_4]] : vector<8xi32>

// CHECK-128B-LABEL: func @fold_unit_dim_add(
// CHECK-128B-NOT: memref.collapse_shape

// -----

func.func @fold_unit_dim_mulf(%arg0 : vector<8x[2]x1xf32>,

@@ -352,6 +412,9 @@ func.func @fold_unit_dim_mulf(%arg0 : vector<8x[2]x1xf32>,
// CHECK: %[[VAL_4:.*]] = arith.mulf %[[VAL_2]], %[[VAL_3]] : vector<8x[2]xf32>
// CHECK: return %[[VAL_4]] : vector<8x[2]xf32>

// CHECK-128B-LABEL: func @fold_unit_dim_mulf(
// CHECK-128B-NOT: memref.collapse_shape

// -----

func.func @fold_unit_dim_sitofp(%arg0 : vector<8x[2]x1xi8>) -> vector<8x[2]xf32> {

@@ -367,6 +430,9 @@ func.func @fold_unit_dim_sitofp(%arg0 : vector<8x[2]x1xi8>) -> vector<8x[2]xf32>
// CHECK: %[[VAL_2:.*]] = arith.sitofp %[[VAL_1]] : vector<8x[2]xi8> to vector<8x[2]xf32>
// CHECK: return %[[VAL_2]] : vector<8x[2]xf32>

// CHECK-128B-LABEL: func @fold_unit_dim_sitofp(
// CHECK-128B-NOT: memref.collapse_shape

// -----

// All shape casts are folded away

@@ -389,3 +455,7 @@ func.func @fold_unit_dims_entirely(%arg0 : vector<8xi32>,
// CHECK: %[[VAL_3:.*]] = arith.muli %[[VAL_0]], %[[VAL_1]] : vector<8xi32>
// CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_3]], %[[VAL_2]] : vector<8xi32>
// CHECK: return %[[VAL_4]] : vector<8xi32>

// CHECK-128B-LABEL: func @fold_unit_dims_entirely(
// CHECK-128B-NOT: memref.collapse_shape
Comment:

There could be cases where only a partial slice is contiguous. In that case, we could flatten the trailing dims. I wonder if we will relax this a little more in the near future?

Reply: Yes, that's the ultimate goal: partial and more targeted flattening, but one step at a time. We first have to flatten the producers/consumers of these xfer ops to make sure we don't generate ops to reshape the vector.
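For instance (an illustrative case, not one of the tests in this patch), a transfer_read of vector<5x4x3x2xi8> from memref<5x4x3x2xi8, strided<[48, 6, 2, 1]>> is not contiguous in the outermost dimension (stride 48 rather than 24), but its trailing 4x3x2 block is; in principle those trailing dims could still be collapsed so that the op reads vector<5x24xi8> from a partially flattened memref.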