diff --git a/mlir/include/mlir/Dialect/Shape/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Shape/IR/CMakeLists.txt --- a/mlir/include/mlir/Dialect/Shape/IR/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/Shape/IR/CMakeLists.txt @@ -1 +1,5 @@ -add_mlir_dialect(ShapeOps shape ShapeOps) +set(LLVM_TARGET_DEFINITIONS ShapeOps.td) +mlir_tablegen(ShapeOps.h.inc -gen-op-decls) +mlir_tablegen(ShapeOps.cpp.inc -gen-op-defs) +mlir_tablegen(ShapeOpsDialect.h.inc -gen-dialect-decls) +add_public_tablegen_target(MLIRShapeOpsIncGen) diff --git a/mlir/include/mlir/Dialect/Shape/IR/Shape.h b/mlir/include/mlir/Dialect/Shape/IR/Shape.h --- a/mlir/include/mlir/Dialect/Shape/IR/Shape.h +++ b/mlir/include/mlir/Dialect/Shape/IR/Shape.h @@ -27,7 +27,8 @@ Element, Shape, Size, - ValueShape + ValueShape, + LAST_SHAPE_TYPE = ValueShape }; } // namespace ShapeTypes diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -14,6 +14,7 @@ #define SHAPE_OPS include "mlir/IR/OpBase.td" +include "mlir/Interfaces/SideEffects.td" // TODO(jpienaar): Move to base. def AnyShaped: ShapedContainerType<[AnyType], IsShapedTypePred, "shaped">; @@ -40,20 +41,24 @@ let cppNamespace = "shape"; } -def Shape_SizeType : DialectType()">, "dim"> { +def Shape_ComponentType : DialectType()">, "component type"> { let typeDescription = [{ - `shape.size` represents a non-negative integer with support for being - unknown and invalid. + `shape.component` represents the tuple of shape, element type and attribute + of a ShapedType. + }]; +} - Operations on `shape.size` types are specialized to handle unknown/dynamic - value. So, for example, ` + x == ` for all non-error `x : - !shape.size` (e.g., an unknown value does not become known due to addition).
+def Shape_ElementType : DialectType()">, "element type"> { + let typeDescription = [{ + `shape.element` represents the element type of the ShapedType. It may + be unknown, error or regular element type supported by ShapedType. }]; } def Shape_ShapeType : DialectType()">, "shape"> { + CPred<"$_self.isa<::mlir::shape::ShapeType>()">, "shape"> { let typeDescription = [{ `shape.type` represents either an unranked shape, a ranked shape with possibly unknown dimensions or an invalid shape. The rank is of type @@ -70,24 +75,20 @@ }]; } -def Shape_ElementType : DialectType()">, "element type"> { +def Shape_SizeType : DialectType()">, "size"> { let typeDescription = [{ - `shape.element_type` represents the element type of the ShapedType. It may - be unknown, error or regular element type supported by ShapedType. - }]; -} + `shape.size` represents a non-negative integer with support for being + unknown and invalid. -def Shape_ComponentType : DialectType()">, "component type"> { - let typeDescription = [{ - `shape.element_type` represents the element type of the ShapedType. It may - be unknown, error or regular element type supported by ShapedType. + Operations on `shape.size` types are specialized to handle unknown/dynamic + value. So, for example, ` + x == ` for all non-error `x : + !shape.size` (e.g., an unknown value does not become known due to addition). }]; } def Shape_ValueShapeType : DialectType()">, "value shape"> { + CPred<"$_self.isa<::mlir::shape::ValueShapeType>()">, "value shape"> { let typeDescription = [{ `shape.value_shape` represents the value produced by an operation (this corresponds to `Value` in the compiler) and a shape. Conceptually this is a
Conceptually this is a @@ -116,8 +117,8 @@ * lhs + rhs = (int)lhs + (int)rhs if known; }]; - let arguments = (ins Shape_ShapeType:$lhs, Shape_ShapeType:$rhs); - let results = (outs Shape_ShapeType:$result); + let arguments = (ins Shape_SizeType:$lhs, Shape_SizeType:$rhs); + let results = (outs Shape_SizeType:$result); } def Shape_BroadcastOp : Shape_Op<"broadcast", []> { @@ -158,8 +159,13 @@ // TODO(jpienaar): Change to a more specialized attribute that would // encapsulate the unknown parsing while using denser packing. - let arguments = (ins ArrayAttr:$value); + let arguments = (ins AnyAttr:$value); let results = (outs Shape_ShapeOrSizeType:$result); + + // TODO: Move this to main so that all shape ops implement these. + let printer = [{ return ::print(p, *this); }]; + let verifier = [{ return ::verify(*this); }]; + let parser = [{ return ::parse$cppClass(parser, result); }]; } def Shape_CreateShapeOp : Shape_Op<"create_shape", []> { @@ -214,8 +220,8 @@ - lhs * rhs = (int)lhs * (int)rhs if both known; }]; - let arguments = (ins Shape_ShapeType:$lhs, Shape_ShapeType:$rhs); - let results = (outs Shape_ShapeType:$result); + let arguments = (ins Shape_SizeType:$lhs, Shape_SizeType:$rhs); + let results = (outs Shape_SizeType:$result); } def Shape_ReduceOp : Shape_Op<"reduce", []> { @@ -244,7 +250,7 @@ ^bb0(%index: i32, %dim: !shape.size, %lci: !shape.size): %acc = "shape.mul"(%lci, %dim) : (!shape.size, !shape.size) -> !shape.size - "shape.return"(%acc) : (!shape.size) -> () + shape.yield %acc : !shape.size }) : (!shape.type, !shape.size) -> (!shape.size) return %1 : !shape.size } @@ -266,6 +272,18 @@ let results = (outs Shape_ShapeType:$result); } +def Shape_YieldOp : Shape_Op<"yield", [NoSideEffect, Terminator]> { + let summary = "Returns the value to parent op"; + + let arguments = (ins Variadic:$operands); + + let builders = [OpBuilder< + "Builder *b, OperationState &result", [{ build(b, result, llvm::None); }] + >]; + + let assemblyFormat = "attr-dict ($operands^ 
`:` type($operands))?"; +} + // TODO: Add Ops: if_static, if_ranked // For testing usage. diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -26,6 +26,7 @@ #include "mlir/Dialect/QuantOps/QuantOps.h" #include "mlir/Dialect/SDBM/SDBMDialect.h" #include "mlir/Dialect/SPIRV/SPIRVDialect.h" +#include "mlir/Dialect/Shape/IR/Shape.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/VectorOps/VectorOps.h" #include "mlir/IR/Dialect.h" @@ -50,6 +51,7 @@ registerDialect(); registerDialect(); registerDialect(); + registerDialect(); return true; }(); (void)init_once; diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt --- a/mlir/lib/Dialect/CMakeLists.txt +++ b/mlir/lib/Dialect/CMakeLists.txt @@ -7,7 +7,7 @@ add_subdirectory(OpenMP) add_subdirectory(QuantOps) add_subdirectory(SDBM) -#add_subdirectory(Shape) +add_subdirectory(Shape) add_subdirectory(SPIRV) add_subdirectory(StandardOps) add_subdirectory(VectorOps) diff --git a/mlir/lib/Dialect/Shape/CMakeLists.txt b/mlir/lib/Dialect/Shape/CMakeLists.txt --- a/mlir/lib/Dialect/Shape/CMakeLists.txt +++ b/mlir/lib/Dialect/Shape/CMakeLists.txt @@ -1,12 +1,15 @@ -file(GLOB globbed *.c *.cpp) add_mlir_dialect_library(MLIRShape - ${globbed} + IR/Shape.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Shape + + DEPENDS + MLIRShapeOpsIncGen ) -add_dependencies(MLIRShape MLIRShapeOpsIncGen LLVMSupport) target_link_libraries(MLIRShape PUBLIC + MLIRIR MLIRSideEffects - LLVMSupport) + LLVMSupport + ) diff --git a/mlir/lib/Dialect/Shape/DialectRegistration.cpp b/mlir/lib/Dialect/Shape/DialectRegistration.cpp deleted file mode 100644 --- a/mlir/lib/Dialect/Shape/DialectRegistration.cpp +++ /dev/null @@ -1,13 +0,0 @@ -//===- DialectRegistration.cpp - Register shape dialect -------------------===// -// -// Part of the LLVM Project, under the Apache License v2.0 
with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "mlir/Dialect/Shape/IR/Shape.h" -using namespace mlir; - -// Static initialization for shape dialect registration. -static DialectRegistration ShapeOps; diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -0,0 +1,116 @@ +//===- Shape.cpp - MLIR Shape Operations ----------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Shape/IR/Shape.h" + +#include "mlir/IR/Builders.h" +#include "mlir/IR/DialectImplementation.h" +#include "mlir/IR/StandardTypes.h" +#include "llvm/Support/raw_ostream.h" + +using namespace mlir; +using namespace mlir::shape; + +ShapeDialect::ShapeDialect(MLIRContext *context) + : Dialect(getDialectNamespace(), context) { + addOperations< +#define GET_OP_LIST +#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" + >(); + addTypes(); + // Allow unknown operations during prototyping and testing. As the dialect is + // still evolving it makes it simple to start with unregistered ops and + // try different variants before actually defining the op. + allowUnknownOperations(); +} + +/// Parse a type registered to this dialect.
+Type ShapeDialect::parseType(DialectAsmParser &parser) const { + StringRef keyword; + if (parser.parseKeyword(&keyword)) + return Type(); + + if (keyword == "component") + return ComponentType::get(getContext()); + if (keyword == "element") + return ElementType::get(getContext()); + if (keyword == "shape") + return ShapeType::get(getContext()); + if (keyword == "size") + return SizeType::get(getContext()); + if (keyword == "value_shape") + return ValueShapeType::get(getContext()); + + parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword; + return Type(); +} + +/// Print a type registered to this dialect. +void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const { + switch (type.getKind()) { + case ShapeTypes::Component: + os << "component"; + return; + case ShapeTypes::Element: + os << "element"; + return; + case ShapeTypes::Size: + os << "size"; + return; + case ShapeTypes::Shape: + os << "shape"; + return; + case ShapeTypes::ValueShape: + os << "value_shape"; + return; + default: + llvm_unreachable("unexpected 'shape' type kind"); + } +} + +//===----------------------------------------------------------------------===// +// ConstantOp +//===----------------------------------------------------------------------===// + +static void print(OpAsmPrinter &p, ConstantOp &op) { + p << "shape.constant "; + p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"value"}); + + if (op.getAttrs().size() > 1) + p << ' '; + p.printAttributeWithoutType(op.value()); + p << " : " << op.getType(); +} + +static ParseResult parseConstantOp(OpAsmParser &parser, + OperationState &result) { + Attribute valueAttr; + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + Type i64Type = parser.getBuilder().getIntegerType(64); + if (parser.parseAttribute(valueAttr, i64Type, "value", result.attributes)) + return failure(); + + Type type; + if (parser.parseColonType(type)) + return failure(); + + // Record the parsed type as the op's result type.
+ return parser.addTypeToList(type, result.types); +} + +static LogicalResult verify(ConstantOp &op) { return success(); } + +namespace mlir { +namespace shape { + +#define GET_OP_CLASSES +#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc" + +} // namespace shape +} // namespace mlir diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Shape/ops.mlir @@ -0,0 +1,58 @@ +// RUN: mlir-opt -split-input-file %s | FileCheck %s --dump-input-on-failure + +// CHECK-LABEL: shape_num_elements +func @shape_num_elements(%shape : !shape.shape) -> !shape.size { + %0 = shape.constant 1 : !shape.size + %1 = "shape.reduce"(%shape, %0) ( { + ^bb0(%index: i32, %dim: !shape.size, %lci: !shape.size): + %acc = "shape.mul"(%lci, %dim) : (!shape.size, !shape.size) -> !shape.size + "shape.yield"(%acc) : (!shape.size) -> () + }) : (!shape.shape, !shape.size) -> (!shape.size) + return %1 : !shape.size +} + +func @test_shape_num_elements_unknown() { + %0 = "shape.unknown_shape"() : () -> !shape.shape + %1 = call @shape_num_elements(%0) : (!shape.shape) -> (!shape.size) + %2 = "shape.print"(%1) : (!shape.size) -> !shape.size + return +} + +func @test_shape_num_elements_fixed() { + %0 = "shape.constant"() { value = [1, 57, 92] }: () -> !shape.shape + %1 = call @shape_num_elements(%0) : (!shape.shape) -> (!shape.size) + %3 = "shape.print"(%1) : (!shape.size) -> !shape.size + return +} + +func @test_broadcastable_fixed() { + %0 = "shape.constant"() { value = [10, 1, 57, 92] }: () -> !shape.shape + %1 = "shape.constant"() { value = [4, 57, 92] }: () -> !shape.shape + %2 = "shape.broadcastable"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape + %3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape + return +} + +func @test_shape_any_fixed() { + %0 = "shape.constant"() { value = [4, 57, 92] }: () -> !shape.shape + %1 = "shape.constant"() { value = [4, 57, 92] }: () -> !shape.shape + %2 = "shape.join"(%0, %1)
: (!shape.shape, !shape.shape) -> !shape.shape + %3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape + return +} + +func @test_shape_any_unknown() { + %0 = "shape.constant"() { value = [4, -1, 92] }: () -> !shape.shape + %1 = "shape.constant"() { value = [-1, 57, 92] }: () -> !shape.shape + %2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape + %3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape + return +} + +func @test_shape_any_fixed_mismatch() { + %0 = "shape.constant"() { value = [4, 57, 92] }: () -> !shape.shape + %1 = "shape.constant"() { value = [2, 57, 92] }: () -> !shape.shape + %2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape + %3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape + return +}