Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 48 additions & 24 deletions source/source_estate/module_pot/pot_ml_exx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@ ML_EXX::ML_EXX()

ML_EXX::~ML_EXX(){}

void ML_EXX::set_para(const Input_para& inp, const UnitCell* ucell_in, const ModulePW::PW_Basis* rho_basis_in)
void ML_EXX::set_para(const Input_para& inp, const UnitCell* ucell_in, const ModulePW::PW_Basis* rho_basis_in, std::ostream& ofs_running)
{
torch::set_default_dtype(caffe2::TypeMeta::fromScalarType(torch::kDouble));
auto output = torch::get_default_dtype();
std::cout << "Default type: " << output << std::endl;
ofs_running << " Default type: " << output << std::endl;
Comment thread
mohanchen marked this conversation as resolved.

this->set_device(inp.of_ml_device);
this->set_device(inp.of_ml_device, ofs_running);

this->nx = rho_basis_in->nrxx;
this->nx_tot = rho_basis_in->nrxx;
Expand All @@ -48,15 +48,23 @@ void ML_EXX::set_para(const Input_para& inp, const UnitCell* ucell_in, const Mod
inp.of_ml_tanhp_nl,
inp.of_ml_tanhq_nl);

std::cout << "ninput = " << this->ninput << std::endl;
ofs_running << "ninput = " << this->ninput << std::endl;

if (PARAM.inp.ml_exx)
{
int nnode = 100;
int nlayer = 3;
this->nn = std::make_shared<NN_OFImpl>(this->nx, 0, this->ninput, nnode, nlayer, this->device);
torch::load(this->nn, "net.pt", this->device_type);
std::cout << "load net done" << std::endl;
this->nn = std::make_shared<NN_OFImpl>(this->nx, 0, this->ninput, nnode, nlayer, this->device, ofs_running);
try
{
torch::load(this->nn, "net.pt", this->device_type);
}
catch (const std::exception& e)
{
ModuleBase::WARNING_QUIT("ML_EXX::set_para",
"Failed to load neural network model from net.pt: " + std::string(e.what()));
}
ofs_running << "load net done (ML EXX neural network functional model loaded successfully)" << std::endl;
if (PARAM.inp.of_ml_feg != 0)
{
torch::Tensor feg_inpt = torch::zeros(this->ninput, this->device_type);
Expand All @@ -74,7 +82,7 @@ void ML_EXX::set_para(const Input_para& inp, const UnitCell* ucell_in, const Mod
this->feg_net_F = this->nn->forward(feg_inpt).to(this->device_CPU).contiguous().data_ptr<double>()[0];
}

std::cout << "feg_net_F = " << this->feg_net_F << std::endl;
ofs_running << "feg_net_F = " << this->feg_net_F << std::endl;
}
}

Expand All @@ -88,8 +96,24 @@ void ML_EXX::set_para(const Input_para& inp, const UnitCell* ucell_in, const Mod
this->chi_pnl = inp.of_ml_chi_pnl;
this->chi_qnl = inp.of_ml_chi_qnl;

this->cal_tool->set_para(this->nx, inp.nelec, inp.of_tf_weight, inp.of_vw_weight, this->chi_p, this->chi_q,
this->chi_xi, this->chi_pnl, this->chi_qnl, this->nkernel, inp.of_ml_kernel, inp.of_ml_kernel_scaling, inp.of_ml_yukawa_alpha, inp.of_ml_kernel_file, this->dV * rho_basis_in->nxyz, rho_basis_in);
this->cal_tool->set_para(
this->nx,
inp.nelec,
inp.of_tf_weight,
inp.of_vw_weight,
this->chi_p,
this->chi_q,
this->chi_xi,
this->chi_pnl,
this->chi_qnl,
this->nkernel,
inp.of_ml_kernel,
inp.of_ml_kernel_scaling,
inp.of_ml_yukawa_alpha,
inp.of_ml_kernel_file,
this->dV * rho_basis_in->nxyz,
rho_basis_in,
ofs_running);
}
}

Expand All @@ -112,9 +136,9 @@ void ML_EXX::ml_potential(const double * const * prho, const ModulePW::PW_Basis
rho_data[ir] = std::abs(prho[0][ir]);
}

this->updateInput(prho_mod, pw_rho);
this->update_input(prho_mod, pw_rho);

this->NN_forward(prho_mod, pw_rho, true);
this->nn_forward(prho_mod, pw_rho, true);

torch::Tensor enhancement_cpu_tensor = this->nn->F.to(this->device_CPU).contiguous();
this->enhancement_cpu_ptr = enhancement_cpu_tensor.data_ptr<double>();
Expand Down Expand Up @@ -148,13 +172,13 @@ void ML_EXX::ml_potential(const double * const * prho, const ModulePW::PW_Basis
* @param pw_rho PW_Basis
* @param veff effective potential
*/
void ML_EXX::generateTrainData(const double * const *prho, const ModulePW::PW_Basis *pw_rho, const double *veff)
void ML_EXX::gen_training_data(const double * const *prho, const ModulePW::PW_Basis *pw_rho, const double *veff)
{
if (PARAM.inp.of_kinetic == "ml")
{
this->updateInput(prho, pw_rho);
this->update_input(prho, pw_rho);

this->NN_forward(prho, pw_rho, true);
this->nn_forward(prho, pw_rho, true);

torch::Tensor enhancement_cpu_tensor = this->nn->F.to(this->device_CPU).contiguous();
this->enhancement_cpu_ptr = enhancement_cpu_tensor.data_ptr<double>();
Expand All @@ -166,8 +190,8 @@ void ML_EXX::generateTrainData(const double * const *prho, const ModulePW::PW_Ba

this->get_potential_(prho, pw_rho, potential);

this->dumpTensor("enhancement.npy", enhancement);
this->dumpMatrix("potential.npy", potential);
this->dump_tensor("enhancement.npy", enhancement);
this->dump_matrix("potential.npy", potential);
}
}

Expand All @@ -177,28 +201,28 @@ void ML_EXX::generateTrainData(const double * const *prho, const ModulePW::PW_Ba
* @param prho charge density
* @param pw_rho PW_Basis
*/
void ML_EXX::localTest(const double * const *pprho, const ModulePW::PW_Basis *pw_rho)
void ML_EXX::localTest(const double * const *pprho, const ModulePW::PW_Basis *pw_rho, std::ostream& ofs_running)
{
// for test =====================
std::vector<long unsigned int> cshape = {(long unsigned) this->nx};
bool fortran_order = false;

std::vector<double> temp_prho(this->nx);
this->loadVector("path_to_rho_file", temp_prho);
this->load_vector("path_to_rho_file", temp_prho);

double ** prho = new double *[1];
prho[0] = new double[this->nx];
for (int ir = 0; ir < this->nx; ++ir) prho[0][ir] = temp_prho[ir];
for (int ir = 0; ir < this->nx; ++ir)
{
if (prho[0][ir] == 0.){
std::cout << "WARNING: rho = 0" << std::endl;
ofs_running << "WARNING: rho = 0" << std::endl;
}
};
// ==============================
this->updateInput(prho, pw_rho);
this->update_input(prho, pw_rho);

this->NN_forward(prho, pw_rho, true);
this->nn_forward(prho, pw_rho, true);

torch::Tensor enhancement_cpu_tensor = this->nn->F.to(this->device_CPU).contiguous();
this->enhancement_cpu_ptr = enhancement_cpu_tensor.data_ptr<double>();
Expand All @@ -210,8 +234,8 @@ void ML_EXX::localTest(const double * const *pprho, const ModulePW::PW_Basis *pw

this->get_potential_(prho, pw_rho, potential);

this->dumpTensor("enhancement-abacus.npy", enhancement);
this->dumpMatrix("potential-abacus.npy", potential);
this->dump_tensor("enhancement-abacus.npy", enhancement);
this->dump_matrix("potential-abacus.npy", potential);
exit(0);
}

Expand Down
11 changes: 5 additions & 6 deletions source/source_estate/module_pot/pot_ml_exx.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,12 @@ class ML_EXX : public ML_Base
ML_EXX();
virtual ~ML_EXX();

void set_para(const Input_para& inp, const UnitCell* ucell_in, const ModulePW::PW_Basis* rho_basis_in);
void set_para(const Input_para& inp, const UnitCell* ucell_in, const ModulePW::PW_Basis* rho_basis_in, std::ostream& ofs_running);

void ml_potential(const double * const * prho, const ModulePW::PW_Basis *pw_rho, ModuleBase::matrix &rpotential);

// output all parameters
void generateTrainData(const double * const *prho, const ModulePW::PW_Basis *pw_rho, const double *veff);
void localTest(const double * const *prho, const ModulePW::PW_Basis *pw_rho);
void gen_training_data(const double * const *prho, const ModulePW::PW_Basis *pw_rho, const double *veff);
void localTest(const double * const *prho, const ModulePW::PW_Basis *pw_rho, std::ostream& ofs_running);

void init_data(
const int &nkernel,
Expand Down Expand Up @@ -56,13 +55,13 @@ class PotML_EXX : public PotBase
this->dynamic_mode = true;
this->fixed_mode = false;

this->ml_exx.set_para(PARAM.inp, ucell_in, rho_basis_in);
this->ml_exx.set_para(PARAM.inp, ucell_in, rho_basis_in, GlobalV::ofs_running);
}
~PotML_EXX() {};

void cal_v_eff(const Charge*const chg, const UnitCell*const ucell, ModuleBase::matrix& v_eff) override
{
if (PARAM.inp.of_ml_local_test) this->ml_exx.localTest(chg->rho, this->rho_basis_);
if (PARAM.inp.of_ml_local_test) this->ml_exx.localTest(chg->rho, this->rho_basis_, GlobalV::ofs_running);
this->ml_exx.ml_potential(chg->rho, this->rho_basis_, v_eff);
}

Expand Down
3 changes: 2 additions & 1 deletion source/source_io/module_ctrl/ctrl_output_pw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,8 @@ void ModuleIO::ctrl_runner_pw(UnitCell& ucell,
inp.of_ml_yukawa_alpha,
inp.of_ml_kernel_file,
ucell.omega,
pw_rho);
pw_rho,
GlobalV::ofs_running);

write_mlkedf_desc.generateTrainData_KS(PARAM.globalv.global_mlkedf_descriptor_dir,
stp.template get_psi_t<T, Device>(),
Expand Down
4 changes: 2 additions & 2 deletions source/source_io/module_ml/cal_mlkedf_descriptors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ void Cal_MLKEDF_Descriptors::set_para(
const std::vector<double> &yukawa_alpha,
const std::vector<std::string> &kernel_file,
const double &omega,
const ModulePW::PW_Basis *pw_rho
const ModulePW::PW_Basis *pw_rho,
std::ostream& ofs_running
)
{
this->nx = nx;
Expand All @@ -34,7 +35,6 @@ void Cal_MLKEDF_Descriptors::set_para(
this->kernel_scaling = kernel_scaling;
this->yukawa_alpha = yukawa_alpha;
this->kernel_file = kernel_file;
std::cout << "nkernel = " << nkernel << std::endl;

if (PARAM.inp.of_wt_rho0 != 0)
{
Expand Down
3 changes: 2 additions & 1 deletion source/source_io/module_ml/cal_mlkedf_descriptors.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ class Cal_MLKEDF_Descriptors
const std::vector<double> &yukawa_alpha,
const std::vector<std::string> &kernel_file,
const double &omega,
const ModulePW::PW_Basis *pw_rho);
const ModulePW::PW_Basis *pw_rho,
std::ostream& ofs_running);
// get input parameters
void getGamma(const double * const *prho, std::vector<double> &rgamma);
void getP(const double * const *prho, const ModulePW::PW_Basis *pw_rho, std::vector<std::vector<double>> &pnablaRho, std::vector<double> &rp);
Expand Down
1 change: 1 addition & 0 deletions source/source_pw/module_ofdft/kedf_extwt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,7 @@ void KEDF_ExtWT::tau_extwt(const double* const* prho, ModulePW::PW_Basis* pw_rho
*/
void KEDF_ExtWT::extwt_potential(const double* const* prho, ModulePW::PW_Basis* pw_rho, ModuleBase::matrix& rpotential)
{
ModuleBase::TITLE("KEDF_ExtWT", "extwt_potential");
ModuleBase::timer::start("KEDF_ExtWT", "extwt_potential");

// 1. WT potential
Expand Down
5 changes: 3 additions & 2 deletions source/source_pw/module_ofdft/kedf_lkt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,8 @@ void KEDF_LKT::tau_lkt(const double* const* prho, ModulePW::PW_Basis* pw_rho, do
*/
void KEDF_LKT::lkt_potential(const double* const* prho, ModulePW::PW_Basis* pw_rho, ModuleBase::matrix& rpotential)
{
ModuleBase::timer::start("KEDF_LKT", "LKT_potential");
ModuleBase::TITLE("KEDF_LKT", "lkt_potential");
ModuleBase::timer::start("KEDF_LKT", "lkt_potential");
this->lkt_energy = 0.;
double* as = new double[pw_rho->nrxx]; // a*s
double** nabla_rho = new double*[3];
Expand Down Expand Up @@ -193,7 +194,7 @@ void KEDF_LKT::lkt_potential(const double* const* prho, ModulePW::PW_Basis* pw_r
delete[] nabla_rho;
delete[] nabla_term;

ModuleBase::timer::end("KEDF_LKT", "LKT_potential");
ModuleBase::timer::end("KEDF_LKT", "lkt_potential");
}

/**
Expand Down
Loading
Loading