diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -296,7 +296,8 @@
   let hasFolder = 1;
 }
 
-def Shape_ReduceOp : Shape_Op<"reduce", []> {
+def Shape_ReduceOp : Shape_Op<"reduce",
+    [SingleBlockImplicitTerminator<"YieldOp">]> {
   let summary = "Returns an expression reduced over a shape";
   let description = [{
     An operation that takes as input a shape, number of initial values and has a
@@ -316,25 +317,32 @@
     number of elements
 
     ```mlir
-    func @shape_num_elements(%shape : !shape.shape) -> !shape.size {
-      %0 = "shape.constant_dim"() {value = 1 : i32} : () -> !shape.size
-      %1 = "shape.reduce"(%shape, %0) ( {
-        ^bb0(%index: i32, %dim: !shape.size, %lci: !shape.size):
+    func @reduce(%shape : !shape.shape, %init : !shape.size) -> !shape.size {
+      %num_elements = shape.reduce(%shape, %init) -> !shape.size {
+        ^bb0(%index: index, %dim: !shape.size, %lci: !shape.size):
           %acc = "shape.mul"(%lci, %dim) :
             (!shape.size, !shape.size) -> !shape.size
           shape.yield %acc : !shape.size
-      }) : (!shape.shape, !shape.size) -> (!shape.size)
-      return %1 : !shape.size
+      }
+      return %num_elements : !shape.size
     }
     ```
 
     If the shape is unranked, then the results of the op is also unranked.
   }];
 
-  let arguments = (ins Shape_ShapeType:$shape, Variadic<AnyType>:$args);
+  let arguments = (ins Shape_ShapeType:$shape, Variadic<AnyType>:$initVals);
   let results = (outs Variadic<AnyType>:$result);
-
   let regions = (region SizedRegion<1>:$body);
+
+  let builders = [
+    OpBuilder<"OpBuilder &builder, OperationState &result, "
+              "Value shape, ValueRange initVals">,
+  ];
+
+  let verifier = [{ return ::verify(*this); }];
+  let printer = [{ return ::print(p, *this); }];
+  let parser = [{ return ::parse$cppClass(parser, result); }];
 }
 
 def Shape_ShapeOfOp : Shape_Op<"shape_of",
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -450,6 +450,107 @@
   return DenseIntElementsAttr::get(type, shape);
 }
 
+//===----------------------------------------------------------------------===//
+// ReduceOp
+//===----------------------------------------------------------------------===//
+
+/// Builds a ReduceOp over `shape` with one result per initial value. The body
+/// block receives an index argument, a !shape.size argument and one argument
+/// per initial value; no operations (not even a terminator) are inserted.
+void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
+                     ValueRange initVals) {
+  result.addOperands(shape);
+  result.addOperands(initVals);
+
+  Region *bodyRegion = result.addRegion();
+  bodyRegion->push_back(new Block);
+  Block &bodyBlock = bodyRegion->front();
+  bodyBlock.addArgument(builder.getIndexType());
+  bodyBlock.addArgument(SizeType::get(builder.getContext()));
+
+  for (Type initValType : initVals.getTypes()) {
+    bodyBlock.addArgument(initValType);
+    result.addTypes(initValType);
+  }
+}
+
+/// Checks that the body block takes an index, a !shape.size, and one argument
+/// per initial value whose type matches that initial value.
+static LogicalResult verify(ReduceOp op) {
+  // Verify block arg types.
+  Block &block = op.body().front();
+
+  auto blockArgsCount = op.initVals().size() + 2;
+  if (block.getNumArguments() != blockArgsCount)
+    return op.emitOpError() << "ReduceOp body is expected to have "
+                            << blockArgsCount << " arguments";
+
+  if (block.getArgument(0).getType() != IndexType::get(op.getContext()))
+    return op.emitOpError(
+        "argument 0 of ReduceOp body is expected to be of IndexType");
+
+  if (block.getArgument(1).getType() != SizeType::get(op.getContext()))
+    return op.emitOpError(
+        "argument 1 of ReduceOp body is expected to be of SizeType");
+
+  for (auto it : llvm::enumerate(op.initVals()))
+    if (block.getArgument(it.index() + 2).getType() != it.value().getType())
+      return op.emitOpError()
+             << "type mismatch between argument " << it.index() + 2
+             << " of ReduceOp body and initial value " << it.index();
+  return success();
+}
+
+/// Parses `shape.reduce(%shape, %init...) -> types { body } attr-dict`. The
+/// optional result types also serve as the types of the initial values.
+static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) {
+  auto *ctx = parser.getBuilder().getContext();
+  // Parse operands.
+  SmallVector<OpAsmParser::OperandType, 2> operands;
+  if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1,
+                              OpAsmParser::Delimiter::Paren) ||
+      parser.parseOptionalArrowTypeList(result.types))
+    return failure();
+
+  // The shape operand is mandatory; front()/drop_front() below must not see
+  // an empty list.
+  if (operands.empty())
+    return parser.emitError(parser.getNameLoc(),
+                            "expected at least one operand (the shape)");
+
+  // Resolve operands.
+  auto initVals = llvm::makeArrayRef(operands).drop_front();
+  if (parser.resolveOperand(operands.front(), ShapeType::get(ctx),
+                            result.operands) ||
+      parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
+                             result.operands))
+    return failure();
+
+  // Parse the body.
+  Region *body = result.addRegion();
+  if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{}))
+    return failure();
+
+  // Parse attributes.
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+
+  return success();
+}
+
+/// Prints the custom form accepted by parseReduceOp.
+static void print(OpAsmPrinter &p, ReduceOp op) {
+  p << op.getOperationName() << '(' << op.shape();
+  // Emit the separator only when there are initial values; `(%shape, )` is
+  // rejected by the parser.
+  if (!op.initVals().empty())
+    p << ", " << op.initVals();
+  p << ") ";
+  p.printOptionalArrowTypeList(op.getResultTypes());
+  p.printRegion(op.body());
+  p.printOptionalAttrDict(op.getAttrs());
+}
+
 namespace mlir {
 namespace shape {
 
diff --git a/mlir/test/Dialect/Shape/invalid.mlir b/mlir/test/Dialect/Shape/invalid.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Shape/invalid.mlir
@@ -0,0 +1,51 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+
+func @reduce_op_args_num_mismatch(%shape : !shape.shape, %init : !shape.size) {
+  // expected-error@+1 {{ReduceOp body is expected to have 3 arguments}}
+  %num_elements = shape.reduce(%shape, %init) -> !shape.size {
+    ^bb0(%index: index, %dim: !shape.size):
+      "shape.yield"(%dim) : (!shape.size) -> ()
+  }
+}
+
+// -----
+
+func @reduce_op_arg0_wrong_type(%shape : !shape.shape, %init : !shape.size) {
+  // expected-error@+1 {{argument 0 of ReduceOp body is expected to be of IndexType}}
+  %num_elements = shape.reduce(%shape, %init) -> !shape.size {
+    ^bb0(%index: f32, %dim: !shape.size, %lci: !shape.size):
+      %acc = "shape.add"(%lci, %dim) : (!shape.size, !shape.size) -> !shape.size
+      "shape.yield"(%acc) : (!shape.size) -> ()
+  }
+}
+
+// -----
+
+func @reduce_op_arg1_wrong_type(%shape : !shape.shape, %init : !shape.size) {
+  // expected-error@+1 {{argument 1 of ReduceOp body is expected to be of SizeType}}
+  %num_elements = shape.reduce(%shape, %init) -> !shape.size {
+    ^bb0(%index: index, %dim: f32, %lci: !shape.size):
+      "shape.yield"() : () -> ()
+  }
+}
+
+// -----
+
+func @reduce_op_init_type_mismatch(%shape : !shape.shape, %init : f32) {
+  // expected-error@+1 {{type mismatch between argument 2 of ReduceOp body and initial value 0}}
+  %num_elements = shape.reduce(%shape, %init) -> f32 {
+    ^bb0(%index: index, %dim: !shape.size,
+         %lci: !shape.size):
+      "shape.yield"() : () -> ()
+  }
+}
+
+// -----
+
+func @reduce_op_no_operands() {
+  // expected-error@+1 {{expected at least one operand (the shape)}}
+  shape.reduce() {
+    ^bb0(%index: index, %dim: !shape.size):
+      "shape.yield"() : () -> ()
+  }
+}
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -1,14 +1,27 @@
 // RUN: mlir-opt -split-input-file %s | mlir-opt | FileCheck %s --dump-input-on-failure
+// Verify the printed output can be parsed.
+// RUN: mlir-opt %s | mlir-opt | FileCheck %s --dump-input-on-failure
+// Verify the generic form can be parsed.
+// RUN: mlir-opt -mlir-print-op-generic %s | mlir-opt | FileCheck %s --dump-input-on-failure
 
 // CHECK-LABEL: shape_num_elements
 func @shape_num_elements(%shape : !shape.shape) -> !shape.size {
-  %0 = shape.const_size 0
-  %1 = "shape.reduce"(%shape, %0) ( {
-    ^bb0(%index: i32, %dim: !shape.size, %lci: !shape.size):
+  %init = shape.const_size 0
+  %num_elements = shape.reduce(%shape, %init) -> !shape.size {
+    ^bb0(%index: index, %dim: !shape.size, %lci: !shape.size):
       %acc = "shape.add"(%lci, %dim) : (!shape.size, !shape.size) -> !shape.size
       "shape.yield"(%acc) : (!shape.size) -> ()
-  }) : (!shape.shape, !shape.size) -> (!shape.size)
-  return %1 : !shape.size
+  }
+  return %num_elements : !shape.size
 }
 
+// CHECK-LABEL: shape_reduce_no_init_vals
+func @shape_reduce_no_init_vals(%shape : !shape.shape) {
+  shape.reduce(%shape) {
+    ^bb0(%index: index, %dim: !shape.size):
+      "shape.yield"() : () -> ()
+  }
+  return
+}
+
 func @test_shape_num_elements_unknown() {