[mlir][shape] Add shape.const_shape / shape.const_size and basic folding.

Replaces the generic shape.constant op with the typed shape.const_shape and
shape.const_size ops, enables a constant materializer for the shape dialect,
marks the shape dialect types as BuildableType, and folds shape.shape_of on
statically shaped operands into a constant shape.

diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -40,10 +40,13 @@
   }];
 
   let cppNamespace = "shape";
+
+  let hasConstantMaterializer = 1;
 }
 
 def Shape_ComponentType : DialectType<ShapeDialect,
-    CPred<"$_self.isa<ComponentType>()">, "component type"> {
+    CPred<"$_self.isa<::mlir::shape::ComponentType>()">, "component type">,
+    BuildableType<"$_builder.getType<::mlir::shape::ComponentType>()"> {
   let typeDescription = [{
     `shape.element_type` represents the element type of the ShapedType. It may
     be unknown, error or regular element type supported by ShapedType.
@@ -51,7 +54,8 @@
 }
 
 def Shape_ElementType : DialectType<ShapeDialect,
-    CPred<"$_self.isa<ElementType>()">, "element type"> {
+    CPred<"$_self.isa<::mlir::shape::ElementType>()">, "element type">,
+    BuildableType<"$_builder.getType<::mlir::shape::ElementType>()"> {
   let typeDescription = [{
     `shape.element_type` represents the element type of the ShapedType. It may
     be unknown, error or regular element type supported by ShapedType.
@@ -59,7 +63,8 @@
 }
 
 def Shape_ShapeType : DialectType<ShapeDialect,
-    CPred<"$_self.isa<ShapeType>()">, "shape"> {
+    CPred<"$_self.isa<::mlir::shape::ShapeType>()">, "shape">,
+    BuildableType<"$_builder.getType<::mlir::shape::ShapeType>()"> {
   let typeDescription = [{
     `shape.type` represents either an unranked shape, a ranked shape with
     possibly unknown dimensions or an invalid shape. The rank is of type
@@ -77,7 +82,8 @@
 }
 
 def Shape_SizeType : DialectType<ShapeDialect,
-    CPred<"$_self.isa<SizeType>()">, "size"> {
+    CPred<"$_self.isa<::mlir::shape::SizeType>()">, "size">,
+    BuildableType<"$_builder.getType<::mlir::shape::SizeType>()"> {
   let typeDescription = [{
     `shape.size` represents a non-negative integer with support for being
     unknown and invalid.
@@ -89,7 +95,9 @@
 }
 
 def Shape_ValueShapeType : DialectType<ShapeDialect,
-    CPred<"$_self.isa<ValueShapeType>()">, "value shape"> {
+    CPred<"$_self.isa<::mlir::shape::ValueShapeType>()">, "value shape">,
+    BuildableType<"::mlir::shape::ValueShapeType::get($_builder.getContext())">
+{
   let typeDescription = [{
     `shape.value_shape` represents the value produced by an operation (this
     corresponds to `Value` in the compiler) and a shape. Conceptually this is a
@@ -146,27 +154,46 @@
   let results = (outs Shape_ShapeType:$result);
 }
 
-def Shape_ConstantOp : Shape_Op<"constant", []> {
-  let summary = "Creates a shape constant";
+def Shape_ConstShapeOp : Shape_Op<"const_shape",
+    [ConstantLike,
+     NoSideEffect,
+     DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+  let summary = "Creates a constant of !shape.shape type.";
   let description = [{
-    An operation that builds a size or shape from integer or array attribute.
-    It allows for creating dynamically valued shapes by using `?` for unknown
-    values. A constant shape specified with `*` will return an unranked shape.
+    Creates a !shape.shape with rank given by the length of `shape` and with
+    dimension sizes given by the values of `shape`.
 
     ```mlir
-    %x = shape.constant 10 : !shape.size
+    %0 = shape.const_shape []
+    %1 = shape.const_shape [1, 2, 3]
     ```
   }];
-
-  // TODO(jpienaar): Change to a more specialized attribute that would
-  // encapsulate the unknown parsing while using denser packing.
-  let arguments = (ins AnyAttr:$value);
-  let results = (outs Shape_ShapeOrSizeType:$result);
+  let arguments = (ins I64ElementsAttr:$shape);
+  let results = (outs Shape_ShapeType:$result);
 
   // TODO: Move this to main so that all shape ops implement these.
let printer = [{ return ::print(p, *this); }]; - let verifier = [{ return ::verify(*this); }]; let parser = [{ return ::parse$cppClass(parser, result); }]; + let hasFolder = 1; +} + +def Shape_ConstSizeOp : Shape_Op<"const_size", + [ConstantLike, + NoSideEffect, + DeclareOpInterfaceMethods]> { + let summary = "Creates a constant of !shape.size type."; + let description = [{ + Creates a !shape.size type representing the constant size given by `value`. + + ```mlir + %x = shape.const_size 10 + ``` + }]; + + let arguments = (ins IndexAttr:$value); + let results = (outs Shape_SizeType:$result); + + let assemblyFormat = "attr-dict $value"; } def Shape_FromExtentTensorOp : Shape_Op<"from_extent_tensor", []> { @@ -291,6 +318,8 @@ let arguments = (ins AnyTypeOf<[AnyShaped, Shape_ValueShapeType]>:$arg); let results = (outs Shape_ShapeType:$result); + + let hasFolder = 1; } def Shape_YieldOp : Shape_Op<"yield", [NoSideEffect, Terminator]> { diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -10,6 +10,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/PatternMatch.h" #include "mlir/IR/StandardTypes.h" #include "llvm/Support/raw_ostream.h" @@ -29,6 +30,19 @@ allowUnknownOperations(); } +Operation *ShapeDialect::materializeConstant(OpBuilder &builder, + Attribute value, Type type, + Location loc) { + if (auto shapeType = type.dyn_cast()) { + return builder.create(loc, type, + value.cast()); + } + if (auto sizeType = type.dyn_cast()) { + return builder.create(loc, type, value.cast()); + } + return nullptr; +} + /// Parse a type registered to this dialect. 
Type ShapeDialect::parseType(DialectAsmParser &parser) const { StringRef keyword; @@ -74,37 +88,79 @@ } //===----------------------------------------------------------------------===// -// Constant*Op +// ConstShapeOp //===----------------------------------------------------------------------===// -static void print(OpAsmPrinter &p, ConstantOp &op) { - p << "shape.constant "; - p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"value"}); - - if (op.getAttrs().size() > 1) - p << ' '; - p.printAttributeWithoutType(op.value()); - p << " : " << op.getType(); +static void print(OpAsmPrinter &p, ConstShapeOp &op) { + p << "shape.const_shape "; + p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"shape"}); + p << "["; + interleaveComma(op.shape().getValues(), p, + [&](int64_t i) { p << i; }); + p << "]"; } -static ParseResult parseConstantOp(OpAsmParser &parser, - OperationState &result) { - Attribute valueAttr; +static ParseResult parseConstShapeOp(OpAsmParser &parser, + OperationState &result) { if (parser.parseOptionalAttrDict(result.attributes)) return failure(); - Type i64Type = parser.getBuilder().getIntegerType(64); - if (parser.parseAttribute(valueAttr, i64Type, "value", result.attributes)) + // We piggy-back on ArrayAttr parsing, though we don't internally store the + // shape as an ArrayAttr. + // TODO: Implement custom parser and maybe make syntax a bit more concise. + Attribute extentsRaw; + SmallVector dummy; + if (parser.parseAttribute(extentsRaw, "dummy", dummy)) return failure(); - - Type type; - if (parser.parseColonType(type)) + auto extentsArray = extentsRaw.dyn_cast(); + if (!extentsArray) return failure(); + SmallVector ints; + for (Attribute extent : extentsArray) { + IntegerAttr attr = extent.dyn_cast(); + if (!attr) + return failure(); + ints.push_back(attr.getInt()); + } + Builder &builder = parser.getBuilder(); + result.addAttribute("shape", builder.getI64TensorAttr(ints)); - // Add the attribute type to the list. 
-  return parser.addTypeToList(type, result.types);
+  result.types.push_back(ShapeType::get(builder.getContext()));
+  return success();
 }
 
-static LogicalResult verify(ConstantOp &op) { return success(); }
+OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shape(); }
+
+LogicalResult ConstShapeOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    ArrayRef<NamedAttribute> attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  inferredReturnTypes.push_back(ShapeType::get(context));
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// ConstSizeOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult ConstSizeOp::inferReturnTypes(
+    MLIRContext *context, Optional<Location> location, ValueRange operands,
+    ArrayRef<NamedAttribute> attributes, RegionRange regions,
+    SmallVectorImpl<Type> &inferredReturnTypes) {
+  inferredReturnTypes.push_back(SizeType::get(context));
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// ShapeOfOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
+  auto type = getOperand().getType().dyn_cast<ShapedType>();
+  if (!type || !type.hasStaticShape())
+    return nullptr;
+  Builder builder(getContext());
+  return builder.getI64TensorAttr(type.getShape());
+}
 
 //===----------------------------------------------------------------------===//
 // SplitAtOp
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -0,0 +1,9 @@
+// RUN: mlir-opt -canonicalize <%s | FileCheck %s --dump-input=fail
+
+// -----
+// CHECK-LABEL: func @f
+func @f(%arg0: tensor<2x3x4xf32>) -> !shape.shape {
+  // CHECK: shape.const_shape [2, 3, 4]
+  %0 = "shape.shape_of"(%arg0) : (tensor<2x3x4xf32>) -> !shape.shape
+  return %0 : !shape.shape
+}
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -1,8 +1,8 @@
-// RUN: mlir-opt -split-input-file %s | FileCheck %s --dump-input-on-failure
+// RUN: mlir-opt -split-input-file %s | mlir-opt | FileCheck %s --dump-input-on-failure
 
 // CHECK-LABEL: shape_num_elements
 func @shape_num_elements(%shape : !shape.shape) -> !shape.size {
-  %0 = shape.constant 0 : !shape.size
+  %0 = shape.const_size 0
   %1 = "shape.reduce"(%shape, %0) ( {
     ^bb0(%index: i32, %dim: !shape.size, %lci: !shape.size):
       %acc = "shape.add"(%lci, %dim) : (!shape.size, !shape.size) -> !shape.size
@@ -19,40 +19,46 @@
 }
 
 func @test_shape_num_elements_fixed() {
-  %0 = "shape.constant"() { value = [1, 57, 92] }: () -> !shape.shape
+  %0 = shape.const_shape [1, 57, 92]
   %1 = call @shape_num_elements(%0) : (!shape.shape) -> (!shape.size)
   %3 = "shape.print"(%1) : (!shape.size) -> !shape.size
   return
 }
 
 func @test_broadcastable_fixed() {
-  %0 = "shape.constant"() { value = [10, 1, 57, 92] }: () -> !shape.shape
-  %1 = "shape.constant"() { value = [4, 57, 92] }: () -> !shape.shape
+  %0 = shape.const_shape [10, 1, 57, 92]
+  %1 = shape.const_shape [4, 57, 92]
   %2 = "shape.broadcastable"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
   %3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
   return
 }
 
 func @test_shape_any_fixed() {
-  %0 = "shape.constant"() { value = [4, 57, 92] }: () -> !shape.shape
-  %1 = "shape.constant"() { value = [4, 57, 92] }: () -> !shape.shape
+  %0 = shape.const_shape [4, 57, 92]
+  %1 = shape.const_shape [4, 57, 92]
   %2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
   %3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
   return
 }
 
 func @test_shape_any_unknown() {
-  %0 = "shape.constant"() { value = [4, -1, 92] }: () -> !shape.shape
-  %1 = "shape.constant"() { value = [-1, 57, 92] }: () -> !shape.shape
+  %0 = shape.const_shape [4, -1, 92]
+  %1 = shape.const_shape [-1, 57, 92]
   %2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
   %3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
   return
 }
 
 func @test_shape_any_fixed_mismatch() {
-  %0 = "shape.constant"() { value = [4, 57, 92] }: () -> !shape.shape
-  %1 = "shape.constant"() { value = [2, 57, 92] }: () -> !shape.shape
+  %0 = shape.const_shape [4, 57, 92]
+  %1 = shape.const_shape [2, 57, 92]
   %2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
   %3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
   return
 }
+
+func @test_parse_const_shape() {
+  %0 = shape.const_shape []
+  %1 = shape.const_shape [1, 2, 3]
+  return
+}