diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td --- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td +++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td @@ -201,202 +201,6 @@ let hasCanonicalizer = 1; } -// Base class for ops with static/dynamic offset, sizes and strides -// attributes/arguments. -class BaseOpWithOffsetSizesAndStrides traits = []> : - Std_Op { - code extraBaseClassDeclaration = [{ - /// Returns the number of dynamic offset operands. - int64_t getNumOffsets() { return llvm::size(offsets()); } - - /// Returns the number of dynamic size operands. - int64_t getNumSizes() { return llvm::size(sizes()); } - - /// Returns the number of dynamic stride operands. - int64_t getNumStrides() { return llvm::size(strides()); } - - /// Returns the dynamic sizes for this subview operation if specified. - operand_range getDynamicSizes() { return sizes(); } - - /// Returns in `staticStrides` the static value of the stride - /// operands. Returns failure() if the static value of the stride - /// operands could not be retrieved. - LogicalResult getStaticStrides(SmallVectorImpl &staticStrides) { - if (!strides().empty()) - return failure(); - staticStrides.reserve(static_strides().size()); - for (auto s : static_strides().getAsValueRange()) - staticStrides.push_back(s.getZExtValue()); - return success(); - } - - /// Return the list of Range (i.e. offset, size, stride). Each - /// Range entry contains either the dynamic value or a ConstantIndexOp - /// constructed with `b` at location `loc`. - SmallVector getOrCreateRanges(OpBuilder &b, Location loc); - - /// Return the offsets as Values. Each Value is either the dynamic - /// value specified in the op or a ConstantIndexOp constructed - /// with `b` at location `loc` - SmallVector getOrCreateOffsets(OpBuilder &b, Location loc) { - unsigned dynamicIdx = 1; - return llvm::to_vector<4>(llvm::map_range( - static_offsets().cast(), [&](Attribute a) -> Value { - int64_t staticOffset = a.cast().getInt(); - if (ShapedType::isDynamicStrideOrOffset(staticOffset)) - return getOperand(dynamicIdx++); - else - return b.create( - loc, b.getIndexType(), b.getIndexAttr(staticOffset)); - })); - } - - /// Return the sizes as Values. Each Value is either the dynamic - /// value specified in the op or a ConstantIndexOp constructed - /// with `b` at location `loc` - SmallVector getOrCreateSizes(OpBuilder &b, Location loc) { - unsigned dynamicIdx = 1 + offsets().size(); - return llvm::to_vector<4>(llvm::map_range( - static_sizes().cast(), [&](Attribute a) -> Value { - int64_t staticSize = a.cast().getInt(); - if (ShapedType::isDynamic(staticSize)) - return getOperand(dynamicIdx++); - else - return b.create( - loc, b.getIndexType(), b.getIndexAttr(staticSize)); - })); - } - - /// Return the strides as Values. Each Value is either the dynamic - /// value specified in the op or a ConstantIndexOp constructed with - /// `b` at location `loc` - SmallVector getOrCreateStrides(OpBuilder &b, Location loc) { - unsigned dynamicIdx = 1 + offsets().size() + sizes().size(); - return llvm::to_vector<4>(llvm::map_range( - static_strides().cast(), [&](Attribute a) -> Value { - int64_t staticStride = a.cast().getInt(); - if (ShapedType::isDynamicStrideOrOffset(staticStride)) - return getOperand(dynamicIdx++); - else - return b.create( - loc, b.getIndexType(), b.getIndexAttr(staticStride)); - })); - } - - /// Return the rank of the source ShapedType. 
- unsigned getSourceRank() { - return source().getType().cast().getRank(); - } - - /// Return the rank of the result ShapedType. - unsigned getResultRank() { return getType().getRank(); } - - /// Return true if the offset `idx` is a static constant. - bool isDynamicOffset(unsigned idx) { - APInt v = *(static_offsets().getAsValueRange().begin() + idx); - return ShapedType::isDynamicStrideOrOffset(v.getSExtValue()); - } - /// Return true if the size `idx` is a static constant. - bool isDynamicSize(unsigned idx) { - APInt v = *(static_sizes().getAsValueRange().begin() + idx); - return ShapedType::isDynamic(v.getSExtValue()); - } - - /// Return true if the stride `idx` is a static constant. - bool isDynamicStride(unsigned idx) { - APInt v = *(static_strides().getAsValueRange().begin() + idx); - return ShapedType::isDynamicStrideOrOffset(v.getSExtValue()); - } - - /// Assert the offset `idx` is a static constant and return its value. - int64_t getStaticOffset(unsigned idx) { - assert(!isDynamicOffset(idx) && "expected static offset"); - APInt v = *(static_offsets().getAsValueRange().begin() + idx); - return v.getSExtValue(); - } - /// Assert the size `idx` is a static constant and return its value. - int64_t getStaticSize(unsigned idx) { - assert(!isDynamicSize(idx) && "expected static size"); - APInt v = *(static_sizes().getAsValueRange().begin() + idx); - return v.getSExtValue(); - } - /// Assert the stride `idx` is a static constant and return its value. - int64_t getStaticStride(unsigned idx) { - assert(!isDynamicStride(idx) && "expected static stride"); - APInt v = *(static_strides().getAsValueRange().begin() + idx); - return v.getSExtValue(); - } - - unsigned getNumDynamicEntriesUpToIdx(ArrayAttr attr, - llvm::function_ref isDynamic, unsigned idx) { - return std::count_if( - attr.getValue().begin(), attr.getValue().begin() + idx, - [&](Attribute attr) { - return isDynamic(attr.cast().getInt()); - }); - } - /// Assert the offset `idx` is dynamic and return the position of the - /// corresponding operand. - unsigned getIndexOfDynamicOffset(unsigned idx) { - assert(isDynamicOffset(idx) && "expected static offset"); - auto numDynamic = - getNumDynamicEntriesUpToIdx(static_offsets().cast(), - ShapedType::isDynamicStrideOrOffset, idx); - return 1 + numDynamic; - } - /// Assert the size `idx` is dynamic and return the position of the - /// corresponding operand. - unsigned getIndexOfDynamicSize(unsigned idx) { - assert(isDynamicSize(idx) && "expected static size"); - auto numDynamic = getNumDynamicEntriesUpToIdx( - static_sizes().cast(), ShapedType::isDynamic, idx); - return 1 + offsets().size() + numDynamic; - } - /// Assert the stride `idx` is dynamic and return the position of the - /// corresponding operand. - unsigned getIndexOfDynamicStride(unsigned idx) { - assert(isDynamicStride(idx) && "expected static stride"); - auto numDynamic = - getNumDynamicEntriesUpToIdx(static_strides().cast(), - ShapedType::isDynamicStrideOrOffset, idx); - return 1 + offsets().size() + sizes().size() + numDynamic; - } - - /// Assert the offset `idx` is dynamic and return its value. - Value getDynamicOffset(unsigned idx) { - return getOperand(getIndexOfDynamicOffset(idx)); - } - /// Assert the size `idx` is dynamic and return its value. - Value getDynamicSize(unsigned idx) { - return getOperand(getIndexOfDynamicSize(idx)); - } - /// Assert the stride `idx` is dynamic and return its value. 
- Value getDynamicStride(unsigned idx) { - return getOperand(getIndexOfDynamicStride(idx)); - } - - static StringRef getStaticOffsetsAttrName() { - return "static_offsets"; - } - static StringRef getStaticSizesAttrName() { - return "static_sizes"; - } - static StringRef getStaticStridesAttrName() { - return "static_strides"; - } - static ArrayRef getSpecialAttrNames() { - static SmallVector names{ - getStaticOffsetsAttrName(), - getStaticSizesAttrName(), - getStaticStridesAttrName(), - getOperandSegmentSizeAttr()}; - return names; - } - }]; -} - - //===----------------------------------------------------------------------===// // AbsFOp //===----------------------------------------------------------------------===// @@ -2216,51 +2020,6 @@ }]; } -//===----------------------------------------------------------------------===// -// MemRefReinterpretCastOp -//===----------------------------------------------------------------------===// - -def MemRefReinterpretCastOp: - BaseOpWithOffsetSizesAndStrides<"memref_reinterpret_cast", [ - NoSideEffect, ViewLikeOpInterface - ]> { - let summary = "memref reinterpret cast operation"; - let description = [{ - Modify offset, sizes and strides of an unranked/ranked memref. - - Example: - ```mlir - memref_reinterpret_cast %ranked to - offset: [0], - sizes: [%size0, 10], - strides: [1, %stride1] - : memref to memref - - memref_reinterpret_cast %unranked to - offset: [%offset], - sizes: [%size0, %size1], - strides: [%stride0, %stride1] - : memref<*xf32> to memref - ``` - }]; - - let arguments = (ins - Arg:$source, - Variadic:$offsets, - Variadic:$sizes, - Variadic:$strides, - I64ArrayAttr:$static_offsets, - I64ArrayAttr:$static_sizes, - I64ArrayAttr:$static_strides - ); - let results = (outs AnyMemRef:$result); - let extraClassDeclaration = extraBaseClassDeclaration # [{ - // The result of the op is always a ranked memref. - MemRefType getType() { return getResult().getType().cast(); } - Value getViewSource() { return source(); } - }]; -} - //===----------------------------------------------------------------------===// // MulFOp //===----------------------------------------------------------------------===// @@ -2951,6 +2710,212 @@ // SubViewOp //===----------------------------------------------------------------------===// +class BaseOpWithOffsetSizesAndStrides traits = []> : + Std_Op { + let builders = [ + // Build a SubViewOp with mixed static and dynamic entries. + OpBuilder< + "Value source, ArrayRef staticOffsets, " + "ArrayRef staticSizes, ArrayRef staticStrides, " + "ValueRange offsets, ValueRange sizes, ValueRange strides, " + "ArrayRef attrs = {}">, + // Build a SubViewOp with all dynamic entries. + OpBuilder< + "Value source, ValueRange offsets, ValueRange sizes, ValueRange strides, " + "ArrayRef attrs = {}"> + ]; + + code extraBaseClassDeclaration = [{ + /// Returns the number of dynamic offset operands. + int64_t getNumOffsets() { return llvm::size(offsets()); } + + /// Returns the number of dynamic size operands. + int64_t getNumSizes() { return llvm::size(sizes()); } + + /// Returns the number of dynamic stride operands. + int64_t getNumStrides() { return llvm::size(strides()); } + + /// Returns the dynamic sizes for this subview operation if specified. + operand_range getDynamicSizes() { return sizes(); } + + /// Returns in `staticStrides` the static value of the stride + /// operands. Returns failure() if the static value of the stride + /// operands could not be retrieved. 
+ LogicalResult getStaticStrides(SmallVectorImpl &staticStrides) { + if (!strides().empty()) + return failure(); + staticStrides.reserve(static_strides().size()); + for (auto s : static_strides().getAsValueRange()) + staticStrides.push_back(s.getZExtValue()); + return success(); + } + + /// Return the list of Range (i.e. offset, size, stride). Each + /// Range entry contains either the dynamic value or a ConstantIndexOp + /// constructed with `b` at location `loc`. + SmallVector getOrCreateRanges(OpBuilder &b, Location loc); + + /// Return the offsets as Values. Each Value is either the dynamic + /// value specified in the op or a ConstantIndexOp constructed + /// with `b` at location `loc` + SmallVector getOrCreateOffsets(OpBuilder &b, Location loc) { + unsigned dynamicIdx = 1; + return llvm::to_vector<4>(llvm::map_range( + static_offsets().cast(), [&](Attribute a) -> Value { + int64_t staticOffset = a.cast().getInt(); + if (ShapedType::isDynamicStrideOrOffset(staticOffset)) + return getOperand(dynamicIdx++); + else + return b.create( + loc, b.getIndexType(), b.getIndexAttr(staticOffset)); + })); + } + + /// Return the sizes as Values. Each Value is either the dynamic + /// value specified in the op or a ConstantIndexOp constructed + /// with `b` at location `loc` + SmallVector getOrCreateSizes(OpBuilder &b, Location loc) { + unsigned dynamicIdx = 1 + offsets().size(); + return llvm::to_vector<4>(llvm::map_range( + static_sizes().cast(), [&](Attribute a) -> Value { + int64_t staticSize = a.cast().getInt(); + if (ShapedType::isDynamic(staticSize)) + return getOperand(dynamicIdx++); + else + return b.create( + loc, b.getIndexType(), b.getIndexAttr(staticSize)); + })); + } + + /// Return the strides as Values. Each Value is either the dynamic + /// value specified in the op or a ConstantIndexOp constructed with + /// `b` at location `loc` + SmallVector getOrCreateStrides(OpBuilder &b, Location loc) { + unsigned dynamicIdx = 1 + offsets().size() + sizes().size(); + return llvm::to_vector<4>(llvm::map_range( + static_strides().cast(), [&](Attribute a) -> Value { + int64_t staticStride = a.cast().getInt(); + if (ShapedType::isDynamicStrideOrOffset(staticStride)) + return getOperand(dynamicIdx++); + else + return b.create( + loc, b.getIndexType(), b.getIndexAttr(staticStride)); + })); + } + + /// Return the rank of the source ShapedType. + unsigned getSourceRank() { + return source().getType().cast().getRank(); + } + + /// Return the rank of the result ShapedType. + unsigned getResultRank() { return getType().getRank(); } + + /// Return true if the offset `idx` is a static constant. + bool isDynamicOffset(unsigned idx) { + APInt v = *(static_offsets().getAsValueRange().begin() + idx); + return ShapedType::isDynamicStrideOrOffset(v.getSExtValue()); + } + /// Return true if the size `idx` is a static constant. + bool isDynamicSize(unsigned idx) { + APInt v = *(static_sizes().getAsValueRange().begin() + idx); + return ShapedType::isDynamic(v.getSExtValue()); + } + + /// Return true if the stride `idx` is a static constant. + bool isDynamicStride(unsigned idx) { + APInt v = *(static_strides().getAsValueRange().begin() + idx); + return ShapedType::isDynamicStrideOrOffset(v.getSExtValue()); + } + + /// Assert the offset `idx` is a static constant and return its value. 
+ int64_t getStaticOffset(unsigned idx) { + assert(!isDynamicOffset(idx) && "expected static offset"); + APInt v = *(static_offsets().getAsValueRange().begin() + idx); + return v.getSExtValue(); + } + /// Assert the size `idx` is a static constant and return its value. + int64_t getStaticSize(unsigned idx) { + assert(!isDynamicSize(idx) && "expected static size"); + APInt v = *(static_sizes().getAsValueRange().begin() + idx); + return v.getSExtValue(); + } + /// Assert the stride `idx` is a static constant and return its value. + int64_t getStaticStride(unsigned idx) { + assert(!isDynamicStride(idx) && "expected static stride"); + APInt v = *(static_strides().getAsValueRange().begin() + idx); + return v.getSExtValue(); + } + + unsigned getNumDynamicEntriesUpToIdx(ArrayAttr attr, + llvm::function_ref isDynamic, unsigned idx) { + return std::count_if( + attr.getValue().begin(), attr.getValue().begin() + idx, + [&](Attribute attr) { + return isDynamic(attr.cast().getInt()); + }); + } + /// Assert the offset `idx` is dynamic and return the position of the + /// corresponding operand. + unsigned getIndexOfDynamicOffset(unsigned idx) { + assert(isDynamicOffset(idx) && "expected static offset"); + auto numDynamic = + getNumDynamicEntriesUpToIdx(static_offsets().cast(), + ShapedType::isDynamicStrideOrOffset, idx); + return 1 + numDynamic; + } + /// Assert the size `idx` is dynamic and return the position of the + /// corresponding operand. + unsigned getIndexOfDynamicSize(unsigned idx) { + assert(isDynamicSize(idx) && "expected static size"); + auto numDynamic = getNumDynamicEntriesUpToIdx( + static_sizes().cast(), ShapedType::isDynamic, idx); + return 1 + offsets().size() + numDynamic; + } + /// Assert the stride `idx` is dynamic and return the position of the + /// corresponding operand. + unsigned getIndexOfDynamicStride(unsigned idx) { + assert(isDynamicStride(idx) && "expected static stride"); + auto numDynamic = + getNumDynamicEntriesUpToIdx(static_strides().cast(), + ShapedType::isDynamicStrideOrOffset, idx); + return 1 + offsets().size() + sizes().size() + numDynamic; + } + + /// Assert the offset `idx` is dynamic and return its value. + Value getDynamicOffset(unsigned idx) { + return getOperand(getIndexOfDynamicOffset(idx)); + } + /// Assert the size `idx` is dynamic and return its value. + Value getDynamicSize(unsigned idx) { + return getOperand(getIndexOfDynamicSize(idx)); + } + /// Assert the stride `idx` is dynamic and return its value. 
+ Value getDynamicStride(unsigned idx) { + return getOperand(getIndexOfDynamicStride(idx)); + } + + static StringRef getStaticOffsetsAttrName() { + return "static_offsets"; + } + static StringRef getStaticSizesAttrName() { + return "static_sizes"; + } + static StringRef getStaticStridesAttrName() { + return "static_strides"; + } + static ArrayRef getSpecialAttrNames() { + static SmallVector names{ + getStaticOffsetsAttrName(), + getStaticSizesAttrName(), + getStaticStridesAttrName(), + getOperandSegmentSizeAttr()}; + return names; + } + }]; +} + def SubViewOp : BaseOpWithOffsetSizesAndStrides< "subview", [DeclareOpInterfaceMethods] > { let summary = "memref subview operation"; diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -2416,6 +2416,122 @@ } }; +struct MemRefReinterpretCastOpLowering + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(Operation *op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + auto castOp = cast(op); + MemRefReinterpretCastOp::Adaptor adaptor(operands, op->getAttrDictionary()); + Type srcType = castOp.source().getType(); + + Value descriptor; + if (failed(ConvertSourceMemRefToDescriptor(rewriter, srcType, castOp, + adaptor, &descriptor))) + return failure(); + rewriter.replaceOp(op, {descriptor}); + return success(); + } + +private: + LogicalResult + ConvertSourceMemRefToDescriptor(ConversionPatternRewriter &rewriter, + Type srcType, MemRefReinterpretCastOp castOp, + MemRefReinterpretCastOp::Adaptor adaptor, + Value *descriptor) const { + MemRefType targetMemRefType = + castOp.getResult().getType().cast(); + auto llvmTargetDescriptorTy = typeConverter.convertType(targetMemRefType) + .dyn_cast_or_null(); + if (!llvmTargetDescriptorTy || !llvmTargetDescriptorTy.isStructTy()) + return failure(); + + // Create descriptor. + Location loc = castOp.getLoc(); + auto desc = MemRefDescriptor::undef(rewriter, loc, llvmTargetDescriptorTy); + + // Set allocated and aligned pointers. + Value allocatedPtr, alignedPtr; + ExtractPointers(loc, rewriter, castOp.source(), adaptor.source(), + &allocatedPtr, &alignedPtr); + desc.setAllocatedPtr(rewriter, loc, allocatedPtr); + desc.setAlignedPtr(rewriter, loc, alignedPtr); + + // Set offset. + if (castOp.isDynamicOffset(0)) { + desc.setOffset(rewriter, loc, adaptor.offsets()[0]); + } else { + desc.setConstantOffset(rewriter, loc, castOp.getStaticOffset(0)); + } + + // Set sizes and strides. 
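+    // Dynamic sizes and strides consume the converted operands from
+    // `adaptor.sizes()` / `adaptor.strides()` in declaration order; static
+    // entries are materialized as constants directly in the target descriptor.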
+ unsigned dynSizeId = 0; + unsigned dynStrideId = 0; + for (unsigned i = 0, e = targetMemRefType.getRank(); i < e; ++i) { + if (castOp.isDynamicSize(i)) { + desc.setSize(rewriter, loc, i, adaptor.sizes()[dynSizeId++]); + } else { + desc.setConstantSize(rewriter, loc, i, castOp.getStaticSize(i)); + } + if (castOp.isDynamicStride(i)) { + desc.setStride(rewriter, loc, i, adaptor.strides()[dynStrideId++]); + } else { + desc.setConstantStride(rewriter, loc, i, castOp.getStaticStride(i)); + } + } + *descriptor = desc; + return success(); + } + + void ExtractPointers(Location loc, ConversionPatternRewriter &rewriter, + Value originalOperand, Value convertedOperand, + Value *allocatedPtr, Value *alignedPtr) const { + Type operandType = originalOperand.getType(); + if (operandType.isa()) { + MemRefDescriptor desc(convertedOperand); + *allocatedPtr = desc.allocatedPtr(rewriter, loc); + *alignedPtr = desc.alignedPtr(rewriter, loc); + return; + } + + unsigned memorySpace = + operandType.cast().getMemorySpace(); + LLVM::LLVMType elementType = + typeConverter + .convertType( + operandType.cast().getElementType()) + .cast(); + LLVM::LLVMType elementPtrPtrType = + elementType.getPointerTo(memorySpace).getPointerTo(memorySpace); + + // Extract pointer to the underlying ranked memref descriptor and cast it to + // ElemType**. + UnrankedMemRefDescriptor unrankedDesc(convertedOperand); + Value underlyingDescPtr = unrankedDesc.memRefDescPtr(rewriter, loc); + Value elementPtrPtr = rewriter.create( + loc, elementPtrPtrType, underlyingDescPtr); + + LLVM::LLVMType int32Type = + typeConverter.convertType(rewriter.getI32Type()).cast(); + + // Extract and set allocated pointer. + Value zero = rewriter.create( + loc, int32Type, rewriter.getI32IntegerAttr(memorySpace)); + Value base_gep = rewriter.create( + loc, elementPtrPtrType, elementPtrPtr, ValueRange({zero})); + *allocatedPtr = rewriter.create(loc, base_gep); + + // Extract and set aligned pointer. + Value one = rewriter.create( + loc, int32Type, rewriter.getI32IntegerAttr(1)); + Value aligned_gep = rewriter.create( + loc, elementPtrPtrType, elementPtrPtr, ValueRange({one})); + *alignedPtr = rewriter.create(loc, aligned_gep); + } +}; + struct DialectCastOpLowering : public ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; @@ -3532,6 +3648,7 @@ DimOpLowering, LoadOpLowering, MemRefCastOpLowering, + MemRefReinterpretCastOpLowering, RankOpLowering, StoreOpLowering, SubViewOpLowering, diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp --- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp +++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp @@ -261,126 +261,6 @@ [](APInt a, APInt b) { return a + b; }); } -//===----------------------------------------------------------------------===// -// BaseOpWithOffsetSizesAndStridesOp -//===----------------------------------------------------------------------===// - -/// Print a list with either (1) the static integer value in `arrayAttr` if -/// `isDynamic` evaluates to false or (2) the next value otherwise. -/// This allows idiomatic printing of mixed value and integer attributes in a -/// list. E.g. `[%arg0, 7, 42, %arg42]`. 
-static void -printListOfOperandsOrIntegers(OpAsmPrinter &p, ValueRange values, - ArrayAttr arrayAttr, - llvm::function_ref isDynamic) { - p << '['; - unsigned idx = 0; - llvm::interleaveComma(arrayAttr, p, [&](Attribute a) { - int64_t val = a.cast().getInt(); - if (isDynamic(val)) - p << values[idx++]; - else - p << val; - }); - p << ']'; -} - -/// Parse a mixed list with either (1) static integer values or (2) SSA values. -/// Fill `result` with the integer ArrayAttr named `attrName` where `dynVal` -/// encode the position of SSA values. Add the parsed SSA values to `ssa` -/// in-order. -// -/// E.g. after parsing "[%arg0, 7, 42, %arg42]": -/// 1. `result` is filled with the i64 ArrayAttr "[`dynVal`, 7, 42, `dynVal`]" -/// 2. `ssa` is filled with "[%arg0, %arg1]". -static ParseResult -parseListOfOperandsOrIntegers(OpAsmParser &parser, OperationState &result, - StringRef attrName, int64_t dynVal, - SmallVectorImpl &ssa) { - if (failed(parser.parseLSquare())) - return failure(); - // 0-D. - if (succeeded(parser.parseOptionalRSquare())) { - result.addAttribute(attrName, parser.getBuilder().getArrayAttr({})); - return success(); - } - - SmallVector attrVals; - while (true) { - OpAsmParser::OperandType operand; - auto res = parser.parseOptionalOperand(operand); - if (res.hasValue() && succeeded(res.getValue())) { - ssa.push_back(operand); - attrVals.push_back(dynVal); - } else { - IntegerAttr attr; - if (failed(parser.parseAttribute(attr))) - return parser.emitError(parser.getNameLoc()) - << "expected SSA value or integer"; - attrVals.push_back(attr.getInt()); - } - - if (succeeded(parser.parseOptionalComma())) - continue; - if (failed(parser.parseRSquare())) - return failure(); - break; - } - - auto arrayAttr = parser.getBuilder().getI64ArrayAttr(attrVals); - result.addAttribute(attrName, arrayAttr); - return success(); -} - -/// Verify that a particular offset/size/stride static attribute is well-formed. -template -static LogicalResult verifyOpWithOffsetSizesAndStridesPart( - OpType op, StringRef name, unsigned expectedNumElements, StringRef attrName, - ArrayAttr attr, llvm::function_ref isDynamic, - ValueRange values) { - /// Check static and dynamic offsets/sizes/strides breakdown. - if (attr.size() != expectedNumElements) - return op.emitError("expected ") - << expectedNumElements << " " << name << " values"; - unsigned expectedNumDynamicEntries = - llvm::count_if(attr.getValue(), [&](Attribute attr) { - return isDynamic(attr.cast().getInt()); - }); - if (values.size() != expectedNumDynamicEntries) - return op.emitError("expected ") - << expectedNumDynamicEntries << " dynamic " << name << " values"; - return success(); -} - -/// Extract int64_t values from the assumed ArrayAttr of IntegerAttr. -static SmallVector extractFromI64ArrayAttr(Attribute attr) { - return llvm::to_vector<4>( - llvm::map_range(attr.cast(), [](Attribute a) -> int64_t { - return a.cast().getInt(); - })); -} - -/// Verify static attributes offsets/sizes/strides. 
-template -static LogicalResult verifyOpWithOffsetSizesAndStrides(OpType op) { - unsigned srcRank = op.getSourceRank(); - if (failed(verifyOpWithOffsetSizesAndStridesPart( - op, "offset", srcRank, op.getStaticOffsetsAttrName(), - op.static_offsets(), ShapedType::isDynamicStrideOrOffset, - op.offsets()))) - return failure(); - if (failed(verifyOpWithOffsetSizesAndStridesPart( - op, "size", srcRank, op.getStaticSizesAttrName(), op.static_sizes(), - ShapedType::isDynamic, op.sizes()))) - return failure(); - if (failed(verifyOpWithOffsetSizesAndStridesPart( - op, "stride", srcRank, op.getStaticStridesAttrName(), - op.static_strides(), ShapedType::isDynamicStrideOrOffset, - op.strides()))) - return failure(); - return success(); -} - //===----------------------------------------------------------------------===// // AllocOp / AllocaOp //===----------------------------------------------------------------------===// @@ -2265,169 +2145,6 @@ return impl::foldCastOp(*this); } -//===----------------------------------------------------------------------===// -// MemRefReinterpretCastOp -//===----------------------------------------------------------------------===// - -/// Print of the form: -/// ``` -/// `name` ssa-name to -/// offset: `[` offset `]` -/// sizes: `[` size-list `]` -/// strides:`[` stride-list `]` -/// `:` any-memref-type to strided-memref-type -/// ``` -static void print(OpAsmPrinter &p, MemRefReinterpretCastOp op) { - int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1; - p << op.getOperationName().drop_front(stdDotLen) << " " << op.source() - << " to offset: "; - printListOfOperandsOrIntegers(p, op.offsets(), op.static_offsets(), - ShapedType::isDynamicStrideOrOffset); - p << ", sizes: "; - printListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(), - ShapedType::isDynamic); - p << ", strides: "; - printListOfOperandsOrIntegers(p, op.strides(), op.static_strides(), - ShapedType::isDynamicStrideOrOffset); - p.printOptionalAttrDict( - op.getAttrs(), - /*elidedAttrs=*/{MemRefReinterpretCastOp::getOperandSegmentSizeAttr(), - MemRefReinterpretCastOp::getStaticOffsetsAttrName(), - MemRefReinterpretCastOp::getStaticSizesAttrName(), - MemRefReinterpretCastOp::getStaticStridesAttrName()}); - p << ": " << op.source().getType() << " to " << op.getType(); -} - -/// Parse of the form: -/// ``` -/// `name` ssa-name to -/// offset: `[` offset `]` -/// sizes: `[` size-list `]` -/// strides:`[` stride-list `]` -/// `:` any-memref-type to strided-memref-type -/// ``` -static ParseResult parseMemRefReinterpretCastOp(OpAsmParser &parser, - OperationState &result) { - // Parse `operand` and `offset`. - OpAsmParser::OperandType operand; - if (parser.parseOperand(operand)) - return failure(); - - // Parse offset. - SmallVector offset; - if (parser.parseKeyword("to") || parser.parseKeyword("offset") || - parser.parseColon() || - parseListOfOperandsOrIntegers( - parser, result, MemRefReinterpretCastOp::getStaticOffsetsAttrName(), - ShapedType::kDynamicStrideOrOffset, offset) || - parser.parseComma()) - return failure(); - - // Parse `sizes`. - SmallVector sizes; - if (parser.parseKeyword("sizes") || parser.parseColon() || - parseListOfOperandsOrIntegers( - parser, result, MemRefReinterpretCastOp::getStaticSizesAttrName(), - ShapedType::kDynamicSize, sizes) || - parser.parseComma()) - return failure(); - - // Parse `strides`. 
- SmallVector strides; - if (parser.parseKeyword("strides") || parser.parseColon() || - parseListOfOperandsOrIntegers( - parser, result, MemRefReinterpretCastOp::getStaticStridesAttrName(), - ShapedType::kDynamicStrideOrOffset, strides)) - return failure(); - - // Handle segment sizes. - auto b = parser.getBuilder(); - SmallVector segmentSizes = {1, static_cast(offset.size()), - static_cast(sizes.size()), - static_cast(strides.size())}; - result.addAttribute(MemRefReinterpretCastOp::getOperandSegmentSizeAttr(), - - b.getI32VectorAttr(segmentSizes)); - - // Parse types and resolve. - Type indexType = b.getIndexType(); - Type operandType, resultType; - return failure( - (parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(operandType) || parser.parseKeyword("to") || - parser.parseType(resultType) || - parser.resolveOperand(operand, operandType, result.operands) || - parser.resolveOperands(offset, indexType, result.operands) || - parser.resolveOperands(sizes, indexType, result.operands) || - parser.resolveOperands(strides, indexType, result.operands) || - parser.addTypeToList(resultType, result.types))); -} - -static LogicalResult verify(MemRefReinterpretCastOp op) { - // The source and result memrefs should be in the same memory space. - auto srcType = op.source().getType().cast(); - auto resultType = op.getType().cast(); - if (srcType.getMemorySpace() != resultType.getMemorySpace()) - return op.emitError("different memory spaces specified for source type ") - << srcType << " and result memref type " << resultType; - if (srcType.getElementType() != resultType.getElementType()) - return op.emitError("different element types specified for source type ") - << srcType << " and result memref type " << resultType; - - // Verify that dynamic and static offset/sizes/strides arguments/attributes - // are consistent. - if (failed(verifyOpWithOffsetSizesAndStridesPart( - op, "offset", 1, op.getStaticOffsetsAttrName(), op.static_offsets(), - ShapedType::isDynamicStrideOrOffset, op.offsets()))) - return failure(); - unsigned resultRank = op.getResultRank(); - if (failed(verifyOpWithOffsetSizesAndStridesPart( - op, "size", resultRank, op.getStaticSizesAttrName(), - op.static_sizes(), ShapedType::isDynamic, op.sizes()))) - return failure(); - if (failed(verifyOpWithOffsetSizesAndStridesPart( - op, "stride", resultRank, op.getStaticStridesAttrName(), - op.static_strides(), ShapedType::isDynamicStrideOrOffset, - op.strides()))) - return failure(); - - // Extract source offset and strides. - int64_t resultOffset; - SmallVector resultStrides; - if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset))) - return failure(); - - // Match offset in result memref type and in static_offsets attribute. - int64_t expectedOffset = extractFromI64ArrayAttr(op.static_offsets()).front(); - if (resultOffset != expectedOffset) - return op.emitError("expected result type with offset = ") - << resultOffset << " instead of " << expectedOffset; - - // Match sizes in result memref type and in static_sizes attribute. - for (auto &en : - llvm::enumerate(llvm::zip(resultType.getShape(), - extractFromI64ArrayAttr(op.static_sizes())))) { - int64_t resultSize = std::get<0>(en.value()); - int64_t expectedSize = std::get<1>(en.value()); - if (resultSize != expectedSize) - return op.emitError("expected result type with size = ") - << expectedSize << " instead of " << resultSize - << " in dim = " << en.index(); - } - - // Match strides in result memref type and in static_strides attribute. 
- for (auto &en : llvm::enumerate(llvm::zip( - resultStrides, extractFromI64ArrayAttr(op.static_strides())))) { - int64_t resultStride = std::get<0>(en.value()); - int64_t expectedStride = std::get<1>(en.value()); - if (resultStride != expectedStride) - return op.emitError("expected result type with stride = ") - << expectedStride << " instead of " << resultStride - << " in dim = " << en.index(); - } - return success(); -} - //===----------------------------------------------------------------------===// // MulFOp //===----------------------------------------------------------------------===// @@ -2825,6 +2542,75 @@ // SubViewOp //===----------------------------------------------------------------------===// +/// Print a list with either (1) the static integer value in `arrayAttr` if +/// `isDynamic` evaluates to false or (2) the next value otherwise. +/// This allows idiomatic printing of mixed value and integer attributes in a +/// list. E.g. `[%arg0, 7, 42, %arg42]`. +static void printSubViewListOfOperandsOrIntegers( + OpAsmPrinter &p, ValueRange values, ArrayAttr arrayAttr, + llvm::function_ref isDynamic) { + p << "["; + unsigned idx = 0; + llvm::interleaveComma(arrayAttr, p, [&](Attribute a) { + int64_t val = a.cast().getInt(); + if (isDynamic(val)) + p << values[idx++]; + else + p << val; + }); + p << "] "; +} + +/// Parse a mixed list with either (1) static integer values or (2) SSA values. +/// Fill `result` with the integer ArrayAttr named `attrName` where `dynVal` +/// encode the position of SSA values. Add the parsed SSA values to `ssa` +/// in-order. +// +/// E.g. after parsing "[%arg0, 7, 42, %arg42]": +/// 1. `result` is filled with the i64 ArrayAttr "[`dynVal`, 7, 42, `dynVal`]" +/// 2. `ssa` is filled with "[%arg0, %arg1]". +static ParseResult +parseListOfOperandsOrIntegers(OpAsmParser &parser, OperationState &result, + StringRef attrName, int64_t dynVal, + SmallVectorImpl &ssa) { + if (failed(parser.parseLSquare())) + return failure(); + // 0-D. + if (succeeded(parser.parseOptionalRSquare())) { + result.addAttribute(attrName, parser.getBuilder().getArrayAttr({})); + return success(); + } + + SmallVector attrVals; + while (true) { + OpAsmParser::OperandType operand; + auto res = parser.parseOptionalOperand(operand); + if (res.hasValue() && succeeded(res.getValue())) { + ssa.push_back(operand); + attrVals.push_back(dynVal); + } else { + Attribute attr; + NamedAttrList placeholder; + if (failed(parser.parseAttribute(attr, "_", placeholder)) || + !attr.isa()) + return parser.emitError(parser.getNameLoc()) + << "expected SSA value or integer"; + attrVals.push_back(attr.cast().getInt()); + } + + if (succeeded(parser.parseOptionalComma())) + continue; + if (failed(parser.parseRSquare())) + return failure(); + else + break; + } + + auto arrayAttr = parser.getBuilder().getI64ArrayAttr(attrVals); + result.addAttribute(attrName, arrayAttr); + return success(); +} + namespace { /// Helpers to write more idiomatic operations. 
namespace saturated_arith { @@ -2912,15 +2698,12 @@ p << op.getOperation()->getName().getStringRef().drop_front(stdDotLen) << ' '; p << op.source(); printExtraOperands(p, op); - printListOfOperandsOrIntegers(p, op.offsets(), op.static_offsets(), - ShapedType::isDynamicStrideOrOffset); - p << ' '; - printListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(), - ShapedType::isDynamic); - p << ' '; - printListOfOperandsOrIntegers(p, op.strides(), op.static_strides(), - ShapedType::isDynamicStrideOrOffset); - p << ' '; + printSubViewListOfOperandsOrIntegers(p, op.offsets(), op.static_offsets(), + ShapedType::isDynamicStrideOrOffset); + printSubViewListOfOperandsOrIntegers(p, op.sizes(), op.static_sizes(), + ShapedType::isDynamic); + printSubViewListOfOperandsOrIntegers(p, op.strides(), op.static_strides(), + ShapedType::isDynamicStrideOrOffset); p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{OpType::getSpecialAttrNames()}); p << " : " << op.getSourceType() << " " << resultTypeKeyword << " " @@ -3060,6 +2843,33 @@ /// For ViewLikeOpInterface. Value SubViewOp::getViewSource() { return source(); } +/// Verify that a particular offset/size/stride static attribute is well-formed. +template +static LogicalResult verifyOpWithOffsetSizesAndStridesPart( + OpType op, StringRef name, StringRef attrName, ArrayAttr attr, + llvm::function_ref isDynamic, ValueRange values) { + /// Check static and dynamic offsets/sizes/strides breakdown. + if (attr.size() != op.getSourceRank()) + return op.emitError("expected ") + << op.getSourceRank() << " " << name << " values"; + unsigned expectedNumDynamicEntries = + llvm::count_if(attr.getValue(), [&](Attribute attr) { + return isDynamic(attr.cast().getInt()); + }); + if (values.size() != expectedNumDynamicEntries) + return op.emitError("expected ") + << expectedNumDynamicEntries << " dynamic " << name << " values"; + return success(); +} + +/// Helper function extracts int64_t from the assumedArrayAttr of IntegerAttr. +static SmallVector extractFromI64ArrayAttr(Attribute attr) { + return llvm::to_vector<4>( + llvm::map_range(attr.cast(), [](Attribute a) -> int64_t { + return a.cast().getInt(); + })); +} + llvm::Optional> mlir::computeRankReductionMask(ArrayRef originalShape, ArrayRef reducedShape) { @@ -3195,6 +3005,24 @@ } } +template +static LogicalResult verifyOpWithOffsetSizesAndStrides(OpType op) { + // Verify static attributes offsets/sizes/strides. + if (failed(verifyOpWithOffsetSizesAndStridesPart( + op, "offset", op.getStaticOffsetsAttrName(), op.static_offsets(), + ShapedType::isDynamicStrideOrOffset, op.offsets()))) + return failure(); + + if (failed(verifyOpWithOffsetSizesAndStridesPart( + op, "size", op.getStaticSizesAttrName(), op.static_sizes(), + ShapedType::isDynamic, op.sizes()))) + return failure(); + if (failed(verifyOpWithOffsetSizesAndStridesPart( + op, "stride", op.getStaticStridesAttrName(), op.static_strides(), + ShapedType::isDynamicStrideOrOffset, op.strides()))) + return failure(); + return success(); +} /// Verifier for SubViewOp. 
static LogicalResult verify(SubViewOp op) {
diff --git a/mlir/test/Dialect/Standard/invalid.mlir b/mlir/test/Dialect/Standard/invalid.mlir
--- a/mlir/test/Dialect/Standard/invalid.mlir
+++ b/mlir/test/Dialect/Standard/invalid.mlir
@@ -102,82 +102,3 @@
   // expected-error @+1 {{output type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' does not match transposed input type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>'}}
   transpose %v (i, j) -> (j, i) : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
 }
-
-// -----
-
-// CHECK-LABEL: func @memref_reinterpret_cast_too_many_offsets
-func @memref_reinterpret_cast_too_many_offsets(%in: memref<?xf32>) {
-  // expected-error @+1 {{expected 1 offset values}}
-  %out = memref_reinterpret_cast %in to
-           offset: [0, 0], sizes: [10, 10], strides: [10, 1]
-           : memref<?xf32> to memref<10x10xf32, offset: 0, strides: [10, 1]>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: func @memref_reinterpret_cast_incompatible_element_types
-func @memref_reinterpret_cast_incompatible_element_types(%in: memref<*xf32>) {
-  // expected-error @+1 {{different element types specified}}
-  %out = memref_reinterpret_cast %in to
-           offset: [0], sizes: [10], strides: [1]
-           : memref<*xf32> to memref<10xi32, offset: 0, strides: [1]>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: func @memref_reinterpret_cast_incompatible_memory_space
-func @memref_reinterpret_cast_incompatible_memory_space(%in: memref<*xf32>) {
-  // expected-error @+1 {{different memory spaces specified}}
-  %out = memref_reinterpret_cast %in to
-           offset: [0], sizes: [10], strides: [1]
-           : memref<*xf32> to memref<10xi32, offset: 0, strides: [1], 2>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: func @memref_reinterpret_cast_offset_mismatch
-func @memref_reinterpret_cast_offset_mismatch(%in: memref<?xf32>) {
-  // expected-error @+1 {{expected result type with offset = 0 instead of 1}}
-  %out = memref_reinterpret_cast %in to
-           offset: [1], sizes: [10], strides: [1]
-           : memref<?xf32> to memref<10xf32>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: func @memref_reinterpret_cast_size_mismatch
-func @memref_reinterpret_cast_size_mismatch(%in: memref<*xf32>) {
-  // expected-error @+1 {{expected result type with size = 10 instead of 1 in dim = 0}}
-  %out = memref_reinterpret_cast %in to
-           offset: [0], sizes: [10], strides: [1]
-           : memref<*xf32> to memref<1xf32, offset: 0, strides: [1]>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: func @memref_reinterpret_cast_stride_mismatch
-func @memref_reinterpret_cast_offset_mismatch(%in: memref<?xf32>) {
-  // expected-error @+1 {{expected result type with stride = 2 instead of 1 in dim = 0}}
-  %out = memref_reinterpret_cast %in to
-           offset: [0], sizes: [10], strides: [2]
-           : memref<?xf32> to memref<10xf32>
-  return
-}
-
-// -----
-
-// CHECK-LABEL: func @memref_reinterpret_cast_dynamic_size_mismatch
-func @memref_reinterpret_cast_offset_mismatch(%in: memref<?xf32>) {
-  %c0 = constant 0 : index
-  %c10 = constant 10 : index
-  // expected-error @+1 {{expected result type with size = 10 instead of -1 in dim = 0}}
-  %out = memref_reinterpret_cast %in to
-           offset: [%c0], sizes: [10, %c10], strides: [%c10, 1]
-           : memref<?xf32> to memref<?x?xf32, offset: ?, strides: [?, 1]>
-  return
-}
diff --git a/mlir/test/Dialect/Standard/ops.mlir b/mlir/test/Dialect/Standard/ops.mlir
--- a/mlir/test/Dialect/Standard/ops.mlir
+++ b/mlir/test/Dialect/Standard/ops.mlir
@@ -54,14 +54,3 @@
   %result = atan2 %arg0, %arg1 : f32
   return %result : f32
 }
-
-// CHECK-LABEL: func @memref_reinterpret_cast
-func @memref_reinterpret_cast(%in: memref<?xf32>)
-    -> memref<10x?xf32, offset: ?, strides: [?, 1]> {
-  %c0 = constant 0 : index
-  %c10 = constant 10 : index
-  %out = memref_reinterpret_cast %in to
-           offset: [%c0], sizes: [10, %c10], strides: [%c10, 1]
-           : memref<?xf32> to memref<10x?xf32, offset: ?, strides: [?, 1]>
-  return %out : memref<10x?xf32, offset: ?, strides: [?, 1]>
-}
diff --git a/mlir/test/mlir-cpu-runner/memref_reinterpret_cast.mlir b/mlir/test/mlir-cpu-runner/memref_reinterpret_cast.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/mlir-cpu-runner/memref_reinterpret_cast.mlir
@@ -0,0 +1,103 @@
+// RUN: mlir-opt %s -convert-scf-to-std -convert-std-to-llvm | mlir-cpu-runner -e main -entry-point-result=void -shared-libs=%mlir_runner_utils_dir/libmlir_runner_utils%shlibext,%mlir_runner_utils_dir/libmlir_c_runner_utils%shlibext | FileCheck %s
+
+func @print_memref_f32(memref<*xf32>) attributes { llvm.emit_c_interface }
+
+func @main() -> () {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+
+  // Initialize input.
+  %input = alloc() : memref<2x3xf32>
+  %dim_x = dim %input, %c0 : memref<2x3xf32>
+  %dim_y = dim %input, %c1 : memref<2x3xf32>
+  scf.parallel (%i, %j) = (%c0, %c0) to (%dim_x, %dim_y) step (%c1, %c1) {
+    %prod = muli %i, %dim_y : index
+    %val = addi %prod, %j : index
+    %val_i64 = index_cast %val : index to i64
+    %val_f32 = sitofp %val_i64 : i64 to f32
+    store %val_f32, %input[%i, %j] : memref<2x3xf32>
+  }
+  %unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_input) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [2, 3] strides = [3, 1]
+  // CHECK-NEXT: [0, 1, 2]
+  // CHECK-NEXT: [3, 4, 5]
+
+  // Test cases.
+  call @cast_ranked_memref_to_static_shape(%input) : (memref<2x3xf32>) -> ()
+  call @cast_ranked_memref_to_dynamic_shape(%input) : (memref<2x3xf32>) -> ()
+  call @cast_unranked_memref_to_static_shape(%input) : (memref<2x3xf32>) -> ()
+  call @cast_unranked_memref_to_dynamic_shape(%input) : (memref<2x3xf32>) -> ()
+  return
+}
+
+func @cast_ranked_memref_to_static_shape(%input : memref<2x3xf32>) {
+  %output = memref_reinterpret_cast %input to
+           offset: [0], sizes: [6, 1], strides: [1, 1]
+           : memref<2x3xf32> to memref<6x1xf32>
+
+  %unranked_output = memref_cast %output
+      : memref<6x1xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [6, 1] strides = [1, 1] data =
+  // CHECK-NEXT: [0],
+  // CHECK-NEXT: [1],
+  // CHECK-NEXT: [2],
+  // CHECK-NEXT: [3],
+  // CHECK-NEXT: [4],
+  // CHECK-NEXT: [5]
+  return
+}
+
+func @cast_ranked_memref_to_dynamic_shape(%input : memref<2x3xf32>) {
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c6 = constant 6 : index
+  %output = memref_reinterpret_cast %input to
+           offset: [%c0], sizes: [%c1, %c6], strides: [%c6, %c1]
+           : memref<2x3xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
+
+  %unranked_output = memref_cast %output
+      : memref<?x?xf32, offset: ?, strides: [?, ?]> to memref<*xf32>
+  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [1, 6] strides = [6, 1] data =
+  // CHECK-NEXT: [0, 1, 2, 3, 4, 5]
+  return
+}
+
+func @cast_unranked_memref_to_static_shape(%input : memref<2x3xf32>) {
+  %unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32>
+  %output = memref_reinterpret_cast %unranked_input to
+           offset: [0], sizes: [6, 1], strides: [1, 1]
+           : memref<*xf32> to memref<6x1xf32>
+
+  %unranked_output = memref_cast %output
+      : memref<6x1xf32> to memref<*xf32>
+  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [6, 1] strides = [1, 1] data =
+  // CHECK-NEXT: [0],
+  // CHECK-NEXT: [1],
+  // CHECK-NEXT: [2],
+  // CHECK-NEXT: [3],
+  // CHECK-NEXT: [4],
+  // CHECK-NEXT: [5]
+  return
+}
+
+func @cast_unranked_memref_to_dynamic_shape(%input : memref<2x3xf32>) {
+  %unranked_input = memref_cast %input : memref<2x3xf32> to memref<*xf32>
+  %c0 = constant 0 : index
+  %c1 = constant 1 : index
+  %c6 = constant 6 : index
+  %output = memref_reinterpret_cast %unranked_input to
+           offset: [%c0], sizes: [%c1, %c6], strides: [%c6, %c1]
+           : memref<*xf32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
+
+  %unranked_output = memref_cast %output
+      : memref<?x?xf32, offset: ?, strides: [?, ?]> to memref<*xf32>
+  call @print_memref_f32(%unranked_output) : (memref<*xf32>) -> ()
+  // CHECK: rank = 2 offset = 0 sizes = [1, 6] strides = [6, 1] data =
+  // CHECK-NEXT: [0, 1, 2, 3, 4, 5]
+  return
+}
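Note: the sketch below illustrates the mixed static/dynamic form that the verifier and the LLVM lowering in this patch handle. It is not part of the diff, and the function and SSA value names (`@reinterpret_mixed`, `%size1`, `%stride0`) are hypothetical.

```mlir
// Illustrative only: reinterpret a dynamically shaped memref using a mix of
// static and dynamic entries. The static offset 0, size 4, and stride 1 appear
// in the attributes and in the result type; %size1 and %stride0 are operands.
func @reinterpret_mixed(%in: memref<?x?xf32>, %size1: index, %stride0: index)
    -> memref<4x?xf32, offset: 0, strides: [?, 1]> {
  %out = memref_reinterpret_cast %in to
           offset: [0], sizes: [4, %size1], strides: [%stride0, 1]
           : memref<?x?xf32> to memref<4x?xf32, offset: 0, strides: [?, 1]>
  return %out : memref<4x?xf32, offset: 0, strides: [?, 1]>
}
```

When lowered through `MemRefReinterpretCastOpLowering`, the static offset, size, and stride entries are written into the result MemRef descriptor as constants, while the dynamic entries are copied from the corresponding converted operands.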