Program Listing for File transformations.hpp

Return to documentation for file (src/integrals/transformations.hpp)

//
// Created by Zack Williams on 23/02/2024.
//

#ifndef TRANSFORMATIONS_HPP
#define TRANSFORMATIONS_HPP

#include <cassert>
#include <cstddef>
#include <stdexcept>

#include "utils/linalg.hpp"
#include "utils/parallel.hpp"
#include "utils/utils.hpp"

namespace uw12::integrals::transformations {
/**
 * Transform one AO index of packed three-index quantities with C.
 *
 * Each column of J3 stores an n_ao x n_ao matrix in packed form
 * (n_ao * (n_ao + 1) / 2 entries). utils::square expands that column, the
 * expanded matrix is multiplied on the right by C, and the n_ao x n_orb
 * product is written (through vectorise) into the same column of the
 * output.
 *
 * @param J3 packed quantities: n_ao * (n_ao + 1) / 2 rows, one column per
 *           index A (n_df columns)
 * @param C  coefficient matrix with n_ao rows and n_orb columns
 * @return matrix with n_ao * n_orb rows and n_df columns
 * @throws std::runtime_error if n_rows(J3) != n_ao * (n_ao + 1) / 2
 */
inline linalg::Mat mo_transform_one_index_full(
    const linalg::Mat &J3, const linalg::Mat &C
) {
  using namespace linalg;

  // The packed row count J3 must have if it matches the AO dimension of C.
  const auto ao_count = n_rows(C);
  const auto packed_rows = ao_count * (ao_count + 1) / 2;
  if (n_rows(J3) != packed_rows) {
    throw std::runtime_error("J3 and C of incompatible sizes");
  }

  const auto df_count = n_cols(J3);
  Mat transformed(ao_count * n_cols(C), df_count);

  // One task per column A: expand, multiply by C, and store as column A.
  const auto fill_column = [&](const size_t A) {
    assign_cols(transformed, vectorise(utils::square(col(J3, A)) * C), A);
  };

  parallel::parallel_for(0, df_count, fill_column);

  return transformed;
}

/**
 * Contract the first index of a three-index quantity with C.
 *
 * Each column A of J3 is viewed, without copying, as an n1 x n2 matrix.
 * Here n1 = n_rows(C) and n2 = n_rows(J3) / n1. That view is replaced by
 * vectorise(C^T * view), which has n_cols(C) * n2 elements.
 *
 * @param J3 matrix with n_rows(C) * n2 rows and one column per index A
 * @param C  coefficient matrix whose rows match the index being transformed
 * @return matrix of size (n_cols(C) * n2) x n_cols(J3)
 * @throws std::runtime_error if C has no rows, or if n_rows(J3) is not a
 *         multiple of n_rows(C)
 */
inline linalg::Mat transform_first_index(
    const linalg::Mat &J3, const linalg::Mat &C
) {
  using namespace linalg;

  const auto n1 = n_rows(C);
  const auto n3 = n_cols(J3);
  const auto n4 = n_cols(C);
  // Must be checked before the modulo below: `x % 0` is undefined behaviour.
  if (n1 == 0) {
    throw std::runtime_error("coefficient matrix C has no rows");
  }
  if (n_rows(J3) % n1 != 0) {
    throw std::runtime_error("number of rows of J3 not a multiple of n_ao");
  }
  const size_t n2 = n_rows(J3) / n1;

  // Transpose once here rather than inside every parallel task.
  const Mat C_t = transpose(C);

  Mat result(n4 * n2, n3);

  const auto parallel_fn = [&result, &J3, &C_t, n1, n2](const size_t A) {
    // reshaping the column of J3 without memory copy
    const auto A12 = reshape_col(J3, A, n1, n2);

    assign_cols(result, vectorise(C_t * A12), A);
  };

  parallel::parallel_for(0, n3, parallel_fn);

  return result;
}

/**
 * Contract the second index of a three-index quantity with C.
 *
 * Each column A of J3 is viewed as an n1 x n2 matrix. Here
 * n2 = n_rows(C) and n1 = n_rows(J3) / n2. That view is replaced by
 * vectorise(view * C), which has n1 * n_cols(C) elements.
 *
 * @param J3 matrix with n1 * n_rows(C) rows and one column per index A
 * @param C  coefficient matrix whose rows match the index being transformed
 * @return matrix of size (n1 * n_cols(C)) x n_cols(J3)
 * @throws std::runtime_error if C has no rows, or if n_rows(J3) is not a
 *         multiple of n_rows(C)
 */
inline linalg::Mat transform_second_index(
    const linalg::Mat &J3, const linalg::Mat &C
) {
  using namespace linalg;

  const auto n2 = n_rows(C);
  const auto n3 = n_cols(J3);
  const auto n4 = n_cols(C);
  // Must be checked before the modulo below: `x % 0` is undefined behaviour.
  if (n2 == 0) {
    throw std::runtime_error("coefficient matrix C has no rows");
  }
  if (n_rows(J3) % n2 != 0) {
    throw std::runtime_error(
        "number of rows J3 is not a multiple of number of orbitals being "
        "transformed"
    );
  }
  const size_t n1 = n_rows(J3) / n2;

  Mat result(n1 * n4, n3);

  const auto parallel_fn = [&result, &J3, &C, n1, n2](const size_t A) {
    // View column A of J3 as an n1 x n2 matrix.
    const auto col_mat = reshape_col(J3, A, n1, n2);

    assign_cols(result, vectorise(col_mat * C), A);
  };

  parallel::parallel_for(0, n3, parallel_fn);

  return result;
}

/**
 * Transform both AO indices of packed three-index quantities.
 *
 * mo_transform_one_index_full first applies C_right. Then
 * transform_first_index applies C_left to the remaining AO index of the
 * intermediate.
 *
 * @param J3      packed quantities, in the form accepted by
 *                mo_transform_one_index_full
 * @param C_left  coefficients applied by transform_first_index
 * @param C_right coefficients applied by mo_transform_one_index_full
 * @return matrix with n_cols(C_left) * n_cols(C_right) rows and n_cols(J3)
 *         columns
 * @throws std::runtime_error if C_left and C_right have different row
 *         counts, or if either underlying transformation rejects its inputs
 */
inline linalg::Mat mo_transform_two_index_full(
    const linalg::Mat &J3, const linalg::Mat &C_left, const linalg::Mat &C_right
) {
  // Both coefficient matrices must span the same AO dimension.
  const bool same_ao_count =
      linalg::n_rows(C_left) == linalg::n_rows(C_right);
  if (!same_ao_count) {
    throw std::runtime_error(
        "Coefficient matrices with different numbers of ao functions"
    );
  }

  // Intermediate: n_rows(C_right) * n_cols(C_right) rows, one column per A.
  const auto half_transformed = mo_transform_one_index_full(J3, C_right);
  return transform_first_index(half_transformed, C_left);
}
}  // namespace uw12::integrals::transformations

#endif  // TRANSFORMATIONS_HPP