diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.h @@ -33,17 +33,14 @@ class FuncOp; class OpBuilder; -/// Auxiliary range data structure to unpack the offset, size and stride -/// operands of the SubViewOp / SubTensorOp into a list of triples. -/// Such a list of triple is sometimes more convenient to manipulate. -struct Range { - Value offset; - Value size; - Value stride; -}; - raw_ostream &operator<<(raw_ostream &os, Range &range); +/// Return the list of Range (i.e. offset, size, stride). Each Range +/// entry contains either the dynamic value or a ConstantIndexOp constructed +/// with `b` at location `loc`. +SmallVector getOrCreateRanges(OffsetSizeAndStrideOpInterface op, + OpBuilder &b, Location loc); + #define GET_OP_CLASSES #include "mlir/Dialect/StandardOps/IR/Ops.h.inc" diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -223,189 +223,21 @@ Std_Op { code extraBaseClassDeclaration = [{ - /// Returns the number of dynamic offset operands. - int64_t getNumOffsets() { return llvm::size(offsets()); } - - /// Returns the number of dynamic size operands. - int64_t getNumSizes() { return llvm::size(sizes()); } - - /// Returns the number of dynamic stride operands. - int64_t getNumStrides() { return llvm::size(strides()); } - /// Returns the dynamic sizes for this subview operation if specified. operand_range getDynamicSizes() { return sizes(); } - /// Returns in `staticStrides` the static value of the stride - /// operands. Returns failure() if the static value of the stride - /// operands could not be retrieved. - LogicalResult getStaticStrides(SmallVectorImpl &staticStrides) { - if (!strides().empty()) - return failure(); - staticStrides.reserve(static_strides().size()); - for (auto s : static_strides().getAsValueRange()) - staticStrides.push_back(s.getZExtValue()); - return success(); - } - /// Return the list of Range (i.e. offset, size, stride). Each /// Range entry contains either the dynamic value or a ConstantIndexOp /// constructed with `b` at location `loc`. - SmallVector getOrCreateRanges(OpBuilder &b, Location loc); - - /// Return the offsets as Values. Each Value is either the dynamic - /// value specified in the op or a ConstantIndexOp constructed - /// with `b` at location `loc` - SmallVector getOrCreateOffsets(OpBuilder &b, Location loc) { - unsigned dynamicIdx = 1; - return llvm::to_vector<4>(llvm::map_range( - static_offsets().cast(), [&](Attribute a) -> Value { - int64_t staticOffset = a.cast().getInt(); - if (ShapedType::isDynamicStrideOrOffset(staticOffset)) - return getOperand(dynamicIdx++); - else - return b.create( - loc, b.getIndexType(), b.getIndexAttr(staticOffset)); - })); - } - - /// Return the sizes as Values. Each Value is either the dynamic - /// value specified in the op or a ConstantIndexOp constructed - /// with `b` at location `loc` - SmallVector getOrCreateSizes(OpBuilder &b, Location loc) { - unsigned dynamicIdx = 1 + offsets().size(); - return llvm::to_vector<4>(llvm::map_range( - static_sizes().cast(), [&](Attribute a) -> Value { - int64_t staticSize = a.cast().getInt(); - if (ShapedType::isDynamic(staticSize)) - return getOperand(dynamicIdx++); - else - return b.create( - loc, b.getIndexType(), b.getIndexAttr(staticSize)); - })); - } - - /// Return the strides as Values. Each Value is either the dynamic - /// value specified in the op or a ConstantIndexOp constructed with - /// `b` at location `loc` - SmallVector getOrCreateStrides(OpBuilder &b, Location loc) { - unsigned dynamicIdx = 1 + offsets().size() + sizes().size(); - return llvm::to_vector<4>(llvm::map_range( - static_strides().cast(), [&](Attribute a) -> Value { - int64_t staticStride = a.cast().getInt(); - if (ShapedType::isDynamicStrideOrOffset(staticStride)) - return getOperand(dynamicIdx++); - else - return b.create( - loc, b.getIndexType(), b.getIndexAttr(staticStride)); - })); - } - - /// Return the rank of the source ShapedType. - unsigned getSourceRank() { - return source().getType().cast().getRank(); + SmallVector getOrCreateRanges(OpBuilder &b, Location loc) { + return mlir::getOrCreateRanges(*this, b, loc); } - /// Return the rank of the result ShapedType. - unsigned getResultRank() { return getType().getRank(); } - - /// Return true if the offset `idx` is a static constant. - bool isDynamicOffset(unsigned idx) { - APInt v = *(static_offsets().getAsValueRange().begin() + idx); - return ShapedType::isDynamicStrideOrOffset(v.getSExtValue()); - } - /// Return true if the size `idx` is a static constant. - bool isDynamicSize(unsigned idx) { - APInt v = *(static_sizes().getAsValueRange().begin() + idx); - return ShapedType::isDynamic(v.getSExtValue()); - } - - /// Return true if the stride `idx` is a static constant. - bool isDynamicStride(unsigned idx) { - APInt v = *(static_strides().getAsValueRange().begin() + idx); - return ShapedType::isDynamicStrideOrOffset(v.getSExtValue()); - } - - /// Assert the offset `idx` is a static constant and return its value. - int64_t getStaticOffset(unsigned idx) { - assert(!isDynamicOffset(idx) && "expected static offset"); - APInt v = *(static_offsets().getAsValueRange().begin() + idx); - return v.getSExtValue(); - } - /// Assert the size `idx` is a static constant and return its value. - int64_t getStaticSize(unsigned idx) { - assert(!isDynamicSize(idx) && "expected static size"); - APInt v = *(static_sizes().getAsValueRange().begin() + idx); - return v.getSExtValue(); - } - /// Assert the stride `idx` is a static constant and return its value. - int64_t getStaticStride(unsigned idx) { - assert(!isDynamicStride(idx) && "expected static stride"); - APInt v = *(static_strides().getAsValueRange().begin() + idx); - return v.getSExtValue(); - } - - unsigned getNumDynamicEntriesUpToIdx(ArrayAttr attr, - llvm::function_ref isDynamic, unsigned idx) { - return std::count_if( - attr.getValue().begin(), attr.getValue().begin() + idx, - [&](Attribute attr) { - return isDynamic(attr.cast().getInt()); - }); - } - /// Assert the offset `idx` is dynamic and return the position of the - /// corresponding operand. - unsigned getIndexOfDynamicOffset(unsigned idx) { - assert(isDynamicOffset(idx) && "expected static offset"); - auto numDynamic = - getNumDynamicEntriesUpToIdx(static_offsets().cast(), - ShapedType::isDynamicStrideOrOffset, idx); - return 1 + numDynamic; - } - /// Assert the size `idx` is dynamic and return the position of the - /// corresponding operand. - unsigned getIndexOfDynamicSize(unsigned idx) { - assert(isDynamicSize(idx) && "expected static size"); - auto numDynamic = getNumDynamicEntriesUpToIdx( - static_sizes().cast(), ShapedType::isDynamic, idx); - return 1 + offsets().size() + numDynamic; - } - /// Assert the stride `idx` is dynamic and return the position of the - /// corresponding operand. - unsigned getIndexOfDynamicStride(unsigned idx) { - assert(isDynamicStride(idx) && "expected static stride"); - auto numDynamic = - getNumDynamicEntriesUpToIdx(static_strides().cast(), - ShapedType::isDynamicStrideOrOffset, idx); - return 1 + offsets().size() + sizes().size() + numDynamic; - } - - /// Assert the offset `idx` is dynamic and return its value. - Value getDynamicOffset(unsigned idx) { - return getOperand(getIndexOfDynamicOffset(idx)); - } - /// Assert the size `idx` is dynamic and return its value. - Value getDynamicSize(unsigned idx) { - return getOperand(getIndexOfDynamicSize(idx)); - } - /// Assert the stride `idx` is dynamic and return its value. - Value getDynamicStride(unsigned idx) { - return getOperand(getIndexOfDynamicStride(idx)); - } - - static StringRef getStaticOffsetsAttrName() { - return "static_offsets"; - } - static StringRef getStaticSizesAttrName() { - return "static_sizes"; - } - static StringRef getStaticStridesAttrName() { - return "static_strides"; - } static ArrayRef getSpecialAttrNames() { static SmallVector names{ - getStaticOffsetsAttrName(), - getStaticSizesAttrName(), - getStaticStridesAttrName(), + OffsetSizeAndStrideOpInterface::getStaticOffsetsAttrName(), + OffsetSizeAndStrideOpInterface::getStaticSizesAttrName(), + OffsetSizeAndStrideOpInterface::getStaticStridesAttrName(), getOperandSegmentSizeAttr()}; return names; } @@ -2340,7 +2172,7 @@ def MemRefReinterpretCastOp: BaseOpWithOffsetSizesAndStrides<"memref_reinterpret_cast", [ - NoSideEffect, ViewLikeOpInterface + NoSideEffect, ViewLikeOpInterface, OffsetSizeAndStrideOpInterface ]> { let summary = "memref reinterpret cast operation"; let description = [{ @@ -2390,6 +2222,18 @@ // The result of the op is always a ranked memref. MemRefType getType() { return getResult().getType().cast(); } Value getViewSource() { return source(); } + + /// Return the rank of the source ShapedType. + unsigned getResultRank() { + return getResult().getType().cast().getRank(); + } + + /// Return the expected rank of each of the`static_offsets`, `static_sizes` + /// and `static_strides` attributes. + std::array getArrayAttrRanks() { + unsigned resultRank = getResult().getType().cast().getRank(); + return {1, resultRank, resultRank}; + } }]; } @@ -3210,7 +3054,7 @@ //===----------------------------------------------------------------------===// def SubViewOp : BaseOpWithOffsetSizesAndStrides< - "subview", [DeclareOpInterfaceMethods] > { + "subview", [DeclareOpInterfaceMethods, OffsetSizeAndStrideOpInterface] > { let summary = "memref subview operation"; let description = [{ The "subview" operation converts a memref type to another memref type @@ -3389,6 +3233,13 @@ ArrayRef staticOffsets, ArrayRef staticSizes, ArrayRef staticStrides); + + /// Return the expected rank of each of the`static_offsets`, `static_sizes` + /// and `static_strides` attributes. + std::array getArrayAttrRanks() { + unsigned rank = getSourceType().getRank(); + return {rank, rank, rank}; + } }]; let hasCanonicalizer = 1; @@ -3399,7 +3250,7 @@ // SubTensorOp //===----------------------------------------------------------------------===// -def SubTensorOp : BaseOpWithOffsetSizesAndStrides<"subtensor"> { +def SubTensorOp : BaseOpWithOffsetSizesAndStrides<"subtensor", [OffsetSizeAndStrideOpInterface]> { let summary = "subtensor operation"; let description = [{ The "subtensor" operation extract a tensor from another tensor as @@ -3480,6 +3331,13 @@ ArrayRef staticOffsets, ArrayRef staticSizes, ArrayRef staticStrides); + + /// Return the expected rank of each of the`static_offsets`, `static_sizes` + /// and `static_strides` attributes. + std::array getArrayAttrRanks() { + unsigned rank = getSourceType().getRank(); + return {rank, rank, rank}; + } }]; let hasCanonicalizer = 1; @@ -3489,7 +3347,7 @@ // SubTensorInsertOp //===----------------------------------------------------------------------===// -def SubTensorInsertOp : BaseOpWithOffsetSizesAndStrides<"subtensor_insert"> { +def SubTensorInsertOp : BaseOpWithOffsetSizesAndStrides<"subtensor_insert", [OffsetSizeAndStrideOpInterface]> { let summary = "subtensor_insert operation"; let description = [{ The "subtensor_insert" operation insert a tensor `source` into another @@ -3556,6 +3414,13 @@ RankedTensorType getType() { return getResult().getType().cast(); } + + /// Return the expected rank of each of the`static_offsets`, `static_sizes` + /// and `static_strides` attributes. + std::array getArrayAttrRanks() { + unsigned rank = getSourceType().getRank(); + return {rank, rank, rank}; + } }]; } diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.h b/mlir/include/mlir/Interfaces/ViewLikeInterface.h --- a/mlir/include/mlir/Interfaces/ViewLikeInterface.h +++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.h @@ -13,7 +13,23 @@ #ifndef MLIR_INTERFACES_VIEWLIKEINTERFACE_H_ #define MLIR_INTERFACES_VIEWLIKEINTERFACE_H_ +#include "mlir/IR/Builders.h" #include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" + +namespace mlir { +/// Auxiliary range data structure to unpack the offset, size and stride +/// operands into a list of triples. Such a list can be more convenient to +/// manipulate. +struct Range { + Value offset; + Value size; + Value stride; +}; + +class OffsetSizeAndStrideOpInterface; +LogicalResult verify(OffsetSizeAndStrideOpInterface op); +} // namespace mlir /// Include the generated interface declarations. #include "mlir/Interfaces/ViewLikeInterface.h.inc" diff --git a/mlir/include/mlir/Interfaces/ViewLikeInterface.td b/mlir/include/mlir/Interfaces/ViewLikeInterface.td --- a/mlir/include/mlir/Interfaces/ViewLikeInterface.td +++ b/mlir/include/mlir/Interfaces/ViewLikeInterface.td @@ -30,4 +30,338 @@ ]; } +def OffsetSizeAndStrideOpInterface : OpInterface<"OffsetSizeAndStrideOpInterface"> { + let description = [{ + Common interface for ops that allow specifying mixed dynamic and static + offsets, sizes and strides variadic operands. + Ops that implement this interface need to expose the following methods: + 1. `getArrayAttrRanks` to specify the length of static integer + attributes. + 2. `offsets`, `sizes` and `strides` variadic operands. + 3. `static_offsets`, resp. `static_sizes` and `static_strides` integer + array attributes. + + The invariants of this interface are: + 1. `static_offsets`, `static_sizes` and `static_strides` have length + exactly `getArrayAttrRanks()`[0] (resp. [1], [2]). + 2. `offsets`, `sizes` and `strides` have each length at most + `getArrayAttrRanks()`[0] (resp. [1], [2]). + 3. if an entry of `static_offsets` (resp. `static_sizes`, + `static_strides`) is equal to a special sentinel value, namely + `ShapedType::kDynamicStrideOrOffset` (resp. `ShapedType::kDynamicSize`, + `ShapedType::kDynamicStrideOrOffset`), then the corresponding entry is + a dynamic offset (resp. size, stride). + 4. a variadic `offset` (resp. `sizes`, `strides`) operand must be present + for each dynamic offset (resp. size, stride). + + This interface is useful to factor out common behavior and provide support + for carrying or injecting static behavior through the use of the static + attributes. + }]; + + let cppNamespace = "::mlir"; + + let methods = [ + InterfaceMethod< + /*desc=*/[{ + Return the expected rank of each of the`static_offsets`, `static_sizes` + and `static_strides` attributes. + }], + /*retTy=*/"std::array", + /*methodName=*/"getArrayAttrRanks", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return $_op.offsets(); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Return the dynamic offset operands. + }], + /*retTy=*/"OperandRange", + /*methodName=*/"offsets", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return $_op.offsets(); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Return the dynamic size operands. + }], + /*retTy=*/"OperandRange", + /*methodName=*/"sizes", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return $_op.sizes(); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Return the dynamic stride operands. + }], + /*retTy=*/"OperandRange", + /*methodName=*/"strides", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return $_op.strides(); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Return the static offset attributes. + }], + /*retTy=*/"ArrayAttr", + /*methodName=*/"static_offsets", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return $_op.static_offsets(); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Return the static size attributes. + }], + /*retTy=*/"ArrayAttr", + /*methodName=*/"static_sizes", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return $_op.static_sizes(); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Return the dynamic stride attributes. + }], + /*retTy=*/"ArrayAttr", + /*methodName=*/"static_strides", + /*args=*/(ins), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return $_op.static_strides(); + }] + >, + + InterfaceMethod< + /*desc=*/[{ + Return true if the offset `idx` is dynamic. + }], + /*retTy=*/"bool", + /*methodName=*/"isDynamicOffset", + /*args=*/(ins "unsigned":$idx), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + APInt v = *(static_offsets() + .template getAsValueRange().begin() + idx); + return ShapedType::isDynamicStrideOrOffset(v.getSExtValue()); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Return true if the size `idx` is dynamic. + }], + /*retTy=*/"bool", + /*methodName=*/"isDynamicSize", + /*args=*/(ins "unsigned":$idx), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + APInt v = *(static_sizes() + .template getAsValueRange().begin() + idx); + return ShapedType::isDynamic(v.getSExtValue()); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Return true if the stride `idx` is dynamic. + }], + /*retTy=*/"bool", + /*methodName=*/"isDynamicStride", + /*args=*/(ins "unsigned":$idx), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + APInt v = *(static_strides() + .template getAsValueRange().begin() + idx); + return ShapedType::isDynamicStrideOrOffset(v.getSExtValue()); + }] + >, + + InterfaceMethod< + /*desc=*/[{ + Assert the offset `idx` is a static constant and return its value. + }], + /*retTy=*/"int64_t", + /*methodName=*/"getStaticOffset", + /*args=*/(ins "unsigned":$idx), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + assert(!$_op.isDynamicOffset(idx) && "expected static offset"); + APInt v = *(static_offsets(). + template getAsValueRange().begin() + idx); + return v.getSExtValue(); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Assert the size `idx` is a static constant and return its value. + }], + /*retTy=*/"int64_t", + /*methodName=*/"getStaticSize", + /*args=*/(ins "unsigned":$idx), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + assert(!$_op.isDynamicSize(idx) && "expected static size"); + APInt v = *(static_sizes(). + template getAsValueRange().begin() + idx); + return v.getSExtValue(); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Assert the stride `idx` is a static constant and return its value. + }], + /*retTy=*/"int64_t", + /*methodName=*/"getStaticStride", + /*args=*/(ins "unsigned":$idx), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + assert(!$_op.isDynamicStride(idx) && "expected static stride"); + APInt v = *(static_strides(). + template getAsValueRange().begin() + idx); + return v.getSExtValue(); + }] + >, + + InterfaceMethod< + /*desc=*/[{ + Assert the offset `idx` is dynamic and return the position of the + corresponding operand. + }], + /*retTy=*/"unsigned", + /*methodName=*/"getIndexOfDynamicOffset", + /*args=*/(ins "unsigned":$idx), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + assert($_op.isDynamicOffset(idx) && "expected dynamic offset"); + auto numDynamic = getNumDynamicEntriesUpToIdx( + static_offsets().template cast(), + ShapedType::isDynamicStrideOrOffset, + idx); + return 1 + numDynamic; + }] + >, + InterfaceMethod< + /*desc=*/[{ + Assert the size `idx` is dynamic and return the position of the + corresponding operand. + }], + /*retTy=*/"unsigned", + /*methodName=*/"getIndexOfDynamicSize", + /*args=*/(ins "unsigned":$idx), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + assert($_op.isDynamicSize(idx) && "expected dynamic size"); + auto numDynamic = getNumDynamicEntriesUpToIdx( + static_sizes().template cast(), ShapedType::isDynamic, idx); + return 1 + offsets().size() + numDynamic; + }] + >, + InterfaceMethod< + /*desc=*/[{ + Assert the stride `idx` is dynamic and return the position of the + corresponding operand. + }], + /*retTy=*/"unsigned", + /*methodName=*/"getIndexOfDynamicStride", + /*args=*/(ins "unsigned":$idx), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + assert($_op.isDynamicStride(idx) && "expected dynamic stride"); + auto numDynamic = getNumDynamicEntriesUpToIdx( + static_strides().template cast(), + ShapedType::isDynamicStrideOrOffset, + idx); + return 1 + offsets().size() + sizes().size() + numDynamic; + }] + >, + InterfaceMethod< + /*desc=*/[{ + Helper method to compute the number of dynamic entries of `attr`, up to + `idx` using `isDynamic` to determine whether an entry is dynamic. + }], + /*retTy=*/"unsigned", + /*methodName=*/"getNumDynamicEntriesUpToIdx", + /*args=*/(ins "ArrayAttr":$attr, + "llvm::function_ref":$isDynamic, + "unsigned":$idx), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return std::count_if( + attr.getValue().begin(), attr.getValue().begin() + idx, + [&](Attribute attr) { + return isDynamic(attr.cast().getInt()); + }); + }] + >, + + InterfaceMethod< + /*desc=*/[{ + Assert the offset `idx` is dynamic and return its value. + }], + /*retTy=*/"Value", + /*methodName=*/"getDynamicOffset", + /*args=*/(ins "unsigned":$idx), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return $_op.getOperand(getIndexOfDynamicOffset(idx)); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Assert the size `idx` is dynamic and return its value. + }], + /*retTy=*/"Value", + /*methodName=*/"getDynamicSize", + /*args=*/(ins "unsigned":$idx), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return $_op.getOperand(getIndexOfDynamicSize(idx)); + }] + >, + InterfaceMethod< + /*desc=*/[{ + Assert the stride `idx` is dynamic and return its value. + }], + /*retTy=*/"Value", + /*methodName=*/"getDynamicStride", + /*args=*/(ins "unsigned":$idx), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + return $_op.getOperand(getIndexOfDynamicStride(idx)); + }] + >, + ]; + + let extraClassDeclaration = [{ + static StringRef getStaticOffsetsAttrName() { + return "static_offsets"; + } + static StringRef getStaticSizesAttrName() { + return "static_sizes"; + } + static StringRef getStaticStridesAttrName() { + return "static_strides"; + } + }]; + + let verify = [{ + return mlir::verify(cast($_op)); + }]; +} + #endif // MLIR_INTERFACES_VIEWLIKEINTERFACE diff --git a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp --- a/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp +++ b/mlir/lib/Conversion/StandardToSPIRV/LegalizeStandardForSPIRV.cpp @@ -114,19 +114,18 @@ // TODO: Aborting when the offsets are static. There might be a way to fold // the subview op with load even if the offsets have been canonicalized // away. - SmallVector opOffsets = subViewOp.getOrCreateOffsets(rewriter, loc); - SmallVector opStrides = subViewOp.getOrCreateStrides(rewriter, loc); - assert(opOffsets.size() == indices.size() && - "expected as many indices as rank of subview op result type"); - assert(opStrides.size() == indices.size() && + SmallVector opRanges = subViewOp.getOrCreateRanges(rewriter, loc); + auto opOffsets = llvm::map_range(opRanges, [](Range r) { return r.offset; }); + auto opStrides = llvm::map_range(opRanges, [](Range r) { return r.stride; }); + assert(opRanges.size() == indices.size() && "expected as many indices as rank of subview op result type"); // New indices for the load are the current indices * subview_stride + // subview_offset. sourceIndices.resize(indices.size()); for (auto index : llvm::enumerate(indices)) { - auto offset = opOffsets[index.index()]; - auto stride = opStrides[index.index()]; + auto offset = *(opOffsets.begin() + index.index()); + auto stride = *(opStrides.begin() + index.index()); auto mul = rewriter.create(loc, index.value(), stride); sourceIndices[index.index()] = rewriter.create(loc, offset, mul).getResult(); diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -320,11 +320,10 @@ } /// Verify that a particular offset/size/stride static attribute is well-formed. -template static LogicalResult verifyOpWithOffsetSizesAndStridesPart( - OpType op, StringRef name, unsigned expectedNumElements, StringRef attrName, - ArrayAttr attr, llvm::function_ref isDynamic, - ValueRange values) { + OffsetSizeAndStrideOpInterface op, StringRef name, + unsigned expectedNumElements, StringRef attrName, ArrayAttr attr, + llvm::function_ref isDynamic, ValueRange values) { /// Check static and dynamic offsets/sizes/strides breakdown. if (attr.size() != expectedNumElements) return op.emitError("expected ") @@ -347,27 +346,6 @@ })); } -/// Verify static attributes offsets/sizes/strides. -template -static LogicalResult verifyOpWithOffsetSizesAndStrides(OpType op) { - unsigned srcRank = op.getSourceRank(); - if (failed(verifyOpWithOffsetSizesAndStridesPart( - op, "offset", srcRank, op.getStaticOffsetsAttrName(), - op.static_offsets(), ShapedType::isDynamicStrideOrOffset, - op.offsets()))) - return failure(); - if (failed(verifyOpWithOffsetSizesAndStridesPart( - op, "size", srcRank, op.getStaticSizesAttrName(), op.static_sizes(), - ShapedType::isDynamic, op.sizes()))) - return failure(); - if (failed(verifyOpWithOffsetSizesAndStridesPart( - op, "stride", srcRank, op.getStaticStridesAttrName(), - op.static_strides(), ShapedType::isDynamicStrideOrOffset, - op.strides()))) - return failure(); - return success(); -} - //===----------------------------------------------------------------------===// // AllocOp / AllocaOp //===----------------------------------------------------------------------===// @@ -2481,10 +2459,7 @@ ShapedType::isDynamicStrideOrOffset); p.printOptionalAttrDict( op.getAttrs(), - /*elidedAttrs=*/{MemRefReinterpretCastOp::getOperandSegmentSizeAttr(), - MemRefReinterpretCastOp::getStaticOffsetsAttrName(), - MemRefReinterpretCastOp::getStaticSizesAttrName(), - MemRefReinterpretCastOp::getStaticStridesAttrName()}); + /*elidedAttrs=*/MemRefReinterpretCastOp::getSpecialAttrNames()); p << ": " << op.source().getType() << " to " << op.getType(); } @@ -2508,7 +2483,8 @@ if (parser.parseKeyword("to") || parser.parseKeyword("offset") || parser.parseColon() || parseListOfOperandsOrIntegers( - parser, result, MemRefReinterpretCastOp::getStaticOffsetsAttrName(), + parser, result, + OffsetSizeAndStrideOpInterface::getStaticOffsetsAttrName(), ShapedType::kDynamicStrideOrOffset, offset) || parser.parseComma()) return failure(); @@ -2517,7 +2493,8 @@ SmallVector sizes; if (parser.parseKeyword("sizes") || parser.parseColon() || parseListOfOperandsOrIntegers( - parser, result, MemRefReinterpretCastOp::getStaticSizesAttrName(), + parser, result, + OffsetSizeAndStrideOpInterface::getStaticSizesAttrName(), ShapedType::kDynamicSize, sizes) || parser.parseComma()) return failure(); @@ -2526,7 +2503,8 @@ SmallVector strides; if (parser.parseKeyword("strides") || parser.parseColon() || parseListOfOperandsOrIntegers( - parser, result, MemRefReinterpretCastOp::getStaticStridesAttrName(), + parser, result, + OffsetSizeAndStrideOpInterface::getStaticStridesAttrName(), ShapedType::kDynamicStrideOrOffset, strides)) return failure(); @@ -2564,23 +2542,6 @@ return op.emitError("different element types specified for source type ") << srcType << " and result memref type " << resultType; - // Verify that dynamic and static offset/sizes/strides arguments/attributes - // are consistent. - if (failed(verifyOpWithOffsetSizesAndStridesPart( - op, "offset", 1, op.getStaticOffsetsAttrName(), op.static_offsets(), - ShapedType::isDynamicStrideOrOffset, op.offsets()))) - return failure(); - unsigned resultRank = op.getResultRank(); - if (failed(verifyOpWithOffsetSizesAndStridesPart( - op, "size", resultRank, op.getStaticSizesAttrName(), - op.static_sizes(), ShapedType::isDynamic, op.sizes()))) - return failure(); - if (failed(verifyOpWithOffsetSizesAndStridesPart( - op, "stride", resultRank, op.getStaticStridesAttrName(), - op.static_strides(), ShapedType::isDynamicStrideOrOffset, - op.strides()))) - return failure(); - // Match sizes in result memref type and in static_sizes attribute. for (auto &en : llvm::enumerate(llvm::zip(resultType.getShape(), @@ -3289,13 +3250,16 @@ if (parseExtraOperand && parseExtraOperand(parser, dstInfo)) return failure(); if (parseListOfOperandsOrIntegers( - parser, result, OpType::getStaticOffsetsAttrName(), + parser, result, + OffsetSizeAndStrideOpInterface::getStaticOffsetsAttrName(), ShapedType::kDynamicStrideOrOffset, offsetsInfo) || - parseListOfOperandsOrIntegers(parser, result, - OpType::getStaticSizesAttrName(), - ShapedType::kDynamicSize, sizesInfo) || parseListOfOperandsOrIntegers( - parser, result, OpType::getStaticStridesAttrName(), + parser, result, + OffsetSizeAndStrideOpInterface::getStaticSizesAttrName(), + ShapedType::kDynamicSize, sizesInfo) || + parseListOfOperandsOrIntegers( + parser, result, + OffsetSizeAndStrideOpInterface::getStaticStridesAttrName(), ShapedType::kDynamicStrideOrOffset, stridesInfo)) return failure(); @@ -3532,7 +3496,6 @@ llvm_unreachable("unexpected subview verification result"); } - /// Verifier for SubViewOp. static LogicalResult verify(SubViewOp op) { MemRefType baseType = op.getSourceType(); @@ -3548,9 +3511,6 @@ if (!isStrided(baseType)) return op.emitError("base type ") << baseType << " is not strided"; - if (failed(verifyOpWithOffsetSizesAndStrides(op))) - return failure(); - // Verify result type against inferred type. auto expectedType = SubViewOp::inferResultType( baseType, extractFromI64ArrayAttr(op.static_offsets()), @@ -3569,11 +3529,13 @@ /// Return the list of Range (i.e. offset, size, stride). Each Range /// entry contains either the dynamic value or a ConstantIndexOp constructed /// with `b` at location `loc`. -template -static SmallVector getOrCreateRangesImpl(OpType op, OpBuilder &b, - Location loc) { +SmallVector mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op, + OpBuilder &b, Location loc) { + std::array ranks = op.getArrayAttrRanks(); + assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks"); + assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks"); SmallVector res; - unsigned rank = op.getSourceRank(); + unsigned rank = ranks[0]; res.reserve(rank); for (unsigned idx = 0; idx < rank; ++idx) { Value offset = @@ -3592,10 +3554,6 @@ return res; } -SmallVector SubViewOp::getOrCreateRanges(OpBuilder &b, Location loc) { - return ::getOrCreateRangesImpl(*this, b, loc); -} - namespace { /// Take a list of `values` with potential new constant to extract and a list @@ -3658,20 +3616,22 @@ SmallVector newOffsets(op.offsets()); SmallVector newStaticOffsets = extractFromI64ArrayAttr(op.static_offsets()); - assert(newStaticOffsets.size() == op.getSourceRank()); + std::array ranks = op.getArrayAttrRanks(); + (void)ranks; + assert(newStaticOffsets.size() == ranks[0]); canonicalizeSubViewPart(newOffsets, newStaticOffsets, ShapedType::isDynamicStrideOrOffset); SmallVector newSizes(op.sizes()); SmallVector newStaticSizes = extractFromI64ArrayAttr(op.static_sizes()); - assert(newStaticOffsets.size() == op.getSourceRank()); + assert(newStaticSizes.size() == ranks[1]); canonicalizeSubViewPart(newSizes, newStaticSizes, ShapedType::isDynamic); SmallVector newStrides(op.strides()); SmallVector newStaticStrides = extractFromI64ArrayAttr(op.static_strides()); - assert(newStaticOffsets.size() == op.getSourceRank()); + assert(newStaticStrides.size() == ranks[2]); canonicalizeSubViewPart(newStrides, newStaticStrides, ShapedType::isDynamicStrideOrOffset); @@ -3890,7 +3850,8 @@ } OpFoldResult SubViewOp::fold(ArrayRef operands) { - if (getResultRank() == 0 && getSourceRank() == 0) + if (getResult().getType().cast().getRank() == 0 && + source().getType().cast().getRank() == 0) return getViewSource(); return {}; @@ -3961,16 +3922,8 @@ staticStridesVector, offsets, sizes, strides, attrs); } -SmallVector SubTensorOp::getOrCreateRanges(OpBuilder &b, - Location loc) { - return ::getOrCreateRangesImpl(*this, b, loc); -} - /// Verifier for SubTensorOp. static LogicalResult verify(SubTensorOp op) { - if (failed(verifyOpWithOffsetSizesAndStrides(op))) - return failure(); - // Verify result type against inferred type. auto expectedType = SubTensorOp::inferResultType( op.getSourceType(), extractFromI64ArrayAttr(op.static_offsets()), @@ -4039,15 +3992,8 @@ staticStridesVector, offsets, sizes, strides, attrs); } -SmallVector SubTensorInsertOp::getOrCreateRanges(OpBuilder &b, - Location loc) { - return ::getOrCreateRangesImpl(*this, b, loc); -} - /// Verifier for SubViewOp. static LogicalResult verify(SubTensorInsertOp op) { - if (failed(verifyOpWithOffsetSizesAndStrides(op))) - return failure(); if (op.getType() != op.dest().getType()) return op.emitError("expected result type to be ") << op.dest().getType(); return success(); diff --git a/mlir/lib/Interfaces/ViewLikeInterface.cpp b/mlir/lib/Interfaces/ViewLikeInterface.cpp --- a/mlir/lib/Interfaces/ViewLikeInterface.cpp +++ b/mlir/lib/Interfaces/ViewLikeInterface.cpp @@ -8,6 +8,8 @@ #include "mlir/Interfaces/ViewLikeInterface.h" +#include "mlir/IR/StandardTypes.h" + using namespace mlir; //===----------------------------------------------------------------------===// @@ -16,3 +18,43 @@ /// Include the definitions of the loop-like interfaces. #include "mlir/Interfaces/ViewLikeInterface.cpp.inc" + +static LogicalResult verifyOpWithOffsetSizesAndStridesPart( + OffsetSizeAndStrideOpInterface op, StringRef name, + unsigned expectedNumElements, StringRef attrName, ArrayAttr attr, + llvm::function_ref isDynamic, ValueRange values) { + /// Check static and dynamic offsets/sizes/strides breakdown. + if (attr.size() != expectedNumElements) + return op.emitError("expected ") + << expectedNumElements << " " << name << " values"; + unsigned expectedNumDynamicEntries = + llvm::count_if(attr.getValue(), [&](Attribute attr) { + return isDynamic(attr.cast().getInt()); + }); + if (values.size() != expectedNumDynamicEntries) + return op.emitError("expected ") + << expectedNumDynamicEntries << " dynamic " << name << " values"; + return success(); +} + +LogicalResult mlir::verify(OffsetSizeAndStrideOpInterface op) { + std::array ranks = op.getArrayAttrRanks(); + if (failed(verifyOpWithOffsetSizesAndStridesPart( + op, "offset", ranks[0], + OffsetSizeAndStrideOpInterface::getStaticOffsetsAttrName(), + op.static_offsets(), ShapedType::isDynamicStrideOrOffset, + op.offsets()))) + return failure(); + if (failed(verifyOpWithOffsetSizesAndStridesPart( + op, "size", ranks[1], + OffsetSizeAndStrideOpInterface::getStaticSizesAttrName(), + op.static_sizes(), ShapedType::isDynamic, op.sizes()))) + return failure(); + if (failed(verifyOpWithOffsetSizesAndStridesPart( + op, "stride", ranks[2], + OffsetSizeAndStrideOpInterface::getStaticStridesAttrName(), + op.static_strides(), ShapedType::isDynamicStrideOrOffset, + op.strides()))) + return failure(); + return success(); +} diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir --- a/mlir/test/IR/invalid-ops.mlir +++ b/mlir/test/IR/invalid-ops.mlir @@ -981,7 +981,7 @@ func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) { %0 = alloc() : memref<8x16x4xf32, offset: 0, strides: [64, 4, 1], 2> // expected-error@+1 {{different memory spaces}} - %1 = subview %0[0, 0, 0][%arg2][1, 1, 1] + %1 = subview %0[0, 0, 0][%arg2, %arg2, %arg2][1, 1, 1] : memref<8x16x4xf32, offset: 0, strides: [64, 4, 1], 2> to memref<8x?x4xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * s0 + d1 * 4 + d2)>> return @@ -992,7 +992,7 @@ func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) { %0 = alloc() : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 + d1, d1 + d2, d2)>> // expected-error@+1 {{is not strided}} - %1 = subview %0[0, 0, 0][%arg2][1, 1, 1] + %1 = subview %0[0, 0, 0][%arg2, %arg2, %arg2][1, 1, 1] : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 + d1, d1 + d2, d2)>> to memref<8x?x4xf32, offset: 0, strides: [?, 4, 1]> return