diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h @@ -33,17 +33,14 @@ class FuncOp; class OpBuilder; -/// Auxiliary range data structure to unpack the offset, size and stride -/// operands of the SubViewOp / SubTensorOp into a list of triples. -/// Such a list of triple is sometimes more convenient to manipulate. -struct Range { - Value offset; - Value size; - Value stride; -}; - raw_ostream &operator<<(raw_ostream &os, Range &range); +/// Return the list of Range (i.e. offset, size, stride). Each Range +/// entry contains either the dynamic value or a ConstantIndexOp constructed +/// with `b` at location `loc`. +SmallVector getOrCreateRanges(OffsetSizeAndStrideOpInterface op, + OpBuilder &b, Location loc); + #define GET_OP_CLASSES #include "mlir/Dialect/StandardOps/IR/Ops.h.inc" diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -223,81 +223,14 @@ Std_Op { code extraBaseClassDeclaration = [{ - /// Returns the number of dynamic offset operands. - int64_t getNumOffsets() { return llvm::size(offsets()); } - - /// Returns the number of dynamic size operands. - int64_t getNumSizes() { return llvm::size(sizes()); } - - /// Returns the number of dynamic stride operands. - int64_t getNumStrides() { return llvm::size(strides()); } - /// Returns the dynamic sizes for this subview operation if specified. operand_range getDynamicSizes() { return sizes(); } - /// Returns in `staticStrides` the static value of the stride - /// operands. Returns failure() if the static value of the stride - /// operands could not be retrieved. 
- LogicalResult getStaticStrides(SmallVectorImpl &staticStrides) { - if (!strides().empty()) - return failure(); - staticStrides.reserve(static_strides().size()); - for (auto s : static_strides().getAsValueRange()) - staticStrides.push_back(s.getZExtValue()); - return success(); - } - /// Return the list of Range (i.e. offset, size, stride). Each /// Range entry contains either the dynamic value or a ConstantIndexOp /// constructed with `b` at location `loc`. - SmallVector getOrCreateRanges(OpBuilder &b, Location loc); - - /// Return the offsets as Values. Each Value is either the dynamic - /// value specified in the op or a ConstantIndexOp constructed - /// with `b` at location `loc` - SmallVector getOrCreateOffsets(OpBuilder &b, Location loc) { - unsigned dynamicIdx = 1; - return llvm::to_vector<4>(llvm::map_range( - static_offsets().cast(), [&](Attribute a) -> Value { - int64_t staticOffset = a.cast().getInt(); - if (ShapedType::isDynamicStrideOrOffset(staticOffset)) - return getOperand(dynamicIdx++); - else - return b.create( - loc, b.getIndexType(), b.getIndexAttr(staticOffset)); - })); - } - - /// Return the sizes as Values. Each Value is either the dynamic - /// value specified in the op or a ConstantIndexOp constructed - /// with `b` at location `loc` - SmallVector getOrCreateSizes(OpBuilder &b, Location loc) { - unsigned dynamicIdx = 1 + offsets().size(); - return llvm::to_vector<4>(llvm::map_range( - static_sizes().cast(), [&](Attribute a) -> Value { - int64_t staticSize = a.cast().getInt(); - if (ShapedType::isDynamic(staticSize)) - return getOperand(dynamicIdx++); - else - return b.create( - loc, b.getIndexType(), b.getIndexAttr(staticSize)); - })); - } - - /// Return the strides as Values. 
Each Value is either the dynamic - /// value specified in the op or a ConstantIndexOp constructed with - /// `b` at location `loc` - SmallVector getOrCreateStrides(OpBuilder &b, Location loc) { - unsigned dynamicIdx = 1 + offsets().size() + sizes().size(); - return llvm::to_vector<4>(llvm::map_range( - static_strides().cast(), [&](Attribute a) -> Value { - int64_t staticStride = a.cast().getInt(); - if (ShapedType::isDynamicStrideOrOffset(staticStride)) - return getOperand(dynamicIdx++); - else - return b.create( - loc, b.getIndexType(), b.getIndexAttr(staticStride)); - })); + SmallVector getOrCreateRanges(OpBuilder &b, Location loc) { + return mlir::getOrCreateRanges(*this, b, loc); } /// Return the rank of the source ShapedType. @@ -305,107 +238,11 @@ return source().getType().cast().getRank(); } - /// Return the rank of the result ShapedType. - unsigned getResultRank() { return getType().getRank(); } - - /// Return true if the offset `idx` is a static constant. - bool isDynamicOffset(unsigned idx) { - APInt v = *(static_offsets().getAsValueRange().begin() + idx); - return ShapedType::isDynamicStrideOrOffset(v.getSExtValue()); - } - /// Return true if the size `idx` is a static constant. - bool isDynamicSize(unsigned idx) { - APInt v = *(static_sizes().getAsValueRange().begin() + idx); - return ShapedType::isDynamic(v.getSExtValue()); - } - - /// Return true if the stride `idx` is a static constant. - bool isDynamicStride(unsigned idx) { - APInt v = *(static_strides().getAsValueRange().begin() + idx); - return ShapedType::isDynamicStrideOrOffset(v.getSExtValue()); - } - - /// Assert the offset `idx` is a static constant and return its value. - int64_t getStaticOffset(unsigned idx) { - assert(!isDynamicOffset(idx) && "expected static offset"); - APInt v = *(static_offsets().getAsValueRange().begin() + idx); - return v.getSExtValue(); - } - /// Assert the size `idx` is a static constant and return its value. 
- int64_t getStaticSize(unsigned idx) { - assert(!isDynamicSize(idx) && "expected static size"); - APInt v = *(static_sizes().getAsValueRange().begin() + idx); - return v.getSExtValue(); - } - /// Assert the stride `idx` is a static constant and return its value. - int64_t getStaticStride(unsigned idx) { - assert(!isDynamicStride(idx) && "expected static stride"); - APInt v = *(static_strides().getAsValueRange().begin() + idx); - return v.getSExtValue(); - } - - unsigned getNumDynamicEntriesUpToIdx(ArrayAttr attr, - llvm::function_ref isDynamic, unsigned idx) { - return std::count_if( - attr.getValue().begin(), attr.getValue().begin() + idx, - [&](Attribute attr) { - return isDynamic(attr.cast().getInt()); - }); - } - /// Assert the offset `idx` is dynamic and return the position of the - /// corresponding operand. - unsigned getIndexOfDynamicOffset(unsigned idx) { - assert(isDynamicOffset(idx) && "expected static offset"); - auto numDynamic = - getNumDynamicEntriesUpToIdx(static_offsets().cast(), - ShapedType::isDynamicStrideOrOffset, idx); - return 1 + numDynamic; - } - /// Assert the size `idx` is dynamic and return the position of the - /// corresponding operand. - unsigned getIndexOfDynamicSize(unsigned idx) { - assert(isDynamicSize(idx) && "expected static size"); - auto numDynamic = getNumDynamicEntriesUpToIdx( - static_sizes().cast(), ShapedType::isDynamic, idx); - return 1 + offsets().size() + numDynamic; - } - /// Assert the stride `idx` is dynamic and return the position of the - /// corresponding operand. - unsigned getIndexOfDynamicStride(unsigned idx) { - assert(isDynamicStride(idx) && "expected static stride"); - auto numDynamic = - getNumDynamicEntriesUpToIdx(static_strides().cast(), - ShapedType::isDynamicStrideOrOffset, idx); - return 1 + offsets().size() + sizes().size() + numDynamic; - } - - /// Assert the offset `idx` is dynamic and return its value. 
- Value getDynamicOffset(unsigned idx) { - return getOperand(getIndexOfDynamicOffset(idx)); - } - /// Assert the size `idx` is dynamic and return its value. - Value getDynamicSize(unsigned idx) { - return getOperand(getIndexOfDynamicSize(idx)); - } - /// Assert the stride `idx` is dynamic and return its value. - Value getDynamicStride(unsigned idx) { - return getOperand(getIndexOfDynamicStride(idx)); - } - - static StringRef getStaticOffsetsAttrName() { - return "static_offsets"; - } - static StringRef getStaticSizesAttrName() { - return "static_sizes"; - } - static StringRef getStaticStridesAttrName() { - return "static_strides"; - } static ArrayRef getSpecialAttrNames() { static SmallVector names{ - getStaticOffsetsAttrName(), - getStaticSizesAttrName(), - getStaticStridesAttrName(), + OffsetSizeAndStrideOpInterface::getStaticOffsetsAttrName(), + OffsetSizeAndStrideOpInterface::getStaticSizesAttrName(), + OffsetSizeAndStrideOpInterface::getStaticStridesAttrName(), getOperandSegmentSizeAttr()}; return names; } @@ -2340,7 +2177,7 @@ def MemRefReinterpretCastOp: BaseOpWithOffsetSizesAndStrides<"memref_reinterpret_cast", [ - NoSideEffect, ViewLikeOpInterface + NoSideEffect, ViewLikeOpInterface, OffsetSizeAndStrideOpInterface ]> { let summary = "memref reinterpret cast operation"; let description = [{ @@ -3210,7 +3047,7 @@ //===----------------------------------------------------------------------===// def SubViewOp : BaseOpWithOffsetSizesAndStrides< - "subview", [DeclareOpInterfaceMethods] > { + "subview", [DeclareOpInterfaceMethods, OffsetSizeAndStrideOpInterface] > { let summary = "memref subview operation"; let description = [{ The 
"subview" operation converts a memref type to another memref type @@ -3399,7 +3236,7 @@ // SubTensorOp //===----------------------------------------------------------------------===// -def SubTensorOp : BaseOpWithOffsetSizesAndStrides<"subtensor"> { +def SubTensorOp : BaseOpWithOffsetSizesAndStrides<"subtensor", [OffsetSizeAndStrideOpInterface]> { let summary = "subtensor operation"; let description = [{ The "subtensor" operation extract a tensor from another tensor as @@ -3489,7 +3326,7 @@ // SubTensorInsertOp //===----------------------------------------------------------------------===// -def SubTensorInsertOp : BaseOpWithOffsetSizesAndStrides<"subtensor_insert"> { +def SubTensorInsertOp : BaseOpWithOffsetSizesAndStrides<"subtensor_insert", [OffsetSizeAndStrideOpInterface]> { let summary = "subtensor_insert operation"; let description = [{ The "subtensor_insert" operation insert a tensor `source` into another diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h --- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h +++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h @@ -13,7 +13,20 @@ #ifndef MLIR_INTERFACES_VIEWLIKEINTERFACE_H_ #define MLIR_INTERFACES_VIEWLIKEINTERFACE_H_ +#include "mlir/IR/Builders.h" #include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" + +namespace mlir { +/// Auxiliary range data structure to unpack the offset, size and stride +/// operands into a list of triples. Such a list can be more convenient to +/// manipulate. +struct Range { + Value offset; + Value size; + Value stride; +}; +} // namespace mlir /// Include the generated interface declarations. 
#include "mlir/Interfaces/ViewLikeInterface.h.inc" diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.td b/mlir/include/mlir/Interfaces/ViewLikeInterface.td --- a/mlir/include/mlir/Interfaces/ViewLikeInterface.td +++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.td @@ -30,4 +30,334 @@ ]; } +def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface"> { + let description = [{ + Common interface for ops that allow specifying mixed dynamic and static + offsets, sizes and strides variadic operands. + Ops that implement this interface need to expose the following methods: + 1. `getSourceRank` to specify the common length of static integer + attributes. + 2. `offsets`, `sizes` and `strides` variadic operands of length each at + most `getSourceRank()`. + 3. `static_offsets`, `static_sizes` and `static_strides` integer array + attributes of length exactly `getSourceRank()`. + + The invariants of this interface are: + 1. `static_offsets`, `static_sizes` and `static_strides` have length + exactly `getSourceRank()`. + 2. `offsets`, `sizes` and `strides` have each length + at most `getSourceRank()`. + 3. if an entry of `static_offsets` (resp. `static_sizes`, + `static_strides`) is equal to a special sentinel value, namely + `ShapedType::kDynamicStrideOrOffset` (resp. `ShapedType::kDynamicSize`, + `ShapedType::kDynamicStrideOrOffset`), then the corresponding entry is + a dynamic offset (resp. size, stride). + 4. a variadic `offsets` (resp. `sizes`, `strides`) operand must be present + for each dynamic offset (resp. size, stride). + + This interface is useful to factor out common behavior and provide support + for carrying or injecting static behavior through the use of the static + attributes. + }]; + + let cppNamespace = "::mlir"; + + let methods = [ + InterfaceMethod< + /*desc=*/[{ + Return the rank, i.e. the length of the static offsets/sizes/strides. 
+ }], + /*retTy=*/"unsigned", + /*methodName=*/"getRank", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return $_op.getSourceRank(); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Return the dynamic offset operands. + }], + /*retTy=*/"OperandRange", + /*methodName=*/"offsets", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return $_op.offsets(); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Return the dynamic size operands. + }], + /*retTy=*/"OperandRange", + /*methodName=*/"sizes", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return $_op.sizes(); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Return the dynamic stride operands. + }], + /*retTy=*/"OperandRange", + /*methodName=*/"strides", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return $_op.strides(); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Return the static offset attributes. + }], + /*retTy=*/"ArrayAttr", + /*methodName=*/"static_offsets", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return $_op.static_offsets(); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Return the static size attributes. + }], + /*retTy=*/"ArrayAttr", + /*methodName=*/"static_sizes", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return $_op.static_sizes(); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Return the static stride attributes. + }], + /*retTy=*/"ArrayAttr", + /*methodName=*/"static_strides", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return $_op.static_strides(); + }] + >, + + InterfaceMethod< + /*desc=*/[{ + Return true if the offset `idx` is dynamic. 
+ }], + /*retTy=*/"bool", + /*methodName=*/"isDynamicOffset", + /*args=*/(ins "unsigned":$idx), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + APInt v = *(static_offsets() + .template getAsValueRange().begin() + idx); + return ShapedType::isDynamicStrideOrOffset(v.getSExtValue()); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Return true if the size `idx` is dynamic. + }], + /*retTy=*/"bool", + /*methodName=*/"isDynamicSize", + /*args=*/(ins "unsigned":$idx), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + APInt v = *(static_sizes() + .template getAsValueRange().begin() + idx); + return ShapedType::isDynamic(v.getSExtValue()); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Return true if the stride `idx` is dynamic. + }], + /*retTy=*/"bool", + /*methodName=*/"isDynamicStride", + /*args=*/(ins "unsigned":$idx), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + APInt v = *(static_strides() + .template getAsValueRange().begin() + idx); + return ShapedType::isDynamicStrideOrOffset(v.getSExtValue()); + }] + >, + + InterfaceMethod< + /*desc=*/[{ + Assert the offset `idx` is a static constant and return its value. + }], + /*retTy=*/"int64_t", + /*methodName=*/"getStaticOffset", + /*args=*/(ins "unsigned":$idx), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + assert(!$_op.isDynamicOffset(idx) && "expected static offset"); + APInt v = *(static_offsets(). + template getAsValueRange().begin() + idx); + return v.getSExtValue(); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Assert the size `idx` is a static constant and return its value. + }], + /*retTy=*/"int64_t", + /*methodName=*/"getStaticSize", + /*args=*/(ins "unsigned":$idx), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + assert(!$_op.isDynamicSize(idx) && "expected static size"); + APInt v = *(static_sizes(). + template getAsValueRange().begin() + idx); + return v.getSExtValue(); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Assert the stride `idx` is a static constant and return its value. 
+ }], + /*retTy=*/"int64_t", + /*methodName=*/"getStaticStride", + /*args=*/(ins "unsigned":$idx), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + assert(!$_op.isDynamicStride(idx) && "expected static stride"); + APInt v = *(static_strides(). + template getAsValueRange().begin() + idx); + return v.getSExtValue(); + }] + >, + + InterfaceMethod< + /*desc=*/[{ + Assert the offset `idx` is dynamic and return the position of the + corresponding operand in the op's full operand list. + }], + /*retTy=*/"unsigned", + /*methodName=*/"getIndexOfDynamicOffset", + /*args=*/(ins "unsigned":$idx), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + assert($_op.isDynamicOffset(idx) && "expected dynamic offset"); + auto numDynamic = getNumDynamicEntriesUpToIdx( + static_offsets().template cast(), + ShapedType::isDynamicStrideOrOffset, + idx); + return $_op.offsets().getBeginOperandIndex() + numDynamic; + }] + >, + InterfaceMethod< + /*desc=*/[{ + Assert the size `idx` is dynamic and return the position of the + corresponding operand in the op's full operand list. + }], + /*retTy=*/"unsigned", + /*methodName=*/"getIndexOfDynamicSize", + /*args=*/(ins "unsigned":$idx), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + assert($_op.isDynamicSize(idx) && "expected dynamic size"); + auto numDynamic = getNumDynamicEntriesUpToIdx( + static_sizes().template cast(), ShapedType::isDynamic, idx); + return $_op.sizes().getBeginOperandIndex() + numDynamic; + }] + >, + InterfaceMethod< + /*desc=*/[{ + Assert the stride `idx` is dynamic and return the position of the + corresponding operand in the op's full operand list. 
+ }], + /*retTy=*/"unsigned", + /*methodName=*/"getIndexOfDynamicStride", + /*args=*/(ins "unsigned":$idx), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + assert($_op.isDynamicStride(idx) && "expected dynamic stride"); + auto numDynamic = getNumDynamicEntriesUpToIdx( + static_strides().template cast(), + ShapedType::isDynamicStrideOrOffset, + idx); + return $_op.strides().getBeginOperandIndex() + numDynamic; + }] + >, + InterfaceMethod< + /*desc=*/[{ + Helper method to compute the number of dynamic entries of `attr`, up to + `idx` using `isDynamic` to determine whether an entry is dynamic. + }], + /*retTy=*/"unsigned", + /*methodName=*/"getNumDynamicEntriesUpToIdx", + /*args=*/(ins "ArrayAttr":$attr, + "llvm::function_ref":$isDynamic, + "unsigned":$idx), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return std::count_if( + attr.getValue().begin(), attr.getValue().begin() + idx, + [&](Attribute attr) { + return isDynamic(attr.cast().getInt()); + }); + }] + >, + + InterfaceMethod< + /*desc=*/[{ + Assert the offset `idx` is dynamic and return its value. + }], + /*retTy=*/"Value", + /*methodName=*/"getDynamicOffset", + /*args=*/(ins "unsigned":$idx), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return $_op.getOperand(getIndexOfDynamicOffset(idx)); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Assert the size `idx` is dynamic and return its value. + }], + /*retTy=*/"Value", + /*methodName=*/"getDynamicSize", + /*args=*/(ins "unsigned":$idx), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return $_op.getOperand(getIndexOfDynamicSize(idx)); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Assert the stride `idx` is dynamic and return its value. 
+ }], + /*retTy=*/"Value", + /*methodName=*/"getDynamicStride", + /*args=*/(ins "unsigned":$idx), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return $_op.getOperand(getIndexOfDynamicStride(idx)); + }] + >, + ]; + + let extraClassDeclaration = [{ + static StringRef getStaticOffsetsAttrName() { + return "static_offsets"; + } + static StringRef getStaticSizesAttrName() { + return "static_sizes"; + } + static StringRef getStaticStridesAttrName() { + return "static_strides"; + } + }]; +} + #endif // MLIR_INTERFACES_VIEWLIKEINTERFACE diff --git a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp @@ -114,19 +114,16 @@ // TODO: Aborting when the offsets are static. There might be a way to fold // the subview op with load even if the offsets have been canonicalized // away. - SmallVector opOffsets = subViewOp.getOrCreateOffsets(rewriter, loc); - SmallVector opStrides = subViewOp.getOrCreateStrides(rewriter, loc); - assert(opOffsets.size() == indices.size() && - "expected as many indices as rank of subview op result type"); - assert(opStrides.size() == indices.size() && + SmallVector opRanges = subViewOp.getOrCreateRanges(rewriter, loc); + assert(opRanges.size() == indices.size() && "expected as many indices as rank of subview op result type"); // New indices for the load are the current indices * subview_stride + // subview_offset. 
sourceIndices.resize(indices.size()); for (auto index : llvm::enumerate(indices)) { - auto offset = opOffsets[index.index()]; - auto stride = opStrides[index.index()]; + auto offset = opRanges[index.index()].offset; + auto stride = opRanges[index.index()].stride; auto mul = rewriter.create(loc, index.value(), stride); sourceIndices[index.index()] = rewriter.create(loc, offset, mul).getResult(); diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -320,11 +320,10 @@ } /// Verify that a particular offset/size/stride static attribute is well-formed. -template static LogicalResult verifyOpWithOffsetSizesAndStridesPart( - OpType op, StringRef name, unsigned expectedNumElements, StringRef attrName, - ArrayAttr attr, llvm::function_ref isDynamic, - ValueRange values) { + OffsetSizeAndStrideOpInterface op, StringRef name, + unsigned expectedNumElements, StringRef attrName, ArrayAttr attr, + llvm::function_ref isDynamic, ValueRange values) { /// Check static and dynamic offsets/sizes/strides breakdown. if (attr.size() != expectedNumElements) return op.emitError("expected ") @@ -347,20 +347,23 @@ } /// Verify static attributes offsets/sizes/strides. 
-template -static LogicalResult verifyOpWithOffsetSizesAndStrides(OpType op) { - unsigned srcRank = op.getSourceRank(); +static LogicalResult +verifyOpWithOffsetSizesAndStrides(OffsetSizeAndStrideOpInterface op, + unsigned srcRank) { if (failed(verifyOpWithOffsetSizesAndStridesPart( - op, "offset", srcRank, op.getStaticOffsetsAttrName(), + op, "offset", srcRank, + OffsetSizeAndStrideOpInterface::getStaticOffsetsAttrName(), op.static_offsets(), ShapedType::isDynamicStrideOrOffset, op.offsets()))) return failure(); if (failed(verifyOpWithOffsetSizesAndStridesPart( - op, "size", srcRank, op.getStaticSizesAttrName(), op.static_sizes(), - ShapedType::isDynamic, op.sizes()))) + op, "size", srcRank, + OffsetSizeAndStrideOpInterface::getStaticSizesAttrName(), + op.static_sizes(), ShapedType::isDynamic, op.sizes()))) return failure(); if (failed(verifyOpWithOffsetSizesAndStridesPart( - op, "stride", srcRank, op.getStaticStridesAttrName(), + op, "stride", srcRank, + OffsetSizeAndStrideOpInterface::getStaticStridesAttrName(), op.static_strides(), ShapedType::isDynamicStrideOrOffset, op.strides()))) return failure(); @@ -2481,10 +2483,7 @@ ShapedType::isDynamicStrideOrOffset); p.printOptionalAttrDict( op.getAttrs(), - /*elidedAttrs=*/{MemRefReinterpretCastOp::getOperandSegmentSizeAttr(), - MemRefReinterpretCastOp::getStaticOffsetsAttrName(), - MemRefReinterpretCastOp::getStaticSizesAttrName(), - MemRefReinterpretCastOp::getStaticStridesAttrName()}); + /*elidedAttrs=*/MemRefReinterpretCastOp::getSpecialAttrNames()); p << ": " << op.source().getType() << " to " << op.getType(); } @@ -2508,7 +2507,8 @@ if (parser.parseKeyword("to") || parser.parseKeyword("offset") || parser.parseColon() || parseListOfOperandsOrIntegers( - parser, result, MemRefReinterpretCastOp::getStaticOffsetsAttrName(), + parser, result, + OffsetSizeAndStrideOpInterface::getStaticOffsetsAttrName(), ShapedType::kDynamicStrideOrOffset, offset) || parser.parseComma()) return failure(); @@ -2517,7 +2517,8 @@ SmallVector sizes; if 
(parser.parseKeyword("sizes") || parser.parseColon() || parseListOfOperandsOrIntegers( - parser, result, MemRefReinterpretCastOp::getStaticSizesAttrName(), + parser, result, + OffsetSizeAndStrideOpInterface::getStaticSizesAttrName(), ShapedType::kDynamicSize, sizes) || parser.parseComma()) return failure(); @@ -2526,7 +2527,8 @@ SmallVector strides; if (parser.parseKeyword("strides") || parser.parseColon() || parseListOfOperandsOrIntegers( - parser, result, MemRefReinterpretCastOp::getStaticStridesAttrName(), + parser, result, + OffsetSizeAndStrideOpInterface::getStaticStridesAttrName(), ShapedType::kDynamicStrideOrOffset, strides)) return failure(); @@ -2567,16 +2569,20 @@ // Verify that dynamic and static offset/sizes/strides arguments/attributes // are consistent. if (failed(verifyOpWithOffsetSizesAndStridesPart( - op, "offset", 1, op.getStaticOffsetsAttrName(), op.static_offsets(), - ShapedType::isDynamicStrideOrOffset, op.offsets()))) + op, "offset", 1, + OffsetSizeAndStrideOpInterface::getStaticOffsetsAttrName(), + op.static_offsets(), ShapedType::isDynamicStrideOrOffset, + op.offsets()))) return failure(); - unsigned resultRank = op.getResultRank(); + unsigned resultRank = op.getResult().getType().cast().getRank(); if (failed(verifyOpWithOffsetSizesAndStridesPart( - op, "size", resultRank, op.getStaticSizesAttrName(), + op, "size", resultRank, + OffsetSizeAndStrideOpInterface::getStaticSizesAttrName(), op.static_sizes(), ShapedType::isDynamic, op.sizes()))) return failure(); if (failed(verifyOpWithOffsetSizesAndStridesPart( - op, "stride", resultRank, op.getStaticStridesAttrName(), + op, "stride", resultRank, + OffsetSizeAndStrideOpInterface::getStaticStridesAttrName(), op.static_strides(), ShapedType::isDynamicStrideOrOffset, op.strides()))) return failure(); @@ -3289,13 +3295,16 @@ if (parseExtraOperand && parseExtraOperand(parser, dstInfo)) return failure(); if (parseListOfOperandsOrIntegers( - parser, result, OpType::getStaticOffsetsAttrName(), + parser, 
result, + OffsetSizeAndStrideOpInterface::getStaticOffsetsAttrName(), ShapedType::kDynamicStrideOrOffset, offsetsInfo) || - parseListOfOperandsOrIntegers(parser, result, - OpType::getStaticSizesAttrName(), - ShapedType::kDynamicSize, sizesInfo) || parseListOfOperandsOrIntegers( - parser, result, OpType::getStaticStridesAttrName(), + parser, result, + OffsetSizeAndStrideOpInterface::getStaticSizesAttrName(), + ShapedType::kDynamicSize, sizesInfo) || + parseListOfOperandsOrIntegers( + parser, result, + OffsetSizeAndStrideOpInterface::getStaticStridesAttrName(), ShapedType::kDynamicStrideOrOffset, stridesInfo)) return failure(); @@ -3532,7 +3541,6 @@ llvm_unreachable("unexpected subview verification result"); } - /// Verifier for SubViewOp. static LogicalResult verify(SubViewOp op) { MemRefType baseType = op.getSourceType(); @@ -3548,7 +3556,7 @@ if (!isStrided(baseType)) return op.emitError("base type ") << baseType << " is not strided"; - if (failed(verifyOpWithOffsetSizesAndStrides(op))) + if (failed(verifyOpWithOffsetSizesAndStrides(op, op.getSourceRank()))) return failure(); // Verify result type against inferred type. @@ -3569,11 +3577,10 @@ /// Return the list of Range (i.e. offset, size, stride). Each Range /// entry contains either the dynamic value or a ConstantIndexOp constructed /// with `b` at location `loc`. 
-template -static SmallVector getOrCreateRangesImpl(OpType op, OpBuilder &b, - Location loc) { +SmallVector mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op, + OpBuilder &b, Location loc) { SmallVector res; - unsigned rank = op.getSourceRank(); + unsigned rank = op.getRank(); res.reserve(rank); for (unsigned idx = 0; idx < rank; ++idx) { Value offset = @@ -3592,10 +3599,6 @@ return res; } -SmallVector SubViewOp::getOrCreateRanges(OpBuilder &b, Location loc) { - return ::getOrCreateRangesImpl(*this, b, loc); -} - namespace { /// Take a list of `values` with potential new constant to extract and a list @@ -3890,7 +3893,8 @@ } OpFoldResult SubViewOp::fold(ArrayRef operands) { - if (getResultRank() == 0 && getSourceRank() == 0) + if (getResult().getType().cast().getRank() == 0 && + source().getType().cast().getRank() == 0) return getViewSource(); return {}; @@ -3961,14 +3965,9 @@ staticStridesVector, offsets, sizes, strides, attrs); } -SmallVector SubTensorOp::getOrCreateRanges(OpBuilder &b, - Location loc) { - return ::getOrCreateRangesImpl(*this, b, loc); -} - /// Verifier for SubTensorOp. static LogicalResult verify(SubTensorOp op) { - if (failed(verifyOpWithOffsetSizesAndStrides(op))) + if (failed(verifyOpWithOffsetSizesAndStrides(op, op.getSourceRank()))) return failure(); // Verify result type against inferred type. @@ -4039,14 +4038,9 @@ staticStridesVector, offsets, sizes, strides, attrs); } -SmallVector SubTensorInsertOp::getOrCreateRanges(OpBuilder &b, - Location loc) { - return ::getOrCreateRangesImpl(*this, b, loc); -} - /// Verifier for SubViewOp. static LogicalResult verify(SubTensorInsertOp op) { - if (failed(verifyOpWithOffsetSizesAndStrides(op))) + if (failed(verifyOpWithOffsetSizesAndStrides(op, op.getSourceRank()))) return failure(); if (op.getType() != op.dest().getType()) return op.emitError("expected result type to be ") << op.dest().getType();