Skip to content

Implement s3:// protocol #11511

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 1, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions examples/run/run.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,13 @@ static int printe(const char * fmt, ...) {
return ret;
}

// Renders the broken-down time `tm` as text according to the strftime-style
// format string `fmt` (for example "%Y%m%d") and returns the result.
static std::string strftime_fmt(const char * fmt, const std::tm & tm) {
    std::ostringstream rendered;
    // std::put_time takes a pointer to the tm structure, not a reference.
    const std::tm * const tm_ptr = &tm;
    rendered << std::put_time(tm_ptr, fmt);

    return rendered.str();
}

class Opt {
public:
int init(int argc, const char ** argv) {
Expand Down Expand Up @@ -698,6 +705,39 @@ class LlamaData {
return download(url, bn, true);
}

// Downloads an S3 object described by `model`, which must have the form
// "<bucket>/<key>": the first '/' separates the bucket name from the object key.
// (The visible caller strips the scheme with rm_until_substring(model_, "://")
// before calling this.)
//
// model: "<bucket>/<key>" reference; both parts must be non-empty.
// bn:    forwarded unchanged to download().
//
// Returns 1 when the reference is malformed, when AWS_ACCESS_KEY_ID or
// AWS_SECRET_ACCESS_KEY is missing from the environment, or when the current
// UTC time cannot be obtained; otherwise returns whatever download() returns.
//
// NOTE(review): the Authorization header built below carries only the
// "Credential=" component. AWS Signature Version 4 also requires
// "SignedHeaders=" and "Signature=" (an HMAC-SHA256 chain keyed by the secret
// key), and secret_key is currently only checked for presence. Requests to
// non-public objects are therefore expected to be rejected until signing is
// implemented. The region is also hard-coded to us-east-1.
int s3_dl(const std::string & model, const std::string & bn) {
    const size_t slash_pos = model.find('/');
    // Reject a missing separator, an empty bucket ("/key") and an empty key ("bucket/").
    if (slash_pos == std::string::npos || slash_pos == 0 || slash_pos + 1 >= model.size()) {
        printe("Invalid S3 model reference, expected s3://<bucket>/<key>\n");
        return 1;
    }

    const std::string bucket = model.substr(0, slash_pos);
    const std::string key    = model.substr(slash_pos + 1);

    const char * access_key = std::getenv("AWS_ACCESS_KEY_ID");
    const char * secret_key = std::getenv("AWS_SECRET_ACCESS_KEY");
    if (!access_key || !secret_key) {
        printe("AWS credentials not found in environment\n");
        return 1;
    }

    // Current time in UTC; both the credential scope and x-amz-date use it.
    const time_t     now     = time(nullptr);
    const std::tm *  utc_ptr = gmtime(&now);
    if (!utc_ptr) {
        // gmtime() returns a null pointer on failure; do not dereference it.
        printe("Failed to determine current UTC time\n");
        return 1;
    }

    // Copy immediately: the structure gmtime() points at may be overwritten by
    // later calls to the C time functions.
    const std::tm     utc      = *utc_ptr;
    const std::string date     = strftime_fmt("%Y%m%d", utc);
    const std::string datetime = strftime_fmt("%Y%m%dT%H%M%SZ", utc);

    const std::vector<std::string> headers = {
        "Authorization: AWS4-HMAC-SHA256 Credential=" + std::string(access_key) + "/" + date +
            "/us-east-1/s3/aws4_request",
        "x-amz-content-sha256: UNSIGNED-PAYLOAD",
        "x-amz-date: " + datetime,
    };

    // Virtual-hosted-style URL: the bucket is part of the host name.
    const std::string url = "https://" + bucket + ".s3.amazonaws.com/" + key;

    return download(url, bn, true, headers);
}

std::string basename(const std::string & path) {
const size_t pos = path.find_last_of("/\\");
if (pos == std::string::npos) {
Expand Down Expand Up @@ -738,6 +778,9 @@ class LlamaData {
rm_until_substring(model_, "github:");
rm_until_substring(model_, "://");
ret = github_dl(model_, bn);
} else if (string_starts_with(model_, "s3://")) {
rm_until_substring(model_, "://");
ret = s3_dl(model_, bn);
} else { // ollama:// or nothing
rm_until_substring(model_, "ollama.com/library/");
rm_until_substring(model_, "://");
Expand Down
Loading