clj-djl.training

accuracy

(accuracy)

add-accumulator

(add-accumulator acc key)

Adds an accumulator to the accuracy for the results of the evaluation with the given key.

add-evaluator

(add-evaluator config evaluator)

add-training-listeners

(add-training-listeners config listener)

as-consumer

macro

(as-consumer f)

backward

(backward gc target)

binary-accuracy

(binary-accuracy)
(binary-accuracy threshold)
(binary-accuracy acc-name threshold)
(binary-accuracy acc-name threshold axis)

close

(close batch)

config

(config {:keys [loss devices data-manager initializer parameter optimizer evaluator listeners]})

default-training-config

fit

(fit trainer nepochs train-iter)
(fit trainer nepochs train-iter test-iter)

forward

(forward trainer input)

get-accumulator

(get-accumulator acc key)

Returns the accumulated evaluator value for the given key.

get-devices

(get-devices config)

get-evaluators

(get-evaluators trainer)

get-gradient

(get-gradient ndarray)

Returns the gradient NDArray attached to this NDArray.

get-loss

(get-loss trainer)

get-manager

(get-manager trainer)

get-metrics

(get-metrics trainer)

get-model

(get-model trainer)

get-result

(get-result trainer)

get-training-result

gradient-collector

(gradient-collector)
(gradient-collector trainer)

initialize

(initialize trainer shapes)
(initialize trainer shape & shapes)

iter-seq

(iter-seq iterable)
(iter-seq iterable iter)

iterate-dataset

(iterate-dataset trainer ds)

metrics

(metrics)

new-accuracy

new-binary-accuracy

new-default-training-config

(new-default-training-config loss)

new-default-training-listeners

new-gradient-collector

new-progress-bar

new-topk-accuracy

new-trainer

new-training-config

notify-listeners

(notify-listeners trainer callback)

opt-initializer

(opt-initializer config initializer parameter)

opt-optimizer

(opt-optimizer config optimizer)

parameter-store

(parameter-store manager copy)

progress-bar

(progress-bar)

set-metrics

(set-metrics trainer metrics)

set-requires-gradient

(set-requires-gradient ndarray requires-grad)

softmax-cross-entropy-loss

(softmax-cross-entropy-loss)

step

(step trainer)

stop-gradient

(stop-gradient ndarray)

topk-accuracy

(topk-accuracy topk)
(topk-accuracy index topk)
(topk-accuracy name index topk)

train-batch

(train-batch trainer batch)

trainer

(trainer model config)
(trainer {:keys [model loss devices data-manager initializer parameter optimizer listeners]})

training-config

training-listeners

(training-listeners)

update-accumulator

(update-accumulator acc key label-list pred-list)

Updates the accuracy with the given key based on an NDList of labels and an NDList of predictions.

validate-batch

(validate-batch trainer batch)