diff --git a/mlir/docs/EDSC.md b/mlir/docs/EDSC.md --- a/mlir/docs/EDSC.md +++ b/mlir/docs/EDSC.md @@ -21,7 +21,7 @@ operators are provided to allow concise and idiomatic expressions. ```c++ -ValueHandle zero = constant_index(0); +ValueHandle zero = std_constant_index(0); IndexHandle i, j, k; ``` @@ -49,8 +49,8 @@ j(indexType), lb(f->getArgument(0)), ub(f->getArgument(1)); - ValueHandle f7(constant_float(llvm::APFloat(7.0f), f32Type)), - f13(constant_float(llvm::APFloat(13.0f), f32Type)), + ValueHandle f7(std_constant_float(llvm::APFloat(7.0f), f32Type)), + f13(std_constant_float(llvm::APFloat(13.0f), f32Type)), i7(constant_int(7, 32)), i13(constant_int(13, 32)); AffineLoopNestBuilder(&i, lb, ub, 3)([&]{ diff --git a/mlir/include/mlir/Dialect/AffineOps/EDSC/Builders.h b/mlir/include/mlir/Dialect/AffineOps/EDSC/Builders.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/AffineOps/EDSC/Builders.h @@ -0,0 +1,141 @@ +//===- Builders.h - MLIR Declarative Builder Classes ------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Provides intuitive composable interfaces for building structured MLIR +// snippets in a declarative fashion. +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_AFFINEOPS_EDSC_BUILDERS_H_ +#define MLIR_DIALECT_AFFINEOPS_EDSC_BUILDERS_H_ + +#include "mlir/Dialect/AffineOps/AffineOps.h" +#include "mlir/EDSC/Builders.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Types.h" + +namespace mlir { +namespace edsc { + +/// Constructs a new AffineForOp and captures the associated induction +/// variable. 
A ValueHandle pointer is passed as the first argument and is the +/// *only* way to capture the loop induction variable. +LoopBuilder makeAffineLoopBuilder(ValueHandle *iv, + ArrayRef lbHandles, + ArrayRef ubHandles, + int64_t step); + +/// Explicit nested LoopBuilder. Offers a compressed multi-loop builder to avoid +/// explicitly writing all the loops in a nest. This simple functionality is +/// also useful to write rank-agnostic custom ops. +/// +/// Usage: +/// +/// ```c++ +/// AffineLoopNestBuilder({&i, &j, &k}, {lb, lb, lb}, {ub, ub, ub}, {1, 1, +/// 1})( +/// [&](){ +/// ... +/// }); +/// ``` +/// +/// ```c++ +/// AffineLoopNestBuilder({&i}, {lb}, {ub}, {1})([&](){ +/// AffineLoopNestBuilder({&j}, {lb}, {ub}, {1})([&](){ +/// AffineLoopNestBuilder({&k}, {lb}, {ub}, {1})([&](){ +/// ... +/// }), +/// }), +/// }); +/// ``` +class AffineLoopNestBuilder { +public: + /// This entry point accommodates the fact that AffineForOp implicitly uses + /// multiple `lbs` and `ubs` with one single `iv` and `step` to encode `max` + /// and `min` constraints respectively. 
+ AffineLoopNestBuilder(ValueHandle *iv, ArrayRef lbs, + ArrayRef ubs, int64_t step); + AffineLoopNestBuilder(ArrayRef ivs, ArrayRef lbs, + ArrayRef ubs, ArrayRef steps); + + void operator()(function_ref fun = nullptr); + +private: + SmallVector loops; +}; + +namespace op { + +ValueHandle operator+(ValueHandle lhs, ValueHandle rhs); +ValueHandle operator-(ValueHandle lhs, ValueHandle rhs); +ValueHandle operator*(ValueHandle lhs, ValueHandle rhs); +ValueHandle operator/(ValueHandle lhs, ValueHandle rhs); +ValueHandle operator%(ValueHandle lhs, ValueHandle rhs); +ValueHandle floorDiv(ValueHandle lhs, ValueHandle rhs); +ValueHandle ceilDiv(ValueHandle lhs, ValueHandle rhs); + +ValueHandle operator!(ValueHandle value); +ValueHandle operator&&(ValueHandle lhs, ValueHandle rhs); +ValueHandle operator||(ValueHandle lhs, ValueHandle rhs); +ValueHandle operator^(ValueHandle lhs, ValueHandle rhs); +ValueHandle operator==(ValueHandle lhs, ValueHandle rhs); +ValueHandle operator!=(ValueHandle lhs, ValueHandle rhs); +ValueHandle operator<(ValueHandle lhs, ValueHandle rhs); +ValueHandle operator<=(ValueHandle lhs, ValueHandle rhs); +ValueHandle operator>(ValueHandle lhs, ValueHandle rhs); +ValueHandle operator>=(ValueHandle lhs, ValueHandle rhs); + +} // namespace op + +/// Operator overloadings. 
+template +ValueHandle TemplatedIndexedValue::operator+(ValueHandle e) { + using op::operator+; + return static_cast(*this) + e; +} +template +ValueHandle TemplatedIndexedValue::operator-(ValueHandle e) { + using op::operator-; + return static_cast(*this) - e; +} +template +ValueHandle TemplatedIndexedValue::operator*(ValueHandle e) { + using op::operator*; + return static_cast(*this) * e; +} +template +ValueHandle TemplatedIndexedValue::operator/(ValueHandle e) { + using op::operator/; + return static_cast(*this) / e; +} + +template +OperationHandle TemplatedIndexedValue::operator+=(ValueHandle e) { + using op::operator+; + return Store(*this + e, getBase(), {indices.begin(), indices.end()}); +} +template +OperationHandle TemplatedIndexedValue::operator-=(ValueHandle e) { + using op::operator-; + return Store(*this - e, getBase(), {indices.begin(), indices.end()}); +} +template +OperationHandle TemplatedIndexedValue::operator*=(ValueHandle e) { + using op::operator*; + return Store(*this * e, getBase(), {indices.begin(), indices.end()}); +} +template +OperationHandle TemplatedIndexedValue::operator/=(ValueHandle e) { + using op::operator/; + return Store(*this / e, getBase(), {indices.begin(), indices.end()}); +} + +} // namespace edsc +} // namespace mlir + +#endif // MLIR_DIALECT_AFFINEOPS_EDSC_BUILDERS_H_ diff --git a/mlir/include/mlir/Dialect/AffineOps/EDSC/Intrinsics.h b/mlir/include/mlir/Dialect/AffineOps/EDSC/Intrinsics.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/AffineOps/EDSC/Intrinsics.h @@ -0,0 +1,32 @@ +//===- Intrinsics.h - MLIR EDSC Intrinsics for AffineOps --------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_DIALECT_AFFINEOPS_EDSC_INTRINSICS_H_ +#define MLIR_DIALECT_AFFINEOPS_EDSC_INTRINSICS_H_ + +#include "mlir/Dialect/AffineOps/EDSC/Builders.h" +#include "mlir/EDSC/Intrinsics.h" + +namespace mlir { +namespace edsc { +namespace intrinsics { + +using affine_apply = ValueBuilder; +using affine_if = OperationBuilder; +using affine_load = ValueBuilder; +using affine_min = ValueBuilder; +using affine_max = ValueBuilder; +using affine_store = OperationBuilder; + +/// Provide an index notation around affine_load and affine_store. +using AffineIndexedValue = TemplatedIndexedValue; + +} // namespace intrinsics +} // namespace edsc +} // namespace mlir + +#endif // MLIR_DIALECT_AFFINEOPS_EDSC_INTRINSICS_H_ diff --git a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h --- a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h +++ b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h @@ -13,17 +13,24 @@ #ifndef MLIR_DIALECT_LINALG_EDSC_BUILDERS_H_ #define MLIR_DIALECT_LINALG_EDSC_BUILDERS_H_ -#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h" +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +// TODO(ntv): Needed for SubViewOp::Range, clean this up. +#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" #include "mlir/EDSC/Builders.h" -#include "mlir/EDSC/Intrinsics.h" -#include "mlir/IR/AffineExpr.h" -#include "mlir/IR/Builders.h" namespace mlir { +class AffineForOp; class BlockArgument; +class SubViewOp; + +namespace loop { +class ParallelOp; +} // namespace loop namespace edsc { +class AffineLoopNestBuilder; +class ParallelLoopNestBuilder; /// A LoopRangeBuilder is a generic NestedBuilder for loop.for operations. 
/// More specifically it is meant to be used as a temporary object for @@ -115,7 +122,6 @@ namespace ops { using edsc::StructuredIndexed; using edsc::ValueHandle; -using edsc::intrinsics::linalg_yield; //===----------------------------------------------------------------------===// // EDSC builders for linalg generic operations. diff --git a/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h b/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h --- a/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h +++ b/mlir/include/mlir/Dialect/Linalg/EDSC/Intrinsics.h @@ -8,13 +8,58 @@ #ifndef MLIR_DIALECT_LINALG_EDSC_INTRINSICS_H_ #define MLIR_DIALECT_LINALG_EDSC_INTRINSICS_H_ -#include "mlir/Dialect/Linalg/IR/LinalgOps.h" -#include "mlir/EDSC/Builders.h" +#include "mlir/Dialect/Linalg/EDSC/Builders.h" #include "mlir/EDSC/Intrinsics.h" +#include "mlir/Transforms/FoldUtils.h" namespace mlir { namespace edsc { + +template +ValueHandle ValueHandle::create(OperationFolder *folder, Args... args) { + return folder ? ValueHandle(folder->create(ScopedContext::getBuilder(), + ScopedContext::getLocation(), + args...)) + : ValueHandle(ScopedContext::getBuilder().create( + ScopedContext::getLocation(), args...)); +} + namespace intrinsics { +namespace folded { +/// Helper variadic abstraction to allow extending to any MLIR op without +/// boilerplate or Tablegen. +/// Arguably a builder is not a ValueHandle but in practice it is only used as +/// an alias to a notional ValueHandle. +/// Implementing it as a subclass allows it to compose all the way to Value. +/// Without subclassing, implicit conversion to Value would fail when composing +/// in patterns such as: `select(a, b, select(c, d, e))`. +template +struct ValueBuilder : public ValueHandle { + /// Folder-based + template + ValueBuilder(OperationFolder *folder, Args... 
args) + : ValueHandle(ValueHandle::create(folder, detail::unpack(args)...)) {} + ValueBuilder(OperationFolder *folder, ArrayRef vs) + : ValueBuilder(ValueBuilder::create(folder, detail::unpack(vs))) {} + template + ValueBuilder(OperationFolder *folder, ArrayRef vs, Args... args) + : ValueHandle(ValueHandle::create(folder, detail::unpack(vs), + detail::unpack(args)...)) {} + template + ValueBuilder(OperationFolder *folder, T t, ArrayRef vs, + Args... args) + : ValueHandle(ValueHandle::create(folder, detail::unpack(t), + detail::unpack(vs), + detail::unpack(args)...)) {} + template + ValueBuilder(OperationFolder *folder, T1 t1, T2 t2, ArrayRef vs, + Args... args) + : ValueHandle(ValueHandle::create( + folder, detail::unpack(t1), detail::unpack(t2), detail::unpack(vs), + detail::unpack(args)...)) {} +}; + +} // namespace folded using linalg_copy = OperationBuilder; using linalg_fill = OperationBuilder; diff --git a/mlir/include/mlir/Dialect/LoopOps/EDSC/Builders.h b/mlir/include/mlir/Dialect/LoopOps/EDSC/Builders.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/LoopOps/EDSC/Builders.h @@ -0,0 +1,68 @@ +//===- Builders.h - MLIR Declarative Builder Classes ------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// Provides intuitive composable interfaces for building structured MLIR +// snippets in a declarative fashion. 
+// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_LOOPOPS_EDSC_BUILDERS_H_ +#define MLIR_DIALECT_LOOPOPS_EDSC_BUILDERS_H_ + +#include "mlir/Dialect/LoopOps/LoopOps.h" +#include "mlir/EDSC/Builders.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Types.h" + +namespace mlir { +namespace edsc { + +/// Constructs a new loop::ParallelOp and captures the associated induction +/// variables. An array of ValueHandle pointers is passed as the first +/// argument and is the *only* way to capture loop induction variables. +LoopBuilder makeParallelLoopBuilder(ArrayRef ivs, + ArrayRef lbHandles, + ArrayRef ubHandles, + ArrayRef steps); +/// Constructs a new loop::ForOp and captures the associated induction +/// variable. A ValueHandle pointer is passed as the first argument and is the +/// *only* way to capture the loop induction variable. +LoopBuilder makeLoopBuilder(ValueHandle *iv, ValueHandle lbHandle, + ValueHandle ubHandle, ValueHandle stepHandle); + +/// Helper class to sugar building loop.parallel loop nests from lower/upper +/// bounds and step sizes. +class ParallelLoopNestBuilder { +public: + ParallelLoopNestBuilder(ArrayRef ivs, + ArrayRef lbs, ArrayRef ubs, + ArrayRef steps); + + void operator()(function_ref fun = nullptr); + +private: + SmallVector loops; +}; + +/// Helper class to sugar building loop.for loop nests from ranges. +/// This is similar to edsc::AffineLoopNestBuilder except it operates on +/// loop.for. 
+class LoopNestBuilder { +public: + LoopNestBuilder(ArrayRef ivs, ArrayRef lbs, + ArrayRef ubs, ArrayRef steps); + void operator()(std::function fun = nullptr); + +private: + SmallVector loops; +}; + +} // namespace edsc +} // namespace mlir + +#endif // MLIR_DIALECT_LOOPOPS_EDSC_BUILDERS_H_ diff --git a/mlir/include/mlir/Dialect/StandardOps/EDSC/Builders.h b/mlir/include/mlir/Dialect/StandardOps/EDSC/Builders.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/StandardOps/EDSC/Builders.h @@ -0,0 +1,81 @@ +//===- Builders.h - MLIR EDSC Builders for StandardOps ----------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_DIALECT_STANDARDOPS_EDSC_BUILDERS_H_ +#define MLIR_DIALECT_STANDARDOPS_EDSC_BUILDERS_H_ + +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/EDSC/Builders.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Types.h" + +namespace mlir { +namespace edsc { + +/// Base class for MemRefBoundsCapture and VectorBoundsCapture. 
+class BoundsCapture { +public: + unsigned rank() const { return lbs.size(); } + ValueHandle lb(unsigned idx) { return lbs[idx]; } + ValueHandle ub(unsigned idx) { return ubs[idx]; } + int64_t step(unsigned idx) { return steps[idx]; } + std::tuple range(unsigned idx) { + return std::make_tuple(lbs[idx], ubs[idx], steps[idx]); + } + void swapRanges(unsigned i, unsigned j) { + if (i == j) + return; + lbs[i].swap(lbs[j]); + ubs[i].swap(ubs[j]); + std::swap(steps[i], steps[j]); + } + + ArrayRef getLbs() { return lbs; } + ArrayRef getUbs() { return ubs; } + ArrayRef getSteps() { return steps; } + +protected: + SmallVector lbs; + SmallVector ubs; + SmallVector steps; +}; + +/// A MemRefBoundsCapture represents the information required to step through a +/// MemRef. It has placeholders for non-contiguous tensors that fit within the +/// Fortran subarray model. +/// At the moment it can only capture a MemRef with an identity layout map. +// TODO(ntv): Support MemRefs with layoutMaps. +class MemRefBoundsCapture : public BoundsCapture { +public: + explicit MemRefBoundsCapture(Value v); + MemRefBoundsCapture(const MemRefBoundsCapture &) = default; + MemRefBoundsCapture &operator=(const MemRefBoundsCapture &) = default; + + unsigned fastestVarying() const { return rank() - 1; } + +private: + ValueHandle base; +}; + +/// A VectorBoundsCapture represents the information required to step through a +/// Vector accessing each scalar element at a time. It is the counterpart of +/// a MemRefBoundsCapture but for vectors. This exists purely for boilerplate +/// avoidance. 
+class VectorBoundsCapture : public BoundsCapture { +public: + explicit VectorBoundsCapture(Value v); + VectorBoundsCapture(const VectorBoundsCapture &) = default; + VectorBoundsCapture &operator=(const VectorBoundsCapture &) = default; + +private: + ValueHandle base; +}; + +} // namespace edsc +} // namespace mlir + +#endif // MLIR_DIALECT_STANDARDOPS_EDSC_BUILDERS_H_ diff --git a/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h b/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h @@ -0,0 +1,103 @@ +//===- Intrinsics.h - MLIR EDSC Intrinsics for StandardOps ------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#ifndef MLIR_DIALECT_STANDARDOPS_EDSC_INTRINSICS_H_ +#define MLIR_DIALECT_STANDARDOPS_EDSC_INTRINSICS_H_ + +#include "mlir/Dialect/StandardOps/EDSC/Builders.h" +#include "mlir/EDSC/Intrinsics.h" + +namespace mlir { +namespace edsc { +namespace intrinsics { + +using std_addf = ValueBuilder; +using std_alloc = ValueBuilder; +using std_call = OperationBuilder; +using std_constant_float = ValueBuilder; +using std_constant_index = ValueBuilder; +using std_constant_int = ValueBuilder; +using std_dealloc = OperationBuilder; +using std_dim = ValueBuilder; +using std_muli = ValueBuilder; +using std_mulf = ValueBuilder; +using std_memref_cast = ValueBuilder; +using std_ret = OperationBuilder; +using std_select = ValueBuilder; +using std_load = ValueBuilder; +using std_store = OperationBuilder; +using std_subi = ValueBuilder; +using std_tanh = ValueBuilder; +using std_view = ValueBuilder; +using std_zero_extendi = ValueBuilder; +using std_sign_extendi = ValueBuilder; + +/// Branches into the mlir::Block* 
captured by BlockHandle `bh` with `operands`. +/// +/// Prerequisites: +/// All Handles have already captured previously constructed IR objects. +OperationHandle br(BlockHandle bh, ArrayRef operands); + +/// Creates a new mlir::Block* and branches to it from the current block. +/// Argument types are specified by `operands`. +/// Captures the new block in `bh` and the actual `operands` in `captures`. To +/// insert the new mlir::Block*, a local ScopedContext is constructed and +/// released to the current block. The branch operation is then added to the +/// new block. +/// +/// Prerequisites: +/// `bh` has not yet captured an mlir::Block*. +/// No `captures` have captured any mlir::Value. +/// All `operands` have already captured an mlir::Value +/// captures.size() == operands.size() +/// captures and operands are pairwise of the same type. +OperationHandle br(BlockHandle *bh, ArrayRef captures, + ArrayRef operands); + +/// Branches into the mlir::Block* captured by BlockHandle `trueBranch` with +/// `trueOperands` if `cond` evaluates to `true` (resp. `falseBranch` and +/// `falseOperands` if `cond` evaluates to `false`). +/// +/// Prerequisites: +/// All Handles have captured previously constructed IR objects. +OperationHandle cond_br(ValueHandle cond, BlockHandle trueBranch, + ArrayRef trueOperands, + BlockHandle falseBranch, + ArrayRef falseOperands); + +/// Eagerly creates new mlir::Block* with argument types specified by +/// `trueOperands`/`falseOperands`. +/// Captures the new blocks in `trueBranch`/`falseBranch` and the arguments in +/// `trueCaptures`/`falseCaptures`. +/// To insert the new mlir::Block*, a local ScopedContext is constructed and +/// released. The branch operation is then added in the original location and +/// targeting the eagerly constructed blocks. +/// +/// Prerequisites: +/// `trueBranch`/`falseBranch` has not yet captured an mlir::Block*. +/// No `trueCaptures`/`falseCaptures` have captured any mlir::Value. 
+/// All `trueOperands`/`falseOperands` have already captured an mlir::Value +/// `trueCaptures`.size() == `trueOperands`.size() +/// `falseCaptures`.size() == `falseOperands`.size() +/// `trueCaptures` and `trueOperands` are pairwise of the same type +/// `falseCaptures` and `falseOperands` are pairwise of the same type. +OperationHandle cond_br(ValueHandle cond, BlockHandle *trueBranch, + ArrayRef trueCaptures, + ArrayRef trueOperands, + BlockHandle *falseBranch, + ArrayRef falseCaptures, + ArrayRef falseOperands); + +/// Provide an index notation around std_load and std_store. +using StdIndexedValue = + TemplatedIndexedValue; + +} // namespace intrinsics +} // namespace edsc +} // namespace mlir + +#endif // MLIR_DIALECT_STANDARDOPS_EDSC_INTRINSICS_H_ diff --git a/mlir/include/mlir/EDSC/Builders.h b/mlir/include/mlir/EDSC/Builders.h --- a/mlir/include/mlir/EDSC/Builders.h +++ b/mlir/include/mlir/EDSC/Builders.h @@ -14,23 +14,15 @@ #ifndef MLIR_EDSC_BUILDERS_H_ #define MLIR_EDSC_BUILDERS_H_ -#include "mlir/Dialect/AffineOps/AffineOps.h" -#include "mlir/Dialect/LoopOps/LoopOps.h" -#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/Builders.h" -#include "mlir/Transforms/FoldUtils.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/Types.h" namespace mlir { +class OperationFolder; namespace edsc { - -struct index_type { - explicit index_type(int64_t v) : v(v) {} - explicit operator int64_t() { return v; } - int64_t v; -}; - class BlockHandle; class CapturableHandle; class NestedBuilder; @@ -150,24 +142,6 @@ /// the name LoopBuilder (as opposed to say ForBuilder or AffineForBuilder). class LoopBuilder : public NestedBuilder { public: - /// Constructs a new AffineForOp and captures the associated induction - /// variable. A ValueHandle pointer is passed as the first argument and is the - /// *only* way to capture the loop induction variable. 
- static LoopBuilder makeAffine(ValueHandle *iv, - ArrayRef lbHandles, - ArrayRef ubHandles, int64_t step); - /// Constructs a new loop::ParallelOp and captures the associated induction - /// variables. An array of ValueHandle pointers is passed as the first - /// argument and is the *only* way to capture loop induction variables. - static LoopBuilder makeParallel(ArrayRef ivs, - ArrayRef lbHandles, - ArrayRef ubHandles, - ArrayRef steps); - /// Constructs a new loop::ForOp and captures the associated induction - /// variable. A ValueHandle pointer is passed as the first argument and is the - /// *only* way to capture the loop induction variable. - static LoopBuilder makeLoop(ValueHandle *iv, ValueHandle lbHandle, - ValueHandle ubHandle, ValueHandle stepHandle); LoopBuilder(const LoopBuilder &) = delete; LoopBuilder(LoopBuilder &&) = default; @@ -181,72 +155,18 @@ private: LoopBuilder() = default; -}; - -/// Explicit nested LoopBuilder. Offers a compressed multi-loop builder to avoid -/// explicitly writing all the loops in a nest. This simple functionality is -/// also useful to write rank-agnostic custom ops. -/// -/// Usage: -/// -/// ```c++ -/// AffineLoopNestBuilder({&i, &j, &k}, {lb, lb, lb}, {ub, ub, ub}, {1, 1, -/// 1})( -/// [&](){ -/// ... -/// }); -/// ``` -/// -/// ```c++ -/// AffineLoopNestBuilder({&i}, {lb}, {ub}, {1})([&](){ -/// AffineLoopNestBuilder({&j}, {lb}, {ub}, {1})([&](){ -/// AffineLoopNestBuilder({&k}, {lb}, {ub}, {1})([&](){ -/// ... -/// }), -/// }), -/// }); -/// ``` -class AffineLoopNestBuilder { -public: - // This entry point accommodates the fact that AffineForOp implicitly uses - // multiple `lbs` and `ubs` with one single `iv` and `step` to encode `max` - // and and `min` constraints respectively. 
- AffineLoopNestBuilder(ValueHandle *iv, ArrayRef lbs, - ArrayRef ubs, int64_t step); - AffineLoopNestBuilder(ArrayRef ivs, ArrayRef lbs, - ArrayRef ubs, ArrayRef steps); - - void operator()(function_ref fun = nullptr); - -private: - SmallVector loops; -}; -/// Helper class to sugar building loop.parallel loop nests from lower/upper -/// bounds and step sizes. -class ParallelLoopNestBuilder { -public: - ParallelLoopNestBuilder(ArrayRef ivs, - ArrayRef lbs, ArrayRef ubs, - ArrayRef steps); - - void operator()(function_ref fun = nullptr); - -private: - SmallVector loops; -}; - -/// Helper class to sugar building loop.for loop nests from ranges. -/// This is similar to edsc::AffineLoopNestBuilder except it operates on -/// loop.for. -class LoopNestBuilder { -public: - LoopNestBuilder(ArrayRef ivs, ArrayRef lbs, - ArrayRef ubs, ArrayRef steps); - void operator()(std::function fun = nullptr); - -private: - SmallVector loops; + friend LoopBuilder makeAffineLoopBuilder(ValueHandle *iv, + ArrayRef lbHandles, + ArrayRef ubHandles, + int64_t step); + friend LoopBuilder makeParallelLoopBuilder(ArrayRef ivs, + ArrayRef lbHandles, + ArrayRef ubHandles, + ArrayRef steps); + friend LoopBuilder makeLoopBuilder(ValueHandle *iv, ValueHandle lbHandle, + ValueHandle ubHandle, + ValueHandle stepHandle); }; // This class exists solely to handle the C++ vexing parse case when @@ -337,13 +257,6 @@ /// been constructed in the past and that is captured "now" in the program. explicit ValueHandle(Value v) : t(v.getType()), v(v) {} - /// Builds a ConstantIndexOp of value `cst`. The constant is created at the - /// current insertion point. - /// This implicit constructor is provided to each build an eager Value for a - /// constant at the current insertion point in the IR. An implicit constructor - /// allows idiomatic expressions mixing ValueHandle and literals. - ValueHandle(index_type cst); - /// ValueHandle is a value type, use the default copy constructor. 
ValueHandle(const ValueHandle &other) = default; @@ -377,11 +290,6 @@ template static ValueHandle create(OperationFolder *folder, Args... args); - /// Special case to build composed AffineApply operations. - // TODO: createOrFold when available and move inside of the `create` method. - static ValueHandle createComposedAffineApply(AffineMap map, - ArrayRef operands); - /// Generic create for a named operation producing a single value. static ValueHandle create(StringRef name, ArrayRef operands, ArrayRef resultTypes, @@ -401,6 +309,12 @@ return v.getDefiningOp(); } + // Return a vector of fresh ValueHandles that have not captured. + static SmallVector makeIndexHandles(unsigned count) { + auto indexType = IndexType::get(ScopedContext::getContext()); + return SmallVector(count, ValueHandle(indexType)); + } + protected: ValueHandle() : t(), v(nullptr) {} @@ -555,48 +469,11 @@ Operation *op = ScopedContext::getBuilder() .create(ScopedContext::getLocation(), args...) .getOperation(); - if (op->getNumResults() == 1) { + if (op->getNumResults() == 1) return ValueHandle(op->getResult(0)); - } else if (op->getNumResults() == 0) { - if (auto f = dyn_cast(op)) { - return ValueHandle(f.getInductionVar()); - } - } llvm_unreachable("unsupported operation, use an OperationHandle instead"); } -template -ValueHandle ValueHandle::create(OperationFolder *folder, Args... args) { - return folder ? 
ValueHandle(folder->create(ScopedContext::getBuilder(), - ScopedContext::getLocation(), - args...)) - : ValueHandle(ScopedContext::getBuilder().create( - ScopedContext::getLocation(), args...)); -} - -namespace op { - -ValueHandle operator+(ValueHandle lhs, ValueHandle rhs); -ValueHandle operator-(ValueHandle lhs, ValueHandle rhs); -ValueHandle operator*(ValueHandle lhs, ValueHandle rhs); -ValueHandle operator/(ValueHandle lhs, ValueHandle rhs); -ValueHandle operator%(ValueHandle lhs, ValueHandle rhs); -ValueHandle floorDiv(ValueHandle lhs, ValueHandle rhs); -ValueHandle ceilDiv(ValueHandle lhs, ValueHandle rhs); - -ValueHandle operator!(ValueHandle value); -ValueHandle operator&&(ValueHandle lhs, ValueHandle rhs); -ValueHandle operator||(ValueHandle lhs, ValueHandle rhs); -ValueHandle operator^(ValueHandle lhs, ValueHandle rhs); -ValueHandle operator==(ValueHandle lhs, ValueHandle rhs); -ValueHandle operator!=(ValueHandle lhs, ValueHandle rhs); -ValueHandle operator<(ValueHandle lhs, ValueHandle rhs); -ValueHandle operator<=(ValueHandle lhs, ValueHandle rhs); -ValueHandle operator>(ValueHandle lhs, ValueHandle rhs); -ValueHandle operator>=(ValueHandle lhs, ValueHandle rhs); - -} // namespace op - /// Entry point to build multiple ValueHandle from a `Container` of Value or /// Type. template @@ -608,6 +485,105 @@ return res; } +/// A TemplatedIndexedValue brings an index notation over the template Load and +/// Store parameters. Assigning to an IndexedValue emits an actual `Store` +/// operation, while converting an IndexedValue to a ValueHandle emits an actual +/// `Load` operation. 
+template class TemplatedIndexedValue { +public: + explicit TemplatedIndexedValue(Type t) : base(t) {} + explicit TemplatedIndexedValue(Value v) + : TemplatedIndexedValue(ValueHandle(v)) {} + explicit TemplatedIndexedValue(ValueHandle v) : base(v) {} + + TemplatedIndexedValue(const TemplatedIndexedValue &rhs) = default; + + TemplatedIndexedValue operator()() { return *this; } + /// Returns a new `TemplatedIndexedValue`. + TemplatedIndexedValue operator()(ValueHandle index) { + TemplatedIndexedValue res(base); + res.indices.push_back(index); + return res; + } + template + TemplatedIndexedValue operator()(ValueHandle index, Args... indices) { + return TemplatedIndexedValue(base, index).append(indices...); + } + TemplatedIndexedValue operator()(ArrayRef indices) { + return TemplatedIndexedValue(base, indices); + } + + /// Emits a `store`. + OperationHandle operator=(const TemplatedIndexedValue &rhs) { + ValueHandle rrhs(rhs); + return Store(rrhs, getBase(), {indices.begin(), indices.end()}); + } + OperationHandle operator=(ValueHandle rhs) { + return Store(rhs, getBase(), {indices.begin(), indices.end()}); + } + + /// Emits a `load` when converting to a ValueHandle. + operator ValueHandle() const { + return Load(getBase(), {indices.begin(), indices.end()}); + } + + /// Emits a `load` when converting to a Value. + Value operator*(void) const { + return Load(getBase(), {indices.begin(), indices.end()}).getValue(); + } + + ValueHandle getBase() const { return base; } + + /// Operator overloadings. 
+ ValueHandle operator+(ValueHandle e); + ValueHandle operator-(ValueHandle e); + ValueHandle operator*(ValueHandle e); + ValueHandle operator/(ValueHandle e); + OperationHandle operator+=(ValueHandle e); + OperationHandle operator-=(ValueHandle e); + OperationHandle operator*=(ValueHandle e); + OperationHandle operator/=(ValueHandle e); + ValueHandle operator+(TemplatedIndexedValue e) { + return *this + static_cast(e); + } + ValueHandle operator-(TemplatedIndexedValue e) { + return *this - static_cast(e); + } + ValueHandle operator*(TemplatedIndexedValue e) { + return *this * static_cast(e); + } + ValueHandle operator/(TemplatedIndexedValue e) { + return *this / static_cast(e); + } + OperationHandle operator+=(TemplatedIndexedValue e) { + return this->operator+=(static_cast(e)); + } + OperationHandle operator-=(TemplatedIndexedValue e) { + return this->operator-=(static_cast(e)); + } + OperationHandle operator*=(TemplatedIndexedValue e) { + return this->operator*=(static_cast(e)); + } + OperationHandle operator/=(TemplatedIndexedValue e) { + return this->operator/=(static_cast(e)); + } + +private: + TemplatedIndexedValue(ValueHandle base, ArrayRef indices) + : base(base), indices(indices.begin(), indices.end()) {} + + TemplatedIndexedValue &append() { return *this; } + + template + TemplatedIndexedValue &append(T index, Args... indices) { + this->indices.push_back(static_cast(index)); + append(indices...); + return *this; + } + ValueHandle base; + SmallVector indices; +}; + } // namespace edsc } // namespace mlir diff --git a/mlir/include/mlir/EDSC/Helpers.h b/mlir/include/mlir/EDSC/Helpers.h deleted file mode 100644 --- a/mlir/include/mlir/EDSC/Helpers.h +++ /dev/null @@ -1,258 +0,0 @@ -//===- Helpers.h - MLIR Declarative Helper Functionality --------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// Provides helper classes and syntactic sugar for declarative builders. -// -//===----------------------------------------------------------------------===// - -#ifndef MLIR_EDSC_HELPERS_H_ -#define MLIR_EDSC_HELPERS_H_ - -#include "mlir/EDSC/Builders.h" -#include "mlir/EDSC/Intrinsics.h" - -namespace mlir { -namespace edsc { - -// A TemplatedIndexedValue brings an index notation over the template Load and -// Store parameters. -template class TemplatedIndexedValue; - -// By default, edsc::IndexedValue provides an index notation around the affine -// load and stores. edsc::StdIndexedValue provides the standard load/store -// counterpart. -using IndexedValue = - TemplatedIndexedValue; -using StdIndexedValue = - TemplatedIndexedValue; - -// Base class for MemRefView and VectorView. -class View { -public: - unsigned rank() const { return lbs.size(); } - ValueHandle lb(unsigned idx) { return lbs[idx]; } - ValueHandle ub(unsigned idx) { return ubs[idx]; } - int64_t step(unsigned idx) { return steps[idx]; } - std::tuple range(unsigned idx) { - return std::make_tuple(lbs[idx], ubs[idx], steps[idx]); - } - void swapRanges(unsigned i, unsigned j) { - if (i == j) - return; - lbs[i].swap(lbs[j]); - ubs[i].swap(ubs[j]); - std::swap(steps[i], steps[j]); - } - - ArrayRef getLbs() { return lbs; } - ArrayRef getUbs() { return ubs; } - ArrayRef getSteps() { return steps; } - -protected: - SmallVector lbs; - SmallVector ubs; - SmallVector steps; -}; - -/// A MemRefView represents the information required to step through a -/// MemRef. It has placeholders for non-contiguous tensors that fit within the -/// Fortran subarray model. -/// At the moment it can only capture a MemRef with an identity layout map. -// TODO(ntv): Support MemRefs with layoutMaps. 
-class MemRefView : public View { -public: - explicit MemRefView(Value v); - MemRefView(const MemRefView &) = default; - MemRefView &operator=(const MemRefView &) = default; - - unsigned fastestVarying() const { return rank() - 1; } - -private: - friend IndexedValue; - ValueHandle base; -}; - -/// A VectorView represents the information required to step through a -/// Vector accessing each scalar element at a time. It is the counterpart of -/// a MemRefView but for vectors. This exists purely for boilerplate avoidance. -class VectorView : public View { -public: - explicit VectorView(Value v); - VectorView(const VectorView &) = default; - VectorView &operator=(const VectorView &) = default; - -private: - friend IndexedValue; - ValueHandle base; -}; - -/// A TemplatedIndexedValue brings an index notation over the template Load and -/// Store parameters. This helper class is an abstraction purely for sugaring -/// purposes and allows writing compact expressions such as: -/// -/// ```mlir -/// // `IndexedValue` provided by default in the mlir::edsc namespace. -/// using IndexedValue = -/// TemplatedIndexedValue; -/// IndexedValue A(...), B(...), C(...); -/// For(ivs, zeros, shapeA, ones, { -/// C(ivs) = A(ivs) + B(ivs) -/// }); -/// ``` -/// -/// Assigning to an IndexedValue emits an actual `Store` operation, while -/// converting an IndexedValue to a ValueHandle emits an actual `Load` -/// operation. -template class TemplatedIndexedValue { -public: - explicit TemplatedIndexedValue(Type t) : base(t) {} - explicit TemplatedIndexedValue(Value v) - : TemplatedIndexedValue(ValueHandle(v)) {} - explicit TemplatedIndexedValue(ValueHandle v) : base(v) {} - - TemplatedIndexedValue(const TemplatedIndexedValue &rhs) = default; - - TemplatedIndexedValue operator()() { return *this; } - /// Returns a new `TemplatedIndexedValue`. 
- TemplatedIndexedValue operator()(ValueHandle index) { - TemplatedIndexedValue res(base); - res.indices.push_back(index); - return res; - } - template - TemplatedIndexedValue operator()(ValueHandle index, Args... indices) { - return TemplatedIndexedValue(base, index).append(indices...); - } - TemplatedIndexedValue operator()(ArrayRef indices) { - return TemplatedIndexedValue(base, indices); - } - TemplatedIndexedValue operator()(ArrayRef indices) { - return TemplatedIndexedValue( - base, ArrayRef(indices.begin(), indices.end())); - } - - /// Emits a `store`. - // NOLINTNEXTLINE: unconventional-assign-operator - OperationHandle operator=(const TemplatedIndexedValue &rhs) { - ValueHandle rrhs(rhs); - return Store(rrhs, getBase(), {indices.begin(), indices.end()}); - } - // NOLINTNEXTLINE: unconventional-assign-operator - OperationHandle operator=(ValueHandle rhs) { - return Store(rhs, getBase(), {indices.begin(), indices.end()}); - } - - /// Emits a `load` when converting to a ValueHandle. - operator ValueHandle() const { - return Load(getBase(), {indices.begin(), indices.end()}); - } - - /// Emits a `load` when converting to a Value. - Value operator*(void) const { - return Load(getBase(), {indices.begin(), indices.end()}).getValue(); - } - - ValueHandle getBase() const { return base; } - - /// Operator overloadings. 
- ValueHandle operator+(ValueHandle e); - ValueHandle operator-(ValueHandle e); - ValueHandle operator*(ValueHandle e); - ValueHandle operator/(ValueHandle e); - OperationHandle operator+=(ValueHandle e); - OperationHandle operator-=(ValueHandle e); - OperationHandle operator*=(ValueHandle e); - OperationHandle operator/=(ValueHandle e); - ValueHandle operator+(TemplatedIndexedValue e) { - return *this + static_cast(e); - } - ValueHandle operator-(TemplatedIndexedValue e) { - return *this - static_cast(e); - } - ValueHandle operator*(TemplatedIndexedValue e) { - return *this * static_cast(e); - } - ValueHandle operator/(TemplatedIndexedValue e) { - return *this / static_cast(e); - } - OperationHandle operator+=(TemplatedIndexedValue e) { - return this->operator+=(static_cast(e)); - } - OperationHandle operator-=(TemplatedIndexedValue e) { - return this->operator-=(static_cast(e)); - } - OperationHandle operator*=(TemplatedIndexedValue e) { - return this->operator*=(static_cast(e)); - } - OperationHandle operator/=(TemplatedIndexedValue e) { - return this->operator/=(static_cast(e)); - } - -private: - TemplatedIndexedValue(ValueHandle base, ArrayRef indices) - : base(base), indices(indices.begin(), indices.end()) {} - - TemplatedIndexedValue &append() { return *this; } - - template - TemplatedIndexedValue &append(T index, Args... indices) { - this->indices.push_back(static_cast(index)); - append(indices...); - return *this; - } - ValueHandle base; - SmallVector indices; -}; - -/// Operator overloadings. 
-template -ValueHandle TemplatedIndexedValue::operator+(ValueHandle e) { - using op::operator+; - return static_cast(*this) + e; -} -template -ValueHandle TemplatedIndexedValue::operator-(ValueHandle e) { - using op::operator-; - return static_cast(*this) - e; -} -template -ValueHandle TemplatedIndexedValue::operator*(ValueHandle e) { - using op::operator*; - return static_cast(*this) * e; -} -template -ValueHandle TemplatedIndexedValue::operator/(ValueHandle e) { - using op::operator/; - return static_cast(*this) / e; -} - -template -OperationHandle TemplatedIndexedValue::operator+=(ValueHandle e) { - using op::operator+; - return Store(*this + e, getBase(), {indices.begin(), indices.end()}); -} -template -OperationHandle TemplatedIndexedValue::operator-=(ValueHandle e) { - using op::operator-; - return Store(*this - e, getBase(), {indices.begin(), indices.end()}); -} -template -OperationHandle TemplatedIndexedValue::operator*=(ValueHandle e) { - using op::operator*; - return Store(*this * e, getBase(), {indices.begin(), indices.end()}); -} -template -OperationHandle TemplatedIndexedValue::operator/=(ValueHandle e) { - using op::operator/; - return Store(*this / e, getBase(), {indices.begin(), indices.end()}); -} - -} // namespace edsc -} // namespace mlir - -#endif // MLIR_EDSC_HELPERS_H_ diff --git a/mlir/include/mlir/EDSC/Intrinsics.h b/mlir/include/mlir/EDSC/Intrinsics.h --- a/mlir/include/mlir/EDSC/Intrinsics.h +++ b/mlir/include/mlir/EDSC/Intrinsics.h @@ -7,7 +7,7 @@ //===----------------------------------------------------------------------===// // // Provides intuitive composable intrinsics for building snippets of MLIR -// declaratively +// declaratively. 
// //===----------------------------------------------------------------------===// @@ -15,6 +15,7 @@ #define MLIR_EDSC_INTRINSICS_H_ #include "mlir/EDSC/Builders.h" +#include "mlir/IR/StandardTypes.h" #include "mlir/Support/LLVM.h" namespace mlir { @@ -24,62 +25,16 @@ namespace edsc { -/// An IndexHandle is a simple wrapper around a ValueHandle. -/// IndexHandles are ubiquitous enough to justify a new type to allow simple -/// declarations without boilerplate such as: -/// -/// ```c++ -/// IndexHandle i, j, k; -/// ``` -struct IndexHandle : public ValueHandle { - explicit IndexHandle() - : ValueHandle(ScopedContext::getBuilder().getIndexType()) {} - explicit IndexHandle(index_type v) : ValueHandle(v) {} - explicit IndexHandle(Value v) : ValueHandle(v) { - assert(v.getType() == ScopedContext::getBuilder().getIndexType() && - "Expected index type"); - } - explicit IndexHandle(ValueHandle v) : ValueHandle(v) { - assert(v.getType() == ScopedContext::getBuilder().getIndexType() && - "Expected index type"); - } - IndexHandle &operator=(const ValueHandle &v) { - assert(v.getType() == ScopedContext::getBuilder().getIndexType() && - "Expected index type"); - /// Creating a new IndexHandle(v) and then std::swap rightly complains the - /// binding has already occurred and that we should use another name. - this->t = v.getType(); - this->v = v.getValue(); - return *this; - } -}; - -inline SmallVector makeIndexHandles(unsigned rank) { - return SmallVector(rank); -} - -/// Entry point to build multiple ValueHandle* from a mutable list `ivs` of T. -template +/// Entry point to build multiple ValueHandle* from a mutable list `ivs`. inline SmallVector -makeHandlePointers(MutableArrayRef ivs) { +makeHandlePointers(MutableArrayRef ivs) { SmallVector pivs; pivs.reserve(ivs.size()); - for (auto &iv : ivs) { + for (auto &iv : ivs) pivs.push_back(&iv); - } return pivs; } -/// Returns a vector of the underlying Value from `ivs`. 
-inline SmallVector extractValues(ArrayRef ivs) { - SmallVector vals; - vals.reserve(ivs.size()); - for (auto &iv : ivs) { - vals.push_back(iv.getValue()); - } - return vals; -} - /// Provides a set of first class intrinsics. /// In the future, most of intrinsics related to Operation that don't contain /// other operations should be Tablegen'd. @@ -93,13 +48,6 @@ ValueHandleArray(ArrayRef vals) { values.append(vals.begin(), vals.end()); } - ValueHandleArray(ArrayRef vals) { - values.append(vals.begin(), vals.end()); - } - ValueHandleArray(ArrayRef vals) { - SmallVector tmp(vals.begin(), vals.end()); - values.append(tmp.begin(), tmp.end()); - } operator ArrayRef() { return values; } private: @@ -143,29 +91,6 @@ detail::unpack(t1), detail::unpack(t2), detail::unpack(vs), detail::unpack(args)...)) {} - /// Folder-based - template - ValueBuilder(OperationFolder *folder, Args... args) - : ValueHandle(ValueHandle::create(folder, detail::unpack(args)...)) {} - ValueBuilder(OperationFolder *folder, ArrayRef vs) - : ValueBuilder(ValueBuilder::create(folder, detail::unpack(vs))) {} - template - ValueBuilder(OperationFolder *folder, ArrayRef vs, Args... args) - : ValueHandle(ValueHandle::create(folder, detail::unpack(vs), - detail::unpack(args)...)) {} - template - ValueBuilder(OperationFolder *folder, T t, ArrayRef vs, - Args... args) - : ValueHandle(ValueHandle::create(folder, detail::unpack(t), - detail::unpack(vs), - detail::unpack(args)...)) {} - template - ValueBuilder(OperationFolder *folder, T1 t1, T2 t2, ArrayRef vs, - Args... 
args) - : ValueHandle(ValueHandle::create( - folder, detail::unpack(t1), detail::unpack(t2), detail::unpack(vs), - detail::unpack(args)...)) {} - ValueBuilder() : ValueHandle(ValueHandle::create()) {} }; @@ -191,88 +116,6 @@ OperationBuilder() : OperationHandle(OperationHandle::create()) {} }; -using addf = ValueBuilder; -using affine_apply = ValueBuilder; -using affine_if = OperationBuilder; -using affine_load = ValueBuilder; -using affine_min = ValueBuilder; -using affine_max = ValueBuilder; -using affine_store = OperationBuilder; -using alloc = ValueBuilder; -using call = OperationBuilder; -using constant_float = ValueBuilder; -using constant_index = ValueBuilder; -using constant_int = ValueBuilder; -using dealloc = OperationBuilder; -using dim = ValueBuilder; -using muli = ValueBuilder; -using mulf = ValueBuilder; -using memref_cast = ValueBuilder; -using ret = OperationBuilder; -using select = ValueBuilder; -using std_load = ValueBuilder; -using std_store = OperationBuilder; -using subi = ValueBuilder; -using tanh = ValueBuilder; -using view = ValueBuilder; -using zero_extendi = ValueBuilder; -using sign_extendi = ValueBuilder; - -/// Branches into the mlir::Block* captured by BlockHandle `b` with `operands`. -/// -/// Prerequisites: -/// All Handles have already captured previously constructed IR objects. -OperationHandle br(BlockHandle bh, ArrayRef operands); - -/// Creates a new mlir::Block* and branches to it from the current block. -/// Argument types are specified by `operands`. -/// Captures the new block in `bh` and the actual `operands` in `captures`. To -/// insert the new mlir::Block*, a local ScopedContext is constructed and -/// released to the current block. The branch operation is then added to the -/// new block. -/// -/// Prerequisites: -/// `b` has not yet captured an mlir::Block*. -/// No `captures` have captured any mlir::Value. 
-/// All `operands` have already captured an mlir::Value -/// captures.size() == operands.size() -/// captures and operands are pairwise of the same type. -OperationHandle br(BlockHandle *bh, ArrayRef captures, - ArrayRef operands); - -/// Branches into the mlir::Block* captured by BlockHandle `trueBranch` with -/// `trueOperands` if `cond` evaluates to `true` (resp. `falseBranch` and -/// `falseOperand` if `cond` evaluates to `false`). -/// -/// Prerequisites: -/// All Handles have captured previously constructed IR objects. -OperationHandle cond_br(ValueHandle cond, BlockHandle trueBranch, - ArrayRef trueOperands, - BlockHandle falseBranch, - ArrayRef falseOperands); - -/// Eagerly creates new mlir::Block* with argument types specified by -/// `trueOperands`/`falseOperands`. -/// Captures the new blocks in `trueBranch`/`falseBranch` and the arguments in -/// `trueCaptures/falseCaptures`. -/// To insert the new mlir::Block*, a local ScopedContext is constructed and -/// released. The branch operation is then added in the original location and -/// targeting the eagerly constructed blocks. -/// -/// Prerequisites: -/// `trueBranch`/`falseBranch` has not yet captured an mlir::Block*. -/// No `trueCaptures`/`falseCaptures` have captured any mlir::Value. -/// All `trueOperands`/`trueOperands` have already captured an mlir::Value -/// `trueCaptures`.size() == `trueOperands`.size() -/// `falseCaptures`.size() == `falseOperands`.size() -/// `trueCaptures` and `trueOperands` are pairwise of the same type -/// `falseCaptures` and `falseOperands` are pairwise of the same type. 
-OperationHandle cond_br(ValueHandle cond, BlockHandle *trueBranch, - ArrayRef trueCaptures, - ArrayRef trueOperands, - BlockHandle *falseBranch, - ArrayRef falseCaptures, - ArrayRef falseOperands); } // namespace intrinsics } // namespace edsc } // namespace mlir diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp --- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp +++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp @@ -16,8 +16,7 @@ #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "mlir/Dialect/Linalg/Passes.h" -#include "mlir/EDSC/Builders.h" -#include "mlir/EDSC/Intrinsics.h" +#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" @@ -47,24 +46,22 @@ using namespace mlir::LLVM; using namespace mlir::linalg; -using add = ValueBuilder; -using addi = ValueBuilder; -using bitcast = ValueBuilder; -using cmpi = ValueBuilder; -using constant = ValueBuilder; -using extractvalue = ValueBuilder; -using gep = ValueBuilder; -using insertvalue = ValueBuilder; -using llvm_call = OperationBuilder; +using llvm_add = ValueBuilder; +using llvm_bitcast = ValueBuilder; +using llvm_constant = ValueBuilder; +using llvm_extractvalue = ValueBuilder; +using llvm_gep = ValueBuilder; +using llvm_insertvalue = ValueBuilder; +using llvm_call = OperationBuilder; using llvm_icmp = ValueBuilder; using llvm_load = ValueBuilder; using llvm_store = OperationBuilder; using llvm_select = ValueBuilder; -using mul = ValueBuilder; -using ptrtoint = ValueBuilder; -using sub = ValueBuilder; -using llvm_undef = ValueBuilder; -using urem = ValueBuilder; +using llvm_mul = ValueBuilder; +using llvm_ptrtoint = ValueBuilder; +using llvm_sub = ValueBuilder; +using llvm_undef = ValueBuilder; +using llvm_urem = ValueBuilder; using llvm_alloca = ValueBuilder; using llvm_return = OperationBuilder; 
@@ -156,9 +153,9 @@ // Fill in an aggregate value of the descriptor. RangeOpOperandAdaptor adaptor(operands); Value desc = llvm_undef(rangeDescriptorTy); - desc = insertvalue(desc, adaptor.min(), rewriter.getI64ArrayAttr(0)); - desc = insertvalue(desc, adaptor.max(), rewriter.getI64ArrayAttr(1)); - desc = insertvalue(desc, adaptor.step(), rewriter.getI64ArrayAttr(2)); + desc = llvm_insertvalue(desc, adaptor.min(), rewriter.getI64ArrayAttr(0)); + desc = llvm_insertvalue(desc, adaptor.max(), rewriter.getI64ArrayAttr(1)); + desc = llvm_insertvalue(desc, adaptor.step(), rewriter.getI64ArrayAttr(2)); rewriter.replaceOp(op, desc); return matchSuccess(); } @@ -249,8 +246,8 @@ Value indexing = adaptor.indexings()[i]; Value min = indexing; if (sliceOp.indexing(i).getType().isa()) - min = extractvalue(int64Ty, indexing, pos(0)); - baseOffset = add(baseOffset, mul(min, strides[i])); + min = llvm_extractvalue(int64Ty, indexing, pos(0)); + baseOffset = llvm_add(baseOffset, llvm_mul(min, strides[i])); } // Insert the base and aligned pointers. @@ -264,8 +261,8 @@ if (sliceOp.getShapedType().getRank() == 0) return rewriter.replaceOp(op, {desc}), matchSuccess(); - Value zero = - constant(int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 0)); + Value zero = llvm_constant( + int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 0)); // Compute and insert view sizes (max - min along the range) and strides. // Skip the non-range operands as they will be projected away from the view. 
int numNewDims = 0; @@ -274,19 +271,19 @@ if (indexing.getType().isa()) { int rank = en.index(); Value rangeDescriptor = adaptor.indexings()[rank]; - Value min = extractvalue(int64Ty, rangeDescriptor, pos(0)); - Value max = extractvalue(int64Ty, rangeDescriptor, pos(1)); - Value step = extractvalue(int64Ty, rangeDescriptor, pos(2)); + Value min = llvm_extractvalue(int64Ty, rangeDescriptor, pos(0)); + Value max = llvm_extractvalue(int64Ty, rangeDescriptor, pos(1)); + Value step = llvm_extractvalue(int64Ty, rangeDescriptor, pos(2)); Value baseSize = baseDesc.size(rank); // Bound upper by base view upper bound. max = llvm_select(llvm_icmp(ICmpPredicate::slt, max, baseSize), max, baseSize); - Value size = sub(max, min); + Value size = llvm_sub(max, min); // Bound lower by zero. size = llvm_select(llvm_icmp(ICmpPredicate::slt, size, zero), zero, size); - Value stride = mul(strides[rank], step); + Value stride = llvm_mul(strides[rank], step); desc.setSize(numNewDims, size); desc.setStride(numNewDims, stride); ++numNewDims; @@ -450,8 +447,7 @@ /// Conversion pattern specialization for CopyOp. This kicks in when both input /// and output permutations are left unspecified or are the identity. 
-template <> -class LinalgOpConversion : public OpRewritePattern { +template <> class LinalgOpConversion : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; diff --git a/mlir/lib/Conversion/VectorToLoops/ConvertVectorToLoops.cpp b/mlir/lib/Conversion/VectorToLoops/ConvertVectorToLoops.cpp --- a/mlir/lib/Conversion/VectorToLoops/ConvertVectorToLoops.cpp +++ b/mlir/lib/Conversion/VectorToLoops/ConvertVectorToLoops.cpp @@ -13,9 +13,10 @@ #include #include "mlir/Conversion/VectorToLoops/ConvertVectorToLoops.h" +#include "mlir/Dialect/AffineOps/EDSC/Intrinsics.h" +#include "mlir/Dialect/LoopOps/EDSC/Builders.h" +#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" #include "mlir/Dialect/VectorOps/VectorOps.h" -#include "mlir/EDSC/Builders.h" -#include "mlir/EDSC/Helpers.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Attributes.h" @@ -27,17 +28,19 @@ #include "mlir/IR/Types.h" using namespace mlir; +using namespace mlir::edsc; +using namespace mlir::edsc::intrinsics; using vector::TransferReadOp; using vector::TransferWriteOp; /// Analyzes the `transfer` to find an access dimension along the fastest remote /// MemRef dimension. If such a dimension with coalescing properties is found, -/// `pivs` and `vectorView` are swapped so that the invocation of +/// `pivs` and `vectorBoundsCapture` are swapped so that the invocation of /// LoopNestBuilder captures it in the innermost loop. template static void coalesceCopy(TransferOpTy transfer, - SmallVectorImpl *pivs, - edsc::VectorView *vectorView) { + SmallVectorImpl *pivs, + VectorBoundsCapture *vectorBoundsCapture) { // rank of the remote memory access, coalescing behavior occurs on the // innermost memory dimension. 
auto remoteRank = transfer.getMemRefType().getRank(); @@ -61,25 +64,22 @@ } if (coalescedIdx >= 0) { std::swap(pivs->back(), (*pivs)[coalescedIdx]); - vectorView->swapRanges(pivs->size() - 1, coalescedIdx); + vectorBoundsCapture->swapRanges(pivs->size() - 1, coalescedIdx); } } /// Emits remote memory accesses that are clipped to the boundaries of the /// MemRef. template -static SmallVector clip(TransferOpTy transfer, - edsc::MemRefView &view, - ArrayRef ivs) { +static SmallVector clip(TransferOpTy transfer, + MemRefBoundsCapture &bounds, + ArrayRef ivs) { using namespace mlir::edsc; - using namespace edsc::op; - using edsc::intrinsics::select; - - IndexHandle zero(index_type(0)), one(index_type(1)); - SmallVector memRefAccess(transfer.indices()); - SmallVector clippedScalarAccessExprs( - memRefAccess.size(), edsc::IndexHandle()); + ValueHandle zero(std_constant_index(0)), one(std_constant_index(1)); + SmallVector memRefAccess(transfer.indices()); + auto clippedScalarAccessExprs = + ValueHandle::makeIndexHandles(memRefAccess.size()); // Indices accessing to remote memory are clipped and their expressions are // returned in clippedScalarAccessExprs. for (unsigned memRefDim = 0; memRefDim < clippedScalarAccessExprs.size(); @@ -103,19 +103,21 @@ // We cannot distinguish atm between unrolled dimensions that implement // the "always full" tile abstraction and need clipping from the other // ones. So we conservatively clip everything. 
- auto N = view.ub(memRefDim); + using namespace edsc::op; + auto N = bounds.ub(memRefDim); auto i = memRefAccess[memRefDim]; if (loopIndex < 0) { auto N_minus_1 = N - one; - auto select_1 = select(i < N, i, N_minus_1); - clippedScalarAccessExprs[memRefDim] = select(i < zero, zero, select_1); + auto select_1 = std_select(i < N, i, N_minus_1); + clippedScalarAccessExprs[memRefDim] = + std_select(i < zero, zero, select_1); } else { auto ii = ivs[loopIndex]; auto i_plus_ii = i + ii; auto N_minus_1 = N - one; - auto select_1 = select(i_plus_ii < N, i_plus_ii, N_minus_1); + auto select_1 = std_select(i_plus_ii < N, i_plus_ii, N_minus_1); clippedScalarAccessExprs[memRefDim] = - select(i_plus_ii < zero, zero, select_1); + std_select(i_plus_ii < zero, zero, select_1); } } @@ -165,9 +167,9 @@ /// /// ```mlir-dsc /// auto condMax = i + ii < N; -/// auto max = select(condMax, i + ii, N - one) +/// auto max = std_select(condMax, i + ii, N - one) /// auto cond = i + ii < zero; -/// select(cond, zero, max); +/// std_select(cond, zero, max); /// ``` /// /// In the future, clipping should not be the only way and instead we should @@ -246,41 +248,37 @@ template <> PatternMatchResult VectorTransferRewriter::matchAndRewrite( Operation *op, PatternRewriter &rewriter) const { - using namespace mlir::edsc; using namespace mlir::edsc::op; - using namespace mlir::edsc::intrinsics; - using IndexedValue = - TemplatedIndexedValue; TransferReadOp transfer = cast(op); // 1. Setup all the captures. 
ScopedContext scope(rewriter, transfer.getLoc()); - IndexedValue remote(transfer.memref()); - MemRefView view(transfer.memref()); - VectorView vectorView(transfer.vector()); - SmallVector ivs = makeIndexHandles(vectorView.rank()); + StdIndexedValue remote(transfer.memref()); + MemRefBoundsCapture memRefBoundsCapture(transfer.memref()); + VectorBoundsCapture vectorBoundsCapture(transfer.vector()); + auto ivs = ValueHandle::makeIndexHandles(vectorBoundsCapture.rank()); SmallVector pivs = - makeHandlePointers(MutableArrayRef(ivs)); - coalesceCopy(transfer, &pivs, &vectorView); + makeHandlePointers(MutableArrayRef(ivs)); + coalesceCopy(transfer, &pivs, &vectorBoundsCapture); - auto lbs = vectorView.getLbs(); - auto ubs = vectorView.getUbs(); + auto lbs = vectorBoundsCapture.getLbs(); + auto ubs = vectorBoundsCapture.getUbs(); SmallVector steps; - steps.reserve(vectorView.getSteps().size()); - for (auto step : vectorView.getSteps()) - steps.push_back(constant_index(step)); + steps.reserve(vectorBoundsCapture.getSteps().size()); + for (auto step : vectorBoundsCapture.getSteps()) + steps.push_back(std_constant_index(step)); // 2. Emit alloc-copy-load-dealloc. - ValueHandle tmp = alloc(tmpMemRefType(transfer)); - IndexedValue local(tmp); + ValueHandle tmp = std_alloc(tmpMemRefType(transfer)); + StdIndexedValue local(tmp); ValueHandle vec = vector_type_cast(tmp); LoopNestBuilder(pivs, lbs, ubs, steps)([&] { // Computes clippedScalarAccessExprs in the loop nest scope (ivs exist). - local(ivs) = remote(clip(transfer, view, ivs)); + local(ivs) = remote(clip(transfer, memRefBoundsCapture, ivs)); }); ValueHandle vectorValue = std_load(vec); - (dealloc(tmp)); // vexing parse + (std_dealloc(tmp)); // vexing parse // 3. Propagate. 
rewriter.replaceOp(op, vectorValue.getValue()); @@ -308,42 +306,38 @@ template <> PatternMatchResult VectorTransferRewriter::matchAndRewrite( Operation *op, PatternRewriter &rewriter) const { - using namespace mlir::edsc; - using namespace mlir::edsc::op; - using namespace mlir::edsc::intrinsics; - using IndexedValue = - TemplatedIndexedValue; + using namespace edsc::op; TransferWriteOp transfer = cast(op); // 1. Setup all the captures. ScopedContext scope(rewriter, transfer.getLoc()); - IndexedValue remote(transfer.memref()); - MemRefView view(transfer.memref()); + StdIndexedValue remote(transfer.memref()); + MemRefBoundsCapture memRefBoundsCapture(transfer.memref()); ValueHandle vectorValue(transfer.vector()); - VectorView vectorView(transfer.vector()); - SmallVector ivs = makeIndexHandles(vectorView.rank()); + VectorBoundsCapture vectorBoundsCapture(transfer.vector()); + auto ivs = ValueHandle::makeIndexHandles(vectorBoundsCapture.rank()); SmallVector pivs = - makeHandlePointers(MutableArrayRef(ivs)); - coalesceCopy(transfer, &pivs, &vectorView); + makeHandlePointers(MutableArrayRef(ivs)); + coalesceCopy(transfer, &pivs, &vectorBoundsCapture); - auto lbs = vectorView.getLbs(); - auto ubs = vectorView.getUbs(); + auto lbs = vectorBoundsCapture.getLbs(); + auto ubs = vectorBoundsCapture.getUbs(); SmallVector steps; - steps.reserve(vectorView.getSteps().size()); - for (auto step : vectorView.getSteps()) - steps.push_back(constant_index(step)); + steps.reserve(vectorBoundsCapture.getSteps().size()); + for (auto step : vectorBoundsCapture.getSteps()) + steps.push_back(std_constant_index(step)); // 2. Emit alloc-store-copy-dealloc. 
- ValueHandle tmp = alloc(tmpMemRefType(transfer)); - IndexedValue local(tmp); + ValueHandle tmp = std_alloc(tmpMemRefType(transfer)); + StdIndexedValue local(tmp); ValueHandle vec = vector_type_cast(tmp); std_store(vectorValue, vec); LoopNestBuilder(pivs, lbs, ubs, steps)([&] { // Computes clippedScalarAccessExprs in the loop nest scope (ivs exist). - remote(clip(transfer, view, ivs)) = local(ivs); + remote(clip(transfer, memRefBoundsCapture, ivs)) = local(ivs); }); - (dealloc(tmp)); // vexing parse... + (std_dealloc(tmp)); // vexing parse... rewriter.eraseOp(op); return matchSuccess(); diff --git a/mlir/lib/Dialect/AffineOps/EDSC/Builders.cpp b/mlir/lib/Dialect/AffineOps/EDSC/Builders.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/AffineOps/EDSC/Builders.cpp @@ -0,0 +1,286 @@ +//===- Builders.cpp - MLIR Declarative Builder Classes --------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/AffineOps/EDSC/Builders.h" +#include "mlir/Dialect/StandardOps/EDSC/Builders.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" + +using namespace mlir; +using namespace mlir::edsc; + +static Optional emitStaticFor(ArrayRef lbs, + ArrayRef ubs, + int64_t step) { + if (lbs.size() != 1 || ubs.size() != 1) + return Optional(); + + auto *lbDef = lbs.front().getValue().getDefiningOp(); + auto *ubDef = ubs.front().getValue().getDefiningOp(); + if (!lbDef || !ubDef) + return Optional(); + + auto lbConst = dyn_cast(lbDef); + auto ubConst = dyn_cast(ubDef); + if (!lbConst || !ubConst) + return Optional(); + + return ValueHandle(ScopedContext::getBuilder() + .create(ScopedContext::getLocation(), + lbConst.getValue(), + ubConst.getValue(), step) + .getInductionVar()); +} + +LoopBuilder mlir::edsc::makeAffineLoopBuilder(ValueHandle *iv, + ArrayRef lbHandles, + ArrayRef ubHandles, + int64_t step) { + mlir::edsc::LoopBuilder result; + if (auto staticFor = emitStaticFor(lbHandles, ubHandles, step)) { + *iv = staticFor.getValue(); + } else { + SmallVector lbs(lbHandles.begin(), lbHandles.end()); + SmallVector ubs(ubHandles.begin(), ubHandles.end()); + auto b = ScopedContext::getBuilder(); + *iv = ValueHandle( + b.create(ScopedContext::getLocation(), lbs, + b.getMultiDimIdentityMap(lbs.size()), ubs, + b.getMultiDimIdentityMap(ubs.size()), step) + .getInductionVar()); + } + auto *body = getForInductionVarOwner(iv->getValue()).getBody(); + result.enter(body, /*prev=*/1); + return result; +} + +mlir::edsc::AffineLoopNestBuilder::AffineLoopNestBuilder( + ValueHandle *iv, ArrayRef lbs, ArrayRef ubs, + int64_t step) { + loops.emplace_back(makeAffineLoopBuilder(iv, lbs, ubs, step)); +} + +mlir::edsc::AffineLoopNestBuilder::AffineLoopNestBuilder( + ArrayRef ivs, ArrayRef lbs, + ArrayRef ubs, ArrayRef 
steps) { + assert(ivs.size() == lbs.size() && "Mismatch in number of arguments"); + assert(ivs.size() == ubs.size() && "Mismatch in number of arguments"); + assert(ivs.size() == steps.size() && "Mismatch in number of arguments"); + for (auto it : llvm::zip(ivs, lbs, ubs, steps)) + loops.emplace_back(makeAffineLoopBuilder(std::get<0>(it), std::get<1>(it), + std::get<2>(it), std::get<3>(it))); +} + +void mlir::edsc::AffineLoopNestBuilder::operator()( + function_ref fun) { + if (fun) + fun(); + // Iterate on the calling operator() on all the loops in the nest. + // The iteration order is from innermost to outermost because enter/exit needs + // to be asymmetric (i.e. enter() occurs on LoopBuilder construction, exit() + // occurs on calling operator()). The asymmetry is required for properly + // nesting imperfectly nested regions (see LoopBuilder::operator()). + for (auto lit = loops.rbegin(), eit = loops.rend(); lit != eit; ++lit) + (*lit)(); +} + +template +static ValueHandle createBinaryHandle(ValueHandle lhs, ValueHandle rhs) { + return ValueHandle::create(lhs.getValue(), rhs.getValue()); +} + +static std::pair +categorizeValueByAffineType(MLIRContext *context, Value val, unsigned &numDims, + unsigned &numSymbols) { + AffineExpr d; + Value resultVal = nullptr; + if (auto constant = dyn_cast_or_null(val.getDefiningOp())) { + d = getAffineConstantExpr(constant.getValue(), context); + } else if (isValidSymbol(val) && !isValidDim(val)) { + d = getAffineSymbolExpr(numSymbols++, context); + resultVal = val; + } else { + d = getAffineDimExpr(numDims++, context); + resultVal = val; + } + return std::make_pair(d, resultVal); +} + +static ValueHandle createBinaryIndexHandle( + ValueHandle lhs, ValueHandle rhs, + function_ref affCombiner) { + MLIRContext *context = ScopedContext::getContext(); + unsigned numDims = 0, numSymbols = 0; + AffineExpr d0, d1; + Value v0, v1; + std::tie(d0, v0) = + categorizeValueByAffineType(context, lhs.getValue(), numDims, numSymbols); + 
std::tie(d1, v1) = + categorizeValueByAffineType(context, rhs.getValue(), numDims, numSymbols); + SmallVector operands; + if (v0) { + operands.push_back(v0); + } + if (v1) { + operands.push_back(v1); + } + auto map = AffineMap::get(numDims, numSymbols, {affCombiner(d0, d1)}); + // TODO: createOrFold when available. + Operation *op = + makeComposedAffineApply(ScopedContext::getBuilder(), + ScopedContext::getLocation(), map, operands) + .getOperation(); + assert(op->getNumResults() == 1 && "Expected single result AffineApply"); + return ValueHandle(op->getResult(0)); +} + +template +static ValueHandle createBinaryHandle( + ValueHandle lhs, ValueHandle rhs, + function_ref affCombiner) { + auto thisType = lhs.getValue().getType(); + auto thatType = rhs.getValue().getType(); + assert(thisType == thatType && "cannot mix types in operators"); + (void)thisType; + (void)thatType; + if (thisType.isIndex()) { + return createBinaryIndexHandle(lhs, rhs, affCombiner); + } else if (thisType.isa()) { + return createBinaryHandle(lhs, rhs); + } else if (thisType.isa()) { + return createBinaryHandle(lhs, rhs); + } else if (thisType.isa() || thisType.isa()) { + auto aggregateType = thisType.cast(); + if (aggregateType.getElementType().isa()) + return createBinaryHandle(lhs, rhs); + else if (aggregateType.getElementType().isa()) + return createBinaryHandle(lhs, rhs); + } + llvm_unreachable("failed to create a ValueHandle"); +} + +ValueHandle mlir::edsc::op::operator+(ValueHandle lhs, ValueHandle rhs) { + return createBinaryHandle( + lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0 + d1; }); +} + +ValueHandle mlir::edsc::op::operator-(ValueHandle lhs, ValueHandle rhs) { + return createBinaryHandle( + lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0 - d1; }); +} + +ValueHandle mlir::edsc::op::operator*(ValueHandle lhs, ValueHandle rhs) { + return createBinaryHandle( + lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0 * d1; }); +} + +ValueHandle 
mlir::edsc::op::operator/(ValueHandle lhs, ValueHandle rhs) { + return createBinaryHandle( + lhs, rhs, [](AffineExpr d0, AffineExpr d1) -> AffineExpr { + llvm_unreachable("only exprs of non-index type support operator/"); + }); +} + +ValueHandle mlir::edsc::op::operator%(ValueHandle lhs, ValueHandle rhs) { + return createBinaryHandle( + lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0 % d1; }); +} + +ValueHandle mlir::edsc::op::floorDiv(ValueHandle lhs, ValueHandle rhs) { + return createBinaryIndexHandle( + lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0.floorDiv(d1); }); +} + +ValueHandle mlir::edsc::op::ceilDiv(ValueHandle lhs, ValueHandle rhs) { + return createBinaryIndexHandle( + lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0.ceilDiv(d1); }); +} + +ValueHandle mlir::edsc::op::operator!(ValueHandle value) { + assert(value.getType().isInteger(1) && "expected boolean expression"); + return ValueHandle::create(1, 1) - value; +} + +ValueHandle mlir::edsc::op::operator&&(ValueHandle lhs, ValueHandle rhs) { + assert(lhs.getType().isInteger(1) && "expected boolean expression on LHS"); + assert(rhs.getType().isInteger(1) && "expected boolean expression on RHS"); + return lhs * rhs; +} + +ValueHandle mlir::edsc::op::operator||(ValueHandle lhs, ValueHandle rhs) { + return !(!lhs && !rhs); +} + +static ValueHandle createIComparisonExpr(CmpIPredicate predicate, + ValueHandle lhs, ValueHandle rhs) { + auto lhsType = lhs.getType(); + auto rhsType = rhs.getType(); + (void)lhsType; + (void)rhsType; + assert(lhsType == rhsType && "cannot mix types in operators"); + assert((lhsType.isa() || lhsType.isa()) && + "only integer comparisons are supported"); + + auto op = ScopedContext::getBuilder().create( + ScopedContext::getLocation(), predicate, lhs.getValue(), rhs.getValue()); + return ValueHandle(op.getResult()); +} + +static ValueHandle createFComparisonExpr(CmpFPredicate predicate, + ValueHandle lhs, ValueHandle rhs) { + auto lhsType = lhs.getType(); + auto 
rhsType = rhs.getType(); + (void)lhsType; + (void)rhsType; + assert(lhsType == rhsType && "cannot mix types in operators"); + assert(lhsType.isa() && "only float comparisons are supported"); + + auto op = ScopedContext::getBuilder().create( + ScopedContext::getLocation(), predicate, lhs.getValue(), rhs.getValue()); + return ValueHandle(op.getResult()); +} + +// All floating point comparison are ordered through EDSL +ValueHandle mlir::edsc::op::operator==(ValueHandle lhs, ValueHandle rhs) { + auto type = lhs.getType(); + return type.isa() + ? createFComparisonExpr(CmpFPredicate::OEQ, lhs, rhs) + : createIComparisonExpr(CmpIPredicate::eq, lhs, rhs); +} +ValueHandle mlir::edsc::op::operator!=(ValueHandle lhs, ValueHandle rhs) { + auto type = lhs.getType(); + return type.isa() + ? createFComparisonExpr(CmpFPredicate::ONE, lhs, rhs) + : createIComparisonExpr(CmpIPredicate::ne, lhs, rhs); +} +ValueHandle mlir::edsc::op::operator<(ValueHandle lhs, ValueHandle rhs) { + auto type = lhs.getType(); + return type.isa() + ? createFComparisonExpr(CmpFPredicate::OLT, lhs, rhs) + : + // TODO(ntv,zinenko): signed by default, how about unsigned? + createIComparisonExpr(CmpIPredicate::slt, lhs, rhs); +} +ValueHandle mlir::edsc::op::operator<=(ValueHandle lhs, ValueHandle rhs) { + auto type = lhs.getType(); + return type.isa() + ? createFComparisonExpr(CmpFPredicate::OLE, lhs, rhs) + : createIComparisonExpr(CmpIPredicate::sle, lhs, rhs); +} +ValueHandle mlir::edsc::op::operator>(ValueHandle lhs, ValueHandle rhs) { + auto type = lhs.getType(); + return type.isa() + ? createFComparisonExpr(CmpFPredicate::OGT, lhs, rhs) + : createIComparisonExpr(CmpIPredicate::sgt, lhs, rhs); +} +ValueHandle mlir::edsc::op::operator>=(ValueHandle lhs, ValueHandle rhs) { + auto type = lhs.getType(); + return type.isa() + ? 
createFComparisonExpr(CmpFPredicate::OGE, lhs, rhs) + : createIComparisonExpr(CmpIPredicate::sge, lhs, rhs); +} diff --git a/mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp b/mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp --- a/mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp @@ -13,14 +13,15 @@ #include "mlir/Dialect/GPU/MemoryPromotion.h" #include "mlir/Dialect/GPU/GPUDialect.h" -#include "mlir/Dialect/LoopOps/LoopOps.h" -#include "mlir/EDSC/Builders.h" -#include "mlir/EDSC/Helpers.h" +#include "mlir/Dialect/LoopOps/EDSC/Builders.h" +#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" #include "mlir/Pass/Pass.h" #include "mlir/Support/Functional.h" #include "mlir/Transforms/LoopUtils.h" using namespace mlir; +using namespace mlir::edsc; +using namespace mlir::edsc::intrinsics; using namespace mlir::gpu; /// Returns the textual name of a GPU dimension. @@ -41,17 +42,17 @@ /// single-iteration loops. Maps the innermost loops to thread dimensions, in /// reverse order to enable access coalescing in the innermost loop. static void insertCopyLoops(OpBuilder &builder, Location loc, - edsc::MemRefView &bounds, Value from, Value to) { + MemRefBoundsCapture &bounds, Value from, Value to) { // Create EDSC handles for bounds. unsigned rank = bounds.rank(); - SmallVector lbs, ubs, steps; + SmallVector lbs, ubs, steps; // Make sure we have enough loops to use all thread dimensions, these trivial // loops should be outermost and therefore inserted first. 
if (rank < GPUDialect::getNumWorkgroupDimensions()) { unsigned extraLoops = GPUDialect::getNumWorkgroupDimensions() - rank; - edsc::ValueHandle zero = edsc::intrinsics::constant_index(0); - edsc::ValueHandle one = edsc::intrinsics::constant_index(1); + ValueHandle zero = std_constant_index(0); + ValueHandle one = std_constant_index(1); lbs.resize(extraLoops, zero); ubs.resize(extraLoops, one); steps.resize(extraLoops, one); @@ -63,9 +64,8 @@ // Emit constant operations for steps. steps.reserve(lbs.size()); - llvm::transform( - bounds.getSteps(), std::back_inserter(steps), - [](int64_t step) { return edsc::intrinsics::constant_index(step); }); + llvm::transform(bounds.getSteps(), std::back_inserter(steps), + [](int64_t step) { return std_constant_index(step); }); // Obtain thread identifiers and block sizes, necessary to map to them. auto indexType = builder.getIndexType(); @@ -79,12 +79,11 @@ } // Produce the loop nest with copies. - auto ivs = edsc::makeIndexHandles(lbs.size()); - auto ivPtrs = - edsc::makeHandlePointers(MutableArrayRef(ivs)); - edsc::LoopNestBuilder(ivPtrs, lbs, ubs, steps)([&]() { + SmallVector ivs(lbs.size(), ValueHandle(indexType)); + auto ivPtrs = makeHandlePointers(MutableArrayRef(ivs)); + LoopNestBuilder(ivPtrs, lbs, ubs, steps)([&]() { auto activeIvs = llvm::makeArrayRef(ivs).take_back(rank); - edsc::StdIndexedValue fromHandle(from), toHandle(to); + StdIndexedValue fromHandle(from), toHandle(to); toHandle(activeIvs) = fromHandle(activeIvs); }); @@ -146,14 +145,14 @@ OpBuilder builder(region.getContext()); builder.setInsertionPointToStart(&region.front()); - edsc::ScopedContext edscContext(builder, loc); - edsc::MemRefView fromView(from); - insertCopyLoops(builder, loc, fromView, from, to); + ScopedContext edscContext(builder, loc); + MemRefBoundsCapture fromBoundsCapture(from); + insertCopyLoops(builder, loc, fromBoundsCapture, from, to); builder.create(loc); builder.setInsertionPoint(&region.front().back()); builder.create(loc); - 
insertCopyLoops(builder, loc, fromView, to, from); + insertCopyLoops(builder, loc, fromBoundsCapture, to, from); } /// Promotes a function argument to workgroup memory in the given function. The diff --git a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp --- a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp +++ b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp @@ -6,20 +6,18 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Linalg/EDSC/Builders.h" +#include "mlir/IR/Builders.h" +#include "mlir/Dialect/AffineOps/EDSC/Intrinsics.h" #include "mlir/Dialect/Linalg/EDSC/Intrinsics.h" -#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/LoopOps/EDSC/Builders.h" +#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" -#include "mlir/EDSC/Builders.h" -#include "mlir/EDSC/Intrinsics.h" #include "mlir/IR/AffineExpr.h" -#include "mlir/IR/Builders.h" #include "mlir/Support/Functional.h" using namespace mlir; using namespace mlir::edsc; using namespace mlir::edsc::intrinsics; -using namespace mlir::edsc::ops; using namespace mlir::linalg; using namespace mlir::loop; @@ -261,8 +259,8 @@ Operation *mlir::edsc::ops::linalg_pointwise_tanh(StructuredIndexed I, StructuredIndexed O) { - using edsc::intrinsics::tanh; - UnaryPointwiseOpBuilder unOp([](ValueHandle a) -> Value { return tanh(a); }); + UnaryPointwiseOpBuilder unOp( + [](ValueHandle a) -> Value { return std_tanh(a); }); return linalg_pointwise(unOp, I, O); } @@ -302,9 +300,8 @@ StructuredIndexed I2, StructuredIndexed O) { BinaryPointwiseOpBuilder binOp([](ValueHandle a, ValueHandle b) -> Value { - using edsc::intrinsics::select; using edsc::op::operator>; - return select(a > b, a, b).getValue(); + return std_select(a > b, a, b).getValue(); }); return linalg_pointwise(binOp, I1, I2, O); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp 
b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -12,11 +12,12 @@ #include "mlir/Analysis/Dominance.h" #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" +#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" -#include "mlir/EDSC/Helpers.h" +#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/OpImplementation.h" @@ -36,6 +37,8 @@ using namespace mlir::edsc::intrinsics; using namespace mlir::linalg; +using folded_std_constant_index = folded::ValueBuilder; + using llvm::dbgs; /// Implements a simple high-level fusion pass of linalg library operations. @@ -188,9 +191,11 @@ << "existing LoopRange: " << loopRanges[i] << "\n"); else { auto viewDim = getViewDefiningLoopRange(producer, i); - loopRanges[i] = SubViewOp::Range{constant_index(folder, 0), - dim(viewDim.view, viewDim.dimension), - constant_index(folder, 1)}; + loopRanges[i] = SubViewOp::Range{ + folded_std_constant_index(folder, 0), + std_dim(viewDim.view, viewDim.dimension), + folded_std_constant_index(folder, 1) + }; LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n"); } } diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp @@ -6,16 +6,15 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Linalg/EDSC/Builders.h" +#include "mlir/Dialect/AffineOps/EDSC/Intrinsics.h" #include "mlir/Dialect/Linalg/EDSC/Intrinsics.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include 
"mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/Linalg/Transforms/LinalgTransforms.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" -#include "mlir/Dialect/LoopOps/LoopOps.h" -#include "mlir/Dialect/StandardOps/Ops.h" -#include "mlir/EDSC/Helpers.h" +#include "mlir/Dialect/LoopOps/EDSC/Builders.h" +#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/BlockAndValueMapping.h" @@ -31,9 +30,6 @@ using namespace mlir::edsc::intrinsics; using namespace mlir::linalg; -using IndexedStdValue = TemplatedIndexedValue; -using IndexedAffineValue = TemplatedIndexedValue; - using edsc::op::operator+; using edsc::op::operator==; @@ -77,7 +73,7 @@ SmallVector res; for (unsigned idx = 0, e = map.getNumResults(); idx < e; ++idx) { res.push_back( - linalg_range(constant_index(0), sizes[idx], constant_index(1))); + linalg_range(std_constant_index(0), sizes[idx], std_constant_index(1))); } return res; } @@ -98,8 +94,8 @@ permuteIvs(allIvs.take_front(nPar), copyOp.inputPermutation()); auto outputIvs = permuteIvs(allIvs.take_front(nPar), copyOp.outputPermutation()); - SmallVector iivs(inputIvs.begin(), inputIvs.end()); - SmallVector oivs(outputIvs.begin(), outputIvs.end()); + SmallVector iivs(inputIvs.begin(), inputIvs.end()); + SmallVector oivs(outputIvs.begin(), outputIvs.end()); IndexedValueType O(copyOp.getOutputBuffer(0)), I(copyOp.getInput(0)); // Emit the proper scalar assignment, whether we are dealing with a 0-D or // an n-D loop nest; with or without permutations. @@ -119,7 +115,7 @@ auto nPar = fillOp.getNumParallelLoops(); assert(nPar == allIvs.size()); auto ivs = - SmallVector(allIvs.begin(), allIvs.begin() + nPar); + SmallVector(allIvs.begin(), allIvs.begin() + nPar); IndexedValueType O(fillOp.getOutputBuffer(0)); // Emit the proper scalar assignment, whether we are dealing with a 0-D or // an n-D loop nest; with or without permutations. 
@@ -135,7 +131,7 @@ assert(dotOp.hasBufferSemantics() && "expected linalg op with buffer semantics"); assert(allIvs.size() == 1); - IndexHandle r_i(allIvs[0]); + ValueHandle r_i(allIvs[0]); IndexedValueType A(dotOp.getInput(0)), B(dotOp.getInput(1)), C(dotOp.getOutputBuffer(0)); // Emit scalar form. @@ -151,7 +147,7 @@ assert(matvecOp.hasBufferSemantics() && "expected linalg op with buffer semantics"); assert(allIvs.size() == 2); - IndexHandle i(allIvs[0]), r_j(allIvs[1]); + ValueHandle i(allIvs[0]), r_j(allIvs[1]); IndexedValueType A(matvecOp.getInput(0)), B(matvecOp.getInput(1)), C(matvecOp.getOutputBuffer(0)); // Emit scalar form. @@ -167,7 +163,7 @@ assert(matmulOp.hasBufferSemantics() && "expected linalg op with buffer semantics"); assert(allIvs.size() == 3); - IndexHandle i(allIvs[0]), j(allIvs[1]), r_k(allIvs[2]); + ValueHandle i(allIvs[0]), j(allIvs[1]), r_k(allIvs[2]); IndexedValueType A(matmulOp.getInput(0)), B(matmulOp.getInput(1)), C(matmulOp.getOutputBuffer(0)); // Emit scalar form. @@ -258,7 +254,7 @@ auto funcOp = genericOp.getFunction(); if (funcOp) { // 2. Emit call. - Operation *callOp = call(funcOp, indexedValues); + Operation *callOp = std_call(funcOp, indexedValues); assert(callOp->getNumResults() == genericOp.getNumOutputs()); // 3. Emit std_store. @@ -359,7 +355,7 @@ if (auto funcOp = indexedGenericOp.getFunction()) { // 2. Emit call. - Operation *callOp = call(funcOp, indexedValues); + Operation *callOp = std_call(funcOp, indexedValues); assert(callOp->getNumResults() == indexedGenericOp.getNumOutputs()); // 3. Emit std_store. 
@@ -442,15 +438,15 @@ return success(); } - SmallVector allIvs(nLoops); + SmallVector allIvs(nLoops, ValueHandle(b.getIndexType())); SmallVector allPIvs = - makeHandlePointers(MutableArrayRef(allIvs)); + makeHandlePointers(MutableArrayRef(allIvs)); auto loopRanges = emitLoopRanges(scope.getBuilder(), scope.getLocation(), invertedMap, getViewSizes(b, linalgOp)); assert(loopRanges.size() == allIvs.size()); GenericLoopNestRangeBuilder(allPIvs, loopRanges)([&] { - auto allIvValues = extractValues(allIvs); + SmallVector allIvValues(allIvs.begin(), allIvs.end()); LinalgScopedEmitter::emitScalarImplementation( allIvValues, linalgOp); }); @@ -568,26 +564,26 @@ std::unique_ptr> mlir::createConvertLinalgToLoopsPass() { return std::make_unique< - LowerLinalgToLoopsPass>(); + LowerLinalgToLoopsPass>(); } std::unique_ptr> mlir::createConvertLinalgToParallelLoopsPass() { return std::make_unique< - LowerLinalgToLoopsPass>(); + LowerLinalgToLoopsPass>(); } std::unique_ptr> mlir::createConvertLinalgToAffineLoopsPass() { return std::make_unique< - LowerLinalgToLoopsPass>(); + LowerLinalgToLoopsPass>(); } /// Emits a loop nest of `loop.for` with the proper body for `op`. 
template LogicalResult mlir::linalg::linalgOpToLoops(PatternRewriter &rewriter, Operation *op) { - return LinalgOpToLoopsImpl::doit( + return LinalgOpToLoopsImpl::doit( op, rewriter); } @@ -595,7 +591,7 @@ template LogicalResult mlir::linalg::linalgOpToAffineLoops(PatternRewriter &rewriter, Operation *op) { - return LinalgOpToLoopsImpl::doit( + return LinalgOpToLoopsImpl::doit( op, rewriter); } @@ -603,7 +599,7 @@ template LogicalResult mlir::linalg::linalgOpToParallelLoops(PatternRewriter &rewriter, Operation *op) { - return LinalgOpToLoopsImpl::doit(op, rewriter); } @@ -630,18 +626,18 @@ mlir::linalg::linalgOpToParallelLoops(PatternRewriter &rewriter, Operation *op); -static PassRegistration> +static PassRegistration> structuredLoopsPass( "convert-linalg-to-loops", "Lower the operations from the linalg dialect into loops"); static PassRegistration< - LowerLinalgToLoopsPass> + LowerLinalgToLoopsPass> parallelLoopsPass( "convert-linalg-to-parallel-loops", "Lower the operations from the linalg dialect into parallel loops"); -static PassRegistration> +static PassRegistration> affineLoopsPass( "convert-linalg-to-affine-loops", "Lower the operations from the linalg dialect into affine loops"); diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp @@ -14,9 +14,8 @@ #include "mlir/Dialect/Linalg/Analysis/DependenceAnalysis.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" #include "mlir/Dialect/VectorOps/VectorOps.h" -#include "mlir/EDSC/Helpers.h" -#include "mlir/EDSC/Intrinsics.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" @@ -191,8 +190,6 @@ SmallVector mlir::linalg::vectorizeLinalgOp(PatternRewriter &rewriter, Operation *op) 
{ - using edsc::intrinsics::std_load; - using edsc::intrinsics::std_store; using vector_contract = edsc::intrinsics::ValueBuilder; using vector_broadcast = edsc::intrinsics::ValueBuilder; using vector_type_cast = edsc::intrinsics::ValueBuilder; diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp @@ -10,13 +10,14 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/AffineOps/EDSC/Intrinsics.h" #include "mlir/Dialect/Linalg/EDSC/Intrinsics.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/LoopOps/LoopOps.h" -#include "mlir/EDSC/Helpers.h" +#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/AffineMap.h" @@ -35,6 +36,13 @@ using namespace mlir::linalg; using namespace mlir::loop; +using folded_affine_min = folded::ValueBuilder; +using folded_std_constant_index = folded::ValueBuilder; +using folded_std_constant_float = folded::ValueBuilder; +using folded_std_dim = folded::ValueBuilder; +using folded_std_muli = folded::ValueBuilder; +using folded_linalg_range = folded::ValueBuilder; + using llvm::SetVector; #define DEBUG_TYPE "linalg-promotion" @@ -50,10 +58,10 @@ auto width = llvm::divideCeil(elementType.getIntOrFloatBitWidth(), 8); if (!dynamicBuffers) if (auto cst = dyn_cast_or_null(size.getDefiningOp())) - return alloc( + return std_alloc( MemRefType::get(width * cst.getValue(), IntegerType::get(8, ctx))); - Value mul = muli(constant_index(width), size); - return alloc(MemRefType::get(-1, IntegerType::get(8, ctx)), mul); + Value mul = std_muli(std_constant_index(width), size); + return 
std_alloc(MemRefType::get(-1, IntegerType::get(8, ctx)), mul); } // Performs promotion of a `subView` into a local buffer of the size of the @@ -77,8 +85,8 @@ SubViewOp subView, bool dynamicBuffers, OperationFolder *folder) { - auto zero = constant_index(folder, 0); - auto one = constant_index(folder, 1); + auto zero = folded_std_constant_index(folder, 0); + auto one = folded_std_constant_index(folder, 1); auto viewType = subView.getType(); auto rank = viewType.getRank(); @@ -90,15 +98,15 @@ auto rank = en.index(); auto rangeValue = en.value(); Value d = rangeValue.size; - allocSize = muli(folder, allocSize, d).getValue(); + allocSize = folded_std_muli(folder, allocSize, d).getValue(); fullRanges.push_back(d); partialRanges.push_back( - linalg_range(folder, zero, dim(subView, rank), one)); + folded_linalg_range(folder, zero, std_dim(subView, rank), one)); } SmallVector dynSizes(fullRanges.size(), -1); auto buffer = allocBuffer(viewType.getElementType(), allocSize, dynamicBuffers); - auto fullLocalView = view( + auto fullLocalView = std_view( MemRefType::get(dynSizes, viewType.getElementType()), buffer, fullRanges); auto partialLocalView = linalg_slice(fullLocalView, partialRanges); return PromotionInfo{buffer, fullLocalView, partialLocalView}; @@ -135,7 +143,7 @@ // TODO(ntv): value to fill with should be related to the operation. // For now, just use APFloat(0.0f). auto t = subView.getType().getElementType().cast(); - Value fillVal = constant_float(folder, APFloat(0.0f), t); + Value fillVal = folded_std_constant_float(folder, APFloat(0.0f), t); // TODO(ntv): fill is only necessary if `promotionInfo` has a full local // view that is different from the partial local view and we are on the // boundary. @@ -198,7 +206,7 @@ // 4. Dealloc local buffers. 
for (const auto &pi : promotedBufferAndViews) - dealloc(pi.buffer); + std_dealloc(pi.buffer); return res; } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -10,13 +10,13 @@ // //===----------------------------------------------------------------------===// -#include "mlir/Dialect/Linalg/EDSC/Builders.h" -#include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/AffineOps/EDSC/Intrinsics.h" +#include "mlir/Dialect/Linalg/EDSC/Intrinsics.h" #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" -#include "mlir/Dialect/LoopOps/LoopOps.h" -#include "mlir/EDSC/Helpers.h" +#include "mlir/Dialect/LoopOps/EDSC/Builders.h" +#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineExprVisitor.h" #include "mlir/IR/AffineMap.h" @@ -34,6 +34,10 @@ using namespace mlir::linalg; using namespace mlir::loop; +using folded_affine_min = folded::ValueBuilder; +using folded_std_constant_index = folded::ValueBuilder; +using folded_std_dim = folded::ValueBuilder; + #define DEBUG_TYPE "linalg-tiling" static llvm::cl::OptionCategory clOptionsCategory(DEBUG_TYPE " options"); @@ -83,8 +87,8 @@ // Create a new range with the applied tile sizes. 
SmallVector res; for (unsigned idx = 0, e = tileSizes.size(); idx < e; ++idx) { - res.push_back(SubViewOp::Range{constant_index(folder, 0), viewSizes[idx], - tileSizes[idx]}); + res.push_back(SubViewOp::Range{folded_std_constant_index(folder, 0), + viewSizes[idx], tileSizes[idx]}); } return std::make_tuple(res, loopIndexToRangeIndex); } @@ -239,16 +243,15 @@ [](Value v) { return !isZero(v); })) && "expected as many ivs as non-zero sizes"); - using edsc::intrinsics::select; - using edsc::op::operator+; - using edsc::op::operator<; + using namespace edsc::op; // Construct (potentially temporary) mins and maxes on which to apply maps // that define tile subviews. SmallVector lbs, subViewSizes; for (unsigned idx = 0, idxIvs = 0, e = tileSizes.size(); idx < e; ++idx) { bool isTiled = !isZero(tileSizes[idx]); - lbs.push_back(isTiled ? ivs[idxIvs++] : (Value)constant_index(folder, 0)); + lbs.push_back(isTiled ? ivs[idxIvs++] + : (Value)folded_std_constant_index(folder, 0)); subViewSizes.push_back(isTiled ? 
tileSizes[idx] : viewSizes[idx]); } @@ -276,9 +279,9 @@ strides.reserve(rank); for (unsigned r = 0; r < rank; ++r) { if (!isTiled(map.getSubMap({r}), tileSizes)) { - offsets.push_back(constant_index(folder, 0)); - sizes.push_back(dim(view, r)); - strides.push_back(constant_index(folder, 1)); + offsets.push_back(folded_std_constant_index(folder, 0)); + sizes.push_back(std_dim(view, r)); + strides.push_back(folded_std_constant_index(folder, 1)); continue; } @@ -302,13 +305,13 @@ {getAffineDimExpr(/*position=*/0, b.getContext()), getAffineDimExpr(/*position=*/1, b.getContext()) - getAffineDimExpr(/*position=*/2, b.getContext())}); - auto d = dim(folder, view, r); - size = affine_min(folder, b.getIndexType(), minMap, - ValueRange{size, d, offset}); + auto d = folded_std_dim(folder, view, r); + size = folded_affine_min(folder, b.getIndexType(), minMap, + ValueRange{size, d, offset}); } sizes.push_back(size); - strides.push_back(constant_index(folder, 1)); + strides.push_back(folded_std_constant_index(folder, 1)); } res.push_back(b.create(loc, view, offsets, sizes, strides)); @@ -367,8 +370,8 @@ // 3. Create the tiled loops. LinalgOp res = op; - SmallVector ivs(loopRanges.size()); - auto pivs = makeHandlePointers(MutableArrayRef(ivs)); + auto ivs = ValueHandle::makeIndexHandles(loopRanges.size()); + auto pivs = makeHandlePointers(MutableArrayRef(ivs)); // Convert SubViewOp::Range to linalg_range. SmallVector linalgRanges; for (auto &range : loopRanges) { @@ -434,11 +437,11 @@ SmallVector tileSizeValues; tileSizeValues.reserve(tileSizes.size()); for (auto ts : tileSizes) - tileSizeValues.push_back(constant_index(folder, ts)); + tileSizeValues.push_back(folded_std_constant_index(folder, ts)); // Pad tile sizes with zero values to enforce our convention. 
if (tileSizeValues.size() < nLoops) { for (unsigned i = tileSizeValues.size(); i < nLoops; ++i) - tileSizeValues.push_back(constant_index(folder, 0)); + tileSizeValues.push_back(folded_std_constant_index(folder, 0)); } return tileLinalgOpImpl(b, op, tileSizeValues, permutation, folder); diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -11,11 +11,11 @@ //===----------------------------------------------------------------------===// #include "mlir/Dialect/Linalg/Utils/Utils.h" +#include "mlir/Dialect/AffineOps/AffineOps.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" #include "mlir/Dialect/Linalg/IR/LinalgTypes.h" #include "mlir/Dialect/LoopOps/LoopOps.h" #include "mlir/Dialect/StandardOps/Ops.h" -#include "mlir/EDSC/Helpers.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/Matchers.h" @@ -25,8 +25,6 @@ #include "mlir/Transforms/FoldUtils.h" using namespace mlir; -using namespace mlir::edsc; -using namespace mlir::edsc::intrinsics; using namespace mlir::linalg; using namespace mlir::loop; diff --git a/mlir/lib/Dialect/LoopOps/EDSC/Builders.cpp b/mlir/lib/Dialect/LoopOps/EDSC/Builders.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/LoopOps/EDSC/Builders.cpp @@ -0,0 +1,92 @@ +//===- Builders.cpp - MLIR Declarative Builder Classes --------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/LoopOps/EDSC/Builders.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" + +using namespace mlir; +using namespace mlir::edsc; + +mlir::edsc::ParallelLoopNestBuilder::ParallelLoopNestBuilder( + ArrayRef ivs, ArrayRef lbs, + ArrayRef ubs, ArrayRef steps) { + assert(ivs.size() == lbs.size() && "Mismatch in number of arguments"); + assert(ivs.size() == ubs.size() && "Mismatch in number of arguments"); + assert(ivs.size() == steps.size() && "Mismatch in number of arguments"); + + loops.emplace_back(makeParallelLoopBuilder(ivs, lbs, ubs, steps)); +} + +void mlir::edsc::ParallelLoopNestBuilder::operator()( + function_ref fun) { + if (fun) + fun(); + // Iterate on the calling operator() on all the loops in the nest. + // The iteration order is from innermost to outermost because enter/exit needs + // to be asymmetric (i.e. enter() occurs on LoopBuilder construction, exit() + // occurs on calling operator()). The asymmetry is required for properly + // nesting imperfectly nested regions (see LoopBuilder::operator()). 
+ for (auto lit = loops.rbegin(), eit = loops.rend(); lit != eit; ++lit) + (*lit)(); +} + +mlir::edsc::LoopNestBuilder::LoopNestBuilder(ArrayRef ivs, + ArrayRef lbs, + ArrayRef ubs, + ArrayRef steps) { + assert(ivs.size() == lbs.size() && "expected size of ivs and lbs to match"); + assert(ivs.size() == ubs.size() && "expected size of ivs and ubs to match"); + assert(ivs.size() == steps.size() && + "expected size of ivs and steps to match"); + loops.reserve(ivs.size()); + for (auto it : llvm::zip(ivs, lbs, ubs, steps)) { + loops.emplace_back(makeLoopBuilder(std::get<0>(it), std::get<1>(it), + std::get<2>(it), std::get<3>(it))); + } + assert(loops.size() == ivs.size() && "Mismatch loops vs ivs size"); +} + +void mlir::edsc::LoopNestBuilder::LoopNestBuilder::operator()( + std::function fun) { + if (fun) + fun(); + for (auto &lit : reverse(loops)) + lit({}); +} + +LoopBuilder mlir::edsc::makeParallelLoopBuilder(ArrayRef ivs, + ArrayRef lbHandles, + ArrayRef ubHandles, + ArrayRef steps) { + LoopBuilder result; + auto opHandle = OperationHandle::create( + SmallVector(lbHandles.begin(), lbHandles.end()), + SmallVector(ubHandles.begin(), ubHandles.end()), + SmallVector(steps.begin(), steps.end())); + + loop::ParallelOp parallelOp = + cast(*opHandle.getOperation()); + for (size_t i = 0, e = ivs.size(); i < e; ++i) + *ivs[i] = ValueHandle(parallelOp.getBody()->getArgument(i)); + result.enter(parallelOp.getBody(), /*prev=*/1); + return result; +} + +mlir::edsc::LoopBuilder mlir::edsc::makeLoopBuilder(ValueHandle *iv, + ValueHandle lbHandle, + ValueHandle ubHandle, + ValueHandle stepHandle) { + mlir::edsc::LoopBuilder result; + auto forOp = + OperationHandle::createOp(lbHandle, ubHandle, stepHandle); + *iv = ValueHandle(forOp.getInductionVar()); + auto *body = loop::getForInductionVarOwner(iv->getValue()).getBody(); + result.enter(body, /*prev=*/1); + return result; +} diff --git a/mlir/lib/Dialect/StandardOps/CMakeLists.txt b/mlir/lib/Dialect/StandardOps/CMakeLists.txt --- 
a/mlir/lib/Dialect/StandardOps/CMakeLists.txt +++ b/mlir/lib/Dialect/StandardOps/CMakeLists.txt @@ -1,6 +1,7 @@ file(GLOB globbed *.c *.cpp) add_llvm_library(MLIRStandardOps ${globbed} + EDSC/Intrinsics.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/StandardOps diff --git a/mlir/lib/EDSC/Helpers.cpp b/mlir/lib/Dialect/StandardOps/EDSC/Builders.cpp rename from mlir/lib/EDSC/Helpers.cpp rename to mlir/lib/Dialect/StandardOps/EDSC/Builders.cpp --- a/mlir/lib/EDSC/Helpers.cpp +++ b/mlir/lib/Dialect/StandardOps/EDSC/Builders.cpp @@ -1,4 +1,4 @@ -//===- Helpers.cpp - MLIR Declarative Helper Functionality ----------------===// +//===- Builders.cpp - MLIR Declarative Builder Classes --------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,12 +6,13 @@ // //===----------------------------------------------------------------------===// -#include "mlir/EDSC/Helpers.h" -#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" #include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" using namespace mlir; using namespace mlir::edsc; +using namespace mlir::edsc::intrinsics; static SmallVector getMemRefSizes(Value memRef) { MemRefType memRefType = memRef.getType().cast(); @@ -21,32 +22,28 @@ res.reserve(memRefType.getShape().size()); const auto &shape = memRefType.getShape(); for (unsigned idx = 0, n = shape.size(); idx < n; ++idx) { - if (shape[idx] == -1) { - res.push_back(ValueHandle::create(memRef, idx)); - } else { - res.push_back(static_cast(shape[idx])); - } + if (shape[idx] == -1) + res.push_back(std_dim(memRef, idx)); + else + res.push_back(std_constant_index(shape[idx])); } return res; } -mlir::edsc::MemRefView::MemRefView(Value v) : base(v) { - assert(v.getType().isa() && "MemRefType expected"); - +mlir::edsc::MemRefBoundsCapture::MemRefBoundsCapture(Value v) : base(v) { auto 
memrefSizeValues = getMemRefSizes(v); - for (auto &size : memrefSizeValues) { - lbs.push_back(static_cast(0)); - ubs.push_back(size); + for (auto s : memrefSizeValues) { + lbs.push_back(std_constant_index(0)); + ubs.push_back(s); steps.push_back(1); } } -mlir::edsc::VectorView::VectorView(Value v) : base(v) { +mlir::edsc::VectorBoundsCapture::VectorBoundsCapture(Value v) : base(v) { auto vectorType = v.getType().cast(); - for (auto s : vectorType.getShape()) { - lbs.push_back(static_cast(0)); - ubs.push_back(static_cast(s)); + lbs.push_back(std_constant_index(0)); + ubs.push_back(std_constant_index(s)); steps.push_back(1); } } diff --git a/mlir/lib/EDSC/Intrinsics.cpp b/mlir/lib/Dialect/StandardOps/EDSC/Intrinsics.cpp rename from mlir/lib/EDSC/Intrinsics.cpp rename to mlir/lib/Dialect/StandardOps/EDSC/Intrinsics.cpp --- a/mlir/lib/EDSC/Intrinsics.cpp +++ b/mlir/lib/Dialect/StandardOps/EDSC/Intrinsics.cpp @@ -6,8 +6,7 @@ // //===----------------------------------------------------------------------===// -#include "mlir/EDSC/Intrinsics.h" -#include "mlir/EDSC/Builders.h" +#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" #include "mlir/IR/AffineExpr.h" using namespace mlir; diff --git a/mlir/lib/EDSC/Builders.cpp b/mlir/lib/EDSC/Builders.cpp --- a/mlir/lib/EDSC/Builders.cpp +++ b/mlir/lib/EDSC/Builders.cpp @@ -7,8 +7,8 @@ //===----------------------------------------------------------------------===// #include "mlir/EDSC/Builders.h" -#include "mlir/Dialect/StandardOps/Ops.h" #include "mlir/IR/AffineExpr.h" +#include "mlir/IR/AffineMap.h" #include "llvm/ADT/Optional.h" @@ -65,13 +65,6 @@ return getBuilder().getContext(); } -mlir::edsc::ValueHandle::ValueHandle(index_type cst) { - auto &b = ScopedContext::getBuilder(); - auto loc = ScopedContext::getLocation(); - v = b.create(loc, cst.v).getResult(); - t = v.getType(); -} - ValueHandle &mlir::edsc::ValueHandle::operator=(const ValueHandle &other) { assert(t == other.t && "Wrong type capture"); assert(!v && 
"ValueHandle has already been captured, use a new name!"); @@ -79,28 +72,13 @@ return *this; } -ValueHandle -mlir::edsc::ValueHandle::createComposedAffineApply(AffineMap map, - ArrayRef operands) { - Operation *op = - makeComposedAffineApply(ScopedContext::getBuilder(), - ScopedContext::getLocation(), map, operands) - .getOperation(); - assert(op->getNumResults() == 1 && "Not a single result AffineApply"); - return ValueHandle(op->getResult(0)); -} - ValueHandle ValueHandle::create(StringRef name, ArrayRef operands, ArrayRef resultTypes, ArrayRef attributes) { Operation *op = OperationHandle::create(name, operands, resultTypes, attributes); - if (op->getNumResults() == 1) { + if (op->getNumResults() == 1) return ValueHandle(op->getResult(0)); - } - if (auto f = dyn_cast(op)) { - return ValueHandle(f.getInductionVar()); - } llvm_unreachable("unsupported operation, use an OperationHandle instead"); } @@ -149,75 +127,6 @@ return res; } -static Optional emitStaticFor(ArrayRef lbs, - ArrayRef ubs, - int64_t step) { - if (lbs.size() != 1 || ubs.size() != 1) - return Optional(); - - auto *lbDef = lbs.front().getValue().getDefiningOp(); - auto *ubDef = ubs.front().getValue().getDefiningOp(); - if (!lbDef || !ubDef) - return Optional(); - - auto lbConst = dyn_cast(lbDef); - auto ubConst = dyn_cast(ubDef); - if (!lbConst || !ubConst) - return Optional(); - - return ValueHandle::create(lbConst.getValue(), - ubConst.getValue(), step); -} - -mlir::edsc::LoopBuilder mlir::edsc::LoopBuilder::makeAffine( - ValueHandle *iv, ArrayRef lbHandles, - ArrayRef ubHandles, int64_t step) { - mlir::edsc::LoopBuilder result; - if (auto staticFor = emitStaticFor(lbHandles, ubHandles, step)) { - *iv = staticFor.getValue(); - } else { - SmallVector lbs(lbHandles.begin(), lbHandles.end()); - SmallVector ubs(ubHandles.begin(), ubHandles.end()); - *iv = ValueHandle::create( - lbs, ScopedContext::getBuilder().getMultiDimIdentityMap(lbs.size()), - ubs, 
ScopedContext::getBuilder().getMultiDimIdentityMap(ubs.size()), - step); - } - auto *body = getForInductionVarOwner(iv->getValue()).getBody(); - result.enter(body, /*prev=*/1); - return result; -} - -mlir::edsc::LoopBuilder mlir::edsc::LoopBuilder::makeParallel( - ArrayRef ivs, ArrayRef lbHandles, - ArrayRef ubHandles, ArrayRef steps) { - mlir::edsc::LoopBuilder result; - auto opHandle = OperationHandle::create( - SmallVector(lbHandles.begin(), lbHandles.end()), - SmallVector(ubHandles.begin(), ubHandles.end()), - SmallVector(steps.begin(), steps.end())); - - loop::ParallelOp parallelOp = - cast(*opHandle.getOperation()); - for (size_t i = 0, e = ivs.size(); i < e; ++i) - *ivs[i] = ValueHandle(parallelOp.getBody()->getArgument(i)); - result.enter(parallelOp.getBody(), /*prev=*/1); - return result; -} - -mlir::edsc::LoopBuilder -mlir::edsc::LoopBuilder::makeLoop(ValueHandle *iv, ValueHandle lbHandle, - ValueHandle ubHandle, - ValueHandle stepHandle) { - mlir::edsc::LoopBuilder result; - auto forOp = - OperationHandle::createOp(lbHandle, ubHandle, stepHandle); - *iv = ValueHandle(forOp.getInductionVar()); - auto *body = loop::getForInductionVarOwner(iv->getValue()).getBody(); - result.enter(body, /*prev=*/1); - return result; -} - void mlir::edsc::LoopBuilder::operator()(function_ref fun) { // Call to `exit` must be explicit and asymmetric (cannot happen in the // destructor) because of ordering wrt comma operator. 
@@ -242,83 +151,6 @@ exit(); } -mlir::edsc::AffineLoopNestBuilder::AffineLoopNestBuilder( - ValueHandle *iv, ArrayRef lbs, ArrayRef ubs, - int64_t step) { - loops.emplace_back(LoopBuilder::makeAffine(iv, lbs, ubs, step)); -} - -mlir::edsc::AffineLoopNestBuilder::AffineLoopNestBuilder( - ArrayRef ivs, ArrayRef lbs, - ArrayRef ubs, ArrayRef steps) { - assert(ivs.size() == lbs.size() && "Mismatch in number of arguments"); - assert(ivs.size() == ubs.size() && "Mismatch in number of arguments"); - assert(ivs.size() == steps.size() && "Mismatch in number of arguments"); - for (auto it : llvm::zip(ivs, lbs, ubs, steps)) - loops.emplace_back(LoopBuilder::makeAffine( - std::get<0>(it), std::get<1>(it), std::get<2>(it), std::get<3>(it))); -} - -void mlir::edsc::AffineLoopNestBuilder::operator()( - function_ref fun) { - if (fun) - fun(); - // Iterate on the calling operator() on all the loops in the nest. - // The iteration order is from innermost to outermost because enter/exit needs - // to be asymmetric (i.e. enter() occurs on LoopBuilder construction, exit() - // occurs on calling operator()). The asymmetry is required for properly - // nesting imperfectly nested regions (see LoopBuilder::operator()). - for (auto lit = loops.rbegin(), eit = loops.rend(); lit != eit; ++lit) - (*lit)(); -} - -mlir::edsc::ParallelLoopNestBuilder::ParallelLoopNestBuilder( - ArrayRef ivs, ArrayRef lbs, - ArrayRef ubs, ArrayRef steps) { - assert(ivs.size() == lbs.size() && "Mismatch in number of arguments"); - assert(ivs.size() == ubs.size() && "Mismatch in number of arguments"); - assert(ivs.size() == steps.size() && "Mismatch in number of arguments"); - - loops.emplace_back(LoopBuilder::makeParallel(ivs, lbs, ubs, steps)); -} - -void mlir::edsc::ParallelLoopNestBuilder::operator()( - function_ref fun) { - if (fun) - fun(); - // Iterate on the calling operator() on all the loops in the nest. 
- // The iteration order is from innermost to outermost because enter/exit needs - // to be asymmetric (i.e. enter() occurs on LoopBuilder construction, exit() - // occurs on calling operator()). The asymmetry is required for properly - // nesting imperfectly nested regions (see LoopBuilder::operator()). - for (auto lit = loops.rbegin(), eit = loops.rend(); lit != eit; ++lit) - (*lit)(); -} - -mlir::edsc::LoopNestBuilder::LoopNestBuilder(ArrayRef ivs, - ArrayRef lbs, - ArrayRef ubs, - ArrayRef steps) { - assert(ivs.size() == lbs.size() && "expected size of ivs and lbs to match"); - assert(ivs.size() == ubs.size() && "expected size of ivs and ubs to match"); - assert(ivs.size() == steps.size() && - "expected size of ivs and steps to match"); - loops.reserve(ivs.size()); - for (auto it : llvm::zip(ivs, lbs, ubs, steps)) { - loops.emplace_back(LoopBuilder::makeLoop(std::get<0>(it), std::get<1>(it), - std::get<2>(it), std::get<3>(it))); - } - assert(loops.size() == ivs.size() && "Mismatch loops vs ivs size"); -} - -void LoopNestBuilder::LoopNestBuilder::operator()( - std::function fun) { - if (fun) - fun(); - for (auto &lit : reverse(loops)) - lit({}); -} - mlir::edsc::BlockBuilder::BlockBuilder(BlockHandle bh, Append) { assert(bh && "Expected already captured BlockHandle"); enter(bh.getBlock()); @@ -367,194 +199,3 @@ fun(); exit(); } - -template -static ValueHandle createBinaryHandle(ValueHandle lhs, ValueHandle rhs) { - return ValueHandle::create(lhs.getValue(), rhs.getValue()); -} - -static std::pair -categorizeValueByAffineType(MLIRContext *context, Value val, unsigned &numDims, - unsigned &numSymbols) { - AffineExpr d; - Value resultVal = nullptr; - if (auto constant = dyn_cast_or_null(val.getDefiningOp())) { - d = getAffineConstantExpr(constant.getValue(), context); - } else if (isValidSymbol(val) && !isValidDim(val)) { - d = getAffineSymbolExpr(numSymbols++, context); - resultVal = val; - } else { - d = getAffineDimExpr(numDims++, context); - resultVal = val; - 
} - return std::make_pair(d, resultVal); -} - -static ValueHandle createBinaryIndexHandle( - ValueHandle lhs, ValueHandle rhs, - function_ref affCombiner) { - MLIRContext *context = ScopedContext::getContext(); - unsigned numDims = 0, numSymbols = 0; - AffineExpr d0, d1; - Value v0, v1; - std::tie(d0, v0) = - categorizeValueByAffineType(context, lhs.getValue(), numDims, numSymbols); - std::tie(d1, v1) = - categorizeValueByAffineType(context, rhs.getValue(), numDims, numSymbols); - SmallVector operands; - if (v0) { - operands.push_back(v0); - } - if (v1) { - operands.push_back(v1); - } - auto map = AffineMap::get(numDims, numSymbols, {affCombiner(d0, d1)}); - // TODO: createOrFold when available. - return ValueHandle::createComposedAffineApply(map, operands); -} - -template -static ValueHandle createBinaryHandle( - ValueHandle lhs, ValueHandle rhs, - function_ref affCombiner) { - auto thisType = lhs.getValue().getType(); - auto thatType = rhs.getValue().getType(); - assert(thisType == thatType && "cannot mix types in operators"); - (void)thisType; - (void)thatType; - if (thisType.isIndex()) { - return createBinaryIndexHandle(lhs, rhs, affCombiner); - } else if (thisType.isa()) { - return createBinaryHandle(lhs, rhs); - } else if (thisType.isa()) { - return createBinaryHandle(lhs, rhs); - } else if (thisType.isa() || thisType.isa()) { - auto aggregateType = thisType.cast(); - if (aggregateType.getElementType().isa()) - return createBinaryHandle(lhs, rhs); - else if (aggregateType.getElementType().isa()) - return createBinaryHandle(lhs, rhs); - } - llvm_unreachable("failed to create a ValueHandle"); -} - -ValueHandle mlir::edsc::op::operator+(ValueHandle lhs, ValueHandle rhs) { - return createBinaryHandle( - lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0 + d1; }); -} - -ValueHandle mlir::edsc::op::operator-(ValueHandle lhs, ValueHandle rhs) { - return createBinaryHandle( - lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0 - d1; }); -} - -ValueHandle 
mlir::edsc::op::operator*(ValueHandle lhs, ValueHandle rhs) { - return createBinaryHandle( - lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0 * d1; }); -} - -ValueHandle mlir::edsc::op::operator/(ValueHandle lhs, ValueHandle rhs) { - return createBinaryHandle( - lhs, rhs, [](AffineExpr d0, AffineExpr d1) -> AffineExpr { - llvm_unreachable("only exprs of non-index type support operator/"); - }); -} - -ValueHandle mlir::edsc::op::operator%(ValueHandle lhs, ValueHandle rhs) { - return createBinaryHandle( - lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0 % d1; }); -} - -ValueHandle mlir::edsc::op::floorDiv(ValueHandle lhs, ValueHandle rhs) { - return createBinaryIndexHandle( - lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0.floorDiv(d1); }); -} - -ValueHandle mlir::edsc::op::ceilDiv(ValueHandle lhs, ValueHandle rhs) { - return createBinaryIndexHandle( - lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0.ceilDiv(d1); }); -} - -ValueHandle mlir::edsc::op::operator!(ValueHandle value) { - assert(value.getType().isInteger(1) && "expected boolean expression"); - return ValueHandle::create(1, 1) - value; -} - -ValueHandle mlir::edsc::op::operator&&(ValueHandle lhs, ValueHandle rhs) { - assert(lhs.getType().isInteger(1) && "expected boolean expression on LHS"); - assert(rhs.getType().isInteger(1) && "expected boolean expression on RHS"); - return lhs * rhs; -} - -ValueHandle mlir::edsc::op::operator||(ValueHandle lhs, ValueHandle rhs) { - return !(!lhs && !rhs); -} - -static ValueHandle createIComparisonExpr(CmpIPredicate predicate, - ValueHandle lhs, ValueHandle rhs) { - auto lhsType = lhs.getType(); - auto rhsType = rhs.getType(); - (void)lhsType; - (void)rhsType; - assert(lhsType == rhsType && "cannot mix types in operators"); - assert((lhsType.isa() || lhsType.isa()) && - "only integer comparisons are supported"); - - auto op = ScopedContext::getBuilder().create( - ScopedContext::getLocation(), predicate, lhs.getValue(), rhs.getValue()); - return 
ValueHandle(op.getResult()); -} - -static ValueHandle createFComparisonExpr(CmpFPredicate predicate, - ValueHandle lhs, ValueHandle rhs) { - auto lhsType = lhs.getType(); - auto rhsType = rhs.getType(); - (void)lhsType; - (void)rhsType; - assert(lhsType == rhsType && "cannot mix types in operators"); - assert(lhsType.isa() && "only float comparisons are supported"); - - auto op = ScopedContext::getBuilder().create( - ScopedContext::getLocation(), predicate, lhs.getValue(), rhs.getValue()); - return ValueHandle(op.getResult()); -} - -// All floating point comparison are ordered through EDSL -ValueHandle mlir::edsc::op::operator==(ValueHandle lhs, ValueHandle rhs) { - auto type = lhs.getType(); - return type.isa() - ? createFComparisonExpr(CmpFPredicate::OEQ, lhs, rhs) - : createIComparisonExpr(CmpIPredicate::eq, lhs, rhs); -} -ValueHandle mlir::edsc::op::operator!=(ValueHandle lhs, ValueHandle rhs) { - auto type = lhs.getType(); - return type.isa() - ? createFComparisonExpr(CmpFPredicate::ONE, lhs, rhs) - : createIComparisonExpr(CmpIPredicate::ne, lhs, rhs); -} -ValueHandle mlir::edsc::op::operator<(ValueHandle lhs, ValueHandle rhs) { - auto type = lhs.getType(); - return type.isa() - ? createFComparisonExpr(CmpFPredicate::OLT, lhs, rhs) - : - // TODO(ntv,zinenko): signed by default, how about unsigned? - createIComparisonExpr(CmpIPredicate::slt, lhs, rhs); -} -ValueHandle mlir::edsc::op::operator<=(ValueHandle lhs, ValueHandle rhs) { - auto type = lhs.getType(); - return type.isa() - ? createFComparisonExpr(CmpFPredicate::OLE, lhs, rhs) - : createIComparisonExpr(CmpIPredicate::sle, lhs, rhs); -} -ValueHandle mlir::edsc::op::operator>(ValueHandle lhs, ValueHandle rhs) { - auto type = lhs.getType(); - return type.isa() - ? createFComparisonExpr(CmpFPredicate::OGT, lhs, rhs) - : createIComparisonExpr(CmpIPredicate::sgt, lhs, rhs); -} -ValueHandle mlir::edsc::op::operator>=(ValueHandle lhs, ValueHandle rhs) { - auto type = lhs.getType(); - return type.isa() - ? 
createFComparisonExpr(CmpFPredicate::OGE, lhs, rhs) - : createIComparisonExpr(CmpIPredicate::sge, lhs, rhs); -} diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp --- a/mlir/test/EDSC/builder-api-test.cpp +++ b/mlir/test/EDSC/builder-api-test.cpp @@ -8,13 +8,11 @@ // RUN: mlir-edsc-builder-api-test | FileCheck %s -#include "mlir/Dialect/AffineOps/AffineOps.h" -#include "mlir/Dialect/Linalg/EDSC/Builders.h" +#include "mlir/Dialect/AffineOps/EDSC/Intrinsics.h" #include "mlir/Dialect/Linalg/EDSC/Intrinsics.h" -#include "mlir/Dialect/Linalg/IR/LinalgOps.h" -#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/Dialect/LoopOps/EDSC/Builders.h" +#include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" #include "mlir/EDSC/Builders.h" -#include "mlir/EDSC/Helpers.h" #include "mlir/EDSC/Intrinsics.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/Builders.h" @@ -34,6 +32,8 @@ #include "llvm/Support/raw_ostream.h" using namespace mlir; +using namespace mlir::edsc; +using namespace mlir::edsc::intrinsics; static MLIRContext &globalContext() { static thread_local MLIRContext context; @@ -50,9 +50,6 @@ } TEST_FUNC(builder_dynamic_for_func_args) { - using namespace edsc; - using namespace edsc::op; - using namespace edsc::intrinsics; auto indexType = IndexType::get(&globalContext()); auto f32Type = FloatType::getF32(&globalContext()); auto f = @@ -62,16 +59,18 @@ ScopedContext scope(builder, f.getLoc()); ValueHandle i(indexType), j(indexType), lb(f.getArgument(0)), ub(f.getArgument(1)); - ValueHandle f7(constant_float(llvm::APFloat(7.0f), f32Type)); - ValueHandle f13(constant_float(llvm::APFloat(13.0f), f32Type)); - ValueHandle i7(constant_int(7, 32)); - ValueHandle i13(constant_int(13, 32)); + ValueHandle f7(std_constant_float(llvm::APFloat(7.0f), f32Type)); + ValueHandle f13(std_constant_float(llvm::APFloat(13.0f), f32Type)); + ValueHandle i7(std_constant_int(7, 32)); + ValueHandle i13(std_constant_int(13, 32)); AffineLoopNestBuilder(&i, 
lb, ub, 3)([&] { - lb *index_type(3) + ub; - lb + index_type(3); + using namespace edsc::op; + lb *std_constant_index(3) + ub; + lb + std_constant_index(3); AffineLoopNestBuilder(&j, lb, ub, 2)([&] { - ceilDiv(index_type(31) * floorDiv(i + j * index_type(3), index_type(32)), - index_type(32)); + ceilDiv(std_constant_index(31) * floorDiv(i + j * std_constant_index(3), + std_constant_index(32)), + std_constant_index(32)); ((f7 + f13) / f7) % f13 - f7 *f13; ((i7 + i13) / i7) % i13 - i7 *i13; }); @@ -103,9 +102,6 @@ } TEST_FUNC(builder_dynamic_for) { - using namespace edsc; - using namespace edsc::op; - using namespace edsc::intrinsics; auto indexType = IndexType::get(&globalContext()); auto f = makeFunction("builder_dynamic_for", {}, {indexType, indexType, indexType, indexType}); @@ -114,6 +110,7 @@ ScopedContext scope(builder, f.getLoc()); ValueHandle i(indexType), a(f.getArgument(0)), b(f.getArgument(1)), c(f.getArgument(2)), d(f.getArgument(3)); + using namespace edsc::op; AffineLoopNestBuilder(&i, a - b, c + d, 2)(); // clang-format off @@ -127,9 +124,6 @@ } TEST_FUNC(builder_loop_for) { - using namespace edsc; - using namespace edsc::op; - using namespace edsc::intrinsics; auto indexType = IndexType::get(&globalContext()); auto f = makeFunction("builder_loop_for", {}, {indexType, indexType, indexType, indexType}); @@ -138,6 +132,7 @@ ScopedContext scope(builder, f.getLoc()); ValueHandle i(indexType), a(f.getArgument(0)), b(f.getArgument(1)), c(f.getArgument(2)), d(f.getArgument(3)); + using namespace edsc::op; LoopNestBuilder(&i, a - b, c + d, a)(); // clang-format off @@ -151,9 +146,6 @@ } TEST_FUNC(builder_max_min_for) { - using namespace edsc; - using namespace edsc::op; - using namespace edsc::intrinsics; auto indexType = IndexType::get(&globalContext()); auto f = makeFunction("builder_max_min_for", {}, {indexType, indexType, indexType, indexType}); @@ -163,7 +155,7 @@ ValueHandle i(indexType), lb1(f.getArgument(0)), lb2(f.getArgument(1)), 
ub1(f.getArgument(2)), ub2(f.getArgument(3)); AffineLoopNestBuilder(&i, {lb1, lb2}, {ub1, ub2}, 1)(); - ret(); + std_ret(); // clang-format off // CHECK-LABEL: func @builder_max_min_for(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) { @@ -175,8 +167,6 @@ } TEST_FUNC(builder_blocks) { - using namespace edsc; - using namespace edsc::intrinsics; using namespace edsc::op; auto f = makeFunction("builder_blocks"); @@ -220,8 +210,6 @@ } TEST_FUNC(builder_blocks_eager) { - using namespace edsc; - using namespace edsc::intrinsics; using namespace edsc::op; auto f = makeFunction("builder_blocks_eager"); @@ -264,8 +252,6 @@ } TEST_FUNC(builder_cond_branch) { - using namespace edsc; - using namespace edsc::intrinsics; auto f = makeFunction("builder_cond_branch", {}, {IntegerType::get(1, &globalContext())}); @@ -278,8 +264,8 @@ ValueHandle arg1(c32.getType()), arg2(c64.getType()), arg3(c32.getType()); BlockHandle b1, b2, functionBlock(&f.front()); - BlockBuilder(&b1, {&arg1})([&] { ret(); }); - BlockBuilder(&b2, {&arg2, &arg3})([&] { ret(); }); + BlockBuilder(&b1, {&arg1})([&] { std_ret(); }); + BlockBuilder(&b2, {&arg2, &arg3})([&] { std_ret(); }); // Get back to entry block and add a conditional branch BlockBuilder(functionBlock, Append())([&] { cond_br(funcArg, b1, {c32}, b2, {c64, c42}); @@ -301,8 +287,6 @@ } TEST_FUNC(builder_cond_branch_eager) { - using namespace edsc; - using namespace edsc::intrinsics; using namespace edsc::op; auto f = makeFunction("builder_cond_branch_eager", {}, {IntegerType::get(1, &globalContext())}); @@ -319,10 +303,10 @@ BlockHandle b1, b2; cond_br(funcArg, &b1, {&arg1}, {c32}, &b2, {&arg2, &arg3}, {c64, c42}); BlockBuilder(b1, Append())([]{ - ret(); + std_ret(); }); BlockBuilder(b2, Append())([]{ - ret(); + std_ret(); }); // CHECK-LABEL: @builder_cond_branch_eager @@ -340,9 +324,8 @@ } TEST_FUNC(builder_helpers) { - using namespace edsc; - using namespace edsc::intrinsics; using namespace edsc::op; + auto indexType = 
IndexType::get(&globalContext()); auto f32Type = FloatType::getF32(&globalContext()); auto memrefType = MemRefType::get({ShapedType::kDynamicSize, ShapedType::kDynamicSize, @@ -356,10 +339,12 @@ // clang-format off ValueHandle f7( ValueHandle::create(llvm::APFloat(7.0f), f32Type)); - MemRefView vA(f.getArgument(0)), vB(f.getArgument(1)), + MemRefBoundsCapture vA(f.getArgument(0)), vB(f.getArgument(1)), vC(f.getArgument(2)); - IndexedValue A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2)); - IndexHandle i, j, k1, k2, lb0, lb1, lb2, ub0, ub1, ub2; + AffineIndexedValue A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2)); + ValueHandle i(indexType), j(indexType), k1(indexType), k2(indexType), + lb0(indexType), lb1(indexType), lb2(indexType), + ub0(indexType), ub1(indexType), ub2(indexType); int64_t step0, step1, step2; std::tie(lb0, ub0, step0) = vA.range(0); std::tie(lb1, ub1, step1) = vA.range(1); @@ -398,8 +383,6 @@ } TEST_FUNC(custom_ops) { - using namespace edsc; - using namespace edsc::intrinsics; using namespace edsc::op; auto indexType = IndexType::get(&globalContext()); auto f = makeFunction("custom_ops", {}, {indexType, indexType}); @@ -413,8 +396,9 @@ // clang-format off ValueHandle vh(indexType), vh20(indexType), vh21(indexType); OperationHandle ih0, ih2; - IndexHandle m, n, M(f.getArgument(0)), N(f.getArgument(1)); - IndexHandle ten(index_type(10)), twenty(index_type(20)); + ValueHandle m(indexType), n(indexType); + ValueHandle M(f.getArgument(0)), N(f.getArgument(1)); + ValueHandle ten(std_constant_index(10)), twenty(std_constant_index(20)); AffineLoopNestBuilder({&m, &n}, {M, N}, {M + ten, N + twenty}, {1, 1})([&]{ vh = MY_CUSTOM_OP({m, m + n}, {indexType}, {}); ih0 = MY_CUSTOM_OP_0({m, m + n}, {}); @@ -438,8 +422,6 @@ } TEST_FUNC(insertion_in_block) { - using namespace edsc; - using namespace edsc::intrinsics; using namespace edsc::op; auto indexType = IndexType::get(&globalContext()); auto f = makeFunction("insertion_in_block", 
{}, {indexType, indexType}); @@ -463,23 +445,22 @@ f.erase(); } -TEST_FUNC(zero_and_sign_extendi_op_i1_to_i8) { - using namespace edsc; - using namespace edsc::intrinsics; +TEST_FUNC(zero_and_std_sign_extendi_op_i1_to_i8) { using namespace edsc::op; auto i1Type = IntegerType::get(1, &globalContext()); auto i8Type = IntegerType::get(8, &globalContext()); auto memrefType = MemRefType::get({}, i1Type, {}, 0); - auto f = makeFunction("zero_and_sign_extendi_op", {}, {memrefType, memrefType}); + auto f = makeFunction("zero_and_std_sign_extendi_op", {}, + {memrefType, memrefType}); OpBuilder builder(f.getBody()); ScopedContext scope(builder, f.getLoc()); - IndexedValue A(f.getArgument(0)); - IndexedValue B(f.getArgument(1)); + AffineIndexedValue A(f.getArgument(0)); + AffineIndexedValue B(f.getArgument(1)); // clang-format off - edsc::intrinsics::zero_extendi(*A, i8Type); - edsc::intrinsics::sign_extendi(*B, i8Type); - // CHECK-LABEL: @zero_and_sign_extendi_op + edsc::intrinsics::std_zero_extendi(*A, i8Type); + edsc::intrinsics::std_sign_extendi(*B, i8Type); + // CHECK-LABEL: @zero_and_std_sign_extendi_op // CHECK: %[[SRC1:.*]] = affine.load // CHECK: zexti %[[SRC1]] : i1 to i8 // CHECK: %[[SRC2:.*]] = affine.load @@ -490,9 +471,8 @@ } TEST_FUNC(select_op_i32) { - using namespace edsc; - using namespace edsc::intrinsics; using namespace edsc::op; + auto indexType = IndexType::get(&globalContext()); auto f32Type = FloatType::getF32(&globalContext()); auto memrefType = MemRefType::get( {ShapedType::kDynamicSize, ShapedType::kDynamicSize}, f32Type, {}, 0); @@ -501,16 +481,17 @@ OpBuilder builder(f.getBody()); ScopedContext scope(builder, f.getLoc()); // clang-format off - ValueHandle zero = constant_index(0), one = constant_index(1); - MemRefView vA(f.getArgument(0)); - IndexedValue A(f.getArgument(0)); - IndexHandle i, j; + ValueHandle zero = std_constant_index(0), one = std_constant_index(1); + MemRefBoundsCapture vA(f.getArgument(0)); + AffineIndexedValue 
A(f.getArgument(0)); + ValueHandle i(indexType), j(indexType); AffineLoopNestBuilder({&i, &j}, {zero, zero}, {one, one}, {1, 1})([&]{ - // This test exercises IndexedValue::operator Value. + // This test exercises AffineIndexedValue::operator Value. // Without it, one must force conversion to ValueHandle as such: - // edsc::intrinsics::select( + // std_select( // i == zero, ValueHandle(A(zero, zero)), ValueHandle(ValueA(i, j))) - edsc::intrinsics::select(i == zero, *A(zero, zero), *A(i, j)); + using edsc::op::operator==; + std_select(i == zero, *A(zero, zero), *A(i, j)); }); // CHECK-LABEL: @select_op @@ -526,9 +507,7 @@ } TEST_FUNC(select_op_f32) { - using namespace edsc; - using namespace edsc::intrinsics; - using namespace edsc::op; + auto indexType = IndexType::get(&globalContext()); auto f32Type = FloatType::getF32(&globalContext()); auto memrefType = MemRefType::get( {ShapedType::kDynamicSize, ShapedType::kDynamicSize}, f32Type, {}, 0); @@ -537,18 +516,18 @@ OpBuilder builder(f.getBody()); ScopedContext scope(builder, f.getLoc()); // clang-format off - ValueHandle zero = constant_index(0), one = constant_index(1); - MemRefView vA(f.getArgument(0)), vB(f.getArgument(1)); - IndexedValue A(f.getArgument(0)), B(f.getArgument(1)); - IndexHandle i, j; + ValueHandle zero = std_constant_index(0), one = std_constant_index(1); + MemRefBoundsCapture vA(f.getArgument(0)), vB(f.getArgument(1)); + AffineIndexedValue A(f.getArgument(0)), B(f.getArgument(1)); + ValueHandle i(indexType), j(indexType); AffineLoopNestBuilder({&i, &j}, {zero, zero}, {one, one}, {1, 1})([&]{ - - edsc::intrinsics::select(B(i, j) == B(i+one, j), *A(zero, zero), *A(i, j)); - edsc::intrinsics::select(B(i, j) != B(i+one, j), *A(zero, zero), *A(i, j)); - edsc::intrinsics::select(B(i, j) >= B(i+one, j), *A(zero, zero), *A(i, j)); - edsc::intrinsics::select(B(i, j) <= B(i+one, j), *A(zero, zero), *A(i, j)); - edsc::intrinsics::select(B(i, j) < B(i+one, j), *A(zero, zero), *A(i, j)); - 
edsc::intrinsics::select(B(i, j) > B(i+one, j), *A(zero, zero), *A(i, j)); + using namespace edsc::op; + std_select(B(i, j) == B(i + one, j), *A(zero, zero), *A(i, j)); + std_select(B(i, j) != B(i + one, j), *A(zero, zero), *A(i, j)); + std_select(B(i, j) >= B(i + one, j), *A(zero, zero), *A(i, j)); + std_select(B(i, j) <= B(i + one, j), *A(zero, zero), *A(i, j)); + std_select(B(i, j) < B(i + one, j), *A(zero, zero), *A(i, j)); + std_select(B(i, j) > B(i + one, j), *A(zero, zero), *A(i, j)); }); // CHECK-LABEL: @select_op @@ -604,9 +583,7 @@ // Inject an EDSC-constructed computation to exercise imperfectly nested 2-d // tiling. TEST_FUNC(tile_2d) { - using namespace edsc; - using namespace edsc::intrinsics; - using namespace edsc::op; + auto indexType = IndexType::get(&globalContext()); auto memrefType = MemRefType::get({ShapedType::kDynamicSize, ShapedType::kDynamicSize, ShapedType::kDynamicSize}, @@ -615,12 +592,16 @@ OpBuilder builder(f.getBody()); ScopedContext scope(builder, f.getLoc()); - ValueHandle zero = constant_index(0); - MemRefView vA(f.getArgument(0)), vB(f.getArgument(1)), vC(f.getArgument(2)); - IndexedValue A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2)); - IndexHandle i, j, k1, k2, M(vC.ub(0)), N(vC.ub(1)), O(vC.ub(2)); + ValueHandle zero = std_constant_index(0); + MemRefBoundsCapture vA(f.getArgument(0)), vB(f.getArgument(1)), + vC(f.getArgument(2)); + AffineIndexedValue A(f.getArgument(0)), B(f.getArgument(1)), + C(f.getArgument(2)); + ValueHandle i(indexType), j(indexType), k1(indexType), k2(indexType); + ValueHandle M(vC.ub(0)), N(vC.ub(1)), O(vC.ub(2)); // clang-format off + using namespace edsc::op; AffineLoopNestBuilder({&i, &j}, {zero, zero}, {M, N}, {1, 1})([&]{ AffineLoopNestBuilder(&k1, zero, O, 1)([&]{ C(i, j, k1) = A(i, j, k1) + B(i, j, k1); @@ -675,8 +656,6 @@ // Exercise StdIndexedValue for loads and stores. 
TEST_FUNC(indirect_access) { - using namespace edsc; - using namespace edsc::intrinsics; using namespace edsc::op; auto memrefType = MemRefType::get({ShapedType::kDynamicSize}, FloatType::getF32(&globalContext()), {}, 0); @@ -685,11 +664,11 @@ OpBuilder builder(f.getBody()); ScopedContext scope(builder, f.getLoc()); - ValueHandle zero = constant_index(0); - MemRefView vC(f.getArgument(2)); - IndexedValue B(f.getArgument(1)), D(f.getArgument(3)); + ValueHandle zero = std_constant_index(0); + MemRefBoundsCapture vC(f.getArgument(2)); + AffineIndexedValue B(f.getArgument(1)), D(f.getArgument(3)); StdIndexedValue A(f.getArgument(0)), C(f.getArgument(2)); - IndexHandle i, N(vC.ub(0)); + ValueHandle i(builder.getIndexType()), N(vC.ub(0)); // clang-format off AffineLoopNestBuilder(&i, zero, N, 1)([&]{ @@ -711,8 +690,6 @@ // Exercise affine loads and stores build with empty maps. TEST_FUNC(empty_map_load_store) { - using namespace edsc; - using namespace edsc::intrinsics; using namespace edsc::op; auto memrefType = MemRefType::get({}, FloatType::getF32(&globalContext()), {}, 0); @@ -721,10 +698,10 @@ OpBuilder builder(f.getBody()); ScopedContext scope(builder, f.getLoc()); - ValueHandle zero = constant_index(0); - ValueHandle one = constant_index(1); - IndexedValue input(f.getArgument(0)), res(f.getArgument(1)); - IndexHandle iv; + ValueHandle zero = std_constant_index(0); + ValueHandle one = std_constant_index(1); + AffineIndexedValue input(f.getArgument(0)), res(f.getArgument(1)); + ValueHandle iv(builder.getIndexType()); // clang-format off AffineLoopNestBuilder(&iv, zero, one, 1)([&]{ @@ -749,8 +726,6 @@ // CHECK-NEXT: } else { // clang-format on TEST_FUNC(affine_if_op) { - using namespace edsc; - using namespace edsc::intrinsics; using namespace edsc::op; auto f32Type = FloatType::getF32(&globalContext()); auto memrefType = MemRefType::get( @@ -760,7 +735,7 @@ OpBuilder builder(f.getBody()); ScopedContext scope(builder, f.getLoc()); - ValueHandle zero = 
constant_index(0), ten = constant_index(10); + ValueHandle zero = std_constant_index(0), ten = std_constant_index(10); SmallVector isEq = {false, false, false, false}; SmallVector affineExprs = { @@ -927,9 +902,6 @@ // CHECK: linalg.reshape {{.*}} [affine_map<(d0, d1, d2) -> (d0, d1)>, affine_map<(d0, d1, d2) -> (d2)>] : memref<32x16xf32> into memref<4x8x16xf32> // clang-format on TEST_FUNC(linalg_metadata_ops) { - using namespace edsc; - using namespace edsc::intrinsics; - auto f32Type = FloatType::getF32(&globalContext()); auto memrefType = MemRefType::get({4, 8, 16}, f32Type, {}, 0); auto f = makeFunction("linalg_metadata_ops", {}, {memrefType});