Skip to content

Commit 38f8a3c

Browse files
authored
[mlir][spirv] Improve folding of MemRef to SPIRV Lowering (#85433)
Investigate the lowering of MemRef Load/Store ops and implement additional folding of created ops Aims to improve readability of generated lowered SPIR-V code. Part of work #70704
1 parent 6295e67 commit 38f8a3c

File tree

8 files changed

+93
-210
lines changed

8 files changed

+93
-210
lines changed

mlir/lib/Conversion/MemRefToSPIRV/MemRefToSPIRV.cpp

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -50,11 +50,12 @@ static Value getOffsetForBitwidth(Location loc, Value srcIdx, int sourceBits,
5050
assert(targetBits % sourceBits == 0);
5151
Type type = srcIdx.getType();
5252
IntegerAttr idxAttr = builder.getIntegerAttr(type, targetBits / sourceBits);
53-
auto idx = builder.create<spirv::ConstantOp>(loc, type, idxAttr);
53+
auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, idxAttr);
5454
IntegerAttr srcBitsAttr = builder.getIntegerAttr(type, sourceBits);
55-
auto srcBitsValue = builder.create<spirv::ConstantOp>(loc, type, srcBitsAttr);
56-
auto m = builder.create<spirv::UModOp>(loc, srcIdx, idx);
57-
return builder.create<spirv::IMulOp>(loc, type, m, srcBitsValue);
55+
auto srcBitsValue =
56+
builder.createOrFold<spirv::ConstantOp>(loc, type, srcBitsAttr);
57+
auto m = builder.createOrFold<spirv::UModOp>(loc, srcIdx, idx);
58+
return builder.createOrFold<spirv::IMulOp>(loc, type, m, srcBitsValue);
5859
}
5960

6061
/// Returns an adjusted spirv::AccessChainOp. Based on the
@@ -74,11 +75,11 @@ adjustAccessChainForBitwidth(const SPIRVTypeConverter &typeConverter,
7475
Value lastDim = op->getOperand(op.getNumOperands() - 1);
7576
Type type = lastDim.getType();
7677
IntegerAttr attr = builder.getIntegerAttr(type, targetBits / sourceBits);
77-
auto idx = builder.create<spirv::ConstantOp>(loc, type, attr);
78+
auto idx = builder.createOrFold<spirv::ConstantOp>(loc, type, attr);
7879
auto indices = llvm::to_vector<4>(op.getIndices());
7980
// There are two elements if this is a 1-D tensor.
8081
assert(indices.size() == 2);
81-
indices.back() = builder.create<spirv::SDivOp>(loc, lastDim, idx);
82+
indices.back() = builder.createOrFold<spirv::SDivOp>(loc, lastDim, idx);
8283
Type t = typeConverter.convertType(op.getComponentPtr().getType());
8384
return builder.create<spirv::AccessChainOp>(loc, t, op.getBasePtr(), indices);
8485
}
@@ -91,7 +92,8 @@ static Value castBoolToIntN(Location loc, Value srcBool, Type dstType,
9192
return srcBool;
9293
Value zero = spirv::ConstantOp::getZero(dstType, loc, builder);
9394
Value one = spirv::ConstantOp::getOne(dstType, loc, builder);
94-
return builder.create<spirv::SelectOp>(loc, dstType, srcBool, one, zero);
95+
return builder.createOrFold<spirv::SelectOp>(loc, dstType, srcBool, one,
96+
zero);
9597
}
9698

9799
/// Returns the `targetBits`-bit value shifted by the given `offset`, and cast
@@ -111,10 +113,10 @@ static Value shiftValue(Location loc, Value value, Value offset, Value mask,
111113
loc, builder.getIntegerType(targetBits), value);
112114
}
113115

114-
value = builder.create<spirv::BitwiseAndOp>(loc, value, mask);
116+
value = builder.createOrFold<spirv::BitwiseAndOp>(loc, value, mask);
115117
}
116-
return builder.create<spirv::ShiftLeftLogicalOp>(loc, value.getType(), value,
117-
offset);
118+
return builder.createOrFold<spirv::ShiftLeftLogicalOp>(loc, value.getType(),
119+
value, offset);
118120
}
119121

120122
/// Returns true if the allocations of memref `type` generated from `allocOp`
@@ -165,7 +167,7 @@ static Value castIntNToBool(Location loc, Value srcInt, OpBuilder &builder) {
165167
return srcInt;
166168

167169
auto one = spirv::ConstantOp::getOne(srcInt.getType(), loc, builder);
168-
return builder.create<spirv::IEqualOp>(loc, srcInt, one);
170+
return builder.createOrFold<spirv::IEqualOp>(loc, srcInt, one);
169171
}
170172

171173
//===----------------------------------------------------------------------===//
@@ -597,25 +599,26 @@ IntLoadOpPattern::matchAndRewrite(memref::LoadOp loadOp, OpAdaptor adaptor,
597599
// ____XXXX________ -> ____________XXXX
598600
Value lastDim = accessChainOp->getOperand(accessChainOp.getNumOperands() - 1);
599601
Value offset = getOffsetForBitwidth(loc, lastDim, srcBits, dstBits, rewriter);
600-
Value result = rewriter.create<spirv::ShiftRightArithmeticOp>(
602+
Value result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
601603
loc, spvLoadOp.getType(), spvLoadOp, offset);
602604

603605
// Apply the mask to extract corresponding bits.
604-
Value mask = rewriter.create<spirv::ConstantOp>(
606+
Value mask = rewriter.createOrFold<spirv::ConstantOp>(
605607
loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
606-
result = rewriter.create<spirv::BitwiseAndOp>(loc, dstType, result, mask);
608+
result =
609+
rewriter.createOrFold<spirv::BitwiseAndOp>(loc, dstType, result, mask);
607610

608611
// Apply sign extension on the loading value unconditionally. The signedness
609612
// semantic is carried in the operator itself, we relies other pattern to
610613
// handle the casting.
611614
IntegerAttr shiftValueAttr =
612615
rewriter.getIntegerAttr(dstType, dstBits - srcBits);
613616
Value shiftValue =
614-
rewriter.create<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
615-
result = rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, result,
616-
shiftValue);
617-
result = rewriter.create<spirv::ShiftRightArithmeticOp>(loc, dstType, result,
618-
shiftValue);
617+
rewriter.createOrFold<spirv::ConstantOp>(loc, dstType, shiftValueAttr);
618+
result = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(loc, dstType,
619+
result, shiftValue);
620+
result = rewriter.createOrFold<spirv::ShiftRightArithmeticOp>(
621+
loc, dstType, result, shiftValue);
619622

620623
rewriter.replaceOp(loadOp, result);
621624

@@ -744,11 +747,12 @@ IntStoreOpPattern::matchAndRewrite(memref::StoreOp storeOp, OpAdaptor adaptor,
744747

745748
// Create a mask to clear the destination. E.g., if it is the second i8 in
746749
// i32, 0xFFFF00FF is created.
747-
Value mask = rewriter.create<spirv::ConstantOp>(
750+
Value mask = rewriter.createOrFold<spirv::ConstantOp>(
748751
loc, dstType, rewriter.getIntegerAttr(dstType, (1 << srcBits) - 1));
749-
Value clearBitsMask =
750-
rewriter.create<spirv::ShiftLeftLogicalOp>(loc, dstType, mask, offset);
751-
clearBitsMask = rewriter.create<spirv::NotOp>(loc, dstType, clearBitsMask);
752+
Value clearBitsMask = rewriter.createOrFold<spirv::ShiftLeftLogicalOp>(
753+
loc, dstType, mask, offset);
754+
clearBitsMask =
755+
rewriter.createOrFold<spirv::NotOp>(loc, dstType, clearBitsMask);
752756

753757
Value storeVal = shiftValue(loc, adaptor.getValue(), offset, mask, rewriter);
754758
Value adjustedPtr = adjustAccessChainForBitwidth(typeConverter, accessChainOp,
@@ -910,7 +914,7 @@ LogicalResult ReinterpretCastPattern::matchAndRewrite(
910914

911915
int64_t attrVal = cast<IntegerAttr>(offset.get<Attribute>()).getInt();
912916
Attribute attr = rewriter.getIntegerAttr(intType, attrVal);
913-
return rewriter.create<spirv::ConstantOp>(loc, intType, attr);
917+
return rewriter.createOrFold<spirv::ConstantOp>(loc, intType, attr);
914918
}();
915919

916920
rewriter.replaceOpWithNewOp<spirv::InBoundsPtrAccessChainOp>(

mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -991,15 +991,16 @@ Value mlir::spirv::linearizeIndex(ValueRange indices, ArrayRef<int64_t> strides,
991991
// broken down into progressive small steps so we can have intermediate steps
992992
// using other dialects. At the moment SPIR-V is the final sink.
993993

994-
Value linearizedIndex = builder.create<spirv::ConstantOp>(
994+
Value linearizedIndex = builder.createOrFold<spirv::ConstantOp>(
995995
loc, integerType, IntegerAttr::get(integerType, offset));
996996
for (const auto &index : llvm::enumerate(indices)) {
997-
Value strideVal = builder.create<spirv::ConstantOp>(
997+
Value strideVal = builder.createOrFold<spirv::ConstantOp>(
998998
loc, integerType,
999999
IntegerAttr::get(integerType, strides[index.index()]));
1000-
Value update = builder.create<spirv::IMulOp>(loc, strideVal, index.value());
1000+
Value update =
1001+
builder.createOrFold<spirv::IMulOp>(loc, index.value(), strideVal);
10011002
linearizedIndex =
1002-
builder.create<spirv::IAddOp>(loc, linearizedIndex, update);
1003+
builder.createOrFold<spirv::IAddOp>(loc, update, linearizedIndex);
10031004
}
10041005
return linearizedIndex;
10051006
}

mlir/test/Conversion/GPUToSPIRV/load-store.mlir

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,13 +60,9 @@ module attributes {
6060
// CHECK: %[[INDEX2:.*]] = spirv.IAdd %[[ARG4]], %[[LOCALINVOCATIONIDX]]
6161
%13 = arith.addi %arg4, %3 : index
6262
// CHECK: %[[ZERO:.*]] = spirv.Constant 0 : i32
63-
// CHECK: %[[OFFSET1_0:.*]] = spirv.Constant 0 : i32
6463
// CHECK: %[[STRIDE1_1:.*]] = spirv.Constant 4 : i32
65-
// CHECK: %[[UPDATE1_1:.*]] = spirv.IMul %[[STRIDE1_1]], %[[INDEX1]] : i32
66-
// CHECK: %[[OFFSET1_1:.*]] = spirv.IAdd %[[OFFSET1_0]], %[[UPDATE1_1]] : i32
67-
// CHECK: %[[STRIDE1_2:.*]] = spirv.Constant 1 : i32
68-
// CHECK: %[[UPDATE1_2:.*]] = spirv.IMul %[[STRIDE1_2]], %[[INDEX2]] : i32
69-
// CHECK: %[[OFFSET1_2:.*]] = spirv.IAdd %[[OFFSET1_1]], %[[UPDATE1_2]] : i32
64+
// CHECK: %[[UPDATE1_1:.*]] = spirv.IMul %[[INDEX1]], %[[STRIDE1_1]] : i32
65+
// CHECK: %[[OFFSET1_2:.*]] = spirv.IAdd %[[INDEX2]], %[[UPDATE1_1]] : i32
7066
// CHECK: %[[PTR1:.*]] = spirv.AccessChain %[[ARG0]]{{\[}}%[[ZERO]], %[[OFFSET1_2]]{{\]}}
7167
// CHECK-NEXT: %[[VAL1:.*]] = spirv.Load "StorageBuffer" %[[PTR1]]
7268
%14 = memref.load %arg0[%12, %13] : memref<12x4xf32, #spirv.storage_class<StorageBuffer>>

0 commit comments

Comments
 (0)