diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h index d87869270c21..902220b2d9ce 100644 --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h @@ -1,401 +1,391 @@ //===- Ops.h - Standard MLIR Operations -------------------------*- C++ -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file defines convenience types for working with standard operations // in the MLIR operation set. // //===----------------------------------------------------------------------===// #ifndef MLIR_DIALECT_STANDARDOPS_IR_OPS_H #define MLIR_DIALECT_STANDARDOPS_IR_OPS_H #include "mlir/IR/Builders.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/StandardTypes.h" #include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Interfaces/VectorInterfaces.h" #include "mlir/Interfaces/ViewLikeInterface.h" // Pull in all enum type definitions and utility function declarations. #include "mlir/Dialect/StandardOps/IR/OpsEnums.h.inc" namespace mlir { class AffineMap; class Builder; class FuncOp; class OpBuilder; /// Auxiliary range data structure to unpack the offset, size and stride /// operands of the SubViewOp / SubTensorOp into a list of triples. /// Such a list of triple is sometimes more convenient to manipulate. struct Range { Value offset; Value size; Value stride; }; raw_ostream &operator<<(raw_ostream &os, Range &range); #define GET_OP_CLASSES #include "mlir/Dialect/StandardOps/IR/Ops.h.inc" #include "mlir/Dialect/StandardOps/IR/OpsDialect.h.inc" /// This is a refinement of the "constant" op for the case where it is /// returning a float value of FloatType. /// /// %1 = "std.constant"(){value: 42.0} : bf16 /// class ConstantFloatOp : public ConstantOp { public: using ConstantOp::ConstantOp; /// Builds a constant float op producing a float of the specified type. static void build(OpBuilder &builder, OperationState &result, const APFloat &value, FloatType type); APFloat getValue() { return getAttrOfType("value").getValue(); } static bool classof(Operation *op); }; /// This is a refinement of the "constant" op for the case where it is /// returning an integer value of IntegerType. /// /// %1 = "std.constant"(){value: 42} : i32 /// class ConstantIntOp : public ConstantOp { public: using ConstantOp::ConstantOp; /// Build a constant int op producing an integer of the specified width. static void build(OpBuilder &builder, OperationState &result, int64_t value, unsigned width); /// Build a constant int op producing an integer with the specified type, /// which must be an integer type. static void build(OpBuilder &builder, OperationState &result, int64_t value, Type type); int64_t getValue() { return getAttrOfType("value").getInt(); } static bool classof(Operation *op); }; /// This is a refinement of the "constant" op for the case where it is /// returning an integer value of Index type. /// /// %1 = "std.constant"(){value: 99} : () -> index /// class ConstantIndexOp : public ConstantOp { public: using ConstantOp::ConstantOp; /// Build a constant int op producing an index. 
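  /// For example, calling this builder through `OpBuilder::create` materializes
  /// the op written as `%c0 = constant 0 : index` in the custom assembly form
  /// (illustrative; `%c0` is a placeholder result name).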
static void build(OpBuilder &builder, OperationState &result, int64_t value); int64_t getValue() { return getAttrOfType("value").getInt(); } static bool classof(Operation *op); }; // DmaStartOp starts a non-blocking DMA operation that transfers data from a // source memref to a destination memref. The source and destination memref need // not be of the same dimensionality, but need to have the same elemental type. // The operands include the source and destination memref's each followed by its // indices, size of the data transfer in terms of the number of elements (of the // elemental type of the memref), a tag memref with its indices, and optionally // at the end, a stride and a number_of_elements_per_stride arguments. The tag // location is used by a DmaWaitOp to check for completion. The indices of the // source memref, destination memref, and the tag memref have the same // restrictions as any load/store. The optional stride arguments should be of // 'index' type, and specify a stride for the slower memory space (memory space // with a lower memory space id), transferring chunks of // number_of_elements_per_stride every stride until %num_elements are // transferred. Either both or no stride arguments should be specified. // // For example, a DmaStartOp operation that transfers 256 elements of a memref // '%src' in memory space 0 at indices [%i, %j] to memref '%dst' in memory space // 1 at indices [%k, %l], would be specified as follows: // // %num_elements = constant 256 // %idx = constant 0 : index // %tag = alloc() : memref<1 x i32, (d0) -> (d0), 4> // dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx] : // memref<40 x 128 x f32>, (d0) -> (d0), 0>, // memref<2 x 1024 x f32>, (d0) -> (d0), 1>, // memref<1 x i32>, (d0) -> (d0), 2> // // If %stride and %num_elt_per_stride are specified, the DMA is expected to // transfer %num_elt_per_stride elements every %stride elements apart from // memory space 0 until %num_elements are transferred. // // dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx], %stride, // %num_elt_per_stride : // // TODO: add additional operands to allow source and destination striding, and // multiple stride levels. // TODO: Consider replacing src/dst memref indices with view memrefs. class DmaStartOp : public Op { public: using Op::Op; static void build(OpBuilder &builder, OperationState &result, Value srcMemRef, ValueRange srcIndices, Value destMemRef, ValueRange destIndices, Value numElements, Value tagMemRef, ValueRange tagIndices, Value stride = nullptr, Value elementsPerStride = nullptr); // Returns the source MemRefType for this DMA operation. Value getSrcMemRef() { return getOperand(0); } // Returns the rank (number of indices) of the source MemRefType. unsigned getSrcMemRefRank() { return getSrcMemRef().getType().cast().getRank(); } // Returns the source memref indices for this DMA operation. operand_range getSrcIndices() { return {getOperation()->operand_begin() + 1, getOperation()->operand_begin() + 1 + getSrcMemRefRank()}; } // Returns the destination MemRefType for this DMA operations. Value getDstMemRef() { return getOperand(1 + getSrcMemRefRank()); } // Returns the rank (number of indices) of the destination MemRefType. 
unsigned getDstMemRefRank() { return getDstMemRef().getType().cast().getRank(); } unsigned getSrcMemorySpace() { return getSrcMemRef().getType().cast().getMemorySpace(); } unsigned getDstMemorySpace() { return getDstMemRef().getType().cast().getMemorySpace(); } // Returns the destination memref indices for this DMA operation. operand_range getDstIndices() { return {getOperation()->operand_begin() + 1 + getSrcMemRefRank() + 1, getOperation()->operand_begin() + 1 + getSrcMemRefRank() + 1 + getDstMemRefRank()}; } // Returns the number of elements being transferred by this DMA operation. Value getNumElements() { return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank()); } // Returns the Tag MemRef for this DMA operation. Value getTagMemRef() { return getOperand(1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1); } // Returns the rank (number of indices) of the tag MemRefType. unsigned getTagMemRefRank() { return getTagMemRef().getType().cast().getRank(); } // Returns the tag memref index for this DMA operation. operand_range getTagIndices() { unsigned tagIndexStartPos = 1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1 + 1; return {getOperation()->operand_begin() + tagIndexStartPos, getOperation()->operand_begin() + tagIndexStartPos + getTagMemRefRank()}; } /// Returns true if this is a DMA from a faster memory space to a slower one. bool isDestMemorySpaceFaster() { return (getSrcMemorySpace() < getDstMemorySpace()); } /// Returns true if this is a DMA from a slower memory space to a faster one. bool isSrcMemorySpaceFaster() { // Assumes that a lower number is for a slower memory space. return (getDstMemorySpace() < getSrcMemorySpace()); } /// Given a DMA start operation, returns the operand position of either the /// source or destination memref depending on the one that is at the higher /// level of the memory hierarchy. Asserts failure if neither is true. unsigned getFasterMemPos() { assert(isSrcMemorySpaceFaster() || isDestMemorySpaceFaster()); return isSrcMemorySpaceFaster() ? 0 : getSrcMemRefRank() + 1; } static StringRef getOperationName() { return "std.dma_start"; } static ParseResult parse(OpAsmParser &parser, OperationState &result); void print(OpAsmPrinter &p); LogicalResult verify(); LogicalResult fold(ArrayRef cstOperands, SmallVectorImpl &results); bool isStrided() { return getNumOperands() != 1 + getSrcMemRefRank() + 1 + getDstMemRefRank() + 1 + 1 + getTagMemRefRank(); } Value getStride() { if (!isStrided()) return nullptr; return getOperand(getNumOperands() - 1 - 1); } Value getNumElementsPerStride() { if (!isStrided()) return nullptr; return getOperand(getNumOperands() - 1); } }; // DmaWaitOp blocks until the completion of a DMA operation associated with the // tag element '%tag[%index]'. %tag is a memref, and %index has to be an index // with the same restrictions as any load/store index. %num_elements is the // number of elements associated with the DMA operation. For example: // // dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%index] : // memref<2048 x f32>, (d0) -> (d0), 0>, // memref<256 x f32>, (d0) -> (d0), 1> // memref<1 x i32>, (d0) -> (d0), 2> // ... // ... // dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 2> // class DmaWaitOp : public Op { public: using Op::Op; static void build(OpBuilder &builder, OperationState &result, Value tagMemRef, ValueRange tagIndices, Value numElements); static StringRef getOperationName() { return "std.dma_wait"; } // Returns the Tag MemRef associated with the DMA operation being waited on. 
Value getTagMemRef() { return getOperand(0); } // Returns the tag memref index for this DMA operation. operand_range getTagIndices() { return {getOperation()->operand_begin() + 1, getOperation()->operand_begin() + 1 + getTagMemRefRank()}; } // Returns the rank (number of indices) of the tag memref. unsigned getTagMemRefRank() { return getTagMemRef().getType().cast().getRank(); } // Returns the number of elements transferred in the associated DMA operation. Value getNumElements() { return getOperand(1 + getTagMemRefRank()); } static ParseResult parse(OpAsmParser &parser, OperationState &result); void print(OpAsmPrinter &p); LogicalResult fold(ArrayRef cstOperands, SmallVectorImpl &results); LogicalResult verify(); }; /// Given an `originalShape` and a `reducedShape` assumed to be a subset of /// `originalShape` with some `1` entries erased, return the vector of booleans /// that specifies which of the entries of `originalShape` are keep to obtain /// `reducedShape`. The returned mask can be applied as a projection to /// `originalShape` to obtain the `reducedShape`. This mask is useful to track /// which dimensions must be kept when e.g. compute MemRef strides under /// rank-reducing operations. Return None if reducedShape cannot be obtained /// by dropping only `1` entries in `originalShape`. llvm::Optional> computeRankReductionMask(ArrayRef originalShape, ArrayRef reducedShape); -/// Prints dimension and symbol list. -void printDimAndSymbolList(Operation::operand_iterator begin, - Operation::operand_iterator end, unsigned numDims, - OpAsmPrinter &p); - -/// Parses dimension and symbol list and returns true if parsing failed. -ParseResult parseDimAndSymbolList(OpAsmParser &parser, - SmallVectorImpl &operands, - unsigned &numDims); - /// Determines whether MemRefCastOp casts to a more dynamic version of the /// source memref. This is useful to to fold a memref_cast into a consuming op /// and implement canonicalization patterns for ops in different dialects that /// may consume the results of memref_cast operations. Such foldable memref_cast /// operations are typically inserted as `view` and `subview` ops and are /// canonicalized, to preserve the type compatibility of their uses. /// /// Returns true when all conditions are met: /// 1. source and result are ranked memrefs with strided semantics and same /// element type and rank. /// 2. each of the source's size, offset or stride has more static information /// than the corresponding result's size, offset or stride. /// /// Example 1: /// ```mlir /// %1 = memref_cast %0 : memref<8x16xf32> to memref /// %2 = consumer %1 ... : memref ... /// ``` /// /// may fold into: /// /// ```mlir /// %2 = consumer %0 ... : memref<8x16xf32> ... /// ``` /// /// Example 2: /// ``` /// %1 = memref_cast %0 : memref(16 * i + j)>> /// to memref /// consumer %1 : memref ... /// ``` /// /// may fold into: /// /// ``` /// consumer %0 ... : memref(16 * i + j)>> /// ``` bool canFoldIntoConsumerOp(MemRefCastOp castOp); /// Counterpart of `canFoldIntoConsumerOp(MemRefCastOp castOp)` for tensors. /// Determines whether TensorCastOp casts to a more dynamic version of the /// source tensor. This is useful to fold a tensor_cast into a consuming op and /// implement canonicalization patterns for ops in different dialects that may /// consume the results of tensor_cast operations. Such foldable tensor_cast /// operations are typically inserted as `subtensor` ops and are canonicalized, /// to preserve the type compatibility of their uses. 
/// /// Returns true when all conditions are met: /// 1. source and result are ranked tensors with same element type and rank. /// 2. the tensor type has more static information than the result /// /// Example: /// ```mlir /// %1 = tensor_cast %0 : tensor<8x16xf32> to tensor /// %2 = consumer %1 ... : tensor ... /// ``` /// /// folds into: /// /// ```mlir /// %2 = consumer %0 ... : tensor<8x16xf32> ... /// ``` bool canFoldIntoConsumerOp(TensorCastOp castOp); /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer /// comparison predicates. bool applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs, const APInt &rhs); /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point /// comparison predicates. bool applyCmpPredicate(CmpFPredicate predicate, const APFloat &lhs, const APFloat &rhs); } // end namespace mlir #endif // MLIR_DIALECT_IR_STANDARDOPS_IR_OPS_H diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td index d15f06b37fa5..652efa70fe06 100644 --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -1,4106 +1,4114 @@ //===- Ops.td - Standard operation definitions -------------*- tablegen -*-===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // Defines some MLIR standard operations. // //===----------------------------------------------------------------------===// #ifndef STANDARD_OPS #define STANDARD_OPS include "mlir/Dialect/StandardOps/IR/StandardOpsBase.td" include "mlir/IR/OpAsmInterface.td" include "mlir/IR/SymbolInterfaces.td" include "mlir/Interfaces/CallInterfaces.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/VectorInterfaces.td" include "mlir/Interfaces/ViewLikeInterface.td" def StandardOps_Dialect : Dialect { let name = "std"; let cppNamespace = ""; let hasConstantMaterializer = 1; } // Base class for Standard dialect ops. class Std_Op traits = []> : Op { // For every standard op, there needs to be a: // * void print(OpAsmPrinter &p, ${C++ class of Op} op) // * LogicalResult verify(${C++ class of Op} op) // * ParseResult parse${C++ class of Op}(OpAsmParser &parser, // OperationState &result) // functions. let printer = [{ return ::print(p, *this); }]; let verifier = [{ return ::verify(*this); }]; let parser = [{ return ::parse$cppClass(parser, result); }]; } // Base class for standard cast operations. Requires single operand and result, // but does not constrain them to specific types. class CastOp traits = []> : Std_Op { let results = (outs AnyType); let builders = [ OpBuilderDAG<(ins "Value":$source, "Type":$destType), [{ impl::buildCastOp($_builder, $_state, source, destType); }]> ]; let parser = [{ return impl::parseCastOp(parser, result); }]; let printer = [{ return printStandardCastOp(this->getOperation(), p); }]; let verifier = [{ return ::verifyCastOp(*this); }]; let hasFolder = 1; } // Base class for arithmetic cast operations. class ArithmeticCastOp traits = []> : CastOp { } // Base class for unary ops. Requires single operand and result. Individual // classes will have `operand` accessor. 
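// The custom assembly form mirrors the arithmetic base classes below,
// e.g. (illustrative):
//
//   <op> %0 : f32
//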
class UnaryOp traits = []> : Op { let results = (outs AnyType); let printer = [{ return printStandardUnaryOp(this->getOperation(), p); }]; } class UnaryOpSameOperandAndResultType traits = []> : UnaryOp { let parser = [{ return impl::parseOneResultSameOperandTypeOp(parser, result); }]; } class FloatUnaryOp traits = []> : UnaryOpSameOperandAndResultType, ElementwiseMappable])>, Arguments<(ins FloatLike:$operand)>; // Base class for standard arithmetic operations. Requires operands and // results to be of the same type, but does not constrain them to specific // types. Individual classes will have `lhs` and `rhs` accessor to operands. class ArithmeticOp traits = []> : Op { let results = (outs AnyType); let parser = [{ return impl::parseOneResultSameOperandTypeOp(parser, result); }]; let printer = [{ return printStandardBinaryOp(this->getOperation(), p); }]; } // Base class for standard arithmetic operations on integers, vectors and // tensors thereof. This operation takes two operands and returns one result, // each of these is required to be of the same type. This type may be an // integer scalar type, a vector whose element type is an integer type, or an // integer tensor. The custom assembly form of the operation is as follows // // i %0, %1 : i32 // class IntArithmeticOp traits = []> : ArithmeticOp])>, Arguments<(ins SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs)>; // Base class for standard arithmetic binary operations on floats, vectors and // tensors thereof. This operation has two operands and returns one result, // each of these is required to be of the same type. This type may be a // floating point scalar type, a vector whose element type is a floating point // type, or a floating point tensor. The custom assembly form of the operation // is as follows // // f %0, %1 : f32 // class FloatArithmeticOp traits = []> : ArithmeticOp])>, Arguments<(ins FloatLike:$lhs, FloatLike:$rhs)>; // Base class for standard arithmetic operations on complex numbers with a // floating-point element type. // These operations take two operands and return one result, all of which must // be complex numbers of the same type. // The assembly format is as follows // // cf %0, %1 : complex // class ComplexFloatArithmeticOp traits = []> : ArithmeticOp, Arguments<(ins Complex:$lhs, Complex:$rhs)>; // Base class for memref allocating ops: alloca and alloc. // // %0 = alloclike(%m)[%s] : memref<8x?xf32, (d0, d1)[s0] -> ((d0 + s0), d1)> // class AllocLikeOp traits = []> : - Std_Op]>], traits)> { - - let arguments = (ins Variadic:$value, + Std_Op]>, + AttrSizedOperandSegments + ], traits)> { + + let arguments = (ins Variadic:$dynamicSizes, + // The symbolic operands (the ones in square brackets) bind + // to the symbols of the memref's layout map. 
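+                       // For instance, in the illustrative form
+                       //   %0 = alloc(%d)[%s] : memref<8x?xf32, affine_map<(d0, d1)[s0] -> (d0 + s0, d1)>>
+                       // %d binds to the dynamic dimension and %s binds to s0.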
+ Variadic:$symbolOperands, Confined, [IntMinValue<0>]>:$alignment); - let results = (outs Res]>); + let results = (outs Res]>:$memref); let builders = [ - OpBuilderDAG<(ins "MemRefType":$memrefType), [{ - $_state.types.push_back(memrefType); + OpBuilderDAG<(ins "MemRefType":$memrefType, + CArg<"IntegerAttr", "IntegerAttr()">:$alignment), [{ + return build($_builder, $_state, memrefType, {}, alignment); }]>, - OpBuilderDAG<(ins "MemRefType":$memrefType, "ValueRange":$operands, - CArg<"IntegerAttr", "IntegerAttr()">:$alignment), [{ - $_state.addOperands(operands); + OpBuilderDAG<(ins "MemRefType":$memrefType, "ValueRange":$dynamicSizes, + CArg<"IntegerAttr", "IntegerAttr()">:$alignment), [{ + return build($_builder, $_state, memrefType, dynamicSizes, {}, alignment); + }]>, + OpBuilderDAG<(ins "MemRefType":$memrefType, "ValueRange":$dynamicSizes, + "ValueRange":$symbolOperands, + CArg<"IntegerAttr", "{}">:$alignment), [{ $_state.types.push_back(memrefType); + $_state.addOperands(dynamicSizes); + $_state.addOperands(symbolOperands); + $_state.addAttribute(getOperandSegmentSizeAttr(), + $_builder.getI32VectorAttr({ + static_cast(dynamicSizes.size()), + static_cast(symbolOperands.size())})); if (alignment) $_state.addAttribute(getAlignmentAttrName(), alignment); }]>]; let extraClassDeclaration = [{ static StringRef getAlignmentAttrName() { return "alignment"; } MemRefType getType() { return getResult().getType().cast(); } - /// Returns the number of symbolic operands (the ones in square brackets), - /// which bind to the symbols of the memref's layout map. - unsigned getNumSymbolicOperands() { - return getNumOperands() - getType().getNumDynamicDims(); - } - - /// Returns the symbolic operands (the ones in square brackets), which bind - /// to the symbols of the memref's layout map. - operand_range getSymbolicOperands() { - return {operand_begin() + getType().getNumDynamicDims(), operand_end()}; - } - /// Returns the dynamic sizes for this alloc operation if specified. - operand_range getDynamicSizes() { return getOperands(); } + operand_range getDynamicSizes() { return dynamicSizes(); } }]; - let parser = [{ return ::parseAllocLikeOp(parser, result); }]; + let assemblyFormat = [{ + `(`$dynamicSizes`)` (`` `[` $symbolOperands^ `]`)? attr-dict `:` type($memref) + }]; let hasCanonicalizer = 1; } // Base class for ops with static/dynamic offset, sizes and strides // attributes/arguments. class BaseOpWithOffsetSizesAndStrides traits = []> : Std_Op { code extraBaseClassDeclaration = [{ /// Returns the number of dynamic offset operands. int64_t getNumOffsets() { return llvm::size(offsets()); } /// Returns the number of dynamic size operands. int64_t getNumSizes() { return llvm::size(sizes()); } /// Returns the number of dynamic stride operands. int64_t getNumStrides() { return llvm::size(strides()); } /// Returns the dynamic sizes for this subview operation if specified. operand_range getDynamicSizes() { return sizes(); } /// Returns in `staticStrides` the static value of the stride /// operands. Returns failure() if the static value of the stride /// operands could not be retrieved. LogicalResult getStaticStrides(SmallVectorImpl &staticStrides) { if (!strides().empty()) return failure(); staticStrides.reserve(static_strides().size()); for (auto s : static_strides().getAsValueRange()) staticStrides.push_back(s.getZExtValue()); return success(); } /// Return the list of Range (i.e. offset, size, stride). 
Each /// Range entry contains either the dynamic value or a ConstantIndexOp /// constructed with `b` at location `loc`. SmallVector getOrCreateRanges(OpBuilder &b, Location loc); /// Return the offsets as Values. Each Value is either the dynamic /// value specified in the op or a ConstantIndexOp constructed /// with `b` at location `loc` SmallVector getOrCreateOffsets(OpBuilder &b, Location loc) { unsigned dynamicIdx = 1; return llvm::to_vector<4>(llvm::map_range( static_offsets().cast(), [&](Attribute a) -> Value { int64_t staticOffset = a.cast().getInt(); if (ShapedType::isDynamicStrideOrOffset(staticOffset)) return getOperand(dynamicIdx++); else return b.create( loc, b.getIndexType(), b.getIndexAttr(staticOffset)); })); } /// Return the sizes as Values. Each Value is either the dynamic /// value specified in the op or a ConstantIndexOp constructed /// with `b` at location `loc` SmallVector getOrCreateSizes(OpBuilder &b, Location loc) { unsigned dynamicIdx = 1 + offsets().size(); return llvm::to_vector<4>(llvm::map_range( static_sizes().cast(), [&](Attribute a) -> Value { int64_t staticSize = a.cast().getInt(); if (ShapedType::isDynamic(staticSize)) return getOperand(dynamicIdx++); else return b.create( loc, b.getIndexType(), b.getIndexAttr(staticSize)); })); } /// Return the strides as Values. Each Value is either the dynamic /// value specified in the op or a ConstantIndexOp constructed with /// `b` at location `loc` SmallVector getOrCreateStrides(OpBuilder &b, Location loc) { unsigned dynamicIdx = 1 + offsets().size() + sizes().size(); return llvm::to_vector<4>(llvm::map_range( static_strides().cast(), [&](Attribute a) -> Value { int64_t staticStride = a.cast().getInt(); if (ShapedType::isDynamicStrideOrOffset(staticStride)) return getOperand(dynamicIdx++); else return b.create( loc, b.getIndexType(), b.getIndexAttr(staticStride)); })); } /// Return the rank of the source ShapedType. unsigned getSourceRank() { return source().getType().cast().getRank(); } /// Return the rank of the result ShapedType. unsigned getResultRank() { return getType().getRank(); } /// Return true if the offset `idx` is a static constant. bool isDynamicOffset(unsigned idx) { APInt v = *(static_offsets().getAsValueRange().begin() + idx); return ShapedType::isDynamicStrideOrOffset(v.getSExtValue()); } /// Return true if the size `idx` is a static constant. bool isDynamicSize(unsigned idx) { APInt v = *(static_sizes().getAsValueRange().begin() + idx); return ShapedType::isDynamic(v.getSExtValue()); } /// Return true if the stride `idx` is a static constant. bool isDynamicStride(unsigned idx) { APInt v = *(static_strides().getAsValueRange().begin() + idx); return ShapedType::isDynamicStrideOrOffset(v.getSExtValue()); } /// Assert the offset `idx` is a static constant and return its value. int64_t getStaticOffset(unsigned idx) { assert(!isDynamicOffset(idx) && "expected static offset"); APInt v = *(static_offsets().getAsValueRange().begin() + idx); return v.getSExtValue(); } /// Assert the size `idx` is a static constant and return its value. int64_t getStaticSize(unsigned idx) { assert(!isDynamicSize(idx) && "expected static size"); APInt v = *(static_sizes().getAsValueRange().begin() + idx); return v.getSExtValue(); } /// Assert the stride `idx` is a static constant and return its value. 
int64_t getStaticStride(unsigned idx) { assert(!isDynamicStride(idx) && "expected static stride"); APInt v = *(static_strides().getAsValueRange().begin() + idx); return v.getSExtValue(); } unsigned getNumDynamicEntriesUpToIdx(ArrayAttr attr, llvm::function_ref isDynamic, unsigned idx) { return std::count_if( attr.getValue().begin(), attr.getValue().begin() + idx, [&](Attribute attr) { return isDynamic(attr.cast().getInt()); }); } /// Assert the offset `idx` is dynamic and return the position of the /// corresponding operand. unsigned getIndexOfDynamicOffset(unsigned idx) { assert(isDynamicOffset(idx) && "expected static offset"); auto numDynamic = getNumDynamicEntriesUpToIdx(static_offsets().cast(), ShapedType::isDynamicStrideOrOffset, idx); return 1 + numDynamic; } /// Assert the size `idx` is dynamic and return the position of the /// corresponding operand. unsigned getIndexOfDynamicSize(unsigned idx) { assert(isDynamicSize(idx) && "expected static size"); auto numDynamic = getNumDynamicEntriesUpToIdx( static_sizes().cast(), ShapedType::isDynamic, idx); return 1 + offsets().size() + numDynamic; } /// Assert the stride `idx` is dynamic and return the position of the /// corresponding operand. unsigned getIndexOfDynamicStride(unsigned idx) { assert(isDynamicStride(idx) && "expected static stride"); auto numDynamic = getNumDynamicEntriesUpToIdx(static_strides().cast(), ShapedType::isDynamicStrideOrOffset, idx); return 1 + offsets().size() + sizes().size() + numDynamic; } /// Assert the offset `idx` is dynamic and return its value. Value getDynamicOffset(unsigned idx) { return getOperand(getIndexOfDynamicOffset(idx)); } /// Assert the size `idx` is dynamic and return its value. Value getDynamicSize(unsigned idx) { return getOperand(getIndexOfDynamicSize(idx)); } /// Assert the stride `idx` is dynamic and return its value. Value getDynamicStride(unsigned idx) { return getOperand(getIndexOfDynamicStride(idx)); } static StringRef getStaticOffsetsAttrName() { return "static_offsets"; } static StringRef getStaticSizesAttrName() { return "static_sizes"; } static StringRef getStaticStridesAttrName() { return "static_strides"; } static ArrayRef getSpecialAttrNames() { static SmallVector names{ getStaticOffsetsAttrName(), getStaticSizesAttrName(), getStaticStridesAttrName(), getOperandSegmentSizeAttr()}; return names; } }]; } //===----------------------------------------------------------------------===// // AbsFOp //===----------------------------------------------------------------------===// def AbsFOp : FloatUnaryOp<"absf"> { let summary = "floating point absolute-value operation"; let description = [{ The `absf` operation computes the absolute value. It takes one operand and returns one result of the same type. This type may be a float scalar type, a vector whose element type is float, or a tensor of floats. Example: ```mlir // Scalar absolute value. %a = absf %b : f64 // SIMD vector element-wise absolute value. %f = absf %g : vector<4xf32> // Tensor element-wise absolute value. %x = absf %y : tensor<4x?xf8> ``` }]; } //===----------------------------------------------------------------------===// // AddCFOp //===----------------------------------------------------------------------===// def AddCFOp : ComplexFloatArithmeticOp<"addcf"> { let summary = "complex number addition"; let description = [{ The `addcf` operation takes two complex number operands and returns their sum, a single complex number. 
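    For instance, with an `f32` element type the operation reads as follows
    (illustrative; the operand names are placeholders):

    ```mlir
    %r = addcf %a, %b : complex<f32>
    ```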
All operands and result must be of the same type, a complex number with a floating-point element type. Example: ```mlir %a = addcf %b, %c : complex ``` }]; } //===----------------------------------------------------------------------===// // AddFOp //===----------------------------------------------------------------------===// def AddFOp : FloatArithmeticOp<"addf"> { let summary = "floating point addition operation"; let description = [{ Syntax: ``` operation ::= ssa-id `=` `std.addf` ssa-use `,` ssa-use `:` type ``` The `addf` operation takes two operands and returns one result, each of these is required to be the same type. This type may be a floating point scalar type, a vector whose element type is a floating point type, or a floating point tensor. Example: ```mlir // Scalar addition. %a = addf %b, %c : f64 // SIMD vector addition, e.g. for Intel SSE. %f = addf %g, %h : vector<4xf32> // Tensor addition. %x = addf %y, %z : tensor<4x?xbf16> ``` TODO: In the distant future, this will accept optional attributes for fast math, contraction, rounding mode, and other controls. }]; let hasFolder = 1; } //===----------------------------------------------------------------------===// // AddIOp //===----------------------------------------------------------------------===// def AddIOp : IntArithmeticOp<"addi", [Commutative]> { let summary = "integer addition operation"; let description = [{ Syntax: ``` operation ::= ssa-id `=` `std.addi` ssa-use `,` ssa-use `:` type ``` The `addi` operation takes two operands and returns one result, each of these is required to be the same type. This type may be an integer scalar type, a vector whose element type is integer, or a tensor of integers. It has no standard attributes. Example: ```mlir // Scalar addition. %a = addi %b, %c : i64 // SIMD vector element-wise addition, e.g. for Intel SSE. %f = addi %g, %h : vector<4xi32> // Tensor element-wise addition. %x = addi %y, %z : tensor<4x?xi8> ``` }]; let hasFolder = 1; } //===----------------------------------------------------------------------===// // AllocOp //===----------------------------------------------------------------------===// def AllocOp : AllocLikeOp<"alloc", DefaultResource> { let summary = "memory allocation operation"; let description = [{ The `alloc` operation allocates a region of memory, as specified by its memref type. Example: ```mlir %0 = alloc() : memref<8x64xf32, 1> ``` The optional list of dimension operands are bound to the dynamic dimensions specified in its memref type. In the example below, the ssa value '%d' is bound to the second dimension of the memref (which is dynamic). ```mlir %0 = alloc(%d) : memref<8x?xf32, 1> ``` The optional list of symbol operands are bound to the symbols of the memrefs affine map. In the example below, the ssa value '%s' is bound to the symbol 's0' in the affine map specified in the allocs memref type. ```mlir %0 = alloc()[%s] : memref<8x64xf32, affine_map<(d0, d1)[s0] -> ((d0 + s0), d1)>, 1> ``` This operation returns a single ssa value of memref type, which can be used by subsequent load and store operations. The optional `alignment` attribute may be specified to ensure that the region of memory that will be indexed is aligned at the specified byte boundary. 
```mlir %0 = alloc()[%s] {alignment = 8} : memref<8x64xf32, affine_map<(d0, d1)[s0] -> ((d0 + s0), d1)>, 1> ``` }]; } //===----------------------------------------------------------------------===// // AllocaOp //===----------------------------------------------------------------------===// def AllocaOp : AllocLikeOp<"alloca", AutomaticAllocationScopeResource> { let summary = "stack memory allocation operation"; let description = [{ The `alloca` operation allocates memory on the stack, to be automatically released when control transfers back from the region of its closest surrounding operation with an [`AutomaticAllocationScope`](../Traits.md#automaticallocationscope) trait. The amount of memory allocated is specified by its memref and additional operands. For example: ```mlir %0 = alloca() : memref<8x64xf32> ``` The optional list of dimension operands are bound to the dynamic dimensions specified in its memref type. In the example below, the SSA value '%d' is bound to the second dimension of the memref (which is dynamic). ```mlir %0 = alloca(%d) : memref<8x?xf32> ``` The optional list of symbol operands are bound to the symbols of the memref's affine map. In the example below, the SSA value '%s' is bound to the symbol 's0' in the affine map specified in the allocs memref type. ```mlir %0 = alloca()[%s] : memref<8x64xf32, affine_map<(d0, d1)[s0] -> ((d0 + s0), d1)>> ``` This operation returns a single SSA value of memref type, which can be used by subsequent load and store operations. An optional alignment attribute, if specified, guarantees alignment at least to that boundary. If not specified, an alignment on any convenient boundary compatible with the type will be chosen. }]; } //===----------------------------------------------------------------------===// // AndOp //===----------------------------------------------------------------------===// def AndOp : IntArithmeticOp<"and", [Commutative]> { let summary = "integer binary and"; let description = [{ Syntax: ``` operation ::= ssa-id `=` `std.and` ssa-use `,` ssa-use `:` type ``` The `and` operation takes two operands and returns one result, each of these is required to be the same type. This type may be an integer scalar type, a vector whose element type is integer, or a tensor of integers. It has no standard attributes. Example: ```mlir // Scalar integer bitwise and. %a = and %b, %c : i64 // SIMD vector element-wise bitwise integer and. %f = and %g, %h : vector<4xi32> // Tensor element-wise bitwise integer and. %x = and %y, %z : tensor<4x?xi8> ``` }]; let hasFolder = 1; } //===----------------------------------------------------------------------===// // AssertOp //===----------------------------------------------------------------------===// def AssertOp : Std_Op<"assert"> { let summary = "Assert operation with message attribute"; let description = [{ Assert operation with single boolean operand and an error message attribute. If the argument is `true` this operation has no effect. Otherwise, the program execution will abort. The provided error message may be used by a runtime to propagate the error to the user. Example: ```mlir assert %b, "Expected ... to be true" ``` }]; let arguments = (ins I1:$arg, StrAttr:$msg); let assemblyFormat = "$arg `,` $msg attr-dict"; // AssertOp is fully verified by its traits. 
let verifier = ?; let hasCanonicalizer = 1; } //===----------------------------------------------------------------------===// // AssumeAlignmentOp //===----------------------------------------------------------------------===// def AssumeAlignmentOp : Std_Op<"assume_alignment"> { let summary = "assertion that gives alignment information to the input memref"; let description = [{ The `assume_alignment` operation takes a memref and an integer of alignment value, and internally annotates the buffer with the given alignment. If the buffer isn't aligned to the given alignment, the behavior is undefined. This operation doesn't affect the semantics of a correct program. It's for optimization only, and the optimization is best-effort. }]; let arguments = (ins AnyMemRef:$memref, Confined:$alignment); let results = (outs); let assemblyFormat = "$memref `,` $alignment attr-dict `:` type($memref)"; } //===----------------------------------------------------------------------===// // AtanOp //===----------------------------------------------------------------------===// def AtanOp : FloatUnaryOp<"atan", []>{ let summary = "arcus tangent of the given value"; let description = [{ Syntax: ``` operation ::= ssa-id `=` `std.atan` ssa-use `:` type ``` The `atan` operation computes the arcus tangent of a given value. It takes one operand and returns one result of the same type. This type may be a float scalar type, a vector whose element type is float, or a tensor of floats. It has no standard attributes. Example: ```mlir // Arcus tangent of scalar value. %a = atan %b : f64 // SIMD vector element-wise arcus tangent. %f = atan %g : vector<4xf32> // Tensor element-wise arcus tangent. %x = atan %y : tensor<4x?xf8> ``` }]; } //===----------------------------------------------------------------------===// // Atan2Op //===----------------------------------------------------------------------===// def Atan2Op : FloatArithmeticOp<"atan2">{ let summary = "2-argument arcus tangent of the given values"; let description = [{ Syntax: ``` operation ::= ssa-id `=` `std.atan2` ssa-use `,` ssa-use `:` type ``` The `atan2` operation takes two operands and returns one result, all of which must be of the same type. This type may be a floating point scalar type, a vector whose element type is a floating point type, or a floating point tensor. The 2-argument arcus tangent `atan2(y, x)` returns the angle in the Euclidian plane between the positive x-axis and the ray through the point (x, y). It is a generalization of the 1-argument arcus tangent which returns the angle on the basis of the ratio y/x. See also https://en.wikipedia.org/wiki/Atan2 Example: ```mlir // Scalar variant. %a = atan2 %b, %c : f32 // SIMD vector variant. %f = atan2 %g, %h : vector<4xf32> // Tensor variant. %x = atan2 %y, %z : tensor<4x?xf32> ``` }]; } //===----------------------------------------------------------------------===// // AtomicRMWOp //===----------------------------------------------------------------------===// def AtomicRMWOp : Std_Op<"atomic_rmw", [ AllTypesMatch<["value", "result"]>, TypesMatchWith<"value type matches element type of memref", "memref", "value", "$_self.cast().getElementType()"> ]> { let summary = "atomic read-modify-write operation"; let description = [{ The `atomic_rmw` operation provides a way to perform a read-modify-write sequence that is free from data races. The kind enumeration specifies the modification to perform. The value operand represents the new value to be applied during the modification. 
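    For instance, an integer variant using the `addi` kind might look like this
    (illustrative; operand names are placeholders):

    ```mlir
    %x = atomic_rmw "addi" %value, %I[%i] : (i32, memref<10xi32>) -> i32
    ```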
The memref operand represents the buffer that the read and write will be performed against, as accessed by the specified indices. The arity of the indices is the rank of the memref. The result represents the latest value that was stored. Example: ```mlir %x = atomic_rmw "addf" %value, %I[%i] : (f32, memref<10xf32>) -> f32 ``` }]; let arguments = (ins AtomicRMWKindAttr:$kind, AnyTypeOf<[AnySignlessInteger, AnyFloat]>:$value, MemRefOf<[AnySignlessInteger, AnyFloat]>:$memref, Variadic:$indices); let results = (outs AnyTypeOf<[AnySignlessInteger, AnyFloat]>:$result); let assemblyFormat = [{ $kind $value `,` $memref `[` $indices `]` attr-dict `:` `(` type($value) `,` type($memref) `)` `->` type($result) }]; let extraClassDeclaration = [{ MemRefType getMemRefType() { return memref().getType().cast(); } }]; } def GenericAtomicRMWOp : Std_Op<"generic_atomic_rmw", [ SingleBlockImplicitTerminator<"AtomicYieldOp">, TypesMatchWith<"result type matches element type of memref", "memref", "result", "$_self.cast().getElementType()"> ]> { let summary = "atomic read-modify-write operation with a region"; let description = [{ The `generic_atomic_rmw` operation provides a way to perform a read-modify-write sequence that is free from data races. The memref operand represents the buffer that the read and write will be performed against, as accessed by the specified indices. The arity of the indices is the rank of the memref. The result represents the latest value that was stored. The region contains the code for the modification itself. The entry block has a single argument that represents the value stored in `memref[indices]` before the write is performed. No side-effecting ops are allowed in the body of `GenericAtomicRMWOp`. Example: ```mlir %x = generic_atomic_rmw %I[%i] : memref<10xf32> { ^bb0(%current_value : f32): %c1 = constant 1.0 : f32 %inc = addf %c1, %current_value : f32 atomic_yield %inc : f32 } ``` }]; let arguments = (ins MemRefOf<[AnySignlessInteger, AnyFloat]>:$memref, Variadic:$indices); let results = (outs AnyTypeOf<[AnySignlessInteger, AnyFloat]>:$result); let regions = (region AnyRegion:$body); let skipDefaultBuilders = 1; let builders = [OpBuilderDAG<(ins "Value":$memref, "ValueRange":$ivs)>]; let extraClassDeclaration = [{ // The value stored in memref[ivs]. Value getCurrentValue() { return body().getArgument(0); } MemRefType getMemRefType() { return memref().getType().cast(); } }]; } def AtomicYieldOp : Std_Op<"atomic_yield", [ HasParent<"GenericAtomicRMWOp">, NoSideEffect, Terminator ]> { let summary = "yield operation for GenericAtomicRMWOp"; let description = [{ "atomic_yield" yields an SSA value from a GenericAtomicRMWOp region. }]; let arguments = (ins AnyType:$result); let assemblyFormat = "$result attr-dict `:` type($result)"; } //===----------------------------------------------------------------------===// // BranchOp //===----------------------------------------------------------------------===// def BranchOp : Std_Op<"br", [DeclareOpInterfaceMethods, NoSideEffect, Terminator]> { let summary = "branch operation"; let description = [{ The `br` operation represents a branch operation in a function. The operation takes variable number of operands and produces no results. The operand number and types for each successor must match the arguments of the block successor. 
Example: ```mlir ^bb2: %2 = call @someFn() br ^bb3(%2 : tensor<*xf32>) ^bb3(%3: tensor<*xf32>): ``` }]; let arguments = (ins Variadic:$destOperands); let successors = (successor AnySuccessor:$dest); let builders = [ OpBuilderDAG<(ins "Block *":$dest, CArg<"ValueRange", "{}">:$destOperands), [{ $_state.addSuccessors(dest); $_state.addOperands(destOperands); }]>]; // BranchOp is fully verified by traits. let verifier = ?; let extraClassDeclaration = [{ Block *getDest(); void setDest(Block *block); /// Erase the operand at 'index' from the operand list. void eraseOperand(unsigned index); }]; let hasCanonicalizer = 1; let assemblyFormat = [{ $dest (`(` $destOperands^ `:` type($destOperands) `)`)? attr-dict }]; } //===----------------------------------------------------------------------===// // CallOp //===----------------------------------------------------------------------===// def CallOp : Std_Op<"call", [CallOpInterface, MemRefsNormalizable, DeclareOpInterfaceMethods]> { let summary = "call operation"; let description = [{ The `call` operation represents a direct call to a function that is within the same symbol scope as the call. The operands and result types of the call must match the specified function type. The callee is encoded as a symbol reference attribute named "callee". Example: ```mlir %2 = call @my_add(%0, %1) : (f32, f32) -> f32 ``` }]; let arguments = (ins FlatSymbolRefAttr:$callee, Variadic:$operands); let results = (outs Variadic); let builders = [ OpBuilderDAG<(ins "FuncOp":$callee, CArg<"ValueRange", "{}">:$operands), [{ $_state.addOperands(operands); $_state.addAttribute("callee",$_builder.getSymbolRefAttr(callee)); $_state.addTypes(callee.getType().getResults()); }]>, OpBuilderDAG<(ins "SymbolRefAttr":$callee, "TypeRange":$results, CArg<"ValueRange", "{}">:$operands), [{ $_state.addOperands(operands); $_state.addAttribute("callee", callee); $_state.addTypes(results); }]>, OpBuilderDAG<(ins "StringRef":$callee, "TypeRange":$results, CArg<"ValueRange", "{}">:$operands), [{ build($_builder, $_state, $_builder.getSymbolRefAttr(callee), results, operands); }]>]; let extraClassDeclaration = [{ StringRef getCallee() { return callee(); } FunctionType getCalleeType(); /// Get the argument operands to the called function. operand_range getArgOperands() { return {arg_operand_begin(), arg_operand_end()}; } operand_iterator arg_operand_begin() { return operand_begin(); } operand_iterator arg_operand_end() { return operand_end(); } /// Return the callee of this operation. CallInterfaceCallable getCallableForCallee() { return getAttrOfType("callee"); } }]; let assemblyFormat = [{ $callee `(` $operands `)` attr-dict `:` functional-type($operands, results) }]; let verifier = ?; } //===----------------------------------------------------------------------===// // CallIndirectOp //===----------------------------------------------------------------------===// def CallIndirectOp : Std_Op<"call_indirect", [ CallOpInterface, TypesMatchWith<"callee input types match argument types", "callee", "operands", "$_self.cast().getInputs()">, TypesMatchWith<"callee result types match result types", "callee", "results", "$_self.cast().getResults()"> ]> { let summary = "indirect call operation"; let description = [{ The `call_indirect` operation represents an indirect call to a value of function type. Functions are first class types in MLIR, and may be passed as arguments and merged together with block arguments. The operands and result types of the call must match the specified function type. 
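    For instance, a function value obtained from a `constant` operation (see
    below) can be called indirectly as follows (illustrative; `@my_add` is a
    placeholder symbol):

    ```mlir
    %f = constant @my_add : (f32, f32) -> f32
    %r = call_indirect %f(%0, %1) : (f32, f32) -> f32
    ```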
Function values can be created with the [`constant` operation](#stdconstant-constantop). Example: ```mlir %31 = call_indirect %15(%0, %1) : (tensor<16xf32>, tensor<16xf32>) -> tensor<16xf32> ``` }]; let arguments = (ins FunctionType:$callee, Variadic:$operands); let results = (outs Variadic:$results); let builders = [ OpBuilderDAG<(ins "Value":$callee, CArg<"ValueRange", "{}">:$operands), [{ $_state.operands.push_back(callee); $_state.addOperands(operands); $_state.addTypes(callee.getType().cast().getResults()); }]>]; let extraClassDeclaration = [{ Value getCallee() { return getOperand(0); } /// Get the argument operands to the called function. operand_range getArgOperands() { return {arg_operand_begin(), arg_operand_end()}; } operand_iterator arg_operand_begin() { return ++operand_begin(); } operand_iterator arg_operand_end() { return operand_end(); } /// Return the callee of this operation. CallInterfaceCallable getCallableForCallee() { return getCallee(); } }]; let verifier = ?; let hasCanonicalizer = 1; let assemblyFormat = "$callee `(` $operands `)` attr-dict `:` type($callee)"; } //===----------------------------------------------------------------------===// // CeilFOp //===----------------------------------------------------------------------===// def CeilFOp : FloatUnaryOp<"ceilf"> { let summary = "ceiling of the specified value"; let description = [{ Syntax: ``` operation ::= ssa-id `=` `std.ceilf` ssa-use `:` type ``` The `ceilf` operation computes the ceiling of a given value. It takes one operand and returns one result of the same type. This type may be a float scalar type, a vector whose element type is float, or a tensor of floats. It has no standard attributes. Example: ```mlir // Scalar ceiling value. %a = ceilf %b : f64 // SIMD vector element-wise ceiling value. %f = ceilf %g : vector<4xf32> // Tensor element-wise ceiling value. %x = ceilf %y : tensor<4x?xf8> ``` }]; } //===----------------------------------------------------------------------===// // FloorFOp //===----------------------------------------------------------------------===// def FloorFOp : FloatUnaryOp<"floorf"> { let summary = "floor of the specified value"; let description = [{ Syntax: ``` operation ::= ssa-id `=` `std.floorf` ssa-use `:` type ``` The `floorf` operation computes the floor of a given value. It takes one operand and returns one result of the same type. This type may be a float scalar type, a vector whose element type is float, or a tensor of floats. It has no standard attributes. Example: ```mlir // Scalar floor value. %a = floorf %b : f64 // SIMD vector element-wise floor value. %f = floorf %g : vector<4xf32> // Tensor element-wise floor value. %x = floorf %y : tensor<4x?xf8> ``` }]; } //===----------------------------------------------------------------------===// // CmpFOp //===----------------------------------------------------------------------===// // The predicate indicates the type of the comparison to perform: // (un)orderedness, (in)equality and less/greater than (or equal to) as // well as predicates that are always true or false. 
def CMPF_P_FALSE : I64EnumAttrCase<"AlwaysFalse", 0, "false">; def CMPF_P_OEQ : I64EnumAttrCase<"OEQ", 1, "oeq">; def CMPF_P_OGT : I64EnumAttrCase<"OGT", 2, "ogt">; def CMPF_P_OGE : I64EnumAttrCase<"OGE", 3, "oge">; def CMPF_P_OLT : I64EnumAttrCase<"OLT", 4, "olt">; def CMPF_P_OLE : I64EnumAttrCase<"OLE", 5, "ole">; def CMPF_P_ONE : I64EnumAttrCase<"ONE", 6, "one">; def CMPF_P_ORD : I64EnumAttrCase<"ORD", 7, "ord">; def CMPF_P_UEQ : I64EnumAttrCase<"UEQ", 8, "ueq">; def CMPF_P_UGT : I64EnumAttrCase<"UGT", 9, "ugt">; def CMPF_P_UGE : I64EnumAttrCase<"UGE", 10, "uge">; def CMPF_P_ULT : I64EnumAttrCase<"ULT", 11, "ult">; def CMPF_P_ULE : I64EnumAttrCase<"ULE", 12, "ule">; def CMPF_P_UNE : I64EnumAttrCase<"UNE", 13, "une">; def CMPF_P_UNO : I64EnumAttrCase<"UNO", 14, "uno">; def CMPF_P_TRUE : I64EnumAttrCase<"AlwaysTrue", 15, "true">; def CmpFPredicateAttr : I64EnumAttr< "CmpFPredicate", "", [CMPF_P_FALSE, CMPF_P_OEQ, CMPF_P_OGT, CMPF_P_OGE, CMPF_P_OLT, CMPF_P_OLE, CMPF_P_ONE, CMPF_P_ORD, CMPF_P_UEQ, CMPF_P_UGT, CMPF_P_UGE, CMPF_P_ULT, CMPF_P_ULE, CMPF_P_UNE, CMPF_P_UNO, CMPF_P_TRUE]> { let cppNamespace = "::mlir"; } def CmpFOp : Std_Op<"cmpf", [NoSideEffect, SameTypeOperands, SameOperandsAndResultShape, TypesMatchWith< "result type has i1 element type and same shape as operands", "lhs", "result", "getI1SameShape($_self)">, ElementwiseMappable]> { let summary = "floating-point comparison operation"; let description = [{ The `cmpf` operation compares its two operands according to the float comparison rules and the predicate specified by the respective attribute. The predicate defines the type of comparison: (un)orderedness, (in)equality and signed less/greater than (or equal to) as well as predicates that are always true or false. The operands must have the same type, and this type must be a float type, or a vector or tensor thereof. The result is an i1, or a vector/tensor thereof having the same shape as the inputs. Unlike cmpi, the operands are always treated as signed. The u prefix indicates *unordered* comparison, not unsigned comparison, so "une" means unordered or not equal. For the sake of readability by humans, custom assembly form for the operation uses a string-typed attribute for the predicate. The value of this attribute corresponds to lower-cased name of the predicate constant, e.g., "one" means "ordered not equal". The string representation of the attribute is merely a syntactic sugar and is converted to an integer attribute by the parser. 
Example: ```mlir %r1 = cmpf "oeq" %0, %1 : f32 %r2 = cmpf "ult" %0, %1 : tensor<42x42xf64> %r3 = "std.cmpf"(%0, %1) {predicate: 0} : (f8, f8) -> i1 ``` }]; let arguments = (ins CmpFPredicateAttr:$predicate, FloatLike:$lhs, FloatLike:$rhs ); let results = (outs BoolLike:$result); let builders = [ OpBuilderDAG<(ins "CmpFPredicate":$predicate, "Value":$lhs, "Value":$rhs), [{ ::buildCmpFOp($_builder, $_state, predicate, lhs, rhs); }]>]; let extraClassDeclaration = [{ static StringRef getPredicateAttrName() { return "predicate"; } static CmpFPredicate getPredicateByName(StringRef name); CmpFPredicate getPredicate() { return (CmpFPredicate)getAttrOfType(getPredicateAttrName()) .getInt(); } }]; let verifier = [{ return success(); }]; let hasFolder = 1; let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs)"; } //===----------------------------------------------------------------------===// // CmpIOp //===----------------------------------------------------------------------===// def CMPI_P_EQ : I64EnumAttrCase<"eq", 0>; def CMPI_P_NE : I64EnumAttrCase<"ne", 1>; def CMPI_P_SLT : I64EnumAttrCase<"slt", 2>; def CMPI_P_SLE : I64EnumAttrCase<"sle", 3>; def CMPI_P_SGT : I64EnumAttrCase<"sgt", 4>; def CMPI_P_SGE : I64EnumAttrCase<"sge", 5>; def CMPI_P_ULT : I64EnumAttrCase<"ult", 6>; def CMPI_P_ULE : I64EnumAttrCase<"ule", 7>; def CMPI_P_UGT : I64EnumAttrCase<"ugt", 8>; def CMPI_P_UGE : I64EnumAttrCase<"uge", 9>; def CmpIPredicateAttr : I64EnumAttr< "CmpIPredicate", "", [CMPI_P_EQ, CMPI_P_NE, CMPI_P_SLT, CMPI_P_SLE, CMPI_P_SGT, CMPI_P_SGE, CMPI_P_ULT, CMPI_P_ULE, CMPI_P_UGT, CMPI_P_UGE]> { let cppNamespace = "::mlir"; } def CmpIOp : Std_Op<"cmpi", [NoSideEffect, SameTypeOperands, SameOperandsAndResultShape, TypesMatchWith< "result type has i1 element type and same shape as operands", "lhs", "result", "getI1SameShape($_self)">, ElementwiseMappable]> { let summary = "integer comparison operation"; let description = [{ The `cmpi` operation is a generic comparison for integer-like types. Its two arguments can be integers, vectors or tensors thereof as long as their types match. The operation produces an i1 for the former case, a vector or a tensor of i1 with the same shape as inputs in the other cases. Its first argument is an attribute that defines which type of comparison is performed. The following comparisons are supported: - equal (mnemonic: `"eq"`; integer value: `0`) - not equal (mnemonic: `"ne"`; integer value: `1`) - signed less than (mnemonic: `"slt"`; integer value: `2`) - signed less than or equal (mnemonic: `"sle"`; integer value: `3`) - signed greater than (mnemonic: `"sgt"`; integer value: `4`) - signed greater than or equal (mnemonic: `"sge"`; integer value: `5`) - unsigned less than (mnemonic: `"ult"`; integer value: `6`) - unsigned less than or equal (mnemonic: `"ule"`; integer value: `7`) - unsigned greater than (mnemonic: `"ugt"`; integer value: `8`) - unsigned greater than or equal (mnemonic: `"uge"`; integer value: `9`) The result is `1` if the comparison is true and `0` otherwise. For vector or tensor operands, the comparison is performed elementwise and the element of the result indicates whether the comparison is true for the operand elements with the same indices as those of the result. Note: while the custom assembly form uses strings, the actual underlying attribute has integer type (or rather enum class in C++ code) as seen from the generic assembly form. String literals are used to improve readability of the IR by humans. 
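    As an illustration of the signed/unsigned distinction discussed below, the
    same pair of operands can compare differently under the two predicate
    families (operand names are placeholders):

    ```mlir
    %a = constant -1 : i32
    %b = constant 0 : i32
    // Signed interpretation: -1 < 0, so the result is true.
    %s = cmpi "slt", %a, %b : i32
    // Unsigned interpretation: 0xFFFFFFFF > 0, so the result is false.
    %u = cmpi "ult", %a, %b : i32
    ```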
This operation only applies to integer-like operands, but not floats. The main reason being that comparison operations have diverging sets of attributes: integers require sign specification while floats require various floating point-related particularities, e.g., `-ffast-math` behavior, IEEE754 compliance, etc ([rationale](../Rationale/Rationale.md#splitting-floating-point-vs-integer-operations)). The type of comparison is specified as attribute to avoid introducing ten similar operations, taking into account that they are often implemented using the same operation downstream ([rationale](../Rationale/Rationale.md#specifying-comparison-kind-as-attribute)). The separation between signed and unsigned order comparisons is necessary because of integers being signless. The comparison operation must know how to interpret values with the foremost bit being set: negatives in two's complement or large positives ([rationale](../Rationale/Rationale.md#specifying-sign-in-integer-comparison-operations)). Example: ```mlir // Custom form of scalar "signed less than" comparison. %x = cmpi "slt", %lhs, %rhs : i32 // Generic form of the same operation. %x = "std.cmpi"(%lhs, %rhs) {predicate = 2 : i64} : (i32, i32) -> i1 // Custom form of vector equality comparison. %x = cmpi "eq", %lhs, %rhs : vector<4xi64> // Generic form of the same operation. %x = "std.cmpi"(%lhs, %rhs) {predicate = 0 : i64} : (vector<4xi64>, vector<4xi64>) -> vector<4xi1> ``` }]; let arguments = (ins CmpIPredicateAttr:$predicate, SignlessIntegerLike:$lhs, SignlessIntegerLike:$rhs ); let results = (outs BoolLike:$result); let builders = [ OpBuilderDAG<(ins "CmpIPredicate":$predicate, "Value":$lhs, "Value":$rhs), [{ ::buildCmpIOp($_builder, $_state, predicate, lhs, rhs); }]>]; let extraClassDeclaration = [{ static StringRef getPredicateAttrName() { return "predicate"; } static CmpIPredicate getPredicateByName(StringRef name); CmpIPredicate getPredicate() { return (CmpIPredicate)getAttrOfType(getPredicateAttrName()) .getInt(); } }]; let verifier = [{ return success(); }]; let hasFolder = 1; let assemblyFormat = "$predicate `,` $lhs `,` $rhs attr-dict `:` type($lhs)"; } //===----------------------------------------------------------------------===// // CreateComplexOp //===----------------------------------------------------------------------===// def CreateComplexOp : Std_Op<"create_complex", [NoSideEffect, AllTypesMatch<["real", "imaginary"]>, TypesMatchWith<"complex element type matches real operand type", "complex", "real", "$_self.cast().getElementType()">, TypesMatchWith<"complex element type matches imaginary operand type", "complex", "imaginary", "$_self.cast().getElementType()">]> { let summary = "creates a complex number"; let description = [{ The `create_complex` operation creates a complex number from two floating-point operands, the real and the imaginary part. Example: ```mlir %a = create_complex %b, %c : complex ``` }]; let arguments = (ins AnyFloat:$real, AnyFloat:$imaginary); let results = (outs Complex:$complex); let assemblyFormat = "$real `,` $imaginary attr-dict `:` type($complex)"; // `CreateComplexOp` is fully verified by its traits. 
let verifier = ?; } //===----------------------------------------------------------------------===// // CondBranchOp //===----------------------------------------------------------------------===// def CondBranchOp : Std_Op<"cond_br", [AttrSizedOperandSegments, DeclareOpInterfaceMethods, NoSideEffect, Terminator]> { let summary = "conditional branch operation"; let description = [{ The `cond_br` terminator operation represents a conditional branch on a boolean (1-bit integer) value. If the bit is set, then the first destination is jumped to; if it is false, the second destination is chosen. The count and types of operands must align with the arguments in the corresponding target blocks. The MLIR conditional branch operation is not allowed to target the entry block for a region. The two destinations of the conditional branch operation are allowed to be the same. The following example illustrates a function with a conditional branch operation that targets the same block. Example: ```mlir func @select(%a: i32, %b: i32, %flag: i1) -> i32 { // Both targets are the same, operands differ cond_br %flag, ^bb1(%a : i32), ^bb1(%b : i32) ^bb1(%x : i32) : return %x : i32 } ``` }]; let arguments = (ins I1:$condition, Variadic:$trueDestOperands, Variadic:$falseDestOperands); let successors = (successor AnySuccessor:$trueDest, AnySuccessor:$falseDest); let builders = [ OpBuilderDAG<(ins "Value":$condition, "Block *":$trueDest, "ValueRange":$trueOperands, "Block *":$falseDest, "ValueRange":$falseOperands), [{ build($_builder, $_state, condition, trueOperands, falseOperands, trueDest, falseDest); }]>, OpBuilderDAG<(ins "Value":$condition, "Block *":$trueDest, "Block *":$falseDest, CArg<"ValueRange", "{}">:$falseOperands), [{ build($_builder, $_state, condition, trueDest, ValueRange(), falseDest, falseOperands); }]>]; // CondBranchOp is fully verified by traits. let verifier = ?; let extraClassDeclaration = [{ // These are the indices into the dests list. enum { trueIndex = 0, falseIndex = 1 }; // The condition operand is the first operand in the list. Value getCondition() { return getOperand(0); } /// Return the destination if the condition is true. Block *getTrueDest() { return getSuccessor(trueIndex); } /// Return the destination if the condition is false. Block *getFalseDest() { return getSuccessor(falseIndex); } // Accessors for operands to the 'true' destination. Value getTrueOperand(unsigned idx) { assert(idx < getNumTrueOperands()); return getOperand(getTrueDestOperandIndex() + idx); } void setTrueOperand(unsigned idx, Value value) { assert(idx < getNumTrueOperands()); setOperand(getTrueDestOperandIndex() + idx, value); } operand_range getTrueOperands() { return trueDestOperands(); } unsigned getNumTrueOperands() { return getTrueOperands().size(); } /// Erase the operand at 'index' from the true operand list. void eraseTrueOperand(unsigned index) { trueDestOperandsMutable().erase(index); } // Accessors for operands to the 'false' destination. Value getFalseOperand(unsigned idx) { assert(idx < getNumFalseOperands()); return getOperand(getFalseDestOperandIndex() + idx); } void setFalseOperand(unsigned idx, Value value) { assert(idx < getNumFalseOperands()); setOperand(getFalseDestOperandIndex() + idx, value); } operand_range getFalseOperands() { return falseDestOperands(); } unsigned getNumFalseOperands() { return getFalseOperands().size(); } /// Erase the operand at 'index' from the false operand list. 
void eraseFalseOperand(unsigned index) { falseDestOperandsMutable().erase(index); } private: /// Get the index of the first true destination operand. unsigned getTrueDestOperandIndex() { return 1; } /// Get the index of the first false destination operand. unsigned getFalseDestOperandIndex() { return getTrueDestOperandIndex() + getNumTrueOperands(); } }]; let hasCanonicalizer = 1; let assemblyFormat = [{ $condition `,` $trueDest (`(` $trueDestOperands^ `:` type($trueDestOperands) `)`)? `,` $falseDest (`(` $falseDestOperands^ `:` type($falseDestOperands) `)`)? attr-dict }]; } //===----------------------------------------------------------------------===// // ConstantOp //===----------------------------------------------------------------------===// def ConstantOp : Std_Op<"constant", [ConstantLike, NoSideEffect, DeclareOpInterfaceMethods]> { let summary = "constant"; let description = [{ Syntax: ``` operation ::= ssa-id `=` `std.constant` attribute-value `:` type ``` The `constant` operation produces an SSA value equal to some constant specified by an attribute. This is the way that MLIR uses to form simple integer and floating point constants, as well as more exotic things like references to functions and tensor/vector constants. Example: ```mlir // Integer constant %1 = constant 42 : i32 // Reference to function @myfn. %3 = constant @myfn : (tensor<16xf32>, f32) -> tensor<16xf32> // Equivalent generic forms %1 = "std.constant"() {value = 42 : i32} : () -> i32 %3 = "std.constant"() {value = @myfn} : () -> ((tensor<16xf32>, f32) -> tensor<16xf32>) ``` MLIR does not allow direct references to functions in SSA operands because the compiler is multithreaded, and disallowing SSA values to directly reference a function simplifies this ([rationale](../Rationale/Rationale.md#multithreading-the-compiler)). }]; let arguments = (ins AnyAttr:$value); let results = (outs AnyType); let builders = [ OpBuilderDAG<(ins "Attribute":$value), [{ build($_builder, $_state, value.getType(), value); }]>]; let extraClassDeclaration = [{ Attribute getValue() { return getAttr("value"); } /// Returns true if a constant operation can be built with the given value /// and result type. static bool isBuildableWith(Attribute value, Type type); }]; let hasFolder = 1; } //===----------------------------------------------------------------------===// // CopySignOp //===----------------------------------------------------------------------===// def CopySignOp : FloatArithmeticOp<"copysign"> { let summary = "A copysign operation"; let description = [{ Syntax: ``` operation ::= ssa-id `=` `std.copysign` ssa-use `:` type ``` The `copysign` returns a value with the magnitude of the first operand and the sign of the second operand. It takes two operands and returns one result of the same type. This type may be a float scalar type, a vector whose element type is float, or a tensor of floats. It has no standard attributes. Example: ```mlir // Scalar copysign value. %a = copysign %b %c : f64 // SIMD vector element-wise copysign value. %f = copysign %g %h : vector<4xf32> // Tensor element-wise copysign value. 
%x = copysign %y %z : tensor<4x?xf8> ``` }]; } //===----------------------------------------------------------------------===// // CosOp //===----------------------------------------------------------------------===// def CosOp : FloatUnaryOp<"cos"> { let summary = "cosine of the specified value"; let description = [{ Syntax: ``` operation ::= ssa-id `=` `std.cos` ssa-use `:` type ``` The `cos` operation computes the cosine of a given value. It takes one operand and returns one result of the same type. This type may be a float scalar type, a vector whose element type is float, or a tensor of floats. It has no standard attributes. Example: ```mlir // Scalar cosine value. %a = cos %b : f64 // SIMD vector element-wise cosine value. %f = cos %g : vector<4xf32> // Tensor element-wise cosine value. %x = cos %y : tensor<4x?xf8> ``` }]; } //===----------------------------------------------------------------------===// // SinOp //===----------------------------------------------------------------------===// def SinOp : FloatUnaryOp<"sin"> { let summary = "sine of the specified value"; let description = [{ Syntax: ``` operation ::= ssa-id `=` `std.sin` ssa-use `:` type ``` The `sin` operation computes the sine of a given value. It takes one operand and returns one result of the same type. This type may be a float scalar type, a vector whose element type is float, or a tensor of floats. It has no standard attributes. Example: ```mlir // Scalar sine value. %a = sin %b : f64 // SIMD vector element-wise sine value. %f = sin %g : vector<4xf32> // Tensor element-wise sine value. %x = sin %y : tensor<4x?xf8> ``` }]; } //===----------------------------------------------------------------------===// // DeallocOp //===----------------------------------------------------------------------===// def DeallocOp : Std_Op<"dealloc", [MemoryEffects<[MemFree]>, MemRefsNormalizable]> { let summary = "memory deallocation operation"; let description = [{ The `dealloc` operation frees the region of memory referenced by a memref which was originally created by the `alloc` operation. The `dealloc` operation should not be called on memrefs which alias an alloc'd memref (e.g. memrefs returned by `view` operations). Example: ```mlir %0 = alloc() : memref<8x64xf32, (d0, d1) -> (d0, d1), 1> dealloc %0 : memref<8x64xf32, (d0, d1) -> (d0, d1), 1> ``` }]; let arguments = (ins Arg:$memref); let hasCanonicalizer = 1; let hasFolder = 1; let assemblyFormat = "$memref attr-dict `:` type($memref)"; } //===----------------------------------------------------------------------===// // DimOp //===----------------------------------------------------------------------===// def DimOp : Std_Op<"dim", [NoSideEffect]> { let summary = "dimension index operation"; let description = [{ The `dim` operation takes a memref/tensor and a dimension operand of type `index`. It returns the size of the requested dimension of the given memref/tensor. If the dimension index is out of bounds the behavior is undefined. The specified memref or tensor type is that of the first operand. Example: ```mlir // Always returns 4, can be constant folded: %c0 = constant 0 : index %x = dim %A, %c0 : tensor<4 x ? x f32> // Returns the dynamic dimension of %A. %c1 = constant 1 : index %y = dim %A, %c1 : tensor<4 x ? x f32> // Equivalent generic form: %x = "std.dim"(%A, %c0) : (tensor<4 x ? x f32>, index) -> index %y = "std.dim"(%A, %c1) : (tensor<4 x ?
x f32>, index) -> index ``` }]; let arguments = (ins AnyTypeOf<[AnyRankedOrUnrankedMemRef, AnyTensor], "any tensor or memref type">:$memrefOrTensor, Index:$index); let results = (outs Index:$result); let assemblyFormat = [{ attr-dict $memrefOrTensor `,` $index `:` type($memrefOrTensor) }]; let builders = [ OpBuilderDAG<(ins "Value":$memrefOrTensor, "int64_t":$index)>, OpBuilderDAG<(ins "Value":$memrefOrTensor, "Value":$index)> ]; let extraClassDeclaration = [{ /// Helper function to get the index as a simple integer if it is constant. Optional getConstantIndex(); }]; let hasFolder = 1; } //===----------------------------------------------------------------------===// // DivFOp //===----------------------------------------------------------------------===// def DivFOp : FloatArithmeticOp<"divf"> { let summary = "floating point division operation"; } //===----------------------------------------------------------------------===// // DynamicTensorFromElementsOp //===----------------------------------------------------------------------===// def DynamicTensorFromElementsOp : Std_Op<"dynamic_tensor_from_elements", [RecursiveSideEffects, SingleBlockImplicitTerminator<"YieldOp">]> { string summary = "Creates a dynamically sized tensor from elements"; string description = [{ This operation creates a dynamically sized tensor with elements of any type. It expects one index operand per dynamic extent of the result tensor. The body region defines the tensor's elements. It takes index operands as its region arguments that span the index space. The element at the given position is yielded with the `yield` operation (see `YieldOp`). There is no defined ordering to the invocations of the body. It is conceptually a "parallel map" operation. Example: ```mlir %tnsr = dynamic_tensor_from_elements %m, %n { ^bb0(%i : index, %j : index, %k : index): ... yield %elem : f32 } : tensor ``` }]; let arguments = (ins Variadic:$dynamicExtents); let results = (outs AnyRankedTensor:$result); let regions = (region SizedRegion<1>:$body); let builders = [ // Build op and populate its body per callback function. OpBuilderDAG<(ins "Type":$resultTy, "ValueRange":$dynamicExtents, "function_ref")>, ]; let hasCanonicalizer = 1; } //===----------------------------------------------------------------------===// // ExpOp //===----------------------------------------------------------------------===// def ExpOp : FloatUnaryOp<"exp"> { let summary = "base-e exponential of the specified value"; let description = [{ Syntax: ``` operation ::= ssa-id `=` `std.exp` ssa-use `:` type ``` The `exp` operation takes one operand and returns one result of the same type. This type may be a float scalar type, a vector whose element type is float, or a tensor of floats. It has no standard attributes. Example: ```mlir // Scalar natural exponential. %a = exp %b : f64 // SIMD vector element-wise natural exponential. %f = exp %g : vector<4xf32> // Tensor element-wise natural exponential. 
%x = exp %y : tensor<4x?xf8> ``` }]; } //===----------------------------------------------------------------------===// // Exp2Op //===----------------------------------------------------------------------===// def Exp2Op : FloatUnaryOp<"exp2"> { let summary = "base-2 exponential of the specified value"; } //===----------------------------------------------------------------------===// // ExtractElementOp //===----------------------------------------------------------------------===// def ExtractElementOp : Std_Op<"extract_element", [NoSideEffect, TypesMatchWith<"result type matches element type of aggregate", "aggregate", "result", "$_self.cast().getElementType()">]> { let summary = "element extract operation"; let description = [{ The `extract_element` op reads a tensor or vector and returns one element from it specified by an index list. The output of the 'extract_element' is a new value with the same type as the elements of the tensor or vector. The arity of indices matches the rank of the accessed value (i.e., if a tensor is of rank 3, then 3 indices are required for the extract). The indices should all be of `index` type. Example: ```mlir %3 = extract_element %v[%1, %2] : vector<4x4xi32> %4 = extract_element %t[%1, %2] : tensor<4x4xi32> %5 = extract_element %ut[%1, %2] : tensor<*xi32> ``` }]; let arguments = (ins AnyTypeOf<[AnyVector, AnyTensor]>:$aggregate, Variadic:$indices); let results = (outs AnyType:$result); let builders = [ OpBuilderDAG<(ins "Value":$aggregate, CArg<"ValueRange", "{}">:$indices), [{ auto resType = aggregate.getType().cast() .getElementType(); build($_builder, $_state, resType, aggregate, indices); }]>]; let extraClassDeclaration = [{ Value getAggregate() { return getOperand(0); } operand_range getIndices() { return {operand_begin() + 1, operand_end()}; } }]; let hasFolder = 1; let assemblyFormat = [{ $aggregate `[` $indices `]` attr-dict `:` type($aggregate) }]; } //===----------------------------------------------------------------------===// // TensorFromElementsOp //===----------------------------------------------------------------------===// def TensorFromElementsOp : Std_Op<"tensor_from_elements", [ NoSideEffect, TypesMatchWith<"operand types match result element type", "result", "elements", "SmallVector(" "$_self.cast().getDimSize(0), " "$_self.cast().getElementType())"> ]> { string summary = "tensor from elements operation."; string description = [{ Create a 1D tensor from a range of same-type arguments. Example: ```mlir tensor_from_elements(i_1, ..., i_N) : tensor ``` }]; let arguments = (ins Variadic:$elements); let results = (outs 1DTensorOf<[AnyType]>:$result); let assemblyFormat = "$elements attr-dict `:` type($result)"; // This op is fully verified by its traits. let verifier = ?; let skipDefaultBuilders = 1; let builders = [ OpBuilderDAG<(ins "Type":$elementType, "ValueRange":$elements)>, // Special case builder for when `elements` has size >=1. OpBuilderDAG<(ins "ValueRange":$elements)> ]; let hasCanonicalizer = 1; } //===----------------------------------------------------------------------===// // FPExtOp //===----------------------------------------------------------------------===// def FPExtOp : ArithmeticCastOp<"fpext">, Arguments<(ins AnyType:$in)> { let summary = "cast from floating-point to wider floating-point"; let description = [{ Cast a floating-point value to a larger floating-point-typed value. The destination type must be strictly wider than the source type. Only scalars are currently supported.
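A minimal usage sketch (the SSA value `%0` and the specific bit widths here are illustrative assumptions, not taken from the original text):

```mlir
// Widen a 32-bit float to a 64-bit float.
%1 = fpext %0 : f32 to f64
```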
}]; let extraClassDeclaration = [{ /// Return true if `a` and `b` are valid operand and result pairs for /// the operation. static bool areCastCompatible(Type a, Type b); }]; let hasFolder = 0; } //===----------------------------------------------------------------------===// // FPToSIOp //===----------------------------------------------------------------------===// def FPToSIOp : ArithmeticCastOp<"fptosi">, Arguments<(ins AnyType:$in)> { let summary = "cast from floating-point type to signed integer type"; let description = [{ Cast from a value interpreted as floating-point to the nearest (rounding towards zero) signed integer value. }]; let extraClassDeclaration = [{ /// Return true if `a` and `b` are valid operand and result pairs for /// the operation. static bool areCastCompatible(Type a, Type b); }]; let hasFolder = 0; } //===----------------------------------------------------------------------===// // FPToUIOp //===----------------------------------------------------------------------===// def FPToUIOp : ArithmeticCastOp<"fptoui">, Arguments<(ins AnyType:$in)> { let summary = "cast from floating-point type to unsigned integer type"; let description = [{ Cast from a value interpreted as floating-point to the nearest (rounding towards zero) unsigned integer value. }]; let extraClassDeclaration = [{ /// Return true if `a` and `b` are valid operand and result pairs for /// the operation. static bool areCastCompatible(Type a, Type b); }]; let hasFolder = 0; } //===----------------------------------------------------------------------===// // FPTruncOp //===----------------------------------------------------------------------===// def FPTruncOp : ArithmeticCastOp<"fptrunc">, Arguments<(ins AnyType:$in)> { let summary = "cast from floating-point to narrower floating-point"; let description = [{ Truncate a floating-point value to a smaller floating-point-typed value. The destination type must be strictly narrower than the source type. If the value cannot be exactly represented, it is rounded using the default rounding mode. Only scalars are currently supported. }]; let extraClassDeclaration = [{ /// Return true if `a` and `b` are valid operand and result pairs for /// the operation. static bool areCastCompatible(Type a, Type b); }]; let hasFolder = 0; } //===----------------------------------------------------------------------===// // GlobalMemrefOp //===----------------------------------------------------------------------===// def GlobalMemrefOp : Std_Op<"global_memref", [Symbol]> { let summary = "declare or define a global memref variable"; let description = [{ The `global_memref` operation declares or defines a named global variable. The backing memory for the variable is allocated statically and is described by the type of the variable (which should be a statically shaped memref type). The operation is a declaration if no `initial_value` is specified, else it is a definition. The `initial_value` can either be a unit attribute to represent a definition of an uninitialized global variable, or an elements attribute to represent the definition of a global variable with an initial value. The global variable can also be marked constant using the `constant` unit attribute. Writing to such constant global variables is undefined. The global variable can be accessed by using the `get_global_memref` operation to retrieve the memref for the global variable.
Note that the memref for such a global variable is itself immutable (i.e., get_global_memref for a given global variable will always return the same memref descriptor). Example: ```mlir // Private variable with an initial value. global_memref "private" @x : memref<2xf32> = dense<0.0,2.0> // Declaration of an external variable. global_memref "private" @y : memref<4xi32> // Uninitialized externally visible variable. global_memref @z : memref<3xf16> = uninitialized // Externally visible constant variable. global_memref constant @c : memref<2xi32> = dense<1, 4> ``` }]; let arguments = (ins SymbolNameAttr:$sym_name, OptionalAttr:$sym_visibility, TypeAttr:$type, OptionalAttr:$initial_value, UnitAttr:$constant ); let assemblyFormat = [{ ($sym_visibility^)? (`constant` $constant^)? $sym_name `:` custom($type, $initial_value) attr-dict }]; let extraClassDeclaration = [{ bool isExternal() { return !initial_value(); } bool isUninitialized() { return !isExternal() && initial_value().getValue().isa(); } }]; } //===----------------------------------------------------------------------===// // GetGlobalMemrefOp //===----------------------------------------------------------------------===// def GetGlobalMemrefOp : Std_Op<"get_global_memref", [NoSideEffect, DeclareOpInterfaceMethods]> { let summary = "get the memref pointing to a global variable"; let description = [{ The `get_global_memref` operation retrieves the memref pointing to a named global variable. If the global variable is marked constant, writing to the result memref (such as through a `std.store` operation) is undefined. Example: ```mlir %x = get_global_memref @foo : memref<2xf32> ``` }]; let arguments = (ins FlatSymbolRefAttr:$name); let results = (outs AnyStaticShapeMemRef:$result); let assemblyFormat = "$name `:` type($result) attr-dict"; // `GetGlobalMemrefOp` is fully verified by its traits. let verifier = ?; } //===----------------------------------------------------------------------===// // ImOp //===----------------------------------------------------------------------===// def ImOp : Std_Op<"im", [NoSideEffect, TypesMatchWith<"complex element type matches result type", "complex", "imaginary", "$_self.cast().getElementType()">]> { let summary = "extracts the imaginary part of a complex number"; let description = [{ The `im` operation takes a single complex number as its operand and extracts the imaginary part as a floating-point value. Example: ```mlir %a = im %b : complex ``` }]; let arguments = (ins Complex:$complex); let results = (outs AnyFloat:$imaginary); let assemblyFormat = "$complex attr-dict `:` type($complex)"; // `ImOp` is fully verified by its traits. let verifier = ?; } //===----------------------------------------------------------------------===// // IndexCastOp //===----------------------------------------------------------------------===// def IndexCastOp : ArithmeticCastOp<"index_cast">, Arguments<(ins AnyType:$in)> { let summary = "cast between index and integer types"; let description = [{ Casts between integer scalars and 'index' scalars. Index is an integer of platform-specific bit width. If casting to a wider integer, the value is sign-extended. If casting to a narrower integer, the value is truncated. }]; let extraClassDeclaration = [{ /// Return true if `a` and `b` are valid operand and result pairs for /// the operation.
static bool areCastCompatible(Type a, Type b); }]; let hasFolder = 1; } //===----------------------------------------------------------------------===// // LoadOp //===----------------------------------------------------------------------===// def LoadOp : Std_Op<"load", [TypesMatchWith<"result type matches element type of 'memref'", "memref", "result", "$_self.cast().getElementType()">, MemRefsNormalizable]> { let summary = "load operation"; let description = [{ The `load` op reads an element from a memref specified by an index list. The output of load is a new value with the same type as the elements of the memref. The arity of indices is the rank of the memref (i.e., if the memref loaded from is of rank 3, then 3 indices are required for the load following the memref identifier). In an `affine.if` or `affine.for` body, the indices of a load are restricted to SSA values bound to surrounding loop induction variables, [symbols](Affine.md#dimensions-and-symbols), results of a [`constant` operation](#stdconstant-constantop), or the result of an `affine.apply` operation that can in turn take as arguments all of the aforementioned SSA values or, recursively, the result of such an `affine.apply` operation. Example: ```mlir %1 = affine.apply affine_map<(d0, d1) -> (3*d0)> (%i, %j) %2 = affine.apply affine_map<(d0, d1) -> (d1+1)> (%i, %j) %12 = load %A[%1, %2] : memref<8x?xi32, #layout, memspace0> // Example of an indirect load (treated as non-affine) %3 = affine.apply affine_map<(d0) -> (2*d0 + 1)>(%12) %13 = load %A[%3, %2] : memref<4x?xi32, #layout, memspace0> ``` **Context:** The `load` and `store` operations are specifically crafted to fully resolve a reference to an element of a memref, and (in `affine.if` and `affine.for` operations) the compiler can follow use-def chains (e.g. through [`affine.apply`](Affine.md#affineapply-affineapplyop) operations) to precisely analyze references at compile-time using polyhedral techniques. This is possible because of the [restrictions on dimensions and symbols](Affine.md#restrictions-on-dimensions-and-symbols) in these contexts.
}]; let arguments = (ins Arg:$memref, Variadic:$indices); let results = (outs AnyType:$result); let builders = [ OpBuilderDAG<(ins "Value":$memref, CArg<"ValueRange", "{}">:$indices), [{ auto memrefType = memref.getType().cast(); $_state.addOperands(memref); $_state.addOperands(indices); $_state.types.push_back(memrefType.getElementType()); }]>]; let extraClassDeclaration = [{ Value getMemRef() { return getOperand(0); } void setMemRef(Value value) { setOperand(0, value); } MemRefType getMemRefType() { return getMemRef().getType().cast(); } operand_range getIndices() { return {operand_begin() + 1, operand_end()}; } }]; let hasFolder = 1; let assemblyFormat = "$memref `[` $indices `]` attr-dict `:` type($memref)"; } //===----------------------------------------------------------------------===// // LogOp //===----------------------------------------------------------------------===// def LogOp : FloatUnaryOp<"log"> { let summary = "base-e logarithm of the specified value"; } def Log10Op : FloatUnaryOp<"log10"> { let summary = "base-10 logarithm of the specified value"; } def Log2Op : FloatUnaryOp<"log2"> { let summary = "base-2 logarithm of the specified value"; } //===----------------------------------------------------------------------===// // MemRefCastOp //===----------------------------------------------------------------------===// def MemRefCastOp : CastOp<"memref_cast", [ DeclareOpInterfaceMethods ]> { let summary = "memref cast operation"; let description = [{ Syntax: ``` operation ::= ssa-id `=` `std.memref_cast` ssa-use `:` type `to` type ``` The `memref_cast` operation converts a memref from one type to an equivalent type with a compatible shape. The source and destination types are compatible if: a. Both are ranked memref types with the same element type, address space, and rank and: 1. Both have the same layout or both have compatible strided layouts. 2. The individual sizes (resp. offset and strides in the case of strided memrefs) may convert constant dimensions to dynamic dimensions and vice-versa. If the cast converts any dimensions from an unknown to a known size, then it acts as an assertion that fails at runtime if the dynamic dimensions disagree with resultant destination size. Example: ```mlir // Assert that the input dynamic shape matches the destination static shape. %2 = memref_cast %1 : memref to memref<4x4xf32> // Erase static shape information, replacing it with dynamic information. %3 = memref_cast %1 : memref<4xf32> to memref // The same holds true for offsets and strides. // Assert that the input dynamic shape matches the destination static stride. %4 = memref_cast %1 : memref<12x4xf32, offset:?, strides: [?, ?]> to memref<12x4xf32, offset:5, strides: [4, 1]> // Erase static offset and stride information, replacing it with // dynamic information. %5 = memref_cast %1 : memref<12x4xf32, offset:5, strides: [4, 1]> to memref<12x4xf32, offset:?, strides: [?, ?]> ``` b. Either or both memref types are unranked with the same element type, and address space. Example: ```mlir Cast to concrete shape. %4 = memref_cast %1 : memref<*xf32> to memref<4x?xf32> Erase rank information. %5 = memref_cast %1 : memref<4x?xf32> to memref<*xf32> ``` }]; let arguments = (ins AnyRankedOrUnrankedMemRef:$source); let results = (outs AnyRankedOrUnrankedMemRef); let extraClassDeclaration = [{ /// Return true if `a` and `b` are valid operand and result pairs for /// the operation. static bool areCastCompatible(Type a, Type b); /// The result of a memref_cast is always a memref. 
Type getType() { return getResult().getType(); } }]; } //===----------------------------------------------------------------------===// // MemRefReinterpretCastOp //===----------------------------------------------------------------------===// def MemRefReinterpretCastOp: BaseOpWithOffsetSizesAndStrides<"memref_reinterpret_cast", [ NoSideEffect, ViewLikeOpInterface ]> { let summary = "memref reinterpret cast operation"; let description = [{ Modify offset, sizes and strides of an unranked/ranked memref. Example: ```mlir memref_reinterpret_cast %ranked to offset: [0], sizes: [%size0, 10], strides: [1, %stride1] : memref to memref memref_reinterpret_cast %unranked to offset: [%offset], sizes: [%size0, %size1], strides: [%stride0, %stride1] : memref<*xf32> to memref ``` }]; let arguments = (ins Arg:$source, Variadic:$offsets, Variadic:$sizes, Variadic:$strides, I64ArrayAttr:$static_offsets, I64ArrayAttr:$static_sizes, I64ArrayAttr:$static_strides ); let results = (outs AnyMemRef:$result); let builders = [ // Build a ReinterpretCastOp with mixed static and dynamic entries. OpBuilderDAG<(ins "MemRefType":$resultType, "Value":$source, "int64_t":$staticOffset, "ArrayRef":$staticSizes, "ArrayRef":$staticStrides, "ValueRange":$offset, "ValueRange":$sizes, "ValueRange":$strides, CArg<"ArrayRef", "{}">:$attrs)>, // Build a ReinterpretCastOp with all dynamic entries. OpBuilderDAG<(ins "MemRefType":$resultType, "Value":$source, "Value":$offset, "ValueRange":$sizes, "ValueRange":$strides, CArg<"ArrayRef", "{}">:$attrs)>, ]; let extraClassDeclaration = extraBaseClassDeclaration # [{ // The result of the op is always a ranked memref. MemRefType getType() { return getResult().getType().cast(); } Value getViewSource() { return source(); } }]; } //===----------------------------------------------------------------------===// // MemRefReshapeOp //===----------------------------------------------------------------------===// def MemRefReshapeOp: Std_Op<"memref_reshape", [ ViewLikeOpInterface, NoSideEffect]> { let summary = "memref reshape operation"; let description = [{ The `memref_reshape` operation converts a memref from one type to an equivalent type with a provided shape. The data is never copied or modified. The source and destination types are compatible if both have the same element type, same number of elements, address space and identity layout map. The following combinations are possible: a. Source type is ranked or unranked. Shape argument has static size. Result type is ranked. ```mlir // Reshape statically-shaped memref. %dst = memref_reshape %src(%shape) : (memref<4x1xf32>, memref<1xi32>) to memref<4xf32> %dst0 = memref_reshape %src(%shape0) : (memref<4x1xf32>, memref<2xi32>) to memref<2x2xf32> // Flatten unranked memref. %dst = memref_reshape %src(%shape) : (memref<*xf32>, memref<1xi32>) to memref ``` a. Source type is ranked or unranked. Shape argument has dynamic size. Result type is unranked. ```mlir // Reshape dynamically-shaped 1D memref. %dst = memref_reshape %src(%shape) : (memref, memref) to memref<*xf32> // Reshape unranked memref. 
%dst = memref_reshape %src(%shape) : (memref<*xf32>, memref) to memref<*xf32> ``` }]; let arguments = (ins AnyRankedOrUnrankedMemRef:$source, MemRefRankOf<[AnySignlessInteger, Index], [1]>:$shape ); let results = (outs AnyRankedOrUnrankedMemRef:$result); let builders = [OpBuilderDAG< (ins "MemRefType":$resultType, "Value":$operand, "Value":$shape), [{ $_state.addOperands(operand); $_state.addOperands(shape); $_state.addTypes(resultType); }]>]; let extraClassDeclaration = [{ MemRefType getType() { return getResult().getType().cast(); } Value getViewSource() { return source(); } }]; let assemblyFormat = [{ $source `(` $shape `)` attr-dict `:` functional-type(operands, results) }]; } //===----------------------------------------------------------------------===// // MulFOp //===----------------------------------------------------------------------===// def MulFOp : FloatArithmeticOp<"mulf"> { let summary = "floating point multiplication operation"; let description = [{ Syntax: ``` operation ::= ssa-id `=` `std.mulf` ssa-use `,` ssa-use `:` type ``` The `mulf` operation takes two operands and returns one result, each of these is required to be the same type. This type may be a floating point scalar type, a vector whose element type is a floating point type, or a floating point tensor. Example: ```mlir // Scalar multiplication. %a = mulf %b, %c : f64 // SIMD pointwise vector multiplication, e.g. for Intel SSE. %f = mulf %g, %h : vector<4xf32> // Tensor pointwise multiplication. %x = mulf %y, %z : tensor<4x?xbf16> ``` TODO: In the distant future, this will accept optional attributes for fast math, contraction, rounding mode, and other controls. }]; let hasFolder = 1; } //===----------------------------------------------------------------------===// // MulIOp //===----------------------------------------------------------------------===// def MulIOp : IntArithmeticOp<"muli", [Commutative]> { let summary = "integer multiplication operation"; let hasFolder = 1; } //===----------------------------------------------------------------------===// // NegFOp //===----------------------------------------------------------------------===// def NegFOp : FloatUnaryOp<"negf"> { let summary = "floating point negation"; let description = [{ Syntax: ``` operation ::= ssa-id `=` `negf` ssa-use `:` type ``` The `negf` operation computes the negation of a given value. It takes one operand and returns one result of the same type. This type may be a float scalar type, a vector whose element type is float, or a tensor of floats. It has no standard attributes. Example: ```mlir // Scalar negation value. %a = negf %b : f64 // SIMD vector element-wise negation value. %f = negf %g : vector<4xf32> // Tensor element-wise negation value. %x = negf %y : tensor<4x?xf8> ``` }]; } //===----------------------------------------------------------------------===// // OrOp //===----------------------------------------------------------------------===// def OrOp : IntArithmeticOp<"or", [Commutative]> { let summary = "integer binary or"; let description = [{ Syntax: ``` operation ::= ssa-id `=` `or` ssa-use `,` ssa-use `:` type ``` The `or` operation takes two operands and returns one result, each of these is required to be the same type. This type may be an integer scalar type, a vector whose element type is integer, or a tensor of integers. It has no standard attributes. Example: ```mlir // Scalar integer bitwise or. %a = or %b, %c : i64 // SIMD vector element-wise bitwise integer or. 
%f = or %g, %h : vector<4xi32> // Tensor element-wise bitwise integer or. %x = or %y, %z : tensor<4x?xi8> ``` }]; let hasFolder = 1; } //===----------------------------------------------------------------------===// // PrefetchOp //===----------------------------------------------------------------------===// def PrefetchOp : Std_Op<"prefetch"> { let summary = "prefetch operation"; let description = [{ The "prefetch" op prefetches data from a memref location described with subscript indices similar to std.load, and with three attributes: a read/write specifier, a locality hint, and a cache type specifier as shown below: ```mlir prefetch %0[%i, %j], read, locality<3>, data : memref<400x400xi32> ``` The read/write specifier is either 'read' or 'write', the locality hint ranges from locality<0> (no locality) to locality<3> (extremely local keep in cache). The cache type specifier is either 'data' or 'instr' and specifies whether the prefetch is performed on data cache or on instruction cache. }]; let arguments = (ins AnyMemRef:$memref, Variadic:$indices, BoolAttr:$isWrite, Confined, IntMaxValue<3>]>:$localityHint, BoolAttr:$isDataCache); let extraClassDeclaration = [{ MemRefType getMemRefType() { return memref().getType().cast(); } static StringRef getLocalityHintAttrName() { return "localityHint"; } static StringRef getIsWriteAttrName() { return "isWrite"; } static StringRef getIsDataCacheAttrName() { return "isDataCache"; } }]; let hasFolder = 1; } //===----------------------------------------------------------------------===// // RankOp //===----------------------------------------------------------------------===// def RankOp : Std_Op<"rank", [NoSideEffect]> { let summary = "rank operation"; let description = [{ The `rank` operation takes a memref/tensor operand and returns its rank. Example: ```mlir %1 = rank %arg0 : tensor<*xf32> %2 = rank %arg1 : memref<*xf32> ``` }]; let arguments = (ins AnyTypeOf<[AnyRankedOrUnrankedMemRef, AnyTensor], "any tensor or memref type">:$memrefOrTensor); let results = (outs Index); let verifier = ?; let builders = [ OpBuilderDAG<(ins "Value":$tensor), [{ auto indexType = $_builder.getIndexType(); build($_builder, $_state, indexType, tensor); }]>]; let hasFolder = 1; let assemblyFormat = "$memrefOrTensor attr-dict `:` type($memrefOrTensor)"; } //===----------------------------------------------------------------------===// // ReOp //===----------------------------------------------------------------------===// def ReOp : Std_Op<"re", [NoSideEffect, TypesMatchWith<"complex element type matches result type", "complex", "real", "$_self.cast().getElementType()">]> { let summary = "extracts the real part of a complex number"; let description = [{ The `re` operation takes a single complex number as its operand and extracts the real part as a floating-point value. Example: ```mlir %a = re %b : complex ``` }]; let arguments = (ins Complex:$complex); let results = (outs AnyFloat:$real); let assemblyFormat = "$complex attr-dict `:` type($complex)"; // `ReOp` is fully verified by its traits. 
let verifier = ?; } //===----------------------------------------------------------------------===// // RemFOp //===----------------------------------------------------------------------===// def RemFOp : FloatArithmeticOp<"remf"> { let summary = "floating point division remainder operation"; } //===----------------------------------------------------------------------===// // ReturnOp //===----------------------------------------------------------------------===// def ReturnOp : Std_Op<"return", [NoSideEffect, HasParent<"FuncOp">, MemRefsNormalizable, ReturnLike, Terminator]> { let summary = "return operation"; let description = [{ The `return` operation represents a return operation within a function. The operation takes variable number of operands and produces no results. The operand number and types must match the signature of the function that contains the operation. Example: ```mlir func @foo() : (i32, f8) { ... return %0, %1 : i32, f8 } ``` }]; let arguments = (ins Variadic:$operands); let builders = [ OpBuilderDAG<(ins), [{ build($_builder, $_state, llvm::None); }]>]; let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?"; } //===----------------------------------------------------------------------===// // RsqrtOp //===----------------------------------------------------------------------===// def RsqrtOp : FloatUnaryOp<"rsqrt"> { let summary = "reciprocal of sqrt (1 / sqrt of the specified value)"; let description = [{ The `rsqrt` operation computes the reciprocal of the square root. It takes one operand and returns one result of the same type. This type may be a float scalar type, a vector whose element type is float, or a tensor of floats. It has no standard attributes. }]; } //===----------------------------------------------------------------------===// // SelectOp //===----------------------------------------------------------------------===// def SelectOp : Std_Op<"select", [NoSideEffect, AllTypesMatch<["true_value", "false_value", "result"]>, ElementwiseMappable]> { let summary = "select operation"; let description = [{ The `select` operation chooses one value based on a binary condition supplied as its first operand. If the value of the first operand is `1`, the second operand is chosen, otherwise the third operand is chosen. The second and the third operand must have the same type. The operation applies to vectors and tensors elementwise given the _shape_ of all operands is identical. The choice is made for each element individually based on the value at the same position as the element in the condition operand. If an i1 is provided as the condition, the entire vector or tensor is chosen. The `select` operation combined with [`cmpi`](#stdcmpi-cmpiop) can be used to implement `min` and `max` with signed or unsigned comparison semantics. Example: ```mlir // Custom form of scalar selection. %x = select %cond, %true, %false : i32 // Generic form of the same operation. %x = "std.select"(%cond, %true, %false) : (i1, i32, i32) -> i32 // Element-wise vector selection. %vx = std.select %vcond, %vtrue, %vfalse : vector<42xi1>, vector<42xf32> // Full vector selection. 
%vx = std.select %cond, %vtrue, %vfalse : vector<42xf32> ``` }]; let arguments = (ins BoolLike:$condition, AnyType:$true_value, AnyType:$false_value); let results = (outs AnyType:$result); let builders = [ OpBuilderDAG<(ins "Value":$condition, "Value":$trueValue, "Value":$falseValue), [{ $_state.addOperands({condition, trueValue, falseValue}); $_state.addTypes(trueValue.getType()); }]>]; let extraClassDeclaration = [{ Value getCondition() { return condition(); } Value getTrueValue() { return true_value(); } Value getFalseValue() { return false_value(); } }]; let hasFolder = 1; } //===----------------------------------------------------------------------===// // ShiftLeftOp //===----------------------------------------------------------------------===// def ShiftLeftOp : IntArithmeticOp<"shift_left"> { let summary = "integer left-shift"; let description = [{ The shift_left operation shifts an integer value to the left by a variable amount. The low order bits are filled with zeros. Example: ```mlir %1 = constant 5 : i8 // %1 is 0b00000101 %2 = constant 3 : i8 %3 = shift_left %1, %2 : (i8, i8) -> i8 // %3 is 0b00101000 ``` }]; } //===----------------------------------------------------------------------===// // SignedDivIOp //===----------------------------------------------------------------------===// def SignedDivIOp : IntArithmeticOp<"divi_signed"> { let summary = "signed integer division operation"; let description = [{ Syntax: ``` operation ::= ssa-id `=` `divi_signed` ssa-use `,` ssa-use `:` type ``` Signed integer division. Rounds towards zero. Treats the leading bit as sign, i.e. `6 / -2 = -3`. Note: the semantics of division by zero or signed division overflow (minimum value divided by -1) is TBD; do NOT assume any specific behavior. Example: ```mlir // Scalar signed integer division. %a = divi_signed %b, %c : i64 // SIMD vector element-wise division. %f = divi_signed %g, %h : vector<4xi32> // Tensor element-wise integer division. %x = divi_signed %y, %z : tensor<4x?xi8> ``` }]; let hasFolder = 1; } //===----------------------------------------------------------------------===// // SignedFloorDivIOp //===----------------------------------------------------------------------===// def SignedFloorDivIOp : IntArithmeticOp<"floordivi_signed"> { let summary = "signed floor integer division operation"; let description = [{ Syntax: ``` operation ::= ssa-id `=` `floordivi_signed` ssa-use `,` ssa-use `:` type ``` Signed integer division. Rounds towards negative infinity, i.e. `5 / -2 = -3`. Note: the semantics of division by zero or signed division overflow (minimum value divided by -1) is TBD; do NOT assume any specific behavior. Example: ```mlir // Scalar signed integer division. %a = floordivi_signed %b, %c : i64 ``` }]; let hasFolder = 1; } //===----------------------------------------------------------------------===// // SignedCeilDivIOp //===----------------------------------------------------------------------===// def SignedCeilDivIOp : IntArithmeticOp<"ceildivi_signed"> { let summary = "signed ceil integer division operation"; let description = [{ Syntax: ``` operation ::= ssa-id `=` `ceildivi_signed` ssa-use `,` ssa-use `:` type ``` Signed integer division. Rounds towards positive infinity, i.e. `7 / -2 = -3`. Note: the semantics of division by zero or signed division overflow (minimum value divided by -1) is TBD; do NOT assume any specific behavior. Example: ```mlir // Scalar signed integer division.
%a = ceildivi_signed %b, %c : i64 ``` }]; let hasFolder = 1; } //===----------------------------------------------------------------------===// // SignedRemIOp //===----------------------------------------------------------------------===// def SignedRemIOp : IntArithmeticOp<"remi_signed"> { let summary = "signed integer division remainder operation"; let description = [{ Syntax: ``` operation ::= ssa-id `=` `std.remi_signed` ssa-use `,` ssa-use `:` type ``` Signed integer division remainder. Treats the leading bit as sign, i.e. `6 % -2 = 0`. Note: the semantics of division by zero is TBD; do NOT assume any specific behavior. Example: ```mlir // Scalar signed integer division remainder. %a = remi_signed %b, %c : i64 // SIMD vector element-wise division remainder. %f = remi_signed %g, %h : vector<4xi32> // Tensor element-wise integer division remainder. %x = remi_signed %y, %z : tensor<4x?xi8> ``` }]; let hasFolder = 1; } //===----------------------------------------------------------------------===// // SignedShiftRightOp //===----------------------------------------------------------------------===// def SignedShiftRightOp : IntArithmeticOp<"shift_right_signed"> { let summary = "signed integer right-shift"; let description = [{ The shift_right_signed operation shifts an integer value to the right by a variable amount. The integer is interpreted as signed. The high order bits in the output are filled with copies of the most-significant bit of the shifted value (which means that the sign of the value is preserved). Example: ```mlir %1 = constant 160 : i8 // %1 is 0b10100000 %2 = constant 3 : i8 %3 = shift_right_signed %1, %2 : (i8, i8) -> i8 // %3 is 0b11110100 %4 = constant 96 : i8 // %4 is 0b01100000 %5 = shift_right_signed %4, %2 : (i8, i8) -> i8 // %5 is 0b00001100 ``` }]; } //===----------------------------------------------------------------------===// // SignExtendIOp //===----------------------------------------------------------------------===// def SignExtendIOp : Std_Op<"sexti", [NoSideEffect, SameOperandsAndResultShape, ElementwiseMappable]> { let summary = "integer sign extension operation"; let description = [{ The integer sign extension operation takes an integer input of width M and an integer destination type of width N. The destination bit-width must be larger than the input bit-width (N > M). The top-most (N - M) bits of the output are filled with copies of the most-significant bit of the input. Example: ```mlir %1 = constant 5 : i3 // %1 is 0b101 %2 = sexti %1 : i3 to i6 // %2 is 0b111101 %3 = constant 2 : i3 // %3 is 0b010 %4 = sexti %3 : i3 to i6 // %4 is 0b000010 %5 = sexti %0 : vector<2 x i32> to vector<2 x i64> ``` }]; let arguments = (ins SignlessIntegerLike:$value); let results = (outs SignlessIntegerLike); let builders = [ OpBuilderDAG<(ins "Value":$value, "Type":$destType), [{ $_state.addOperands(value); $_state.addTypes(destType); }]>]; let parser = [{ return impl::parseCastOp(parser, result); }]; let printer = [{ return printStandardCastOp(this->getOperation(), p); }]; } //===----------------------------------------------------------------------===// // SIToFPOp //===----------------------------------------------------------------------===// def SIToFPOp : ArithmeticCastOp<"sitofp">, Arguments<(ins AnyType:$in)> { let summary = "cast from integer type to floating-point"; let description = [{ Cast from a value interpreted as a signed integer or a vector of signed integers to the corresponding floating-point scalar or vector value.
If the value cannot be exactly represented, it is rounded using the default rounding mode. Scalars and vector types are currently supported. }]; let extraClassDeclaration = [{ /// Return true if `a` and `b` are valid operand and result pairs for /// the operation. static bool areCastCompatible(Type a, Type b); }]; let hasFolder = 0; } //===----------------------------------------------------------------------===// // SplatOp //===----------------------------------------------------------------------===// def SplatOp : Std_Op<"splat", [NoSideEffect, TypesMatchWith<"operand type matches element type of result", "aggregate", "input", "$_self.cast().getElementType()">]> { let summary = "splat or broadcast operation"; let description = [{ Broadcast the operand to all elements of the result vector or tensor. The operand has to be of either integer or float type. When the result is a tensor, it has to be statically shaped. Example: ```mlir %s = load %A[%i] : memref<128xf32> %v = splat %s : vector<4xf32> %t = splat %s : tensor<8x16xi32> ``` TODO: This operation is easy to extend to broadcast to dynamically shaped tensors in the same way dynamically shaped memrefs are handled. ```mlir // Broadcasts %s to a 2-d dynamically shaped tensor, with %m, %n binding // to the sizes of the two dynamic dimensions. %m = "foo"() : () -> (index) %n = "bar"() : () -> (index) %t = splat %s [%m, %n] : tensor ``` }]; let arguments = (ins AnyTypeOf<[AnySignlessInteger, AnyFloat], "integer or float type">:$input); let results = (outs AnyTypeOf<[AnyVector, AnyStaticShapeTensor]>:$aggregate); let builders = [ OpBuilderDAG<(ins "Value":$element, "Type":$aggregateType), [{ build($_builder, $_state, aggregateType, element); }]>]; let hasFolder = 1; let assemblyFormat = "$input attr-dict `:` type($aggregate)"; } //===----------------------------------------------------------------------===// // SqrtOp //===----------------------------------------------------------------------===// def SqrtOp : FloatUnaryOp<"sqrt"> { let summary = "sqrt of the specified value"; let description = [{ The `sqrt` operation computes the square root. It takes one operand and returns one result of the same type. This type may be a float scalar type, a vector whose element type is float, or a tensor of floats. It has no standard attributes. Example: ```mlir // Scalar square root value. %a = sqrt %b : f64 // SIMD vector element-wise square root value. %f = sqrt %g : vector<4xf32> // Tensor element-wise square root value. %x = sqrt %y : tensor<4x?xf32> ``` }]; } //===----------------------------------------------------------------------===// // StoreOp //===----------------------------------------------------------------------===// def StoreOp : Std_Op<"store", [TypesMatchWith<"type of 'value' matches element type of 'memref'", "memref", "value", "$_self.cast().getElementType()">, MemRefsNormalizable]> { let summary = "store operation"; let description = [{ Store a value to a memref location given by indices. The value stored should have the same type as the elemental type of the memref. The number of arguments provided within brackets need to match the rank of the memref. 
In an affine context, the indices of a store are restricted to SSA values bound to surrounding loop induction variables, [symbols](Affine.md#restrictions-on-dimensions-and-symbols), results of a [`constant` operation](#stdconstant-constantop), or the result of an [`affine.apply`](Affine.md#affineapply-affineapplyop) operation that can in turn take as arguments all of the aforementioned SSA values or, recursively, the result of such an `affine.apply` operation. Example: ```mlir store %100, %A[%1, 1023] : memref<4x?xf32, #layout, memspace0> ``` **Context:** The `load` and `store` operations are specifically crafted to fully resolve a reference to an element of a memref, and (in polyhedral `affine.if` and `affine.for` operations) the compiler can follow use-def chains (e.g. through [`affine.apply`](Affine.md#affineapply-affineapplyop) operations) to precisely analyze references at compile-time using polyhedral techniques. This is possible because of the [restrictions on dimensions and symbols](Affine.md#restrictions-on-dimensions-and-symbols) in these contexts. }]; let arguments = (ins AnyType:$value, Arg:$memref, Variadic:$indices); let builders = [ OpBuilderDAG<(ins "Value":$valueToStore, "Value":$memref), [{ $_state.addOperands(valueToStore); $_state.addOperands(memref); }]>]; let extraClassDeclaration = [{ Value getValueToStore() { return getOperand(0); } Value getMemRef() { return getOperand(1); } void setMemRef(Value value) { setOperand(1, value); } MemRefType getMemRefType() { return getMemRef().getType().cast(); } operand_range getIndices() { return {operand_begin() + 2, operand_end()}; } }]; let hasFolder = 1; let assemblyFormat = [{ $value `,` $memref `[` $indices `]` attr-dict `:` type($memref) }]; } //===----------------------------------------------------------------------===// // SubCFOp //===----------------------------------------------------------------------===// def SubCFOp : ComplexFloatArithmeticOp<"subcf"> { let summary = "complex number subtraction"; let description = [{ The `subcf` operation takes two complex number operands and returns their difference, a single complex number. All operands and result must be of the same type, a complex number with a floating-point element type. Example: ```mlir %a = subcf %b, %c : complex ``` }]; } //===----------------------------------------------------------------------===// // SubFOp //===----------------------------------------------------------------------===// def SubFOp : FloatArithmeticOp<"subf"> { let summary = "floating point subtraction operation"; let hasFolder = 1; } //===----------------------------------------------------------------------===// // SubIOp //===----------------------------------------------------------------------===// def SubIOp : IntArithmeticOp<"subi"> { let summary = "integer subtraction operation"; let hasFolder = 1; } //===----------------------------------------------------------------------===// // SubViewOp //===----------------------------------------------------------------------===// def SubViewOp : BaseOpWithOffsetSizesAndStrides< "subview", [DeclareOpInterfaceMethods] > { let summary = "memref subview operation"; let description = [{ The "subview" operation converts a memref type to another memref type which represents a reduced-size view of the original memref as specified by the operation's offsets, sizes and strides arguments. The SubView operation supports the following arguments: * memref: the "base" memref on which to create a "view" memref.
* offsets: memref-rank number of offsets into the "base" memref at which to create the "view" memref. * sizes: memref-rank number of sizes which specify the sizes of the result "view" memref type. * strides: memref-rank number of strides that compose multiplicatively with the base memref strides in each dimension. The representation based on offsets, sizes and strides supports a partially-static specification via attributes specified through the `static_offsets`, `static_sizes` and `static_strides` arguments. The special sentinel values ShapedType::kDynamicSize and ShapedType::kDynamicStrideOrOffset encode that the corresponding entry has a dynamic value. A subview operation may additionally reduce the rank of the resulting view by removing dimensions that are statically known to be of size 1. Example 1: ```mlir %0 = alloc() : memref<64x4xf32, (d0, d1) -> (d0 * 4 + d1)> // Create a sub-view of "base" memref '%0' with offset arguments '%c0', // dynamic sizes for each dimension, and stride arguments '%c1'. %1 = subview %0[%c0, %c0][%size0, %size1][%c1, %c1] : memref<64x4xf32, (d0, d1) -> (d0 * 4 + d1) > to memref (d0 * s1 + d1 + s0)> ``` Example 2: ```mlir %0 = alloc() : memref<8x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)> // Create a sub-view of "base" memref '%0' with dynamic offsets, sizes, // and strides. // Note that dynamic offsets are represented by the linearized dynamic // offset symbol 's0' in the subview memref layout map, and that the // dynamic strides operands, after being applied to the base memref // strides in each dimension, are represented in the view memref layout // map as symbols 's1', 's2' and 's3'. %1 = subview %0[%i, %j, %k][%size0, %size1, %size2][%x, %y, %z] : memref<8x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)> to memref (d0 * s1 + d1 * s2 + d2 * s3 + s0)> ``` Example 3: ```mlir %0 = alloc() : memref<8x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)> // Subview with constant offsets, sizes and strides. %1 = subview %0[0, 2, 0][4, 4, 4][64, 4, 1] : memref<8x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)> to memref<4x4x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2 + 8)> ``` Example 4: ```mlir %0 = alloc(%arg0, %arg1) : memref // Subview with constant size, but dynamic offsets and // strides. The resulting memref has a static shape, but if the // base memref has an affine map to describe the layout, the result // memref also uses an affine map to describe the layout. The // strides of the result memref are computed as follows: // // Let #map1 represent the layout of the base memref, and #map2 // represent the layout of the result memref. A #mapsubview can be // constructed to map an index from the result memref to the base // memref (note that the description below uses more convenient // naming for symbols, while in affine maps, symbols are // represented as unsigned numbers that identify that symbol in the // given affine map). // // #mapsubview = (d0, d1)[o0, o1, t0, t1] -> (d0 * t0 + o0, d1 * t1 + o1) // // where, o0, o1, ... are offsets, and t0, t1, ... are strides. Then, // // #map2 = #map1.compose(#mapsubview) // // If the layout map is represented as // // #map1 = (d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0) // // then, // // #map2 = (d0, d1)[s0, s1, s2, o0, o1, t0, t1] -> // (d0 * s1 * t0 + d1 * s2 * t1 + o0 * s1 + o1 * s2 + s0) // // Representing this canonically // // #map2 = (d0, d1)[r0, r1, r2] -> (d0 * r1 + d1 * r2 + r0) // // where, r0 = o0 * s1 + o1 * s2 + s0, r1 = s1 * t0, r2 = s2 * t1.
%1 = subview %0[%i, %j][4, 4][%x, %y] : memref (d0 * s1 + d1 * s2 + s0)> to memref<4x4xf32, (d0, d1)[r0, r1, r2] -> (d0 * r1 + d1 * r2 + r0)> // Note that the subview op does not guarantee that the result // memref is "inbounds" w.r.t. the base memref. It is up to the client // to ensure that the subview is accessed in a manner that is // in-bounds. ``` Example 5: ```mlir // Rank-reducing subview. %1 = subview %0[0, 0, 0][1, 16, 4][1, 1, 1] : memref<8x16x4xf32> to memref<16x4xf32> %3 = subview %2[3, 4, 2][1, 6, 3][1, 1, 1] : memref<8x16x4xf32> to memref<6x3xf32, offset: 210, strides: [4, 1]> ``` }]; let arguments = (ins AnyMemRef:$source, Variadic:$offsets, Variadic:$sizes, Variadic:$strides, I64ArrayAttr:$static_offsets, I64ArrayAttr:$static_sizes, I64ArrayAttr:$static_strides ); let results = (outs AnyMemRef:$result); let builders = [ // Build a SubViewOp with mixed static and dynamic entries. OpBuilderDAG<(ins "Value":$source, "ArrayRef":$staticOffsets, "ArrayRef":$staticSizes, "ArrayRef":$staticStrides, "ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides, CArg<"ArrayRef", "{}">:$attrs)>, // Build a SubViewOp with all dynamic entries. OpBuilderDAG<(ins "Value":$source, "ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides, CArg<"ArrayRef", "{}">:$attrs)>, // Build a SubViewOp with mixed static and dynamic entries // and custom result type. OpBuilderDAG<(ins "MemRefType":$resultType, "Value":$source, "ArrayRef":$staticOffsets, "ArrayRef":$staticSizes, "ArrayRef":$staticStrides, "ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides, CArg<"ArrayRef", "{}">:$attrs)>, // Build a SubViewOp with all dynamic entries and custom result type. OpBuilderDAG<(ins "MemRefType":$resultType, "Value":$source, "ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides, CArg<"ArrayRef", "{}">:$attrs)> ]; let extraClassDeclaration = extraBaseClassDeclaration # [{ /// Returns the type of the base memref operand. MemRefType getSourceType() { return source().getType().cast(); } /// The result of a subview is always a memref. MemRefType getType() { return getResult().getType().cast(); } /// A subview result type can be fully inferred from the source type and the /// static representation of offsets, sizes and strides. Special sentinels /// encode the dynamic case. static Type inferResultType(MemRefType sourceMemRefType, ArrayRef staticOffsets, ArrayRef staticSizes, ArrayRef staticStrides); }]; let hasCanonicalizer = 1; let hasFolder = 1; } //===----------------------------------------------------------------------===// // SubTensorOp //===----------------------------------------------------------------------===// def SubTensorOp : BaseOpWithOffsetSizesAndStrides<"subtensor"> { let summary = "subtensor operation"; let description = [{ The "subtensor" operation extracts a tensor from another tensor as specified by the operation's offsets, sizes and strides arguments. The subtensor operation supports the following arguments: * tensor: the "base" tensor from which to extract a subtensor. * offsets: tensor-rank number of offsets into the "base" tensor from which to extract the subtensor. * sizes: tensor-rank number of sizes which specify the sizes of the result tensor type. * strides: tensor-rank number of strides specifying subsampling in each dimension. The representation based on offsets, sizes and strides supports a partially-static specification via attributes specified through the `static_offsets`, `static_sizes` and `static_strides` arguments.
The special sentinel values ShapedType::kDynamicSize and ShapedType::kDynamicStrideOrOffset encode that the corresponding entry has a dynamic value. After buffer-allocation, the "subtensor" op is expected to lower into a "subview" op. A subtensor operation may additionally reduce the rank of the resulting tensor by removing dimensions that are statically known to be of size 1. Example: ``` // Rank-reducing subtensor. %1 = subtensor %0[0, 0, 0][1, 16, 4][1, 1, 1] : tensor<8x16x4xf32> to tensor<16x4xf32> %3 = subtensor %2[3, 4, 2][1, 6, 3][1, 1, 1] : tensor<8x16x4xf32> to tensor<6x3xf32> ``` }]; let arguments = (ins AnyRankedTensor:$source, Variadic:$offsets, Variadic:$sizes, Variadic:$strides, I64ArrayAttr:$static_offsets, I64ArrayAttr:$static_sizes, I64ArrayAttr:$static_strides ); let results = (outs AnyRankedTensor:$result); let builders = [ // Build a SubTensorOp with mixed static and dynamic entries. OpBuilderDAG<(ins "Value":$source, "ArrayRef":$staticOffsets, "ArrayRef":$staticSizes, "ArrayRef":$staticStrides, "ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides, CArg<"ArrayRef", "{}">:$attrs)>, // Build a SubTensorOp with all dynamic entries. OpBuilderDAG<(ins "Value":$source, "ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides, CArg<"ArrayRef", "{}">:$attrs)> ]; let extraClassDeclaration = extraBaseClassDeclaration # [{ /// Returns the type of the base tensor operand. RankedTensorType getSourceType() { return source().getType().cast(); } /// The result of a subtensor is always a tensor. RankedTensorType getType() { return getResult().getType().cast(); } /// A subtensor result type can be fully inferred from the source type and the /// static representation of offsets, sizes and strides. Special sentinels /// encode the dynamic case. static Type inferResultType(RankedTensorType sourceRankedTensorType, ArrayRef staticOffsets, ArrayRef staticSizes, ArrayRef staticStrides); }]; let hasCanonicalizer = 1; } //===----------------------------------------------------------------------===// // SubTensorInsertOp //===----------------------------------------------------------------------===// def SubTensorInsertOp : BaseOpWithOffsetSizesAndStrides<"subtensor_insert"> { let summary = "subtensor_insert operation"; let description = [{ The "subtensor_insert" operation inserts a tensor `source` into another tensor `dest` as specified by the operation's offsets, sizes and strides arguments. It returns a copy of `dest` with the proper subtensor updated with the value of `source`. The subtensor_insert operation encodes the following information: * source: the tensor that is inserted. * dest: the tensor into which the source tensor is inserted. * offsets: tensor-rank number of offsets into the "base" tensor from which to extract the subtensor. * sizes: tensor-rank number of sizes which specify the sizes of the result tensor type. * strides: tensor-rank number of strides that specify subsampling in each dimension. The representation based on offsets, sizes and strides supports a partially-static specification via attributes specified through the `static_offsets`, `static_sizes` and `static_strides` arguments. The special sentinel values ShapedType::kDynamicSize and ShapedType::kDynamicStrideOrOffset encode that the corresponding entry has a dynamic value. After buffer-allocation, the "subtensor_insert" op is expected to become an in-place buffer update.
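Example (illustrative only; the SSA value names and shapes below are hypothetical):

```mlir
// Insert a statically shaped 4x4x4 tensor into an 8x16x4 tensor at
// offsets [0, 2, 0] with unit strides.
%inserted = subtensor_insert %small into %big[0, 2, 0][4, 4, 4][1, 1, 1] :
  tensor<4x4x4xf32> into tensor<8x16x4xf32>
```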
}]; let arguments = (ins AnyRankedTensor:$source, AnyRankedTensor:$dest, Variadic:$offsets, Variadic:$sizes, Variadic:$strides, I64ArrayAttr:$static_offsets, I64ArrayAttr:$static_sizes, I64ArrayAttr:$static_strides ); let results = (outs AnyRankedTensor:$result); let builders = [ // Build a SubViewOp with mixed static and dynamic entries. OpBuilderDAG<(ins "Value":$source, "Value":$dest, "ArrayRef":$staticOffsets, "ArrayRef":$staticSizes, "ArrayRef":$staticStrides, "ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides, CArg<"ArrayRef", "{}">:$attrs)>, // Build a SubViewOp with all dynamic entries. OpBuilderDAG<(ins "Value":$source, "Value":$dest, "ValueRange":$offsets, "ValueRange":$sizes, "ValueRange":$strides, CArg<"ArrayRef", "{}">:$attrs)> ]; let extraClassDeclaration = extraBaseClassDeclaration # [{ /// Returns the type of the base tensor operand. RankedTensorType getSourceType() { return source().getType().cast(); } /// The result of a subtensor is always a tensor. RankedTensorType getType() { return getResult().getType().cast(); } }]; } //===----------------------------------------------------------------------===// // TanhOp //===----------------------------------------------------------------------===// def TanhOp : FloatUnaryOp<"tanh"> { let summary = "hyperbolic tangent of the specified value"; let description = [{ Syntax: ``` operation ::= ssa-id `=` `std.tanh` ssa-use `:` type ``` The `tanh` operation computes the hyperbolic tangent. It takes one operand and returns one result of the same type. This type may be a float scalar type, a vector whose element type is float, or a tensor of floats. It has no standard attributes. Example: ```mlir // Scalar hyperbolic tangent value. %a = tanh %b : f64 // SIMD vector element-wise hyperbolic tangent value. %f = tanh %g : vector<4xf32> // Tensor element-wise hyperbolic tangent value. %x = tanh %y : tensor<4x?xf8> ``` }]; } //===----------------------------------------------------------------------===// // TensorCastOp //===----------------------------------------------------------------------===// def TensorCastOp : CastOp<"tensor_cast"> { let summary = "tensor cast operation"; let description = [{ Syntax: ``` operation ::= ssa-id `=` `std.tensor_cast` ssa-use `:` type `to` type ``` Convert a tensor from one type to an equivalent type without changing any data elements. The source and destination types must both be tensor types with the same element type. If both are ranked, then the rank should be the same and static dimensions should match. The operation is invalid if converting to a mismatching constant dimension. Example: ```mlir // Convert from unknown rank to rank 2 with unknown dimension sizes. %2 = "std.tensor_cast"(%1) : (tensor<*xf32>) -> tensor %2 = tensor_cast %1 : tensor<*xf32> to tensor // Convert to a type with more known dimensions. %3 = "std.tensor_cast"(%2) : (tensor) -> tensor<4x?xf32> // Discard static dimension and rank information. %4 = "std.tensor_cast"(%3) : (tensor<4x?xf32>) -> tensor %5 = "std.tensor_cast"(%4) : (tensor) -> tensor<*xf32> ``` }]; let arguments = (ins AnyTensor:$source); let results = (outs AnyTensor); let extraClassDeclaration = [{ /// Return true if `a` and `b` are valid operand and result pairs for /// the operation. static bool areCastCompatible(Type a, Type b); /// The result of a tensor_cast is always a tensor. 
TensorType getType() { return getResult().getType().cast(); } }]; let hasCanonicalizer = 1; } //===----------------------------------------------------------------------===// // TensorLoadOp //===----------------------------------------------------------------------===// def TensorLoadOp : Std_Op<"tensor_load", [SameOperandsAndResultShape, SameOperandsAndResultElementType, TypesMatchWith<"result type matches tensor equivalent of 'memref'", "memref", "result", "getTensorTypeFromMemRefType($_self)">]> { let summary = "tensor load operation"; let description = [{ Create a tensor from a memref, making an independent copy of the element data. The result value is a tensor whose shape and element type match the memref operand. The opposite of this op is tensor_to_memref. Together, these two ops are useful for source/target materializations when doing type conversions involving tensors and memrefs. Example: ```mlir // Produces a value of tensor<4x?xf32> type. %12 = tensor_load %10 : memref<4x?xf32, #layout, memspace0> ``` }]; let arguments = (ins Arg:$memref); let results = (outs AnyTensor:$result); // TensorLoadOp is fully verified by traits. let verifier = ?; let builders = [ OpBuilderDAG<(ins "Value":$memref), [{ $_state.addOperands(memref); $_state.addTypes(getTensorTypeFromMemRefType(memref.getType())); }]>]; let extraClassDeclaration = [{ /// The result of a tensor_load is always a tensor. TensorType getType() { Type resultType = getResult().getType(); if (resultType.isa()) return resultType.cast(); return {}; } }]; let assemblyFormat = "$memref attr-dict `:` type($memref)"; let hasFolder = 1; } //===----------------------------------------------------------------------===// // TensorStoreOp //===----------------------------------------------------------------------===// def TensorStoreOp : Std_Op<"tensor_store", [SameOperandsShape, SameOperandsElementType, TypesMatchWith<"type of 'value' matches tensor equivalent of 'memref'", "memref", "tensor", "getTensorTypeFromMemRefType($_self)">]> { let summary = "tensor store operation"; let description = [{ Stores the contents of a tensor into a memref. The first operand is a value of tensor type, the second operand is a value of memref type. The shapes and element types of these must match, and are specified by the memref type. Example: ```mlir %9 = dim %8, 1 : tensor<4x?xf32> %10 = alloc(%9) : memref<4x?xf32, #layout, memspace0> tensor_store %8, %10 : memref<4x?xf32, #layout, memspace0> ``` }]; let arguments = (ins AnyTensor:$tensor, Arg:$memref); // TensorStoreOp is fully verified by traits. let verifier = ?; let assemblyFormat = "$tensor `,` $memref attr-dict `:` type($memref)"; } //===----------------------------------------------------------------------===// // TensorToMemrefOp //===----------------------------------------------------------------------===// def TensorToMemrefOp : Std_Op<"tensor_to_memref", [SameOperandsAndResultShape, SameOperandsAndResultElementType, TypesMatchWith<"type of 'tensor' is the tensor equivalent of 'memref'", "memref", "tensor", "getTensorTypeFromMemRefType($_self)">]> { let summary = "tensor to memref operation"; let description = [{ Create a memref from a tensor. This is equivalent to allocating a new memref of the appropriate (possibly dynamic) shape, and then copying the elements (as if by a tensor_store op) into the newly allocated memref. The opposite of this op is tensor_load. Together, these two ops are useful for source/target materializations when doing type conversions involving tensors and memrefs. 
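For illustration (the value names here are hypothetical), a round trip through the two materialization ops looks like:

```mlir
// Materialize a memref copy of tensor %t, then load it back into a tensor.
%m = tensor_to_memref %t : memref<4xf32>
%t2 = tensor_load %m : memref<4xf32>
```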
Note: This op takes the memref type in its pretty form because the tensor type can always be inferred from the memref type, but the reverse is not true. For example, the memref might have a layout map or memory space which cannot be inferred from the tensor type. ```mlir // Result type is tensor<4x?xf32> %12 = tensor_to_memref %10 : memref<4x?xf32, #map0, 42> ``` }]; let arguments = (ins AnyTensor:$tensor); let results = (outs Res:$memref); // This op is fully verified by traits. let verifier = ?; let assemblyFormat = "$tensor attr-dict `:` type($memref)"; let hasFolder = 1; } //===----------------------------------------------------------------------===// // TransposeOp //===----------------------------------------------------------------------===// def TransposeOp : Std_Op<"transpose", [NoSideEffect]>, Arguments<(ins AnyStridedMemRef:$in, AffineMapAttr:$permutation)>, Results<(outs AnyStridedMemRef)> { let summary = "`transpose` produces a new strided memref (metadata-only)"; let description = [{ The `transpose` op produces a strided memref whose sizes and strides are a permutation of the original `in` memref. This is purely a metadata transformation. Example: ```mlir %1 = transpose %0 (i, j) -> (j, i) : memref to memref (d1 * s0 + d0)>> ``` }]; let builders = [ OpBuilderDAG<(ins "Value":$in, "AffineMapAttr":$permutation, CArg<"ArrayRef", "{}">:$attrs)>]; let extraClassDeclaration = [{ static StringRef getPermutationAttrName() { return "permutation"; } ShapedType getShapedType() { return in().getType().cast(); } }]; let hasFolder = 1; } //===----------------------------------------------------------------------===// // TruncateIOp //===----------------------------------------------------------------------===// def TruncateIOp : Std_Op<"trunci", [NoSideEffect, SameOperandsAndResultShape, ElementwiseMappable]> { let summary = "integer truncation operation"; let description = [{ The integer truncation operation takes an integer input of width M and an integer destination type of width N. The destination bit-width must be smaller than the input bit-width (N < M). The top-most (M - N) bits of the input are discarded. Example: ```mlir %1 = constant 21 : i5 // %1 is 0b10101 %2 = trunci %1 : i5 to i4 // %2 is 0b0101 %3 = trunci %1 : i5 to i3 // %3 is 0b101 %5 = trunci %0 : vector<2 x i32> to vector<2 x i16> ``` }]; let arguments = (ins SignlessIntegerLike:$value); let results = (outs SignlessIntegerLike); let builders = [ OpBuilderDAG<(ins "Value":$value, "Type":$destType), [{ $_state.addOperands(value); $_state.addTypes(destType); }]>]; let parser = [{ return impl::parseCastOp(parser, result); }]; let printer = [{ return printStandardCastOp(this->getOperation(), p); }]; } //===----------------------------------------------------------------------===// // UIToFPOp //===----------------------------------------------------------------------===// def UIToFPOp : ArithmeticCastOp<"uitofp">, Arguments<(ins AnyType:$in)> { let summary = "cast from unsigned integer type to floating-point"; let description = [{ Cast from a value interpreted as unsigned integer or vector of unsigned integers to the corresponding scalar or vector floating-point value. If the value cannot be exactly represented, it is rounded using the default rounding mode. Scalars and vector types are currently supported. }]; let extraClassDeclaration = [{ /// Return true if `a` and `b` are valid operand and result pairs for /// the operation.
static bool areCastCompatible(Type a, Type b); }]; let hasFolder = 0; } //===----------------------------------------------------------------------===// // UnsignedDivIOp //===----------------------------------------------------------------------===// def UnsignedDivIOp : IntArithmeticOp<"divi_unsigned"> { let summary = "unsigned integer division operation"; let description = [{ Syntax: ``` operation ::= ssa-id `=` `std.divi_unsigned` ssa-use `,` ssa-use `:` type ``` Unsigned integer division. Rounds towards zero. Treats the leading bit as the most significant, i.e. for `i16` given two's complement representation, `6 / -2 = 6 / (2^16 - 2) = 0`. Note: the semantics of division by zero is TBD; do NOT assume any specific behavior. Example: ```mlir // Scalar unsigned integer division. %a = diviu %b, %c : i64 // SIMD vector element-wise division. %f = diviu %g, %h : vector<4xi32> // Tensor element-wise integer division. %x = diviu %y, %z : tensor<4x?xi8> ``` }]; let hasFolder = 1; } //===----------------------------------------------------------------------===// // UnsignedRemIOp //===----------------------------------------------------------------------===// def UnsignedRemIOp : IntArithmeticOp<"remi_unsigned"> { let summary = "unsigned integer division remainder operation"; let description = [{ Syntax: ``` operation ::= ssa-id `=` `std.remi_unsigned` ssa-use `,` ssa-use `:` type ``` Unsigned integer division remainder. Treats the leading bit as the most significant, i.e. for `i16`, `6 % -2 = 6 % (2^16 - 2) = 6`. Note: the semantics of division by zero is TBD; do NOT assume any specific behavior. Example: ```mlir // Scalar unsigned integer division remainder. %a = remiu %b, %c : i64 // SIMD vector element-wise division remainder. %f = remiu %g, %h : vector<4xi32> // Tensor element-wise integer division remainder. %x = remiu %y, %z : tensor<4x?xi8> ``` }]; let hasFolder = 1; } //===----------------------------------------------------------------------===// // UnsignedShiftRightOp //===----------------------------------------------------------------------===// def UnsignedShiftRightOp : IntArithmeticOp<"shift_right_unsigned"> { let summary = "unsigned integer right-shift"; let description = [{ The shift_right_unsigned operation shifts an integer value to the right by a variable amount. The integer is interpreted as unsigned. The high order bits are always filled with zeros. Example: ```mlir %1 = constant 160 : i8 // %1 is 0b10100000 %2 = constant 3 : i8 %3 = shift_right_unsigned %1, %2 : (i8, i8) -> i8 // %3 is 0b00010100 ``` }]; } //===----------------------------------------------------------------------===// // ViewOp //===----------------------------------------------------------------------===// def ViewOp : Std_Op<"view", [ DeclareOpInterfaceMethods, NoSideEffect]> { let summary = "memref view operation"; let description = [{ The "view" operation extracts an N-D contiguous memref with empty layout map with arbitrary element type from a 1-D contiguous memref with empty layout map of i8 element type. The ViewOp supports the following arguments: * A single dynamic byte-shift operand must be specified which represents a a shift of the base 1-D memref pointer from which to create the resulting contiguous memref view with identity layout. * A dynamic size operand that must be specified for each dynamic dimension in the resulting view memref type. The "view" operation gives a structured indexing form to a flat 1-D buffer. Unlike "subview" it can perform a type change. 
The type change behavior requires the op to have special semantics because, e.g. a byte shift of 3 cannot be represented as an offset on f64. For now, a "view" op: 1. Only takes a contiguous source memref with 0 offset and empty layout. 2. Must specify a byte_shift operand (in the future, a special integer attribute may be added to support the folded case). 3. Returns a contiguous memref with 0 offset and empty layout. Example: ```mlir // Allocate a flat 1D/i8 memref. %0 = alloc() : memref<2048xi8> // ViewOp with dynamic offset and static sizes. %1 = view %0[%offset_1024][] : memref<2048xi8> to memref<64x4xf32> // ViewOp with dynamic offset and two dynamic size. %2 = view %0[%offset_1024][%size0, %size1] : memref<2048xi8> to memref ``` }]; let arguments = (ins MemRefRankOf<[I8], [1]>:$source, Index:$byte_shift, Variadic:$sizes); let results = (outs AnyMemRef); let extraClassDeclaration = [{ /// The result of a view is always a memref. MemRefType getType() { return getResult().getType().cast(); } /// Returns the dynamic sizes for this view operation. This is redundant /// with `sizes` but needed in template implementations. More specifically: /// ``` /// template /// bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index, /// Region *region) /// ``` operand_range getDynamicSizes() { return {sizes().begin(), sizes().end()}; } }]; let hasCanonicalizer = 1; } //===----------------------------------------------------------------------===// // YieldOp //===----------------------------------------------------------------------===// def YieldOp : Std_Op<"yield", [NoSideEffect, ReturnLike, Terminator, HasParent<"DynamicTensorFromElementsOp">]> { let summary = "Yield a value from a region"; let description = [{ This operation is used to yield a single value from a within a region. It is used to create dynamically sized tensors (see `DynamicTensorFromElementsOp`). }]; let arguments = (ins AnyType:$value); let assemblyFormat = "$value attr-dict `:` type($value)"; let verifier = ?; } //===----------------------------------------------------------------------===// // XOrOp //===----------------------------------------------------------------------===// def XOrOp : IntArithmeticOp<"xor", [Commutative]> { let summary = "integer binary xor"; let description = [{ The `xor` operation takes two operands and returns one result, each of these is required to be the same type. This type may be an integer scalar type, a vector whose element type is integer, or a tensor of integers. It has no standard attributes. Example: ```mlir // Scalar integer bitwise xor. %a = xor %b, %c : i64 // SIMD vector element-wise bitwise integer xor. %f = xor %g, %h : vector<4xi32> // Tensor element-wise bitwise integer xor. %x = xor %y, %z : tensor<4x?xi8> ``` }]; let hasFolder = 1; } //===----------------------------------------------------------------------===// // ZeroExtendIOp //===----------------------------------------------------------------------===// def ZeroExtendIOp : Std_Op<"zexti", [NoSideEffect, SameOperandsAndResultShape, ElementwiseMappable]> { let summary = "integer zero extension operation"; let description = [{ The integer zero extension operation takes an integer input of width M and an integer destination type of width N. The destination bit-width must be larger than the input bit-width (N > M). The top-most (N - M) bits of the output are filled with zeros. 
Example: ```mlir %1 = constant 5 : i3 // %1 is 0b101 %2 = zexti %1 : i3 to i6 // %2 is 0b000101 %3 = constant 2 : i3 // %3 is 0b010 %4 = zexti %3 : i3 to i6 // %4 is 0b000010 %5 = zexti %0 : vector<2 x i32> to vector<2 x i64> ``` }]; let arguments = (ins SignlessIntegerLike:$value); let results = (outs SignlessIntegerLike); let builders = [ OpBuilderDAG<(ins "Value":$value, "Type":$destType), [{ $_state.addOperands(value); $_state.addTypes(destType); }]>]; let parser = [{ return impl::parseCastOp(parser, result); }]; let printer = [{ return printStandardCastOp(this->getOperation(), p); }]; } #endif // STANDARD_OPS diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index 4cb6821c9c15..4c2196ff176f 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -1,2997 +1,3024 @@ //===- AffineOps.cpp - MLIR Affine Operations -----------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Affine/IR/AffineValueMap.h" #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/IR/Function.h" #include "mlir/IR/IntegerSet.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Transforms/InliningUtils.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallBitVector.h" #include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Debug.h" using namespace mlir; using llvm::dbgs; #define DEBUG_TYPE "affine-analysis" //===----------------------------------------------------------------------===// // AffineDialect Interfaces //===----------------------------------------------------------------------===// namespace { /// This class defines the interface for handling inlining with affine /// operations. struct AffineInlinerInterface : public DialectInlinerInterface { using DialectInlinerInterface::DialectInlinerInterface; //===--------------------------------------------------------------------===// // Analysis Hooks //===--------------------------------------------------------------------===// /// Returns true if the given region 'src' can be inlined into the region /// 'dest' that is attached to an operation registered to the current dialect. bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned, BlockAndValueMapping &valueMapping) const final { // Conservatively don't allow inlining into affine structures. return false; } /// Returns true if the given operation 'op', that is registered to this /// dialect, can be inlined into the given region, false otherwise. bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned, BlockAndValueMapping &valueMapping) const final { // Always allow inlining affine operations into the top-level region of a // function. There are some edge cases when inlining *into* affine // structures, but that is handled in the other 'isLegalToInline' hook // above. // TODO: We should be able to inline into other regions than functions. return isa(region->getParentOp()); } /// Affine regions should be analyzed recursively. 
bool shouldAnalyzeRecursively(Operation *op) const final { return true; } }; } // end anonymous namespace //===----------------------------------------------------------------------===// // AffineDialect //===----------------------------------------------------------------------===// void AffineDialect::initialize() { addOperations(); addInterfaces(); } /// Materialize a single constant operation from a given attribute value with /// the desired resultant type. Operation *AffineDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { return builder.create(loc, type, value); } /// A utility function to check if a value is defined at the top level of an /// op with trait `AffineScope`. If the value is defined in an unlinked region, /// conservatively assume it is not top-level. A value of index type defined at /// the top level is always a valid symbol. bool mlir::isTopLevelValue(Value value) { if (auto arg = value.dyn_cast()) { // The block owning the argument may be unlinked, e.g. when the surrounding // region has not yet been attached to an Op, at which point the parent Op // is null. Operation *parentOp = arg.getOwner()->getParentOp(); return parentOp && parentOp->hasTrait(); } // The defining Op may live in an unlinked block so its parent Op may be null. Operation *parentOp = value.getDefiningOp()->getParentOp(); return parentOp && parentOp->hasTrait(); } /// A utility function to check if a value is defined at the top level of /// `region` or is an argument of `region`. A value of index type defined at the /// top level of a `AffineScope` region is always a valid symbol for all /// uses in that region. static bool isTopLevelValue(Value value, Region *region) { if (auto arg = value.dyn_cast()) return arg.getParentRegion() == region; return value.getDefiningOp()->getParentRegion() == region; } /// Returns the closest region enclosing `op` that is held by an operation with /// trait `AffineScope`; `nullptr` if there is no such region. // TODO: getAffineScope should be publicly exposed for affine passes/utilities. static Region *getAffineScope(Operation *op) { auto *curOp = op; while (auto *parentOp = curOp->getParentOp()) { if (parentOp->hasTrait()) return curOp->getParentRegion(); curOp = parentOp; } return nullptr; } // A Value can be used as a dimension id iff it meets one of the following // conditions: // *) It is valid as a symbol. // *) It is an induction variable. // *) It is the result of affine apply operation with dimension id arguments. bool mlir::isValidDim(Value value) { // The value must be an index type. if (!value.getType().isIndex()) return false; if (auto *defOp = value.getDefiningOp()) return isValidDim(value, getAffineScope(defOp)); // This value has to be a block argument for an op that has the // `AffineScope` trait or for an affine.for or affine.parallel. auto *parentOp = value.cast().getOwner()->getParentOp(); return parentOp && (parentOp->hasTrait() || isa(parentOp)); } // Value can be used as a dimension id iff it meets one of the following // conditions: // *) It is valid as a symbol. // *) It is an induction variable. // *) It is the result of an affine apply operation with dimension id operands. bool mlir::isValidDim(Value value, Region *region) { // The value must be an index type. if (!value.getType().isIndex()) return false; // All valid symbols are okay. 
if (isValidSymbol(value, region)) return true; auto *op = value.getDefiningOp(); if (!op) { // This value has to be a block argument for an affine.for or an // affine.parallel. auto *parentOp = value.cast().getOwner()->getParentOp(); return isa(parentOp); } // Affine apply operation is ok if all of its operands are ok. if (auto applyOp = dyn_cast(op)) return applyOp.isValidDim(region); // The dim op is okay if its operand memref/tensor is defined at the top // level. if (auto dimOp = dyn_cast(op)) return isTopLevelValue(dimOp.memrefOrTensor()); return false; } /// Returns true if the 'index' dimension of the `memref` defined by /// `memrefDefOp` is a statically shaped one or defined using a valid symbol /// for `region`. template static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index, Region *region) { auto memRefType = memrefDefOp.getType(); // Statically shaped. if (!memRefType.isDynamicDim(index)) return true; // Get the position of the dimension among dynamic dimensions; unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index); return isValidSymbol(*(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos), region); } /// Returns true if the result of the dim op is a valid symbol for `region`. static bool isDimOpValidSymbol(DimOp dimOp, Region *region) { // The dim op is okay if its operand memref/tensor is defined at the top // level. if (isTopLevelValue(dimOp.memrefOrTensor())) return true; // Conservatively handle remaining BlockArguments as non-valid symbols. // E.g. scf.for iterArgs. if (dimOp.memrefOrTensor().isa()) return false; // The dim op is also okay if its operand memref/tensor is a view/subview // whose corresponding size is a valid symbol. Optional index = dimOp.getConstantIndex(); assert(index.hasValue() && "expect only `dim` operations with a constant index"); int64_t i = index.getValue(); return TypeSwitch(dimOp.memrefOrTensor().getDefiningOp()) .Case( [&](auto op) { return isMemRefSizeValidSymbol(op, i, region); }) .Default([](Operation *) { return false; }); } // A value can be used as a symbol (at all its use sites) iff it meets one of // the following conditions: // *) It is a constant. // *) Its defining op or block arg appearance is immediately enclosed by an op // with `AffineScope` trait. // *) It is the result of an affine.apply operation with symbol operands. // *) It is a result of the dim op on a memref whose corresponding size is a // valid symbol. bool mlir::isValidSymbol(Value value) { // The value must be an index type. if (!value.getType().isIndex()) return false; // Check that the value is a top level value. if (isTopLevelValue(value)) return true; if (auto *defOp = value.getDefiningOp()) return isValidSymbol(value, getAffineScope(defOp)); return false; } /// A value can be used as a symbol for `region` iff it meets onf of the the /// following conditions: /// *) It is a constant. /// *) It is the result of an affine apply operation with symbol arguments. /// *) It is a result of the dim op on a memref whose corresponding size is /// a valid symbol. /// *) It is defined at the top level of 'region' or is its argument. /// *) It dominates `region`'s parent op. /// If `region` is null, conservatively assume the symbol definition scope does /// not exist and only accept the values that would be symbols regardless of /// the surrounding region structure, i.e. the first three cases above. bool mlir::isValidSymbol(Value value, Region *region) { // The value must be an index type. 
if (!value.getType().isIndex()) return false; // A top-level value is a valid symbol. if (region && ::isTopLevelValue(value, region)) return true; auto *defOp = value.getDefiningOp(); if (!defOp) { // A block argument that is not a top-level value is a valid symbol if it // dominates region's parent op. if (region && !region->getParentOp()->isKnownIsolatedFromAbove()) if (auto *parentOpRegion = region->getParentOp()->getParentRegion()) return isValidSymbol(value, parentOpRegion); return false; } // Constant operation is ok. Attribute operandCst; if (matchPattern(defOp, m_Constant(&operandCst))) return true; // Affine apply operation is ok if all of its operands are ok. if (auto applyOp = dyn_cast(defOp)) return applyOp.isValidSymbol(region); // Dim op results could be valid symbols at any level. if (auto dimOp = dyn_cast(defOp)) return isDimOpValidSymbol(dimOp, region); // Check for values dominating `region`'s parent op. if (region && !region->getParentOp()->isKnownIsolatedFromAbove()) if (auto *parentRegion = region->getParentOp()->getParentRegion()) return isValidSymbol(value, parentRegion); return false; } // Returns true if 'value' is a valid index to an affine operation (e.g. // affine.load, affine.store, affine.dma_start, affine.dma_wait) where // `region` provides the polyhedral symbol scope. Returns false otherwise. static bool isValidAffineIndexOperand(Value value, Region *region) { return isValidDim(value, region) || isValidSymbol(value, region); } +/// Prints dimension and symbol list. +static void printDimAndSymbolList(Operation::operand_iterator begin, + Operation::operand_iterator end, + unsigned numDims, OpAsmPrinter &printer) { + OperandRange operands(begin, end); + printer << '(' << operands.take_front(numDims) << ')'; + if (operands.size() > numDims) + printer << '[' << operands.drop_front(numDims) << ']'; +} + +/// Parses dimension and symbol list and returns true if parsing failed. +static ParseResult parseDimAndSymbolList(OpAsmParser &parser, + SmallVectorImpl &operands, + unsigned &numDims) { + SmallVector opInfos; + if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::Paren)) + return failure(); + // Store number of dimensions for validation by caller. + numDims = opInfos.size(); + + // Parse the optional symbol operands. + auto indexTy = parser.getBuilder().getIndexType(); + return failure(parser.parseOperandList( + opInfos, OpAsmParser::Delimiter::OptionalSquare) || + parser.resolveOperands(opInfos, indexTy, operands)); +} + /// Utility function to verify that a set of operands are valid dimension and /// symbol identifiers. The operands should be laid out such that the dimension /// operands are before the symbol operands. This function returns failure if /// there was an invalid operand. An operation is provided to emit any necessary /// errors. 
template static LogicalResult verifyDimAndSymbolIdentifiers(OpTy &op, Operation::operand_range operands, unsigned numDims) { unsigned opIt = 0; for (auto operand : operands) { if (opIt++ < numDims) { if (!isValidDim(operand, getAffineScope(op))) return op.emitOpError("operand cannot be used as a dimension id"); } else if (!isValidSymbol(operand, getAffineScope(op))) { return op.emitOpError("operand cannot be used as a symbol"); } } return success(); } //===----------------------------------------------------------------------===// // AffineApplyOp //===----------------------------------------------------------------------===// AffineValueMap AffineApplyOp::getAffineValueMap() { return AffineValueMap(getAffineMap(), getOperands(), getResult()); } static ParseResult parseAffineApplyOp(OpAsmParser &parser, OperationState &result) { auto &builder = parser.getBuilder(); auto indexTy = builder.getIndexType(); AffineMapAttr mapAttr; unsigned numDims; if (parser.parseAttribute(mapAttr, "map", result.attributes) || parseDimAndSymbolList(parser, result.operands, numDims) || parser.parseOptionalAttrDict(result.attributes)) return failure(); auto map = mapAttr.getValue(); if (map.getNumDims() != numDims || numDims + map.getNumSymbols() != result.operands.size()) { return parser.emitError(parser.getNameLoc(), "dimension or symbol index mismatch"); } result.types.append(map.getNumResults(), indexTy); return success(); } static void print(OpAsmPrinter &p, AffineApplyOp op) { p << AffineApplyOp::getOperationName() << " " << op.mapAttr(); printDimAndSymbolList(op.operand_begin(), op.operand_end(), op.getAffineMap().getNumDims(), p); p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"map"}); } static LogicalResult verify(AffineApplyOp op) { // Check input and output dimensions match. auto map = op.map(); // Verify that operand count matches affine map dimension and symbol count. if (op.getNumOperands() != map.getNumDims() + map.getNumSymbols()) return op.emitOpError( "operand count and affine map dimension and symbol count must match"); // Verify that the map only produces one result. if (map.getNumResults() != 1) return op.emitOpError("mapping must produce one value"); return success(); } // The result of the affine apply operation can be used as a dimension id if all // its operands are valid dimension ids. bool AffineApplyOp::isValidDim() { return llvm::all_of(getOperands(), [](Value op) { return mlir::isValidDim(op); }); } // The result of the affine apply operation can be used as a dimension id if all // its operands are valid dimension ids with the parent operation of `region` // defining the polyhedral scope for symbols. bool AffineApplyOp::isValidDim(Region *region) { return llvm::all_of(getOperands(), [&](Value op) { return ::isValidDim(op, region); }); } // The result of the affine apply operation can be used as a symbol if all its // operands are symbols. bool AffineApplyOp::isValidSymbol() { return llvm::all_of(getOperands(), [](Value op) { return mlir::isValidSymbol(op); }); } // The result of the affine apply operation can be used as a symbol in `region` // if all its operands are symbols in `region`. bool AffineApplyOp::isValidSymbol(Region *region) { return llvm::all_of(getOperands(), [&](Value operand) { return mlir::isValidSymbol(operand, region); }); } OpFoldResult AffineApplyOp::fold(ArrayRef operands) { auto map = getAffineMap(); // Fold dims and symbols to existing values. 
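// For instance (hypothetical IR), applying affine_map<(d0, d1) -> (d1)> to
// operands (%a, %b) folds directly to %b without creating any new operation.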
auto expr = map.getResult(0); if (auto dim = expr.dyn_cast()) return getOperand(dim.getPosition()); if (auto sym = expr.dyn_cast()) return getOperand(map.getNumDims() + sym.getPosition()); // Otherwise, default to folding the map. SmallVector result; if (failed(map.constantFold(operands, result))) return {}; return result[0]; } AffineDimExpr AffineApplyNormalizer::renumberOneDim(Value v) { DenseMap::iterator iterPos; bool inserted = false; std::tie(iterPos, inserted) = dimValueToPosition.insert(std::make_pair(v, dimValueToPosition.size())); if (inserted) { reorderedDims.push_back(v); } return getAffineDimExpr(iterPos->second, v.getContext()) .cast(); } AffineMap AffineApplyNormalizer::renumber(const AffineApplyNormalizer &other) { SmallVector dimRemapping; for (auto v : other.reorderedDims) { auto kvp = other.dimValueToPosition.find(v); if (dimRemapping.size() <= kvp->second) dimRemapping.resize(kvp->second + 1); dimRemapping[kvp->second] = renumberOneDim(kvp->first); } unsigned numSymbols = concatenatedSymbols.size(); unsigned numOtherSymbols = other.concatenatedSymbols.size(); SmallVector symRemapping(numOtherSymbols); for (unsigned idx = 0; idx < numOtherSymbols; ++idx) { symRemapping[idx] = getAffineSymbolExpr(idx + numSymbols, other.affineMap.getContext()); } concatenatedSymbols.insert(concatenatedSymbols.end(), other.concatenatedSymbols.begin(), other.concatenatedSymbols.end()); auto map = other.affineMap; return map.replaceDimsAndSymbols(dimRemapping, symRemapping, reorderedDims.size(), concatenatedSymbols.size()); } // Gather the positions of the operands that are produced by an AffineApplyOp. static llvm::SetVector indicesFromAffineApplyOp(ArrayRef operands) { llvm::SetVector res; for (auto en : llvm::enumerate(operands)) if (isa_and_nonnull(en.value().getDefiningOp())) res.insert(en.index()); return res; } // Support the special case of a symbol coming from an AffineApplyOp that needs // to be composed into the current AffineApplyOp. // This case is handled by rewriting all such symbols into dims for the purpose // of allowing mathematical AffineMap composition. // Returns an AffineMap where symbols that come from an AffineApplyOp have been // rewritten as dims and are ordered after the original dims. // TODO: This promotion makes AffineMap lose track of which // symbols are represented as dims. This loss is static but can still be // recovered dynamically (with `isValidSymbol`). Still this is annoying for the // semi-affine map case. A dynamic canonicalization of all dims that are valid // symbols (a.k.a `canonicalizePromotedSymbols`) into symbols helps and even // results in better simplifications and foldings. But we should evaluate // whether this behavior is what we really want after using more. static AffineMap promoteComposedSymbolsAsDims(AffineMap map, ArrayRef symbols) { if (symbols.empty()) { return map; } // Sanity check on symbols. for (auto sym : symbols) { assert(isValidSymbol(sym) && "Expected only valid symbols"); (void)sym; } // Extract the symbol positions that come from an AffineApplyOp and // needs to be rewritten as dims. auto symPositions = indicesFromAffineApplyOp(symbols); if (symPositions.empty()) { return map; } // Create the new map by replacing each symbol at pos by the next new dim. unsigned numDims = map.getNumDims(); unsigned numSymbols = map.getNumSymbols(); unsigned numNewDims = 0; unsigned numNewSymbols = 0; SmallVector symReplacements(numSymbols); for (unsigned i = 0; i < numSymbols; ++i) { symReplacements[i] = symPositions.count(i) > 0 ? 
getAffineDimExpr(numDims + numNewDims++, map.getContext()) : getAffineSymbolExpr(numNewSymbols++, map.getContext()); } assert(numSymbols >= numNewDims); AffineMap newMap = map.replaceDimsAndSymbols( {}, symReplacements, numDims + numNewDims, numNewSymbols); return newMap; } /// The AffineNormalizer composes AffineApplyOp recursively. Its purpose is to /// keep a correspondence between the mathematical `map` and the `operands` of /// a given AffineApplyOp. This correspondence is maintained by iterating over /// the operands and forming an `auxiliaryMap` that can be composed /// mathematically with `map`. To keep this correspondence in cases where /// symbols are produced by affine.apply operations, we perform a local rewrite /// of symbols as dims. /// /// Rationale for locally rewriting symbols as dims: /// ================================================ /// The mathematical composition of AffineMap must always concatenate symbols /// because it does not have enough information to do otherwise. For example, /// composing `(d0)[s0] -> (d0 + s0)` with itself must produce /// `(d0)[s0, s1] -> (d0 + s0 + s1)`. /// /// The result is only equivalent to `(d0)[s0] -> (d0 + 2 * s0)` when /// applied to the same mlir::Value for both s0 and s1. /// As a consequence mathematical composition of AffineMap always concatenates /// symbols. /// /// When AffineMaps are used in AffineApplyOp however, they may specify /// composition via symbols, which is ambiguous mathematically. This corner case /// is handled by locally rewriting such symbols that come from AffineApplyOp /// into dims and composing through dims. /// TODO: Composition via symbols comes at a significant code /// complexity. Alternatively we should investigate whether we want to /// explicitly disallow symbols coming from affine.apply and instead force the /// user to compose symbols beforehand. The annoyances may be small (i.e. 1 or 2 /// extra API calls for such uses, which haven't popped up until now) and the /// benefit potentially big: simpler and more maintainable code for a /// non-trivial, recursive, procedure. AffineApplyNormalizer::AffineApplyNormalizer(AffineMap map, ArrayRef operands) : AffineApplyNormalizer() { static_assert(kMaxAffineApplyDepth > 0, "kMaxAffineApplyDepth must be > 0"); assert(map.getNumInputs() == operands.size() && "number of operands does not match the number of map inputs"); LLVM_DEBUG(map.print(dbgs() << "\nInput map: ")); // Promote symbols that come from an AffineApplyOp to dims by rewriting the // map to always refer to: // (dims, symbols coming from AffineApplyOp, other symbols). // The order of operands can remain unchanged. // This is a simplification that relies on 2 ordering properties: // 1. rewritten symbols always appear after the original dims in the map; // 2. operands are traversed in order and either dispatched to: // a. auxiliaryExprs (dims and symbols rewritten as dims); // b. concatenatedSymbols (all other symbols) // This allows operand order to remain unchanged. unsigned numDimsBeforeRewrite = map.getNumDims(); map = promoteComposedSymbolsAsDims(map, operands.take_back(map.getNumSymbols())); LLVM_DEBUG(map.print(dbgs() << "\nRewritten map: ")); SmallVector auxiliaryExprs; bool furtherCompose = (affineApplyDepth() <= kMaxAffineApplyDepth); // We fully spell out the 2 cases below. In this particular instance a little // code duplication greatly improves readability. // Note that the first branch would disappear if we only supported full // composition (i.e. infinite kMaxAffineApplyDepth). 
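// As a hypothetical illustration of what the recursive composition below
// achieves: folding %i = affine.apply affine_map<(d0) -> (d0 + 1)>(%j) into
// affine.apply affine_map<(d0) -> (2 * d0)>(%i) yields
// affine.apply affine_map<(d0) -> (2 * d0 + 2)>(%j).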
if (!furtherCompose) { // 1. Only dispatch dims or symbols. for (auto en : llvm::enumerate(operands)) { auto t = en.value(); assert(t.getType().isIndex()); bool isDim = (en.index() < map.getNumDims()); if (isDim) { // a. The mathematical composition of AffineMap composes dims. auxiliaryExprs.push_back(renumberOneDim(t)); } else { // b. The mathematical composition of AffineMap concatenates symbols. // We do the same for symbol operands. concatenatedSymbols.push_back(t); } } } else { assert(numDimsBeforeRewrite <= operands.size()); // 2. Compose AffineApplyOps and dispatch dims or symbols. for (unsigned i = 0, e = operands.size(); i < e; ++i) { auto t = operands[i]; auto affineApply = t.getDefiningOp(); if (affineApply) { // a. Compose affine.apply operations. LLVM_DEBUG(affineApply.getOperation()->print( dbgs() << "\nCompose AffineApplyOp recursively: ")); AffineMap affineApplyMap = affineApply.getAffineMap(); SmallVector affineApplyOperands( affineApply.getOperands().begin(), affineApply.getOperands().end()); AffineApplyNormalizer normalizer(affineApplyMap, affineApplyOperands); LLVM_DEBUG(normalizer.affineMap.print( dbgs() << "\nRenumber into current normalizer: ")); auto renumberedMap = renumber(normalizer); LLVM_DEBUG( renumberedMap.print(dbgs() << "\nRecursive composition yields: ")); auxiliaryExprs.push_back(renumberedMap.getResult(0)); } else { if (i < numDimsBeforeRewrite) { // b. The mathematical composition of AffineMap composes dims. auxiliaryExprs.push_back(renumberOneDim(t)); } else { // c. The mathematical composition of AffineMap concatenates symbols. // Note that the map composition will put symbols already present // in the map before any symbols coming from the auxiliary map, so // we insert them before any symbols that are due to renumbering, // and after the proper symbols we have seen already. concatenatedSymbols.insert( std::next(concatenatedSymbols.begin(), numProperSymbols++), t); } } } } // Early exit if `map` is already composed. if (auxiliaryExprs.empty()) { affineMap = map; return; } assert(concatenatedSymbols.size() >= map.getNumSymbols() && "Unexpected number of concatenated symbols"); auto numDims = dimValueToPosition.size(); auto numSymbols = concatenatedSymbols.size() - map.getNumSymbols(); auto auxiliaryMap = AffineMap::get(numDims, numSymbols, auxiliaryExprs, map.getContext()); LLVM_DEBUG(map.print(dbgs() << "\nCompose map: ")); LLVM_DEBUG(auxiliaryMap.print(dbgs() << "\nWith map: ")); LLVM_DEBUG(map.compose(auxiliaryMap).print(dbgs() << "\nResult: ")); // TODO: Disabling simplification results in major speed gains. // Another option is to cache the results as it is expected a lot of redundant // work is performed in practice. affineMap = simplifyAffineMap(map.compose(auxiliaryMap)); LLVM_DEBUG(affineMap.print(dbgs() << "\nSimplified result: ")); LLVM_DEBUG(dbgs() << "\n"); } void AffineApplyNormalizer::normalize(AffineMap *otherMap, SmallVectorImpl *otherOperands) { AffineApplyNormalizer other(*otherMap, *otherOperands); *otherMap = renumber(other); otherOperands->reserve(reorderedDims.size() + concatenatedSymbols.size()); otherOperands->assign(reorderedDims.begin(), reorderedDims.end()); otherOperands->append(concatenatedSymbols.begin(), concatenatedSymbols.end()); } /// Implements `map` and `operands` composition and simplification to support /// `makeComposedAffineApply`. This can be called to achieve the same effects /// on `map` and `operands` without creating an AffineApplyOp that needs to be /// immediately deleted. 
static void composeAffineMapAndOperands(AffineMap *map, SmallVectorImpl *operands) { AffineApplyNormalizer normalizer(*map, *operands); auto normalizedMap = normalizer.getAffineMap(); auto normalizedOperands = normalizer.getOperands(); canonicalizeMapAndOperands(&normalizedMap, &normalizedOperands); *map = normalizedMap; *operands = normalizedOperands; assert(*map); } void mlir::fullyComposeAffineMapAndOperands(AffineMap *map, SmallVectorImpl *operands) { while (llvm::any_of(*operands, [](Value v) { return isa_and_nonnull(v.getDefiningOp()); })) { composeAffineMapAndOperands(map, operands); } } AffineApplyOp mlir::makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef operands) { AffineMap normalizedMap = map; SmallVector normalizedOperands(operands.begin(), operands.end()); composeAffineMapAndOperands(&normalizedMap, &normalizedOperands); assert(normalizedMap); return b.create(loc, normalizedMap, normalizedOperands); } // A symbol may appear as a dim in affine.apply operations. This function // canonicalizes dims that are valid symbols into actual symbols. template static void canonicalizePromotedSymbols(MapOrSet *mapOrSet, SmallVectorImpl *operands) { if (!mapOrSet || operands->empty()) return; assert(mapOrSet->getNumInputs() == operands->size() && "map/set inputs must match number of operands"); auto *context = mapOrSet->getContext(); SmallVector resultOperands; resultOperands.reserve(operands->size()); SmallVector remappedSymbols; remappedSymbols.reserve(operands->size()); unsigned nextDim = 0; unsigned nextSym = 0; unsigned oldNumSyms = mapOrSet->getNumSymbols(); SmallVector dimRemapping(mapOrSet->getNumDims()); for (unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) { if (i < mapOrSet->getNumDims()) { if (isValidSymbol((*operands)[i])) { // This is a valid symbol that appears as a dim, canonicalize it. dimRemapping[i] = getAffineSymbolExpr(oldNumSyms + nextSym++, context); remappedSymbols.push_back((*operands)[i]); } else { dimRemapping[i] = getAffineDimExpr(nextDim++, context); resultOperands.push_back((*operands)[i]); } } else { resultOperands.push_back((*operands)[i]); } } resultOperands.append(remappedSymbols.begin(), remappedSymbols.end()); *operands = resultOperands; *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, {}, nextDim, oldNumSyms + nextSym); assert(mapOrSet->getNumInputs() == operands->size() && "map/set inputs must match number of operands"); } // Works for either an affine map or an integer set. template static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet, SmallVectorImpl *operands) { static_assert(llvm::is_one_of::value, "Argument must be either of AffineMap or IntegerSet type"); if (!mapOrSet || operands->empty()) return; assert(mapOrSet->getNumInputs() == operands->size() && "map/set inputs must match number of operands"); canonicalizePromotedSymbols(mapOrSet, operands); // Check to see what dims are used. 
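// For example (hypothetical), affine_map<(d0, d1)[s0] -> (d0 + s0)> with
// operands (%a, %b)[%c] has an unused d1; canonicalization drops it, giving
// affine_map<(d0)[s0] -> (d0 + s0)> with operands (%a)[%c].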
llvm::SmallBitVector usedDims(mapOrSet->getNumDims()); llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols()); mapOrSet->walkExprs([&](AffineExpr expr) { if (auto dimExpr = expr.dyn_cast()) usedDims[dimExpr.getPosition()] = true; else if (auto symExpr = expr.dyn_cast()) usedSyms[symExpr.getPosition()] = true; }); auto *context = mapOrSet->getContext(); SmallVector resultOperands; resultOperands.reserve(operands->size()); llvm::SmallDenseMap seenDims; SmallVector dimRemapping(mapOrSet->getNumDims()); unsigned nextDim = 0; for (unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) { if (usedDims[i]) { // Remap dim positions for duplicate operands. auto it = seenDims.find((*operands)[i]); if (it == seenDims.end()) { dimRemapping[i] = getAffineDimExpr(nextDim++, context); resultOperands.push_back((*operands)[i]); seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i])); } else { dimRemapping[i] = it->second; } } } llvm::SmallDenseMap seenSymbols; SmallVector symRemapping(mapOrSet->getNumSymbols()); unsigned nextSym = 0; for (unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) { if (!usedSyms[i]) continue; // Handle constant operands (only needed for symbolic operands since // constant operands in dimensional positions would have already been // promoted to symbolic positions above). IntegerAttr operandCst; if (matchPattern((*operands)[i + mapOrSet->getNumDims()], m_Constant(&operandCst))) { symRemapping[i] = getAffineConstantExpr(operandCst.getValue().getSExtValue(), context); continue; } // Remap symbol positions for duplicate operands. auto it = seenSymbols.find((*operands)[i + mapOrSet->getNumDims()]); if (it == seenSymbols.end()) { symRemapping[i] = getAffineSymbolExpr(nextSym++, context); resultOperands.push_back((*operands)[i + mapOrSet->getNumDims()]); seenSymbols.insert(std::make_pair((*operands)[i + mapOrSet->getNumDims()], symRemapping[i])); } else { symRemapping[i] = it->second; } } *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, symRemapping, nextDim, nextSym); *operands = resultOperands; } void mlir::canonicalizeMapAndOperands(AffineMap *map, SmallVectorImpl *operands) { canonicalizeMapOrSetAndOperands(map, operands); } void mlir::canonicalizeSetAndOperands(IntegerSet *set, SmallVectorImpl *operands) { canonicalizeMapOrSetAndOperands(set, operands); } namespace { /// Simplify AffineApply, AffineLoad, and AffineStore operations by composing /// maps that supply results into them. /// template struct SimplifyAffineOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; /// Replace the affine op with another instance of it with the supplied /// map and mapOperands. void replaceAffineOp(PatternRewriter &rewriter, AffineOpTy affineOp, AffineMap map, ArrayRef mapOperands) const; LogicalResult matchAndRewrite(AffineOpTy affineOp, PatternRewriter &rewriter) const override { static_assert(llvm::is_one_of::value, "affine load/store/apply/prefetch/min/max op expected"); auto map = affineOp.getAffineMap(); AffineMap oldMap = map; auto oldOperands = affineOp.getMapOperands(); SmallVector resultOperands(oldOperands); composeAffineMapAndOperands(&map, &resultOperands); if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(), resultOperands.begin())) return failure(); replaceAffineOp(rewriter, affineOp, map, resultOperands); return success(); } }; // Specialize the template to account for the different build signatures for // affine load, store, and apply ops. 
template <> void SimplifyAffineOp::replaceAffineOp( PatternRewriter &rewriter, AffineLoadOp load, AffineMap map, ArrayRef mapOperands) const { rewriter.replaceOpWithNewOp(load, load.getMemRef(), map, mapOperands); } template <> void SimplifyAffineOp::replaceAffineOp( PatternRewriter &rewriter, AffinePrefetchOp prefetch, AffineMap map, ArrayRef mapOperands) const { rewriter.replaceOpWithNewOp( prefetch, prefetch.memref(), map, mapOperands, prefetch.localityHint(), prefetch.isWrite(), prefetch.isDataCache()); } template <> void SimplifyAffineOp::replaceAffineOp( PatternRewriter &rewriter, AffineStoreOp store, AffineMap map, ArrayRef mapOperands) const { rewriter.replaceOpWithNewOp( store, store.getValueToStore(), store.getMemRef(), map, mapOperands); } // Generic version for ops that don't have extra operands. template void SimplifyAffineOp::replaceAffineOp( PatternRewriter &rewriter, AffineOpTy op, AffineMap map, ArrayRef mapOperands) const { rewriter.replaceOpWithNewOp(op, map, mapOperands); } } // end anonymous namespace. void AffineApplyOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { results.insert>(context); } //===----------------------------------------------------------------------===// // Common canonicalization pattern support logic //===----------------------------------------------------------------------===// /// This is a common class used for patterns of the form /// "someop(memrefcast) -> someop". It folds the source of any memref_cast /// into the root operation directly. static LogicalResult foldMemRefCast(Operation *op) { bool folded = false; for (OpOperand &operand : op->getOpOperands()) { auto cast = operand.get().getDefiningOp(); if (cast && !cast.getOperand().getType().isa()) { operand.set(cast.getOperand()); folded = true; } } return success(folded); } //===----------------------------------------------------------------------===// // AffineDmaStartOp //===----------------------------------------------------------------------===// // TODO: Check that map operands are loop IVs or symbols. void AffineDmaStartOp::build(OpBuilder &builder, OperationState &result, Value srcMemRef, AffineMap srcMap, ValueRange srcIndices, Value destMemRef, AffineMap dstMap, ValueRange destIndices, Value tagMemRef, AffineMap tagMap, ValueRange tagIndices, Value numElements, Value stride, Value elementsPerStride) { result.addOperands(srcMemRef); result.addAttribute(getSrcMapAttrName(), AffineMapAttr::get(srcMap)); result.addOperands(srcIndices); result.addOperands(destMemRef); result.addAttribute(getDstMapAttrName(), AffineMapAttr::get(dstMap)); result.addOperands(destIndices); result.addOperands(tagMemRef); result.addAttribute(getTagMapAttrName(), AffineMapAttr::get(tagMap)); result.addOperands(tagIndices); result.addOperands(numElements); if (stride) { result.addOperands({stride, elementsPerStride}); } } void AffineDmaStartOp::print(OpAsmPrinter &p) { p << "affine.dma_start " << getSrcMemRef() << '['; p.printAffineMapOfSSAIds(getSrcMapAttr(), getSrcIndices()); p << "], " << getDstMemRef() << '['; p.printAffineMapOfSSAIds(getDstMapAttr(), getDstIndices()); p << "], " << getTagMemRef() << '['; p.printAffineMapOfSSAIds(getTagMapAttr(), getTagIndices()); p << "], " << getNumElements(); if (isStrided()) { p << ", " << getStride(); p << ", " << getNumElementsPerStride(); } p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", " << getTagMemRefType(); } // Parse AffineDmaStartOp. 
// Ex: // affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %size, // %stride, %num_elt_per_stride // : memref<3076 x f32, 0>, memref<1024 x f32, 2>, memref<1 x i32> // ParseResult AffineDmaStartOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::OperandType srcMemRefInfo; AffineMapAttr srcMapAttr; SmallVector srcMapOperands; OpAsmParser::OperandType dstMemRefInfo; AffineMapAttr dstMapAttr; SmallVector dstMapOperands; OpAsmParser::OperandType tagMemRefInfo; AffineMapAttr tagMapAttr; SmallVector tagMapOperands; OpAsmParser::OperandType numElementsInfo; SmallVector strideInfo; SmallVector types; auto indexType = parser.getBuilder().getIndexType(); // Parse and resolve the following list of operands: // *) dst memref followed by its affine maps operands (in square brackets). // *) src memref followed by its affine map operands (in square brackets). // *) tag memref followed by its affine map operands (in square brackets). // *) number of elements transferred by DMA operation. if (parser.parseOperand(srcMemRefInfo) || parser.parseAffineMapOfSSAIds(srcMapOperands, srcMapAttr, getSrcMapAttrName(), result.attributes) || parser.parseComma() || parser.parseOperand(dstMemRefInfo) || parser.parseAffineMapOfSSAIds(dstMapOperands, dstMapAttr, getDstMapAttrName(), result.attributes) || parser.parseComma() || parser.parseOperand(tagMemRefInfo) || parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr, getTagMapAttrName(), result.attributes) || parser.parseComma() || parser.parseOperand(numElementsInfo)) return failure(); // Parse optional stride and elements per stride. if (parser.parseTrailingOperandList(strideInfo)) { return failure(); } if (!strideInfo.empty() && strideInfo.size() != 2) { return parser.emitError(parser.getNameLoc(), "expected two stride related operands"); } bool isStrided = strideInfo.size() == 2; if (parser.parseColonTypeList(types)) return failure(); if (types.size() != 3) return parser.emitError(parser.getNameLoc(), "expected three types"); if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) || parser.resolveOperands(srcMapOperands, indexType, result.operands) || parser.resolveOperand(dstMemRefInfo, types[1], result.operands) || parser.resolveOperands(dstMapOperands, indexType, result.operands) || parser.resolveOperand(tagMemRefInfo, types[2], result.operands) || parser.resolveOperands(tagMapOperands, indexType, result.operands) || parser.resolveOperand(numElementsInfo, indexType, result.operands)) return failure(); if (isStrided) { if (parser.resolveOperands(strideInfo, indexType, result.operands)) return failure(); } // Check that src/dst/tag operand counts match their map.numInputs. if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() || dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() || tagMapOperands.size() != tagMapAttr.getValue().getNumInputs()) return parser.emitError(parser.getNameLoc(), "memref operand count not equal to map.numInputs"); return success(); } LogicalResult AffineDmaStartOp::verify() { if (!getOperand(getSrcMemRefOperandIndex()).getType().isa()) return emitOpError("expected DMA source to be of memref type"); if (!getOperand(getDstMemRefOperandIndex()).getType().isa()) return emitOpError("expected DMA destination to be of memref type"); if (!getOperand(getTagMemRefOperandIndex()).getType().isa()) return emitOpError("expected DMA tag to be of memref type"); // DMAs from different memory spaces supported. 
  if (getSrcMemorySpace() == getDstMemorySpace()) {
    return emitOpError("DMA should be between different memory spaces");
  }
  unsigned numInputsAllMaps = getSrcMap().getNumInputs() +
                              getDstMap().getNumInputs() +
                              getTagMap().getNumInputs();
  if (getNumOperands() != numInputsAllMaps + 3 + 1 &&
      getNumOperands() != numInputsAllMaps + 3 + 1 + 2) {
    return emitOpError("incorrect number of operands");
  }

  Region *scope = getAffineScope(*this);
  for (auto idx : getSrcIndices()) {
    if (!idx.getType().isIndex())
      return emitOpError("src index to dma_start must have 'index' type");
    if (!isValidAffineIndexOperand(idx, scope))
      return emitOpError("src index must be a dimension or symbol identifier");
  }
  for (auto idx : getDstIndices()) {
    if (!idx.getType().isIndex())
      return emitOpError("dst index to dma_start must have 'index' type");
    if (!isValidAffineIndexOperand(idx, scope))
      return emitOpError("dst index must be a dimension or symbol identifier");
  }
  for (auto idx : getTagIndices()) {
    if (!idx.getType().isIndex())
      return emitOpError("tag index to dma_start must have 'index' type");
    if (!isValidAffineIndexOperand(idx, scope))
      return emitOpError("tag index must be a dimension or symbol identifier");
  }
  return success();
}

LogicalResult AffineDmaStartOp::fold(ArrayRef<Attribute> cstOperands,
                                     SmallVectorImpl<OpFoldResult> &results) {
  /// dma_start(memrefcast) -> dma_start
  return foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// AffineDmaWaitOp
//===----------------------------------------------------------------------===//

// TODO: Check that map operands are loop IVs or symbols.
void AffineDmaWaitOp::build(OpBuilder &builder, OperationState &result,
                            Value tagMemRef, AffineMap tagMap,
                            ValueRange tagIndices, Value numElements) {
  result.addOperands(tagMemRef);
  result.addAttribute(getTagMapAttrName(), AffineMapAttr::get(tagMap));
  result.addOperands(tagIndices);
  result.addOperands(numElements);
}

void AffineDmaWaitOp::print(OpAsmPrinter &p) {
  p << "affine.dma_wait " << getTagMemRef() << '[';
  SmallVector<Value, 2> operands(getTagIndices());
  p.printAffineMapOfSSAIds(getTagMapAttr(), operands);
  p << "], ";
  p.printOperand(getNumElements());
  p << " : " << getTagMemRef().getType();
}

// Parse AffineDmaWaitOp.
// Eg:
//   affine.dma_wait %tag[%index], %num_elements
//     : memref<1 x i32, (d0) -> (d0), 4>
//
ParseResult AffineDmaWaitOp::parse(OpAsmParser &parser,
                                   OperationState &result) {
  OpAsmParser::OperandType tagMemRefInfo;
  AffineMapAttr tagMapAttr;
  SmallVector<OpAsmParser::OperandType, 2> tagMapOperands;
  Type type;
  auto indexType = parser.getBuilder().getIndexType();
  OpAsmParser::OperandType numElementsInfo;

  // Parse tag memref, its map operands, and dma size.
if (parser.parseOperand(tagMemRefInfo) || parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr, getTagMapAttrName(), result.attributes) || parser.parseComma() || parser.parseOperand(numElementsInfo) || parser.parseColonType(type) || parser.resolveOperand(tagMemRefInfo, type, result.operands) || parser.resolveOperands(tagMapOperands, indexType, result.operands) || parser.resolveOperand(numElementsInfo, indexType, result.operands)) return failure(); if (!type.isa()) return parser.emitError(parser.getNameLoc(), "expected tag to be of memref type"); if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs()) return parser.emitError(parser.getNameLoc(), "tag memref operand count != to map.numInputs"); return success(); } LogicalResult AffineDmaWaitOp::verify() { if (!getOperand(0).getType().isa()) return emitOpError("expected DMA tag to be of memref type"); Region *scope = getAffineScope(*this); for (auto idx : getTagIndices()) { if (!idx.getType().isIndex()) return emitOpError("index to dma_wait must have 'index' type"); if (!isValidAffineIndexOperand(idx, scope)) return emitOpError("index must be a dimension or symbol identifier"); } return success(); } LogicalResult AffineDmaWaitOp::fold(ArrayRef cstOperands, SmallVectorImpl &results) { /// dma_wait(memrefcast) -> dma_wait return foldMemRefCast(*this); } //===----------------------------------------------------------------------===// // AffineForOp //===----------------------------------------------------------------------===// /// 'bodyBuilder' is used to build the body of affine.for. If iterArgs and /// bodyBuilder are empty/null, we include default terminator op. void AffineForOp::build(OpBuilder &builder, OperationState &result, ValueRange lbOperands, AffineMap lbMap, ValueRange ubOperands, AffineMap ubMap, int64_t step, ValueRange iterArgs, BodyBuilderFn bodyBuilder) { assert(((!lbMap && lbOperands.empty()) || lbOperands.size() == lbMap.getNumInputs()) && "lower bound operand count does not match the affine map"); assert(((!ubMap && ubOperands.empty()) || ubOperands.size() == ubMap.getNumInputs()) && "upper bound operand count does not match the affine map"); assert(step > 0 && "step has to be a positive integer constant"); for (Value val : iterArgs) result.addTypes(val.getType()); // Add an attribute for the step. result.addAttribute(getStepAttrName(), builder.getIntegerAttr(builder.getIndexType(), step)); // Add the lower bound. result.addAttribute(getLowerBoundAttrName(), AffineMapAttr::get(lbMap)); result.addOperands(lbOperands); // Add the upper bound. result.addAttribute(getUpperBoundAttrName(), AffineMapAttr::get(ubMap)); result.addOperands(ubOperands); result.addOperands(iterArgs); // Create a region and a block for the body. The argument of the region is // the loop induction variable. Region *bodyRegion = result.addRegion(); bodyRegion->push_back(new Block); Block &bodyBlock = bodyRegion->front(); Value inductionVar = bodyBlock.addArgument(builder.getIndexType()); for (Value val : iterArgs) bodyBlock.addArgument(val.getType()); // Create the default terminator if the builder is not provided and if the // iteration arguments are not provided. Otherwise, leave this to the caller // because we don't know which values to return from the loop. 
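  // (Illustrative note, not from the original source: when iter_args are
  // present, the caller-supplied body is expected to end with an explicit
  // terminator yielding the carried values, e.g.
  //   affine.yield %next_acc : f32
  // whereas the default terminator created below is an empty affine.yield.)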
if (iterArgs.empty() && !bodyBuilder) { ensureTerminator(*bodyRegion, builder, result.location); } else if (bodyBuilder) { OpBuilder::InsertionGuard guard(builder); builder.setInsertionPointToStart(&bodyBlock); bodyBuilder(builder, result.location, inductionVar, bodyBlock.getArguments().drop_front()); } } void AffineForOp::build(OpBuilder &builder, OperationState &result, int64_t lb, int64_t ub, int64_t step, ValueRange iterArgs, BodyBuilderFn bodyBuilder) { auto lbMap = AffineMap::getConstantMap(lb, builder.getContext()); auto ubMap = AffineMap::getConstantMap(ub, builder.getContext()); return build(builder, result, {}, lbMap, {}, ubMap, step, iterArgs, bodyBuilder); } static LogicalResult verify(AffineForOp op) { // Check that the body defines as single block argument for the induction // variable. auto *body = op.getBody(); if (body->getNumArguments() == 0 || !body->getArgument(0).getType().isIndex()) return op.emitOpError( "expected body to have a single index argument for the " "induction variable"); // Verify that the bound operands are valid dimension/symbols. /// Lower bound. if (op.getLowerBoundMap().getNumInputs() > 0) if (failed( verifyDimAndSymbolIdentifiers(op, op.getLowerBoundOperands(), op.getLowerBoundMap().getNumDims()))) return failure(); /// Upper bound. if (op.getUpperBoundMap().getNumInputs() > 0) if (failed( verifyDimAndSymbolIdentifiers(op, op.getUpperBoundOperands(), op.getUpperBoundMap().getNumDims()))) return failure(); unsigned opNumResults = op.getNumResults(); if (opNumResults == 0) return success(); // If ForOp defines values, check that the number and types of the defined // values match ForOp initial iter operands and backedge basic block // arguments. if (op.getNumIterOperands() != opNumResults) return op.emitOpError( "mismatch between the number of loop-carried values and results"); if (op.getNumRegionIterArgs() != opNumResults) return op.emitOpError( "mismatch between the number of basic block args and results"); return success(); } /// Parse a for operation loop bounds. static ParseResult parseBound(bool isLower, OperationState &result, OpAsmParser &p) { // 'min' / 'max' prefixes are generally syntactic sugar, but are required if // the map has multiple results. bool failedToParsedMinMax = failed(p.parseOptionalKeyword(isLower ? "max" : "min")); auto &builder = p.getBuilder(); auto boundAttrName = isLower ? AffineForOp::getLowerBoundAttrName() : AffineForOp::getUpperBoundAttrName(); // Parse ssa-id as identity map. SmallVector boundOpInfos; if (p.parseOperandList(boundOpInfos)) return failure(); if (!boundOpInfos.empty()) { // Check that only one operand was parsed. if (boundOpInfos.size() > 1) return p.emitError(p.getNameLoc(), "expected only one loop bound operand"); // TODO: improve error message when SSA value is not of index type. // Currently it is 'use of value ... expects different type than prior uses' if (p.resolveOperand(boundOpInfos.front(), builder.getIndexType(), result.operands)) return failure(); // Create an identity map using symbol id. This representation is optimized // for storage. Analysis passes may expand it into a multi-dimensional map // if desired. AffineMap map = builder.getSymbolIdentityMap(); result.addAttribute(boundAttrName, AffineMapAttr::get(map)); return success(); } // Get the attribute location. 
llvm::SMLoc attrLoc = p.getCurrentLocation(); Attribute boundAttr; if (p.parseAttribute(boundAttr, builder.getIndexType(), boundAttrName, result.attributes)) return failure(); // Parse full form - affine map followed by dim and symbol list. if (auto affineMapAttr = boundAttr.dyn_cast()) { unsigned currentNumOperands = result.operands.size(); unsigned numDims; if (parseDimAndSymbolList(p, result.operands, numDims)) return failure(); auto map = affineMapAttr.getValue(); if (map.getNumDims() != numDims) return p.emitError( p.getNameLoc(), "dim operand count and affine map dim count must match"); unsigned numDimAndSymbolOperands = result.operands.size() - currentNumOperands; if (numDims + map.getNumSymbols() != numDimAndSymbolOperands) return p.emitError( p.getNameLoc(), "symbol operand count and affine map symbol count must match"); // If the map has multiple results, make sure that we parsed the min/max // prefix. if (map.getNumResults() > 1 && failedToParsedMinMax) { if (isLower) { return p.emitError(attrLoc, "lower loop bound affine map with " "multiple results requires 'max' prefix"); } return p.emitError(attrLoc, "upper loop bound affine map with multiple " "results requires 'min' prefix"); } return success(); } // Parse custom assembly form. if (auto integerAttr = boundAttr.dyn_cast()) { result.attributes.pop_back(); result.addAttribute( boundAttrName, AffineMapAttr::get(builder.getConstantAffineMap(integerAttr.getInt()))); return success(); } return p.emitError( p.getNameLoc(), "expected valid affine map representation for loop bounds"); } static ParseResult parseAffineForOp(OpAsmParser &parser, OperationState &result) { auto &builder = parser.getBuilder(); OpAsmParser::OperandType inductionVariable; // Parse the induction variable followed by '='. if (parser.parseRegionArgument(inductionVariable) || parser.parseEqual()) return failure(); // Parse loop bounds. if (parseBound(/*isLower=*/true, result, parser) || parser.parseKeyword("to", " between bounds") || parseBound(/*isLower=*/false, result, parser)) return failure(); // Parse the optional loop step, we default to 1 if one is not present. if (parser.parseOptionalKeyword("step")) { result.addAttribute( AffineForOp::getStepAttrName(), builder.getIntegerAttr(builder.getIndexType(), /*value=*/1)); } else { llvm::SMLoc stepLoc = parser.getCurrentLocation(); IntegerAttr stepAttr; if (parser.parseAttribute(stepAttr, builder.getIndexType(), AffineForOp::getStepAttrName().data(), result.attributes)) return failure(); if (stepAttr.getValue().getSExtValue() < 0) return parser.emitError( stepLoc, "expected step to be representable as a positive signed integer"); } // Parse the optional initial iteration arguments. SmallVector regionArgs, operands; SmallVector argTypes; regionArgs.push_back(inductionVariable); if (succeeded(parser.parseOptionalKeyword("iter_args"))) { // Parse assignment list and results type list. if (parser.parseAssignmentList(regionArgs, operands) || parser.parseArrowTypeList(result.types)) return failure(); // Resolve input operands. for (auto operandType : llvm::zip(operands, result.types)) if (parser.resolveOperand(std::get<0>(operandType), std::get<1>(operandType), result.operands)) return failure(); } // Induction variable. Type indexType = builder.getIndexType(); argTypes.push_back(indexType); // Loop carried variables. argTypes.append(result.types.begin(), result.types.end()); // Parse the body region. 
Region *body = result.addRegion(); if (regionArgs.size() != argTypes.size()) return parser.emitError( parser.getNameLoc(), "mismatch between the number of loop-carried values and results"); if (parser.parseRegion(*body, regionArgs, argTypes)) return failure(); AffineForOp::ensureTerminator(*body, builder, result.location); // Parse the optional attribute list. return parser.parseOptionalAttrDict(result.attributes); } static void printBound(AffineMapAttr boundMap, Operation::operand_range boundOperands, const char *prefix, OpAsmPrinter &p) { AffineMap map = boundMap.getValue(); // Check if this bound should be printed using custom assembly form. // The decision to restrict printing custom assembly form to trivial cases // comes from the will to roundtrip MLIR binary -> text -> binary in a // lossless way. // Therefore, custom assembly form parsing and printing is only supported for // zero-operand constant maps and single symbol operand identity maps. if (map.getNumResults() == 1) { AffineExpr expr = map.getResult(0); // Print constant bound. if (map.getNumDims() == 0 && map.getNumSymbols() == 0) { if (auto constExpr = expr.dyn_cast()) { p << constExpr.getValue(); return; } } // Print bound that consists of a single SSA symbol if the map is over a // single symbol. if (map.getNumDims() == 0 && map.getNumSymbols() == 1) { if (auto symExpr = expr.dyn_cast()) { p.printOperand(*boundOperands.begin()); return; } } } else { // Map has multiple results. Print 'min' or 'max' prefix. p << prefix << ' '; } // Print the map and its operands. p << boundMap; printDimAndSymbolList(boundOperands.begin(), boundOperands.end(), map.getNumDims(), p); } unsigned AffineForOp::getNumIterOperands() { AffineMap lbMap = getLowerBoundMapAttr().getValue(); AffineMap ubMap = getUpperBoundMapAttr().getValue(); return getNumOperands() - lbMap.getNumInputs() - ubMap.getNumInputs(); } static void print(OpAsmPrinter &p, AffineForOp op) { p << op.getOperationName() << ' '; p.printOperand(op.getBody()->getArgument(0)); p << " = "; printBound(op.getLowerBoundMapAttr(), op.getLowerBoundOperands(), "max", p); p << " to "; printBound(op.getUpperBoundMapAttr(), op.getUpperBoundOperands(), "min", p); if (op.getStep() != 1) p << " step " << op.getStep(); bool printBlockTerminators = false; if (op.getNumIterOperands() > 0) { p << " iter_args("; auto regionArgs = op.getRegionIterArgs(); auto operands = op.getIterOperands(); llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) { p << std::get<0>(it) << " = " << std::get<1>(it); }); p << ") -> (" << op.getResultTypes() << ")"; printBlockTerminators = true; } p.printRegion(op.region(), /*printEntryBlockArgs=*/false, printBlockTerminators); p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{op.getLowerBoundAttrName(), op.getUpperBoundAttrName(), op.getStepAttrName()}); } /// Fold the constant bounds of a loop. static LogicalResult foldLoopBounds(AffineForOp forOp) { auto foldLowerOrUpperBound = [&forOp](bool lower) { // Check to see if each of the operands is the result of a constant. If // so, get the value. If not, ignore it. SmallVector operandConstants; auto boundOperands = lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands(); for (auto operand : boundOperands) { Attribute operandCst; matchPattern(operand, m_Constant(&operandCst)); operandConstants.push_back(operandCst); } AffineMap boundMap = lower ? 
        forOp.getLowerBoundMap() : forOp.getUpperBoundMap();
    assert(boundMap.getNumResults() >= 1 &&
           "bound maps should have at least one result");
    SmallVector<Attribute, 4> foldedResults;
    if (failed(boundMap.constantFold(operandConstants, foldedResults)))
      return failure();

    // Compute the max or min as applicable over the results.
    assert(!foldedResults.empty() && "bounds should have at least one result");
    auto maxOrMin = foldedResults[0].cast<IntegerAttr>().getValue();
    for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {
      auto foldedResult = foldedResults[i].cast<IntegerAttr>().getValue();
      maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
                       : llvm::APIntOps::smin(maxOrMin, foldedResult);
    }
    lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue())
          : forOp.setConstantUpperBound(maxOrMin.getSExtValue());
    return success();
  };

  // Try to fold the lower bound.
  bool folded = false;
  if (!forOp.hasConstantLowerBound())
    folded |= succeeded(foldLowerOrUpperBound(/*lower=*/true));

  // Try to fold the upper bound.
  if (!forOp.hasConstantUpperBound())
    folded |= succeeded(foldLowerOrUpperBound(/*lower=*/false));
  return success(folded);
}

/// Canonicalize the bounds of the given loop.
static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) {
  SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands());
  SmallVector<Value, 4> ubOperands(forOp.getUpperBoundOperands());

  auto lbMap = forOp.getLowerBoundMap();
  auto ubMap = forOp.getUpperBoundMap();
  auto prevLbMap = lbMap;
  auto prevUbMap = ubMap;

  canonicalizeMapAndOperands(&lbMap, &lbOperands);
  lbMap = removeDuplicateExprs(lbMap);

  canonicalizeMapAndOperands(&ubMap, &ubOperands);
  ubMap = removeDuplicateExprs(ubMap);

  // Any canonicalization change always leads to updated map(s).
  if (lbMap == prevLbMap && ubMap == prevUbMap)
    return failure();

  if (lbMap != prevLbMap)
    forOp.setLowerBound(lbOperands, lbMap);
  if (ubMap != prevUbMap)
    forOp.setUpperBound(ubOperands, ubMap);
  return success();
}

namespace {
/// This is a pattern to fold trivially empty loops.
struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> {
  using OpRewritePattern<AffineForOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AffineForOp forOp,
                                PatternRewriter &rewriter) const override {
    // Check that the body only contains a yield.
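    // (Example for illustration, not from the original source: a loop such as
    //   affine.for %i = 0 to 10 {
    //   }
    // contains only the implicitly created affine.yield and is erased.)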
    if (!llvm::hasSingleElement(*forOp.getBody()))
      return failure();
    rewriter.eraseOp(forOp);
    return success();
  }
};
} // end anonymous namespace

void AffineForOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                              MLIRContext *context) {
  results.insert<AffineForEmptyLoopFolder>(context);
}

LogicalResult AffineForOp::fold(ArrayRef<Attribute> operands,
                                SmallVectorImpl<OpFoldResult> &results) {
  bool folded = succeeded(foldLoopBounds(*this));
  folded |= succeeded(canonicalizeLoopBounds(*this));
  return success(folded);
}

AffineBound AffineForOp::getLowerBound() {
  auto lbMap = getLowerBoundMap();
  return AffineBound(AffineForOp(*this), 0, lbMap.getNumInputs(), lbMap);
}

AffineBound AffineForOp::getUpperBound() {
  auto lbMap = getLowerBoundMap();
  auto ubMap = getUpperBoundMap();
  return AffineBound(AffineForOp(*this), lbMap.getNumInputs(),
                     lbMap.getNumInputs() + ubMap.getNumInputs(), ubMap);
}

void AffineForOp::setLowerBound(ValueRange lbOperands, AffineMap map) {
  assert(lbOperands.size() == map.getNumInputs());
  assert(map.getNumResults() >= 1 && "bound map has at least one result");

  SmallVector<Value, 4> newOperands(lbOperands.begin(), lbOperands.end());

  auto ubOperands = getUpperBoundOperands();
  newOperands.append(ubOperands.begin(), ubOperands.end());
  auto iterOperands = getIterOperands();
  newOperands.append(iterOperands.begin(), iterOperands.end());
  getOperation()->setOperands(newOperands);

  setAttr(getLowerBoundAttrName(), AffineMapAttr::get(map));
}

void AffineForOp::setUpperBound(ValueRange ubOperands, AffineMap map) {
  assert(ubOperands.size() == map.getNumInputs());
  assert(map.getNumResults() >= 1 && "bound map has at least one result");

  SmallVector<Value, 4> newOperands(getLowerBoundOperands());
  newOperands.append(ubOperands.begin(), ubOperands.end());
  auto iterOperands = getIterOperands();
  newOperands.append(iterOperands.begin(), iterOperands.end());
  getOperation()->setOperands(newOperands);

  setAttr(getUpperBoundAttrName(), AffineMapAttr::get(map));
}

void AffineForOp::setLowerBoundMap(AffineMap map) {
  auto lbMap = getLowerBoundMap();
  assert(lbMap.getNumDims() == map.getNumDims() &&
         lbMap.getNumSymbols() == map.getNumSymbols());
  assert(map.getNumResults() >= 1 && "bound map has at least one result");
  (void)lbMap;
  setAttr(getLowerBoundAttrName(), AffineMapAttr::get(map));
}

void AffineForOp::setUpperBoundMap(AffineMap map) {
  auto ubMap = getUpperBoundMap();
  assert(ubMap.getNumDims() == map.getNumDims() &&
         ubMap.getNumSymbols() == map.getNumSymbols());
  assert(map.getNumResults() >= 1 && "bound map has at least one result");
  (void)ubMap;
  setAttr(getUpperBoundAttrName(), AffineMapAttr::get(map));
}

bool AffineForOp::hasConstantLowerBound() {
  return getLowerBoundMap().isSingleConstant();
}

bool AffineForOp::hasConstantUpperBound() {
  return getUpperBoundMap().isSingleConstant();
}

int64_t AffineForOp::getConstantLowerBound() {
  return getLowerBoundMap().getSingleConstantResult();
}

int64_t AffineForOp::getConstantUpperBound() {
  return getUpperBoundMap().getSingleConstantResult();
}

void AffineForOp::setConstantLowerBound(int64_t value) {
  setLowerBound({}, AffineMap::getConstantMap(value, getContext()));
}

void AffineForOp::setConstantUpperBound(int64_t value) {
  setUpperBound({}, AffineMap::getConstantMap(value, getContext()));
}

AffineForOp::operand_range AffineForOp::getLowerBoundOperands() {
  return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()};
}

AffineForOp::operand_range AffineForOp::getUpperBoundOperands() {
  return {operand_begin() + getLowerBoundMap().getNumInputs(),
          operand_begin() + getLowerBoundMap().getNumInputs() +
getUpperBoundMap().getNumInputs()}; } bool AffineForOp::matchingBoundOperandList() { auto lbMap = getLowerBoundMap(); auto ubMap = getUpperBoundMap(); if (lbMap.getNumDims() != ubMap.getNumDims() || lbMap.getNumSymbols() != ubMap.getNumSymbols()) return false; unsigned numOperands = lbMap.getNumInputs(); for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) { // Compare Value 's. if (getOperand(i) != getOperand(numOperands + i)) return false; } return true; } Region &AffineForOp::getLoopBody() { return region(); } bool AffineForOp::isDefinedOutsideOfLoop(Value value) { return !region().isAncestor(value.getParentRegion()); } LogicalResult AffineForOp::moveOutOfLoop(ArrayRef ops) { for (auto *op : ops) op->moveBefore(*this); return success(); } /// Returns true if the provided value is the induction variable of a /// AffineForOp. bool mlir::isForInductionVar(Value val) { return getForInductionVarOwner(val) != AffineForOp(); } /// Returns the loop parent of an induction variable. If the provided value is /// not an induction variable, then return nullptr. AffineForOp mlir::getForInductionVarOwner(Value val) { auto ivArg = val.dyn_cast(); if (!ivArg || !ivArg.getOwner()) return AffineForOp(); auto *containingInst = ivArg.getOwner()->getParent()->getParentOp(); return dyn_cast(containingInst); } /// Extracts the induction variables from a list of AffineForOps and returns /// them. void mlir::extractForInductionVars(ArrayRef forInsts, SmallVectorImpl *ivs) { ivs->reserve(forInsts.size()); for (auto forInst : forInsts) ivs->push_back(forInst.getInductionVar()); } /// Builds an affine loop nest, using "loopCreatorFn" to create individual loop /// operations. template static void buildAffineLoopNestImpl( OpBuilder &builder, Location loc, BoundListTy lbs, BoundListTy ubs, ArrayRef steps, function_ref bodyBuilderFn, LoopCreatorTy &&loopCreatorFn) { assert(lbs.size() == ubs.size() && "Mismatch in number of arguments"); assert(lbs.size() == steps.size() && "Mismatch in number of arguments"); // If there are no loops to be constructed, construct the body anyway. OpBuilder::InsertionGuard guard(builder); if (lbs.empty()) { if (bodyBuilderFn) bodyBuilderFn(builder, loc, ValueRange()); return; } // Create the loops iteratively and store the induction variables. SmallVector ivs; ivs.reserve(lbs.size()); for (unsigned i = 0, e = lbs.size(); i < e; ++i) { // Callback for creating the loop body, always creates the terminator. auto loopBody = [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv, ValueRange iterArgs) { ivs.push_back(iv); // In the innermost loop, call the body builder. if (i == e - 1 && bodyBuilderFn) { OpBuilder::InsertionGuard nestedGuard(nestedBuilder); bodyBuilderFn(nestedBuilder, nestedLoc, ivs); } nestedBuilder.create(nestedLoc); }; // Delegate actual loop creation to the callback in order to dispatch // between constant- and variable-bound loops. auto loop = loopCreatorFn(builder, loc, lbs[i], ubs[i], steps[i], loopBody); builder.setInsertionPointToStart(loop.getBody()); } } /// Creates an affine loop from the bounds known to be constants. static AffineForOp buildAffineLoopFromConstants(OpBuilder &builder, Location loc, int64_t lb, int64_t ub, int64_t step, AffineForOp::BodyBuilderFn bodyBuilderFn) { return builder.create(loc, lb, ub, step, /*iterArgs=*/llvm::None, bodyBuilderFn); } /// Creates an affine loop from the bounds that may or may not be constants. 
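/// (Sketch of the intended behavior, for illustration: when both `lb` and `ub`
/// are defined by constant index ops the loop is built with constant bounds,
/// e.g. `affine.for %i = 0 to 128`; otherwise the values are passed through
/// identity maps, yielding e.g. `affine.for %i = %lb to %ub`.)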
static AffineForOp buildAffineLoopFromValues(OpBuilder &builder, Location loc, Value lb, Value ub, int64_t step, AffineForOp::BodyBuilderFn bodyBuilderFn) { auto lbConst = lb.getDefiningOp(); auto ubConst = ub.getDefiningOp(); if (lbConst && ubConst) return buildAffineLoopFromConstants(builder, loc, lbConst.getValue(), ubConst.getValue(), step, bodyBuilderFn); return builder.create(loc, lb, builder.getDimIdentityMap(), ub, builder.getDimIdentityMap(), step, /*iterArgs=*/llvm::None, bodyBuilderFn); } void mlir::buildAffineLoopNest( OpBuilder &builder, Location loc, ArrayRef lbs, ArrayRef ubs, ArrayRef steps, function_ref bodyBuilderFn) { buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn, buildAffineLoopFromConstants); } void mlir::buildAffineLoopNest( OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs, ArrayRef steps, function_ref bodyBuilderFn) { buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn, buildAffineLoopFromValues); } //===----------------------------------------------------------------------===// // AffineIfOp //===----------------------------------------------------------------------===// namespace { /// Remove else blocks that have nothing other than a zero value yield. struct SimplifyDeadElse : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AffineIfOp ifOp, PatternRewriter &rewriter) const override { if (ifOp.elseRegion().empty() || !llvm::hasSingleElement(*ifOp.getElseBlock()) || ifOp.getNumResults()) return failure(); rewriter.startRootUpdate(ifOp); rewriter.eraseBlock(ifOp.getElseBlock()); rewriter.finalizeRootUpdate(ifOp); return success(); } }; } // end anonymous namespace. static LogicalResult verify(AffineIfOp op) { // Verify that we have a condition attribute. auto conditionAttr = op.getAttrOfType(op.getConditionAttrName()); if (!conditionAttr) return op.emitOpError( "requires an integer set attribute named 'condition'"); // Verify that there are enough operands for the condition. IntegerSet condition = conditionAttr.getValue(); if (op.getNumOperands() != condition.getNumInputs()) return op.emitOpError( "operand count and condition integer set dimension and " "symbol count must match"); // Verify that the operands are valid dimension/symbols. if (failed(verifyDimAndSymbolIdentifiers(op, op.getOperands(), condition.getNumDims()))) return failure(); return success(); } static ParseResult parseAffineIfOp(OpAsmParser &parser, OperationState &result) { // Parse the condition attribute set. IntegerSetAttr conditionAttr; unsigned numDims; if (parser.parseAttribute(conditionAttr, AffineIfOp::getConditionAttrName(), result.attributes) || parseDimAndSymbolList(parser, result.operands, numDims)) return failure(); // Verify the condition operands. auto set = conditionAttr.getValue(); if (set.getNumDims() != numDims) return parser.emitError( parser.getNameLoc(), "dim operand count and integer set dim count must match"); if (numDims + set.getNumSymbols() != result.operands.size()) return parser.emitError( parser.getNameLoc(), "symbol operand count and integer set symbol count must match"); if (parser.parseOptionalArrowTypeList(result.types)) return failure(); // Create the regions for 'then' and 'else'. The latter must be created even // if it remains empty for the validity of the operation. result.regions.reserve(2); Region *thenRegion = result.addRegion(); Region *elseRegion = result.addRegion(); // Parse the 'then' region. 
if (parser.parseRegion(*thenRegion, {}, {})) return failure(); AffineIfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location); // If we find an 'else' keyword then parse the 'else' region. if (!parser.parseOptionalKeyword("else")) { if (parser.parseRegion(*elseRegion, {}, {})) return failure(); AffineIfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location); } // Parse the optional attribute list. if (parser.parseOptionalAttrDict(result.attributes)) return failure(); return success(); } static void print(OpAsmPrinter &p, AffineIfOp op) { auto conditionAttr = op.getAttrOfType(op.getConditionAttrName()); p << "affine.if " << conditionAttr; printDimAndSymbolList(op.operand_begin(), op.operand_end(), conditionAttr.getValue().getNumDims(), p); p.printOptionalArrowTypeList(op.getResultTypes()); p.printRegion(op.thenRegion(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/op.getNumResults()); // Print the 'else' regions if it has any blocks. auto &elseRegion = op.elseRegion(); if (!elseRegion.empty()) { p << " else"; p.printRegion(elseRegion, /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/op.getNumResults()); } // Print the attribute list. p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/op.getConditionAttrName()); } IntegerSet AffineIfOp::getIntegerSet() { return getAttrOfType(getConditionAttrName()).getValue(); } void AffineIfOp::setIntegerSet(IntegerSet newSet) { setAttr(getConditionAttrName(), IntegerSetAttr::get(newSet)); } void AffineIfOp::setConditional(IntegerSet set, ValueRange operands) { setIntegerSet(set); getOperation()->setOperands(operands); } void AffineIfOp::build(OpBuilder &builder, OperationState &result, TypeRange resultTypes, IntegerSet set, ValueRange args, bool withElseRegion) { assert(resultTypes.empty() || withElseRegion); result.addTypes(resultTypes); result.addOperands(args); result.addAttribute(getConditionAttrName(), IntegerSetAttr::get(set)); Region *thenRegion = result.addRegion(); thenRegion->push_back(new Block()); if (resultTypes.empty()) AffineIfOp::ensureTerminator(*thenRegion, builder, result.location); Region *elseRegion = result.addRegion(); if (withElseRegion) { elseRegion->push_back(new Block()); if (resultTypes.empty()) AffineIfOp::ensureTerminator(*elseRegion, builder, result.location); } } void AffineIfOp::build(OpBuilder &builder, OperationState &result, IntegerSet set, ValueRange args, bool withElseRegion) { AffineIfOp::build(builder, result, /*resultTypes=*/{}, set, args, withElseRegion); } /// Canonicalize an affine if op's conditional (integer set + operands). LogicalResult AffineIfOp::fold(ArrayRef, SmallVectorImpl &) { auto set = getIntegerSet(); SmallVector operands(getOperands()); canonicalizeSetAndOperands(&set, &operands); // Any canonicalization change always leads to either a reduction in the // number of operands or a change in the number of symbolic operands // (promotion of dims to symbols). 
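  // (Illustration, not from the original source: if the condition were
  //   affine_set<(d0, d1) : (d0 - d1 >= 0)>
  // applied to the duplicate operands (%i, %i), canonicalization merges the
  // duplicates, the operand count shrinks, and the check below updates the op.)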
if (operands.size() < getIntegerSet().getNumInputs() || set.getNumSymbols() > getIntegerSet().getNumSymbols()) { setConditional(set, operands); return success(); } return failure(); } void AffineIfOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { results.insert(context); } //===----------------------------------------------------------------------===// // AffineLoadOp //===----------------------------------------------------------------------===// void AffineLoadOp::build(OpBuilder &builder, OperationState &result, AffineMap map, ValueRange operands) { assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands"); result.addOperands(operands); if (map) result.addAttribute(getMapAttrName(), AffineMapAttr::get(map)); auto memrefType = operands[0].getType().cast(); result.types.push_back(memrefType.getElementType()); } void AffineLoadOp::build(OpBuilder &builder, OperationState &result, Value memref, AffineMap map, ValueRange mapOperands) { assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info"); result.addOperands(memref); result.addOperands(mapOperands); auto memrefType = memref.getType().cast(); result.addAttribute(getMapAttrName(), AffineMapAttr::get(map)); result.types.push_back(memrefType.getElementType()); } void AffineLoadOp::build(OpBuilder &builder, OperationState &result, Value memref, ValueRange indices) { auto memrefType = memref.getType().cast(); auto rank = memrefType.getRank(); // Create identity map for memrefs with at least one dimension or () -> () // for zero-dimensional memrefs. auto map = rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap(); build(builder, result, memref, map, indices); } static ParseResult parseAffineLoadOp(OpAsmParser &parser, OperationState &result) { auto &builder = parser.getBuilder(); auto indexTy = builder.getIndexType(); MemRefType type; OpAsmParser::OperandType memrefInfo; AffineMapAttr mapAttr; SmallVector mapOperands; return failure( parser.parseOperand(memrefInfo) || parser.parseAffineMapOfSSAIds(mapOperands, mapAttr, AffineLoadOp::getMapAttrName(), result.attributes) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(type) || parser.resolveOperand(memrefInfo, type, result.operands) || parser.resolveOperands(mapOperands, indexTy, result.operands) || parser.addTypeToList(type.getElementType(), result.types)); } static void print(OpAsmPrinter &p, AffineLoadOp op) { p << "affine.load " << op.getMemRef() << '['; if (AffineMapAttr mapAttr = op.getAttrOfType(op.getMapAttrName())) p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands()); p << ']'; p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{op.getMapAttrName()}); p << " : " << op.getMemRefType(); } /// Verify common indexing invariants of affine.load, affine.store, /// affine.vector_load and affine.vector_store. 
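/// (Hypothetical example for illustration: for
///   affine.load %A[%i + 1, %j] : memref<100x100xf32>
/// the map must have two results to match the memref rank, two inputs to match
/// the subscript count, and %i, %j must be valid affine dims or symbols.)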
static LogicalResult
verifyMemoryOpIndexing(Operation *op, AffineMapAttr mapAttr,
                       Operation::operand_range mapOperands,
                       MemRefType memrefType, unsigned numIndexOperands) {
  if (mapAttr) {
    AffineMap map = mapAttr.getValue();
    if (map.getNumResults() != memrefType.getRank())
      return op->emitOpError("affine map num results must equal memref rank");
    if (map.getNumInputs() != numIndexOperands)
      return op->emitOpError("expects as many subscripts as affine map inputs");
  } else {
    if (memrefType.getRank() != numIndexOperands)
      return op->emitOpError(
          "expects the number of subscripts to be equal to memref rank");
  }

  Region *scope = getAffineScope(op);
  for (auto idx : mapOperands) {
    if (!idx.getType().isIndex())
      return op->emitOpError("index to load must have 'index' type");
    if (!isValidAffineIndexOperand(idx, scope))
      return op->emitOpError("index must be a dimension or symbol identifier");
  }

  return success();
}

LogicalResult verify(AffineLoadOp op) {
  auto memrefType = op.getMemRefType();
  if (op.getType() != memrefType.getElementType())
    return op.emitOpError("result type must match element type of memref");

  if (failed(verifyMemoryOpIndexing(
          op.getOperation(),
          op.getAttrOfType<AffineMapAttr>(op.getMapAttrName()),
          op.getMapOperands(), memrefType,
          /*numIndexOperands=*/op.getNumOperands() - 1)))
    return failure();

  return success();
}

void AffineLoadOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<SimplifyAffineOp<AffineLoadOp>>(context);
}

OpFoldResult AffineLoadOp::fold(ArrayRef<Attribute> cstOperands) {
  /// load(memrefcast) -> load
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return OpFoldResult();
}

//===----------------------------------------------------------------------===//
// AffineStoreOp
//===----------------------------------------------------------------------===//

void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
                          Value valueToStore, Value memref, AffineMap map,
                          ValueRange mapOperands) {
  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
  result.addOperands(valueToStore);
  result.addOperands(memref);
  result.addOperands(mapOperands);
  result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
}

// Use identity map.
void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
                          Value valueToStore, Value memref,
                          ValueRange indices) {
  auto memrefType = memref.getType().cast<MemRefType>();
  auto rank = memrefType.getRank();
  // Create identity map for memrefs with at least one dimension or () -> ()
  // for zero-dimensional memrefs.
  auto map = rank ?
builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap(); build(builder, result, valueToStore, memref, map, indices); } static ParseResult parseAffineStoreOp(OpAsmParser &parser, OperationState &result) { auto indexTy = parser.getBuilder().getIndexType(); MemRefType type; OpAsmParser::OperandType storeValueInfo; OpAsmParser::OperandType memrefInfo; AffineMapAttr mapAttr; SmallVector mapOperands; return failure(parser.parseOperand(storeValueInfo) || parser.parseComma() || parser.parseOperand(memrefInfo) || parser.parseAffineMapOfSSAIds(mapOperands, mapAttr, AffineStoreOp::getMapAttrName(), result.attributes) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(type) || parser.resolveOperand(storeValueInfo, type.getElementType(), result.operands) || parser.resolveOperand(memrefInfo, type, result.operands) || parser.resolveOperands(mapOperands, indexTy, result.operands)); } static void print(OpAsmPrinter &p, AffineStoreOp op) { p << "affine.store " << op.getValueToStore(); p << ", " << op.getMemRef() << '['; if (AffineMapAttr mapAttr = op.getAttrOfType(op.getMapAttrName())) p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands()); p << ']'; p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{op.getMapAttrName()}); p << " : " << op.getMemRefType(); } LogicalResult verify(AffineStoreOp op) { // First operand must have same type as memref element type. auto memrefType = op.getMemRefType(); if (op.getValueToStore().getType() != memrefType.getElementType()) return op.emitOpError( "first operand must have same type memref element type"); if (failed(verifyMemoryOpIndexing( op.getOperation(), op.getAttrOfType(op.getMapAttrName()), op.getMapOperands(), memrefType, /*numIndexOperands=*/op.getNumOperands() - 2))) return failure(); return success(); } void AffineStoreOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { results.insert>(context); } LogicalResult AffineStoreOp::fold(ArrayRef cstOperands, SmallVectorImpl &results) { /// store(memrefcast) -> store return foldMemRefCast(*this); } //===----------------------------------------------------------------------===// // AffineMinMaxOpBase //===----------------------------------------------------------------------===// template static LogicalResult verifyAffineMinMaxOp(T op) { // Verify that operand count matches affine map dimension and symbol count. 
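// (For illustration, not from the original source:
//   affine.min affine_map<(d0)[s0] -> (d0, s0 + 10)>(%i)[%n]
// has one dim and one symbol, so exactly two operands are expected.)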
if (op.getNumOperands() != op.map().getNumDims() + op.map().getNumSymbols()) return op.emitOpError( "operand count and affine map dimension and symbol count must match"); return success(); } template static void printAffineMinMaxOp(OpAsmPrinter &p, T op) { p << op.getOperationName() << ' ' << op.getAttr(T::getMapAttrName()); auto operands = op.getOperands(); unsigned numDims = op.map().getNumDims(); p << '(' << operands.take_front(numDims) << ')'; if (operands.size() != numDims) p << '[' << operands.drop_front(numDims) << ']'; p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{T::getMapAttrName()}); } template static ParseResult parseAffineMinMaxOp(OpAsmParser &parser, OperationState &result) { auto &builder = parser.getBuilder(); auto indexType = builder.getIndexType(); SmallVector dim_infos; SmallVector sym_infos; AffineMapAttr mapAttr; return failure( parser.parseAttribute(mapAttr, T::getMapAttrName(), result.attributes) || parser.parseOperandList(dim_infos, OpAsmParser::Delimiter::Paren) || parser.parseOperandList(sym_infos, OpAsmParser::Delimiter::OptionalSquare) || parser.parseOptionalAttrDict(result.attributes) || parser.resolveOperands(dim_infos, indexType, result.operands) || parser.resolveOperands(sym_infos, indexType, result.operands) || parser.addTypeToList(indexType, result.types)); } /// Fold an affine min or max operation with the given operands. The operand /// list may contain nulls, which are interpreted as the operand not being a /// constant. template static OpFoldResult foldMinMaxOp(T op, ArrayRef operands) { static_assert(llvm::is_one_of::value, "expected affine min or max op"); // Fold the affine map. // TODO: Fold more cases: // min(some_affine, some_affine + constant, ...), etc. SmallVector results; auto foldedMap = op.map().partialConstantFold(operands, &results); // If some of the map results are not constant, try changing the map in-place. if (results.empty()) { // If the map is the same, report that folding did not happen. if (foldedMap == op.map()) return {}; op.setAttr("map", AffineMapAttr::get(foldedMap)); return op.getResult(); } // Otherwise, completely fold the op into a constant. auto resultIt = std::is_same::value ? 
std::min_element(results.begin(), results.end()) : std::max_element(results.begin(), results.end()); if (resultIt == results.end()) return {}; return IntegerAttr::get(IndexType::get(op.getContext()), *resultIt); } //===----------------------------------------------------------------------===// // AffineMinOp //===----------------------------------------------------------------------===// // // %0 = affine.min (d0) -> (1000, d0 + 512) (%i0) // OpFoldResult AffineMinOp::fold(ArrayRef operands) { return foldMinMaxOp(*this, operands); } void AffineMinOp::getCanonicalizationPatterns( OwningRewritePatternList &patterns, MLIRContext *context) { patterns.insert>(context); } //===----------------------------------------------------------------------===// // AffineMaxOp //===----------------------------------------------------------------------===// // // %0 = affine.max (d0) -> (1000, d0 + 512) (%i0) // OpFoldResult AffineMaxOp::fold(ArrayRef operands) { return foldMinMaxOp(*this, operands); } void AffineMaxOp::getCanonicalizationPatterns( OwningRewritePatternList &patterns, MLIRContext *context) { patterns.insert>(context); } //===----------------------------------------------------------------------===// // AffinePrefetchOp //===----------------------------------------------------------------------===// // // affine.prefetch %0[%i, %j + 5], read, locality<3>, data : memref<400x400xi32> // static ParseResult parseAffinePrefetchOp(OpAsmParser &parser, OperationState &result) { auto &builder = parser.getBuilder(); auto indexTy = builder.getIndexType(); MemRefType type; OpAsmParser::OperandType memrefInfo; IntegerAttr hintInfo; auto i32Type = parser.getBuilder().getIntegerType(32); StringRef readOrWrite, cacheType; AffineMapAttr mapAttr; SmallVector mapOperands; if (parser.parseOperand(memrefInfo) || parser.parseAffineMapOfSSAIds(mapOperands, mapAttr, AffinePrefetchOp::getMapAttrName(), result.attributes) || parser.parseComma() || parser.parseKeyword(&readOrWrite) || parser.parseComma() || parser.parseKeyword("locality") || parser.parseLess() || parser.parseAttribute(hintInfo, i32Type, AffinePrefetchOp::getLocalityHintAttrName(), result.attributes) || parser.parseGreater() || parser.parseComma() || parser.parseKeyword(&cacheType) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(type) || parser.resolveOperand(memrefInfo, type, result.operands) || parser.resolveOperands(mapOperands, indexTy, result.operands)) return failure(); if (!readOrWrite.equals("read") && !readOrWrite.equals("write")) return parser.emitError(parser.getNameLoc(), "rw specifier has to be 'read' or 'write'"); result.addAttribute( AffinePrefetchOp::getIsWriteAttrName(), parser.getBuilder().getBoolAttr(readOrWrite.equals("write"))); if (!cacheType.equals("data") && !cacheType.equals("instr")) return parser.emitError(parser.getNameLoc(), "cache type has to be 'data' or 'instr'"); result.addAttribute( AffinePrefetchOp::getIsDataCacheAttrName(), parser.getBuilder().getBoolAttr(cacheType.equals("data"))); return success(); } static void print(OpAsmPrinter &p, AffinePrefetchOp op) { p << AffinePrefetchOp::getOperationName() << " " << op.memref() << '['; AffineMapAttr mapAttr = op.getAttrOfType(op.getMapAttrName()); if (mapAttr) { SmallVector operands(op.getMapOperands()); p.printAffineMapOfSSAIds(mapAttr, operands); } p << ']' << ", " << (op.isWrite() ? "write" : "read") << ", " << "locality<" << op.localityHint() << ">, " << (op.isDataCache() ? 
"data" : "instr"); p.printOptionalAttrDict( op.getAttrs(), /*elidedAttrs=*/{op.getMapAttrName(), op.getLocalityHintAttrName(), op.getIsDataCacheAttrName(), op.getIsWriteAttrName()}); p << " : " << op.getMemRefType(); } static LogicalResult verify(AffinePrefetchOp op) { auto mapAttr = op.getAttrOfType(op.getMapAttrName()); if (mapAttr) { AffineMap map = mapAttr.getValue(); if (map.getNumResults() != op.getMemRefType().getRank()) return op.emitOpError("affine.prefetch affine map num results must equal" " memref rank"); if (map.getNumInputs() + 1 != op.getNumOperands()) return op.emitOpError("too few operands"); } else { if (op.getNumOperands() != 1) return op.emitOpError("too few operands"); } Region *scope = getAffineScope(op); for (auto idx : op.getMapOperands()) { if (!isValidAffineIndexOperand(idx, scope)) return op.emitOpError("index must be a dimension or symbol identifier"); } return success(); } void AffinePrefetchOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { // prefetch(memrefcast) -> prefetch results.insert>(context); } LogicalResult AffinePrefetchOp::fold(ArrayRef cstOperands, SmallVectorImpl &results) { /// prefetch(memrefcast) -> prefetch return foldMemRefCast(*this); } //===----------------------------------------------------------------------===// // AffineParallelOp //===----------------------------------------------------------------------===// void AffineParallelOp::build(OpBuilder &builder, OperationState &result, TypeRange resultTypes, ArrayRef reductions, ArrayRef ranges) { SmallVector lbExprs(ranges.size(), builder.getAffineConstantExpr(0)); auto lbMap = AffineMap::get(0, 0, lbExprs, builder.getContext()); SmallVector ubExprs; for (int64_t range : ranges) ubExprs.push_back(builder.getAffineConstantExpr(range)); auto ubMap = AffineMap::get(0, 0, ubExprs, builder.getContext()); build(builder, result, resultTypes, reductions, lbMap, /*lbArgs=*/{}, ubMap, /*ubArgs=*/{}); } void AffineParallelOp::build(OpBuilder &builder, OperationState &result, TypeRange resultTypes, ArrayRef reductions, AffineMap lbMap, ValueRange lbArgs, AffineMap ubMap, ValueRange ubArgs) { auto numDims = lbMap.getNumResults(); // Verify that the dimensionality of both maps are the same. assert(numDims == ubMap.getNumResults() && "num dims and num results mismatch"); // Make default step sizes of 1. SmallVector steps(numDims, 1); build(builder, result, resultTypes, reductions, lbMap, lbArgs, ubMap, ubArgs, steps); } void AffineParallelOp::build(OpBuilder &builder, OperationState &result, TypeRange resultTypes, ArrayRef reductions, AffineMap lbMap, ValueRange lbArgs, AffineMap ubMap, ValueRange ubArgs, ArrayRef steps) { auto numDims = lbMap.getNumResults(); // Verify that the dimensionality of the maps matches the number of steps. assert(numDims == ubMap.getNumResults() && "num dims and num results mismatch"); assert(numDims == steps.size() && "num dims and num steps mismatch"); result.addTypes(resultTypes); // Convert the reductions to integer attributes. 
SmallVector reductionAttrs; for (AtomicRMWKind reduction : reductions) reductionAttrs.push_back( builder.getI64IntegerAttr(static_cast(reduction))); result.addAttribute(getReductionsAttrName(), builder.getArrayAttr(reductionAttrs)); result.addAttribute(getLowerBoundsMapAttrName(), AffineMapAttr::get(lbMap)); result.addAttribute(getUpperBoundsMapAttrName(), AffineMapAttr::get(ubMap)); result.addAttribute(getStepsAttrName(), builder.getI64ArrayAttr(steps)); result.addOperands(lbArgs); result.addOperands(ubArgs); // Create a region and a block for the body. auto bodyRegion = result.addRegion(); auto body = new Block(); // Add all the block arguments. for (unsigned i = 0; i < numDims; ++i) body->addArgument(IndexType::get(builder.getContext())); bodyRegion->push_back(body); if (resultTypes.empty()) ensureTerminator(*bodyRegion, builder, result.location); } Region &AffineParallelOp::getLoopBody() { return region(); } bool AffineParallelOp::isDefinedOutsideOfLoop(Value value) { return !region().isAncestor(value.getParentRegion()); } LogicalResult AffineParallelOp::moveOutOfLoop(ArrayRef ops) { for (Operation *op : ops) op->moveBefore(*this); return success(); } unsigned AffineParallelOp::getNumDims() { return steps().size(); } AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() { return getOperands().take_front(lowerBoundsMap().getNumInputs()); } AffineParallelOp::operand_range AffineParallelOp::getUpperBoundsOperands() { return getOperands().drop_front(lowerBoundsMap().getNumInputs()); } AffineValueMap AffineParallelOp::getLowerBoundsValueMap() { return AffineValueMap(lowerBoundsMap(), getLowerBoundsOperands()); } AffineValueMap AffineParallelOp::getUpperBoundsValueMap() { return AffineValueMap(upperBoundsMap(), getUpperBoundsOperands()); } AffineValueMap AffineParallelOp::getRangesValueMap() { AffineValueMap out; AffineValueMap::difference(getUpperBoundsValueMap(), getLowerBoundsValueMap(), &out); return out; } Optional> AffineParallelOp::getConstantRanges() { // Try to convert all the ranges to constant expressions. 
  SmallVector<int64_t, 8> out;
  AffineValueMap rangesValueMap = getRangesValueMap();
  out.reserve(rangesValueMap.getNumResults());
  for (unsigned i = 0, e = rangesValueMap.getNumResults(); i < e; ++i) {
    auto expr = rangesValueMap.getResult(i);
    auto cst = expr.dyn_cast<AffineConstantExpr>();
    if (!cst)
      return llvm::None;
    out.push_back(cst.getValue());
  }
  return out;
}

Block *AffineParallelOp::getBody() { return &region().front(); }

OpBuilder AffineParallelOp::getBodyBuilder() {
  return OpBuilder(getBody(), std::prev(getBody()->end()));
}

void AffineParallelOp::setLowerBounds(ValueRange lbOperands, AffineMap map) {
  assert(lbOperands.size() == map.getNumInputs() &&
         "operands to map must match number of inputs");
  assert(map.getNumResults() >= 1 && "bounds map has at least one result");

  auto ubOperands = getUpperBoundsOperands();

  SmallVector<Value, 4> newOperands(lbOperands);
  newOperands.append(ubOperands.begin(), ubOperands.end());
  getOperation()->setOperands(newOperands);

  lowerBoundsMapAttr(AffineMapAttr::get(map));
}

void AffineParallelOp::setUpperBounds(ValueRange ubOperands, AffineMap map) {
  assert(ubOperands.size() == map.getNumInputs() &&
         "operands to map must match number of inputs");
  assert(map.getNumResults() >= 1 && "bounds map has at least one result");

  SmallVector<Value, 4> newOperands(getLowerBoundsOperands());
  newOperands.append(ubOperands.begin(), ubOperands.end());
  getOperation()->setOperands(newOperands);

  upperBoundsMapAttr(AffineMapAttr::get(map));
}

void AffineParallelOp::setLowerBoundsMap(AffineMap map) {
  AffineMap lbMap = lowerBoundsMap();
  assert(lbMap.getNumDims() == map.getNumDims() &&
         lbMap.getNumSymbols() == map.getNumSymbols());
  (void)lbMap;
  lowerBoundsMapAttr(AffineMapAttr::get(map));
}

void AffineParallelOp::setUpperBoundsMap(AffineMap map) {
  AffineMap ubMap = upperBoundsMap();
  assert(ubMap.getNumDims() == map.getNumDims() &&
         ubMap.getNumSymbols() == map.getNumSymbols());
  (void)ubMap;
  upperBoundsMapAttr(AffineMapAttr::get(map));
}

SmallVector<int64_t, 8> AffineParallelOp::getSteps() {
  SmallVector<int64_t, 8> result;
  for (Attribute attr : steps()) {
    result.push_back(attr.cast<IntegerAttr>().getInt());
  }
  return result;
}

void AffineParallelOp::setSteps(ArrayRef<int64_t> newSteps) {
  stepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps));
}

static LogicalResult verify(AffineParallelOp op) {
  auto numDims = op.getNumDims();
  if (op.lowerBoundsMap().getNumResults() != numDims ||
      op.upperBoundsMap().getNumResults() != numDims ||
      op.steps().size() != numDims ||
      op.getBody()->getNumArguments() != numDims)
    return op.emitOpError("region argument count and num results of upper "
                          "bounds, lower bounds, and steps must all match");

  if (op.reductions().size() != op.getNumResults())
    return op.emitOpError("a reduction must be specified for each output");

  // Verify reduction ops are all valid
  for (Attribute attr : op.reductions()) {
    auto intAttr = attr.dyn_cast<IntegerAttr>();
    if (!intAttr || !symbolizeAtomicRMWKind(intAttr.getInt()))
      return op.emitOpError("invalid reduction attribute");
  }

  // Verify that the bound operands are valid dimension/symbols.
  /// Lower bounds.
  if (failed(verifyDimAndSymbolIdentifiers(op, op.getLowerBoundsOperands(),
                                           op.lowerBoundsMap().getNumDims())))
    return failure();
  /// Upper bounds.
if (failed(verifyDimAndSymbolIdentifiers(op, op.getUpperBoundsOperands(), op.upperBoundsMap().getNumDims()))) return failure(); return success(); } LogicalResult AffineValueMap::canonicalize() { SmallVector newOperands{operands}; auto newMap = getAffineMap(); composeAffineMapAndOperands(&newMap, &newOperands); if (newMap == getAffineMap() && newOperands == operands) return failure(); reset(newMap, newOperands); return success(); } /// Canonicalize the bounds of the given loop. static LogicalResult canonicalizeLoopBounds(AffineParallelOp op) { AffineValueMap lb = op.getLowerBoundsValueMap(); bool lbCanonicalized = succeeded(lb.canonicalize()); AffineValueMap ub = op.getUpperBoundsValueMap(); bool ubCanonicalized = succeeded(ub.canonicalize()); // Any canonicalization change always leads to updated map(s). if (!lbCanonicalized && !ubCanonicalized) return failure(); if (lbCanonicalized) op.setLowerBounds(lb.getOperands(), lb.getAffineMap()); if (ubCanonicalized) op.setUpperBounds(ub.getOperands(), ub.getAffineMap()); return success(); } LogicalResult AffineParallelOp::fold(ArrayRef operands, SmallVectorImpl &results) { return canonicalizeLoopBounds(*this); } static void print(OpAsmPrinter &p, AffineParallelOp op) { p << op.getOperationName() << " (" << op.getBody()->getArguments() << ") = ("; p.printAffineMapOfSSAIds(op.lowerBoundsMapAttr(), op.getLowerBoundsOperands()); p << ") to ("; p.printAffineMapOfSSAIds(op.upperBoundsMapAttr(), op.getUpperBoundsOperands()); p << ')'; SmallVector steps = op.getSteps(); bool elideSteps = llvm::all_of(steps, [](int64_t step) { return step == 1; }); if (!elideSteps) { p << " step ("; llvm::interleaveComma(steps, p); p << ')'; } if (op.getNumResults()) { p << " reduce ("; llvm::interleaveComma(op.reductions(), p, [&](auto &attr) { AtomicRMWKind sym = *symbolizeAtomicRMWKind(attr.template cast().getInt()); p << "\"" << stringifyAtomicRMWKind(sym) << "\""; }); p << ") -> (" << op.getResultTypes() << ")"; } p.printRegion(op.region(), /*printEntryBlockArgs=*/false, /*printBlockTerminators=*/op.getNumResults()); p.printOptionalAttrDict( op.getAttrs(), /*elidedAttrs=*/{AffineParallelOp::getReductionsAttrName(), AffineParallelOp::getLowerBoundsMapAttrName(), AffineParallelOp::getUpperBoundsMapAttrName(), AffineParallelOp::getStepsAttrName()}); } // // operation ::= `affine.parallel` `(` ssa-ids `)` `=` `(` map-of-ssa-ids `)` // `to` `(` map-of-ssa-ids `)` steps? region attr-dict? 
// steps ::= `steps` `(` integer-literals `)` // static ParseResult parseAffineParallelOp(OpAsmParser &parser, OperationState &result) { auto &builder = parser.getBuilder(); auto indexType = builder.getIndexType(); AffineMapAttr lowerBoundsAttr, upperBoundsAttr; SmallVector ivs; SmallVector lowerBoundsMapOperands; SmallVector upperBoundsMapOperands; if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1, OpAsmParser::Delimiter::Paren) || parser.parseEqual() || parser.parseAffineMapOfSSAIds( lowerBoundsMapOperands, lowerBoundsAttr, AffineParallelOp::getLowerBoundsMapAttrName(), result.attributes, OpAsmParser::Delimiter::Paren) || parser.resolveOperands(lowerBoundsMapOperands, indexType, result.operands) || parser.parseKeyword("to") || parser.parseAffineMapOfSSAIds( upperBoundsMapOperands, upperBoundsAttr, AffineParallelOp::getUpperBoundsMapAttrName(), result.attributes, OpAsmParser::Delimiter::Paren) || parser.resolveOperands(upperBoundsMapOperands, indexType, result.operands)) return failure(); AffineMapAttr stepsMapAttr; NamedAttrList stepsAttrs; SmallVector stepsMapOperands; if (failed(parser.parseOptionalKeyword("step"))) { SmallVector steps(ivs.size(), 1); result.addAttribute(AffineParallelOp::getStepsAttrName(), builder.getI64ArrayAttr(steps)); } else { if (parser.parseAffineMapOfSSAIds(stepsMapOperands, stepsMapAttr, AffineParallelOp::getStepsAttrName(), stepsAttrs, OpAsmParser::Delimiter::Paren)) return failure(); // Convert steps from an AffineMap into an I64ArrayAttr. SmallVector steps; auto stepsMap = stepsMapAttr.getValue(); for (const auto &result : stepsMap.getResults()) { auto constExpr = result.dyn_cast(); if (!constExpr) return parser.emitError(parser.getNameLoc(), "steps must be constant integers"); steps.push_back(constExpr.getValue()); } result.addAttribute(AffineParallelOp::getStepsAttrName(), builder.getI64ArrayAttr(steps)); } // Parse optional clause of the form: `reduce ("addf", "maxf")`, where the // quoted strings are a member of the enum AtomicRMWKind. SmallVector reductions; if (succeeded(parser.parseOptionalKeyword("reduce"))) { if (parser.parseLParen()) return failure(); do { // Parse a single quoted string via the attribute parsing, and then // verify it is a member of the enum and convert to it's integer // representation. StringAttr attrVal; NamedAttrList attrStorage; auto loc = parser.getCurrentLocation(); if (parser.parseAttribute(attrVal, builder.getNoneType(), "reduce", attrStorage)) return failure(); llvm::Optional reduction = symbolizeAtomicRMWKind(attrVal.getValue()); if (!reduction) return parser.emitError(loc, "invalid reduction value: ") << attrVal; reductions.push_back(builder.getI64IntegerAttr( static_cast(reduction.getValue()))); // While we keep getting commas, keep parsing. } while (succeeded(parser.parseOptionalComma())); if (parser.parseRParen()) return failure(); } result.addAttribute(AffineParallelOp::getReductionsAttrName(), builder.getArrayAttr(reductions)); // Parse return types of reductions (if any) if (parser.parseOptionalArrowTypeList(result.types)) return failure(); // Now parse the body. Region *body = result.addRegion(); SmallVector types(ivs.size(), indexType); if (parser.parseRegion(*body, ivs, types) || parser.parseOptionalAttrDict(result.attributes)) return failure(); // Add a terminator if none was parsed. 
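  // As a purely illustrative example, a parallel band written without an
  // explicit terminator, e.g.
  //   affine.parallel (%i, %j) = (0, 0) to (128, 256) step (8, 8) { ... }
  // has an empty `affine.yield` appended to its region at this point.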
  AffineParallelOp::ensureTerminator(*body, builder, result.location);
  return success();
}

//===----------------------------------------------------------------------===//
// AffineYieldOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(AffineYieldOp op) {
  auto *parentOp = op.getParentOp();
  auto results = parentOp->getResults();
  auto operands = op.getOperands();

  if (!isa<AffineParallelOp, AffineIfOp, AffineForOp>(parentOp))
    return op.emitOpError() << "only terminates affine.if/for/parallel regions";
  if (parentOp->getNumResults() != op.getNumOperands())
    return op.emitOpError() << "parent of yield must have same number of "
                               "results as the yield operands";
  for (auto it : llvm::zip(results, operands)) {
    if (std::get<0>(it).getType() != std::get<1>(it).getType())
      return op.emitOpError()
             << "types mismatch between yield op and its parent";
  }
  return success();
}

//===----------------------------------------------------------------------===//
// AffineVectorLoadOp
//===----------------------------------------------------------------------===//

static ParseResult parseAffineVectorLoadOp(OpAsmParser &parser,
                                           OperationState &result) {
  auto &builder = parser.getBuilder();
  auto indexTy = builder.getIndexType();

  MemRefType memrefType;
  VectorType resultType;
  OpAsmParser::OperandType memrefInfo;
  AffineMapAttr mapAttr;
  SmallVector<OpAsmParser::OperandType, 1> mapOperands;
  return failure(
      parser.parseOperand(memrefInfo) ||
      parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
                                    AffineVectorLoadOp::getMapAttrName(),
                                    result.attributes) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(memrefType) || parser.parseComma() ||
      parser.parseType(resultType) ||
      parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
      parser.resolveOperands(mapOperands, indexTy, result.operands) ||
      parser.addTypeToList(resultType, result.types));
}

static void print(OpAsmPrinter &p, AffineVectorLoadOp op) {
  p << "affine.vector_load " << op.getMemRef() << '[';
  if (AffineMapAttr mapAttr =
          op.getAttrOfType<AffineMapAttr>(op.getMapAttrName()))
    p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands());
  p << ']';
  p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{op.getMapAttrName()});
  p << " : " << op.getMemRefType() << ", " << op.getType();
}

/// Verify common invariants of affine.vector_load and affine.vector_store.
static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType,
                                          VectorType vectorType) {
  // Check that memref and vector element types match.
if (memrefType.getElementType() != vectorType.getElementType()) return op->emitOpError( "requires memref and vector types of the same elemental type"); return success(); } static LogicalResult verify(AffineVectorLoadOp op) { MemRefType memrefType = op.getMemRefType(); if (failed(verifyMemoryOpIndexing( op.getOperation(), op.getAttrOfType(op.getMapAttrName()), op.getMapOperands(), memrefType, /*numIndexOperands=*/op.getNumOperands() - 1))) return failure(); if (failed(verifyVectorMemoryOp(op.getOperation(), memrefType, op.getVectorType()))) return failure(); return success(); } //===----------------------------------------------------------------------===// // AffineVectorStoreOp //===----------------------------------------------------------------------===// static ParseResult parseAffineVectorStoreOp(OpAsmParser &parser, OperationState &result) { auto indexTy = parser.getBuilder().getIndexType(); MemRefType memrefType; VectorType resultType; OpAsmParser::OperandType storeValueInfo; OpAsmParser::OperandType memrefInfo; AffineMapAttr mapAttr; SmallVector mapOperands; return failure( parser.parseOperand(storeValueInfo) || parser.parseComma() || parser.parseOperand(memrefInfo) || parser.parseAffineMapOfSSAIds(mapOperands, mapAttr, AffineVectorStoreOp::getMapAttrName(), result.attributes) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(memrefType) || parser.parseComma() || parser.parseType(resultType) || parser.resolveOperand(storeValueInfo, resultType, result.operands) || parser.resolveOperand(memrefInfo, memrefType, result.operands) || parser.resolveOperands(mapOperands, indexTy, result.operands)); } static void print(OpAsmPrinter &p, AffineVectorStoreOp op) { p << "affine.vector_store " << op.getValueToStore(); p << ", " << op.getMemRef() << '['; if (AffineMapAttr mapAttr = op.getAttrOfType(op.getMapAttrName())) p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands()); p << ']'; p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{op.getMapAttrName()}); p << " : " << op.getMemRefType() << ", " << op.getValueToStore().getType(); } static LogicalResult verify(AffineVectorStoreOp op) { MemRefType memrefType = op.getMemRefType(); if (failed(verifyMemoryOpIndexing( op.getOperation(), op.getAttrOfType(op.getMapAttrName()), op.getMapOperands(), memrefType, /*numIndexOperands=*/op.getNumOperands() - 2))) return failure(); if (failed(verifyVectorMemoryOp(op.getOperation(), memrefType, op.getVectorType()))) return failure(); return success(); } //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// #define GET_OP_CLASSES #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc" diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp index d333ddc8e34c..bc584bd628fb 100644 --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -1,4516 +1,4433 @@ //===- Ops.cpp - Standard MLIR Operations ---------------------------------===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/CommonFolders.h" #include "mlir/IR/AffineExpr.h" #include "mlir/IR/AffineMap.h" #include "mlir/IR/BlockAndValueMapping.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Function.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/Module.h" #include "mlir/IR/OpImplementation.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/StandardTypes.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Value.h" #include "mlir/Support/MathExtras.h" #include "mlir/Transforms/InliningUtils.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringSwitch.h" #include "llvm/Support/FormatVariadic.h" #include "llvm/Support/raw_ostream.h" // Pull in all enum type definitions and utility function declarations. #include "mlir/Dialect/StandardOps/IR/OpsEnums.cpp.inc" using namespace mlir; //===----------------------------------------------------------------------===// // StandardOpsDialect Interfaces //===----------------------------------------------------------------------===// namespace { /// This class defines the interface for handling inlining with standard /// operations. struct StdInlinerInterface : public DialectInlinerInterface { using DialectInlinerInterface::DialectInlinerInterface; //===--------------------------------------------------------------------===// // Analysis Hooks //===--------------------------------------------------------------------===// /// All call operations within standard ops can be inlined. bool isLegalToInline(Operation *call, Operation *callable, bool wouldBeCloned) const final { return true; } /// All operations within standard ops can be inlined. bool isLegalToInline(Operation *, Region *, bool, BlockAndValueMapping &) const final { return true; } //===--------------------------------------------------------------------===// // Transformation Hooks //===--------------------------------------------------------------------===// /// Handle the given inlined terminator by replacing it with a new operation /// as necessary. void handleTerminator(Operation *op, Block *newDest) const final { // Only "std.return" needs to be handled here. auto returnOp = dyn_cast(op); if (!returnOp) return; // Replace the return with a branch to the dest. OpBuilder builder(op); builder.create(op->getLoc(), newDest, returnOp.getOperands()); op->erase(); } /// Handle the given inlined terminator by replacing it with a new operation /// as necessary. void handleTerminator(Operation *op, ArrayRef valuesToRepl) const final { // Only "std.return" needs to be handled here. auto returnOp = cast(op); // Replace the values directly with the return operands. assert(returnOp.getNumOperands() == valuesToRepl.size()); for (const auto &it : llvm::enumerate(returnOp.getOperands())) valuesToRepl[it.index()].replaceAllUsesWith(it.value()); } }; } // end anonymous namespace //===----------------------------------------------------------------------===// // StandardOpsDialect //===----------------------------------------------------------------------===// /// A custom unary operation printer that omits the "std." prefix from the /// operation names. 
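/// For example, the generic form `"std.negf"(%x) : (f32) -> f32` is printed by
/// this helper as `negf %x : f32`.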
static void printStandardUnaryOp(Operation *op, OpAsmPrinter &p) { assert(op->getNumOperands() == 1 && "unary op should have one operand"); assert(op->getNumResults() == 1 && "unary op should have one result"); int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1; p << op->getName().getStringRef().drop_front(stdDotLen) << ' ' << op->getOperand(0); p.printOptionalAttrDict(op->getAttrs()); p << " : " << op->getOperand(0).getType(); } /// A custom binary operation printer that omits the "std." prefix from the /// operation names. static void printStandardBinaryOp(Operation *op, OpAsmPrinter &p) { assert(op->getNumOperands() == 2 && "binary op should have two operands"); assert(op->getNumResults() == 1 && "binary op should have one result"); // If not all the operand and result types are the same, just use the // generic assembly form to avoid omitting information in printing. auto resultType = op->getResult(0).getType(); if (op->getOperand(0).getType() != resultType || op->getOperand(1).getType() != resultType) { p.printGenericOp(op); return; } int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1; p << op->getName().getStringRef().drop_front(stdDotLen) << ' ' << op->getOperand(0) << ", " << op->getOperand(1); p.printOptionalAttrDict(op->getAttrs()); // Now we can output only one type for all operands and the result. p << " : " << op->getResult(0).getType(); } /// A custom cast operation printer that omits the "std." prefix from the /// operation names. static void printStandardCastOp(Operation *op, OpAsmPrinter &p) { int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1; p << op->getName().getStringRef().drop_front(stdDotLen) << ' ' << op->getOperand(0) << " : " << op->getOperand(0).getType() << " to " << op->getResult(0).getType(); } /// A custom cast operation verifier. template static LogicalResult verifyCastOp(T op) { auto opType = op.getOperand().getType(); auto resType = op.getType(); if (!T::areCastCompatible(opType, resType)) return op.emitError("operand type ") << opType << " and result type " << resType << " are cast incompatible"; return success(); } void StandardOpsDialect::initialize() { addOperations(); addInterfaces(); } /// Materialize a single constant operation from a given attribute value with /// the desired resultant type. Operation *StandardOpsDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { return builder.create(loc, type, value); } -void mlir::printDimAndSymbolList(Operation::operand_iterator begin, - Operation::operand_iterator end, - unsigned numDims, OpAsmPrinter &p) { - Operation::operand_range operands(begin, end); - p << '(' << operands.take_front(numDims) << ')'; - if (operands.size() != numDims) - p << '[' << operands.drop_front(numDims) << ']'; -} - -// Parses dimension and symbol list, and sets 'numDims' to the number of -// dimension operands parsed. -// Returns 'false' on success and 'true' on error. -ParseResult mlir::parseDimAndSymbolList(OpAsmParser &parser, - SmallVectorImpl &operands, - unsigned &numDims) { - SmallVector opInfos; - if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::Paren)) - return failure(); - // Store number of dimensions for validation by caller. - numDims = opInfos.size(); - - // Parse the optional symbol operands. 
- auto indexTy = parser.getBuilder().getIndexType(); - if (parser.parseOperandList(opInfos, - OpAsmParser::Delimiter::OptionalSquare) || - parser.resolveOperands(opInfos, indexTy, operands)) - return failure(); - return success(); -} - /// Matches a ConstantIndexOp. /// TODO: This should probably just be a general matcher that uses m_Constant /// and checks the operation for an index type. static detail::op_matcher m_ConstantIndex() { return detail::op_matcher(); } //===----------------------------------------------------------------------===// // Common canonicalization pattern support logic //===----------------------------------------------------------------------===// /// This is a common class used for patterns of the form /// "someop(memrefcast) -> someop". It folds the source of any memref_cast /// into the root operation directly. static LogicalResult foldMemRefCast(Operation *op) { bool folded = false; for (OpOperand &operand : op->getOpOperands()) { auto cast = operand.get().getDefiningOp(); if (cast && !cast.getOperand().getType().isa()) { operand.set(cast.getOperand()); folded = true; } } return success(folded); } //===----------------------------------------------------------------------===// // Common cast compatibility check for vector types. //===----------------------------------------------------------------------===// /// This method checks for cast compatibility of vector types. /// If 'a' and 'b' are vector types, and they are cast compatible, /// it calls the 'areElementsCastCompatible' function to check for /// element cast compatibility. /// Returns 'true' if the vector types are cast compatible, and 'false' /// otherwise. static bool areVectorCastSimpleCompatible( Type a, Type b, function_ref areElementsCastCompatible) { if (auto va = a.dyn_cast()) if (auto vb = b.dyn_cast()) return va.getShape().equals(vb.getShape()) && areElementsCastCompatible(va.getElementType(), vb.getElementType()); return false; } //===----------------------------------------------------------------------===// // Helpers for Tensor[Load|Store]Op, TensorToMemrefOp, and GlobalMemrefOp //===----------------------------------------------------------------------===// static Type getTensorTypeFromMemRefType(Type type) { if (auto memref = type.dyn_cast()) return RankedTensorType::get(memref.getShape(), memref.getElementType()); if (auto memref = type.dyn_cast()) return UnrankedTensorType::get(memref.getElementType()); return NoneType::get(type.getContext()); } //===----------------------------------------------------------------------===// // AddFOp //===----------------------------------------------------------------------===// OpFoldResult AddFOp::fold(ArrayRef operands) { return constFoldBinaryOp( operands, [](APFloat a, APFloat b) { return a + b; }); } //===----------------------------------------------------------------------===// // AddIOp //===----------------------------------------------------------------------===// OpFoldResult AddIOp::fold(ArrayRef operands) { /// addi(x, 0) -> x if (matchPattern(rhs(), m_Zero())) return lhs(); return constFoldBinaryOp(operands, [](APInt a, APInt b) { return a + b; }); } //===----------------------------------------------------------------------===// // BaseOpWithOffsetSizesAndStridesOp //===----------------------------------------------------------------------===// /// Print a list with either (1) the static integer value in `arrayAttr` if /// `isDynamic` evaluates to false or (2) the next value otherwise. 
/// This allows idiomatic printing of mixed value and integer attributes in a /// list. E.g. `[%arg0, 7, 42, %arg42]`. static void printListOfOperandsOrIntegers(OpAsmPrinter &p, ValueRange values, ArrayAttr arrayAttr, llvm::function_ref isDynamic) { p << '['; unsigned idx = 0; llvm::interleaveComma(arrayAttr, p, [&](Attribute a) { int64_t val = a.cast().getInt(); if (isDynamic(val)) p << values[idx++]; else p << val; }); p << ']'; } /// Parse a mixed list with either (1) static integer values or (2) SSA values. /// Fill `result` with the integer ArrayAttr named `attrName` where `dynVal` /// encode the position of SSA values. Add the parsed SSA values to `ssa` /// in-order. // /// E.g. after parsing "[%arg0, 7, 42, %arg42]": /// 1. `result` is filled with the i64 ArrayAttr "[`dynVal`, 7, 42, `dynVal`]" /// 2. `ssa` is filled with "[%arg0, %arg1]". static ParseResult parseListOfOperandsOrIntegers(OpAsmParser &parser, OperationState &result, StringRef attrName, int64_t dynVal, SmallVectorImpl &ssa) { if (failed(parser.parseLSquare())) return failure(); // 0-D. if (succeeded(parser.parseOptionalRSquare())) { result.addAttribute(attrName, parser.getBuilder().getArrayAttr({})); return success(); } SmallVector attrVals; while (true) { OpAsmParser::OperandType operand; auto res = parser.parseOptionalOperand(operand); if (res.hasValue() && succeeded(res.getValue())) { ssa.push_back(operand); attrVals.push_back(dynVal); } else { IntegerAttr attr; if (failed(parser.parseAttribute(attr))) return parser.emitError(parser.getNameLoc()) << "expected SSA value or integer"; attrVals.push_back(attr.getInt()); } if (succeeded(parser.parseOptionalComma())) continue; if (failed(parser.parseRSquare())) return failure(); break; } auto arrayAttr = parser.getBuilder().getI64ArrayAttr(attrVals); result.addAttribute(attrName, arrayAttr); return success(); } /// Verify that a particular offset/size/stride static attribute is well-formed. template static LogicalResult verifyOpWithOffsetSizesAndStridesPart( OpType op, StringRef name, unsigned expectedNumElements, StringRef attrName, ArrayAttr attr, llvm::function_ref isDynamic, ValueRange values) { /// Check static and dynamic offsets/sizes/strides breakdown. if (attr.size() != expectedNumElements) return op.emitError("expected ") << expectedNumElements << " " << name << " values"; unsigned expectedNumDynamicEntries = llvm::count_if(attr.getValue(), [&](Attribute attr) { return isDynamic(attr.cast().getInt()); }); if (values.size() != expectedNumDynamicEntries) return op.emitError("expected ") << expectedNumDynamicEntries << " dynamic " << name << " values"; return success(); } /// Extract int64_t values from the assumed ArrayAttr of IntegerAttr. static SmallVector extractFromI64ArrayAttr(Attribute attr) { return llvm::to_vector<4>( llvm::map_range(attr.cast(), [](Attribute a) -> int64_t { return a.cast().getInt(); })); } /// Verify static attributes offsets/sizes/strides. 
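/// For example, if an op's static_offsets array holds two dynamic sentinel
/// values, the op must also supply exactly two SSA offset operands.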
template static LogicalResult verifyOpWithOffsetSizesAndStrides(OpType op) { unsigned srcRank = op.getSourceRank(); if (failed(verifyOpWithOffsetSizesAndStridesPart( op, "offset", srcRank, op.getStaticOffsetsAttrName(), op.static_offsets(), ShapedType::isDynamicStrideOrOffset, op.offsets()))) return failure(); if (failed(verifyOpWithOffsetSizesAndStridesPart( op, "size", srcRank, op.getStaticSizesAttrName(), op.static_sizes(), ShapedType::isDynamic, op.sizes()))) return failure(); if (failed(verifyOpWithOffsetSizesAndStridesPart( op, "stride", srcRank, op.getStaticStridesAttrName(), op.static_strides(), ShapedType::isDynamicStrideOrOffset, op.strides()))) return failure(); return success(); } //===----------------------------------------------------------------------===// // AllocOp / AllocaOp //===----------------------------------------------------------------------===// template -static void printAllocLikeOp(OpAsmPrinter &p, AllocLikeOp op, StringRef name) { - static_assert(llvm::is_one_of::value, - "applies to only alloc or alloca"); - p << name; - - // Print dynamic dimension operands. - MemRefType type = op.getType(); - printDimAndSymbolList(op.operand_begin(), op.operand_end(), - type.getNumDynamicDims(), p); - p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"map"}); - p << " : " << type; -} - -static void print(OpAsmPrinter &p, AllocOp op) { - printAllocLikeOp(p, op, "alloc"); -} - -static void print(OpAsmPrinter &p, AllocaOp op) { - printAllocLikeOp(p, op, "alloca"); -} - -static ParseResult parseAllocLikeOp(OpAsmParser &parser, - OperationState &result) { - MemRefType type; - - // Parse the dimension operands and optional symbol operands, followed by a - // memref type. - unsigned numDimOperands; - if (parseDimAndSymbolList(parser, result.operands, numDimOperands) || - parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(type)) - return failure(); - - // Check numDynamicDims against number of question marks in memref type. - // Note: this check remains here (instead of in verify()), because the - // partition between dim operands and symbol operands is lost after parsing. - // Verification still checks that the total number of operands matches - // the number of symbols in the affine map, plus the number of dynamic - // dimensions in the memref. - if (numDimOperands != type.getNumDynamicDims()) - return parser.emitError(parser.getNameLoc()) - << "dimension operand count does not equal memref dynamic dimension " - "count"; - result.types.push_back(type); - return success(); -} - -template -static LogicalResult verify(AllocLikeOp op) { +static LogicalResult verifyAllocLikeOp(AllocLikeOp op) { static_assert(llvm::is_one_of::value, "applies to only alloc or alloca"); auto memRefType = op.getResult().getType().template dyn_cast(); if (!memRefType) return op.emitOpError("result must be a memref"); - unsigned numSymbols = 0; - if (!memRefType.getAffineMaps().empty()) { - // Store number of symbols used in affine map (used in subsequent check). - AffineMap affineMap = memRefType.getAffineMaps()[0]; - numSymbols = affineMap.getNumSymbols(); - } + if (static_cast(op.dynamicSizes().size()) != + memRefType.getNumDynamicDims()) + return op.emitOpError("dimension operand count does not equal memref " + "dynamic dimension count"); - // Check that the total number of operands matches the number of symbols in - // the affine map, plus the number of dynamic dimensions specified in the - // memref type. 
- unsigned numDynamicDims = memRefType.getNumDynamicDims(); - if (op.getNumOperands() != numDynamicDims + numSymbols) + unsigned numSymbols = 0; + if (!memRefType.getAffineMaps().empty()) + numSymbols = memRefType.getAffineMaps().front().getNumSymbols(); + if (op.symbolOperands().size() != numSymbols) return op.emitOpError( - "operand count does not equal dimension plus symbol operand count"); + "symbol operand count does not equal memref symbol count"); - // Verify that all operands are of type Index. - for (auto operandType : op.getOperandTypes()) - if (!operandType.isIndex()) - return op.emitOpError("requires operands to be of type Index"); + return success(); +} - if (std::is_same::value) - return success(); +static LogicalResult verify(AllocOp op) { return verifyAllocLikeOp(op); } +static LogicalResult verify(AllocaOp op) { // An alloca op needs to have an ancestor with an allocation scope trait. - if (!op.template getParentWithTrait()) + if (!op.getParentWithTrait()) return op.emitOpError( "requires an ancestor op with AutomaticAllocationScope trait"); - return success(); + return verifyAllocLikeOp(op); } namespace { /// Fold constant dimensions into an alloc like operation. template struct SimplifyAllocConst : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AllocLikeOp alloc, PatternRewriter &rewriter) const override { // Check to see if any dimensions operands are constants. If so, we can // substitute and drop them. if (llvm::none_of(alloc.getOperands(), [](Value operand) { return matchPattern(operand, m_ConstantIndex()); })) return failure(); auto memrefType = alloc.getType(); // Ok, we have one or more constant operands. Collect the non-constant ones // and keep track of the resultant memref type to build. SmallVector newShapeConstants; newShapeConstants.reserve(memrefType.getRank()); SmallVector newOperands; unsigned dynamicDimPos = 0; for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) { int64_t dimSize = memrefType.getDimSize(dim); // If this is already static dimension, keep it. if (dimSize != -1) { newShapeConstants.push_back(dimSize); continue; } auto *defOp = alloc.getOperand(dynamicDimPos).getDefiningOp(); if (auto constantIndexOp = dyn_cast_or_null(defOp)) { // Dynamic shape dimension will be folded. newShapeConstants.push_back(constantIndexOp.getValue()); } else { // Dynamic shape dimension not folded; copy operand from old memref. newShapeConstants.push_back(-1); newOperands.push_back(alloc.getOperand(dynamicDimPos)); } dynamicDimPos++; } // Create new memref type (which will have fewer dynamic dimensions). MemRefType newMemRefType = MemRefType::Builder(memrefType).setShape(newShapeConstants); assert(static_cast(newOperands.size()) == newMemRefType.getNumDynamicDims()); // Create and insert the alloc op for the new memref. auto newAlloc = rewriter.create(alloc.getLoc(), newMemRefType, newOperands, IntegerAttr()); // Insert a cast so we have the same type as the old alloc. auto resultCast = rewriter.create(alloc.getLoc(), newAlloc, alloc.getType()); rewriter.replaceOp(alloc, {resultCast}); return success(); } }; /// Fold alloc operations with no uses. Alloc has side effects on the heap, /// but can still be deleted if it has zero uses. 
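/// For example, an `%0 = alloc() : memref<8xf32>` whose result is never used
/// is erased by this pattern, even though the allocation itself has an effect.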
struct SimplifyDeadAlloc : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AllocOp alloc, PatternRewriter &rewriter) const override { if (alloc.use_empty()) { rewriter.eraseOp(alloc); return success(); } return failure(); } }; } // end anonymous namespace. void AllocOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { results.insert, SimplifyDeadAlloc>(context); } void AllocaOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { results.insert>(context); } //===----------------------------------------------------------------------===// // AndOp //===----------------------------------------------------------------------===// OpFoldResult AndOp::fold(ArrayRef operands) { /// and(x, 0) -> 0 if (matchPattern(rhs(), m_Zero())) return rhs(); /// and(x, allOnes) -> x APInt intValue; if (matchPattern(rhs(), m_ConstantInt(&intValue)) && intValue.isAllOnesValue()) return lhs(); /// and(x,x) -> x if (lhs() == rhs()) return rhs(); return constFoldBinaryOp(operands, [](APInt a, APInt b) { return a & b; }); } //===----------------------------------------------------------------------===// // AssertOp //===----------------------------------------------------------------------===// namespace { struct EraseRedundantAssertions : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(AssertOp op, PatternRewriter &rewriter) const override { // Erase assertion if argument is constant true. if (matchPattern(op.arg(), m_One())) { rewriter.eraseOp(op); return success(); } return failure(); } }; } // namespace void AssertOp::getCanonicalizationPatterns(OwningRewritePatternList &patterns, MLIRContext *context) { patterns.insert(context); } //===----------------------------------------------------------------------===// // AssumeAlignmentOp //===----------------------------------------------------------------------===// static LogicalResult verify(AssumeAlignmentOp op) { unsigned alignment = op.alignment(); if (!llvm::isPowerOf2_32(alignment)) return op.emitOpError("alignment must be power of 2"); return success(); } //===----------------------------------------------------------------------===// // AtomicRMWOp //===----------------------------------------------------------------------===// static LogicalResult verify(AtomicRMWOp op) { if (op.getMemRefType().getRank() != op.getNumOperands() - 2) return op.emitOpError( "expects the number of subscripts to be equal to memref rank"); switch (op.kind()) { case AtomicRMWKind::addf: case AtomicRMWKind::maxf: case AtomicRMWKind::minf: case AtomicRMWKind::mulf: if (!op.value().getType().isa()) return op.emitOpError() << "with kind '" << stringifyAtomicRMWKind(op.kind()) << "' expects a floating-point type"; break; case AtomicRMWKind::addi: case AtomicRMWKind::maxs: case AtomicRMWKind::maxu: case AtomicRMWKind::mins: case AtomicRMWKind::minu: case AtomicRMWKind::muli: if (!op.value().getType().isa()) return op.emitOpError() << "with kind '" << stringifyAtomicRMWKind(op.kind()) << "' expects an integer type"; break; default: break; } return success(); } //===----------------------------------------------------------------------===// // GenericAtomicRMWOp //===----------------------------------------------------------------------===// void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result, Value memref, ValueRange ivs) { result.addOperands(memref); result.addOperands(ivs); if (auto memrefType = 
memref.getType().dyn_cast()) { Type elementType = memrefType.getElementType(); result.addTypes(elementType); Region *bodyRegion = result.addRegion(); bodyRegion->push_back(new Block()); bodyRegion->addArgument(elementType); } } static LogicalResult verify(GenericAtomicRMWOp op) { auto &body = op.body(); if (body.getNumArguments() != 1) return op.emitOpError("expected single number of entry block arguments"); if (op.getResult().getType() != body.getArgument(0).getType()) return op.emitOpError( "expected block argument of the same type result type"); bool hasSideEffects = body.walk([&](Operation *nestedOp) { if (MemoryEffectOpInterface::hasNoEffect(nestedOp)) return WalkResult::advance(); nestedOp->emitError("body of 'generic_atomic_rmw' should contain " "only operations with no side effects"); return WalkResult::interrupt(); }) .wasInterrupted(); return hasSideEffects ? failure() : success(); } static ParseResult parseGenericAtomicRMWOp(OpAsmParser &parser, OperationState &result) { OpAsmParser::OperandType memref; Type memrefType; SmallVector ivs; Type indexType = parser.getBuilder().getIndexType(); if (parser.parseOperand(memref) || parser.parseOperandList(ivs, OpAsmParser::Delimiter::Square) || parser.parseColonType(memrefType) || parser.resolveOperand(memref, memrefType, result.operands) || parser.resolveOperands(ivs, indexType, result.operands)) return failure(); Region *body = result.addRegion(); if (parser.parseRegion(*body, llvm::None, llvm::None) || parser.parseOptionalAttrDict(result.attributes)) return failure(); result.types.push_back(memrefType.cast().getElementType()); return success(); } static void print(OpAsmPrinter &p, GenericAtomicRMWOp op) { p << op.getOperationName() << ' ' << op.memref() << "[" << op.indices() << "] : " << op.memref().getType(); p.printRegion(op.body()); p.printOptionalAttrDict(op.getAttrs()); } //===----------------------------------------------------------------------===// // AtomicYieldOp //===----------------------------------------------------------------------===// static LogicalResult verify(AtomicYieldOp op) { Type parentType = op.getParentOp()->getResultTypes().front(); Type resultType = op.result().getType(); if (parentType != resultType) return op.emitOpError() << "types mismatch between yield op: " << resultType << " and its parent: " << parentType; return success(); } //===----------------------------------------------------------------------===// // BranchOp //===----------------------------------------------------------------------===// /// Given a successor, try to collapse it to a new destination if it only /// contains a passthrough unconditional branch. If the successor is /// collapsable, `successor` and `successorOperands` are updated to reference /// the new destination and values. `argStorage` is an optional storage to use /// if operands to the collapsed successor need to be remapped. static LogicalResult collapseBranch(Block *&successor, ValueRange &successorOperands, SmallVectorImpl &argStorage) { // Check that the successor only contains a unconditional branch. if (std::next(successor->begin()) != successor->end()) return failure(); // Check that the terminator is an unconditional branch. BranchOp successorBranch = dyn_cast(successor->getTerminator()); if (!successorBranch) return failure(); // Check that the arguments are only used within the terminator. 
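  // If an argument were used anywhere else, the block could not simply be
  // bypassed by forwarding operands to its terminator's destination.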
for (BlockArgument arg : successor->getArguments()) { for (Operation *user : arg.getUsers()) if (user != successorBranch) return failure(); } // Don't try to collapse branches to infinite loops. Block *successorDest = successorBranch.getDest(); if (successorDest == successor) return failure(); // Update the operands to the successor. If the branch parent has no // arguments, we can use the branch operands directly. OperandRange operands = successorBranch.getOperands(); if (successor->args_empty()) { successor = successorDest; successorOperands = operands; return success(); } // Otherwise, we need to remap any argument operands. for (Value operand : operands) { BlockArgument argOperand = operand.dyn_cast(); if (argOperand && argOperand.getOwner() == successor) argStorage.push_back(successorOperands[argOperand.getArgNumber()]); else argStorage.push_back(operand); } successor = successorDest; successorOperands = argStorage; return success(); } namespace { /// Simplify a branch to a block that has a single predecessor. This effectively /// merges the two blocks. struct SimplifyBrToBlockWithSinglePred : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(BranchOp op, PatternRewriter &rewriter) const override { // Check that the successor block has a single predecessor. Block *succ = op.getDest(); Block *opParent = op.getOperation()->getBlock(); if (succ == opParent || !llvm::hasSingleElement(succ->getPredecessors())) return failure(); // Merge the successor into the current block and erase the branch. rewriter.mergeBlocks(succ, opParent, op.getOperands()); rewriter.eraseOp(op); return success(); } }; /// br ^bb1 /// ^bb1 /// br ^bbN(...) /// /// -> br ^bbN(...) /// struct SimplifyPassThroughBr : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(BranchOp op, PatternRewriter &rewriter) const override { Block *dest = op.getDest(); ValueRange destOperands = op.getOperands(); SmallVector destOperandStorage; // Try to collapse the successor if it points somewhere other than this // block. if (dest == op.getOperation()->getBlock() || failed(collapseBranch(dest, destOperands, destOperandStorage))) return failure(); // Create a new branch with the collapsed successor. rewriter.replaceOpWithNewOp(op, dest, destOperands); return success(); } }; } // end anonymous namespace. Block *BranchOp::getDest() { return getSuccessor(); } void BranchOp::setDest(Block *block) { return setSuccessor(block); } void BranchOp::eraseOperand(unsigned index) { getOperation()->eraseOperand(index); } void BranchOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { results.insert( context); } Optional BranchOp::getMutableSuccessorOperands(unsigned index) { assert(index == 0 && "invalid successor index"); return destOperandsMutable(); } Block *BranchOp::getSuccessorForOperands(ArrayRef) { return dest(); } //===----------------------------------------------------------------------===// // CallOp //===----------------------------------------------------------------------===// LogicalResult CallOp::verifySymbolUses(SymbolTableCollection &symbolTable) { // Check that the callee attribute was specified. 
auto fnAttr = getAttrOfType("callee"); if (!fnAttr) return emitOpError("requires a 'callee' symbol reference attribute"); FuncOp fn = symbolTable.lookupNearestSymbolFrom(*this, fnAttr); if (!fn) return emitOpError() << "'" << fnAttr.getValue() << "' does not reference a valid function"; // Verify that the operand and result types match the callee. auto fnType = fn.getType(); if (fnType.getNumInputs() != getNumOperands()) return emitOpError("incorrect number of operands for callee"); for (unsigned i = 0, e = fnType.getNumInputs(); i != e; ++i) if (getOperand(i).getType() != fnType.getInput(i)) return emitOpError("operand type mismatch: expected operand type ") << fnType.getInput(i) << ", but provided " << getOperand(i).getType() << " for operand number " << i; if (fnType.getNumResults() != getNumResults()) return emitOpError("incorrect number of results for callee"); for (unsigned i = 0, e = fnType.getNumResults(); i != e; ++i) if (getResult(i).getType() != fnType.getResult(i)) return emitOpError("result type mismatch"); return success(); } FunctionType CallOp::getCalleeType() { return FunctionType::get(getOperandTypes(), getResultTypes(), getContext()); } //===----------------------------------------------------------------------===// // CallIndirectOp //===----------------------------------------------------------------------===// namespace { /// Fold indirect calls that have a constant function as the callee operand. struct SimplifyIndirectCallWithKnownCallee : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(CallIndirectOp indirectCall, PatternRewriter &rewriter) const override { // Check that the callee is a constant callee. SymbolRefAttr calledFn; if (!matchPattern(indirectCall.getCallee(), m_Constant(&calledFn))) return failure(); // Replace with a direct call. rewriter.replaceOpWithNewOp(indirectCall, calledFn, indirectCall.getResultTypes(), indirectCall.getArgOperands()); return success(); } }; } // end anonymous namespace. void CallIndirectOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { results.insert(context); } //===----------------------------------------------------------------------===// // General helpers for comparison ops //===----------------------------------------------------------------------===// // Return the type of the same shape (scalar, vector or tensor) containing i1. static Type getI1SameShape(Type type) { auto i1Type = IntegerType::get(1, type.getContext()); if (auto tensorType = type.dyn_cast()) return RankedTensorType::get(tensorType.getShape(), i1Type); if (type.isa()) return UnrankedTensorType::get(i1Type); if (auto vectorType = type.dyn_cast()) return VectorType::get(vectorType.getShape(), i1Type); return i1Type; } //===----------------------------------------------------------------------===// // CmpIOp //===----------------------------------------------------------------------===// static void buildCmpIOp(OpBuilder &build, OperationState &result, CmpIPredicate predicate, Value lhs, Value rhs) { result.addOperands({lhs, rhs}); result.types.push_back(getI1SameShape(lhs.getType())); result.addAttribute(CmpIOp::getPredicateAttrName(), build.getI64IntegerAttr(static_cast(predicate))); } // Compute `lhs` `pred` `rhs`, where `pred` is one of the known integer // comparison predicates. 
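// For example, with 32-bit operands -1 and 0, `slt` evaluates to true while
// `ult` evaluates to false, since -1 is the unsigned value 0xFFFFFFFF.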
bool mlir::applyCmpPredicate(CmpIPredicate predicate, const APInt &lhs, const APInt &rhs) { switch (predicate) { case CmpIPredicate::eq: return lhs.eq(rhs); case CmpIPredicate::ne: return lhs.ne(rhs); case CmpIPredicate::slt: return lhs.slt(rhs); case CmpIPredicate::sle: return lhs.sle(rhs); case CmpIPredicate::sgt: return lhs.sgt(rhs); case CmpIPredicate::sge: return lhs.sge(rhs); case CmpIPredicate::ult: return lhs.ult(rhs); case CmpIPredicate::ule: return lhs.ule(rhs); case CmpIPredicate::ugt: return lhs.ugt(rhs); case CmpIPredicate::uge: return lhs.uge(rhs); } llvm_unreachable("unknown comparison predicate"); } // Constant folding hook for comparisons. OpFoldResult CmpIOp::fold(ArrayRef operands) { assert(operands.size() == 2 && "cmpi takes two arguments"); auto lhs = operands.front().dyn_cast_or_null(); auto rhs = operands.back().dyn_cast_or_null(); if (!lhs || !rhs) return {}; auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); return IntegerAttr::get(IntegerType::get(1, getContext()), APInt(1, val)); } //===----------------------------------------------------------------------===// // CmpFOp //===----------------------------------------------------------------------===// static void buildCmpFOp(OpBuilder &build, OperationState &result, CmpFPredicate predicate, Value lhs, Value rhs) { result.addOperands({lhs, rhs}); result.types.push_back(getI1SameShape(lhs.getType())); result.addAttribute(CmpFOp::getPredicateAttrName(), build.getI64IntegerAttr(static_cast(predicate))); } /// Compute `lhs` `pred` `rhs`, where `pred` is one of the known floating point /// comparison predicates. bool mlir::applyCmpPredicate(CmpFPredicate predicate, const APFloat &lhs, const APFloat &rhs) { auto cmpResult = lhs.compare(rhs); switch (predicate) { case CmpFPredicate::AlwaysFalse: return false; case CmpFPredicate::OEQ: return cmpResult == APFloat::cmpEqual; case CmpFPredicate::OGT: return cmpResult == APFloat::cmpGreaterThan; case CmpFPredicate::OGE: return cmpResult == APFloat::cmpGreaterThan || cmpResult == APFloat::cmpEqual; case CmpFPredicate::OLT: return cmpResult == APFloat::cmpLessThan; case CmpFPredicate::OLE: return cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual; case CmpFPredicate::ONE: return cmpResult != APFloat::cmpUnordered && cmpResult != APFloat::cmpEqual; case CmpFPredicate::ORD: return cmpResult != APFloat::cmpUnordered; case CmpFPredicate::UEQ: return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpEqual; case CmpFPredicate::UGT: return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpGreaterThan; case CmpFPredicate::UGE: return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpGreaterThan || cmpResult == APFloat::cmpEqual; case CmpFPredicate::ULT: return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpLessThan; case CmpFPredicate::ULE: return cmpResult == APFloat::cmpUnordered || cmpResult == APFloat::cmpLessThan || cmpResult == APFloat::cmpEqual; case CmpFPredicate::UNE: return cmpResult != APFloat::cmpEqual; case CmpFPredicate::UNO: return cmpResult == APFloat::cmpUnordered; case CmpFPredicate::AlwaysTrue: return true; } llvm_unreachable("unknown comparison predicate"); } // Constant folding hook for comparisons. 
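// For example, `cmpf "oeq"` on two constant 1.0 operands folds to an i1 true,
// while any ordered comparison with a constant NaN operand folds to false.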
OpFoldResult CmpFOp::fold(ArrayRef operands) { assert(operands.size() == 2 && "cmpf takes two arguments"); auto lhs = operands.front().dyn_cast_or_null(); auto rhs = operands.back().dyn_cast_or_null(); // TODO: We could actually do some intelligent things if we know only one // of the operands, but it's inf or nan. if (!lhs || !rhs) return {}; auto val = applyCmpPredicate(getPredicate(), lhs.getValue(), rhs.getValue()); return IntegerAttr::get(IntegerType::get(1, getContext()), APInt(1, val)); } //===----------------------------------------------------------------------===// // CondBranchOp //===----------------------------------------------------------------------===// namespace { /// cond_br true, ^bb1, ^bb2 /// -> br ^bb1 /// cond_br false, ^bb1, ^bb2 /// -> br ^bb2 /// struct SimplifyConstCondBranchPred : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(CondBranchOp condbr, PatternRewriter &rewriter) const override { if (matchPattern(condbr.getCondition(), m_NonZero())) { // True branch taken. rewriter.replaceOpWithNewOp(condbr, condbr.getTrueDest(), condbr.getTrueOperands()); return success(); } else if (matchPattern(condbr.getCondition(), m_Zero())) { // False branch taken. rewriter.replaceOpWithNewOp(condbr, condbr.getFalseDest(), condbr.getFalseOperands()); return success(); } return failure(); } }; /// cond_br %cond, ^bb1, ^bb2 /// ^bb1 /// br ^bbN(...) /// ^bb2 /// br ^bbK(...) /// /// -> cond_br %cond, ^bbN(...), ^bbK(...) /// struct SimplifyPassThroughCondBranch : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(CondBranchOp condbr, PatternRewriter &rewriter) const override { Block *trueDest = condbr.trueDest(), *falseDest = condbr.falseDest(); ValueRange trueDestOperands = condbr.getTrueOperands(); ValueRange falseDestOperands = condbr.getFalseOperands(); SmallVector trueDestOperandStorage, falseDestOperandStorage; // Try to collapse one of the current successors. LogicalResult collapsedTrue = collapseBranch(trueDest, trueDestOperands, trueDestOperandStorage); LogicalResult collapsedFalse = collapseBranch(falseDest, falseDestOperands, falseDestOperandStorage); if (failed(collapsedTrue) && failed(collapsedFalse)) return failure(); // Create a new branch with the collapsed successors. rewriter.replaceOpWithNewOp(condbr, condbr.getCondition(), trueDest, trueDestOperands, falseDest, falseDestOperands); return success(); } }; /// cond_br %cond, ^bb1(A, ..., N), ^bb1(A, ..., N) /// -> br ^bb1(A, ..., N) /// /// cond_br %cond, ^bb1(A), ^bb1(B) /// -> %select = select %cond, A, B /// br ^bb1(%select) /// struct SimplifyCondBranchIdenticalSuccessors : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(CondBranchOp condbr, PatternRewriter &rewriter) const override { // Check that the true and false destinations are the same and have the same // operands. Block *trueDest = condbr.trueDest(); if (trueDest != condbr.falseDest()) return failure(); // If all of the operands match, no selects need to be generated. OperandRange trueOperands = condbr.getTrueOperands(); OperandRange falseOperands = condbr.getFalseOperands(); if (trueOperands == falseOperands) { rewriter.replaceOpWithNewOp(condbr, trueDest, trueOperands); return success(); } // Otherwise, if the current block is the only predecessor insert selects // for any mismatched branch operands. 
if (trueDest->getUniquePredecessor() != condbr.getOperation()->getBlock()) return failure(); // Generate a select for any operands that differ between the two. SmallVector mergedOperands; mergedOperands.reserve(trueOperands.size()); Value condition = condbr.getCondition(); for (auto it : llvm::zip(trueOperands, falseOperands)) { if (std::get<0>(it) == std::get<1>(it)) mergedOperands.push_back(std::get<0>(it)); else mergedOperands.push_back(rewriter.create( condbr.getLoc(), condition, std::get<0>(it), std::get<1>(it))); } rewriter.replaceOpWithNewOp(condbr, trueDest, mergedOperands); return success(); } }; /// ... /// cond_br %cond, ^bb1(...), ^bb2(...) /// ... /// ^bb1: // has single predecessor /// ... /// cond_br %cond, ^bb3(...), ^bb4(...) /// /// -> /// /// ... /// cond_br %cond, ^bb1(...), ^bb2(...) /// ... /// ^bb1: // has single predecessor /// ... /// br ^bb3(...) /// struct SimplifyCondBranchFromCondBranchOnSameCondition : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(CondBranchOp condbr, PatternRewriter &rewriter) const override { // Check that we have a single distinct predecessor. Block *currentBlock = condbr.getOperation()->getBlock(); Block *predecessor = currentBlock->getSinglePredecessor(); if (!predecessor) return failure(); // Check that the predecessor terminates with a conditional branch to this // block and that it branches on the same condition. auto predBranch = dyn_cast(predecessor->getTerminator()); if (!predBranch || condbr.getCondition() != predBranch.getCondition()) return failure(); // Fold this branch to an unconditional branch. if (currentBlock == predBranch.trueDest()) rewriter.replaceOpWithNewOp(condbr, condbr.trueDest(), condbr.trueDestOperands()); else rewriter.replaceOpWithNewOp(condbr, condbr.falseDest(), condbr.falseDestOperands()); return success(); } }; } // end anonymous namespace void CondBranchOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { results.insert(context); } Optional CondBranchOp::getMutableSuccessorOperands(unsigned index) { assert(index < getNumSuccessors() && "invalid successor index"); return index == trueIndex ? trueDestOperandsMutable() : falseDestOperandsMutable(); } Block *CondBranchOp::getSuccessorForOperands(ArrayRef operands) { if (IntegerAttr condAttr = operands.front().dyn_cast_or_null()) return condAttr.getValue().isOneValue() ? trueDest() : falseDest(); return nullptr; } //===----------------------------------------------------------------------===// // Constant*Op //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, ConstantOp &op) { p << "constant "; p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"value"}); if (op.getAttrs().size() > 1) p << ' '; p << op.getValue(); // If the value is a symbol reference, print a trailing type. if (op.getValue().isa()) p << " : " << op.getType(); } static ParseResult parseConstantOp(OpAsmParser &parser, OperationState &result) { Attribute valueAttr; if (parser.parseOptionalAttrDict(result.attributes) || parser.parseAttribute(valueAttr, "value", result.attributes)) return failure(); // If the attribute is a symbol reference, then we expect a trailing type. Type type; if (!valueAttr.isa()) type = valueAttr.getType(); else if (parser.parseColonType(type)) return failure(); // Add the attribute type to the list. 
return parser.addTypeToList(type, result.types); } /// The constant op requires an attribute, and furthermore requires that it /// matches the return type. static LogicalResult verify(ConstantOp &op) { auto value = op.getValue(); if (!value) return op.emitOpError("requires a 'value' attribute"); auto type = op.getType(); if (!value.getType().isa() && type != value.getType()) return op.emitOpError() << "requires attribute's type (" << value.getType() << ") to match op's return type (" << type << ")"; if (type.isa() || value.isa()) return success(); if (auto intAttr = value.dyn_cast()) { // If the type has a known bitwidth we verify that the value can be // represented with the given bitwidth. auto bitwidth = type.cast().getWidth(); auto intVal = intAttr.getValue(); if (!intVal.isSignedIntN(bitwidth) && !intVal.isIntN(bitwidth)) return op.emitOpError("requires 'value' to be an integer within the " "range of the integer result type"); return success(); } if (type.isa()) { if (!value.isa()) return op.emitOpError("requires 'value' to be a floating point constant"); return success(); } if (type.isa()) { if (!value.isa()) return op.emitOpError("requires 'value' to be a shaped constant"); return success(); } if (type.isa()) { auto fnAttr = value.dyn_cast(); if (!fnAttr) return op.emitOpError("requires 'value' to be a function reference"); // Try to find the referenced function. auto fn = op.getParentOfType().lookupSymbol(fnAttr.getValue()); if (!fn) return op.emitOpError() << "reference to undefined function '" << fnAttr.getValue() << "'"; // Check that the referenced function has the correct type. if (fn.getType() != type) return op.emitOpError("reference to function with mismatched type"); return success(); } if (type.isa() && value.isa()) return success(); return op.emitOpError("unsupported 'value' attribute: ") << value; } OpFoldResult ConstantOp::fold(ArrayRef operands) { assert(operands.empty() && "constant has no operands"); return getValue(); } void ConstantOp::getAsmResultNames( function_ref setNameFn) { Type type = getType(); if (auto intCst = getValue().dyn_cast()) { IntegerType intTy = type.dyn_cast(); // Sugar i1 constants with 'true' and 'false'. if (intTy && intTy.getWidth() == 1) return setNameFn(getResult(), (intCst.getInt() ? "true" : "false")); // Otherwise, build a complex name with the value and type. SmallString<32> specialNameBuffer; llvm::raw_svector_ostream specialName(specialNameBuffer); specialName << 'c' << intCst.getInt(); if (intTy) specialName << '_' << type; setNameFn(getResult(), specialName.str()); } else if (type.isa()) { setNameFn(getResult(), "f"); } else { setNameFn(getResult(), "cst"); } } /// Returns true if a constant operation can be built with the given value and /// result type. bool ConstantOp::isBuildableWith(Attribute value, Type type) { // SymbolRefAttr can only be used with a function type. if (value.isa()) return type.isa(); // Otherwise, the attribute must have the same type as 'type'. if (value.getType() != type) return false; // Finally, check that the attribute kind is handled. return value.isa(); } void ConstantFloatOp::build(OpBuilder &builder, OperationState &result, const APFloat &value, FloatType type) { ConstantOp::build(builder, result, type, builder.getFloatAttr(type, value)); } bool ConstantFloatOp::classof(Operation *op) { return ConstantOp::classof(op) && op->getResult(0).getType().isa(); } /// ConstantIntOp only matches values whose result type is an IntegerType. 
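/// More precisely, the result type must be a *signless* integer (e.g. i32 from
/// `%c = constant 42 : i32`); explicitly signed or unsigned integer results
/// are rejected by the classof check below.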
bool ConstantIntOp::classof(Operation *op) { return ConstantOp::classof(op) && op->getResult(0).getType().isSignlessInteger(); } void ConstantIntOp::build(OpBuilder &builder, OperationState &result, int64_t value, unsigned width) { Type type = builder.getIntegerType(width); ConstantOp::build(builder, result, type, builder.getIntegerAttr(type, value)); } /// Build a constant int op producing an integer with the specified type, /// which must be an integer type. void ConstantIntOp::build(OpBuilder &builder, OperationState &result, int64_t value, Type type) { assert(type.isSignlessInteger() && "ConstantIntOp can only have signless integer type"); ConstantOp::build(builder, result, type, builder.getIntegerAttr(type, value)); } /// ConstantIndexOp only matches values whose result type is Index. bool ConstantIndexOp::classof(Operation *op) { return ConstantOp::classof(op) && op->getResult(0).getType().isIndex(); } void ConstantIndexOp::build(OpBuilder &builder, OperationState &result, int64_t value) { Type type = builder.getIndexType(); ConstantOp::build(builder, result, type, builder.getIntegerAttr(type, value)); } //===----------------------------------------------------------------------===// // DeallocOp //===----------------------------------------------------------------------===// namespace { /// Fold Dealloc operations that are deallocating an AllocOp that is only used /// by other Dealloc operations. struct SimplifyDeadDealloc : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(DeallocOp dealloc, PatternRewriter &rewriter) const override { // Check that the memref operand's defining operation is an AllocOp. Value memref = dealloc.memref(); if (!isa_and_nonnull(memref.getDefiningOp())) return failure(); // Check that all of the uses of the AllocOp are other DeallocOps. for (auto *user : memref.getUsers()) if (!isa(user)) return failure(); // Erase the dealloc operation. rewriter.eraseOp(dealloc); return success(); } }; } // end anonymous namespace. static LogicalResult verify(DeallocOp op) { if (!op.memref().getType().isa()) return op.emitOpError("operand must be a memref"); return success(); } void DeallocOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { results.insert(context); } LogicalResult DeallocOp::fold(ArrayRef cstOperands, SmallVectorImpl &results) { /// dealloc(memrefcast) -> dealloc return foldMemRefCast(*this); } //===----------------------------------------------------------------------===// // DimOp //===----------------------------------------------------------------------===// void DimOp::build(OpBuilder &builder, OperationState &result, Value memrefOrTensor, int64_t index) { auto loc = result.location; Value indexValue = builder.create(loc, index); build(builder, result, memrefOrTensor, indexValue); } void DimOp::build(OpBuilder &builder, OperationState &result, Value memrefOrTensor, Value index) { auto indexTy = builder.getIndexType(); build(builder, result, indexTy, memrefOrTensor, index); } Optional DimOp::getConstantIndex() { if (auto constantOp = index().getDefiningOp()) return constantOp.getValue().cast().getInt(); return {}; } static LogicalResult verify(DimOp op) { // Assume unknown index to be in range. Optional index = op.getConstantIndex(); if (!index.hasValue()) return success(); // Check that constant index is not knowingly out of range. 
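  // For instance, a `dim` of a memref<2x4xf32> with constant index 3 is
  // rejected here, since 3 is not smaller than the operand's rank (2).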
auto type = op.memrefOrTensor().getType(); if (auto tensorType = type.dyn_cast()) { if (index.getValue() >= tensorType.getRank()) return op.emitOpError("index is out of range"); } else if (auto memrefType = type.dyn_cast()) { if (index.getValue() >= memrefType.getRank()) return op.emitOpError("index is out of range"); } else if (type.isa() || type.isa()) { // Assume index to be in range. } else { llvm_unreachable("expected operand with tensor or memref type"); } return success(); } OpFoldResult DimOp::fold(ArrayRef operands) { auto index = operands[1].dyn_cast_or_null(); // All forms of folding require a known index. if (!index) return {}; auto argTy = memrefOrTensor().getType(); // Fold if the shape extent along the given index is known. if (auto shapedTy = argTy.dyn_cast()) { // Folding for unranked types (UnrankedMemRefType, UnrankedTensorType) is // not supported. if (!shapedTy.hasRank()) return {}; if (!shapedTy.isDynamicDim(index.getInt())) { Builder builder(getContext()); return builder.getIndexAttr(shapedTy.getShape()[index.getInt()]); } } Operation *definingOp = memrefOrTensor().getDefiningOp(); // dim(tensor_load(memref)) -> dim(memref) if (auto tensorLoadOp = dyn_cast_or_null(definingOp)) { setOperand(0, tensorLoadOp.memref()); return getResult(); } // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`. auto memrefType = argTy.dyn_cast(); if (!memrefType) return {}; // The size at the given index is now known to be a dynamic size of a memref. unsigned unsignedIndex = index.getValue().getZExtValue(); if (auto alloc = dyn_cast_or_null(definingOp)) return *(alloc.getDynamicSizes().begin() + memrefType.getDynamicDimIndex(unsignedIndex)); if (auto view = dyn_cast_or_null(definingOp)) return *(view.getDynamicSizes().begin() + memrefType.getDynamicDimIndex(unsignedIndex)); if (auto subview = dyn_cast_or_null(definingOp)) { assert(subview.isDynamicSize(unsignedIndex) && "Expected dynamic subview size"); return subview.getDynamicSize(unsignedIndex); } // dim(memrefcast) -> dim if (succeeded(foldMemRefCast(*this))) return getResult(); return {}; } // --------------------------------------------------------------------------- // DmaStartOp // --------------------------------------------------------------------------- void DmaStartOp::build(OpBuilder &builder, OperationState &result, Value srcMemRef, ValueRange srcIndices, Value destMemRef, ValueRange destIndices, Value numElements, Value tagMemRef, ValueRange tagIndices, Value stride, Value elementsPerStride) { result.addOperands(srcMemRef); result.addOperands(srcIndices); result.addOperands(destMemRef); result.addOperands(destIndices); result.addOperands({numElements, tagMemRef}); result.addOperands(tagIndices); if (stride) result.addOperands({stride, elementsPerStride}); } void DmaStartOp::print(OpAsmPrinter &p) { p << "dma_start " << getSrcMemRef() << '[' << getSrcIndices() << "], " << getDstMemRef() << '[' << getDstIndices() << "], " << getNumElements() << ", " << getTagMemRef() << '[' << getTagIndices() << ']'; if (isStrided()) p << ", " << getStride() << ", " << getNumElementsPerStride(); p.printOptionalAttrDict(getAttrs()); p << " : " << getSrcMemRef().getType() << ", " << getDstMemRef().getType() << ", " << getTagMemRef().getType(); } // Parse DmaStartOp. 
// Ex: // %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size, // %tag[%index], %stride, %num_elt_per_stride : // : memref<3076 x f32, 0>, // memref<1024 x f32, 2>, // memref<1 x i32> // ParseResult DmaStartOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::OperandType srcMemRefInfo; SmallVector srcIndexInfos; OpAsmParser::OperandType dstMemRefInfo; SmallVector dstIndexInfos; OpAsmParser::OperandType numElementsInfo; OpAsmParser::OperandType tagMemrefInfo; SmallVector tagIndexInfos; SmallVector strideInfo; SmallVector types; auto indexType = parser.getBuilder().getIndexType(); // Parse and resolve the following list of operands: // *) source memref followed by its indices (in square brackets). // *) destination memref followed by its indices (in square brackets). // *) dma size in KiB. if (parser.parseOperand(srcMemRefInfo) || parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) || parser.parseComma() || parser.parseOperand(dstMemRefInfo) || parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) || parser.parseComma() || parser.parseOperand(numElementsInfo) || parser.parseComma() || parser.parseOperand(tagMemrefInfo) || parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square)) return failure(); // Parse optional stride and elements per stride. if (parser.parseTrailingOperandList(strideInfo)) return failure(); bool isStrided = strideInfo.size() == 2; if (!strideInfo.empty() && !isStrided) { return parser.emitError(parser.getNameLoc(), "expected two stride related operands"); } if (parser.parseColonTypeList(types)) return failure(); if (types.size() != 3) return parser.emitError(parser.getNameLoc(), "fewer/more types expected"); if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) || parser.resolveOperands(srcIndexInfos, indexType, result.operands) || parser.resolveOperand(dstMemRefInfo, types[1], result.operands) || parser.resolveOperands(dstIndexInfos, indexType, result.operands) || // size should be an index. parser.resolveOperand(numElementsInfo, indexType, result.operands) || parser.resolveOperand(tagMemrefInfo, types[2], result.operands) || // tag indices should be index. parser.resolveOperands(tagIndexInfos, indexType, result.operands)) return failure(); if (isStrided) { if (parser.resolveOperands(strideInfo, indexType, result.operands)) return failure(); } return success(); } LogicalResult DmaStartOp::verify() { unsigned numOperands = getNumOperands(); // Mandatory non-variadic operands are: src memref, dst memref, tag memref and // the number of elements. if (numOperands < 4) return emitOpError("expected at least 4 operands"); // Check types of operands. The order of these calls is important: the later // calls rely on some type properties to compute the operand position. // 1. Source memref. if (!getSrcMemRef().getType().isa()) return emitOpError("expected source to be of memref type"); if (numOperands < getSrcMemRefRank() + 4) return emitOpError() << "expected at least " << getSrcMemRefRank() + 4 << " operands"; if (!getSrcIndices().empty() && !llvm::all_of(getSrcIndices().getTypes(), [](Type t) { return t.isIndex(); })) return emitOpError("expected source indices to be of index type"); // 2. Destination memref. 
if (!getDstMemRef().getType().isa()) return emitOpError("expected destination to be of memref type"); unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4; if (numOperands < numExpectedOperands) return emitOpError() << "expected at least " << numExpectedOperands << " operands"; if (!getDstIndices().empty() && !llvm::all_of(getDstIndices().getTypes(), [](Type t) { return t.isIndex(); })) return emitOpError("expected destination indices to be of index type"); // 3. Number of elements. if (!getNumElements().getType().isIndex()) return emitOpError("expected num elements to be of index type"); // 4. Tag memref. if (!getTagMemRef().getType().isa()) return emitOpError("expected tag to be of memref type"); numExpectedOperands += getTagMemRefRank(); if (numOperands < numExpectedOperands) return emitOpError() << "expected at least " << numExpectedOperands << " operands"; if (!getTagIndices().empty() && !llvm::all_of(getTagIndices().getTypes(), [](Type t) { return t.isIndex(); })) return emitOpError("expected tag indices to be of index type"); // DMAs from different memory spaces supported. if (getSrcMemorySpace() == getDstMemorySpace()) return emitOpError("DMA should be between different memory spaces"); // Optional stride-related operands must be either both present or both // absent. if (numOperands != numExpectedOperands && numOperands != numExpectedOperands + 2) return emitOpError("incorrect number of operands"); // 5. Strides. if (isStrided()) { if (!getStride().getType().isIndex() || !getNumElementsPerStride().getType().isIndex()) return emitOpError( "expected stride and num elements per stride to be of type index"); } return success(); } LogicalResult DmaStartOp::fold(ArrayRef cstOperands, SmallVectorImpl &results) { /// dma_start(memrefcast) -> dma_start return foldMemRefCast(*this); } // --------------------------------------------------------------------------- // DmaWaitOp // --------------------------------------------------------------------------- void DmaWaitOp::build(OpBuilder &builder, OperationState &result, Value tagMemRef, ValueRange tagIndices, Value numElements) { result.addOperands(tagMemRef); result.addOperands(tagIndices); result.addOperands(numElements); } void DmaWaitOp::print(OpAsmPrinter &p) { p << "dma_wait " << getTagMemRef() << '[' << getTagIndices() << "], " << getNumElements(); p.printOptionalAttrDict(getAttrs()); p << " : " << getTagMemRef().getType(); } // Parse DmaWaitOp. // Eg: // dma_wait %tag[%index], %num_elements : memref<1 x i32, (d0) -> (d0), 4> // ParseResult DmaWaitOp::parse(OpAsmParser &parser, OperationState &result) { OpAsmParser::OperandType tagMemrefInfo; SmallVector tagIndexInfos; Type type; auto indexType = parser.getBuilder().getIndexType(); OpAsmParser::OperandType numElementsInfo; // Parse tag memref, its indices, and dma size. if (parser.parseOperand(tagMemrefInfo) || parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square) || parser.parseComma() || parser.parseOperand(numElementsInfo) || parser.parseColonType(type) || parser.resolveOperand(tagMemrefInfo, type, result.operands) || parser.resolveOperands(tagIndexInfos, indexType, result.operands) || parser.resolveOperand(numElementsInfo, indexType, result.operands)) return failure(); return success(); } LogicalResult DmaWaitOp::fold(ArrayRef cstOperands, SmallVectorImpl &results) { /// dma_wait(memrefcast) -> dma_wait return foldMemRefCast(*this); } LogicalResult DmaWaitOp::verify() { // Mandatory non-variadic operands are tag and the number of elements. 
if (getNumOperands() < 2) return emitOpError() << "expected at least 2 operands"; // Check types of operands. The order of these calls is important: the later // calls rely on some type properties to compute the operand position. if (!getTagMemRef().getType().isa()) return emitOpError() << "expected tag to be of memref type"; if (getNumOperands() != 2 + getTagMemRefRank()) return emitOpError() << "expected " << 2 + getTagMemRefRank() << " operands"; if (!getTagIndices().empty() && !llvm::all_of(getTagIndices().getTypes(), [](Type t) { return t.isIndex(); })) return emitOpError() << "expected tag indices to be of index type"; if (!getNumElements().getType().isIndex()) return emitOpError() << "expected the number of elements to be of index type"; return success(); } //===----------------------------------------------------------------------===// // DynamicTensorFromElementsOp //===----------------------------------------------------------------------===// static ParseResult parseDynamicTensorFromElementsOp(OpAsmParser &parser, OperationState &result) { // Parse operands. SmallVector dynamicExtents; Type indexTy = parser.getBuilder().getIndexType(); if (parser.parseOperandList(dynamicExtents) || parser.resolveOperands(dynamicExtents, indexTy, result.operands)) return failure(); // Parse body. Region *body = result.addRegion(); if (parser.parseRegion(*body, {}, {})) return failure(); // Parse result type. Type resultType; if (parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(resultType)) return failure(); result.addTypes(resultType); return success(); } static void print(OpAsmPrinter &p, DynamicTensorFromElementsOp op) { p << "dynamic_tensor_from_elements " << op.dynamicExtents(); p.printRegion(op.body()); p.printOptionalAttrDict(op.getAttrs()); p << " : " << op.getType(); } static LogicalResult verify(DynamicTensorFromElementsOp op) { // Ensure that the tensor type has as many dynamic dimensions as are specified // by the operands. RankedTensorType resultTy = op.getType().cast(); if (op.getNumOperands() != resultTy.getNumDynamicDims()) return op.emitError("must have as many index operands as dynamic extents " "in the result type"); // Ensure that region arguments span the index space. if (!llvm::all_of(op.body().getArgumentTypes(), [](Type ty) { return ty.isIndex(); })) return op.emitError("all body arguments must be index"); if (op.body().getNumArguments() != resultTy.getRank()) return op.emitError("must have one body argument per input dimension"); // Ensure that the region yields an element of the right type. auto yieldOp = llvm::cast(op.body().getBlocks().front().getTerminator()); if (yieldOp.value().getType() != resultTy.getElementType()) return op.emitOpError( "body must be terminated with a `yield` operation of the tensor " "element type"); return success(); } void DynamicTensorFromElementsOp::build( OpBuilder &b, OperationState &result, Type resultTy, ValueRange dynamicExtents, function_ref bodyBuilder) { build(b, result, resultTy, dynamicExtents); // Build and populate body. 
OpBuilder::InsertionGuard guard(b); Region *bodyRegion = result.regions.front().get(); auto rank = resultTy.cast().getRank(); SmallVector argumentTypes(rank, b.getIndexType()); Block *bodyBlock = b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes); bodyBuilder(b, result.location, bodyBlock->getArguments()); } namespace { /// Canonicalizes dynamic_tensor_from_elements operations with a constant /// operand into the equivalent operation with the operand expressed in the /// result type, instead. We also insert a type cast to make sure that the /// resulting IR is still well-typed. struct StaticDynamicTensorFromElements : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(DynamicTensorFromElementsOp tensorFromElements, PatternRewriter &rewriter) const final { auto resultType = tensorFromElements.getResult().getType().cast(); if (resultType.hasStaticShape()) return failure(); SmallVector newOperands; SmallVector newShape; auto operandsIt = tensorFromElements.dynamicExtents().begin(); for (int64_t dim : resultType.getShape()) { if (dim != RankedTensorType::kDynamicSize) { newShape.push_back(dim); continue; } APInt index; if (!matchPattern(*operandsIt, m_ConstantInt(&index))) { newShape.push_back(RankedTensorType::kDynamicSize); newOperands.push_back(*operandsIt++); continue; } newShape.push_back(index.getSExtValue()); operandsIt++; } if (newOperands.size() == tensorFromElements.dynamicExtents().size()) return failure(); auto loc = tensorFromElements.getLoc(); auto newOp = rewriter.create( loc, RankedTensorType::get(newShape, resultType.getElementType()), newOperands); rewriter.inlineRegionBefore(tensorFromElements.body(), newOp.body(), newOp.body().begin()); rewriter.replaceOpWithNewOp(tensorFromElements, resultType, newOp); return success(); } }; /// Canonicalizes the pattern of the form /// /// %tensor = dynamic_tensor_from_elements %x { /// ^bb0(%arg0: index): // no predecessors /// /// yield %1 : index /// } : tensor /// %extracted_element = extract_element %tensor[%c0] : tensor /// /// to just with %arg0 replaced by %c0. We only do this if the /// dynamic_tensor_from_elements operation has no side-effects. struct ExtractElementFromDynamicTensorFromElements : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ExtractElementOp extract, PatternRewriter &rewriter) const final { auto tensorFromElements = extract.aggregate().getDefiningOp(); if (!tensorFromElements || !wouldOpBeTriviallyDead(tensorFromElements)) return failure(); BlockAndValueMapping mapping; Block *body = tensorFromElements.getBody(); mapping.map(body->getArguments(), extract.indices()); for (auto &op : body->without_terminator()) rewriter.clone(op, mapping); auto yield = cast(body->getTerminator()); rewriter.replaceOp(extract, mapping.lookupOrDefault(yield.value())); return success(); } }; } // namespace void DynamicTensorFromElementsOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { results.insert(context); } //===----------------------------------------------------------------------===// // ExtractElementOp //===----------------------------------------------------------------------===// static LogicalResult verify(ExtractElementOp op) { // Verify the # indices match if we have a ranked type. 
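  // For example, `extract_element %t[%i, %j] : tensor<4x4xf32>` must supply
  // exactly two indices (the rank of the aggregate); any other count is
  // diagnosed below.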
auto aggregateType = op.getAggregate().getType().cast(); if (aggregateType.hasRank() && aggregateType.getRank() != op.getNumOperands() - 1) return op.emitOpError("incorrect number of indices for extract_element"); return success(); } OpFoldResult ExtractElementOp::fold(ArrayRef operands) { assert(!operands.empty() && "extract_element takes at least one operand"); // The aggregate operand must be a known constant. Attribute aggregate = operands.front(); if (!aggregate) return {}; // If this is a splat elements attribute, simply return the value. All of the // elements of a splat attribute are the same. if (auto splatAggregate = aggregate.dyn_cast()) return splatAggregate.getSplatValue(); // Otherwise, collect the constant indices into the aggregate. SmallVector indices; for (Attribute indice : llvm::drop_begin(operands, 1)) { if (!indice || !indice.isa()) return {}; indices.push_back(indice.cast().getInt()); } // If this is an elements attribute, query the value at the given indices. auto elementsAttr = aggregate.dyn_cast(); if (elementsAttr && elementsAttr.isValidIndex(indices)) return elementsAttr.getValue(indices); return {}; } //===----------------------------------------------------------------------===// // TensorFromElementsOp //===----------------------------------------------------------------------===// void TensorFromElementsOp::build(OpBuilder &builder, OperationState &result, Type elementType, ValueRange elements) { Type resultTy = RankedTensorType::get({static_cast(elements.size())}, elementType); result.addOperands(elements); result.addTypes(resultTy); } void TensorFromElementsOp::build(OpBuilder &builder, OperationState &result, ValueRange elements) { assert(!elements.empty() && "expected at least one element"); build(builder, result, elements.front().getType(), elements); } namespace { // Canonicalizes the pattern of the form // // %tensor = "tensor_from_elements(%element) : (i32) -> tensor<1xi32> // %extracted_element = extract_element %tensor[%c0] : tensor<1xi32> // // to just %element. 
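// Note that the rewrite below only applies when the extraction index is a
// known constant, so the matching tensor_from_elements operand can be
// selected statically; otherwise the pattern fails to match.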
struct ExtractElementFromTensorFromElements : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ExtractElementOp extract, PatternRewriter &rewriter) const final { if (extract.indices().size() != 1) return failure(); auto tensorFromElements = dyn_cast_or_null( extract.aggregate().getDefiningOp()); if (tensorFromElements == nullptr) return failure(); APInt index; if (!matchPattern(*extract.indices().begin(), m_ConstantInt(&index))) return failure(); rewriter.replaceOp(extract, tensorFromElements.getOperand(index.getZExtValue())); return success(); } }; } // namespace void TensorFromElementsOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { results.insert(context); } //===----------------------------------------------------------------------===// // FPExtOp //===----------------------------------------------------------------------===// bool FPExtOp::areCastCompatible(Type a, Type b) { if (auto fa = a.dyn_cast()) if (auto fb = b.dyn_cast()) return fa.getWidth() < fb.getWidth(); return areVectorCastSimpleCompatible(a, b, areCastCompatible); } //===----------------------------------------------------------------------===// // FPToSIOp //===----------------------------------------------------------------------===// bool FPToSIOp::areCastCompatible(Type a, Type b) { if (a.isa() && b.isSignlessInteger()) return true; return areVectorCastSimpleCompatible(a, b, areCastCompatible); } //===----------------------------------------------------------------------===// // FPToUIOp //===----------------------------------------------------------------------===// bool FPToUIOp::areCastCompatible(Type a, Type b) { if (a.isa() && b.isSignlessInteger()) return true; return areVectorCastSimpleCompatible(a, b, areCastCompatible); } //===----------------------------------------------------------------------===// // FPTruncOp //===----------------------------------------------------------------------===// bool FPTruncOp::areCastCompatible(Type a, Type b) { if (auto fa = a.dyn_cast()) if (auto fb = b.dyn_cast()) return fa.getWidth() > fb.getWidth(); return areVectorCastSimpleCompatible(a, b, areCastCompatible); } //===----------------------------------------------------------------------===// // GlobalMemrefOp //===----------------------------------------------------------------------===// static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalMemrefOp op, TypeAttr type, Attribute initialValue) { p << type; if (!op.isExternal()) { p << " = "; if (op.isUninitialized()) p << "uninitialized"; else p.printAttributeWithoutType(initialValue); } } static ParseResult parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, Attribute &initialValue) { Type type; if (parser.parseType(type)) return failure(); auto memrefType = type.dyn_cast(); if (!memrefType || !memrefType.hasStaticShape()) return parser.emitError(parser.getNameLoc()) << "type should be static shaped memref, but got " << type; typeAttr = TypeAttr::get(type); if (parser.parseOptionalEqual()) return success(); if (succeeded(parser.parseOptionalKeyword("uninitialized"))) { initialValue = UnitAttr::get(parser.getBuilder().getContext()); return success(); } Type tensorType = getTensorTypeFromMemRefType(memrefType); if (parser.parseAttribute(initialValue, tensorType)) return failure(); if (!initialValue.isa()) return parser.emitError(parser.getNameLoc()) << "initial value should be a unit or elements attribute"; return success(); } static 
LogicalResult verify(GlobalMemrefOp op) { auto memrefType = op.type().dyn_cast(); if (!memrefType || !memrefType.hasStaticShape()) return op.emitOpError("type should be static shaped memref, but got ") << op.type(); // Verify that the initial value, if present, is either a unit attribute or // an elements attribute. if (op.initial_value().hasValue()) { Attribute initValue = op.initial_value().getValue(); if (!initValue.isa() && !initValue.isa()) return op.emitOpError("initial value should be a unit or elements " "attribute, but got ") << initValue; // Check that the type of the initial value is compatible with the type of // the global variable. if (initValue.isa()) { Type initType = initValue.getType(); Type tensorType = getTensorTypeFromMemRefType(memrefType); if (initType != tensorType) return op.emitOpError("initial value expected to be of type ") << tensorType << ", but was of type " << initType; } } // TODO: verify visibility for declarations. return success(); } //===----------------------------------------------------------------------===// // GetGlobalMemrefOp //===----------------------------------------------------------------------===// LogicalResult GetGlobalMemrefOp::verifySymbolUses(SymbolTableCollection &symbolTable) { // Verify that the result type is same as the type of the referenced // global_memref op. auto global = symbolTable.lookupNearestSymbolFrom(*this, nameAttr()); if (!global) return emitOpError("'") << name() << "' does not reference a valid global memref"; Type resultType = result().getType(); if (global.type() != resultType) return emitOpError("result type ") << resultType << " does not match type " << global.type() << " of the global memref @" << name(); return success(); } //===----------------------------------------------------------------------===// // IndexCastOp //===----------------------------------------------------------------------===// // Index cast is applicable from index to integer and backwards. bool IndexCastOp::areCastCompatible(Type a, Type b) { if (a.isa() && b.isa()) { auto aShaped = a.cast(); auto bShaped = b.cast(); return (aShaped.getShape() == bShaped.getShape()) && areCastCompatible(aShaped.getElementType(), bShaped.getElementType()); } return (a.isIndex() && b.isSignlessInteger()) || (a.isSignlessInteger() && b.isIndex()); } OpFoldResult IndexCastOp::fold(ArrayRef cstOperands) { // Fold IndexCast(IndexCast(x)) -> x auto cast = getOperand().getDefiningOp(); if (cast && cast.getOperand().getType() == getType()) return cast.getOperand(); // Fold IndexCast(constant) -> constant // A little hack because we go through int. Otherwise, the size // of the constant might need to change. 
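  // For example, an `index_cast` of a constant 42 : index with result type i32
  // folds directly to the attribute 42 : i32 (and vice versa for i32 -> index).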
if (auto value = cstOperands[0].dyn_cast_or_null()) return IntegerAttr::get(getType(), value.getInt()); return {}; } //===----------------------------------------------------------------------===// // LoadOp //===----------------------------------------------------------------------===// static LogicalResult verify(LoadOp op) { if (op.getNumOperands() != 1 + op.getMemRefType().getRank()) return op.emitOpError("incorrect number of indices for load"); return success(); } OpFoldResult LoadOp::fold(ArrayRef cstOperands) { /// load(memrefcast) -> load if (succeeded(foldMemRefCast(*this))) return getResult(); return OpFoldResult(); } //===----------------------------------------------------------------------===// // MemRefCastOp //===----------------------------------------------------------------------===// Value MemRefCastOp::getViewSource() { return source(); } bool MemRefCastOp::areCastCompatible(Type a, Type b) { auto aT = a.dyn_cast(); auto bT = b.dyn_cast(); auto uaT = a.dyn_cast(); auto ubT = b.dyn_cast(); if (aT && bT) { if (aT.getElementType() != bT.getElementType()) return false; if (aT.getAffineMaps() != bT.getAffineMaps()) { int64_t aOffset, bOffset; SmallVector aStrides, bStrides; if (failed(getStridesAndOffset(aT, aStrides, aOffset)) || failed(getStridesAndOffset(bT, bStrides, bOffset)) || aStrides.size() != bStrides.size()) return false; // Strides along a dimension/offset are compatible if the value in the // source memref is static and the value in the target memref is the // same. They are also compatible if either one is dynamic (see // description of MemRefCastOp for details). auto checkCompatible = [](int64_t a, int64_t b) { return (a == MemRefType::getDynamicStrideOrOffset() || b == MemRefType::getDynamicStrideOrOffset() || a == b); }; if (!checkCompatible(aOffset, bOffset)) return false; for (auto aStride : enumerate(aStrides)) if (!checkCompatible(aStride.value(), bStrides[aStride.index()])) return false; } if (aT.getMemorySpace() != bT.getMemorySpace()) return false; // They must have the same rank, and any specified dimensions must match. if (aT.getRank() != bT.getRank()) return false; for (unsigned i = 0, e = aT.getRank(); i != e; ++i) { int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i); if (aDim != -1 && bDim != -1 && aDim != bDim) return false; } return true; } else { if (!aT && !uaT) return false; if (!bT && !ubT) return false; // Unranked to unranked casting is unsupported if (uaT && ubT) return false; auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType(); auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType(); if (aEltType != bEltType) return false; auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace(); auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace(); if (aMemSpace != bMemSpace) return false; return true; } return false; } OpFoldResult MemRefCastOp::fold(ArrayRef operands) { if (Value folded = impl::foldCastOp(*this)) return folded; return succeeded(foldMemRefCast(*this)) ? 
getResult() : Value(); } //===----------------------------------------------------------------------===// // MemRefReinterpretCastOp //===----------------------------------------------------------------------===// void mlir::MemRefReinterpretCastOp::build( OpBuilder &b, OperationState &result, MemRefType resultType, Value source, int64_t staticOffset, ArrayRef staticSizes, ArrayRef staticStrides, ValueRange offset, ValueRange sizes, ValueRange strides, ArrayRef attrs) { build(b, result, resultType, source, offset, sizes, strides, b.getI64ArrayAttr(staticOffset), b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides)); result.addAttributes(attrs); } /// Build a MemRefReinterpretCastOp with all dynamic entries: `staticOffsets`, /// `staticSizes` and `staticStrides` are automatically filled with /// source-memref-rank sentinel values that encode dynamic entries. void mlir::MemRefReinterpretCastOp::build(OpBuilder &b, OperationState &result, MemRefType resultType, Value source, Value offset, ValueRange sizes, ValueRange strides, ArrayRef attrs) { unsigned rank = resultType.getRank(); SmallVector staticSizesVector(rank, ShapedType::kDynamicSize); SmallVector staticStridesVector( rank, ShapedType::kDynamicStrideOrOffset); build(b, result, resultType, source, /*staticOffset=*/ShapedType::kDynamicStrideOrOffset, staticSizesVector, staticStridesVector, offset, sizes, strides, attrs); } /// Print of the form: /// ``` /// `name` ssa-name to /// offset: `[` offset `]` /// sizes: `[` size-list `]` /// strides:`[` stride-list `]` /// `:` any-memref-type to strided-memref-type /// ``` static void print(OpAsmPrinter &p, MemRefReinterpretCastOp op) { int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1; p << op.getOperationName().drop_front(stdDotLen) << " " << op.source() << " to offset: "; printListOfOperandsOrIntegers(p, op.offsets(), op.static_offsets(), ShapedType::isDynamicStrideOrOffset); p << ", sizes: "; printListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(), ShapedType::isDynamic); p << ", strides: "; printListOfOperandsOrIntegers(p, op.strides(), op.static_strides(), ShapedType::isDynamicStrideOrOffset); p.printOptionalAttrDict( op.getAttrs(), /*elidedAttrs=*/{MemRefReinterpretCastOp::getOperandSegmentSizeAttr(), MemRefReinterpretCastOp::getStaticOffsetsAttrName(), MemRefReinterpretCastOp::getStaticSizesAttrName(), MemRefReinterpretCastOp::getStaticStridesAttrName()}); p << ": " << op.source().getType() << " to " << op.getType(); } /// Parse of the form: /// ``` /// `name` ssa-name to /// offset: `[` offset `]` /// sizes: `[` size-list `]` /// strides:`[` stride-list `]` /// `:` any-memref-type to strided-memref-type /// ``` static ParseResult parseMemRefReinterpretCastOp(OpAsmParser &parser, OperationState &result) { // Parse `operand` and `offset`. OpAsmParser::OperandType operand; if (parser.parseOperand(operand)) return failure(); // Parse offset. SmallVector offset; if (parser.parseKeyword("to") || parser.parseKeyword("offset") || parser.parseColon() || parseListOfOperandsOrIntegers( parser, result, MemRefReinterpretCastOp::getStaticOffsetsAttrName(), ShapedType::kDynamicStrideOrOffset, offset) || parser.parseComma()) return failure(); // Parse `sizes`. SmallVector sizes; if (parser.parseKeyword("sizes") || parser.parseColon() || parseListOfOperandsOrIntegers( parser, result, MemRefReinterpretCastOp::getStaticSizesAttrName(), ShapedType::kDynamicSize, sizes) || parser.parseComma()) return failure(); // Parse `strides`. 
SmallVector strides; if (parser.parseKeyword("strides") || parser.parseColon() || parseListOfOperandsOrIntegers( parser, result, MemRefReinterpretCastOp::getStaticStridesAttrName(), ShapedType::kDynamicStrideOrOffset, strides)) return failure(); // Handle segment sizes. auto b = parser.getBuilder(); SmallVector segmentSizes = {1, static_cast(offset.size()), static_cast(sizes.size()), static_cast(strides.size())}; result.addAttribute(MemRefReinterpretCastOp::getOperandSegmentSizeAttr(), b.getI32VectorAttr(segmentSizes)); // Parse types and resolve. Type indexType = b.getIndexType(); Type operandType, resultType; return failure( (parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(operandType) || parser.parseKeyword("to") || parser.parseType(resultType) || parser.resolveOperand(operand, operandType, result.operands) || parser.resolveOperands(offset, indexType, result.operands) || parser.resolveOperands(sizes, indexType, result.operands) || parser.resolveOperands(strides, indexType, result.operands) || parser.addTypeToList(resultType, result.types))); } static LogicalResult verify(MemRefReinterpretCastOp op) { // The source and result memrefs should be in the same memory space. auto srcType = op.source().getType().cast(); auto resultType = op.getType().cast(); if (srcType.getMemorySpace() != resultType.getMemorySpace()) return op.emitError("different memory spaces specified for source type ") << srcType << " and result memref type " << resultType; if (srcType.getElementType() != resultType.getElementType()) return op.emitError("different element types specified for source type ") << srcType << " and result memref type " << resultType; // Verify that dynamic and static offset/sizes/strides arguments/attributes // are consistent. if (failed(verifyOpWithOffsetSizesAndStridesPart( op, "offset", 1, op.getStaticOffsetsAttrName(), op.static_offsets(), ShapedType::isDynamicStrideOrOffset, op.offsets()))) return failure(); unsigned resultRank = op.getResultRank(); if (failed(verifyOpWithOffsetSizesAndStridesPart( op, "size", resultRank, op.getStaticSizesAttrName(), op.static_sizes(), ShapedType::isDynamic, op.sizes()))) return failure(); if (failed(verifyOpWithOffsetSizesAndStridesPart( op, "stride", resultRank, op.getStaticStridesAttrName(), op.static_strides(), ShapedType::isDynamicStrideOrOffset, op.strides()))) return failure(); // Match sizes in result memref type and in static_sizes attribute. for (auto &en : llvm::enumerate(llvm::zip(resultType.getShape(), extractFromI64ArrayAttr(op.static_sizes())))) { int64_t resultSize = std::get<0>(en.value()); int64_t expectedSize = std::get<1>(en.value()); if (resultSize != expectedSize) return op.emitError("expected result type with size = ") << expectedSize << " instead of " << resultSize << " in dim = " << en.index(); } // Match offset and strides in static_offset and static_strides attributes if // result memref type has an affine map specified. if (!resultType.getAffineMaps().empty()) { int64_t resultOffset; SmallVector resultStrides; if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset))) return failure(); // Match offset in result memref type and in static_offsets attribute. int64_t expectedOffset = extractFromI64ArrayAttr(op.static_offsets()).front(); if (resultOffset != expectedOffset) return op.emitError("expected result type with offset = ") << resultOffset << " instead of " << expectedOffset; // Match strides in result memref type and in static_strides attribute. 
for (auto &en : llvm::enumerate(llvm::zip( resultStrides, extractFromI64ArrayAttr(op.static_strides())))) { int64_t resultStride = std::get<0>(en.value()); int64_t expectedStride = std::get<1>(en.value()); if (resultStride != expectedStride) return op.emitError("expected result type with stride = ") << expectedStride << " instead of " << resultStride << " in dim = " << en.index(); } } return success(); } //===----------------------------------------------------------------------===// // MemRefReshapeOp //===----------------------------------------------------------------------===// static LogicalResult verify(MemRefReshapeOp op) { Type operandType = op.source().getType(); Type resultType = op.result().getType(); Type operandElementType = operandType.cast().getElementType(); Type resultElementType = resultType.cast().getElementType(); if (operandElementType != resultElementType) return op.emitOpError("element types of source and destination memref " "types should be the same"); if (auto operandMemRefType = operandType.dyn_cast()) if (!operandMemRefType.getAffineMaps().empty()) return op.emitOpError( "source memref type should have identity affine map"); int64_t shapeSize = op.shape().getType().cast().getDimSize(0); auto resultMemRefType = resultType.dyn_cast(); if (resultMemRefType) { if (!resultMemRefType.getAffineMaps().empty()) return op.emitOpError( "result memref type should have identity affine map"); if (shapeSize == ShapedType::kDynamicSize) return op.emitOpError("cannot use shape operand with dynamic length to " "reshape to statically-ranked memref type"); if (shapeSize != resultMemRefType.getRank()) return op.emitOpError( "length of shape operand differs from the result's memref rank"); } return success(); } //===----------------------------------------------------------------------===// // MulFOp //===----------------------------------------------------------------------===// OpFoldResult MulFOp::fold(ArrayRef operands) { return constFoldBinaryOp( operands, [](APFloat a, APFloat b) { return a * b; }); } //===----------------------------------------------------------------------===// // MulIOp //===----------------------------------------------------------------------===// OpFoldResult MulIOp::fold(ArrayRef operands) { /// muli(x, 0) -> 0 if (matchPattern(rhs(), m_Zero())) return rhs(); /// muli(x, 1) -> x if (matchPattern(rhs(), m_One())) return getOperand(0); // TODO: Handle the overflow case. return constFoldBinaryOp(operands, [](APInt a, APInt b) { return a * b; }); } //===----------------------------------------------------------------------===// // OrOp //===----------------------------------------------------------------------===// OpFoldResult OrOp::fold(ArrayRef operands) { /// or(x, 0) -> x if (matchPattern(rhs(), m_Zero())) return lhs(); /// or(x,x) -> x if (lhs() == rhs()) return rhs(); return constFoldBinaryOp(operands, [](APInt a, APInt b) { return a | b; }); } //===----------------------------------------------------------------------===// // PrefetchOp //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, PrefetchOp op) { p << PrefetchOp::getOperationName() << " " << op.memref() << '['; p.printOperands(op.indices()); p << ']' << ", " << (op.isWrite() ? "write" : "read"); p << ", locality<" << op.localityHint(); p << ">, " << (op.isDataCache() ? 
"data" : "instr"); p.printOptionalAttrDict( op.getAttrs(), /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"}); p << " : " << op.getMemRefType(); } static ParseResult parsePrefetchOp(OpAsmParser &parser, OperationState &result) { OpAsmParser::OperandType memrefInfo; SmallVector indexInfo; IntegerAttr localityHint; MemRefType type; StringRef readOrWrite, cacheType; auto indexTy = parser.getBuilder().getIndexType(); auto i32Type = parser.getBuilder().getIntegerType(32); if (parser.parseOperand(memrefInfo) || parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) || parser.parseComma() || parser.parseKeyword(&readOrWrite) || parser.parseComma() || parser.parseKeyword("locality") || parser.parseLess() || parser.parseAttribute(localityHint, i32Type, "localityHint", result.attributes) || parser.parseGreater() || parser.parseComma() || parser.parseKeyword(&cacheType) || parser.parseColonType(type) || parser.resolveOperand(memrefInfo, type, result.operands) || parser.resolveOperands(indexInfo, indexTy, result.operands)) return failure(); if (!readOrWrite.equals("read") && !readOrWrite.equals("write")) return parser.emitError(parser.getNameLoc(), "rw specifier has to be 'read' or 'write'"); result.addAttribute( PrefetchOp::getIsWriteAttrName(), parser.getBuilder().getBoolAttr(readOrWrite.equals("write"))); if (!cacheType.equals("data") && !cacheType.equals("instr")) return parser.emitError(parser.getNameLoc(), "cache type has to be 'data' or 'instr'"); result.addAttribute( PrefetchOp::getIsDataCacheAttrName(), parser.getBuilder().getBoolAttr(cacheType.equals("data"))); return success(); } static LogicalResult verify(PrefetchOp op) { if (op.getNumOperands() != 1 + op.getMemRefType().getRank()) return op.emitOpError("too few indices"); return success(); } LogicalResult PrefetchOp::fold(ArrayRef cstOperands, SmallVectorImpl &results) { // prefetch(memrefcast) -> prefetch return foldMemRefCast(*this); } //===----------------------------------------------------------------------===// // RankOp //===----------------------------------------------------------------------===// OpFoldResult RankOp::fold(ArrayRef operands) { // Constant fold rank when the rank of the operand is known. auto type = getOperand().getType(); if (auto shapedType = type.dyn_cast()) if (shapedType.hasRank()) return IntegerAttr::get(IndexType::get(getContext()), shapedType.getRank()); return IntegerAttr(); } //===----------------------------------------------------------------------===// // ReturnOp //===----------------------------------------------------------------------===// static LogicalResult verify(ReturnOp op) { auto function = cast(op.getParentOp()); // The operand number and types must match the function signature. 
const auto &results = function.getType().getResults(); if (op.getNumOperands() != results.size()) return op.emitOpError("has ") << op.getNumOperands() << " operands, but enclosing function (@" << function.getName() << ") returns " << results.size(); for (unsigned i = 0, e = results.size(); i != e; ++i) if (op.getOperand(i).getType() != results[i]) return op.emitError() << "type of return operand " << i << " (" << op.getOperand(i).getType() << ") doesn't match function result type (" << results[i] << ")" << " in function @" << function.getName(); return success(); } //===----------------------------------------------------------------------===// // SelectOp //===----------------------------------------------------------------------===// OpFoldResult SelectOp::fold(ArrayRef operands) { auto condition = getCondition(); // select true, %0, %1 => %0 if (matchPattern(condition, m_One())) return getTrueValue(); // select false, %0, %1 => %1 if (matchPattern(condition, m_Zero())) return getFalseValue(); return nullptr; } static void print(OpAsmPrinter &p, SelectOp op) { p << "select " << op.getOperands(); p.printOptionalAttrDict(op.getAttrs()); p << " : "; if (ShapedType condType = op.getCondition().getType().dyn_cast()) p << condType << ", "; p << op.getType(); } static ParseResult parseSelectOp(OpAsmParser &parser, OperationState &result) { Type conditionType, resultType; SmallVector operands; if (parser.parseOperandList(operands, /*requiredOperandCount=*/3) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(resultType)) return failure(); // Check for the explicit condition type if this is a masked tensor or vector. if (succeeded(parser.parseOptionalComma())) { conditionType = resultType; if (parser.parseType(resultType)) return failure(); } else { conditionType = parser.getBuilder().getI1Type(); } result.addTypes(resultType); return parser.resolveOperands(operands, {conditionType, resultType, resultType}, parser.getNameLoc(), result.operands); } static LogicalResult verify(SelectOp op) { Type conditionType = op.getCondition().getType(); if (conditionType.isSignlessInteger(1)) return success(); // If the result type is a vector or tensor, the type can be a mask with the // same elements. Type resultType = op.getType(); if (!resultType.isa()) return op.emitOpError() << "expected condition to be a signless i1, but got " << conditionType; Type shapedConditionType = getI1SameShape(resultType); if (conditionType != shapedConditionType) return op.emitOpError() << "expected condition type to have the same shape " "as the result type, expected " << shapedConditionType << ", but got " << conditionType; return success(); } //===----------------------------------------------------------------------===// // SignExtendIOp //===----------------------------------------------------------------------===// static LogicalResult verify(SignExtendIOp op) { // Get the scalar type (which is either directly the type of the operand // or the vector's/tensor's element type. auto srcType = getElementTypeOrSelf(op.getOperand().getType()); auto dstType = getElementTypeOrSelf(op.getType()); // For now, index is forbidden for the source and the destination type. 
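  // For example, sign-extending i8 to i32 is accepted, while i32 to i32 or
  // i32 to i16 is rejected below because the result must be strictly wider
  // than the operand.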
if (srcType.isa()) return op.emitError() << srcType << " is not a valid operand type"; if (dstType.isa()) return op.emitError() << dstType << " is not a valid result type"; if (srcType.cast().getWidth() >= dstType.cast().getWidth()) return op.emitError("result type ") << dstType << " must be wider than operand type " << srcType; return success(); } //===----------------------------------------------------------------------===// // SignedDivIOp //===----------------------------------------------------------------------===// OpFoldResult SignedDivIOp::fold(ArrayRef operands) { assert(operands.size() == 2 && "binary operation takes two operands"); // Don't fold if it would overflow or if it requires a division by zero. bool overflowOrDiv0 = false; auto result = constFoldBinaryOp(operands, [&](APInt a, APInt b) { if (overflowOrDiv0 || !b) { overflowOrDiv0 = true; return a; } return a.sdiv_ov(b, overflowOrDiv0); }); // Fold out division by one. Assumes all tensors of all ones are splats. if (auto rhs = operands[1].dyn_cast_or_null()) { if (rhs.getValue() == 1) return lhs(); } else if (auto rhs = operands[1].dyn_cast_or_null()) { if (rhs.getSplatValue().getValue() == 1) return lhs(); } return overflowOrDiv0 ? Attribute() : result; } //===----------------------------------------------------------------------===// // SignedFloorDivIOp //===----------------------------------------------------------------------===// static APInt signedCeilNonnegInputs(APInt a, APInt b, bool &overflow) { // Returns (a-1)/b + 1 APInt one(a.getBitWidth(), 1, true); // Signed value 1. APInt val = a.ssub_ov(one, overflow).sdiv_ov(b, overflow); return val.sadd_ov(one, overflow); } OpFoldResult SignedFloorDivIOp::fold(ArrayRef operands) { assert(operands.size() == 2 && "binary operation takes two operands"); // Don't fold if it would overflow or if it requires a division by zero. bool overflowOrDiv0 = false; auto result = constFoldBinaryOp(operands, [&](APInt a, APInt b) { if (overflowOrDiv0 || !b) { overflowOrDiv0 = true; return a; } unsigned bits = a.getBitWidth(); APInt zero = APInt::getNullValue(bits); if (a.sge(zero) && b.sgt(zero)) { // Both positive (or a is zero), return a / b. return a.sdiv_ov(b, overflowOrDiv0); } else if (a.sle(zero) && b.slt(zero)) { // Both negative (or a is zero), return -a / -b. APInt posA = zero.ssub_ov(a, overflowOrDiv0); APInt posB = zero.ssub_ov(b, overflowOrDiv0); return posA.sdiv_ov(posB, overflowOrDiv0); } else if (a.slt(zero) && b.sgt(zero)) { // A is negative, b is positive, return - ceil(-a, b). APInt posA = zero.ssub_ov(a, overflowOrDiv0); APInt ceil = signedCeilNonnegInputs(posA, b, overflowOrDiv0); return zero.ssub_ov(ceil, overflowOrDiv0); } else { // A is positive, b is negative, return - ceil(a, -b). APInt posB = zero.ssub_ov(b, overflowOrDiv0); APInt ceil = signedCeilNonnegInputs(a, posB, overflowOrDiv0); return zero.ssub_ov(ceil, overflowOrDiv0); } }); // Fold out floor division by one. Assumes all tensors of all ones are // splats. if (auto rhs = operands[1].dyn_cast_or_null()) { if (rhs.getValue() == 1) return lhs(); } else if (auto rhs = operands[1].dyn_cast_or_null()) { if (rhs.getSplatValue().getValue() == 1) return lhs(); } return overflowOrDiv0 ? 
Attribute() : result; } //===----------------------------------------------------------------------===// // SignedCeilDivIOp //===----------------------------------------------------------------------===// OpFoldResult SignedCeilDivIOp::fold(ArrayRef operands) { assert(operands.size() == 2 && "binary operation takes two operands"); // Don't fold if it would overflow or if it requires a division by zero. bool overflowOrDiv0 = false; auto result = constFoldBinaryOp(operands, [&](APInt a, APInt b) { if (overflowOrDiv0 || !b) { overflowOrDiv0 = true; return a; } unsigned bits = a.getBitWidth(); APInt zero = APInt::getNullValue(bits); if (a.sgt(zero) && b.sgt(zero)) { // Both positive, return ceil(a, b). return signedCeilNonnegInputs(a, b, overflowOrDiv0); } else if (a.slt(zero) && b.slt(zero)) { // Both negative, return ceil(-a, -b). APInt posA = zero.ssub_ov(a, overflowOrDiv0); APInt posB = zero.ssub_ov(b, overflowOrDiv0); return signedCeilNonnegInputs(posA, posB, overflowOrDiv0); } else if (a.slt(zero) && b.sgt(zero)) { // A is negative, b is positive, return - ( -a / b). APInt posA = zero.ssub_ov(a, overflowOrDiv0); APInt div = posA.sdiv_ov(b, overflowOrDiv0); return zero.ssub_ov(div, overflowOrDiv0); } else { // A is positive (or zero), b is negative, return - (a / -b). APInt posB = zero.ssub_ov(b, overflowOrDiv0); APInt div = a.sdiv_ov(posB, overflowOrDiv0); return zero.ssub_ov(div, overflowOrDiv0); } }); // Fold out floor division by one. Assumes all tensors of all ones are // splats. if (auto rhs = operands[1].dyn_cast_or_null()) { if (rhs.getValue() == 1) return lhs(); } else if (auto rhs = operands[1].dyn_cast_or_null()) { if (rhs.getSplatValue().getValue() == 1) return lhs(); } return overflowOrDiv0 ? Attribute() : result; } //===----------------------------------------------------------------------===// // SignedRemIOp //===----------------------------------------------------------------------===// OpFoldResult SignedRemIOp::fold(ArrayRef operands) { assert(operands.size() == 2 && "remi_signed takes two operands"); auto rhs = operands.back().dyn_cast_or_null(); if (!rhs) return {}; auto rhsValue = rhs.getValue(); // x % 1 = 0 if (rhsValue.isOneValue()) return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0)); // Don't fold if it requires division by zero. if (rhsValue.isNullValue()) return {}; auto lhs = operands.front().dyn_cast_or_null(); if (!lhs) return {}; return IntegerAttr::get(lhs.getType(), lhs.getValue().srem(rhsValue)); } //===----------------------------------------------------------------------===// // SIToFPOp //===----------------------------------------------------------------------===// // sitofp is applicable from integer types to float types. bool SIToFPOp::areCastCompatible(Type a, Type b) { if (a.isSignlessInteger() && b.isa()) return true; return areVectorCastSimpleCompatible(a, b, areCastCompatible); } //===----------------------------------------------------------------------===// // SplatOp //===----------------------------------------------------------------------===// static LogicalResult verify(SplatOp op) { // TODO: we could replace this by a trait. if (op.getOperand().getType() != op.getType().cast().getElementType()) return op.emitError("operand should be of elemental type of result type"); return success(); } // Constant folding hook for SplatOp. 
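// For example, splatting the result of `constant 1.0 : f32` into a
// vector<4xf32> folds to a splat elements attribute holding 1.0 for every
// element of the vector.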
OpFoldResult SplatOp::fold(ArrayRef operands) { assert(operands.size() == 1 && "splat takes one operand"); auto constOperand = operands.front(); if (!constOperand || !constOperand.isa()) return {}; auto shapedType = getType().cast(); assert(shapedType.getElementType() == constOperand.getType() && "incorrect input attribute type for folding"); // SplatElementsAttr::get treats single value for second arg as being a splat. return SplatElementsAttr::get(shapedType, {constOperand}); } //===----------------------------------------------------------------------===// // StoreOp //===----------------------------------------------------------------------===// static LogicalResult verify(StoreOp op) { if (op.getNumOperands() != 2 + op.getMemRefType().getRank()) return op.emitOpError("store index operand count not equal to memref rank"); return success(); } LogicalResult StoreOp::fold(ArrayRef cstOperands, SmallVectorImpl &results) { /// store(memrefcast) -> store return foldMemRefCast(*this); } //===----------------------------------------------------------------------===// // SubFOp //===----------------------------------------------------------------------===// OpFoldResult SubFOp::fold(ArrayRef operands) { return constFoldBinaryOp( operands, [](APFloat a, APFloat b) { return a - b; }); } //===----------------------------------------------------------------------===// // SubIOp //===----------------------------------------------------------------------===// OpFoldResult SubIOp::fold(ArrayRef operands) { // subi(x,x) -> 0 if (getOperand(0) == getOperand(1)) return Builder(getContext()).getZeroAttr(getType()); // subi(x,0) -> x if (matchPattern(rhs(), m_Zero())) return lhs(); return constFoldBinaryOp(operands, [](APInt a, APInt b) { return a - b; }); } //===----------------------------------------------------------------------===// // UIToFPOp //===----------------------------------------------------------------------===// // uitofp is applicable from integer types to float types. bool UIToFPOp::areCastCompatible(Type a, Type b) { if (a.isSignlessInteger() && b.isa()) return true; return areVectorCastSimpleCompatible(a, b, areCastCompatible); } //===----------------------------------------------------------------------===// // SubViewOp //===----------------------------------------------------------------------===// namespace { /// Helpers to write more idiomatic operations. namespace saturated_arith { struct Wrapper { explicit Wrapper(int64_t v) : v(v) {} operator int64_t() { return v; } int64_t v; }; Wrapper operator+(Wrapper a, int64_t b) { if (ShapedType::isDynamicStrideOrOffset(a) || ShapedType::isDynamicStrideOrOffset(b)) return Wrapper(ShapedType::kDynamicStrideOrOffset); return Wrapper(a.v + b); } Wrapper operator*(Wrapper a, int64_t b) { if (ShapedType::isDynamicStrideOrOffset(a) || ShapedType::isDynamicStrideOrOffset(b)) return Wrapper(ShapedType::kDynamicStrideOrOffset); return Wrapper(a.v * b); } } // end namespace saturated_arith } // end namespace /// A subview result type can be fully inferred from the source type and the /// static representation of offsets, sizes and strides. Special sentinels /// encode the dynamic case. 
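/// For example, taking a subview of a memref<8x16x4xf32> (identity layout,
/// strides [64, 4, 1], offset 0) with static offsets [2, 0, 0], sizes
/// [4, 8, 2] and strides [1, 1, 2] infers:
///   - offset:  0 + 2*64 + 0*4 + 0*1 = 128
///   - strides: [64*1, 4*1, 1*2]     = [64, 4, 2]
/// i.e. memref<4x8x2xf32, offset: 128, strides: [64, 4, 2]>. Any dynamic
/// offset, size or stride sentinel propagates to a dynamic value in the
/// inferred layout.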
Type SubViewOp::inferResultType(MemRefType sourceMemRefType, ArrayRef staticOffsets, ArrayRef staticSizes, ArrayRef staticStrides) { unsigned rank = sourceMemRefType.getRank(); (void)rank; assert(staticOffsets.size() == rank && "unexpected staticOffsets size mismatch"); assert(staticSizes.size() == rank && "unexpected staticSizes size mismatch"); assert(staticStrides.size() == rank && "unexpected staticStrides size mismatch"); // Extract source offset and strides. int64_t sourceOffset; SmallVector sourceStrides; auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset); assert(succeeded(res) && "SubViewOp expected strided memref type"); (void)res; // Compute target offset whose value is: // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`. int64_t targetOffset = sourceOffset; for (auto it : llvm::zip(staticOffsets, sourceStrides)) { auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it); using namespace saturated_arith; targetOffset = Wrapper(targetOffset) + Wrapper(staticOffset) * targetStride; } // Compute target stride whose value is: // `sourceStrides_i * staticStrides_i`. SmallVector targetStrides; targetStrides.reserve(staticOffsets.size()); for (auto it : llvm::zip(sourceStrides, staticStrides)) { auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it); using namespace saturated_arith; targetStrides.push_back(Wrapper(sourceStride) * staticStride); } // The type is now known. return MemRefType::get( staticSizes, sourceMemRefType.getElementType(), makeStridedLinearLayoutMap(targetStrides, targetOffset, sourceMemRefType.getContext()), sourceMemRefType.getMemorySpace()); } /// Print SubViewOp in the form: /// ``` /// subview ssa-name `[` offset-list `]` `[` size-list `]` `[` stride-list `]` /// `:` strided-memref-type `to` strided-memref-type /// ``` template static void printOpWithOffsetsSizesAndStrides( OpAsmPrinter &p, OpType op, llvm::function_ref printExtraOperands = [](OpAsmPrinter &p, OpType op) {}, StringRef resultTypeKeyword = "to") { int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1; p << op.getOperation()->getName().getStringRef().drop_front(stdDotLen) << ' '; p << op.source(); printExtraOperands(p, op); printListOfOperandsOrIntegers(p, op.offsets(), op.static_offsets(), ShapedType::isDynamicStrideOrOffset); p << ' '; printListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(), ShapedType::isDynamic); p << ' '; printListOfOperandsOrIntegers(p, op.strides(), op.static_strides(), ShapedType::isDynamicStrideOrOffset); p << ' '; p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{OpType::getSpecialAttrNames()}); p << " : " << op.getSourceType() << " " << resultTypeKeyword << " " << op.getType(); } static void print(OpAsmPrinter &p, SubViewOp op) { return printOpWithOffsetsSizesAndStrides(p, op); } /// Parse of the form: /// ``` /// `name` ssa-name (extra-operands)? 
/// `[` offset-list `]` `[` size-list `]` `[` stride-list `]` /// `:` strided-memref-type `resultTypeKeyword strided-memref-type /// ``` template static ParseResult parseOpWithOffsetsSizesAndStrides( OpAsmParser &parser, OperationState &result, std::function parseExtraOperand = nullptr, StringRef resultTypeKeyword = "to") { OpAsmParser::OperandType srcInfo, dstInfo; SmallVector offsetsInfo, sizesInfo, stridesInfo; auto indexType = parser.getBuilder().getIndexType(); Type srcType, dstType; if (parser.parseOperand(srcInfo)) return failure(); if (parseExtraOperand && parseExtraOperand(parser, dstInfo)) return failure(); if (parseListOfOperandsOrIntegers( parser, result, OpType::getStaticOffsetsAttrName(), ShapedType::kDynamicStrideOrOffset, offsetsInfo) || parseListOfOperandsOrIntegers(parser, result, OpType::getStaticSizesAttrName(), ShapedType::kDynamicSize, sizesInfo) || parseListOfOperandsOrIntegers( parser, result, OpType::getStaticStridesAttrName(), ShapedType::kDynamicStrideOrOffset, stridesInfo)) return failure(); // Handle segment sizes. auto b = parser.getBuilder(); SmallVector segmentSizes = {1, static_cast(offsetsInfo.size()), static_cast(sizesInfo.size()), static_cast(stridesInfo.size())}; // If we parse an extra operand it needs to appear in the segmentSizes if (parseExtraOperand) segmentSizes.insert(segmentSizes.begin(), 1); result.addAttribute(OpType::getOperandSegmentSizeAttr(), b.getI32VectorAttr(segmentSizes)); return failure( parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(srcType) || parser.parseKeywordType(resultTypeKeyword.str().c_str(), dstType) || parser.resolveOperand(srcInfo, srcType, result.operands) || (parseExtraOperand && parser.resolveOperand(dstInfo, dstType, result.operands)) || parser.resolveOperands(offsetsInfo, indexType, result.operands) || parser.resolveOperands(sizesInfo, indexType, result.operands) || parser.resolveOperands(stridesInfo, indexType, result.operands) || parser.addTypeToList(dstType, result.types)); } static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) { return parseOpWithOffsetsSizesAndStrides(parser, result); } void mlir::SubViewOp::build(OpBuilder &b, OperationState &result, Value source, ArrayRef staticOffsets, ArrayRef staticSizes, ArrayRef staticStrides, ValueRange offsets, ValueRange sizes, ValueRange strides, ArrayRef attrs) { auto sourceMemRefType = source.getType().cast(); auto resultType = inferResultType(sourceMemRefType, staticOffsets, staticSizes, staticStrides); build(b, result, resultType, source, offsets, sizes, strides, b.getI64ArrayAttr(staticOffsets), b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides)); result.addAttributes(attrs); } /// Build a SubViewOp with all dynamic entries: `staticOffsets`, `staticSizes` /// and `staticStrides` are automatically filled with source-memref-rank /// sentinel values that encode dynamic entries. 
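/// For a rank-2 source this is equivalent to calling the builder above with
/// staticOffsets = staticStrides = {kDynamicStrideOrOffset,
/// kDynamicStrideOrOffset} and staticSizes = {kDynamicSize, kDynamicSize},
/// so every offset, size and stride must be supplied as an SSA value of
/// `index` type in `offsets`, `sizes` and `strides`.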
void mlir::SubViewOp::build(OpBuilder &b, OperationState &result, Value source, ValueRange offsets, ValueRange sizes, ValueRange strides, ArrayRef attrs) { auto sourceMemRefType = source.getType().cast(); unsigned rank = sourceMemRefType.getRank(); SmallVector staticOffsetsVector; staticOffsetsVector.assign(rank, ShapedType::kDynamicStrideOrOffset); SmallVector staticSizesVector; staticSizesVector.assign(rank, ShapedType::kDynamicSize); SmallVector staticStridesVector; staticStridesVector.assign(rank, ShapedType::kDynamicStrideOrOffset); build(b, result, source, staticOffsetsVector, staticSizesVector, staticStridesVector, offsets, sizes, strides, attrs); } /// Build a SubViewOp as above but with custom result type. void mlir::SubViewOp::build(OpBuilder &b, OperationState &result, MemRefType resultType, Value source, ArrayRef staticOffsets, ArrayRef staticSizes, ArrayRef staticStrides, ValueRange offsets, ValueRange sizes, ValueRange strides, ArrayRef attrs) { build(b, result, resultType, source, offsets, sizes, strides, b.getI64ArrayAttr(staticOffsets), b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides)); result.addAttributes(attrs); } /// Build a SubViewOp as above but with custom result type. void mlir::SubViewOp::build(OpBuilder &b, OperationState &result, MemRefType resultType, Value source, ValueRange offsets, ValueRange sizes, ValueRange strides, ArrayRef attrs) { auto sourceMemRefType = source.getType().cast(); unsigned rank = sourceMemRefType.getRank(); SmallVector staticOffsetsVector; staticOffsetsVector.assign(rank, ShapedType::kDynamicStrideOrOffset); SmallVector staticSizesVector; staticSizesVector.assign(rank, ShapedType::kDynamicSize); SmallVector staticStridesVector; staticStridesVector.assign(rank, ShapedType::kDynamicStrideOrOffset); build(b, result, resultType, source, staticOffsetsVector, staticSizesVector, staticStridesVector, offsets, sizes, strides, attrs); } /// For ViewLikeOpInterface. Value SubViewOp::getViewSource() { return source(); } llvm::Optional> mlir::computeRankReductionMask(ArrayRef originalShape, ArrayRef reducedShape) { size_t originalRank = originalShape.size(), reducedRank = reducedShape.size(); SmallVector mask(originalRank); unsigned reducedIdx = 0; for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) { // Skip matching dims greedily. mask[originalIdx] = (reducedIdx < reducedRank) && (originalShape[originalIdx] == reducedShape[reducedIdx]); if (mask[originalIdx]) reducedIdx++; // 1 is the only non-matching allowed. else if (originalShape[originalIdx] != 1) return {}; } if (reducedIdx != reducedRank) return {}; return mask; } enum SubViewVerificationResult { Success, RankTooLarge, SizeMismatch, StrideMismatch, ElemTypeMismatch, MemSpaceMismatch, AffineMapMismatch }; /// Checks if `original` Type type can be rank reduced to `reduced` type. /// This function is slight variant of `is subsequence` algorithm where /// not matching dimension must be 1. 
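/// For example, an original shape [4, 1, 5, 1] can be rank-reduced to [4, 5]
/// with mask [1, 0, 1, 0] (the two unit dimensions are dropped), whereas
/// [4, 2, 5] cannot be reduced to [4, 5] since the unmatched dimension is 2,
/// not 1.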
static SubViewVerificationResult isRankReducedType(Type originalType, Type reducedType) { if (originalType == reducedType) return SubViewVerificationResult::Success; if (!originalType.isa() && !originalType.isa()) return SubViewVerificationResult::Success; if (originalType.isa() && !reducedType.isa()) return SubViewVerificationResult::Success; if (originalType.isa() && !reducedType.isa()) return SubViewVerificationResult::Success; ShapedType originalShapedType = originalType.cast(); ShapedType reducedShapedType = reducedType.cast(); // Rank and size logic is valid for all ShapedTypes. ArrayRef originalShape = originalShapedType.getShape(); ArrayRef reducedShape = reducedShapedType.getShape(); unsigned originalRank = originalShape.size(), reducedRank = reducedShape.size(); if (reducedRank > originalRank) return SubViewVerificationResult::RankTooLarge; auto optionalMask = computeRankReductionMask(originalShape, reducedShape); // Sizes cannot be matched in case empty vector is returned. if (!optionalMask.hasValue()) return SubViewVerificationResult::SizeMismatch; // We are done for the tensor case. if (originalType.isa()) return SubViewVerificationResult::Success; // Strided layout logic is relevant for MemRefType only. MemRefType original = originalType.cast(); MemRefType reduced = reducedType.cast(); MLIRContext *c = original.getContext(); int64_t originalOffset, reducedOffset; SmallVector originalStrides, reducedStrides, keepStrides; SmallVector keepMask = optionalMask.getValue(); getStridesAndOffset(original, originalStrides, originalOffset); getStridesAndOffset(reduced, reducedStrides, reducedOffset); // Filter strides based on the mask and check that they are the same // as reduced ones. unsigned reducedIdx = 0; for (unsigned originalIdx = 0; originalIdx < originalRank; ++originalIdx) { if (keepMask[originalIdx]) { if (originalStrides[originalIdx] != reducedStrides[reducedIdx++]) return SubViewVerificationResult::StrideMismatch; keepStrides.push_back(originalStrides[originalIdx]); } } if (original.getElementType() != reduced.getElementType()) return SubViewVerificationResult::ElemTypeMismatch; if (original.getMemorySpace() != reduced.getMemorySpace()) return SubViewVerificationResult::MemSpaceMismatch; auto reducedMap = makeStridedLinearLayoutMap(keepStrides, originalOffset, c); if (!reduced.getAffineMaps().empty() && reducedMap != reduced.getAffineMaps().front()) return SubViewVerificationResult::AffineMapMismatch; return SubViewVerificationResult::Success; } template static LogicalResult produceSubViewErrorMsg(SubViewVerificationResult result, OpTy op, Type expectedType) { auto memrefType = expectedType.cast(); switch (result) { case SubViewVerificationResult::Success: return success(); case SubViewVerificationResult::RankTooLarge: return op.emitError("expected result rank to be smaller or equal to ") << "the source rank."; case SubViewVerificationResult::SizeMismatch: return op.emitError("expected result type to be ") << expectedType << " or a rank-reduced version. (mismatch of result sizes)"; case SubViewVerificationResult::StrideMismatch: return op.emitError("expected result type to be ") << expectedType << " or a rank-reduced version. 
(mismatch of result strides)"; case SubViewVerificationResult::ElemTypeMismatch: return op.emitError("expected result element type to be ") << memrefType.getElementType(); case SubViewVerificationResult::MemSpaceMismatch: return op.emitError("expected result and source memory spaces to match."); case SubViewVerificationResult::AffineMapMismatch: return op.emitError("expected result type to be ") << expectedType << " or a rank-reduced version. (mismatch of result affine map)"; } llvm_unreachable("unexpected subview verification result"); } /// Verifier for SubViewOp. static LogicalResult verify(SubViewOp op) { MemRefType baseType = op.getSourceType(); MemRefType subViewType = op.getType(); // The base memref and the view memref should be in the same memory space. if (baseType.getMemorySpace() != subViewType.getMemorySpace()) return op.emitError("different memory spaces specified for base memref " "type ") << baseType << " and subview memref type " << subViewType; // Verify that the base memref type has a strided layout map. if (!isStrided(baseType)) return op.emitError("base type ") << baseType << " is not strided"; if (failed(verifyOpWithOffsetSizesAndStrides(op))) return failure(); // Verify result type against inferred type. auto expectedType = SubViewOp::inferResultType( baseType, extractFromI64ArrayAttr(op.static_offsets()), extractFromI64ArrayAttr(op.static_sizes()), extractFromI64ArrayAttr(op.static_strides())); auto result = isRankReducedType(expectedType, subViewType); return produceSubViewErrorMsg(result, op, expectedType); } raw_ostream &mlir::operator<<(raw_ostream &os, Range &range) { return os << "range " << range.offset << ":" << range.size << ":" << range.stride; } /// Return the list of Range (i.e. offset, size, stride). Each Range /// entry contains either the dynamic value or a ConstantIndexOp constructed /// with `b` at location `loc`. template static SmallVector getOrCreateRangesImpl(OpType op, OpBuilder &b, Location loc) { SmallVector res; unsigned rank = op.getSourceRank(); res.reserve(rank); for (unsigned idx = 0; idx < rank; ++idx) { Value offset = op.isDynamicOffset(idx) ? op.getDynamicOffset(idx) : b.create(loc, op.getStaticOffset(idx)); Value size = op.isDynamicSize(idx) ? op.getDynamicSize(idx) : b.create(loc, op.getStaticSize(idx)); Value stride = op.isDynamicStride(idx) ? op.getDynamicStride(idx) : b.create(loc, op.getStaticStride(idx)); res.emplace_back(Range{offset, size, stride}); } return res; } SmallVector SubViewOp::getOrCreateRanges(OpBuilder &b, Location loc) { return ::getOrCreateRangesImpl(*this, b, loc); } namespace { /// Take a list of `values` with potential new constant to extract and a list /// of `constantValues` with`values.size()` sentinel that evaluate to true by /// applying `isDynamic`. /// Detects the `values` produced by a ConstantIndexOp and places the new /// constant in place of the corresponding sentinel value. void canonicalizeSubViewPart(SmallVectorImpl &values, SmallVectorImpl &constantValues, llvm::function_ref isDynamic) { bool hasNewStaticValue = llvm::any_of( values, [](Value val) { return matchPattern(val, m_ConstantIndex()); }); if (hasNewStaticValue) { for (unsigned cstIdx = 0, valIdx = 0, e = constantValues.size(); cstIdx != e; ++cstIdx) { // Was already static, skip. if (!isDynamic(constantValues[cstIdx])) continue; // Newly static, move from Value to constant. if (matchPattern(values[valIdx], m_ConstantIndex())) { constantValues[cstIdx] = cast(values[valIdx].getDefiningOp()).getValue(); // Erase for impl. simplicity. 
Reverse iterator if we really must. values.erase(std::next(values.begin(), valIdx)); continue; } // Remains dynamic move to next value. ++valIdx; } } } static void replaceWithNewOp(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) { rewriter.replaceOpWithNewOp(op, newOp, op.getType()); } static void replaceWithNewOp(PatternRewriter &rewriter, SubTensorOp op, SubTensorOp newOp) { rewriter.replaceOpWithNewOp(op, newOp, op.getType()); } /// Pattern to rewrite a subview op with constant arguments. template class OpWithOffsetSizesAndStridesConstantArgumentFolder final : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(OpType op, PatternRewriter &rewriter) const override { // No constant operand, just return; if (llvm::none_of(op.getOperands(), [](Value operand) { return matchPattern(operand, m_ConstantIndex()); })) return failure(); // At least one of offsets/sizes/strides is a new constant. // Form the new list of operands and constant attributes from the existing. SmallVector newOffsets(op.offsets()); SmallVector newStaticOffsets = extractFromI64ArrayAttr(op.static_offsets()); assert(newStaticOffsets.size() == op.getSourceRank()); canonicalizeSubViewPart(newOffsets, newStaticOffsets, ShapedType::isDynamicStrideOrOffset); SmallVector newSizes(op.sizes()); SmallVector newStaticSizes = extractFromI64ArrayAttr(op.static_sizes()); assert(newStaticOffsets.size() == op.getSourceRank()); canonicalizeSubViewPart(newSizes, newStaticSizes, ShapedType::isDynamic); SmallVector newStrides(op.strides()); SmallVector newStaticStrides = extractFromI64ArrayAttr(op.static_strides()); assert(newStaticOffsets.size() == op.getSourceRank()); canonicalizeSubViewPart(newStrides, newStaticStrides, ShapedType::isDynamicStrideOrOffset); // Create the new op in canonical form. auto newOp = rewriter.create( op.getLoc(), op.source(), newStaticOffsets, newStaticSizes, newStaticStrides, newOffsets, newSizes, newStrides); replaceWithNewOp(rewriter, op, newOp); return success(); } }; } // end anonymous namespace /// Determines whether MemRefCastOp casts to a more dynamic version of the /// source memref. This is useful to to fold a memref_cast into a consuming op /// and implement canonicalization patterns for ops in different dialects that /// may consume the results of memref_cast operations. Such foldable memref_cast /// operations are typically inserted as `view` and `subview` ops are /// canonicalized, to preserve the type compatibility of their uses. /// /// Returns true when all conditions are met: /// 1. source and result are ranked memrefs with strided semantics and same /// element type and rank. /// 2. each of the source's size, offset or stride has more static information /// than the corresponding result's size, offset or stride. /// /// Example 1: /// ```mlir /// %1 = memref_cast %0 : memref<8x16xf32> to memref /// %2 = consumer %1 ... : memref ... /// ``` /// /// may fold into: /// /// ```mlir /// %2 = consumer %0 ... : memref<8x16xf32> ... /// ``` /// /// Example 2: /// ``` /// %1 = memref_cast %0 : memref(16 * i + j)>> /// to memref /// consumer %1 : memref ... /// ``` /// /// may fold into: /// /// ``` /// consumer %0 ... : memref(16 * i + j)>> /// ``` bool mlir::canFoldIntoConsumerOp(MemRefCastOp castOp) { MemRefType sourceType = castOp.source().getType().dyn_cast(); MemRefType resultType = castOp.getType().dyn_cast(); // Requires ranked MemRefType. if (!sourceType || !resultType) return false; // Requires same elemental type. 
if (sourceType.getElementType() != resultType.getElementType()) return false; // Requires same rank. if (sourceType.getRank() != resultType.getRank()) return false; // Only fold casts between strided memref forms. int64_t sourceOffset, resultOffset; SmallVector sourceStrides, resultStrides; if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) || failed(getStridesAndOffset(resultType, resultStrides, resultOffset))) return false; // If cast is towards more static sizes along any dimension, don't fold. for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) { auto ss = std::get<0>(it), st = std::get<1>(it); if (ss != st) if (MemRefType::isDynamic(ss) && !MemRefType::isDynamic(st)) return false; } // If cast is towards more static offset along any dimension, don't fold. if (sourceOffset != resultOffset) if (MemRefType::isDynamicStrideOrOffset(sourceOffset) && !MemRefType::isDynamicStrideOrOffset(resultOffset)) return false; // If cast is towards more static strides along any dimension, don't fold. for (auto it : llvm::zip(sourceStrides, resultStrides)) { auto ss = std::get<0>(it), st = std::get<1>(it); if (ss != st) if (MemRefType::isDynamicStrideOrOffset(ss) && !MemRefType::isDynamicStrideOrOffset(st)) return false; } return true; } /// Counterpart of `canFoldIntoConsumerOp(MemRefCastOp castOp)` for tensors. /// Determines whether TensorCastOp casts to a more dynamic version of the /// source tensor. This is useful to fold a tensor_cast into a consuming op and /// implement canonicalization patterns for ops in different dialects that may /// consume the results of tensor_cast operations. Such foldable tensor_cast /// operations are typically inserted as `subtensor` ops and are canonicalized, /// to preserve the type compatibility of their uses. /// /// Returns true when all conditions are met: /// 1. source and result are ranked tensors with same element type and rank. /// 2. the tensor type has more static information than the result /// /// Example: /// ```mlir /// %1 = tensor_cast %0 : tensor<8x16xf32> to tensor /// %2 = consumer %1 ... : tensor ... /// ``` /// /// folds into: /// /// ```mlir /// %2 = consumer %0 ... : tensor<8x16xf32> ... /// ``` bool mlir::canFoldIntoConsumerOp(TensorCastOp castOp) { if (!castOp) return false; RankedTensorType sourceType = castOp.source().getType().dyn_cast(); RankedTensorType resultType = castOp.getType().dyn_cast(); // Requires RankedTensorType. if (!sourceType || !resultType) return false; // Requires same elemental type. if (sourceType.getElementType() != resultType.getElementType()) return false; // Requires same rank. if (sourceType.getRank() != resultType.getRank()) return false; // If cast is towards more static sizes along any dimension, don't fold. for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) { auto ss = std::get<0>(it), st = std::get<1>(it); if (ss != st) if (ShapedType::isDynamic(ss) && !ShapedType::isDynamic(st)) return false; } return true; } namespace { /// Pattern to rewrite a subview op with MemRefCast arguments. /// This essentially pushes memref_cast past its consuming subview when /// `canFoldIntoConsumerOp` is true. 
/// /// Example: /// ``` /// %0 = memref_cast %V : memref<16x16xf32> to memref /// %1 = subview %0[0, 0][3, 4][1, 1] : /// memref to memref<3x4xf32, offset:?, strides:[?, 1]> /// ``` /// is rewritten into: /// ``` /// %0 = subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]> /// %1 = memref_cast %0: memref<3x4xf32, offset:0, strides:[16, 1]> to /// memref<3x4xf32, offset:?, strides:[?, 1]> /// ``` class SubViewOpMemRefCastFolder final : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(SubViewOp subViewOp, PatternRewriter &rewriter) const override { // Any constant operand, just return to let SubViewOpConstantFolder kick in. if (llvm::any_of(subViewOp.getOperands(), [](Value operand) { return matchPattern(operand, m_ConstantIndex()); })) return failure(); auto castOp = subViewOp.source().getDefiningOp(); if (!castOp) return failure(); if (!canFoldIntoConsumerOp(castOp)) return failure(); /// Deduce the resultType of the SubViewOp using `inferSubViewResultType` on /// the cast source operand type and the SubViewOp static information. This /// is the resulting type if the MemRefCastOp were folded. Type resultType = SubViewOp::inferResultType( castOp.source().getType().cast(), extractFromI64ArrayAttr(subViewOp.static_offsets()), extractFromI64ArrayAttr(subViewOp.static_sizes()), extractFromI64ArrayAttr(subViewOp.static_strides())); Value newSubView = rewriter.create( subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(), subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(), subViewOp.static_sizes(), subViewOp.static_strides()); rewriter.replaceOpWithNewOp(subViewOp, subViewOp.getType(), newSubView); return success(); } }; } // namespace void SubViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { results.insert, SubViewOpMemRefCastFolder>(context); } OpFoldResult SubViewOp::fold(ArrayRef operands) { if (getResultRank() == 0 && getSourceRank() == 0) return getViewSource(); return {}; } //===----------------------------------------------------------------------===// // SubTensorOp //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, SubTensorOp op) { return printOpWithOffsetsSizesAndStrides(p, op); } static ParseResult parseSubTensorOp(OpAsmParser &parser, OperationState &result) { return parseOpWithOffsetsSizesAndStrides(parser, result); } /// A subtensor result type can be fully inferred from the source type and the /// static representation of offsets, sizes and strides. Special sentinels /// encode the dynamic case. 
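/// For example, a subtensor of a tensor<8x16x4xf32> with static sizes
/// [4, 8, 2] infers tensor<4x8x2xf32>: only the sizes matter since tensors,
/// unlike memrefs, carry no layout; offsets and strides determine which
/// elements are extracted but not the result type.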
Type SubTensorOp::inferResultType(RankedTensorType sourceRankedTensorType, ArrayRef staticOffsets, ArrayRef staticSizes, ArrayRef staticStrides) { unsigned rank = sourceRankedTensorType.getRank(); (void)rank; assert(staticOffsets.size() == rank && "unexpected staticOffsets size mismatch"); assert(staticSizes.size() == rank && "unexpected staticSizes size mismatch"); assert(staticStrides.size() == rank && "unexpected staticStrides size mismatch"); return RankedTensorType::get(staticSizes, sourceRankedTensorType.getElementType()); } void mlir::SubTensorOp::build(OpBuilder &b, OperationState &result, Value source, ArrayRef staticOffsets, ArrayRef staticSizes, ArrayRef staticStrides, ValueRange offsets, ValueRange sizes, ValueRange strides, ArrayRef attrs) { auto sourceRankedTensorType = source.getType().cast(); auto resultType = inferResultType(sourceRankedTensorType, staticOffsets, staticSizes, staticStrides); build(b, result, resultType, source, offsets, sizes, strides, b.getI64ArrayAttr(staticOffsets), b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides)); result.addAttributes(attrs); } /// Build a SubTensorOp with all dynamic entries: `staticOffsets`, `staticSizes` /// and `staticStrides` are automatically filled with sentinel values that /// encode dynamic entries. void mlir::SubTensorOp::build(OpBuilder &b, OperationState &result, Value source, ValueRange offsets, ValueRange sizes, ValueRange strides, ArrayRef attrs) { auto sourceRankedTensorType = source.getType().cast(); unsigned rank = sourceRankedTensorType.getRank(); SmallVector staticOffsetsVector( rank, ShapedType::kDynamicStrideOrOffset); SmallVector staticSizesVector(rank, ShapedType::kDynamicSize); SmallVector staticStridesVector( rank, ShapedType::kDynamicStrideOrOffset); build(b, result, source, staticOffsetsVector, staticSizesVector, staticStridesVector, offsets, sizes, strides, attrs); } SmallVector SubTensorOp::getOrCreateRanges(OpBuilder &b, Location loc) { return ::getOrCreateRangesImpl(*this, b, loc); } /// Verifier for SubTensorOp. static LogicalResult verify(SubTensorOp op) { if (failed(verifyOpWithOffsetSizesAndStrides(op))) return failure(); // Verify result type against inferred type. 
auto expectedType = SubTensorOp::inferResultType( op.getSourceType(), extractFromI64ArrayAttr(op.static_offsets()), extractFromI64ArrayAttr(op.static_sizes()), extractFromI64ArrayAttr(op.static_strides())); auto result = isRankReducedType(expectedType, op.getType()); return produceSubViewErrorMsg(result, op, expectedType); } void SubTensorOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { results .insert>( context); } //===----------------------------------------------------------------------===// // SubTensorInsertOp //===----------------------------------------------------------------------===// static void print(OpAsmPrinter &p, SubTensorInsertOp op) { return printOpWithOffsetsSizesAndStrides( p, op, [](OpAsmPrinter &p, SubTensorInsertOp op) { p << " into " << op.dest(); }, /*resultTypeKeyword=*/"into"); } static ParseResult parseSubTensorInsertOp(OpAsmParser &parser, OperationState &result) { return parseOpWithOffsetsSizesAndStrides( parser, result, [](OpAsmParser &parser, OpAsmParser::OperandType &dstInfo) { return failure(parser.parseKeyword("into") || parser.parseOperand(dstInfo)); }, "into"); } void mlir::SubTensorInsertOp::build( OpBuilder &b, OperationState &result, Value source, Value dest, ArrayRef staticOffsets, ArrayRef staticSizes, ArrayRef staticStrides, ValueRange offsets, ValueRange sizes, ValueRange strides, ArrayRef attrs) { build(b, result, dest.getType(), source, dest, offsets, sizes, strides, b.getI64ArrayAttr(staticOffsets), b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides)); result.addAttributes(attrs); } /// Build a SubViewOp with all dynamic entries: `staticOffsets`, `staticSizes` /// and `staticStrides` are automatically filled with source-memref-rank /// sentinel values that encode dynamic entries. void mlir::SubTensorInsertOp::build(OpBuilder &b, OperationState &result, Value source, Value dest, ValueRange offsets, ValueRange sizes, ValueRange strides, ArrayRef attrs) { auto sourceRankedTensorType = source.getType().cast(); unsigned rank = sourceRankedTensorType.getRank(); SmallVector staticOffsetsVector( rank, ShapedType::kDynamicStrideOrOffset); SmallVector staticSizesVector(rank, ShapedType::kDynamicSize); SmallVector staticStridesVector( rank, ShapedType::kDynamicStrideOrOffset); build(b, result, source, dest, staticOffsetsVector, staticSizesVector, staticStridesVector, offsets, sizes, strides, attrs); } SmallVector SubTensorInsertOp::getOrCreateRanges(OpBuilder &b, Location loc) { return ::getOrCreateRangesImpl(*this, b, loc); } /// Verifier for SubViewOp. static LogicalResult verify(SubTensorInsertOp op) { if (failed(verifyOpWithOffsetSizesAndStrides(op))) return failure(); if (op.getType() != op.dest().getType()) return op.emitError("expected result type to be ") << op.dest().getType(); return success(); } //===----------------------------------------------------------------------===// // TensorCastOp //===----------------------------------------------------------------------===// bool TensorCastOp::areCastCompatible(Type a, Type b) { auto aT = a.dyn_cast(); auto bT = b.dyn_cast(); if (!aT || !bT) return false; if (aT.getElementType() != bT.getElementType()) return false; return succeeded(verifyCompatibleShape(aT, bT)); } OpFoldResult TensorCastOp::fold(ArrayRef operands) { return impl::foldCastOp(*this); } /// Compute a TensorType that has the joined shape knowledge of the two /// given TensorTypes. The element types need to match. 
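/// For example:
///   join(tensor<?x8xf32>, tensor<4x?xf32>) = tensor<4x8xf32>
///   join(tensor<*xf32>,   tensor<4x8xf32>) = tensor<4x8xf32>
///   join(tensor<4x8xf32>, tensor<4x9xf32>) = {} (no consistent join exists)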
static TensorType joinShapes(TensorType one, TensorType two) { assert(one.getElementType() == two.getElementType()); if (!one.hasRank()) return two; if (!two.hasRank()) return one; int64_t rank = one.getRank(); if (rank != two.getRank()) return {}; SmallVector join; join.reserve(rank); for (int64_t i = 0; i < rank; ++i) { if (one.isDynamicDim(i)) { join.push_back(two.getDimSize(i)); continue; } if (two.isDynamicDim(i)) { join.push_back(one.getDimSize(i)); continue; } if (one.getDimSize(i) != two.getDimSize(i)) return {}; join.push_back(one.getDimSize(i)); } return RankedTensorType::get(join, one.getElementType()); } namespace { /// Replaces chains of two tensor_cast operations by a single tensor_cast /// operation if doing so does not remove runtime constraints. struct ChainedTensorCast : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(TensorCastOp tensorCast, PatternRewriter &rewriter) const final { auto tensorCastOperand = tensorCast.getOperand().getDefiningOp(); if (!tensorCastOperand) return failure(); auto sourceType = tensorCastOperand.getOperand().getType().cast(); auto intermediateType = tensorCastOperand.getType().cast(); auto resultType = tensorCast.getType().cast(); // We can remove the intermediate cast if joining all three produces the // same result as just joining the source and result shapes. auto firstJoin = joinShapes(joinShapes(sourceType, intermediateType), resultType); // The join might not exist if the cast sequence would fail at runtime. if (!firstJoin) return failure(); // The newJoin always exists if the above join exists, it might just contain // less information. If so, we cannot drop the intermediate cast, as doing // so would remove runtime checks. auto newJoin = joinShapes(sourceType, resultType); if (firstJoin != newJoin) return failure(); rewriter.replaceOpWithNewOp(tensorCast, resultType, tensorCastOperand.getOperand()); return success(); } }; } // namespace void TensorCastOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { results.insert(context); } //===----------------------------------------------------------------------===// // TensorLoadOp //===----------------------------------------------------------------------===// OpFoldResult TensorLoadOp::fold(ArrayRef) { if (auto tensorToMemref = memref().getDefiningOp()) return tensorToMemref.tensor(); return {}; } //===----------------------------------------------------------------------===// // TensorToMemrefOp //===----------------------------------------------------------------------===// OpFoldResult TensorToMemrefOp::fold(ArrayRef) { if (auto tensorLoad = tensor().getDefiningOp()) if (tensorLoad.memref().getType() == getType()) return tensorLoad.memref(); return {}; } //===----------------------------------------------------------------------===// // TransposeOp //===----------------------------------------------------------------------===// /// Build a strided memref type by applying `permutationMap` tp `memRefType`. static MemRefType inferTransposeResultType(MemRefType memRefType, AffineMap permutationMap) { auto rank = memRefType.getRank(); auto originalSizes = memRefType.getShape(); // Compute permuted sizes. SmallVector sizes(rank, 0); for (auto en : llvm::enumerate(permutationMap.getResults())) sizes[en.index()] = originalSizes[en.value().cast().getPosition()]; // Compute permuted strides. 
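// For example, permuting %0 : memref<3x4xf32> (strides [4, 1]) with
// affine_map<(d0, d1) -> (d1, d0)> yields sizes [4, 3] and, after composing
// the strided layout with the permutation, strides [1, 4], i.e.
// memref<4x3xf32, offset: 0, strides: [1, 4]>.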
int64_t offset; SmallVector strides; auto res = getStridesAndOffset(memRefType, strides, offset); assert(succeeded(res) && strides.size() == static_cast(rank)); (void)res; auto map = makeStridedLinearLayoutMap(strides, offset, memRefType.getContext()); map = permutationMap ? map.compose(permutationMap) : map; return MemRefType::Builder(memRefType).setShape(sizes).setAffineMaps(map); } void TransposeOp::build(OpBuilder &b, OperationState &result, Value in, AffineMapAttr permutation, ArrayRef attrs) { auto permutationMap = permutation.getValue(); assert(permutationMap); auto memRefType = in.getType().cast(); // Compute result type. MemRefType resultType = inferTransposeResultType(memRefType, permutationMap); build(b, result, resultType, in, attrs); result.addAttribute(TransposeOp::getPermutationAttrName(), permutation); } // transpose $in $permutation attr-dict : type($in) `to` type(results) static void print(OpAsmPrinter &p, TransposeOp op) { p << "transpose " << op.in() << " " << op.permutation(); p.printOptionalAttrDict(op.getAttrs(), {TransposeOp::getPermutationAttrName()}); p << " : " << op.in().getType() << " to " << op.getType(); } static ParseResult parseTransposeOp(OpAsmParser &parser, OperationState &result) { OpAsmParser::OperandType in; AffineMap permutation; MemRefType srcType, dstType; if (parser.parseOperand(in) || parser.parseAffineMap(permutation) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(srcType) || parser.resolveOperand(in, srcType, result.operands) || parser.parseKeywordType("to", dstType) || parser.addTypeToList(dstType, result.types)) return failure(); result.addAttribute(TransposeOp::getPermutationAttrName(), AffineMapAttr::get(permutation)); return success(); } static LogicalResult verify(TransposeOp op) { if (!op.permutation().isPermutation()) return op.emitOpError("expected a permutation map"); if (op.permutation().getNumDims() != op.getShapedType().getRank()) return op.emitOpError( "expected a permutation map of same rank as the input"); auto srcType = op.in().getType().cast(); auto dstType = op.getType().cast(); auto transposedType = inferTransposeResultType(srcType, op.permutation()); if (dstType != transposedType) return op.emitOpError("output type ") << dstType << " does not match transposed input type " << srcType << ", " << transposedType; return success(); } OpFoldResult TransposeOp::fold(ArrayRef) { if (succeeded(foldMemRefCast(*this))) return getResult(); return {}; } //===----------------------------------------------------------------------===// // TruncateIOp //===----------------------------------------------------------------------===// static LogicalResult verify(TruncateIOp op) { auto srcType = getElementTypeOrSelf(op.getOperand().getType()); auto dstType = getElementTypeOrSelf(op.getType()); if (srcType.isa()) return op.emitError() << srcType << " is not a valid operand type"; if (dstType.isa()) return op.emitError() << dstType << " is not a valid result type"; if (srcType.cast().getWidth() <= dstType.cast().getWidth()) return op.emitError("operand type ") << srcType << " must be wider than result type " << dstType; return success(); } //===----------------------------------------------------------------------===// // UnsignedDivIOp //===----------------------------------------------------------------------===// OpFoldResult UnsignedDivIOp::fold(ArrayRef operands) { assert(operands.size() == 2 && "binary operation takes two operands"); // Don't fold if it would require a division by zero. 
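// For example, constant operands 7 and 2 (i32) fold to 3 (unsigned division
// truncates); a zero divisor leaves the op unfolded, and a divisor that is a
// splat of ones folds to the lhs below.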
bool div0 = false; auto result = constFoldBinaryOp(operands, [&](APInt a, APInt b) { if (div0 || !b) { div0 = true; return a; } return a.udiv(b); }); // Fold out division by one. Assumes all tensors of all ones are splats. if (auto rhs = operands[1].dyn_cast_or_null()) { if (rhs.getValue() == 1) return lhs(); } else if (auto rhs = operands[1].dyn_cast_or_null()) { if (rhs.getSplatValue().getValue() == 1) return lhs(); } return div0 ? Attribute() : result; } //===----------------------------------------------------------------------===// // UnsignedRemIOp //===----------------------------------------------------------------------===// OpFoldResult UnsignedRemIOp::fold(ArrayRef operands) { assert(operands.size() == 2 && "remi_unsigned takes two operands"); auto rhs = operands.back().dyn_cast_or_null(); if (!rhs) return {}; auto rhsValue = rhs.getValue(); // x % 1 = 0 if (rhsValue.isOneValue()) return IntegerAttr::get(rhs.getType(), APInt(rhsValue.getBitWidth(), 0)); // Don't fold if it requires division by zero. if (rhsValue.isNullValue()) return {}; auto lhs = operands.front().dyn_cast_or_null(); if (!lhs) return {}; return IntegerAttr::get(lhs.getType(), lhs.getValue().urem(rhsValue)); } //===----------------------------------------------------------------------===// // ViewOp //===----------------------------------------------------------------------===// static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) { OpAsmParser::OperandType srcInfo; SmallVector offsetInfo; SmallVector sizesInfo; auto indexType = parser.getBuilder().getIndexType(); Type srcType, dstType; llvm::SMLoc offsetLoc; if (parser.parseOperand(srcInfo) || parser.getCurrentLocation(&offsetLoc) || parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square)) return failure(); if (offsetInfo.size() != 1) return parser.emitError(offsetLoc) << "expects 1 offset operand"; return failure( parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) || parser.parseOptionalAttrDict(result.attributes) || parser.parseColonType(srcType) || parser.resolveOperand(srcInfo, srcType, result.operands) || parser.resolveOperands(offsetInfo, indexType, result.operands) || parser.resolveOperands(sizesInfo, indexType, result.operands) || parser.parseKeywordType("to", dstType) || parser.addTypeToList(dstType, result.types)); } static void print(OpAsmPrinter &p, ViewOp op) { p << op.getOperationName() << ' ' << op.getOperand(0) << '['; p.printOperand(op.byte_shift()); p << "][" << op.sizes() << ']'; p.printOptionalAttrDict(op.getAttrs()); p << " : " << op.getOperand(0).getType() << " to " << op.getType(); } static LogicalResult verify(ViewOp op) { auto baseType = op.getOperand(0).getType().cast(); auto viewType = op.getType(); // The base memref should have identity layout map (or none). if (baseType.getAffineMaps().size() > 1 || (baseType.getAffineMaps().size() == 1 && !baseType.getAffineMaps()[0].isIdentity())) return op.emitError("unsupported map for base memref type ") << baseType; // The result memref should have identity layout map (or none). if (viewType.getAffineMaps().size() > 1 || (viewType.getAffineMaps().size() == 1 && !viewType.getAffineMaps()[0].isIdentity())) return op.emitError("unsupported map for result memref type ") << viewType; // The base memref and the view memref should be in the same memory space. 
if (baseType.getMemorySpace() != viewType.getMemorySpace()) return op.emitError("different memory spaces specified for base memref " "type ") << baseType << " and view memref type " << viewType; // Verify that we have the correct number of sizes for the result type. unsigned numDynamicDims = viewType.getNumDynamicDims(); if (op.sizes().size() != numDynamicDims) return op.emitError("incorrect number of size operands for type ") << viewType; return success(); } Value ViewOp::getViewSource() { return source(); } namespace { struct ViewOpShapeFolder : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ViewOp viewOp, PatternRewriter &rewriter) const override { // Return if none of the operands are constants. if (llvm::none_of(viewOp.getOperands(), [](Value operand) { return matchPattern(operand, m_ConstantIndex()); })) return failure(); // Get result memref type. auto memrefType = viewOp.getType(); // Get offset from old memref view type 'memRefType'. int64_t oldOffset; SmallVector oldStrides; if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset))) return failure(); assert(oldOffset == 0 && "Expected 0 offset"); SmallVector newOperands; // Offset cannot be folded into result type. // Fold any dynamic dim operands which are produced by a constant. SmallVector newShapeConstants; newShapeConstants.reserve(memrefType.getRank()); unsigned dynamicDimPos = 0; unsigned rank = memrefType.getRank(); for (unsigned dim = 0, e = rank; dim < e; ++dim) { int64_t dimSize = memrefType.getDimSize(dim); // If this is already static dimension, keep it. if (!ShapedType::isDynamic(dimSize)) { newShapeConstants.push_back(dimSize); continue; } auto *defOp = viewOp.sizes()[dynamicDimPos].getDefiningOp(); if (auto constantIndexOp = dyn_cast_or_null(defOp)) { // Dynamic shape dimension will be folded. newShapeConstants.push_back(constantIndexOp.getValue()); } else { // Dynamic shape dimension not folded; copy operand from old memref. newShapeConstants.push_back(dimSize); newOperands.push_back(viewOp.sizes()[dynamicDimPos]); } dynamicDimPos++; } // Create new memref type with constant folded dims. MemRefType newMemRefType = MemRefType::Builder(memrefType).setShape(newShapeConstants); // Nothing new, don't fold. if (newMemRefType == memrefType) return failure(); // Create new ViewOp. auto newViewOp = rewriter.create(viewOp.getLoc(), newMemRefType, viewOp.getOperand(0), viewOp.byte_shift(), newOperands); // Insert a cast so we have the same type as the old memref type. 
rewriter.replaceOpWithNewOp(viewOp, newViewOp, viewOp.getType()); return success(); } }; struct ViewOpMemrefCastFolder : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ViewOp viewOp, PatternRewriter &rewriter) const override { Value memrefOperand = viewOp.getOperand(0); MemRefCastOp memrefCastOp = memrefOperand.getDefiningOp(); if (!memrefCastOp) return failure(); Value allocOperand = memrefCastOp.getOperand(); AllocOp allocOp = allocOperand.getDefiningOp(); if (!allocOp) return failure(); rewriter.replaceOpWithNewOp(viewOp, viewOp.getType(), allocOperand, viewOp.byte_shift(), viewOp.sizes()); return success(); } }; } // end anonymous namespace void ViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { results.insert(context); } //===----------------------------------------------------------------------===// // XOrOp //===----------------------------------------------------------------------===// OpFoldResult XOrOp::fold(ArrayRef operands) { /// xor(x, 0) -> x if (matchPattern(rhs(), m_Zero())) return lhs(); /// xor(x,x) -> 0 if (lhs() == rhs()) return Builder(getContext()).getZeroAttr(getType()); return constFoldBinaryOp(operands, [](APInt a, APInt b) { return a ^ b; }); } //===----------------------------------------------------------------------===// // ZeroExtendIOp //===----------------------------------------------------------------------===// static LogicalResult verify(ZeroExtendIOp op) { auto srcType = getElementTypeOrSelf(op.getOperand().getType()); auto dstType = getElementTypeOrSelf(op.getType()); if (srcType.isa()) return op.emitError() << srcType << " is not a valid operand type"; if (dstType.isa()) return op.emitError() << dstType << " is not a valid result type"; if (srcType.cast().getWidth() >= dstType.cast().getWidth()) return op.emitError("result type ") << dstType << " must be wider than operand type " << srcType; return success(); } //===----------------------------------------------------------------------===// // TableGen'd op method definitions //===----------------------------------------------------------------------===// #define GET_OP_CLASSES #include "mlir/Dialect/StandardOps/IR/Ops.cpp.inc" diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp index 516f8c060a93..cef0a827f08d 100644 --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -1,497 +1,497 @@ //===- Utils.cpp ---- Misc utilities for code and data transformation -----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // //===----------------------------------------------------------------------===// // // This file implements miscellaneous transformation routines for non-loop IR // structures. // //===----------------------------------------------------------------------===// #include "mlir/Transforms/Utils.h" #include "mlir/Analysis/AffineAnalysis.h" #include "mlir/Analysis/AffineStructures.h" #include "mlir/Analysis/Utils.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Dominance.h" #include "mlir/IR/Function.h" #include "mlir/IR/Module.h" #include "mlir/Support/MathExtras.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/TypeSwitch.h" using namespace mlir; /// Return true if this operation dereferences one or more memref's. 
// Temporary utility: will be replaced when this is modeled through // side-effects/op traits. TODO static bool isMemRefDereferencingOp(Operation &op) { return isa(op); } /// Return the AffineMapAttr associated with memory 'op' on 'memref'. static NamedAttribute getAffineMapAttrForMemRef(Operation *op, Value memref) { return TypeSwitch(op) .Case( [=](auto op) { return op.getAffineMapAttrForMemRef(memref); }); } // Perform the replacement in `op`. LogicalResult mlir::replaceAllMemRefUsesWith(Value oldMemRef, Value newMemRef, Operation *op, ArrayRef extraIndices, AffineMap indexRemap, ArrayRef extraOperands, ArrayRef symbolOperands, bool allowNonDereferencingOps) { unsigned newMemRefRank = newMemRef.getType().cast().getRank(); (void)newMemRefRank; // unused in opt mode unsigned oldMemRefRank = oldMemRef.getType().cast().getRank(); (void)oldMemRefRank; // unused in opt mode if (indexRemap) { assert(indexRemap.getNumSymbols() == symbolOperands.size() && "symbolic operand count mismatch"); assert(indexRemap.getNumInputs() == extraOperands.size() + oldMemRefRank + symbolOperands.size()); assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank); } else { assert(oldMemRefRank + extraIndices.size() == newMemRefRank); } // Assert same elemental type. assert(oldMemRef.getType().cast().getElementType() == newMemRef.getType().cast().getElementType()); SmallVector usePositions; for (const auto &opEntry : llvm::enumerate(op->getOperands())) { if (opEntry.value() == oldMemRef) usePositions.push_back(opEntry.index()); } // If memref doesn't appear, nothing to do. if (usePositions.empty()) return success(); if (usePositions.size() > 1) { // TODO: extend it for this case when needed (rare). assert(false && "multiple dereferencing uses in a single op not supported"); return failure(); } unsigned memRefOperandPos = usePositions.front(); OpBuilder builder(op); // The following checks if op is dereferencing memref and performs the access // index rewrites. if (!isMemRefDereferencingOp(*op)) { if (!allowNonDereferencingOps) // Failure: memref used in a non-dereferencing context (potentially // escapes); no replacement in these cases unless allowNonDereferencingOps // is set. return failure(); op->setOperand(memRefOperandPos, newMemRef); return success(); } // Perform index rewrites for the dereferencing op and then replace the op NamedAttribute oldMapAttrPair = getAffineMapAttrForMemRef(op, oldMemRef); AffineMap oldMap = oldMapAttrPair.second.cast().getValue(); unsigned oldMapNumInputs = oldMap.getNumInputs(); SmallVector oldMapOperands( op->operand_begin() + memRefOperandPos + 1, op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs); // Apply 'oldMemRefOperands = oldMap(oldMapOperands)'. SmallVector oldMemRefOperands; SmallVector affineApplyOps; oldMemRefOperands.reserve(oldMemRefRank); if (oldMap != builder.getMultiDimIdentityMap(oldMap.getNumDims())) { for (auto resultExpr : oldMap.getResults()) { auto singleResMap = AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(), resultExpr); auto afOp = builder.create(op->getLoc(), singleResMap, oldMapOperands); oldMemRefOperands.push_back(afOp); affineApplyOps.push_back(afOp); } } else { oldMemRefOperands.assign(oldMapOperands.begin(), oldMapOperands.end()); } // Construct new indices as a remap of the old ones if a remapping has been // provided. The indices of a memref come right after it, i.e., // at position memRefOperandPos + 1. 
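// For example, with no extra indices/operands/symbols and
// indexRemap = affine_map<(d0, d1) -> (d0 floordiv 2, d0 mod 2, d1)>,
// an access %oldMemRef[%i, %j] is rewritten to an access of the rank-3
// %newMemRef at [%i floordiv 2, %i mod 2, %j] (the remap is composed into
// the op's affine map below).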
SmallVector remapOperands; remapOperands.reserve(extraOperands.size() + oldMemRefRank + symbolOperands.size()); remapOperands.append(extraOperands.begin(), extraOperands.end()); remapOperands.append(oldMemRefOperands.begin(), oldMemRefOperands.end()); remapOperands.append(symbolOperands.begin(), symbolOperands.end()); SmallVector remapOutputs; remapOutputs.reserve(oldMemRefRank); if (indexRemap && indexRemap != builder.getMultiDimIdentityMap(indexRemap.getNumDims())) { // Remapped indices. for (auto resultExpr : indexRemap.getResults()) { auto singleResMap = AffineMap::get( indexRemap.getNumDims(), indexRemap.getNumSymbols(), resultExpr); auto afOp = builder.create(op->getLoc(), singleResMap, remapOperands); remapOutputs.push_back(afOp); affineApplyOps.push_back(afOp); } } else { // No remapping specified. remapOutputs.assign(remapOperands.begin(), remapOperands.end()); } SmallVector newMapOperands; newMapOperands.reserve(newMemRefRank); // Prepend 'extraIndices' in 'newMapOperands'. for (Value extraIndex : extraIndices) { assert(extraIndex.getDefiningOp()->getNumResults() == 1 && "single result op's expected to generate these indices"); assert((isValidDim(extraIndex) || isValidSymbol(extraIndex)) && "invalid memory op index"); newMapOperands.push_back(extraIndex); } // Append 'remapOutputs' to 'newMapOperands'. newMapOperands.append(remapOutputs.begin(), remapOutputs.end()); // Create new fully composed AffineMap for new op to be created. assert(newMapOperands.size() == newMemRefRank); auto newMap = builder.getMultiDimIdentityMap(newMemRefRank); // TODO: Avoid creating/deleting temporary AffineApplyOps here. fullyComposeAffineMapAndOperands(&newMap, &newMapOperands); newMap = simplifyAffineMap(newMap); canonicalizeMapAndOperands(&newMap, &newMapOperands); // Remove any affine.apply's that became dead as a result of composition. for (Value value : affineApplyOps) if (value.use_empty()) value.getDefiningOp()->erase(); OperationState state(op->getLoc(), op->getName()); // Construct the new operation using this memref. state.operands.reserve(op->getNumOperands() + extraIndices.size()); // Insert the non-memref operands. state.operands.append(op->operand_begin(), op->operand_begin() + memRefOperandPos); // Insert the new memref value. state.operands.push_back(newMemRef); // Insert the new memref map operands. state.operands.append(newMapOperands.begin(), newMapOperands.end()); // Insert the remaining operands unmodified. state.operands.append(op->operand_begin() + memRefOperandPos + 1 + oldMapNumInputs, op->operand_end()); // Result types don't change. Both memref's are of the same elemental type. state.types.reserve(op->getNumResults()); for (auto result : op->getResults()) state.types.push_back(result.getType()); // Add attribute for 'newMap', other Attributes do not change. auto newMapAttr = AffineMapAttr::get(newMap); for (auto namedAttr : op->getAttrs()) { if (namedAttr.first == oldMapAttrPair.first) state.attributes.push_back({namedAttr.first, newMapAttr}); else state.attributes.push_back(namedAttr); } // Create the new operation. 
auto *repOp = builder.createOperation(state); op->replaceAllUsesWith(repOp); op->erase(); return success(); } LogicalResult mlir::replaceAllMemRefUsesWith( Value oldMemRef, Value newMemRef, ArrayRef extraIndices, AffineMap indexRemap, ArrayRef extraOperands, ArrayRef symbolOperands, Operation *domInstFilter, Operation *postDomInstFilter, bool allowNonDereferencingOps, bool replaceInDeallocOp) { unsigned newMemRefRank = newMemRef.getType().cast().getRank(); (void)newMemRefRank; // unused in opt mode unsigned oldMemRefRank = oldMemRef.getType().cast().getRank(); (void)oldMemRefRank; if (indexRemap) { assert(indexRemap.getNumSymbols() == symbolOperands.size() && "symbol operand count mismatch"); assert(indexRemap.getNumInputs() == extraOperands.size() + oldMemRefRank + symbolOperands.size()); assert(indexRemap.getNumResults() + extraIndices.size() == newMemRefRank); } else { assert(oldMemRefRank + extraIndices.size() == newMemRefRank); } // Assert same elemental type. assert(oldMemRef.getType().cast().getElementType() == newMemRef.getType().cast().getElementType()); std::unique_ptr domInfo; std::unique_ptr postDomInfo; if (domInstFilter) domInfo = std::make_unique( domInstFilter->getParentOfType()); if (postDomInstFilter) postDomInfo = std::make_unique( postDomInstFilter->getParentOfType()); // Walk all uses of old memref; collect ops to perform replacement. We use a // DenseSet since an operation could potentially have multiple uses of a // memref (although rare), and the replacement later is going to erase ops. DenseSet opsToReplace; for (auto *op : oldMemRef.getUsers()) { // Skip this use if it's not dominated by domInstFilter. if (domInstFilter && !domInfo->dominates(domInstFilter, op)) continue; // Skip this use if it's not post-dominated by postDomInstFilter. if (postDomInstFilter && !postDomInfo->postDominates(postDomInstFilter, op)) continue; // Skip dealloc's - no replacement is necessary, and a memref replacement // at other uses doesn't hurt these dealloc's. if (isa(op) && !replaceInDeallocOp) continue; // Check if the memref was used in a non-dereferencing context. It is fine // for the memref to be used in a non-dereferencing way outside of the // region where this replacement is happening. if (!isMemRefDereferencingOp(*op)) { if (!allowNonDereferencingOps) return failure(); // Currently we support the following non-dereferencing ops to be a // candidate for replacement: Dealloc, CallOp and ReturnOp. // TODO: Add support for other kinds of ops. if (!op->hasTrait()) return failure(); } // We'll first collect and then replace --- since replacement erases the op // that has the use, and that op could be postDomFilter or domFilter itself! opsToReplace.insert(op); } for (auto *op : opsToReplace) { if (failed(replaceAllMemRefUsesWith( oldMemRef, newMemRef, op, extraIndices, indexRemap, extraOperands, symbolOperands, allowNonDereferencingOps))) llvm_unreachable("memref replacement guaranteed to succeed here"); } return success(); } /// Given an operation, inserts one or more single result affine /// apply operations, results of which are exclusively used by this operation /// operation. The operands of these newly created affine apply ops are /// guaranteed to be loop iterators or terminal symbols of a function. /// /// Before /// /// affine.for %i = 0 to #map(%N) /// %idx = affine.apply (d0) -> (d0 mod 2) (%i) /// "send"(%idx, %A, ...) /// "compute"(%idx) /// /// After /// /// affine.for %i = 0 to #map(%N) /// %idx = affine.apply (d0) -> (d0 mod 2) (%i) /// "send"(%idx, %A, ...) 
/// %idx_ = affine.apply (d0) -> (d0 mod 2) (%i) /// "compute"(%idx_) /// /// This allows applying different transformations on send and compute (e.g., /// different shifts/delays). /// /// Leaves `sliceOps` empty either if none of opInst's operands were the result /// of an affine.apply (and thus there was no affine computation slice to /// create), or if all the affine.apply ops supplying operands to this opInst /// had no uses other than this opInst; otherwise returns the affine.apply /// operations created in the output argument `sliceOps`. void mlir::createAffineComputationSlice( Operation *opInst, SmallVectorImpl *sliceOps) { // Collect all operands that are results of affine apply ops. SmallVector subOperands; subOperands.reserve(opInst->getNumOperands()); for (auto operand : opInst->getOperands()) if (isa_and_nonnull(operand.getDefiningOp())) subOperands.push_back(operand); // Gather sequence of AffineApplyOps reachable from 'subOperands'. SmallVector affineApplyOps; getReachableAffineApplyOps(subOperands, affineApplyOps); // Skip transforming if there are no affine maps to compose. if (affineApplyOps.empty()) return; // Check if all uses of the affine apply ops lie only in this op, in // which case there would be nothing to do. bool localized = true; for (auto *op : affineApplyOps) { for (auto result : op->getResults()) { for (auto *user : result.getUsers()) { if (user != opInst) { localized = false; break; } } } } if (localized) return; OpBuilder builder(opInst); SmallVector composedOpOperands(subOperands); auto composedMap = builder.getMultiDimIdentityMap(composedOpOperands.size()); fullyComposeAffineMapAndOperands(&composedMap, &composedOpOperands); // Create an affine.apply for each of the map results. sliceOps->reserve(composedMap.getNumResults()); for (auto resultExpr : composedMap.getResults()) { auto singleResMap = AffineMap::get(composedMap.getNumDims(), composedMap.getNumSymbols(), resultExpr); sliceOps->push_back(builder.create( opInst->getLoc(), singleResMap, composedOpOperands)); } // Construct the new operands that include the results from the composed // affine apply ops above instead of the existing ones (subOperands). They // differ from opInst's operands only for those operands in 'subOperands', // which are replaced by the corresponding result from 'sliceOps'. SmallVector newOperands(opInst->getOperands()); for (unsigned i = 0, e = newOperands.size(); i < e; i++) { // Replace the subOperands from among the new operands. unsigned j, f; for (j = 0, f = subOperands.size(); j < f; j++) { if (newOperands[i] == subOperands[j]) break; } if (j < subOperands.size()) { newOperands[i] = (*sliceOps)[j]; } } for (unsigned idx = 0, e = newOperands.size(); idx < e; idx++) { opInst->setOperand(idx, newOperands[idx]); } } // TODO: Currently works for static memrefs with a single layout map. LogicalResult mlir::normalizeMemRef(AllocOp allocOp) { MemRefType memrefType = allocOp.getType(); OpBuilder b(allocOp); // Fetch a new memref type after normalizing the old memref to have an // identity map layout. MemRefType newMemRefType = - normalizeMemRefType(memrefType, b, allocOp.getNumSymbolicOperands()); + normalizeMemRefType(memrefType, b, allocOp.symbolOperands().size()); if (newMemRefType == memrefType) // Either memrefType already had an identity map or the map couldn't be // transformed to an identity map.
return failure(); Value oldMemRef = allocOp.getResult(); - SmallVector symbolOperands(allocOp.getSymbolicOperands()); + SmallVector symbolOperands(allocOp.symbolOperands()); AllocOp newAlloc = b.create(allocOp.getLoc(), newMemRefType, - llvm::None, allocOp.alignmentAttr()); + allocOp.alignmentAttr()); AffineMap layoutMap = memrefType.getAffineMaps().front(); // Replace all uses of the old memref. if (failed(replaceAllMemRefUsesWith(oldMemRef, /*newMemRef=*/newAlloc, /*extraIndices=*/{}, /*indexRemap=*/layoutMap, /*extraOperands=*/{}, /*symbolOperands=*/symbolOperands, /*domInstFilter=*/nullptr, /*postDomInstFilter=*/nullptr, /*allowDereferencingOps=*/true))) { // If it failed (due to escapes for example), bail out. newAlloc.erase(); return failure(); } // Replace any uses of the original alloc op and erase it. All remaining uses // have to be dealloc's; RAMUW above would've failed otherwise. assert(llvm::all_of(oldMemRef.getUsers(), [](Operation *op) { return isa(op); })); oldMemRef.replaceAllUsesWith(newAlloc); allocOp.erase(); return success(); } MemRefType mlir::normalizeMemRefType(MemRefType memrefType, OpBuilder b, unsigned numSymbolicOperands) { unsigned rank = memrefType.getRank(); if (rank == 0) return memrefType; ArrayRef layoutMaps = memrefType.getAffineMaps(); if (layoutMaps.empty() || layoutMaps.front() == b.getMultiDimIdentityMap(rank)) { // Either no maps is associated with this memref or this memref has // a trivial (identity) map. return memrefType; } // We don't do any checks for one-to-one'ness; we assume that it is // one-to-one. // TODO: Only for static memref's for now. if (memrefType.getNumDynamicDims() > 0) return memrefType; // We have a single map that is not an identity map. Create a new memref // with the right shape and an identity layout map. ArrayRef shape = memrefType.getShape(); // FlatAffineConstraint may later on use symbolicOperands. FlatAffineConstraints fac(rank, numSymbolicOperands); for (unsigned d = 0; d < rank; ++d) { fac.addConstantLowerBound(d, 0); fac.addConstantUpperBound(d, shape[d] - 1); } // We compose this map with the original index (logical) space to derive // the upper bounds for the new index space. AffineMap layoutMap = layoutMaps.front(); unsigned newRank = layoutMap.getNumResults(); if (failed(fac.composeMatchingMap(layoutMap))) return memrefType; // TODO: Handle semi-affine maps. // Project out the old data dimensions. fac.projectOut(newRank, fac.getNumIds() - newRank - fac.getNumLocalIds()); SmallVector newShape(newRank); for (unsigned d = 0; d < newRank; ++d) { // The lower bound for the shape is always zero. auto ubConst = fac.getConstantUpperBound(d); // For a static memref and an affine map with no symbols, this is // always bounded. assert(ubConst.hasValue() && "should always have an upper bound"); if (ubConst.getValue() < 0) // This is due to an invalid map that maps to a negative space. return memrefType; newShape[d] = ubConst.getValue() + 1; } // Create the new memref type after trivializing the old layout map. 
MemRefType newMemRefType = MemRefType::Builder(memrefType) .setShape(newShape) .setAffineMaps(b.getMultiDimIdentityMap(newRank)); return newMemRefType; } diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir index 76aff5c6d401..eb2477438649 100644 --- a/mlir/test/IR/invalid-ops.mlir +++ b/mlir/test/IR/invalid-ops.mlir @@ -1,1307 +1,1309 @@ // RUN: mlir-opt -allow-unregistered-dialect %s -split-input-file -verify-diagnostics func @dim(%arg : tensor<1x?xf32>) { %c2 = constant 2 : index dim %arg, %c2 : tensor<1x?xf32> // expected-error {{'std.dim' op index is out of range}} return } // ----- func @rank(f32) { ^bb(%0: f32): "std.rank"(%0): (f32)->index // expected-error {{'std.rank' op operand #0 must be any tensor or memref type}} return } // ----- func @constant() { ^bb: %x = "std.constant"(){value = "xyz"} : () -> i32 // expected-error {{unsupported 'value' attribute}} return } // ----- func @constant_out_of_range() { ^bb: %x = "std.constant"(){value = 100} : () -> i1 // expected-error {{requires attribute's type ('i64') to match op's return type ('i1')}} return } // ----- func @constant_wrong_type() { ^bb: %x = "std.constant"(){value = 10.} : () -> f32 // expected-error {{requires attribute's type ('f64') to match op's return type ('f32')}} return } // ----- func @affine_apply_no_map() { ^bb0: %i = constant 0 : index %x = "affine.apply" (%i) { } : (index) -> (index) // expected-error {{requires attribute 'map'}} return } // ----- func @affine_apply_wrong_operand_count() { ^bb0: %i = constant 0 : index %x = "affine.apply" (%i) {map = affine_map<(d0, d1) -> ((d0 + 1), (d1 + 2))>} : (index) -> (index) // expected-error {{'affine.apply' op operand count and affine map dimension and symbol count must match}} return } // ----- func @affine_apply_wrong_result_count() { ^bb0: %i = constant 0 : index %j = constant 1 : index %x = "affine.apply" (%i, %j) {map = affine_map<(d0, d1) -> ((d0 + 1), (d1 + 2))>} : (index,index) -> (index) // expected-error {{'affine.apply' op mapping must produce one value}} return } // ----- func @unknown_custom_op() { ^bb0: %i = crazyThing() {value = 0} : () -> index // expected-error {{custom op 'crazyThing' is unknown}} return } // ----- func @unknown_std_op() { // expected-error@+1 {{unregistered operation 'std.foo_bar_op' found in dialect ('std') that does not allow unknown operations}} %0 = "std.foo_bar_op"() : () -> index return } // ----- func @bad_alloc_wrong_dynamic_dim_count() { ^bb0: %0 = constant 7 : index // Test alloc with wrong number of dynamic dimensions. 
- %1 = alloc(%0)[%1] : memref<2x4xf32, affine_map<(d0, d1)[s0] -> ((d0 + s0), d1)>, 1> // expected-error {{op 'std.alloc' dimension operand count does not equal memref dynamic dimension count}} + // expected-error@+1 {{dimension operand count does not equal memref dynamic dimension count}} + %1 = alloc(%0)[%0] : memref<2x4xf32, affine_map<(d0, d1)[s0] -> ((d0 + s0), d1)>, 1> return } // ----- func @bad_alloc_wrong_symbol_count() { ^bb0: %0 = constant 7 : index // Test alloc with wrong number of symbols - %1 = alloc(%0) : memref<2x?xf32, affine_map<(d0, d1)[s0] -> ((d0 + s0), d1)>, 1> // expected-error {{operand count does not equal dimension plus symbol operand count}} + // expected-error@+1 {{symbol operand count does not equal memref symbol count}} + %1 = alloc(%0) : memref<2x?xf32, affine_map<(d0, d1)[s0] -> ((d0 + s0), d1)>, 1> return } // ----- func @test_store_zero_results() { ^bb0: %0 = alloc() : memref<1024x64xf32, affine_map<(d0, d1) -> (d0, d1)>, 1> %1 = constant 0 : index %2 = constant 1 : index %3 = load %0[%1, %2] : memref<1024x64xf32, affine_map<(d0, d1) -> (d0, d1)>, 1> // Test that store returns zero results. %4 = store %3, %0[%1, %2] : memref<1024x64xf32, affine_map<(d0, d1) -> (d0, d1)>, 1> // expected-error {{cannot name an operation with no results}} return } // ----- func @test_store_zero_results2(%x: i32, %p: memref) { "std.store"(%x,%p) : (i32, memref) -> i32 // expected-error {{'std.store' op requires zero results}} return } // ----- func @test_alloc_memref_map_rank_mismatch() { ^bb0: %0 = alloc() : memref<1024x64xf32, affine_map<(d0) -> (d0)>, 1> // expected-error {{memref affine map dimension mismatch}} return } // ----- func @intlimit2() { ^bb: %0 = "std.constant"() {value = 0} : () -> i4096 %1 = "std.constant"() {value = 1} : () -> i4097 // expected-error {{integer bitwidth is limited to 4096 bits}} return } // ----- func @calls(%arg0: i32) { %x = call @calls() : () -> i32 // expected-error {{incorrect number of operands for callee}} return } // ----- func @func_with_ops(f32) { ^bb0(%a : f32): %sf = addf %a, %a, %a : f32 // expected-error {{'std.addf' op expected 2 operands}} } // ----- func @func_with_ops(f32) { ^bb0(%a : f32): %sf = addf(%a, %a) : f32 // expected-error {{expected ':'}} } // ----- func @func_with_ops(f32) { ^bb0(%a : f32): %sf = addf{%a, %a} : f32 // expected-error {{expected attribute name}} } // ----- func @func_with_ops(f32) { ^bb0(%a : f32): // expected-error@+1 {{'std.addi' op operand #0 must be signless-integer-like}} %sf = addi %a, %a : f32 } // ----- func @func_with_ops(i32) { ^bb0(%a : i32): %sf = addf %a, %a : i32 // expected-error {{'std.addf' op operand #0 must be floating-point-like}} } // ----- func @func_with_ops(i32) { ^bb0(%a : i32): // expected-error@+1 {{failed to satisfy constraint: allowed 64-bit signless integer cases: 0, 1, 2, 3, 4, 5, 6, 7, 8, 9}} %r = "std.cmpi"(%a, %a) {predicate = 42} : (i32, i32) -> i1 } // ----- // Comparison are defined for arguments of the same type. func @func_with_ops(i32, i64) { ^bb0(%a : i32, %b : i64): // expected-note {{prior use here}} %r = cmpi "eq", %a, %b : i32 // expected-error {{use of value '%b' expects different type than prior uses}} } // ----- // Comparisons must have the "predicate" attribute. func @func_with_ops(i32, i32) { ^bb0(%a : i32, %b : i32): %r = cmpi %a, %b : i32 // expected-error {{expected non-function type}} } // ----- // Integer comparisons are not recognized for float types. 
func @func_with_ops(f32, f32) { ^bb0(%a : f32, %b : f32): %r = cmpi "eq", %a, %b : f32 // expected-error {{'lhs' must be signless-integer-like, but got 'f32'}} } // ----- // Result type must be boolean like. func @func_with_ops(i32, i32) { ^bb0(%a : i32, %b : i32): %r = "std.cmpi"(%a, %b) {predicate = 0} : (i32, i32) -> i32 // expected-error {{op result #0 must be bool-like}} } // ----- func @func_with_ops(i32, i32) { ^bb0(%a : i32, %b : i32): // expected-error@+1 {{requires attribute 'predicate'}} %r = "std.cmpi"(%a, %b) {foo = 1} : (i32, i32) -> i1 } // ----- func @func_with_ops() { ^bb0: %c = constant dense<0> : vector<42 x i32> // expected-error@+1 {{op requires the same shape for all operands and results}} %r = "std.cmpi"(%c, %c) {predicate = 0} : (vector<42 x i32>, vector<42 x i32>) -> vector<41 x i1> } // ----- func @func_with_ops(i32, i32, i32) { ^bb0(%cond : i32, %t : i32, %f : i32): // expected-error@+2 {{different type than prior uses}} // expected-note@-2 {{prior use here}} %r = select %cond, %t, %f : i32 } // ----- func @func_with_ops(i32, i32, i32) { ^bb0(%cond : i32, %t : i32, %f : i32): // expected-error@+1 {{op operand #0 must be bool-like}} %r = "std.select"(%cond, %t, %f) : (i32, i32, i32) -> i32 } // ----- func @func_with_ops(i1, i32, i64) { ^bb0(%cond : i1, %t : i32, %f : i64): // expected-error@+1 {{all of {true_value, false_value, result} have same type}} %r = "std.select"(%cond, %t, %f) : (i1, i32, i64) -> i32 } // ----- func @func_with_ops(vector<12xi1>, vector<42xi32>, vector<42xi32>) { ^bb0(%cond : vector<12xi1>, %t : vector<42xi32>, %f : vector<42xi32>): // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'vector<42xi32>' and 'vector<12xi1>'}} %r = "std.select"(%cond, %t, %f) : (vector<12xi1>, vector<42xi32>, vector<42xi32>) -> vector<42xi32> } // ----- func @func_with_ops(tensor<12xi1>, tensor<42xi32>, tensor<42xi32>) { ^bb0(%cond : tensor<12xi1>, %t : tensor<42xi32>, %f : tensor<42xi32>): // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'tensor<42xi32>' and 'tensor<12xi1>'}} %r = "std.select"(%cond, %t, %f) : (tensor<12xi1>, tensor<42xi32>, tensor<42xi32>) -> tensor<42xi32> } // ----- func @invalid_cmp_shape(%idx : () -> ()) { // expected-error@+1 {{'lhs' must be signless-integer-like, but got '() -> ()'}} %cmp = cmpi "eq", %idx, %idx : () -> () // ----- func @dma_start_not_enough_operands() { // expected-error@+1 {{expected at least 4 operands}} "std.dma_start"() : () -> () } // ----- func @dma_no_src_memref(%m : f32, %tag : f32, %c0 : index) { // expected-error@+1 {{expected source to be of memref type}} dma_start %m[%c0], %m[%c0], %c0, %tag[%c0] : f32, f32, f32 } // ----- func @dma_start_not_enough_operands_for_src( %src: memref<2x2x2xf32>, %idx: index) { // expected-error@+1 {{expected at least 7 operands}} "std.dma_start"(%src, %idx, %idx, %idx) : (memref<2x2x2xf32>, index, index, index) -> () } // ----- func @dma_start_src_index_wrong_type( %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>, %tag: memref, %flt: f32) { // expected-error@+1 {{expected source indices to be of index type}} "std.dma_start"(%src, %idx, %flt, %dst, %idx, %tag, %idx) : (memref<2x2xf32>, index, f32, memref<2xf32,1>, index, memref, index) -> () } // ----- func @dma_no_dst_memref(%m : f32, %tag : f32, %c0 : index) { %mref = alloc() : memref<8 x f32> // expected-error@+1 {{expected destination to be of memref type}} dma_start %mref[%c0], %m[%c0], %c0, %tag[%c0] : memref<8 x 
f32>, f32, f32 } // ----- func @dma_start_not_enough_operands_for_dst( %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>, %tag: memref) { // expected-error@+1 {{expected at least 7 operands}} "std.dma_start"(%src, %idx, %idx, %dst, %idx, %idx) : (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index) -> () } // ----- func @dma_start_dst_index_wrong_type( %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>, %tag: memref, %flt: f32) { // expected-error@+1 {{expected destination indices to be of index type}} "std.dma_start"(%src, %idx, %idx, %dst, %flt, %tag, %idx) : (memref<2x2xf32>, index, index, memref<2xf32,1>, f32, memref, index) -> () } // ----- func @dma_start_dst_index_wrong_type( %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>, %tag: memref, %flt: f32) { // expected-error@+1 {{expected num elements to be of index type}} "std.dma_start"(%src, %idx, %idx, %dst, %idx, %flt, %tag) : (memref<2x2xf32>, index, index, memref<2xf32,1>, index, f32, memref) -> () } // ----- func @dma_no_tag_memref(%tag : f32, %c0 : index) { %mref = alloc() : memref<8 x f32> // expected-error@+1 {{expected tag to be of memref type}} dma_start %mref[%c0], %mref[%c0], %c0, %tag[%c0] : memref<8 x f32>, memref<8 x f32>, f32 } // ----- func @dma_start_not_enough_operands_for_tag( %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>, %tag: memref<2xi32,2>) { // expected-error@+1 {{expected at least 8 operands}} "std.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag) : (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<2xi32,2>) -> () } // ----- func @dma_start_dst_index_wrong_type( %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>, %tag: memref<2xi32,2>, %flt: f32) { // expected-error@+1 {{expected tag indices to be of index type}} "std.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag, %flt) : (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref<2xi32,2>, f32) -> () } // ----- func @dma_start_same_space( %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32>, %tag: memref) { // expected-error@+1 {{DMA should be between different memory spaces}} dma_start %src[%idx, %idx], %dst[%idx], %idx, %tag[] : memref<2x2xf32>, memref<2xf32>, memref } // ----- func @dma_start_too_many_operands( %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>, %tag: memref) { // expected-error@+1 {{incorrect number of operands}} "std.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag, %idx, %idx, %idx) : (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref, index, index, index) -> () } // ----- func @dma_start_wrong_stride_type( %src: memref<2x2xf32>, %idx: index, %dst: memref<2xf32,1>, %tag: memref, %flt: f32) { // expected-error@+1 {{expected stride and num elements per stride to be of type index}} "std.dma_start"(%src, %idx, %idx, %dst, %idx, %idx, %tag, %idx, %flt) : (memref<2x2xf32>, index, index, memref<2xf32,1>, index, index, memref, index, f32) -> () } // ----- func @dma_wait_not_enough_operands() { // expected-error@+1 {{expected at least 2 operands}} "std.dma_wait"() : () -> () } // ----- func @dma_wait_no_tag_memref(%tag : f32, %c0 : index) { // expected-error@+1 {{expected tag to be of memref type}} "std.dma_wait"(%tag, %c0, %c0) : (f32, index, index) -> () } // ----- func @dma_wait_wrong_index_type(%tag : memref<2xi32>, %idx: index, %flt: f32) { // expected-error@+1 {{expected tag indices to be of index type}} "std.dma_wait"(%tag, %flt, %idx) : (memref<2xi32>, f32, index) -> () } // ----- func 
@dma_wait_wrong_num_elements_type(%tag : memref<2xi32>, %idx: index, %flt: f32) { // expected-error@+1 {{expected the number of elements to be of index type}} "std.dma_wait"(%tag, %idx, %flt) : (memref<2xi32>, index, f32) -> () } // ----- func @invalid_cmp_attr(%idx : i32) { // expected-error@+1 {{invalid kind of attribute specified}} %cmp = cmpi i1, %idx, %idx : i32 // ----- func @cmpf_generic_invalid_predicate_value(%a : f32) { // expected-error@+1 {{attribute 'predicate' failed to satisfy constraint: allowed 64-bit signless integer cases}} %r = "std.cmpf"(%a, %a) {predicate = 42} : (f32, f32) -> i1 } // ----- func @cmpf_canonical_invalid_predicate_value(%a : f32) { // expected-error@+1 {{invalid predicate attribute specification: "foo"}} %r = cmpf "foo", %a, %a : f32 } // ----- func @cmpf_canonical_invalid_predicate_value_signed(%a : f32) { // expected-error@+1 {{invalid predicate attribute specification: "sge"}} %r = cmpf "sge", %a, %a : f32 } // ----- func @cmpf_canonical_invalid_predicate_value_no_order(%a : f32) { // expected-error@+1 {{invalid predicate attribute specification: "eq"}} %r = cmpf "eq", %a, %a : f32 } // ----- func @cmpf_canonical_no_predicate_attr(%a : f32, %b : f32) { %r = cmpf %a, %b : f32 // expected-error {{}} } // ----- func @cmpf_generic_no_predicate_attr(%a : f32, %b : f32) { // expected-error@+1 {{requires attribute 'predicate'}} %r = "std.cmpf"(%a, %b) {foo = 1} : (f32, f32) -> i1 } // ----- func @cmpf_wrong_type(%a : i32, %b : i32) { %r = cmpf "oeq", %a, %b : i32 // expected-error {{must be floating-point-like}} } // ----- func @cmpf_generic_wrong_result_type(%a : f32, %b : f32) { // expected-error@+1 {{result #0 must be bool-like}} %r = "std.cmpf"(%a, %b) {predicate = 0} : (f32, f32) -> f32 } // ----- func @cmpf_canonical_wrong_result_type(%a : f32, %b : f32) -> f32 { %r = cmpf "oeq", %a, %b : f32 // expected-note {{prior use here}} // expected-error@+1 {{use of value '%r' expects different type than prior uses}} return %r : f32 } // ----- func @cmpf_result_shape_mismatch(%a : vector<42xf32>) { // expected-error@+1 {{op requires the same shape for all operands and results}} %r = "std.cmpf"(%a, %a) {predicate = 0} : (vector<42 x f32>, vector<42 x f32>) -> vector<41 x i1> } // ----- func @cmpf_operand_shape_mismatch(%a : vector<42xf32>, %b : vector<41xf32>) { // expected-error@+1 {{op requires all operands to have the same type}} %r = "std.cmpf"(%a, %b) {predicate = 0} : (vector<42 x f32>, vector<41 x f32>) -> vector<42 x i1> } // ----- func @cmpf_generic_operand_type_mismatch(%a : f32, %b : f64) { // expected-error@+1 {{op requires all operands to have the same type}} %r = "std.cmpf"(%a, %b) {predicate = 0} : (f32, f64) -> i1 } // ----- func @cmpf_canonical_type_mismatch(%a : f32, %b : f64) { // expected-note {{prior use here}} // expected-error@+1 {{use of value '%b' expects different type than prior uses}} %r = cmpf "oeq", %a, %b : f32 } // ----- func @extract_element_no_operands() { // expected-error@+1 {{op expected 1 or more operands}} %0 = "std.extract_element"() : () -> f32 return } // ----- func @extract_element_no_indices(%v : vector<3xf32>) { // expected-error@+1 {{incorrect number of indices for extract_element}} %0 = "std.extract_element"(%v) : (vector<3xf32>) -> f32 return } // ----- func @extract_element_invalid_index_type(%v : vector<3xf32>, %i : i32) { // expected-error@+1 {{operand #1 must be index}} %0 = "std.extract_element"(%v, %i) : (vector<3xf32>, i32) -> f32 return } // ----- func @extract_element_element_result_type_mismatch(%v : 
vector<3xf32>, %i : index) { // expected-error@+1 {{result type matches element type of aggregate}} %0 = "std.extract_element"(%v, %i) : (vector<3xf32>, index) -> f64 return } // ----- func @extract_element_vector_too_many_indices(%v : vector<3xf32>, %i : index) { // expected-error@+1 {{incorrect number of indices for extract_element}} %0 = "std.extract_element"(%v, %i, %i) : (vector<3xf32>, index, index) -> f32 return } // ----- func @extract_element_tensor_too_many_indices(%t : tensor<2x3xf32>, %i : index) { // expected-error@+1 {{incorrect number of indices for extract_element}} %0 = "std.extract_element"(%t, %i, %i, %i) : (tensor<2x3xf32>, index, index, index) -> f32 return } // ----- func @extract_element_tensor_too_few_indices(%t : tensor<2x3xf32>, %i : index) { // expected-error@+1 {{incorrect number of indices for extract_element}} %0 = "std.extract_element"(%t, %i) : (tensor<2x3xf32>, index) -> f32 return } // ----- func @tensor_from_elements_wrong_result_type() { // expected-error@+2 {{'result' must be 1D tensor of any type values, but got 'tensor<*xi32>'}} %c0 = constant 0 : i32 %0 = tensor_from_elements %c0 : tensor<*xi32> return } // ----- func @tensor_from_elements_wrong_elements_count() { // expected-error@+2 {{1 operands present, but expected 2}} %c0 = constant 0 : index %0 = tensor_from_elements %c0 : tensor<2xindex> return } // ----- func @index_cast_index_to_index(%arg0: index) { // expected-error@+1 {{are cast incompatible}} %0 = index_cast %arg0: index to index return } // ----- func @index_cast_float(%arg0: index, %arg1: f32) { // expected-error@+1 {{are cast incompatible}} %0 = index_cast %arg0 : index to f32 return } // ----- func @index_cast_float_to_index(%arg0: f32) { // expected-error@+1 {{are cast incompatible}} %0 = index_cast %arg0 : f32 to index return } // ----- func @sitofp_i32_to_i64(%arg0 : i32) { // expected-error@+1 {{are cast incompatible}} %0 = sitofp %arg0 : i32 to i64 return } // ----- func @sitofp_f32_to_i32(%arg0 : f32) { // expected-error@+1 {{are cast incompatible}} %0 = sitofp %arg0 : f32 to i32 return } // ----- func @fpext_f32_to_f16(%arg0 : f32) { // expected-error@+1 {{are cast incompatible}} %0 = fpext %arg0 : f32 to f16 return } // ----- func @fpext_f16_to_f16(%arg0 : f16) { // expected-error@+1 {{are cast incompatible}} %0 = fpext %arg0 : f16 to f16 return } // ----- func @fpext_i32_to_f32(%arg0 : i32) { // expected-error@+1 {{are cast incompatible}} %0 = fpext %arg0 : i32 to f32 return } // ----- func @fpext_f32_to_i32(%arg0 : f32) { // expected-error@+1 {{are cast incompatible}} %0 = fpext %arg0 : f32 to i32 return } // ----- func @fpext_vec(%arg0 : vector<2xf16>) { // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'vector<3xf32>' and 'vector<2xf16>'}} %0 = fpext %arg0 : vector<2xf16> to vector<3xf32> return } // ----- func @fpext_vec_f32_to_f16(%arg0 : vector<2xf32>) { // expected-error@+1 {{are cast incompatible}} %0 = fpext %arg0 : vector<2xf32> to vector<2xf16> return } // ----- func @fpext_vec_f16_to_f16(%arg0 : vector<2xf16>) { // expected-error@+1 {{are cast incompatible}} %0 = fpext %arg0 : vector<2xf16> to vector<2xf16> return } // ----- func @fpext_vec_i32_to_f32(%arg0 : vector<2xi32>) { // expected-error@+1 {{are cast incompatible}} %0 = fpext %arg0 : vector<2xi32> to vector<2xf32> return } // ----- func @fpext_vec_f32_to_i32(%arg0 : vector<2xf32>) { // expected-error@+1 {{are cast incompatible}} %0 = fpext %arg0 : vector<2xf32> to vector<2xi32> return } // ----- func 
@fptrunc_f16_to_f32(%arg0 : f16) { // expected-error@+1 {{are cast incompatible}} %0 = fptrunc %arg0 : f16 to f32 return } // ----- func @fptrunc_f32_to_f32(%arg0 : f32) { // expected-error@+1 {{are cast incompatible}} %0 = fptrunc %arg0 : f32 to f32 return } // ----- func @fptrunc_i32_to_f32(%arg0 : i32) { // expected-error@+1 {{are cast incompatible}} %0 = fptrunc %arg0 : i32 to f32 return } // ----- func @fptrunc_f32_to_i32(%arg0 : f32) { // expected-error@+1 {{are cast incompatible}} %0 = fptrunc %arg0 : f32 to i32 return } // ----- func @fptrunc_vec(%arg0 : vector<2xf16>) { // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type: found 'vector<3xf32>' and 'vector<2xf16>'}} %0 = fptrunc %arg0 : vector<2xf16> to vector<3xf32> return } // ----- func @fptrunc_vec_f16_to_f32(%arg0 : vector<2xf16>) { // expected-error@+1 {{are cast incompatible}} %0 = fptrunc %arg0 : vector<2xf16> to vector<2xf32> return } // ----- func @fptrunc_vec_f32_to_f32(%arg0 : vector<2xf32>) { // expected-error@+1 {{are cast incompatible}} %0 = fptrunc %arg0 : vector<2xf32> to vector<2xf32> return } // ----- func @fptrunc_vec_i32_to_f32(%arg0 : vector<2xi32>) { // expected-error@+1 {{are cast incompatible}} %0 = fptrunc %arg0 : vector<2xi32> to vector<2xf32> return } // ----- func @fptrunc_vec_f32_to_i32(%arg0 : vector<2xf32>) { // expected-error@+1 {{are cast incompatible}} %0 = fptrunc %arg0 : vector<2xf32> to vector<2xi32> return } // ----- func @sexti_index_as_operand(%arg0 : index) { // expected-error@+1 {{'index' is not a valid operand type}} %0 = sexti %arg0 : index to i128 return } // ----- func @zexti_index_as_operand(%arg0 : index) { // expected-error@+1 {{'index' is not a valid operand type}} %0 = zexti %arg0 : index to i128 return } // ----- func @trunci_index_as_operand(%arg0 : index) { // expected-error@+1 {{'index' is not a valid operand type}} %2 = trunci %arg0 : index to i128 return } // ----- func @sexti_index_as_result(%arg0 : i1) { // expected-error@+1 {{'index' is not a valid result type}} %0 = sexti %arg0 : i1 to index return } // ----- func @zexti_index_as_operand(%arg0 : i1) { // expected-error@+1 {{'index' is not a valid result type}} %0 = zexti %arg0 : i1 to index return } // ----- func @trunci_index_as_result(%arg0 : i128) { // expected-error@+1 {{'index' is not a valid result type}} %2 = trunci %arg0 : i128 to index return } // ----- func @sexti_cast_to_narrower(%arg0 : i16) { // expected-error@+1 {{must be wider}} %0 = sexti %arg0 : i16 to i15 return } // ----- func @zexti_cast_to_narrower(%arg0 : i16) { // expected-error@+1 {{must be wider}} %0 = zexti %arg0 : i16 to i15 return } // ----- func @trunci_cast_to_wider(%arg0 : i16) { // expected-error@+1 {{must be wider}} %0 = trunci %arg0 : i16 to i17 return } // ----- func @sexti_cast_to_same_width(%arg0 : i16) { // expected-error@+1 {{must be wider}} %0 = sexti %arg0 : i16 to i16 return } // ----- func @zexti_cast_to_same_width(%arg0 : i16) { // expected-error@+1 {{must be wider}} %0 = zexti %arg0 : i16 to i16 return } // ----- func @trunci_cast_to_same_width(%arg0 : i16) { // expected-error@+1 {{must be wider}} %0 = trunci %arg0 : i16 to i16 return } // ----- func @return_not_in_function() { "foo.region"() ({ // expected-error@+1 {{'std.return' op expects parent op 'func'}} return }): () -> () return } // ----- func @invalid_splat(%v : f32) { splat %v : memref<8xf32> // expected-error@-1 {{must be vector of any type values or statically shaped tensor of any type values}} return } // ----- func 
@invalid_splat(%v : vector<8xf32>) { %w = splat %v : tensor<8xvector<8xf32>> // expected-error@-1 {{must be integer or float type}} return } // ----- func @invalid_splat(%v : f32) { // expected-note {{prior use here}} splat %v : vector<8xf64> // expected-error@-1 {{expects different type than prior uses}} return } // ----- func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) { %0 = alloc() : memref<2048xi8> // expected-error@+1 {{expects 1 offset operand}} %1 = view %0[][%arg0, %arg1] : memref<2048xi8> to memref return } // ----- func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) { %0 = alloc() : memref<2048xi8, affine_map<(d0) -> (d0 floordiv 8, d0 mod 8)>> // expected-error@+1 {{unsupported map for base memref type}} %1 = view %0[%arg2][%arg0, %arg1] : memref<2048xi8, affine_map<(d0) -> (d0 floordiv 8, d0 mod 8)>> to memref (d0 * 4 + d1 + s0)>> return } // ----- func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) { %0 = alloc() : memref<2048xi8> // expected-error@+1 {{unsupported map for result memref type}} %1 = view %0[%arg2][%arg0, %arg1] : memref<2048xi8> to memref (d0, d1, s0)>> return } // ----- func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) { %0 = alloc() : memref<2048xi8, 2> // expected-error@+1 {{different memory spaces}} %1 = view %0[%arg2][%arg0, %arg1] : memref<2048xi8, 2> to memref return } // ----- func @invalid_view(%arg0 : index, %arg1 : index, %arg2 : index) { %0 = alloc() : memref<2048xi8> // expected-error@+1 {{incorrect number of size operands for type}} %1 = view %0[%arg2][%arg0] : memref<2048xi8> to memref return } // ----- func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) { %0 = alloc() : memref<8x16x4xf32, offset: 0, strides: [64, 4, 1], 2> // expected-error@+1 {{different memory spaces}} %1 = subview %0[0, 0, 0][%arg2][1, 1, 1] : memref<8x16x4xf32, offset: 0, strides: [64, 4, 1], 2> to memref<8x?x4xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * s0 + d1 * 4 + d2)>> return } // ----- func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) { %0 = alloc() : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 + d1, d1 + d2, d2)>> // expected-error@+1 {{is not strided}} %1 = subview %0[0, 0, 0][%arg2][1, 1, 1] : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 + d1, d1 + d2, d2)>> to memref<8x?x4xf32, offset: 0, strides: [?, 4, 1]> return } // ----- func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) { %0 = alloc() : memref<8x16x4xf32> // expected-error@+1 {{expected 3 offset values}} %1 = subview %0[%arg0, %arg1][%arg2][1, 1, 1] : memref<8x16x4xf32> to memref<8x?x4xf32, offset: 0, strides:[?, ?, 4]> return } // ----- func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) { %0 = alloc() : memref<8x16x4xf32> // expected-error@+1 {{expected result type to be 'memref (d0 * s1 + s0 + d1 * s2 + d2 * s3)>>' or a rank-reduced version. 
(mismatch of result strides)}} %1 = subview %0[%arg0, %arg1, %arg2][%arg0, %arg1, %arg2][%arg0, %arg1, %arg2] : memref<8x16x4xf32> to memref return } // ----- func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) { %0 = alloc() : memref<8x16x4xf32> // expected-error@+1 {{expected result element type to be 'f32'}} %1 = subview %0[0, 0, 0][8, 16, 4][1, 1, 1] : memref<8x16x4xf32> to memref<8x16x4xi32> return } // ----- func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) { %0 = alloc() : memref<8x16x4xf32> // expected-error@+1 {{expected result rank to be smaller or equal to the source rank.}} %1 = subview %0[0, 0, 0][8, 16, 4][1, 1, 1] : memref<8x16x4xf32> to memref<8x16x4x3xi32> return } // ----- func @invalid_rank_reducing_subview(%arg0 : index, %arg1 : index, %arg2 : index) { %0 = alloc() : memref<8x16x4xf32> // expected-error@+1 {{expected result type to be 'memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>>' or a rank-reduced version. (mismatch of result sizes)}} %1 = subview %0[0, 0, 0][8, 16, 4][1, 1, 1] : memref<8x16x4xf32> to memref<16x4xf32> return } // ----- func @invalid_rank_reducing_subview(%arg0 : memref, %arg1 : index, %arg2 : index) { // expected-error@+1 {{expected result type to be 'memref (d0 * s1 + s0 + d1)>>' or a rank-reduced version. (mismatch of result strides)}} %0 = subview %arg0[0, %arg1][%arg2, 1][1, 1] : memref to memref return } // ----- func @invalid_memref_cast(%arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16, 1]>) { // expected-error@+1{{operand type 'memref<12x4x16xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 16 + d2)>>' and result type 'memref<12x4x16xf32, affine_map<(d0, d1, d2) -> (d0 * 128 + d1 * 32 + d2 * 2)>>' are cast incompatible}} %0 = memref_cast %arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16, 1]> to memref<12x4x16xf32, offset:0, strides:[128, 32, 2]> return } // ----- func @invalid_memref_cast(%arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16, 1]>) { // expected-error@+1{{operand type 'memref<12x4x16xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 16 + d2)>>' and result type 'memref<12x4x16xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 16 + d2 + 16)>>' are cast incompatible}} %0 = memref_cast %arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16, 1]> to memref<12x4x16xf32, offset:16, strides:[64, 16, 1]> return } // ----- // incompatible element types func @invalid_memref_cast() { %0 = alloc() : memref<2x5xf32, 0> // expected-error@+1 {{operand type 'memref<2x5xf32>' and result type 'memref<*xi32>' are cast incompatible}} %1 = memref_cast %0 : memref<2x5xf32, 0> to memref<*xi32> return } // ----- func @invalid_prefetch_rw(%i : index) { %0 = alloc() : memref<10xf32> // expected-error@+1 {{rw specifier has to be 'read' or 'write'}} prefetch %0[%i], rw, locality<0>, data : memref<10xf32> return } // ----- func @invalid_prefetch_cache_type(%i : index) { %0 = alloc() : memref<10xf32> // expected-error@+1 {{cache type has to be 'data' or 'instr'}} prefetch %0[%i], read, locality<0>, false : memref<10xf32> return } // ----- func @invalid_prefetch_locality_hint(%i : index) { %0 = alloc() : memref<10xf32> // expected-error@+1 {{32-bit signless integer attribute whose minimum value is 0 whose maximum value is 3}} prefetch %0[%i], read, locality<5>, data : memref<10xf32> return } // ----- // incompatible memory space func @invalid_memref_cast() { %0 = alloc() : memref<2x5xf32, 0> // expected-error@+1 {{operand type 'memref<2x5xf32>' and result type 'memref<*xf32, 1>' are cast 
incompatible}} %1 = memref_cast %0 : memref<2x5xf32, 0> to memref<*xf32, 1> return } // ----- // unranked to unranked func @invalid_memref_cast() { %0 = alloc() : memref<2x5xf32, 0> %1 = memref_cast %0 : memref<2x5xf32, 0> to memref<*xf32, 0> // expected-error@+1 {{operand type 'memref<*xf32>' and result type 'memref<*xf32>' are cast incompatible}} %2 = memref_cast %1 : memref<*xf32, 0> to memref<*xf32, 0> return } // ----- func @atomic_rmw_idxs_rank_mismatch(%I: memref<16x10xf32>, %i : index, %val : f32) { // expected-error@+1 {{expects the number of subscripts to be equal to memref rank}} %x = atomic_rmw "addf" %val, %I[%i] : (f32, memref<16x10xf32>) -> f32 return } // ----- func @atomic_rmw_expects_float(%I: memref<16x10xi32>, %i : index, %val : i32) { // expected-error@+1 {{expects a floating-point type}} %x = atomic_rmw "addf" %val, %I[%i, %i] : (i32, memref<16x10xi32>) -> i32 return } // ----- func @atomic_rmw_expects_int(%I: memref<16x10xf32>, %i : index, %val : f32) { // expected-error@+1 {{expects an integer type}} %x = atomic_rmw "addi" %val, %I[%i, %i] : (f32, memref<16x10xf32>) -> f32 return } // ----- func @generic_atomic_rmw_wrong_arg_num(%I: memref<10xf32>, %i : index) { // expected-error@+1 {{expected single number of entry block arguments}} %x = generic_atomic_rmw %I[%i] : memref<10xf32> { ^bb0(%arg0 : f32, %arg1 : f32): %c1 = constant 1.0 : f32 atomic_yield %c1 : f32 } return } // ----- func @generic_atomic_rmw_wrong_arg_type(%I: memref<10xf32>, %i : index) { // expected-error@+1 {{expected block argument of the same type result type}} %x = generic_atomic_rmw %I[%i] : memref<10xf32> { ^bb0(%old_value : i32): %c1 = constant 1.0 : f32 atomic_yield %c1 : f32 } return } // ----- func @generic_atomic_rmw_result_type_mismatch(%I: memref<10xf32>, %i : index) { // expected-error@+1 {{failed to verify that result type matches element type of memref}} %0 = "std.generic_atomic_rmw"(%I, %i) ( { ^bb0(%old_value: f32): %c1 = constant 1.0 : f32 atomic_yield %c1 : f32 }) : (memref<10xf32>, index) -> i32 return } // ----- func @generic_atomic_rmw_has_side_effects(%I: memref<10xf32>, %i : index) { // expected-error@+4 {{should contain only operations with no side effects}} %x = generic_atomic_rmw %I[%i] : memref<10xf32> { ^bb0(%old_value : f32): %c1 = constant 1.0 : f32 %buf = alloc() : memref<2048xf32> atomic_yield %c1 : f32 } } // ----- func @atomic_yield_type_mismatch(%I: memref<10xf32>, %i : index) { // expected-error@+4 {{op types mismatch between yield op: 'i32' and its parent: 'f32'}} %x = generic_atomic_rmw %I[%i] : memref<10xf32> { ^bb0(%old_value : f32): %c1 = constant 1 : i32 atomic_yield %c1 : i32 } return } // ----- // alignment is not power of 2. func @assume_alignment(%0: memref<4x4xf16>) { // expected-error@+1 {{alignment must be power of 2}} std.assume_alignment %0, 12 : memref<4x4xf16> return } // ----- // 0 alignment value. 
func @assume_alignment(%0: memref<4x4xf16>) { // expected-error@+1 {{attribute 'alignment' failed to satisfy constraint: 32-bit signless integer attribute whose value is positive}} std.assume_alignment %0, 0 : memref<4x4xf16> return } // ----- "alloca_without_scoped_alloc_parent"() ( { std.alloca() : memref<1xf32> // expected-error@-1 {{requires an ancestor op with AutomaticAllocationScope trait}} return }) : () -> () // ----- func @complex_number_from_non_float_operands(%real: i32, %imag: i32) { // expected-error@+1 {{'complex' must be complex type with floating-point elements, but got 'complex'}} std.create_complex %real, %imag : complex return } // ----- // expected-note@+1 {{prior use here}} func @complex_number_from_different_float_types(%real: f32, %imag: f64) { // expected-error@+1 {{expects different type than prior uses: 'f32' vs 'f64'}} std.create_complex %real, %imag : complex return } // ----- // expected-note@+1 {{prior use here}} func @complex_number_from_incompatible_float_type(%real: f32, %imag: f32) { // expected-error@+1 {{expects different type than prior uses: 'f64' vs 'f32'}} std.create_complex %real, %imag : complex return } // ----- // expected-note@+1 {{prior use here}} func @real_part_from_incompatible_complex_type(%cplx: complex) { // expected-error@+1 {{expects different type than prior uses: 'complex' vs 'complex'}} std.re %cplx : complex return } // ----- // expected-note@+1 {{prior use here}} func @imaginary_part_from_incompatible_complex_type(%cplx: complex) { // expected-error@+1 {{expects different type than prior uses: 'complex' vs 'complex'}} std.re %cplx : complex return } // ----- func @subtensor_wrong_dynamic_type(%t: tensor<8x16x4xf32>, %idx : index) { // expected-error @+1 {{expected result type to be 'tensor<4x4x4xf32>' or a rank-reduced version. (mismatch of result sizes)}} %0 = subtensor %t[0, 2, 0][4, 4, 4][1, 1, 1] : tensor<8x16x4xf32> to tensor return } // ----- func @subtensor_wrong_static_type(%t: tensor<8x16x4xf32>, %idx : index) { // expected-error @+1 {{expected result type to be 'tensor' or a rank-reduced version. (mismatch of result sizes)}} %0 = subtensor %t[0, 0, 0][%idx, 3, %idx][1, 1, 1] : tensor<8x16x4xf32> to tensor<4x4x4xf32> return }
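// For contrast with the two invalid subtensor cases above, sketches of the
// corresponding well-formed ops (illustrative only, not additional test cases;
// the result types are the ones the verifier's error messages ask for):
//
//   %0 = subtensor %t[0, 2, 0][4, 4, 4][1, 1, 1]
//     : tensor<8x16x4xf32> to tensor<4x4x4xf32>
//   %1 = subtensor %t[0, 0, 0][%idx, 3, %idx][1, 1, 1]
//     : tensor<8x16x4xf32> to tensor<?x3x?xf32>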