SG++-Doxygen-Documentation
test_sgdeLaplace.py

This example can be found under datadriven/examples/test_sgdeLaplace.py.

1 # -------------------------------------------------------------------------------
2 # DataDist tests
3 # -------------------------------------------------------------------------------
4 import matplotlib.pyplot as plt
5 import numpy as np
6 import json
7 from scipy.integrate.quadpack import dblquad
8 from pysgpp import createOperationDensityMarginalize, \
9  createOperationLTwoDotExplicit, createOperationQuadrature, \
10  createOperationMakePositive, DataVector, Grid, \
11  BandwidthOptimizationType_SILVERMANSRULE, \
12  KernelType_GAUSSIAN
13 import pysgpp.extensions.datadriven.uq.dists as dists
14 from pysgpp.extensions.datadriven.uq.dists import J, Normal, Uniform, SGDEdist, Lognormal
15 from pysgpp.extensions.datadriven.uq.plot import plotDensity2d
16 from pysgpp.extensions.datadriven.uq.plot.plot1d import plotDensity1d, plotSG1d
17 from pysgpp.extensions.datadriven.uq.plot.plot2d import plotSGDE2d, plotSG2d
18 from pysgpp.extensions.datadriven.uq.plot.plot3d import plotDensity3d, plotSG3d
19 from pysgpp.extensions.datadriven.uq.operations.sparse_grid import hierarchize
20 from pysgpp.extensions.datadriven.uq.transformation.JointTransformation import JointTransformation
21 from pysgpp.extensions.datadriven.uq.transformation.LinearTransformation import LinearTransformation
22 from pysgpp.extensions.datadriven.uq.dists.MultivariateNormal import MultivariateNormal
23 from pysgpp.extensions.datadriven.uq.dists.KDEDist import KDEDist
24 from pysgpp.extensions.datadriven.uq.quadrature.sparse_grid import doQuadrature
25 from pysgpp.extensions.datadriven.uq.dists.Dist import Dist
26 
27 
28 
def test_sgdeLaplace():
    """Compare Laplace-regularized SGDE with a Gaussian KDE by L2 error.

    For each training-set size in a fixed list, samples are drawn from the
    joint distribution ``dists.J`` of two ``Lognormal.by_alpha(0.5, 0.1,
    0.001)`` marginals.  For every grid type, a sparse grid density
    estimate is learned from these samples via
    ``SGDEdist.byLearnerSGDEConfig`` (level 1, no refinement steps, Laplace
    regularization, 4-fold cross validation of lambda), and its L2 error
    with respect to the reference distribution is recorded together with
    the grid size.  A ``KDEDist`` with Gaussian kernel and Silverman's rule
    is fitted to the same samples as a baseline.

    The recorded errors are plotted against the training-set size on a
    logarithmic y-axis and the figure is shown with ``plt.show()``.  The
    function takes no arguments and returns nothing.
    """
    # Number of evaluation samples passed to l2error (keyword ``n``).
    l2_samples = 10000
    sample_range = [10, 20, 50, 100, 200, 500]
    grids = [
        "linear",
        "modlinear",  # no OperationQuadrature available
        "poly",
        "modpoly",
        "polyBoundary",
        "polyClenshawCurtis",
        "modPolyClenshawCurtis",
        "polyClenshawCurtisBoundary",
        "bsplineClenshawCurtis",
        "modBsplineClenshawCurtis",  # no OperationMultipleEval available
    ]

    # Reference distribution the estimates are compared against.
    U = dists.J([dists.Lognormal.by_alpha(0.5, 0.1, 0.001),
                 dists.Lognormal.by_alpha(0.5, 0.1, 0.001)])

    # One result list per grid type, plus one for the KDE baseline.
    points = {grid: [] for grid in grids}
    l2_errors = {grid: [] for grid in grids}
    l2_errors["kde"] = []

    for samples in sample_range:
        trainSamples = U.rvs(samples)

        for grid_name in grids:
            print("--------------------Samples: {} Grid: {}--------------------".format(samples, grid_name))
            dist_sgde = SGDEdist.byLearnerSGDEConfig(
                trainSamples,
                bounds=U.getBounds(),
                unitIntegrand=True,
                config={"grid_level": 1,
                        "grid_type": grid_name,
                        "grid_maxDegree": 6,
                        "refinement_numSteps": 0,
                        "refinement_numPoints": 10,
                        "solver_threshold": 1e-10,
                        "solver_verbose": False,
                        "regularization_type": "Laplace",
                        "crossValidation_lambda": 1e-6,
                        "crossValidation_enable": True,
                        "crossValidation_kfold": 4,
                        "crossValidation_lambdaSteps": 10,
                        "crossValidation_silent": False})
            points[grid_name].append(dist_sgde.grid.getSize())
            l2_errors[grid_name].append(dist_sgde.l2error(U, n=l2_samples))

        # KDE baseline: fitted once per sample size so that l2_errors["kde"]
        # has exactly one entry per element of sample_range.
        dist_kde = dists.KDEDist(
            trainSamples,
            kernelType=KernelType_GAUSSIAN,
            bandwidthOptimizationType=BandwidthOptimizationType_SILVERMANSRULE)
        # Evaluate the KDE error the same way as the SGDE error above.  The
        # previous call passed an undefined ``testSamples`` variable, which
        # raised a NameError.
        l2_errors["kde"].append(dist_kde.l2error(U, n=l2_samples))

    for grid_name in grids:
        plt.plot(sample_range, l2_errors[grid_name], label=grid_name)
    plt.plot(sample_range, l2_errors["kde"], label="KDE")

    # NOTE(review): the x-axis label says "grid points", but the x values
    # plotted above are training-sample counts -- confirm the intended label.
    plt.xlabel("# Gitterpunkte")
    plt.ylabel("L2-Fehler")
    plt.yscale("log")
    plt.legend()
    plt.show()


# Run the experiment as soon as the module is executed.
test_sgdeLaplace()