@@ -926,7 +926,8 @@ CreateFileHandler(MemoryBuffer &FirstInput,
926
926
" '" + FilesType + " ': invalid file type specified" );
927
927
}
928
928
929
- OffloadBundlerConfig::OffloadBundlerConfig () {
929
+ OffloadBundlerConfig::OffloadBundlerConfig ()
930
+ : CompressedBundleVersion(CompressedOffloadBundle::DefaultVersion) {
930
931
if (llvm::compression::zstd::isAvailable ()) {
931
932
CompressionFormat = llvm::compression::Format::Zstd;
932
933
// Compression level 3 is usually sufficient for zstd since long distance
@@ -942,16 +943,13 @@ OffloadBundlerConfig::OffloadBundlerConfig() {
942
943
llvm::sys::Process::GetEnv (" OFFLOAD_BUNDLER_IGNORE_ENV_VAR" );
943
944
if (IgnoreEnvVarOpt.has_value () && IgnoreEnvVarOpt.value () == " 1" )
944
945
return ;
945
-
946
946
auto VerboseEnvVarOpt = llvm::sys::Process::GetEnv (" OFFLOAD_BUNDLER_VERBOSE" );
947
947
if (VerboseEnvVarOpt.has_value ())
948
948
Verbose = VerboseEnvVarOpt.value () == " 1" ;
949
-
950
949
auto CompressEnvVarOpt =
951
950
llvm::sys::Process::GetEnv (" OFFLOAD_BUNDLER_COMPRESS" );
952
951
if (CompressEnvVarOpt.has_value ())
953
952
Compress = CompressEnvVarOpt.value () == " 1" ;
954
-
955
953
auto CompressionLevelEnvVarOpt =
956
954
llvm::sys::Process::GetEnv (" OFFLOAD_BUNDLER_COMPRESSION_LEVEL" );
957
955
if (CompressionLevelEnvVarOpt.has_value ()) {
@@ -964,6 +962,26 @@ OffloadBundlerConfig::OffloadBundlerConfig() {
964
962
<< " Warning: Invalid value for OFFLOAD_BUNDLER_COMPRESSION_LEVEL: "
965
963
<< CompressionLevelStr.str () << " . Ignoring it.\n " ;
966
964
}
965
+ auto CompressedBundleFormatVersionOpt =
966
+ llvm::sys::Process::GetEnv (" COMPRESSED_BUNDLE_FORMAT_VERSION" );
967
+ if (CompressedBundleFormatVersionOpt.has_value ()) {
968
+ llvm::StringRef VersionStr = CompressedBundleFormatVersionOpt.value ();
969
+ uint16_t Version;
970
+ if (!VersionStr.getAsInteger (10 , Version)) {
971
+ if (Version >= 2 && Version <= 3 )
972
+ CompressedBundleVersion = Version;
973
+ else
974
+ llvm::errs ()
975
+ << " Warning: Invalid value for COMPRESSED_BUNDLE_FORMAT_VERSION: "
976
+ << VersionStr.str ()
977
+ << " . Valid values are 2 or 3. Using default version "
978
+ << CompressedBundleVersion << " .\n " ;
979
+ } else
980
+ llvm::errs ()
981
+ << " Warning: Invalid value for COMPRESSED_BUNDLE_FORMAT_VERSION: "
982
+ << VersionStr.str () << " . Using default version "
983
+ << CompressedBundleVersion << " .\n " ;
984
+ }
967
985
}
968
986
969
987
// Utility function to format numbers with commas
@@ -980,12 +998,11 @@ static std::string formatWithCommas(unsigned long long Value) {
980
998
llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>>
981
999
CompressedOffloadBundle::compress (llvm::compression::Params P,
982
1000
const llvm::MemoryBuffer &Input,
983
- bool Verbose) {
1001
+ uint16_t Version, bool Verbose) {
984
1002
if (!llvm::compression::zstd::isAvailable () &&
985
1003
!llvm::compression::zlib::isAvailable ())
986
1004
return createStringError (llvm::inconvertibleErrorCode (),
987
1005
" Compression not supported" );
988
-
989
1006
llvm::Timer HashTimer (" Hash Calculation Timer" , " Hash calculation time" ,
990
1007
ClangOffloadBundlerTimerGroup);
991
1008
if (Verbose)
@@ -1002,7 +1019,6 @@ CompressedOffloadBundle::compress(llvm::compression::Params P,
1002
1019
auto BufferUint8 = llvm::ArrayRef<uint8_t >(
1003
1020
reinterpret_cast <const uint8_t *>(Input.getBuffer ().data ()),
1004
1021
Input.getBuffer ().size ());
1005
-
1006
1022
llvm::Timer CompressTimer (" Compression Timer" , " Compression time" ,
1007
1023
ClangOffloadBundlerTimerGroup);
1008
1024
if (Verbose)
@@ -1012,22 +1028,54 @@ CompressedOffloadBundle::compress(llvm::compression::Params P,
1012
1028
CompressTimer.stopTimer ();
1013
1029
1014
1030
uint16_t CompressionMethod = static_cast <uint16_t >(P.format );
1015
- uint32_t UncompressedSize = Input.getBuffer ().size ();
1016
- uint32_t TotalFileSize = MagicNumber.size () + sizeof (TotalFileSize) +
1017
- sizeof (Version) + sizeof (CompressionMethod) +
1018
- sizeof (UncompressedSize) + sizeof (TruncatedHash) +
1019
- CompressedBuffer.size ();
1031
+
1032
+ // Store sizes in 64-bit variables first
1033
+ uint64_t UncompressedSize64 = Input.getBuffer ().size ();
1034
+ uint64_t TotalFileSize64;
1035
+
1036
+ // Calculate total file size based on version
1037
+ if (Version == 2 ) {
1038
+ // For V2, ensure the sizes don't exceed 32-bit limit
1039
+ if (UncompressedSize64 > std::numeric_limits<uint32_t >::max ())
1040
+ return createStringError (llvm::inconvertibleErrorCode (),
1041
+ " Uncompressed size exceeds version 2 limit" );
1042
+ if ((MagicNumber.size () + sizeof (uint32_t ) + sizeof (Version) +
1043
+ sizeof (CompressionMethod) + sizeof (uint32_t ) + sizeof (TruncatedHash) +
1044
+ CompressedBuffer.size ()) > std::numeric_limits<uint32_t >::max ())
1045
+ return createStringError (llvm::inconvertibleErrorCode (),
1046
+ " Total file size exceeds version 2 limit" );
1047
+
1048
+ TotalFileSize64 = MagicNumber.size () + sizeof (uint32_t ) + sizeof (Version) +
1049
+ sizeof (CompressionMethod) + sizeof (uint32_t ) +
1050
+ sizeof (TruncatedHash) + CompressedBuffer.size ();
1051
+ } else { // Version 3
1052
+ TotalFileSize64 = MagicNumber.size () + sizeof (uint64_t ) + sizeof (Version) +
1053
+ sizeof (CompressionMethod) + sizeof (uint64_t ) +
1054
+ sizeof (TruncatedHash) + CompressedBuffer.size ();
1055
+ }
1020
1056
1021
1057
SmallVector<char , 0 > FinalBuffer;
1022
1058
llvm::raw_svector_ostream OS (FinalBuffer);
1023
1059
OS << MagicNumber;
1024
1060
OS.write (reinterpret_cast <const char *>(&Version), sizeof (Version));
1025
1061
OS.write (reinterpret_cast <const char *>(&CompressionMethod),
1026
1062
sizeof (CompressionMethod));
1027
- OS.write (reinterpret_cast <const char *>(&TotalFileSize),
1028
- sizeof (TotalFileSize));
1029
- OS.write (reinterpret_cast <const char *>(&UncompressedSize),
1030
- sizeof (UncompressedSize));
1063
+
1064
+ // Write size fields according to version
1065
+ if (Version == 2 ) {
1066
+ uint32_t TotalFileSize32 = static_cast <uint32_t >(TotalFileSize64);
1067
+ uint32_t UncompressedSize32 = static_cast <uint32_t >(UncompressedSize64);
1068
+ OS.write (reinterpret_cast <const char *>(&TotalFileSize32),
1069
+ sizeof (TotalFileSize32));
1070
+ OS.write (reinterpret_cast <const char *>(&UncompressedSize32),
1071
+ sizeof (UncompressedSize32));
1072
+ } else { // Version 3
1073
+ OS.write (reinterpret_cast <const char *>(&TotalFileSize64),
1074
+ sizeof (TotalFileSize64));
1075
+ OS.write (reinterpret_cast <const char *>(&UncompressedSize64),
1076
+ sizeof (UncompressedSize64));
1077
+ }
1078
+
1031
1079
OS.write (reinterpret_cast <const char *>(&TruncatedHash),
1032
1080
sizeof (TruncatedHash));
1033
1081
OS.write (reinterpret_cast <const char *>(CompressedBuffer.data ()),
@@ -1037,18 +1085,17 @@ CompressedOffloadBundle::compress(llvm::compression::Params P,
1037
1085
auto MethodUsed =
1038
1086
P.format == llvm::compression::Format::Zstd ? " zstd" : " zlib" ;
1039
1087
double CompressionRate =
1040
- static_cast <double >(UncompressedSize ) / CompressedBuffer.size ();
1088
+ static_cast <double >(UncompressedSize64 ) / CompressedBuffer.size ();
1041
1089
double CompressionTimeSeconds = CompressTimer.getTotalTime ().getWallTime ();
1042
1090
double CompressionSpeedMBs =
1043
- (UncompressedSize / (1024.0 * 1024.0 )) / CompressionTimeSeconds;
1044
-
1091
+ (UncompressedSize64 / (1024.0 * 1024.0 )) / CompressionTimeSeconds;
1045
1092
llvm::errs () << " Compressed bundle format version: " << Version << " \n "
1046
1093
<< " Total file size (including headers): "
1047
- << formatWithCommas (TotalFileSize ) << " bytes\n "
1094
+ << formatWithCommas (TotalFileSize64 ) << " bytes\n "
1048
1095
<< " Compression method used: " << MethodUsed << " \n "
1049
1096
<< " Compression level: " << P.level << " \n "
1050
1097
<< " Binary size before compression: "
1051
- << formatWithCommas (UncompressedSize ) << " bytes\n "
1098
+ << formatWithCommas (UncompressedSize64 ) << " bytes\n "
1052
1099
<< " Binary size after compression: "
1053
1100
<< formatWithCommas (CompressedBuffer.size ()) << " bytes\n "
1054
1101
<< " Compression rate: "
@@ -1060,16 +1107,17 @@ CompressedOffloadBundle::compress(llvm::compression::Params P,
1060
1107
<< " Truncated MD5 hash: "
1061
1108
<< llvm::format_hex (TruncatedHash, 16 ) << " \n " ;
1062
1109
}
1110
+
1063
1111
return llvm::MemoryBuffer::getMemBufferCopy (
1064
1112
llvm::StringRef (FinalBuffer.data (), FinalBuffer.size ()));
1065
1113
}
1066
1114
1067
1115
llvm::Expected<std::unique_ptr<llvm::MemoryBuffer>>
1068
1116
CompressedOffloadBundle::decompress (const llvm::MemoryBuffer &Input,
1069
1117
bool Verbose) {
1070
-
1071
1118
StringRef Blob = Input.getBuffer ();
1072
1119
1120
+ // Check minimum header size (using V1 as it's the smallest)
1073
1121
if (Blob.size () < V1HeaderSize)
1074
1122
return llvm::MemoryBuffer::getMemBufferCopy (Blob);
1075
1123
@@ -1082,31 +1130,56 @@ CompressedOffloadBundle::decompress(const llvm::MemoryBuffer &Input,
1082
1130
1083
1131
size_t CurrentOffset = MagicSize;
1084
1132
1133
+ // Read version
1085
1134
uint16_t ThisVersion;
1086
1135
memcpy (&ThisVersion, Blob.data () + CurrentOffset, sizeof (uint16_t ));
1087
1136
CurrentOffset += VersionFieldSize;
1088
1137
1138
+ // Verify header size based on version
1139
+ if (ThisVersion >= 2 && ThisVersion <= 3 ) {
1140
+ size_t RequiredSize = (ThisVersion == 2 ) ? V2HeaderSize : V3HeaderSize;
1141
+ if (Blob.size () < RequiredSize)
1142
+ return createStringError (inconvertibleErrorCode (),
1143
+ " Compressed bundle header size too small" );
1144
+ }
1145
+
1146
+ // Read compression method
1089
1147
uint16_t CompressionMethod;
1090
1148
memcpy (&CompressionMethod, Blob.data () + CurrentOffset, sizeof (uint16_t ));
1091
1149
CurrentOffset += MethodFieldSize;
1092
1150
1093
- uint32_t TotalFileSize;
1151
+ // Read total file size (version 2+)
1152
+ uint64_t TotalFileSize = 0 ;
1094
1153
if (ThisVersion >= 2 ) {
1095
- if (Blob.size () < V2HeaderSize)
1096
- return createStringError (inconvertibleErrorCode (),
1097
- " Compressed bundle header size too small" );
1098
- memcpy (&TotalFileSize, Blob.data () + CurrentOffset, sizeof (uint32_t ));
1099
- CurrentOffset += FileSizeFieldSize;
1154
+ if (ThisVersion == 2 ) {
1155
+ uint32_t TotalFileSize32;
1156
+ memcpy (&TotalFileSize32, Blob.data () + CurrentOffset, sizeof (uint32_t ));
1157
+ TotalFileSize = TotalFileSize32;
1158
+ CurrentOffset += FileSizeFieldSizeV2;
1159
+ } else { // Version 3
1160
+ memcpy (&TotalFileSize, Blob.data () + CurrentOffset, sizeof (uint64_t ));
1161
+ CurrentOffset += FileSizeFieldSizeV3;
1162
+ }
1100
1163
}
1101
1164
1102
- uint32_t UncompressedSize;
1103
- memcpy (&UncompressedSize, Blob.data () + CurrentOffset, sizeof (uint32_t ));
1104
- CurrentOffset += UncompressedSizeFieldSize;
1165
+ // Read uncompressed size
1166
+ uint64_t UncompressedSize = 0 ;
1167
+ if (ThisVersion <= 2 ) {
1168
+ uint32_t UncompressedSize32;
1169
+ memcpy (&UncompressedSize32, Blob.data () + CurrentOffset, sizeof (uint32_t ));
1170
+ UncompressedSize = UncompressedSize32;
1171
+ CurrentOffset += UncompressedSizeFieldSizeV2;
1172
+ } else { // Version 3
1173
+ memcpy (&UncompressedSize, Blob.data () + CurrentOffset, sizeof (uint64_t ));
1174
+ CurrentOffset += UncompressedSizeFieldSizeV3;
1175
+ }
1105
1176
1177
+ // Read hash
1106
1178
uint64_t StoredHash;
1107
1179
memcpy (&StoredHash, Blob.data () + CurrentOffset, sizeof (uint64_t ));
1108
1180
CurrentOffset += HashFieldSize;
1109
1181
1182
+ // Determine compression format
1110
1183
llvm::compression::Format CompressionFormat;
1111
1184
if (CompressionMethod ==
1112
1185
static_cast <uint16_t >(llvm::compression::Format::Zlib))
@@ -1372,7 +1445,8 @@ Error OffloadBundler::BundleFiles() {
1372
1445
auto CompressionResult = CompressedOffloadBundle::compress (
1373
1446
{BundlerConfig.CompressionFormat , BundlerConfig.CompressionLevel ,
1374
1447
/* zstdEnableLdm=*/ true },
1375
- *BufferMemory, BundlerConfig.Verbose );
1448
+ *BufferMemory, BundlerConfig.CompressedBundleVersion ,
1449
+ BundlerConfig.Verbose );
1376
1450
if (auto Error = CompressionResult.takeError ())
1377
1451
return Error;
1378
1452
0 commit comments