Add regex interface with re2 and std::regex implementations #48


Merged · 5 commits · Apr 18, 2025
44 changes: 44 additions & 0 deletions include/pytorch/tokenizers/re2_regex.h
@@ -0,0 +1,44 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <memory>
#include <string>

#include <re2/re2.h>

#include <pytorch/tokenizers/regex.h>

namespace tokenizers {

/**
* @brief RE2-based implementation of IRegex.
*/
class Re2Regex : public IRegex {
public:
/**
* @brief Construct a RE2 regex with the given pattern.
*
* @param pattern The regex pattern to compile.
*/
explicit Re2Regex(const std::string& pattern);

/**
* @brief Return all non-overlapping matches found in the input string.
*/
virtual std::vector<Match> find_all(const std::string& text) const override;

private:
std::unique_ptr<re2::RE2> regex_;

friend Result<std::unique_ptr<IRegex>> create_regex(
const std::string& pattern);
};

} // namespace tokenizers
48 changes: 48 additions & 0 deletions include/pytorch/tokenizers/regex.h
@@ -0,0 +1,48 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <memory>
#include <string>
#include <vector>

#include <pytorch/tokenizers/result.h>

namespace tokenizers {

struct Match {
size_t start; // starting index of the match
size_t end; // ending index of the match (exclusive)
};

/**
 * @brief Abstract interface for regex wrappers.
 */

Review comment:

Maybe something like this:

#pragma once

#include <string>
#include <tuple>
#include <utility>
#include <vector>

class Regex {
 public:
  virtual ~Regex() = default;

  // The only method subclasses have to implement.
  virtual std::pair<size_t, size_t> match(const std::string& text, size_t start) const = 0;

  // Convenience overload to match from the beginning.
  std::pair<size_t, size_t> match(const std::string& text) const {
    return match(text, 0);
  }

  // General implementation to match all.
  std::vector<std::pair<size_t, size_t>> match_all(const std::string& text, size_t start = 0) const {
    std::vector<std::pair<size_t, size_t>> matches;
    for (size_t length = 0;; start += length) {
      std::tie(start, length) = match(text, start);
      if (length == 0) {
        break;
      }
      matches.emplace_back(start, length);
    }
    return matches;
  }
};

Contributor Author:

I feel like we should just leave this API as is. We can get into a more granular API design later if necessary, but the main point of all of this was simply to provide a pcre2 fallback if re2 didn't work. I don't really expect people to add different regex implementations, so I don't want to over-engineer. Another reason is that I'd rather not touch the current re2 code, which uses FindAndConsume; that call is stateful and would not fit the proposed match API.

class IRegex {
public:
virtual ~IRegex() = default;

/**
* @brief Find all non-overlapping matches in the input string.
*
* @param text The input string to search.
* @return A vector of strings containing all matched substrings.
*/
virtual std::vector<Match> find_all(const std::string& text) const = 0;
};

/**
* @brief Creates a regex instance. Tries RE2 first, falls back to std::regex.
*
* @param pattern The regex pattern to compile.
* @return A unique pointer to an IRegex-compatible object.
*/
Result<std::unique_ptr<IRegex>> create_regex(const std::string& pattern);

} // namespace tokenizers
40 changes: 40 additions & 0 deletions include/pytorch/tokenizers/std_regex.h
@@ -0,0 +1,40 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <memory>
#include <regex>
#include <string>
#include "regex.h"

namespace tokenizers {

/**
* @brief std::regex-based implementation of IRegex.
*/
class StdRegex : public IRegex {
public:
/**
* @brief Construct a std::regex wrapper with the given pattern.
*
* @param pattern The regex pattern to compile.
* @throws std::regex_error if the pattern is invalid.
*/
explicit StdRegex(const std::string& pattern);

/**
* @brief Find all non-overlapping matches in the input string.
*/
virtual std::vector<Match> find_all(const std::string& text) const override;

private:
std::regex regex_;
};

} // namespace tokenizers
36 changes: 36 additions & 0 deletions src/re2_regex.cpp
@@ -0,0 +1,36 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <pytorch/tokenizers/re2_regex.h>

namespace tokenizers {

Re2Regex::Re2Regex(const std::string& pattern) {
regex_ = std::make_unique<re2::RE2>(pattern);
// Warm up RE2, as it is slow on the first run. Void the return value, as it
// is not needed. Refer to
// https://github.com/google/re2/blob/6dcd83d60f7944926bfd308cc13979fc53dd69ca/re2/fuzzing/re2_fuzzer.cc#L136-L141
(void)regex_->ReverseProgramSize();
}

std::vector<Match> Re2Regex::find_all(const std::string& text) const {
std::vector<Match> result;
re2::StringPiece input(text);
re2::StringPiece piece;

const char* base = input.data();

while (RE2::FindAndConsume(&input, *regex_, &piece)) {
size_t start = piece.data() - base;
result.push_back({start, start + piece.size()});
}

return result;
}

} // namespace tokenizers
50 changes: 50 additions & 0 deletions src/regex.cpp
@@ -0,0 +1,50 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <pytorch/tokenizers/re2_regex.h>
#include <pytorch/tokenizers/regex.h>
#include <pytorch/tokenizers/std_regex.h>

#include <re2/re2.h>
#include <iostream>
#include <memory>

namespace tokenizers {

/**
* @brief Factory function that creates a regex object using RE2 if possible.
* Falls back to std::regex if RE2 rejects the pattern with
* ErrorBadPerlOp.
*/
Result<std::unique_ptr<IRegex>> create_regex(const std::string& pattern) {
// Try RE2 first
auto re2 = std::make_unique<Re2Regex>("(" + pattern + ")");

if (re2->regex_->ok()) {
return static_cast<std::unique_ptr<IRegex>>(std::move(re2));
}

if (re2->regex_->error_code() == re2::RE2::ErrorBadPerlOp) {
try {
std::cout
<< "RE2 is unable to support things such as negative lookaheads in "
<< pattern << ", defaulting to std::regex." << std::endl;
auto std_regex = std::make_unique<StdRegex>("(" + pattern + ")");
return static_cast<std::unique_ptr<IRegex>>(std::move(std_regex));
} catch (const std::regex_error& e) {
std::cerr << "std::regex failed: " << e.what() << std::endl;
return tokenizers::Error::LoadFailure;
}
} else {
std::cerr << "RE2 failed to compile pattern: " << pattern << "\n";
std::cerr << "Error: " << (re2->regex_->error()) << std::endl;
return tokenizers::Error::LoadFailure;
}
}

} // namespace tokenizers
30 changes: 30 additions & 0 deletions src/std_regex.cpp
@@ -0,0 +1,30 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <pytorch/tokenizers/std_regex.h>
#include <regex>

namespace tokenizers {

StdRegex::StdRegex(const std::string& pattern) : regex_(pattern) {}

std::vector<Match> StdRegex::find_all(const std::string& text) const {
std::vector<Match> result;
std::sregex_iterator iter(text.begin(), text.end(), regex_);
std::sregex_iterator end;

for (; iter != end; ++iter) {
const auto& match = *iter;
// Use group 1: create_regex wraps the pattern in "(...)", so group 1
// spans the entire match.
size_t start = match.position(1);
result.push_back({start, start + match[1].length()});
}

return result;
}

} // namespace tokenizers