diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt
--- a/mlir/include/mlir/Dialect/CMakeLists.txt
+++ b/mlir/include/mlir/Dialect/CMakeLists.txt
@@ -6,5 +6,6 @@
 add_subdirectory(LoopOps)
 add_subdirectory(QuantOps)
 add_subdirectory(SPIRV)
+add_subdirectory(Shape)
 add_subdirectory(StandardOps)
 add_subdirectory(VectorOps)
diff --git a/mlir/include/mlir/Dialect/Shape/CMakeLists.txt b/mlir/include/mlir/Dialect/Shape/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Shape/CMakeLists.txt
@@ -0,0 +1 @@
+add_subdirectory(IR)
diff --git a/mlir/include/mlir/Dialect/Shape/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/Shape/IR/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Shape/IR/CMakeLists.txt
@@ -0,0 +1 @@
+add_mlir_dialect(ShapeOps ShapeOps)
diff --git a/mlir/include/mlir/Dialect/Shape/IR/Shape.h b/mlir/include/mlir/Dialect/Shape/IR/Shape.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Shape/IR/Shape.h
@@ -0,0 +1,119 @@
+//===- Shape.h - MLIR Shape dialect -----------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines the shape dialect that is used to describe and solve shape
+// relations of MLIR operations using ShapedType.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_SHAPE_IR_SHAPE_H
+#define MLIR_SHAPE_IR_SHAPE_H
+
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/FunctionSupport.h"
+#include "mlir/IR/OpDefinition.h"
+
+namespace mlir {
+namespace shape {
+
+/// This dialect contains shape inference related operations and facilities.
+class ShapeDialect : public Dialect {
+public:
+  /// Create the dialect in the given `context`.
+  explicit ShapeDialect(MLIRContext *context);
+};
+
+namespace ShapeTypes {
+enum Kind {
+  Dim = Type::FIRST_SHAPE_TYPE,
+  Shape,
+  Element,
+  Component,
+  ValueShape,
+};
+} // namespace ShapeTypes
+
+/// The type of a single dimension.
+class DimType : public Type::TypeBase<DimType, Type> {
+public:
+  using Base::Base;
+
+  static DimType get(MLIRContext *context) {
+    return Base::get(context, ShapeTypes::Kind::Dim);
+  }
+
+  /// Support method to enable LLVM-style type casting.
+  static bool kindof(unsigned kind) { return kind == ShapeTypes::Kind::Dim; }
+};
+
+/// The shape descriptor type represents rank and dimension sizes.
+class ShapeType : public Type::TypeBase<ShapeType, Type> {
+public:
+  using Base::Base;
+
+  static ShapeType get(MLIRContext *context) {
+    return Base::get(context, ShapeTypes::Kind::Shape);
+  }
+
+  /// Support method to enable LLVM-style type casting.
+  static bool kindof(unsigned kind) { return kind == ShapeTypes::Kind::Shape; }
+};
+
+/// The element type of the shaped type.
+class ElementType : public Type::TypeBase<ElementType, Type> {
+public:
+  using Base::Base;
+
+  static ElementType get(MLIRContext *context) {
+    return Base::get(context, ShapeTypes::Kind::Element);
+  }
+
+  /// Support method to enable LLVM-style type casting.
+  static bool kindof(unsigned kind) {
+    return kind == ShapeTypes::Kind::Element;
+  }
+};
+
+/// The component type corresponding to shape, element type and attribute.
+class ComponentType : public Type::TypeBase<ComponentType, Type> {
+public:
+  using Base::Base;
+
+  static ComponentType get(MLIRContext *context) {
+    return Base::get(context, ShapeTypes::Kind::Component);
+  }
+
+  /// Support method to enable LLVM-style type casting.
+  static bool kindof(unsigned kind) {
+    return kind == ShapeTypes::Kind::Component;
+  }
+};
+
+/// The ValueShape represents a (potentially unknown) runtime value and shape.
+class ValueShapeType : public Type::TypeBase<ValueShapeType, Type> {
+public:
+  using Base::Base;
+
+  static ValueShapeType get(MLIRContext *context) {
+    return Base::get(context, ShapeTypes::Kind::ValueShape);
+  }
+
+  /// Support method to enable LLVM-style type casting.
+  static bool kindof(unsigned kind) {
+    return kind == ShapeTypes::Kind::ValueShape;
+  }
+};
+
+#define GET_OP_CLASSES
+#include "mlir/Dialect/Shape/IR/ShapeOps.h.inc"
+
+} // namespace shape
+} // namespace mlir
+
+#endif // MLIR_SHAPE_IR_SHAPE_H
diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -0,0 +1,294 @@
+//===- ShapeOps.td - Shape operations definition -----------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This is the operation definition file for Shape dialect operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef SHAPE_OPS
+#define SHAPE_OPS
+
+include "mlir/IR/OpBase.td"
+
+// TODO(jpienaar): Move to base.
+def AnyShaped: ShapedContainerType<[AnyType], IsShapedTypePred, "shaped">;
+
+//===----------------------------------------------------------------------===//
+// Shape Inference dialect definitions
+//===----------------------------------------------------------------------===//
+
+def ShapeDialect : Dialect {
+  let name = "shape";
+
+  let summary = "Types and operations for shape dialect";
+  let description = [{
+    This dialect contains operations for shape inference.
+
+    Note: Unless explicitly stated, all functions that return a shape and take
+    shapes as input, return the invalid shape if one of its operands is an
+    invalid shape. This avoids flagging multiple errors for one verification
+    failure. The dialect itself does not specify how errors should be combined
+    (there are multiple different options, from always choosing first operand,
+    concatenating etc. on how to combine them).
+  }];
+
+  let cppNamespace = "shape";
+}
+
+def Shape_Dim : DialectType<ShapeDialect,
+    CPred<"$_self.isa<DimType>()">, "dim"> {
+  let typeDescription = [{
+    `shape.dim` represents a dimension type which could be either a non-negative
+    integer, unknown or error.
+
+    Operations on `shape.dim` types are specialized to handle unknown/dynamic
+    value. So, for example, `<unknown> + x == <unknown>` for all non-error `x :
+    !shape.dim` (e.g., an unknown value does not become known due to addition).
+  }];
+}
+
+def Shape_Type : DialectType<ShapeDialect,
+    CPred<"$_self.isa<ShapeType>()">, "shape"> {
+  let typeDescription = [{
+    `shape.type` represents either an unranked shape, a ranked shape with
+    possibly unknown dimensions or an invalid shape. The rank is of type
+    `shape.dim` and, if rank is known, the extent is a 1D tensor of type
+    `shape.dim`.
+
+    Shape is printed:
+    * `[*]` if it is an unranked shape
+    * `[?, 2]` if a rank 2 tensor with one unknown dimension
+    * `[3, 4]` is a rank 2 static tensor
+    * `[]` is a scalar
+    * `[1]` is a rank 1 tensor with 1 element
+    * `[invalid]` for an invalid shape
+  }];
+}
+
+def Shape_ElementType : DialectType<ShapeDialect,
+    CPred<"$_self.isa<ElementType>()">, "element type"> {
+  let typeDescription = [{
+    `shape.element_type` represents the element type of the ShapedType. It may
+    be unknown, error or regular element type supported by ShapedType.
+  }];
+}
+
+def Shape_ComponentType : DialectType<ShapeDialect,
+    CPred<"$_self.isa<ComponentType>()">, "component type"> {
+  let typeDescription = [{
+    `shape.component_type` represents the tuple of shape, element type and
+    attribute.
+  }];
+}
+
+def Shape_ValueShape : DialectType<ShapeDialect,
+    CPred<"$_self.isa<ValueShapeType>()">, "value shape"> {
+  let typeDescription = [{
+    `shape.value_shape` represents the value produced by an operation (this
+    corresponds to `Value` in the compiler) and a shape. Conceptually this is a
+    tuple of a value (potentially unknown) and `shape.type`. The value and shape
+    can either or both be unknown. If both the `value` and `shape` are known,
+    then the shape of `value` is conformant with `shape`.
+  }];
+}
+
+def Shape_DimOrType: AnyTypeOf<[Shape_Dim, Shape_Type], "dim or type">;
+
+//===----------------------------------------------------------------------===//
+// Shape op definitions
+//===----------------------------------------------------------------------===//
+
+// Base class for the operation in this dialect
+class Shape_Op<string mnemonic, list<OpTrait> traits = []> :
+    Op<ShapeDialect, mnemonic, traits>;
+
+// TODO: consider just making a constant op instead.
+def Shape_ConstantDimOp : Shape_Op<"constant_dim", []> {
+  let summary = "Creates a constant dim";
+  let description = [{
+    An operation that builds a dim from integer attribute. -1 is treated as
+    unknown, no other negative values are supported.
+
+    ```mlir
+    %x = shape.constant_dim 10 : !shape.dim
+    ```
+  }];
+
+  let arguments = (ins I32Attr:$value);
+  let results = (outs Shape_Dim:$result);
+}
+
+def Shape_ConstantShapeOp : Shape_Op<"constant_shape", []> {
+  let summary = "Create a constant shape description";
+  let description = [{
+    An operation that builds a Shape of a given rank and extents. The value is
+    treated like an array of dimensions (converted as if fed to
+    `shape.constant_dim`) with the rank of the shape equal to the number of
+    elements in the array.
+
+    Note: it is not possible to create an unranked shape using the op.
+  }];
+
+  // TODO: ArrayAttr is not specific enough.
+  let arguments = (ins ArrayAttr:$value);
+  let results = (outs Shape_Type:$output);
+}
+
+def Shape_CreateUnrankedShapeOp : Shape_Op<"unranked_shape", []> {
+  let summary = "Creates an unranked Shape descriptor";
+  let description = [{
+    An operation that builds a shape without known rank.
+  }];
+
+  let results = (outs Shape_Type:$output);
+}
+
+def Shape_CreateShapeOp : Shape_Op<"create_shape", []> {
+  let summary = "Creates a shape descriptor from a tensor";
+  let description = [{
+    Creates a shape from a 1D integral tensor. The rank equals the number of
+    elements in the tensor, and extent matches the values of the elements.
+  }];
+
+  let arguments = (ins I32Tensor:$input);
+  let results = (outs Shape_Type:$result);
+}
+
+def Shape_ShapeOfOp : Shape_Op<"shape_of", []> {
+  let summary = "Returns shape of a value or shaped type operand";
+
+  let arguments = (ins AnyTypeOf<[AnyShaped, Shape_ValueShape]>:$arg);
+  let results = (outs Shape_Type:$result);
+}
+
+def Shape_AddOp : Shape_Op<"add", [SameOperandsAndResultType]> {
+  let summary = "Adds two dimensions";
+  let description = [{
+    Adds two valid dimensions as follows:
+    * lhs + rhs = unknown if either lhs or rhs unknown;
+    * lhs + rhs = (int)lhs + (int)rhs if known;
+  }];
+
+  let arguments = (ins Shape_DimOrType:$lhs, Shape_DimOrType:$rhs);
+  let results = (outs Shape_DimOrType:$result);
+}
+
+def Shape_BroadcastableOp : Shape_Op<"broadcastable", []> {
+  let summary = "Returns the broadcastable output shape";
+  let description = [{
+    Computes the broadcastable output shape following:
+    1. If any inputs are unranked, output is unranked;
+    2. Else the input array with number of dimensions smaller than the max
+       input dimension, has 1's prepended to its shapes and the output shape is
+       calculated as follows:
+
+           output[i] = lhs[i] if lhs[i] == rhs[i] or rhs[i] is unknown/undefined
+                     = rhs[i] if lhs[i] is unknown/undefined
+                     = lhs[i] if rhs[i] == 1
+                     = rhs[i] if lhs[i] == 1
+                     = error if lhs[i] != rhs[i]
+
+    Op has an optional string attribute for the error case where there is no
+    broadcastable output possible for the given inputs.
+  }];
+
+  let arguments = (ins Shape_Type:$lhs, Shape_Type:$rhs,
+                   OptionalAttr<StrAttr>:$error);
+  let results = (outs Shape_Type:$result);
+}
+
+// Previously called any_of and I do like that name, not sure which one is
+// easier for folks.
+def Shape_JoinOp : Shape_Op<"join", []> {
+  let summary = "Returns the least general shape/dim of its operands";
+  let description = [{
+    An operation that computes the least general shape of input operands. This
+    effectively asserts that corresponding static dimensions are equal. The
+    behavior is to match each element of the `shape.type` and propagate the most
+    restrictive information, returning an invalid shape if there are
+    contradictory requirements. E.g., using pseudo code
+
+    ```
+    shape.join([*], [*]) -> [*]
+    shape.join([*], [1, ?]) -> [1, ?]
+    shape.join([1, 2], [1, ?]) -> [1, 2]
+    shape.join([*], [1, 2]) -> [1, 2]
+    shape.join([invalid], X) -> [invalid]
+    shape.join([1, U], [2, U, U]) -> [invalid]
+    ```
+
+    `shape.join` also allows specifying an optional error string, that may be
+    used to return an error to the user upon mismatch of dimensions.
+
+    ```mlir
+    %c = shape.join %a, %b, error="" : !shape.type
+    ```
+  }];
+
+  let arguments = (ins Shape_DimOrType:$arg0, Shape_DimOrType:$arg1,
+                   OptionalAttr<StrAttr>:$error);
+  let results = (outs Shape_DimOrType:$result);
+}
+
+def Shape_ReduceOp : Shape_Op<"reduce", []> {
+  let summary = "Returns an expression reduced over shape";
+  let description = [{
+    An operation that takes as input a shape, number of initial values and has a
+    region/function that is applied repeatedly for every dimension of the
+    shape.
+
+    Conceptually this op performs the following reduction:
+
+    ```
+    res[] = init;
+    for (int i = 0, e = shape.rank(); i != e; ++i) {
+      res = fn(i, shape[i], res[0], ..., res[n]);
+    }
+    ```
+
+    Where fn is provided by the user and the result of the reduce op is the
+    last computed output of the reduce function. As an example, computing the
+    number of elements
+
+    ```mlir
+    func @shape_num_elements(%shape : !shape.type) -> !shape.dim {
+      %0 = "shape.constant_dim"() {value = 0 : i32} : () -> !shape.dim
+      %1 = "shape.reduce"(%shape, %0) ( {
+        ^bb0(%index: i32, %dim: !shape.dim, %lci: !shape.dim):
+          %acc = "shape.add"(%lci, %dim) :
+            (!shape.dim, !shape.dim) -> !shape.dim
+          "shape.return"(%acc) : (!shape.dim) -> ()
+        }) : (!shape.type, !shape.dim) -> (!shape.dim)
+      return %1 : !shape.dim
+    }
+    ```
+
+    If the shape is unknown, then the results of the op is also unknown.
+  }];
+
+  let arguments = (ins Shape_Type:$shape, Variadic<AnyType>:$args);
+  let results = (outs Variadic<AnyType>:$result);
+
+  let regions = (region SizedRegion<1>:$body);
+}
+
+def Shape_PrintOp : Shape_Op<"print", []> {
+  let summary = "Prints the input dim or shape";
+  let description = [{
+    Prints the input dim or shape and passes through input.
+
+    Note: This is intended for testing and debugging only.
+  }];
+
+  let arguments = (ins Shape_DimOrType:$input);
+  let results = (outs Shape_DimOrType:$output);
+}
+
+// TODO: Add Ops: if_static, if_ranked
+
+#endif // SHAPE_OPS
diff --git a/mlir/include/mlir/IR/DialectSymbolRegistry.def b/mlir/include/mlir/IR/DialectSymbolRegistry.def
--- a/mlir/include/mlir/IR/DialectSymbolRegistry.def
+++ b/mlir/include/mlir/IR/DialectSymbolRegistry.def
@@ -24,6 +24,7 @@
 DEFINE_SYM_KIND_RANGE(TOY) // Toy language (tutorial) Dialect
 DEFINE_SYM_KIND_RANGE(SPIRV) // SPIR-V dialect
 DEFINE_SYM_KIND_RANGE(XLA_HLO) // XLA HLO dialect
+DEFINE_SYM_KIND_RANGE(SHAPE) // Shape dialect
 
 // The following ranges are reserved for experimenting with MLIR dialects in a
 // private context without having to register them here.
diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt
--- a/mlir/lib/Dialect/CMakeLists.txt
+++ b/mlir/lib/Dialect/CMakeLists.txt
@@ -7,6 +7,7 @@
 add_subdirectory(QuantOps)
 add_subdirectory(SDBM)
 add_subdirectory(SPIRV)
+add_subdirectory(Shape)
 add_subdirectory(StandardOps)
 add_subdirectory(VectorOps)
 
diff --git a/mlir/lib/Dialect/Shape/CMakeLists.txt b/mlir/lib/Dialect/Shape/CMakeLists.txt
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Shape/CMakeLists.txt
@@ -0,0 +1,9 @@
+file(GLOB globbed *.c *.cpp)
+add_llvm_library(MLIRShape
+  ${globbed}
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/Shape
+  )
+add_dependencies(MLIRShape MLIRShapeOpsIncGen LLVMSupport)
+target_link_libraries(MLIRShape LLVMSupport)
diff --git a/mlir/lib/Dialect/Shape/DialectRegistration.cpp b/mlir/lib/Dialect/Shape/DialectRegistration.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/Shape/DialectRegistration.cpp
@@ -0,0 +1,13 @@
+//===- DialectRegistration.cpp - Register shape dialect -------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Shape/IR/Shape.h"
+using namespace mlir;
+
+// Static initialization for shape dialect registration.
+static DialectRegistration<shape::ShapeDialect> ShapeOps;