diff --git a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
--- a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
+++ b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td
@@ -182,7 +182,10 @@
   let builders = [
     OpBuilder<"Builder *builder, OperationState &result, "
               "ValueRange lowerBounds, ValueRange upperBounds, "
-              "ValueRange steps">
+              "ValueRange steps">,
+    OpBuilder<"Builder *builder, OperationState &result, "
+              "ValueRange lowerBounds, ValueRange upperBounds, "
+              "ValueRange steps, ArrayRef<Type> resultTypes">
   ];
 
   let extraClassDeclaration = [{
@@ -231,7 +234,13 @@
 
   }];
 
-  let arguments = (ins AnyType:$operand);
+  let skipDefaultBuilders = 1;
+  let builders = [
+    OpBuilder<"Builder *builder, OperationState &result, "
+              "Value initValue">
+  ];
+
+  let arguments = (ins AnyType:$initValue);
   let regions = (region SizedRegion<1>:$reductionOperator);
 }
 
diff --git a/mlir/lib/Dialect/LoopOps/LoopOps.cpp b/mlir/lib/Dialect/LoopOps/LoopOps.cpp
--- a/mlir/lib/Dialect/LoopOps/LoopOps.cpp
+++ b/mlir/lib/Dialect/LoopOps/LoopOps.cpp
@@ -236,6 +236,22 @@
     bodyRegion->front().addArgument(builder->getIndexType());
 }
 
+// Variant of the builder above that additionally sets the op's result types;
+// the verifier compares each result type with that of its reduce op.
+void ParallelOp::build(Builder *builder, OperationState &result, ValueRange lbs,
+                       ValueRange ubs, ValueRange steps,
+                       ArrayRef<Type> resultTypes) {
+  result.addTypes(resultTypes);
+  result.addOperands(lbs);
+  result.addOperands(ubs);
+  result.addOperands(steps);
+  Region *bodyRegion = result.addRegion();
+  ParallelOp::ensureTerminator(*bodyRegion, *builder, result.location);
+  // One index-typed block argument per entry in `steps`.
+  for (size_t i = 0; i < steps.size(); ++i)
+    bodyRegion->front().addArgument(builder->getIndexType());
+}
+
 static LogicalResult verify(ParallelOp op) {
   // Check that there is at least one value in lowerBound, upperBound and step.
   // It is sufficient to test only step, because it is ensured already that the
@@ -273,7 +289,7 @@
   for (auto resultAndReduce : llvm::zip(op.results(), reductions)) {
     auto resultType = std::get<0>(resultAndReduce).getType();
     auto reduceOp = std::get<1>(resultAndReduce);
-    auto reduceType = reduceOp.operand().getType();
+    auto reduceType = reduceOp.initValue().getType();
     if (resultType != reduceType)
       return reduceOp.emitOpError()
              << "expects type of reduce to be the same as result type: "
@@ -354,9 +370,25 @@
 // ReduceOp
 //===----------------------------------------------------------------------===//
 
+// Builds a ReduceOp over `initValue`. The reduction region gets one block with
+// two arguments of `initValue`'s type and no operations; callers must fill it
+// in, since the verifier rejects an empty block.
+void ReduceOp::build(Builder *builder, OperationState &result,
+                     Value initValue) {
+  auto type = initValue.getType();
+  result.addOperands(initValue);
+  Region *bodyRegion = result.addRegion();
+
+  // Let the region take ownership of the block as soon as it is created.
+  Block *body = new Block();
+  bodyRegion->push_back(body);
+  body->addArguments(ArrayRef<Type>{type, type});
+}
+
 static LogicalResult verify(ReduceOp op) {
   // The region of a ReduceOp has two arguments of the same type as its operand.
-  auto type = op.operand().getType();
+  auto type = op.initValue().getType();
   Block &block = op.reductionOperator().front();
   if (block.empty())
     return op.emitOpError("the block inside reduce should not be empty");
@@ -397,9 +429,9 @@
 }
 
 static void print(OpAsmPrinter &p, ReduceOp op) {
-  p << op.getOperationName() << "(" << op.operand() << ") ";
+  p << op.getOperationName() << "(" << op.initValue() << ") ";
   p.printRegion(op.reductionOperator());
-  p << " : " << op.operand().getType();
+  p << " : " << op.initValue().getType();
 }
 
 //===----------------------------------------------------------------------===//
@@ -410,7 +442,7 @@
   // The type of the return value should be the same type as the type of the
   // operand of the enclosing ReduceOp.
   auto reduceOp = cast<ReduceOp>(op.getParentOp());
-  Type reduceType = reduceOp.operand().getType();
+  Type reduceType = reduceOp.initValue().getType();
   if (reduceType != op.result().getType())
     return op.emitOpError() << "needs to have type " << reduceType
                             << " (the type of the enclosing ReduceOp)";