-
Notifications
You must be signed in to change notification settings - Fork 61
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implements Chebyshev point and polynomial generation.
Adds additional functionality (and minor bug fixes) to Polynomial data structure to support generation of Chebyshev polynomials. Part of #266 PiperOrigin-RevId: 721820339
- Loading branch information
1 parent
d1361e4
commit 1098b98
Showing
7 changed files
with
348 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
"""Approximation utilities""" | ||
|
||
package( | ||
default_applicable_licenses = ["@heir//:license"], | ||
default_visibility = ["//visibility:public"], | ||
) | ||
|
||
cc_library( | ||
name = "Chebyshev", | ||
srcs = ["Chebyshev.cpp"], | ||
hdrs = ["Chebyshev.h"], | ||
deps = [ | ||
"@heir//lib/Utils/Polynomial", | ||
"@llvm-project//llvm:Support", | ||
"@llvm-project//mlir:Support", | ||
], | ||
) | ||
|
||
cc_test( | ||
name = "ChebyshevTest", | ||
srcs = ["ChebyshevTest.cpp"], | ||
deps = [ | ||
":Chebyshev", | ||
"@googletest//:gtest_main", | ||
"@heir//lib/Utils/Polynomial", | ||
"@llvm-project//llvm:Support", | ||
"@llvm-project//mlir:Support", | ||
], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
|
||
#include <cmath> | ||
#include <cstdint> | ||
|
||
#include "lib/Utils/Polynomial/Polynomial.h" | ||
#include "llvm/include/llvm/ADT/APFloat.h" // from @llvm-project | ||
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project | ||
|
||
namespace mlir { | ||
namespace heir { | ||
namespace approximation { | ||
|
||
using ::llvm::APFloat; | ||
using ::llvm::SmallVector; | ||
using ::mlir::heir::polynomial::FloatPolynomial; | ||
|
||
// When we move to C++20, we can use std::numbers::pi | ||
inline constexpr double kPi = 3.14159265358979323846; | ||
|
||
void getChebyshevPoints(int64_t numPoints, SmallVector<APFloat> &results) { | ||
if (numPoints == 0) { | ||
return; | ||
} | ||
if (numPoints == 1) { | ||
results.push_back(APFloat(0.)); | ||
return; | ||
} | ||
|
||
// The values are most simply described as | ||
// | ||
// cos(pi * j / (n-1)) for 0 <= j <= n-1. | ||
// | ||
// But to enforce symmetry around the origin---broken by slight numerical | ||
// inaccuracies---and the left-to-right ordering, we apply the identity | ||
// | ||
// cos(x + pi) = -cos(x) = sin(x - pi/2) | ||
// | ||
// to arrive at | ||
// | ||
// sin(pi*j/(n-1) - pi/2) = sin(pi * (2j - (n-1)) / (2(n-1))) | ||
// | ||
// An this is equivalent to the formula below, where the range of j is shifted | ||
// and rescaled from {0, ..., n-1} to {-n+1, -n+3, ..., n-3, n-1}. | ||
int64_t m = numPoints - 1; | ||
for (int64_t j = -m; j < m + 1; j += 2) { | ||
results.push_back(APFloat(std::sin(kPi * j / (2 * m)))); | ||
} | ||
} | ||
|
||
void getChebyshevPolynomials(int64_t numPolynomials, | ||
SmallVector<FloatPolynomial> &results) { | ||
if (numPolynomials < 1) return; | ||
|
||
if (numPolynomials >= 1) { | ||
// 1 | ||
results.push_back(FloatPolynomial::fromCoefficients({1.})); | ||
} | ||
if (numPolynomials >= 2) { | ||
// 2x | ||
results.push_back(FloatPolynomial::fromCoefficients({0., 2.})); | ||
} | ||
|
||
if (numPolynomials <= 2) return; | ||
|
||
for (int64_t i = 2; i < numPolynomials; ++i) { | ||
auto &last = results.back(); | ||
auto &secondLast = results[results.size() - 2]; | ||
results.push_back(last.monomialMul(1).scale(APFloat(2.)).sub(secondLast)); | ||
} | ||
} | ||
|
||
} // namespace approximation | ||
} // namespace heir | ||
} // namespace mlir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
#ifndef LIB_UTILS_APPROXIMATION_CHEBYSHEV_H_ | ||
#define LIB_UTILS_APPROXIMATION_CHEBYSHEV_H_ | ||
|
||
#include <cstdint> | ||
|
||
#include "lib/Utils/Polynomial/Polynomial.h" | ||
#include "llvm/include/llvm/ADT/APFloat.h" // from @llvm-project | ||
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project | ||
|
||
namespace mlir { | ||
namespace heir { | ||
namespace approximation { | ||
|
||
/// Generate Chebyshev points of the second kind, storing them in the results | ||
/// outparameter. The output points are ordered left to right on the interval | ||
/// [-1, 1]. | ||
/// | ||
/// This is a port of the chebfun routine at | ||
/// https://github.com/chebfun/chebfun/blob/db207bc9f48278ca4def15bf90591bfa44d0801d/%40chebtech2/chebpts.m#L34 | ||
void getChebyshevPoints(int64_t numPoints, | ||
SmallVector<::llvm::APFloat> &results); | ||
|
||
/// Generate the first `numPolynomials` Chebyshev polynomials of the second | ||
/// kind, storing them in the results outparameter. | ||
/// | ||
/// The first few polynomials are 1, 2x, 4x^2 - 1, 8x^3 - 4x, ... | ||
void getChebyshevPolynomials( | ||
int64_t numPolynomials, | ||
SmallVector<::mlir::heir::polynomial::FloatPolynomial> &results); | ||
|
||
} // namespace approximation | ||
} // namespace heir | ||
} // namespace mlir | ||
|
||
#endif // LIB_UTILS_APPROXIMATION_CHEBYSHEV_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
#include <cstdint> | ||
|
||
#include "gmock/gmock.h" // from @googletest | ||
#include "gtest/gtest.h" // from @googletest | ||
#include "lib/Utils/Approximation/Chebyshev.h" | ||
#include "lib/Utils/Polynomial/Polynomial.h" | ||
#include "llvm/include/llvm/ADT/APFloat.h" // from @llvm-project | ||
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project | ||
|
||
namespace mlir { | ||
namespace heir { | ||
namespace approximation { | ||
namespace { | ||
|
||
using ::llvm::APFloat; | ||
using ::mlir::heir::polynomial::FloatPolynomial; | ||
using ::testing::ElementsAre; | ||
|
||
TEST(ChebyshevTest, TestGetChebyshevPointsSingle) { | ||
SmallVector<APFloat> chebPts; | ||
int64_t n = 1; | ||
getChebyshevPoints(n, chebPts); | ||
EXPECT_THAT(chebPts, ElementsAre(APFloat(0.))); | ||
} | ||
|
||
TEST(ChebyshevTest, TestGetChebyshevPoints5) { | ||
SmallVector<APFloat> chebPts; | ||
int64_t n = 5; | ||
getChebyshevPoints(n, chebPts); | ||
EXPECT_THAT(chebPts, ElementsAre(APFloat(-1.0), APFloat(-0.7071067811865475), | ||
APFloat(0.0), APFloat(0.7071067811865475), | ||
APFloat(1.0))); | ||
} | ||
|
||
TEST(ChebyshevTest, TestGetChebyshevPoints9) { | ||
SmallVector<APFloat> chebPts; | ||
int64_t n = 9; | ||
getChebyshevPoints(n, chebPts); | ||
EXPECT_THAT(chebPts, ElementsAre(APFloat(-1.0), APFloat(-0.9238795325112867), | ||
APFloat(-0.7071067811865475), | ||
APFloat(-0.3826834323650898), APFloat(0.0), | ||
APFloat(0.3826834323650898), | ||
APFloat(0.7071067811865475), | ||
APFloat(0.9238795325112867), APFloat(1.0))); | ||
} | ||
|
||
TEST(ChebyshevTest, TestGetChebyshevPolynomials) { | ||
SmallVector<FloatPolynomial> chebPolys; | ||
int64_t n = 9; | ||
getChebyshevPolynomials(n, chebPolys); | ||
|
||
for (const auto& p : chebPolys) p.dump(); | ||
|
||
EXPECT_THAT( | ||
chebPolys, | ||
ElementsAre( | ||
FloatPolynomial::fromCoefficients({1.}), | ||
FloatPolynomial::fromCoefficients({0., 2.}), | ||
FloatPolynomial::fromCoefficients({-1., 0., 4.}), | ||
FloatPolynomial::fromCoefficients({0., -4., 0., 8.}), | ||
FloatPolynomial::fromCoefficients({1., 0., -12., 0., 16.}), | ||
FloatPolynomial::fromCoefficients({0., 6., 0., -32., 0., 32.}), | ||
FloatPolynomial::fromCoefficients({-1., 0., 24., 0., -80., 0., 64.}), | ||
FloatPolynomial::fromCoefficients( | ||
{0., -8., 0., 80., 0., -192., 0., 128.}), | ||
FloatPolynomial::fromCoefficients( | ||
{1., 0., -40., 0., 240., 0., -448., 0., 256.}))); | ||
} | ||
|
||
} // namespace | ||
} // namespace approximation | ||
} // namespace heir | ||
} // namespace mlir |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.