diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -2504,7 +2504,7 @@ //===----------------------------------------------------------------------===// def SubViewOp : Std_Op<"subview", [ - AttrSizedOperandSegments, + AttrSizedOperandSegments, DeclareOpInterfaceMethods, NoSideEffect, ]> { @@ -2516,17 +2516,14 @@ The SubView operation supports the following arguments: *) Memref: the "base" memref on which to create a "view" memref. - *) Offsets: zero or memref-rank number of dynamic offsets into the "base" - memref at which to create the "view" memref. - *) Sizes: zero or memref-rank dynamic size operands which specify the - dynamic sizes of the result "view" memref type. - *) Strides: zero or memref-rank number of dynamic strides which are applied - multiplicatively to the base memref strides in each dimension. - - Note on the number of operands for offsets, sizes and strides: For - each of these, the number of operands must either be same as the - memref-rank number or empty. For the latter, those values will be - treated as constants. + *) Offsets: memref-rank number of dynamic offsets or static integer + attributes into the "base" memref at which to create the "view" + memref. + *) Sizes: memref-rank number of dynamic sizes or static integer attributes + which specify the sizes of the result "view" memref type. + *) Strides: memref-rank number of dynamic strides or static integer + attributes multiplicatively to the base memref strides in each + dimension. Example 1: @@ -2564,9 +2561,9 @@ %0 = alloc() : memref<8x16x4xf32, (d0, d1, d1) -> (d0 * 64 + d1 * 4 + d2)> // Subview with constant offsets, sizes and strides. - %1 = subview %0[][][] + %1 = subview %0[0, 2, 0][4, 4, 4][64, 4, 1] : memref<8x16x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)> to - memref<4x4x4xf32, (d0, d1, d2) -> (d0 * 16 + d1 * 4 + d2 + 8)> + memref<4x4x4xf32, (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2 + 8)> ``` Example 4: @@ -2608,7 +2605,7 @@ // #map2 = (d0, d1)[r0, r1, r2] -> (d0 * r1 + d1 * r2 + r0) // // where, r0 = o0 * s1 + o1 * s2 + s0, r1 = s1 * t0, r2 = s2 * t1. - %1 = subview %0[%i, %j][][%x, %y] : + %1 = subview %0[%i, %j][4, 4][%x, %y] : : memref (d0 * s1 + d1 * s2 + s0)> to memref<4x4xf32, (d0, d1)[r0, r1, r2] -> (d0 * r1 + d1 * r2 + r0)> @@ -2624,24 +2621,25 @@ AnyMemRef:$source, Variadic:$offsets, Variadic:$sizes, - Variadic:$strides + Variadic:$strides, + I64ArrayAttr:$static_offsets, + I64ArrayAttr:$static_sizes, + I64ArrayAttr:$static_strides ); let results = (outs AnyMemRef:$result); - let assemblyFormat = [{ - $source `[` $offsets `]` `[` $sizes `]` `[` $strides `]` attr-dict `:` - type($source) `to` type($result) - }]; - let builders = [ + // Build a SubViewOp with mized static and dynamic entries. OpBuilder< "OpBuilder &b, OperationState &result, Value source, " - "ValueRange offsets, ValueRange sizes, " - "ValueRange strides, Type resultType = Type(), " - "ArrayRef attrs = {}">, + "ArrayRef staticOffsets, ArrayRef staticSizes," + "ArrayRef staticStrides, ValueRange offsets, ValueRange sizes, " + "ValueRange strides, ArrayRef attrs = {}">, + // Build a SubViewOp with all dynamic entries. 
    OpBuilder<
-      "OpBuilder &builder, OperationState &result, "
-      "Type resultType, Value source">
+      "OpBuilder &b, OperationState &result, Value source, "
+      "ValueRange offsets, ValueRange sizes, ValueRange strides, "
+      "ArrayRef<NamedAttribute> attrs = {}">
   ];

   let extraClassDeclaration = [{
@@ -2670,13 +2668,83 @@
    /// operands could not be retrieved.
    LogicalResult getStaticStrides(SmallVectorImpl<int64_t> &staticStrides);

-    // Auxiliary range data structure and helper function that unpacks the
-    // offset, size and stride operands of the SubViewOp into a list of triples.
-    // Such a list of triple is sometimes more convenient to manipulate.
+    /// Auxiliary range data structure and helper function that unpacks the
+    /// offset, size and stride operands of the SubViewOp into a list of triples.
+    /// Such a list of triples is sometimes more convenient to manipulate.
     struct Range {
       Value offset, size, stride;
     };
+    // TODO: retire `getRanges`.
     SmallVector<Range, 8> getRanges();
+
+    /// Return the rank of the result MemRefType.
+    unsigned getRank() { return getType().getRank(); }
+
+    /// Return true if the offset `idx` is dynamic (i.e. not a static constant).
+    bool isDynamicOffset(unsigned idx);
+    /// Return true if the size `idx` is dynamic (i.e. not a static constant).
+    bool isDynamicSize(unsigned idx);
+    /// Return true if the stride `idx` is dynamic (i.e. not a static constant).
+    bool isDynamicStride(unsigned idx);
+
+    /// Assert the offset `idx` is a static constant and return its value.
+    int64_t getStaticOffset(unsigned idx) {
+      assert(!isDynamicOffset(idx) && "expected static offset");
+      return
+        static_offsets().cast<ArrayAttr>()[idx].cast<IntegerAttr>().getInt();
+    }
+    /// Assert the size `idx` is a static constant and return its value.
+    int64_t getStaticSize(unsigned idx) {
+      assert(!isDynamicSize(idx) && "expected static size");
+      return static_sizes().cast<ArrayAttr>()[idx].cast<IntegerAttr>().getInt();
+    }
+    /// Assert the stride `idx` is a static constant and return its value.
+    int64_t getStaticStride(unsigned idx) {
+      assert(!isDynamicStride(idx) && "expected static stride");
+      return
+        static_strides().cast<ArrayAttr>()[idx].cast<IntegerAttr>().getInt();
+    }
+
+    /// Assert the offset `idx` is dynamic and return the position of the
+    /// corresponding operand.
+    unsigned getIndexOfDynamicOffset(unsigned idx);
+    /// Assert the size `idx` is dynamic and return the position of the
+    /// corresponding operand.
+    unsigned getIndexOfDynamicSize(unsigned idx);
+    /// Assert the stride `idx` is dynamic and return the position of the
+    /// corresponding operand.
+    unsigned getIndexOfDynamicStride(unsigned idx);
+
+    /// Assert the offset `idx` is dynamic and return its value.
+    Value getDynamicOffset(unsigned idx) {
+      return getOperand(getIndexOfDynamicOffset(idx));
+    }
+    /// Assert the size `idx` is dynamic and return its value.
+    Value getDynamicSize(unsigned idx) {
+      return getOperand(getIndexOfDynamicSize(idx));
+    }
+    /// Assert the stride `idx` is dynamic and return its value.
+ Value getDynamicStride(unsigned idx) { + return getOperand(getIndexOfDynamicStride(idx)); + } + + static StringRef getStaticOffsetsAttrName() { + return "static_offsets"; + } + static StringRef getStaticSizesAttrName() { + return "static_sizes"; + } + static StringRef getStaticStridesAttrName() { + return "static_strides"; + } + static ArrayRef getSpecialAttrNames() { + static SmallVector names{ + getStaticOffsetsAttrName(), + getStaticSizesAttrName(), + getStaticStridesAttrName(), + getOperandSegmentSizeAttr()}; + return names; + } }]; let hasCanonicalizer = 1; diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -2495,28 +2495,14 @@ matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); - auto viewOp = cast(op); - // TODO(b/144779634, ravishankarm) : After Tblgen is adapted to support - // having multiple variadic operands where each operand can have different - // number of entries, clean all of this up. - SmallVector dynamicOffsets( - std::next(operands.begin()), - std::next(operands.begin(), 1 + viewOp.getNumOffsets())); - SmallVector dynamicSizes( - std::next(operands.begin(), 1 + viewOp.getNumOffsets()), - std::next(operands.begin(), - 1 + viewOp.getNumOffsets() + viewOp.getNumSizes())); - SmallVector dynamicStrides( - std::next(operands.begin(), - 1 + viewOp.getNumOffsets() + viewOp.getNumSizes()), - operands.end()); - - auto sourceMemRefType = viewOp.source().getType().cast(); + auto subViewOp = cast(op); + + auto sourceMemRefType = subViewOp.source().getType().cast(); auto sourceElementTy = typeConverter.convertType(sourceMemRefType.getElementType()) .dyn_cast_or_null(); - auto viewMemRefType = viewOp.getType(); + auto viewMemRefType = subViewOp.getType(); auto targetElementTy = typeConverter.convertType(viewMemRefType.getElementType()) .dyn_cast(); @@ -2525,26 +2511,13 @@ if (!sourceElementTy || !targetDescTy) return failure(); - // Currently, only rank > 0 and full or no operands are supported. Fail to - // convert otherwise. - unsigned rank = sourceMemRefType.getRank(); - if (viewMemRefType.getRank() == 0 || - (!dynamicOffsets.empty() && rank != dynamicOffsets.size()) || - (!dynamicSizes.empty() && rank != dynamicSizes.size()) || - (!dynamicStrides.empty() && rank != dynamicStrides.size())) - return failure(); - + // Extract the offset and strides from the type. int64_t offset; SmallVector strides; auto successStrides = getStridesAndOffset(viewMemRefType, strides, offset); if (failed(successStrides)) return failure(); - // Fail to convert if neither a dynamic nor static offset is available. - if (dynamicOffsets.empty() && - offset == MemRefType::getDynamicStrideOrOffset()) - return failure(); - // Create the descriptor. if (!operands.front().getType().isa()) return failure(); @@ -2558,6 +2531,7 @@ extracted); targetMemRef.setAllocatedPtr(rewriter, loc, bitcastPtr); + // Copy the buffer pointer from the old descriptor to the new one. extracted = sourceMemRef.alignedPtr(rewriter, loc); bitcastPtr = rewriter.create( loc, targetElementTy.getPointerTo(viewMemRefType.getMemorySpace()), @@ -2570,42 +2544,48 @@ for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i) strideValues.push_back(sourceMemRef.stride(rewriter, loc, i)); - // Fill in missing dynamic sizes. 
- auto llvmIndexType = typeConverter.convertType(rewriter.getIndexType()); - if (dynamicSizes.empty()) { - dynamicSizes.reserve(viewMemRefType.getRank()); - auto shape = viewMemRefType.getShape(); - for (auto extent : shape) { - dynamicSizes.push_back(rewriter.create( - loc, llvmIndexType, rewriter.getI64IntegerAttr(extent))); - } - } - // Offset. - if (dynamicOffsets.empty()) { + auto llvmIndexType = typeConverter.convertType(rewriter.getIndexType()); + if (!ShapedType::isDynamicStrideOrOffset(offset)) { targetMemRef.setConstantOffset(rewriter, loc, offset); } else { Value baseOffset = sourceMemRef.offset(rewriter, loc); - for (int i = 0, e = viewMemRefType.getRank(); i < e; ++i) { - Value min = dynamicOffsets[i]; - baseOffset = rewriter.create( - loc, baseOffset, - rewriter.create(loc, min, strideValues[i])); + for (unsigned i = 0, e = viewMemRefType.getRank(); i < e; ++i) { + Value offset = + subViewOp.isDynamicOffset(i) + ? operands[subViewOp.getIndexOfDynamicOffset(i)] + : rewriter.create( + loc, llvmIndexType, + rewriter.getI64IntegerAttr(subViewOp.getStaticOffset(i))); + Value mul = rewriter.create(loc, offset, strideValues[i]); + baseOffset = rewriter.create(loc, baseOffset, mul); } targetMemRef.setOffset(rewriter, loc, baseOffset); } // Update sizes and strides. for (int i = viewMemRefType.getRank() - 1; i >= 0; --i) { - targetMemRef.setSize(rewriter, loc, i, dynamicSizes[i]); - Value newStride; - if (dynamicStrides.empty()) - newStride = rewriter.create( + Value size = + subViewOp.isDynamicSize(i) + ? operands[subViewOp.getIndexOfDynamicSize(i)] + : rewriter.create( + loc, llvmIndexType, + rewriter.getI64IntegerAttr(subViewOp.getStaticSize(i))); + targetMemRef.setSize(rewriter, loc, i, size); + Value stride; + if (!ShapedType::isDynamicStrideOrOffset(strides[i])) { + stride = rewriter.create( loc, llvmIndexType, rewriter.getI64IntegerAttr(strides[i])); - else - newStride = rewriter.create(loc, dynamicStrides[i], - strideValues[i]); - targetMemRef.setStride(rewriter, loc, i, newStride); + } else { + stride = + subViewOp.isDynamicStride(i) + ? operands[subViewOp.getIndexOfDynamicStride(i)] + : rewriter.create( + loc, llvmIndexType, + rewriter.getI64IntegerAttr(subViewOp.getStaticStride(i))); + stride = rewriter.create(loc, stride, strideValues[i]); + } + targetMemRef.setStride(rewriter, loc, i, stride); } rewriter.replaceOp(op, {targetMemRef}); diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -1293,7 +1293,7 @@ auto indexAttr = op.getAttrOfType("index"); if (!indexAttr) return op.emitOpError("requires an integer attribute named 'index'"); - int64_t index = indexAttr.getValue().getSExtValue(); + int64_t index = indexAttr.getInt(); auto type = op.getOperand().getType(); if (auto tensorType = type.dyn_cast()) { @@ -1449,7 +1449,6 @@ return failure(); } - return success(); } @@ -2183,59 +2182,272 @@ // SubViewOp //===----------------------------------------------------------------------===// -// Returns a MemRefType with dynamic sizes and offset and the same stride as the -// `memRefType` passed as argument. -// TODO(andydavis,ntv) Evolve to a more powerful inference that can also keep -// sizes and offset static. 
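The rewritten lowering above no longer slices the flat operand list by getNumOffsets()/getNumSizes(); for each dimension it asks the op whether that entry is dynamic and only then consumes an operand, otherwise it materializes the static attribute value as a constant. A condensed sketch of that selection step, assuming `subViewOp`, `operands`, `rewriter`, `loc` and `llvmIndexType` are in scope exactly as in the pattern above (the lambda is illustrative, not part of the patch):

```cpp
// Pick the i-th offset: either the already-converted dynamic operand or an
// LLVM constant built from the static attribute entry.
auto offsetValue = [&](unsigned i) -> Value {
  if (subViewOp.isDynamicOffset(i))
    return operands[subViewOp.getIndexOfDynamicOffset(i)];
  return rewriter.create<LLVM::ConstantOp>(
      loc, llvmIndexType,
      rewriter.getI64IntegerAttr(subViewOp.getStaticOffset(i)));
};
// Sizes and strides follow the same shape, with isDynamicSize/getStaticSize
// and isDynamicStride/getStaticStride respectively.
```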
-static Type inferSubViewResultType(MemRefType memRefType) { - auto rank = memRefType.getRank(); - int64_t offset; - SmallVector strides; - auto res = getStridesAndOffset(memRefType, strides, offset); +/// Print a list with either (1) the static integer value in `arrayAttr` if +/// `isDynamic` evaluates to false or (2) the next value otherwise. +/// This allows idiomatic printing of mixed value and integer attributes in a +/// list. E.g. `[%arg0, 7, 42, %arg42]`. +static void printSubViewListOfOperandsOrIntegers( + OpAsmPrinter &p, ValueRange values, ArrayAttr arrayAttr, + llvm::function_ref isDynamic) { + p << "["; + unsigned idx = 0; + llvm::interleaveComma(arrayAttr, p, [&](Attribute a) { + int64_t val = a.cast().getInt(); + if (isDynamic(val)) + p << values[idx++]; + else + p << val; + }); + p << "] "; +} + +/// Parse a mixed list with either (1) static integer values or (2) SSA values. +/// Fill `result` with the integer ArrayAttr named `attrName` where `dynVal` +/// encode the position of SSA values. Add the parsed SSA values to `ssa` +/// in-order. +// +/// E.g. after parsing "[%arg0, 7, 42, %arg42]": +/// 1. `result` is filled with the i64 ArrayAttr "[`dynVal`, 7, 42, `dynVal`]" +/// 2. `ssa` is filled with "[%arg0, %arg1]". +static ParseResult +parseListOfOperandsOrIntegers(OpAsmParser &parser, OperationState &result, + StringRef attrName, int64_t dynVal, + SmallVectorImpl &ssa) { + if (failed(parser.parseLSquare())) + return failure(); + // 0-D. + if (succeeded(parser.parseOptionalRSquare())) + return success(); + + SmallVector attrVals; + while (true) { + OpAsmParser::OperandType operand; + auto res = parser.parseOptionalOperand(operand); + if (res.hasValue() && succeeded(res.getValue())) { + ssa.push_back(operand); + attrVals.push_back(dynVal); + } else { + Attribute attr; + NamedAttrList placeholder; + if (failed(parser.parseAttribute(attr, "_", placeholder)) || + !attr.isa()) + return parser.emitError(parser.getNameLoc()) + << "expected SSA value or integer"; + attrVals.push_back(attr.cast().getInt()); + } + + if (succeeded(parser.parseOptionalComma())) + continue; + if (failed(parser.parseRSquare())) + return failure(); + else + break; + } + + auto arrayAttr = parser.getBuilder().getI64ArrayAttr(attrVals); + result.addAttribute(attrName, arrayAttr); + return success(); +} + +namespace { +/// Helpers to write more idiomatic operations. +namespace saturated_arith { +struct Wrapper { + explicit Wrapper(int64_t v) : v(v) {} + operator int64_t() { return v; } + int64_t v; +}; +Wrapper operator+(Wrapper a, int64_t b) { + if (ShapedType::isDynamicStrideOrOffset(a) || + ShapedType::isDynamicStrideOrOffset(b)) + return Wrapper(ShapedType::kDynamicStrideOrOffset); + return Wrapper(a.v + b); +} +Wrapper operator*(Wrapper a, int64_t b) { + if (ShapedType::isDynamicStrideOrOffset(a) || + ShapedType::isDynamicStrideOrOffset(b)) + return Wrapper(ShapedType::kDynamicStrideOrOffset); + return Wrapper(a.v * b); +} +} // end namespace saturated_arith +} // end namespace + +/// A subview result type can be fully inferred from the source type and the +/// static representation of offsets, sizes and strides. Special sentinels +/// encode the dynamic case. 
+static Type inferSubViewResultType(MemRefType sourceMemRefType, + ArrayRef staticOffsets, + ArrayRef staticSizes, + ArrayRef staticStrides) { + unsigned rank = sourceMemRefType.getRank(); + (void)rank; + assert(staticOffsets.size() == rank && + "unexpected staticOffsets size mismatch"); + assert(staticSizes.size() == rank && "unexpected staticSizes size mismatch"); + assert(staticStrides.size() == rank && + "unexpected staticStrides size mismatch"); + + // Extract source offset and strides. + int64_t sourceOffset; + SmallVector sourceStrides; + auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset); assert(succeeded(res) && "SubViewOp expected strided memref type"); (void)res; - // Assume sizes and offset are fully dynamic for now until canonicalization - // occurs on the ranges. Typed strides don't change though. - offset = MemRefType::getDynamicStrideOrOffset(); - // Overwrite strides because verifier will not pass. - // TODO(b/144419106): don't force degrade the strides to fully dynamic. - for (auto &stride : strides) - stride = MemRefType::getDynamicStrideOrOffset(); - auto stridedLayout = - makeStridedLinearLayoutMap(strides, offset, memRefType.getContext()); - SmallVector sizes(rank, ShapedType::kDynamicSize); - return MemRefType::Builder(memRefType) - .setShape(sizes) - .setAffineMaps(stridedLayout); + // Compute target offset whose value is: + // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`. + int64_t targetOffset = sourceOffset; + for (auto it : llvm::zip(staticOffsets, sourceStrides)) { + auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it); + using namespace saturated_arith; + targetOffset = Wrapper(targetOffset) + Wrapper(staticOffset) * targetStride; + } + + // Compute target stride whose value is: + // `sourceStrides_i * staticStrides_i`. + SmallVector targetStrides; + targetStrides.reserve(staticOffsets.size()); + for (auto it : llvm::zip(sourceStrides, staticStrides)) { + auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it); + using namespace saturated_arith; + targetStrides.push_back(Wrapper(sourceStride) * staticStride); + } + + // The type is now known. 
+ return MemRefType::get( + staticSizes, sourceMemRefType.getElementType(), + makeStridedLinearLayoutMap(targetStrides, targetOffset, + sourceMemRefType.getContext()), + sourceMemRefType.getMemorySpace()); +} + +/// Print SubViewOp in the form: +/// ``` +/// subview ssa-name `[` offset-list `]` `[` size-list `]` `[` stride-list `]` +/// `:` strided-memref-type `to` strided-memref-type +/// ``` +static void print(OpAsmPrinter &p, SubViewOp op) { + int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1; + p << op.getOperation()->getName().getStringRef().drop_front(stdDotLen) << ' '; + p << op.getOperand(0); + printSubViewListOfOperandsOrIntegers(p, op.offsets(), op.static_offsets(), + ShapedType::isDynamicStrideOrOffset); + printSubViewListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(), + ShapedType::isDynamic); + printSubViewListOfOperandsOrIntegers(p, op.strides(), op.static_strides(), + ShapedType::isDynamicStrideOrOffset); + p.printOptionalAttrDict(op.getAttrs(), + /*elidedAttrs=*/{SubViewOp::getSpecialAttrNames()}); + p << " : " << op.getOperand(0).getType() << " to " << op.getType(); +} + +/// Parse SubViewOp of the form: +/// ``` +/// subview ssa-name `[` offset-list `]` `[` size-list `]` `[` stride-list `]` +/// `:` strided-memref-type `to` strided-memref-type +/// ``` +static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) { + OpAsmParser::OperandType srcInfo; + SmallVector offsetsInfo, sizesInfo, stridesInfo; + auto indexType = parser.getBuilder().getIndexType(); + Type srcType, dstType; + if (parser.parseOperand(srcInfo)) + return failure(); + if (parseListOfOperandsOrIntegers( + parser, result, SubViewOp::getStaticOffsetsAttrName(), + ShapedType::kDynamicStrideOrOffset, offsetsInfo) || + parseListOfOperandsOrIntegers(parser, result, + SubViewOp::getStaticSizesAttrName(), + ShapedType::kDynamicSize, sizesInfo) || + parseListOfOperandsOrIntegers( + parser, result, SubViewOp::getStaticStridesAttrName(), + ShapedType::kDynamicStrideOrOffset, stridesInfo)) + return failure(); + + auto b = parser.getBuilder(); + SmallVector segmentSizes{1, static_cast(offsetsInfo.size()), + static_cast(sizesInfo.size()), + static_cast(stridesInfo.size())}; + result.addAttribute(SubViewOp::getOperandSegmentSizeAttr(), + b.getI32VectorAttr(segmentSizes)); + + return failure( + parser.parseOptionalAttrDict(result.attributes) || + parser.parseColonType(srcType) || + parser.resolveOperand(srcInfo, srcType, result.operands) || + parser.resolveOperands(offsetsInfo, indexType, result.operands) || + parser.resolveOperands(sizesInfo, indexType, result.operands) || + parser.resolveOperands(stridesInfo, indexType, result.operands) || + parser.parseKeywordType("to", dstType) || + parser.addTypeToList(dstType, result.types)); } void mlir::SubViewOp::build(OpBuilder &b, OperationState &result, Value source, - ValueRange offsets, ValueRange sizes, - ValueRange strides, Type resultType, + ArrayRef staticOffsets, + ArrayRef staticSizes, + ArrayRef staticStrides, ValueRange offsets, + ValueRange sizes, ValueRange strides, ArrayRef attrs) { - if (!resultType) - resultType = inferSubViewResultType(source.getType().cast()); - build(b, result, resultType, source, offsets, sizes, strides); + auto sourceMemRefType = source.getType().cast(); + auto resultType = inferSubViewResultType(sourceMemRefType, staticOffsets, + staticSizes, staticStrides); + build(b, result, resultType, source, offsets, sizes, strides, + b.getI64ArrayAttr(staticOffsets), b.getI64ArrayAttr(staticSizes), + 
b.getI64ArrayAttr(staticStrides)); result.addAttributes(attrs); } -void mlir::SubViewOp::build(OpBuilder &b, OperationState &result, - Type resultType, Value source) { - build(b, result, source, /*offsets=*/{}, /*sizes=*/{}, /*strides=*/{}, - resultType); +/// Build a SubViewOp with all dynamic entries: `staticOffsets`, `staticSizes` +/// and `staticStrides` are automatically filled with source-memref-rank +/// sentinel values that encode dynamic entries. +void mlir::SubViewOp::build(OpBuilder &b, OperationState &result, Value source, + ValueRange offsets, ValueRange sizes, + ValueRange strides, + ArrayRef attrs) { + auto sourceMemRefType = source.getType().cast(); + unsigned rank = sourceMemRefType.getRank(); + SmallVector staticOffsetsVector; + staticOffsetsVector.assign(rank, ShapedType::kDynamicStrideOrOffset); + SmallVector staticSizesVector; + staticSizesVector.assign(rank, ShapedType::kDynamicSize); + SmallVector staticStridesVector; + staticStridesVector.assign(rank, ShapedType::kDynamicStrideOrOffset); + build(b, result, source, staticOffsetsVector, staticSizesVector, + staticStridesVector, offsets, sizes, strides, attrs); +} + +/// Verify that a particular offset/size/stride static attribute is well-formed. +static LogicalResult +verifySubViewOpPart(SubViewOp op, StringRef name, StringRef attrName, + ArrayAttr attr, llvm::function_ref isDynamic, + ValueRange values) { + /// Check static and dynamic offsets/sizes/strides breakdown. + if (attr.size() != op.getRank()) + return op.emitError("expected ") + << op.getRank() << " " << name << " values"; + unsigned expectedNumDynamicEntries = + llvm::count_if(attr.getValue(), [&](Attribute attr) { + return isDynamic(attr.cast().getInt()); + }); + if (values.size() != expectedNumDynamicEntries) + return op.emitError("expected ") + << expectedNumDynamicEntries << " dynamic " << name << " values"; + return success(); +} + +/// Helper function extracts int64_t from the assumedArrayAttr of IntegerAttr. +static SmallVector extractFromI64ArrayAttr(Attribute attr) { + return llvm::to_vector<4>( + llvm::map_range(attr.cast(), [](Attribute a) -> int64_t { + return a.cast().getInt(); + })); } +/// Verifier for SubViewOp. static LogicalResult verify(SubViewOp op) { auto baseType = op.getBaseMemRefType().cast(); auto subViewType = op.getType(); - // The rank of the base and result subview must match. - if (baseType.getRank() != subViewType.getRank()) { - return op.emitError( - "expected rank of result type to match rank of base type "); - } - // The base memref and the view memref should be in the same memory space. if (baseType.getMemorySpace() != subViewType.getMemorySpace()) return op.emitError("different memory spaces specified for base memref " @@ -2243,96 +2455,32 @@ << baseType << " and subview memref type " << subViewType; // Verify that the base memref type has a strided layout map. - int64_t baseOffset; - SmallVector baseStrides; - if (failed(getStridesAndOffset(baseType, baseStrides, baseOffset))) - return op.emitError("base type ") << subViewType << " is not strided"; - - // Verify that the result memref type has a strided layout map. - int64_t subViewOffset; - SmallVector subViewStrides; - if (failed(getStridesAndOffset(subViewType, subViewStrides, subViewOffset))) - return op.emitError("result type ") << subViewType << " is not strided"; - - // Num offsets should either be zero or rank of memref. 
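As a concrete check of the inference implemented in inferSubViewResultType above, the constant subview exercised by the core-ops.mlir test further down (source memref<8x16x4xf32, offset: 0, strides: [64, 4, 1]>, offsets [0, 2, 0], sizes [4, 4, 4], strides [1, 1, 1]) works out as follows; this is plain arithmetic, no MLIR APIs involved:

```cpp
#include <cstdint>
#include <iostream>

int main() {
  // Source layout of memref<8x16x4xf32, offset: 0, strides: [64, 4, 1]>.
  int64_t sourceOffset = 0;
  int64_t sourceStrides[3] = {64, 4, 1};
  // Static subview arguments: offsets [0, 2, 0], strides [1, 1, 1].
  int64_t offsets[3] = {0, 2, 0};
  int64_t strides[3] = {1, 1, 1};

  // targetOffset = sourceOffset + sum_i(offset_i * sourceStride_i) = 2 * 4 = 8.
  int64_t targetOffset = sourceOffset;
  for (int i = 0; i < 3; ++i)
    targetOffset += offsets[i] * sourceStrides[i];

  // targetStride_i = sourceStride_i * stride_i = [64, 4, 1].
  int64_t targetStrides[3];
  for (int i = 0; i < 3; ++i)
    targetStrides[i] = sourceStrides[i] * strides[i];

  // The static sizes [4, 4, 4] become the shape, so the inferred type is
  // memref<4x4x4xf32, offset: 8, strides: [64, 4, 1]>, i.e. the affine map
  // (d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2 + 8) checked in the tests.
  std::cout << "offset " << targetOffset << ", strides [" << targetStrides[0]
            << ", " << targetStrides[1] << ", " << targetStrides[2] << "]\n";
}
```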
- if (op.getNumOffsets() != 0 && op.getNumOffsets() != subViewType.getRank()) { - return op.emitError("expected number of dynamic offsets specified to match " - "the rank of the result type ") - << subViewType; - } - - // Num sizes should either be zero or rank of memref. - if (op.getNumSizes() != 0 && op.getNumSizes() != subViewType.getRank()) { - return op.emitError("expected number of dynamic sizes specified to match " - "the rank of the result type ") - << subViewType; - } + if (!isStrided(baseType)) + return op.emitError("base type ") << baseType << " is not strided"; - // Num strides should either be zero or rank of memref. - if (op.getNumStrides() != 0 && op.getNumStrides() != subViewType.getRank()) { - return op.emitError("expected number of dynamic strides specified to match " - "the rank of the result type ") - << subViewType; - } - - // Verify that if the shape of the subview type is static, then sizes are not - // dynamic values, and vice versa. - if ((subViewType.hasStaticShape() && op.getNumSizes() != 0) || - (op.getNumSizes() == 0 && !subViewType.hasStaticShape())) { - return op.emitError("invalid to specify dynamic sizes when subview result " - "type is statically shaped and viceversa"); - } - - // Verify that if dynamic sizes are specified, then the result memref type - // have full dynamic dimensions. - if (op.getNumSizes() > 0) { - if (llvm::any_of(subViewType.getShape(), [](int64_t dim) { - return dim != ShapedType::kDynamicSize; - })) { - // TODO: This is based on the assumption that number of size arguments are - // either 0, or the rank of the result type. It is possible to have more - // fine-grained verification where only particular dimensions are - // dynamic. That probably needs further changes to the shape op - // specification. - return op.emitError("expected shape of result type to be fully dynamic " - "when sizes are specified"); - } - } + // Verify static attributes offsets/sizes/strides. + if (failed(verifySubViewOpPart( + op, "offset", op.getStaticOffsetsAttrName(), op.static_offsets(), + ShapedType::isDynamicStrideOrOffset, op.offsets()))) + return failure(); - // Verify that if dynamic offsets are specified or base memref has dynamic - // offset or base memref has dynamic strides, then the subview offset is - // dynamic. - if ((op.getNumOffsets() > 0 || - baseOffset == MemRefType::getDynamicStrideOrOffset() || - llvm::is_contained(baseStrides, - MemRefType::getDynamicStrideOrOffset())) && - subViewOffset != MemRefType::getDynamicStrideOrOffset()) { - return op.emitError( - "expected result memref layout map to have dynamic offset"); - } + if (failed(verifySubViewOpPart(op, "size", op.getStaticSizesAttrName(), + op.static_sizes(), ShapedType::isDynamic, + op.sizes()))) + return failure(); + if (failed(verifySubViewOpPart( + op, "stride", op.getStaticStridesAttrName(), op.static_strides(), + ShapedType::isDynamicStrideOrOffset, op.strides()))) + return failure(); - // For now, verify that if dynamic strides are specified, then all the result - // memref type have dynamic strides. - if (op.getNumStrides() > 0) { - if (llvm::any_of(subViewStrides, [](int64_t stride) { - return stride != MemRefType::getDynamicStrideOrOffset(); - })) { - return op.emitError("expected result type to have dynamic strides"); - } - } + // Verify result type against inferred type. 
+  auto expectedType = inferSubViewResultType(
+      op.getBaseMemRefType(), extractFromI64ArrayAttr(op.static_offsets()),
+      extractFromI64ArrayAttr(op.static_sizes()),
+      extractFromI64ArrayAttr(op.static_strides()));
+  if (op.getType() != expectedType)
+    return op.emitError("expected result type to be ") << expectedType;
-  // If any of the base memref has dynamic stride, then the corresponding
-  // stride of the subview must also have dynamic stride.
-  assert(baseStrides.size() == subViewStrides.size());
-  for (auto stride : enumerate(baseStrides)) {
-    if (stride.value() == MemRefType::getDynamicStrideOrOffset() &&
-        subViewStrides[stride.index()] !=
-            MemRefType::getDynamicStrideOrOffset()) {
-      return op.emitError(
-          "expected result type to have dynamic stride along a dimension if "
-          "the base memref type has dynamic stride along that dimension");
-    }
-  }
   return success();
 }
@@ -2351,39 +2499,52 @@
   return res;
 }

+static unsigned getNumDynamicEntriesUpToIdx(
+    ArrayAttr attr, llvm::function_ref<bool(int64_t)> isDynamic, unsigned idx) {
+  return std::count_if(attr.getValue().begin(), attr.getValue().begin() + idx,
+                       [&](Attribute attr) {
+                         return isDynamic(attr.cast<IntegerAttr>().getInt());
+                       });
+}
+
+bool SubViewOp::isDynamicOffset(unsigned idx) {
+  return ShapedType::isDynamicStrideOrOffset(
+      extractFromI64ArrayAttr(static_offsets())[idx]);
+}
+bool SubViewOp::isDynamicSize(unsigned idx) {
+  return ShapedType::isDynamic(extractFromI64ArrayAttr(static_sizes())[idx]);
+}
+bool SubViewOp::isDynamicStride(unsigned idx) {
+  return ShapedType::isDynamicStrideOrOffset(
+      extractFromI64ArrayAttr(static_strides())[idx]);
+}
+
+unsigned SubViewOp::getIndexOfDynamicOffset(unsigned idx) {
+  assert(isDynamicOffset(idx) && "expected dynamic offset");
+  auto numDynamic =
+      getNumDynamicEntriesUpToIdx(static_offsets().cast<ArrayAttr>(),
+                                  ShapedType::isDynamicStrideOrOffset, idx);
+  return 1 + numDynamic;
+}
+unsigned SubViewOp::getIndexOfDynamicSize(unsigned idx) {
+  assert(isDynamicSize(idx) && "expected dynamic size");
+  auto numDynamic = getNumDynamicEntriesUpToIdx(
+      static_sizes().cast<ArrayAttr>(), ShapedType::isDynamic, idx);
+  return 1 + offsets().size() + numDynamic;
+}
+unsigned SubViewOp::getIndexOfDynamicStride(unsigned idx) {
+  assert(isDynamicStride(idx) && "expected dynamic stride");
+  auto numDynamic =
+      getNumDynamicEntriesUpToIdx(static_strides().cast<ArrayAttr>(),
+                                  ShapedType::isDynamicStrideOrOffset, idx);
+  return 1 + offsets().size() + sizes().size() + numDynamic;
+}
+
 LogicalResult
 SubViewOp::getStaticStrides(SmallVectorImpl<int64_t> &staticStrides) {
-  // If the strides are dynamic return failure.
-  if (getNumStrides())
-    return failure();
-
-  // When static, the stride operands can be retrieved by taking the strides of
-  // the result of the subview op, and dividing the strides of the base memref.
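The index computations above rely on the operand order imposed by AttrSizedOperandSegments: the flat operand list is [source, dynamic offsets..., dynamic sizes..., dynamic strides...], so the position of a dynamic entry is 1 (for the source) plus the number of dynamic entries in the preceding groups plus the number of dynamic entries of its own group before it. A small standalone illustration, using the mixed subview from the subview_mixed_static_dynamic test added below (INT64_MIN stands in for ShapedType::kDynamicStrideOrOffset here):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Stand-in sentinel for a dynamic entry in a static_* array attribute.
constexpr int64_t kDynamic = INT64_MIN;

// Mirrors getNumDynamicEntriesUpToIdx: dynamic entries strictly before `idx`.
static unsigned numDynamicUpTo(const std::vector<int64_t> &group, unsigned idx) {
  unsigned n = 0;
  for (unsigned i = 0; i < idx; ++i)
    n += group[i] == kDynamic ? 1 : 0;
  return n;
}

int main() {
  // %1 = subview %0[%arg1, 8][62, %arg2][%arg0, 1] has the operand list
  // [source, %arg1, %arg2, %arg0].
  std::vector<int64_t> offsets = {kDynamic, 8};
  std::vector<int64_t> sizes = {62, kDynamic};
  std::vector<int64_t> strides = {kDynamic, 1};
  unsigned numDynOffsets = numDynamicUpTo(offsets, offsets.size());
  unsigned numDynSizes = numDynamicUpTo(sizes, sizes.size());

  assert(1 + numDynamicUpTo(offsets, 0) == 1);               // offset 0 -> %arg1
  assert(1 + numDynOffsets + numDynamicUpTo(sizes, 1) == 2); // size 1   -> %arg2
  assert(1 + numDynOffsets + numDynSizes +
             numDynamicUpTo(strides, 0) == 3);               // stride 0 -> %arg0
  return 0;
}
```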
- int64_t resultOffset, baseOffset; - SmallVector resultStrides, baseStrides; - if (failed( - getStridesAndOffset(getBaseMemRefType(), baseStrides, baseOffset)) || - llvm::is_contained(baseStrides, MemRefType::getDynamicStrideOrOffset()) || - failed(getStridesAndOffset(getType(), resultStrides, resultOffset))) + if (!strides().empty()) return failure(); - - assert(static_cast(resultStrides.size()) == getType().getRank() && - baseStrides.size() == resultStrides.size() && - "base and result memrefs must have the same rank"); - assert(!llvm::is_contained(resultStrides, - MemRefType::getDynamicStrideOrOffset()) && - "strides of subview op must be static, when there are no dynamic " - "strides specified"); - staticStrides.resize(getType().getRank()); - for (auto resultStride : enumerate(resultStrides)) { - auto baseStride = baseStrides[resultStride.index()]; - // The result stride is expected to be a multiple of the base stride. Abort - // if that is not the case. - if (resultStride.value() < baseStride || - resultStride.value() % baseStride != 0) - return failure(); - staticStrides[resultStride.index()] = resultStride.value() / baseStride; - } + staticStrides = extractFromI64ArrayAttr(static_strides()); return success(); } @@ -2391,136 +2552,80 @@ namespace { -/// Pattern to rewrite a subview op with constant size arguments. -class SubViewOpShapeFolder final : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(SubViewOp subViewOp, - PatternRewriter &rewriter) const override { - MemRefType subViewType = subViewOp.getType(); - // Follow all or nothing approach for shapes for now. If all the operands - // for sizes are constants then fold it into the type of the result memref. - if (subViewType.hasStaticShape() || - llvm::any_of(subViewOp.sizes(), [](Value operand) { - return !matchPattern(operand, m_ConstantIndex()); - })) { - return failure(); - } - SmallVector staticShape(subViewOp.getNumSizes()); - for (auto size : llvm::enumerate(subViewOp.sizes())) { - auto defOp = size.value().getDefiningOp(); - assert(defOp); - staticShape[size.index()] = cast(defOp).getValue(); +/// Take a list of `values` with potential new constant to extract and a list +/// of `constantValues` with`values.size()` sentinel that evaluate to true by +/// applying `isDynamic`. +/// Detects the `values` produced by a ConstantIndexOp and places the new +/// constant in place of the corresponding sentinel value. +void canonicalizeSubViewPart(SmallVectorImpl &values, + SmallVectorImpl &constantValues, + llvm::function_ref isDynamic) { + bool hasNewStaticValue = llvm::any_of( + values, [](Value val) { return matchPattern(val, m_ConstantIndex()); }); + if (hasNewStaticValue) { + for (unsigned cstIdx = 0, valIdx = 0, e = constantValues.size(); + cstIdx != e; ++cstIdx) { + // Was already static, skip. + if (!isDynamic(constantValues[cstIdx])) + continue; + // Newly static, move from Value to constant. + if (matchPattern(values[valIdx], m_ConstantIndex())) { + constantValues[cstIdx] = + cast(values[valIdx].getDefiningOp()).getValue(); + // Erase for impl. simplicity. Reverse iterator if we really must. + values.erase(std::next(values.begin(), valIdx)); + continue; + } + // Remains dynamic move to next value. 
+ ++valIdx; } - MemRefType newMemRefType = - MemRefType::Builder(subViewType).setShape(staticShape); - auto newSubViewOp = rewriter.create( - subViewOp.getLoc(), subViewOp.source(), subViewOp.offsets(), - ArrayRef(), subViewOp.strides(), newMemRefType); - // Insert a memref_cast for compatibility of the uses of the op. - rewriter.replaceOpWithNewOp(subViewOp, newSubViewOp, - subViewOp.getType()); - return success(); } -}; +} -// Pattern to rewrite a subview op with constant stride arguments. -class SubViewOpStrideFolder final : public OpRewritePattern { +/// Pattern to rewrite a subview op with constant arguments. +class SubViewOpFolder final : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(SubViewOp subViewOp, PatternRewriter &rewriter) const override { - if (subViewOp.getNumStrides() == 0) { - return failure(); - } - // Follow all or nothing approach for strides for now. If all the operands - // for strides are constants then fold it into the strides of the result - // memref. - int64_t baseOffset, resultOffset; - SmallVector baseStrides, resultStrides; - MemRefType subViewType = subViewOp.getType(); - if (failed(getStridesAndOffset(subViewOp.getBaseMemRefType(), baseStrides, - baseOffset)) || - failed(getStridesAndOffset(subViewType, resultStrides, resultOffset)) || - llvm::is_contained(baseStrides, - MemRefType::getDynamicStrideOrOffset()) || - llvm::any_of(subViewOp.strides(), [](Value stride) { - return !matchPattern(stride, m_ConstantIndex()); - })) { + // No constant operand, just return; + if (llvm::none_of(subViewOp.getOperands(), [](Value operand) { + return matchPattern(operand, m_ConstantIndex()); + })) return failure(); - } - SmallVector staticStrides(subViewOp.getNumStrides()); - for (auto stride : llvm::enumerate(subViewOp.strides())) { - auto defOp = stride.value().getDefiningOp(); - assert(defOp); - assert(baseStrides[stride.index()] > 0); - staticStrides[stride.index()] = - cast(defOp).getValue() * baseStrides[stride.index()]; - } - AffineMap layoutMap = makeStridedLinearLayoutMap( - staticStrides, resultOffset, rewriter.getContext()); - MemRefType newMemRefType = - MemRefType::Builder(subViewType).setAffineMaps(layoutMap); + // At least one of offsets/sizes/strides is a new constant. + // Form the new list of operands and constant attributes from the existing. + SmallVector newOffsets(subViewOp.offsets()); + SmallVector newStaticOffsets = + extractFromI64ArrayAttr(subViewOp.static_offsets()); + assert(newStaticOffsets.size() == subViewOp.getRank()); + canonicalizeSubViewPart(newOffsets, newStaticOffsets, + ShapedType::isDynamicStrideOrOffset); + + SmallVector newSizes(subViewOp.sizes()); + SmallVector newStaticSizes = + extractFromI64ArrayAttr(subViewOp.static_sizes()); + assert(newStaticOffsets.size() == subViewOp.getRank()); + canonicalizeSubViewPart(newSizes, newStaticSizes, ShapedType::isDynamic); + + SmallVector newStrides(subViewOp.strides()); + SmallVector newStaticStrides = + extractFromI64ArrayAttr(subViewOp.static_strides()); + assert(newStaticOffsets.size() == subViewOp.getRank()); + canonicalizeSubViewPart(newStrides, newStaticStrides, + ShapedType::isDynamicStrideOrOffset); + + // Create the new op in canonical form. auto newSubViewOp = rewriter.create( - subViewOp.getLoc(), subViewOp.source(), subViewOp.offsets(), - subViewOp.sizes(), ArrayRef(), newMemRefType); - // Insert a memref_cast for compatibility of the uses of the op. 
- rewriter.replaceOpWithNewOp(subViewOp, newSubViewOp, - subViewOp.getType()); - return success(); - } -}; - -// Pattern to rewrite a subview op with constant offset arguments. -class SubViewOpOffsetFolder final : public OpRewritePattern { -public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(SubViewOp subViewOp, - PatternRewriter &rewriter) const override { - if (subViewOp.getNumOffsets() == 0) { - return failure(); - } - // Follow all or nothing approach for offsets for now. If all the operands - // for offsets are constants then fold it into the offset of the result - // memref. - int64_t baseOffset, resultOffset; - SmallVector baseStrides, resultStrides; - MemRefType subViewType = subViewOp.getType(); - if (failed(getStridesAndOffset(subViewOp.getBaseMemRefType(), baseStrides, - baseOffset)) || - failed(getStridesAndOffset(subViewType, resultStrides, resultOffset)) || - llvm::is_contained(baseStrides, - MemRefType::getDynamicStrideOrOffset()) || - baseOffset == MemRefType::getDynamicStrideOrOffset() || - llvm::any_of(subViewOp.offsets(), [](Value stride) { - return !matchPattern(stride, m_ConstantIndex()); - })) { - return failure(); - } - - auto staticOffset = baseOffset; - for (auto offset : llvm::enumerate(subViewOp.offsets())) { - auto defOp = offset.value().getDefiningOp(); - assert(defOp); - assert(baseStrides[offset.index()] > 0); - staticOffset += - cast(defOp).getValue() * baseStrides[offset.index()]; - } + subViewOp.getLoc(), subViewOp.source(), newStaticOffsets, + newStaticSizes, newStaticStrides, newOffsets, newSizes, newStrides); - AffineMap layoutMap = makeStridedLinearLayoutMap( - resultStrides, staticOffset, rewriter.getContext()); - MemRefType newMemRefType = - MemRefType::Builder(subViewType).setAffineMaps(layoutMap); - auto newSubViewOp = rewriter.create( - subViewOp.getLoc(), subViewOp.source(), ArrayRef(), - subViewOp.sizes(), subViewOp.strides(), newMemRefType); // Insert a memref_cast for compatibility of the uses of the op. rewriter.replaceOpWithNewOp(subViewOp, newSubViewOp, subViewOp.getType()); + return success(); } }; @@ -2633,8 +2738,7 @@ void SubViewOp::getCanonicalizationPatterns(OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); + results.insert(context); } //===----------------------------------------------------------------------===// diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir --- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir +++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir @@ -839,7 +839,7 @@ // CHECK32: %[[ARG0:[a-zA-Z0-9]*]]: !llvm.i32, // CHECK32: %[[ARG1:[a-zA-Z0-9]*]]: !llvm.i32, // CHECK32: %[[ARG2:.*]]: !llvm.i32) -func @subview(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>>, %arg0 : index, %arg1 : index, %arg2 : index) { +func @subview(%0 : memref<64x4xf32, offset: 0, strides: [4, 1]>, %arg0 : index, %arg1 : index, %arg2 : index) { // The last "insertvalue" that populates the memref descriptor from the function arguments. 
// CHECK: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] // CHECK32: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] @@ -883,7 +883,8 @@ // CHECK32: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : !llvm.i32 %1 = subview %0[%arg0, %arg1][%arg0, %arg1][%arg0, %arg1] : - memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>> to memref (d0 * s1 + d1 * s2 + s0)>> + memref<64x4xf32, offset: 0, strides: [4, 1]> + to memref return } @@ -899,7 +900,7 @@ // CHECK32: %[[ARG0:[a-zA-Z0-9]*]]: !llvm.i32, // CHECK32: %[[ARG1:[a-zA-Z0-9]*]]: !llvm.i32, // CHECK32: %[[ARG2:.*]]: !llvm.i32) -func @subview_non_zero_addrspace(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>, 3>, %arg0 : index, %arg1 : index, %arg2 : index) { +func @subview_non_zero_addrspace(%0 : memref<64x4xf32, offset: 0, strides: [4, 1], 3>, %arg0 : index, %arg1 : index, %arg2 : index) { // The last "insertvalue" that populates the memref descriptor from the function arguments. // CHECK: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] // CHECK32: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] @@ -943,13 +944,14 @@ // CHECK32: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : !llvm.i32 %1 = subview %0[%arg0, %arg1][%arg0, %arg1][%arg0, %arg1] : - memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>, 3> to memref (d0 * s1 + d1 * s2 + s0)>, 3> + memref<64x4xf32, offset: 0, strides: [4, 1], 3> + to memref return } // CHECK-LABEL: func @subview_const_size( // CHECK32-LABEL: func @subview_const_size( -func @subview_const_size(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>>, %arg0 : index, %arg1 : index, %arg2 : index) { +func @subview_const_size(%0 : memref<64x4xf32, offset: 0, strides: [4, 1]>, %arg0 : index, %arg1 : index, %arg2 : index) { // The last "insertvalue" that populates the memref descriptor from the function arguments. 
// CHECK: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] // CHECK32: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] @@ -961,17 +963,17 @@ // CHECK: %[[DESC1:.*]] = llvm.insertvalue %[[BITCAST1]], %[[DESC0]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> // CHECK: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEMREF]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> // CHECK: %[[STRIDE1:.*]] = llvm.extractvalue %[[MEMREF]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> - // CHECK: %[[CST4:.*]] = llvm.mlir.constant(4 : i64) - // CHECK: %[[CST2:.*]] = llvm.mlir.constant(2 : i64) // CHECK: %[[OFF:.*]] = llvm.extractvalue %[[MEMREF]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> // CHECK: %[[OFFINC:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : !llvm.i64 // CHECK: %[[OFF1:.*]] = llvm.add %[[OFF]], %[[OFFINC]] : !llvm.i64 // CHECK: %[[OFFINC1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : !llvm.i64 // CHECK: %[[OFF2:.*]] = llvm.add %[[OFF1]], %[[OFFINC1]] : !llvm.i64 // CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[OFF2]], %[[DESC1]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: %[[CST2:.*]] = llvm.mlir.constant(2 : i64) // CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[CST2]], %[[DESC2]][3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> // CHECK: %[[DESCSTRIDE1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : !llvm.i64 // CHECK: %[[DESC4:.*]] = llvm.insertvalue %[[DESCSTRIDE1]], %[[DESC3]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> + // CHECK: %[[CST4:.*]] = llvm.mlir.constant(4 : i64) // CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[CST4]], %[[DESC4]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> // CHECK: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : !llvm.i64 // CHECK: llvm.insertvalue %[[DESCSTRIDE0]], %[[DESC5]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> @@ -982,28 +984,29 @@ // CHECK32: %[[DESC1:.*]] = llvm.insertvalue %[[BITCAST1]], %[[DESC0]][1] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }"> // CHECK32: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEMREF]][4, 0] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }"> // CHECK32: %[[STRIDE1:.*]] = llvm.extractvalue %[[MEMREF]][4, 1] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }"> - // CHECK32: %[[CST4:.*]] = llvm.mlir.constant(4 : i64) - // CHECK32: %[[CST2:.*]] = llvm.mlir.constant(2 : i64) // CHECK32: %[[OFF:.*]] = llvm.extractvalue %[[MEMREF]][2] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }"> // CHECK32: %[[OFFINC:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : !llvm.i32 // CHECK32: %[[OFF1:.*]] = llvm.add %[[OFF]], %[[OFFINC]] : !llvm.i32 // CHECK32: %[[OFFINC1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : !llvm.i32 // CHECK32: %[[OFF2:.*]] = llvm.add %[[OFF1]], %[[OFFINC1]] : !llvm.i32 // CHECK32: %[[DESC2:.*]] = llvm.insertvalue %[[OFF2]], %[[DESC1]][2] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }"> + // CHECK32: %[[CST2:.*]] = llvm.mlir.constant(2 : i64) // CHECK32: %[[DESC3:.*]] = llvm.insertvalue %[[CST2]], %[[DESC2]][3, 1] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }"> // CHECK32: %[[DESCSTRIDE1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : !llvm.i32 // CHECK32: %[[DESC4:.*]] = llvm.insertvalue %[[DESCSTRIDE1]], %[[DESC3]][4, 1] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }"> + // CHECK32: %[[CST4:.*]] = llvm.mlir.constant(4 : i64) // CHECK32: %[[DESC5:.*]] = llvm.insertvalue %[[CST4]], %[[DESC4]][3, 0] : !llvm<"{ float*, float*, i32, [2 x 
i32], [2 x i32] }"> // CHECK32: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : !llvm.i32 // CHECK32: llvm.insertvalue %[[DESCSTRIDE0]], %[[DESC5]][4, 0] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }"> - %1 = subview %0[%arg0, %arg1][][%arg0, %arg1] : - memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>> to memref<4x2xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)>> + %1 = subview %0[%arg0, %arg1][4, 2][%arg0, %arg1] : + memref<64x4xf32, offset: 0, strides: [4, 1]> + to memref<4x2xf32, offset: ?, strides: [?, ?]> return } // CHECK-LABEL: func @subview_const_stride( // CHECK32-LABEL: func @subview_const_stride( -func @subview_const_stride(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>>, %arg0 : index, %arg1 : index, %arg2 : index) { +func @subview_const_stride(%0 : memref<64x4xf32, offset: 0, strides: [4, 1]>, %arg0 : index, %arg1 : index, %arg2 : index) { // The last "insertvalue" that populates the memref descriptor from the function arguments. // CHECK: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] // CHECK32: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] @@ -1046,35 +1049,19 @@ // CHECK32: %[[DESC5:.*]] = llvm.insertvalue %[[ARG0]], %[[DESC4]][3, 0] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }"> // CHECK32: %[[CST4:.*]] = llvm.mlir.constant(4 : i64) // CHECK32: llvm.insertvalue %[[CST4]], %[[DESC5]][4, 0] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }"> - %1 = subview %0[%arg0, %arg1][%arg0, %arg1][] : - memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>> to memref (d0 * 4 + d1 * 2 + s0)>> + %1 = subview %0[%arg0, %arg1][%arg0, %arg1][1, 2] : + memref<64x4xf32, offset: 0, strides: [4, 1]> + to memref return } // CHECK-LABEL: func @subview_const_stride_and_offset( // CHECK32-LABEL: func @subview_const_stride_and_offset( -func @subview_const_stride_and_offset(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>>) { +func @subview_const_stride_and_offset(%0 : memref<64x4xf32, offset: 0, strides: [4, 1]>) { // The last "insertvalue" that populates the memref descriptor from the function arguments. 
// CHECK: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] // CHECK32: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] - // CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> - // CHECK: %[[BITCAST0:.*]] = llvm.bitcast %{{.*}} : !llvm<"float*"> to !llvm<"float*"> - // CHECK: %[[DESC0:.*]] = llvm.insertvalue %[[BITCAST0]], %[[DESC]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> - // CHECK: %[[BITCAST1:.*]] = llvm.bitcast %{{.*}} : !llvm<"float*"> to !llvm<"float*"> - // CHECK: %[[DESC1:.*]] = llvm.insertvalue %[[BITCAST1]], %[[DESC0]][1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> - // CHECK: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEMREF]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> - // CHECK: %[[STRIDE1:.*]] = llvm.extractvalue %[[MEMREF]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> - // CHECK: %[[CST62:.*]] = llvm.mlir.constant(62 : i64) - // CHECK: %[[CST3:.*]] = llvm.mlir.constant(3 : i64) - // CHECK: %[[CST8:.*]] = llvm.mlir.constant(8 : index) - // CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[CST8]], %[[DESC1]][2] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> - // CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[CST3]], %[[DESC2]][3, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> - // CHECK: %[[CST1:.*]] = llvm.mlir.constant(1 : i64) - // CHECK: %[[DESC4:.*]] = llvm.insertvalue %[[CST1]], %[[DESC3]][4, 1] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> - // CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[CST62]], %[[DESC4]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> - // CHECK: %[[CST4:.*]] = llvm.mlir.constant(4 : i64) - // CHECK: llvm.insertvalue %[[CST4]], %[[DESC5]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }"> // CHECK32: %[[DESC:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }"> // CHECK32: %[[BITCAST0:.*]] = llvm.bitcast %{{.*}} : !llvm<"float*"> to !llvm<"float*"> // CHECK32: %[[DESC0:.*]] = llvm.insertvalue %[[BITCAST0]], %[[DESC]][0] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }"> @@ -1082,18 +1069,64 @@ // CHECK32: %[[DESC1:.*]] = llvm.insertvalue %[[BITCAST1]], %[[DESC0]][1] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }"> // CHECK32: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEMREF]][4, 0] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }"> // CHECK32: %[[STRIDE1:.*]] = llvm.extractvalue %[[MEMREF]][4, 1] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }"> - // CHECK32: %[[CST62:.*]] = llvm.mlir.constant(62 : i64) - // CHECK32: %[[CST3:.*]] = llvm.mlir.constant(3 : i64) // CHECK32: %[[CST8:.*]] = llvm.mlir.constant(8 : index) // CHECK32: %[[DESC2:.*]] = llvm.insertvalue %[[CST8]], %[[DESC1]][2] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }"> + // CHECK32: %[[CST3:.*]] = llvm.mlir.constant(3 : i64) // CHECK32: %[[DESC3:.*]] = llvm.insertvalue %[[CST3]], %[[DESC2]][3, 1] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }"> // CHECK32: %[[CST1:.*]] = llvm.mlir.constant(1 : i64) // CHECK32: %[[DESC4:.*]] = llvm.insertvalue %[[CST1]], %[[DESC3]][4, 1] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }"> + // CHECK32: %[[CST62:.*]] = llvm.mlir.constant(62 : i64) // CHECK32: %[[DESC5:.*]] = llvm.insertvalue %[[CST62]], %[[DESC4]][3, 0] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }"> // CHECK32: %[[CST4:.*]] = llvm.mlir.constant(4 : i64) // CHECK32: llvm.insertvalue %[[CST4]], %[[DESC5]][4, 0] : !llvm<"{ float*, float*, 
i32, [2 x i32], [2 x i32] }"> - %1 = subview %0[][][] : - memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>> to memref<62x3xf32, affine_map<(d0, d1) -> (d0 * 4 + d1 + 8)>> + %1 = subview %0[0, 8][62, 3][1, 1] : + memref<64x4xf32, offset: 0, strides: [4, 1]> + to memref<62x3xf32, offset: 8, strides: [4, 1]> + return +} + +// CHECK-LABEL: func @subview_mixed_static_dynamic( +// CHECK-COUNT-2: !llvm<"float*">, +// CHECK-COUNT-5: {{%[a-zA-Z0-9]*}}: !llvm.i64, +// CHECK: %[[ARG0:[a-zA-Z0-9]*]]: !llvm.i64, +// CHECK: %[[ARG1:[a-zA-Z0-9]*]]: !llvm.i64, +// CHECK: %[[ARG2:.*]]: !llvm.i64) +// CHECK32-LABEL: func @subview_mixed_static_dynamic( +// CHECK32-COUNT-2: !llvm<"float*">, +// CHECK32-COUNT-5: {{%[a-zA-Z0-9]*}}: !llvm.i32, +// CHECK32: %[[ARG0:[a-zA-Z0-9]*]]: !llvm.i32, +// CHECK32: %[[ARG1:[a-zA-Z0-9]*]]: !llvm.i32, +// CHECK32: %[[ARG2:.*]]: !llvm.i32) +func @subview_mixed_static_dynamic(%0 : memref<64x4xf32, offset: 0, strides: [4, 1]>, %arg0 : index, %arg1 : index, %arg2 : index) { + // The last "insertvalue" that populates the memref descriptor from the function arguments. + // CHECK: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] + // CHECK32: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1] + + // CHECK32: %[[DESC:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }"> + // CHECK32: %[[BITCAST0:.*]] = llvm.bitcast %{{.*}} : !llvm<"float*"> to !llvm<"float*"> + // CHECK32: %[[DESC0:.*]] = llvm.insertvalue %[[BITCAST0]], %[[DESC]][0] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }"> + // CHECK32: %[[BITCAST1:.*]] = llvm.bitcast %{{.*}} : !llvm<"float*"> to !llvm<"float*"> + // CHECK32: %[[DESC1:.*]] = llvm.insertvalue %[[BITCAST1]], %[[DESC0]][1] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }"> + // CHECK32: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEMREF]][4, 0] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }"> + // CHECK32: %[[STRIDE1:.*]] = llvm.extractvalue %[[MEMREF]][4, 1] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }"> + // CHECK32: %[[OFF:.*]] = llvm.extractvalue %[[MEMREF]][2] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }"> + // CHECK32: %[[OFFM1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE0]] : !llvm.i32 + // CHECK32: %[[OFFA1:.*]] = llvm.add %[[OFF]], %[[OFFM1]] : !llvm.i32 + // CHECK32: %[[CST8:.*]] = llvm.mlir.constant(8 : i64) : !llvm.i32 + // CHECK32: %[[OFFM2:.*]] = llvm.mul %[[CST8]], %[[STRIDE1]] : !llvm.i32 + // CHECK32: %[[OFFA2:.*]] = llvm.add %[[OFFA1]], %[[OFFM2]] : !llvm.i32 + // CHECK32: %[[DESC2:.*]] = llvm.insertvalue %[[OFFA2]], %[[DESC1]][2] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }"> + + // CHECK32: %[[DESC3:.*]] = llvm.insertvalue %[[ARG2]], %[[DESC2]][3, 1] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }"> + // CHECK32: %[[CST1:.*]] = llvm.mlir.constant(1 : i64) : !llvm.i32 + // CHECK32: %[[DESC4:.*]] = llvm.insertvalue %[[CST1]], %[[DESC3]][4, 1] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }"> + // CHECK32: %[[CST62:.*]] = llvm.mlir.constant(62 : i64) : !llvm.i32 + // CHECK32: %[[DESC5:.*]] = llvm.insertvalue %[[CST62]], %[[DESC4]][3, 0] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }"> + // CHECK32: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : !llvm.i32 + // CHECK32: llvm.insertvalue %[[DESCSTRIDE0]], %[[DESC5]][4, 0] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }"> + %1 = subview %0[%arg1, 8][62, %arg2][%arg0, 1] : + memref<64x4xf32, offset: 0, strides: [4, 1]> + to memref<62x?xf32, offset: ?, strides: [?, 1]> return } 
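The subview_mixed_static_dynamic test above is the kind of op the new mixed builder is meant to produce: static entries live in the static_* attributes and only the dynamic entries become operands. A hedged sketch of creating that op programmatically, assuming an OpBuilder `b`, a Location `loc`, and index-typed Values `source`, `arg0`, `arg1`, `arg2` are already available (the create call forwards to the first OpBuilder declared in Ops.td above):

```cpp
// Sentinels mark the dynamic slots; they must line up with the operands below.
SmallVector<int64_t, 2> staticOffsets{ShapedType::kDynamicStrideOrOffset, 8};
SmallVector<int64_t, 2> staticSizes{62, ShapedType::kDynamicSize};
SmallVector<int64_t, 2> staticStrides{ShapedType::kDynamicStrideOrOffset, 1};

// Dynamic operands, grouped in order: offsets, then sizes, then strides.
SmallVector<Value, 1> offsetVals{arg1}, sizeVals{arg2}, strideVals{arg0};

// Equivalent to: subview %source[%arg1, 8][62, %arg2][%arg0, 1]; the result
// MemRefType is inferred from the source type and the static entries.
auto subView = b.create<SubViewOp>(loc, source, staticOffsets, staticSizes,
                                   staticStrides, offsetVals, sizeVals,
                                   strideVals);
```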
diff --git a/mlir/test/Conversion/StandardToLLVM/invalid.mlir b/mlir/test/Conversion/StandardToLLVM/invalid.mlir --- a/mlir/test/Conversion/StandardToLLVM/invalid.mlir +++ b/mlir/test/Conversion/StandardToLLVM/invalid.mlir @@ -7,7 +7,7 @@ %c0 = constant 0 : index // expected-error@+1 {{'std.memref_cast' op operand #0 must be unranked.memref of any type values or memref of any type values, but got '!llvm<"{ double*, double*, i64, [2 x i64], [2 x i64] }">'}} %5 = memref_cast %arg0 : memref to memref - %25 = std.subview %5[%c0, %c0][%c1, %c1][] : memref to memref + %25 = std.subview %5[%c0, %c0][%c1, %c1][1, 1] : memref to memref return } diff --git a/mlir/test/Conversion/StandardToSPIRV/legalization.mlir b/mlir/test/Conversion/StandardToSPIRV/legalization.mlir --- a/mlir/test/Conversion/StandardToSPIRV/legalization.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/legalization.mlir @@ -11,7 +11,7 @@ // CHECK: [[STRIDE2:%.*]] = muli [[ARG4]], [[C3]] : index // CHECK: [[INDEX2:%.*]] = addi [[ARG2]], [[STRIDE2]] : index // CHECK: load [[ARG0]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}} - %0 = subview %arg0[%arg1, %arg2][][] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]> + %0 = subview %arg0[%arg1, %arg2][4, 4][2, 3] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]> %1 = load %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [64, 3]> return %1 : f32 } @@ -25,7 +25,8 @@ // CHECK: [[STRIDE2:%.*]] = muli [[ARG4]], [[ARG6]] : index // CHECK: [[INDEX2:%.*]] = addi [[ARG2]], [[STRIDE2]] : index // CHECK: load [[ARG0]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}} - %0 = subview %arg0[%arg1, %arg2][][%arg5, %arg6] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [?, ?]> + %0 = subview %arg0[%arg1, %arg2][4, 4][%arg5, %arg6] : + memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [?, ?]> %1 = load %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [?, ?]> return %1 : f32 } @@ -41,7 +42,8 @@ // CHECK: [[STRIDE2:%.*]] = muli [[ARG4]], [[C3]] : index // CHECK: [[INDEX2:%.*]] = addi [[ARG2]], [[STRIDE2]] : index // CHECK: store [[ARG5]], [[ARG0]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}} - %0 = subview %arg0[%arg1, %arg2][][] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]> + %0 = subview %arg0[%arg1, %arg2][4, 4][2, 3] : + memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]> store %arg5, %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [64, 3]> return } @@ -55,7 +57,8 @@ // CHECK: [[STRIDE2:%.*]] = muli [[ARG4]], [[ARG6]] : index // CHECK: [[INDEX2:%.*]] = addi [[ARG2]], [[STRIDE2]] : index // CHECK: store [[ARG7]], [[ARG0]]{{\[}}[[INDEX1]], [[INDEX2]]{{\]}} - %0 = subview %arg0[%arg1, %arg2][][%arg5, %arg6] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [?, ?]> + %0 = subview %arg0[%arg1, %arg2][4, 4][%arg5, %arg6] : + memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [?, ?]> store %arg7, %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [?, ?]> return } diff --git a/mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir --- a/mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/subview-to-spirv.mlir @@ -28,7 +28,7 @@ // CHECK: %[[T8:.*]] = muli %[[ARG4]], %[[C3]] // CHECK: %[[T9:.*]] = addi %[[ARG2]], %[[T8]] // CHECK store %[[STOREVAL]], %[[ARG0]][%[[T7]], %[[T9]]] - %0 = subview %arg0[%arg1, %arg2][][] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]> + %0 = subview %arg0[%arg1, %arg2][4, 
4][2, 3] : memref<12x32xf32> to memref<4x4xf32, offset:?, strides: [64, 3]> %1 = load %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [64, 3]> %2 = sqrt %1 : f32 store %2, %0[%arg3, %arg4] : memref<4x4xf32, offset:?, strides: [64, 3]> diff --git a/mlir/test/Dialect/Affine/ops.mlir b/mlir/test/Dialect/Affine/ops.mlir --- a/mlir/test/Dialect/Affine/ops.mlir +++ b/mlir/test/Dialect/Affine/ops.mlir @@ -103,8 +103,8 @@ affine.for %arg4 = 0 to %13 step 264 { %18 = dim %0, 0 : memref %20 = std.subview %0[%c0, %c0][%18,%arg4][%c1,%c1] : memref - to memref (d0 * s1 + d1 * s2 + s0)>> - %24 = dim %20, 0 : memref (d0 * s1 + d1 * s2 + s0)>> + to memref + %24 = dim %20, 0 : memref affine.for %arg5 = 0 to %24 step 768 { "foo"() : () -> () } diff --git a/mlir/test/Dialect/Linalg/promote.mlir b/mlir/test/Dialect/Linalg/promote.mlir --- a/mlir/test/Dialect/Linalg/promote.mlir +++ b/mlir/test/Dialect/Linalg/promote.mlir @@ -23,9 +23,9 @@ loop.for %arg4 = %c0 to %6 step %c2 { loop.for %arg5 = %c0 to %8 step %c3 { loop.for %arg6 = %c0 to %7 step %c4 { - %11 = std.subview %3[%arg4, %arg6][%c2, %c4][] : memref to memref - %14 = std.subview %4[%arg6, %arg5][%c4, %c3][] : memref to memref - %17 = std.subview %5[%arg4, %arg5][%c2, %c3][] : memref to memref + %11 = std.subview %3[%arg4, %arg6][%c2, %c4][1, 1] : memref to memref + %14 = std.subview %4[%arg6, %arg5][%c4, %c3][1, 1] : memref to memref + %17 = std.subview %5[%arg4, %arg5][%c2, %c3][1, 1] : memref to memref linalg.matmul(%11, %14, %17) : memref, memref, memref } } @@ -88,9 +88,9 @@ loop.for %arg4 = %c0 to %6 step %c2 { loop.for %arg5 = %c0 to %8 step %c3 { loop.for %arg6 = %c0 to %7 step %c4 { - %11 = std.subview %3[%arg4, %arg6][%c2, %c4][] : memref to memref - %14 = std.subview %4[%arg6, %arg5][%c4, %c3][] : memref to memref - %17 = std.subview %5[%arg4, %arg5][%c2, %c3][] : memref to memref + %11 = std.subview %3[%arg4, %arg6][%c2, %c4][1, 1] : memref to memref + %14 = std.subview %4[%arg6, %arg5][%c4, %c3][1, 1] : memref to memref + %17 = std.subview %5[%arg4, %arg5][%c2, %c3][1, 1] : memref to memref linalg.matmul(%11, %14, %17) : memref, memref, memref } } @@ -153,9 +153,9 @@ loop.for %arg4 = %c0 to %6 step %c2 { loop.for %arg5 = %c0 to %8 step %c3 { loop.for %arg6 = %c0 to %7 step %c4 { - %11 = std.subview %3[%arg4, %arg6][%c2, %c4][] : memref to memref - %14 = std.subview %4[%arg6, %arg5][%c4, %c3][] : memref to memref - %17 = std.subview %5[%arg4, %arg5][%c2, %c3][] : memref to memref + %11 = std.subview %3[%arg4, %arg6][%c2, %c4][1, 1] : memref to memref + %14 = std.subview %4[%arg6, %arg5][%c4, %c3][1, 1] : memref to memref + %17 = std.subview %5[%arg4, %arg5][%c2, %c3][1, 1] : memref to memref linalg.matmul(%11, %14, %17) : memref, memref, memref } } diff --git a/mlir/test/IR/core-ops.mlir b/mlir/test/IR/core-ops.mlir --- a/mlir/test/IR/core-ops.mlir +++ b/mlir/test/IR/core-ops.mlir @@ -10,15 +10,14 @@ // CHECK-DAG: #[[BASE_MAP0:map[0-9]+]] = affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)> // CHECK-DAG: #[[BASE_MAP3:map[0-9]+]] = affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)> -// CHECK-DAG: #[[SUBVIEW_MAP0:map[0-9]+]] = affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + d1 * s2 + d2 * s3 + s0)> // CHECK-DAG: #[[BASE_MAP1:map[0-9]+]] = affine_map<(d0)[s0] -> (d0 + s0)> // CHECK-DAG: #[[SUBVIEW_MAP1:map[0-9]+]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)> // CHECK-DAG: #[[BASE_MAP2:map[0-9]+]] = affine_map<(d0, d1) -> (d0 * 22 + d1)> -// CHECK-DAG: #[[SUBVIEW_MAP2:map[0-9]+]] = affine_map<(d0, 
d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)> -// CHECK-DAG: #[[SUBVIEW_MAP3:map[0-9]+]] = affine_map<(d0, d1, d2) -> (d0 * 16 + d1 * 4 + d2 + 8)> -// CHECK-DAG: #[[SUBVIEW_MAP4:map[0-9]+]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)> +// CHECK-DAG: #[[SUBVIEW_MAP2:map[0-9]+]] = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)> +// CHECK-DAG: #[[SUBVIEW_MAP3:map[0-9]+]] = affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2 + 8)> +// CHECK-DAG: #[[SUBVIEW_MAP4:map[0-9]+]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)> // CHECK-DAG: #[[SUBVIEW_MAP5:map[0-9]+]] = affine_map<(d0, d1)[s0] -> (d0 * 8 + s0 + d1 * 2)> // CHECK-LABEL: func @func_with_ops(%arg0: f32) { @@ -708,41 +707,56 @@ %c1 = constant 1 : index %0 = alloc() : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>> - // CHECK: subview %0[%c0, %c0, %c0] [%arg0, %arg1, %arg2] [%c1, %c1, %c1] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref + // CHECK: subview %0[%c0, %c0, %c0] [%arg0, %arg1, %arg2] [%c1, %c1, %c1] : + // CHECK-SAME: memref<8x16x4xf32, #[[BASE_MAP0]]> + // CHECK-SAME: to memref %1 = subview %0[%c0, %c0, %c0][%arg0, %arg1, %arg2][%c1, %c1, %c1] - : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>> to - memref (d0 * s1 + d1 * s2 + d2 * s3 + s0)>> + : memref<8x16x4xf32, offset:0, strides: [64, 4, 1]> to + memref %2 = alloc()[%arg2] : memref<64xf32, affine_map<(d0)[s0] -> (d0 + s0)>> - // CHECK: subview %2[%c1] [%arg0] [%c1] : memref<64xf32, #[[BASE_MAP1]]> to memref + // CHECK: subview %2[%c1] [%arg0] [%c1] : + // CHECK-SAME: memref<64xf32, #[[BASE_MAP1]]> + // CHECK-SAME: to memref %3 = subview %2[%c1][%arg0][%c1] : memref<64xf32, affine_map<(d0)[s0] -> (d0 + s0)>> to memref (d0 * s1 + s0)>> %4 = alloc() : memref<64x22xf32, affine_map<(d0, d1) -> (d0 * 22 + d1)>> - // CHECK: subview %4[%c0, %c1] [%arg0, %arg1] [%c1, %c0] : memref<64x22xf32, #[[BASE_MAP2]]> to memref + // CHECK: subview %4[%c0, %c1] [%arg0, %arg1] [%c1, %c0] : + // CHECK-SAME: memref<64x22xf32, #[[BASE_MAP2]]> + // CHECK-SAME: to memref %5 = subview %4[%c0, %c1][%arg0, %arg1][%c1, %c0] - : memref<64x22xf32, affine_map<(d0, d1) -> (d0 * 22 + d1)>> to - memref (d0 * s1 + d1 * s2 + s0)>> + : memref<64x22xf32, offset:0, strides: [22, 1]> to + memref - // CHECK: subview %0[] [] [] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref<4x4x4xf32, #[[SUBVIEW_MAP3]]> - %6 = subview %0[][][] - : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>> to - memref<4x4x4xf32, affine_map<(d0, d1, d2) -> (d0 * 16 + d1 * 4 + d2 + 8)>> + // CHECK: subview %0[0, 2, 0] [4, 4, 4] [1, 1, 1] : + // CHECK-SAME: memref<8x16x4xf32, #[[BASE_MAP0]]> + // CHECK-SAME: to memref<4x4x4xf32, #[[SUBVIEW_MAP3]]> + %6 = subview %0[0, 2, 0][4, 4, 4][1, 1, 1] + : memref<8x16x4xf32, offset:0, strides: [64, 4, 1]> to + memref<4x4x4xf32, offset:8, strides: [64, 4, 1]> %7 = alloc(%arg1, %arg2) : memref - // CHECK: subview {{%.*}}[] [] [] : memref to memref<4x4xf32, #[[SUBVIEW_MAP4]]> - %8 = subview %7[][][] - : memref to memref<4x4xf32, offset: ?, strides:[?, ?]> + // CHECK: subview {{%.*}}[0, 0] [4, 4] [1, 1] : + // CHECK-SAME: memref + // CHECK-SAME: to memref<4x4xf32, #[[SUBVIEW_MAP4]]> + %8 = subview %7[0, 0][4, 4][1, 1] + : memref to memref<4x4xf32, offset: ?, strides:[?, 1]> %9 = alloc() : memref<16x4xf32> - // CHECK: subview {{%.*}}[{{%.*}}, {{%.*}}] [] [{{%.*}}, {{%.*}}] : memref<16x4xf32> to memref<4x4xf32, #[[SUBVIEW_MAP4]] - %10 = subview %9[%arg1, %arg1][][%arg2, %arg2] + // CHECK: subview 
{{%.*}}[{{%.*}}, {{%.*}}] [4, 4] [{{%.*}}, {{%.*}}] : + // CHECK-SAME: memref<16x4xf32> + // CHECK-SAME: to memref<4x4xf32, #[[SUBVIEW_MAP2]] + %10 = subview %9[%arg1, %arg1][4, 4][%arg2, %arg2] : memref<16x4xf32> to memref<4x4xf32, offset: ?, strides:[?, ?]> - // CHECK: subview {{%.*}}[{{%.*}}, {{%.*}}] [] [] : memref<16x4xf32> to memref<4x4xf32, #[[SUBVIEW_MAP5]] - %11 = subview %9[%arg1, %arg2][][] + + // CHECK: subview {{%.*}}[{{%.*}}, {{%.*}}] [4, 4] [2, 2] : + // CHECK-SAME: memref<16x4xf32> + // CHECK-SAME: to memref<4x4xf32, #[[SUBVIEW_MAP5]] + %11 = subview %9[%arg1, %arg2][4, 4][2, 2] : memref<16x4xf32> to memref<4x4xf32, offset: ?, strides:[8, 2]> + return } diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir --- a/mlir/test/IR/invalid-ops.mlir +++ b/mlir/test/IR/invalid-ops.mlir @@ -976,10 +976,10 @@ // ----- func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) { - %0 = alloc() : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>, 2> + %0 = alloc() : memref<8x16x4xf32, offset: 0, strides: [64, 4, 1], 2> // expected-error@+1 {{different memory spaces}} - %1 = subview %0[][%arg2][] - : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>, 2> to + %1 = subview %0[0, 0, 0][%arg2][1, 1, 1] + : memref<8x16x4xf32, offset: 0, strides: [64, 4, 1], 2> to memref<8x?x4xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * s0 + d1 * 4 + d2)>> return } @@ -987,22 +987,11 @@ // ----- func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) { - %0 = alloc() : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>> - // expected-error@+1 {{is not strided}} - %1 = subview %0[][%arg2][] - : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>> to - memref<8x?x4xf32, affine_map<(d0, d1, d2)[s0] -> (d0 + s0, d1, d2)>> - return -} - -// ----- - -func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) { %0 = alloc() : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 + d1, d1 + d2, d2)>> // expected-error@+1 {{is not strided}} - %1 = subview %0[][%arg2][] + %1 = subview %0[0, 0, 0][%arg2][1, 1, 1] : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 + d1, d1 + d2, d2)>> to - memref<8x?x4xf32, affine_map<(d0, d1, d2)[s0] -> (d0 * s0 + d1 * 4 + d2)>> + memref<8x?x4xf32, offset: 0, strides: [?, 4, 1]> return } @@ -1010,8 +999,8 @@ func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) { %0 = alloc() : memref<8x16x4xf32> - // expected-error@+1 {{expected number of dynamic offsets specified to match the rank of the result type}} - %1 = subview %0[%arg0, %arg1][%arg2][] + // expected-error@+1 {{expected 3 offset values}} + %1 = subview %0[%arg0, %arg1][%arg2][1, 1, 1] : memref<8x16x4xf32> to memref<8x?x4xf32, offset: 0, strides:[?, ?, 4]> return @@ -1021,7 +1010,7 @@ func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) { %0 = alloc() : memref<8x16x4xf32> - // expected-error@+1 {{expected result type to have dynamic strides}} + // expected-error@+1 {{expected result type to be 'memref (d0 * s1 + s0 + d1 * s2 + d2 * s3)>>'}} %1 = subview %0[%arg0, %arg1, %arg2][%arg0, %arg1, %arg2][%arg0, %arg1, %arg2] : memref<8x16x4xf32> to memref @@ -1030,106 +1019,6 @@ // ----- -func @invalid_subview(%arg0 : index, %arg1 : index, %arg2 : index) { - %0 = alloc() : memref<8x16x4xf32> - %c0 = constant 0 : index - %c1 = constant 1 : index - // expected-error@+1 {{expected result memref layout map to have dynamic offset}} - %1 = subview %0[%c0, %c0, %c0][%arg0, %arg1, %arg2][%c1, %c1, 
%c1] - : memref<8x16x4xf32> to - memref - return -} - -// ----- - -func @invalid_subview(%arg0 : index, %arg1 : memref) { - // expected-error@+1 {{expected rank of result type to match rank of base type}} - %0 = subview %arg1[%arg0, %arg0][][%arg0, %arg0] : memref to memref -} - -// ----- - -func @invalid_subview(%arg0 : index, %arg1 : memref) { - // expected-error@+1 {{expected number of dynamic offsets specified to match the rank of the result type}} - %0 = subview %arg1[%arg0][][] : memref to memref<4x4xf32, offset: ?, strides: [4, 1]> -} - -// ----- - -func @invalid_subview(%arg0 : index, %arg1 : memref) { - // expected-error@+1 {{expected number of dynamic sizes specified to match the rank of the result type}} - %0 = subview %arg1[][%arg0][] : memref to memref -} - -// ----- - -func @invalid_subview(%arg0 : index, %arg1 : memref) { - // expected-error@+1 {{expected number of dynamic strides specified to match the rank of the result type}} - %0 = subview %arg1[][][%arg0] : memref to memref -} - -// ----- - -func @invalid_subview(%arg0 : index, %arg1 : memref) { - // expected-error@+1 {{invalid to specify dynamic sizes when subview result type is statically shaped and viceversa}} - %0 = subview %arg1[][%arg0, %arg0][] : memref to memref<4x8xf32, offset: ?, strides: [?, ?]> -} - -// ----- - -func @invalid_subview(%arg0 : index, %arg1 : memref) { - // expected-error@+1 {{invalid to specify dynamic sizes when subview result type is statically shaped and viceversa}} - %0 = subview %arg1[][][] : memref to memref -} - -// ----- - -func @invalid_subview(%arg0 : index, %arg1 : memref<16x4xf32>) { - // expected-error@+1 {{expected result memref layout map to have dynamic offset}} - %0 = subview %arg1[%arg0, %arg0][][] : memref<16x4xf32> to memref<4x2xf32> -} - -// ----- - -func @invalid_subview(%arg0 : index, %arg1 : memref<16x4xf32, offset: ?, strides: [4, 1]>) { - // expected-error@+1 {{expected result memref layout map to have dynamic offset}} - %0 = subview %arg1[][][] : memref<16x4xf32, offset: ?, strides: [4, 1]> to memref<4x2xf32> -} - -// ----- - -func @invalid_subview(%arg0 : index, %arg1 : memref<16x4xf32, offset: 8, strides:[?, 1]>) { - // expected-error@+1 {{expected result memref layout map to have dynamic offset}} - %0 = subview %arg1[][][] : memref<16x4xf32, offset: 8, strides:[?, 1]> to memref<4x2xf32> -} - -// ----- - -func @invalid_subview(%arg0 : index, %arg1 : memref<16x4xf32>) { - // expected-error@+1 {{expected result type to have dynamic strides}} - %0 = subview %arg1[][][%arg0, %arg0] : memref<16x4xf32> to memref<4x2xf32> -} - -// ----- - -func @invalid_subview(%arg0 : index, %arg1 : memref<16x4xf32, offset: 0, strides:[?, ?]>) { - // expected-error@+1 {{expected result type to have dynamic stride along a dimension if the base memref type has dynamic stride along that dimension}} - %0 = subview %arg1[][][] : memref<16x4xf32, offset: 0, strides:[?, ?]> to memref<4x2xf32, offset:?, strides:[2, 1]> -} - -// ----- - -func @invalid_subview(%arg0 : index, %arg1 : memref) { - %c0 = constant 0 : index - %c1 = constant 1 : index - // expected-error@+1 {{expected shape of result type to be fully dynamic when sizes are specified}} - %0 = subview %arg1[%c0, %c0, %c0][%c1, %arg0, %c1][%c1, %c1, %c1] : memref to memref - return -} - -// ----- - func @invalid_memref_cast(%arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16, 1]>) { // expected-error@+1{{operand type 'memref<12x4x16xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 16 + d2)>>' and result type 'memref<12x4x16xf32, 
affine_map<(d0, d1, d2) -> (d0 * 128 + d1 * 32 + d2 * 2)>>' are cast incompatible}} %0 = memref_cast %arg0 : memref<12x4x16xf32, offset:0, strides:[64, 16, 1]> to memref<12x4x16xf32, offset:0, strides:[128, 32, 2]> diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir --- a/mlir/test/Transforms/canonicalize.mlir +++ b/mlir/test/Transforms/canonicalize.mlir @@ -427,7 +427,7 @@ return %c, %d : memref, memref } -#map1 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)> +#map1 = affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)> #map2 = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s2 + d1 * s1 + d2 + s0)> // CHECK-LABEL: func @dim_op_fold(%arg0: index, %arg1: index, %arg2: index, @@ -684,106 +684,138 @@ // CHECK-DAG: #[[SUBVIEW_MAP3:map[0-9]+]] = affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)> // CHECK-DAG: #[[SUBVIEW_MAP4:map[0-9]+]] = affine_map<(d0, d1, d2)[s0] -> (d0 * 128 + s0 + d1 * 28 + d2 * 11)> // CHECK-DAG: #[[SUBVIEW_MAP5:map[0-9]+]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s0 + d1 * s1 + d2 * s2 + 79)> -// CHECK-DAG: #[[SUBVIEW_MAP6:map[0-9]+]] = affine_map<(d0, d1)[s0] -> (d0 * 4 + s0 + d1)> -// CHECK-DAG: #[[SUBVIEW_MAP7:map[0-9]+]] = affine_map<(d0, d1) -> (d0 * 4 + d1 + 12)> +// CHECK-DAG: #[[SUBVIEW_MAP6:map[0-9]+]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2 * 2)> +// CHECK-DAG: #[[SUBVIEW_MAP7:map[0-9]+]] = affine_map<(d0, d1)[s0] -> (d0 * 4 + s0 + d1)> +// CHECK-DAG: #[[SUBVIEW_MAP8:map[0-9]+]] = affine_map<(d0, d1) -> (d0 * 4 + d1 + 12)> + // CHECK-LABEL: func @subview // CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index func @subview(%arg0 : index, %arg1 : index) -> (index, index) { // CHECK: %[[C0:.*]] = constant 0 : index %c0 = constant 0 : index - // CHECK: %[[C1:.*]] = constant 1 : index + // CHECK-NOT: constant 1 : index %c1 = constant 1 : index - // CHECK: %[[C2:.*]] = constant 2 : index + // CHECK-NOT: constant 2 : index %c2 = constant 2 : index + // Folded but reappears after subview folding into dim. // CHECK: %[[C7:.*]] = constant 7 : index %c7 = constant 7 : index + // Folded but reappears after subview folding into dim. // CHECK: %[[C11:.*]] = constant 11 : index %c11 = constant 11 : index + // CHECK-NOT: constant 15 : index %c15 = constant 15 : index // CHECK: %[[ALLOC0:.*]] = alloc() - %0 = alloc() : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>> + %0 = alloc() : memref<8x16x4xf32, offset : 0, strides : [64, 4, 1]> // Test: subview with constant base memref and constant operands is folded. // Note that the subview uses the base memrefs layout map because it used // zero offset and unit stride arguments. - // CHECK: subview %[[ALLOC0]][] [] [] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref<7x11x2xf32, #[[BASE_MAP0]]> + // CHECK: subview %[[ALLOC0]][0, 0, 0] [7, 11, 2] [1, 1, 1] : + // CHECK-SAME: memref<8x16x4xf32, #[[BASE_MAP0]]> + // CHECK-SAME: to memref<7x11x2xf32, #[[BASE_MAP0]]> %1 = subview %0[%c0, %c0, %c0] [%c7, %c11, %c2] [%c1, %c1, %c1] - : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>> to - memref (d0 * s1 + d1 * s2 + d2 * s3 + s0)>> - %v0 = load %1[%c0, %c0, %c0] : memref (d0 * s1 + d1 * s2 + d2 * s3 + s0)>> - - // Test: subview with one dynamic operand should not be folded. 
- // CHECK: subview %[[ALLOC0]][%[[C0]], %[[ARG0]], %[[C0]]] [] [] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref<7x11x15xf32, #[[SUBVIEW_MAP0]]> + : memref<8x16x4xf32, offset : 0, strides : [64, 4, 1]> to + memref + %v0 = load %1[%c0, %c0, %c0] : memref + + // Test: subview with one dynamic operand can also be folded. + // CHECK: subview %[[ALLOC0]][0, %[[ARG0]], 0] [7, 11, 15] [1, 1, 1] : + // CHECK-SAME: memref<8x16x4xf32, #[[BASE_MAP0]]> + // CHECK-SAME: to memref<7x11x15xf32, #[[SUBVIEW_MAP0]]> %2 = subview %0[%c0, %arg0, %c0] [%c7, %c11, %c15] [%c1, %c1, %c1] - : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>> to - memref (d0 * s1 + d1 * s2 + d2 * s3 + s0)>> - store %v0, %2[%c0, %c0, %c0] : memref (d0 * s1 + d1 * s2 + d2 * s3 + s0)>> + : memref<8x16x4xf32, offset : 0, strides : [64, 4, 1]> to + memref + store %v0, %2[%c0, %c0, %c0] : memref // CHECK: %[[ALLOC1:.*]] = alloc(%[[ARG0]]) - %3 = alloc(%arg0) : memref (d0 * 64 + d1 * 4 + d2)>> + %3 = alloc(%arg0) : memref // Test: subview with constant operands but dynamic base memref is folded as long as the strides and offset of the base memref are static. - // CHECK: subview %[[ALLOC1]][] [] [] : memref to memref<7x11x15xf32, #[[BASE_MAP0]]> + // CHECK: subview %[[ALLOC1]][0, 0, 0] [7, 11, 15] [1, 1, 1] : + // CHECK-SAME: memref + // CHECK-SAME: to memref<7x11x15xf32, #[[BASE_MAP0]]> %4 = subview %3[%c0, %c0, %c0] [%c7, %c11, %c15] [%c1, %c1, %c1] - : memref (d0 * 64 + d1 * 4 + d2)>> to - memref (d0 * s1 + d1 * s2 + d2 * s3 + s0)>> - store %v0, %4[%c0, %c0, %c0] : memref (d0 * s1 + d1 * s2 + d2 * s3 + s0)>> + : memref to + memref + store %v0, %4[%c0, %c0, %c0] : memref // Test: subview offset operands are folded correctly w.r.t. base strides. - // CHECK: subview %[[ALLOC0]][] [] [] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref<7x11x2xf32, #[[SUBVIEW_MAP1]]> + // CHECK: subview %[[ALLOC0]][1, 2, 7] [7, 11, 2] [1, 1, 1] : + // CHECK-SAME: memref<8x16x4xf32, #[[BASE_MAP0]]> to + // CHECK-SAME: memref<7x11x2xf32, #[[SUBVIEW_MAP1]]> %5 = subview %0[%c1, %c2, %c7] [%c7, %c11, %c2] [%c1, %c1, %c1] - : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>> to - memref (d0 * s1 + d1 * s2 + d2 * s3 + s0)>> - store %v0, %5[%c0, %c0, %c0] : memref (d0 * s1 + d1 * s2 + d2 * s3 + s0)>> + : memref<8x16x4xf32, offset : 0, strides : [64, 4, 1]> to + memref + store %v0, %5[%c0, %c0, %c0] : memref // Test: subview stride operands are folded correctly w.r.t. base strides. 
- // CHECK: subview %[[ALLOC0]][] [] [] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref<7x11x2xf32, #[[SUBVIEW_MAP2]]> + // CHECK: subview %[[ALLOC0]][0, 0, 0] [7, 11, 2] [2, 7, 11] : + // CHECK-SAME: memref<8x16x4xf32, #[[BASE_MAP0]]> + // CHECK-SAME: to memref<7x11x2xf32, #[[SUBVIEW_MAP2]]> %6 = subview %0[%c0, %c0, %c0] [%c7, %c11, %c2] [%c2, %c7, %c11] - : memref<8x16x4xf32, affine_map<(d0, d1, d2) -> (d0 * 64 + d1 * 4 + d2)>> to - memref (d0 * s1 + d1 * s2 + d2 * s3 + s0)>> - store %v0, %6[%c0, %c0, %c0] : memref (d0 * s1 + d1 * s2 + d2 * s3 + s0)>> + : memref<8x16x4xf32, offset : 0, strides : [64, 4, 1]> to + memref + store %v0, %6[%c0, %c0, %c0] : memref // Test: subview shape are folded, but offsets and strides are not even if base memref is static - // CHECK: subview %[[ALLOC0]][%[[ARG0]], %[[ARG0]], %[[ARG0]]] [] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref<7x11x2xf32, #[[SUBVIEW_MAP3]]> - %10 = subview %0[%arg0, %arg0, %arg0] [%c7, %c11, %c2] [%arg1, %arg1, %arg1] : memref<8x16x4xf32, offset:0, strides:[64, 4, 1]> to memref - store %v0, %10[%arg1, %arg1, %arg1] : memref + // CHECK: subview %[[ALLOC0]][%[[ARG0]], %[[ARG0]], %[[ARG0]]] [7, 11, 2] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] : + // CHECK-SAME: memref<8x16x4xf32, #[[BASE_MAP0]]> to + // CHECK-SAME: memref<7x11x2xf32, #[[SUBVIEW_MAP3]]> + %10 = subview %0[%arg0, %arg0, %arg0] [%c7, %c11, %c2] [%arg1, %arg1, %arg1] : + memref<8x16x4xf32, offset:0, strides:[64, 4, 1]> to + memref + store %v0, %10[%arg1, %arg1, %arg1] : + memref // Test: subview strides are folded, but offsets and shape are not even if base memref is static - // CHECK: subview %[[ALLOC0]][%[[ARG0]], %[[ARG0]], %[[ARG0]]] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] [] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref to memref - store %v0, %11[%arg0, %arg0, %arg0] : memref + // CHECK: subview %[[ALLOC0]][%[[ARG0]], %[[ARG0]], %[[ARG0]]] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] [2, 7, 11] : + // CHECK-SAME: memref<8x16x4xf32, #[[BASE_MAP0]]> to + // CHECK-SAME: memref to + memref + store %v0, %11[%arg0, %arg0, %arg0] : + memref // Test: subview offsets are folded, but strides and shape are not even if base memref is static - // CHECK: subview %[[ALLOC0]][] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] [%[[ARG0]], %[[ARG0]], %[[ARG0]]] : memref<8x16x4xf32, #[[BASE_MAP0]]> to memref to memref - store %v0, %13[%arg1, %arg1, %arg1] : memref + // CHECK: subview %[[ALLOC0]][1, 2, 7] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] [%[[ARG0]], %[[ARG0]], %[[ARG0]]] : + // CHECK-SAME: memref<8x16x4xf32, #[[BASE_MAP0]]> to + // CHECK-SAME: memref to + memref + store %v0, %13[%arg1, %arg1, %arg1] : + memref // CHECK: %[[ALLOC2:.*]] = alloc(%[[ARG0]], %[[ARG0]], %[[ARG1]]) %14 = alloc(%arg0, %arg0, %arg1) : memref // Test: subview shape are folded, even if base memref is not static - // CHECK: subview %[[ALLOC2]][%[[ARG0]], %[[ARG0]], %[[ARG0]]] [] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] : memref to memref<7x11x2xf32, #[[SUBVIEW_MAP3]]> - %15 = subview %14[%arg0, %arg0, %arg0] [%c7, %c11, %c2] [%arg1, %arg1, %arg1] : memref to memref + // CHECK: subview %[[ALLOC2]][%[[ARG0]], %[[ARG0]], %[[ARG0]]] [7, 11, 2] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] : + // CHECK-SAME: memref to + // CHECK-SAME: memref<7x11x2xf32, #[[SUBVIEW_MAP3]]> + %15 = subview %14[%arg0, %arg0, %arg0] [%c7, %c11, %c2] [%arg1, %arg1, %arg1] : + memref to + memref store %v0, %15[%arg1, %arg1, %arg1] : memref - // TEST: subview strides are not folded when the base memref is not static - // CHECK: subview %[[ALLOC2]][%[[ARG0]], 
%[[ARG0]], %[[ARG0]]] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] [%[[C2]], %[[C2]], %[[C2]]] : memref to memref to memref + // TEST: subview strides are folded, in the type only the most minor stride is folded. + // CHECK: subview %[[ALLOC2]][%[[ARG0]], %[[ARG0]], %[[ARG0]]] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] [2, 2, 2] : + // CHECK-SAME: memref to + // CHECK-SAME: memref to + memref store %v0, %16[%arg0, %arg0, %arg0] : memref - // TEST: subview offsets are not folded when the base memref is not static - // CHECK: subview %[[ALLOC2]][%[[C1]], %[[C1]], %[[C1]]] [%[[ARG0]], %[[ARG0]], %[[ARG0]]] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] : memref to memref to memref + // TEST: subview offsets are folded but the type offset remains dynamic, when the base memref is not static + // CHECK: subview %[[ALLOC2]][1, 1, 1] [%[[ARG0]], %[[ARG0]], %[[ARG0]]] [%[[ARG1]], %[[ARG1]], %[[ARG1]]] : + // CHECK-SAME: memref to + // CHECK-SAME: memref to + memref store %v0, %17[%arg0, %arg0, %arg0] : memref // CHECK: %[[ALLOC3:.*]] = alloc() : memref<12x4xf32> @@ -791,20 +823,26 @@ %c4 = constant 4 : index // TEST: subview strides are maintained when sizes are folded - // CHECK: subview %[[ALLOC3]][%arg1, %arg1] [] [] : memref<12x4xf32> to memref<2x4xf32, #[[SUBVIEW_MAP6]]> - %19 = subview %18[%arg1, %arg1] [%c2, %c4] [] : memref<12x4xf32> to memref + // CHECK: subview %[[ALLOC3]][%arg1, %arg1] [2, 4] [1, 1] : + // CHECK-SAME: memref<12x4xf32> to + // CHECK-SAME: memref<2x4xf32, #[[SUBVIEW_MAP7]]> + %19 = subview %18[%arg1, %arg1] [%c2, %c4] [1, 1] : + memref<12x4xf32> to + memref store %v0, %19[%arg1, %arg1] : memref // TEST: subview strides and sizes are maintained when offsets are folded - // CHECK: subview %[[ALLOC3]][] [] [] : memref<12x4xf32> to memref<12x4xf32, #[[SUBVIEW_MAP7]]> - %20 = subview %18[%c2, %c4] [] [] : memref<12x4xf32> to memref<12x4xf32, offset: ?, strides:[4, 1]> + // CHECK: subview %[[ALLOC3]][2, 4] [12, 4] [1, 1] : + // CHECK-SAME: memref<12x4xf32> to + // CHECK-SAME: memref<12x4xf32, #[[SUBVIEW_MAP8]]> + %20 = subview %18[%c2, %c4] [12, 4] [1, 1] : + memref<12x4xf32> to + memref<12x4xf32, offset: ?, strides:[4, 1]> store %v0, %20[%arg1, %arg1] : memref<12x4xf32, offset: ?, strides:[4, 1]> // Test: dim on subview is rewritten to size operand. - %7 = dim %4, 0 : memref (d0 * s1 + d1 * s2 + d2 * s3 + s0)>> - %8 = dim %4, 1 : memref (d0 * s1 + d1 * s2 + d2 * s3 + s0)>> + %7 = dim %4, 0 : memref + %8 = dim %4, 1 : memref // CHECK: return %[[C7]], %[[C11]] return %7, %8 : index, index @@ -898,7 +936,7 @@ func @memref_cast_folding_subview(%arg0: memref<4x5xf32>, %i: index) -> (memref) { %0 = memref_cast %arg0 : memref<4x5xf32> to memref // CHECK-NEXT: subview %{{.*}}: memref<4x5xf32> - %1 = subview %0[][%i,%i][]: memref to memref + %1 = subview %0[%i, %i][%i, %i][%i, %i]: memref to memref // CHECK-NEXT: return %{{.*}} return %1: memref }
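Taken together, the updated tests exercise the new mixed form of `subview`, where each offset, size and stride is either an SSA operand or a static integer, and the result type reflects what is statically known. The following is a minimal sketch of that syntax (a hypothetical function, not taken from this patch), under the assumption that the base memref has the canonical strides [64, 4, 1]:

```mlir
func @subview_sketch(%src : memref<8x16x4xf32>, %i : index, %sz : index) {
  // Offset %i and size %sz are dynamic, so the result type carries `offset: ?`
  // and a `?` in the corresponding dimension; all stride entries are static,
  // so the result strides [64, 4, 1] stay static.
  %v = subview %src[%i, 0, 0][4, %sz, 4][1, 1, 1]
    : memref<8x16x4xf32> to memref<4x?x4xf32, offset: ?, strides: [64, 4, 1]>
  return
}
```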