diff --git a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td --- a/mlir/include/mlir/Dialect/LoopOps/LoopOps.td +++ b/mlir/include/mlir/Dialect/LoopOps/LoopOps.td @@ -182,7 +182,10 @@ let builders = [ OpBuilder<"Builder *builder, OperationState &result, " "ValueRange lowerBounds, ValueRange upperBounds, " - "ValueRange steps"> + "ValueRange steps">, + OpBuilder<"Builder *builder, OperationState &result, " + "ValueRange lowerBounds, ValueRange upperBounds, " + "ValueRange steps, ArrayRef<Type> resultTypes"> ]; let extraClassDeclaration = [{ @@ -221,8 +224,8 @@ Example: ```mlir - %zero = constant 0.0 : f32 - loop.reduce(%zero) { + %operand = constant 1.0 : f32 + loop.reduce(%operand) { ^bb0(%lhs : f32, %rhs: f32): %res = addf %lhs, %rhs : f32 loop.reduce.return %res : f32 @@ -231,6 +234,12 @@ }]; + let skipDefaultBuilders = 1; + let builders = [ + OpBuilder<"Builder *builder, OperationState &result, " + "Value operand"> + ]; + let arguments = (ins AnyType:$operand); let regions = (region SizedRegion<1>:$reductionOperator); } diff --git a/mlir/lib/Dialect/LoopOps/LoopOps.cpp b/mlir/lib/Dialect/LoopOps/LoopOps.cpp --- a/mlir/lib/Dialect/LoopOps/LoopOps.cpp +++ b/mlir/lib/Dialect/LoopOps/LoopOps.cpp @@ -232,10 +232,17 @@ result.addOperands(steps); Region *bodyRegion = result.addRegion(); ParallelOp::ensureTerminator(*bodyRegion, *builder, result.location); - for (size_t i = 0; i < steps.size(); ++i) + for (size_t i = 0, e = steps.size(); i < e; ++i) bodyRegion->front().addArgument(builder->getIndexType()); } +void ParallelOp::build(Builder *builder, OperationState &result, ValueRange lbs, + ValueRange ubs, ValueRange steps, + ArrayRef<Type> resultTypes) { + result.addTypes(resultTypes); // Record the result types, then delegate operands/region/induction-var setup to the resultless builder. + build(builder, result, lbs, ubs, steps); +} + static LogicalResult verify(ParallelOp op) { // Check that there is at least one value in lowerBound, upperBound and step.
// It is sufficient to test only step, because it is ensured already that the @@ -354,6 +361,17 @@ // ReduceOp //===----------------------------------------------------------------------===// +void ReduceOp::build(Builder *builder, OperationState &result, Value operand) { + auto type = operand.getType(); + result.addOperands(operand); + Region *bodyRegion = result.addRegion(); + + // Reduction body: one block with (lhs, rhs) arguments of the operand's type, as verify(ReduceOp) requires; no terminator is added here, the caller fills the body (ending in loop.reduce.return per the op's documented example). + Block *b = new Block(); + b->addArguments(ArrayRef<Type>{type, type}); + bodyRegion->getBlocks().insert(bodyRegion->end(), b); +} + static LogicalResult verify(ReduceOp op) { // The region of a ReduceOp has two arguments of the same type as its operand. auto type = op.operand().getType();