diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -338,23 +338,26 @@ def Shape_ReduceOp : Shape_Op<"reduce", [SingleBlockImplicitTerminator<"YieldOp">]> { - let summary = "Returns an expression reduced over a shape"; + let summary = "Returns an expression reduced over a shape or extent tensor"; let description = [{ - An operation that takes as input a shape, number of initial values and has a - region/function that is applied repeatedly for every dimension of the shape. + An operation that takes as input a shape or extent tensor, and a number of + initial values. This operation has a region/function that is applied + repeatedly for every extent of the input. Starting with the initial values, + the individual extents are then aggregated as defined by the associated + region. Conceptually this op performs the following reduction: ``` res[] = init; - for (int i = 0, e = shape.rank(); i != e; ++i) { + for (int i = 0; i < shape.rank(); i++) { res = fn(i, shape[i], res[0], ..., res[n]); } ``` - Where fn is provided by the user and the result of the reduce op is the + Where `fn` is provided by the user and the result of the reduce op is the last computed output of the reduce function. As an example, computing the - number of elements + number of elements can be defined as follows: ```mlir func @reduce(%shape : !shape.shape, %init : !shape.size) -> !shape.size { @@ -367,11 +370,10 @@ return %num_elements : !shape.size } ``` - - If the shape is unranked, then the results of the op is also unranked. 
}]; - let arguments = (ins Shape_ShapeType:$shape, Variadic<AnyType>:$initVals); + let arguments = (ins Shape_ShapeOrExtentTensorType:$shape, + Variadic<AnyType>:$initVals); let results = (outs Variadic<AnyType>:$result); let regions = (region SizedRegion<1>:$region); diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -721,18 +721,31 @@ // Verify block arg types. Block &block = op.region().front(); + // The block takes index, extent, and aggregated values as arguments. auto blockArgsCount = op.initVals().size() + 2; if (block.getNumArguments() != blockArgsCount) return op.emitOpError() << "ReduceOp body is expected to have " << blockArgsCount << " arguments"; - if (block.getArgument(0).getType() != IndexType::get(op.getContext())) + // The first block argument is the index and must always be of type `index`. + if (!block.getArgument(0).getType().isa<IndexType>()) return op.emitOpError( "argument 0 of ReduceOp body is expected to be of IndexType"); - if (block.getArgument(1).getType() != SizeType::get(op.getContext())) - return op.emitOpError( - "argument 1 of ReduceOp body is expected to be of SizeType"); + // The second block argument is the extent and must be of type `size` or + // `index`, depending on whether the reduce operation is applied to a shape or + // to an extent tensor. 
+ Type extentTy = block.getArgument(1).getType(); + if (op.shape().getType().isa<ShapeType>()) { + if (!extentTy.isa<SizeType>()) + return op.emitOpError("argument 1 of ReduceOp body is expected to be of " + "SizeType if the ReduceOp operates on a ShapeType"); + } else { + if (!extentTy.isa<IndexType>()) + return op.emitOpError( + "argument 1 of ReduceOp body is expected to be of IndexType if the " + "ReduceOp operates on an extent tensor"); + } for (auto type : llvm::enumerate(op.initVals())) if (block.getArgument(type.index() + 2).getType() != type.value().getType()) @@ -743,17 +756,18 @@ } static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) { - auto *ctx = parser.getBuilder().getContext(); // Parse operands. SmallVector<OpAsmParser::OperandType, 3> operands; + Type shapeOrExtentTensorType; if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1, OpAsmParser::Delimiter::Paren) || + parser.parseColonType(shapeOrExtentTensorType) || parser.parseOptionalArrowTypeList(result.types)) return failure(); // Resolve operands. 
auto initVals = llvm::makeArrayRef(operands).drop_front(); - if (parser.resolveOperand(operands.front(), ShapeType::get(ctx), + if (parser.resolveOperand(operands.front(), shapeOrExtentTensorType, result.operands) || parser.resolveOperands(initVals, result.types, parser.getNameLoc(), result.operands)) @@ -773,7 +787,7 @@ static void print(OpAsmPrinter &p, ReduceOp op) { p << op.getOperationName() << '(' << op.shape() << ", " << op.initVals() - << ") "; + << ") : " << op.shape().getType(); p.printOptionalArrowTypeList(op.getResultTypes()); p.printRegion(op.region()); p.printOptionalAttrDict(op.getAttrs()); diff --git a/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir b/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir --- a/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir +++ b/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir @@ -1,10 +1,10 @@ // RUN: mlir-opt -convert-shape-to-scf -split-input-file %s | FileCheck %s -// CHECK-LABEL: shape_reduce -// CHECK-SAME: [[SHAPE:%.*]]: !shape.shape) -> !shape.size { +// CHECK-LABEL: @shape_reduce +// CHECK-SAME: ([[SHAPE:%.*]]: !shape.shape) -> !shape.size func @shape_reduce(%shape : !shape.shape) -> !shape.size { %init = shape.const_size 1 - %num_elements = shape.reduce(%shape, %init) -> !shape.size { + %num_elements = shape.reduce(%shape, %init) : !shape.shape -> !shape.size { ^bb0(%index: index, %dim: !shape.size, %acc: !shape.size): %new_acc = shape.mul %acc, %dim shape.yield %new_acc : !shape.size diff --git a/mlir/test/Dialect/Shape/invalid.mlir b/mlir/test/Dialect/Shape/invalid.mlir --- a/mlir/test/Dialect/Shape/invalid.mlir +++ b/mlir/test/Dialect/Shape/invalid.mlir @@ -2,7 +2,7 @@ func @reduce_op_args_num_mismatch(%shape : !shape.shape, %init : !shape.size) { // expected-error@+1 {{ReduceOp body is expected to have 3 arguments}} - %num_elements = shape.reduce(%shape, %init) -> !shape.size { + %num_elements = shape.reduce(%shape, %init) : !shape.shape -> !shape.size { ^bb0(%index: index, %dim: !shape.size): 
shape.yield %dim : !shape.size } @@ -12,7 +12,7 @@ func @reduce_op_arg0_wrong_type(%shape : !shape.shape, %init : !shape.size) { // expected-error@+1 {{argument 0 of ReduceOp body is expected to be of IndexType}} - %num_elements = shape.reduce(%shape, %init) -> !shape.size { + %num_elements = shape.reduce(%shape, %init) : !shape.shape -> !shape.size { ^bb0(%index: f32, %dim: !shape.size, %acc: !shape.size): %new_acc = "shape.add"(%acc, %dim) : (!shape.size, !shape.size) -> !shape.size @@ -24,7 +24,7 @@ func @reduce_op_arg1_wrong_type(%shape : !shape.shape, %init : !shape.size) { // expected-error@+1 {{argument 1 of ReduceOp body is expected to be of SizeType}} - %num_elements = shape.reduce(%shape, %init) -> !shape.size { + %num_elements = shape.reduce(%shape, %init) : !shape.shape -> !shape.size { ^bb0(%index: index, %dim: f32, %lci: !shape.size): shape.yield } @@ -34,7 +34,7 @@ func @reduce_op_init_type_mismatch(%shape : !shape.shape, %init : f32) { // expected-error@+1 {{type mismatch between argument 2 of ReduceOp body and initial value 0}} - %num_elements = shape.reduce(%shape, %init) -> f32 { + %num_elements = shape.reduce(%shape, %init) : !shape.shape -> f32 { ^bb0(%index: index, %dim: !shape.size, %lci: !shape.size): shape.yield } @@ -44,7 +44,7 @@ func @yield_op_args_num_mismatch(%shape : !shape.shape, %init : !shape.size) { // expected-error@+3 {{number of operands does not match number of results of its parent}} - %num_elements = shape.reduce(%shape, %init) -> !shape.size { + %num_elements = shape.reduce(%shape, %init) : !shape.shape -> !shape.size { ^bb0(%index: index, %dim: !shape.size, %lci: !shape.size): shape.yield %dim, %dim : !shape.size, !shape.size } @@ -54,7 +54,7 @@ func @yield_op_type_mismatch(%shape : !shape.shape, %init : !shape.size) { // expected-error@+4 {{types mismatch between yield op and its parent}} - %num_elements = shape.reduce(%shape, %init) -> !shape.size { + %num_elements = shape.reduce(%shape, %init) : !shape.shape -> 
!shape.size { ^bb0(%index: index, %dim: !shape.size, %lci: !shape.size): %c0 = constant 1 : index shape.yield %c0 : index diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir --- a/mlir/test/Dialect/Shape/ops.mlir +++ b/mlir/test/Dialect/Shape/ops.mlir @@ -6,15 +6,26 @@ // CHECK-LABEL: shape_num_elements func @shape_num_elements(%shape : !shape.shape) -> !shape.size { - %init = shape.const_size 0 - %num_elements = shape.reduce(%shape, %init) -> !shape.size { - ^bb0(%index: index, %dim: !shape.size, %lci: !shape.size): - %acc = shape.add %lci, %dim - shape.yield %acc : !shape.size + %init = shape.const_size 1 + %num_elements = shape.reduce(%shape, %init) : !shape.shape -> !shape.size { + ^bb0(%index : index, %extent : !shape.size, %acc : !shape.size): + %acc_next = shape.mul %acc, %extent + shape.yield %acc_next : !shape.size } return %num_elements : !shape.size } +// CHECK-LABEL: extent_tensor_num_elements +func @extent_tensor_num_elements(%shape : tensor<?xindex>) -> index { + %init = constant 1 : index + %num_elements = shape.reduce(%shape, %init) : tensor<?xindex> -> index { + ^bb0(%index : index, %extent : index, %acc : index): + %acc_next = muli %acc, %extent : index + shape.yield %acc_next : index + } + return %num_elements : index +} + func @test_shape_num_elements_unknown() { %0 = "shape.unknown_shape"() : () -> !shape.shape %1 = call @shape_num_elements(%0) : (!shape.shape) -> (!shape.size) diff --git a/mlir/test/Dialect/Shape/shape-to-shape.mlir b/mlir/test/Dialect/Shape/shape-to-shape.mlir --- a/mlir/test/Dialect/Shape/shape-to-shape.mlir +++ b/mlir/test/Dialect/Shape/shape-to-shape.mlir @@ -1,16 +1,16 @@ // RUN: mlir-opt -shape-to-shape-lowering -split-input-file %s | FileCheck %s // CHECK-LABEL: func @num_elements_to_reduce( -// CHECK-SAME: [[ARG:%.*]]: !shape.shape) -> [[SIZE_TY:!.*]] { +// CHECK-SAME: [[ARG:%.*]]: !shape.shape) -> !shape.size { func @num_elements_to_reduce(%shape : !shape.shape) -> !shape.size { %num_elements = 
shape.num_elements %shape return %num_elements : !shape.size } // CHECK: [[C1:%.*]] = shape.const_size 1 -// CHECK: [[NUM_ELEMENTS:%.*]] = shape.reduce([[ARG]], [[C1]]) -> [[SIZE_TY]] -// CHECK: ^bb0({{.*}}: index, [[DIM:%.*]]: [[SIZE_TY]], [[ACC:%.*]]: [[SIZE_TY]] +// CHECK: [[NUM_ELEMENTS:%.*]] = shape.reduce([[ARG]], [[C1]]) : !shape.shape -> !shape.size +// CHECK: ^bb0({{.*}}: index, [[DIM:%.*]]: !shape.size, [[ACC:%.*]]: !shape.size // CHECK: [[NEW_ACC:%.*]] = shape.mul [[DIM]], [[ACC]] -// CHECK: shape.yield [[NEW_ACC]] : [[SIZE_TY]] +// CHECK: shape.yield [[NEW_ACC]] : !shape.size // CHECK: } -// CHECK: return [[NUM_ELEMENTS]] : [[SIZE_TY]] +// CHECK: return [[NUM_ELEMENTS]] : !shape.size