Skip to content

Commit 934ffa3

Browse files
committed
Add regex interface
1 parent 6a6e24f commit 934ffa3

File tree

6 files changed

+196
-0
lines changed

6 files changed

+196
-0
lines changed
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
#pragma once
2+
3+
#include <memory>
4+
#include <string>
5+
#include "regex.h"
6+
7+
// Third Party
8+
#include <re2/re2.h>
9+
10+
/**
11+
* @brief RE2-based implementation of IRegex.
12+
*/
13+
class Re2Regex : public IRegex {
14+
public:
15+
/**
16+
* @brief Construct a RE2 regex with the given pattern.
17+
*
18+
* @param pattern The regex pattern to compile.
19+
*/
20+
explicit Re2Regex(const std::string& pattern);
21+
22+
/**
23+
* @brief Return all non-overlapping matches found in the input string.
24+
*/
25+
virtual std::vector<Match> findAll(const std::string& text) const override;
26+
27+
protected:
28+
/**
29+
* @brief Check if RE2 compiled the pattern successfully.
30+
*/
31+
bool ok() const;
32+
33+
/**
34+
* @brief Expose internal RE2 pointer to the factory if needed.
35+
*/
36+
const re2::RE2* rawRegex() const;
37+
38+
private:
39+
std::unique_ptr<re2::RE2> regex_;
40+
41+
friend std::unique_ptr<IRegex> createRegex(const std::string& pattern);
42+
};

include/pytorch/tokenizers/regex.h

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
#pragma once
2+
3+
#include <memory>
4+
#include <string>
5+
#include <vector>
6+
7+
/**
 * @brief One regex hit: the matched substring and where it starts.
 */
struct Match {
  // The matched substring.
  std::string text;
  // Offset of the match within the searched string. Initialised so that a
  // default-constructed Match never carries an indeterminate value.
  size_t position = 0;
};
11+
12+
/**
13+
* @brief Abstract interface for regex wrappers.
14+
*/
15+
class IRegex {
16+
public:
17+
virtual ~IRegex() = default;
18+
19+
/**
20+
* @brief Find all non-overlapping matches in the input string.
21+
*
22+
* @param text The input string to search.
23+
* @return A vector of strings containing all matched substrings.
24+
*/
25+
virtual std::vector<Match> findAll(const std::string& text) const = 0;
26+
};
27+
28+
/**
29+
* @brief Creates a regex instance. Tries RE2 first, falls back to std::regex.
30+
*
31+
* @param pattern The regex pattern to compile.
32+
* @return A unique pointer to an IRegex-compatible object.
33+
*/
34+
std::unique_ptr<IRegex> createRegex(const std::string& pattern);
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
#pragma once
2+
3+
#include <memory>
4+
#include <regex>
5+
#include <string>
6+
#include "regex.h"
7+
8+
/**
9+
* @brief std::regex-based implementation of IRegex.
10+
*/
11+
class StdRegex : public IRegex {
12+
public:
13+
/**
14+
* @brief Construct a std::regex wrapper with the given pattern.
15+
*
16+
* @param pattern The regex pattern to compile.
17+
* @throws std::regex_error if the pattern is invalid.
18+
*/
19+
explicit StdRegex(const std::string& pattern);
20+
21+
/**
22+
* @brief Find all non-overlapping matches in the input string.
23+
*/
24+
virtual std::vector<Match> findAll(const std::string& text) const override;
25+
26+
private:
27+
std::regex regex_;
28+
};

src/re2_regex.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
#include "pytorch/tokenizers/re2_regex.h"
2+
#include <re2/re2.h>
3+
4+
/**
 * @brief Compile @p pattern with RE2, wrapped in one outer capture group.
 *
 * @param pattern The regex pattern to compile.
 */
Re2Regex::Re2Regex(const std::string& pattern)
    : regex_(std::make_unique<re2::RE2>("(" + pattern + ")")) {
  // Warm up RE2, which is slow on its first run. The return value is not
  // needed, so it is discarded. See
  // https://github.com/google/re2/blob/6dcd83d60f7944926bfd308cc13979fc53dd69ca/re2/fuzzing/re2_fuzzer.cc#L136-L141
  static_cast<void>(regex_->ReverseProgramSize());
}
11+
12+
bool Re2Regex::ok() const {
13+
return regex_ && regex_->ok();
14+
}
15+
16+
/**
 * @brief Return a non-owning pointer to the wrapped RE2 object.
 */
const re2::RE2* Re2Regex::rawRegex() const {
  const re2::RE2* const compiled = regex_.get();
  return compiled;
}
19+
20+
std::vector<Match> Re2Regex::findAll(const std::string& text) const {
21+
std::vector<Match> result;
22+
re2::StringPiece input(text);
23+
re2::StringPiece piece;
24+
25+
const char* base = input.data();
26+
27+
while (RE2::FindAndConsume(&input, *regex_, &piece)) {
28+
size_t pos = piece.data() - base;
29+
result.push_back({ std::string(piece.data(), piece.size()), pos });
30+
}
31+
32+
return result;
33+
}

src/regex.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
#include "pytorch/tokenizers/regex.h"
2+
#include "pytorch/tokenizers/re2_regex.h"
3+
#include "pytorch/tokenizers/std_regex.h"
4+
5+
#include <re2/re2.h>
6+
#include <iostream>
7+
#include <memory>
8+
9+
/**
10+
* @brief Factory function that creates a regex object using RE2 if possible.
11+
* Falls back to std::regex if RE2 rejects the pattern with
12+
* ErrorBadPerlOp.
13+
*/
14+
std::unique_ptr<IRegex> createRegex(const std::string& pattern) {
15+
auto re2 = std::make_unique<Re2Regex>(pattern);
16+
17+
if (re2->ok()) {
18+
return re2;
19+
}
20+
21+
const re2::RE2* raw = re2->rawRegex();
22+
if (raw && raw->error_code() == re2::RE2::ErrorBadPerlOp) {
23+
try {
24+
std::cout
25+
<< "RE2 is unable to support things such as negative lookaheads in "
26+
<< pattern << ", defaulting to std::regex.";
27+
return std::make_unique<StdRegex>(pattern);
28+
} catch (const std::regex_error& e) {
29+
std::cerr << "std::regex failed: " << e.what() << std::endl;
30+
return nullptr;
31+
}
32+
} else {
33+
std::cerr << "RE2 failed to compile pattern: " << pattern << "\n";
34+
std::cerr << "Error: " << (raw ? raw->error() : "unknown") << std::endl;
35+
return nullptr;
36+
}
37+
}

src/std_regex.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#include "pytorch/tokenizers/std_regex.h"
2+
#include <regex>
3+
4+
/**
 * @brief Compile @p pattern with std::regex, wrapped in one outer capture
 *        group in the same way as the RE2 version.
 *
 * @param pattern The regex pattern to compile.
 * @throws std::regex_error if the pattern is invalid.
 */
StdRegex::StdRegex(const std::string& pattern)
    : regex_(std::string("(").append(pattern).append(")")) {}
7+
8+
std::vector<Match> StdRegex::findAll(const std::string& text) const {
9+
std::vector<Match> result;
10+
std::sregex_iterator iter(text.begin(), text.end(), regex_);
11+
std::sregex_iterator end;
12+
13+
for (; iter != end; ++iter) {
14+
const auto& match = *iter;
15+
result.push_back({
16+
match[1].str(), // capture group 1
17+
static_cast<size_t>(match.position(1)) // position of group 1
18+
});
19+
}
20+
21+
return result;
22+
}

0 commit comments

Comments
 (0)