
[mlir][spirv] Improve folding of MemRef to SPIRV Lowering #85433


Merged: 1 commit, Mar 21, 2024
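The diff below swaps `builder.create<Op>` for `builder.createOrFold<Op>` throughout the MemRef-to-SPIR-V lowering and in `spirv::linearizeIndex`, so ops whose folders can simplify them at construction time (adds of a constant zero, multiplies by a constant one, and the constants feeding them) never materialize in the generated IR. A minimal sketch of the difference, using a hypothetical helper that is not part of the patch:

```cpp
// Sketch only: `emitScaledIndex` and its arguments are illustrative, not from
// the patch. It emits `base + index * stride` the same way the lowering does.
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
#include "mlir/IR/Builders.h"

using namespace mlir;

static Value emitScaledIndex(OpBuilder &builder, Location loc, Value base,
                             Value index, int64_t stride) {
  Type i32 = builder.getIntegerType(32);
  // `create` always inserts the op; `createOrFold` first asks the op's folder,
  // so `index * 1` yields `index` and `base + 0` yields `base` directly,
  // without inserting a spirv.IMul / spirv.IAdd at all.
  Value strideVal = builder.createOrFold<spirv::ConstantOp>(
      loc, i32, builder.getIntegerAttr(i32, stride));
  Value scaled = builder.createOrFold<spirv::IMulOp>(loc, index, strideVal);
  return builder.createOrFold<spirv::IAddOp>(loc, base, scaled);
}
```

Because `createOrFold` returns a `Value` (the folded result when folding succeeds, otherwise the new op's result), it drops into these patterns wherever the result only feeds further ops, which is the case at every call site touched here.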
52 changes: 28 additions & 24 deletions mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp
@@ -50,11 +50,12 @@ static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
assert(targetBits % sourceBits == 0);
Type type = srcIdx.getType();
IntegerAttr idxAttr = builder.getIntegerAttr(type, targetBits / sourceBits);
auto idx = builder.create<spirv::ConstantOp>(loc, type, idxAttr);
auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, idxAttr);
IntegerAttr srcBitsAttr = builder.getIntegerAttr(type, sourceBits);
auto srcBitsValue = builder.create<spirv::ConstantOp>(loc, type, srcBitsAttr);
auto m = builder.create<spirv::UModOp>(loc, srcIdx, idx);
return builder.create<spirv::IMulOp>(loc, type, m, srcBitsValue);
auto srcBitsValue =
builder.createOrFold<spirv::ConstantOp>(loc, type, srcBitsAttr);
auto m = builder.createOrFold<spirv::UModOp>(loc, srcIdx, idx);
return builder.createOrFold<spirv::IMulOp>(loc, type, m, srcBitsValue);
}

/// Returns an adjusted spirv::AccessChainOp. Based on the
@@ -74,11 +75,11 @@ adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter,
Value lastDim = op->getOperand(op.getNumOperands() - 1);
Type type = lastDim.getType();
IntegerAttr attr = builder.getIntegerAttr(type, targetBits / sourceBits);
auto idx = builder.create<spirv::ConstantOp>(loc, type, attr);
auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, attr);
auto indices = llvm::to_vector<4>(op.getIndices());
// There are two elements if this is a 1-D tensor.
assert(indices.size() == 2);
indices.back() = builder.create<spirv::SDivOp>(loc, lastDim, idx);
indices.back() = builder.createOrFold<spirv::SDivOp>(loc, lastDim, idx);
Type t = typeConverter.convertType(op.getComponentPtr().getType());
return builder.create<spirv::AccessChainOp>(loc, t, op.getBasePtr(), indices);
}
@@ -91,7 +92,8 @@ static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
return srcBool;
Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
return builder.create<spirv::SelectOp>(loc, dstType, srcBool, one, zero);
return builder.createOrFold<spirv::SelectOp>(loc, dstType, srcBool, one,
zero);
}

/// Returns the `targetBits`-bit value shifted by the given `offset`, and cast
@@ -111,10 +113,10 @@ static Value shiftValue(Location loc, Value value, Value offset, Value mask,
loc, builder.getIntegerType(targetBits), value);
}

value = builder.create<spirv::BitwiseAndOp>(loc, value, mask);
value = builder.createOrFold<spirv::BitwiseAndOp>(loc, value, mask);
}
return builder.create<spirv::ShiftLeftLogicalOp>(loc, value.getType(), value,
offset);
return builder.createOrFold<spirv::ShiftLeftLogicalOp>(loc, value.getType(),
value, offset);
}

/// Returns true if the allocations of memref `type` generated from `allocOp`
@@ -165,7 +167,7 @@ static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
return srcInt;

auto one = spirv::ConstantOp::getOne(srcInt.getType(), loc, builder);
return builder.create<spirv::IEqualOp>(loc, srcInt, one);
return builder.createOrFold<spirv::IEqualOp>(loc, srcInt, one);
}

//===----------------------------------------------------------------------===//
@@ -597,25 +599,26 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
// ____XXXX________ -> ____________XXXX
Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
Value result = rewriter.create<spirv::ShiftRightArithmeticOp>(
Value result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
loc, spvLoadOp.getType(), spvLoadOp, offset);

// Apply the mask to extract corresponding bits.
Value mask = rewriter.create<spirv::ConstantOp>(
Value mask = rewriter.createOrFold<spirv::ConstantOp>(
loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
result = rewriter.create<spirv::BitwiseAndOp>(loc, dstType, result, mask);
result =
rewriter.createOrFold<spirv::BitwiseAndOp>(loc, dstType, result, mask);

// Apply sign extension on the loading value unconditionally. The signedness
// semantic is carried in the operator itself, we relies other pattern to
// handle the casting.
IntegerAttr shiftValueAttr =
rewriter.getIntegerAttr(dstType, dstBits - srcBits);
Value shiftValue =
rewriter.create<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
result = rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, result,
shiftValue);
result = rewriter.create<spirv::ShiftRightArithmeticOp>(loc, dstType, result,
shiftValue);
rewriter.createOrFold<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
result = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(loc, dstType,
result, shiftValue);
result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
loc, dstType, result, shiftValue);

rewriter.replaceOp(loadOp, result);
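For sub-word loads, the pattern above extracts an i8 (or other narrow type) from the i32 word it is packed in: shift right by the element's bit offset, mask to the source width, then sign-extend with a shift-left / arithmetic-shift-right pair. A plain C++ mirror of that arithmetic, purely illustrative and not code from the patch:

```cpp
// Scalar model of the emitted SPIR-V op sequence for an i8 load from an i32
// word (illustrative only; assumes two's-complement shift behavior).
#include <cstdint>

int32_t loadI8FromWord(uint32_t word, uint32_t elemIndex) {
  constexpr uint32_t srcBits = 8, dstBits = 32;
  uint32_t offset = (elemIndex % (dstBits / srcBits)) * srcBits; // getOffsetForBitwidth
  uint32_t extracted = (word >> offset) & ((1u << srcBits) - 1); // shift + BitwiseAnd
  // Sign-extend: move the narrow value's sign bit up to bit 31, then shift back
  // arithmetically (the ShiftLeftLogical / ShiftRightArithmetic pair above).
  int32_t widened = static_cast<int32_t>(extracted << (dstBits - srcBits));
  return widened >> (dstBits - srcBits);
}
// Example: byte 1 of 0x0000FF00 is 0xFF, which widens to -1 as a signed i8.
```

When a constant index makes the bit offset a compile-time constant, `createOrFold` can let parts of this shift-and-mask chain fold away instead of surviving as SPIR-V ops.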

@@ -744,11 +747,12 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,

// Create a mask to clear the destination. E.g., if it is the second i8 in
// i32, 0xFFFF00FF is created.
Value mask = rewriter.create<spirv::ConstantOp>(
Value mask = rewriter.createOrFold<spirv::ConstantOp>(
loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
Value clearBitsMask =
rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, mask, offset);
clearBitsMask = rewriter.create<spirv::NotOp>(loc, dstType, clearBitsMask);
Value clearBitsMask = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(
loc, dstType, mask, offset);
clearBitsMask =
rewriter.createOrFold<spirv::NotOp>(loc, dstType, clearBitsMask);

Value storeVal = shiftValue(loc, adaptor.getValue(), offset, mask, rewriter);
Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
@@ -910,7 +914,7 @@ LogicalResult ReinterpretCastPattern::matchAndRewrite(

int64_t attrVal = cast<IntegerAttr>(offset.get<Attribute>()).getInt();
Attribute attr = rewriter.getIntegerAttr(intType, attrVal);
return rewriter.create<spirv::ConstantOp>(loc, intType, attr);
return rewriter.createOrFold<spirv::ConstantOp>(loc, intType, attr);
}();

rewriter.replaceOpWithNewOp<spirv::InBoundsPtrAccessChainOp>(
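On the store side, the pattern builds a clearing mask for the destination byte (the in-source comment gives 0xFFFF00FF for the second i8 in an i32), and with `createOrFold` that mask computation can collapse to a single constant whenever the offset is known. A scalar mirror of the mask math, illustrative only:

```cpp
// Scalar model of the store-side clear mask (not code from the patch).
#include <cstdint>

uint32_t clearMaskForElement(uint32_t offsetBits) {
  constexpr uint32_t srcBits = 8;
  uint32_t mask = (1u << srcBits) - 1;   // spirv.Constant 0xFF
  return ~(mask << offsetBits);          // ShiftLeftLogical, then Not
}
// clearMaskForElement(8) == 0xFFFF00FF: AND-ing the destination word with it
// zeroes the target byte before the shifted store value is merged in.
```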
9 changes: 5 additions & 4 deletions mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -991,15 +991,16 @@ Value mlir::spirv::linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides,
// broken down into progressive small steps so we can have intermediate steps
// using other dialects. At the moment SPIR-V is the final sink.

Value linearizedIndex = builder.create<spirv::ConstantOp>(
Value linearizedIndex = builder.createOrFold<spirv::ConstantOp>(
loc, integerType, IntegerAttr::get(integerType, offset));
for (const auto &index : llvm::enumerate(indices)) {
Value strideVal = builder.create<spirv::ConstantOp>(
Value strideVal = builder.createOrFold<spirv::ConstantOp>(
loc, integerType,
IntegerAttr::get(integerType, strides[index.index()]));
Value update = builder.create<spirv::IMulOp>(loc, strideVal, index.value());
Value update =
builder.createOrFold<spirv::IMulOp>(loc, index.value(), strideVal);
linearizedIndex =
builder.create<spirv::IAddOp>(loc, linearizedIndex, update);
builder.createOrFold<spirv::IAddOp>(loc, update, linearizedIndex);
}
return linearizedIndex;
}
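`spirv::linearizeIndex` computes `offset + sum_i(strides[i] * indices[i])`. Besides switching to `createOrFold`, the change reorders the `IMul` and `IAdd` operands so the likely-constant operand (the stride, or the initial constant offset) comes second, presumably matching the operand position the folders handle. A scalar model of the computation, illustrative only:

```cpp
// Scalar model of spirv::linearizeIndex (illustrative, not the MLIR code):
// linear = offset + sum_i(strides[i] * indices[i]).
#include <cstdint>
#include <vector>

int64_t linearize(int64_t offset, const std::vector<int64_t> &strides,
                  const std::vector<int64_t> &indices) {
  int64_t linear = offset;
  for (size_t i = 0; i < strides.size(); ++i)
    linear += strides[i] * indices[i];
  return linear;
}
// For the memref<12x4xf32> in the test below (strides 4 and 1, offset 0), the
// index pair linearizes to 4 * i + j. With createOrFold, the leading constant
// 0, its IAdd, and the multiply by the unit stride no longer appear, which is
// exactly what the updated CHECK lines in load-store.mlir verify.
```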
8 changes: 2 additions & 6 deletions mlir/test/Conversion/GPUToSPIRV/load-store.mlir
@@ -60,13 +60,9 @@ module attributes {
// CHECK: %[[INDEX2:.*]] = spirv.IAdd %[[ARG4]], %[[LOCALINVOCATIONIDX]]
%13 = arith.addi %arg4, %3 : index
// CHECK: %[[ZERO:.*]] = spirv.Constant 0 : i32
// CHECK: %[[OFFSET1_0:.*]] = spirv.Constant 0 : i32
// CHECK: %[[STRIDE1_1:.*]] = spirv.Constant 4 : i32
// CHECK: %[[UPDATE1_1:.*]] = spirv.IMul %[[STRIDE1_1]], %[[INDEX1]] : i32
// CHECK: %[[OFFSET1_1:.*]] = spirv.IAdd %[[OFFSET1_0]], %[[UPDATE1_1]] : i32
// CHECK: %[[STRIDE1_2:.*]] = spirv.Constant 1 : i32
// CHECK: %[[UPDATE1_2:.*]] = spirv.IMul %[[STRIDE1_2]], %[[INDEX2]] : i32
// CHECK: %[[OFFSET1_2:.*]] = spirv.IAdd %[[OFFSET1_1]], %[[UPDATE1_2]] : i32
// CHECK: %[[UPDATE1_1:.*]] = spirv.IMul %[[INDEX1]], %[[STRIDE1_1]] : i32
// CHECK: %[[OFFSET1_2:.*]] = spirv.IAdd %[[INDEX2]], %[[UPDATE1_1]] : i32
// CHECK: %[[PTR1:.*]] = spirv.AccessChain %[[ARG0]]{{\[}}%[[ZERO]], %[[OFFSET1_2]]{{\]}}
// CHECK-NEXT: %[[VAL1:.*]] = spirv.Load "StorageBuffer" %[[PTR1]]
%14 = memref.load %arg0[%12, %13] : memref<12x4xf32, #spirv.storage_class<StorageBuffer>>