diff --git a/mlir/include/mlir/Dialect/Affine/EDSC/Builders.h b/mlir/include/mlir/Dialect/Affine/EDSC/Builders.h --- a/mlir/include/mlir/Dialect/Affine/EDSC/Builders.h +++ b/mlir/include/mlir/Dialect/Affine/EDSC/Builders.h @@ -129,34 +129,34 @@ /// Assignment-arithmetic operator overloadings. template -OperationHandle TemplatedIndexedValue::operator+=(Value e) { +Store TemplatedIndexedValue::operator+=(Value e) { using op::operator+; - return Store(*this + e, getBase(), {indices.begin(), indices.end()}); + return Store(*this + e, getBase(), indices); } template -OperationHandle TemplatedIndexedValue::operator-=(Value e) { +Store TemplatedIndexedValue::operator-=(Value e) { using op::operator-; - return Store(*this - e, getBase(), {indices.begin(), indices.end()}); + return Store(*this - e, getBase(), indices); } template -OperationHandle TemplatedIndexedValue::operator*=(Value e) { +Store TemplatedIndexedValue::operator*=(Value e) { using op::operator*; - return Store(*this * e, getBase(), {indices.begin(), indices.end()}); + return Store(*this * e, getBase(), indices); } template -OperationHandle TemplatedIndexedValue::operator/=(Value e) { +Store TemplatedIndexedValue::operator/=(Value e) { using op::operator/; - return Store(*this / e, getBase(), {indices.begin(), indices.end()}); + return Store(*this / e, getBase(), indices); } template -OperationHandle TemplatedIndexedValue::operator%=(Value e) { +Store TemplatedIndexedValue::operator%=(Value e) { using op::operator%; - return Store(*this % e, getBase(), {indices.begin(), indices.end()}); + return Store(*this % e, getBase(), indices); } template -OperationHandle TemplatedIndexedValue::operator^=(Value e) { +Store TemplatedIndexedValue::operator^=(Value e) { using op::operator^; - return Store(*this ^ e, getBase(), {indices.begin(), indices.end()}); + return Store(*this ^ e, getBase(), indices); } /// Logical operator overloadings. diff --git a/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h b/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h --- a/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h +++ b/mlir/include/mlir/Dialect/StandardOps/EDSC/Intrinsics.h @@ -46,7 +46,7 @@ /// /// Prerequisites: /// All Handles have already captured previously constructed IR objects. -OperationHandle std_br(BlockHandle bh, ArrayRef operands); +BranchOp std_br(BlockHandle bh, ArrayRef operands); /// Creates a new mlir::Block* and branches to it from the current block. /// Argument types are specified by `operands`. @@ -61,8 +61,8 @@ /// All `operands` have already captured an mlir::Value /// captures.size() == operands.size() /// captures and operands are pairwise of the same type. -OperationHandle std_br(BlockHandle *bh, ArrayRef types, - ArrayRef captures, ArrayRef operands); +BranchOp std_br(BlockHandle *bh, ArrayRef types, + ArrayRef captures, ArrayRef operands); /// Branches into the mlir::Block* captured by BlockHandle `trueBranch` with /// `trueOperands` if `cond` evaluates to `true` (resp. `falseBranch` and @@ -70,10 +70,9 @@ /// /// Prerequisites: /// All Handles have captured previously constructed IR objects. -OperationHandle std_cond_br(Value cond, BlockHandle trueBranch, - ArrayRef trueOperands, - BlockHandle falseBranch, - ArrayRef falseOperands); +CondBranchOp std_cond_br(Value cond, BlockHandle trueBranch, + ArrayRef trueOperands, BlockHandle falseBranch, + ArrayRef falseOperands); /// Eagerly creates new mlir::Block* with argument types specified by /// `trueOperands`/`falseOperands`. @@ -91,7 +90,7 @@ /// `falseCaptures`.size() == `falseOperands`.size() /// `trueCaptures` and `trueOperands` are pairwise of the same type /// `falseCaptures` and `falseOperands` are pairwise of the same type. -OperationHandle +CondBranchOp std_cond_br(Value cond, BlockHandle *trueBranch, ArrayRef trueTypes, ArrayRef trueCaptures, ArrayRef trueOperands, BlockHandle *falseBranch, ArrayRef falseTypes, diff --git a/mlir/include/mlir/EDSC/Builders.h b/mlir/include/mlir/EDSC/Builders.h --- a/mlir/include/mlir/EDSC/Builders.h +++ b/mlir/include/mlir/EDSC/Builders.h @@ -73,18 +73,27 @@ template struct ValueBuilder { // Builder-based template ValueBuilder(Args... args) { - Operation *op = ScopedContext::getBuilder() - .create(ScopedContext::getLocation(), args...) - .getOperation(); - if (op->getNumResults() != 1) - llvm_unreachable("unsupported operation, use OperationBuilder instead"); - value = op->getResult(0); + value = ScopedContext::getBuilder() + .create(ScopedContext::getLocation(), args...) + .getResult(); } - operator Value() { return value; } Value value; }; +template +struct OperationBuilder { + // Builder-based + template + OperationBuilder(Args... args) { + op = ScopedContext::getBuilder().create(ScopedContext::getLocation(), + args...); + } + operator Op() { return op; } + operator Operation *() { return op.getOperation(); } + Op op; +}; + /// A NestedBuilder is a scoping abstraction to create an idiomatic syntax /// embedded in C++ that serves the purpose of building nested MLIR. /// Nesting and compositionality is obtained by using the strict ordering that @@ -166,15 +175,13 @@ private: LoopBuilder() = default; - friend LoopBuilder makeAffineLoopBuilder(Value *iv, ArrayRef lbHandles, - ArrayRef ubHandles, - int64_t step); + friend LoopBuilder makeAffineLoopBuilder(Value *iv, ArrayRef lbs, + ArrayRef ubs, int64_t step); friend LoopBuilder makeParallelLoopBuilder(MutableArrayRef ivs, - ArrayRef lbHandles, - ArrayRef ubHandles, + ArrayRef lbs, + ArrayRef ubs, ArrayRef steps); - friend LoopBuilder makeLoopBuilder(Value *iv, Value lbHandle, Value ubHandle, - Value stepHandle, + friend LoopBuilder makeLoopBuilder(Value *iv, Value lb, Value ub, Value step, MutableArrayRef iterArgsHandles, ValueRange iterArgsInitValues); Operation *op; @@ -230,51 +237,12 @@ BlockBuilder &operator=(BlockBuilder &other) = delete; }; -/// Base class for Value, OperationHandle and BlockHandle. -/// Not meant to be used outside of these classes. -class CapturableHandle { -protected: - CapturableHandle() = default; -}; - -/// An OperationHandle can be used in lieu of Value to capture the -/// operation in cases when one does not care about, or cannot extract, a -/// unique Value from the operation. -/// This can be used for capturing zero result operations as well as -/// multi-result operations that are not supported by Value. -/// We do not distinguish further between zero and multi-result operations at -/// this time. -struct OperationHandle : public CapturableHandle { - OperationHandle() : op(nullptr) {} - OperationHandle(Operation *op) : op(op) {} - - OperationHandle(const OperationHandle &) = default; - OperationHandle &operator=(const OperationHandle &) = default; - - /// Generic mlir::Op create. This is the key to being extensible to the whole - /// of MLIR without duplicating the type system or the op definitions. - template - static OperationHandle create(Args... args); - template static Op createOp(Args... args); - - /// Generic create for a named operation. - static OperationHandle create(StringRef name, ArrayRef operands, - ArrayRef resultTypes, - ArrayRef attributes = {}); - - operator Operation *() { return op; } - Operation *getOperation() const { return op; } - -private: - Operation *op; -}; - /// A BlockHandle represents a (potentially "delayed") Block abstraction. /// This extra abstraction is necessary because an mlir::Block is not an /// mlir::Value. /// A BlockHandle should be captured by pointer but otherwise passed by Value /// everywhere. -class BlockHandle : public CapturableHandle { +class BlockHandle { public: /// A BlockHandle constructed without an mlir::Block* represents a "delayed" /// Block. A delayed Block represents the declaration (in the PL sense) of a @@ -361,22 +329,6 @@ SmallVector exprs; }; -template -OperationHandle OperationHandle::create(Args... args) { - return OperationHandle(ScopedContext::getBuilder() - .create(ScopedContext::getLocation(), args...) - .getOperation()); -} - -template -Op OperationHandle::createOp(Args... args) { - return cast( - OperationHandle(ScopedContext::getBuilder() - .create(ScopedContext::getLocation(), args...) - .getOperation()) - .getOperation()); -} - /// A TemplatedIndexedValue brings an index notation over the template Load and /// Store parameters. Assigning to an IndexedValue emits an actual `Store` /// operation, while converting an IndexedValue to a Value emits an actual @@ -403,10 +355,10 @@ } /// Emits a `store`. - OperationHandle operator=(const TemplatedIndexedValue &rhs) { + Store operator=(const TemplatedIndexedValue &rhs) { return Store(rhs, value, indices); } - OperationHandle operator=(Value rhs) { return Store(rhs, value, indices); } + Store operator=(Value rhs) { return Store(rhs, value, indices); } /// Emits a `load` when converting to a Value. operator Value() const { return Load(value, indices); } @@ -440,28 +392,28 @@ } /// Assignment-arithmetic operator overloadings. - OperationHandle operator+=(Value e); - OperationHandle operator-=(Value e); - OperationHandle operator*=(Value e); - OperationHandle operator/=(Value e); - OperationHandle operator%=(Value e); - OperationHandle operator^=(Value e); - OperationHandle operator+=(TemplatedIndexedValue e) { + Store operator+=(Value e); + Store operator-=(Value e); + Store operator*=(Value e); + Store operator/=(Value e); + Store operator%=(Value e); + Store operator^=(Value e); + Store operator+=(TemplatedIndexedValue e) { return this->operator+=(static_cast(e)); } - OperationHandle operator-=(TemplatedIndexedValue e) { + Store operator-=(TemplatedIndexedValue e) { return this->operator-=(static_cast(e)); } - OperationHandle operator*=(TemplatedIndexedValue e) { + Store operator*=(TemplatedIndexedValue e) { return this->operator*=(static_cast(e)); } - OperationHandle operator/=(TemplatedIndexedValue e) { + Store operator/=(TemplatedIndexedValue e) { return this->operator/=(static_cast(e)); } - OperationHandle operator%=(TemplatedIndexedValue e) { + Store operator%=(TemplatedIndexedValue e) { return this->operator%=(static_cast(e)); } - OperationHandle operator^=(TemplatedIndexedValue e) { + Store operator^=(TemplatedIndexedValue e) { return this->operator^=(static_cast(e)); } diff --git a/mlir/include/mlir/EDSC/Intrinsics.h b/mlir/include/mlir/EDSC/Intrinsics.h --- a/mlir/include/mlir/EDSC/Intrinsics.h +++ b/mlir/include/mlir/EDSC/Intrinsics.h @@ -29,25 +29,6 @@ /// In the future, most of intrinsics related to Operation that don't contain /// other operations should be Tablegen'd. namespace intrinsics { - -template struct OperationBuilder : public OperationHandle { - template - OperationBuilder(Args... args) - : OperationHandle(OperationHandle::create(args...)) {} - OperationBuilder(ArrayRef vs) - : OperationHandle(OperationHandle::create(vs)) {} - template - OperationBuilder(ArrayRef vs, Args... args) - : OperationHandle(OperationHandle::create(vs, args...)) {} - template - OperationBuilder(T t, ArrayRef vs, Args... args) - : OperationHandle(OperationHandle::create(t, vs, args...)) {} - template - OperationBuilder(T1 t1, T2 t2, ArrayRef vs, Args... args) - : OperationHandle(OperationHandle::create(t1, t2, vs, args...)) {} - OperationBuilder() : OperationHandle(OperationHandle::create()) {} -}; - } // namespace intrinsics } // namespace edsc } // namespace mlir diff --git a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp --- a/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp +++ b/mlir/lib/Dialect/Linalg/EDSC/Builders.cpp @@ -28,16 +28,15 @@ auto lb = rangeOp.min(); auto ub = rangeOp.max(); auto step = rangeOp.step(); - auto forOp = OperationHandle::createOp(lb, ub, step); - *iv = Value(forOp.getInductionVar()); + ForOp forOp = OperationBuilder(lb, ub, step); + *iv = forOp.getInductionVar(); auto *body = forOp.getBody(); enter(body, /*prev=*/1); } mlir::edsc::LoopRangeBuilder::LoopRangeBuilder(Value *iv, SubViewOp::Range range) { - auto forOp = - OperationHandle::createOp(range.offset, range.size, range.stride); + ForOp forOp = OperationBuilder(range.offset, range.size, range.stride); *iv = forOp.getInductionVar(); auto *body = forOp.getBody(); enter(body, /*prev=*/1); @@ -53,18 +52,16 @@ mlir::edsc::LoopNestRangeBuilder::LoopNestRangeBuilder( MutableArrayRef ivs, ArrayRef ranges) { loops.reserve(ranges.size()); - for (unsigned i = 0, e = ranges.size(); i < e; ++i) { + for (unsigned i = 0, e = ranges.size(); i < e; ++i) loops.emplace_back(&ivs[i], ranges[i]); - } assert(loops.size() == ivs.size() && "Mismatch loops vs ivs size"); } mlir::edsc::LoopNestRangeBuilder::LoopNestRangeBuilder( MutableArrayRef ivs, ArrayRef ranges) { loops.reserve(ranges.size()); - for (unsigned i = 0, e = ranges.size(); i < e; ++i) { + for (unsigned i = 0, e = ranges.size(); i < e; ++i) loops.emplace_back(&ivs[i], ranges[i]); - } assert(loops.size() == ivs.size() && "Mismatch loops vs ivs size"); } diff --git a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LinalgToLoops.cpp @@ -537,10 +537,6 @@ // TODO(ravishankarm): Generate loop.parallel for all "parallel" iterator // types, not just the outer most ones. Also handle "reduction" iterator // types. - auto nPar = linalgOp.getNumParallelLoops(); - auto nRed = linalgOp.getNumReductionLoops(); - auto nWin = linalgOp.getNumWindowLoops(); - auto nLoops = nPar + nRed + nWin; auto nOuterPar = linalgOp.iterator_types() .getValue() .take_while([](Attribute attr) { @@ -550,14 +546,6 @@ .size(); // If there are no outer parallel loops, then number of loop ops is same as // the number of loops, and they are all loop.for ops. - auto nLoopOps = (nOuterPar ? nLoops - nOuterPar + 1 : nLoops); - SmallVector allLoops(nLoopOps, OperationHandle()); - SmallVector allPLoops; - allPLoops.reserve(allLoops.size()); - for (OperationHandle &loop : allLoops) - allPLoops.push_back(&loop); - ArrayRef allPLoopsRef(allPLoops); - if (nOuterPar) { GenericLoopNestRangeBuilder( allIvs.take_front(nOuterPar), loopRanges.take_front(nOuterPar))([&] { diff --git a/mlir/lib/Dialect/LoopOps/EDSC/Builders.cpp b/mlir/lib/Dialect/LoopOps/EDSC/Builders.cpp --- a/mlir/lib/Dialect/LoopOps/EDSC/Builders.cpp +++ b/mlir/lib/Dialect/LoopOps/EDSC/Builders.cpp @@ -80,41 +80,38 @@ } LoopBuilder mlir::edsc::makeParallelLoopBuilder(MutableArrayRef ivs, - ArrayRef lbHandles, - ArrayRef ubHandles, + ArrayRef lbs, + ArrayRef ubs, ArrayRef steps) { - LoopBuilder result; - auto opHandle = OperationHandle::create( - SmallVector(lbHandles.begin(), lbHandles.end()), - SmallVector(ubHandles.begin(), ubHandles.end()), + loop::ParallelOp parallelOp = OperationBuilder( + SmallVector(lbs.begin(), lbs.end()), + SmallVector(ubs.begin(), ubs.end()), SmallVector(steps.begin(), steps.end())); - - loop::ParallelOp parallelOp = - cast(*opHandle.getOperation()); for (size_t i = 0, e = ivs.size(); i < e; ++i) ivs[i] = Value(parallelOp.getBody()->getArgument(i)); + LoopBuilder result; result.enter(parallelOp.getBody(), /*prev=*/1); return result; } mlir::edsc::LoopBuilder -mlir::edsc::makeLoopBuilder(Value *iv, Value lbHandle, Value ubHandle, - Value stepHandle, ArrayRef iterArgsHandles, +mlir::edsc::makeLoopBuilder(Value *iv, Value lb, Value ub, Value step, + ArrayRef iterArgsHandles, ValueRange iterArgsInitValues) { SmallVector args(iterArgsHandles.size()); - auto res = makeLoopBuilder(iv, lbHandle, ubHandle, stepHandle, args, - iterArgsInitValues); + auto res = makeLoopBuilder(iv, lb, ub, step, args, iterArgsInitValues); for (auto it : llvm::zip(iterArgsHandles, args)) *(std::get<0>(it)) = std::get<1>(it); return res; } -mlir::edsc::LoopBuilder mlir::edsc::makeLoopBuilder( - Value *iv, Value lbHandle, Value ubHandle, Value stepHandle, - MutableArrayRef iterArgsHandles, ValueRange iterArgsInitValues) { +mlir::edsc::LoopBuilder +mlir::edsc::makeLoopBuilder(Value *iv, Value lb, Value ub, Value step, + MutableArrayRef iterArgsHandles, + ValueRange iterArgsInitValues) { mlir::edsc::LoopBuilder result; - auto forOp = OperationHandle::createOp( - lbHandle, ubHandle, stepHandle, iterArgsInitValues); + loop::ForOp forOp = + OperationBuilder(lb, ub, step, iterArgsInitValues); *iv = Value(forOp.getInductionVar()); auto *body = loop::getForInductionVarOwner(*iv).getBody(); for (size_t i = 0, e = iterArgsHandles.size(); i < e; ++i) { diff --git a/mlir/lib/Dialect/StandardOps/EDSC/Intrinsics.cpp b/mlir/lib/Dialect/StandardOps/EDSC/Intrinsics.cpp --- a/mlir/lib/Dialect/StandardOps/EDSC/Intrinsics.cpp +++ b/mlir/lib/Dialect/StandardOps/EDSC/Intrinsics.cpp @@ -12,37 +12,36 @@ using namespace mlir; using namespace mlir::edsc; -OperationHandle mlir::edsc::intrinsics::std_br(BlockHandle bh, - ArrayRef operands) { +BranchOp mlir::edsc::intrinsics::std_br(BlockHandle bh, + ArrayRef operands) { assert(bh && "Expected already captured BlockHandle"); for (auto &o : operands) { (void)o; assert(o && "Expected already captured Value"); } SmallVector ops(operands.begin(), operands.end()); - return OperationHandle::create(bh.getBlock(), ops); + return OperationBuilder(bh.getBlock(), ops); } -OperationHandle mlir::edsc::intrinsics::std_br(BlockHandle *bh, - ArrayRef types, - ArrayRef captures, - ArrayRef operands) { +BranchOp mlir::edsc::intrinsics::std_br(BlockHandle *bh, ArrayRef types, + ArrayRef captures, + ArrayRef operands) { assert(!*bh && "Unexpected already captured BlockHandle"); BlockBuilder(bh, types, captures)(/* no body */); SmallVector ops(operands.begin(), operands.end()); - return OperationHandle::create(bh->getBlock(), ops); + return OperationBuilder(bh->getBlock(), ops); } -OperationHandle mlir::edsc::intrinsics::std_cond_br( +CondBranchOp mlir::edsc::intrinsics::std_cond_br( Value cond, BlockHandle trueBranch, ArrayRef trueOperands, BlockHandle falseBranch, ArrayRef falseOperands) { SmallVector trueOps(trueOperands.begin(), trueOperands.end()); SmallVector falseOps(falseOperands.begin(), falseOperands.end()); - return OperationHandle::create( - cond, trueBranch.getBlock(), trueOps, falseBranch.getBlock(), falseOps); + return OperationBuilder(cond, trueBranch.getBlock(), trueOps, + falseBranch.getBlock(), falseOps); } -OperationHandle mlir::edsc::intrinsics::std_cond_br( +CondBranchOp mlir::edsc::intrinsics::std_cond_br( Value cond, BlockHandle *trueBranch, ArrayRef trueTypes, ArrayRef trueCaptures, ArrayRef trueOperands, BlockHandle *falseBranch, ArrayRef falseTypes, @@ -53,6 +52,6 @@ BlockBuilder(falseBranch, falseTypes, falseCaptures)(/* no body */); SmallVector trueOps(trueOperands.begin(), trueOperands.end()); SmallVector falseOps(falseOperands.begin(), falseOperands.end()); - return OperationHandle::create( - cond, trueBranch->getBlock(), trueOps, falseBranch->getBlock(), falseOps); + return OperationBuilder(cond, trueBranch->getBlock(), trueOps, + falseBranch->getBlock(), falseOps); } diff --git a/mlir/lib/EDSC/Builders.cpp b/mlir/lib/EDSC/Builders.cpp --- a/mlir/lib/EDSC/Builders.cpp +++ b/mlir/lib/EDSC/Builders.cpp @@ -65,20 +65,6 @@ return getBuilder().getContext(); } -OperationHandle OperationHandle::create(StringRef name, - ArrayRef operands, - ArrayRef resultTypes, - ArrayRef attributes) { - OperationState state(ScopedContext::getLocation(), name); - SmallVector ops(operands.begin(), operands.end()); - state.addOperands(ops); - state.addTypes(resultTypes); - for (const auto &attr : attributes) { - state.addAttribute(attr.first, attr.second); - } - return OperationHandle(ScopedContext::getBuilder().createOperation(state)); -} - BlockHandle mlir::edsc::BlockHandle::create(ArrayRef argTypes) { auto ¤tB = ScopedContext::getBuilder(); auto *ib = currentB.getInsertionBlock();