Skip to content

Commit b959690

Browse files
ericcurtinNeoZhangJianyu
authored andcommitted
Implement s3:// protocol (ggml-org#11511)
For those that want to pull from s3 Signed-off-by: Eric Curtin <[email protected]>
1 parent 0c2584c commit b959690

File tree

1 file changed

+43
-0
lines changed

1 file changed

+43
-0
lines changed

examples/run/run.cpp

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,13 @@ static int printe(const char * fmt, ...) {
6565
return ret;
6666
}
6767

68+
static std::string strftime_fmt(const char * fmt, const std::tm & tm) {
69+
std::ostringstream oss;
70+
oss << std::put_time(&tm, fmt);
71+
72+
return oss.str();
73+
}
74+
6875
class Opt {
6976
public:
7077
int init(int argc, const char ** argv) {
@@ -698,6 +705,39 @@ class LlamaData {
698705
return download(url, bn, true);
699706
}
700707

708+
int s3_dl(const std::string & model, const std::string & bn) {
709+
const size_t slash_pos = model.find('/');
710+
if (slash_pos == std::string::npos) {
711+
return 1;
712+
}
713+
714+
const std::string bucket = model.substr(0, slash_pos);
715+
const std::string key = model.substr(slash_pos + 1);
716+
const char * access_key = std::getenv("AWS_ACCESS_KEY_ID");
717+
const char * secret_key = std::getenv("AWS_SECRET_ACCESS_KEY");
718+
if (!access_key || !secret_key) {
719+
printe("AWS credentials not found in environment\n");
720+
return 1;
721+
}
722+
723+
// Generate AWS Signature Version 4 headers
724+
// (Implementation requires HMAC-SHA256 and date handling)
725+
// Get current timestamp
726+
const time_t now = time(nullptr);
727+
const tm tm = *gmtime(&now);
728+
const std::string date = strftime_fmt("%Y%m%d", tm);
729+
const std::string datetime = strftime_fmt("%Y%m%dT%H%M%SZ", tm);
730+
const std::vector<std::string> headers = {
731+
"Authorization: AWS4-HMAC-SHA256 Credential=" + std::string(access_key) + "/" + date +
732+
"/us-east-1/s3/aws4_request",
733+
"x-amz-content-sha256: UNSIGNED-PAYLOAD", "x-amz-date: " + datetime
734+
};
735+
736+
const std::string url = "https://" + bucket + ".s3.amazonaws.com/" + key;
737+
738+
return download(url, bn, true, headers);
739+
}
740+
701741
std::string basename(const std::string & path) {
702742
const size_t pos = path.find_last_of("/\\");
703743
if (pos == std::string::npos) {
@@ -738,6 +778,9 @@ class LlamaData {
738778
rm_until_substring(model_, "github:");
739779
rm_until_substring(model_, "://");
740780
ret = github_dl(model_, bn);
781+
} else if (string_starts_with(model_, "s3://")) {
782+
rm_until_substring(model_, "://");
783+
ret = s3_dl(model_, bn);
741784
} else { // ollama:// or nothing
742785
rm_until_substring(model_, "ollama.com/library/");
743786
rm_until_substring(model_, "://");

0 commit comments

Comments
 (0)