aboutsummaryrefslogtreecommitdiff
path: root/examples/train-text-from-scratch
diff options
context:
space:
mode:
Diffstat (limited to 'examples/train-text-from-scratch')
-rw-r--r--examples/train-text-from-scratch/train-text-from-scratch.cpp18
1 files changed, 10 insertions, 8 deletions
diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp
index 51271b4..7ec8595 100644
--- a/examples/train-text-from-scratch/train-text-from-scratch.cpp
+++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp
@@ -12,6 +12,9 @@
#include <algorithm>
#include <string>
+#if defined(_MSC_VER)
+#pragma warning(disable: 4244 4267) // possible loss of data
+#endif
struct random_normal_distribution {
std::mt19937 gen;
@@ -20,7 +23,6 @@ struct random_normal_distribution {
float max;
};
-
struct random_uniform_distribution {
std::mt19937 gen;
std::uniform_real_distribution<float> rd;
@@ -2366,7 +2368,7 @@ void write_tensor(struct llama_file * file, struct ggml_tensor * tensor) {
file->write_u32(0);
file->write_u32(0);
file->write_u32(GGML_TYPE_F32);
- file->seek(-file->tell() & 31, SEEK_CUR);
+ file->seek(0-file->tell() & 31, SEEK_CUR);
return;
}
const char * name = ggml_get_name(tensor);
@@ -2381,7 +2383,7 @@ void write_tensor(struct llama_file * file, struct ggml_tensor * tensor) {
file->write_u32(tensor->type);
file->write_raw(ne, sizeof(ne[0]) * nd);
file->write_raw(name, name_len);
- file->seek(-file->tell() & 31, SEEK_CUR);
+ file->seek(0-file->tell() & 31, SEEK_CUR);
file->write_raw(tensor->data, ggml_nbytes(tensor));
}
@@ -2402,7 +2404,7 @@ void read_tensor(struct llama_file * file, struct ggml_tensor * tensor) {
std::string name = file->read_string(name_len);
GGML_ASSERT(strncmp(ggml_get_name(tensor), name.c_str(), sizeof(tensor->name)-1) == 0);
- file->seek(-file->tell() & 31, SEEK_CUR);
+ file->seek(0-file->tell() & 31, SEEK_CUR);
file->read_raw(tensor->data, ggml_nbytes(tensor));
}
@@ -2756,8 +2758,8 @@ struct train_params get_default_train_params() {
params.lbfgs_n_iter = 16;
params.adam_n_iter = 16;
- params.adam_alpha = 1e-3;
- params.adam_decay = 1e-3;
+ params.adam_alpha = 1e-3f;
+ params.adam_decay = 1e-3f;
params.mem_model_gb = 2;
params.mem_compute_gb = 24;
@@ -3331,8 +3333,8 @@ int main(int argc, char ** argv) {
int n_gen = params.n_predict;
int sample_ctx = n_tokens - n_tokens/8;
- sampler.params.temp = 0.2;
- sampler.params.repeat_penalty = 1.1;
+ sampler.params.temp = 0.2f;
+ sampler.params.repeat_penalty = 1.1f;
sampler.params.mirostat = 2;
init_sampler(&sampler, lctx);