diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h @@ -33,17 +33,14 @@ class FuncOp; class OpBuilder; -/// Auxiliary range data structure to unpack the offset, size and stride -/// operands of the SubViewOp / SubTensorOp into a list of triples. -/// Such a list of triple is sometimes more convenient to manipulate. -struct Range { - Value offset; - Value size; - Value stride; -}; - raw_ostream &operator<<(raw_ostream &os, Range &range); +/// Return the list of Range (i.e. offset, size, stride). Each Range +/// entry contains either the dynamic value or a ConstantIndexOp constructed +/// with `b` at location `loc`. +SmallVector getOrCreateRanges(OffsetSizeAndStrideOpInterface op, + OpBuilder &b, Location loc); + #define GET_OP_CLASSES #include "mlir/Dialect/StandardOps/IR/Ops.h.inc" diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -223,81 +223,14 @@ Std_Op { code extraBaseClassDeclaration = [{ - /// Returns the number of dynamic offset operands. - int64_t getNumOffsets() { return llvm::size(offsets()); } - - /// Returns the number of dynamic size operands. - int64_t getNumSizes() { return llvm::size(sizes()); } - - /// Returns the number of dynamic stride operands. - int64_t getNumStrides() { return llvm::size(strides()); } - /// Returns the dynamic sizes for this subview operation if specified. operand_range getDynamicSizes() { return sizes(); } - /// Returns in `staticStrides` the static value of the stride - /// operands. Returns failure() if the static value of the stride - /// operands could not be retrieved.
- LogicalResult getStaticStrides(SmallVectorImpl &staticStrides) { - if (!strides().empty()) - return failure(); - staticStrides.reserve(static_strides().size()); - for (auto s : static_strides().getAsValueRange()) - staticStrides.push_back(s.getZExtValue()); - return success(); - } - /// Return the list of Range (i.e. offset, size, stride). Each /// Range entry contains either the dynamic value or a ConstantIndexOp /// constructed with `b` at location `loc`. - SmallVector getOrCreateRanges(OpBuilder &b, Location loc); - - /// Return the offsets as Values. Each Value is either the dynamic - /// value specified in the op or a ConstantIndexOp constructed - /// with `b` at location `loc` - SmallVector getOrCreateOffsets(OpBuilder &b, Location loc) { - unsigned dynamicIdx = 1; - return llvm::to_vector<4>(llvm::map_range( - static_offsets().cast(), [&](Attribute a) -> Value { - int64_t staticOffset = a.cast().getInt(); - if (ShapedType::isDynamicStrideOrOffset(staticOffset)) - return getOperand(dynamicIdx++); - else - return b.create( - loc, b.getIndexType(), b.getIndexAttr(staticOffset)); - })); - } - - /// Return the sizes as Values. Each Value is either the dynamic - /// value specified in the op or a ConstantIndexOp constructed - /// with `b` at location `loc` - SmallVector getOrCreateSizes(OpBuilder &b, Location loc) { - unsigned dynamicIdx = 1 + offsets().size(); - return llvm::to_vector<4>(llvm::map_range( - static_sizes().cast(), [&](Attribute a) -> Value { - int64_t staticSize = a.cast().getInt(); - if (ShapedType::isDynamic(staticSize)) - return getOperand(dynamicIdx++); - else - return b.create( - loc, b.getIndexType(), b.getIndexAttr(staticSize)); - })); - } - - /// Return the strides as Values.
Each Value is either the dynamic - /// value specified in the op or a ConstantIndexOp constructed with - /// `b` at location `loc` - SmallVector getOrCreateStrides(OpBuilder &b, Location loc) { - unsigned dynamicIdx = 1 + offsets().size() + sizes().size(); - return llvm::to_vector<4>(llvm::map_range( - static_strides().cast(), [&](Attribute a) -> Value { - int64_t staticStride = a.cast().getInt(); - if (ShapedType::isDynamicStrideOrOffset(staticStride)) - return getOperand(dynamicIdx++); - else - return b.create( - loc, b.getIndexType(), b.getIndexAttr(staticStride)); - })); + SmallVector getOrCreateRanges(OpBuilder &b, Location loc) { + return mlir::getOrCreateRanges(*this, b, loc); } /// Return the rank of the source ShapedType. @@ -305,107 +238,11 @@ return source().getType().cast().getRank(); } - /// Return the rank of the result ShapedType. - unsigned getResultRank() { return getType().getRank(); } - - /// Return true if the offset `idx` is a static constant. - bool isDynamicOffset(unsigned idx) { - APInt v = *(static_offsets().getAsValueRange().begin() + idx); - return ShapedType::isDynamicStrideOrOffset(v.getSExtValue()); - } - /// Return true if the size `idx` is a static constant. - bool isDynamicSize(unsigned idx) { - APInt v = *(static_sizes().getAsValueRange().begin() + idx); - return ShapedType::isDynamic(v.getSExtValue()); - } - - /// Return true if the stride `idx` is a static constant. - bool isDynamicStride(unsigned idx) { - APInt v = *(static_strides().getAsValueRange().begin() + idx); - return ShapedType::isDynamicStrideOrOffset(v.getSExtValue()); - } - - /// Assert the offset `idx` is a static constant and return its value. - int64_t getStaticOffset(unsigned idx) { - assert(!isDynamicOffset(idx) && "expected static offset"); - APInt v = *(static_offsets().getAsValueRange().begin() + idx); - return v.getSExtValue(); - } - /// Assert the size `idx` is a static constant and return its value.
- int64_t getStaticSize(unsigned idx) { - assert(!isDynamicSize(idx) && "expected static size"); - APInt v = *(static_sizes().getAsValueRange().begin() + idx); - return v.getSExtValue(); - } - /// Assert the stride `idx` is a static constant and return its value. - int64_t getStaticStride(unsigned idx) { - assert(!isDynamicStride(idx) && "expected static stride"); - APInt v = *(static_strides().getAsValueRange().begin() + idx); - return v.getSExtValue(); - } - - unsigned getNumDynamicEntriesUpToIdx(ArrayAttr attr, - llvm::function_ref isDynamic, unsigned idx) { - return std::count_if( - attr.getValue().begin(), attr.getValue().begin() + idx, - [&](Attribute attr) { - return isDynamic(attr.cast().getInt()); - }); - } - /// Assert the offset `idx` is dynamic and return the position of the - /// corresponding operand. - unsigned getIndexOfDynamicOffset(unsigned idx) { - assert(isDynamicOffset(idx) && "expected static offset"); - auto numDynamic = - getNumDynamicEntriesUpToIdx(static_offsets().cast(), - ShapedType::isDynamicStrideOrOffset, idx); - return 1 + numDynamic; - } - /// Assert the size `idx` is dynamic and return the position of the - /// corresponding operand. - unsigned getIndexOfDynamicSize(unsigned idx) { - assert(isDynamicSize(idx) && "expected static size"); - auto numDynamic = getNumDynamicEntriesUpToIdx( - static_sizes().cast(), ShapedType::isDynamic, idx); - return 1 + offsets().size() + numDynamic; - } - /// Assert the stride `idx` is dynamic and return the position of the - /// corresponding operand. - unsigned getIndexOfDynamicStride(unsigned idx) { - assert(isDynamicStride(idx) && "expected static stride"); - auto numDynamic = - getNumDynamicEntriesUpToIdx(static_strides().cast(), - ShapedType::isDynamicStrideOrOffset, idx); - return 1 + offsets().size() + sizes().size() + numDynamic; - } - - /// Assert the offset `idx` is dynamic and return its value.
- Value getDynamicOffset(unsigned idx) { - return getOperand(getIndexOfDynamicOffset(idx)); - } - /// Assert the size `idx` is dynamic and return its value. - Value getDynamicSize(unsigned idx) { - return getOperand(getIndexOfDynamicSize(idx)); - } - /// Assert the stride `idx` is dynamic and return its value. - Value getDynamicStride(unsigned idx) { - return getOperand(getIndexOfDynamicStride(idx)); - } - - static StringRef getStaticOffsetsAttrName() { - return "static_offsets"; - } - static StringRef getStaticSizesAttrName() { - return "static_sizes"; - } - static StringRef getStaticStridesAttrName() { - return "static_strides"; - } static ArrayRef getSpecialAttrNames() { static SmallVector names{ - getStaticOffsetsAttrName(), - getStaticSizesAttrName(), - getStaticStridesAttrName(), + OffsetSizeAndStrideOpInterface::getStaticOffsetsAttrName(), + OffsetSizeAndStrideOpInterface::getStaticSizesAttrName(), + OffsetSizeAndStrideOpInterface::getStaticStridesAttrName(), getOperandSegmentSizeAttr()}; return names; } @@ -2338,7 +2175,7 @@ def MemRefReinterpretCastOp: BaseOpWithOffsetSizesAndStrides<"memref_reinterpret_cast", [ - NoSideEffect, ViewLikeOpInterface + NoSideEffect, ViewLikeOpInterface, OffsetSizeAndStrideOpInterface ]> { let summary = "memref reinterpret cast operation"; let description = [{ @@ -3208,7 +3045,7 @@ //===----------------------------------------------------------------------===// def SubViewOp : BaseOpWithOffsetSizesAndStrides< - "subview", [DeclareOpInterfaceMethods] > { + "subview", [DeclareOpInterfaceMethods, OffsetSizeAndStrideOpInterface] > { let summary = "memref subview operation"; let description = [{ The "subview" operation converts a memref type to another memref type @@ -3397,7 +3234,7 @@ // SubTensorOp //===----------------------------------------------------------------------===// -def SubTensorOp : BaseOpWithOffsetSizesAndStrides<"subtensor"> { +def SubTensorOp : BaseOpWithOffsetSizesAndStrides<"subtensor",
[OffsetSizeAndStrideOpInterface]> { let summary = "subtensor operation"; let description = [{ The "subtensor" operation extract a tensor from another tensor as @@ -3487,7 +3324,7 @@ // SubTensorInsertOp //===----------------------------------------------------------------------===// -def SubTensorInsertOp : BaseOpWithOffsetSizesAndStrides<"subtensor_insert"> { +def SubTensorInsertOp : BaseOpWithOffsetSizesAndStrides<"subtensor_insert", [OffsetSizeAndStrideOpInterface]> { let summary = "subtensor_insert operation"; let description = [{ The "subtensor_insert" operation insert a tensor `source` into another diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h --- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h +++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h @@ -13,7 +13,20 @@ #ifndef MLIR_INTERFACES_VIEWLIKEINTERFACE_H_ #define MLIR_INTERFACES_VIEWLIKEINTERFACE_H_ +#include "mlir/IR/Builders.h" #include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" + +namespace mlir { +/// Auxiliary range data structure to unpack the offset, size and stride +/// operands of the SubViewOp / SubTensorOp into a list of triples. +/// Such a list of triple is sometimes more convenient to manipulate. +struct Range { + Value offset; + Value size; + Value stride; +}; +} // namespace mlir /// Include the generated interface declarations. #include "mlir/Interfaces/ViewLikeInterface.h.inc" diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.td b/mlir/include/mlir/Interfaces/ViewLikeInterface.td --- a/mlir/include/mlir/Interfaces/ViewLikeInterface.td +++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.td @@ -30,4 +30,298 @@ ]; } +def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface"> { + let description = [{ + Common interface for ops that carry offsets, sizes and strides, each entry being either a dynamic operand or a static value stored in the `static_offsets` / `static_sizes` / `static_strides` attribute. + }]; + let cppNamespace = "::mlir"; + let methods = [ + InterfaceMethod< + /*desc=*/[{ + Return the rank of the operation (by default, the rank of its source).
+ }], + /*retTy=*/"unsigned", + /*methodName=*/"getRank", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return $_op.getSourceRank(); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Return the dynamic offset operands. + }], + /*retTy=*/"OperandRange", + /*methodName=*/"offsets", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return $_op.offsets(); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Return the dynamic size operands. + }], + /*retTy=*/"OperandRange", + /*methodName=*/"sizes", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return $_op.sizes(); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Return the dynamic stride operands. + }], + /*retTy=*/"OperandRange", + /*methodName=*/"strides", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return $_op.strides(); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Return the static offsets attribute. + }], + /*retTy=*/"ArrayAttr", + /*methodName=*/"static_offsets", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return $_op.static_offsets(); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Return the static sizes attribute. + }], + /*retTy=*/"ArrayAttr", + /*methodName=*/"static_sizes", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return $_op.static_sizes(); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Return the static strides attribute. + }], + /*retTy=*/"ArrayAttr", + /*methodName=*/"static_strides", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return $_op.static_strides(); + }] + >, + + InterfaceMethod< + /*desc=*/[{ + Return true if the offset `idx` is dynamic, i.e. not a static constant.
+ }], + /*retTy=*/"bool", + /*methodName=*/"isDynamicOffset", + /*args=*/(ins "unsigned":$idx), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + APInt v = *(static_offsets().template getAsValueRange().begin() + idx); + return ShapedType::isDynamicStrideOrOffset(v.getSExtValue()); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Return true if the size `idx` is dynamic, i.e. not a static constant. + }], + /*retTy=*/"bool", + /*methodName=*/"isDynamicSize", + /*args=*/(ins "unsigned":$idx), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + APInt v = *(static_sizes().template getAsValueRange().begin() + idx); + return ShapedType::isDynamic(v.getSExtValue()); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Return true if the stride `idx` is dynamic, i.e. not a static constant. + }], + /*retTy=*/"bool", + /*methodName=*/"isDynamicStride", + /*args=*/(ins "unsigned":$idx), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + APInt v = *(static_strides().template getAsValueRange().begin() + idx); + return ShapedType::isDynamicStrideOrOffset(v.getSExtValue()); + }] + >, + + InterfaceMethod< + /*desc=*/[{ + Assert the offset `idx` is a static constant and return its value. + }], + /*retTy=*/"int64_t", + /*methodName=*/"getStaticOffset", + /*args=*/(ins "unsigned":$idx), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + assert(!$_op.isDynamicOffset(idx) && "expected static offset"); + APInt v = *(static_offsets().template getAsValueRange().begin() + idx); + return v.getSExtValue(); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Assert the size `idx` is a static constant and return its value. + }], + /*retTy=*/"int64_t", + /*methodName=*/"getStaticSize", + /*args=*/(ins "unsigned":$idx), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + assert(!$_op.isDynamicSize(idx) && "expected static size"); + APInt v = *(static_sizes().template getAsValueRange().begin() + idx); + return v.getSExtValue(); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Assert the stride `idx` is a static constant and return its value.
+ }], + /*retTy=*/"int64_t", + /*methodName=*/"getStaticStride", + /*args=*/(ins "unsigned":$idx), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + assert(!$_op.isDynamicStride(idx) && "expected static stride"); + APInt v = *(static_strides().template getAsValueRange().begin() + idx); + return v.getSExtValue(); + }] + >, + + InterfaceMethod< + /*desc=*/[{ + Assert the offset `idx` is dynamic and return the position of the + corresponding operand. + }], + /*retTy=*/"unsigned", + /*methodName=*/"getIndexOfDynamicOffset", + /*args=*/(ins "unsigned":$idx), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + assert($_op.isDynamicOffset(idx) && "expected dynamic offset"); + auto numDynamic = + getNumDynamicEntriesUpToIdx(static_offsets().template cast(), + ShapedType::isDynamicStrideOrOffset, idx); + return $_op.offsets().getBeginOperandIndex() + numDynamic; + }] + >, + InterfaceMethod< + /*desc=*/[{ + Assert the size `idx` is dynamic and return the position of the + corresponding operand. + }], + /*retTy=*/"unsigned", + /*methodName=*/"getIndexOfDynamicSize", + /*args=*/(ins "unsigned":$idx), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + assert($_op.isDynamicSize(idx) && "expected dynamic size"); + auto numDynamic = getNumDynamicEntriesUpToIdx( + static_sizes().template cast(), ShapedType::isDynamic, idx); + return $_op.sizes().getBeginOperandIndex() + numDynamic; + }] + >, + InterfaceMethod< + /*desc=*/[{ + Assert the stride `idx` is dynamic and return the position of the + corresponding operand. + }], + /*retTy=*/"unsigned", + /*methodName=*/"getIndexOfDynamicStride", + /*args=*/(ins "unsigned":$idx), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + assert($_op.isDynamicStride(idx) && "expected dynamic stride"); + auto numDynamic = + getNumDynamicEntriesUpToIdx(static_strides().template cast(), + ShapedType::isDynamicStrideOrOffset, idx); + return $_op.strides().getBeginOperandIndex() + numDynamic; + }] + >, + InterfaceMethod< + /*desc=*/[{ + Return how many of the first `idx` entries of `attr` satisfy `isDynamic`.
+ }], + /*retTy=*/"unsigned", + /*methodName=*/"getNumDynamicEntriesUpToIdx", + /*args=*/(ins "ArrayAttr":$attr, + "llvm::function_ref":$isDynamic, + "unsigned":$idx), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return std::count_if( + attr.getValue().begin(), attr.getValue().begin() + idx, + [&](Attribute attr) { + return isDynamic(attr.cast().getInt()); + }); + }] + >, + + InterfaceMethod< + /*desc=*/[{ + Assert the offset `idx` is dynamic and return its value. + }], + /*retTy=*/"Value", + /*methodName=*/"getDynamicOffset", + /*args=*/(ins "unsigned":$idx), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return $_op.getOperand(getIndexOfDynamicOffset(idx)); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Assert the size `idx` is dynamic and return its value. + }], + /*retTy=*/"Value", + /*methodName=*/"getDynamicSize", + /*args=*/(ins "unsigned":$idx), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return $_op.getOperand(getIndexOfDynamicSize(idx)); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Assert the stride `idx` is dynamic and return its value. + }], + /*retTy=*/"Value", + /*methodName=*/"getDynamicStride", + /*args=*/(ins "unsigned":$idx), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return $_op.getOperand(getIndexOfDynamicStride(idx)); + }] + >, + ]; + + let extraClassDeclaration = [{ + static StringRef getStaticOffsetsAttrName() { + return "static_offsets"; + } + static StringRef getStaticSizesAttrName() { + return "static_sizes"; + } + static StringRef getStaticStridesAttrName() { + return "static_strides"; + } + }]; +} + #endif // MLIR_INTERFACES_VIEWLIKEINTERFACE diff --git a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp @@ -114,19 +114,16 @@ // TODO: Aborting when the offsets are static.
There might be a way to fold // the subview op with load even if the offsets have been canonicalized // away. - SmallVector opOffsets = subViewOp.getOrCreateOffsets(rewriter, loc); - SmallVector opStrides = subViewOp.getOrCreateStrides(rewriter, loc); - assert(opOffsets.size() == indices.size() && - "expected as many indices as rank of subview op result type"); - assert(opStrides.size() == indices.size() && + SmallVector opRanges = subViewOp.getOrCreateRanges(rewriter, loc); + assert(opRanges.size() == indices.size() && "expected as many indices as rank of subview op result type"); // New indices for the load are the current indices * subview_stride + // subview_offset. sourceIndices.resize(indices.size()); for (auto index : llvm::enumerate(indices)) { - auto offset = opOffsets[index.index()]; - auto stride = opStrides[index.index()]; + Value offset = opRanges[index.index()].offset; + Value stride = opRanges[index.index()].stride; auto mul = rewriter.create(loc, index.value(), stride); sourceIndices[index.index()] = rewriter.create(loc, offset, mul).getResult(); diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -348,20 +348,23 @@ } /// Verify static attributes offsets/sizes/strides.
-template -static LogicalResult verifyOpWithOffsetSizesAndStrides(OpType op) { - unsigned srcRank = op.getSourceRank(); +static LogicalResult +verifyOpWithOffsetSizesAndStrides(OffsetSizeAndStrideOpInterface op, + unsigned rank) { if (failed(verifyOpWithOffsetSizesAndStridesPart( - op, "offset", srcRank, op.getStaticOffsetsAttrName(), + op, "offset", rank, + OffsetSizeAndStrideOpInterface::getStaticOffsetsAttrName(), op.static_offsets(), ShapedType::isDynamicStrideOrOffset, op.offsets()))) return failure(); if (failed(verifyOpWithOffsetSizesAndStridesPart( - op, "size", srcRank, op.getStaticSizesAttrName(), op.static_sizes(), - ShapedType::isDynamic, op.sizes()))) + op, "size", rank, + OffsetSizeAndStrideOpInterface::getStaticSizesAttrName(), + op.static_sizes(), ShapedType::isDynamic, op.sizes()))) return failure(); if (failed(verifyOpWithOffsetSizesAndStridesPart( - op, "stride", srcRank, op.getStaticStridesAttrName(), + op, "stride", rank, + OffsetSizeAndStrideOpInterface::getStaticStridesAttrName(), op.static_strides(), ShapedType::isDynamicStrideOrOffset, op.strides()))) return failure(); @@ -2405,10 +2408,7 @@ ShapedType::isDynamicStrideOrOffset); p.printOptionalAttrDict( op.getAttrs(), - /*elidedAttrs=*/{MemRefReinterpretCastOp::getOperandSegmentSizeAttr(), - MemRefReinterpretCastOp::getStaticOffsetsAttrName(), - MemRefReinterpretCastOp::getStaticSizesAttrName(), - MemRefReinterpretCastOp::getStaticStridesAttrName()}); + /*elidedAttrs=*/MemRefReinterpretCastOp::getSpecialAttrNames()); p << ": " << op.source().getType() << " to " << op.getType(); } @@ -2432,7 +2432,8 @@ if (parser.parseKeyword("to") || parser.parseKeyword("offset") || parser.parseColon() || parseListOfOperandsOrIntegers( - parser, result, MemRefReinterpretCastOp::getStaticOffsetsAttrName(), + parser, result, + OffsetSizeAndStrideOpInterface::getStaticOffsetsAttrName(), ShapedType::kDynamicStrideOrOffset, offset) || parser.parseComma()) return failure(); @@ -2441,7 +2442,8 @@ SmallVector
sizes; if (parser.parseKeyword("sizes") || parser.parseColon() || parseListOfOperandsOrIntegers( - parser, result, MemRefReinterpretCastOp::getStaticSizesAttrName(), + parser, result, + OffsetSizeAndStrideOpInterface::getStaticSizesAttrName(), ShapedType::kDynamicSize, sizes) || parser.parseComma()) return failure(); @@ -2450,7 +2452,8 @@ SmallVector strides; if (parser.parseKeyword("strides") || parser.parseColon() || parseListOfOperandsOrIntegers( - parser, result, MemRefReinterpretCastOp::getStaticStridesAttrName(), + parser, result, + OffsetSizeAndStrideOpInterface::getStaticStridesAttrName(), ShapedType::kDynamicStrideOrOffset, strides)) return failure(); @@ -2491,16 +2494,20 @@ // Verify that dynamic and static offset/sizes/strides arguments/attributes // are consistent. if (failed(verifyOpWithOffsetSizesAndStridesPart( - op, "offset", 1, op.getStaticOffsetsAttrName(), op.static_offsets(), - ShapedType::isDynamicStrideOrOffset, op.offsets()))) + op, "offset", 1, + OffsetSizeAndStrideOpInterface::getStaticOffsetsAttrName(), + op.static_offsets(), ShapedType::isDynamicStrideOrOffset, + op.offsets()))) return failure(); - unsigned resultRank = op.getResultRank(); + unsigned resultRank = op.getResult().getType().cast().getRank(); if (failed(verifyOpWithOffsetSizesAndStridesPart( - op, "size", resultRank, op.getStaticSizesAttrName(), + op, "size", resultRank, + OffsetSizeAndStrideOpInterface::getStaticSizesAttrName(), op.static_sizes(), ShapedType::isDynamic, op.sizes()))) return failure(); if (failed(verifyOpWithOffsetSizesAndStridesPart( - op, "stride", resultRank, op.getStaticStridesAttrName(), + op, "stride", resultRank, + OffsetSizeAndStrideOpInterface::getStaticStridesAttrName(), op.static_strides(), ShapedType::isDynamicStrideOrOffset, op.strides()))) return failure(); @@ -3213,13 +3220,16 @@ if (parseExtraOperand && parseExtraOperand(parser, dstInfo)) return failure(); if (parseListOfOperandsOrIntegers( - parser, result, OpType::getStaticOffsetsAttrName(),
+ parser, result, + OffsetSizeAndStrideOpInterface::getStaticOffsetsAttrName(), ShapedType::kDynamicStrideOrOffset, offsetsInfo) || - parseListOfOperandsOrIntegers(parser, result, - OpType::getStaticSizesAttrName(), - ShapedType::kDynamicSize, sizesInfo) || parseListOfOperandsOrIntegers( - parser, result, OpType::getStaticStridesAttrName(), + parser, result, + OffsetSizeAndStrideOpInterface::getStaticSizesAttrName(), + ShapedType::kDynamicSize, sizesInfo) || + parseListOfOperandsOrIntegers( + parser, result, + OffsetSizeAndStrideOpInterface::getStaticStridesAttrName(), ShapedType::kDynamicStrideOrOffset, stridesInfo)) return failure(); @@ -3456,7 +3466,6 @@ llvm_unreachable("unexpected subview verification result"); } - /// Verifier for SubViewOp. static LogicalResult verify(SubViewOp op) { MemRefType baseType = op.getSourceType(); @@ -3472,7 +3481,7 @@ if (!isStrided(baseType)) return op.emitError("base type ") << baseType << " is not strided"; - if (failed(verifyOpWithOffsetSizesAndStrides(op))) + if (failed(verifyOpWithOffsetSizesAndStrides(op, op.getSourceRank()))) return failure(); // Verify result type against inferred type. @@ -3493,11 +3502,10 @@ /// Return the list of Range (i.e. offset, size, stride). Each Range /// entry contains either the dynamic value or a ConstantIndexOp constructed /// with `b` at location `loc`.
-template -static SmallVector getOrCreateRangesImpl(OpType op, OpBuilder &b, - Location loc) { +SmallVector mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op, + OpBuilder &b, Location loc) { SmallVector res; - unsigned rank = op.getSourceRank(); + unsigned rank = op.getRank(); res.reserve(rank); for (unsigned idx = 0; idx < rank; ++idx) { Value offset = @@ -3516,10 +3524,6 @@ return res; } -SmallVector SubViewOp::getOrCreateRanges(OpBuilder &b, Location loc) { - return ::getOrCreateRangesImpl(*this, b, loc); -} - namespace { /// Take a list of `values` with potential new constant to extract and a list @@ -3814,7 +3818,8 @@ } OpFoldResult SubViewOp::fold(ArrayRef operands) { - if (getResultRank() == 0 && getSourceRank() == 0) + if (getResult().getType().cast().getRank() == 0 && + source().getType().cast().getRank() == 0) return getViewSource(); return {}; @@ -3885,14 +3890,9 @@ staticStridesVector, offsets, sizes, strides, attrs); } -SmallVector SubTensorOp::getOrCreateRanges(OpBuilder &b, - Location loc) { - return ::getOrCreateRangesImpl(*this, b, loc); -} - /// Verifier for SubTensorOp. static LogicalResult verify(SubTensorOp op) { - if (failed(verifyOpWithOffsetSizesAndStrides(op))) + if (failed(verifyOpWithOffsetSizesAndStrides(op, op.getSourceRank()))) return failure(); // Verify result type against inferred type. @@ -3963,14 +3963,9 @@ staticStridesVector, offsets, sizes, strides, attrs); } -SmallVector SubTensorInsertOp::getOrCreateRanges(OpBuilder &b, - Location loc) { - return ::getOrCreateRangesImpl(*this, b, loc); -} - /// Verifier for SubViewOp. static LogicalResult verify(SubTensorInsertOp op) { - if (failed(verifyOpWithOffsetSizesAndStrides(op))) + if (failed(verifyOpWithOffsetSizesAndStrides(op, op.getSourceRank()))) return failure(); if (op.getType() != op.dest().getType()) return op.emitError("expected result type to be ") << op.dest().getType();