Skip to content

Commit c0d477b

Browse files
committed
update
1 parent d54e8b8 commit c0d477b

File tree

8 files changed

+281
-48
lines changed

8 files changed

+281
-48
lines changed

build.zig

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,20 @@
11
const std = @import("std");
22

33
// Zig Version: 0.11.0-dev.3798+a5e15eced
4+
// Zig Build Command: zig build
5+
// Zig Run Command: zig build -h
6+
// zig build run_test0_zig
7+
// zig build run_test1_zig
48
pub fn build(b: *std.Build) void {
59
const target = b.standardTargetOptions(.{});
610
const optimize = b.standardOptimizeOption(.{});
711

8-
const zig_tests = .{
12+
// tests_zig
13+
const tests_zig = .{
914
"test0",
15+
"test1",
1016
};
11-
inline for (zig_tests) |name| {
17+
inline for (tests_zig) |name| {
1218
const exe = b.addExecutable(.{
1319
.name = name,
1420
.root_source_file = .{ .path = std.fmt.comptimePrint("tests/{s}.zig", .{name}) },
@@ -25,7 +31,7 @@ pub fn build(b: *std.Build) void {
2531
const run_cmd = b.addRunArtifact(exe);
2632
run_cmd.step.dependOn(b.getInstallStep());
2733
if (b.args) |args| run_cmd.addArgs(args);
28-
const run_step = b.step("test_" ++ name, "Run tests");
34+
const run_step = b.step("run_" ++ name ++ "_zig", "Run tests_zig");
2935
run_step.dependOn(&run_cmd.step);
3036
}
3137
}

test1-1-backward.dot

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
digraph G {
2+
newrank = true;
3+
rankdir = LR;
4+
"00000225cc0f7070" [ style = filled; fillcolor = yellow; shape = record; label="node_0 (f32)|0 [1, 1, 1] | <x>none | <g>x+y"; ]
5+
"00000225cc0f73a0" [ style = filled; fillcolor = green; shape = record; label="node_1 (f32)|1 [1, 1, 1] | <x>x*y | <g>x+y"; ]
6+
"00000225cc0f75c0" [ style = filled; fillcolor = green; shape = record; label="node_2 (f32)|2 [1, 1, 1] | <x>x*y | <g>none"; ]
7+
"00000225cc0f77e0" [ style = filled; fillcolor = white; shape = record; label="node_3 (f32)|3 [1, 1, 1] | <x>x*y"; ]
8+
"00000225cc0f7a00" [ style = filled; fillcolor = lightblue; shape = record; label="node_5 (f32)|5 [1, 1, 1] | <x>x*y | <g>none"; ]
9+
"00000225cc0f7c20" [ style = filled; fillcolor = lightblue; shape = record; label="node_6 (f32)|6 [1, 1, 1] | <x>x+y | <g>none"; ]
10+
"00000225cc0f7e40" [ style = filled; fillcolor = lightblue; shape = record; label="node_7 (f32)|7 [1, 1, 1] | <x>x*y | <g>none"; ]
11+
"00000225cc0f7290" [ style = filled; fillcolor = pink; shape = record; label="<x>leaf_0 (f32)|CONST 0 [1, 1] | (3.0e+00)"; ]
12+
"00000225cc0f7180" [ style = filled; fillcolor = pink; shape = record; label="<x>leaf_1 (f32)|CONST 1 [1, 1] | (0.0e+00)"; ]
13+
"00000225cc0f74b0" [ style = filled; fillcolor = pink; shape = record; label="<x>leaf_2 (f32)|CONST 2 [1, 1] | (0.0e+00)"; ]
14+
"00000225cc0f76d0" [ style = filled; fillcolor = pink; shape = record; label="<x>leaf_3 (f32)|CONST 3 [1, 1] | (1.0e+00)"; ]
15+
"00000225cc0f7070":x -> "00000225cc0f73a0":x [ arrowhead = vee; style = solid; label = "x"; ]
16+
"00000225cc0f7070":x -> "00000225cc0f73a0":x [ arrowhead = vee; style = solid; label = "y"; ]
17+
"00000225cc0f73a0":x -> "00000225cc0f75c0":x [ arrowhead = vee; style = solid; label = "x"; ]
18+
"00000225cc0f7290":x -> "00000225cc0f75c0":x [ arrowhead = vee; style = solid; label = "y"; ]
19+
"00000225cc0f7290":x -> "00000225cc0f77e0":x [ arrowhead = vee; style = solid; label = "x"; ]
20+
"00000225cc0f75c0":g -> "00000225cc0f77e0":x [ arrowhead = vee; style = solid; label = "y"; ]
21+
"00000225cc0f74b0":x -> "00000225cc0f73a0":g [ arrowhead = empty; style = dashed; label = "x"; ]
22+
"00000225cc0f77e0":x -> "00000225cc0f73a0":g [ arrowhead = empty; style = dashed; label = "y"; ]
23+
"00000225cc0f7070":x -> "00000225cc0f7a00":x [ arrowhead = vee; style = solid; label = "x"; ]
24+
"00000225cc0f73a0":g -> "00000225cc0f7a00":x [ arrowhead = vee; style = solid; label = "y"; ]
25+
"00000225cc0f7180":x -> "00000225cc0f7c20":x [ arrowhead = vee; style = solid; label = "x"; ]
26+
"00000225cc0f7a00":x -> "00000225cc0f7c20":x [ arrowhead = vee; style = solid; label = "y"; ]
27+
"00000225cc0f7070":x -> "00000225cc0f7e40":x [ arrowhead = vee; style = solid; label = "x"; ]
28+
"00000225cc0f73a0":g -> "00000225cc0f7e40":x [ arrowhead = vee; style = solid; label = "y"; ]
29+
"00000225cc0f7c20":x -> "00000225cc0f7070":g [ arrowhead = empty; style = dashed; label = "x"; ]
30+
"00000225cc0f7e40":x -> "00000225cc0f7070":g [ arrowhead = empty; style = dashed; label = "y"; ]
31+
}

test1-1-forward.dot

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
digraph G {
2+
newrank = true;
3+
rankdir = LR;
4+
"00000225cc0f7070" [ style = filled; fillcolor = yellow; shape = record; label="node_0 (f32)|0 [1, 1, 1] | <x>none | <g>x+y"; ]
5+
"00000225cc0f73a0" [ style = filled; fillcolor = green; shape = record; label="node_1 (f32)|1 [1, 1, 1] | <x>x*y | <g>x+y"; ]
6+
"00000225cc0f75c0" [ style = filled; fillcolor = green; shape = record; label="node_2 (f32)|2 [1, 1, 1] | <x>x*y | <g>none"; ]
7+
"00000225cc0f7290" [ style = filled; fillcolor = pink; shape = record; label="<x>leaf_0 (f32)|CONST 0 [1, 1] | (3.0e+00)"; ]
8+
"00000225cc0f7070":x -> "00000225cc0f73a0":x [ arrowhead = vee; style = solid; label = "x"; ]
9+
"00000225cc0f7070":x -> "00000225cc0f73a0":x [ arrowhead = vee; style = solid; label = "y"; ]
10+
"00000225cc0f73a0":x -> "00000225cc0f75c0":x [ arrowhead = vee; style = solid; label = "x"; ]
11+
"00000225cc0f7290":x -> "00000225cc0f75c0":x [ arrowhead = vee; style = solid; label = "y"; ]
12+
}

test1-2-backward.dot

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
digraph G {
2+
newrank = true;
3+
rankdir = LR;
4+
"00000225cc0f8280" [ style = filled; fillcolor = yellow; shape = record; label="node_0 (f32)|0 [1, 1, 1] | <x>none | <g>x+y"; ]
5+
"00000225cc0f87d0" [ style = filled; fillcolor = green; shape = record; label="node_1 (f32)|1 [1, 1, 1] | <x>x*y | <g>x+y"; ]
6+
"00000225cc0f8390" [ style = filled; fillcolor = yellow; shape = record; label="node_2 (f32)|2 [1, 1, 1] | <x>none | <g>x+y"; ]
7+
"00000225cc0f89f0" [ style = filled; fillcolor = green; shape = record; label="node_3 (f32)|3 [1, 1, 1] | <x>x*y | <g>x+y"; ]
8+
"00000225cc0f8c10" [ style = filled; fillcolor = green; shape = record; label="node_4 (f32)|4 [1, 1, 1] | <x>x+y | <g>none"; ]
9+
"00000225cc0f8f40" [ style = filled; fillcolor = white; shape = record; label="node_5 (f32)|5 [1, 1, 1] | <x>x+y"; ]
10+
"00000225cc0f9490" [ style = filled; fillcolor = lightblue; shape = record; label="node_6 (f32)|6 [1, 1, 1] | <x>x*y | <g>x+y"; ]
11+
"00000225cc0f96b0" [ style = filled; fillcolor = lightblue; shape = record; label="node_7 (f32)|7 [1, 1, 1] | <x>x+y | <g>none"; ]
12+
"00000225cc0f9050" [ style = filled; fillcolor = lightblue; shape = record; label="node_8 (f32)|8 [1, 1, 1] | <x>x*y | <g>x+y"; ]
13+
"00000225cc0f9270" [ style = filled; fillcolor = lightblue; shape = record; label="node_9 (f32)|9 [1, 1, 1] | <x>x+y | <g>x+y"; ]
14+
"00000225cc0f8e30" [ style = filled; fillcolor = white; shape = record; label="node_10 (f32)|10 [1, 1, 1] | <x>x+y"; ]
15+
"00000225cc0f98d0" [ style = filled; fillcolor = lightblue; shape = record; label="node_11 (f32)|11 [1, 1, 1] | <x>x*y | <g>x+y"; ]
16+
"00000225cc0f9af0" [ style = filled; fillcolor = lightblue; shape = record; label="node_12 (f32)|12 [1, 1, 1] | <x>x+y | <g>x+y"; ]
17+
"00000225cc0f9d10" [ style = filled; fillcolor = lightblue; shape = record; label="node_13 (f32)|13 [1, 1, 1] | <x>x*y | <g>x+y"; ]
18+
"00000225cc0f9f30" [ style = filled; fillcolor = lightblue; shape = record; label="node_14 (f32)|14 [1, 1, 1] | <x>x+y | <g>none"; ]
19+
"00000225cc0f86c0" [ style = filled; fillcolor = pink; shape = record; label="<x>leaf_0 (f32)|CONST 0 [1, 1] | (0.0e+00)"; ]
20+
"00000225cc0f8b00" [ style = filled; fillcolor = pink; shape = record; label="<x>leaf_1 (f32)|CONST 1 [1, 1] | (0.0e+00)"; ]
21+
"00000225cc0f8d20" [ style = filled; fillcolor = pink; shape = record; label="<x>leaf_2 (f32)|CONST 2 [1, 1] | (1.0e+00)"; ]
22+
"00000225cc0f85b0" [ style = filled; fillcolor = pink; shape = record; label="<x>leaf_3 (f32)|CONST 3 [1, 1] | (0.0e+00)"; ]
23+
"00000225cc0f88e0" [ style = filled; fillcolor = pink; shape = record; label="<x>leaf_4 (f32)|CONST 4 [1, 1] | (0.0e+00)"; ]
24+
"00000225cc0f8280":x -> "00000225cc0f87d0":x [ arrowhead = vee; style = solid; label = "x"; ]
25+
"00000225cc0f8280":x -> "00000225cc0f87d0":x [ arrowhead = vee; style = solid; label = "y"; ]
26+
"00000225cc0f8280":x -> "00000225cc0f89f0":x [ arrowhead = vee; style = solid; label = "x"; ]
27+
"00000225cc0f8390":x -> "00000225cc0f89f0":x [ arrowhead = vee; style = solid; label = "y"; ]
28+
"00000225cc0f87d0":x -> "00000225cc0f8c10":x [ arrowhead = vee; style = solid; label = "x"; ]
29+
"00000225cc0f89f0":x -> "00000225cc0f8c10":x [ arrowhead = vee; style = solid; label = "y"; ]
30+
"00000225cc0f8b00":x -> "00000225cc0f8f40":x [ arrowhead = vee; style = solid; label = "x"; ]
31+
"00000225cc0f8d20":x -> "00000225cc0f8f40":x [ arrowhead = vee; style = solid; label = "y"; ]
32+
"00000225cc0f8280":x -> "00000225cc0f9490":x [ arrowhead = vee; style = solid; label = "x"; ]
33+
"00000225cc0f8f40":x -> "00000225cc0f9490":x [ arrowhead = vee; style = solid; label = "y"; ]
34+
"00000225cc0f86c0":x -> "00000225cc0f96b0":x [ arrowhead = vee; style = solid; label = "x"; ]
35+
"00000225cc0f9490":x -> "00000225cc0f96b0":x [ arrowhead = vee; style = solid; label = "y"; ]
36+
"00000225cc0f8390":x -> "00000225cc0f9050":x [ arrowhead = vee; style = solid; label = "x"; ]
37+
"00000225cc0f8f40":x -> "00000225cc0f9050":x [ arrowhead = vee; style = solid; label = "y"; ]
38+
"00000225cc0f85b0":x -> "00000225cc0f9270":x [ arrowhead = vee; style = solid; label = "x"; ]
39+
"00000225cc0f9050":x -> "00000225cc0f9270":x [ arrowhead = vee; style = solid; label = "y"; ]
40+
"00000225cc0f88e0":x -> "00000225cc0f8e30":x [ arrowhead = vee; style = solid; label = "x"; ]
41+
"00000225cc0f8d20":x -> "00000225cc0f8e30":x [ arrowhead = vee; style = solid; label = "y"; ]
42+
"00000225cc0f8280":x -> "00000225cc0f98d0":x [ arrowhead = vee; style = solid; label = "x"; ]
43+
"00000225cc0f8e30":x -> "00000225cc0f98d0":x [ arrowhead = vee; style = solid; label = "y"; ]
44+
"00000225cc0f9270":x -> "00000225cc0f9af0":x [ arrowhead = vee; style = solid; label = "x"; ]
45+
"00000225cc0f98d0":x -> "00000225cc0f9af0":x [ arrowhead = vee; style = solid; label = "y"; ]
46+
"00000225cc0f8280":x -> "00000225cc0f9d10":x [ arrowhead = vee; style = solid; label = "x"; ]
47+
"00000225cc0f8e30":x -> "00000225cc0f9d10":x [ arrowhead = vee; style = solid; label = "y"; ]
48+
"00000225cc0f9af0":x -> "00000225cc0f9f30":x [ arrowhead = vee; style = solid; label = "x"; ]
49+
"00000225cc0f9d10":x -> "00000225cc0f9f30":x [ arrowhead = vee; style = solid; label = "y"; ]
50+
}

test1-2-forward.dot

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
digraph G {
2+
newrank = true;
3+
rankdir = LR;
4+
"00000225cc0f8280" [ style = filled; fillcolor = yellow; shape = record; label="node_0 (f32)|0 [1, 1, 1] | <x>none | <g>x+y"; ]
5+
"00000225cc0f87d0" [ style = filled; fillcolor = green; shape = record; label="node_1 (f32)|1 [1, 1, 1] | <x>x*y | <g>x+y"; ]
6+
"00000225cc0f8390" [ style = filled; fillcolor = yellow; shape = record; label="node_2 (f32)|2 [1, 1, 1] | <x>none | <g>x+y"; ]
7+
"00000225cc0f89f0" [ style = filled; fillcolor = green; shape = record; label="node_3 (f32)|3 [1, 1, 1] | <x>x*y | <g>x+y"; ]
8+
"00000225cc0f8c10" [ style = filled; fillcolor = green; shape = record; label="node_4 (f32)|4 [1, 1, 1] | <x>x+y | <g>none"; ]
9+
"00000225cc0f8280":x -> "00000225cc0f87d0":x [ arrowhead = vee; style = solid; label = "x"; ]
10+
"00000225cc0f8280":x -> "00000225cc0f87d0":x [ arrowhead = vee; style = solid; label = "y"; ]
11+
"00000225cc0f8280":x -> "00000225cc0f89f0":x [ arrowhead = vee; style = solid; label = "x"; ]
12+
"00000225cc0f8390":x -> "00000225cc0f89f0":x [ arrowhead = vee; style = solid; label = "y"; ]
13+
"00000225cc0f87d0":x -> "00000225cc0f8c10":x [ arrowhead = vee; style = solid; label = "x"; ]
14+
"00000225cc0f89f0":x -> "00000225cc0f8c10":x [ arrowhead = vee; style = solid; label = "y"; ]
15+
}

tests/test0.zig

Lines changed: 43 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,43 @@
1-
const std = @import("std");
2-
const c = @cImport({
3-
@cInclude("ggml/ggml.h");
4-
@cInclude("stdio.h");
5-
@cInclude("stdlib.h");
6-
});
7-
8-
pub fn main() !void {
9-
const params: c.ggml_init_params = .{
10-
.mem_size = 128*1024*1024,
11-
.mem_buffer = null,
12-
.no_alloc = false,
13-
};
14-
15-
const ctx0 = c.ggml_init(params);
16-
17-
const t1 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 10);
18-
const t2 = c.ggml_new_tensor_2d(ctx0, c.GGML_TYPE_I16, 10, 20);
19-
const t3 = c.ggml_new_tensor_3d(ctx0, c.GGML_TYPE_I32, 10, 20, 30);
20-
21-
try std.testing.expect(t1.*.n_dims == 1);
22-
try std.testing.expect(t1.*.ne[0] == 10);
23-
try std.testing.expect(t1.*.nb[1] == 10*@sizeOf(f32));
24-
25-
try std.testing.expect(t2.*.n_dims == 2);
26-
try std.testing.expect(t2.*.ne[0] == 10);
27-
try std.testing.expect(t2.*.ne[1] == 20);
28-
try std.testing.expect(t2.*.nb[1] == 10*@sizeOf(i16));
29-
try std.testing.expect(t2.*.nb[2] == 10*20*@sizeOf(i16));
30-
31-
try std.testing.expect(t3.*.n_dims == 3);
32-
try std.testing.expect(t3.*.ne[0] == 10);
33-
try std.testing.expect(t3.*.ne[1] == 20);
34-
try std.testing.expect(t3.*.ne[2] == 30);
35-
try std.testing.expect(t3.*.nb[1] == 10*@sizeOf(i32));
36-
try std.testing.expect(t3.*.nb[2] == 10*20*@sizeOf(i32));
37-
try std.testing.expect(t3.*.nb[3] == 10*20*30*@sizeOf(i32));
38-
39-
c.ggml_print_objects(ctx0);
40-
41-
c.ggml_free(ctx0);
42-
43-
_ = try std.io.getStdIn().reader().readByte();
44-
}
1+
const std = @import("std");
2+
const c = @cImport({
3+
@cInclude("ggml/ggml.h");
4+
@cInclude("stdio.h");
5+
@cInclude("stdlib.h");
6+
});
7+
8+
pub fn main() !void {
9+
const params: c.ggml_init_params = .{
10+
.mem_size = 128*1024*1024,
11+
.mem_buffer = null,
12+
.no_alloc = false,
13+
};
14+
15+
const ctx0 = c.ggml_init(params);
16+
defer c.ggml_free(ctx0);
17+
18+
const t1 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 10);
19+
const t2 = c.ggml_new_tensor_2d(ctx0, c.GGML_TYPE_I16, 10, 20);
20+
const t3 = c.ggml_new_tensor_3d(ctx0, c.GGML_TYPE_I32, 10, 20, 30);
21+
22+
try std.testing.expect(t1.*.n_dims == 1);
23+
try std.testing.expect(t1.*.ne[0] == 10);
24+
try std.testing.expect(t1.*.nb[1] == 10*@sizeOf(f32));
25+
26+
try std.testing.expect(t2.*.n_dims == 2);
27+
try std.testing.expect(t2.*.ne[0] == 10);
28+
try std.testing.expect(t2.*.ne[1] == 20);
29+
try std.testing.expect(t2.*.nb[1] == 10*@sizeOf(i16));
30+
try std.testing.expect(t2.*.nb[2] == 10*20*@sizeOf(i16));
31+
32+
try std.testing.expect(t3.*.n_dims == 3);
33+
try std.testing.expect(t3.*.ne[0] == 10);
34+
try std.testing.expect(t3.*.ne[1] == 20);
35+
try std.testing.expect(t3.*.ne[2] == 30);
36+
try std.testing.expect(t3.*.nb[1] == 10*@sizeOf(i32));
37+
try std.testing.expect(t3.*.nb[2] == 10*20*@sizeOf(i32));
38+
try std.testing.expect(t3.*.nb[3] == 10*20*30*@sizeOf(i32));
39+
40+
c.ggml_print_objects(ctx0);
41+
42+
_ = try std.io.getStdIn().reader().readByte();
43+
}

tests/test1.zig

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
const std = @import("std");
2+
const c = @cImport({
3+
@cInclude("ggml/ggml.h");
4+
@cInclude("stdio.h");
5+
@cInclude("stdlib.h");
6+
});
7+
8+
pub fn main() !void {
9+
const params = .{
10+
.mem_size = 128*1024*1024,
11+
.mem_buffer = null,
12+
.no_alloc = false,
13+
};
14+
15+
const ctx0 = c.ggml_init(params);
16+
defer c.ggml_free(ctx0);
17+
18+
{
19+
const x: [*c]c.struct_ggml_tensor = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 1);
20+
21+
c.ggml_set_param(ctx0, x);
22+
23+
const a = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 1);
24+
const b = c.ggml_mul(ctx0, x, x);
25+
const f = c.ggml_mul(ctx0, b, a);
26+
27+
// a*x^2
28+
// 2*a*x
29+
30+
c.ggml_print_objects(ctx0);
31+
32+
const gf = c.ggml_build_forward(f);
33+
const gb = c.ggml_build_backward(ctx0, @constCast(&gf), false);
34+
35+
36+
_ = c.ggml_set_f32(x, 2.0);
37+
_ = c.ggml_set_f32(a, 3.0);
38+
39+
c.ggml_graph_reset(@constCast(&gf));
40+
_ = c.ggml_set_f32(f.*.grad, 1.0);
41+
42+
c.ggml_graph_compute(ctx0, @constCast(&gb));
43+
44+
std.debug.print("f = {d:.6}\n", .{c.ggml_get_f32_1d(f, 0)});
45+
std.debug.print("df/dx = {d:.6}\n", .{c.ggml_get_f32_1d(x.*.grad, 0)});
46+
47+
try std.testing.expect(c.ggml_get_f32_1d(f, 0) == 12.0);
48+
try std.testing.expect(c.ggml_get_f32_1d(x.*.grad, 0) == 12.0);
49+
50+
_ = c.ggml_set_f32(x, 3.0);
51+
52+
c.ggml_graph_reset(@constCast(&gf));
53+
_ = c.ggml_set_f32(f.*.grad, 1.0);
54+
55+
c.ggml_graph_compute(ctx0, @constCast(&gb));
56+
57+
std.debug.print("f = {d:.6}\n", .{c.ggml_get_f32_1d(f, 0)});
58+
std.debug.print("df/dx = {d:.6}\n", .{c.ggml_get_f32_1d(x.*.grad, 0)});
59+
60+
try std.testing.expect(c.ggml_get_f32_1d(f, 0) == 27.0);
61+
try std.testing.expect(c.ggml_get_f32_1d(x.*.grad, 0) == 18.0);
62+
63+
c.ggml_graph_dump_dot(&gf, null, "test1-1-forward.dot");
64+
c.ggml_graph_dump_dot(&gb, &gf, "test1-1-backward.dot");
65+
}
66+
67+
/////////////////////////////////////////////////////////////
68+
69+
{
70+
const x1 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 1);
71+
const x2 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 1);
72+
const x3 = c.ggml_new_tensor_1d(ctx0, c.GGML_TYPE_F32, 1);
73+
74+
_ = c.ggml_set_f32(x1, 3.0);
75+
_ = c.ggml_set_f32(x2, 1.0);
76+
_ = c.ggml_set_f32(x3, 0.0);
77+
78+
c.ggml_set_param(ctx0, x1);
79+
c.ggml_set_param(ctx0, x2);
80+
81+
const y = c.ggml_add(ctx0, c.ggml_mul(ctx0, x1, x1), c.ggml_mul(ctx0, x1, x2));
82+
83+
const gf = c.ggml_build_forward(y);
84+
const gb = c.ggml_build_backward(ctx0, @constCast(&gf), false);
85+
86+
c.ggml_graph_reset(@constCast(&gf));
87+
_ = c.ggml_set_f32(y.*.grad, 1.0);
88+
89+
c.ggml_graph_compute(ctx0, @constCast(&gb));
90+
91+
std.debug.print("y = {d:.6}\n", .{c.ggml_get_f32_1d(y, 0)});
92+
std.debug.print("df/dx1 = {d:.6}\n", .{c.ggml_get_f32_1d(x1.*.grad, 0)});
93+
std.debug.print("df/dx2 = {d:.6}\n", .{c.ggml_get_f32_1d(x2.*.grad, 0)});
94+
95+
try std.testing.expect(c.ggml_get_f32_1d(y, 0) == 12.0);
96+
try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 0) == 7.0);
97+
try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 0) == 3.0);
98+
99+
const g1 = x1.*.grad;
100+
const g2 = x2.*.grad;
101+
102+
const gbb = c.ggml_build_backward(ctx0, @constCast(&gb), true);
103+
104+
c.ggml_graph_reset(@constCast(&gb));
105+
_ = c.ggml_set_f32(g1.*.grad, 1.0);
106+
_ = c.ggml_set_f32(g2.*.grad, 1.0);
107+
108+
c.ggml_graph_compute(ctx0, @constCast(&gbb));
109+
110+
std.debug.print("H * [1, 1] = [ {d:.6} {d:.6} ]\n", .{c.ggml_get_f32_1d(x1.*.grad, 0), c.ggml_get_f32_1d(x2.*.grad, 0)});
111+
112+
try std.testing.expect(c.ggml_get_f32_1d(x1.*.grad, 0) == 3.0);
113+
try std.testing.expect(c.ggml_get_f32_1d(x2.*.grad, 0) == 1.0);
114+
115+
c.ggml_graph_dump_dot(&gf, null, "test1-2-forward.dot");
116+
c.ggml_graph_dump_dot(&gb, &gf, "test1-2-backward.dot");
117+
}
118+
119+
_ = try std.io.getStdIn().reader().readByte();
120+
}

thirdparty/ggml

Submodule ggml updated from bc696b3 to 00c40aa

0 commit comments

Comments
 (0)