Skip to content

Commit cc839da

Browse files
committed
fixup
1 parent 4c87214 commit cc839da

File tree

6 files changed

+88
-44
lines changed

6 files changed

+88
-44
lines changed

deps/ReactantExtra/API.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,15 @@ template <typename T> T MyValueOrThrow(absl::StatusOr<T> v) {
119119
}
120120
}
121121

122+
// C-linkage hook that receives a numeric result code and reports failures.
// A value of 0 is treated as success and nothing happens; any other value is
// formatted into a "Bad Cuda Result = <code>" message and forwarded to
// ReactantThrowError.
// NOTE(review): the name and message suggest `curesult` is a CUDA driver
// CUresult value — confirm against the caller that passes this pointer.
extern "C" void ReactantHandleCuResult(uint32_t curesult) {
123+
if (curesult != 0) {
124+
std::string err = "Bad Cuda Result = " + std::to_string(curesult);
125+
// The callback is only invoked when it is non-null; if it is unset, the
// error is dropped without any report.
if (ReactantThrowError) {
126+
ReactantThrowError(err.c_str());
127+
}
128+
}
129+
}
130+
122131
// MLIR C-API extras
123132
#pragma region MLIR Extra
124133
extern "C" MlirAttribute mlirComplexAttrDoubleGet(MlirContext ctx,

deps/ReactantExtra/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,7 @@ cc_library(
436436
"-Wl,-exported_symbol,_ConvertLLVMToMLIR",
437437
"-Wl,-exported_symbol,_RegisterEnzymeXLAGPUHandler",
438438
"-Wl,-exported_symbol,_ReactantThrowError",
439+
"-Wl,-exported_symbol,_ReactantHandleCuResult",
439440
]}),
440441
deps = [
441442
"@enzyme//:EnzymeMLIR",

deps/ReactantExtra/WORKSPACE

Lines changed: 37 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ http_archive(
99
urls = ["https://github.com/wsmoses/nsync/archive/{commit}.tar.gz".format(commit = NSYNC_COMMIT)],
1010
)
1111

12-
ENZYMEXLA_COMMIT = "64a1c283072d4ce4eb319c69b32a6f3c68f30cbe"
12+
ENZYMEXLA_COMMIT = "362f33f518900ebf66cee7f0135a436907f8f692"
1313
ENZYMEXLA_SHA256 = ""
1414

1515
http_archive(
@@ -95,39 +95,39 @@ LLVM_TARGETS = select({
9595
}) + ["AArch64", "X86", "ARM"]
9696

9797
# Uncomment these lines to use a custom LLVM commit
98-
LLVM_COMMIT = "b39c5cb6977f35ad727d86b2dd6232099734ffd3"
99-
LLVM_SHA256 = ""
100-
http_archive(
101-
name = "llvm-raw",
102-
build_file_content = "# empty",
103-
sha256 = LLVM_SHA256,
104-
strip_prefix = "llvm-project-" + LLVM_COMMIT,
105-
urls = ["https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT)],
106-
)
107-
108-
109-
load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe")
110-
maybe(
111-
http_archive,
112-
name = "llvm_zlib",
113-
build_file = "@llvm-raw//utils/bazel/third_party_build:zlib-ng.BUILD",
114-
sha256 = "e36bb346c00472a1f9ff2a0a4643e590a254be6379da7cddd9daeb9a7f296731",
115-
strip_prefix = "zlib-ng-2.0.7",
116-
urls = [
117-
"https://github.com/zlib-ng/zlib-ng/archive/refs/tags/2.0.7.zip",
118-
],
119-
)
120-
121-
maybe(
122-
http_archive,
123-
name = "llvm_zstd",
124-
build_file = "@llvm-raw//utils/bazel/third_party_build:zstd.BUILD",
125-
sha256 = "7c42d56fac126929a6a85dbc73ff1db2411d04f104fae9bdea51305663a83fd0",
126-
strip_prefix = "zstd-1.5.2",
127-
urls = [
128-
"https://github.com/facebook/zstd/releases/download/v1.5.2/zstd-1.5.2.tar.gz"
129-
],
130-
)
98+
# LLVM_COMMIT = "b39c5cb6977f35ad727d86b2dd6232099734ffd3"
99+
# LLVM_SHA256 = ""
100+
# http_archive(
101+
# name = "llvm-raw",
102+
# build_file_content = "# empty",
103+
# sha256 = LLVM_SHA256,
104+
# strip_prefix = "llvm-project-" + LLVM_COMMIT,
105+
# urls = ["https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT)],
106+
# )
107+
#
108+
#
109+
# load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe")
110+
# maybe(
111+
# http_archive,
112+
# name = "llvm_zlib",
113+
# build_file = "@llvm-raw//utils/bazel/third_party_build:zlib-ng.BUILD",
114+
# sha256 = "e36bb346c00472a1f9ff2a0a4643e590a254be6379da7cddd9daeb9a7f296731",
115+
# strip_prefix = "zlib-ng-2.0.7",
116+
# urls = [
117+
# "https://github.com/zlib-ng/zlib-ng/archive/refs/tags/2.0.7.zip",
118+
# ],
119+
# )
120+
#
121+
# maybe(
122+
# http_archive,
123+
# name = "llvm_zstd",
124+
# build_file = "@llvm-raw//utils/bazel/third_party_build:zstd.BUILD",
125+
# sha256 = "7c42d56fac126929a6a85dbc73ff1db2411d04f104fae9bdea51305663a83fd0",
126+
# strip_prefix = "zstd-1.5.2",
127+
# urls = [
128+
# "https://github.com/facebook/zstd/releases/download/v1.5.2/zstd-1.5.2.tar.gz"
129+
# ],
130+
# )
131131

132132
http_archive(
133133
name = "jax",
@@ -138,9 +138,9 @@ http_archive(
138138
patches = ["@enzyme_ad//:patches/jax.patch"],
139139
)
140140

141-
load("@jax//third_party/xla:workspace.bzl", "XLA_COMMIT", "XLA_SHA256")
142-
# XLA_COMMIT = "88d46fe4b15fff95eae16c64f612e18b71ff49c5"
143-
# XLA_SHA256 = ""
141+
# load("@jax//third_party/xla:workspace.bzl", "XLA_COMMIT", "XLA_SHA256")
142+
XLA_COMMIT = "1bb4fc18e73faa1c001d96bfe3a22f733987b018"
143+
XLA_SHA256 = ""
144144

145145
http_archive(
146146
name = "xla",

ext/ReactantCUDAExt.jl

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -423,6 +423,7 @@ function get_field_offset(T::Type, path)
423423
offset = 0
424424
current_type = T
425425

426+
426427
for field in path
427428
# Get the field index
428429
field_idx = if field isa Integer
@@ -438,11 +439,17 @@ function get_field_offset(T::Type, path)
438439
end
439440

440441
# Add the offset of this field
441-
offset += fieldoffset(current_type, field_idx)
442+
toffset = fieldoffset(current_type, field_idx)
443+
tcurrent_type = fieldtype(current_type, field_idx)
444+
offset += toffset
445+
@show current_type, field_idx, toffset, offset, tcurrent_type
442446

443447
# Update current_type to the field's type for next iteration
444-
current_type = fieldtype(current_type, field_idx)
448+
current_type = tcurrent_type
449+
445450
end
451+
452+
@show T, path, offset
446453

447454
return offset
448455
end
@@ -550,6 +557,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
550557
1,
551558
)
552559
push!(allocs, (alloc, argty))
560+
@show string(alloc), string(argty), typeof(a)
553561

554562
sz = sizeof(a)
555563
array_ty = MLIR.IR.Type(MLIR.API.mlirLLVMArrayTypeGet(MLIR.IR.Type(Int8), sz))
@@ -656,7 +664,7 @@ Reactant.@reactant_overlay @noinline function (func::LLVMFunc{F,tt})(
656664
fn=MLIR.IR.FlatSymbolRefAttribute(sym_name),
657665
output_operand_aliases=MLIR.IR.Attribute(output_operand_aliases),
658666
)
659-
@show string(call), typeof(func.f), collect(map(typeof, args))
667+
# @show string(call), typeof(func.f), collect(map(typeof, args))
660668

661669
argidx = 1
662670
for arg in values(seen)
@@ -786,6 +794,11 @@ function __init__()
786794
Reactant.Compiler.cuLaunch[] = Base.reinterpret(UInt, ptr1)
787795
Reactant.Compiler.cuModule[] = Base.reinterpret(UInt, ptr2)
788796
Reactant.Compiler.cuFunc[] = Base.reinterpret(UInt, ptr3)
797+
ptr4 = Reactant.XLA.Libdl.dlsym(handle, "cuStreamSynchronize"; throw_error=false)
798+
if ptr4 === nothing
799+
ptr4 = C_NULL
800+
end
801+
Reactant.Compiler.cuSync[] = Base.reinterpret(UInt, ptr4)
789802
end
790803
return nothing
791804
end

src/Compiler.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,8 @@ end
357357
const cuLaunch = Ref{UInt}(0)
358358
const cuFunc = Ref{UInt}(0)
359359
const cuModule = Ref{UInt}(0)
360+
const cuSync = Ref{UInt}(0)
361+
const DEBUG_KERNEL = Ref{Bool}(false)
360362

361363
function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::Bool=false)
362364
# Explicitly don't use block! to avoid creating a closure, which creates
@@ -385,7 +387,14 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::
385387
if isdefined(Reactant_jll, :ptxas_path)
386388
toolkit = Reactant_jll.ptxas_path[1:(end - length("/bin/ptxas"))]
387389
end
388-
kern = "lower-kernel{run_init=true toolkitPath=$toolkit cuLaunchKernelPtr=$(cuLaunch[]) cuModuleLoadDataPtr=$(cuModule[]) cuModuleGetFunctionPtr=$(cuFunc[])},symbol-dce"
390+
if DEBUG_KERNEL[]
391+
curesulthandler = XLA.Libdl.dlsym(Reactant_jll.libReactantExtra_handle, "ReactantHandleCuResult")
392+
@assert curesulthandler !== nothing
393+
curesulthandler = Base.reinterpret(UInt, curesulthandler)
394+
kern = "lower-kernel{debug=true cuResultHandlerPtr=$curesulthandler run_init=true toolkitPath=$toolkit cuLaunchKernelPtr=$(cuLaunch[]) cuModuleLoadDataPtr=$(cuModule[]) cuModuleGetFunctionPtr=$(cuFunc[]) cuStreamSynchronizePtr=$(cuSync[])},symbol-dce"
395+
else
396+
kern = "lower-kernel{run_init=true toolkitPath=$toolkit cuLaunchKernelPtr=$(cuLaunch[]) cuModuleLoadDataPtr=$(cuModule[]) cuModuleGetFunctionPtr=$(cuFunc[])},symbol-dce"
397+
end
389398

390399
opt_passes = optimization_passes(; no_nan, sroa=true)
391400
opt_passes2 = optimization_passes(; no_nan, sroa=false)

src/Tracing.jl

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,18 @@ function Base.showerror(io::IO, err::NoFieldMatchError)
318318
)
319319
end
320320

321+
# `make_tracer` method for values of type `Base.ExceptionStack` or
# `Core.MethodInstance`: such values are passed through untouched.
# `seen`, `path`, `mode` and every keyword argument (`toscalar`, `tobatch`,
# `track_numbers`, plus any extra `kwargs`) are accepted but not used here.
# `prev` and `path` are marked `@nospecialize`.
function make_tracer(
322+
seen,
323+
@nospecialize(prev::Union{Base.ExceptionStack, Core.MethodInstance}),
324+
@nospecialize(path),
325+
mode;
326+
toscalar=false,
327+
tobatch=nothing,
328+
track_numbers=(),
329+
kwargs...,
330+
)
331+
# Return the very same object; nothing is recorded in `seen`.
return prev
332+
end
321333
append_path(path, i) = (path..., i)
322334

323335
function make_tracer(
@@ -590,7 +602,7 @@ function make_tracer(
590602
if mode == ArrayToConcrete
591603
return ConcreteRNumber(prev)
592604
else
593-
if mode == TracedTrack
605+
if mode == TracedTrack || mode == NoStopTracedTrack
594606
res = TracedRNumber{RT}(
595607
(path,), TracedUtils.broadcast_to_size(prev, ()).mlir_data
596608
)
@@ -638,7 +650,7 @@ end
638650
function make_tracer(
639651
seen, @nospecialize(prev::RT), @nospecialize(path), mode; track_numbers=(), kwargs...
640652
) where {RT<:Array}
641-
if haskey(seen, prev)
653+
if mode != NoStopTracedTrack && haskey(seen, prev)
642654
return seen[prev]
643655
end
644656
if mode == ArrayToConcrete && eltype(RT) <: ReactantPrimitive
@@ -699,7 +711,7 @@ function make_tracer(
699711
end
700712

701713
function make_tracer(seen, prev::Core.Box, @nospecialize(path), mode; kwargs...)
702-
if haskey(seen, prev)
714+
if mode != NoStopTracedTrack && haskey(seen, prev)
703715
return seen[prev]
704716
end
705717
prev2 = prev.contents

0 commit comments

Comments
 (0)