Skip to content

[mlir][Parser] Deduplicate floating-point parsing functionality #116172

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 7 additions & 28 deletions mlir/lib/AsmParser/AsmParserImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -287,34 +287,13 @@ class AsmParserImpl : public BaseT {
APFloat &result) override {
bool isNegative = parser.consumeIf(Token::minus);
Token curTok = parser.getToken();
SMLoc loc = curTok.getLoc();

// Check for a floating point value.
if (curTok.is(Token::floatliteral)) {
auto val = curTok.getFloatingPointValue();
if (!val)
return emitError(loc, "floating point value too large");
parser.consumeToken(Token::floatliteral);
result = APFloat(isNegative ? -*val : *val);
bool losesInfo;
result.convert(semantics, APFloat::rmNearestTiesToEven, &losesInfo);
return success();
}

// Check for a hexadecimal float value.
if (curTok.is(Token::integer)) {
std::optional<APFloat> apResult;
if (failed(parser.parseFloatFromIntegerLiteral(
apResult, curTok, isNegative, semantics,
APFloat::semanticsSizeInBits(semantics))))
return failure();

result = *apResult;
parser.consumeToken(Token::integer);
return success();
}

return emitError(loc, "expected floating point literal");
std::optional<APFloat> apResult;
if (failed(parser.parseFloatFromLiteral(apResult, curTok, isNegative,
semantics)))
return failure();
parser.consumeToken();
result = *apResult;
return success();
}

/// Parse a floating point value from the stream.
Expand Down
70 changes: 13 additions & 57 deletions mlir/lib/AsmParser/AttributeParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -424,8 +424,7 @@ Attribute Parser::parseDecOrHexAttr(Type type, bool isNegative) {
if (auto floatType = dyn_cast<FloatType>(type)) {
std::optional<APFloat> result;
if (failed(parseFloatFromIntegerLiteral(result, tok, isNegative,
floatType.getFloatSemantics(),
floatType.getWidth())))
floatType.getFloatSemantics())))
return Attribute();
return FloatAttr::get(floatType, *result);
}
Expand Down Expand Up @@ -658,36 +657,11 @@ TensorLiteralParser::getFloatAttrElements(SMLoc loc, FloatType eltTy,
for (const auto &signAndToken : storage) {
bool isNegative = signAndToken.first;
const Token &token = signAndToken.second;

// Handle hexadecimal float literals.
if (token.is(Token::integer) && token.getSpelling().starts_with("0x")) {
std::optional<APFloat> result;
if (failed(p.parseFloatFromIntegerLiteral(result, token, isNegative,
eltTy.getFloatSemantics(),
eltTy.getWidth())))
return failure();

floatValues.push_back(*result);
continue;
}

// Check to see if any decimal integers or booleans were parsed.
if (!token.is(Token::floatliteral))
return p.emitError()
<< "expected floating-point elements, but parsed integer";

// Build the float values from tokens.
auto val = token.getFloatingPointValue();
if (!val)
return p.emitError("floating point value too large for attribute");

APFloat apVal(isNegative ? -*val : *val);
if (!eltTy.isF64()) {
bool unused;
apVal.convert(eltTy.getFloatSemantics(), APFloat::rmNearestTiesToEven,
&unused);
}
floatValues.push_back(apVal);
std::optional<APFloat> result;
if (failed(p.parseFloatFromLiteral(result, token, isNegative,
eltTy.getFloatSemantics())))
return failure();
floatValues.push_back(*result);
}
return success();
}
Expand Down Expand Up @@ -905,32 +879,14 @@ ParseResult DenseArrayElementParser::parseIntegerElement(Parser &p) {

ParseResult DenseArrayElementParser::parseFloatElement(Parser &p) {
bool isNegative = p.consumeIf(Token::minus);

Token token = p.getToken();
std::optional<APFloat> result;
auto floatType = cast<FloatType>(type);
if (p.consumeIf(Token::integer)) {
// Parse an integer literal as a float.
if (p.parseFloatFromIntegerLiteral(result, token, isNegative,
floatType.getFloatSemantics(),
floatType.getWidth()))
return failure();
} else if (p.consumeIf(Token::floatliteral)) {
// Parse a floating point literal.
std::optional<double> val = token.getFloatingPointValue();
if (!val)
return failure();
result = APFloat(isNegative ? -*val : *val);
if (!type.isF64()) {
bool unused;
result->convert(floatType.getFloatSemantics(),
APFloat::rmNearestTiesToEven, &unused);
}
} else {
return p.emitError("expected integer or floating point literal");
}

append(result->bitcastToAPInt());
std::optional<APFloat> fromIntLit;
if (failed(
p.parseFloatFromLiteral(fromIntLit, token, isNegative,
cast<FloatType>(type).getFloatSemantics())))
return failure();
p.consumeToken();
append(fromIntLit->bitcastToAPInt());
return success();
}

Expand Down
48 changes: 36 additions & 12 deletions mlir/lib/AsmParser/Parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -347,34 +347,58 @@ OptionalParseResult Parser::parseOptionalDecimalInteger(APInt &result) {
return success();
}

ParseResult Parser::parseFloatFromLiteral(std::optional<APFloat> &result,
const Token &tok, bool isNegative,
const llvm::fltSemantics &semantics) {
// Check for a floating point value.
if (tok.is(Token::floatliteral)) {
auto val = tok.getFloatingPointValue();
if (!val)
return emitError(tok.getLoc()) << "floating point value too large";

result.emplace(isNegative ? -*val : *val);
bool unused;
result->convert(semantics, APFloat::rmNearestTiesToEven, &unused);
return success();
}

// Check for a hexadecimal float value.
if (tok.is(Token::integer))
return parseFloatFromIntegerLiteral(result, tok, isNegative, semantics);

return emitError(tok.getLoc()) << "expected floating point literal";
}

/// Parse a floating point value from an integer literal token.
ParseResult Parser::parseFloatFromIntegerLiteral(
std::optional<APFloat> &result, const Token &tok, bool isNegative,
const llvm::fltSemantics &semantics, size_t typeSizeInBits) {
SMLoc loc = tok.getLoc();
ParseResult
Parser::parseFloatFromIntegerLiteral(std::optional<APFloat> &result,
const Token &tok, bool isNegative,
const llvm::fltSemantics &semantics) {
StringRef spelling = tok.getSpelling();
bool isHex = spelling.size() > 1 && spelling[1] == 'x';
if (!isHex) {
return emitError(loc, "unexpected decimal integer literal for a "
"floating point value")
return emitError(tok.getLoc(), "unexpected decimal integer literal for a "
"floating point value")
.attachNote()
<< "add a trailing dot to make the literal a float";
}
if (isNegative) {
return emitError(loc, "hexadecimal float literal should not have a "
"leading minus");
return emitError(tok.getLoc(),
"hexadecimal float literal should not have a "
"leading minus");
}

APInt intValue;
tok.getSpelling().getAsInteger(isHex ? 0 : 10, intValue);
if (intValue.getActiveBits() > typeSizeInBits)
return emitError(loc, "hexadecimal float constant out of range for type");
auto typeSizeInBits = APFloat::semanticsSizeInBits(semantics);
if (intValue.getActiveBits() > typeSizeInBits) {
return emitError(tok.getLoc(),
"hexadecimal float constant out of range for type");
}

APInt truncatedValue(typeSizeInBits, intValue.getNumWords(),
intValue.getRawData());

result.emplace(semantics, truncatedValue);

return success();
}

Expand Down
9 changes: 7 additions & 2 deletions mlir/lib/AsmParser/Parser.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

namespace mlir {
namespace detail {

//===----------------------------------------------------------------------===//
// Parser
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -151,11 +152,15 @@ class Parser {
/// Parse an optional integer value only in decimal format from the stream.
OptionalParseResult parseOptionalDecimalInteger(APInt &result);

/// Parse a floating point value from a literal.
ParseResult parseFloatFromLiteral(std::optional<APFloat> &result,
const Token &tok, bool isNegative,
const llvm::fltSemantics &semantics);

/// Parse a floating point value from an integer literal token.
ParseResult parseFloatFromIntegerLiteral(std::optional<APFloat> &result,
const Token &tok, bool isNegative,
const llvm::fltSemantics &semantics,
size_t typeSizeInBits);
const llvm::fltSemantics &semantics);

/// Returns true if the current token corresponds to a keyword.
bool isCurrentTokenAKeyword() const {
Expand Down
10 changes: 6 additions & 4 deletions mlir/test/IR/invalid-builtin-attributes.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ func.func @elementsattr_floattype1() -> () {
// -----

func.func @elementsattr_floattype2() -> () {
// expected-error@+1 {{expected floating-point elements, but parsed integer}}
// expected-error@below {{unexpected decimal integer literal for a floating point value}}
// expected-note@below {{add a trailing dot to make the literal a float}}
"foo"(){bar = dense<[4]> : tensor<1xf32>} : () -> ()
}

Expand Down Expand Up @@ -138,21 +139,22 @@ func.func @float_in_int_tensor() {
// -----

func.func @float_in_bool_tensor() {
// expected-error @+1 {{expected integer elements, but parsed floating-point}}
// expected-error@below {{expected integer elements, but parsed floating-point}}
"foo"() {bar = dense<[true, 42.0]> : tensor<2xi1>} : () -> ()
}

// -----

func.func @decimal_int_in_float_tensor() {
// expected-error @+1 {{expected floating-point elements, but parsed integer}}
// expected-error@below {{unexpected decimal integer literal for a floating point value}}
// expected-note@below {{add a trailing dot to make the literal a float}}
"foo"() {bar = dense<[42, 42.0]> : tensor<2xf32>} : () -> ()
}

// -----

func.func @bool_in_float_tensor() {
// expected-error @+1 {{expected floating-point elements, but parsed integer}}
// expected-error @+1 {{expected floating point literal}}
"foo"() {bar = dense<[42.0, true]> : tensor<2xf32>} : () -> ()
}

Expand Down
Loading