Skip to content

Commit ae6d11f

Browse files
committed
address comments
1 parent 9d78e81 commit ae6d11f

File tree

4 files changed

+7
-13
lines changed

4 files changed

+7
-13
lines changed

mlir/include/mlir/Dialect/Vector/Utils/VectorUtils.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -211,11 +211,6 @@ Value createReadOrMaskedRead(OpBuilder &builder, Location loc, Value source,
211211
/// static sizes in `shape`.
212212
LogicalResult isValidMaskedInputVector(ArrayRef<int64_t> shape,
213213
ArrayRef<int64_t> inputVectorSizes);
214-
215-
/// Returns true if the leading dim(s) of `type` are fixed and the trailing dim
216-
/// is scalable.
217-
bool isTrailingDimScalable(VectorType type);
218-
219214
} // namespace vector
220215

221216
/// Constructs a permutation map of invariant memref indices to vector

mlir/lib/Dialect/Vector/Transforms/LowerVectorGather.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,8 @@ struct FlattenGather : OpRewritePattern<vector::GatherOp> {
6868

6969
// Unrolling doesn't take vscale into account. Pattern is disabled for
7070
// vectors with leading scalable dim(s).
71-
if (resultTy.isScalable() && !isTrailingDimScalable(resultTy))
72-
return rewriter.notifyMatchFailure(
73-
op, "vector type must be fixed-width or scalable in trailing dim");
71+
if (resultTy.getScalableDims().front())
72+
return rewriter.notifyMatchFailure(op, "cannot unroll scalable dim");
7473

7574
Location loc = op.getLoc();
7675
Value indexVec = op.getIndexVec();

mlir/lib/Dialect/Vector/Transforms/LowerVectorShapeCast.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,11 @@ class ScalableShapeCastOpRewritePattern
342342
rewriter.replaceOp(op, result);
343343
return success();
344344
}
345+
346+
static bool isTrailingDimScalable(VectorType type) {
347+
return type.getRank() >= 1 && type.getScalableDims().back() &&
348+
!llvm::is_contained(type.getScalableDims().drop_back(), true);
349+
}
345350
};
346351

347352
} // namespace

mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -396,8 +396,3 @@ vector::isValidMaskedInputVector(ArrayRef<int64_t> shape,
396396
}
397397
return success();
398398
}
399-
400-
bool vector::isTrailingDimScalable(VectorType type) {
401-
return type.getRank() >= 1 && type.getScalableDims().back() &&
402-
!llvm::is_contained(type.getScalableDims().drop_back(), true);
403-
}

0 commit comments

Comments
 (0)