Skip to content

Commit 1e3c630

Browse files
authored
[MLIR] Extend floating point parsing support (#90442)
Parsing support for floating point types was missing a few features: 1. Parsing floating point attributes from integer literals was supported only for types with bitwidth smaller or equal to 64. 2. Downstream users could not use `AsmParser::parseFloat` to parse float types which are printed as integer literals. This commit addresses both these points. It extends `Parser::parseFloatFromIntegerLiteral` to support arbitrary bitwidth, and exposes a new API to parse arbitrary floating point given an fltSemantics as input. The usage of this new API is introduced in the Test Dialect.
1 parent 294eecd commit 1e3c630

File tree

7 files changed

+165
-16
lines changed

7 files changed

+165
-16
lines changed

mlir/include/mlir/IR/OpImplementation.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -700,6 +700,10 @@ class AsmParser {
700700
/// Parse a floating point value from the stream.
701701
virtual ParseResult parseFloat(double &result) = 0;
702702

703+
/// Parse a floating point value into APFloat from the stream.
704+
virtual ParseResult parseFloat(const llvm::fltSemantics &semantics,
705+
APFloat &result) = 0;
706+
703707
/// Parse an integer value from the stream.
704708
template <typename IntT>
705709
ParseResult parseInteger(IntT &result) {

mlir/lib/AsmParser/AsmParserImpl.h

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -269,8 +269,12 @@ class AsmParserImpl : public BaseT {
269269
return success();
270270
}
271271

272-
/// Parse a floating point value from the stream.
273-
ParseResult parseFloat(double &result) override {
272+
/// Parse a floating point value with given semantics from the stream. Since
273+
/// this implementation parses the string as double precision and only
274+
/// afterwards converts the value to the requested semantic, precision may be
275+
/// lost.
276+
ParseResult parseFloat(const llvm::fltSemantics &semantics,
277+
APFloat &result) override {
274278
bool isNegative = parser.consumeIf(Token::minus);
275279
Token curTok = parser.getToken();
276280
SMLoc loc = curTok.getLoc();
@@ -281,26 +285,38 @@ class AsmParserImpl : public BaseT {
281285
if (!val)
282286
return emitError(loc, "floating point value too large");
283287
parser.consumeToken(Token::floatliteral);
284-
result = isNegative ? -*val : *val;
288+
result = APFloat(isNegative ? -*val : *val);
289+
bool losesInfo;
290+
result.convert(semantics, APFloat::rmNearestTiesToEven, &losesInfo);
285291
return success();
286292
}
287293

288294
// Check for a hexadecimal float value.
289295
if (curTok.is(Token::integer)) {
290296
std::optional<APFloat> apResult;
291297
if (failed(parser.parseFloatFromIntegerLiteral(
292-
apResult, curTok, isNegative, APFloat::IEEEdouble(),
293-
/*typeSizeInBits=*/64)))
298+
apResult, curTok, isNegative, semantics,
299+
APFloat::semanticsSizeInBits(semantics))))
294300
return failure();
295301

302+
result = *apResult;
296303
parser.consumeToken(Token::integer);
297-
result = apResult->convertToDouble();
298304
return success();
299305
}
300306

301307
return emitError(loc, "expected floating point literal");
302308
}
303309

310+
/// Parse a floating point value from the stream.
311+
ParseResult parseFloat(double &result) override {
312+
llvm::APFloat apResult(0.0);
313+
if (parseFloat(APFloat::IEEEdouble(), apResult))
314+
return failure();
315+
316+
result = apResult.convertToDouble();
317+
return success();
318+
}
319+
304320
/// Parse an optional integer value from the stream.
305321
OptionalParseResult parseOptionalInteger(APInt &result) override {
306322
return parser.parseOptionalInteger(result);

mlir/lib/AsmParser/Parser.cpp

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -326,19 +326,15 @@ ParseResult Parser::parseFloatFromIntegerLiteral(
326326
"leading minus");
327327
}
328328

329-
std::optional<uint64_t> value = tok.getUInt64IntegerValue();
330-
if (!value)
329+
APInt intValue;
330+
tok.getSpelling().getAsInteger(isHex ? 0 : 10, intValue);
331+
if (intValue.getActiveBits() > typeSizeInBits)
331332
return emitError(loc, "hexadecimal float constant out of range for type");
332333

333-
if (&semantics == &APFloat::IEEEdouble()) {
334-
result = APFloat(semantics, APInt(typeSizeInBits, *value));
335-
return success();
336-
}
334+
APInt truncatedValue(typeSizeInBits, intValue.getNumWords(),
335+
intValue.getRawData());
337336

338-
APInt apInt(typeSizeInBits, *value);
339-
if (apInt != *value)
340-
return emitError(loc, "hexadecimal float constant out of range for type");
341-
result = APFloat(semantics, apInt);
337+
result.emplace(semantics, truncatedValue);
342338

343339
return success();
344340
}
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
// RUN: mlir-opt %s -split-input-file -verify-diagnostics| FileCheck %s
2+
3+
// CHECK-LABEL: @test_enum_attr_roundtrip
4+
func.func @test_enum_attr_roundtrip() -> () {
5+
// CHECK: attr = #test.custom_float<"float" : 2.000000e+00>
6+
"test.op"() {attr = #test.custom_float<"float" : 2.>} : () -> ()
7+
// CHECK: attr = #test.custom_float<"double" : 2.000000e+00>
8+
"test.op"() {attr = #test.custom_float<"double" : 2.>} : () -> ()
9+
// CHECK: attr = #test.custom_float<"fp80" : 2.000000e+00>
10+
"test.op"() {attr = #test.custom_float<"fp80" : 2.>} : () -> ()
11+
// CHECK: attr = #test.custom_float<"float" : 0x7FC00000>
12+
"test.op"() {attr = #test.custom_float<"float" : 0x7FC00000>} : () -> ()
13+
// CHECK: attr = #test.custom_float<"double" : 0x7FF0000001000000>
14+
"test.op"() {attr = #test.custom_float<"double" : 0x7FF0000001000000>} : () -> ()
15+
// CHECK: attr = #test.custom_float<"fp80" : 0x7FFFC000000000100000>
16+
"test.op"() {attr = #test.custom_float<"fp80" : 0x7FFFC000000000100000>} : () -> ()
17+
return
18+
}
19+
20+
// -----
21+
22+
// Verify literal must be hex or float
23+
24+
// expected-error @below {{unexpected decimal integer literal for a floating point value}}
25+
// expected-note @below {{add a trailing dot to make the literal a float}}
26+
"test.op"() {attr = #test.custom_float<"float" : 42>} : () -> ()
27+
28+
// -----
29+
30+
// Integer value must be in the width of the floating point type
31+
32+
// expected-error @below {{hexadecimal float constant out of range for type}}
33+
"test.op"() {attr = #test.custom_float<"float" : 0x7FC000000>} : () -> ()
34+
35+
36+
// -----
37+
38+
// Integer value must be in the width of the floating point type
39+
40+
// expected-error @below {{hexadecimal float constant out of range for type}}
41+
"test.op"() {attr = #test.custom_float<"double" : 0x7FC000007FC0000000>} : () -> ()
42+
43+
44+
// -----
45+
46+
// Integer value must be in the width of the floating point type
47+
48+
// expected-error @below {{hexadecimal float constant out of range for type}}
49+
"test.op"() {attr = #test.custom_float<"fp80" : 0x7FC0000007FC0000007FC000000>} : () -> ()
50+
51+
// -----
52+
53+
// Value must be a floating point literal or integer literal
54+
55+
// expected-error @below {{expected floating point literal}}
56+
"test.op"() {attr = #test.custom_float<"float" : "blabla">} : () -> ()
57+

mlir/test/IR/parser.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1105,6 +1105,30 @@ func.func @bfloat16_special_values() {
11051105
return
11061106
}
11071107

1108+
// CHECK-LABEL: @f80_special_values
1109+
func.func @f80_special_values() {
1110+
// F80 signaling NaNs.
1111+
// CHECK: arith.constant 0x7FFFE000000000000001 : f80
1112+
%0 = arith.constant 0x7FFFE000000000000001 : f80
1113+
// CHECK: arith.constant 0x7FFFB000000000000011 : f80
1114+
%1 = arith.constant 0x7FFFB000000000000011 : f80
1115+
1116+
// F80 quiet NaNs.
1117+
// CHECK: arith.constant 0x7FFFC000000000100000 : f80
1118+
%2 = arith.constant 0x7FFFC000000000100000 : f80
1119+
// CHECK: arith.constant 0x7FFFE000000001000000 : f80
1120+
%3 = arith.constant 0x7FFFE000000001000000 : f80
1121+
1122+
// F80 positive infinity.
1123+
// CHECK: arith.constant 0x7FFF8000000000000000 : f80
1124+
%4 = arith.constant 0x7FFF8000000000000000 : f80
1125+
// F80 negative infinity.
1126+
// CHECK: arith.constant 0xFFFF8000000000000000 : f80
1127+
%5 = arith.constant 0xFFFF8000000000000000 : f80
1128+
1129+
return
1130+
}
1131+
11081132
// We want to print floats in exponential notation with 6 significant digits,
11091133
// but it may lead to precision loss when parsing back, in which case we print
11101134
// the decimal form instead.

mlir/test/lib/Dialect/Test/TestAttrDefs.td

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,4 +340,15 @@ def TestConditionalAliasAttr : Test_Attr<"TestConditionalAlias"> {
340340
}];
341341
}
342342

343+
// Test AsmParser::parseFloat(const fltSemnatics&, APFloat&) API through the
344+
// custom parser and printer.
345+
def TestCustomFloatAttr : Test_Attr<"TestCustomFloat"> {
346+
let mnemonic = "custom_float";
347+
let parameters = (ins "mlir::StringAttr":$type_str, APFloatParameter<"">:$value);
348+
349+
let assemblyFormat = [{
350+
`<` custom<CustomFloatAttr>($type_str, $value) `>`
351+
}];
352+
}
353+
343354
#endif // TEST_ATTRDEFS

mlir/test/lib/Dialect/Test/TestAttributes.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include "mlir/IR/ExtensibleDialect.h"
1919
#include "mlir/IR/Types.h"
2020
#include "mlir/Support/LogicalResult.h"
21+
#include "llvm/ADT/APFloat.h"
2122
#include "llvm/ADT/Hashing.h"
2223
#include "llvm/ADT/StringExtras.h"
2324
#include "llvm/ADT/TypeSwitch.h"
@@ -240,6 +241,46 @@ static void printConditionalAlias(AsmPrinter &p, StringAttr value) {
240241
p.printKeywordOrString(value);
241242
}
242243

244+
//===----------------------------------------------------------------------===//
245+
// Custom Float Attribute
246+
//===----------------------------------------------------------------------===//
247+
248+
static void printCustomFloatAttr(AsmPrinter &p, StringAttr typeStrAttr,
249+
APFloat value) {
250+
p << typeStrAttr << " : " << value;
251+
}
252+
253+
static ParseResult parseCustomFloatAttr(AsmParser &p, StringAttr &typeStrAttr,
254+
FailureOr<APFloat> &value) {
255+
256+
std::string str;
257+
if (p.parseString(&str))
258+
return failure();
259+
260+
typeStrAttr = StringAttr::get(p.getContext(), str);
261+
262+
if (p.parseColon())
263+
return failure();
264+
265+
const llvm::fltSemantics *semantics;
266+
if (str == "float")
267+
semantics = &llvm::APFloat::IEEEsingle();
268+
else if (str == "double")
269+
semantics = &llvm::APFloat::IEEEdouble();
270+
else if (str == "fp80")
271+
semantics = &llvm::APFloat::x87DoubleExtended();
272+
else
273+
return p.emitError(p.getCurrentLocation(), "unknown float type, expected "
274+
"'float', 'double' or 'fp80'");
275+
276+
APFloat parsedValue(0.0);
277+
if (p.parseFloat(*semantics, parsedValue))
278+
return failure();
279+
280+
value.emplace(parsedValue);
281+
return success();
282+
}
283+
243284
//===----------------------------------------------------------------------===//
244285
// Tablegen Generated Definitions
245286
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)