Skip to content

Commit 3392ab5

Browse files
committed
fix: some JET errors
1 parent bc67537 commit 3392ab5

File tree

2 files changed

+12
-6
lines changed

2 files changed

+12
-6
lines changed

src/EvaluateDerivative.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -222,12 +222,16 @@ function eval_grad_tree_array(
222222
eval_grad_tree_array(tree, n_gradients, nothing, cX, operators, Val(true))
223223
elseif constant_mode
224224
index_tree = index_constants(tree)
225-
eval_grad_tree_array(tree, n_gradients, index_tree, cX, operators, Val(false))
225+
eval_grad_tree_array(
226+
tree, n_gradients, index_tree, cX, operators, Val(false)
227+
)
226228
elseif both_mode
227229
# features come first because we can use size(cX, 1) to skip them
228230
index_tree = index_constants(tree)
229-
eval_grad_tree_array(tree, n_gradients, index_tree, cX, operators, Val(:both))
230-
end
231+
eval_grad_tree_array(
232+
tree, n_gradients, index_tree, cX, operators, Val(:both)
233+
)
234+
end::ResultOk2
231235

232236
return (result.x, result.dx, result.ok)
233237
end
@@ -351,7 +355,7 @@ function grad_deg0_eval(
351355
end
352356

353357
derivative_part = zero_mat
354-
derivative_part[index, :] .= one(T)
358+
fill!(@view(derivative_part[index, :]), one(T))
355359
return ResultOk2(const_part, derivative_part, true)
356360
end
357361

src/OperatorEnumConstruction.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,8 +228,10 @@ function _extend_operators(operators, skip_user_operators, kws, __module__::Modu
228228
)
229229
end
230230

231-
empty_old_operators_idx = findfirst(x -> first(x.args) == :empty_old_operators, kws)
232-
internal_idx = findfirst(x -> first(x.args) == :internal, kws)
231+
empty_old_operators_idx = findfirst(
232+
x -> hasproperty(x, :args) && first(x.args) == :empty_old_operators, kws
233+
)
234+
internal_idx = findfirst(x -> hasproperty(x, :args) && first(x.args) == :internal, kws)
233235

234236
empty_old_operators = if empty_old_operators_idx !== nothing
235237
@assert kws[empty_old_operators_idx].head == :(=)

0 commit comments

Comments
 (0)