diff --git a/mlir/docs/Dialects/Linalg.md b/mlir/docs/Dialects/Linalg.md --- a/mlir/docs/Dialects/Linalg.md +++ b/mlir/docs/Dialects/Linalg.md @@ -531,9 +531,9 @@ void batchmatmul::regionBuilder(ArrayRef args) { using namespace edsc; using namespace intrinsics; - ValueHandle _0(args[0]), _1(args[1]), _2(args[2]); - ValueHandle _4 = std_mulf(_0, _1); - ValueHandle _5 = std_addf(_2, _4); + Value _0(args[0]), _1(args[1]), _2(args[2]); + Value _4 = std_mulf(_0, _1); + Value _5 = std_addf(_2, _4); (linalg_yield(ValueRange{ _5 })); } ``` diff --git a/mlir/docs/EDSC.md b/mlir/docs/EDSC.md --- a/mlir/docs/EDSC.md +++ b/mlir/docs/EDSC.md @@ -13,30 +13,17 @@ supporting a simple declarative API with globally accessible builders. These declarative builders are available within the lifetime of a `ScopedContext`. -## ValueHandle and IndexHandle - -`mlir::edsc::ValueHandle` and `mlir::edsc::IndexHandle` provide typed -abstractions around an `mlir::Value`. These abstractions are "delayed", in the -sense that they allow separating declaration from definition. They may capture -IR snippets, as they are built, for programmatic manipulation. Intuitive -operators are provided to allow concise and idiomatic expressions. - -```c++ -ValueHandle zero = std_constant_index(0); -IndexHandle i, j, k; -``` - ## Intrinsics -`mlir::edsc::ValueBuilder` is a generic wrapper for the `mlir::Builder::create` -method that operates on `ValueHandle` objects and return a single ValueHandle. -For instructions that return no values or that return multiple values, the -`mlir::edsc::InstructionBuilder` can be used. Named intrinsics are provided as +`mlir::ValueBuilder` is a generic wrapper for the `mlir::OpBuilder::create` +method that operates on `Value` objects and returns a single Value. For +instructions that return no values or that return multiple values, the +`mlir::edsc::OperationBuilder` can be used. Named intrinsics are provided as syntactic sugar to further reduce boilerplate. 
```c++ using load = ValueBuilder; -using store = InstructionBuilder; +using store = OperationBuilder; ``` ## LoopBuilder and AffineLoopNestBuilder @@ -46,14 +33,11 @@ ```c++ ScopedContext scope(f.get()); - ValueHandle i(indexType), - j(indexType), - lb(f->getArgument(0)), - ub(f->getArgument(1)); - ValueHandle f7(std_constant_float(llvm::APFloat(7.0f), f32Type)), - f13(std_constant_float(llvm::APFloat(13.0f), f32Type)), - i7(constant_int(7, 32)), - i13(constant_int(13, 32)); + Value i, j, lb(f->getArgument(0)), ub(f->getArgument(1)); + Value f7(std_constant_float(llvm::APFloat(7.0f), f32Type)), + f13(std_constant_float(llvm::APFloat(13.0f), f32Type)), + i7(constant_int(7, 32)), + i13(constant_int(13, 32)); AffineLoopNestBuilder(&i, lb, ub, 3)([&]{ lb * index_type(3) + ub; lb + index_type(3); @@ -84,11 +68,10 @@ Arguments<(ins Tensor:$A, Tensor:$B)>, Results<(outs Tensor: $C)> { code referenceImplementation = [{ - auto ivs = makeIndexHandles(view_A.rank()); - auto pivs = makePIndexHandles(ivs); + SmallVector ivs(view_A.rank()); IndexedValue A(arg_A), B(arg_B), C(arg_C); - AffineLoopNestBuilder(pivs, view_A.getLbs(), view_A.getUbs(), view_A.getSteps())( - [&]{ + AffineLoopNestBuilder( + ivs, view_A.getLbs(), view_A.getUbs(), view_A.getSteps())([&]{ C(ivs) = A(ivs) + B(ivs) }); }]; @@ -124,10 +107,4 @@ `LoopNestBuilder`. See the `builder-api-test.cpp` test for more usage examples. Since the implementation of declarative builders is in C++, it is also available -to program the IR with an embedded-DSL flavor directly integrated in MLIR. We -make use of these properties in the tutorial. - -Spoiler: MLIR also provides Python bindings for these builders, and a -full-fledged Python machine learning DSL with automatic differentiation -targeting MLIR was built as an early research collaboration. - +to program the IR with an embedded-DSL flavor directly integrated in MLIR. 
diff --git a/mlir/include/mlir/Dialect/Affine/EDSC/Builders.h b/mlir/include/mlir/Dialect/Affine/EDSC/Builders.h --- a/mlir/include/mlir/Dialect/Affine/EDSC/Builders.h +++ b/mlir/include/mlir/Dialect/Affine/EDSC/Builders.h @@ -23,12 +23,10 @@ namespace edsc { /// Constructs a new AffineForOp and captures the associated induction -/// variable. A ValueHandle pointer is passed as the first argument and is the +/// variable. A Value pointer is passed as the first argument and is the /// *only* way to capture the loop induction variable. -LoopBuilder makeAffineLoopBuilder(ValueHandle *iv, - ArrayRef lbHandles, - ArrayRef ubHandles, - int64_t step); +LoopBuilder makeAffineLoopBuilder(Value *iv, ArrayRef lbs, + ArrayRef ubs, int64_t step); /// Explicit nested LoopBuilder. Offers a compressed multi-loop builder to avoid /// explicitly writing all the loops in a nest. This simple functionality is @@ -58,10 +56,10 @@ /// This entry point accommodates the fact that AffineForOp implicitly uses /// multiple `lbs` and `ubs` with one single `iv` and `step` to encode `max` /// and and `min` constraints respectively. 
- AffineLoopNestBuilder(ValueHandle *iv, ArrayRef lbs, - ArrayRef ubs, int64_t step); - AffineLoopNestBuilder(ArrayRef ivs, ArrayRef lbs, - ArrayRef ubs, ArrayRef steps); + AffineLoopNestBuilder(Value *iv, ArrayRef lbs, ArrayRef ubs, + int64_t step); + AffineLoopNestBuilder(MutableArrayRef ivs, ArrayRef lbs, + ArrayRef ubs, ArrayRef steps); void operator()(function_ref fun = nullptr); @@ -71,133 +69,134 @@ namespace op { -ValueHandle operator+(ValueHandle lhs, ValueHandle rhs); -ValueHandle operator-(ValueHandle lhs, ValueHandle rhs); -ValueHandle operator*(ValueHandle lhs, ValueHandle rhs); -ValueHandle operator/(ValueHandle lhs, ValueHandle rhs); -ValueHandle operator%(ValueHandle lhs, ValueHandle rhs); -ValueHandle floorDiv(ValueHandle lhs, ValueHandle rhs); -ValueHandle ceilDiv(ValueHandle lhs, ValueHandle rhs); - -ValueHandle operator!(ValueHandle value); -ValueHandle operator&&(ValueHandle lhs, ValueHandle rhs); -ValueHandle operator||(ValueHandle lhs, ValueHandle rhs); -ValueHandle operator^(ValueHandle lhs, ValueHandle rhs); -ValueHandle operator==(ValueHandle lhs, ValueHandle rhs); -ValueHandle operator!=(ValueHandle lhs, ValueHandle rhs); -ValueHandle operator<(ValueHandle lhs, ValueHandle rhs); -ValueHandle operator<=(ValueHandle lhs, ValueHandle rhs); -ValueHandle operator>(ValueHandle lhs, ValueHandle rhs); -ValueHandle operator>=(ValueHandle lhs, ValueHandle rhs); +Value operator+(Value lhs, Value rhs); +Value operator-(Value lhs, Value rhs); +Value operator*(Value lhs, Value rhs); +Value operator/(Value lhs, Value rhs); +Value operator%(Value lhs, Value rhs); +Value floorDiv(Value lhs, Value rhs); +Value ceilDiv(Value lhs, Value rhs); + +/// Logical operator overloadings. +Value negate(Value value); +Value operator&&(Value lhs, Value rhs); +Value operator||(Value lhs, Value rhs); +Value operator^(Value lhs, Value rhs); + +/// Comparison operator overloadings. 
+Value eq(Value lhs, Value rhs); +Value ne(Value lhs, Value rhs); +Value operator<(Value lhs, Value rhs); +Value operator<=(Value lhs, Value rhs); +Value operator>(Value lhs, Value rhs); +Value operator>=(Value lhs, Value rhs); } // namespace op /// Arithmetic operator overloadings. template -ValueHandle TemplatedIndexedValue::operator+(ValueHandle e) { +Value TemplatedIndexedValue::operator+(Value e) { using op::operator+; - return static_cast(*this) + e; + return static_cast(*this) + e; } template -ValueHandle TemplatedIndexedValue::operator-(ValueHandle e) { +Value TemplatedIndexedValue::operator-(Value e) { using op::operator-; - return static_cast(*this) - e; + return static_cast(*this) - e; } template -ValueHandle TemplatedIndexedValue::operator*(ValueHandle e) { +Value TemplatedIndexedValue::operator*(Value e) { using op::operator*; - return static_cast(*this) * e; + return static_cast(*this) * e; } template -ValueHandle TemplatedIndexedValue::operator/(ValueHandle e) { +Value TemplatedIndexedValue::operator/(Value e) { using op::operator/; - return static_cast(*this) / e; + return static_cast(*this) / e; } template -ValueHandle TemplatedIndexedValue::operator%(ValueHandle e) { +Value TemplatedIndexedValue::operator%(Value e) { using op::operator%; - return static_cast(*this) % e; + return static_cast(*this) % e; } template -ValueHandle TemplatedIndexedValue::operator^(ValueHandle e) { +Value TemplatedIndexedValue::operator^(Value e) { using op::operator^; - return static_cast(*this) ^ e; + return static_cast(*this) ^ e; } /// Assignment-arithmetic operator overloadings. 
template -OperationHandle TemplatedIndexedValue::operator+=(ValueHandle e) { +OperationHandle TemplatedIndexedValue::operator+=(Value e) { using op::operator+; return Store(*this + e, getBase(), {indices.begin(), indices.end()}); } template -OperationHandle TemplatedIndexedValue::operator-=(ValueHandle e) { +OperationHandle TemplatedIndexedValue::operator-=(Value e) { using op::operator-; return Store(*this - e, getBase(), {indices.begin(), indices.end()}); } template -OperationHandle TemplatedIndexedValue::operator*=(ValueHandle e) { +OperationHandle TemplatedIndexedValue::operator*=(Value e) { using op::operator*; return Store(*this * e, getBase(), {indices.begin(), indices.end()}); } template -OperationHandle TemplatedIndexedValue::operator/=(ValueHandle e) { +OperationHandle TemplatedIndexedValue::operator/=(Value e) { using op::operator/; return Store(*this / e, getBase(), {indices.begin(), indices.end()}); } template -OperationHandle TemplatedIndexedValue::operator%=(ValueHandle e) { +OperationHandle TemplatedIndexedValue::operator%=(Value e) { using op::operator%; return Store(*this % e, getBase(), {indices.begin(), indices.end()}); } template -OperationHandle TemplatedIndexedValue::operator^=(ValueHandle e) { +OperationHandle TemplatedIndexedValue::operator^=(Value e) { using op::operator^; return Store(*this ^ e, getBase(), {indices.begin(), indices.end()}); } /// Logical operator overloadings. template -ValueHandle TemplatedIndexedValue::operator&&(ValueHandle e) { +Value TemplatedIndexedValue::operator&&(Value e) { using op::operator&&; - return static_cast(*this) && e; + return static_cast(*this) && e; } template -ValueHandle TemplatedIndexedValue::operator||(ValueHandle e) { +Value TemplatedIndexedValue::operator||(Value e) { using op::operator||; - return static_cast(*this) || e; + return static_cast(*this) || e; } /// Comparison operator overloadings. 
template -ValueHandle TemplatedIndexedValue::operator==(ValueHandle e) { - using op::operator==; - return static_cast(*this) == e; +Value TemplatedIndexedValue::eq(Value e) { + return eq(value, e); } template -ValueHandle TemplatedIndexedValue::operator!=(ValueHandle e) { - using op::operator!=; - return static_cast(*this) != e; +Value TemplatedIndexedValue::ne(Value e) { + return ne(value, e); } template -ValueHandle TemplatedIndexedValue::operator<(ValueHandle e) { +Value TemplatedIndexedValue::operator<(Value e) { using op::operator<; - return static_cast(*this) < e; + return static_cast(*this) < e; } template -ValueHandle TemplatedIndexedValue::operator<=(ValueHandle e) { +Value TemplatedIndexedValue::operator<=(Value e) { using op::operator<=; - return static_cast(*this) <= e; + return static_cast(*this) <= e; } template -ValueHandle TemplatedIndexedValue::operator>(ValueHandle e) { +Value TemplatedIndexedValue::operator>(Value e) { using op::operator>; - return static_cast(*this) > e; + return static_cast(*this) > e; } template -ValueHandle TemplatedIndexedValue::operator>=(ValueHandle e) { +Value TemplatedIndexedValue::operator>=(Value e) { using op::operator>=; - return static_cast(*this) >= e; + return static_cast(*this) >= e; } } // namespace edsc diff --git a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h --- a/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h +++ b/mlir/include/mlir/Dialect/Linalg/EDSC/Builders.h @@ -42,11 +42,10 @@ class LoopRangeBuilder : public NestedBuilder { public: /// Constructs a new loop.for and captures the associated induction - /// variable. A ValueHandle pointer is passed as the first argument and is the + /// variable. A Value pointer is passed as the first argument and is the /// *only* way to capture the loop induction variable. 
- LoopRangeBuilder(ValueHandle *iv, ValueHandle range); - LoopRangeBuilder(ValueHandle *iv, Value range); - LoopRangeBuilder(ValueHandle *iv, SubViewOp::Range range); + LoopRangeBuilder(Value *iv, Value range); + LoopRangeBuilder(Value *iv, SubViewOp::Range range); LoopRangeBuilder(const LoopRangeBuilder &) = delete; LoopRangeBuilder(LoopRangeBuilder &&) = default; @@ -57,7 +56,7 @@ /// The only purpose of this operator is to serve as a sequence point so that /// the evaluation of `fun` (which build IR snippets in a scoped fashion) is /// scoped within a LoopRangeBuilder. - ValueHandle operator()(std::function fun = nullptr); + Value operator()(std::function fun = nullptr); }; /// Helper class to sugar building loop.for loop nests from ranges. @@ -65,13 +64,10 @@ /// directly. In the current implementation it produces loop.for operations. class LoopNestRangeBuilder { public: - LoopNestRangeBuilder(ArrayRef ivs, - ArrayRef ranges); - LoopNestRangeBuilder(ArrayRef ivs, - ArrayRef ranges); - LoopNestRangeBuilder(ArrayRef ivs, + LoopNestRangeBuilder(MutableArrayRef ivs, ArrayRef ranges); + LoopNestRangeBuilder(MutableArrayRef ivs, ArrayRef ranges); - edsc::ValueHandle operator()(std::function fun = nullptr); + Value operator()(std::function fun = nullptr); private: SmallVector loops; @@ -81,7 +77,7 @@ /// ranges. template class GenericLoopNestRangeBuilder { public: - GenericLoopNestRangeBuilder(ArrayRef ivs, + GenericLoopNestRangeBuilder(MutableArrayRef ivs, ArrayRef ranges); void operator()(std::function fun = nullptr) { (*builder)(fun); } @@ -124,7 +120,6 @@ namespace ops { using edsc::StructuredIndexed; -using edsc::ValueHandle; //===----------------------------------------------------------------------===// // EDSC builders for linalg generic operations. @@ -160,7 +155,7 @@ /// with in-place semantics and parallelism. /// Unary pointwise operation (with broadcast) entry point. 
-using UnaryPointwiseOpBuilder = function_ref; +using UnaryPointwiseOpBuilder = function_ref; Operation *linalg_generic_pointwise(UnaryPointwiseOpBuilder unaryOp, StructuredIndexed I, StructuredIndexed O); @@ -171,7 +166,7 @@ StructuredIndexed O); /// Binary pointwise operation (with broadcast) entry point. -using BinaryPointwiseOpBuilder = function_ref; +using BinaryPointwiseOpBuilder = function_ref; Operation *linalg_generic_pointwise(BinaryPointwiseOpBuilder binaryOp, StructuredIndexed I1, StructuredIndexed I2, StructuredIndexed O); @@ -202,7 +197,7 @@ /// | C(m, n) += A(m, k) * B(k, n) /// ``` Operation * -linalg_generic_matmul(ValueHandle vA, ValueHandle vB, ValueHandle vC, +linalg_generic_matmul(Value vA, Value vB, Value vC, MatmulRegionBuilder regionBuilder = macRegionBuilder); /// Build a linalg.generic, under the current ScopedContext, at the current @@ -214,7 +209,7 @@ /// ``` /// and returns the tensor `C`. Operation * -linalg_generic_matmul(ValueHandle vA, ValueHandle vB, RankedTensorType tC, +linalg_generic_matmul(Value vA, Value vB, RankedTensorType tC, MatmulRegionBuilder regionBuilder = mulRegionBuilder); /// Build a linalg.generic, under the current ScopedContext, at the current @@ -226,8 +221,7 @@ /// ``` /// and returns the tensor `D`. Operation * -linalg_generic_matmul(ValueHandle vA, ValueHandle vB, ValueHandle vC, - RankedTensorType tD, +linalg_generic_matmul(Value vA, Value vB, Value vC, RankedTensorType tD, MatmulRegionBuilder regionBuilder = macRegionBuilder); template @@ -260,8 +254,8 @@ /// For now `...` must be empty (i.e. only 2-D convolutions are supported). /// // TODO(ntv) Extend convolution rank with some template magic. -Operation *linalg_generic_conv_nhwc(ValueHandle vI, ValueHandle vW, - ValueHandle vO, ArrayRef strides = {}, +Operation *linalg_generic_conv_nhwc(Value vI, Value vW, Value vO, + ArrayRef strides = {}, ArrayRef dilations = {}); template @@ -295,8 +289,7 @@ /// For now `...` must be empty (i.e. 
only 2-D convolutions are supported). /// // TODO(ntv) Extend convolution rank with some template magic. -Operation *linalg_generic_dilated_conv_nhwc(ValueHandle vI, ValueHandle vW, - ValueHandle vO, +Operation *linalg_generic_dilated_conv_nhwc(Value vI, Value vW, Value vO, int depth_multiplier = 1, ArrayRef strides = {}, ArrayRef dilations = {}); diff --git a/mlir/include/mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h b/mlir/include/mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h --- a/mlir/include/mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h +++ b/mlir/include/mlir/Dialect/Linalg/EDSC/FoldedIntrinsics.h @@ -15,16 +15,52 @@ namespace mlir { namespace edsc { +namespace intrinsics { -template -ValueHandle ValueHandle::create(OperationFolder *folder, Args... args) { - return folder ? ValueHandle(folder->create(ScopedContext::getBuilder(), - ScopedContext::getLocation(), - args...)) - : ValueHandle(ScopedContext::getBuilder().create( - ScopedContext::getLocation(), args...)); -} +template +struct FoldedValueBuilder { + // Builder-based + template + FoldedValueBuilder(OperationFolder *folder, Args... args) { + value = folder ? folder->create(ScopedContext::getBuilder(), + ScopedContext::getLocation(), args...) 
+ : ScopedContext::getBuilder().create( + ScopedContext::getLocation(), args...); + } + operator Value() { return value; } + Value value; +}; + +using folded_std_constant_index = FoldedValueBuilder; +using folded_std_constant_float = FoldedValueBuilder; +using folded_std_constant_int = FoldedValueBuilder; +using folded_std_constant = FoldedValueBuilder; +using folded_std_dim = FoldedValueBuilder; +using folded_std_muli = FoldedValueBuilder; +using folded_std_addi = FoldedValueBuilder; +using folded_std_addf = FoldedValueBuilder; +using folded_std_alloc = FoldedValueBuilder; +using folded_std_constant = FoldedValueBuilder; +using folded_std_constant_float = FoldedValueBuilder; +using folded_std_constant_index = FoldedValueBuilder; +using folded_std_constant_int = FoldedValueBuilder; +using folded_std_dim = FoldedValueBuilder; +using folded_std_extract_element = FoldedValueBuilder; +using folded_std_index_cast = FoldedValueBuilder; +using folded_std_muli = FoldedValueBuilder; +using folded_std_mulf = FoldedValueBuilder; +using folded_std_memref_cast = FoldedValueBuilder; +using folded_std_select = FoldedValueBuilder; +using folded_std_load = FoldedValueBuilder; +using folded_std_subi = FoldedValueBuilder; +using folded_std_sub_view = FoldedValueBuilder; +using folded_std_tanh = FoldedValueBuilder; +using folded_std_tensor_load = FoldedValueBuilder; +using folded_std_view = FoldedValueBuilder; +using folded_std_zero_extendi = FoldedValueBuilder; +using folded_std_sign_extendi = FoldedValueBuilder; +} // namespace intrinsics } // namespace edsc } // namespace mlir diff --git a/mlir/include/mlir/Dialect/LoopOps/EDSC/Builders.h b/mlir/include/mlir/Dialect/LoopOps/EDSC/Builders.h --- a/mlir/include/mlir/Dialect/LoopOps/EDSC/Builders.h +++ b/mlir/include/mlir/Dialect/LoopOps/EDSC/Builders.h @@ -23,27 +23,30 @@ namespace edsc { /// Constructs a new loop::ParallelOp and captures the associated induction -/// variables. 
An array of ValueHandle pointers is passed as the first +/// variables. An array of Value pointers is passed as the first /// argument and is the *only* way to capture loop induction variables. -LoopBuilder makeParallelLoopBuilder(ArrayRef ivs, - ArrayRef lbHandles, - ArrayRef ubHandles, - ArrayRef steps); +LoopBuilder makeParallelLoopBuilder(MutableArrayRef ivs, + ArrayRef lbs, ArrayRef ubs, + ArrayRef steps); /// Constructs a new loop::ForOp and captures the associated induction -/// variable. A ValueHandle pointer is passed as the first argument and is the +/// variable. A Value pointer is passed as the first argument and is the /// *only* way to capture the loop induction variable. -LoopBuilder makeLoopBuilder(ValueHandle *iv, ValueHandle lbHandle, - ValueHandle ubHandle, ValueHandle stepHandle, - ArrayRef iter_args_handles = {}, - ValueRange iter_args_init_values = {}); +LoopBuilder makeLoopBuilder(Value *iv, Value lb, Value ub, Value step, + MutableArrayRef iterArgsHandles, + ValueRange iterArgsInitValues); +LoopBuilder makeLoopBuilder(Value *iv, Value lb, Value ub, Value step, + MutableArrayRef iterArgsHandles, + ValueRange iterArgsInitValues); +inline LoopBuilder makeLoopBuilder(Value *iv, Value lb, Value ub, Value step) { + return makeLoopBuilder(iv, lb, ub, step, MutableArrayRef{}, {}); +} /// Helper class to sugar building loop.parallel loop nests from lower/upper /// bounds and step sizes. class ParallelLoopNestBuilder { public: - ParallelLoopNestBuilder(ArrayRef ivs, - ArrayRef lbs, ArrayRef ubs, - ArrayRef steps); + ParallelLoopNestBuilder(MutableArrayRef ivs, ArrayRef lbs, + ArrayRef ubs, ArrayRef steps); void operator()(function_ref fun = nullptr); @@ -56,12 +59,12 @@ /// loop.for. 
class LoopNestBuilder { public: - LoopNestBuilder(ValueHandle *iv, ValueHandle lb, ValueHandle ub, - ValueHandle step, - ArrayRef iter_args_handles = {}, - ValueRange iter_args_init_values = {}); - LoopNestBuilder(ArrayRef ivs, ArrayRef lbs, - ArrayRef ubs, ArrayRef steps); + LoopNestBuilder(Value *iv, Value lb, Value ub, Value step); + LoopNestBuilder(Value *iv, Value lb, Value ub, Value step, + MutableArrayRef iterArgsHandles, + ValueRange iterArgsInitValues); + LoopNestBuilder(MutableArrayRef ivs, ArrayRef lbs, + ArrayRef ubs, ArrayRef steps); Operation::result_range operator()(std::function fun = nullptr); private: diff --git a/mlir/include/mlir/Dialect/StandardOps/EDSC/Builders.h b/mlir/include/mlir/Dialect/StandardOps/EDSC/Builders.h --- a/mlir/include/mlir/Dialect/StandardOps/EDSC/Builders.h +++ b/mlir/include/mlir/Dialect/StandardOps/EDSC/Builders.h @@ -20,27 +20,27 @@ class BoundsCapture { public: unsigned rank() const { return lbs.size(); } - ValueHandle lb(unsigned idx) { return lbs[idx]; } - ValueHandle ub(unsigned idx) { return ubs[idx]; } + Value lb(unsigned idx) { return lbs[idx]; } + Value ub(unsigned idx) { return ubs[idx]; } int64_t step(unsigned idx) { return steps[idx]; } - std::tuple range(unsigned idx) { + std::tuple range(unsigned idx) { return std::make_tuple(lbs[idx], ubs[idx], steps[idx]); } void swapRanges(unsigned i, unsigned j) { if (i == j) return; - lbs[i].swap(lbs[j]); - ubs[i].swap(ubs[j]); + std::swap(lbs[i], lbs[j]); + std::swap(ubs[i], ubs[j]); std::swap(steps[i], steps[j]); } - ArrayRef getLbs() { return lbs; } - ArrayRef getUbs() { return ubs; } + ArrayRef getLbs() { return lbs; } + ArrayRef getUbs() { return ubs; } ArrayRef getSteps() { return steps; } protected: - SmallVector lbs; - SmallVector ubs; + SmallVector lbs; + SmallVector ubs; SmallVector steps; }; @@ -58,7 +58,7 @@ unsigned fastestVarying() const { return rank() - 1; } private: - ValueHandle base; + Value base; }; /// A VectorBoundsCapture represents the 
information required to step through a @@ -72,7 +72,7 @@ VectorBoundsCapture &operator=(const VectorBoundsCapture &) = default; private: - ValueHandle base; + Value base; }; } // namespace edsc diff --git a/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h b/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h --- a/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h +++ b/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h @@ -14,40 +14,6 @@ namespace mlir { namespace edsc { namespace intrinsics { -namespace folded { -/// Helper variadic abstraction to allow extending to any MLIR op without -/// boilerplate or Tablegen. -/// Arguably a builder is not a ValueHandle but in practice it is only used as -/// an alias to a notional ValueHandle. -/// Implementing it as a subclass allows it to compose all the way to Value. -/// Without subclassing, implicit conversion to Value would fail when composing -/// in patterns such as: `select(a, b, select(c, d, e))`. -template -struct ValueBuilder : public ValueHandle { - /// Folder-based - template - ValueBuilder(OperationFolder *folder, Args... args) - : ValueHandle(ValueHandle::create(folder, detail::unpack(args)...)) {} - ValueBuilder(OperationFolder *folder, ArrayRef vs) - : ValueBuilder(ValueBuilder::create(folder, detail::unpack(vs))) {} - template - ValueBuilder(OperationFolder *folder, ArrayRef vs, Args... args) - : ValueHandle(ValueHandle::create(folder, detail::unpack(vs), - detail::unpack(args)...)) {} - template - ValueBuilder(OperationFolder *folder, T t, ArrayRef vs, - Args... args) - : ValueHandle(ValueHandle::create(folder, detail::unpack(t), - detail::unpack(vs), - detail::unpack(args)...)) {} - template - ValueBuilder(OperationFolder *folder, T1 t1, T2 t2, ArrayRef vs, - Args... 
args) - : ValueHandle(ValueHandle::create( - folder, detail::unpack(t1), detail::unpack(t2), detail::unpack(vs), - detail::unpack(args)...)) {} -}; -} // namespace folded using std_addf = ValueBuilder; using std_alloc = ValueBuilder; @@ -80,7 +46,7 @@ /// /// Prerequisites: /// All Handles have already captured previously constructed IR objects. -OperationHandle std_br(BlockHandle bh, ArrayRef operands); +OperationHandle std_br(BlockHandle bh, ArrayRef operands); /// Creates a new mlir::Block* and branches to it from the current block. /// Argument types are specified by `operands`. @@ -95,8 +61,9 @@ /// All `operands` have already captured an mlir::Value /// captures.size() == operands.size() /// captures and operands are pairwise of the same type. -OperationHandle std_br(BlockHandle *bh, ArrayRef captures, - ArrayRef operands); +OperationHandle std_br(BlockHandle *bh, ArrayRef types, + MutableArrayRef captures, + ArrayRef operands); /// Branches into the mlir::Block* captured by BlockHandle `trueBranch` with /// `trueOperands` if `cond` evaluates to `true` (resp. `falseBranch` and @@ -104,10 +71,10 @@ /// /// Prerequisites: /// All Handles have captured previously constructed IR objects. -OperationHandle std_cond_br(ValueHandle cond, BlockHandle trueBranch, - ArrayRef trueOperands, +OperationHandle std_cond_br(Value cond, BlockHandle trueBranch, + ArrayRef trueOperands, BlockHandle falseBranch, - ArrayRef falseOperands); + ArrayRef falseOperands); /// Eagerly creates new mlir::Block* with argument types specified by /// `trueOperands`/`falseOperands`. @@ -125,45 +92,17 @@ /// `falseCaptures`.size() == `falseOperands`.size() /// `trueCaptures` and `trueOperands` are pairwise of the same type /// `falseCaptures` and `falseOperands` are pairwise of the same type. 
-OperationHandle std_cond_br(ValueHandle cond, BlockHandle *trueBranch, - ArrayRef trueCaptures, - ArrayRef trueOperands, - BlockHandle *falseBranch, - ArrayRef falseCaptures, - ArrayRef falseOperands); +OperationHandle std_cond_br(Value cond, BlockHandle *trueBranch, + ArrayRef trueTypes, + MutableArrayRef trueCaptures, + ArrayRef trueOperands, + BlockHandle *falseBranch, ArrayRef falseTypes, + MutableArrayRef falseCaptures, + ArrayRef falseOperands); /// Provide an index notation around sdt_load and std_store. using StdIndexedValue = TemplatedIndexedValue; - -using folded_std_constant_index = folded::ValueBuilder; -using folded_std_constant_float = folded::ValueBuilder; -using folded_std_constant_int = folded::ValueBuilder; -using folded_std_constant = folded::ValueBuilder; -using folded_std_dim = folded::ValueBuilder; -using folded_std_muli = folded::ValueBuilder; -using folded_std_addi = folded::ValueBuilder; -using folded_std_addf = folded::ValueBuilder; -using folded_std_alloc = folded::ValueBuilder; -using folded_std_constant = folded::ValueBuilder; -using folded_std_constant_float = folded::ValueBuilder; -using folded_std_constant_index = folded::ValueBuilder; -using folded_std_constant_int = folded::ValueBuilder; -using folded_std_dim = folded::ValueBuilder; -using folded_std_extract_element = folded::ValueBuilder; -using folded_std_index_cast = folded::ValueBuilder; -using folded_std_muli = folded::ValueBuilder; -using folded_std_mulf = folded::ValueBuilder; -using folded_std_memref_cast = folded::ValueBuilder; -using folded_std_select = folded::ValueBuilder; -using folded_std_load = folded::ValueBuilder; -using folded_std_subi = folded::ValueBuilder; -using folded_std_sub_view = folded::ValueBuilder; -using folded_std_tanh = folded::ValueBuilder; -using folded_std_tensor_load = folded::ValueBuilder; -using folded_std_view = folded::ValueBuilder; -using folded_std_zero_extendi = folded::ValueBuilder; -using folded_std_sign_extendi = folded::ValueBuilder; 
} // namespace intrinsics } // namespace edsc } // namespace mlir diff --git a/mlir/include/mlir/Dialect/Vector/EDSC/Intrinsics.h b/mlir/include/mlir/Dialect/Vector/EDSC/Intrinsics.h --- a/mlir/include/mlir/Dialect/Vector/EDSC/Intrinsics.h +++ b/mlir/include/mlir/Dialect/Vector/EDSC/Intrinsics.h @@ -18,6 +18,7 @@ using vector_contract = ValueBuilder; using vector_matmul = ValueBuilder; using vector_print = OperationBuilder; +using vector_type_cast = ValueBuilder; } // namespace intrinsics } // namespace edsc diff --git a/mlir/include/mlir/EDSC/Builders.h b/mlir/include/mlir/EDSC/Builders.h --- a/mlir/include/mlir/EDSC/Builders.h +++ b/mlir/include/mlir/EDSC/Builders.h @@ -24,9 +24,7 @@ namespace edsc { class BlockHandle; -class CapturableHandle; class NestedBuilder; -class ValueHandle; /// Helper class to transparently handle builder insertion points by RAII. /// As its name indicates, a ScopedContext is means to be used locally in a @@ -70,10 +68,23 @@ /// Defensively keeps track of the current NestedBuilder to ensure proper /// scoping usage. NestedBuilder *nestedBuilder; +}; + +template +struct ValueBuilder { + // Builder-based + template + ValueBuilder(Args... args) { + Operation *op = ScopedContext::getBuilder() + .create(ScopedContext::getLocation(), args...) + .getOperation(); + if (op->getNumResults() != 1) + llvm_unreachable("unsupported operation, use OperationBuilder instead"); + value = op->getResult(0); + } - // TODO: Implement scoping of ValueHandles. To do this we need a proper data - // structure to hold ValueHandle objects. We can emulate one but there should - // already be something available in LLVM for this purpose. + operator Value() { return value; } + Value value; }; /// A NestedBuilder is a scoping abstraction to create an idiomatic syntax @@ -82,8 +93,7 @@ /// exists between object construction and method invocation on said object (in /// our case, the call to `operator()`). 
/// This ordering allows implementing an abstraction that decouples definition -/// from declaration (in a PL sense) on placeholders of type ValueHandle and -/// BlockHandle. +/// from declaration (in a PL sense) on placeholders. class NestedBuilder { protected: NestedBuilder() = default; @@ -158,19 +168,17 @@ private: LoopBuilder() = default; - friend LoopBuilder makeAffineLoopBuilder(ValueHandle *iv, - ArrayRef lbHandles, - ArrayRef ubHandles, + friend LoopBuilder makeAffineLoopBuilder(Value *iv, ArrayRef lbHandles, + ArrayRef ubHandles, int64_t step); - friend LoopBuilder makeParallelLoopBuilder(ArrayRef ivs, - ArrayRef lbHandles, - ArrayRef ubHandles, - ArrayRef steps); - friend LoopBuilder makeLoopBuilder(ValueHandle *iv, ValueHandle lbHandle, - ValueHandle ubHandle, - ValueHandle stepHandle, - ArrayRef iter_args_handles, - ValueRange iter_args_init_values); + friend LoopBuilder makeParallelLoopBuilder(MutableArrayRef ivs, + ArrayRef lbHandles, + ArrayRef ubHandles, + ArrayRef steps); + friend LoopBuilder makeLoopBuilder(Value *iv, Value lbHandle, Value ubHandle, + Value stepHandle, + MutableArrayRef iterArgsHandles, + ValueRange iterArgsInitValues); Operation *op; }; @@ -194,9 +202,11 @@ /// Enters the new mlir::Block* and sets the insertion point to its end. /// /// Prerequisites: - /// The ValueHandle `args` are typed delayed ValueHandles; i.e. they are + /// The Value `args` are typed delayed Values; i.e. they are /// not yet bound to mlir::Value. - BlockBuilder(BlockHandle *bh, ArrayRef args); + BlockBuilder(BlockHandle *bh) : BlockBuilder(bh, {}, {}) {} + BlockBuilder(BlockHandle *bh, ArrayRef types, + MutableArrayRef args); /// Constructs a new mlir::Block with argument types derived from `args` and /// appends it as the last block in the region. @@ -204,9 +214,10 @@ /// Enters the new mlir::Block* and sets the insertion point to its end. /// /// Prerequisites: - /// The ValueHandle `args` are typed delayed ValueHandles; i.e. 
they are + /// The Value `args` are typed delayed Values; i.e. they are /// not yet bound to mlir::Value. - BlockBuilder(BlockHandle *bh, Region ®ion, ArrayRef args); + BlockBuilder(BlockHandle *bh, Region ®ion, ArrayRef types, + MutableArrayRef args); /// The only purpose of this operator is to serve as a sequence point so that /// the evaluation of `fun` (which build IR snippets in a scoped fashion) is @@ -218,120 +229,18 @@ BlockBuilder &operator=(BlockBuilder &other) = delete; }; -/// Base class for ValueHandle, OperationHandle and BlockHandle. +/// Base class for Value, OperationHandle and BlockHandle. /// Not meant to be used outside of these classes. class CapturableHandle { protected: CapturableHandle() = default; }; -/// ValueHandle implements a (potentially "delayed") typed Value abstraction. -/// ValueHandle should be captured by pointer but otherwise passed by Value -/// everywhere. -/// A ValueHandle can have 3 states: -/// 1. null state (empty type and empty value), in which case it does not hold -/// a value and must never hold a Value (now or in the future). This is -/// used for MLIR operations with zero returns as well as the result of -/// calling a NestedBuilder::operator(). In both cases the objective is to -/// have an object that can be inserted in an ArrayRef to -/// implement nesting; -/// 2. delayed state (empty value), in which case it represents an eagerly -/// typed "delayed" value that can be hold a Value in the future; -/// 3. constructed state,in which case it holds a Value. -/// -/// A ValueHandle is meant to capture a single Value and should be used for -/// operations that have a single result. For convenience of use, we also -/// include AffineForOp in this category although it does not return a value. -/// In the case of AffineForOp, the captured Value is the loop induction -/// variable. 
-class ValueHandle : public CapturableHandle { -public: - /// A ValueHandle in a null state can never be captured; - static ValueHandle null() { return ValueHandle(); } - - /// A ValueHandle that is constructed from a Type represents a typed "delayed" - /// Value. A delayed Value can only capture Values of the specified type. - /// Such a delayed value represents the declaration (in the PL sense) of a - /// placeholder for an mlir::Value that will be constructed and captured at - /// some later point in the program. - explicit ValueHandle(Type t) : t(t), v(nullptr) {} - - /// A ValueHandle that is constructed from an mlir::Value is an "eager" - /// Value. An eager Value represents both the declaration and the definition - /// (in the PL sense) of a placeholder for an mlir::Value that has already - /// been constructed in the past and that is captured "now" in the program. - explicit ValueHandle(Value v) : t(v.getType()), v(v) {} - - /// ValueHandle is a value type, use the default copy constructor. - ValueHandle(const ValueHandle &other) = default; - - /// ValueHandle is a value type, the assignment operator typechecks before - /// assigning. - ValueHandle &operator=(const ValueHandle &other); - - /// Provide a swap operator. - void swap(ValueHandle &other) { - if (this == &other) - return; - std::swap(t, other.t); - std::swap(v, other.v); - } - - /// Implicit conversion useful for automatic conversion to Container. - operator Value() const { return getValue(); } - operator Type() const { return getType(); } - operator bool() const { return hasValue(); } - - /// Generic mlir::Op create. This is the key to being extensible to the whole - /// of MLIR without duplicating the type system or the op definitions. - template - static ValueHandle create(Args... args); - - /// Generic mlir::Op create. This is the key to being extensible to the whole - /// of MLIR without duplicating the type system or the op definitions. 
- /// When non-null, the optional pointer `folder` is used to call into the - /// `createAndFold` builder method. If `folder` is null, the regular `create` - /// method is called. - template - static ValueHandle create(OperationFolder *folder, Args... args); - - /// Generic create for a named operation producing a single value. - static ValueHandle create(StringRef name, ArrayRef operands, - ArrayRef resultTypes, - ArrayRef attributes = {}); - - bool hasValue() const { return v != nullptr; } - Value getValue() const { - assert(hasValue() && "Unexpected null value;"); - return v; - } - bool hasType() const { return t != Type(); } - Type getType() const { return t; } - - Operation *getOperation() const { - if (!v) - return nullptr; - return v.getDefiningOp(); - } - - // Return a vector of fresh ValueHandles that have not captured. - static SmallVector makeIndexHandles(unsigned count) { - auto indexType = IndexType::get(ScopedContext::getContext()); - return SmallVector(count, ValueHandle(indexType)); - } - -protected: - ValueHandle() : t(), v(nullptr) {} - - Type t; - Value v; -}; - -/// An OperationHandle can be used in lieu of ValueHandle to capture the +/// An OperationHandle can be used in lieu of Value to capture the /// operation in cases when one does not care about, or cannot extract, a /// unique Value from the operation. /// This can be used for capturing zero result operations as well as -/// multi-result operations that are not supported by ValueHandle. +/// multi-result operations that are not supported by Value. /// We do not distinguish further between zero and multi-result operations at /// this time. struct OperationHandle : public CapturableHandle { @@ -349,7 +258,7 @@ static Op createOp(Args... args); /// Generic create for a named operation. 
- static OperationHandle create(StringRef name, ArrayRef operands, + static OperationHandle create(StringRef name, ArrayRef operands, ArrayRef resultTypes, ArrayRef attributes = {}); @@ -360,23 +269,6 @@ Operation *op; }; -/// Simple wrapper to build a generic operation without successor blocks. -template -struct CustomOperation { - CustomOperation(StringRef name) : name(name) { - static_assert(std::is_same() || - std::is_same(), - "Only CustomOperation or " - "CustomOperation can be constructed."); - } - HandleType operator()(ArrayRef operands = {}, - ArrayRef resultTypes = {}, - ArrayRef attributes = {}) { - return HandleType::create(name, operands, resultTypes, attributes); - } - std::string name; -}; - /// A BlockHandle represents a (potentially "delayed") Block abstraction. /// This extra abstraction is necessary because an mlir::Block is not an /// mlir::Value. @@ -427,32 +319,45 @@ /// C(buffer_value_or_tensor_type); /// makeGenericLinalgOp({A({m, n}), B({k, n})}, {C({m, n})}, ... ); /// ``` -struct StructuredIndexed : public ValueHandle { - StructuredIndexed(Type type) : ValueHandle(type) {} - StructuredIndexed(Value value) : ValueHandle(value) {} - StructuredIndexed(ValueHandle valueHandle) : ValueHandle(valueHandle) {} +struct StructuredIndexed { + StructuredIndexed(Value v) : value(v) {} + StructuredIndexed(Type t) : type(t) {} StructuredIndexed operator()(ArrayRef indexings) { - return this->hasValue() ? StructuredIndexed(this->getValue(), indexings) - : StructuredIndexed(this->getType(), indexings); + return value ? 
StructuredIndexed(value, indexings) + : StructuredIndexed(type, indexings); } - StructuredIndexed(Type t, ArrayRef indexings) - : ValueHandle(t), exprs(indexings.begin(), indexings.end()) { - assert(t.isa() && "RankedTensor expected"); - } StructuredIndexed(Value v, ArrayRef indexings) - : ValueHandle(v), exprs(indexings.begin(), indexings.end()) { + : value(v), exprs(indexings.begin(), indexings.end()) { assert((v.getType().isa() || v.getType().isa() || v.getType().isa()) && "MemRef, RankedTensor or Vector expected"); } - StructuredIndexed(ValueHandle vh, ArrayRef indexings) - : ValueHandle(vh), exprs(indexings.begin(), indexings.end()) {} + StructuredIndexed(Type t, ArrayRef indexings) + : type(t), exprs(indexings.begin(), indexings.end()) { + assert((t.isa() || t.isa() || + t.isa()) && + "MemRef, RankedTensor or Vector expected"); + } - ArrayRef getExprs() { return exprs; } + bool hasValue() const { return value; } + Value getValue() const { + assert(value && "StructuredIndexed Value not set."); + return value; + } + Type getType() const { + assert((value || type) && "StructuredIndexed Value and Type not set."); + return value ? value.getType() : type; + } + ArrayRef getExprs() const { return exprs; } + operator Value() const { return getValue(); } + operator Type() const { return getType(); } private: + // Only one of Value or type may be set. + Type type; + Value value; SmallVector exprs; }; @@ -472,179 +377,139 @@ .getOperation()); } -template -ValueHandle ValueHandle::create(Args... args) { - Operation *op = ScopedContext::getBuilder() - .create(ScopedContext::getLocation(), args...) - .getOperation(); - if (op->getNumResults() == 1) - return ValueHandle(op->getResult(0)); - llvm_unreachable("unsupported operation, use an OperationHandle instead"); -} - -/// Entry point to build multiple ValueHandle from a `Container` of Value or -/// Type. 
-template -inline SmallVector makeValueHandles(Container values) { - SmallVector res; - res.reserve(values.size()); - for (auto v : values) - res.push_back(ValueHandle(v)); - return res; -} - /// A TemplatedIndexedValue brings an index notation over the template Load and /// Store parameters. Assigning to an IndexedValue emits an actual `Store` -/// operation, while converting an IndexedValue to a ValueHandle emits an actual +/// operation, while converting an IndexedValue to a Value emits an actual /// `Load` operation. template class TemplatedIndexedValue { public: - explicit TemplatedIndexedValue(Type t) : base(t) {} - explicit TemplatedIndexedValue(Value v) - : TemplatedIndexedValue(ValueHandle(v)) {} - explicit TemplatedIndexedValue(ValueHandle v) : base(v) {} + explicit TemplatedIndexedValue(Value v) : value(v) {} TemplatedIndexedValue(const TemplatedIndexedValue &rhs) = default; TemplatedIndexedValue operator()() { return *this; } /// Returns a new `TemplatedIndexedValue`. - TemplatedIndexedValue operator()(ValueHandle index) { - TemplatedIndexedValue res(base); + TemplatedIndexedValue operator()(Value index) { + TemplatedIndexedValue res(value); res.indices.push_back(index); return res; } template - TemplatedIndexedValue operator()(ValueHandle index, Args... indices) { - return TemplatedIndexedValue(base, index).append(indices...); + TemplatedIndexedValue operator()(Value index, Args... indices) { + return TemplatedIndexedValue(value, index).append(indices...); } - TemplatedIndexedValue operator()(ArrayRef indices) { - return TemplatedIndexedValue(base, indices); + TemplatedIndexedValue operator()(ArrayRef indices) { + return TemplatedIndexedValue(value, indices); } /// Emits a `store`. 
OperationHandle operator=(const TemplatedIndexedValue &rhs) { - ValueHandle rrhs(rhs); - return Store(rrhs, getBase(), {indices.begin(), indices.end()}); - } - OperationHandle operator=(ValueHandle rhs) { - return Store(rhs, getBase(), {indices.begin(), indices.end()}); - } - - /// Emits a `load` when converting to a ValueHandle. - operator ValueHandle() const { - return Load(getBase(), {indices.begin(), indices.end()}); + return Store(rhs, value, indices); } + OperationHandle operator=(Value rhs) { return Store(rhs, value, indices); } /// Emits a `load` when converting to a Value. - Value operator*(void)const { - return Load(getBase(), {indices.begin(), indices.end()}).getValue(); - } + operator Value() const { return Load(value, indices); } - ValueHandle getBase() const { return base; } + Value getBase() const { return value; } /// Arithmetic operator overloadings. - ValueHandle operator+(ValueHandle e); - ValueHandle operator-(ValueHandle e); - ValueHandle operator*(ValueHandle e); - ValueHandle operator/(ValueHandle e); - ValueHandle operator%(ValueHandle e); - ValueHandle operator^(ValueHandle e); - ValueHandle operator+(TemplatedIndexedValue e) { - return *this + static_cast(e); + Value operator+(Value e); + Value operator-(Value e); + Value operator*(Value e); + Value operator/(Value e); + Value operator%(Value e); + Value operator^(Value e); + Value operator+(TemplatedIndexedValue e) { + return *this + static_cast(e); } - ValueHandle operator-(TemplatedIndexedValue e) { - return *this - static_cast(e); + Value operator-(TemplatedIndexedValue e) { + return *this - static_cast(e); } - ValueHandle operator*(TemplatedIndexedValue e) { - return *this * static_cast(e); + Value operator*(TemplatedIndexedValue e) { + return *this * static_cast(e); } - ValueHandle operator/(TemplatedIndexedValue e) { - return *this / static_cast(e); + Value operator/(TemplatedIndexedValue e) { + return *this / static_cast(e); } - ValueHandle operator%(TemplatedIndexedValue e) { - 
return *this % static_cast(e); + Value operator%(TemplatedIndexedValue e) { + return *this % static_cast(e); } - ValueHandle operator^(TemplatedIndexedValue e) { - return *this ^ static_cast(e); + Value operator^(TemplatedIndexedValue e) { + return *this ^ static_cast(e); } /// Assignment-arithmetic operator overloadings. - OperationHandle operator+=(ValueHandle e); - OperationHandle operator-=(ValueHandle e); - OperationHandle operator*=(ValueHandle e); - OperationHandle operator/=(ValueHandle e); - OperationHandle operator%=(ValueHandle e); - OperationHandle operator^=(ValueHandle e); + OperationHandle operator+=(Value e); + OperationHandle operator-=(Value e); + OperationHandle operator*=(Value e); + OperationHandle operator/=(Value e); + OperationHandle operator%=(Value e); + OperationHandle operator^=(Value e); OperationHandle operator+=(TemplatedIndexedValue e) { - return this->operator+=(static_cast(e)); + return this->operator+=(static_cast(e)); } OperationHandle operator-=(TemplatedIndexedValue e) { - return this->operator-=(static_cast(e)); + return this->operator-=(static_cast(e)); } OperationHandle operator*=(TemplatedIndexedValue e) { - return this->operator*=(static_cast(e)); + return this->operator*=(static_cast(e)); } OperationHandle operator/=(TemplatedIndexedValue e) { - return this->operator/=(static_cast(e)); + return this->operator/=(static_cast(e)); } OperationHandle operator%=(TemplatedIndexedValue e) { - return this->operator%=(static_cast(e)); + return this->operator%=(static_cast(e)); } OperationHandle operator^=(TemplatedIndexedValue e) { - return this->operator^=(static_cast(e)); + return this->operator^=(static_cast(e)); } /// Logical operator overloadings. 
- ValueHandle operator&&(ValueHandle e); - ValueHandle operator||(ValueHandle e); - ValueHandle operator&&(TemplatedIndexedValue e) { - return *this && static_cast(e); + Value operator&&(Value e); + Value operator||(Value e); + Value operator&&(TemplatedIndexedValue e) { + return *this && static_cast(e); } - ValueHandle operator||(TemplatedIndexedValue e) { - return *this || static_cast(e); + Value operator||(TemplatedIndexedValue e) { + return *this || static_cast(e); } /// Comparison operator overloadings. - ValueHandle operator==(ValueHandle e); - ValueHandle operator!=(ValueHandle e); - ValueHandle operator<(ValueHandle e); - ValueHandle operator<=(ValueHandle e); - ValueHandle operator>(ValueHandle e); - ValueHandle operator>=(ValueHandle e); - ValueHandle operator==(TemplatedIndexedValue e) { - return *this == static_cast(e); - } - ValueHandle operator!=(TemplatedIndexedValue e) { - return *this != static_cast(e); - } - ValueHandle operator<(TemplatedIndexedValue e) { - return *this < static_cast(e); + Value eq(Value e); + Value ne(Value e); + Value operator<(Value e); + Value operator<=(Value e); + Value operator>(Value e); + Value operator>=(Value e); + Value operator<(TemplatedIndexedValue e) { + return *this < static_cast(e); } - ValueHandle operator<=(TemplatedIndexedValue e) { - return *this <= static_cast(e); + Value operator<=(TemplatedIndexedValue e) { + return *this <= static_cast(e); } - ValueHandle operator>(TemplatedIndexedValue e) { - return *this > static_cast(e); + Value operator>(TemplatedIndexedValue e) { + return *this > static_cast(e); } - ValueHandle operator>=(TemplatedIndexedValue e) { - return *this >= static_cast(e); + Value operator>=(TemplatedIndexedValue e) { + return *this >= static_cast(e); } private: - TemplatedIndexedValue(ValueHandle base, ArrayRef indices) - : base(base), indices(indices.begin(), indices.end()) {} + TemplatedIndexedValue(Value value, ArrayRef indices) + : value(value), indices(indices.begin(), indices.end()) 
{} TemplatedIndexedValue &append() { return *this; } template TemplatedIndexedValue &append(T index, Args... indices) { - this->indices.push_back(static_cast(index)); + this->indices.push_back(static_cast(index)); append(indices...); return *this; } - ValueHandle base; - SmallVector indices; + Value value; + SmallVector indices; }; } // namespace edsc diff --git a/mlir/include/mlir/EDSC/Intrinsics.h b/mlir/include/mlir/EDSC/Intrinsics.h --- a/mlir/include/mlir/EDSC/Intrinsics.h +++ b/mlir/include/mlir/EDSC/Intrinsics.h @@ -25,99 +25,27 @@ namespace edsc { -/// Entry point to build multiple ValueHandle* from a mutable list `ivs`. -inline SmallVector -makeHandlePointers(MutableArrayRef ivs) { - SmallVector pivs; - pivs.reserve(ivs.size()); - for (auto &iv : ivs) - pivs.push_back(&iv); - return pivs; -} - /// Provides a set of first class intrinsics. /// In the future, most of intrinsics related to Operation that don't contain /// other operations should be Tablegen'd. namespace intrinsics { -namespace detail { -/// Helper structure to be used with ValueBuilder / OperationBuilder. -/// It serves the purpose of removing boilerplate specialization for the sole -/// purpose of implicitly converting ArrayRef -> ArrayRef. -class ValueHandleArray { -public: - ValueHandleArray(ArrayRef vals) { - values.append(vals.begin(), vals.end()); - } - operator ArrayRef() { return values; } - -private: - ValueHandleArray() = default; - SmallVector values; -}; - -template -inline T unpack(T value) { - return value; -} - -inline detail::ValueHandleArray unpack(ArrayRef values) { - return detail::ValueHandleArray(values); -} - -} // namespace detail - -/// Helper variadic abstraction to allow extending to any MLIR op without -/// boilerplate or Tablegen. -/// Arguably a builder is not a ValueHandle but in practice it is only used as -/// an alias to a notional ValueHandle. -/// Implementing it as a subclass allows it to compose all the way to Value. 
-/// Without subclassing, implicit conversion to Value would fail when composing -/// in patterns such as: `select(a, b, select(c, d, e))`. -template -struct ValueBuilder : public ValueHandle { - // Builder-based - template - ValueBuilder(Args... args) - : ValueHandle(ValueHandle::create(detail::unpack(args)...)) {} - ValueBuilder(ArrayRef vs) - : ValueBuilder(ValueBuilder::create(detail::unpack(vs))) {} - template - ValueBuilder(ArrayRef vs, Args... args) - : ValueHandle(ValueHandle::create(detail::unpack(vs), - detail::unpack(args)...)) {} - template - ValueBuilder(T t, ArrayRef vs, Args... args) - : ValueHandle(ValueHandle::create( - detail::unpack(t), detail::unpack(vs), detail::unpack(args)...)) {} - template - ValueBuilder(T1 t1, T2 t2, ArrayRef vs, Args... args) - : ValueHandle(ValueHandle::create( - detail::unpack(t1), detail::unpack(t2), detail::unpack(vs), - detail::unpack(args)...)) {} - - ValueBuilder() : ValueHandle(ValueHandle::create()) {} -}; template struct OperationBuilder : public OperationHandle { template OperationBuilder(Args... args) - : OperationHandle(OperationHandle::create(detail::unpack(args)...)) {} - OperationBuilder(ArrayRef vs) - : OperationHandle(OperationHandle::create(detail::unpack(vs))) {} + : OperationHandle(OperationHandle::create(args...)) {} + OperationBuilder(ArrayRef vs) + : OperationHandle(OperationHandle::create(vs)) {} template - OperationBuilder(ArrayRef vs, Args... args) - : OperationHandle(OperationHandle::create(detail::unpack(vs), - detail::unpack(args)...)) {} + OperationBuilder(ArrayRef vs, Args... args) + : OperationHandle(OperationHandle::create(vs, args...)) {} template - OperationBuilder(T t, ArrayRef vs, Args... args) - : OperationHandle(OperationHandle::create( - detail::unpack(t), detail::unpack(vs), detail::unpack(args)...)) {} + OperationBuilder(T t, ArrayRef vs, Args... args) + : OperationHandle(OperationHandle::create(t, vs, args...)) {} template - OperationBuilder(T1 t1, T2 t2, ArrayRef vs, Args... 
args) - : OperationHandle(OperationHandle::create( - detail::unpack(t1), detail::unpack(t2), detail::unpack(vs), - detail::unpack(args)...)) {} + OperationBuilder(T1 t1, T2 t2, ArrayRef vs, Args... args) + : OperationHandle(OperationHandle::create(t1, t2, vs, args...)) {} OperationBuilder() : OperationHandle(OperationHandle::create()) {} }; diff --git a/mlir/lib/Conversion/VectorToLoops/ConvertVectorToLoops.cpp b/mlir/lib/Conversion/VectorToLoops/ConvertVectorToLoops.cpp --- a/mlir/lib/Conversion/VectorToLoops/ConvertVectorToLoops.cpp +++ b/mlir/lib/Conversion/VectorToLoops/ConvertVectorToLoops.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/Affine/EDSC/Intrinsics.h" #include "mlir/Dialect/LoopOps/EDSC/Builders.h" #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" +#include "mlir/Dialect/Vector/EDSC/Intrinsics.h" #include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" @@ -38,9 +39,7 @@ /// `pivs` and `vectorBoundsCapture` are swapped so that the invocation of /// LoopNestBuilder captures it in the innermost loop. template -static void coalesceCopy(TransferOpTy transfer, - SmallVectorImpl *pivs, - VectorBoundsCapture *vectorBoundsCapture) { +static int computeCoalescedIndex(TransferOpTy transfer) { // rank of the remote memory access, coalescing behavior occurs on the // innermost memory dimension. auto remoteRank = transfer.getMemRefType().getRank(); @@ -62,24 +61,19 @@ coalescedIdx = en.index(); } } - if (coalescedIdx >= 0) { - std::swap(pivs->back(), (*pivs)[coalescedIdx]); - vectorBoundsCapture->swapRanges(pivs->size() - 1, coalescedIdx); - } + return coalescedIdx; } /// Emits remote memory accesses that are clipped to the boundaries of the /// MemRef. 
template -static SmallVector clip(TransferOpTy transfer, - MemRefBoundsCapture &bounds, - ArrayRef ivs) { +static SmallVector +clip(TransferOpTy transfer, MemRefBoundsCapture &bounds, ArrayRef ivs) { using namespace mlir::edsc; - ValueHandle zero(std_constant_index(0)), one(std_constant_index(1)); - SmallVector memRefAccess(transfer.indices()); - auto clippedScalarAccessExprs = - ValueHandle::makeIndexHandles(memRefAccess.size()); + Value zero(std_constant_index(0)), one(std_constant_index(1)); + SmallVector memRefAccess(transfer.indices()); + SmallVector clippedScalarAccessExprs(memRefAccess.size()); // Indices accessing to remote memory are clipped and their expressions are // returned in clippedScalarAccessExprs. for (unsigned memRefDim = 0; memRefDim < clippedScalarAccessExprs.size(); @@ -126,8 +120,6 @@ namespace { -using vector_type_cast = edsc::intrinsics::ValueBuilder; - /// Implements lowering of TransferReadOp and TransferWriteOp to a /// proper abstraction for the hardware. /// @@ -257,31 +249,36 @@ StdIndexedValue remote(transfer.memref()); MemRefBoundsCapture memRefBoundsCapture(transfer.memref()); VectorBoundsCapture vectorBoundsCapture(transfer.vector()); - auto ivs = ValueHandle::makeIndexHandles(vectorBoundsCapture.rank()); - SmallVector pivs = - makeHandlePointers(MutableArrayRef(ivs)); - coalesceCopy(transfer, &pivs, &vectorBoundsCapture); + int coalescedIdx = computeCoalescedIndex(transfer); + // Swap the vectorBoundsCapture which will reorder loop bounds. + if (coalescedIdx >= 0) + vectorBoundsCapture.swapRanges(vectorBoundsCapture.rank() - 1, + coalescedIdx); auto lbs = vectorBoundsCapture.getLbs(); auto ubs = vectorBoundsCapture.getUbs(); - SmallVector steps; + SmallVector steps; steps.reserve(vectorBoundsCapture.getSteps().size()); for (auto step : vectorBoundsCapture.getSteps()) steps.push_back(std_constant_index(step)); // 2. Emit alloc-copy-load-dealloc. 
- ValueHandle tmp = std_alloc(tmpMemRefType(transfer)); + Value tmp = std_alloc(tmpMemRefType(transfer)); StdIndexedValue local(tmp); - ValueHandle vec = vector_type_cast(tmp); - LoopNestBuilder(pivs, lbs, ubs, steps)([&] { + Value vec = vector_type_cast(tmp); + SmallVector ivs(lbs.size()); + LoopNestBuilder(ivs, lbs, ubs, steps)([&] { + // Swap the ivs which will reorder memory accesses. + if (coalescedIdx >= 0) + std::swap(ivs.back(), ivs[coalescedIdx]); // Computes clippedScalarAccessExprs in the loop nest scope (ivs exist). local(ivs) = remote(clip(transfer, memRefBoundsCapture, ivs)); }); - ValueHandle vectorValue = std_load(vec); + Value vectorValue = std_load(vec); (std_dealloc(tmp)); // vexing parse // 3. Propagate. - rewriter.replaceOp(op, vectorValue.getValue()); + rewriter.replaceOp(op, vectorValue); return success(); } @@ -314,26 +311,31 @@ ScopedContext scope(rewriter, transfer.getLoc()); StdIndexedValue remote(transfer.memref()); MemRefBoundsCapture memRefBoundsCapture(transfer.memref()); - ValueHandle vectorValue(transfer.vector()); + Value vectorValue(transfer.vector()); VectorBoundsCapture vectorBoundsCapture(transfer.vector()); - auto ivs = ValueHandle::makeIndexHandles(vectorBoundsCapture.rank()); - SmallVector pivs = - makeHandlePointers(MutableArrayRef(ivs)); - coalesceCopy(transfer, &pivs, &vectorBoundsCapture); + int coalescedIdx = computeCoalescedIndex(transfer); + // Swap the vectorBoundsCapture which will reorder loop bounds. + if (coalescedIdx >= 0) + vectorBoundsCapture.swapRanges(vectorBoundsCapture.rank() - 1, + coalescedIdx); auto lbs = vectorBoundsCapture.getLbs(); auto ubs = vectorBoundsCapture.getUbs(); - SmallVector steps; + SmallVector steps; steps.reserve(vectorBoundsCapture.getSteps().size()); for (auto step : vectorBoundsCapture.getSteps()) steps.push_back(std_constant_index(step)); // 2. Emit alloc-store-copy-dealloc. 
- ValueHandle tmp = std_alloc(tmpMemRefType(transfer)); + Value tmp = std_alloc(tmpMemRefType(transfer)); StdIndexedValue local(tmp); - ValueHandle vec = vector_type_cast(tmp); + Value vec = vector_type_cast(tmp); std_store(vectorValue, vec); - LoopNestBuilder(pivs, lbs, ubs, steps)([&] { + SmallVector ivs(lbs.size()); + LoopNestBuilder(ivs, lbs, ubs, steps)([&] { + // Swap the ivs which will reorder memory accesses. + if (coalescedIdx >= 0) + std::swap(ivs.back(), ivs[coalescedIdx]); // Computes clippedScalarAccessExprs in the loop nest scope (ivs exist). remote(clip(transfer, memRefBoundsCapture, ivs)) = local(ivs); }); diff --git a/mlir/lib/Dialect/Affine/EDSC/Builders.cpp b/mlir/lib/Dialect/Affine/EDSC/Builders.cpp --- a/mlir/lib/Dialect/Affine/EDSC/Builders.cpp +++ b/mlir/lib/Dialect/Affine/EDSC/Builders.cpp @@ -14,65 +14,61 @@ using namespace mlir; using namespace mlir::edsc; -static Optional emitStaticFor(ArrayRef lbs, - ArrayRef ubs, - int64_t step) { +static Optional emitStaticFor(ArrayRef lbs, ArrayRef ubs, + int64_t step) { if (lbs.size() != 1 || ubs.size() != 1) - return Optional(); + return Optional(); - auto *lbDef = lbs.front().getValue().getDefiningOp(); - auto *ubDef = ubs.front().getValue().getDefiningOp(); + auto *lbDef = lbs.front().getDefiningOp(); + auto *ubDef = ubs.front().getDefiningOp(); if (!lbDef || !ubDef) - return Optional(); + return Optional(); auto lbConst = dyn_cast(lbDef); auto ubConst = dyn_cast(ubDef); if (!lbConst || !ubConst) - return Optional(); - - return ValueHandle(ScopedContext::getBuilder() - .create(ScopedContext::getLocation(), - lbConst.getValue(), - ubConst.getValue(), step) - .getInductionVar()); + return Optional(); + return ScopedContext::getBuilder() + .create(ScopedContext::getLocation(), lbConst.getValue(), + ubConst.getValue(), step) + .getInductionVar(); } -LoopBuilder mlir::edsc::makeAffineLoopBuilder(ValueHandle *iv, - ArrayRef lbHandles, - ArrayRef ubHandles, +LoopBuilder 
mlir::edsc::makeAffineLoopBuilder(Value *iv, ArrayRef lbs, + ArrayRef ubs, int64_t step) { mlir::edsc::LoopBuilder result; - if (auto staticFor = emitStaticFor(lbHandles, ubHandles, step)) { - *iv = staticFor.getValue(); + if (auto staticForIv = emitStaticFor(lbs, ubs, step)) { + *iv = staticForIv.getValue(); } else { - SmallVector lbs(lbHandles.begin(), lbHandles.end()); - SmallVector ubs(ubHandles.begin(), ubHandles.end()); auto b = ScopedContext::getBuilder(); - *iv = ValueHandle( - b.create(ScopedContext::getLocation(), lbs, - b.getMultiDimIdentityMap(lbs.size()), ubs, - b.getMultiDimIdentityMap(ubs.size()), step) - .getInductionVar()); + *iv = + Value(b.create(ScopedContext::getLocation(), lbs, + b.getMultiDimIdentityMap(lbs.size()), ubs, + b.getMultiDimIdentityMap(ubs.size()), step) + .getInductionVar()); } - auto *body = getForInductionVarOwner(iv->getValue()).getBody(); + + auto *body = getForInductionVarOwner(*iv).getBody(); result.enter(body, /*prev=*/1); return result; } -mlir::edsc::AffineLoopNestBuilder::AffineLoopNestBuilder( - ValueHandle *iv, ArrayRef lbs, ArrayRef ubs, - int64_t step) { +mlir::edsc::AffineLoopNestBuilder::AffineLoopNestBuilder(Value *iv, + ArrayRef lbs, + ArrayRef ubs, + int64_t step) { loops.emplace_back(makeAffineLoopBuilder(iv, lbs, ubs, step)); } mlir::edsc::AffineLoopNestBuilder::AffineLoopNestBuilder( - ArrayRef ivs, ArrayRef lbs, - ArrayRef ubs, ArrayRef steps) { + MutableArrayRef ivs, ArrayRef lbs, ArrayRef ubs, + ArrayRef steps) { assert(ivs.size() == lbs.size() && "Mismatch in number of arguments"); assert(ivs.size() == ubs.size() && "Mismatch in number of arguments"); assert(ivs.size() == steps.size() && "Mismatch in number of arguments"); for (auto it : llvm::zip(ivs, lbs, ubs, steps)) - loops.emplace_back(makeAffineLoopBuilder(std::get<0>(it), std::get<1>(it), + loops.emplace_back(makeAffineLoopBuilder(&std::get<0>(it), std::get<1>(it), std::get<2>(it), std::get<3>(it))); } @@ -89,11 +85,6 @@ (*lit)(); } -template 
-static ValueHandle createBinaryHandle(ValueHandle lhs, ValueHandle rhs) { - return ValueHandle::create(lhs.getValue(), rhs.getValue()); -} - static std::pair categorizeValueByAffineType(MLIRContext *context, Value val, unsigned &numDims, unsigned &numSymbols) { @@ -111,115 +102,109 @@ return std::make_pair(d, resultVal); } -static ValueHandle createBinaryIndexHandle( - ValueHandle lhs, ValueHandle rhs, +static Value createBinaryIndexHandle( + Value lhs, Value rhs, function_ref affCombiner) { MLIRContext *context = ScopedContext::getContext(); unsigned numDims = 0, numSymbols = 0; AffineExpr d0, d1; Value v0, v1; std::tie(d0, v0) = - categorizeValueByAffineType(context, lhs.getValue(), numDims, numSymbols); + categorizeValueByAffineType(context, lhs, numDims, numSymbols); std::tie(d1, v1) = - categorizeValueByAffineType(context, rhs.getValue(), numDims, numSymbols); + categorizeValueByAffineType(context, rhs, numDims, numSymbols); SmallVector operands; - if (v0) { + if (v0) operands.push_back(v0); - } - if (v1) { + if (v1) operands.push_back(v1); - } auto map = AffineMap::get(numDims, numSymbols, affCombiner(d0, d1)); + // TODO: createOrFold when available. 
Operation *op = makeComposedAffineApply(ScopedContext::getBuilder(), ScopedContext::getLocation(), map, operands) .getOperation(); assert(op->getNumResults() == 1 && "Expected single result AffineApply"); - return ValueHandle(op->getResult(0)); + return op->getResult(0); } template -static ValueHandle createBinaryHandle( - ValueHandle lhs, ValueHandle rhs, +static Value createBinaryHandle( + Value lhs, Value rhs, function_ref affCombiner) { - auto thisType = lhs.getValue().getType(); - auto thatType = rhs.getValue().getType(); + auto thisType = lhs.getType(); + auto thatType = rhs.getType(); assert(thisType == thatType && "cannot mix types in operators"); (void)thisType; (void)thatType; if (thisType.isIndex()) { return createBinaryIndexHandle(lhs, rhs, affCombiner); } else if (thisType.isSignlessInteger()) { - return createBinaryHandle(lhs, rhs); + return ValueBuilder(lhs, rhs); } else if (thisType.isa()) { - return createBinaryHandle(lhs, rhs); + return ValueBuilder(lhs, rhs); } else if (thisType.isa() || thisType.isa()) { auto aggregateType = thisType.cast(); if (aggregateType.getElementType().isSignlessInteger()) - return createBinaryHandle(lhs, rhs); + return ValueBuilder(lhs, rhs); else if (aggregateType.getElementType().isa()) - return createBinaryHandle(lhs, rhs); + return ValueBuilder(lhs, rhs); } - llvm_unreachable("failed to create a ValueHandle"); + llvm_unreachable("failed to create a Value"); } -ValueHandle mlir::edsc::op::operator+(ValueHandle lhs, ValueHandle rhs) { +Value mlir::edsc::op::operator+(Value lhs, Value rhs) { return createBinaryHandle( lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0 + d1; }); } -ValueHandle mlir::edsc::op::operator-(ValueHandle lhs, ValueHandle rhs) { +Value mlir::edsc::op::operator-(Value lhs, Value rhs) { return createBinaryHandle( lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0 - d1; }); } -ValueHandle mlir::edsc::op::operator*(ValueHandle lhs, ValueHandle rhs) { +Value mlir::edsc::op::operator*(Value 
lhs, Value rhs) { return createBinaryHandle( lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0 * d1; }); } -ValueHandle mlir::edsc::op::operator/(ValueHandle lhs, ValueHandle rhs) { +Value mlir::edsc::op::operator/(Value lhs, Value rhs) { return createBinaryHandle( lhs, rhs, [](AffineExpr d0, AffineExpr d1) -> AffineExpr { llvm_unreachable("only exprs of non-index type support operator/"); }); } -ValueHandle mlir::edsc::op::operator%(ValueHandle lhs, ValueHandle rhs) { +Value mlir::edsc::op::operator%(Value lhs, Value rhs) { return createBinaryHandle( lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0 % d1; }); } -ValueHandle mlir::edsc::op::floorDiv(ValueHandle lhs, ValueHandle rhs) { +Value mlir::edsc::op::floorDiv(Value lhs, Value rhs) { return createBinaryIndexHandle( lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0.floorDiv(d1); }); } -ValueHandle mlir::edsc::op::ceilDiv(ValueHandle lhs, ValueHandle rhs) { +Value mlir::edsc::op::ceilDiv(Value lhs, Value rhs) { return createBinaryIndexHandle( lhs, rhs, [](AffineExpr d0, AffineExpr d1) { return d0.ceilDiv(d1); }); } -ValueHandle mlir::edsc::op::operator!(ValueHandle value) { - assert(value.getType().isInteger(1) && "expected boolean expression"); - return ValueHandle::create(1, 1) - value; -} - -ValueHandle mlir::edsc::op::operator&&(ValueHandle lhs, ValueHandle rhs) { +Value mlir::edsc::op::operator&&(Value lhs, Value rhs) { assert(lhs.getType().isInteger(1) && "expected boolean expression on LHS"); assert(rhs.getType().isInteger(1) && "expected boolean expression on RHS"); - return ValueHandle::create(lhs, rhs); + return ValueBuilder(lhs, rhs); } -ValueHandle mlir::edsc::op::operator||(ValueHandle lhs, ValueHandle rhs) { +Value mlir::edsc::op::operator||(Value lhs, Value rhs) { assert(lhs.getType().isInteger(1) && "expected boolean expression on LHS"); assert(rhs.getType().isInteger(1) && "expected boolean expression on RHS"); - return ValueHandle::create(lhs, rhs); + return ValueBuilder(lhs, 
rhs); } -static ValueHandle createIComparisonExpr(CmpIPredicate predicate, - ValueHandle lhs, ValueHandle rhs) { +static Value createIComparisonExpr(CmpIPredicate predicate, Value lhs, + Value rhs) { auto lhsType = lhs.getType(); auto rhsType = rhs.getType(); (void)lhsType; @@ -228,13 +213,12 @@ assert((lhsType.isa() || lhsType.isSignlessInteger()) && "only integer comparisons are supported"); - auto op = ScopedContext::getBuilder().create( - ScopedContext::getLocation(), predicate, lhs.getValue(), rhs.getValue()); - return ValueHandle(op.getResult()); + return ScopedContext::getBuilder().create( + ScopedContext::getLocation(), predicate, lhs, rhs); } -static ValueHandle createFComparisonExpr(CmpFPredicate predicate, - ValueHandle lhs, ValueHandle rhs) { +static Value createFComparisonExpr(CmpFPredicate predicate, Value lhs, + Value rhs) { auto lhsType = lhs.getType(); auto rhsType = rhs.getType(); (void)lhsType; @@ -242,25 +226,24 @@ assert(lhsType == rhsType && "cannot mix types in operators"); assert(lhsType.isa() && "only float comparisons are supported"); - auto op = ScopedContext::getBuilder().create( - ScopedContext::getLocation(), predicate, lhs.getValue(), rhs.getValue()); - return ValueHandle(op.getResult()); + return ScopedContext::getBuilder().create( + ScopedContext::getLocation(), predicate, lhs, rhs); } // All floating point comparison are ordered through EDSL -ValueHandle mlir::edsc::op::operator==(ValueHandle lhs, ValueHandle rhs) { +Value mlir::edsc::op::eq(Value lhs, Value rhs) { auto type = lhs.getType(); return type.isa() ? createFComparisonExpr(CmpFPredicate::OEQ, lhs, rhs) : createIComparisonExpr(CmpIPredicate::eq, lhs, rhs); } -ValueHandle mlir::edsc::op::operator!=(ValueHandle lhs, ValueHandle rhs) { +Value mlir::edsc::op::ne(Value lhs, Value rhs) { auto type = lhs.getType(); return type.isa() ? 
createFComparisonExpr(CmpFPredicate::ONE, lhs, rhs) : createIComparisonExpr(CmpIPredicate::ne, lhs, rhs); } -ValueHandle mlir::edsc::op::operator<(ValueHandle lhs, ValueHandle rhs) { +Value mlir::edsc::op::operator<(Value lhs, Value rhs) { auto type = lhs.getType(); return type.isa() ? createFComparisonExpr(CmpFPredicate::OLT, lhs, rhs) @@ -268,19 +251,19 @@ // TODO(ntv,zinenko): signed by default, how about unsigned? createIComparisonExpr(CmpIPredicate::slt, lhs, rhs); } -ValueHandle mlir::edsc::op::operator<=(ValueHandle lhs, ValueHandle rhs) { +Value mlir::edsc::op::operator<=(Value lhs, Value rhs) { auto type = lhs.getType(); return type.isa() ? createFComparisonExpr(CmpFPredicate::OLE, lhs, rhs) : createIComparisonExpr(CmpIPredicate::sle, lhs, rhs); } -ValueHandle mlir::edsc::op::operator>(ValueHandle lhs, ValueHandle rhs) { +Value mlir::edsc::op::operator>(Value lhs, Value rhs) { auto type = lhs.getType(); return type.isa() ? createFComparisonExpr(CmpFPredicate::OGT, lhs, rhs) : createIComparisonExpr(CmpIPredicate::sgt, lhs, rhs); } -ValueHandle mlir::edsc::op::operator>=(ValueHandle lhs, ValueHandle rhs) { +Value mlir::edsc::op::operator>=(Value lhs, Value rhs) { auto type = lhs.getType(); return type.isa() ? createFComparisonExpr(CmpFPredicate::OGE, lhs, rhs) diff --git a/mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp b/mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp --- a/mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/MemoryPromotion.cpp @@ -44,14 +44,14 @@ MemRefBoundsCapture &bounds, Value from, Value to) { // Create EDSC handles for bounds. unsigned rank = bounds.rank(); - SmallVector lbs, ubs, steps; + SmallVector lbs, ubs, steps; // Make sure we have enough loops to use all thread dimensions, these trivial // loops should be outermost and therefore inserted first. 
if (rank < GPUDialect::getNumWorkgroupDimensions()) { unsigned extraLoops = GPUDialect::getNumWorkgroupDimensions() - rank; - ValueHandle zero = std_constant_index(0); - ValueHandle one = std_constant_index(1); + Value zero = std_constant_index(0); + Value one = std_constant_index(1); lbs.resize(extraLoops, zero); ubs.resize(extraLoops, one); steps.resize(extraLoops, one); @@ -78,9 +78,8 @@ } // Produce the loop nest with copies. - SmallVector ivs(lbs.size(), ValueHandle(indexType)); - auto ivPtrs = makeHandlePointers(MutableArrayRef(ivs)); - LoopNestBuilder(ivPtrs, lbs, ubs, steps)([&]() { + SmallVector ivs(lbs.size()); + LoopNestBuilder(ivs, lbs, ubs, steps)([&]() { auto activeIvs = llvm::makeArrayRef(ivs).take_back(rank); StdIndexedValue fromHandle(from), toHandle(to); toHandle(activeIvs) = fromHandle(activeIvs); @@ -90,8 +89,8 @@ for (auto en : llvm::enumerate(llvm::reverse(llvm::makeArrayRef(ivs).take_back( GPUDialect::getNumWorkgroupDimensions())))) { - auto loop = cast( - en.value().getValue().getParentRegion()->getParentOp()); + Value v = en.value(); + auto loop = cast(v.getParentRegion()->getParentOp()); mapLoopToProcessorIds(loop, {threadIds[en.index()]}, {blockDims[en.index()]}); } diff --git a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp --- a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp +++ b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp @@ -21,69 +21,61 @@ using namespace mlir::linalg; using namespace mlir::loop; -mlir::edsc::LoopRangeBuilder::LoopRangeBuilder(ValueHandle *iv, - ValueHandle range) { +mlir::edsc::LoopRangeBuilder::LoopRangeBuilder(Value *iv, Value range) { assert(range.getType() && "expected !linalg.range type"); - assert(range.getValue().getDefiningOp() && - "need operations to extract range parts"); - auto rangeOp = cast(range.getValue().getDefiningOp()); + assert(range.getDefiningOp() && "need operations to extract range parts"); + auto rangeOp = cast(range.getDefiningOp()); auto lb = rangeOp.min(); 
auto ub = rangeOp.max(); auto step = rangeOp.step(); auto forOp = OperationHandle::createOp(lb, ub, step); - *iv = ValueHandle(forOp.getInductionVar()); + *iv = forOp.getInductionVar(); auto *body = forOp.getBody(); enter(body, /*prev=*/1); } -mlir::edsc::LoopRangeBuilder::LoopRangeBuilder(ValueHandle *iv, +mlir::edsc::LoopRangeBuilder::LoopRangeBuilder(Value *iv, SubViewOp::Range range) { auto forOp = OperationHandle::createOp(range.offset, range.size, range.stride); - *iv = ValueHandle(forOp.getInductionVar()); + *iv = forOp.getInductionVar(); auto *body = forOp.getBody(); enter(body, /*prev=*/1); } -ValueHandle -mlir::edsc::LoopRangeBuilder::operator()(std::function fun) { +Value mlir::edsc::LoopRangeBuilder::operator()(std::function fun) { if (fun) fun(); exit(); - return ValueHandle::null(); + return Value(); } mlir::edsc::LoopNestRangeBuilder::LoopNestRangeBuilder( - ArrayRef ivs, ArrayRef ranges) { + MutableArrayRef ivs, ArrayRef ranges) { loops.reserve(ranges.size()); for (unsigned i = 0, e = ranges.size(); i < e; ++i) { - loops.emplace_back(ivs[i], ranges[i]); + loops.emplace_back(&ivs[i], ranges[i]); } assert(loops.size() == ivs.size() && "Mismatch loops vs ivs size"); } mlir::edsc::LoopNestRangeBuilder::LoopNestRangeBuilder( - ArrayRef ivs, ArrayRef ranges) { + MutableArrayRef ivs, ArrayRef ranges) { loops.reserve(ranges.size()); for (unsigned i = 0, e = ranges.size(); i < e; ++i) { - loops.emplace_back(ivs[i], ranges[i]); + loops.emplace_back(&ivs[i], ranges[i]); } assert(loops.size() == ivs.size() && "Mismatch loops vs ivs size"); } -mlir::edsc::LoopNestRangeBuilder::LoopNestRangeBuilder( - ArrayRef ivs, ArrayRef ranges) - : LoopNestRangeBuilder( - ivs, SmallVector(ranges.begin(), ranges.end())) {} - -ValueHandle LoopNestRangeBuilder::LoopNestRangeBuilder::operator()( +Value LoopNestRangeBuilder::LoopNestRangeBuilder::operator()( std::function fun) { if (fun) fun(); for (auto &lit : reverse(loops)) { lit({}); } - return ValueHandle::null(); + return 
Value(); } namespace mlir { @@ -91,15 +83,15 @@ template <> GenericLoopNestRangeBuilder::GenericLoopNestRangeBuilder( - ArrayRef ivs, ArrayRef ranges) { + MutableArrayRef ivs, ArrayRef ranges) { builder = std::make_unique(ivs, ranges); } template <> GenericLoopNestRangeBuilder::GenericLoopNestRangeBuilder( - ArrayRef ivs, ArrayRef ranges) { - SmallVector lbs; - SmallVector ubs; + MutableArrayRef ivs, ArrayRef ranges) { + SmallVector lbs; + SmallVector ubs; SmallVector steps; for (Value range : ranges) { assert(range.getType() && "expected linalg.range type"); @@ -114,8 +106,8 @@ template <> GenericLoopNestRangeBuilder::GenericLoopNestRangeBuilder( - ArrayRef ivs, ArrayRef ranges) { - SmallVector lbs, ubs, steps; + MutableArrayRef ivs, ArrayRef ranges) { + SmallVector lbs, ubs, steps; for (Value range : ranges) { assert(range.getType() && "expected linalg.range type"); assert(range.getDefiningOp() && "need operations to extract range parts"); @@ -197,10 +189,9 @@ OpBuilder opBuilder(op); ScopedContext scope(opBuilder, op->getLoc()); BlockHandle b; - auto handles = makeValueHandles(blockTypes); - BlockBuilder(&b, op->getRegion(0), - makeHandlePointers(MutableArrayRef(handles)))( - [&] { regionBuilder(b.getBlock()->getArguments()); }); + SmallVector handles(blockTypes.size()); + BlockBuilder(&b, op->getRegion(0), blockTypes, + handles)([&] { regionBuilder(b.getBlock()->getArguments()); }); assert(op->getRegion(0).getBlocks().size() == 1); return op; } @@ -209,16 +200,16 @@ using edsc::op::operator+; using edsc::op::operator*; assert(args.size() == 2 && "expected 2 block arguments"); - ValueHandle a(args[0]), b(args[1]); - linalg_yield((a * b).getValue()); + Value a(args[0]), b(args[1]); + linalg_yield(a * b); } void mlir::edsc::ops::macRegionBuilder(ArrayRef args) { using edsc::op::operator+; using edsc::op::operator*; assert(args.size() == 3 && "expected 3 block arguments"); - ValueHandle a(args[0]), b(args[1]), c(args[2]); - linalg_yield((c + a * b).getValue()); + 
Value a(args[0]), b(args[1]), c(args[2]); + linalg_yield(c + a * b); } Operation *mlir::edsc::ops::linalg_generic_pointwise( @@ -228,14 +219,14 @@ if (O.getType().isa()) { auto fun = [&unaryOp](ArrayRef args) { assert(args.size() == 1 && "expected 1 block arguments"); - ValueHandle a(args[0]); + Value a(args[0]); linalg_yield(unaryOp(a)); }; return makeGenericLinalgOp(iterTypes, {I}, {O}, fun); } auto fun = [&unaryOp](ArrayRef args) { assert(args.size() == 2 && "expected 2 block arguments"); - ValueHandle a(args[0]); + Value a(args[0]); linalg_yield(unaryOp(a)); }; return makeGenericLinalgOp(iterTypes, {I}, {O}, fun); @@ -243,8 +234,7 @@ Operation *mlir::edsc::ops::linalg_generic_pointwise_tanh(StructuredIndexed I, StructuredIndexed O) { - UnaryPointwiseOpBuilder unOp( - [](ValueHandle a) -> Value { return std_tanh(a); }); + UnaryPointwiseOpBuilder unOp([](Value a) -> Value { return std_tanh(a); }); return linalg_generic_pointwise(unOp, I, O); } @@ -257,14 +247,14 @@ if (O.getType().isa()) { auto fun = [&binaryOp](ArrayRef args) { assert(args.size() == 2 && "expected 2 block arguments"); - ValueHandle a(args[0]), b(args[1]); + Value a(args[0]), b(args[1]); linalg_yield(binaryOp(a, b)); }; return makeGenericLinalgOp(iterTypes, {I1, I2}, {O}, fun); } auto fun = [&binaryOp](ArrayRef args) { assert(args.size() == 3 && "expected 3 block arguments"); - ValueHandle a(args[0]), b(args[1]); + Value a(args[0]), b(args[1]); linalg_yield(binaryOp(a, b)); }; return makeGenericLinalgOp(iterTypes, {I1, I2}, {O}, fun); @@ -275,23 +265,22 @@ StructuredIndexed O) { using edsc::op::operator+; BinaryPointwiseOpBuilder binOp( - [](ValueHandle a, ValueHandle b) -> Value { return a + b; }); + [](Value a, Value b) -> Value { return a + b; }); return linalg_generic_pointwise(binOp, I1, I2, O); } Operation *mlir::edsc::ops::linalg_generic_pointwise_max(StructuredIndexed I1, StructuredIndexed I2, StructuredIndexed O) { - BinaryPointwiseOpBuilder binOp([](ValueHandle a, ValueHandle b) -> 
Value { + BinaryPointwiseOpBuilder binOp([](Value a, Value b) -> Value { using edsc::op::operator>; - return std_select(a > b, a, b).getValue(); + return std_select(a > b, a, b); }); return linalg_generic_pointwise(binOp, I1, I2, O); } Operation * -mlir::edsc::ops::linalg_generic_matmul(ValueHandle vA, ValueHandle vB, - ValueHandle vC, +mlir::edsc::ops::linalg_generic_matmul(Value vA, Value vB, Value vC, MatmulRegionBuilder regionBuilder) { // clang-format off AffineExpr m, n, k; @@ -306,8 +295,7 @@ } Operation * -mlir::edsc::ops::linalg_generic_matmul(ValueHandle vA, ValueHandle vB, - RankedTensorType tC, +mlir::edsc::ops::linalg_generic_matmul(Value vA, Value vB, RankedTensorType tC, MatmulRegionBuilder regionBuilder) { // clang-format off AffineExpr m, n, k; @@ -322,8 +310,8 @@ } Operation * -mlir::edsc::ops::linalg_generic_matmul(ValueHandle vA, ValueHandle vB, - ValueHandle vC, RankedTensorType tD, +mlir::edsc::ops::linalg_generic_matmul(Value vA, Value vB, Value vC, + RankedTensorType tD, MatmulRegionBuilder regionBuilder) { // clang-format off AffineExpr m, n, k; @@ -337,9 +325,8 @@ // clang-format on } -Operation *mlir::edsc::ops::linalg_generic_conv_nhwc(ValueHandle vI, - ValueHandle vW, - ValueHandle vO, +Operation *mlir::edsc::ops::linalg_generic_conv_nhwc(Value vI, Value vW, + Value vO, ArrayRef strides, ArrayRef dilations) { MLIRContext *ctx = ScopedContext::getContext(); @@ -373,8 +360,8 @@ } Operation *mlir::edsc::ops::linalg_generic_dilated_conv_nhwc( - ValueHandle vI, ValueHandle vW, ValueHandle vO, int depth_multiplier, - ArrayRef strides, ArrayRef dilations) { + Value vI, Value vW, Value vO, int depth_multiplier, ArrayRef strides, + ArrayRef dilations) { MLIRContext *ctx = ScopedContext::getContext(); // TODO(ntv) some template magic to make everything rank-polymorphic. 
assert((dilations.empty() || dilations.size() == 2) && "only 2-D conv atm"); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -35,7 +35,7 @@ using namespace mlir::edsc::intrinsics; using namespace mlir::linalg; -using folded_std_constant_index = folded::ValueBuilder; +using folded_std_constant_index = FoldedValueBuilder; using llvm::dbgs; diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp @@ -29,17 +29,16 @@ using namespace mlir::linalg; using edsc::op::operator+; -using edsc::op::operator==; -using mlir::edsc::intrinsics::detail::ValueHandleArray; -static SmallVector -makeCanonicalAffineApplies(OpBuilder &b, Location loc, AffineMap map, - ArrayRef vals) { +static SmallVector makeCanonicalAffineApplies(OpBuilder &b, + Location loc, + AffineMap map, + ArrayRef vals) { if (map.isEmpty()) return {}; assert(map.getNumSymbols() == 0); assert(map.getNumInputs() == vals.size()); - SmallVector res; + SmallVector res; res.reserve(map.getNumResults()); auto dims = map.getNumDims(); for (auto e : map.getResults()) { @@ -80,10 +79,10 @@ } template -static void inlineRegionAndEmitStdStore(OpType op, - ArrayRef indexedValues, - ArrayRef indexing, - ArrayRef outputBuffers) { +static void +inlineRegionAndEmitStdStore(OpType op, ArrayRef indexedValues, + ArrayRef> indexing, + ArrayRef outputBuffers) { auto &b = ScopedContext::getBuilder(); auto &block = op.region().front(); BlockAndValueMapping map; @@ -99,25 +98,27 @@ "expected an yield op in the end of the region"); for (unsigned i = 0, e = terminator.getNumOperands(); i < e; ++i) { std_store(map.lookupOrDefault(terminator.getOperand(i)), outputBuffers[i], - indexing[i]); + 
ArrayRef{indexing[i].begin(), indexing[i].end()}); } } // Returns a pair that contains input indices and output indices of a // SingleInputPoolingOp `op`. +struct InputAndOutputIndices { + SmallVector inputs; + SmallVector outputs; +}; template -static std::pair, SmallVector> -getInputAndOutputIndices(ArrayRef allIvs, SingleInputPoolingOp op) { +static InputAndOutputIndices getInputAndOutputIndices(ArrayRef allIvs, + SingleInputPoolingOp op) { auto &b = ScopedContext::getBuilder(); auto loc = ScopedContext::getLocation(); auto mapsRange = op.indexing_maps().template getAsRange(); auto maps = llvm::to_vector<8>( llvm::map_range(mapsRange, [](AffineMapAttr a) { return a.getValue(); })); - SmallVector iIdx( - makeCanonicalAffineApplies(b, loc, maps[0], allIvs)); - SmallVector oIdx( - makeCanonicalAffineApplies(b, loc, maps[2], allIvs)); - return {iIdx, oIdx}; + return InputAndOutputIndices{ + makeCanonicalAffineApplies(b, loc, maps[0], allIvs), + makeCanonicalAffineApplies(b, loc, maps[2], allIvs)}; } namespace { @@ -150,8 +151,8 @@ permuteIvs(allIvs.take_front(nPar), copyOp.inputPermutation()); auto outputIvs = permuteIvs(allIvs.take_front(nPar), copyOp.outputPermutation()); - SmallVector iivs(inputIvs.begin(), inputIvs.end()); - SmallVector oivs(outputIvs.begin(), outputIvs.end()); + SmallVector iivs(inputIvs.begin(), inputIvs.end()); + SmallVector oivs(outputIvs.begin(), outputIvs.end()); IndexedValueType O(copyOp.getOutputBuffer(0)), I(copyOp.getInput(0)); // Emit the proper scalar assignment, whether we are dealing with a 0-D or // an n-D loop nest; with or without permutations. 
@@ -170,13 +171,11 @@ "expected linalg op with buffer semantics"); auto nPar = fillOp.getNumParallelLoops(); assert(nPar == allIvs.size()); - auto ivs = - SmallVector(allIvs.begin(), allIvs.begin() + nPar); + auto ivs = SmallVector(allIvs.begin(), allIvs.begin() + nPar); IndexedValueType O(fillOp.getOutputBuffer(0)); // Emit the proper scalar assignment, whether we are dealing with a 0-D or // an n-D loop nest; with or without permutations. - nPar > 0 ? O(ivs) = ValueHandle(fillOp.value()) - : O() = ValueHandle(fillOp.value()); + nPar > 0 ? O(ivs) = fillOp.value() : O() = fillOp.value(); } }; @@ -187,7 +186,7 @@ assert(dotOp.hasBufferSemantics() && "expected linalg op with buffer semantics"); assert(allIvs.size() == 1); - ValueHandle r_i(allIvs[0]); + Value r_i(allIvs[0]); IndexedValueType A(dotOp.getInput(0)), B(dotOp.getInput(1)), C(dotOp.getOutputBuffer(0)); // Emit scalar form. @@ -203,7 +202,7 @@ assert(matvecOp.hasBufferSemantics() && "expected linalg op with buffer semantics"); assert(allIvs.size() == 2); - ValueHandle i(allIvs[0]), r_j(allIvs[1]); + Value i(allIvs[0]), r_j(allIvs[1]); IndexedValueType A(matvecOp.getInput(0)), B(matvecOp.getInput(1)), C(matvecOp.getOutputBuffer(0)); // Emit scalar form. @@ -219,7 +218,7 @@ assert(matmulOp.hasBufferSemantics() && "expected linalg op with buffer semantics"); assert(allIvs.size() == 3); - ValueHandle i(allIvs[0]), j(allIvs[1]), r_k(allIvs[2]); + Value i(allIvs[0]), j(allIvs[1]), r_k(allIvs[2]); IndexedValueType A(matmulOp.getInput(0)), B(matmulOp.getInput(1)), C(matmulOp.getOutputBuffer(0)); // Emit scalar form. @@ -232,16 +231,16 @@ public: /// Returns the input value of convOp. If the indices in `imIdx` is out of /// boundary, returns 0 instead. - static ValueHandle getConvOpInput(ConvOp convOp, IndexedValueType im, - ArrayRef imIdx) { + static Value getConvOpInput(ConvOp convOp, IndexedValueType im, + MutableArrayRef imIdx) { // TODO(ntv): add a level of indirection to linalg.generic. 
if (!convOp.padding()) return im(imIdx); auto *context = ScopedContext::getContext(); - ValueHandle zeroIndex = std_constant_index(0); - SmallVector conds; - SmallVector clampedImIdx; + Value zeroIndex = std_constant_index(0); + SmallVector conds; + SmallVector clampedImIdx; for (auto iter : llvm::enumerate(imIdx)) { int idx = iter.index(); auto dim = iter.value(); @@ -254,12 +253,12 @@ using edsc::op::operator<; using edsc::op::operator>=; using edsc::op::operator||; - ValueHandle leftOutOfBound = dim < zeroIndex; + Value leftOutOfBound = dim < zeroIndex; if (conds.empty()) conds.push_back(leftOutOfBound); else conds.push_back(conds.back() || leftOutOfBound); - ValueHandle rightBound = std_dim(convOp.input(), idx); + Value rightBound = std_dim(convOp.input(), idx); conds.push_back(conds.back() || (dim >= rightBound)); // When padding is involved, the indices will only be shifted to negative, @@ -274,10 +273,10 @@ auto b = ScopedContext::getBuilder(); Type type = convOp.input().getType().cast().getElementType(); - ValueHandle zero = std_constant(type, b.getZeroAttr(type)); - ValueHandle readInput = im(clampedImIdx); + Value zero = std_constant(type, b.getZeroAttr(type)); + Value readInput = im(clampedImIdx); return conds.empty() ? readInput - : std_select(conds.back(), zero, readInput); + : (Value)std_select(conds.back(), zero, readInput); } static void emitScalarImplementation(ArrayRef allIvs, ConvOp convOp) { @@ -288,16 +287,16 @@ auto mapsRange = convOp.indexing_maps().getAsRange(); auto maps = llvm::to_vector<8>(llvm::map_range( mapsRange, [](AffineMapAttr a) { return a.getValue(); })); - SmallVector fIdx( + SmallVector fIdx( makeCanonicalAffineApplies(b, loc, maps[0], allIvs)); - SmallVector imIdx( + SmallVector imIdx( makeCanonicalAffineApplies(b, loc, maps[1], allIvs)); - SmallVector oIdx( + SmallVector oIdx( makeCanonicalAffineApplies(b, loc, maps[2], allIvs)); IndexedValueType F(convOp.filter()), I(convOp.input()), O(convOp.output()); // Emit scalar form. 
- ValueHandle paddedInput = getConvOpInput(convOp, I, imIdx); + Value paddedInput = getConvOpInput(convOp, I, imIdx); O(oIdx) += F(fIdx) * paddedInput; } }; @@ -308,15 +307,12 @@ static void emitScalarImplementation(ArrayRef allIvs, PoolingMaxOp op) { auto indices = getInputAndOutputIndices(allIvs, op); - ValueHandleArray iIdx(indices.first); - ValueHandleArray oIdx(indices.second); - // Emit scalar form. - ValueHandle lhs = std_load(op.output(), oIdx); - ValueHandle rhs = std_load(op.input(), iIdx); + Value lhs = std_load(op.output(), indices.outputs); + Value rhs = std_load(op.input(), indices.inputs); using edsc::op::operator>; - ValueHandle maxValue = std_select(lhs > rhs, lhs, rhs); - std_store(maxValue, op.output(), oIdx); + Value maxValue = std_select(lhs > rhs, lhs, rhs); + std_store(maxValue, op.output(), indices.outputs); } }; @@ -326,15 +322,12 @@ static void emitScalarImplementation(ArrayRef allIvs, PoolingMinOp op) { auto indices = getInputAndOutputIndices(allIvs, op); - ValueHandleArray iIdx(indices.first); - ValueHandleArray oIdx(indices.second); - // Emit scalar form. - ValueHandle lhs = std_load(op.output(), oIdx); - ValueHandle rhs = std_load(op.input(), iIdx); + Value lhs = std_load(op.output(), indices.outputs); + Value rhs = std_load(op.input(), indices.inputs); using edsc::op::operator<; - ValueHandle minValue = std_select(lhs < rhs, lhs, rhs); - std_store(minValue, op.output(), oIdx); + Value minValue = std_select(lhs < rhs, lhs, rhs); + std_store(minValue, op.output(), indices.outputs); } }; @@ -344,12 +337,10 @@ static void emitScalarImplementation(ArrayRef allIvs, PoolingSumOp op) { auto indices = getInputAndOutputIndices(allIvs, op); - SmallVector iIdx = indices.first; - SmallVector oIdx = indices.second; IndexedValueType input(op.input()), output(op.output()); // Emit scalar form. 
- output(oIdx) += input(iIdx); + output(indices.outputs) += input(indices.inputs); } }; @@ -392,15 +383,14 @@ "expected linalg op with buffer semantics"); auto b = ScopedContext::getBuilder(); auto loc = ScopedContext::getLocation(); - using edsc::intrinsics::detail::ValueHandleArray; unsigned nInputs = genericOp.getNumInputs(); unsigned nOutputs = genericOp.getNumOutputs(); SmallVector indexedValues(nInputs + nOutputs); // 1.a. Emit std_load from input views. for (unsigned i = 0; i < nInputs; ++i) { - ValueHandleArray indexing(makeCanonicalAffineApplies( - b, loc, genericOp.getInputIndexingMap(i), allIvs)); + auto indexing = makeCanonicalAffineApplies( + b, loc, genericOp.getInputIndexingMap(i), allIvs); indexedValues[i] = std_load(genericOp.getInput(i), indexing); } @@ -409,18 +399,18 @@ // region has no uses. for (unsigned i = 0; i < nOutputs; ++i) { Value output = genericOp.getOutputBuffer(i); - ValueHandleArray indexing(makeCanonicalAffineApplies( - b, loc, genericOp.getOutputIndexingMap(i), allIvs)); + auto indexing = makeCanonicalAffineApplies( + b, loc, genericOp.getOutputIndexingMap(i), allIvs); indexedValues[nInputs + i] = std_load(output, indexing); } // TODO(ntv): When a region inliner exists, use it. // 2. Inline region, currently only works for a single basic block. // 3. Emit std_store. 
- SmallVector indexing; + SmallVector, 8> indexing; SmallVector outputBuffers; for (unsigned i = 0; i < nOutputs; ++i) { - indexing.emplace_back(makeCanonicalAffineApplies( + indexing.push_back(makeCanonicalAffineApplies( b, loc, genericOp.getOutputIndexingMap(i), allIvs)); outputBuffers.push_back(genericOp.getOutputBuffer(i)); } @@ -468,7 +458,6 @@ "expected linalg op with buffer semantics"); auto b = ScopedContext::getBuilder(); auto loc = ScopedContext::getLocation(); - using edsc::intrinsics::detail::ValueHandleArray; unsigned nInputs = indexedGenericOp.getNumInputs(); unsigned nOutputs = indexedGenericOp.getNumOutputs(); unsigned nLoops = allIvs.size(); @@ -481,26 +470,26 @@ // 1.a. Emit std_load from input views. for (unsigned i = 0; i < nInputs; ++i) { Value input = indexedGenericOp.getInput(i); - ValueHandleArray indexing(makeCanonicalAffineApplies( - b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs)); + auto indexing = makeCanonicalAffineApplies( + b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs); indexedValues[nLoops + i] = std_load(input, indexing); } // 1.b. Emit std_load from output views. for (unsigned i = 0; i < nOutputs; ++i) { Value output = indexedGenericOp.getOutputBuffer(i); - ValueHandleArray indexing(makeCanonicalAffineApplies( - b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs)); + auto indexing = makeCanonicalAffineApplies( + b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs); indexedValues[nLoops + nInputs + i] = std_load(output, indexing); } // TODO(ntv): When a region inliner exists, use it. // 2. Inline region, currently only works for a single basic block. // 3. Emit std_store. 
- SmallVector indexing; + SmallVector, 8> indexing; SmallVector outputBuffers; for (unsigned i = 0; i < nOutputs; ++i) { - indexing.emplace_back(makeCanonicalAffineApplies( + indexing.push_back(makeCanonicalAffineApplies( b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs)); outputBuffers.push_back(indexedGenericOp.getOutputBuffer(i)); } @@ -533,11 +522,8 @@ typename std::conditional::value, AffineIndexedValue, StdIndexedValue>::type; static void doit(ConcreteOpTy linalgOp, ArrayRef loopRanges, - MutableArrayRef allIvs) { - SmallVector allPIvs = - makeHandlePointers(MutableArrayRef(allIvs)); - - GenericLoopNestRangeBuilder(allPIvs, loopRanges)([&] { + MutableArrayRef allIvs) { + GenericLoopNestRangeBuilder(allIvs, loopRanges)([&] { SmallVector allIvValues(allIvs.begin(), allIvs.end()); LinalgScopedEmitter::emitScalarImplementation(allIvValues, @@ -555,7 +541,7 @@ using IndexedValueTy = StdIndexedValue; static void doit(ConcreteOpTy linalgOp, ArrayRef loopRanges, - MutableArrayRef allIvs) { + MutableArrayRef allIvs) { // Only generate loop.parallel for outer consecutive "parallel" // iterator_types. // TODO(ravishankarm): Generate loop.parallel for all "parallel" iterator @@ -575,24 +561,18 @@ // If there are no outer parallel loops, then number of loop ops is same as // the number of loops, and they are all loop.for ops. auto nLoopOps = (nOuterPar ? 
nLoops - nOuterPar + 1 : nLoops); - SmallVector allPIvs = - makeHandlePointers(MutableArrayRef(allIvs)); - SmallVector allLoops(nLoopOps, OperationHandle()); SmallVector allPLoops; allPLoops.reserve(allLoops.size()); for (OperationHandle &loop : allLoops) allPLoops.push_back(&loop); - - ArrayRef allPIvsRef(allPIvs); ArrayRef allPLoopsRef(allPLoops); if (nOuterPar) { GenericLoopNestRangeBuilder( - allPIvsRef.take_front(nOuterPar), - loopRanges.take_front(nOuterPar))([&] { + allIvs.take_front(nOuterPar), loopRanges.take_front(nOuterPar))([&] { GenericLoopNestRangeBuilder( - allPIvsRef.drop_front(nOuterPar), + allIvs.drop_front(nOuterPar), loopRanges.drop_front(nOuterPar))([&] { SmallVector allIvValues(allIvs.begin(), allIvs.end()); LinalgScopedEmitter:: @@ -602,7 +582,7 @@ } else { // If there are no parallel loops then fallback to generating all loop.for // operations. - GenericLoopNestRangeBuilder(allPIvsRef, loopRanges)([&] { + GenericLoopNestRangeBuilder(allIvs, loopRanges)([&] { SmallVector allIvValues(allIvs.begin(), allIvs.end()); LinalgScopedEmitter::emitScalarImplementation(allIvValues, @@ -645,8 +625,7 @@ return LinalgLoops(); } - SmallVector allIvs(nLoops, - ValueHandle(rewriter.getIndexType())); + SmallVector allIvs(nLoops); auto loopRanges = emitLoopRanges(scope.getBuilder(), scope.getLocation(), invertedMap, getViewSizes(rewriter, linalgOp)); @@ -655,12 +634,12 @@ // Number of loop ops might be different from the number of ivs since some // loops like affine.parallel and loop.parallel have multiple ivs. llvm::SetVector loopSet; - for (ValueHandle &iv : allIvs) { - if (!iv.hasValue()) + for (Value iv : allIvs) { + if (!iv) return {}; // The induction variable is a block argument of the entry block of the // loop operation. 
- BlockArgument ivVal = iv.getValue().dyn_cast(); + BlockArgument ivVal = iv.dyn_cast(); if (!ivVal) return {}; loopSet.insert(ivVal.getOwner()->getParentOp()); diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgTransforms.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/StandardOps/EDSC/Intrinsics.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" +#include "mlir/Dialect/Vector/EDSC/Intrinsics.h" #include "mlir/Dialect/Vector/VectorOps.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/Matchers.h" @@ -219,10 +220,6 @@ SmallVector mlir::linalg::vectorizeLinalgOp(PatternRewriter &rewriter, Operation *op) { - using vector_contract = edsc::intrinsics::ValueBuilder; - using vector_broadcast = edsc::intrinsics::ValueBuilder; - using vector_type_cast = edsc::intrinsics::ValueBuilder; - assert(succeeded(vectorizeLinalgOpPrecondition(op)) && "DRR failure case must be a precondition"); auto linalgOp = cast(op); @@ -242,8 +239,8 @@ "]: Rewrite linalg.fill as vector.broadcast: " << *op << ":\n"); auto dstMemrefVec = vector_type_cast(fillOp.getOutputBuffer(0)); - auto dstVec = std_load(dstMemrefVec); - auto resVec = vector_broadcast(dstVec, fillOp.value()); + Value dstVec = std_load(dstMemrefVec); + auto resVec = vector_broadcast(dstVec.getType(), fillOp.value()); std_store(resVec, dstMemrefVec); } else { // Vectorize other ops as vector contraction (currently only matmul). 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Promotion.cpp @@ -36,11 +36,11 @@ using llvm::SetVector; -using folded_affine_min = folded::ValueBuilder; -using folded_linalg_range = folded::ValueBuilder; -using folded_std_dim = folded::ValueBuilder; -using folded_std_subview = folded::ValueBuilder; -using folded_std_view = folded::ValueBuilder; +using folded_affine_min = FoldedValueBuilder; +using folded_linalg_range = FoldedValueBuilder; +using folded_std_dim = FoldedValueBuilder; +using folded_std_subview = FoldedValueBuilder; +using folded_std_view = FoldedValueBuilder; #define DEBUG_TYPE "linalg-promotion" @@ -74,8 +74,8 @@ if (!dynamicBuffers) if (auto cst = dyn_cast_or_null(size.getDefiningOp())) return std_alloc( - MemRefType::get(width * cst.getValue(), IntegerType::get(8, ctx)), {}, - alignment_attr); + MemRefType::get(width * cst.getValue(), IntegerType::get(8, ctx)), + ValueRange{}, alignment_attr); Value mul = folded_std_muli(folder, folded_std_constant_index(folder, width), size); return std_alloc(MemRefType::get(-1, IntegerType::get(8, ctx)), mul, @@ -118,7 +118,7 @@ auto rangeValue = en.value(); // Try to extract a tight constant Value size = extractSmallestConstantBoundingSize(b, loc, rangeValue.size); - allocSize = folded_std_muli(folder, allocSize, size).getValue(); + allocSize = folded_std_muli(folder, allocSize, size); fullSizes.push_back(size); partialSizes.push_back(folded_std_dim(folder, subView, rank)); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -32,7 +32,7 @@ using namespace mlir::linalg; using namespace mlir::loop; -using folded_affine_min = folded::ValueBuilder; +using folded_affine_min = FoldedValueBuilder; 
#define DEBUG_TYPE "linalg-tiling" @@ -163,7 +163,7 @@ // TODO(pifon, ntv): Investigate whether mixing implicit and explicit indices // does not lead to losing information. static void transformIndexedGenericOpIndices( - OpBuilder &b, LinalgOp op, ArrayRef pivs, + OpBuilder &b, LinalgOp op, SmallVectorImpl &ivs, const LoopIndexToRangeIndexMap &loopIndexToRangeIndex) { assert(op.hasBufferSemantics() && "expected linalg op with buffer semantics"); auto indexedGenericOp = dyn_cast(op.getOperation()); @@ -193,7 +193,7 @@ // Offset the index argument `i` by the value of the corresponding induction // variable and replace all uses of the previous value. Value newIndex = b.create(indexedGenericOp.getLoc(), oldIndex, - pivs[rangeIndex->second]->getValue()); + ivs[rangeIndex->second]); for (auto &use : oldIndex.getUses()) { if (use.getOwner() == newIndex.getDefiningOp()) continue; @@ -376,15 +376,14 @@ // 3. Create the tiled loops. LinalgOp res = op; - auto ivs = ValueHandle::makeIndexHandles(loopRanges.size()); - auto pivs = makeHandlePointers(MutableArrayRef(ivs)); + SmallVector ivs(loopRanges.size()); // Convert SubViewOp::Range to linalg_range. SmallVector linalgRanges; for (auto &range : loopRanges) { linalgRanges.push_back( linalg_range(range.offset, range.size, range.stride)); } - GenericLoopNestRangeBuilder(pivs, linalgRanges)([&] { + GenericLoopNestRangeBuilder(ivs, linalgRanges)([&] { auto b = ScopedContext::getBuilder(); auto loc = ScopedContext::getLocation(); SmallVector ivValues(ivs.begin(), ivs.end()); @@ -405,7 +404,7 @@ }); // 4. Transforms index arguments of `linalg.generic` w.r.t. to the tiling. - transformIndexedGenericOpIndices(b, res, pivs, loopIndexToRangeIndex); + transformIndexedGenericOpIndices(b, res, ivs, loopIndexToRangeIndex); // 5. Gather the newly created loops and return them with the new op. 
SmallVector loops; diff --git a/mlir/lib/Dialect/LoopOps/EDSC/Builders.cpp b/mlir/lib/Dialect/LoopOps/EDSC/Builders.cpp --- a/mlir/lib/Dialect/LoopOps/EDSC/Builders.cpp +++ b/mlir/lib/Dialect/LoopOps/EDSC/Builders.cpp @@ -14,8 +14,8 @@ using namespace mlir::edsc; mlir::edsc::ParallelLoopNestBuilder::ParallelLoopNestBuilder( - ArrayRef ivs, ArrayRef lbs, - ArrayRef ubs, ArrayRef steps) { + MutableArrayRef ivs, ArrayRef lbs, ArrayRef ubs, + ArrayRef steps) { assert(ivs.size() == lbs.size() && "Mismatch in number of arguments"); assert(ivs.size() == ubs.size() && "Mismatch in number of arguments"); assert(ivs.size() == steps.size() && "Mismatch in number of arguments"); @@ -36,29 +36,34 @@ (*lit)(); } -mlir::edsc::LoopNestBuilder::LoopNestBuilder(ArrayRef ivs, - ArrayRef lbs, - ArrayRef ubs, - ArrayRef steps) { +mlir::edsc::LoopNestBuilder::LoopNestBuilder(MutableArrayRef ivs, + ArrayRef lbs, + ArrayRef ubs, + ArrayRef steps) { assert(ivs.size() == lbs.size() && "expected size of ivs and lbs to match"); assert(ivs.size() == ubs.size() && "expected size of ivs and ubs to match"); assert(ivs.size() == steps.size() && "expected size of ivs and steps to match"); loops.reserve(ivs.size()); for (auto it : llvm::zip(ivs, lbs, ubs, steps)) - loops.emplace_back(makeLoopBuilder(std::get<0>(it), std::get<1>(it), + loops.emplace_back(makeLoopBuilder(&std::get<0>(it), std::get<1>(it), std::get<2>(it), std::get<3>(it))); assert(loops.size() == ivs.size() && "Mismatch loops vs ivs size"); } mlir::edsc::LoopNestBuilder::LoopNestBuilder( - ValueHandle *iv, ValueHandle lb, ValueHandle ub, ValueHandle step, - ArrayRef iter_args_handles, - ValueRange iter_args_init_values) { - assert(iter_args_init_values.size() == iter_args_handles.size() && + Value *iv, Value lb, Value ub, Value step, + MutableArrayRef iterArgsHandles, ValueRange iterArgsInitValues) { + assert(iterArgsInitValues.size() == iterArgsHandles.size() && "expected size of arguments and argument_handles to match"); - 
loops.emplace_back(makeLoopBuilder(iv, lb, ub, step, iter_args_handles, - iter_args_init_values)); + loops.emplace_back( + makeLoopBuilder(iv, lb, ub, step, iterArgsHandles, iterArgsInitValues)); +} + +mlir::edsc::LoopNestBuilder::LoopNestBuilder(Value *iv, Value lb, Value ub, + Value step) { + SmallVector noArgs; + loops.emplace_back(makeLoopBuilder(iv, lb, ub, step, noArgs, {})); } Operation::result_range @@ -73,10 +78,10 @@ return loops[0].getOp()->getResults(); } -LoopBuilder mlir::edsc::makeParallelLoopBuilder(ArrayRef ivs, - ArrayRef lbHandles, - ArrayRef ubHandles, - ArrayRef steps) { +LoopBuilder mlir::edsc::makeParallelLoopBuilder(MutableArrayRef ivs, + ArrayRef lbHandles, + ArrayRef ubHandles, + ArrayRef steps) { LoopBuilder result; auto opHandle = OperationHandle::create( SmallVector(lbHandles.begin(), lbHandles.end()), @@ -86,24 +91,22 @@ loop::ParallelOp parallelOp = cast(*opHandle.getOperation()); for (size_t i = 0, e = ivs.size(); i < e; ++i) - *ivs[i] = ValueHandle(parallelOp.getBody()->getArgument(i)); + ivs[i] = parallelOp.getBody()->getArgument(i); result.enter(parallelOp.getBody(), /*prev=*/1); return result; } -mlir::edsc::LoopBuilder -mlir::edsc::makeLoopBuilder(ValueHandle *iv, ValueHandle lbHandle, - ValueHandle ubHandle, ValueHandle stepHandle, - ArrayRef iter_args_handles, - ValueRange iter_args_init_values) { +mlir::edsc::LoopBuilder mlir::edsc::makeLoopBuilder( + Value *iv, Value lbHandle, Value ubHandle, Value stepHandle, + MutableArrayRef iterArgsHandles, ValueRange iterArgsInitValues) { mlir::edsc::LoopBuilder result; auto forOp = OperationHandle::createOp( - lbHandle, ubHandle, stepHandle, iter_args_init_values); - *iv = ValueHandle(forOp.getInductionVar()); - auto *body = loop::getForInductionVarOwner(iv->getValue()).getBody(); - for (size_t i = 0, e = iter_args_handles.size(); i < e; ++i) { + lbHandle, ubHandle, stepHandle, iterArgsInitValues); + *iv = forOp.getInductionVar(); + auto *body = 
loop::getForInductionVarOwner(*iv).getBody(); + for (size_t i = 0, e = iterArgsHandles.size(); i < e; ++i) { // Skipping the induction variable. - *(iter_args_handles[i]) = ValueHandle(body->getArgument(i + 1)); + iterArgsHandles[i] = body->getArgument(i + 1); } result.setOp(forOp); result.enter(body, /*prev=*/1); diff --git a/mlir/lib/Dialect/StandardOps/EDSC/Builders.cpp b/mlir/lib/Dialect/StandardOps/EDSC/Builders.cpp --- a/mlir/lib/Dialect/StandardOps/EDSC/Builders.cpp +++ b/mlir/lib/Dialect/StandardOps/EDSC/Builders.cpp @@ -14,11 +14,11 @@ using namespace mlir::edsc; using namespace mlir::edsc::intrinsics; -static SmallVector getMemRefSizes(Value memRef) { +static SmallVector getMemRefSizes(Value memRef) { MemRefType memRefType = memRef.getType().cast(); assert(isStrided(memRefType) && "Expected strided MemRef type"); - SmallVector res; + SmallVector res; res.reserve(memRefType.getShape().size()); const auto &shape = memRefType.getShape(); for (unsigned idx = 0, n = shape.size(); idx < n; ++idx) { diff --git a/mlir/lib/Dialect/StandardOps/EDSC/Intrinsics.cpp b/mlir/lib/Dialect/StandardOps/EDSC/Intrinsics.cpp --- a/mlir/lib/Dialect/StandardOps/EDSC/Intrinsics.cpp +++ b/mlir/lib/Dialect/StandardOps/EDSC/Intrinsics.cpp @@ -13,45 +13,29 @@ using namespace mlir::edsc; OperationHandle mlir::edsc::intrinsics::std_br(BlockHandle bh, - ArrayRef operands) { + ArrayRef operands) { assert(bh && "Expected already captured BlockHandle"); for (auto &o : operands) { (void)o; - assert(o && "Expected already captured ValueHandle"); + assert(o && "Expected already captured Value"); } SmallVector ops(operands.begin(), operands.end()); return OperationHandle::create(bh.getBlock(), ops); } -static void enforceEmptyCapturesMatchOperands(ArrayRef captures, - ArrayRef operands) { - assert(captures.size() == operands.size() && - "Expected same number of captures as operands"); - for (auto it : llvm::zip(captures, operands)) { - (void)it; - assert(!std::get<0>(it)->hasValue() && - 
"Unexpected already captured ValueHandle"); - assert(std::get<1>(it) && "Expected already captured ValueHandle"); - assert(std::get<0>(it)->getType() == std::get<1>(it).getType() && - "Expected the same type for capture and operand"); - } -} - OperationHandle mlir::edsc::intrinsics::std_br(BlockHandle *bh, - ArrayRef captures, - ArrayRef operands) { + ArrayRef types, + MutableArrayRef captures, + ArrayRef operands) { assert(!*bh && "Unexpected already captured BlockHandle"); - enforceEmptyCapturesMatchOperands(captures, operands); - BlockBuilder(bh, captures)(/* no body */); + BlockBuilder(bh, types, captures)(/* no body */); SmallVector ops(operands.begin(), operands.end()); return OperationHandle::create(bh->getBlock(), ops); } -OperationHandle -mlir::edsc::intrinsics::std_cond_br(ValueHandle cond, BlockHandle trueBranch, - ArrayRef trueOperands, - BlockHandle falseBranch, - ArrayRef falseOperands) { +OperationHandle mlir::edsc::intrinsics::std_cond_br( + Value cond, BlockHandle trueBranch, ArrayRef trueOperands, + BlockHandle falseBranch, ArrayRef falseOperands) { SmallVector trueOps(trueOperands.begin(), trueOperands.end()); SmallVector falseOps(falseOperands.begin(), falseOperands.end()); return OperationHandle::create( @@ -59,16 +43,14 @@ } OperationHandle mlir::edsc::intrinsics::std_cond_br( - ValueHandle cond, BlockHandle *trueBranch, - ArrayRef trueCaptures, ArrayRef trueOperands, - BlockHandle *falseBranch, ArrayRef falseCaptures, - ArrayRef falseOperands) { + Value cond, BlockHandle *trueBranch, ArrayRef trueTypes, + MutableArrayRef trueCaptures, ArrayRef trueOperands, + BlockHandle *falseBranch, ArrayRef falseTypes, + MutableArrayRef falseCaptures, ArrayRef falseOperands) { assert(!*trueBranch && "Unexpected already captured BlockHandle"); assert(!*falseBranch && "Unexpected already captured BlockHandle"); - enforceEmptyCapturesMatchOperands(trueCaptures, trueOperands); - enforceEmptyCapturesMatchOperands(falseCaptures, falseOperands); - 
BlockBuilder(trueBranch, trueCaptures)(/* no body */); - BlockBuilder(falseBranch, falseCaptures)(/* no body */); + BlockBuilder(trueBranch, trueTypes, trueCaptures)(/* no body */); + BlockBuilder(falseBranch, falseTypes, falseCaptures)(/* no body */); SmallVector trueOps(trueOperands.begin(), trueOperands.end()); SmallVector falseOps(falseOperands.begin(), falseOperands.end()); return OperationHandle::create( diff --git a/mlir/lib/EDSC/Builders.cpp b/mlir/lib/EDSC/Builders.cpp --- a/mlir/lib/EDSC/Builders.cpp +++ b/mlir/lib/EDSC/Builders.cpp @@ -65,25 +65,8 @@ return getBuilder().getContext(); } -ValueHandle &mlir::edsc::ValueHandle::operator=(const ValueHandle &other) { - assert(t == other.t && "Wrong type capture"); - assert(!v && "ValueHandle has already been captured, use a new name!"); - v = other.v; - return *this; -} - -ValueHandle ValueHandle::create(StringRef name, ArrayRef operands, - ArrayRef resultTypes, - ArrayRef attributes) { - Operation *op = - OperationHandle::create(name, operands, resultTypes, attributes); - if (op->getNumResults() == 1) - return ValueHandle(op->getResult(0)); - llvm_unreachable("unsupported operation, use an OperationHandle instead"); -} - OperationHandle OperationHandle::create(StringRef name, - ArrayRef operands, + ArrayRef operands, ArrayRef resultTypes, ArrayRef attributes) { OperationState state(ScopedContext::getLocation(), name); @@ -156,37 +139,32 @@ enter(bh.getBlock()); } -mlir::edsc::BlockBuilder::BlockBuilder(BlockHandle *bh, - ArrayRef args) { +mlir::edsc::BlockBuilder::BlockBuilder(BlockHandle *bh, ArrayRef types, + MutableArrayRef args) { assert(!*bh && "BlockHandle already captures a block, use " "the explicit BockBuilder(bh, Append())({}) syntax instead."); - SmallVector types; - for (auto *a : args) { - assert(!a->hasValue() && - "Expected delayed ValueHandle that has not yet captured."); - types.push_back(a->getType()); - } + assert((args.empty() || args.size() == types.size()) && + "if args captures are 
specified, their number must match the number " + "of types"); *bh = BlockHandle::create(types); - for (auto it : llvm::zip(args, bh->getBlock()->getArguments())) { - *(std::get<0>(it)) = ValueHandle(std::get<1>(it)); - } + if (!args.empty()) + for (auto it : llvm::zip(args, bh->getBlock()->getArguments())) + std::get<0>(it) = Value(std::get<1>(it)); enter(bh->getBlock()); } mlir::edsc::BlockBuilder::BlockBuilder(BlockHandle *bh, Region ®ion, - ArrayRef args) { + ArrayRef types, + MutableArrayRef args) { assert(!*bh && "BlockHandle already captures a block, use " "the explicit BockBuilder(bh, Append())({}) syntax instead."); - SmallVector types; - for (auto *a : args) { - assert(!a->hasValue() && - "Expected delayed ValueHandle that has not yet captured."); - types.push_back(a->getType()); - } + assert((args.empty() || args.size() == types.size()) && + "if args captures are specified, their number must match the number " + "of types"); *bh = BlockHandle::createInRegion(region, types); - for (auto it : llvm::zip(args, bh->getBlock()->getArguments())) { - *(std::get<0>(it)) = ValueHandle(std::get<1>(it)); - } + if (!args.empty()) + for (auto it : llvm::zip(args, bh->getBlock()->getArguments())) + std::get<0>(it) = Value(std::get<1>(it)); enter(bh->getBlock()); } diff --git a/mlir/test/EDSC/builder-api-test.cpp b/mlir/test/EDSC/builder-api-test.cpp --- a/mlir/test/EDSC/builder-api-test.cpp +++ b/mlir/test/EDSC/builder-api-test.cpp @@ -68,12 +68,11 @@ OpBuilder builder(f.getBody()); ScopedContext scope(builder, f.getLoc()); - ValueHandle i(indexType), j(indexType), lb(f.getArgument(0)), - ub(f.getArgument(1)); - ValueHandle f7(std_constant_float(llvm::APFloat(7.0f), f32Type)); - ValueHandle f13(std_constant_float(llvm::APFloat(13.0f), f32Type)); - ValueHandle i7(std_constant_int(7, 32)); - ValueHandle i13(std_constant_int(13, 32)); + Value i, j, lb(f.getArgument(0)), ub(f.getArgument(1)); + Value f7(std_constant_float(llvm::APFloat(7.0f), f32Type)); + Value 
f13(std_constant_float(llvm::APFloat(13.0f), f32Type)); + Value i7(std_constant_int(7, 32)); + Value i13(std_constant_int(13, 32)); AffineLoopNestBuilder(&i, lb, ub, 3)([&] { using namespace edsc::op; lb *std_constant_index(3) + ub; @@ -119,8 +118,8 @@ OpBuilder builder(f.getBody()); ScopedContext scope(builder, f.getLoc()); - ValueHandle i(indexType), a(f.getArgument(0)), b(f.getArgument(1)), - c(f.getArgument(2)), d(f.getArgument(3)); + Value i, a(f.getArgument(0)), b(f.getArgument(1)), c(f.getArgument(2)), + d(f.getArgument(3)); using namespace edsc::op; AffineLoopNestBuilder(&i, a - b, c + d, 2)(); @@ -141,8 +140,8 @@ OpBuilder builder(f.getBody()); ScopedContext scope(builder, f.getLoc()); - ValueHandle i(indexType), a(f.getArgument(0)), b(f.getArgument(1)), - c(f.getArgument(2)), d(f.getArgument(3)); + Value i, a(f.getArgument(0)), b(f.getArgument(1)), c(f.getArgument(2)), + d(f.getArgument(3)); using namespace edsc::op; LoopNestBuilder(&i, a - b, c + d, a)(); @@ -163,8 +162,8 @@ OpBuilder builder(f.getBody()); ScopedContext scope(builder, f.getLoc()); - ValueHandle i(indexType), lb1(f.getArgument(0)), lb2(f.getArgument(1)), - ub1(f.getArgument(2)), ub2(f.getArgument(3)); + Value i, lb1(f.getArgument(0)), lb2(f.getArgument(1)), ub1(f.getArgument(2)), + ub2(f.getArgument(3)); AffineLoopNestBuilder(&i, {lb1, lb2}, {ub1, ub2}, 1)(); std_ret(); @@ -183,17 +182,20 @@ OpBuilder builder(f.getBody()); ScopedContext scope(builder, f.getLoc()); - ValueHandle c1(ValueHandle::create(42, 32)), - c2(ValueHandle::create(1234, 32)); - ValueHandle arg1(c1.getType()), arg2(c1.getType()), arg3(c1.getType()), - arg4(c1.getType()), r(c1.getType()); - + Value c1(std_constant_int(42, 32)), c2(std_constant_int(1234, 32)); + Value r; + Value args12[2]; + Value &arg1 = args12[0], &arg2 = args12[1]; + Value args34[2]; + Value &arg3 = args34[0], &arg4 = args34[1]; BlockHandle b1, b2, functionBlock(&f.front()); - BlockBuilder(&b1, {&arg1, &arg2})( + BlockBuilder(&b1, {c1.getType(), 
c1.getType()}, args12)( // b2 has not yet been constructed, need to come back later. // This is a byproduct of non-structured control-flow. ); - BlockBuilder(&b2, {&arg3, &arg4})([&] { std_br(b1, {arg3, arg4}); }); + BlockBuilder(&b2, {c1.getType(), c1.getType()}, args34)([&] { + std_br(b1, {arg3, arg4}); + }); // The insertion point within the toplevel function is now past b2, we will // need to get back the entry block. // This is what happens with unstructured control-flow.. @@ -226,24 +228,25 @@ OpBuilder builder(f.getBody()); ScopedContext scope(builder, f.getLoc()); - ValueHandle c1(ValueHandle::create(42, 32)), - c2(ValueHandle::create(1234, 32)); - ValueHandle arg1(c1.getType()), arg2(c1.getType()), arg3(c1.getType()), - arg4(c1.getType()), r(c1.getType()); + Value c1(std_constant_int(42, 32)), c2(std_constant_int(1234, 32)); + Value res; + Value args1And2[2], args3And4[2]; + Value &arg1 = args1And2[0], &arg2 = args1And2[1], &arg3 = args3And4[0], + &arg4 = args3And4[1]; // clang-format off BlockHandle b1, b2; { // Toplevel function scope. // Build a new block for b1 eagerly. - std_br(&b1, {&arg1, &arg2}, {c1, c2}); + std_br(&b1, {c1.getType(), c1.getType()}, args1And2, {c1, c2}); // Construct a new block b2 explicitly with a branch into b1. - BlockBuilder(&b2, {&arg3, &arg4})([&]{ + BlockBuilder(&b2, {c1.getType(), c1.getType()}, args3And4)([&]{ std_br(b1, {arg3, arg4}); }); /// And come back to append into b1 once b2 exists. 
BlockBuilder(b1, Append())([&]{ - r = arg1 + arg2; - std_br(b2, {arg1, r}); + res = arg1 + arg2; + std_br(b2, {arg1, res}); }); } @@ -268,15 +271,14 @@ OpBuilder builder(f.getBody()); ScopedContext scope(builder, f.getLoc()); - ValueHandle funcArg(f.getArgument(0)); - ValueHandle c32(ValueHandle::create(32, 32)), - c64(ValueHandle::create(64, 64)), - c42(ValueHandle::create(42, 32)); - ValueHandle arg1(c32.getType()), arg2(c64.getType()), arg3(c32.getType()); - + Value funcArg(f.getArgument(0)); + Value c32(std_constant_int(32, 32)), c64(std_constant_int(64, 64)), + c42(std_constant_int(42, 32)); + Value arg1; + Value args23[2]; BlockHandle b1, b2, functionBlock(&f.front()); - BlockBuilder(&b1, {&arg1})([&] { std_ret(); }); - BlockBuilder(&b2, {&arg2, &arg3})([&] { std_ret(); }); + BlockBuilder(&b1, c32.getType(), arg1)([&] { std_ret(); }); + BlockBuilder(&b2, {c64.getType(), c32.getType()}, args23)([&] { std_ret(); }); // Get back to entry block and add a conditional branch BlockBuilder(functionBlock, Append())([&] { std_cond_br(funcArg, b1, {c32}, b2, {c64, c42}); @@ -304,15 +306,16 @@ OpBuilder builder(f.getBody()); ScopedContext scope(builder, f.getLoc()); - ValueHandle funcArg(f.getArgument(0)); - ValueHandle c32(ValueHandle::create(32, 32)), - c64(ValueHandle::create(64, 64)), - c42(ValueHandle::create(42, 32)); - ValueHandle arg1(c32.getType()), arg2(c64.getType()), arg3(c32.getType()); + Value arg0(f.getArgument(0)); + Value c32(std_constant_int(32, 32)), c64(std_constant_int(64, 64)), + c42(std_constant_int(42, 32)); // clang-format off BlockHandle b1, b2; - std_cond_br(funcArg, &b1, {&arg1}, {c32}, &b2, {&arg2, &arg3}, {c64, c42}); + Value arg1[1], args2And3[2]; + std_cond_br(arg0, + &b1, c32.getType(), arg1, c32, + &b2, {c64.getType(), c32.getType()}, args2And3, {c64, c42}); BlockBuilder(b1, Append())([]{ std_ret(); }); @@ -336,7 +339,6 @@ TEST_FUNC(builder_helpers) { using namespace edsc::op; - auto indexType = IndexType::get(&globalContext()); auto 
f32Type = FloatType::getF32(&globalContext()); auto memrefType = MemRefType::get({ShapedType::kDynamicSize, ShapedType::kDynamicSize, @@ -348,21 +350,20 @@ OpBuilder builder(f.getBody()); ScopedContext scope(builder, f.getLoc()); // clang-format off - ValueHandle f7( - ValueHandle::create(llvm::APFloat(7.0f), f32Type)); + Value f7 = std_constant_float(llvm::APFloat(7.0f), f32Type); MemRefBoundsCapture vA(f.getArgument(0)), vB(f.getArgument(1)), vC(f.getArgument(2)); AffineIndexedValue A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2)); - ValueHandle i(indexType), j(indexType), k1(indexType), k2(indexType), - lb0(indexType), lb1(indexType), lb2(indexType), - ub0(indexType), ub1(indexType), ub2(indexType); + Value ivs[2]; + Value &i = ivs[0], &j = ivs[1]; + Value k1, k2, lb0, lb1, lb2, ub0, ub1, ub2; int64_t step0, step1, step2; std::tie(lb0, ub0, step0) = vA.range(0); std::tie(lb1, ub1, step1) = vA.range(1); lb2 = vA.lb(2); ub2 = vA.ub(2); step2 = vA.step(2); - AffineLoopNestBuilder({&i, &j}, {lb0, lb1}, {ub0, ub1}, {step0, step1})([&]{ + AffineLoopNestBuilder(ivs, {lb0, lb1}, {ub0, ub1}, {step0, step1})([&]{ AffineLoopNestBuilder(&k1, lb2, ub2, step2)([&]{ C(i, j, k1) = f7 + A(i, j, k1) + B(i, j, k1); }); @@ -393,45 +394,6 @@ f.erase(); } -TEST_FUNC(custom_ops) { - using namespace edsc::op; - auto indexType = IndexType::get(&globalContext()); - auto f = makeFunction("custom_ops", {}, {indexType, indexType}); - - OpBuilder builder(f.getBody()); - ScopedContext scope(builder, f.getLoc()); - CustomOperation MY_CUSTOM_OP("my_custom_op"); - CustomOperation MY_CUSTOM_OP_0("my_custom_op_0"); - CustomOperation MY_CUSTOM_OP_2("my_custom_op_2"); - - // clang-format off - ValueHandle vh(indexType), vh20(indexType), vh21(indexType); - OperationHandle ih0, ih2; - ValueHandle m(indexType), n(indexType); - ValueHandle M(f.getArgument(0)), N(f.getArgument(1)); - ValueHandle ten(std_constant_index(10)), twenty(std_constant_index(20)); - AffineLoopNestBuilder({&m, &n}, 
{M, N}, {M + ten, N + twenty}, {1, 1})([&]{ - vh = MY_CUSTOM_OP({m, m + n}, {indexType}, {}); - ih0 = MY_CUSTOM_OP_0({m, m + n}, {}); - ih2 = MY_CUSTOM_OP_2({m, m + n}, {indexType, indexType}); - // These captures are verbose for now, can improve when used in practice. - vh20 = ValueHandle(ih2.getOperation()->getResult(0)); - vh21 = ValueHandle(ih2.getOperation()->getResult(1)); - MY_CUSTOM_OP({vh20, vh21}, {indexType}, {}); - }); - - // CHECK-LABEL: @custom_ops - // CHECK: affine.for %{{.*}} {{.*}} - // CHECK: affine.for %{{.*}} {{.*}} - // CHECK: {{.*}} = "my_custom_op"{{.*}} : (index, index) -> index - // CHECK: "my_custom_op_0"{{.*}} : (index, index) -> () - // CHECK: [[TWO:%[a-z0-9]+]]:2 = "my_custom_op_2"{{.*}} : (index, index) -> (index, index) - // CHECK: {{.*}} = "my_custom_op"([[TWO]]#0, [[TWO]]#1) : (index, index) -> index - // clang-format on - f.print(llvm::outs()); - f.erase(); -} - TEST_FUNC(insertion_in_block) { using namespace edsc::op; auto indexType = IndexType::get(&globalContext()); @@ -441,11 +403,11 @@ ScopedContext scope(builder, f.getLoc()); BlockHandle b1; // clang-format off - ValueHandle::create(0, 32); - BlockBuilder(&b1, {})([]{ - ValueHandle::create(1, 32); + std_constant_int(0, 32); + (BlockBuilder(&b1))([]{ + std_constant_int(1, 32); }); - ValueHandle::create(2, 32); + std_constant_int(2, 32); // CHECK-LABEL: @insertion_in_block // CHECK: {{.*}} = constant 0 : i32 // CHECK: {{.*}} = constant 2 : i32 @@ -469,8 +431,8 @@ AffineIndexedValue A(f.getArgument(0)); AffineIndexedValue B(f.getArgument(1)); // clang-format off - edsc::intrinsics::std_zero_extendi(*A, i8Type); - edsc::intrinsics::std_sign_extendi(*B, i8Type); + edsc::intrinsics::std_zero_extendi(A, i8Type); + edsc::intrinsics::std_sign_extendi(B, i8Type); // CHECK-LABEL: @zero_and_std_sign_extendi_op // CHECK: %[[SRC1:.*]] = affine.load // CHECK: zexti %[[SRC1]] : i1 to i8 @@ -489,8 +451,8 @@ ScopedContext scope(builder, f.getLoc()); using op::operator||; - ValueHandle 
lhs(f.getArgument(0)); - ValueHandle rhs(f.getArgument(1)); + Value lhs(f.getArgument(0)); + Value rhs(f.getArgument(1)); lhs || rhs; // CHECK-LABEL: @operator_or @@ -508,8 +470,8 @@ ScopedContext scope(builder, f.getLoc()); using op::operator&&; - ValueHandle lhs(f.getArgument(0)); - ValueHandle rhs(f.getArgument(1)); + Value lhs(f.getArgument(0)); + Value rhs(f.getArgument(1)); lhs &&rhs; // CHECK-LABEL: @operator_and @@ -521,7 +483,6 @@ TEST_FUNC(select_op_i32) { using namespace edsc::op; - auto indexType = IndexType::get(&globalContext()); auto f32Type = FloatType::getF32(&globalContext()); auto memrefType = MemRefType::get( {ShapedType::kDynamicSize, ShapedType::kDynamicSize}, f32Type, {}, 0); @@ -530,17 +491,13 @@ OpBuilder builder(f.getBody()); ScopedContext scope(builder, f.getLoc()); // clang-format off - ValueHandle zero = std_constant_index(0), one = std_constant_index(1); + Value zero = std_constant_index(0), one = std_constant_index(1); MemRefBoundsCapture vA(f.getArgument(0)); AffineIndexedValue A(f.getArgument(0)); - ValueHandle i(indexType), j(indexType); - AffineLoopNestBuilder({&i, &j}, {zero, zero}, {one, one}, {1, 1})([&]{ - // This test exercises AffineIndexedValue::operator Value. 
- // Without it, one must force conversion to ValueHandle as such: - // std_select( - // i == zero, ValueHandle(A(zero, zero)), ValueHandle(ValueA(i, j))) - using edsc::op::operator==; - std_select(i == zero, *A(zero, zero), *A(i, j)); + Value ivs[2]; + Value &i = ivs[0], &j = ivs[1]; + AffineLoopNestBuilder(ivs, {zero, zero}, {one, one}, {1, 1})([&]{ + std_select(eq(i, zero), A(zero, zero), A(i, j)); }); // CHECK-LABEL: @select_op @@ -556,7 +513,6 @@ } TEST_FUNC(select_op_f32) { - auto indexType = IndexType::get(&globalContext()); auto f32Type = FloatType::getF32(&globalContext()); auto memrefType = MemRefType::get( {ShapedType::kDynamicSize, ShapedType::kDynamicSize}, f32Type, {}, 0); @@ -565,18 +521,19 @@ OpBuilder builder(f.getBody()); ScopedContext scope(builder, f.getLoc()); // clang-format off - ValueHandle zero = std_constant_index(0), one = std_constant_index(1); + Value zero = std_constant_index(0), one = std_constant_index(1); MemRefBoundsCapture vA(f.getArgument(0)), vB(f.getArgument(1)); AffineIndexedValue A(f.getArgument(0)), B(f.getArgument(1)); - ValueHandle i(indexType), j(indexType); - AffineLoopNestBuilder({&i, &j}, {zero, zero}, {one, one}, {1, 1})([&]{ + Value ivs[2]; + Value &i = ivs[0], &j = ivs[1]; + AffineLoopNestBuilder(ivs, {zero, zero}, {one, one}, {1, 1})([&]{ using namespace edsc::op; - std_select(B(i, j) == B(i + one, j), *A(zero, zero), *A(i, j)); - std_select(B(i, j) != B(i + one, j), *A(zero, zero), *A(i, j)); - std_select(B(i, j) >= B(i + one, j), *A(zero, zero), *A(i, j)); - std_select(B(i, j) <= B(i + one, j), *A(zero, zero), *A(i, j)); - std_select(B(i, j) < B(i + one, j), *A(zero, zero), *A(i, j)); - std_select(B(i, j) > B(i + one, j), *A(zero, zero), *A(i, j)); + std_select(eq(B(i, j), B(i + one, j)), A(zero, zero), A(i, j)); + std_select(ne(B(i, j), B(i + one, j)), A(zero, zero), A(i, j)); + std_select(B(i, j) >= B(i + one, j), A(zero, zero), A(i, j)); + std_select(B(i, j) <= B(i + one, j), A(zero, zero), A(i, j)); + 
std_select(B(i, j) < B(i + one, j), A(zero, zero), A(i, j)); + std_select(B(i, j) > B(i + one, j), A(zero, zero), A(i, j)); }); // CHECK-LABEL: @select_op @@ -632,7 +589,6 @@ // Inject an EDSC-constructed computation to exercise imperfectly nested 2-d // tiling. TEST_FUNC(tile_2d) { - auto indexType = IndexType::get(&globalContext()); auto memrefType = MemRefType::get({ShapedType::kDynamicSize, ShapedType::kDynamicSize, ShapedType::kDynamicSize}, @@ -641,17 +597,19 @@ OpBuilder builder(f.getBody()); ScopedContext scope(builder, f.getLoc()); - ValueHandle zero = std_constant_index(0); + Value zero = std_constant_index(0); MemRefBoundsCapture vA(f.getArgument(0)), vB(f.getArgument(1)), vC(f.getArgument(2)); AffineIndexedValue A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2)); - ValueHandle i(indexType), j(indexType), k1(indexType), k2(indexType); - ValueHandle M(vC.ub(0)), N(vC.ub(1)), O(vC.ub(2)); + Value ivs[2]; + Value &i = ivs[0], &j = ivs[1]; + Value k1, k2; + Value M(vC.ub(0)), N(vC.ub(1)), O(vC.ub(2)); // clang-format off using namespace edsc::op; - AffineLoopNestBuilder({&i, &j}, {zero, zero}, {M, N}, {1, 1})([&]{ + AffineLoopNestBuilder(ivs, {zero, zero}, {M, N}, {1, 1})([&]{ AffineLoopNestBuilder(&k1, zero, O, 1)([&]{ C(i, j, k1) = A(i, j, k1) + B(i, j, k1); }); @@ -661,10 +619,8 @@ }); // clang-format on - auto li = getForInductionVarOwner(i.getValue()), - lj = getForInductionVarOwner(j.getValue()), - lk1 = getForInductionVarOwner(k1.getValue()), - lk2 = getForInductionVarOwner(k2.getValue()); + auto li = getForInductionVarOwner(i), lj = getForInductionVarOwner(j), + lk1 = getForInductionVarOwner(k1), lk2 = getForInductionVarOwner(k2); auto indicesL1 = mlir::tile({li, lj}, {512, 1024}, {lk1, lk2}); auto lii1 = indicesL1[0][0], ljj1 = indicesL1[1][0]; mlir::tile({ljj1, lii1}, {32, 16}, ljj1); @@ -713,15 +669,15 @@ OpBuilder builder(f.getBody()); ScopedContext scope(builder, f.getLoc()); - ValueHandle zero = std_constant_index(0); + Value zero = 
std_constant_index(0); MemRefBoundsCapture vC(f.getArgument(2)); AffineIndexedValue B(f.getArgument(1)), D(f.getArgument(3)); StdIndexedValue A(f.getArgument(0)), C(f.getArgument(2)); - ValueHandle i(builder.getIndexType()), N(vC.ub(0)); + Value i, N(vC.ub(0)); // clang-format off AffineLoopNestBuilder(&i, zero, N, 1)([&]{ - C((ValueHandle)D(i)) = A((ValueHandle)B(i)); + C((Value)D(i)) = A((Value)B(i)); }); // clang-format on @@ -747,12 +703,12 @@ OpBuilder builder(f.getBody()); ScopedContext scope(builder, f.getLoc()); - ValueHandle zero = std_constant_index(0); - ValueHandle one = std_constant_index(1); + Value zero = std_constant_index(0); + Value one = std_constant_index(1); AffineIndexedValue input(f.getArgument(0)), res(f.getArgument(1)); - ValueHandle iv(builder.getIndexType()); // clang-format off + Value iv; AffineLoopNestBuilder(&iv, zero, one, 1)([&]{ res() = input(); }); @@ -784,7 +740,7 @@ OpBuilder builder(f.getBody()); ScopedContext scope(builder, f.getLoc()); - ValueHandle zero = std_constant_index(0), ten = std_constant_index(10); + Value zero = std_constant_index(0), ten = std_constant_index(10); SmallVector isEq = {false, false, false, false}; SmallVector affineExprs = { @@ -834,7 +790,7 @@ OpBuilder builder(f.getBody()); ScopedContext scope(builder, f.getLoc()); - ValueHandle A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2)); + Value A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2)); AffineExpr i, j; bindDims(&globalContext(), i, j); StructuredIndexed SA(A), SB(B), SC(C); @@ -864,12 +820,12 @@ auto f32Type = FloatType::getF32(&globalContext()); auto memrefType = MemRefType::get( {ShapedType::kDynamicSize, ShapedType::kDynamicSize}, f32Type, {}, 0); - auto f = - makeFunction("linalg_generic_matmul", {}, {memrefType, memrefType, memrefType}); + auto f = makeFunction("linalg_generic_matmul", {}, + {memrefType, memrefType, memrefType}); OpBuilder builder(f.getBody()); ScopedContext scope(builder, f.getLoc()); - 
linalg_generic_matmul(makeValueHandles(llvm::to_vector<3>(f.getArguments()))); + linalg_generic_matmul(f.getArguments()); f.print(llvm::outs()); f.erase(); @@ -902,8 +858,8 @@ OpBuilder builder(f.getBody()); ScopedContext scope(builder, f.getLoc()); - linalg_generic_conv_nhwc(makeValueHandles(llvm::to_vector<3>(f.getArguments())), - /*strides=*/{3, 4}, /*dilations=*/{5, 6}); + linalg_generic_conv_nhwc(f.getArguments(), + /*strides=*/{3, 4}, /*dilations=*/{5, 6}); f.print(llvm::outs()); f.erase(); @@ -936,9 +892,9 @@ OpBuilder builder(f.getBody()); ScopedContext scope(builder, f.getLoc()); - linalg_generic_dilated_conv_nhwc(makeValueHandles(f.getArguments()), - /*depth_multiplier=*/7, - /*strides=*/{3, 4}, /*dilations=*/{5, 6}); + linalg_generic_dilated_conv_nhwc(f.getArguments(), + /*depth_multiplier=*/7, + /*strides=*/{3, 4}, /*dilations=*/{5, 6}); f.print(llvm::outs()); f.erase(); @@ -958,7 +914,7 @@ ScopedContext scope(builder, f.getLoc()); AffineExpr i, j, k; bindDims(&globalContext(), i, j, k); - ValueHandle v(f.getArgument(0)); + Value v(f.getArgument(0)); auto reshaped = linalg_reshape(v, ArrayRef>{{i, j}, k}); linalg_reshape(memrefType, reshaped, ArrayRef>{{i, j}, k}); @@ -1015,7 +971,7 @@ OpBuilder builder(f.getBody()); ScopedContext scope(builder, f.getLoc()); - ValueHandle A(f.getArgument(0)), B(f.getArgument(1)); + Value A(f.getArgument(0)), B(f.getArgument(1)); AffineExpr i, j; bindDims(&globalContext(), i, j); StructuredIndexed SA(A), SB(B), SC(tensorType); @@ -1023,7 +979,7 @@ linalg_generic_pointwise_max(SA({i, j}), SB({i, j}), SC({i, j})); linalg_generic_pointwise_tanh(SA({i, j}), SC({i, j})); Value o1 = linalg_generic_matmul(A, B, tensorType)->getResult(0); - linalg_generic_matmul(A, B, ValueHandle(o1), tensorType); + linalg_generic_matmul(A, B, o1, tensorType); f.print(llvm::outs()); f.erase(); @@ -1064,7 +1020,7 @@ OpBuilder builder(f.getBody()); ScopedContext scope(builder, f.getLoc()); - ValueHandle A(f.getArgument(0)), B(f.getArgument(1)), 
C(f.getArgument(2)); + Value A(f.getArgument(0)), B(f.getArgument(1)), C(f.getArgument(2)); auto contractionBuilder = [](ArrayRef args) { assert(args.size() == 3 && "expected 3 block arguments"); (linalg_yield(vector_contraction_matmul(args[0], args[1], args[2]))); @@ -1083,19 +1039,19 @@ OpBuilder builder(f.getBody()); ScopedContext scope(builder, f.getLoc()); - ValueHandle init0 = std_constant_float(llvm::APFloat(1.0f), f32Type); - ValueHandle init1 = std_constant_float(llvm::APFloat(2.0f), f32Type); - ValueHandle i(indexType), a(f.getArgument(0)), b(f.getArgument(1)), - c(f.getArgument(2)), d(f.getArgument(3)); - ValueHandle arg0(f32Type); - ValueHandle arg1(f32Type); + Value init0 = std_constant_float(llvm::APFloat(1.0f), f32Type); + Value init1 = std_constant_float(llvm::APFloat(2.0f), f32Type); + Value i, a(f.getArgument(0)), b(f.getArgument(1)), c(f.getArgument(2)), + d(f.getArgument(3)); + Value args01[2]; + Value &arg0 = args01[0], &arg1 = args01[1]; using namespace edsc::op; auto results = - LoopNestBuilder(&i, a - b, c + d, a, {&arg0, &arg1}, {init0, init1})([&] { + LoopNestBuilder(&i, a - b, c + d, a, args01, {init0, init1})([&] { auto sum = arg0 + arg1; - loop_yield(ArrayRef{arg1, sum}); + loop_yield(ArrayRef{arg1, sum}); }); - ValueHandle(results[0]) + ValueHandle(results[1]); + results[0] + results[1]; // clang-format off // CHECK-LABEL: func @builder_loop_for_yield(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) { diff --git a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc --- a/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc +++ b/mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc @@ -16,9 +16,9 @@ // IMPL-NEXT: AffineMap::get(2, 0, {d0}, context) }; // // IMPL: Test1Op::regionBuilder(Block &block) { -// IMPL: ValueHandle [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]); -// IMPL: ValueHandle [[d:.*]] = std_mulf([[a]], [[b]]); -// IMPL: ValueHandle [[e:.*]] = 
std_addf([[c]], [[d]]); +// IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]); +// IMPL: Value [[d:.*]] = std_mulf([[a]], [[b]]); +// IMPL: Value [[e:.*]] = std_addf([[c]], [[d]]); // IMPL: (linalg_yield(ValueRange{ [[e]] })); // ods_def : @@ -41,9 +41,9 @@ // IMPL-NEXT: AffineMap::get(3, 0, {d0, d1}, context) }; // // IMPL: Test2Op::regionBuilder(Block &block) { -// IMPL: ValueHandle [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]); -// IMPL: ValueHandle [[d:.*]] = std_mulf([[a]], [[b]]); -// IMPL: ValueHandle [[e:.*]] = std_addf([[c]], [[d]]); +// IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]); +// IMPL: Value [[d:.*]] = std_mulf([[a]], [[b]]); +// IMPL: Value [[e:.*]] = std_addf([[c]], [[d]]); // IMPL: (linalg_yield(ValueRange{ [[e]] })); // ods_def : @@ -66,9 +66,9 @@ // IMPL-NEXT: AffineMap::get(4, 0, {d0, d1, d2}, context) }; // // IMPL: Test3Op::regionBuilder(Block &block) { -// IMPL: ValueHandle [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]); -// IMPL: ValueHandle [[d:.*]] = std_mulf([[a]], [[b]]); -// IMPL: ValueHandle [[e:.*]] = std_addf([[c]], [[d]]); +// IMPL: Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]); +// IMPL: Value [[d:.*]] = std_mulf([[a]], [[b]]); +// IMPL: Value [[e:.*]] = std_addf([[c]], [[d]]); // IMPL: (linalg_yield(ValueRange{ [[e]] })); // ods_def : diff --git a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp --- a/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp +++ b/mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp @@ -1601,7 +1601,7 @@ printExpr(subExprsStringStream, *e); }); subExprsStringStream.flush(); - const char *tensorExprFmt = "\n ValueHandle _{0} = {1}({2});"; + const char *tensorExprFmt = "\n Value _{0} = {1}({2});"; os << llvm::formatv(tensorExprFmt, ++count, pTensorExpr->operationName, subExprs); subExprsMap[pTensorExpr] = count; @@ -1613,7 +1613,7 @@ using namespace edsc; using 
namespace intrinsics; auto args = block.getArguments(); - ValueHandle {1}; + Value {1}; {2} (linalg_yield(ValueRange{ {3} })); })FMT";