diff --git a/mlir/include/mlir/Dialect/CMakeLists.txt b/mlir/include/mlir/Dialect/CMakeLists.txt --- a/mlir/include/mlir/Dialect/CMakeLists.txt +++ b/mlir/include/mlir/Dialect/CMakeLists.txt @@ -7,6 +7,7 @@ add_subdirectory(GPU) add_subdirectory(Linalg) add_subdirectory(LLVMIR) +add_subdirectory(MemRef) add_subdirectory(OpenACC) add_subdirectory(OpenMP) add_subdirectory(PDL) diff --git a/mlir/include/mlir/Dialect/MemRef/CMakeLists.txt b/mlir/include/mlir/Dialect/MemRef/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/MemRef/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) \ No newline at end of file diff --git a/mlir/include/mlir/Dialect/MemRef/IR/CMakeLists.txt b/mlir/include/mlir/Dialect/MemRef/IR/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/MemRef/IR/CMakeLists.txt @@ -0,0 +1,2 @@ +add_mlir_dialect(MemRefOps memref) +add_mlir_doc(MemRefOps -gen-dialect-doc MemRefOps Dialects/) \ No newline at end of file diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRef.h @@ -0,0 +1,82 @@ +//===- MemRef.h - MemRef dialect --------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_DIALECT_MEMREF_IR_MEMREF_H_ +#define MLIR_DIALECT_MEMREF_IR_MEMREF_H_ + +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Dialect.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/Interfaces/CallInterfaces.h" +#include "mlir/Interfaces/SideEffectInterfaces.h" +#include "mlir/Interfaces/ViewLikeInterface.h" + +//===----------------------------------------------------------------------===// +// MemRef Dialect +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/MemRef/IR/MemRefOpsDialect.h.inc" + +//===----------------------------------------------------------------------===// +// MemRef Dialect Operations +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "mlir/Dialect/MemRef/IR/MemRefOps.h.inc" + +//===----------------------------------------------------------------------===// +// MemRef Dialect Helpers +//===----------------------------------------------------------------------===// + +namespace mlir { +namespace memref { + +/// Determines whether MemRef_CastOp casts to a more dynamic version of the +/// source memref. This is useful to to fold a memref.cast into a consuming op +/// and implement canonicalization patterns for ops in different dialects that +/// may consume the results of memref.cast operations. Such foldable memref.cast +/// operations are typically inserted as `view` and `subview` ops and are +/// canonicalized, to preserve the type compatibility of their uses. +/// +/// Returns true when all conditions are met: +/// 1. source and result are ranked memrefs with strided semantics and same +/// element type and rank. +/// 2. each of the source's size, offset or stride has more static information +/// than the corresponding result's size, offset or stride. 
+/// +/// Example 1: +/// ```mlir +/// %1 = memref.cast %0 : memref<8x16xf32> to memref +/// %2 = consumer %1 ... : memref ... +/// ``` +/// +/// may fold into: +/// +/// ```mlir +/// %2 = consumer %0 ... : memref<8x16xf32> ... +/// ``` +/// +/// Example 2: +/// ``` +/// %1 = memref.cast %0 : memref(16 * i + j)>> +/// to memref +/// consumer %1 : memref ... +/// ``` +/// +/// may fold into: +/// +/// ``` +/// consumer %0 ... : memref(16 * i + j)>> +/// ``` +bool canFoldIntoConsumerOp(CastOp castOp); + +} // namespace memref +} // namespace mlir + +#endif // MLIR_DIALECT_MEMREF_IR_MEMREF_H_ \ No newline at end of file diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefBase.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefBase.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefBase.td @@ -0,0 +1,26 @@ +//===- MemRefBase.td - Base definitions for memref dialect -*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MEMREF_BASE +#define MEMREF_BASE + +include "mlir/IR/OpBase.td" + +def MemRef_Dialect : Dialect { + let name = "memref"; + let cppNamespace = "::mlir::memref"; + let description = [{ + + The `memref` dialect is intended to hold core memref creation and + manipulation ops, which are not strongly associated with any particular + other dialect or domain abstraction. + + }]; +} + +#endif // MEMREF_BASE diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td new file mode 100644 --- /dev/null +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -0,0 +1,655 @@ +//===- MemRefOps.td - MemRef op definitions ----------------*- tablegen -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MEMREF_OPS +#define MEMREF_OPS + +include "mlir/Dialect/MemRef/IR/MemRefBase.td" +include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/ViewLikeInterface.td" +include "mlir/IR/SymbolInterfaces.td" + +class MemRef_Op traits = []> + : Op { + let printer = [{ return ::print(p, *this); }]; + let verifier = [{ return ::verify(*this); }]; + let parser = [{ return ::parse$cppClass(parser, result); }]; +} + +//===----------------------------------------------------------------------===// +// AllocLikeOp +//===----------------------------------------------------------------------===// + +// Base class for memref allocating ops: alloca and alloc. +// +// %0 = alloclike(%m)[%s] : memref<8x?xf32, (d0, d1)[s0] -> ((d0 + s0), d1)> +// +class AllocLikeOp traits = []> : + MemRef_Op { + + let arguments = (ins Variadic:$dynamicSizes, + // The symbolic operands (the ones in square brackets) bind + // to the symbols of the memref's layout map. 
+ Variadic:$symbolOperands, + Confined, [IntMinValue<0>]>:$alignment); + let results = (outs Res]>:$memref); + + let builders = [ + OpBuilderDAG<(ins "MemRefType":$memrefType, + CArg<"IntegerAttr", "IntegerAttr()">:$alignment), [{ + return build($_builder, $_state, memrefType, {}, alignment); + }]>, + OpBuilderDAG<(ins "MemRefType":$memrefType, "ValueRange":$dynamicSizes, + CArg<"IntegerAttr", "IntegerAttr()">:$alignment), [{ + return build($_builder, $_state, memrefType, dynamicSizes, {}, alignment); + }]>, + OpBuilderDAG<(ins "MemRefType":$memrefType, "ValueRange":$dynamicSizes, + "ValueRange":$symbolOperands, + CArg<"IntegerAttr", "{}">:$alignment), [{ + $_state.types.push_back(memrefType); + $_state.addOperands(dynamicSizes); + $_state.addOperands(symbolOperands); + $_state.addAttribute(getOperandSegmentSizeAttr(), + $_builder.getI32VectorAttr({ + static_cast(dynamicSizes.size()), + static_cast(symbolOperands.size())})); + if (alignment) + $_state.addAttribute(getAlignmentAttrName(), alignment); + }]>]; + + let extraClassDeclaration = [{ + static StringRef getAlignmentAttrName() { return "alignment"; } + + MemRefType getType() { return getResult().getType().cast(); } + + /// Returns the dynamic sizes for this alloc operation if specified. + operand_range getDynamicSizes() { return dynamicSizes(); } + }]; + + let assemblyFormat = [{ + `(`$dynamicSizes`)` (`` `[` $symbolOperands^ `]`)? attr-dict `:` type($memref) + }]; + + let hasCanonicalizer = 1; +} + +//===----------------------------------------------------------------------===// +// AllocOp +//===----------------------------------------------------------------------===// + +def MemRef_AllocOp : AllocLikeOp<"alloc", DefaultResource> { + let summary = "memory allocation operation"; + let description = [{ + The `alloc` operation allocates a region of memory, as specified by its + memref type. + + Example: + + ```mlir + %0 = memref.alloc() : memref<8x64xf32, 1> + ``` + + The optional list of dimension operands are bound to the dynamic dimensions + specified in its memref type. In the example below, the ssa value '%d' is + bound to the second dimension of the memref (which is dynamic). + + ```mlir + %0 = memref.alloc(%d) : memref<8x?xf32, 1> + ``` + + The optional list of symbol operands are bound to the symbols of the + memrefs affine map. In the example below, the ssa value '%s' is bound to + the symbol 's0' in the affine map specified in the allocs memref type. + + ```mlir + %0 = memref.alloc()[%s] : memref<8x64xf32, + affine_map<(d0, d1)[s0] -> ((d0 + s0), d1)>, 1> + ``` + + This operation returns a single ssa value of memref type, which can be used + by subsequent load and store operations. + + The optional `alignment` attribute may be specified to ensure that the + region of memory that will be indexed is aligned at the specified byte + boundary. 
+ + ```mlir + %0 = memref.alloc()[%s] {alignment = 8} : + memref<8x64xf32, affine_map<(d0, d1)[s0] -> ((d0 + s0), d1)>, 1> + ``` + }]; +} + +//===----------------------------------------------------------------------===// +// AllocaOp +//===----------------------------------------------------------------------===// + +def MemRef_AllocaOp : AllocLikeOp<"alloca", AutomaticAllocationScopeResource> { + let summary = "stack memory allocation operation"; + let description = [{ + The `alloca` operation allocates memory on the stack, to be automatically + released when control transfers back from the region of its closest + surrounding operation with an + [`AutomaticAllocationScope`](../Traits.md#automaticallocationscope) trait. + The amount of memory allocated is specified by its memref and additional + operands. For example: + + ```mlir + %0 = memref.alloca() : memref<8x64xf32> + ``` + + The optional list of dimension operands are bound to the dynamic dimensions + specified in its memref type. In the example below, the SSA value '%d' is + bound to the second dimension of the memref (which is dynamic). + + ```mlir + %0 = memref.alloca(%d) : memref<8x?xf32> + ``` + + The optional list of symbol operands are bound to the symbols of the + memref's affine map. In the example below, the SSA value '%s' is bound to + the symbol 's0' in the affine map specified in the allocs memref type. + + ```mlir + %0 = memref.alloca()[%s] : memref<8x64xf32, + affine_map<(d0, d1)[s0] -> ((d0 + s0), d1)>> + ``` + + This operation returns a single SSA value of memref type, which can be used + by subsequent load and store operations. An optional alignment attribute, if + specified, guarantees alignment at least to that boundary. If not specified, + an alignment on any convenient boundary compatible with the type will be + chosen. + }]; +} + +//===----------------------------------------------------------------------===// +// CastOp +//===----------------------------------------------------------------------===// + +def MemRef_CastOp : MemRef_Op<"cast", [NoSideEffect]> { + let summary = "memref cast operation"; + let description = [{ + Syntax: + + ``` + operation ::= ssa-id `=` `memref.cast` ssa-use `:` type `to` type + ``` + + The `memref.cast` operation converts a memref from one type to an equivalent + type with a compatible shape. The source and destination types are + compatible if: + + a. Both are ranked memref types with the same element type, address space, + and rank and: + 1. Both have the same layout or both have compatible strided layouts. + 2. The individual sizes (resp. offset and strides in the case of strided + memrefs) may convert constant dimensions to dynamic dimensions and + vice-versa. + + If the cast converts any dimensions from an unknown to a known size, then it + acts as an assertion that fails at runtime if the dynamic dimensions + disagree with resultant destination size. + + Example: + + ```mlir + // Assert that the input dynamic shape matches the destination static shape. + %2 = memref.cast %1 : memref to memref<4x4xf32> + // Erase static shape information, replacing it with dynamic information. + %3 = memref.cast %1 : memref<4xf32> to memref + + // The same holds true for offsets and strides. + + // Assert that the input dynamic shape matches the destination static stride. + %4 = memref.cast %1 : memref<12x4xf32, offset:?, strides: [?, ?]> to + memref<12x4xf32, offset:5, strides: [4, 1]> + // Erase static offset and stride information, replacing it with + // dynamic information. 
+ %5 = memref.cast %1 : memref<12x4xf32, offset:5, strides: [4, 1]> to + memref<12x4xf32, offset:?, strides: [?, ?]> + ``` + + b. Either or both memref types are unranked with the same element type, and + address space. + + Example: + + ```mlir + Cast to concrete shape. + %4 = memref.cast %1 : memref<*xf32> to memref<4x?xf32> + + Erase rank information. + %5 = memref.cast %1 : memref<4x?xf32> to memref<*xf32> + ``` + }]; + + let arguments = (ins AnyRankedOrUnrankedMemRef:$source); + let results = (outs AnyRankedOrUnrankedMemRef:$dest); + let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)"; + let verifier = "return impl::verifyCastOp(*this, areCastCompatible);"; + let builders = [ + OpBuilderDAG<(ins "Value":$source, "Type":$destType), [{ + impl::buildCastOp($_builder, $_state, source, destType); + }]> + ]; + + let extraClassDeclaration = [{ + /// Return true if `a` and `b` are valid operand and result pairs for + /// the operation. + static bool areCastCompatible(Type a, Type b); + + /// The result of a memref_cast is always a memref. + Type getType() { return getResult().getType(); } + + }]; + + let hasFolder = 1; +} + +//===----------------------------------------------------------------------===// +// DeallocOp +//===----------------------------------------------------------------------===// + +def MemRef_DeallocOp : MemRef_Op<"dealloc", [MemRefsNormalizable]> { + let summary = "memory deallocation operation"; + let description = [{ + The `dealloc` operation frees the region of memory referenced by a memref + which was originally created by the `alloc` operation. + The `dealloc` operation should not be called on memrefs which alias an + alloc'd memref (e.g. memrefs returned by `view` operations). + + Example: + + ```mlir + %0 = memref.alloc() : memref<8x64xf32, (d0, d1) -> (d0, d1), 1> + memref.dealloc %0 : memref<8x64xf32, (d0, d1) -> (d0, d1), 1> + ``` + }]; + + let arguments = (ins Arg:$memref); + + let hasCanonicalizer = 1; + let hasFolder = 1; + let assemblyFormat = "$memref attr-dict `:` type($memref)"; +} + +//===----------------------------------------------------------------------===// +// GetGlobalOp +//===----------------------------------------------------------------------===// + +def MemRef_GetGlobalOp : MemRef_Op<"get_global", + [NoSideEffect, DeclareOpInterfaceMethods]> { + let summary = "get the memref pointing to a global variable"; + let description = [{ + The `memref.get_global` operation retrieves the memref pointing to a + named global variable. If the global variable is marked constant, writing + to the result memref (such as through a `memref.store` operation) is + undefined. + + Example: + + ```mlir + %x = memref.get_global @foo : memref<2xf32> + ``` + }]; + + let arguments = (ins FlatSymbolRefAttr:$name); + let results = (outs AnyStaticShapeMemRef:$result); + let assemblyFormat = "$name `:` type($result) attr-dict"; + + // `GetGlobalOp` is fully verified by its traits. + let verifier = ?; +} + +//===----------------------------------------------------------------------===// +// GlobalOp +//===----------------------------------------------------------------------===// + +def MemRef_GlobalOp : MemRef_Op<"global", [Symbol]> { + let summary = "declare or define a global memref variable"; + let description = [{ + The `memref.global` operation declares or defines a named global variable. 
+ The backing memory for the variable is allocated statically and is described + by the type of the variable (which should be a statically shaped memref + type). The operation is a declaration if no `inital_value` is specified, + else it is a definition. The `initial_value` can either be a unit attribute + to represent a definition of an uninitialized global variable, or an + elements attribute to represent the definition of a global variable with an + initial value. The global variable can also be marked constant using the + `constant` unit attribute. Writing to such constant global variables is + undefined. + + The global variable can be accessed by using the `memref.get_global` to + retrieve the memref for the global variable. Note that the memref + for such global variable itself is immutable (i.e., memref.get_global for a + given global variable will always return the same memref descriptor). + + Example: + + ```mlir + // Private variable with an initial value. + memref.global "private" @x : memref<2xf32> = dense<0.0,2.0> + + // Declaration of an external variable. + memref.global "private" @y : memref<4xi32> + + // Uninitialized externally visible variable. + memref.global @z : memref<3xf16> = uninitialized + + // Externally visible constant variable. + memref.global constant @c : memref<2xi32> = dense<1, 4> + ``` + }]; + + let arguments = (ins + SymbolNameAttr:$sym_name, + OptionalAttr:$sym_visibility, + TypeAttr:$type, + OptionalAttr:$initial_value, + UnitAttr:$constant + ); + + let assemblyFormat = [{ + ($sym_visibility^)? + (`constant` $constant^)? + $sym_name `:` + custom($type, $initial_value) + attr-dict + }]; + + let extraClassDeclaration = [{ + bool isExternal() { return !initial_value(); } + bool isUninitialized() { + return !isExternal() && initial_value().getValue().isa(); + } + }]; +} + +//===----------------------------------------------------------------------===// +// PrefetchOp +//===----------------------------------------------------------------------===// + +def MemRef_PrefetchOp : MemRef_Op<"prefetch"> { + let summary = "prefetch operation"; + let description = [{ + The "prefetch" op prefetches data from a memref location described with + subscript indices similar to std.load, and with three attributes: a + read/write specifier, a locality hint, and a cache type specifier as shown + below: + + ```mlir + memref.prefetch %0[%i, %j], read, locality<3>, data : memref<400x400xi32> + ``` + + The read/write specifier is either 'read' or 'write', the locality hint + ranges from locality<0> (no locality) to locality<3> (extremely local keep + in cache). The cache type specifier is either 'data' or 'instr' + and specifies whether the prefetch is performed on data cache or on + instruction cache. 
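+
+    For instance, a prefetch ahead of a store with no expected reuse might be
+    written as:
+
+    ```mlir
+    memref.prefetch %0[%i, %j], write, locality<0>, data : memref<400x400xi32>
+    ```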
+ }]; + + let arguments = (ins AnyMemRef:$memref, Variadic:$indices, + BoolAttr:$isWrite, + Confined, + IntMaxValue<3>]>:$localityHint, + BoolAttr:$isDataCache); + + let extraClassDeclaration = [{ + MemRefType getMemRefType() { + return memref().getType().cast(); + } + static StringRef getLocalityHintAttrName() { return "localityHint"; } + static StringRef getIsWriteAttrName() { return "isWrite"; } + static StringRef getIsDataCacheAttrName() { return "isDataCache"; } + }]; + + let hasFolder = 1; +} + +//===----------------------------------------------------------------------===// +// ReshapeOp +//===----------------------------------------------------------------------===// + +def MemRef_ReshapeOp: MemRef_Op<"reshape", [ + ViewLikeOpInterface, NoSideEffect]> { + let summary = "memref reshape operation"; + let description = [{ + The `reshape` operation converts a memref from one type to an + equivalent type with a provided shape. The data is never copied or + modified. The source and destination types are compatible if both have the + same element type, same number of elements, address space and identity + layout map. The following combinations are possible: + + a. Source type is ranked or unranked. Shape argument has static size. + Result type is ranked. + + ```mlir + // Reshape statically-shaped memref. + %dst = memref.reshape %src(%shape) + : (memref<4x1xf32>, memref<1xi32>) to memref<4xf32> + %dst0 = memref.reshape %src(%shape0) + : (memref<4x1xf32>, memref<2xi32>) to memref<2x2xf32> + // Flatten unranked memref. + %dst = memref.reshape %src(%shape) + : (memref<*xf32>, memref<1xi32>) to memref + ``` + + b. Source type is ranked or unranked. Shape argument has dynamic size. + Result type is unranked. + + ```mlir + // Reshape dynamically-shaped 1D memref. + %dst = memref.reshape %src(%shape) + : (memref, memref) to memref<*xf32> + // Reshape unranked memref. + %dst = memref.reshape %src(%shape) + : (memref<*xf32>, memref) to memref<*xf32> + ``` + }]; + + let arguments = (ins + AnyRankedOrUnrankedMemRef:$source, + MemRefRankOf<[AnySignlessInteger, Index], [1]>:$shape + ); + let results = (outs AnyRankedOrUnrankedMemRef:$result); + + let builders = [OpBuilderDAG< + (ins "MemRefType":$resultType, "Value":$operand, "Value":$shape), [{ + $_state.addOperands(operand); + $_state.addOperands(shape); + $_state.addTypes(resultType); + }]>]; + + let extraClassDeclaration = [{ + MemRefType getType() { return getResult().getType().cast(); } + Value getViewSource() { return source(); } + }]; + + let assemblyFormat = [{ + $source `(` $shape `)` attr-dict `:` functional-type(operands, results) + }]; +} + +//===----------------------------------------------------------------------===// +// StoreOp +//===----------------------------------------------------------------------===// + +def MemRef_StoreOp : MemRef_Op<"store", + [TypesMatchWith<"type of 'value' matches element type of 'memref'", + "memref", "value", + "$_self.cast().getElementType()">, + MemRefsNormalizable]> { + let summary = "store operation"; + let description = [{ + Store a value to a memref location given by indices. The value stored should + have the same type as the elemental type of the memref. The number of + arguments provided within brackets need to match the rank of the memref. 
+ + In an affine context, the indices of a store are restricted to SSA values + bound to surrounding loop induction variables, + [symbols](Affine.md#restrictions-on-dimensions-and-symbols), results of a + [`constant` operation](#stdconstant-constantop), or the result of an + [`affine.apply`](Affine.md#affineapply-affineapplyop) operation that can in + turn take as arguments all of the aforementioned SSA values or the + recursively result of such an `affine.apply` operation. + + Example: + + ```mlir + memref.store %100, %A[%1, 1023] : memref<4x?xf32, #layout, memspace0> + ``` + + **Context:** The `load` and `store` operations are specifically crafted to + fully resolve a reference to an element of a memref, and (in polyhedral + `affine.if` and `affine.for` operations) the compiler can follow use-def + chains (e.g. through [`affine.apply`](Affine.md#affineapply-affineapplyop) + operations) to precisely analyze references at compile-time using polyhedral + techniques. This is possible because of the + [restrictions on dimensions and symbols](Affine.md#restrictions-on-dimensions-and-symbols) + in these contexts. + }]; + + let arguments = (ins AnyType:$value, + Arg:$memref, + Variadic:$indices); + + let builders = [ + OpBuilderDAG<(ins "Value":$valueToStore, "Value":$memref), [{ + $_state.addOperands(valueToStore); + $_state.addOperands(memref); + }]>]; + + let extraClassDeclaration = [{ + Value getValueToStore() { return getOperand(0); } + + Value getMemRef() { return getOperand(1); } + void setMemRef(Value value) { setOperand(1, value); } + MemRefType getMemRefType() { + return getMemRef().getType().cast(); + } + + operand_range getIndices() { + return {operand_begin() + 2, operand_end()}; + } + }]; + + let hasFolder = 1; + + let assemblyFormat = [{ + $value `,` $memref `[` $indices `]` attr-dict `:` type($memref) + }]; +} + +//===----------------------------------------------------------------------===// +// TransposeOp +//===----------------------------------------------------------------------===// + +def MemRef_TransposeOp : MemRef_Op<"transpose", [NoSideEffect]>, + Arguments<(ins AnyStridedMemRef:$in, AffineMapAttr:$permutation)>, + Results<(outs AnyStridedMemRef)> { + let summary = "`transpose` produces a new strided memref (metadata-only)"; + let description = [{ + The `transpose` op produces a strided memref whose sizes and strides + are a permutation of the original `in` memref. This is purely a metadata + transformation. + + Example: + + ```mlir + %1 = memref.transpose %0 (i, j) -> (j, i) : memref to memref (d1 * s0 + d0)>> + ``` + }]; + + let builders = [ + OpBuilderDAG<(ins "Value":$in, "AffineMapAttr":$permutation, + CArg<"ArrayRef", "{}">:$attrs)>]; + + let extraClassDeclaration = [{ + static StringRef getPermutationAttrName() { return "permutation"; } + ShapedType getShapedType() { return in().getType().cast(); } + }]; + + let hasFolder = 1; +} + +//===----------------------------------------------------------------------===// +// ViewOp +//===----------------------------------------------------------------------===// + +def MemRef_ViewOp : MemRef_Op<"view", [ + DeclareOpInterfaceMethods, NoSideEffect]> { + let summary = "memref view operation"; + let description = [{ + The "view" operation extracts an N-D contiguous memref with empty layout map + with arbitrary element type from a 1-D contiguous memref with empty layout + map of i8 element type. 
The ViewOp supports the following arguments: + + * A single dynamic byte-shift operand must be specified which represents a + a shift of the base 1-D memref pointer from which to create the resulting + contiguous memref view with identity layout. + * A dynamic size operand that must be specified for each dynamic dimension + in the resulting view memref type. + + The "view" operation gives a structured indexing form to a flat 1-D buffer. + Unlike "subview" it can perform a type change. The type change behavior + requires the op to have special semantics because, e.g. a byte shift of 3 + cannot be represented as an offset on f64. + For now, a "view" op: + + 1. Only takes a contiguous source memref with 0 offset and empty layout. + 2. Must specify a byte_shift operand (in the future, a special integer + attribute may be added to support the folded case). + 3. Returns a contiguous memref with 0 offset and empty layout. + + Example: + + ```mlir + // Allocate a flat 1D/i8 memref. + %0 = memref.alloc() : memref<2048xi8> + + // ViewOp with dynamic offset and static sizes. + %1 = memref.view %0[%offset_1024][] : memref<2048xi8> to memref<64x4xf32> + + // ViewOp with dynamic offset and two dynamic size. + %2 = memref.view %0[%offset_1024][%size0, %size1] : + memref<2048xi8> to memref + ``` + }]; + + let arguments = (ins MemRefRankOf<[I8], [1]>:$source, + Index:$byte_shift, + Variadic:$sizes); + let results = (outs AnyMemRef); + + let extraClassDeclaration = [{ + /// The result of a view is always a memref. + MemRefType getType() { return getResult().getType().cast(); } + + /// Returns the dynamic sizes for this view operation. This is redundant + /// with `sizes` but needed in template implementations. More specifically: + /// ``` + /// template + /// bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index, + /// Region *region) + /// ``` + operand_range getDynamicSizes() { + return {sizes().begin(), sizes().end()}; + } + }]; + + let hasCanonicalizer = 1; +} + +#endif // MEMREF_OPS diff --git a/mlir/include/mlir/InitAllDialects.h b/mlir/include/mlir/InitAllDialects.h --- a/mlir/include/mlir/InitAllDialects.h +++ b/mlir/include/mlir/InitAllDialects.h @@ -28,6 +28,7 @@ #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "mlir/Dialect/Linalg/IR/LinalgOps.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/OpenACC/OpenACC.h" #include "mlir/Dialect/OpenMP/OpenMPDialect.h" #include "mlir/Dialect/PDL/IR/PDL.h" @@ -60,6 +61,7 @@ LLVM::LLVMArmNeonDialect, LLVM::LLVMArmSVEDialect, linalg::LinalgDialect, + memref::MemRefDialect, scf::SCFDialect, omp::OpenMPDialect, pdl::PDLDialect, diff --git a/mlir/lib/Dialect/CMakeLists.txt b/mlir/lib/Dialect/CMakeLists.txt --- a/mlir/lib/Dialect/CMakeLists.txt +++ b/mlir/lib/Dialect/CMakeLists.txt @@ -7,6 +7,7 @@ add_subdirectory(GPU) add_subdirectory(Linalg) add_subdirectory(LLVMIR) +add_subdirectory(MemRef) add_subdirectory(OpenACC) add_subdirectory(OpenMP) add_subdirectory(PDL) diff --git a/mlir/lib/Dialect/MemRef/CMakeLists.txt b/mlir/lib/Dialect/MemRef/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/MemRef/CMakeLists.txt @@ -0,0 +1 @@ +add_subdirectory(IR) \ No newline at end of file diff --git a/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt new file mode 100644 --- /dev/null +++ b/mlir/lib/Dialect/MemRef/IR/CMakeLists.txt @@ -0,0 +1,17 @@ +add_mlir_dialect_library(MLIRMemRef + MemRefDialect.cpp + MemRefOps.cpp + + 
ADDITIONAL_HEADER_DIRS
+  ${PROJECT_SOURCE_DIR}/include/mlir/Dialect/MemRef/IR
+
+  DEPENDS
+  MLIRMemRefOpsIncGen
+
+  LINK_COMPONENTS
+  Core
+
+  LINK_LIBS PUBLIC
+  MLIRIR
+  MLIRSupport
+)
\ No newline at end of file
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefDialect.cpp
@@ -0,0 +1,39 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Transforms/InliningUtils.h"
+
+using namespace mlir;
+using namespace mlir::memref;
+
+//===----------------------------------------------------------------------===//
+// MemRefDialect Dialect Interfaces
+//===----------------------------------------------------------------------===//
+
+namespace {
+struct MemRefInlinerInterface : public DialectInlinerInterface {
+  using DialectInlinerInterface::DialectInlinerInterface;
+  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
+                       BlockAndValueMapping &valueMapping) const final {
+    return true;
+  }
+  bool isLegalToInline(Operation *, Region *, bool wouldBeCloned,
+                       BlockAndValueMapping &) const final {
+    return true;
+  }
+};
+} // end anonymous namespace
+
+void MemRefDialect::initialize() {
+  addOperations<
+#define GET_OP_LIST
+#include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"
+      >();
+  addInterfaces<MemRefInlinerInterface>();
+}
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -0,0 +1,884 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Matchers.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/TypeUtilities.h"
+#include "llvm/ADT/STLExtras.h"
+
+using namespace mlir;
+using namespace mlir::memref;
+
+/// Matches a ConstantIndexOp.
+/// TODO: This should probably just be a general matcher that uses m_Constant
+/// and checks the operation for an index type.
+static detail::op_matcher<ConstantIndexOp> m_ConstantIndex() {
+  return detail::op_matcher<ConstantIndexOp>();
+}
+
+//===----------------------------------------------------------------------===//
+// Common canonicalization pattern support logic
+//===----------------------------------------------------------------------===//
+
+/// This is a common class used for patterns of the form
+/// "someop(memrefcast) -> someop". It folds the source of any memref_cast
+/// into the root operation directly.
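+///
+/// For example, combined with the `dealloc(memrefcast) -> dealloc` folder
+/// further below,
+///   %1 = memref.cast %0 : memref<8xf32> to memref<?xf32>
+///   memref.dealloc %1 : memref<?xf32>
+/// folds to `memref.dealloc %0 : memref<8xf32>`.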
+static LogicalResult foldMemRefCast(Operation *op) { + bool folded = false; + for (OpOperand &operand : op->getOpOperands()) { + auto cast = operand.get().getDefiningOp(); + if (cast && !cast.getOperand().getType().isa()) { + operand.set(cast.getOperand()); + folded = true; + } + } + return success(folded); +} + +//===----------------------------------------------------------------------===// +// Helpers for Tensor[Load|Store]Op, TensorToMemrefOp, and GlobalMemrefOp +//===----------------------------------------------------------------------===// + +static Type getTensorTypeFromMemRefType(Type type) { + if (auto memref = type.dyn_cast()) + return RankedTensorType::get(memref.getShape(), memref.getElementType()); + if (auto memref = type.dyn_cast()) + return UnrankedTensorType::get(memref.getElementType()); + return NoneType::get(type.getContext()); +} + +//===----------------------------------------------------------------------===// +// AllocOp / AllocaOp +//===----------------------------------------------------------------------===// + +template +static LogicalResult verifyAllocLikeOp(AllocLikeOp op) { + static_assert( + llvm::is_one_of::value, + "applies to only alloc or alloca"); + auto memRefType = op.getResult().getType().template dyn_cast(); + if (!memRefType) + return op.emitOpError("result must be a memref"); + + if (static_cast(op.dynamicSizes().size()) != + memRefType.getNumDynamicDims()) + return op.emitOpError("dimension operand count does not equal memref " + "dynamic dimension count"); + + unsigned numSymbols = 0; + if (!memRefType.getAffineMaps().empty()) + numSymbols = memRefType.getAffineMaps().front().getNumSymbols(); + if (op.symbolOperands().size() != numSymbols) + return op.emitOpError( + "symbol operand count does not equal memref symbol count"); + + return success(); +} + +static LogicalResult verify(memref::AllocOp op) { + return verifyAllocLikeOp(op); +} + +static LogicalResult verify(memref::AllocaOp op) { + // An alloca op needs to have an ancestor with an allocation scope trait. + if (!op->getParentWithTrait()) + return op.emitOpError( + "requires an ancestor op with AutomaticAllocationScope trait"); + + return verifyAllocLikeOp(op); +} + +namespace { +/// Fold constant dimensions into an alloc like operation. +template +struct SimplifyAllocConst : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(AllocLikeOp alloc, + PatternRewriter &rewriter) const override { + // Check to see if any dimensions operands are constants. If so, we can + // substitute and drop them. + if (llvm::none_of(alloc.getOperands(), [](Value operand) { + return matchPattern(operand, m_ConstantIndex()); + })) + return failure(); + + auto memrefType = alloc.getType(); + + // Ok, we have one or more constant operands. Collect the non-constant ones + // and keep track of the resultant memref type to build. + SmallVector newShapeConstants; + newShapeConstants.reserve(memrefType.getRank()); + SmallVector newOperands; + + unsigned dynamicDimPos = 0; + for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) { + int64_t dimSize = memrefType.getDimSize(dim); + // If this is already static dimension, keep it. + if (dimSize != -1) { + newShapeConstants.push_back(dimSize); + continue; + } + auto *defOp = alloc.getOperand(dynamicDimPos).getDefiningOp(); + if (auto constantIndexOp = dyn_cast_or_null(defOp)) { + // Dynamic shape dimension will be folded. 
+ newShapeConstants.push_back(constantIndexOp.getValue()); + } else { + // Dynamic shape dimension not folded; copy operand from old memref. + newShapeConstants.push_back(-1); + newOperands.push_back(alloc.getOperand(dynamicDimPos)); + } + dynamicDimPos++; + } + + // Create new memref type (which will have fewer dynamic dimensions). + MemRefType newMemRefType = + MemRefType::Builder(memrefType).setShape(newShapeConstants); + assert(static_cast(newOperands.size()) == + newMemRefType.getNumDynamicDims()); + + // Create and insert the alloc op for the new memref. + auto newAlloc = rewriter.create(alloc.getLoc(), newMemRefType, + newOperands, IntegerAttr()); + // Insert a cast so we have the same type as the old alloc. + auto resultCast = rewriter.create(alloc.getLoc(), newAlloc, + alloc.getType()); + + rewriter.replaceOp(alloc, {resultCast}); + return success(); + } +}; + +/// Fold alloc operations with no uses. Alloc has side effects on the heap, +/// but can still be deleted if it has zero uses. +struct SimplifyDeadAlloc : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(memref::AllocOp alloc, + PatternRewriter &rewriter) const override { + if (alloc.use_empty()) { + rewriter.eraseOp(alloc); + return success(); + } + return failure(); + } +}; +} // end anonymous namespace. + +void memref::AllocOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert, SimplifyDeadAlloc>( + context); +} + +void memref::AllocaOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert>(context); +} + +//===----------------------------------------------------------------------===// +// CastOp +//===----------------------------------------------------------------------===// + +/// Determines whether MemRef_CastOp casts to a more dynamic version of the +/// source memref. This is useful to to fold a memref.cast into a consuming op +/// and implement canonicalization patterns for ops in different dialects that +/// may consume the results of memref.cast operations. Such foldable memref.cast +/// operations are typically inserted as `view` and `subview` ops are +/// canonicalized, to preserve the type compatibility of their uses. +/// +/// Returns true when all conditions are met: +/// 1. source and result are ranked memrefs with strided semantics and same +/// element type and rank. +/// 2. each of the source's size, offset or stride has more static information +/// than the corresponding result's size, offset or stride. +/// +/// Example 1: +/// ```mlir +/// %1 = memref.cast %0 : memref<8x16xf32> to memref +/// %2 = consumer %1 ... : memref ... +/// ``` +/// +/// may fold into: +/// +/// ```mlir +/// %2 = consumer %0 ... : memref<8x16xf32> ... +/// ``` +/// +/// Example 2: +/// ``` +/// %1 = memref.cast %0 : memref(16 * i + j)>> +/// to memref +/// consumer %1 : memref ... +/// ``` +/// +/// may fold into: +/// +/// ``` +/// consumer %0 ... : memref(16 * i + j)>> +/// ``` +bool mlir::memref::canFoldIntoConsumerOp(CastOp castOp) { + MemRefType sourceType = castOp.source().getType().dyn_cast(); + MemRefType resultType = castOp.getType().dyn_cast(); + + // Requires ranked MemRefType. + if (!sourceType || !resultType) + return false; + + // Requires same elemental type. + if (sourceType.getElementType() != resultType.getElementType()) + return false; + + // Requires same rank. 
+ if (sourceType.getRank() != resultType.getRank()) + return false; + + // Only fold casts between strided memref forms. + int64_t sourceOffset, resultOffset; + SmallVector sourceStrides, resultStrides; + if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) || + failed(getStridesAndOffset(resultType, resultStrides, resultOffset))) + return false; + + // If cast is towards more static sizes along any dimension, don't fold. + for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) { + auto ss = std::get<0>(it), st = std::get<1>(it); + if (ss != st) + if (MemRefType::isDynamic(ss) && !MemRefType::isDynamic(st)) + return false; + } + + // If cast is towards more static offset along any dimension, don't fold. + if (sourceOffset != resultOffset) + if (MemRefType::isDynamicStrideOrOffset(sourceOffset) && + !MemRefType::isDynamicStrideOrOffset(resultOffset)) + return false; + + // If cast is towards more static strides along any dimension, don't fold. + for (auto it : llvm::zip(sourceStrides, resultStrides)) { + auto ss = std::get<0>(it), st = std::get<1>(it); + if (ss != st) + if (MemRefType::isDynamicStrideOrOffset(ss) && + !MemRefType::isDynamicStrideOrOffset(st)) + return false; + } + + return true; +} + +bool CastOp::areCastCompatible(Type a, Type b) { + auto aT = a.dyn_cast(); + auto bT = b.dyn_cast(); + + auto uaT = a.dyn_cast(); + auto ubT = b.dyn_cast(); + + if (aT && bT) { + if (aT.getElementType() != bT.getElementType()) + return false; + if (aT.getAffineMaps() != bT.getAffineMaps()) { + int64_t aOffset, bOffset; + SmallVector aStrides, bStrides; + if (failed(getStridesAndOffset(aT, aStrides, aOffset)) || + failed(getStridesAndOffset(bT, bStrides, bOffset)) || + aStrides.size() != bStrides.size()) + return false; + + // Strides along a dimension/offset are compatible if the value in the + // source memref is static and the value in the target memref is the + // same. They are also compatible if either one is dynamic (see + // description of MemRefCastOp for details). + auto checkCompatible = [](int64_t a, int64_t b) { + return (a == MemRefType::getDynamicStrideOrOffset() || + b == MemRefType::getDynamicStrideOrOffset() || a == b); + }; + if (!checkCompatible(aOffset, bOffset)) + return false; + for (auto aStride : enumerate(aStrides)) + if (!checkCompatible(aStride.value(), bStrides[aStride.index()])) + return false; + } + if (aT.getMemorySpace() != bT.getMemorySpace()) + return false; + + // They must have the same rank, and any specified dimensions must match. + if (aT.getRank() != bT.getRank()) + return false; + + for (unsigned i = 0, e = aT.getRank(); i != e; ++i) { + int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i); + if (aDim != -1 && bDim != -1 && aDim != bDim) + return false; + } + return true; + } else { + if (!aT && !uaT) + return false; + if (!bT && !ubT) + return false; + // Unranked to unranked casting is unsupported + if (uaT && ubT) + return false; + + auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType(); + auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType(); + if (aEltType != bEltType) + return false; + + auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace(); + auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace(); + if (aMemSpace != bMemSpace) + return false; + + return true; + } + + return false; +} + +OpFoldResult CastOp::fold(ArrayRef operands) { + if (Value folded = impl::foldCastOp(*this)) + return folded; + return succeeded(foldMemRefCast(*this)) ? 
getResult() : Value(); +} + +//===----------------------------------------------------------------------===// +// DeallocOp +//===----------------------------------------------------------------------===// +namespace { +/// Fold Dealloc operations that are deallocating an AllocOp that is only used +/// by other Dealloc operations. +struct SimplifyDeadDealloc : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(memref::DeallocOp dealloc, + PatternRewriter &rewriter) const override { + // Check that the memref operand's defining operation is an AllocOp. + Value memref = dealloc.memref(); + if (!isa_and_nonnull(memref.getDefiningOp())) + return failure(); + + // Check that all of the uses of the AllocOp are other DeallocOps. + for (auto *user : memref.getUsers()) + if (!isa(user)) + return failure(); + + // Erase the dealloc operation. + rewriter.eraseOp(dealloc); + return success(); + } +}; +} // end anonymous namespace. + +static LogicalResult verify(memref::DeallocOp op) { + if (!op.memref().getType().isa()) + return op.emitOpError("operand must be a memref"); + return success(); +} + +void memref::DeallocOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert(context); +} + +LogicalResult memref::DeallocOp::fold(ArrayRef cstOperands, + SmallVectorImpl &results) { + /// dealloc(memrefcast) -> dealloc + return foldMemRefCast(*this); +} + +//===----------------------------------------------------------------------===// +// GlobalOp +//===----------------------------------------------------------------------===// + +static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, + memref::GlobalOp op, + TypeAttr type, + Attribute initialValue) { + p << type; + if (!op.isExternal()) { + p << " = "; + if (op.isUninitialized()) + p << "uninitialized"; + else + p.printAttributeWithoutType(initialValue); + } +} + +static ParseResult +parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, + Attribute &initialValue) { + Type type; + if (parser.parseType(type)) + return failure(); + + auto memrefType = type.dyn_cast(); + if (!memrefType || !memrefType.hasStaticShape()) + return parser.emitError(parser.getNameLoc()) + << "type should be static shaped memref, but got " << type; + typeAttr = TypeAttr::get(type); + + if (parser.parseOptionalEqual()) + return success(); + + if (succeeded(parser.parseOptionalKeyword("uninitialized"))) { + initialValue = UnitAttr::get(parser.getBuilder().getContext()); + return success(); + } + + Type tensorType = getTensorTypeFromMemRefType(memrefType); + if (parser.parseAttribute(initialValue, tensorType)) + return failure(); + if (!initialValue.isa()) + return parser.emitError(parser.getNameLoc()) + << "initial value should be a unit or elements attribute"; + return success(); +} + +static LogicalResult verify(memref::GlobalOp op) { + auto memrefType = op.type().dyn_cast(); + if (!memrefType || !memrefType.hasStaticShape()) + return op.emitOpError("type should be static shaped memref, but got ") + << op.type(); + + // Verify that the initial value, if present, is either a unit attribute or + // an elements attribute. 
+ if (op.initial_value().hasValue()) { + Attribute initValue = op.initial_value().getValue(); + if (!initValue.isa() && !initValue.isa()) + return op.emitOpError("initial value should be a unit or elements " + "attribute, but got ") + << initValue; + + // Check that the type of the initial value is compatible with the type of + // the global variable. + if (initValue.isa()) { + Type initType = initValue.getType(); + Type tensorType = getTensorTypeFromMemRefType(memrefType); + if (initType != tensorType) + return op.emitOpError("initial value expected to be of type ") + << tensorType << ", but was of type " << initType; + } + } + + // TODO: verify visibility for declarations. + return success(); +} + +//===----------------------------------------------------------------------===// +// GetGlobalOp +//===----------------------------------------------------------------------===// + +LogicalResult +memref::GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) { + // Verify that the result type is same as the type of the referenced + // global_memref op. + auto global = + symbolTable.lookupNearestSymbolFrom(*this, nameAttr()); + if (!global) + return emitOpError("'") + << name() << "' does not reference a valid global memref"; + + Type resultType = result().getType(); + if (global.type() != resultType) + return emitOpError("result type ") + << resultType << " does not match type " << global.type() + << " of the global memref @" << name(); + return success(); +} + +//===----------------------------------------------------------------------===// +// PrefetchOp +//===----------------------------------------------------------------------===// + +static void print(OpAsmPrinter &p, memref::PrefetchOp op) { + p << memref::PrefetchOp::getOperationName() << " " << op.memref() << '['; + p.printOperands(op.indices()); + p << ']' << ", " << (op.isWrite() ? "write" : "read"); + p << ", locality<" << op.localityHint(); + p << ">, " << (op.isDataCache() ? 
"data" : "instr"); + p.printOptionalAttrDict( + op.getAttrs(), + /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"}); + p << " : " << op.getMemRefType(); +} + +static ParseResult parsePrefetchOp(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::OperandType memrefInfo; + SmallVector indexInfo; + IntegerAttr localityHint; + MemRefType type; + StringRef readOrWrite, cacheType; + + auto indexTy = parser.getBuilder().getIndexType(); + auto i32Type = parser.getBuilder().getIntegerType(32); + if (parser.parseOperand(memrefInfo) || + parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || + parser.parseComma() || parser.parseKeyword(&readOrWrite) || + parser.parseComma() || parser.parseKeyword("locality") || + parser.parseLess() || + parser.parseAttribute(localityHint, i32Type, "localityHint", + result.attributes) || + parser.parseGreater() || parser.parseComma() || + parser.parseKeyword(&cacheType) || parser.parseColonType(type) || + parser.resolveOperand(memrefInfo, type, result.operands) || + parser.resolveOperands(indexInfo, indexTy, result.operands)) + return failure(); + + if (!readOrWrite.equals("read") && !readOrWrite.equals("write")) + return parser.emitError(parser.getNameLoc(), + "rw specifier has to be 'read' or 'write'"); + result.addAttribute( + memref::PrefetchOp::getIsWriteAttrName(), + parser.getBuilder().getBoolAttr(readOrWrite.equals("write"))); + + if (!cacheType.equals("data") && !cacheType.equals("instr")) + return parser.emitError(parser.getNameLoc(), + "cache type has to be 'data' or 'instr'"); + + result.addAttribute( + memref::PrefetchOp::getIsDataCacheAttrName(), + parser.getBuilder().getBoolAttr(cacheType.equals("data"))); + + return success(); +} + +static LogicalResult verify(memref::PrefetchOp op) { + if (op.getNumOperands() != 1 + op.getMemRefType().getRank()) + return op.emitOpError("too few indices"); + + return success(); +} + +LogicalResult memref::PrefetchOp::fold(ArrayRef cstOperands, + SmallVectorImpl &results) { + // prefetch(memrefcast) -> prefetch + return foldMemRefCast(*this); +} + +//===----------------------------------------------------------------------===// +// ReshapeOp +//===----------------------------------------------------------------------===// + +static LogicalResult verify(memref::ReshapeOp op) { + Type operandType = op.source().getType(); + Type resultType = op.result().getType(); + + Type operandElementType = operandType.cast().getElementType(); + Type resultElementType = resultType.cast().getElementType(); + if (operandElementType != resultElementType) + return op.emitOpError("element types of source and destination memref " + "types should be the same"); + + if (auto operandMemRefType = operandType.dyn_cast()) + if (!operandMemRefType.getAffineMaps().empty()) + return op.emitOpError( + "source memref type should have identity affine map"); + + int64_t shapeSize = op.shape().getType().cast().getDimSize(0); + auto resultMemRefType = resultType.dyn_cast(); + if (resultMemRefType) { + if (!resultMemRefType.getAffineMaps().empty()) + return op.emitOpError( + "result memref type should have identity affine map"); + if (shapeSize == ShapedType::kDynamicSize) + return op.emitOpError("cannot use shape operand with dynamic length to " + "reshape to statically-ranked memref type"); + if (shapeSize != resultMemRefType.getRank()) + return op.emitOpError( + "length of shape operand differs from the result's memref rank"); + } + return success(); +} + 
+//===----------------------------------------------------------------------===//
+// StoreOp
+//===----------------------------------------------------------------------===//
+
+static LogicalResult verify(memref::StoreOp op) {
+  if (op.getNumOperands() != 2 + op.getMemRefType().getRank())
+    return op.emitOpError("store index operand count not equal to memref rank");
+
+  return success();
+}
+
+LogicalResult memref::StoreOp::fold(ArrayRef<Attribute> cstOperands,
+                                    SmallVectorImpl<OpFoldResult> &results) {
+  /// store(memrefcast) -> store
+  return foldMemRefCast(*this);
+}
+
+//===----------------------------------------------------------------------===//
+// TransposeOp
+//===----------------------------------------------------------------------===//
+
+/// Build a strided memref type by applying `permutationMap` to `memRefType`.
+static MemRefType inferTransposeResultType(MemRefType memRefType,
+                                           AffineMap permutationMap) {
+  auto rank = memRefType.getRank();
+  auto originalSizes = memRefType.getShape();
+  // Compute permuted sizes.
+  SmallVector<int64_t, 4> sizes(rank, 0);
+  for (auto en : llvm::enumerate(permutationMap.getResults()))
+    sizes[en.index()] =
+        originalSizes[en.value().cast<AffineDimExpr>().getPosition()];
+
+  // Compute permuted strides.
+  int64_t offset;
+  SmallVector<int64_t, 4> strides;
+  auto res = getStridesAndOffset(memRefType, strides, offset);
+  assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank));
+  (void)res;
+  auto map =
+      makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
+  map = permutationMap ? map.compose(permutationMap) : map;
+  return MemRefType::Builder(memRefType).setShape(sizes).setAffineMaps(map);
+}
+
+void memref::TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
+                                AffineMapAttr permutation,
+                                ArrayRef<NamedAttribute> attrs) {
+  auto permutationMap = permutation.getValue();
+  assert(permutationMap);
+
+  auto memRefType = in.getType().cast<MemRefType>();
+  // Compute result type.
+ MemRefType resultType = inferTransposeResultType(memRefType, permutationMap); + + build(b, result, resultType, in, attrs); + result.addAttribute(TransposeOp::getPermutationAttrName(), permutation); +} + +// transpose $in $permutation attr-dict : type($in) `to` type(results) +static void print(OpAsmPrinter &p, memref::TransposeOp op) { + p << "transpose " << op.in() << " " << op.permutation(); + p.printOptionalAttrDict(op.getAttrs(), + {memref::TransposeOp::getPermutationAttrName()}); + p << " : " << op.in().getType() << " to " << op.getType(); +} + +static ParseResult parseTransposeOp(OpAsmParser &parser, + OperationState &result) { + OpAsmParser::OperandType in; + AffineMap permutation; + MemRefType srcType, dstType; + if (parser.parseOperand(in) || parser.parseAffineMap(permutation) || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(srcType) || + parser.resolveOperand(in, srcType, result.operands) || + parser.parseKeywordType("to", dstType) || + parser.addTypeToList(dstType, result.types)) + return failure(); + + result.addAttribute(memref::TransposeOp::getPermutationAttrName(), + AffineMapAttr::get(permutation)); + return success(); +} + +static LogicalResult verify(memref::TransposeOp op) { + if (!op.permutation().isPermutation()) + return op.emitOpError("expected a permutation map"); + if (op.permutation().getNumDims() != op.getShapedType().getRank()) + return op.emitOpError( + "expected a permutation map of same rank as the input"); + + auto srcType = op.in().getType().cast(); + auto dstType = op.getType().cast(); + auto transposedType = inferTransposeResultType(srcType, op.permutation()); + if (dstType != transposedType) + return op.emitOpError("output type ") + << dstType << " does not match transposed input type " << srcType + << ", " << transposedType; + return success(); +} + +OpFoldResult memref::TransposeOp::fold(ArrayRef) { + if (succeeded(foldMemRefCast(*this))) + return getResult(); + return {}; +} + +//===----------------------------------------------------------------------===// +// ViewOp +//===----------------------------------------------------------------------===// + +static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) { + OpAsmParser::OperandType srcInfo; + SmallVector offsetInfo; + SmallVector sizesInfo; + auto indexType = parser.getBuilder().getIndexType(); + Type srcType, dstType; + llvm::SMLoc offsetLoc; + if (parser.parseOperand(srcInfo) || parser.getCurrentLocation(&offsetLoc) || + parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square)) + return failure(); + + if (offsetInfo.size() != 1) + return parser.emitError(offsetLoc) << "expects 1 offset operand"; + + return failure( + parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) || + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(srcType) || + parser.resolveOperand(srcInfo, srcType, result.operands) || + parser.resolveOperands(offsetInfo, indexType, result.operands) || + parser.resolveOperands(sizesInfo, indexType, result.operands) || + parser.parseKeywordType("to", dstType) || + parser.addTypeToList(dstType, result.types)); +} + +static void print(OpAsmPrinter &p, memref::ViewOp op) { + p << op.getOperationName() << ' ' << op.getOperand(0) << '['; + p.printOperand(op.byte_shift()); + p << "][" << op.sizes() << ']'; + p.printOptionalAttrDict(op.getAttrs()); + p << " : " << op.getOperand(0).getType() << " to " << op.getType(); +} + +static LogicalResult verify(memref::ViewOp op) { + auto baseType = 
op.getOperand(0).getType().cast(); + auto viewType = op.getType(); + + // The base memref should have identity layout map (or none). + if (baseType.getAffineMaps().size() > 1 || + (baseType.getAffineMaps().size() == 1 && + !baseType.getAffineMaps()[0].isIdentity())) + return op.emitError("unsupported map for base memref type ") << baseType; + + // The result memref should have identity layout map (or none). + if (viewType.getAffineMaps().size() > 1 || + (viewType.getAffineMaps().size() == 1 && + !viewType.getAffineMaps()[0].isIdentity())) + return op.emitError("unsupported map for result memref type ") << viewType; + + // The base memref and the view memref should be in the same memory space. + if (baseType.getMemorySpace() != viewType.getMemorySpace()) + return op.emitError("different memory spaces specified for base memref " + "type ") + << baseType << " and view memref type " << viewType; + + // Verify that we have the correct number of sizes for the result type. + unsigned numDynamicDims = viewType.getNumDynamicDims(); + if (op.sizes().size() != numDynamicDims) + return op.emitError("incorrect number of size operands for type ") + << viewType; + + return success(); +} + +Value memref::ViewOp::getViewSource() { return source(); } + +namespace { + +struct ViewOpShapeFolder : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(memref::ViewOp viewOp, + PatternRewriter &rewriter) const override { + // Return if none of the operands are constants. + if (llvm::none_of(viewOp.getOperands(), [](Value operand) { + return matchPattern(operand, m_ConstantIndex()); + })) + return failure(); + + // Get result memref type. + auto memrefType = viewOp.getType(); + + // Get offset from old memref view type 'memRefType'. + int64_t oldOffset; + SmallVector oldStrides; + if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset))) + return failure(); + assert(oldOffset == 0 && "Expected 0 offset"); + + SmallVector newOperands; + + // Offset cannot be folded into result type. + + // Fold any dynamic dim operands which are produced by a constant. + SmallVector newShapeConstants; + newShapeConstants.reserve(memrefType.getRank()); + + unsigned dynamicDimPos = 0; + unsigned rank = memrefType.getRank(); + for (unsigned dim = 0, e = rank; dim < e; ++dim) { + int64_t dimSize = memrefType.getDimSize(dim); + // If this is already static dimension, keep it. + if (!ShapedType::isDynamic(dimSize)) { + newShapeConstants.push_back(dimSize); + continue; + } + auto *defOp = viewOp.sizes()[dynamicDimPos].getDefiningOp(); + if (auto constantIndexOp = dyn_cast_or_null(defOp)) { + // Dynamic shape dimension will be folded. + newShapeConstants.push_back(constantIndexOp.getValue()); + } else { + // Dynamic shape dimension not folded; copy operand from old memref. + newShapeConstants.push_back(dimSize); + newOperands.push_back(viewOp.sizes()[dynamicDimPos]); + } + dynamicDimPos++; + } + + // Create new memref type with constant folded dims. + MemRefType newMemRefType = + MemRefType::Builder(memrefType).setShape(newShapeConstants); + // Nothing new, don't fold. + if (newMemRefType == memrefType) + return failure(); + + // Create new ViewOp. + auto newViewOp = rewriter.create( + viewOp.getLoc(), newMemRefType, viewOp.getOperand(0), + viewOp.byte_shift(), newOperands); + // Insert a cast so we have the same type as the old memref type. 
+ rewriter.replaceOpWithNewOp(viewOp, newViewOp, + viewOp.getType()); + return success(); + } +}; + +struct ViewOpMemrefCastFolder : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(memref::ViewOp viewOp, + PatternRewriter &rewriter) const override { + Value memrefOperand = viewOp.getOperand(0); + memref::CastOp memrefCastOp = memrefOperand.getDefiningOp(); + if (!memrefCastOp) + return failure(); + Value allocOperand = memrefCastOp.getOperand(); + memref::AllocOp allocOp = allocOperand.getDefiningOp(); + if (!allocOp) + return failure(); + rewriter.replaceOpWithNewOp( + viewOp, viewOp.getType(), allocOperand, viewOp.byte_shift(), + viewOp.sizes()); + return success(); + } +}; + +} // end anonymous namespace + +void memref::ViewOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// TableGen'd op method definitions +//===----------------------------------------------------------------------===// + +#define GET_OP_CLASSES +#include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc" diff --git a/mlir/test/mlir-opt/commandline.mlir b/mlir/test/mlir-opt/commandline.mlir --- a/mlir/test/mlir-opt/commandline.mlir +++ b/mlir/test/mlir-opt/commandline.mlir @@ -13,6 +13,7 @@ // CHECK-NEXT: llvm_arm_neon // CHECK-NEXT: llvm_arm_sve // CHECK-NEXT: llvm_avx512 +// CHECK-NEXT: memref // CHECK-NEXT: nvvm // CHECK-NEXT: omp // CHECK-NEXT: pdl
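
For ops in other dialects that consume memrefs, `memref::canFoldIntoConsumerOp` is intended to be called from their canonicalization patterns. A minimal sketch of such a pattern is shown below; `ConsumerOp` and its `source` operand accessor are hypothetical stand-ins, not part of this patch:

```c++
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

/// Sketch only: fold a producing memref.cast into a hypothetical ConsumerOp
/// when the cast merely erases static type information.
struct FoldMemRefCastIntoConsumer : public OpRewritePattern<ConsumerOp> {
  using OpRewritePattern<ConsumerOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ConsumerOp op,
                                PatternRewriter &rewriter) const override {
    auto castOp = op.source().getDefiningOp<memref::CastOp>();
    if (!castOp || !memref::canFoldIntoConsumerOp(castOp))
      return failure();
    // Rewire the consumer to the pre-cast, more static memref; this assumes
    // ConsumerOp also accepts the source type.
    rewriter.updateRootInPlace(
        op, [&] { op.sourceMutable().assign(castOp.source()); });
    return success();
  }
};
```

Such a pattern would typically be registered from the consuming op's `getCanonicalizationPatterns`.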