ggml backend v2

abb128 2023-11-25 09:39:04 +02:00
parent f31db527d6
commit ca9c9d5a9a
14 changed files with 6852 additions and 5935 deletions

View File

@ -550,35 +550,35 @@ static void randomize_lora(struct my_llama_lora * lora, int seed, float mean, fl
struct random_normal_distribution * rnd = init_random_normal_distribution(seed, mean, std, min, max);
randomize_tensor_normal(lora->tok_embeddings_a, rnd);
randomize_tensor_normal(lora->tok_embeddings_b, rnd);
ggml_set_zero(lora->tok_embeddings_b);
randomize_tensor_normal(lora->norm_a, rnd);
randomize_tensor_normal(lora->norm_b, rnd);
ggml_set_zero(lora->norm_b);
randomize_tensor_normal(lora->output_a, rnd);
randomize_tensor_normal(lora->output_b, rnd);
ggml_set_zero(lora->output_b);
for (uint32_t i = 0; i < n_layer; ++i) {
auto & layer = lora->layers[i];
randomize_tensor_normal(layer.attention_norm_a, rnd);
randomize_tensor_normal(layer.attention_norm_b, rnd);
ggml_set_zero(layer.attention_norm_b);
randomize_tensor_normal(layer.wq_a, rnd);
randomize_tensor_normal(layer.wq_b, rnd);
ggml_set_zero(layer.wq_b);
randomize_tensor_normal(layer.wk_a, rnd);
randomize_tensor_normal(layer.wk_b, rnd);
ggml_set_zero(layer.wk_b);
randomize_tensor_normal(layer.wv_a, rnd);
randomize_tensor_normal(layer.wv_b, rnd);
ggml_set_zero(layer.wv_b);
randomize_tensor_normal(layer.wo_a, rnd);
randomize_tensor_normal(layer.wo_b, rnd);
ggml_set_zero(layer.wo_b);
randomize_tensor_normal(layer.ffn_norm_a, rnd);
randomize_tensor_normal(layer.ffn_norm_b, rnd);
ggml_set_zero(layer.ffn_norm_b);
randomize_tensor_normal(layer.w1_a, rnd);
randomize_tensor_normal(layer.w1_b, rnd);
ggml_set_zero(layer.w1_b);
randomize_tensor_normal(layer.w2_a, rnd);
randomize_tensor_normal(layer.w2_b, rnd);
ggml_set_zero(layer.w2_b);
randomize_tensor_normal(layer.w3_a, rnd);
randomize_tensor_normal(layer.w3_b, rnd);
ggml_set_zero(layer.w3_b);
}
free_random_normal_distribution(rnd);
@ -644,8 +644,9 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
const int rope_mode = 0;
return ggml_rope_custom(ctx,
t, KQ_pos, n_rot, rope_mode, n_ctx,
rope_freq_base, rope_freq_scale);
t, KQ_pos, n_rot, rope_mode, n_ctx, 0,
rope_freq_base, rope_freq_scale, 0.0f, 1.0f, 0.0f, 0.0f
);
};
set_name(tokens_input, "tokens_input");
@ -773,7 +774,7 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
if (enable_checkpointing) {
ggml_build_backward_gradient_checkpointing(ctx, gf, gb, gb_tmp, checkpoints.data(), (int) checkpoints.size());
} else {
*gb = *gf;
ggml_graph_cpy(gf, gb);
ggml_build_backward_expand(ctx, gf, gb, true);
}
@ -1308,6 +1309,7 @@ int finetune_train(struct train_params params) {
opt->params = ggml_opt_default_params(GGML_OPT_ADAM);
opt->params.print_forward_graph = false;
opt->params.print_backward_graph = false;
opt->params.graph_size = LLAMA_TRAIN_MAX_NODES;
opt->params.n_threads = params.common.n_threads;
opt->params.past = params.common.opt_past;
opt->params.delta = params.common.opt_delta;
@ -1434,11 +1436,9 @@ int finetune_train(struct train_params params) {
ggml_allocr_free(alloc);
// context for compute tensors without their data
size_t estimated_compute_size_wo_data = (
ggml_tensor_overhead()*GGML_MAX_NODES*2
+ (GGML_OBJECT_SIZE+GGML_GRAPH_SIZE)*(
params.common.use_checkpointing ? 3 : 2
)
const size_t estimated_compute_size_wo_data = (
2*LLAMA_TRAIN_MAX_NODES*ggml_tensor_overhead() +
(params.common.use_checkpointing ? 3 : 2)*(GGML_OBJECT_SIZE+ggml_graph_overhead_custom(LLAMA_TRAIN_MAX_NODES, true))
);
struct ggml_init_params ctx_compute_params = {
estimated_compute_size_wo_data, // mem_size
@ -1461,11 +1461,11 @@ int finetune_train(struct train_params params) {
for (unsigned order = 0; order < (unsigned) GGML_CGRAPH_EVAL_ORDER_COUNT; ++order) {
ctx_compute = ggml_init(ctx_compute_params);
alloc = ggml_allocr_new_measure(tensor_alignment);
gf = ggml_new_graph(ctx_compute);
gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
gf->order = (enum ggml_cgraph_eval_order) order;
gb = ggml_new_graph(ctx_compute);
gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
gb_tmp = params.common.use_checkpointing
? ggml_new_graph(ctx_compute)
? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true)
: NULL;
loss = llama_build_lora_finetune_graphs(
&model, &lora, alloc, ctx_compute,
@ -1494,11 +1494,11 @@ int finetune_train(struct train_params params) {
mem_compute_data.resize(max_compute_size);
ctx_compute = ggml_init(ctx_compute_params);
alloc = ggml_allocr_new(mem_compute_data.data(), mem_compute_data.size(), tensor_alignment);
gf = ggml_new_graph(ctx_compute);
gf = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
gf->order = best_order;
gb = ggml_new_graph(ctx_compute);
gb = ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true);
gb_tmp = params.common.use_checkpointing
? ggml_new_graph(ctx_compute)
? ggml_new_graph_custom(ctx_compute, LLAMA_TRAIN_MAX_NODES, true)
: NULL;
loss = llama_build_lora_finetune_graphs(
&model, &lora, alloc, ctx_compute,
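For reference, a minimal sketch of the new context-sizing pattern used above: the compute context is sized from ggml_tensor_overhead() and ggml_graph_overhead_custom() instead of the old GGML_MAX_NODES constant, and graphs are created with ggml_new_graph_custom(). MAX_NODES and use_checkpointing below are stand-ins for LLAMA_TRAIN_MAX_NODES and params.common.use_checkpointing.

// sketch: sizing a no-alloc compute context for graphs with a custom node limit
#include "ggml.h"

#define MAX_NODES 16384  // stand-in for LLAMA_TRAIN_MAX_NODES

static struct ggml_context * make_compute_ctx(bool use_checkpointing) {
    // tensors created during graph construction carry only metadata here;
    // the context needs room for tensor overhead plus the graph objects
    const size_t estimated_size =
        2*MAX_NODES*ggml_tensor_overhead() +
        (use_checkpointing ? 3 : 2)*(GGML_OBJECT_SIZE + ggml_graph_overhead_custom(MAX_NODES, /*grads =*/ true));

    struct ggml_init_params params = {
        /*.mem_size   =*/ estimated_size,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true, // tensor data is allocated later by ggml-alloc
    };
    return ggml_init(params);
}

// graphs are then created with the same custom size, with gradients enabled:
//   struct ggml_cgraph * gf = ggml_new_graph_custom(ctx, MAX_NODES, /*grads =*/ true);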

View File

@ -1,51 +1,21 @@
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include "ggml-backend-impl.h"
#include "ggml.h"
#include "ggml-impl.h"
#include <assert.h>
#include <limits.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#define UNUSED(x) (void)(x)
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define GGML_MAX_CONCUR (2*GGML_MAX_NODES)
#define MAX_FREE_BLOCKS 256
//#define GGML_ALLOCATOR_DEBUG
//#define AT_PRINTF printf
#define AT_PRINTF(...) ((void)0)
struct hash_node {
struct ggml_tensor * t;
int n_children;
int n_views;
};
static size_t hash(void * p) {
return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
}
static struct hash_node * hash_get(struct hash_node hash_table[], struct ggml_tensor * t) {
size_t h = hash(t);
// linear probing
size_t i = h;
while (hash_table[i].t != NULL) {
if (hash_table[i].t == t) {
return &hash_table[i];
}
i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
if (i == h) {
// hash table is full
GGML_ASSERT(false);
}
}
hash_table[i].t = t;
return &hash_table[i];
}
//#define AT_PRINTF(...) fprintf(stderr, __VA_ARGS__)
#define AT_PRINTF(...)
// TODO: GGML_PAD ?
static size_t aligned_offset(const void * buffer, size_t offset, size_t alignment) {
@ -59,20 +29,18 @@ struct free_block {
size_t size;
};
#define MAX_FREE_BLOCKS 256
struct ggml_allocr {
struct ggml_tallocr {
struct ggml_backend_buffer * buffer;
bool buffer_owned;
void * data;
void * base;
size_t alignment;
int n_free_blocks;
struct free_block free_blocks[MAX_FREE_BLOCKS];
struct hash_node hash_table[GGML_GRAPH_HASHTABLE_SIZE];
size_t max_size;
bool measure;
int parse_seq[GGML_MAX_CONCUR];
int parse_seq_len;
#ifdef GGML_ALLOCATOR_DEBUG
struct ggml_tensor * allocated_tensors[1024];
@ -80,7 +48,7 @@ struct ggml_allocr {
};
#ifdef GGML_ALLOCATOR_DEBUG
static void add_allocated_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
static void add_allocated_tensor(ggml_tallocr_t alloc, struct ggml_tensor * tensor) {
for (int i = 0; i < 1024; i++) {
if (alloc->allocated_tensors[i] == NULL) {
alloc->allocated_tensors[i] = tensor;
@ -89,7 +57,7 @@ static void add_allocated_tensor(struct ggml_allocr * alloc, struct ggml_tensor
}
GGML_ASSERT(!"out of allocated_tensors");
}
static void remove_allocated_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
static void remove_allocated_tensor(ggml_tallocr_t alloc, struct ggml_tensor * tensor) {
for (int i = 0; i < 1024; i++) {
if (alloc->allocated_tensors[i] == tensor ||
(alloc->allocated_tensors[i] != NULL && alloc->allocated_tensors[i]->data == tensor->data)) {
@ -103,7 +71,7 @@ static void remove_allocated_tensor(struct ggml_allocr * alloc, struct ggml_tens
#endif
// check if a tensor is allocated by this buffer
static bool ggml_allocr_is_own(struct ggml_allocr * alloc, const struct ggml_tensor * tensor) {
static bool ggml_tallocr_is_own(ggml_tallocr_t alloc, const struct ggml_tensor * tensor) {
return tensor->buffer == alloc->buffer;
}
@ -111,7 +79,7 @@ static bool ggml_is_view(struct ggml_tensor * t) {
return t->view_src != NULL;
}
void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
void ggml_tallocr_alloc(ggml_tallocr_t alloc, struct ggml_tensor * tensor) {
GGML_ASSERT(!ggml_is_view(tensor)); // views generally get data pointer from one of their sources
GGML_ASSERT(tensor->data == NULL); // avoid allocating tensor which already has memory allocated
@ -162,9 +130,10 @@ void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor)
}
tensor->data = addr;
AT_PRINTF("%s: allocated data at %p\n", __func__, tensor->data);
tensor->buffer = alloc->buffer;
ggml_backend_buffer_init_tensor(alloc->buffer, tensor);
if (!alloc->measure) {
ggml_backend_buffer_init_tensor(alloc->buffer, tensor);
}
#ifdef GGML_ALLOCATOR_DEBUG
add_allocated_tensor(alloc, tensor);
@ -180,16 +149,16 @@ void ggml_allocr_alloc(struct ggml_allocr * alloc, struct ggml_tensor * tensor)
}
#endif
alloc->max_size = MAX(alloc->max_size, (char*)addr - (char*)alloc->data + size);
alloc->max_size = MAX(alloc->max_size, (char*)addr - (char*)alloc->base + size);
}
// this is a very naive implementation, but for our case the number of free blocks should be very small
static void ggml_allocr_free_tensor(struct ggml_allocr * alloc, struct ggml_tensor * tensor) {
if (ggml_allocr_is_own(alloc, tensor) == false) {
static void ggml_tallocr_free_tensor(ggml_tallocr_t alloc, struct ggml_tensor * tensor) {
if (ggml_tallocr_is_own(alloc, tensor) == false) {
// the tensor was not allocated in this buffer
// this can happen because the graph allocator will try to free weights and other tensors from different buffers
// the easiest way to deal with this is just to ignore it
AT_PRINTF("ignoring %s (their buffer: %p, our buffer: %p)\n", tensor->name, (void *)tensor->buffer, (void *)alloc->buffer);
// AT_PRINTF("ignoring %s (their buffer: %p, our buffer: %p)\n", tensor->name, (void *)tensor->buffer, (void *)alloc->buffer);
return;
}
@ -199,7 +168,9 @@ static void ggml_allocr_free_tensor(struct ggml_allocr * alloc, struct ggml_tens
size = aligned_offset(NULL, size, alloc->alignment);
AT_PRINTF("%s: freeing %s at %p (%zu bytes) - n_free_blocks = %d\n", __func__, tensor->name, ptr, size, alloc->n_free_blocks);
ggml_backend_buffer_free_tensor(alloc->buffer, tensor);
if (!alloc->measure) {
ggml_backend_buffer_free_tensor(alloc->buffer, tensor);
}
#ifdef GGML_ALLOCATOR_DEBUG
remove_allocated_tensor(alloc, tensor);
@ -253,91 +224,180 @@ static void ggml_allocr_free_tensor(struct ggml_allocr * alloc, struct ggml_tens
alloc->n_free_blocks++;
}
void ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, const int * list, int n) {
for (int i = 0; i < n; i++) {
alloc->parse_seq[i] = list[i];
}
alloc->parse_seq_len = n;
}
void ggml_allocr_reset(struct ggml_allocr * alloc) {
void ggml_tallocr_reset(ggml_tallocr_t alloc) {
alloc->n_free_blocks = 1;
size_t align_offset = aligned_offset(alloc->data, 0, alloc->alignment);
alloc->free_blocks[0].addr = (char *)alloc->data + align_offset;
alloc->free_blocks[0].size = ggml_backend_buffer_get_size(alloc->buffer) - align_offset;
size_t align_offset = aligned_offset(alloc->base, 0, alloc->alignment);
alloc->free_blocks[0].addr = (char *)alloc->base + align_offset;
if (alloc->measure) {
alloc->free_blocks[0].size = SIZE_MAX/2; // restrict maximum size of a measure allocator to half size_t max to avoid overflows
} else {
alloc->free_blocks[0].size = ggml_backend_buffer_get_size(alloc->buffer) - align_offset;
}
}
struct ggml_allocr * ggml_allocr_new(void * data, size_t size, size_t alignment) {
ggml_tallocr_t ggml_tallocr_new(void * data, size_t size, size_t alignment) {
struct ggml_backend_buffer * buffer = ggml_backend_cpu_buffer_from_ptr(NULL, data, size);
struct ggml_allocr * alloc = (struct ggml_allocr *)malloc(sizeof(struct ggml_allocr));
ggml_tallocr_t alloc = (ggml_tallocr_t)malloc(sizeof(struct ggml_tallocr));
*alloc = (struct ggml_allocr){
/*.buffer = */ buffer,
/*.buffer_owned = */ true,
/*.base = */ ggml_backend_buffer_get_base(buffer),
/*.alignment = */ alignment,
/*.n_free_blocks = */ 0,
/*.free_blocks = */ {{0}},
/*.hash_table = */ {{0}},
/*.max_size = */ 0,
/*.measure = */ false,
/*.parse_seq = */ {0},
/*.parse_seq_len = */ 0,
*alloc = (struct ggml_tallocr) {
/*.buffer = */ buffer,
/*.buffer_owned = */ true,
/*.base = */ ggml_backend_buffer_get_base(buffer),
/*.alignment = */ alignment,
/*.n_free_blocks = */ 0,
/*.free_blocks = */ {{0}},
/*.max_size = */ 0,
/*.measure = */ false,
#ifdef GGML_ALLOCATOR_DEBUG
/*.allocated_tensors = */ {0},
/*.allocated_tensors = */ {0},
#endif
};
ggml_allocr_reset(alloc);
ggml_tallocr_reset(alloc);
return alloc;
}
struct ggml_allocr * ggml_allocr_new_measure(size_t alignment) {
struct ggml_allocr * alloc = ggml_allocr_new((void *)0x1000, (size_t)-0x1001, alignment);
ggml_tallocr_t ggml_tallocr_new_measure(size_t alignment) {
ggml_tallocr_t alloc = ggml_tallocr_new((void *)0x1000, SIZE_MAX/2, alignment);
alloc->measure = true;
return alloc;
}
struct ggml_allocr * ggml_allocr_new_from_buffer(struct ggml_backend_buffer * buffer) {
struct ggml_allocr * alloc = (struct ggml_allocr *)malloc(sizeof(struct ggml_allocr));
ggml_tallocr_t ggml_tallocr_new_measure_from_backend(struct ggml_backend * backend) {
// create a backend buffer to get the correct tensor allocation sizes
ggml_backend_buffer_t buffer = ggml_backend_alloc_buffer(backend, 1);
*alloc = (struct ggml_allocr){
/*.buffer = */ buffer,
/*.buffer_owned = */ false,
/*.base = */ ggml_backend_buffer_get_base(buffer),
/*.alignment = */ ggml_backend_buffer_get_alignment(buffer),
/*.n_free_blocks = */ 0,
/*.free_blocks = */ {{0}},
/*.hash_table = */ {{0}},
/*.max_size = */ 0,
/*.measure = */ false,
/*.parse_seq = */ {0},
/*.parse_seq_len = */ 0,
// TODO: move alloc initialization to a common ggml_tallocr_new_impl function
ggml_tallocr_t alloc = ggml_tallocr_new_from_buffer(buffer);
alloc->buffer_owned = true;
alloc->measure = true;
ggml_tallocr_reset(alloc);
return alloc;
}
ggml_tallocr_t ggml_tallocr_new_from_backend(struct ggml_backend * backend, size_t size) {
ggml_backend_buffer_t buffer = ggml_backend_alloc_buffer(backend, size);
ggml_tallocr_t alloc = ggml_tallocr_new_from_buffer(buffer);
alloc->buffer_owned = true;
return alloc;
}
ggml_tallocr_t ggml_tallocr_new_from_buffer(struct ggml_backend_buffer * buffer) {
ggml_tallocr_t alloc = (ggml_tallocr_t)malloc(sizeof(struct ggml_tallocr));
*alloc = (struct ggml_tallocr) {
/*.buffer = */ buffer,
/*.buffer_owned = */ false,
/*.base = */ ggml_backend_buffer_get_base(buffer),
/*.alignment = */ ggml_backend_buffer_get_alignment(buffer),
/*.n_free_blocks = */ 0,
/*.free_blocks = */ {{0}},
/*.max_size = */ 0,
/*.measure = */ false,
#ifdef GGML_ALLOCATOR_DEBUG
/*.allocated_tensors = */ {0},
/*.allocated_tensors = */ {0},
#endif
};
ggml_allocr_reset(alloc);
ggml_tallocr_reset(alloc);
return alloc;
}
void ggml_allocr_free(struct ggml_allocr * alloc) {
struct ggml_backend_buffer * ggml_tallocr_get_buffer(ggml_tallocr_t alloc) {
return alloc->buffer;
}
void ggml_tallocr_free(ggml_tallocr_t alloc) {
if (alloc == NULL) {
return;
}
if (alloc->buffer_owned) {
ggml_backend_buffer_free(alloc->buffer);
}
free(alloc);
}
bool ggml_allocr_is_measure(struct ggml_allocr * alloc) {
bool ggml_tallocr_is_measure(ggml_tallocr_t alloc) {
return alloc->measure;
}
//////////// compute graph allocator
size_t ggml_tallocr_max_size(ggml_tallocr_t alloc) {
return alloc->max_size;
}
// graph allocator
struct hash_node {
int n_children;
int n_views;
};
struct ggml_gallocr {
ggml_tallocr_t talloc;
struct ggml_hash_set hash_set;
struct hash_node * hash_values;
size_t hash_values_size;
ggml_tallocr_t * hash_allocs;
int * parse_seq;
int parse_seq_len;
};
ggml_gallocr_t ggml_gallocr_new(void) {
ggml_gallocr_t galloc = (ggml_gallocr_t)malloc(sizeof(struct ggml_gallocr));
*galloc = (struct ggml_gallocr) {
/*.talloc = */ NULL,
/*.hash_set = */ {0},
/*.hash_values = */ NULL,
/*.hash_values_size = */ 0,
/*.hash_allocs = */ NULL,
/*.parse_seq = */ NULL,
/*.parse_seq_len = */ 0,
};
return galloc;
}
void ggml_gallocr_free(ggml_gallocr_t galloc) {
if (galloc == NULL) {
return;
}
if (galloc->hash_set.keys != NULL) {
free(galloc->hash_set.keys);
}
if (galloc->hash_values != NULL) {
free(galloc->hash_values);
}
if (galloc->hash_allocs != NULL) {
free(galloc->hash_allocs);
}
if (galloc->parse_seq != NULL) {
free(galloc->parse_seq);
}
free(galloc);
}
void ggml_gallocr_set_parse_seq(ggml_gallocr_t galloc, const int * list, int n) {
free(galloc->parse_seq);
galloc->parse_seq = malloc(sizeof(int) * n);
for (int i = 0; i < n; i++) {
galloc->parse_seq[i] = list[i];
}
galloc->parse_seq_len = n;
}
static struct hash_node * hash_get(ggml_gallocr_t galloc, struct ggml_tensor * t) {
size_t i = ggml_hash_find_or_insert(galloc->hash_set, t);
return &galloc->hash_values[i];
}
static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
if (a->type != b->type) {
@ -378,23 +438,40 @@ static bool ggml_op_can_inplace(enum ggml_op op) {
}
}
static void init_view(struct ggml_allocr * alloc, struct ggml_tensor * view) {
assert(view->view_src != NULL && view->view_src->data != NULL);
view->backend = view->view_src->backend;
static ggml_tallocr_t node_tallocr(ggml_gallocr_t galloc, struct ggml_tensor * node) {
if (galloc->talloc != NULL) {
return galloc->talloc;
}
return galloc->hash_allocs[ggml_hash_find_or_insert(galloc->hash_set, node)];
}
static void init_view(ggml_gallocr_t galloc, struct ggml_tensor * view, bool update_backend) {
ggml_tallocr_t alloc = node_tallocr(galloc, view);
//printf("init_view: %s from src %s\n", view->name, view->view_src->name);
GGML_ASSERT(view->view_src != NULL && view->view_src->data != NULL);
if (update_backend) {
view->backend = view->view_src->backend;
}
view->buffer = view->view_src->buffer;
view->data = (char *)view->view_src->data + view->view_offs;
// FIXME: the view should be initialized by the owning buffer, but currently this breaks the CUDA backend
// due to the ggml_tensor_extra_gpu ring buffer overwriting the KV cache extras
assert(ggml_allocr_is_measure(alloc) || !view->buffer || view->buffer->backend == alloc->buffer->backend);
ggml_backend_buffer_init_tensor(alloc->buffer, view);
assert(ggml_tallocr_is_measure(alloc) || !view->buffer || view->buffer->backend == alloc->buffer->backend);
if (!alloc->measure) {
ggml_backend_buffer_init_tensor(alloc->buffer, view);
}
}
static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node) {
struct hash_node * ht = alloc->hash_table;
static void allocate_node(ggml_gallocr_t galloc, struct ggml_tensor * node) {
ggml_tallocr_t alloc = node_tallocr(galloc, node);
if (node->data == NULL) {
if (ggml_is_view(node)) {
init_view(alloc, node);
init_view(galloc, node, true);
} else {
// see if we can reuse a parent's buffer (inplace)
if (ggml_op_can_inplace(node->op)) {
@ -405,16 +482,16 @@ static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node)
}
// if the node's data is external, then we cannot re-use it
if (ggml_allocr_is_own(alloc, parent) == false) {
if (ggml_tallocr_is_own(alloc, parent) == false) {
AT_PRINTF("not reusing parent %s for %s as %p is external\n", parent->name, node->name, parent->data);
continue;
}
struct hash_node * p_hn = hash_get(ht, parent);
struct hash_node * p_hn = hash_get(galloc, parent);
if (parent->data != NULL && p_hn->n_children == 1 && p_hn->n_views == 0 && ggml_are_same_layout(node, parent)) {
if (ggml_is_view(parent)) {
struct ggml_tensor * view_src = parent->view_src;
struct hash_node * view_src_hn = hash_get(ht, view_src);
struct hash_node * view_src_hn = hash_get(galloc, view_src);
if (view_src_hn->n_views == 1 && view_src_hn->n_children == 0 && view_src->data == parent->data) {
// TODO: the offset of the view parent must be kept to ensure that the op doesn't overwrite
// the parent's data that it will need later (same layout requirement). the problem is that then
@ -424,171 +501,267 @@ static void allocate_node(struct ggml_allocr * alloc, struct ggml_tensor * node)
AT_PRINTF("reusing view parent %s (%s) for %s\n", parent->name, view_src->name, node->name);
node->view_src = view_src;
view_src_hn->n_views += 1;
init_view(alloc, node);
init_view(galloc, node, false);
return;
}
}
else {
} else {
AT_PRINTF("reusing parent %s for %s\n", parent->name, node->name);
node->view_src = parent;
p_hn->n_views += 1;
init_view(alloc, node);
init_view(galloc, node, false);
return;
}
}
}
}
ggml_allocr_alloc(alloc, node);
ggml_tallocr_alloc(alloc, node);
}
}
}
size_t ggml_allocr_alloc_graph_n(
struct ggml_allocr * alloc,
struct ggml_cgraph ** graphs, int n_graphs,
struct ggml_tensor *** inputs, struct ggml_tensor *** outputs) {
static void free_node(ggml_gallocr_t galloc, struct ggml_tensor * node) {
ggml_tallocr_t alloc = node_tallocr(galloc, node);
// reset hash table
struct hash_node * ht = alloc->hash_table;
memset(ht, 0, sizeof(struct hash_node) * GGML_GRAPH_HASHTABLE_SIZE);
ggml_tallocr_free_tensor(alloc, node);
}
static void ggml_tallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgraph * gf) {
const int * parse_seq = galloc->parse_seq;
int parse_seq_len = galloc->parse_seq_len;
// count number of children and views
for (int g = 0; g < n_graphs; g++) {
struct ggml_cgraph * gf = graphs[g];
for (int i = 0; i < gf->n_nodes; i++) {
struct ggml_tensor * node = gf->nodes[i];
for (int i = 0; i < gf->n_nodes; i++) {
struct ggml_tensor * node = gf->nodes[i];
if (ggml_is_view(node)) {
struct ggml_tensor * view_src = node->view_src;
hash_get(ht, view_src)->n_views += 1;
if (node->buffer == NULL && node->data != NULL) {
// view of a pre-allocated tensor, didn't call init_view() yet
init_view(alloc, node);
}
if (ggml_is_view(node)) {
struct ggml_tensor * view_src = node->view_src;
hash_get(galloc, view_src)->n_views += 1;
if (node->buffer == NULL && node->data != NULL) {
// view of a pre-allocated tensor, didn't call init_view() yet
init_view(galloc, node, true);
}
}
for (int j = 0; j < GGML_MAX_SRC; j++) {
struct ggml_tensor * parent = node->src[j];
if (parent == NULL) {
break;
}
hash_get(ht, parent)->n_children += 1;
if (ggml_is_view(parent) && parent->buffer == NULL && parent->data != NULL) {
init_view(alloc, parent);
}
for (int j = 0; j < GGML_MAX_SRC; j++) {
struct ggml_tensor * parent = node->src[j];
if (parent == NULL) {
break;
}
hash_get(galloc, parent)->n_children += 1;
if (ggml_is_view(parent) && parent->buffer == NULL && parent->data != NULL) {
init_view(galloc, parent, true);
}
}
}
// allocate tensors
for (int g = 0; g < n_graphs; g++) {
struct ggml_cgraph * gf = graphs[g];
AT_PRINTF("####### graph %d/%d\n", g, n_graphs);
// graph inputs are allocated first to ensure that they are not overwritten by each other
if (inputs != NULL && inputs[g] != NULL) {
for (int i = 0; inputs[g][i] != NULL; i++) {
struct ggml_tensor * input = inputs[g][i];
AT_PRINTF("input: %s\n", input->name);
allocate_node(alloc, input);
// if we have parse_seq then we allocate nodes following the list, and we only free nodes at barriers
int last_barrier_pos = 0;
int n_nodes = parse_seq_len ? parse_seq_len : gf->n_nodes;
for (int ind = 0; ind < n_nodes; ind++) {
// allocate a node if there is no parse_seq or this is not a barrier
if (parse_seq_len == 0 || parse_seq[ind] != -1) {
int i = parse_seq_len ? parse_seq[ind] : ind;
struct ggml_tensor * node = gf->nodes[i];
// allocate parents (leafs)
for (int j = 0; j < GGML_MAX_SRC; j++) {
struct ggml_tensor * parent = node->src[j];
if (parent == NULL) {
break;
}
allocate_node(galloc, parent);
}
// allocate node
allocate_node(galloc, node);
AT_PRINTF("exec: %s (%s) <= ", ggml_op_name(node->op), node->name);
for (int j = 0; j < GGML_MAX_SRC; j++) {
struct ggml_tensor * parent = node->src[j];
if (parent == NULL) {
break;
}
AT_PRINTF("%s", parent->name);
if (j < GGML_MAX_SRC - 1 && node->src[j + 1] != NULL) {
AT_PRINTF(", ");
}
}
AT_PRINTF("\n");
}
// if we have parse_seq then we allocate nodes following the list, and we only free nodes at barriers
int last_barrier_pos = 0;
int n_nodes = alloc->parse_seq_len ? alloc->parse_seq_len : gf->n_nodes;
for (int ind = 0; ind < n_nodes; ind++) {
// allocate a node if there is no parse_seq or this is not a barrier
if ((alloc->parse_seq_len==0) || alloc->parse_seq[ind] != -1) {
int i = alloc->parse_seq_len ? alloc->parse_seq[ind] : ind;
struct ggml_tensor * node = gf->nodes[i];
// update parents
// update immediately if there is no parse_seq
// update only at barriers if there is parse_seq
if ((parse_seq_len == 0) || parse_seq[ind] == -1) {
int update_start = parse_seq_len ? last_barrier_pos : ind;
int update_end = parse_seq_len ? ind : ind + 1;
for (int i = update_start; i < update_end; i++) {
int node_i = parse_seq_len ? parse_seq[i] : i;
struct ggml_tensor * node = gf->nodes[node_i];
// allocate parents (leafs)
for (int j = 0; j < GGML_MAX_SRC; j++) {
struct ggml_tensor * parent = node->src[j];
if (parent == NULL) {
break;
}
allocate_node(alloc, parent);
}
struct hash_node * p_hn = hash_get(galloc, parent);
p_hn->n_children -= 1;
// allocate node
allocate_node(alloc, node);
//AT_PRINTF("parent %s: %d children, %d views\n", parent->name, parent->n_children, parent->n_views);
AT_PRINTF("exec: %s (%s) <= ", ggml_op_name(node->op), node->name);
for (int j = 0; j < GGML_MAX_SRC; j++) {
struct ggml_tensor * parent = node->src[j];
if (parent == NULL) {
break;
}
AT_PRINTF("%s", parent->name);
if (j < GGML_MAX_SRC - 1 && node->src[j + 1] != NULL) {
AT_PRINTF(", ");
}
}
AT_PRINTF("\n");
}
// update parents
// update immediately if there is no parse_seq
// update only at barriers if there is parse_seq
if ((alloc->parse_seq_len == 0) || alloc->parse_seq[ind] == -1) {
int update_start = alloc->parse_seq_len ? last_barrier_pos : ind;
int update_end = alloc->parse_seq_len ? ind : ind + 1;
for (int i = update_start; i < update_end; i++) {
int node_i = alloc->parse_seq_len ? alloc->parse_seq[i] : i;
struct ggml_tensor * node = gf->nodes[node_i];
for (int j = 0; j < GGML_MAX_SRC; j++) {
struct ggml_tensor * parent = node->src[j];
if (parent == NULL) {
break;
if (p_hn->n_children == 0 && p_hn->n_views == 0) {
if (ggml_is_view(parent)) {
struct ggml_tensor * view_src = parent->view_src;
struct hash_node * view_src_hn = hash_get(galloc, view_src);
view_src_hn->n_views -= 1;
AT_PRINTF("view_src %s: %d children, %d views\n", view_src->name, view_src_hn->n_children, view_src_hn->n_views);
if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0) {
free_node(galloc, view_src);
}
}
struct hash_node * p_hn = hash_get(ht, parent);
p_hn->n_children -= 1;
//AT_PRINTF("parent %s: %d children, %d views\n", parent->name, parent->n_children, parent->n_views);
if (p_hn->n_children == 0 && p_hn->n_views == 0) {
if (ggml_is_view(parent)) {
struct ggml_tensor * view_src = parent->view_src;
struct hash_node * view_src_hn = hash_get(ht, view_src);
view_src_hn->n_views -= 1;
AT_PRINTF("view_src %s: %d children, %d views\n", view_src->name, view_src_hn->n_children, view_src_hn->n_views);
if (view_src_hn->n_views == 0 && view_src_hn->n_children == 0 && view_src->data != node->data) {
ggml_allocr_free_tensor(alloc, view_src);
}
}
else {
if (parent->data != node->data) {
ggml_allocr_free_tensor(alloc, parent);
}
}
else {
free_node(galloc, parent);
}
}
}
AT_PRINTF("\n");
if (alloc->parse_seq_len) {
last_barrier_pos = ind + 1;
}
}
}
// free graph outputs here that wouldn't be freed otherwise because they have no children
if (outputs != NULL && outputs[g] != NULL) {
for (int i = 0; outputs[g][i] != NULL; i++) {
struct ggml_tensor * output = outputs[g][i];
AT_PRINTF("output: %s\n", output->name);
ggml_allocr_free_tensor(alloc, output);
AT_PRINTF("\n");
if (parse_seq_len) {
last_barrier_pos = ind + 1;
}
}
}
return alloc->max_size;
}
size_t ggml_allocr_alloc_graph(struct ggml_allocr * alloc, struct ggml_cgraph * graph) {
return ggml_allocr_alloc_graph_n(alloc, &graph, 1, NULL, NULL);
size_t ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, ggml_tallocr_t talloc, struct ggml_cgraph * graph) {
size_t hash_size = graph->visited_hash_table.size;
// check if the hash table is initialized and large enough
if (galloc->hash_set.size < hash_size) {
if (galloc->hash_set.keys != NULL) {
free(galloc->hash_set.keys);
}
if (galloc->hash_values != NULL) {
free(galloc->hash_values);
}
galloc->hash_set.keys = malloc(sizeof(struct ggml_tensor *) * hash_size);
galloc->hash_set.size = hash_size;
galloc->hash_values = malloc(sizeof(struct hash_node) * hash_size);
}
// reset hash table
memset(galloc->hash_set.keys, 0, sizeof(struct ggml_tensor *) * hash_size);
memset(galloc->hash_values, 0, sizeof(struct hash_node) * hash_size);
galloc->talloc = talloc;
ggml_tallocr_alloc_graph_impl(galloc, graph);
galloc->talloc = NULL;
size_t max_size = ggml_tallocr_max_size(talloc);
return max_size;
}
size_t ggml_allocr_max_size(struct ggml_allocr * alloc) {
return alloc->max_size;
void ggml_gallocr_alloc_graph_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, struct ggml_hash_set hash_set, ggml_tallocr_t * hash_node_talloc) {
const size_t hash_size = hash_set.size;
GGML_ASSERT(hash_size >= (size_t)(graph->n_nodes + graph->n_leafs));
galloc->talloc = NULL;
// alloc hash_values if needed
if (galloc->hash_values == NULL || galloc->hash_values_size < hash_size) {
free(galloc->hash_values);
galloc->hash_values = malloc(sizeof(struct hash_node) * hash_size);
galloc->hash_values_size = hash_size;
}
// free hash_set.keys if needed
if (galloc->hash_set.keys != NULL) {
free(galloc->hash_set.keys);
}
galloc->hash_set = hash_set;
// reset hash values
memset(galloc->hash_values, 0, sizeof(struct hash_node) * hash_size);
galloc->hash_allocs = hash_node_talloc;
ggml_tallocr_alloc_graph_impl(galloc, graph);
// remove unowned resources
galloc->hash_set.keys = NULL;
galloc->hash_allocs = NULL;
}
// legacy API wrapper
struct ggml_allocr {
ggml_tallocr_t talloc;
ggml_gallocr_t galloc;
};
static ggml_allocr_t ggml_allocr_new_impl(ggml_tallocr_t talloc) {
ggml_allocr_t alloc = (ggml_allocr_t)malloc(sizeof(struct ggml_allocr));
*alloc = (struct ggml_allocr) {
/*.talloc = */ talloc,
/*.galloc = */ ggml_gallocr_new(),
};
return alloc;
}
ggml_allocr_t ggml_allocr_new(void * data, size_t size, size_t alignment) {
return ggml_allocr_new_impl(ggml_tallocr_new(data, size, alignment));
}
ggml_allocr_t ggml_allocr_new_measure(size_t alignment) {
return ggml_allocr_new_impl(ggml_tallocr_new_measure(alignment));
}
ggml_allocr_t ggml_allocr_new_from_buffer(struct ggml_backend_buffer * buffer) {
return ggml_allocr_new_impl(ggml_tallocr_new_from_buffer(buffer));
}
ggml_allocr_t ggml_allocr_new_from_backend(struct ggml_backend * backend, size_t size) {
return ggml_allocr_new_impl(ggml_tallocr_new_from_backend(backend, size));
}
ggml_allocr_t ggml_allocr_new_measure_from_backend(struct ggml_backend * backend) {
return ggml_allocr_new_impl(ggml_tallocr_new_measure_from_backend(backend));
}
struct ggml_backend_buffer * ggml_allocr_get_buffer(ggml_allocr_t alloc) {
return ggml_tallocr_get_buffer(alloc->talloc);
}
void ggml_allocr_set_parse_seq(ggml_allocr_t alloc, const int * list, int n) {
ggml_gallocr_set_parse_seq(alloc->galloc, list, n);
}
void ggml_allocr_free(ggml_allocr_t alloc) {
ggml_gallocr_free(alloc->galloc);
ggml_tallocr_free(alloc->talloc);
free(alloc);
}
bool ggml_allocr_is_measure(ggml_allocr_t alloc) {
return ggml_tallocr_is_measure(alloc->talloc);
}
void ggml_allocr_reset(ggml_allocr_t alloc) {
ggml_tallocr_reset(alloc->talloc);
}
void ggml_allocr_alloc(ggml_allocr_t alloc, struct ggml_tensor * tensor) {
ggml_tallocr_alloc(alloc->talloc, tensor);
}
size_t ggml_allocr_max_size(ggml_allocr_t alloc) {
return ggml_tallocr_max_size(alloc->talloc);
}
size_t ggml_allocr_alloc_graph(ggml_allocr_t alloc, struct ggml_cgraph * graph) {
return ggml_gallocr_alloc_graph(alloc->galloc, alloc->talloc, graph);
}
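With the wrapper above, existing callers of the legacy ggml_allocr API keep working unchanged: every call is forwarded to a tensor allocator (ggml_tallocr) plus a graph allocator (ggml_gallocr). A minimal sketch of the classic measure-then-allocate flow through the legacy API, assuming a graph built in a no_alloc context and an illustrative alignment:

// sketch: legacy ggml_allocr flow, now backed internally by ggml_tallocr + ggml_gallocr
#include "ggml-alloc.h"

enum { ALIGN = 32 }; // illustrative; real code uses the backend/tensor alignment

// measure pass: returns the compute buffer size required by the graph
static size_t measure_graph(struct ggml_cgraph * gf) {
    ggml_allocr_t measure = ggml_allocr_new_measure(ALIGN);
    size_t needed = ggml_allocr_alloc_graph(measure, gf);
    ggml_allocr_free(measure);
    return needed;
}

// real pass: allocates tensor data from a user-provided buffer
// note: callers such as finetune.cpp rebuild the graph between the two passes,
// since the measure pass already assigned placeholder data pointers to the tensors
static void alloc_graph(struct ggml_cgraph * gf, void * buf, size_t buf_size) {
    ggml_allocr_t alloc = ggml_allocr_new(buf, buf_size, ALIGN);
    ggml_allocr_alloc_graph(alloc, gf);
    ggml_allocr_free(alloc);
}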

View File

@ -6,27 +6,79 @@
extern "C" {
#endif
struct ggml_backend;
struct ggml_backend_buffer;
GGML_API struct ggml_allocr * ggml_allocr_new(void * data, size_t size, size_t alignment);
GGML_API struct ggml_allocr * ggml_allocr_new_measure(size_t alignment);
GGML_API struct ggml_allocr * ggml_allocr_new_from_buffer(struct ggml_backend_buffer * buffer);
//
// Legacy API
//
typedef struct ggml_allocr * ggml_allocr_t;
// initialize allocator for use with CPU backend only
GGML_API ggml_allocr_t ggml_allocr_new(void * data, size_t size, size_t alignment);
GGML_API ggml_allocr_t ggml_allocr_new_measure(size_t alignment);
// initialize allocator for use with ggml-backend
GGML_API ggml_allocr_t ggml_allocr_new_from_buffer(struct ggml_backend_buffer * buffer);
GGML_API ggml_allocr_t ggml_allocr_new_from_backend(struct ggml_backend * backend, size_t size); // allocates an owned buffer
GGML_API ggml_allocr_t ggml_allocr_new_measure_from_backend(struct ggml_backend * backend);
GGML_API struct ggml_backend_buffer * ggml_allocr_get_buffer(ggml_allocr_t alloc);
// tell the allocator to parse nodes following the order described in the list
// you should call this if your graph is optimized to execute out-of-order
GGML_API void ggml_allocr_set_parse_seq(struct ggml_allocr * alloc, const int * list, int n);
GGML_API void ggml_allocr_set_parse_seq(ggml_allocr_t alloc, const int * list, int n);
GGML_API void ggml_allocr_free (struct ggml_allocr * alloc);
GGML_API bool ggml_allocr_is_measure (struct ggml_allocr * alloc);
GGML_API void ggml_allocr_reset (struct ggml_allocr * alloc);
GGML_API void ggml_allocr_alloc (struct ggml_allocr * alloc, struct ggml_tensor * tensor);
GGML_API size_t ggml_allocr_alloc_graph(struct ggml_allocr * alloc, struct ggml_cgraph * graph);
GGML_API size_t ggml_allocr_max_size (struct ggml_allocr * alloc);
GGML_API void ggml_allocr_free (ggml_allocr_t alloc);
GGML_API bool ggml_allocr_is_measure (ggml_allocr_t alloc);
GGML_API void ggml_allocr_reset (ggml_allocr_t alloc);
GGML_API void ggml_allocr_alloc (ggml_allocr_t alloc, struct ggml_tensor * tensor);
GGML_API size_t ggml_allocr_max_size (ggml_allocr_t alloc);
GGML_API size_t ggml_allocr_alloc_graph_n(
struct ggml_allocr * alloc,
struct ggml_cgraph ** graphs, int n_graphs,
struct ggml_tensor *** inputs, struct ggml_tensor *** outputs);
GGML_API size_t ggml_allocr_alloc_graph(ggml_allocr_t alloc, struct ggml_cgraph * graph);
//
// ggml-backend v2 API
//
// Separate tensor and graph allocator objects
// This is necessary for multi-backend allocation because the graph allocator needs to use multiple tensor allocators
// The original API is kept as a wrapper around the new API
// Tensor allocator
typedef struct ggml_tallocr * ggml_tallocr_t;
GGML_API ggml_tallocr_t ggml_tallocr_new(void * data, size_t size, size_t alignment);
GGML_API ggml_tallocr_t ggml_tallocr_new_measure(size_t alignment);
GGML_API ggml_tallocr_t ggml_tallocr_new_from_buffer(struct ggml_backend_buffer * buffer);
GGML_API ggml_tallocr_t ggml_tallocr_new_from_backend(struct ggml_backend * backend, size_t size); // allocates an owned buffer
GGML_API ggml_tallocr_t ggml_tallocr_new_measure_from_backend(struct ggml_backend * backend);
GGML_API struct ggml_backend_buffer * ggml_tallocr_get_buffer(ggml_tallocr_t talloc);
GGML_API void ggml_tallocr_free (ggml_tallocr_t talloc);
GGML_API bool ggml_tallocr_is_measure (ggml_tallocr_t talloc);
GGML_API void ggml_tallocr_reset (ggml_tallocr_t talloc);
GGML_API void ggml_tallocr_alloc (ggml_tallocr_t talloc, struct ggml_tensor * tensor);
GGML_API size_t ggml_tallocr_max_size (ggml_tallocr_t talloc);
// Graph allocator
typedef struct ggml_gallocr * ggml_gallocr_t;
GGML_API ggml_gallocr_t ggml_gallocr_new(void);
GGML_API void ggml_gallocr_free(ggml_gallocr_t galloc);
GGML_API void ggml_gallocr_set_parse_seq(ggml_gallocr_t galloc, const int * list, int n);
GGML_API size_t ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, ggml_tallocr_t talloc, struct ggml_cgraph * graph);
// Allocate tensors from the allocators given by the hash table
GGML_API void ggml_gallocr_alloc_graph_n(
ggml_gallocr_t galloc,
struct ggml_cgraph * graph,
struct ggml_hash_set hash_set,
ggml_tallocr_t * hash_node_talloc);
#ifdef __cplusplus
}
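A rough usage sketch of the split API declared above, for the single-backend case: one ggml_tallocr owns the backend buffer and one ggml_gallocr walks the graph, handing each node to that tensor allocator. The backend argument is assumed to come from e.g. ggml_backend_cpu_init().

// sketch: allocating a graph with the v2 split allocators (single backend)
#include "ggml-alloc.h"
#include "ggml-backend.h"

static void alloc_graph_v2(struct ggml_backend * backend, struct ggml_cgraph * graph) {
    // measure pass: compute how much memory the graph needs on this backend
    ggml_tallocr_t measure = ggml_tallocr_new_measure_from_backend(backend);
    ggml_gallocr_t galloc  = ggml_gallocr_new();
    size_t size = ggml_gallocr_alloc_graph(galloc, measure, graph);
    ggml_tallocr_free(measure);

    // real pass: allocate an owned backend buffer of that size
    // (as with the legacy API, the graph is rebuilt between the two passes)
    ggml_tallocr_t talloc = ggml_tallocr_new_from_backend(backend, size);
    ggml_gallocr_alloc_graph(galloc, talloc, graph);

    ggml_gallocr_free(galloc);
    // talloc and its buffer stay alive for as long as the tensor data is needed
}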

View File

@ -0,0 +1,87 @@
#pragma once
// ggml-backend internal header
#include "ggml-backend.h"
#ifdef __cplusplus
extern "C" {
#endif
//
// Backend buffer
//
typedef void * ggml_backend_buffer_context_t;
struct ggml_backend_buffer_i {
void (*free_buffer) (ggml_backend_buffer_t buffer);
void * (*get_base) (ggml_backend_buffer_t buffer); // get base pointer
size_t (*get_alloc_size)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // pre-allocation callback
void (*init_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // post-allocation callback
void (*free_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // pre-free callback
};
struct ggml_backend_buffer {
struct ggml_backend_buffer_i iface;
ggml_backend_t backend;
ggml_backend_buffer_context_t context;
size_t size;
};
GGML_API ggml_backend_buffer_t ggml_backend_buffer_init(
struct ggml_backend * backend,
struct ggml_backend_buffer_i iface,
ggml_backend_buffer_context_t context,
size_t size);
//
// Backend
//
typedef void * ggml_backend_context_t;
struct ggml_backend_i {
const char * (*get_name)(ggml_backend_t backend);
void (*free)(ggml_backend_t backend);
// buffer allocation
ggml_backend_buffer_t (*alloc_buffer)(ggml_backend_t backend, size_t size);
// get buffer alignment
size_t (*get_alignment)(ggml_backend_t backend);
// tensor data access
// these functions can be asynchronous, helper functions are provided for synchronous access that automatically call synchronize
void (*set_tensor_async)(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
void (*get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
void (*synchronize) (ggml_backend_t backend);
// (optional) copy tensor between different backends, allowing for single-copy transfers
void (*cpy_tensor_from)(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
void (*cpy_tensor_to) (ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
// compute graph with a plan
ggml_backend_graph_plan_t (*graph_plan_create) (ggml_backend_t backend, struct ggml_cgraph * cgraph);
void (*graph_plan_free) (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
void (*graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
// compute graph without a plan
void (*graph_compute)(ggml_backend_t backend, struct ggml_cgraph * cgraph);
// check if the backend supports an operation
bool (*supports_op)(ggml_backend_t backend, const struct ggml_tensor * op);
};
struct ggml_backend {
struct ggml_backend_i iface;
ggml_backend_context_t context;
};
#ifdef __cplusplus
}
#endif
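The two interface structs above are the whole contract a backend has to fill in. As a rough illustration (not part of this commit), a buffer over caller-owned memory only needs get_base; the optional callbacks stay NULL and the struct is wrapped with ggml_backend_buffer_init(), mirroring the cpu_backend_buffer_i_from_ptr pattern used by the CPU backend below.

// sketch: a minimal ggml_backend_buffer_i over memory the caller owns
#include "ggml-backend-impl.h"

static void * my_buffer_get_base(ggml_backend_buffer_t buffer) {
    return buffer->context; // the context slot just holds the raw pointer
}

static struct ggml_backend_buffer_i my_buffer_i = {
    /* .free_buffer    = */ NULL, // memory is owned by the caller
    /* .get_base       = */ my_buffer_get_base,
    /* .get_alloc_size = */ NULL, // optional, defaults to ggml_nbytes
    /* .init_tensor    = */ NULL, // optional post-allocation hook
    /* .free_tensor    = */ NULL, // optional pre-free hook
};

// wrapping it into a buffer object for a given backend:
//   ggml_backend_buffer_t buf = ggml_backend_buffer_init(backend, my_buffer_i, ptr, size);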

View File

@ -1,7 +1,9 @@
#include "ggml-backend.h"
#include "ggml-backend-impl.h"
#include "ggml-alloc.h"
#include "ggml-impl.h"
#include <assert.h>
#include <limits.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
@ -16,23 +18,27 @@
ggml_backend_buffer_t ggml_backend_buffer_init(
struct ggml_backend * backend,
struct ggml_backend_buffer_i iface,
ggml_backend_buffer_context_t context,
size_t size) {
ggml_backend_buffer_context_t context,
size_t size) {
ggml_backend_buffer_t buffer = malloc(sizeof(struct ggml_backend_buffer));
GGML_ASSERT(iface.get_base != NULL);
(*buffer) = (struct ggml_backend_buffer) {
/* .interface = */ iface,
/* .backend = */ backend,
/* .context = */ context,
/* .size = */ size,
/* .interface = */ iface,
/* .backend = */ backend,
/* .context = */ context,
/* .size = */ size,
};
return buffer;
}
void ggml_backend_buffer_free(ggml_backend_buffer_t buffer) {
if (buffer == NULL) {
return;
}
if (buffer->iface.free_buffer != NULL) {
buffer->iface.free_buffer(buffer);
}
@ -43,15 +49,20 @@ size_t ggml_backend_buffer_get_alignment(ggml_backend_buffer_t buffer) {
return ggml_backend_get_alignment(buffer->backend);
}
void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) {
return buffer->iface.get_base(buffer);
}
size_t ggml_backend_buffer_get_size(ggml_backend_buffer_t buffer) {
return buffer->size;
}
void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) {
void * base = buffer->iface.get_base(buffer);
GGML_ASSERT(base != NULL && "backend buffer base cannot be NULL");
return base;
}
size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
// get_alloc_size is optional, defaults to ggml_nbytes
if (buffer->iface.get_alloc_size) {
return buffer->iface.get_alloc_size(buffer, tensor);
}
@ -59,12 +70,14 @@ size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct g
}
void ggml_backend_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
// init_tensor is optional
if (buffer->iface.init_tensor) {
buffer->iface.init_tensor(buffer, tensor);
}
}
void ggml_backend_buffer_free_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
// free_tensor is optional
if (buffer->iface.free_tensor) {
buffer->iface.free_tensor(buffer, tensor);
}
@ -73,14 +86,21 @@ void ggml_backend_buffer_free_tensor(ggml_backend_buffer_t buffer, struct ggml_t
// backend
ggml_backend_t ggml_get_backend(const struct ggml_tensor * tensor) {
return tensor->buffer->backend;
return tensor->buffer ? tensor->buffer->backend : NULL;
}
const char * ggml_backend_name(ggml_backend_t backend) {
if (backend == NULL) {
return "NULL";
}
return backend->iface.get_name(backend);
}
void ggml_backend_free(ggml_backend_t backend) {
if (backend == NULL) {
return;
}
backend->iface.free(backend);
}
@ -101,13 +121,23 @@ void ggml_backend_tensor_get_async(const struct ggml_tensor * tensor, void * dat
}
void ggml_backend_tensor_set(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
ggml_get_backend(tensor)->iface.set_tensor_async(ggml_get_backend(tensor), tensor, data, offset, size);
ggml_get_backend(tensor)->iface.synchronize(ggml_get_backend(tensor));
ggml_backend_t backend = ggml_get_backend(tensor);
GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
GGML_ASSERT(backend != NULL && "tensor backend not set");
backend->iface.set_tensor_async(backend, tensor, data, offset, size);
backend->iface.synchronize(backend);
}
void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
ggml_get_backend(tensor)->iface.get_tensor_async(ggml_get_backend(tensor), tensor, data, offset, size);
ggml_get_backend(tensor)->iface.synchronize(ggml_get_backend(tensor));
ggml_backend_t backend = ggml_get_backend(tensor);
GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
GGML_ASSERT(backend != NULL && "tensor backend not set");
backend->iface.get_tensor_async(backend, tensor, data, offset, size);
backend->iface.synchronize(backend);
}
void ggml_backend_synchronize(ggml_backend_t backend) {
@ -156,7 +186,7 @@ void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst
//printf("dst: %s ne: [%d %d %d %d] nb: [%d %d %d %d]\n", dst->name, (int)dst->ne[0], (int)dst->ne[1], (int)dst->ne[2], (int)dst->ne[3], (int)dst->nb[0], (int)dst->nb[1], (int)dst->nb[2], (int)dst->nb[3]);
GGML_ASSERT(ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts");
// printf("cpy tensor %s from %s to %s (%lu bytes)\n", src->name, ggml_backend_name(src->backend), ggml_backend_name(dst->backend), ggml_nbytes(src));
// fprintf(stderr, "cpy tensor %s from %s to %s (%lu bytes)\n", src->name, ggml_backend_name(src->backend), ggml_backend_name(dst->backend), ggml_nbytes(src));
if (src == dst) {
return;
@ -170,9 +200,9 @@ void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst
ggml_get_backend(src)->iface.cpy_tensor_to(ggml_get_backend(src)->context, src, dst);
} else {
// shouldn't be hit when copying from/to CPU
#ifndef NDEBUG
#ifndef NDEBUG
fprintf(stderr, "ggml_backend_tensor_copy: neither cpy_tensor_from nor cpy_tensor_to are implemented for backends %s and %s, falling back to get/set\n", ggml_backend_name(src->buffer->backend), ggml_backend_name(dst->buffer->backend));
#endif
#endif
size_t nbytes = ggml_nbytes(src);
void * data = malloc(nbytes);
ggml_backend_tensor_get(src, data, 0, nbytes);
@ -192,7 +222,7 @@ struct ggml_backend_cpu_context {
static const char * ggml_backend_cpu_name(ggml_backend_t backend) {
return "CPU";
UNUSED(backend);
UNUSED(backend);
}
static void ggml_backend_cpu_free(ggml_backend_t backend) {
@ -208,24 +238,24 @@ static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) {
static void ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t buffer) {
free(buffer->context);
UNUSED(buffer);
UNUSED(buffer);
}
static struct ggml_backend_buffer_i cpu_backend_buffer_i = {
/* .free_buffer = */ ggml_backend_cpu_buffer_free_buffer,
/* .get_base = */ ggml_backend_cpu_buffer_get_base,
/* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
/* .init_tensor = */ NULL, // no initialization required
/* .free_tensor = */ NULL, // no cleanup required
/* .free_buffer = */ ggml_backend_cpu_buffer_free_buffer,
/* .get_base = */ ggml_backend_cpu_buffer_get_base,
/* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
/* .init_tensor = */ NULL, // no initialization required
/* .free_tensor = */ NULL, // no cleanup required
};
// for buffers from ptr, free is not called
static struct ggml_backend_buffer_i cpu_backend_buffer_i_from_ptr = {
/* .free_buffer = */ NULL, // ptr is not owned by the buffer, so it does not need to be freed
/* .get_base = */ ggml_backend_cpu_buffer_get_base,
/* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
/* .init_tensor = */ NULL,
/* .free_tensor = */ NULL,
/* .free_buffer = */ NULL, // ptr is not owned by the buffer, so it does not need to be freed
/* .get_base = */ ggml_backend_cpu_buffer_get_base,
/* .get_alloc_size = */ NULL, // defaults to ggml_nbytes
/* .init_tensor = */ NULL,
/* .free_tensor = */ NULL,
};
static const size_t TENSOR_ALIGNMENT = 64; // should be enough for AVX 512
@ -234,12 +264,14 @@ static ggml_backend_buffer_t ggml_backend_cpu_alloc_buffer(ggml_backend_t backen
size += TENSOR_ALIGNMENT; // malloc may return an address that is not aligned
void * data = malloc(size); // TODO: maybe use GGML_ALIGNED_MALLOC?
GGML_ASSERT(data != NULL && "failed to allocate buffer");
return ggml_backend_buffer_init(backend, cpu_backend_buffer_i, data, size);
}
static size_t ggml_backend_cpu_get_alignment(ggml_backend_t backend) {
return TENSOR_ALIGNMENT;
UNUSED(backend);
UNUSED(backend);
}
static void ggml_backend_cpu_set_tensor_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
@ -248,7 +280,7 @@ static void ggml_backend_cpu_set_tensor_async(ggml_backend_t backend, struct ggm
memcpy((char *)tensor->data + offset, data, size);
UNUSED(backend);
UNUSED(backend);
}
static void ggml_backend_cpu_get_tensor_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
@ -257,24 +289,23 @@ static void ggml_backend_cpu_get_tensor_async(ggml_backend_t backend, const stru
memcpy(data, (const char *)tensor->data + offset, size);
UNUSED(backend);
UNUSED(backend);
}
static void ggml_backend_cpu_synchronize(ggml_backend_t backend) {
UNUSED(backend);
UNUSED(backend);
}
static void ggml_backend_cpu_cpy_tensor_from(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) {
ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src));
UNUSED(backend);
UNUSED(backend);
}
static void ggml_backend_cpu_cpy_tensor_to(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst) {
// for a backend such as CUDA that can queue async calls, it is ok to do this asynchronously, but it may not be the case for other backends
ggml_backend_tensor_set_async(dst, src->data, 0, ggml_nbytes(src));
ggml_backend_tensor_set(dst, src->data, 0, ggml_nbytes(src));
UNUSED(backend);
UNUSED(backend);
}
struct ggml_backend_plan_cpu {
@ -303,7 +334,7 @@ static void ggml_backend_cpu_graph_plan_free(ggml_backend_t backend, ggml_backen
free(cpu_plan->cplan.work_data);
free(cpu_plan);
UNUSED(backend);
UNUSED(backend);
}
static void ggml_backend_cpu_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
@ -311,7 +342,7 @@ static void ggml_backend_cpu_graph_plan_compute(ggml_backend_t backend, ggml_bac
ggml_graph_compute(&cpu_plan->cgraph, &cpu_plan->cplan);
UNUSED(backend);
UNUSED(backend);
}
static void ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
@ -332,25 +363,25 @@ static void ggml_backend_cpu_graph_compute(ggml_backend_t backend, struct ggml_c
static bool ggml_backend_cpu_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
return true;
UNUSED(backend);
UNUSED(op);
UNUSED(backend);
UNUSED(op);
}
static struct ggml_backend_i cpu_backend_i = {
/* .get_name = */ ggml_backend_cpu_name,
/* .free = */ ggml_backend_cpu_free,
/* .alloc_buffer = */ ggml_backend_cpu_alloc_buffer,
/* .get_alignment = */ ggml_backend_cpu_get_alignment,
/* .set_tensor_async = */ ggml_backend_cpu_set_tensor_async,
/* .get_tensor_async = */ ggml_backend_cpu_get_tensor_async,
/* .synchronize = */ ggml_backend_cpu_synchronize,
/* .cpy_tensor_from = */ ggml_backend_cpu_cpy_tensor_from,
/* .cpy_tensor_to = */ ggml_backend_cpu_cpy_tensor_to,
/* .graph_plan_create = */ ggml_backend_cpu_graph_plan_create,
/* .graph_plan_free = */ ggml_backend_cpu_graph_plan_free,
/* .graph_plan_compute = */ ggml_backend_cpu_graph_plan_compute,
/* .graph_compute = */ ggml_backend_cpu_graph_compute,
/* .supports_op = */ ggml_backend_cpu_supports_op,
/* .get_name = */ ggml_backend_cpu_name,
/* .free = */ ggml_backend_cpu_free,
/* .alloc_buffer = */ ggml_backend_cpu_alloc_buffer,
/* .get_alignment = */ ggml_backend_cpu_get_alignment,
/* .set_tensor_async = */ ggml_backend_cpu_set_tensor_async,
/* .get_tensor_async = */ ggml_backend_cpu_get_tensor_async,
/* .synchronize = */ ggml_backend_cpu_synchronize,
/* .cpy_tensor_from = */ ggml_backend_cpu_cpy_tensor_from,
/* .cpy_tensor_to = */ ggml_backend_cpu_cpy_tensor_to,
/* .graph_plan_create = */ ggml_backend_cpu_graph_plan_create,
/* .graph_plan_free = */ ggml_backend_cpu_graph_plan_free,
/* .graph_plan_compute = */ ggml_backend_cpu_graph_plan_compute,
/* .graph_compute = */ ggml_backend_cpu_graph_compute,
/* .supports_op = */ ggml_backend_cpu_supports_op,
};
ggml_backend_t ggml_backend_cpu_init(void) {
@ -363,8 +394,8 @@ ggml_backend_t ggml_backend_cpu_init(void) {
ggml_backend_t cpu_backend = malloc(sizeof(struct ggml_backend));
*cpu_backend = (struct ggml_backend) {
/* .interface = */ cpu_backend_i,
/* .context = */ ctx
/* .interface = */ cpu_backend_i,
/* .context = */ ctx
};
return cpu_backend;
}
@ -383,3 +414,537 @@ void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads) {
ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(ggml_backend_t backend_cpu, void * ptr, size_t size) {
return ggml_backend_buffer_init(backend_cpu, cpu_backend_buffer_i_from_ptr, ptr, size);
}
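Putting the CPU backend pieces together, tensor data can be moved in and out of a backend buffer through the generic API alone. A small sketch, assuming t is a tensor created in a no_alloc ggml_context:

// sketch: allocate a tensor in a CPU backend buffer and round-trip its data
#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"

static void roundtrip(struct ggml_tensor * t, const float * src, float * dst) {
    ggml_backend_t backend = ggml_backend_cpu_init();
    ggml_backend_cpu_set_n_threads(backend, 4);

    // allocate a backend buffer and place the tensor in it (extra slack for alignment)
    ggml_backend_buffer_t buf = ggml_backend_alloc_buffer(backend, ggml_nbytes(t) + 64);
    ggml_tallocr_t talloc = ggml_tallocr_new_from_buffer(buf);
    ggml_tallocr_alloc(talloc, t);

    // synchronous helpers: set_tensor_async/get_tensor_async + synchronize under the hood
    ggml_backend_tensor_set(t, src, 0, ggml_nbytes(t));
    ggml_backend_tensor_get(t, dst, 0, ggml_nbytes(t));

    ggml_tallocr_free(talloc);
    ggml_backend_buffer_free(buf);
    ggml_backend_free(backend);
}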
// scheduler
#define GGML_MAX_BACKENDS 4
#define GGML_MAX_SPLITS 256
#define GGML_MAX_SPLIT_INPUTS 16
struct ggml_backend_sched_split {
ggml_tallocr_t tallocr;
int i_start;
int i_end;
struct ggml_tensor * inputs[GGML_MAX_SPLIT_INPUTS];
int n_inputs;
struct ggml_cgraph * graph;
};
struct ggml_backend_sched {
int n_backends;
ggml_backend_t backends[GGML_MAX_BACKENDS];
ggml_tallocr_t tallocs[GGML_MAX_BACKENDS];
ggml_gallocr_t galloc;
struct ggml_hash_set hash_set;
ggml_tallocr_t * node_talloc; // [hash_set.size]
struct ggml_tensor * (* node_copies)[GGML_MAX_BACKENDS]; // [hash_set.size][GGML_MAX_BACKENDS]
struct ggml_cgraph * graph;
struct ggml_backend_sched_split splits[GGML_MAX_SPLITS];
int n_splits;
struct ggml_context * ctx;
// align context_buffer to GGML_MEM_ALIGN
#ifdef _MSC_VER
__declspec(align(GGML_MEM_ALIGN))
#else
__attribute__((aligned(GGML_MEM_ALIGN)))
#endif
char context_buffer[GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS*sizeof(struct ggml_tensor) + GGML_MAX_SPLITS*sizeof(struct ggml_cgraph)];
};
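The scheduler's public entry points live in ggml-backend.h and are not part of this excerpt; sketched under that assumption, the intended flow is to hand it the backends in priority order (lower index is preferred), run one measure pass, and then let it split and compute graphs across the backends:

// sketch: assumed ggml_backend_sched usage; the constructor and compute entry points
// are taken from ggml-backend.h, which is not shown in this diff
static void run_with_sched(ggml_backend_t gpu, ggml_backend_t cpu,
                           struct ggml_cgraph * measure_graph, struct ggml_cgraph * graph) {
    ggml_backend_t backends[2] = { gpu, cpu }; // priority order, best backend first
    ggml_backend_sched_t sched = ggml_backend_sched_new(backends, 2);

    // one measure pass sizes a tallocr-backed buffer per backend
    ggml_backend_sched_init_measure(sched, measure_graph);

    // later graphs are split per backend; inputs are copied at split boundaries
    ggml_backend_sched_graph_compute(sched, graph);

    ggml_backend_sched_free(sched);
}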
#define hash_id(node) ggml_hash_find_or_insert(sched->hash_set, node)
#define node_allocr(node) sched->node_talloc[hash_id(node)]
static bool ggml_is_view_op(enum ggml_op op) {
return op == GGML_OP_VIEW || op == GGML_OP_RESHAPE || op == GGML_OP_PERMUTE || op == GGML_OP_TRANSPOSE;
}
// returns the priority of the backend, lower is better
static int sched_backend_prio(ggml_backend_sched_t sched, ggml_backend_t backend) {
for (int i = 0; i < sched->n_backends; i++) {
if (sched->backends[i] == backend) {
return i;
}
}
return INT_MAX;
}
static int sched_allocr_prio(ggml_backend_sched_t sched, ggml_tallocr_t allocr) {
for (int i = 0; i < sched->n_backends; i++) {
if (sched->tallocs[i] == allocr) {
return i;
}
}
return INT_MAX;
}
// returns the backend that should be used for the node based on the current locations
char causes[GGML_DEFAULT_GRAPH_SIZE*4 + GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS][128]; // debug, remove
static ggml_backend_t sched_backend_from_cur(ggml_backend_sched_t sched, struct ggml_tensor * node) {
// if the dst tensor is already allocated in a buffer, we must assume that it is critical to keep it there
// ie. kv cache updates
// note that this doesn't allow fallback to CPU. need to add output tensors to the splits to copy the data back to the original backend.
// dst
ggml_backend_t cur_backend = ggml_get_backend(node);
if (cur_backend != NULL) {
sprintf(causes[hash_id(node)], "1.dst");
return cur_backend;
}
// view_src
if (node->view_src != NULL && ggml_get_backend(node->view_src) != NULL) {
sprintf(causes[hash_id(node)], "1.vsrc");
return ggml_get_backend(node->view_src);
}
// src
int cur_prio = INT_MAX;
size_t cur_size = 0;
for (int i = 0; i < GGML_MAX_SRC; i++) {
const struct ggml_tensor * src = node->src[i];
if (src == NULL) {
break;
}
ggml_backend_t src_backend = ggml_get_backend(src);
if (src_backend != NULL) {
int src_prio = sched_backend_prio(sched, src_backend);
size_t src_size = ggml_nbytes(src);
if (src_prio < cur_prio && src_size >= cur_size) {
cur_prio = src_prio;
cur_size = src_size;
cur_backend = src_backend;
sprintf(causes[hash_id(node)], "1.src%d", i);
}
}
}
return cur_backend;
}
static char * fmt_size(size_t size) {
static char buffer[128];
if (size >= 1024*1024) {
sprintf(buffer, "%zuM", size/1024/1024);
} else {
sprintf(buffer, "%zuK", size/1024);
}
return buffer;
}
static void sched_print_assignments(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
int cur_split = 0;
for (int i = 0; i < graph->n_nodes; i++) {
if (cur_split < sched->n_splits && i == sched->splits[cur_split].i_start) {
ggml_backend_t split_backend = ggml_tallocr_get_buffer(sched->splits[cur_split].tallocr)->backend;
fprintf(stderr, "\n## SPLIT #%d: %s # %d inputs: ", cur_split, ggml_backend_name(split_backend), sched->splits[cur_split].n_inputs);
for (int j = 0; j < sched->splits[cur_split].n_inputs; j++) {
fprintf(stderr, "[%s (%5.5s)] ", sched->splits[cur_split].inputs[j]->name, fmt_size(ggml_nbytes(sched->splits[cur_split].inputs[j])));
}
fprintf(stderr, "\n");
cur_split++;
}
struct ggml_tensor * node = graph->nodes[i];
if (ggml_is_view_op(node->op)) {
continue;
}
ggml_tallocr_t node_allocr = node_allocr(node);
ggml_backend_t node_backend = node_allocr ? ggml_tallocr_get_buffer(node_allocr)->backend : NULL;
fprintf(stderr, "node #%3d (%10.10s): %20.20s (%4.4s) [%4.4s %8.8s]:", i, ggml_op_name(node->op), node->name, fmt_size(ggml_nbytes(node)), node_allocr ? ggml_backend_name(node_backend) : "NULL", causes[hash_id(node)]);
for (int j = 0; j < GGML_MAX_SRC; j++) {
struct ggml_tensor * src = node->src[j];
if (src == NULL) {
break;
}
ggml_tallocr_t src_allocr = node_allocr(src);
ggml_backend_t src_backend = src_allocr ? ggml_tallocr_get_buffer(src_allocr)->backend : NULL;
fprintf(stderr, " %20.20s (%4.4s) [%4.4s %8.8s]", src->name, fmt_size(ggml_nbytes(src)), src_backend ? ggml_backend_name(src_backend) : "NULL", causes[hash_id(src)]);
}
fprintf(stderr, "\n");
}
}
// creates a copy of the tensor with the same memory layout
static struct ggml_tensor * ggml_dup_tensor_layout(struct ggml_context * ctx, const struct ggml_tensor * tensor) {
struct ggml_tensor * dup = ggml_dup_tensor(ctx, tensor);
for (int i = 0; i < GGML_MAX_DIMS; i++) {
dup->nb[i] = tensor->nb[i];
}
return dup;
}
// assigns backends to ops and splits the graph into subgraphs that can be computed on the same backend
// TODO: merge passes
static void sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
// reset state
size_t hash_size = sched->hash_set.size;
memset(sched->hash_set.keys, 0, sizeof(sched->hash_set.keys[0]) * hash_size);
memset(sched->node_talloc, 0, sizeof(sched->node_talloc[0]) * hash_size);
memset(sched->node_copies, 0, sizeof(sched->node_copies[0]) * hash_size);
sched->n_splits = 0;
struct ggml_init_params params = {
/*.mem_size = */ sizeof(sched->context_buffer),
/*.mem_buffer = */ sched->context_buffer,
/*.no_alloc = */ true
};
if (sched->ctx != NULL) {
ggml_free(sched->ctx);
}
sched->ctx = ggml_init(params);
// pass 1: assign backends to ops with allocated inputs
for (int i = 0; i < graph->n_leafs; i++) {
struct ggml_tensor * leaf = graph->leafs[i];
if (node_allocr(leaf) != NULL) {
// do not overwrite user assignments
continue;
}
ggml_backend_t leaf_backend = ggml_get_backend(leaf);
if (leaf_backend == NULL && leaf->view_src != NULL) {
leaf_backend = ggml_get_backend(leaf->view_src);
}
if (leaf_backend != NULL) {
node_allocr(leaf) = ggml_backend_sched_get_tallocr(sched, leaf_backend);
}
}
for (int i = 0; i < graph->n_nodes; i++) {
struct ggml_tensor * node = graph->nodes[i];
if (node_allocr(node) != NULL) {
// do not overwrite user assignments
continue;
}
ggml_backend_t node_backend = sched_backend_from_cur(sched, node);
if (node_backend != NULL) {
node_allocr(node) = ggml_backend_sched_get_tallocr(sched, node_backend);
}
}
//printf("PASS 1 ASSIGNMENTS\n"); sched_print_assignments(sched, graph);
// pass 2: assign backends to ops from current assignments
// TODO:
// - reuse sched_backend_from_cur
for (int i = 0; i < graph->n_nodes; i++) {
struct ggml_tensor * node = graph->nodes[i];
ggml_tallocr_t node_allocr = node_allocr(node);
if (node_allocr == NULL) {
int cur_prio = INT_MAX;
size_t cur_size = 0;
for (int j = 0; j < GGML_MAX_SRC; j++) {
struct ggml_tensor * src = node->src[j];
if (src == NULL) {
break;
}
ggml_tallocr_t src_allocr = node_allocr(src);
if (src_allocr != NULL) {
int src_prio = sched_allocr_prio(sched, src_allocr);
size_t src_size = ggml_nbytes(src);
if (src_prio < cur_prio && src_size >= cur_size) {
cur_prio = src_prio;
cur_size = src_size;
node_allocr = src_allocr;
sprintf(causes[hash_id(node)], "2.src%d", j);
}
}
}
if (node_allocr != NULL) {
node_allocr(node) = node_allocr;
}
}
}
//printf("PASS 2 ASSIGNMENTS\n"); sched_print_assignments(sched, graph);
// pass 3: assign backends to remaining src from dst (should only be leafs)
for (int i = 0; i < graph->n_nodes; i++) {
struct ggml_tensor * node = graph->nodes[i];
ggml_tallocr_t node_allocr = node_allocr(node);
for (int j = 0; j < GGML_MAX_SRC; j++) {
struct ggml_tensor * src = node->src[j];
if (src == NULL) {
break;
}
ggml_tallocr_t src_allocr = node_allocr(src);
if (src_allocr == NULL) {
node_allocr(src) = node_allocr;
}
}
}
//printf("PASS 3 ASSIGNMENTS\n"); sched_print_assignments(sched, graph);
// pass 4: split graph, find tensors that need to be copied
// TODO:
// - when switching from a less preferred backend to a more preferred backend, check if it is possible to move the switch to an earlier point for the same cost
// find first backend
int cur_split = 0;
for (int i = 0; i < graph->n_nodes; i++) {
struct ggml_tensor * node = graph->nodes[i];
if (node->view_src == NULL) {
sched->splits[0].tallocr = node_allocr(node);
break;
}
}
sched->splits[0].i_start = 0;
sched->splits[0].n_inputs = 0;
memset(sched->splits[0].inputs, 0, sizeof(sched->splits[0].inputs)); //HACK
ggml_tallocr_t cur_allocr = sched->splits[0].tallocr;
size_t cur_backend_id = sched_allocr_prio(sched, cur_allocr);
for (int i = 0; i < graph->n_nodes; i++) {
struct ggml_tensor * node = graph->nodes[i];
if (ggml_is_view_op(node->op)) {
continue;
}
ggml_tallocr_t node_allocr = node_allocr(node);
if (node_allocr != cur_allocr) {
sched->splits[cur_split].i_end = i;
cur_split++;
GGML_ASSERT(cur_split < GGML_MAX_SPLITS);
sched->splits[cur_split].tallocr = node_allocr;
sched->splits[cur_split].i_start = i;
sched->splits[cur_split].n_inputs = 0;
memset(sched->splits[cur_split].inputs, 0, sizeof(sched->splits[cur_split].inputs)); //HACK
cur_allocr = node_allocr;
cur_backend_id = sched_allocr_prio(sched, cur_allocr);
}
// find inputs that are not on the same backend
for (int j = 0; j < GGML_MAX_SRC; j++) {
struct ggml_tensor * src = node->src[j];
if (src == NULL) {
break;
}
ggml_tallocr_t src_allocr = node_allocr(src);
if (src_allocr != node_allocr) {
int n_inputs = sched->splits[cur_split].n_inputs++;
GGML_ASSERT(n_inputs < GGML_MAX_SPLIT_INPUTS);
sched->splits[cur_split].inputs[n_inputs] = (struct ggml_tensor *)src;
// create copies
size_t id = hash_id(src);
if (sched->node_copies[id][cur_backend_id] == NULL) {
struct ggml_tensor * tensor_copy = ggml_dup_tensor_layout(sched->ctx, src);
sched->node_copies[id][cur_backend_id] = tensor_copy;
node_allocr(tensor_copy) = cur_allocr;
ggml_backend_t backend = ggml_tallocr_get_buffer(cur_allocr)->backend;
ggml_format_name(tensor_copy, "%s#%s", ggml_backend_name(backend), src->name);
}
node->src[j] = sched->node_copies[id][cur_backend_id];
}
}
}
sched->splits[cur_split].i_end = graph->n_nodes;
sched->n_splits = cur_split + 1;
//fprintf(stderr, "PASS 4 ASSIGNMENTS\n"); sched_print_assignments(sched, graph); fflush(stdout);
#if 1
// sanity check: all sources should have the same backend as the node
for (int i = 0; i < graph->n_nodes; i++) {
struct ggml_tensor * node = graph->nodes[i];
ggml_tallocr_t node_allocr = node_allocr(node);
if (node_allocr == NULL) {
fprintf(stderr, "!!!!!!! %s has no backend\n", node->name);
}
for (int j = 0; j < GGML_MAX_SRC; j++) {
struct ggml_tensor * src = node->src[j];
if (src == NULL) {
break;
}
ggml_tallocr_t src_allocr = node_allocr(src);
if (src_allocr != node_allocr /* && src_backend != NULL */) { // ignore nulls for now
fprintf(stderr, "!!!! %s has backend %s, src %d (%s) has backend %s\n",
node->name, node_allocr ? ggml_backend_name(ggml_tallocr_get_buffer(node_allocr)->backend) : "NULL",
j, src->name, src_allocr ? ggml_backend_name(ggml_tallocr_get_buffer(src_allocr)->backend) : "NULL");
}
}
}
#endif
// create copies of the graph for each split
// FIXME: avoid this copy, pass split inputs to ggml_gallocr_alloc_graph_n in some other way
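// the copy is sized for the original nodes plus one slot per possible split input, since the
// input copies are added to the graph ahead of each split's own nodes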
struct ggml_cgraph * graph_copy = ggml_new_graph_custom(sched->ctx, graph->n_nodes + sched->n_splits*GGML_MAX_SPLIT_INPUTS, false);
for (int i = 0; i < sched->n_splits; i++) {
struct ggml_backend_sched_split * split = &sched->splits[i];
split->graph = ggml_graph_view(sched->ctx, graph, split->i_start, split->i_end);
// add inputs to the graph copy so that they are allocated by ggml-alloc at the start of the split
for (int j = 0; j < split->n_inputs; j++) {
struct ggml_tensor * input = split->inputs[j];
struct ggml_tensor * input_cpy = sched->node_copies[hash_id(input)][sched_allocr_prio(sched, split->tallocr)];
input_cpy->src[0] = input;
graph_copy->nodes[graph_copy->n_nodes++] = input_cpy;
}
for (int j = split->i_start; j < split->i_end; j++) {
graph_copy->nodes[graph_copy->n_nodes++] = graph->nodes[j];
}
}
sched->graph = graph_copy;
}
static void sched_alloc_splits(ggml_backend_sched_t sched) {
ggml_gallocr_alloc_graph_n(
sched->galloc,
sched->graph,
sched->hash_set,
sched->node_talloc);
}
static void sched_compute_splits(ggml_backend_sched_t sched) {
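// executes each split on its assigned backend; inputs that live on another backend are first
// copied into the split-local tensor copies created in sched_split_graph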
uint64_t copy_us[GGML_MAX_BACKENDS] = {0};
uint64_t compute_us[GGML_MAX_BACKENDS] = {0};
struct ggml_backend_sched_split * splits = sched->splits;
for (int i = 0; i < sched->n_splits; i++) {
struct ggml_backend_sched_split * split = &splits[i];
ggml_backend_t split_backend = ggml_tallocr_get_buffer(split->tallocr)->backend;
int split_backend_id = sched_backend_prio(sched, split_backend);
// copy the input tensors to the split backend
uint64_t copy_start_us = ggml_time_us();
for (int j = 0; j < split->n_inputs; j++) {
struct ggml_tensor * input_cpy = sched->node_copies[hash_id(split->inputs[j])][sched_backend_prio(sched, split_backend)];
if (split->inputs[j]->buffer == NULL) {
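// unallocated views inherit the buffer and data pointer of their view source, so that the
// copy below reads valid data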
if (split->inputs[j]->view_src == NULL) {
fprintf(stderr, "input %s has no buffer and no view_src\n", split->inputs[j]->name);
exit(1);
}
struct ggml_tensor * view = split->inputs[j];
view->backend = view->view_src->backend;
view->buffer = view->view_src->buffer;
view->data = (char *)view->view_src->data + view->view_offs;
ggml_backend_buffer_init_tensor(ggml_backend_sched_get_buffer(sched, view->buffer->backend), view);
}
if (input_cpy->buffer == NULL) {
fprintf(stderr, "input_cpy %s has no buffer\n", input_cpy->name);
exit(1);
}
GGML_ASSERT(split->inputs[j]->buffer->backend != input_cpy->buffer->backend);
GGML_ASSERT(input_cpy->buffer->backend == split_backend);
ggml_backend_tensor_copy(split->inputs[j], input_cpy);
}
// ggml_backend_synchronize(split_backend);
int64_t copy_end_us = ggml_time_us();
copy_us[split_backend_id] += copy_end_us - copy_start_us;
#if 0
char split_filename[GGML_MAX_NAME];
snprintf(split_filename, GGML_MAX_NAME, "split_%i_%s.dot", i, ggml_backend_name(split_backend));
ggml_graph_dump_dot(split->graph, NULL, split_filename);
#endif
uint64_t compute_start_us = ggml_time_us();
ggml_backend_graph_compute(split_backend, split->graph);
// ggml_backend_synchronize(split_backend);
uint64_t compute_end_us = ggml_time_us();
compute_us[split_backend_id] += compute_end_us - compute_start_us;
}
#if 0
// per-backend timings
fprintf(stderr, "sched_compute_splits times (%d splits):\n", sched->n_splits);
for (int i = 0; i < sched->n_backends; i++) {
if (copy_us[i] > 0 || compute_us[i] > 0) {
fprintf(stderr, "\t%5.5s: %lu us copy, %lu us compute\n", ggml_backend_name(sched->backends[i]), copy_us[i], compute_us[i]);
}
}
#endif
}
static void sched_reset(ggml_backend_sched_t sched) {
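// resets the per-backend allocators so that their buffers can be reused for the next graph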
for (int i = 0; i < sched->n_backends; i++) {
ggml_tallocr_reset(sched->tallocs[i]);
}
}
ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, int n_backends) {
GGML_ASSERT(n_backends <= GGML_MAX_BACKENDS);
struct ggml_backend_sched * sched = malloc(sizeof(struct ggml_backend_sched));
memset(sched, 0, sizeof(struct ggml_backend_sched));
fprintf(stderr, "ggml_backend_sched size: %lu KB\n", sizeof(struct ggml_backend_sched)/1024);
sched->n_backends = n_backends;
for (int i = 0; i < n_backends; i++) {
sched->backends[i] = backends[i];
}
sched->galloc = ggml_gallocr_new();
// init measure allocs for each backend
for (int i = 0; i < n_backends; i++) {
sched->tallocs[i] = ggml_tallocr_new_measure_from_backend(backends[i]);
}
return sched;
}
void ggml_backend_sched_free(ggml_backend_sched_t sched) {
if (sched == NULL) {
return;
}
for (int i = 0; i < sched->n_backends; i++) {
ggml_tallocr_free(sched->tallocs[i]);
}
ggml_gallocr_free(sched->galloc);
free(sched->hash_set.keys);
free(sched->node_talloc);
free(sched->node_copies);
free(sched);
}
void ggml_backend_sched_init_measure(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph) {
// initialize hash tables
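// sized for the measure graph plus the worst-case number of split inputs, so that graphs
// computed later (which must not be larger, see the assert in ggml_backend_sched_graph_compute)
// fit without reallocating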
size_t hash_size = measure_graph->visited_hash_table.size + GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS;
sched->hash_set.size = hash_size;
sched->hash_set.keys = malloc(sizeof(sched->hash_set.keys[0]) * hash_size);
sched->node_talloc = malloc(sizeof(sched->node_talloc[0]) * hash_size);
sched->node_copies = malloc(sizeof(sched->node_copies[0]) * hash_size);
sched_split_graph(sched, measure_graph);
sched_alloc_splits(sched);
// allocate buffers and reset allocators
for (int i = 0; i < sched->n_backends; i++) {
size_t size = ggml_tallocr_max_size(sched->tallocs[i]);
ggml_tallocr_free(sched->tallocs[i]);
sched->tallocs[i] = ggml_tallocr_new_from_backend(sched->backends[i], size);
}
sched_reset(sched);
}
void ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
GGML_ASSERT(sched->hash_set.size >= graph->visited_hash_table.size + GGML_MAX_SPLITS*GGML_MAX_SPLIT_INPUTS);
sched_split_graph(sched, graph);
sched_alloc_splits(sched);
sched_compute_splits(sched);
sched_reset(sched);
}
ggml_tallocr_t ggml_backend_sched_get_tallocr(ggml_backend_sched_t sched, ggml_backend_t backend) {
int backend_index = sched_backend_prio(sched, backend);
return sched->tallocs[backend_index];
}
ggml_backend_buffer_t ggml_backend_sched_get_buffer(ggml_backend_sched_t sched, ggml_backend_t backend) {
int backend_index = sched_backend_prio(sched, backend);
return ggml_tallocr_get_buffer(sched->tallocs[backend_index]);
}
void ggml_backend_sched_set_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) {
int backend_index = sched_backend_prio(sched, backend);
GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
node_allocr(node) = sched->tallocs[backend_index];
}

View File

@ -1,142 +1,135 @@
#pragma once
#include "ggml.h"
#include "ggml-alloc.h"
#ifdef __cplusplus
extern "C" {
#endif
struct ggml_backend;
struct ggml_backend_buffer;
// type-erased backend-specific types / wrappers
typedef void * ggml_backend_context_t;
typedef void * ggml_backend_graph_plan_t;
typedef void * ggml_backend_buffer_context_t;
//
// Backend buffer
//
// avoid accessing internals of these types
typedef struct ggml_backend * ggml_backend_t;
typedef struct ggml_backend_buffer * ggml_backend_buffer_t;
struct ggml_backend_buffer;
typedef struct ggml_backend_buffer * ggml_backend_buffer_t;
//
// backend buffer
//
// backend buffer functions
GGML_API void ggml_backend_buffer_free (ggml_backend_buffer_t buffer);
GGML_API size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer);
GGML_API void * ggml_backend_buffer_get_base (ggml_backend_buffer_t buffer);
GGML_API size_t ggml_backend_buffer_get_size (ggml_backend_buffer_t buffer);
GGML_API size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
GGML_API void ggml_backend_buffer_init_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
GGML_API void ggml_backend_buffer_free_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
struct ggml_backend_buffer_i {
void (*free_buffer) (ggml_backend_buffer_t buffer);
void * (*get_base) (ggml_backend_buffer_t buffer); // get base pointer
size_t (*get_alloc_size)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // pre-allocation callback
void (*init_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // post-allocation callback
void (*free_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); // pre-free callback
};
//
// Backend
//
// TODO: hide behind API
struct ggml_backend_buffer {
struct ggml_backend_buffer_i iface;
struct ggml_backend;
typedef struct ggml_backend * ggml_backend_t;
typedef void * ggml_backend_graph_plan_t;
ggml_backend_t backend;
ggml_backend_buffer_context_t context;
GGML_API ggml_backend_t ggml_get_backend(const struct ggml_tensor * tensor);
size_t size;
};
GGML_API const char * ggml_backend_name(ggml_backend_t backend);
GGML_API void ggml_backend_free(ggml_backend_t backend);
// backend buffer functions
GGML_API ggml_backend_buffer_t ggml_backend_buffer_init(
struct ggml_backend * backend,
struct ggml_backend_buffer_i iface,
ggml_backend_buffer_context_t context,
size_t size);
GGML_API ggml_backend_buffer_t ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size);
GGML_API void ggml_backend_buffer_free (ggml_backend_buffer_t buffer);
GGML_API size_t ggml_backend_buffer_get_alignment (ggml_backend_buffer_t buffer);
GGML_API void * ggml_backend_buffer_get_base (ggml_backend_buffer_t buffer);
GGML_API size_t ggml_backend_buffer_get_size (ggml_backend_buffer_t buffer);
GGML_API size_t ggml_backend_buffer_get_alloc_size(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
GGML_API void ggml_backend_buffer_init_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
GGML_API void ggml_backend_buffer_free_tensor (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor);
GGML_API size_t ggml_backend_get_alignment(ggml_backend_t backend);
//
// backend
//
GGML_API void ggml_backend_tensor_set_async( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
GGML_API void ggml_backend_tensor_get_async(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
struct ggml_backend_i {
const char * (*get_name)(ggml_backend_t backend);
GGML_API void ggml_backend_tensor_set( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
GGML_API void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
void (*free)(ggml_backend_t backend);
GGML_API void ggml_backend_synchronize(ggml_backend_t backend);
// buffer allocation
ggml_backend_buffer_t (*alloc_buffer)(ggml_backend_t backend, size_t size);
GGML_API ggml_backend_graph_plan_t ggml_backend_graph_plan_create (ggml_backend_t backend, struct ggml_cgraph * cgraph);
// get buffer alignment
size_t (*get_alignment)(ggml_backend_t backend);
GGML_API void ggml_backend_graph_plan_free (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
GGML_API void ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
GGML_API void ggml_backend_graph_compute (ggml_backend_t backend, struct ggml_cgraph * cgraph);
GGML_API bool ggml_backend_supports_op (ggml_backend_t backend, const struct ggml_tensor * op);
// tensor data access
// these functions can be asynchronous; helper functions are provided for synchronous access that automatically call synchronize
void (*set_tensor_async)(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
void (*get_tensor_async)(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
void (*synchronize) (ggml_backend_t backend);
// tensor copy between different backends
GGML_API void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst);
// (optional) copy tensor between different backends, allowing single-copy transfers
void (*cpy_tensor_from)(ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
void (*cpy_tensor_to) (ggml_backend_t backend, struct ggml_tensor * src, struct ggml_tensor * dst);
//
// CPU backend
//
// compute graph with a plan
ggml_backend_graph_plan_t (*graph_plan_create) (ggml_backend_t backend, struct ggml_cgraph * cgraph);
void (*graph_plan_free) (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
void (*graph_plan_compute)(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
GGML_API ggml_backend_t ggml_backend_cpu_init(void);
// compute graph without a plan
void (*graph_compute)(ggml_backend_t backend, struct ggml_cgraph * cgraph);
GGML_API bool ggml_backend_is_cpu(ggml_backend_t backend);
GGML_API void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads);
// check if the backend supports an operation
bool (*supports_op)(ggml_backend_t backend, const struct ggml_tensor * op);
};
// Create a backend buffer from an existing pointer
GGML_API ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(ggml_backend_t backend_cpu, void * ptr, size_t size);
// TODO: hide behind API
struct ggml_backend {
struct ggml_backend_i iface;
ggml_backend_context_t context;
};
//
// Backend scheduler
//
// backend helper functions
GGML_API ggml_backend_t ggml_get_backend(const struct ggml_tensor * tensor);
// The backend scheduler allows for multiple backends to be used together
// Handles compute buffer allocation, assignment of tensors to backends, and copying of tensors between backends
// The backends are selected based on:
// - the backend that supports the operation
// - the location of the pre-allocated tensors (e.g. the weights)
/*
Example usage:
GGML_API const char * ggml_backend_name(ggml_backend_t backend);
GGML_API void ggml_backend_free(ggml_backend_t backend);
sched = ggml_backend_sched_new({backend_gpu, backend_gpu2, backend_cpu}, num_backends);
// sched is initialized with measure allocators and cannot be used until allocated with a measure graph
GGML_API ggml_backend_buffer_t ggml_backend_alloc_buffer(ggml_backend_t backend, size_t size);
// initialize buffers from a measure graph
measure_graph = build_graph(sched); // use the allocr to allocate inputs as needed
GGML_API size_t ggml_backend_get_alignment(ggml_backend_t backend);
// in build_graph:
build_graph(...) {
// allocating tensors in a specific backend (optional, recommended: pre-allocate inputs in a different buffer)
alloc_cpu = ggml_backend_sched_get_allocr(sched, backend_cpu);
ggml_allocr_alloc(alloc_cpu, tensor);
GGML_API void ggml_backend_tensor_set_async( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
GGML_API void ggml_backend_tensor_get_async(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
// manually assigning nodes to a backend (optional, shouldn't be needed in most cases)
struct ggml_tensor * node = ggml_mul_mat(ctx, ...);
ggml_backend_sched_set_node_backend(sched, node, backend_gpu);
}
GGML_API void ggml_backend_tensor_set( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
GGML_API void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
// allocate backend buffers from measure graph
ggml_backend_sched_init_measure(sched, measure_graph);
GGML_API void ggml_backend_synchronize(ggml_backend_t backend);
// the scheduler is now ready to compute graphs
GGML_API ggml_backend_graph_plan_t ggml_backend_graph_plan_create (ggml_backend_t backend, struct ggml_cgraph * cgraph);
// compute
graph = build_graph(sched);
ggml_backend_sched_graph_compute(sched, graph);
*/
GGML_API void ggml_backend_graph_plan_free (ggml_backend_t backend, ggml_backend_graph_plan_t plan);
GGML_API void ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan);
GGML_API void ggml_backend_graph_compute (ggml_backend_t backend, struct ggml_cgraph * cgraph);
GGML_API bool ggml_backend_supports_op (ggml_backend_t backend, const struct ggml_tensor * op);
struct ggml_backend_sched;
typedef struct ggml_backend_sched * ggml_backend_sched_t;
// tensor copy between different backends
GGML_API void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst);
// Initialize a backend scheduler
GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, int n_backends);
//
// CPU backend
//
GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched);
GGML_API ggml_backend_t ggml_backend_cpu_init(void);
// Initialize backend buffers from a measure graph
GGML_API void ggml_backend_sched_init_measure(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph);
GGML_API bool ggml_backend_is_cpu(ggml_backend_t backend);
GGML_API ggml_tallocr_t ggml_backend_sched_get_tallocr(ggml_backend_sched_t sched, ggml_backend_t backend);
GGML_API ggml_backend_buffer_t ggml_backend_sched_get_buffer (ggml_backend_sched_t sched, ggml_backend_t backend);
GGML_API void ggml_backend_cpu_set_n_threads(ggml_backend_t backend_cpu, int n_threads);
GGML_API void ggml_backend_sched_set_node_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend);
GGML_API ggml_backend_buffer_t ggml_backend_cpu_buffer_from_ptr(ggml_backend_t backend_cpu, void * ptr, size_t size);
// Allocate a graph on the backend scheduler
GGML_API void ggml_backend_sched_graph_compute(
ggml_backend_sched_t sched,
struct ggml_cgraph * graph);
#ifdef __cplusplus
}

View File

@ -39,12 +39,6 @@ extern "C" {
#endif
#endif
#undef MIN
#undef MAX
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define MAX(a, b) ((a) > (b) ? (a) : (b))
// 16-bit float
// on Arm, we use __fp16
// on x86, we use uint16_t
@ -173,7 +167,7 @@ static inline float ggml_compute_fp16_to_fp32(ggml_fp16_t h) {
const uint32_t denormalized_cutoff = UINT32_C(1) << 27;
const uint32_t result = sign |
(two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value));
(two_w < denormalized_cutoff ? fp32_to_bits(denormalized_value) : fp32_to_bits(normalized_value));
return fp32_from_bits(result);
}
@ -230,7 +224,19 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
#endif
// TODO: backend v2 PR
#define GGML_HASHTABLE_FULL ((size_t)-1)
#define GGML_HASHTABLE_ALREADY_EXISTS ((size_t)-2)
bool ggml_hash_contains (const struct ggml_hash_set hash_set, struct ggml_tensor * key);
// returns GGML_HASHTABLE_FULL if table is full, otherwise the current index of the key or where it should be inserted
size_t ggml_hash_find (const struct ggml_hash_set hash_set, struct ggml_tensor * key);
// returns GGML_HASHTABLE_ALREADY_EXISTS if the key already exists, the index otherwise; asserts if the table is full
size_t ggml_hash_insert ( struct ggml_hash_set hash_set, struct ggml_tensor * key);
// returns the index; asserts if the table is full
size_t ggml_hash_find_or_insert( struct ggml_hash_set hash_set, struct ggml_tensor * key);
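// a minimal usage sketch (illustrative, not part of this header), assuming some
// struct ggml_cgraph * graph and the struct ggml_hash_set layout
// { size_t size; struct ggml_tensor ** keys; } used by the graph code:
//
//   struct ggml_hash_set set = {
//       /*.size =*/ 2*(size_t)graph->n_nodes,
//       /*.keys =*/ calloc(2*graph->n_nodes, sizeof(struct ggml_tensor *)),
//   };
//   for (int i = 0; i < graph->n_nodes; ++i) {
//       if (ggml_hash_insert(set, graph->nodes[i]) == GGML_HASHTABLE_ALREADY_EXISTS) {
//           // already visited
//       }
//   }
//   free(set.keys);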
#ifdef __cplusplus
}

View File

@ -14,32 +14,12 @@
//
#include <arm_neon.h>
#if !defined(__aarch64__)
inline static int32_t vaddvq_s16(int16x8_t v) {
return
(int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
(int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) +
(int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) +
(int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7);
}
inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a));
int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b));
return vcombine_s16(a0, b0);
}
inline static int32_t vaddvq_s32(int32x4_t v) {
return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
}
#endif
#else
#ifdef __wasm_simd128__
#include <wasm_simd128.h>
#else
#ifdef __POWER9_VECTOR__
#if defined(__POWER9_VECTOR__) || defined(__powerpc64__)
#include <altivec.h>
#undef bool
#define bool _Bool
@ -47,13 +27,15 @@ inline static int32_t vaddvq_s32(int32x4_t v) {
#if defined(_MSC_VER) || defined(__MINGW32__)
#include <intrin.h>
#else
#if !defined(__riscv) && !defined(__s390__)
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__)
#if !defined(__riscv)
#include <immintrin.h>
#endif
#endif
#endif
#endif
#endif
#endif
#ifdef __riscv_v_intrinsic
#include <riscv_vector.h>
@ -61,6 +43,7 @@ inline static int32_t vaddvq_s32(int32x4_t v) {
#undef MIN
#undef MAX
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define MAX(a, b) ((a) > (b) ? (a) : (b))
@ -283,14 +266,34 @@ static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128
#endif // defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__)
#if defined(__ARM_NEON)
#if !defined(__aarch64__)
/*
// 64-bit compatibility
// vaddvq_s16
// vpaddq_s16
// vaddvq_s32
// vaddvq_f32
// vmaxvq_f32
// vcvtnq_s32_f32
inline static int32_t vaddvq_s16(int16x8_t v) {
return
(int32_t)vgetq_lane_s16(v, 0) + (int32_t)vgetq_lane_s16(v, 1) +
(int32_t)vgetq_lane_s16(v, 2) + (int32_t)vgetq_lane_s16(v, 3) +
(int32_t)vgetq_lane_s16(v, 4) + (int32_t)vgetq_lane_s16(v, 5) +
(int32_t)vgetq_lane_s16(v, 6) + (int32_t)vgetq_lane_s16(v, 7);
}
inline static int16x8_t vpaddq_s16(int16x8_t a, int16x8_t b) {
int16x4_t a0 = vpadd_s16(vget_low_s16(a), vget_high_s16(a));
int16x4_t b0 = vpadd_s16(vget_low_s16(b), vget_high_s16(b));
return vcombine_s16(a0, b0);
}
inline static int32_t vaddvq_s32(int32x4_t v) {
return vgetq_lane_s32(v, 0) + vgetq_lane_s32(v, 1) + vgetq_lane_s32(v, 2) + vgetq_lane_s32(v, 3);
}
*/
inline static float vaddvq_f32(float32x4_t v) {
return vgetq_lane_f32(v, 0) + vgetq_lane_f32(v, 1) + vgetq_lane_f32(v, 2) + vgetq_lane_f32(v, 3);
@ -313,6 +316,96 @@ inline static int32x4_t vcvtnq_s32_f32(float32x4_t v) {
return res;
}
// vld1q_s16_x2
// vld1q_u8_x2
// vld1q_u8_x4
// vld1q_s8_x2
// vld1q_s8_x4
// TODO: double-check these work correctly
typedef struct ggml_int16x8x2_t {
int16x8_t val[2];
} ggml_int16x8x2_t;
inline static ggml_int16x8x2_t ggml_vld1q_s16_x2(const int16_t * ptr) {
ggml_int16x8x2_t res;
res.val[0] = vld1q_s16(ptr + 0);
res.val[1] = vld1q_s16(ptr + 8);
return res;
}
typedef struct ggml_uint8x16x2_t {
uint8x16_t val[2];
} ggml_uint8x16x2_t;
inline static ggml_uint8x16x2_t ggml_vld1q_u8_x2(const uint8_t * ptr) {
ggml_uint8x16x2_t res;
res.val[0] = vld1q_u8(ptr + 0);
res.val[1] = vld1q_u8(ptr + 16);
return res;
}
typedef struct ggml_uint8x16x4_t {
uint8x16_t val[4];
} ggml_uint8x16x4_t;
inline static ggml_uint8x16x4_t ggml_vld1q_u8_x4(const uint8_t * ptr) {
ggml_uint8x16x4_t res;
res.val[0] = vld1q_u8(ptr + 0);
res.val[1] = vld1q_u8(ptr + 16);
res.val[2] = vld1q_u8(ptr + 32);
res.val[3] = vld1q_u8(ptr + 48);
return res;
}
typedef struct ggml_int8x16x2_t {
int8x16_t val[2];
} ggml_int8x16x2_t;
inline static ggml_int8x16x2_t ggml_vld1q_s8_x2(const int8_t * ptr) {
ggml_int8x16x2_t res;
res.val[0] = vld1q_s8(ptr + 0);
res.val[1] = vld1q_s8(ptr + 16);
return res;
}
typedef struct ggml_int8x16x4_t {
int8x16_t val[4];
} ggml_int8x16x4_t;
inline static ggml_int8x16x4_t ggml_vld1q_s8_x4(const int8_t * ptr) {
ggml_int8x16x4_t res;
res.val[0] = vld1q_s8(ptr + 0);
res.val[1] = vld1q_s8(ptr + 16);
res.val[2] = vld1q_s8(ptr + 32);
res.val[3] = vld1q_s8(ptr + 48);
return res;
}
#else
#define ggml_int16x8x2_t int16x8x2_t
#define ggml_uint8x16x2_t uint8x16x2_t
#define ggml_uint8x16x4_t uint8x16x4_t
#define ggml_int8x16x2_t int8x16x2_t
#define ggml_int8x16x4_t int8x16x4_t
#define ggml_vld1q_s16_x2 vld1q_s16_x2
#define ggml_vld1q_u8_x2 vld1q_u8_x2
#define ggml_vld1q_u8_x4 vld1q_u8_x4
#define ggml_vld1q_s8_x2 vld1q_s8_x2
#define ggml_vld1q_s8_x4 vld1q_s8_x4
#endif
#endif
@ -1226,7 +1319,7 @@ static float make_q3_quants(int n, int nmax, const float * restrict x, int8_t *
}
static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t * restrict L, float * restrict the_min,
int ntry, float alpha) {
int ntry, float alpha) {
float min = x[0];
float max = x[0];
for (int i = 1; i < n; ++i) {
@ -1269,13 +1362,18 @@ static float make_qkx1_quants(int n, int nmax, const float * restrict x, uint8_t
}
static float make_qkx2_quants(int n, int nmax, const float * restrict x, const float * restrict weights,
uint8_t * restrict L, float * restrict the_min, uint8_t * restrict Laux,
float rmin, float rdelta, int nstep, bool use_mad) {
uint8_t * restrict L, float * restrict the_min, uint8_t * restrict Laux,
float rmin, float rdelta, int nstep, bool use_mad) {
float min = x[0];
float max = x[0];
float sum_w = weights[0];
float sum_x = sum_w * x[0];
#ifdef HAVE_BUGGY_APPLE_LINKER
// use 'volatile' to prevent unroll and work around a bug in Apple ld64 1015.7
for (volatile int i = 1; i < n; ++i) {
#else
for (int i = 1; i < n; ++i) {
#endif
if (x[i] < min) min = x[i];
if (x[i] > max) max = x[i];
float w = weights[i];
@ -3559,7 +3657,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
const int32x4_t vzero = vdupq_n_s32(0);
#endif
int8x16x2_t q2bytes;
ggml_int8x16x2_t q2bytes;
uint8_t aux[16];
float sum = 0;
@ -3578,8 +3676,8 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
vst1q_u8(aux, scales);
const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4);
const int16x8x2_t q8sums = vld1q_s16_x2(y[i].bsums);
const int16x8x2_t mins16 = {vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))};
const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums);
const ggml_int16x8x2_t mins16 = {vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))};
const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])),
vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0])));
const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])),
@ -3607,7 +3705,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
#endif
#define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\
q8bytes = vld1q_s8_x2(q8); q8 += 32;\
q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;\
q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], (shift)), m3));\
q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\
MULTIPLY_ACCUM_WITH_SCALE((index));
@ -3615,9 +3713,9 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
for (int j = 0; j < QK_K/128; ++j) {
const uint8x16x2_t q2bits = vld1q_u8_x2(q2); q2 += 32;
const ggml_uint8x16x2_t q2bits = ggml_vld1q_u8_x2(q2); q2 += 32;
int8x16x2_t q8bytes = vld1q_s8_x2(q8); q8 += 32;
ggml_int8x16x2_t q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3));
q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3));
MULTIPLY_ACCUM_WITH_SCALE(0);
@ -3951,7 +4049,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
const int32x4_t vzero = vdupq_n_s32(0);
#endif
int8x16x4_t q2bytes;
ggml_int8x16x4_t q2bytes;
uint32_t aux32[2];
const uint8_t * scales = (const uint8_t *)aux32;
@ -3976,7 +4074,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
const uint8x16_t q2bits = vld1q_u8(q2);
const int8x16x4_t q8bytes = vld1q_s8_x4(q8);
const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8);
q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits, m3));
q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 2), m3));
@ -4240,7 +4338,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
const uint8x16_t m3 = vshlq_n_u8(m0, 3);
const int8_t m32 = 32;
int8x16x4_t q3bytes;
ggml_int8x16x4_t q3bytes;
float sum = 0;
@ -4252,9 +4350,9 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
const uint8_t * restrict qh = x[i].hmask;
const int8_t * restrict q8 = y[i].qs;
uint8x16x2_t qhbits = vld1q_u8_x2(qh);
ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh);
uint8x16x4_t q3h;
ggml_uint8x16x4_t q3h;
int32_t isum = 0;
@ -4270,9 +4368,9 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
for (int j = 0; j < QK_K/128; ++j) {
const uint8x16x2_t q3bits = vld1q_u8_x2(q3); q3 += 32;
const int8x16x4_t q8bytes_1 = vld1q_s8_x4(q8); q8 += 64;
const int8x16x4_t q8bytes_2 = vld1q_s8_x4(q8); q8 += 64;
const ggml_uint8x16x2_t q3bits = ggml_vld1q_u8_x2(q3); q3 += 32;
const ggml_int8x16x4_t q8bytes_1 = ggml_vld1q_s8_x4(q8); q8 += 64;
const ggml_int8x16x4_t q8bytes_2 = ggml_vld1q_s8_x4(q8); q8 += 64;
q3h.val[0] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[0]), 2);
q3h.val[1] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[1]), 2);
@ -4774,7 +4872,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
const uint8x16_t m3b = vdupq_n_u8(0x3);
const uint8x16_t mh = vdupq_n_u8(4);
int8x16x4_t q3bytes;
ggml_int8x16x4_t q3bytes;
uint16_t aux16[2];
int8_t * scales = (int8_t *)aux16;
@ -4783,11 +4881,11 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
for (int i = 0; i < nb; ++i) {
uint8x16x4_t q3h;
ggml_uint8x16x4_t q3h;
const uint8x8_t hbits = vld1_u8(x[i].hmask);
const uint8x16_t q3bits = vld1q_u8(x[i].qs);
const int8x16x4_t q8bytes = vld1q_s8_x4(y[i].qs);
const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(y[i].qs);
const uint16_t a = *(const uint16_t *)x[i].scales;
aux16[0] = a & 0x0f0f;
@ -5136,8 +5234,8 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
const int32x4_t mzero = vdupq_n_s32(0);
#endif
int8x16x2_t q4bytes;
int8x16x2_t q8bytes;
ggml_int8x16x2_t q4bytes;
ggml_int8x16x2_t q8bytes;
float sumf = 0;
@ -5172,17 +5270,17 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
for (int j = 0; j < QK_K/64; ++j) {
const uint8x16x2_t q4bits = vld1q_u8_x2(q4); q4 += 32;
const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4); q4 += 32;
#ifdef __ARM_FEATURE_DOTPROD
q8bytes = vld1q_s8_x2(q8); q8 += 32;
q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
sumi1 += vaddvq_s32(p1) * scales[2*j+0];
q8bytes = vld1q_s8_x2(q8); q8 += 32;
q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
@ -5190,7 +5288,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
sumi2 += vaddvq_s32(p2) * scales[2*j+1];
#else
q8bytes = vld1q_s8_x2(q8); q8 += 32;
q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
@ -5199,7 +5297,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) * scales[2*j+0];
q8bytes = vld1q_s8_x2(q8); q8 += 32;
q8bytes = ggml_vld1q_s8_x2(q8); q8 += 32;
q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
@ -5514,8 +5612,8 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
float sumf = 0;
int8x16x2_t q4bytes;
int8x16x4_t q8bytes;
ggml_int8x16x2_t q4bytes;
ggml_int8x16x4_t q8bytes;
float sum_mins = 0.f;
@ -5536,10 +5634,10 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
const float d = y[i].d * (float)x[i].d[0];
const uint8x16x2_t q4bits = vld1q_u8_x2(q4);
const ggml_uint8x16x2_t q4bits = ggml_vld1q_u8_x2(q4);
#ifdef __ARM_FEATURE_DOTPROD
q8bytes = vld1q_s8_x4(q8);
q8bytes = ggml_vld1q_s8_x4(q8);
q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
@ -5553,7 +5651,7 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
const int32_t sumi2 = vaddvq_s32(p2) * scales[1];
#else
q8bytes = vld1q_s8_x4(q8);
q8bytes = ggml_vld1q_s8_x4(q8);
q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
@ -5787,7 +5885,7 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
const int32x4_t mzero = vdupq_n_s32(0);
#endif
int8x16x4_t q5bytes;
ggml_int8x16x4_t q5bytes;
float sumf = 0;
@ -5817,16 +5915,16 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
const uint8_t * restrict qh = x[i].qh;
const int8_t * restrict q8 = y[i].qs;
uint8x16x2_t qhbits = vld1q_u8_x2(qh);
ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh);
uint8x16x4_t q5h;
ggml_uint8x16x4_t q5h;
int32_t sumi = 0;
for (int j = 0; j < QK_K/64; ++j) {
const uint8x16x2_t q5bits = vld1q_u8_x2(q5); q5 += 32;
const int8x16x4_t q8bytes = vld1q_s8_x4(q8); q8 += 64;
const ggml_uint8x16x2_t q5bits = ggml_vld1q_u8_x2(q5); q5 += 32;
const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;
q5h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);
q5h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);
@ -6220,8 +6318,8 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
const int32x4_t mzero = vdupq_n_s32(0);
#endif
int8x16x4_t q5bytes;
uint8x16x4_t q5h;
ggml_int8x16x4_t q5bytes;
ggml_uint8x16x4_t q5h;
float sumf = 0;
@ -6236,8 +6334,8 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
const uint8x8_t qhbits = vld1_u8(qh);
const uint8x16x2_t q5bits = vld1q_u8_x2(q5);
const int8x16x4_t q8bytes = vld1q_s8_x4(q8);
const ggml_uint8x16x2_t q5bits = ggml_vld1q_u8_x2(q5);
const ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8);
const uint8x16_t htmp = vcombine_u8(qhbits, vshr_n_u8(qhbits, 1));
q5h.val[0] = vbicq_u8(mh, vshlq_n_u8(htmp, 4));
@ -6513,8 +6611,8 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
const uint8x16_t mone = vdupq_n_u8(3);
int8x16x4_t q6bytes;
uint8x16x4_t q6h;
ggml_int8x16x4_t q6bytes;
ggml_uint8x16x4_t q6h;
for (int i = 0; i < nb; ++i) {
@ -6526,9 +6624,9 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
const int8_t * restrict scale = x[i].scales;
const int16x8x2_t q8sums = vld1q_s16_x2(y[i].bsums);
const ggml_int16x8x2_t q8sums = ggml_vld1q_s16_x2(y[i].bsums);
const int8x16_t scales = vld1q_s8(scale);
const int16x8x2_t q6scales = {vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))};
const ggml_int16x8x2_t q6scales = {vmovl_s8(vget_low_s8(scales)), vmovl_s8(vget_high_s8(scales))};
const int32x4_t prod = vaddq_s32(vaddq_s32(vmull_s16(vget_low_s16 (q8sums.val[0]), vget_low_s16 (q6scales.val[0])),
vmull_s16(vget_high_s16(q8sums.val[0]), vget_high_s16(q6scales.val[0]))),
@ -6540,9 +6638,9 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
for (int j = 0; j < QK_K/128; ++j) {
uint8x16x2_t qhbits = vld1q_u8_x2(qh); qh += 32;
uint8x16x4_t q6bits = vld1q_u8_x4(q6); q6 += 64;
int8x16x4_t q8bytes = vld1q_s8_x4(q8); q8 += 64;
ggml_uint8x16x2_t qhbits = ggml_vld1q_u8_x2(qh); qh += 32;
ggml_uint8x16x4_t q6bits = ggml_vld1q_u8_x4(q6); q6 += 64;
ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;
q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits.val[0]), 4);
q6h.val[1] = vshlq_n_u8(vandq_u8(mone, qhbits.val[1]), 4);
@ -6585,7 +6683,7 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
scale += 2;
#endif
q8bytes = vld1q_s8_x4(q8); q8 += 64;
q8bytes = ggml_vld1q_s8_x4(q8); q8 += 64;
shifted = vshrq_n_u8(qhbits.val[0], 4);
q6h.val[0] = vshlq_n_u8(vandq_u8(mone, shifted), 4);
@ -6989,8 +7087,8 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
const uint8x16_t mone = vdupq_n_u8(3);
int8x16x4_t q6bytes;
uint8x16x4_t q6h;
ggml_int8x16x4_t q6bytes;
ggml_uint8x16x4_t q6h;
for (int i = 0; i < nb; ++i) {
@ -7004,9 +7102,9 @@ void ggml_vec_dot_q6_K_q8_K(const int n, float * restrict s, const void * restri
int32_t isum = 0;
uint8x16_t qhbits = vld1q_u8(qh);
uint8x16x2_t q6bits = vld1q_u8_x2(q6);
int8x16x4_t q8bytes = vld1q_s8_x4(q8);
uint8x16_t qhbits = vld1q_u8(qh);
ggml_uint8x16x2_t q6bits = ggml_vld1q_u8_x2(q6);
ggml_int8x16x4_t q8bytes = ggml_vld1q_s8_x4(q8);
q6h.val[0] = vshlq_n_u8(vandq_u8(mone, qhbits), 4);
uint8x16_t shifted = vshrq_n_u8(qhbits, 2);

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@ -33,6 +33,7 @@ struct train_state * init_train_state() {
state->opt = new struct ggml_opt_context;
state->opt->ctx = NULL;
state->opt->params = ggml_opt_default_params(GGML_OPT_ADAM);
state->opt->params.graph_size = LLAMA_TRAIN_MAX_NODES;
state->opt->loss_after = 0.0f;
return state;
@ -1136,6 +1137,7 @@ void print_common_train_usage(int /*argc*/, char ** /*argv*/, const struct train
fprintf(stderr, " --adam-beta2 N AdamW beta2 in interval [0,1). How much to smooth the second moment of gradients. (default %f)\n", params->adam_beta2);
fprintf(stderr, " --adam-gclip N AdamW gradient clipping. Disabled when zero. (default %f)\n", params->adam_gclip);
fprintf(stderr, " --adam-epsf N AdamW epsilon for convergence test. Disabled when <= zero. (default %f)\n", params->adam_eps_f);
fprintf(stderr, " -ngl N, --n-gpu-layers N Number of model layers to offload to GPU (default %d)", params->n_gpu_layers);
fprintf(stderr, "\n");
}
@ -1355,6 +1357,17 @@ bool consume_common_train_arg(
return true;
}
params->adam_gclip = std::stof(argv[i]);
} else if (arg == "-ngl" || arg == "--n-gpu-layers") {
if (++i >= argc) {
*invalid_param = true;
return true;
}
#ifdef LLAMA_SUPPORTS_GPU_OFFLOAD
params->n_gpu_layers = std::stoi(argv[i]);
#else
fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n");
fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
#endif
} else if (arg == "-h" || arg == "--help") {
params->print_usage = true;
return true;

View File

@ -9,6 +9,8 @@
#include "ggml.h"
#include "llama.h"
#define LLAMA_TRAIN_MAX_NODES 16384
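// upper bound on the number of graph nodes used by the training examples; it sizes the optimizer
// graphs (opt->params.graph_size) and the compute context estimates in finetune/train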
typedef std::string mt19937_state;
struct train_state {