diff --git a/mlir/docs/Tutorials/Toy/Ch-5.md b/mlir/docs/Tutorials/Toy/Ch-5.md --- a/mlir/docs/Tutorials/Toy/Ch-5.md +++ b/mlir/docs/Tutorials/Toy/Ch-5.md @@ -70,9 +70,14 @@ // We also define the Toy dialect as Illegal so that the conversion will fail // if any of these operations are *not* converted. Given that we actually want // a partial lowering, we explicitly mark the Toy operations that don't want - // to lower, `toy.print`, as *legal*. + // to lower, `toy.print`, as *legal*. `toy.print` will still need its operands + // to be updated though (as we convert from TensorType to MemRefType), so we + // only treat it as `legal` if its operands are legal. target.addIllegalDialect(); - target.addLegalOp(); + target.addDynamicallyLegalOp([](toy::PrintOp op) { + return llvm::none_of(op->getOperandTypes(), + [](Type type) { return type.isa(); }); + }); ... } ``` diff --git a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp --- a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp @@ -197,6 +197,24 @@ } }; +//===----------------------------------------------------------------------===// +// ToyToAffine RewritePatterns: Print operations +//===----------------------------------------------------------------------===// + +struct PrintOpLowering : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(toy::PrintOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + // We don't lower "toy.print" in this pass, but we need to update its + // operands. 
+ rewriter.updateRootInPlace(op, + [&] { op->setOperands(adaptor.getOperands()); }); + return success(); + } +}; + //===----------------------------------------------------------------------===// // ToyToAffine RewritePatterns: Return operations //===----------------------------------------------------------------------===// @@ -294,15 +312,21 @@ // We also define the Toy dialect as Illegal so that the conversion will fail // if any of these operations are *not* converted. Given that we actually want // a partial lowering, we explicitly mark the Toy operations that don't want - // to lower, `toy.print`, as `legal`. + // to lower, `toy.print`, as `legal`. `toy.print` will still need its operands + // to be updated though (as we convert from TensorType to MemRefType), so we + // only treat it as `legal` if its operands are legal. target.addIllegalDialect(); - target.addLegalOp(); + target.addDynamicallyLegalOp([](toy::PrintOp op) { + return llvm::none_of(op->getOperandTypes(), + [](Type type) { return type.isa(); }); + }); // Now that the conversion target has been defined, we just need to provide // the set of patterns that will lower the Toy operations. RewritePatternSet patterns(&getContext()); patterns.add(&getContext()); + PrintOpLowering, ReturnOpLowering, TransposeOpLowering>( + &getContext()); // With the target and rewrite patterns defined, we can now attempt the // conversion. The conversion will signal failure if any of our `illegal` diff --git a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp --- a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp @@ -57,9 +57,9 @@ /// induction variables for the iteration. It returns a value to store at the /// current index of the iteration. 
using LoopIterationFn = function_ref; + OpBuilder &rewriter, ValueRange memRefOperands, ValueRange loopIvs)>; -static void lowerOpToLoops(Operation *op, ArrayRef operands, +static void lowerOpToLoops(Operation *op, ValueRange operands, PatternRewriter &rewriter, LoopIterationFn processIteration) { auto tensorType = (*op->result_type_begin()).cast(); @@ -162,6 +162,7 @@ constantIndices.push_back( rewriter.create(loc, 0)); } + // The constant operation represents a multi-dimensional constant, so we // will need to generate a store for each of the elements. The following // functor recursively walks the dimensions of the constant shape, @@ -196,6 +197,24 @@ } }; +//===----------------------------------------------------------------------===// +// ToyToAffine RewritePatterns: Print operations +//===----------------------------------------------------------------------===// + +struct PrintOpLowering : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(toy::PrintOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + // We don't lower "toy.print" in this pass, but we need to update its + // operands. + rewriter.updateRootInPlace(op, + [&] { op->setOperands(adaptor.getOperands()); }); + return success(); + } +}; + //===----------------------------------------------------------------------===// // ToyToAffine RewritePatterns: Return operations //===----------------------------------------------------------------------===// @@ -293,15 +312,21 @@ // We also define the Toy dialect as Illegal so that the conversion will fail // if any of these operations are *not* converted. Given that we actually want // a partial lowering, we explicitly mark the Toy operations that don't want - // to lower, `toy.print`, as `legal`. + // to lower, `toy.print`, as `legal`. 
`toy.print` will still need its operands + // to be updated though (as we convert from TensorType to MemRefType), so we + // only treat it as `legal` if its operands are legal. target.addIllegalDialect(); - target.addLegalOp(); + target.addDynamicallyLegalOp([](toy::PrintOp op) { + return llvm::none_of(op->getOperandTypes(), + [](Type type) { return type.isa(); }); + }); // Now that the conversion target has been defined, we just need to provide // the set of patterns that will lower the Toy operations. RewritePatternSet patterns(&getContext()); patterns.add(&getContext()); + PrintOpLowering, ReturnOpLowering, TransposeOpLowering>( + &getContext()); // With the target and rewrite patterns defined, we can now attempt the // conversion. The conversion will signal failure if any of our `illegal` diff --git a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp --- a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp +++ b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp @@ -197,6 +197,24 @@ } }; +//===----------------------------------------------------------------------===// +// ToyToAffine RewritePatterns: Print operations +//===----------------------------------------------------------------------===// + +struct PrintOpLowering : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(toy::PrintOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + // We don't lower "toy.print" in this pass, but we need to update its + // operands. 
+ rewriter.updateRootInPlace(op, + [&] { op->setOperands(adaptor.getOperands()); }); + return success(); + } +}; + //===----------------------------------------------------------------------===// // ToyToAffine RewritePatterns: Return operations //===----------------------------------------------------------------------===// @@ -294,15 +312,21 @@ // We also define the Toy dialect as Illegal so that the conversion will fail // if any of these operations are *not* converted. Given that we actually want // a partial lowering, we explicitly mark the Toy operations that don't want - // to lower, `toy.print`, as `legal`. + // to lower, `toy.print`, as `legal`. `toy.print` will still need its operands + // to be updated though (as we convert from TensorType to MemRefType), so we + // only treat it as `legal` if its operands are legal. target.addIllegalDialect(); - target.addLegalOp(); + target.addDynamicallyLegalOp([](toy::PrintOp op) { + return llvm::none_of(op->getOperandTypes(), + [](Type type) { return type.isa(); }); + }); // Now that the conversion target has been defined, we just need to provide // the set of patterns that will lower the Toy operations. RewritePatternSet patterns(&getContext()); patterns.add(&getContext()); + PrintOpLowering, ReturnOpLowering, TransposeOpLowering>( + &getContext()); // With the target and rewrite patterns defined, we can now attempt the // conversion. 
The conversion will signal failure if any of our `illegal` diff --git a/mlir/include/mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h b/mlir/include/mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h --- a/mlir/include/mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h +++ b/mlir/include/mlir/Conversion/SPIRVToLLVM/SPIRVToLLVM.h @@ -25,7 +25,7 @@ public: SPIRVToLLVMConversion(MLIRContext *context, LLVMTypeConverter &typeConverter, PatternBenefit benefit = 1) - : OpConversionPattern(context, benefit), + : OpConversionPattern(typeConverter, context, benefit), typeConverter(typeConverter) {} protected: diff --git a/mlir/include/mlir/IR/BlockAndValueMapping.h b/mlir/include/mlir/IR/BlockAndValueMapping.h --- a/mlir/include/mlir/IR/BlockAndValueMapping.h +++ b/mlir/include/mlir/IR/BlockAndValueMapping.h @@ -27,10 +27,8 @@ public: /// Inserts a new mapping for 'from' to 'to'. If there is an existing mapping, /// it is overwritten. - void map(Block *from, Block *to) { valueMap[from] = to; } - void map(Value from, Value to) { - valueMap[from.getAsOpaquePointer()] = to.getAsOpaquePointer(); - } + void map(Block *from, Block *to) { blockMap[from] = to; } + void map(Value from, Value to) { valueMap[from] = to; } template < typename S, typename T, @@ -42,14 +40,12 @@ } /// Erases a mapping for 'from'. - void erase(Block *from) { valueMap.erase(from); } - void erase(Value from) { valueMap.erase(from.getAsOpaquePointer()); } + void erase(Block *from) { blockMap.erase(from); } + void erase(Value from) { valueMap.erase(from); } /// Checks to see if a mapping for 'from' exists. - bool contains(Block *from) const { return valueMap.count(from); } - bool contains(Value from) const { - return valueMap.count(from.getAsOpaquePointer()); - } + bool contains(Block *from) const { return blockMap.count(from); } + bool contains(Value from) const { return valueMap.count(from); } /// Lookup a mapped value within the map. If a mapping for the provided value /// does not exist then return nullptr. 
@@ -76,28 +72,26 @@ /// Clears all mappings held by the mapper. - void clear() { valueMap.clear(); } + void clear() { valueMap.clear(); blockMap.clear(); } - /// Returns a new mapper containing the inverse mapping. - BlockAndValueMapping getInverse() const { - BlockAndValueMapping result; - for (const auto &pair : valueMap) - result.valueMap.try_emplace(pair.second, pair.first); - return result; - } + /// Return the held value mapping. + const DenseMap &getValueMap() const { return valueMap; } + + /// Return the held block mapping. + const DenseMap &getBlockMap() const { return blockMap; } private: /// Utility lookupOrValue that looks up an existing key or returns the /// provided value. Block *lookupOrValue(Block *from, Block *value) const { - auto it = valueMap.find(from); - return it != valueMap.end() ? reinterpret_cast(it->second) : value; + auto it = blockMap.find(from); + return it != blockMap.end() ? it->second : value; } Value lookupOrValue(Value from, Value value) const { - auto it = valueMap.find(from.getAsOpaquePointer()); - return it != valueMap.end() ? Value::getFromOpaquePointer(it->second) - : value; + auto it = valueMap.find(from); + return it != valueMap.end() ? it->second : value; } - DenseMap valueMap; + DenseMap valueMap; + DenseMap blockMap; }; } // end namespace mlir diff --git a/mlir/include/mlir/IR/Operation.h b/mlir/include/mlir/IR/Operation.h --- a/mlir/include/mlir/IR/Operation.h +++ b/mlir/include/mlir/IR/Operation.h @@ -151,21 +151,8 @@ /// Replace all uses of results of this operation with the provided 'values'. 
template - std::enable_if_t::value> - replaceAllUsesWith(ValuesT &&values) { - assert(std::distance(values.begin(), values.end()) == getNumResults() && - "expected 'values' to correspond 1-1 with the number of results"); - - auto valueIt = values.begin(); - for (unsigned i = 0, e = getNumResults(); i != e; ++i) - getResult(i).replaceAllUsesWith(*(valueIt++)); - } - - /// Replace all uses of results of this operation with results of 'op'.
- void replaceAllUsesWith(Operation *op) { - assert(getNumResults() == op->getNumResults()); - for (unsigned i = 0, e = getNumResults(); i != e; ++i) - getResult(i).replaceAllUsesWith(op->getResult(i)); + void replaceAllUsesWith(ValuesT &&values) { + getResults().replaceAllUsesWith(std::forward(values)); } /// Destroys this operation and its subclass data. diff --git a/mlir/include/mlir/IR/OperationSupport.h b/mlir/include/mlir/IR/OperationSupport.h --- a/mlir/include/mlir/IR/OperationSupport.h +++ b/mlir/include/mlir/IR/OperationSupport.h @@ -903,6 +903,7 @@ ResultRange, detail::OpResultImpl *, OpResult, OpResult, OpResult> { public: using RangeBaseT::RangeBaseT; + ResultRange(OpResult result); //===--------------------------------------------------------------------===// // Types @@ -934,6 +935,22 @@ [](OpResult result) { return result.use_empty(); }); } + /// Replace all uses of results of this range with the provided 'values'. The + /// size of `values` must match the size of this range. + template + std::enable_if_t::value> + replaceAllUsesWith(ValuesT &&values) { + assert(static_cast(std::distance(values.begin(), values.end())) == + size() && + "expected 'values' to correspond 1-1 with the number of results"); + + for (auto it : llvm::zip(*this, values)) + std::get<0>(it).replaceAllUsesWith(std::get<1>(it)); + } + + /// Replace all uses of results of this range with results of 'op'. 
+ void replaceAllUsesWith(Operation *op); + //===--------------------------------------------------------------------===// // Users //===--------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -118,9 +118,8 @@ /// must return a Value of the converted type on success, an `llvm::None` if /// it failed but other materialization can be attempted, and `nullptr` on /// unrecoverable failure. It will only be called for (sub)types of `T`. - /// Materialization functions must be provided when a type conversion - /// results in more than one type, or if a type conversion may persist after - /// the conversion has finished. + /// Materialization functions must be provided when a type conversion may + /// persist after the conversion has finished. /// /// This method registers a materialization that will be called when /// converting an illegal block argument type, to a legal type. @@ -551,10 +550,17 @@ /// Replace all the uses of the block argument `from` with value `to`. void replaceUsesOfBlockArgument(BlockArgument from, Value to); - /// Return the converted value that replaces 'key'. Return 'key' if there is - /// no such a converted value. + /// Return the converted value of 'key' with a type defined by the type + /// converter of the currently executing pattern. Return nullptr in the case + /// of failure, the remapped value otherwise. Value getRemappedValue(Value key); + /// Return the converted values that replace 'keys' with types defined by the + /// type converter of the currently executing pattern. Returns failure if the + /// remap failed, success otherwise. 
+ LogicalResult getRemappedValues(ValueRange keys, + SmallVectorImpl &results); + //===--------------------------------------------------------------------===// // PatternRewriter Hooks //===--------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp --- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp +++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp @@ -56,7 +56,7 @@ public: SCFToSPIRVPattern(MLIRContext *context, SPIRVTypeConverter &converter, ScfToSPIRVContextImpl *scfToSPIRVContext) - : OpConversionPattern::OpConversionPattern(context), + : OpConversionPattern::OpConversionPattern(converter, context), scfToSPIRVContext(scfToSPIRVContext), typeConverter(converter) {} protected: diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp --- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp +++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp @@ -218,7 +218,7 @@ } /// Utility for `spv.Load` and `spv.Store` conversion. 
-static LogicalResult replaceWithLoadOrStore(Operation *op, +static LogicalResult replaceWithLoadOrStore(Operation *op, ValueRange operands, ConversionPatternRewriter &rewriter, LLVMTypeConverter &typeConverter, unsigned alignment, bool isVolatile, @@ -228,12 +228,14 @@ if (!dstType) return failure(); rewriter.replaceOpWithNewOp( - loadOp, dstType, loadOp.ptr(), alignment, isVolatile, isNonTemporal); + loadOp, dstType, spirv::LoadOpAdaptor(operands).ptr(), alignment, + isVolatile, isNonTemporal); return success(); } auto storeOp = cast(op); - rewriter.replaceOpWithNewOp(storeOp, storeOp.value(), - storeOp.ptr(), alignment, + spirv::StoreOpAdaptor adaptor(operands); + rewriter.replaceOpWithNewOp(storeOp, adaptor.value(), + adaptor.ptr(), alignment, isVolatile, isNonTemporal); return success(); } @@ -308,7 +310,7 @@ if (!dstType) return failure(); // To use GEP we need to add a first 0 index to go through the pointer. - auto indices = llvm::to_vector<4>(op.indices()); + auto indices = llvm::to_vector<4>(adaptor.indices()); Type indexType = op.indices().front().getType(); auto llvmIndexType = typeConverter.convertType(indexType); if (!llvmIndexType) @@ -316,7 +318,7 @@ Value zero = rewriter.create( op.getLoc(), llvmIndexType, rewriter.getIntegerAttr(indexType, 0)); indices.insert(indices.begin(), zero); - rewriter.replaceOpWithNewOp(op, dstType, op.base_ptr(), + rewriter.replaceOpWithNewOp(op, dstType, adaptor.base_ptr(), indices); return success(); } @@ -572,11 +574,11 @@ IntegerAttr value = op.indices()[0].cast(); Value index = createI32ConstantOf(loc, rewriter, value.getInt()); rewriter.replaceOpWithNewOp( - op, dstType, op.composite(), index); + op, dstType, adaptor.composite(), index); return success(); } rewriter.replaceOpWithNewOp( - op, dstType, op.composite(), op.indices()); + op, dstType, adaptor.composite(), op.indices()); return success(); } }; @@ -602,11 +604,11 @@ IntegerAttr value = op.indices()[0].cast(); Value index = createI32ConstantOf(loc, 
rewriter, value.getInt()); rewriter.replaceOpWithNewOp( - op, dstType, op.composite(), op.object(), index); + op, dstType, adaptor.composite(), adaptor.object(), index); return success(); } rewriter.replaceOpWithNewOp( - op, dstType, op.composite(), op.object(), op.indices()); + op, dstType, adaptor.composite(), adaptor.object(), op.indices()); return success(); } }; @@ -897,9 +899,10 @@ matchAndRewrite(SPIRVOp op, typename SPIRVOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (!op.memory_access().hasValue()) { - return replaceWithLoadOrStore( - op, rewriter, this->typeConverter, /*alignment=*/0, - /*isVolatile=*/false, /*isNonTemporal=*/false); + return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter, + this->typeConverter, /*alignment=*/0, + /*isVolatile=*/false, + /*isNonTemporal=*/false); } auto memoryAccess = op.memory_access().getValue(); switch (memoryAccess) { @@ -911,8 +914,9 @@ memoryAccess == spirv::MemoryAccess::Aligned ? *op.alignment() : 0; bool isNonTemporal = memoryAccess == spirv::MemoryAccess::Nontemporal; bool isVolatile = memoryAccess == spirv::MemoryAccess::Volatile; - return replaceWithLoadOrStore(op, rewriter, this->typeConverter, - alignment, isVolatile, isNonTemporal); + return replaceWithLoadOrStore(op, adaptor.getOperands(), rewriter, + this->typeConverter, alignment, isVolatile, + isNonTemporal); } default: // There is no support of other memory access attributes. 
@@ -1178,13 +1182,13 @@ Value extended; if (isUnsignedIntegerOrVector(op2Type)) { extended = rewriter.template create(loc, dstType, - operation.operand2()); + adaptor.operand2()); } else { extended = rewriter.template create(loc, dstType, - operation.operand2()); + adaptor.operand2()); } Value result = rewriter.template create( - loc, dstType, operation.operand1(), extended); + loc, dstType, adaptor.operand1(), extended); rewriter.replaceOp(operation, result); return success(); } @@ -1268,7 +1272,7 @@ return success(); } Value allocated = rewriter.create(loc, dstType, size); - rewriter.create(loc, init, allocated); + rewriter.create(loc, adaptor.initializer(), allocated); rewriter.replaceOp(varOp, allocated); return success(); } diff --git a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp --- a/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp +++ b/mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp @@ -503,7 +503,6 @@ // For all other cases, insert the individual values individually. Type eltType; - llvm::errs() << llvmType << "\n"; if (auto arrayType = llvmType.dyn_cast()) eltType = arrayType.getElementType(); else diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp @@ -24,6 +24,9 @@ static Value sourceMaterializationCallback(OpBuilder &builder, Type type, ValueRange inputs, Location loc) { assert(inputs.size() == 1); + if (inputs[0].getType().isa()) + return nullptr; + // A detensored value is converted back by creating a new tensor from its // element(s). 
auto createNewTensorOp = builder.create( diff --git a/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp b/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp --- a/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp +++ b/mlir/lib/Dialect/SPIRV/Transforms/DecorateCompositeTypeLayoutPass.cpp @@ -72,11 +72,29 @@ return success(); } }; + +template +class SPIRVPassThroughConversion : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(OpT op, typename OpT::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.updateRootInPlace(op, + [&] { op->setOperands(adaptor.getOperands()); }); + return success(); + } +}; } // namespace static void populateSPIRVLayoutInfoPatterns(RewritePatternSet &patterns) { patterns.add(patterns.getContext()); + SPIRVAddressOfOpLayoutInfoDecoration, + SPIRVPassThroughConversion, + SPIRVPassThroughConversion, + SPIRVPassThroughConversion>( + patterns.getContext()); } namespace { @@ -104,8 +122,17 @@ return VulkanLayoutUtils::isLegalType(op.pointer().getType()); }); - // TODO: Change the type for the indirect users such as spv.Load, spv.Store, - // spv.FunctionCall and so on. + // Change the type for the indirect users. 
+ target.addDynamicallyLegalOp([&](Operation *op) { + for (Value operand : op->getOperands()) { + auto addrOp = operand.getDefiningOp(); + if (addrOp && !VulkanLayoutUtils::isLegalType(addrOp.pointer().getType())) + return false; + } + return true; + }); + FrozenRewritePatternSet frozenPatterns(std::move(patterns)); for (auto spirvModule : module.getOps()) if (failed(applyFullConversion(spirvModule, target, frozenPatterns))) diff --git a/mlir/lib/IR/OperationSupport.cpp b/mlir/lib/IR/OperationSupport.cpp --- a/mlir/lib/IR/OperationSupport.cpp +++ b/mlir/lib/IR/OperationSupport.cpp @@ -555,6 +555,10 @@ //===----------------------------------------------------------------------===// // ResultRange +ResultRange::ResultRange(OpResult result) + : ResultRange(static_cast(Value(result).getImpl()), + 1) {} + ResultRange::use_range ResultRange::getUses() const { return {use_begin(), use_end()}; } @@ -605,6 +609,10 @@ use = (*it).use_begin(); } +void ResultRange::replaceAllUsesWith(Operation *op) { + replaceAllUsesWith(op->getResults()); +} + //===----------------------------------------------------------------------===// // ValueRange diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -31,7 +31,8 @@ /// regions pre-filtered to avoid considering them for legalization. static LogicalResult computeConversionSet(iterator_range region, - Location regionLoc, std::vector &toConvert, + Location regionLoc, + SmallVectorImpl &toConvert, ConversionTarget *target = nullptr) { if (llvm::empty(region)) return success(); @@ -114,16 +115,32 @@ /// Lookup a mapped value within the map, or return null if a mapping does not /// exist. If a mapping exists, this follows the same behavior of /// `lookupOrDefault`. 
- Value lookupOrNull(Value from) const; + Value lookupOrNull(Value from, Type desiredType = nullptr) const; /// Map a value to the one provided. - void map(Value oldVal, Value newVal) { mapping.map(oldVal, newVal); } + void map(Value oldVal, Value newVal) { + LLVM_DEBUG({ + for (Value it = newVal; it; it = mapping.lookupOrNull(it)) + assert(it != oldVal && "inserting cyclic mapping"); + }); + mapping.map(oldVal, newVal); + } + + /// Try to map a value to the one provided. Returns false if a transitive + /// mapping from the new value to the old value already exists, true if the + /// map was updated. + bool tryMap(Value oldVal, Value newVal); /// Drop the last mapping for the given value. void erase(Value value) { mapping.erase(value); } /// Returns the inverse raw value mapping (without recursive query support). - BlockAndValueMapping getInverse() const { return mapping.getInverse(); } + DenseMap> getInverse() const { + DenseMap> inverse; + for (auto &it : mapping.getValueMap()) + inverse[it.second].push_back(it.first); + return inverse; + } private: /// Current value mappings. @@ -158,9 +175,19 @@ return desiredValue ? desiredValue : from; } -Value ConversionValueMapping::lookupOrNull(Value from) const { - Value result = lookupOrDefault(from); - return result == from ? nullptr : result; +Value ConversionValueMapping::lookupOrNull(Value from, Type desiredType) const { + Value result = lookupOrDefault(from, desiredType); + if (result == from || (desiredType && result.getType() != desiredType)) + return nullptr; + return result; +} + +bool ConversionValueMapping::tryMap(Value oldVal, Value newVal) { + for (Value it = newVal; it; it = mapping.lookupOrNull(it)) + if (it == oldVal) + return false; + map(oldVal, newVal); + return true; } //===----------------------------------------------------------------------===// @@ -170,10 +197,13 @@ /// This class contains a snapshot of the current conversion rewriter state. 
/// This is useful when saving and undoing a set of rewrites. struct RewriterState { - RewriterState(unsigned numCreatedOps, unsigned numReplacements, - unsigned numArgReplacements, unsigned numBlockActions, - unsigned numIgnoredOperations, unsigned numRootUpdates) - : numCreatedOps(numCreatedOps), numReplacements(numReplacements), + RewriterState(unsigned numCreatedOps, unsigned numUnresolvedMaterializations, + unsigned numReplacements, unsigned numArgReplacements, + unsigned numBlockActions, unsigned numIgnoredOperations, + unsigned numRootUpdates) + : numCreatedOps(numCreatedOps), + numUnresolvedMaterializations(numUnresolvedMaterializations), + numReplacements(numReplacements), numArgReplacements(numArgReplacements), numBlockActions(numBlockActions), numIgnoredOperations(numIgnoredOperations), @@ -182,6 +212,9 @@ /// The current number of created operations. unsigned numCreatedOps; + /// The current number of unresolved materializations. + unsigned numUnresolvedMaterializations; + /// The current number of replacements queued. unsigned numReplacements; @@ -321,8 +354,103 @@ MergeInfo mergeInfo; }; }; + +//===----------------------------------------------------------------------===// +// UnresolvedMaterialization + +/// This class represents an unresolved materialization, i.e. a materialization +/// that was inserted during conversion that needs to be legalized at the end of +/// the conversion process. +class UnresolvedMaterialization { +public: + /// The type of materialization. + enum Kind { + /// This materialization materializes a conversion for an illegal block + /// argument type, to a legal one. + Argument, + + /// This materialization materializes a conversion from an illegal type to a + /// legal one. 
+ Target + }; + + UnresolvedMaterialization(UnrealizedConversionCastOp op = nullptr, + TypeConverter *converter = nullptr, + Kind kind = Target, Type origOutputType = nullptr) + : op(op), converterAndKind(converter, kind), + origOutputType(origOutputType) {} + + /// Return the temporary conversion operation inserted for this + /// materialization. + UnrealizedConversionCastOp getOp() const { return op; } + + /// Return the type converter of this materialization (which may be null). + TypeConverter *getConverter() const { return converterAndKind.getPointer(); } + + /// Return the kind of this materialization. + Kind getKind() const { return converterAndKind.getInt(); } + + /// Set the kind of this materialization. + void setKind(Kind kind) { converterAndKind.setInt(kind); } + + /// Return the original illegal output type of the input values. + Type getOrigOutputType() const { return origOutputType; } + +private: + /// The unresolved materialization operation created during conversion. + UnrealizedConversionCastOp op; + + /// The corresponding type converter to use when resolving this + /// materialization, and the kind of this materialization. + llvm::PointerIntPair converterAndKind; + + /// The original output type. This is only used for argument conversions. + Type origOutputType; +}; } // end anonymous namespace +/// Build an unresolved materialization operation given an output type and set +/// of input operands. +static Value buildUnresolvedMaterialization( + UnresolvedMaterialization::Kind kind, Block *insertBlock, + Block::iterator insertPt, Location loc, ValueRange inputs, Type outputType, + Type origOutputType, TypeConverter *converter, + SmallVectorImpl &unresolvedMaterializations) { + // Avoid materializing an unnecessary cast. + if (inputs.size() == 1 && inputs.front().getType() == outputType) + return inputs.front(); + + // Create an unresolved materialization. 
We use a new OpBuilder to avoid + // tracking the materialization like we do for other operations. + OpBuilder builder(insertBlock, insertPt); + auto convertOp = + builder.create(loc, outputType, inputs); + unresolvedMaterializations.emplace_back(convertOp, converter, kind, + origOutputType); + return convertOp.getResult(0); +} +static Value buildUnresolvedArgumentMaterialization( + PatternRewriter &rewriter, Location loc, ValueRange inputs, + Type origOutputType, Type outputType, TypeConverter *converter, + SmallVectorImpl &unresolvedMaterializations) { + return buildUnresolvedMaterialization( + UnresolvedMaterialization::Argument, rewriter.getInsertionBlock(), + rewriter.getInsertionPoint(), loc, inputs, outputType, origOutputType, + converter, unresolvedMaterializations); +} +static Value buildUnresolvedTargetMaterialization( + Location loc, Value input, Type outputType, TypeConverter *converter, + SmallVectorImpl &unresolvedMaterializations) { + Block *insertBlock = input.getParentBlock(); + Block::iterator insertPt = insertBlock->begin(); + if (OpResult inputRes = input.dyn_cast()) + insertPt = ++inputRes.getOwner()->getIterator(); + + return buildUnresolvedMaterialization( + UnresolvedMaterialization::Target, insertBlock, insertPt, loc, input, + outputType, outputType, converter, unresolvedMaterializations); +} + //===----------------------------------------------------------------------===// // ArgConverter //===----------------------------------------------------------------------===// @@ -332,7 +460,11 @@ /// types and extracting the block that contains the old illegal types to allow /// for undoing pending rewrites in the case of failure. 
struct ArgConverter { - ArgConverter(PatternRewriter &rewriter) : rewriter(rewriter) {} + ArgConverter( + PatternRewriter &rewriter, + SmallVectorImpl &unresolvedMaterializations) + : rewriter(rewriter), + unresolvedMaterializations(unresolvedMaterializations) {} /// This structure contains the information pertaining to an argument that has /// been converted. @@ -356,8 +488,8 @@ /// This structure contains information pertaining to a block that has had its /// signature converted. struct ConvertedBlockInfo { - ConvertedBlockInfo(Block *origBlock, TypeConverter &converter) - : origBlock(origBlock), converter(&converter) {} + ConvertedBlockInfo(Block *origBlock, TypeConverter *converter) + : origBlock(origBlock), converter(converter) {} /// The original block that was requested to have its signature converted. Block *origBlock; @@ -420,7 +552,7 @@ /// block is returned containing the new arguments. Returns `block` if it did /// not require conversion. FailureOr - convertSignature(Block *block, TypeConverter &converter, + convertSignature(Block *block, TypeConverter *converter, ConversionValueMapping &mapping, SmallVectorImpl &argReplacements); @@ -431,7 +563,7 @@ /// translate between the origin argument types and those specified in the /// signature conversion. Block *applySignatureConversion( - Block *block, TypeConverter &converter, + Block *block, TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion, ConversionValueMapping &mapping, SmallVectorImpl &argReplacements); @@ -456,6 +588,9 @@ /// The pattern rewriter to use when materializing conversions. PatternRewriter &rewriter; + + /// An ordered set of unresolved materializations during conversion. + SmallVectorImpl &unresolvedMaterializations; }; } // end anonymous namespace @@ -519,7 +654,7 @@ // Handle the case of a 1->0 value mapping. 
if (!argInfo) { - if (Value newArg = mapping.lookupOrNull(origArg)) + if (Value newArg = mapping.lookupOrNull(origArg, origArg.getType())) origArg.replaceAllUsesWith(newArg); continue; } @@ -529,8 +664,10 @@ assert(argInfo->newArgSize >= 1 && castValue && "expected 1->1+ mapping"); // If the argument is still used, replace it with the generated cast. - if (!origArg.use_empty()) - origArg.replaceAllUsesWith(mapping.lookupOrDefault(castValue)); + if (!origArg.use_empty()) { + origArg.replaceAllUsesWith( + mapping.lookupOrDefault(castValue, origArg.getType())); + } } } } @@ -545,31 +682,38 @@ // Process the remapping for each of the original arguments. for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) { - // FIXME: We should run the below checks even if the type conversion was - // 1->N, but a lot of existing lowering rely on the block argument being - // blindly replaced. Those usages should be updated, and this if should be - // removed. - if (blockInfo.argInfo[i]) + // FIXME: We should run the below checks even if a type converter wasn't + // provided, but a lot of existing lowering rely on the block argument + // being blindly replaced. We should rework argument materialization to be + // more robust for temporary source materializations, update existing + // patterns, and remove these checks. + if (!blockInfo.converter && blockInfo.argInfo[i]) continue; // If the type of this argument changed and the argument is still live, we // need to materialize a conversion. 
BlockArgument origArg = origBlock->getArgument(i); - auto argReplacementValue = mapping.lookupOrDefault(origArg); - bool isDroppedArg = argReplacementValue == origArg; - if (argReplacementValue.getType() == origArg.getType() && !isDroppedArg) + if (mapping.lookupOrNull(origArg, origArg.getType())) continue; Operation *liveUser = findLiveUser(origArg); if (!liveUser) continue; - if (OpResult result = argReplacementValue.dyn_cast()) - rewriter.setInsertionPointAfter(result.getOwner()); - else + Value replacementValue = mapping.lookupOrDefault(origArg); + bool isDroppedArg = replacementValue == origArg; + if (isDroppedArg) rewriter.setInsertionPointToStart(newBlock); - Value newArg = blockInfo.converter->materializeSourceConversion( - rewriter, origArg.getLoc(), origArg.getType(), - isDroppedArg ? ValueRange() : ValueRange(argReplacementValue)); + else + rewriter.setInsertionPointAfterValue(replacementValue); + Value newArg; + if (blockInfo.converter) { + newArg = blockInfo.converter->materializeSourceConversion( + rewriter, origArg.getLoc(), origArg.getType(), + isDroppedArg ? 
ValueRange() : ValueRange(replacementValue)); + assert((!newArg || newArg.getType() == origArg.getType()) && + "materialization hook did not provide a value of the expected " + "type"); + } if (!newArg) { InFlightDiagnostic diag = emitError(origArg.getLoc()) @@ -577,7 +721,7 @@ << " that remained live after conversion, type was " << origArg.getType(); if (!isDroppedArg) - diag << ", with target type " << argReplacementValue.getType(); + diag << ", with target type " << replacementValue.getType(); diag.attachNote(liveUser->getLoc()) << "see existing live user here: " << *liveUser; return failure(); @@ -592,22 +736,26 @@ // Conversion FailureOr ArgConverter::convertSignature( - Block *block, TypeConverter &converter, ConversionValueMapping &mapping, + Block *block, TypeConverter *converter, ConversionValueMapping &mapping, SmallVectorImpl &argReplacements) { // Check if the block was already converted. If the block is detached, // conservatively assume it is going to be deleted. if (hasBeenConverted(block) || !block->getParent()) return block; + // If a converter wasn't provided, and the block wasn't already converted, + // there is nothing we can do. + if (!converter) + return failure(); // Try to convert the signature for the block with the provided converter. - if (auto conversion = converter.convertBlockSignature(block)) + if (auto conversion = converter->convertBlockSignature(block)) return applySignatureConversion(block, converter, *conversion, mapping, argReplacements); return failure(); } Block *ArgConverter::applySignatureConversion( - Block *block, TypeConverter &converter, + Block *block, TypeConverter *converter, TypeConverter::SignatureConversion &signatureConversion, ConversionValueMapping &mapping, SmallVectorImpl &argReplacements) { @@ -649,26 +797,35 @@ continue; } - // Otherwise, this is a 1->1+ mapping. Call into the provided type converter - // to pack the new values. 
For 1->1 mappings, if there is no materialization - // provided, use the argument directly instead. + // Otherwise, this is a 1->1+ mapping. auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size); Value newArg; // If this is a 1->1 mapping and the types of new and replacement arguments // match (i.e. it's an identity map), then the argument is mapped to its // original type. - if (replArgs.size() == 1 && replArgs[0].getType() == origArg.getType()) + // FIXME: We simply pass through the replacement argument if there wasn't a + // converter, which isn't great as it allows implicit type conversions to + // appear. We should properly restructure this code to handle cases where a + // converter isn't provided and also to properly handle the case where an + // argument materialization is actually a temporary source materialization + // (e.g. in the case of 1->N). + if (replArgs.size() == 1 && + (!converter || replArgs[0].getType() == origArg.getType())) { newArg = replArgs.front(); - else - newArg = converter.materializeArgumentConversion( - rewriter, origArg.getLoc(), origArg.getType(), replArgs); + } else { + Type origOutputType = origArg.getType(); - if (!newArg) { - assert(replArgs.size() == 1 && - "couldn't materialize the result of 1->N conversion"); - newArg = replArgs.front(); + // Legalize the argument output type. 
+ Type outputType = origOutputType; + if (Type legalOutputType = converter->convertType(outputType)) + outputType = legalOutputType; + + newArg = buildUnresolvedArgumentMaterialization( + rewriter, origArg.getLoc(), replArgs, origOutputType, outputType, + converter, unresolvedMaterializations); } + mapping.map(origArg, newArg); argReplacements.push_back(origArg); info.argInfo[i] = @@ -702,7 +859,7 @@ namespace detail { struct ConversionPatternRewriterImpl { ConversionPatternRewriterImpl(PatternRewriter &rewriter) - : argConverter(rewriter) {} + : argConverter(rewriter, unresolvedMaterializations) {} /// Cleanup and destroy any generated rewrite operations. This method is /// invoked when the conversion process fails. @@ -730,13 +887,12 @@ /// "numActionsToKeep" actions remains. void undoBlockActions(unsigned numActionsToKeep = 0); - /// Remap the given operands to those with potentially different types. The - /// provided type converter is used to ensure that the remapped types are - /// legal. Returns success if the operands could be remapped, failure - /// otherwise. - LogicalResult remapValues(Location loc, PatternRewriter &rewriter, - TypeConverter *converter, - Operation::operand_range operands, + /// Remap the given values to those with potentially different types. Returns + /// success if the values could be remapped, failure otherwise. `valueDiagTag` + /// is the tag used when describing a value within a diagnostic, e.g. + /// "operand". + LogicalResult remapValues(StringRef valueDiagTag, Optional inputLoc, + PatternRewriter &rewriter, ValueRange values, SmallVectorImpl &remapped); /// Returns true if the given operation is ignored, and does not need to be @@ -753,7 +909,7 @@ /// Convert the signature of the given block. 
FailureOr convertBlockSignature( - Block *block, TypeConverter &converter, + Block *block, TypeConverter *converter, TypeConverter::SignatureConversion *conversion = nullptr); /// Apply a signature conversion on the given region, using `converter` for @@ -817,7 +973,11 @@ ArgConverter argConverter; /// Ordered vector of all of the newly created operations during conversion. - std::vector createdOps; + SmallVector createdOps; + + /// Ordered vector of all unresolved type conversion materializations during + /// conversion. + SmallVector unresolvedMaterializations; /// Ordered map of requested operation replacements. llvm::MapVector replacements; @@ -847,10 +1007,6 @@ /// 1->N conversion of some kind. SmallVector operationsWithChangedResults; - /// A default type converter, used when block conversions do not have one - /// explicitly provided. - TypeConverter defaultTypeConverter; - /// The current type converter, or nullptr if no type converter is currently /// active. TypeConverter *currentTypeConverter = nullptr; @@ -896,6 +1052,8 @@ undoBlockActions(); // Remove any newly created ops. + for (UnresolvedMaterialization &materialization : unresolvedMaterializations) + detachNestedAndErase(materialization.getOp()); for (auto *op : llvm::reverse(createdOps)) detachNestedAndErase(op); } @@ -904,7 +1062,7 @@ // Apply all of the rewrites replacements requested during conversion. for (auto &repl : replacements) { for (OpResult result : repl.first->getResults()) - if (Value newValue = mapping.lookupOrNull(result)) + if (Value newValue = mapping.lookupOrNull(result, result.getType())) result.replaceAllUsesWith(newValue); // If this operation defines any regions, drop any pending argument @@ -915,7 +1073,10 @@ // Apply all of the requested argument replacements. 
for (BlockArgument arg : argReplacements) { - Value repl = mapping.lookupOrDefault(arg); + Value repl = mapping.lookupOrNull(arg, arg.getType()); + if (!repl) + continue; + if (repl.isa()) { arg.replaceAllUsesWith(repl); continue; @@ -932,6 +1093,13 @@ }); } + // Drop all of the unresolved materialization operations created during + // conversion. + for (auto &mat : unresolvedMaterializations) { + mat.getOp()->dropAllUses(); + mat.getOp()->erase(); + } + // In a second pass, erase all of the replaced operations in reverse. This // allows processing nested operations before their parent region is // destroyed. Because we process in reverse order, producers may be deleted @@ -952,9 +1120,10 @@ // State Management RewriterState ConversionPatternRewriterImpl::getCurrentState() { - return RewriterState(createdOps.size(), replacements.size(), - argReplacements.size(), blockActions.size(), - ignoredOps.size(), rootUpdates.size()); + return RewriterState(createdOps.size(), unresolvedMaterializations.size(), + replacements.size(), argReplacements.size(), + blockActions.size(), ignoredOps.size(), + rootUpdates.size()); } void ConversionPatternRewriterImpl::resetState(RewriterState state) { @@ -979,6 +1148,20 @@ while (replacements.size() != state.numReplacements) replacements.pop_back(); + // Pop all of the newly inserted materializations. + while (unresolvedMaterializations.size() != + state.numUnresolvedMaterializations) { + UnresolvedMaterialization mat = unresolvedMaterializations.pop_back_val(); + UnrealizedConversionCastOp op = mat.getOp(); + + // If this was a target materialization, drop the mapping that was inserted. + if (mat.getKind() == UnresolvedMaterialization::Target) { + for (Value input : op->getOperands()) + mapping.erase(input); + } + detachNestedAndErase(op); + } + // Pop all of the newly created operations. 
while (createdOps.size() != state.numCreatedOps) { detachNestedAndErase(createdOps.back()); @@ -1070,25 +1253,27 @@ } LogicalResult ConversionPatternRewriterImpl::remapValues( - Location loc, PatternRewriter &rewriter, TypeConverter *converter, - Operation::operand_range operands, SmallVectorImpl &remapped) { - remapped.reserve(llvm::size(operands)); + StringRef valueDiagTag, Optional inputLoc, + PatternRewriter &rewriter, ValueRange values, + SmallVectorImpl &remapped) { + remapped.reserve(llvm::size(values)); SmallVector legalTypes; - for (auto it : llvm::enumerate(operands)) { + for (auto it : llvm::enumerate(values)) { Value operand = it.value(); Type origType = operand.getType(); // If a converter was provided, get the desired legal types for this // operand. Type desiredType; - if (converter) { + if (currentTypeConverter) { // If there is no legal conversion, fail to match this pattern. legalTypes.clear(); - if (failed(converter->convertType(origType, legalTypes))) { - return notifyMatchFailure(loc, [=](Diagnostic &diag) { - diag << "unable to convert type for operand #" << it.index() - << ", type was " << origType; + if (failed(currentTypeConverter->convertType(origType, legalTypes))) { + Location operandLoc = inputLoc ? *inputLoc : operand.getLoc(); + return notifyMatchFailure(operandLoc, [=](Diagnostic &diag) { + diag << "unable to convert type for " << valueDiagTag << " #" + << it.index() << ", type was " << origType; }); } // TODO: There currently isn't any mechanism to do 1->N type conversion @@ -1108,18 +1293,13 @@ // Handle the case where the conversion was 1->1 and the new operand type // isn't legal. Type newOperandType = newOperand.getType(); - if (converter && desiredType && newOperandType != desiredType) { - // Attempt to materialize a conversion for this new value. 
- newOperand = converter->materializeTargetConversion( - rewriter, loc, desiredType, newOperand); - if (!newOperand) { - return notifyMatchFailure(loc, [=](Diagnostic &diag) { - diag << "unable to materialize a conversion for " - "operand #" - << it.index() << ", from " << newOperandType << " to " - << desiredType; - }); - } + if (currentTypeConverter && desiredType && newOperandType != desiredType) { + Location operandLoc = inputLoc ? *inputLoc : operand.getLoc(); + Value castValue = buildUnresolvedTargetMaterialization( + operandLoc, newOperand, desiredType, currentTypeConverter, + unresolvedMaterializations); + mapping.map(mapping.lookupOrDefault(newOperand), castValue); + newOperand = castValue; } remapped.push_back(newOperand); } @@ -1148,7 +1328,7 @@ // Type Conversion FailureOr ConversionPatternRewriterImpl::convertBlockSignature( - Block *block, TypeConverter &converter, + Block *block, TypeConverter *converter, TypeConverter::SignatureConversion *conversion) { FailureOr result = conversion ? argConverter.applySignatureConversion( @@ -1167,11 +1347,8 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion( Region *region, TypeConverter::SignatureConversion &conversion, TypeConverter *converter) { - if (!region->empty()) { - return *convertBlockSignature(®ion->front(), - converter ? 
*converter : defaultTypeConverter, - &conversion); - } + if (!region->empty()) + return *convertBlockSignature(®ion->front(), converter, &conversion); return nullptr; } @@ -1186,7 +1363,7 @@ return failure(); FailureOr newEntry = - convertBlockSignature(®ion->front(), converter, entryConversion); + convertBlockSignature(®ion->front(), &converter, entryConversion); return newEntry; } @@ -1212,7 +1389,7 @@ : const_cast( &blockConversions[blockIdx++]); - if (failed(convertBlockSignature(&block, converter, blockConversion))) + if (failed(convertBlockSignature(&block, &converter, blockConversion))) return failure(); } return success(); @@ -1393,7 +1570,20 @@ } Value ConversionPatternRewriter::getRemappedValue(Value key) { - return impl->mapping.lookupOrDefault(key); + SmallVector remappedValues; + if (failed(impl->remapValues("value", /*inputLoc=*/llvm::None, *this, key, + remappedValues))) + return nullptr; + return remappedValues.front(); +} + +LogicalResult +ConversionPatternRewriter::getRemappedValues(ValueRange keys, + SmallVectorImpl &results) { + if (keys.empty()) + return success(); + return impl->remapValues("value", /*inputLoc=*/llvm::None, *this, keys, + results); } void ConversionPatternRewriter::notifyBlockCreated(Block *block) { @@ -1505,9 +1695,8 @@ // Remap the operands of the operation. 
SmallVector operands; - if (failed(rewriterImpl.remapValues(op->getLoc(), rewriter, - getTypeConverter(), op->getOperands(), - operands))) { + if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter, + op->getOperands(), operands))) { return failure(); } return matchAndRewrite(op, operands, dialectRewriter); @@ -1800,7 +1989,7 @@ auto &os = rewriter.getImpl().logger; os.getOStream() << "\n"; os.startLine() << "* Pattern : '" << op->getName() << " -> ("; - llvm::interleaveComma(pattern.getGeneratedOps(), llvm::dbgs()); + llvm::interleaveComma(pattern.getGeneratedOps(), os.getOStream()); os.getOStream() << ")' {\n"; os.indent(); }); @@ -1879,7 +2068,7 @@ // directly. if (auto *converter = impl.argConverter.getConverter(action.block->getParent())) { - if (failed(impl.convertBlockSignature(action.block, *converter))) { + if (failed(impl.convertBlockSignature(action.block, converter))) { LLVM_DEBUG(logFailure(impl.logger, "failed to convert types of moved " "block")); return failure(); @@ -2088,7 +2277,7 @@ SmallVector, 4> patternsByDepth; patternsByDepth.reserve(patterns.size()); for (const Pattern *pattern : patterns) { - unsigned depth = 0; + unsigned depth = 1; for (auto generatedOp : pattern->getGeneratedOps()) { unsigned generatedOpDepth = computeOpLegalizationDepth( generatedOp, minOpPatternDepth, legalizerPatterns); @@ -2173,6 +2362,12 @@ legalizeConvertedArgumentTypes(ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &rewriterImpl); + /// Legalize any unresolved type materializations. + LogicalResult legalizeUnresolvedMaterializations( + ConversionPatternRewriter &rewriter, + ConversionPatternRewriterImpl &rewriterImpl, + Optional>> &inverseMapping); + /// Legalize an operation result that was marked as "erased". LogicalResult legalizeErasedResult(Operation *op, OpResult result, @@ -2180,12 +2375,11 @@ /// Legalize an operation result that was replaced with a value of a different /// type. 
- LogicalResult - legalizeChangedResultType(Operation *op, OpResult result, Value newValue, - TypeConverter *replConverter, - ConversionPatternRewriter &rewriter, - ConversionPatternRewriterImpl &rewriterImpl, - const BlockAndValueMapping &inverseMapping); + LogicalResult legalizeChangedResultType( + Operation *op, OpResult result, Value newValue, + TypeConverter *replConverter, ConversionPatternRewriter &rewriter, + ConversionPatternRewriterImpl &rewriterImpl, + const DenseMap> &inverseMapping); /// The legalizer to use when converting operations. OperationLegalizer opLegalizer; @@ -2236,7 +2430,7 @@ ConversionTarget &target = opLegalizer.getTarget(); // Compute the set of operations and blocks to convert. - std::vector toConvert; + SmallVector toConvert; for (auto *op : ops) { toConvert.emplace_back(op); for (auto ®ion : op->getRegions()) @@ -2277,17 +2471,16 @@ LogicalResult OperationConverter::finalize(ConversionPatternRewriter &rewriter) { + Optional>> inverseMapping; ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl(); - - // Legalize converted block arguments. - if (failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl))) + if (failed(legalizeUnresolvedMaterializations(rewriter, rewriterImpl, + inverseMapping)) || + failed(legalizeConvertedArgumentTypes(rewriter, rewriterImpl))) return failure(); if (rewriterImpl.operationsWithChangedResults.empty()) return success(); - Optional inverseMapping; - // Process requested operation replacements. for (unsigned i = 0, e = rewriterImpl.operationsWithChangedResults.size(); i != e; ++i) { @@ -2338,22 +2531,290 @@ }); return liveUserIt == val.user_end() ? nullptr : *liveUserIt; }; + return rewriterImpl.argConverter.materializeLiveConversions( + rewriterImpl.mapping, rewriter, findLiveUser); +} + +/// Replace the results of a materialization operation with the given values. 
+static void +replaceMaterialization(ConversionPatternRewriterImpl &rewriterImpl, + ResultRange matResults, ValueRange values, + DenseMap> &inverseMapping) { + matResults.replaceAllUsesWith(values); + + // For each of the materialization results, update the inverse mappings to + // point to the replacement values. + for (auto it : llvm::zip(matResults, values)) { + Value matResult, newValue; + std::tie(matResult, newValue) = it; + auto inverseMapIt = inverseMapping.find(matResult); + if (inverseMapIt == inverseMapping.end()) + continue; - // Materialize any necessary conversions for converted block arguments that - // are still live. - size_t numCreatedOps = rewriterImpl.createdOps.size(); - if (failed(rewriterImpl.argConverter.materializeLiveConversions( - rewriterImpl.mapping, rewriter, findLiveUser))) - return failure(); + // Update the reverse mapping, or remove the mapping if we couldn't update + // it. Not being able to update signals that the mapping would have become + // circular (i.e. %foo -> newValue -> %foo), which may occur as values are + // propagated through temporary materializations. We simply drop the + // mapping, and let the post-conversion replacement logic handle updating + // uses. + for (Value inverseMapVal : inverseMapIt->second) + if (!rewriterImpl.mapping.tryMap(inverseMapVal, newValue)) + rewriterImpl.mapping.erase(inverseMapVal); + } +} - // Legalize any newly created operations during argument materialization. - for (int i : llvm::seq(numCreatedOps, rewriterImpl.createdOps.size())) { - if (failed(opLegalizer.legalize(rewriterImpl.createdOps[i], rewriter))) { - return rewriterImpl.createdOps[i]->emitError() - << "failed to legalize conversion operation generated for block " - "argument that remained live after conversion"; +/// Compute all of the unresolved materializations that will persist beyond the +/// conversion process, and require inserting a proper user materialization for. 
+static void computeNecessaryMaterializations( + DenseMap &materializationOps, + ConversionPatternRewriter &rewriter, + ConversionPatternRewriterImpl &rewriterImpl, + DenseMap> &inverseMapping, + SetVector &necessaryMaterializations) { + auto isLive = [&](Value value) { + auto findFn = [&](Operation *user) { + auto matIt = materializationOps.find(user); + if (matIt != materializationOps.end()) + return !necessaryMaterializations.count(matIt->second); + return rewriterImpl.isOpIgnored(user); + }; + return llvm::find_if_not(value.getUsers(), findFn) != value.user_end(); + }; + + llvm::unique_function lookupRemappedValue = + [&](Value invalidRoot, Value value, Type type) { + // Check to see if the input operation was remapped to a variant of the + // output. + Value remappedValue = rewriterImpl.mapping.lookupOrDefault(value, type); + if (remappedValue.getType() == type && remappedValue != invalidRoot) + return remappedValue; + + // Check to see if the input is a materialization operation that + // provides an inverse conversion. We just check blindly for + // UnrealizedConversionCastOp here, but it has no effect on correctness. + auto inputCastOp = value.getDefiningOp(); + if (inputCastOp && inputCastOp->getNumOperands() == 1) + return lookupRemappedValue(invalidRoot, inputCastOp->getOperand(0), + type); + + return Value(); + }; + + SetVector worklist; + for (auto &mat : rewriterImpl.unresolvedMaterializations) { + materializationOps.try_emplace(mat.getOp(), &mat); + worklist.insert(&mat); + } + while (!worklist.empty()) { + UnresolvedMaterialization *mat = worklist.pop_back_val(); + UnrealizedConversionCastOp op = mat->getOp(); + + // We currently only handle target materializations here. 
+ assert(op->getNumResults() == 1 && "unexpected materialization type"); + OpResult opResult = op->getOpResult(0); + Type outputType = opResult.getType(); + Operation::operand_range inputOperands = op.getOperands(); + + // Try to forward propagate operands for user conversion casts that result + // in the input types of the current cast. + for (Operation *user : llvm::make_early_inc_range(opResult.getUsers())) { + auto castOp = dyn_cast(user); + if (!castOp) + continue; + if (castOp->getResultTypes() == inputOperands.getTypes()) { + replaceMaterialization(rewriterImpl, opResult, inputOperands, + inverseMapping); + necessaryMaterializations.remove(materializationOps.lookup(user)); + } + } + + // Try to avoid materializing a resolved materialization if possible. + // Handle the case of a 1-1 materialization. + if (inputOperands.size() == 1) { + // Check to see if the input operation was remapped to a variant of the + // output. + Value remappedValue = + lookupRemappedValue(opResult, inputOperands[0], outputType); + if (remappedValue && remappedValue != opResult) { + replaceMaterialization(rewriterImpl, opResult, remappedValue, + inverseMapping); + necessaryMaterializations.remove(mat); + continue; + } + } else { + // TODO: Avoid materializing other types of conversions here. + } + + // Check to see if this is an argument materialization. + auto isBlockArg = [](Value v) { return v.isa(); }; + if (llvm::any_of(op->getOperands(), isBlockArg) || + llvm::any_of(inverseMapping[op->getResult(0)], isBlockArg)) { + mat->setKind(UnresolvedMaterialization::Argument); + } + + // If the materialization does not have any live users, we don't need to + // generate a user materialization for it. + // FIXME: For argument materializations, we currently need to check if any + // of the inverse mapped values are used because some patterns expect blind + // value replacement even if the types differ in some cases. 
When those + // patterns are fixed, we can drop the argument special case here. + bool isMaterializationLive = isLive(opResult); + if (mat->getKind() == UnresolvedMaterialization::Argument) + isMaterializationLive |= llvm::any_of(inverseMapping[opResult], isLive); + if (!isMaterializationLive) + continue; + if (!necessaryMaterializations.insert(mat)) + continue; + + // Reprocess input materializations to see if they have an updated status. + for (Value input : inputOperands) { + if (auto parentOp = input.getDefiningOp()) { + if (auto *mat = materializationOps.lookup(parentOp)) + worklist.insert(mat); + } + } + } +} + +/// Legalize the given unresolved materialization. Returns success if the +/// materialization was legalized, failure otherise. +static LogicalResult legalizeUnresolvedMaterialization( + UnresolvedMaterialization &mat, + DenseMap &materializationOps, + ConversionPatternRewriter &rewriter, + ConversionPatternRewriterImpl &rewriterImpl, + DenseMap> &inverseMapping) { + auto findLiveUser = [&](auto &&users) { + auto liveUserIt = llvm::find_if_not( + users, [&](Operation *user) { return rewriterImpl.isOpIgnored(user); }); + return liveUserIt == users.end() ? nullptr : *liveUserIt; + }; + + llvm::unique_function lookupRemappedValue = + [&](Value value, Type type) { + // Check to see if the input operation was remapped to a variant of the + // output. + Value remappedValue = rewriterImpl.mapping.lookupOrDefault(value, type); + if (remappedValue.getType() == type) + return remappedValue; + return Value(); + }; + + UnrealizedConversionCastOp op = mat.getOp(); + if (!rewriterImpl.ignoredOps.insert(op)) + return success(); + + // We currently only handle target materializations here. + OpResult opResult = op->getOpResult(0); + Operation::operand_range inputOperands = op.getOperands(); + Type outputType = opResult.getType(); + + // If any input to this materialization is another materialization, resolve + // the input first. 
+ for (Value value : op->getOperands()) { + auto valueCast = value.getDefiningOp(); + if (!valueCast) + continue; + + auto matIt = materializationOps.find(valueCast); + if (matIt != materializationOps.end()) + if (failed(legalizeUnresolvedMaterialization( + *matIt->second, materializationOps, rewriter, rewriterImpl, + inverseMapping))) + return failure(); + } + + // Perform a last ditch attempt to avoid materializing a resolved + // materialization if possible. + // Handle the case of a 1-1 materialization. + if (inputOperands.size() == 1) { + // Check to see if the input operation was remapped to a variant of the + // output. + Value remappedValue = lookupRemappedValue(inputOperands[0], outputType); + if (remappedValue && remappedValue != opResult) { + replaceMaterialization(rewriterImpl, opResult, remappedValue, + inverseMapping); + return success(); + } + } else { + // TODO: Avoid materializing other types of conversions here. + } + + // Try to materialize the conversion. + if (TypeConverter *converter = mat.getConverter()) { + // FIXME: Determine a suitable insertion location when there are multiple + // inputs. + if (inputOperands.size() == 1) + rewriter.setInsertionPointAfterValue(inputOperands.front()); + else + rewriter.setInsertionPoint(op); + + Value newMaterialization; + switch (mat.getKind()) { + case UnresolvedMaterialization::Argument: + // Try to materialize an argument conversion. + // FIXME: The current argument materialization hook expects the original + // output type, even though it doesn't use that as the actual output type + // of the generated IR. The output type is just used as an indicator of + // the type of materialization to do. This behavior is really awkward in + // that it diverges from the behavior of the other hooks, and can be + // easily misunderstood. We should clean up the argument hooks to better + // represent the desired invariants we actually care about. 
+ newMaterialization = converter->materializeArgumentConversion( + rewriter, op->getLoc(), mat.getOrigOutputType(), inputOperands); + if (newMaterialization) + break; + + // If an argument materialization failed, fallback to trying a target + // materialization. + LLVM_FALLTHROUGH; + case UnresolvedMaterialization::Target: + newMaterialization = converter->materializeTargetConversion( + rewriter, op->getLoc(), outputType, inputOperands); + break; + default: + llvm_unreachable("unknown materialization kind"); + } + if (newMaterialization) { + replaceMaterialization(rewriterImpl, opResult, newMaterialization, + inverseMapping); + return success(); } } + + InFlightDiagnostic diag = op->emitError() + << "failed to legalize unresolved materialization " + "from " + << inputOperands.getTypes() << " to " << outputType + << " that remained live after conversion"; + if (Operation *liveUser = findLiveUser(op->getUsers())) { + diag.attachNote(liveUser->getLoc()) + << "see existing live user here: " << *liveUser; + } + return failure(); +} + +LogicalResult OperationConverter::legalizeUnresolvedMaterializations( + ConversionPatternRewriter &rewriter, + ConversionPatternRewriterImpl &rewriterImpl, + Optional>> &inverseMapping) { + if (rewriterImpl.unresolvedMaterializations.empty()) + return success(); + inverseMapping = rewriterImpl.mapping.getInverse(); + + // As an initial step, compute all of the inserted materializations that we + // expect to persist beyond the conversion process. + DenseMap materializationOps; + SetVector necessaryMaterializations; + computeNecessaryMaterializations(materializationOps, rewriter, rewriterImpl, + *inverseMapping, necessaryMaterializations); + + // Once computed, legalize any necessary materializations. 
+ for (auto *mat : necessaryMaterializations) { + if (failed(legalizeUnresolvedMaterialization( + *mat, materializationOps, rewriter, rewriterImpl, *inverseMapping))) + return failure(); + } return success(); } @@ -2378,10 +2839,13 @@ /// Finds a user of the given value, or of any other value that the given value /// replaced, that was not replaced in the conversion process. -static Operation * -findLiveUserOfReplaced(Value value, ConversionPatternRewriterImpl &rewriterImpl, - const BlockAndValueMapping &inverseMapping) { - do { +static Operation *findLiveUserOfReplaced( + Value initialValue, ConversionPatternRewriterImpl &rewriterImpl, + const DenseMap> &inverseMapping) { + SmallVector worklist(1, initialValue); + while (!worklist.empty()) { + Value value = worklist.pop_back_val(); + // Walk the users of this value to see if there are any live users that // weren't replaced during conversion. auto liveUserIt = llvm::find_if_not(value.getUsers(), [&](Operation *user) { @@ -2389,8 +2853,10 @@ }); if (liveUserIt != value.user_end()) return *liveUserIt; - value = inverseMapping.lookupOrNull(value); - } while (value != nullptr); + auto mapIt = inverseMapping.find(value); + if (mapIt != inverseMapping.end()) + worklist.append(mapIt->second); + } return nullptr; } @@ -2398,30 +2864,14 @@ Operation *op, OpResult result, Value newValue, TypeConverter *replConverter, ConversionPatternRewriter &rewriter, ConversionPatternRewriterImpl &rewriterImpl, - const BlockAndValueMapping &inverseMapping) { + const DenseMap> &inverseMapping) { Operation *liveUser = findLiveUserOfReplaced(result, rewriterImpl, inverseMapping); if (!liveUser) return success(); - // If the replacement has a type converter, attempt to materialize a - // conversion back to the original type. - if (!replConverter) { - // TODO: We should emit an error here, similarly to the case where the - // result is replaced with null. 
Unfortunately a lot of existing - // patterns rely on this behavior, so until those patterns are updated - // we keep the legacy behavior here of just forwarding the new value. - return success(); - } - - // Track the number of created operations so that new ones can be legalized. - size_t numCreatedOps = rewriterImpl.createdOps.size(); - - // Materialize a conversion for this live result value. - Type resultType = result.getType(); - Value convertedValue = replConverter->materializeSourceConversion( - rewriter, op->getLoc(), resultType, newValue); - if (!convertedValue) { + // Functor used to emit a conversion error for a failed materialization. + auto emitConversionError = [&] { InFlightDiagnostic diag = op->emitError() << "failed to materialize conversion for result #" << result.getResultNumber() << " of operation '" @@ -2430,16 +2880,19 @@ diag.attachNote(liveUser->getLoc()) << "see existing live user here: " << *liveUser; return failure(); - } + }; - // Legalize all of the newly created conversion operations. - for (int i : llvm::seq(numCreatedOps, rewriterImpl.createdOps.size())) { - if (failed(opLegalizer.legalize(rewriterImpl.createdOps[i], rewriter))) { - return op->emitError("failed to legalize conversion operation generated ") - << "for result #" << result.getResultNumber() << " of operation '" - << op->getName() << "' that remained live after conversion"; - } - } + // If the replacement has a type converter, attempt to materialize a + // conversion back to the original type. + if (!replConverter) + return emitConversionError(); + + // Materialize a conversion for this live result value. 
+ Type resultType = result.getType(); + Value convertedValue = replConverter->materializeSourceConversion( + rewriter, op->getLoc(), resultType, newValue); + if (!convertedValue) + return emitConversionError(); rewriterImpl.mapping.map(result, convertedValue); return success(); diff --git a/mlir/test/Conversion/ArithmeticToLLVM/arith-to-llvm.mlir b/mlir/test/Conversion/ArithmeticToLLVM/arith-to-llvm.mlir --- a/mlir/test/Conversion/ArithmeticToLLVM/arith-to-llvm.mlir +++ b/mlir/test/Conversion/ArithmeticToLLVM/arith-to-llvm.mlir @@ -322,7 +322,7 @@ func @index_vector(%arg0: vector<4xindex>) { // CHECK: %[[CST:.*]] = llvm.mlir.constant(dense<[0, 1, 2, 3]> : vector<4xindex>) : vector<4xi64> %0 = arith.constant dense<[0, 1, 2, 3]> : vector<4xindex> - // CHECK: %[[V:.*]] = llvm.add %1, %[[CST]] : vector<4xi64> + // CHECK: %[[V:.*]] = llvm.add %{{.*}}, %[[CST]] : vector<4xi64> %1 = arith.addi %arg0, %0 : vector<4xindex> std.return } diff --git a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir --- a/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir +++ b/mlir/test/Conversion/ArithmeticToSPIRV/arithmetic-to-spirv.mlir @@ -161,9 +161,10 @@ spv.target_env = #spv.target_env<#spv.vce, {}> } { +// expected-error @+1 {{failed to materialize conversion for block argument #0 that remained live after conversion, type was 'vector<4xi64>', with target type 'vector<4xi32>'}} func @int_vector4_invalid(%arg0: vector<4xi64>) { // expected-error @+2 {{bitwidth emulation is not implemented yet on unsigned op}} - // expected-error @+1 {{op requires the same type for all operands and results}} + // expected-note @+1 {{see existing live user here}} %0 = arith.divui %arg0, %arg0: vector<4xi64> return } diff --git a/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir --- a/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir +++ 
b/mlir/test/Conversion/ComplexToLLVM/convert-to-llvm.mlir @@ -14,8 +14,7 @@ // CHECK-SAME: (%[[CPLX:.*]]: complex) // CHECK-NEXT: %[[CAST0:.*]] = builtin.unrealized_conversion_cast %[[CPLX]] : complex to !llvm.struct<(f32, f32)> // CHECK-NEXT: %[[REAL:.*]] = llvm.extractvalue %[[CAST0]][0] : !llvm.struct<(f32, f32)> -// CHECK-NEXT: %[[CAST1:.*]] = builtin.unrealized_conversion_cast %[[CPLX]] : complex to !llvm.struct<(f32, f32)> -// CHECK-NEXT: %[[IMAG:.*]] = llvm.extractvalue %[[CAST1]][1] : !llvm.struct<(f32, f32)> +// CHECK-NEXT: %[[IMAG:.*]] = llvm.extractvalue %[[CAST0]][1] : !llvm.struct<(f32, f32)> func @complex_extract(%cplx: complex) { %real1 = complex.re %cplx : complex %imag1 = complex.im %cplx : complex @@ -70,8 +69,8 @@ %div = complex.div %lhs, %rhs : complex return %div : complex } -// CHECK: %[[CASTED_LHS:.*]] = builtin.unrealized_conversion_cast %[[LHS]] : complex to ![[C_TY:.*>]] -// CHECK: %[[CASTED_RHS:.*]] = builtin.unrealized_conversion_cast %[[RHS]] : complex to ![[C_TY]] +// CHECK-DAG: %[[CASTED_LHS:.*]] = builtin.unrealized_conversion_cast %[[LHS]] : complex to ![[C_TY:.*>]] +// CHECK-DAG: %[[CASTED_RHS:.*]] = builtin.unrealized_conversion_cast %[[RHS]] : complex to ![[C_TY]] // CHECK: %[[LHS_RE:.*]] = llvm.extractvalue %[[CASTED_LHS]][0] : ![[C_TY]] // CHECK: %[[LHS_IM:.*]] = llvm.extractvalue %[[CASTED_LHS]][1] : ![[C_TY]] @@ -106,8 +105,8 @@ %mul = complex.mul %lhs, %rhs : complex return %mul : complex } -// CHECK: %[[CASTED_LHS:.*]] = builtin.unrealized_conversion_cast %[[LHS]] : complex to ![[C_TY:.*>]] -// CHECK: %[[CASTED_RHS:.*]] = builtin.unrealized_conversion_cast %[[RHS]] : complex to ![[C_TY]] +// CHECK-DAG: %[[CASTED_LHS:.*]] = builtin.unrealized_conversion_cast %[[LHS]] : complex to ![[C_TY:.*>]] +// CHECK-DAG: %[[CASTED_RHS:.*]] = builtin.unrealized_conversion_cast %[[RHS]] : complex to ![[C_TY]] // CHECK: %[[LHS_RE:.*]] = llvm.extractvalue %[[CASTED_LHS]][0] : ![[C_TY]] // CHECK: %[[LHS_IM:.*]] = llvm.extractvalue 
%[[CASTED_LHS]][1] : ![[C_TY]] diff --git a/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir b/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir --- a/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/convert-dynamic-memref-ops.mlir @@ -4,8 +4,8 @@ // CHECK-LABEL: func @mixed_alloc( // CHECK: %[[Marg:.*]]: index, %[[Narg:.*]]: index) func @mixed_alloc(%arg0: index, %arg1: index) -> memref { -// CHECK: %[[M:.*]] = builtin.unrealized_conversion_cast %[[Marg]] -// CHECK: %[[N:.*]] = builtin.unrealized_conversion_cast %[[Narg]] +// CHECK-DAG: %[[M:.*]] = builtin.unrealized_conversion_cast %[[Marg]] +// CHECK-DAG: %[[N:.*]] = builtin.unrealized_conversion_cast %[[Narg]] // CHECK: %[[c42:.*]] = llvm.mlir.constant(42 : index) : i64 // CHECK-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : i64 // CHECK-NEXT: %[[st0:.*]] = llvm.mul %[[N]], %[[c42]] : i64 @@ -46,8 +46,8 @@ // CHECK-LABEL: func @dynamic_alloc( // CHECK: %[[Marg:.*]]: index, %[[Narg:.*]]: index) func @dynamic_alloc(%arg0: index, %arg1: index) -> memref { -// CHECK: %[[M:.*]] = builtin.unrealized_conversion_cast %[[Marg]] -// CHECK: %[[N:.*]] = builtin.unrealized_conversion_cast %[[Narg]] +// CHECK-DAG: %[[M:.*]] = builtin.unrealized_conversion_cast %[[Marg]] +// CHECK-DAG: %[[N:.*]] = builtin.unrealized_conversion_cast %[[Narg]] // CHECK-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : i64 // CHECK-NEXT: %[[sz:.*]] = llvm.mul %[[N]], %[[M]] : i64 // CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm.ptr @@ -73,8 +73,8 @@ // CHECK-LABEL: func @dynamic_alloca // CHECK: %[[Marg:.*]]: index, %[[Narg:.*]]: index) func @dynamic_alloca(%arg0: index, %arg1: index) -> memref { -// CHECK: %[[M:.*]] = builtin.unrealized_conversion_cast %[[Marg]] -// CHECK: %[[N:.*]] = builtin.unrealized_conversion_cast %[[Narg]] +// CHECK-DAG: %[[M:.*]] = builtin.unrealized_conversion_cast %[[Marg]] +// CHECK-DAG: %[[N:.*]] = 
builtin.unrealized_conversion_cast %[[Narg]] // CHECK-NEXT: %[[st1:.*]] = llvm.mlir.constant(1 : index) : i64 // CHECK-NEXT: %[[num_elems:.*]] = llvm.mul %[[N]], %[[M]] : i64 // CHECK-NEXT: %[[null:.*]] = llvm.mlir.null : !llvm.ptr @@ -119,7 +119,7 @@ // CHECK-LABEL: func @stdlib_aligned_alloc({{.*}}) // ALIGNED-ALLOC-LABEL: func @stdlib_aligned_alloc({{.*}}) func @stdlib_aligned_alloc(%N : index) -> memref<32x18xf32> { -// ALIGNED-ALLOC-NEXT: %[[sz1:.*]] = llvm.mlir.constant(32 : index) : i64 +// ALIGNED-ALLOC: %[[sz1:.*]] = llvm.mlir.constant(32 : index) : i64 // ALIGNED-ALLOC-NEXT: %[[sz2:.*]] = llvm.mlir.constant(18 : index) : i64 // ALIGNED-ALLOC-NEXT: %[[one:.*]] = llvm.mlir.constant(1 : index) : i64 // ALIGNED-ALLOC-NEXT: %[[num_elems:.*]] = llvm.mlir.constant(576 : index) : i64 @@ -148,7 +148,7 @@ %4 = memref.alloc() {alignment = 8} : memref<1024xvector<4xf32>> // Bump the memref allocation size if its size is not a multiple of alignment. // ALIGNED-ALLOC: %[[c32:.*]] = llvm.mlir.constant(32 : index) : i64 - // ALIGNED-ALLOC-NEXT: llvm.mlir.constant(1 : index) : i64 + // ALIGNED-ALLOC: llvm.mlir.constant(1 : index) : i64 // ALIGNED-ALLOC-NEXT: llvm.sub // ALIGNED-ALLOC-NEXT: llvm.add // ALIGNED-ALLOC-NEXT: llvm.urem @@ -167,8 +167,8 @@ // CHECK-LABEL: func @mixed_load( // CHECK: %{{.*}}, %[[Iarg:.*]]: index, %[[Jarg:.*]]: index) func @mixed_load(%mixed : memref<42x?xf32>, %i : index, %j : index) { -// CHECK: %[[I:.*]] = builtin.unrealized_conversion_cast %[[Iarg]] -// CHECK: %[[J:.*]] = builtin.unrealized_conversion_cast %[[Jarg]] +// CHECK-DAG: %[[I:.*]] = builtin.unrealized_conversion_cast %[[Iarg]] +// CHECK-DAG: %[[J:.*]] = builtin.unrealized_conversion_cast %[[Jarg]] // CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: %[[offI:.*]] = llvm.mul 
%[[I]], %[[st0]] : i64 @@ -184,8 +184,8 @@ // CHECK-LABEL: func @dynamic_load( // CHECK: %{{.*}}, %[[Iarg:.*]]: index, %[[Jarg:.*]]: index) func @dynamic_load(%dynamic : memref, %i : index, %j : index) { -// CHECK: %[[I:.*]] = builtin.unrealized_conversion_cast %[[Iarg]] -// CHECK: %[[J:.*]] = builtin.unrealized_conversion_cast %[[Jarg]] +// CHECK-DAG: %[[I:.*]] = builtin.unrealized_conversion_cast %[[Iarg]] +// CHECK-DAG: %[[J:.*]] = builtin.unrealized_conversion_cast %[[Jarg]] // CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : i64 @@ -201,8 +201,8 @@ // CHECK-LABEL: func @prefetch // CHECK: %{{.*}}, %[[Iarg:.*]]: index, %[[Jarg:.*]]: index) func @prefetch(%A : memref, %i : index, %j : index) { -// CHECK: %[[I:.*]] = builtin.unrealized_conversion_cast %[[Iarg]] -// CHECK: %[[J:.*]] = builtin.unrealized_conversion_cast %[[Jarg]] +// CHECK-DAG: %[[I:.*]] = builtin.unrealized_conversion_cast %[[Iarg]] +// CHECK-DAG: %[[J:.*]] = builtin.unrealized_conversion_cast %[[Jarg]] // CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : i64 @@ -231,8 +231,8 @@ // CHECK-LABEL: func @dynamic_store // CHECK: %{{.*}}, %[[Iarg:.*]]: index, %[[Jarg:.*]]: index func @dynamic_store(%dynamic : memref, %i : index, %j : index, %val : f32) { -// CHECK: %[[I:.*]] = builtin.unrealized_conversion_cast %[[Iarg]] -// CHECK: %[[J:.*]] = builtin.unrealized_conversion_cast %[[Jarg]] +// CHECK-DAG: %[[I:.*]] = builtin.unrealized_conversion_cast %[[Iarg]] +// CHECK-DAG: %[[J:.*]] = 
builtin.unrealized_conversion_cast %[[Jarg]] // CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : i64 @@ -248,8 +248,8 @@ // CHECK-LABEL: func @mixed_store // CHECK: %{{.*}}, %[[Iarg:.*]]: index, %[[Jarg:.*]]: index func @mixed_store(%mixed : memref<42x?xf32>, %i : index, %j : index, %val : f32) { -// CHECK: %[[I:.*]] = builtin.unrealized_conversion_cast %[[Iarg]] -// CHECK: %[[J:.*]] = builtin.unrealized_conversion_cast %[[Jarg]] +// CHECK-DAG: %[[I:.*]] = builtin.unrealized_conversion_cast %[[Iarg]] +// CHECK-DAG: %[[J:.*]] = builtin.unrealized_conversion_cast %[[Jarg]] // CHECK: %[[ptr:.*]] = llvm.extractvalue %[[ld:.*]][1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: %[[st0:.*]] = llvm.extractvalue %[[ld]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK-NEXT: %[[offI:.*]] = llvm.mul %[[I]], %[[st0]] : i64 @@ -376,12 +376,12 @@ // CHECK-LABEL: @memref_dim_with_dyn_index // CHECK: %{{.*}}, %[[IDXarg:.*]]: index func @memref_dim_with_dyn_index(%arg : memref<3x?xf32>, %idx : index) -> index { + // CHECK-DAG: %[[IDX:.*]] = builtin.unrealized_conversion_cast %[[IDXarg]] // CHECK-DAG: %[[C0:.*]] = llvm.mlir.constant(0 : index) : i64 // CHECK-DAG: %[[C1:.*]] = llvm.mlir.constant(1 : index) : i64 // CHECK-DAG: %[[SIZES:.*]] = llvm.extractvalue %{{.*}}[3] : ![[DESCR_TY:.*]] // CHECK-DAG: %[[SIZES_PTR:.*]] = llvm.alloca %[[C1]] x !llvm.array<2 x i64> : (i64) -> !llvm.ptr> // CHECK-DAG: llvm.store %[[SIZES]], %[[SIZES_PTR]] : !llvm.ptr> - // CHECK-DAG: %[[IDX:.*]] = builtin.unrealized_conversion_cast %[[IDXarg]] // CHECK-DAG: %[[RESULT_PTR:.*]] = llvm.getelementptr %[[SIZES_PTR]][%[[C0]], %[[IDX]]] : (!llvm.ptr>, i64, i64) -> !llvm.ptr // CHECK-DAG: %[[RESULT:.*]] 
= llvm.load %[[RESULT_PTR]] : !llvm.ptr %result = memref.dim %arg, %idx : memref<3x?xf32> @@ -433,12 +433,12 @@ // CHECK-SAME: ([[OFFSETarg:%[a-z,0-9]+]]: index, // CHECK-SAME: [[SIZE_0arg:%[a-z,0-9]+]]: index, [[SIZE_1arg:%[a-z,0-9]+]]: index, // CHECK-SAME: [[STRIDE_0arg:%[a-z,0-9]+]]: index, [[STRIDE_1arg:%[a-z,0-9]+]]: index, -// CHECK: [[INPUT:%.*]] = builtin.unrealized_conversion_cast -// CHECK: [[OFFSET:%.*]] = builtin.unrealized_conversion_cast [[OFFSETarg]] -// CHECK: [[SIZE_0:%.*]] = builtin.unrealized_conversion_cast [[SIZE_0arg]] -// CHECK: [[SIZE_1:%.*]] = builtin.unrealized_conversion_cast [[SIZE_1arg]] -// CHECK: [[STRIDE_0:%.*]] = builtin.unrealized_conversion_cast [[STRIDE_0arg]] -// CHECK: [[STRIDE_1:%.*]] = builtin.unrealized_conversion_cast [[STRIDE_1arg]] +// CHECK-DAG: [[OFFSET:%.*]] = builtin.unrealized_conversion_cast [[OFFSETarg]] +// CHECK-DAG: [[SIZE_0:%.*]] = builtin.unrealized_conversion_cast [[SIZE_0arg]] +// CHECK-DAG: [[SIZE_1:%.*]] = builtin.unrealized_conversion_cast [[SIZE_1arg]] +// CHECK-DAG: [[STRIDE_0:%.*]] = builtin.unrealized_conversion_cast [[STRIDE_0arg]] +// CHECK-DAG: [[STRIDE_1:%.*]] = builtin.unrealized_conversion_cast [[STRIDE_1arg]] +// CHECK-DAG: [[INPUT:%.*]] = builtin.unrealized_conversion_cast // CHECK: [[OUT_0:%.*]] = llvm.mlir.undef : [[TY:!.*]] // CHECK: [[DESCRIPTOR:%.*]] = llvm.extractvalue [[INPUT]][1] : !llvm.struct<(i64, ptr)> // CHECK: [[BASE_PTR_PTR:%.*]] = llvm.bitcast [[DESCRIPTOR]] : !llvm.ptr to !llvm.ptr> diff --git a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir --- a/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir +++ b/mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir @@ -5,14 +5,14 @@ // CHECK-LABEL: func @view( // CHECK: %[[ARG0F:.*]]: index, %[[ARG1F:.*]]: index, %[[ARG2F:.*]]: index func @view(%arg0 : index, %arg1 : index, %arg2 : index) { + // CHECK: %[[ARG2:.*]] = builtin.unrealized_conversion_cast %[[ARG2F:.*]] + // 
CHECK: %[[ARG0:.*]] = builtin.unrealized_conversion_cast %[[ARG0F:.*]] + // CHECK: %[[ARG1:.*]] = builtin.unrealized_conversion_cast %[[ARG1F:.*]] // CHECK: llvm.mlir.constant(2048 : index) : i64 // CHECK: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %0 = memref.alloc() : memref<2048xi8> // Test two dynamic sizes. - // CHECK: %[[ARG2:.*]] = builtin.unrealized_conversion_cast %[[ARG2F:.*]] - // CHECK: %[[ARG0:.*]] = builtin.unrealized_conversion_cast %[[ARG0F:.*]] - // CHECK: %[[ARG1:.*]] = builtin.unrealized_conversion_cast %[[ARG1F:.*]] // CHECK: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[BASE_PTR:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // CHECK: %[[SHIFTED_BASE_PTR:.*]] = llvm.getelementptr %[[BASE_PTR]][%[[ARG2]]] : (!llvm.ptr, i64) -> !llvm.ptr @@ -29,8 +29,6 @@ %1 = memref.view %0[%arg2][%arg0, %arg1] : memref<2048xi8> to memref // Test one dynamic size. - // CHECK: %[[ARG2:.*]] = builtin.unrealized_conversion_cast %[[ARG2F:.*]] - // CHECK: %[[ARG1:.*]] = builtin.unrealized_conversion_cast %[[ARG1F:.*]] // CHECK: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[BASE_PTR_2:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // CHECK: %[[SHIFTED_BASE_PTR_2:.*]] = llvm.getelementptr %[[BASE_PTR_2]][%[[ARG2]]] : (!llvm.ptr, i64) -> !llvm.ptr @@ -48,7 +46,6 @@ %3 = memref.view %0[%arg2][%arg1] : memref<2048xi8> to memref<4x?xf32> // Test static sizes. 
- // CHECK: %[[ARG2:.*]] = builtin.unrealized_conversion_cast %[[ARG2F:.*]] // CHECK: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[BASE_PTR_3:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // CHECK: %[[SHIFTED_BASE_PTR_3:.*]] = llvm.getelementptr %[[BASE_PTR_3]][%[[ARG2]]] : (!llvm.ptr, i64) -> !llvm.ptr @@ -71,7 +68,6 @@ // CHECK: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> %6 = memref.alloc() : memref<2048xi8, 4> - // CHECK: %[[ARG2:.*]] = builtin.unrealized_conversion_cast %[[ARG2F:.*]] // CHECK: llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[BASE_PTR_4:.*]] = llvm.extractvalue %{{.*}}[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> // CHECK: %[[SHIFTED_BASE_PTR_4:.*]] = llvm.getelementptr %[[BASE_PTR_4]][%[[ARG2]]] : (!llvm.ptr, i64) -> !llvm.ptr @@ -105,21 +101,13 @@ // CHECK32: %[[ARG1f:[a-zA-Z0-9]*]]: index, // CHECK32: %[[ARG2f:.*]]: index) func @subview(%0 : memref<64x4xf32, offset: 0, strides: [4, 1]>, %arg0 : index, %arg1 : index, %arg2 : index) { - // CHECK: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]] - // CHECK32: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]] - - // CHECK: %[[ARG0a:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]] - // CHECK: %[[ARG1a:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]] - // CHECK: %[[ARG0b:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]] - // CHECK: %[[ARG1b:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]] - // CHECK: %[[ARG0c:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]] - // CHECK: %[[ARG1c:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]] - // CHECK32: %[[ARG0a:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]] - // CHECK32: %[[ARG1a:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]] - // CHECK32: %[[ARG0b:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]] - 
// CHECK32: %[[ARG1b:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]] - // CHECK32: %[[ARG0c:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]] - // CHECK32: %[[ARG1c:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]] + // CHECK-DAG: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]] + // CHECK-DAG: %[[ARG0:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]] + // CHECK-DAG: %[[ARG1:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]] + + // CHECK32-DAG: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]] + // CHECK32-DAG: %[[ARG0:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]] + // CHECK32-DAG: %[[ARG1:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]] // CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[BITCAST0:.*]] = llvm.bitcast %{{.*}} : !llvm.ptr to !llvm.ptr @@ -129,16 +117,16 @@ // CHECK: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEMREF]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[STRIDE1:.*]] = llvm.extractvalue %[[MEMREF]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[OFF:.*]] = llvm.extractvalue %[[MEMREF]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // CHECK: %[[OFFINC:.*]] = llvm.mul %[[ARG0a]], %[[STRIDE0]] : i64 + // CHECK: %[[OFFINC:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : i64 // CHECK: %[[OFF1:.*]] = llvm.add %[[OFF]], %[[OFFINC]] : i64 - // CHECK: %[[OFFINC1:.*]] = llvm.mul %[[ARG1a]], %[[STRIDE1]] : i64 + // CHECK: %[[OFFINC1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : i64 // CHECK: %[[OFF2:.*]] = llvm.add %[[OFF1]], %[[OFFINC1]] : i64 // CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[OFF2]], %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // CHECK: %[[DESCSTRIDE1:.*]] = llvm.mul %[[ARG1c]], %[[STRIDE1]] : i64 - // CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[ARG1b]], %[[DESC2]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x 
i64>, array<2 x i64>)> + // CHECK: %[[DESCSTRIDE1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : i64 + // CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[ARG1]], %[[DESC2]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[DESC4:.*]] = llvm.insertvalue %[[DESCSTRIDE1]], %[[DESC3]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // CHECK: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0c]], %[[STRIDE0]] : i64 - // CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[ARG0b]], %[[DESC4]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : i64 + // CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[ARG0]], %[[DESC4]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: llvm.insertvalue %[[DESCSTRIDE0]], %[[DESC5]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK32: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> // CHECK32: %[[BITCAST0:.*]] = llvm.bitcast %{{.*}} : !llvm.ptr to !llvm.ptr @@ -148,16 +136,16 @@ // CHECK32: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEMREF]][4, 0] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> // CHECK32: %[[STRIDE1:.*]] = llvm.extractvalue %[[MEMREF]][4, 1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> // CHECK32: %[[OFF:.*]] = llvm.extractvalue %[[MEMREF]][2] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> - // CHECK32: %[[OFFINC:.*]] = llvm.mul %[[ARG0a]], %[[STRIDE0]] : i32 + // CHECK32: %[[OFFINC:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : i32 // CHECK32: %[[OFF1:.*]] = llvm.add %[[OFF]], %[[OFFINC]] : i32 - // CHECK32: %[[OFFINC1:.*]] = llvm.mul %[[ARG1a]], %[[STRIDE1]] : i32 + // CHECK32: %[[OFFINC1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : i32 // CHECK32: %[[OFF2:.*]] = llvm.add %[[OFF1]], %[[OFFINC1]] : i32 // CHECK32: %[[DESC2:.*]] = llvm.insertvalue %[[OFF2]], %[[DESC1]][2] : 
!llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> - // CHECK32: %[[DESCSTRIDE1:.*]] = llvm.mul %[[ARG1c]], %[[STRIDE1]] : i32 - // CHECK32: %[[DESC3:.*]] = llvm.insertvalue %[[ARG1b]], %[[DESC2]][3, 1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> + // CHECK32: %[[DESCSTRIDE1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : i32 + // CHECK32: %[[DESC3:.*]] = llvm.insertvalue %[[ARG1]], %[[DESC2]][3, 1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> // CHECK32: %[[DESC4:.*]] = llvm.insertvalue %[[DESCSTRIDE1]], %[[DESC3]][4, 1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> - // CHECK32: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0c]], %[[STRIDE0]] : i32 - // CHECK32: %[[DESC5:.*]] = llvm.insertvalue %[[ARG0b]], %[[DESC4]][3, 0] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> + // CHECK32: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : i32 + // CHECK32: %[[DESC5:.*]] = llvm.insertvalue %[[ARG0]], %[[DESC4]][3, 0] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> %1 = memref.subview %0[%arg0, %arg1][%arg0, %arg1][%arg0, %arg1] : memref<64x4xf32, offset: 0, strides: [4, 1]> @@ -178,21 +166,12 @@ // CHECK32: %[[ARG1f:[a-zA-Z0-9]*]]: index, // CHECK32: %[[ARG2f:.*]]: index) func @subview_non_zero_addrspace(%0 : memref<64x4xf32, offset: 0, strides: [4, 1], 3>, %arg0 : index, %arg1 : index, %arg2 : index) { - // CHECK: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]] - // CHECK32: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]] - - // CHECK: %[[ARG0a:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]] - // CHECK: %[[ARG1a:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]] - // CHECK: %[[ARG0b:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]] - // CHECK: %[[ARG1b:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]] - // CHECK: %[[ARG0c:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]] - // CHECK: %[[ARG1c:.*]] = builtin.unrealized_conversion_cast 
%[[ARG1f]] - // CHECK32: %[[ARG0a:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]] - // CHECK32: %[[ARG1a:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]] - // CHECK32: %[[ARG0b:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]] - // CHECK32: %[[ARG1b:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]] - // CHECK32: %[[ARG0c:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]] - // CHECK32: %[[ARG1c:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]] + // CHECK-DAG: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]] + // CHECK-DAG: %[[ARG0:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]] + // CHECK-DAG: %[[ARG1:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]] + // CHECK32-DAG: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]] + // CHECK32-DAG: %[[ARG0:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]] + // CHECK32-DAG: %[[ARG1:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]] // CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[BITCAST0:.*]] = llvm.bitcast %{{.*}} : !llvm.ptr to !llvm.ptr @@ -202,16 +181,16 @@ // CHECK: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEMREF]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[STRIDE1:.*]] = llvm.extractvalue %[[MEMREF]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[OFF:.*]] = llvm.extractvalue %[[MEMREF]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // CHECK: %[[OFFINC:.*]] = llvm.mul %[[ARG0a]], %[[STRIDE0]] : i64 + // CHECK: %[[OFFINC:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : i64 // CHECK: %[[OFF1:.*]] = llvm.add %[[OFF]], %[[OFFINC]] : i64 - // CHECK: %[[OFFINC1:.*]] = llvm.mul %[[ARG1a]], %[[STRIDE1]] : i64 + // CHECK: %[[OFFINC1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : i64 // CHECK: %[[OFF2:.*]] = llvm.add %[[OFF1]], %[[OFFINC1]] : i64 // CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[OFF2]], %[[DESC1]][2] : 
!llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // CHECK: %[[DESCSTRIDE1:.*]] = llvm.mul %[[ARG1c]], %[[STRIDE1]] : i64 - // CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[ARG1b]], %[[DESC2]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: %[[DESCSTRIDE1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : i64 + // CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[ARG1]], %[[DESC2]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[DESC4:.*]] = llvm.insertvalue %[[DESCSTRIDE1]], %[[DESC3]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // CHECK: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0c]], %[[STRIDE0]] : i64 - // CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[ARG0b]], %[[DESC4]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : i64 + // CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[ARG0]], %[[DESC4]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: llvm.insertvalue %[[DESCSTRIDE0]], %[[DESC5]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK32: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> // CHECK32: %[[BITCAST0:.*]] = llvm.bitcast %{{.*}} : !llvm.ptr to !llvm.ptr @@ -221,16 +200,16 @@ // CHECK32: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEMREF]][4, 0] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> // CHECK32: %[[STRIDE1:.*]] = llvm.extractvalue %[[MEMREF]][4, 1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> // CHECK32: %[[OFF:.*]] = llvm.extractvalue %[[MEMREF]][2] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> - // CHECK32: %[[OFFINC:.*]] = llvm.mul %[[ARG0a]], %[[STRIDE0]] : i32 + // CHECK32: %[[OFFINC:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : i32 // CHECK32: %[[OFF1:.*]] = llvm.add %[[OFF]], %[[OFFINC]] : i32 - // CHECK32: %[[OFFINC1:.*]] = 
llvm.mul %[[ARG1a]], %[[STRIDE1]] : i32 + // CHECK32: %[[OFFINC1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : i32 // CHECK32: %[[OFF2:.*]] = llvm.add %[[OFF1]], %[[OFFINC1]] : i32 // CHECK32: %[[DESC2:.*]] = llvm.insertvalue %[[OFF2]], %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> - // CHECK32: %[[DESCSTRIDE1:.*]] = llvm.mul %[[ARG1c]], %[[STRIDE1]] : i32 - // CHECK32: %[[DESC3:.*]] = llvm.insertvalue %[[ARG1b]], %[[DESC2]][3, 1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> + // CHECK32: %[[DESCSTRIDE1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : i32 + // CHECK32: %[[DESC3:.*]] = llvm.insertvalue %[[ARG1]], %[[DESC2]][3, 1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> // CHECK32: %[[DESC4:.*]] = llvm.insertvalue %[[DESCSTRIDE1]], %[[DESC3]][4, 1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> - // CHECK32: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0c]], %[[STRIDE0]] : i32 - // CHECK32: %[[DESC5:.*]] = llvm.insertvalue %[[ARG0b]], %[[DESC4]][3, 0] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> + // CHECK32: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : i32 + // CHECK32: %[[DESC5:.*]] = llvm.insertvalue %[[ARG0]], %[[DESC4]][3, 0] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> %1 = memref.subview %0[%arg0, %arg1][%arg0, %arg1][%arg0, %arg1] : memref<64x4xf32, offset: 0, strides: [4, 1], 3> @@ -251,17 +230,12 @@ // CHECK32-SAME: %[[ARG1f:[a-zA-Z0-9]*]]: index // CHECK32-SAME: %[[ARG2f:[a-zA-Z0-9]*]]: index func @subview_const_size(%0 : memref<64x4xf32, offset: 0, strides: [4, 1]>, %arg0 : index, %arg1 : index, %arg2 : index) { - // CHECK: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]] - // CHECK32: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]] - - // CHECK: %[[ARG0a:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]] - // CHECK: %[[ARG1a:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]] - // CHECK: %[[ARG0b:.*]] = 
builtin.unrealized_conversion_cast %[[ARG0f]] - // CHECK: %[[ARG1b:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]] - // CHECK32: %[[ARG0a:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]] - // CHECK32: %[[ARG1a:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]] - // CHECK32: %[[ARG0b:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]] - // CHECK32: %[[ARG1b:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]] + // CHECK-DAG: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]] + // CHECK-DAG: %[[ARG0:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]] + // CHECK-DAG: %[[ARG1:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]] + // CHECK32-DAG: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]] + // CHECK32-DAG: %[[ARG0:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]] + // CHECK32-DAG: %[[ARG1:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]] // CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[BITCAST0:.*]] = llvm.bitcast %{{.*}} : !llvm.ptr to !llvm.ptr @@ -271,17 +245,17 @@ // CHECK: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEMREF]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[STRIDE1:.*]] = llvm.extractvalue %[[MEMREF]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[OFF:.*]] = llvm.extractvalue %[[MEMREF]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // CHECK: %[[OFFINC:.*]] = llvm.mul %[[ARG0a]], %[[STRIDE0]] : i64 + // CHECK: %[[OFFINC:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : i64 // CHECK: %[[OFF1:.*]] = llvm.add %[[OFF]], %[[OFFINC]] : i64 - // CHECK: %[[OFFINC1:.*]] = llvm.mul %[[ARG1a]], %[[STRIDE1]] : i64 + // CHECK: %[[OFFINC1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : i64 // CHECK: %[[OFF2:.*]] = llvm.add %[[OFF1]], %[[OFFINC1]] : i64 // CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[OFF2]], %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x 
i64>)> // CHECK: %[[CST2:.*]] = llvm.mlir.constant(2 : i64) - // CHECK: %[[DESCSTRIDE1:.*]] = llvm.mul %[[ARG1b]], %[[STRIDE1]] : i64 + // CHECK: %[[DESCSTRIDE1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : i64 // CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[CST2]], %[[DESC2]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[DESC4:.*]] = llvm.insertvalue %[[DESCSTRIDE1]], %[[DESC3]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[CST4:.*]] = llvm.mlir.constant(4 : i64) - // CHECK: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0b]], %[[STRIDE0]] : i64 + // CHECK: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : i64 // CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[CST4]], %[[DESC4]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: llvm.insertvalue %[[DESCSTRIDE0]], %[[DESC5]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK32: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> @@ -292,17 +266,17 @@ // CHECK32: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEMREF]][4, 0] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> // CHECK32: %[[STRIDE1:.*]] = llvm.extractvalue %[[MEMREF]][4, 1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> // CHECK32: %[[OFF:.*]] = llvm.extractvalue %[[MEMREF]][2] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> - // CHECK32: %[[OFFINC:.*]] = llvm.mul %[[ARG0a]], %[[STRIDE0]] : i32 + // CHECK32: %[[OFFINC:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : i32 // CHECK32: %[[OFF1:.*]] = llvm.add %[[OFF]], %[[OFFINC]] : i32 - // CHECK32: %[[OFFINC1:.*]] = llvm.mul %[[ARG1a]], %[[STRIDE1]] : i32 + // CHECK32: %[[OFFINC1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : i32 // CHECK32: %[[OFF2:.*]] = llvm.add %[[OFF1]], %[[OFFINC1]] : i32 // CHECK32: %[[DESC2:.*]] = llvm.insertvalue %[[OFF2]], %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x 
i32>)> // CHECK32: %[[CST2:.*]] = llvm.mlir.constant(2 : i64) - // CHECK32: %[[DESCSTRIDE1:.*]] = llvm.mul %[[ARG1b]], %[[STRIDE1]] : i32 + // CHECK32: %[[DESCSTRIDE1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : i32 // CHECK32: %[[DESC3:.*]] = llvm.insertvalue %[[CST2]], %[[DESC2]][3, 1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> // CHECK32: %[[DESC4:.*]] = llvm.insertvalue %[[DESCSTRIDE1]], %[[DESC3]][4, 1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> // CHECK32: %[[CST4:.*]] = llvm.mlir.constant(4 : i64) - // CHECK32: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0b]], %[[STRIDE0]] : i32 + // CHECK32: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : i32 // CHECK32: %[[DESC5:.*]] = llvm.insertvalue %[[CST4]], %[[DESC4]][3, 0] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> // CHECK32: llvm.insertvalue %[[DESCSTRIDE0]], %[[DESC5]][4, 0] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> %1 = memref.subview %0[%arg0, %arg1][4, 2][%arg0, %arg1] : @@ -324,17 +298,12 @@ // CHECK32-SAME: %[[ARG1f:[a-zA-Z0-9]*]]: index // CHECK32-SAME: %[[ARG2f:[a-zA-Z0-9]*]]: index func @subview_const_stride(%0 : memref<64x4xf32, offset: 0, strides: [4, 1]>, %arg0 : index, %arg1 : index, %arg2 : index) { - // CHECK: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]] - // CHECK32: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]] - - // CHECK: %[[ARG0a:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]] - // CHECK: %[[ARG1a:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]] - // CHECK: %[[ARG0b:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]] - // CHECK: %[[ARG1b:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]] - // CHECK32: %[[ARG0a:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]] - // CHECK32: %[[ARG1a:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]] - // CHECK32: %[[ARG0b:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]] - // CHECK32: %[[ARG1b:.*]] = 
builtin.unrealized_conversion_cast %[[ARG1f]] + // CHECK-DAG: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]] + // CHECK-DAG: %[[ARG0:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]] + // CHECK-DAG: %[[ARG1:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]] + // CHECK32-DAG: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]] + // CHECK32-DAG: %[[ARG0:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]] + // CHECK32-DAG: %[[ARG1:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]] // CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[BITCAST0:.*]] = llvm.bitcast %{{.*}} : !llvm.ptr to !llvm.ptr @@ -344,16 +313,16 @@ // CHECK: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEMREF]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[STRIDE1:.*]] = llvm.extractvalue %[[MEMREF]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[OFF:.*]] = llvm.extractvalue %[[MEMREF]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> - // CHECK: %[[OFFINC:.*]] = llvm.mul %[[ARG0a]], %[[STRIDE0]] : i64 + // CHECK: %[[OFFINC:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : i64 // CHECK: %[[OFF1:.*]] = llvm.add %[[OFF]], %[[OFFINC]] : i64 - // CHECK: %[[OFFINC1:.*]] = llvm.mul %[[ARG1a]], %[[STRIDE1]] : i64 + // CHECK: %[[OFFINC1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : i64 // CHECK: %[[OFF2:.*]] = llvm.add %[[OFF1]], %[[OFFINC1]] : i64 // CHECK: %[[DESC2:.*]] = llvm.insertvalue %[[OFF2]], %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[CST2:.*]] = llvm.mlir.constant(2 : i64) - // CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[ARG1b]], %[[DESC2]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: %[[DESC3:.*]] = llvm.insertvalue %[[ARG1]], %[[DESC2]][3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[DESC4:.*]] = llvm.insertvalue 
%[[CST2]], %[[DESC3]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: %[[CST4:.*]] = llvm.mlir.constant(4 : i64) - // CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[ARG0b]], %[[DESC4]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> + // CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[ARG0]], %[[DESC4]][3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK: llvm.insertvalue %[[CST4]], %[[DESC5]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)> // CHECK32: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> // CHECK32: %[[BITCAST0:.*]] = llvm.bitcast %{{.*}} : !llvm.ptr to !llvm.ptr @@ -363,16 +332,16 @@ // CHECK32: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEMREF]][4, 0] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> // CHECK32: %[[STRIDE1:.*]] = llvm.extractvalue %[[MEMREF]][4, 1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> // CHECK32: %[[OFF:.*]] = llvm.extractvalue %[[MEMREF]][2] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> - // CHECK32: %[[OFFINC:.*]] = llvm.mul %[[ARG0a]], %[[STRIDE0]] : i32 + // CHECK32: %[[OFFINC:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : i32 // CHECK32: %[[OFF1:.*]] = llvm.add %[[OFF]], %[[OFFINC]] : i32 - // CHECK32: %[[OFFINC1:.*]] = llvm.mul %[[ARG1a]], %[[STRIDE1]] : i32 + // CHECK32: %[[OFFINC1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : i32 // CHECK32: %[[OFF2:.*]] = llvm.add %[[OFF1]], %[[OFFINC1]] : i32 // CHECK32: %[[DESC2:.*]] = llvm.insertvalue %[[OFF2]], %[[DESC1]][2] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> // CHECK32: %[[CST2:.*]] = llvm.mlir.constant(2 : i64) - // CHECK32: %[[DESC3:.*]] = llvm.insertvalue %[[ARG1b]], %[[DESC2]][3, 1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> + // CHECK32: %[[DESC3:.*]] = llvm.insertvalue %[[ARG1]], %[[DESC2]][3, 1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, 
array<2 x i32>)> // CHECK32: %[[DESC4:.*]] = llvm.insertvalue %[[CST2]], %[[DESC3]][4, 1] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> // CHECK32: %[[CST4:.*]] = llvm.mlir.constant(4 : i64) - // CHECK32: %[[DESC5:.*]] = llvm.insertvalue %[[ARG0b]], %[[DESC4]][3, 0] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> + // CHECK32: %[[DESC5:.*]] = llvm.insertvalue %[[ARG0]], %[[DESC4]][3, 0] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> // CHECK32: llvm.insertvalue %[[CST4]], %[[DESC5]][4, 0] : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> %1 = memref.subview %0[%arg0, %arg1][%arg0, %arg1][1, 2] : memref<64x4xf32, offset: 0, strides: [4, 1]> @@ -425,10 +394,10 @@ // CHECK32: %[[ARG1f:[a-zA-Z0-9]*]]: index, // CHECK32: %[[ARG2f:.*]]: index) func @subview_mixed_static_dynamic(%0 : memref<64x4xf32, offset: 0, strides: [4, 1]>, %arg0 : index, %arg1 : index, %arg2 : index) { - // CHECK32: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]] - // CHECK32: %[[ARG1:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]] - // CHECK32: %[[ARG2:.*]] = builtin.unrealized_conversion_cast %[[ARG2f]] - // CHECK32: %[[ARG0:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]] + // CHECK32-DAG: %[[MEMREF:.*]] = builtin.unrealized_conversion_cast %[[MEM]] + // CHECK32-DAG: %[[ARG1:.*]] = builtin.unrealized_conversion_cast %[[ARG1f]] + // CHECK32-DAG: %[[ARG2:.*]] = builtin.unrealized_conversion_cast %[[ARG2f]] + // CHECK32-DAG: %[[ARG0:.*]] = builtin.unrealized_conversion_cast %[[ARG0f]] // CHECK32: %[[DESC:.*]] = llvm.mlir.undef : !llvm.struct<(ptr, ptr, i32, array<2 x i32>, array<2 x i32>)> // CHECK32: %[[BITCAST0:.*]] = llvm.bitcast %{{.*}} : !llvm.ptr to !llvm.ptr diff --git a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir --- a/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir +++ b/mlir/test/Conversion/MemRefToSPIRV/memref-to-spirv.mlir @@ 
-17,13 +17,13 @@ // CHECK-LABEL: @load_store_zero_rank_float func @load_store_zero_rank_float(%arg0: memref, %arg1: memref) { // CHECK: [[ARG0:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref to !spv.ptr [0])>, StorageBuffer> + // CHECK: [[ARG1:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref to !spv.ptr [0])>, StorageBuffer> // CHECK: [[ZERO1:%.*]] = spv.Constant 0 : i32 // CHECK: spv.AccessChain [[ARG0]][ // CHECK-SAME: [[ZERO1]], [[ZERO1]] // CHECK-SAME: ] : // CHECK: spv.Load "StorageBuffer" %{{.*}} : f32 %0 = memref.load %arg0[] : memref - // CHECK: [[ARG1:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref to !spv.ptr [0])>, StorageBuffer> // CHECK: [[ZERO2:%.*]] = spv.Constant 0 : i32 // CHECK: spv.AccessChain [[ARG1]][ // CHECK-SAME: [[ZERO2]], [[ZERO2]] @@ -36,13 +36,13 @@ // CHECK-LABEL: @load_store_zero_rank_int func @load_store_zero_rank_int(%arg0: memref, %arg1: memref) { // CHECK: [[ARG0:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref to !spv.ptr [0])>, StorageBuffer> + // CHECK: [[ARG1:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref to !spv.ptr [0])>, StorageBuffer> // CHECK: [[ZERO1:%.*]] = spv.Constant 0 : i32 // CHECK: spv.AccessChain [[ARG0]][ // CHECK-SAME: [[ZERO1]], [[ZERO1]] // CHECK-SAME: ] : // CHECK: spv.Load "StorageBuffer" %{{.*}} : i32 %0 = memref.load %arg0[] : memref - // CHECK: [[ARG1:%.*]] = builtin.unrealized_conversion_cast {{.+}} : memref to !spv.ptr [0])>, StorageBuffer> // CHECK: [[ZERO2:%.*]] = spv.Constant 0 : i32 // CHECK: spv.AccessChain [[ARG1]][ // CHECK-SAME: [[ZERO2]], [[ZERO2]] @@ -55,10 +55,10 @@ // CHECK-LABEL: func @load_store_unknown_dim func @load_store_unknown_dim(%i: index, %source: memref, %dest: memref) { // CHECK: %[[SRC:.+]] = builtin.unrealized_conversion_cast {{.+}} : memref to !spv.ptr [0])>, StorageBuffer> + // CHECK: %[[DST:.+]] = builtin.unrealized_conversion_cast {{.+}} : memref to !spv.ptr [0])>, StorageBuffer> // CHECK: %[[AC0:.+]] = spv.AccessChain 
%[[SRC]] // CHECK: spv.Load "StorageBuffer" %[[AC0]] %0 = memref.load %source[%i] : memref - // CHECK: %[[DST:.+]] = builtin.unrealized_conversion_cast {{.+}} : memref to !spv.ptr [0])>, StorageBuffer> // CHECK: %[[AC1:.+]] = spv.AccessChain %[[DST]] // CHECK: spv.Store "StorageBuffer" %[[AC1]] memref.store %0, %dest[%i]: memref @@ -68,8 +68,8 @@ // CHECK-LABEL: func @load_i1 // CHECK-SAME: (%[[SRC:.+]]: memref<4xi1>, %[[IDX:.+]]: index) func @load_i1(%src: memref<4xi1>, %i : index) -> i1 { - // CHECK: %[[SRC_CAST:.+]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<4xi1> to !spv.ptr [0])>, StorageBuffer> - // CHECK: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]] + // CHECK-DAG: %[[SRC_CAST:.+]] = builtin.unrealized_conversion_cast %[[SRC]] : memref<4xi1> to !spv.ptr [0])>, StorageBuffer> + // CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]] // CHECK: %[[ZERO_0:.+]] = spv.Constant 0 : i32 // CHECK: %[[ZERO_1:.+]] = spv.Constant 0 : i32 // CHECK: %[[ONE:.+]] = spv.Constant 1 : i32 @@ -89,8 +89,8 @@ // CHECK-SAME: %[[IDX:.+]]: index func @store_i1(%dst: memref<4xi1>, %i: index) { %true = arith.constant true - // CHECK: %[[DST_CAST:.+]] = builtin.unrealized_conversion_cast %[[DST]] : memref<4xi1> to !spv.ptr [0])>, StorageBuffer> - // CHECK: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]] + // CHECK-DAG: %[[DST_CAST:.+]] = builtin.unrealized_conversion_cast %[[DST]] : memref<4xi1> to !spv.ptr [0])>, StorageBuffer> + // CHECK-DAG: %[[IDX_CAST:.+]] = builtin.unrealized_conversion_cast %[[IDX]] // CHECK: %[[ZERO_0:.+]] = spv.Constant 0 : i32 // CHECK: %[[ZERO_1:.+]] = spv.Constant 0 : i32 // CHECK: %[[ONE:.+]] = spv.Constant 1 : i32 @@ -237,8 +237,8 @@ // CHECK-LABEL: @store_i8 // CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i8) func @store_i8(%arg0: memref, %value: i8) { - // CHECK: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : i8 to i32 - // CHECK: %[[ARG0_CAST:.+]] = 
builtin.unrealized_conversion_cast %[[ARG0]] + // CHECK-DAG: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : i8 to i32 + // CHECK-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] // CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32 // CHECK: %[[FOUR:.+]] = spv.Constant 4 : i32 // CHECK: %[[EIGHT:.+]] = spv.Constant 8 : i32 @@ -261,9 +261,9 @@ // CHECK-LABEL: @store_i16 // CHECK: (%[[ARG0:.+]]: memref<10xi16>, %[[ARG1:.+]]: index, %[[ARG2:.+]]: i16) func @store_i16(%arg0: memref<10xi16>, %index: index, %value: i16) { - // CHECK: %[[ARG2_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG2]] : i16 to i32 - // CHECK: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] - // CHECK: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i32 + // CHECK-DAG: %[[ARG2_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG2]] : i16 to i32 + // CHECK-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] + // CHECK-DAG: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : index to i32 // CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32 // CHECK: %[[OFFSET:.+]] = spv.Constant 0 : i32 // CHECK: %[[ONE:.+]] = spv.Constant 1 : i32 @@ -350,8 +350,8 @@ // CHECK-LABEL: @store_i8 // CHECK: (%[[ARG0:.+]]: {{.*}}, %[[ARG1:.+]]: i8) func @store_i8(%arg0: memref, %value: i8) { - // CHECK: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : i8 to i32 - // CHECK: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] + // CHECK-DAG: %[[ARG1_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : i8 to i32 + // CHECK-DAG: %[[ARG0_CAST:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] // CHECK: %[[ZERO:.+]] = spv.Constant 0 : i32 // CHECK: %[[FOUR:.+]] = spv.Constant 4 : i32 // CHECK: %[[EIGHT:.+]] = spv.Constant 8 : i32 diff --git a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir --- 
a/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir +++ b/mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir @@ -6,7 +6,9 @@ omp.master { // CHECK-NEXT: ^[[BB0:.*]](%[[ARG1:.*]]: i64, %[[ARG2:.*]]: i64): ^bb0(%arg1: index, %arg2: index): - // CHECK-NEXT: "test.payload"(%[[ARG1]], %[[ARG2]]) : (i64, i64) -> () + // CHECK-DAG: %[[CAST_ARG1:.*]] = builtin.unrealized_conversion_cast %[[ARG1]] : i64 to index + // CHECK-DAG: %[[CAST_ARG2:.*]] = builtin.unrealized_conversion_cast %[[ARG2]] : i64 to index + // CHECK-NEXT: "test.payload"(%[[CAST_ARG1]], %[[CAST_ARG2]]) : (index, index) -> () "test.payload"(%arg1, %arg2) : (index, index) -> () omp.terminator } @@ -50,7 +52,9 @@ // CHECK: omp.wsloop (%[[ARG6:.*]], %[[ARG7:.*]]) : i64 = (%[[ARG0]], %[[ARG1]]) to (%[[ARG2]], %[[ARG3]]) step (%[[ARG4]], %[[ARG5]]) { "omp.wsloop"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) ( { ^bb0(%arg6: index, %arg7: index): // no predecessors - // CHECK: "test.payload"(%[[ARG6]], %[[ARG7]]) : (i64, i64) -> () + // CHECK-DAG: %[[CAST_ARG6:.*]] = builtin.unrealized_conversion_cast %[[ARG6]] : i64 to index + // CHECK-DAG: %[[CAST_ARG7:.*]] = builtin.unrealized_conversion_cast %[[ARG7]] : i64 to index + // CHECK: "test.payload"(%[[CAST_ARG6]], %[[CAST_ARG7]]) : (index, index) -> () "test.payload"(%arg6, %arg7) : (index, index) -> () omp.yield }) {operand_segment_sizes = dense<[2, 2, 2, 0, 0, 0, 0, 0, 0, 0]> : vector<10xi32>} : (index, index, index, index, index, index) -> () diff --git a/mlir/test/Conversion/StandardToLLVM/calling-convention.mlir b/mlir/test/Conversion/StandardToLLVM/calling-convention.mlir --- a/mlir/test/Conversion/StandardToLLVM/calling-convention.mlir +++ b/mlir/test/Conversion/StandardToLLVM/calling-convention.mlir @@ -148,11 +148,13 @@ // Match the construction of the unranked descriptor. 
// CHECK: %[[ALLOCA:.*]] = llvm.alloca // CHECK: %[[MEMORY:.*]] = llvm.bitcast %[[ALLOCA]] + // CHECK: %[[RANK:.*]] = llvm.mlir.constant(2 : i64) // CHECK: %[[DESC_0:.*]] = llvm.mlir.undef : !llvm.struct<(i64, ptr)> - // CHECK: %[[DESC_1:.*]] = llvm.insertvalue %{{.*}}, %[[DESC_0]][0] + // CHECK: %[[DESC_1:.*]] = llvm.insertvalue %[[RANK]], %[[DESC_0]][0] // CHECK: %[[DESC_2:.*]] = llvm.insertvalue %[[MEMORY]], %[[DESC_1]][1] %0 = memref.cast %arg0: memref<4x3xf32> to memref<*xf32> + // CHECK: %[[ONE:.*]] = llvm.mlir.constant(1 : index) // CHECK: %[[TWO:.*]] = llvm.mlir.constant(2 : index) // These sizes may depend on the data layout, not matching specific values. @@ -160,17 +162,14 @@ // CHECK: %[[IDX_SIZE:.*]] = llvm.mlir.constant // CHECK: %[[DOUBLE_PTR_SIZE:.*]] = llvm.mul %[[TWO]], %[[PTR_SIZE]] - // CHECK: %[[RANK:.*]] = llvm.extractvalue %[[DESC_2]][0] : !llvm.struct<(i64, ptr)> // CHECK: %[[DOUBLE_RANK:.*]] = llvm.mul %[[TWO]], %[[RANK]] // CHECK: %[[DOUBLE_RANK_INC:.*]] = llvm.add %[[DOUBLE_RANK]], %[[ONE]] // CHECK: %[[TABLES_SIZE:.*]] = llvm.mul %[[DOUBLE_RANK_INC]], %[[IDX_SIZE]] // CHECK: %[[ALLOC_SIZE:.*]] = llvm.add %[[DOUBLE_PTR_SIZE]], %[[TABLES_SIZE]] // CHECK: %[[FALSE:.*]] = llvm.mlir.constant(false) // CHECK: %[[ALLOCATED:.*]] = llvm.call @malloc(%[[ALLOC_SIZE]]) - // CHECK: %[[SOURCE:.*]] = llvm.extractvalue %[[DESC_2]][1] - // CHECK: "llvm.intr.memcpy"(%[[ALLOCATED]], %[[SOURCE]], %[[ALLOC_SIZE]], %[[FALSE]]) + // CHECK: "llvm.intr.memcpy"(%[[ALLOCATED]], %[[MEMORY]], %[[ALLOC_SIZE]], %[[FALSE]]) // CHECK: %[[NEW_DESC:.*]] = llvm.mlir.undef : !llvm.struct<(i64, ptr)> - // CHECK: %[[RANK:.*]] = llvm.extractvalue %[[DESC_2]][0] : !llvm.struct<(i64, ptr)> // CHECK: %[[NEW_DESC_1:.*]] = llvm.insertvalue %[[RANK]], %[[NEW_DESC]][0] // CHECK: %[[NEW_DESC_2:.*]] = llvm.insertvalue %[[ALLOCATED]], %[[NEW_DESC_1]][1] // CHECK: llvm.return %[[NEW_DESC_2]] @@ -224,15 +223,13 @@ // convention requires the caller to free them and the caller cannot know // 
whether they are the same value or not. // CHECK: %[[ALLOCATED_1:.*]] = llvm.call @malloc(%{{.*}}) - // CHECK: %[[SOURCE_1:.*]] = llvm.extractvalue %[[DESC_2]][1] - // CHECK: "llvm.intr.memcpy"(%[[ALLOCATED_1]], %[[SOURCE_1]], %{{.*}}, %[[FALSE:.*]]) + // CHECK: "llvm.intr.memcpy"(%[[ALLOCATED_1]], %[[MEMORY]], %{{.*}}, %[[FALSE:.*]]) // CHECK: %[[RES_1:.*]] = llvm.mlir.undef // CHECK: %[[RES_11:.*]] = llvm.insertvalue %{{.*}}, %[[RES_1]][0] // CHECK: %[[RES_12:.*]] = llvm.insertvalue %[[ALLOCATED_1]], %[[RES_11]][1] // CHECK: %[[ALLOCATED_2:.*]] = llvm.call @malloc(%{{.*}}) - // CHECK: %[[SOURCE_2:.*]] = llvm.extractvalue %[[DESC_2]][1] - // CHECK: "llvm.intr.memcpy"(%[[ALLOCATED_2]], %[[SOURCE_2]], %{{.*}}, %[[FALSE]]) + // CHECK: "llvm.intr.memcpy"(%[[ALLOCATED_2]], %[[MEMORY]], %{{.*}}, %[[FALSE]]) // CHECK: %[[RES_2:.*]] = llvm.mlir.undef // CHECK: %[[RES_21:.*]] = llvm.insertvalue %{{.*}}, %[[RES_2]][0] // CHECK: %[[RES_22:.*]] = llvm.insertvalue %[[ALLOCATED_2]], %[[RES_21]][1] diff --git a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir --- a/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir +++ b/mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir @@ -173,9 +173,10 @@ spv.target_env = #spv.target_env<#spv.vce, {}> } { +// expected-error@below {{failed to materialize conversion for block argument #0 that remained live after conversion}} func @int_vector4_invalid(%arg0: vector<4xi64>) { - // expected-error @+2 {{bitwidth emulation is not implemented yet on unsigned op}} - // expected-error @+1 {{op requires the same type for all operands and results}} + // expected-error@below {{bitwidth emulation is not implemented yet on unsigned op}} + // expected-note@below {{see existing live user here}} %0 = arith.divui %arg0, %arg0: vector<4xi64> return } diff --git a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir --- 
a/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir +++ b/mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir @@ -111,8 +111,8 @@ } // CHECK-LABEL: @broadcast_vec2d_from_index_vec1d( // CHECK-SAME: %[[A:.*]]: vector<2xindex>) -// CHECK: %[[T0:.*]] = arith.constant dense<0> : vector<3x2xindex> // CHECK: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<2xindex> to vector<2xi64> +// CHECK: %[[T0:.*]] = arith.constant dense<0> : vector<3x2xindex> // CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<3x2xindex> to !llvm.array<3 x vector<2xi64>> // CHECK: %[[T3:.*]] = llvm.insertvalue %[[T1]], %[[T2]][0] : !llvm.array<3 x vector<2xi64>> @@ -128,14 +128,14 @@ // CHECK-LABEL: @broadcast_vec3d_from_vec1d( // CHECK-SAME: %[[A:.*]]: vector<2xf32>) // CHECK: %[[T0:.*]] = arith.constant dense<0.000000e+00> : vector<3x2xf32> +// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>> // CHECK: %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32> +// CHECK: %[[T6:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<4x3x2xf32> to !llvm.array<4 x array<3 x vector<2xf32>>> -// CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>> // CHECK: %[[T3:.*]] = llvm.insertvalue %[[A]], %[[T2]][0] : !llvm.array<3 x vector<2xf32>> // CHECK: %[[T4:.*]] = llvm.insertvalue %[[A]], %[[T3]][1] : !llvm.array<3 x vector<2xf32>> // CHECK: %[[T5:.*]] = llvm.insertvalue %[[A]], %[[T4]][2] : !llvm.array<3 x vector<2xf32>> -// CHECK: %[[T6:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<4x3x2xf32> to !llvm.array<4 x array<3 x vector<2xf32>>> // CHECK: %[[T7:.*]] = llvm.insertvalue %[[T5]], %[[T6]][0] : !llvm.array<4 x array<3 x vector<2xf32>>> // CHECK: %[[T8:.*]] = llvm.insertvalue %[[T5]], %[[T7]][1] : !llvm.array<4 x array<3 x vector<2xf32>>> // CHECK: %[[T9:.*]] = llvm.insertvalue %[[T5]], %[[T8]][2] : 
!llvm.array<4 x array<3 x vector<2xf32>>> @@ -152,16 +152,13 @@ } // CHECK-LABEL: @broadcast_vec3d_from_vec2d( // CHECK-SAME: %[[A:.*]]: vector<3x2xf32>) -// CHECK: %[[T0:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32> // CHECK: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>> +// CHECK: %[[T0:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32> // CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<4x3x2xf32> to !llvm.array<4 x array<3 x vector<2xf32>>> // CHECK: %[[T3:.*]] = llvm.insertvalue %[[T1]], %[[T2]][0] : !llvm.array<4 x array<3 x vector<2xf32>>> -// CHECK: %[[T4:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>> -// CHECK: %[[T5:.*]] = llvm.insertvalue %[[T4]], %[[T3]][1] : !llvm.array<4 x array<3 x vector<2xf32>>> -// CHECK: %[[T6:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>> -// CHECK: %[[T7:.*]] = llvm.insertvalue %[[T6]], %[[T5]][2] : !llvm.array<4 x array<3 x vector<2xf32>>> -// CHECK: %[[T8:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>> -// CHECK: %[[T9:.*]] = llvm.insertvalue %[[T8]], %[[T7]][3] : !llvm.array<4 x array<3 x vector<2xf32>>> +// CHECK: %[[T5:.*]] = llvm.insertvalue %[[T1]], %[[T3]][1] : !llvm.array<4 x array<3 x vector<2xf32>>> +// CHECK: %[[T7:.*]] = llvm.insertvalue %[[T1]], %[[T5]][2] : !llvm.array<4 x array<3 x vector<2xf32>>> +// CHECK: %[[T9:.*]] = llvm.insertvalue %[[T1]], %[[T7]][3] : !llvm.array<4 x array<3 x vector<2xf32>>> // CHECK: %[[T10:.*]] = builtin.unrealized_conversion_cast %[[T9]] : !llvm.array<4 x array<3 x vector<2xf32>>> to vector<4x3x2xf32> // CHECK: return %[[T10]] : vector<4x3x2xf32> @@ -187,10 +184,10 @@ } // CHECK-LABEL: @broadcast_stretch_at_start( // CHECK-SAME: %[[A:.*]]: vector<1x4xf32>) -// CHECK: %[[T1:.*]] = arith.constant dense<0.000000e+00> : 
vector<3x4xf32> // CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<1x4xf32> to !llvm.array<1 x vector<4xf32>> -// CHECK: %[[T3:.*]] = llvm.extractvalue %[[T2]][0] : !llvm.array<1 x vector<4xf32>> +// CHECK: %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<3x4xf32> // CHECK: %[[T4:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<3x4xf32> to !llvm.array<3 x vector<4xf32>> +// CHECK: %[[T3:.*]] = llvm.extractvalue %[[T2]][0] : !llvm.array<1 x vector<4xf32>> // CHECK: %[[T5:.*]] = llvm.insertvalue %[[T3]], %[[T4]][0] : !llvm.array<3 x vector<4xf32>> // CHECK: %[[T6:.*]] = llvm.insertvalue %[[T3]], %[[T5]][1] : !llvm.array<3 x vector<4xf32>> // CHECK: %[[T7:.*]] = llvm.insertvalue %[[T3]], %[[T6]][2] : !llvm.array<3 x vector<4xf32>> @@ -205,28 +202,25 @@ } // CHECK-LABEL: @broadcast_stretch_at_end( // CHECK-SAME: %[[A:.*]]: vector<4x1xf32>) -// CHECK: %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<4x3xf32> // CHECK: %[[T2:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<4x1xf32> to !llvm.array<4 x vector<1xf32>> +// CHECK: %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<4x3xf32> +// CHECK: %[[T7:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<4x3xf32> to !llvm.array<4 x vector<3xf32>> // CHECK: %[[T3:.*]] = llvm.extractvalue %[[T2]][0] : !llvm.array<4 x vector<1xf32>> // CHECK: %[[T4:.*]] = llvm.mlir.constant(0 : i64) : i64 // CHECK: %[[T5:.*]] = llvm.extractelement %[[T3]]{{\[}}%[[T4]] : i64] : vector<1xf32> // CHECK: %[[T6:.*]] = splat %[[T5]] : vector<3xf32> -// CHECK: %[[T7:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<4x3xf32> to !llvm.array<4 x vector<3xf32>> // CHECK: %[[T8:.*]] = llvm.insertvalue %[[T6]], %[[T7]][0] : !llvm.array<4 x vector<3xf32>> -// CHECK: %[[T9:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<4x1xf32> to !llvm.array<4 x vector<1xf32>> -// CHECK: %[[T10:.*]] = llvm.extractvalue %[[T9]][1] : !llvm.array<4 x vector<1xf32>> +// CHECK: 
%[[T10:.*]] = llvm.extractvalue %[[T2]][1] : !llvm.array<4 x vector<1xf32>> // CHECK: %[[T11:.*]] = llvm.mlir.constant(0 : i64) : i64 // CHECK: %[[T12:.*]] = llvm.extractelement %[[T10]]{{\[}}%[[T11]] : i64] : vector<1xf32> // CHECK: %[[T13:.*]] = splat %[[T12]] : vector<3xf32> // CHECK: %[[T14:.*]] = llvm.insertvalue %[[T13]], %[[T8]][1] : !llvm.array<4 x vector<3xf32>> -// CHECK: %[[T15:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<4x1xf32> to !llvm.array<4 x vector<1xf32>> -// CHECK: %[[T16:.*]] = llvm.extractvalue %[[T15]][2] : !llvm.array<4 x vector<1xf32>> +// CHECK: %[[T16:.*]] = llvm.extractvalue %[[T2]][2] : !llvm.array<4 x vector<1xf32>> // CHECK: %[[T17:.*]] = llvm.mlir.constant(0 : i64) : i64 // CHECK: %[[T18:.*]] = llvm.extractelement %[[T16]]{{\[}}%[[T17]] : i64] : vector<1xf32> // CHECK: %[[T19:.*]] = splat %[[T18]] : vector<3xf32> // CHECK: %[[T20:.*]] = llvm.insertvalue %[[T19]], %[[T14]][2] : !llvm.array<4 x vector<3xf32>> -// CHECK: %[[T21:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<4x1xf32> to !llvm.array<4 x vector<1xf32>> -// CHECK: %[[T22:.*]] = llvm.extractvalue %[[T21]][3] : !llvm.array<4 x vector<1xf32>> +// CHECK: %[[T22:.*]] = llvm.extractvalue %[[T2]][3] : !llvm.array<4 x vector<1xf32>> // CHECK: %[[T23:.*]] = llvm.mlir.constant(0 : i64) : i64 // CHECK: %[[T24:.*]] = llvm.extractelement %[[T22]]{{\[}}%[[T23]] : i64] : vector<1xf32> // CHECK: %[[T25:.*]] = splat %[[T24]] : vector<3xf32> @@ -242,34 +236,28 @@ } // CHECK-LABEL: @broadcast_stretch_in_middle( // CHECK-SAME: %[[A:.*]]: vector<4x1x2xf32>) -> vector<4x3x2xf32> { +// CHECK: %[[T3:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<4x1x2xf32> to !llvm.array<4 x array<1 x vector<2xf32>>> // CHECK: %[[T1:.*]] = arith.constant dense<0.000000e+00> : vector<4x3x2xf32> +// CHECK: %[[T9:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<4x3x2xf32> to !llvm.array<4 x array<3 x vector<2xf32>>> // CHECK: %[[T2:.*]] = arith.constant 
dense<0.000000e+00> : vector<3x2xf32> -// CHECK: %[[T3:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<4x1x2xf32> to !llvm.array<4 x array<1 x vector<2xf32>>> -// CHECK: %[[T4:.*]] = llvm.extractvalue %[[T3]][0, 0] : !llvm.array<4 x array<1 x vector<2xf32>>> // CHECK: %[[T5:.*]] = builtin.unrealized_conversion_cast %[[T2]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>> +// CHECK: %[[T4:.*]] = llvm.extractvalue %[[T3]][0, 0] : !llvm.array<4 x array<1 x vector<2xf32>>> // CHECK: %[[T6:.*]] = llvm.insertvalue %[[T4]], %[[T5]][0] : !llvm.array<3 x vector<2xf32>> // CHECK: %[[T7:.*]] = llvm.insertvalue %[[T4]], %[[T6]][1] : !llvm.array<3 x vector<2xf32>> // CHECK: %[[T8:.*]] = llvm.insertvalue %[[T4]], %[[T7]][2] : !llvm.array<3 x vector<2xf32>> -// CHECK: %[[T9:.*]] = builtin.unrealized_conversion_cast %[[T1]] : vector<4x3x2xf32> to !llvm.array<4 x array<3 x vector<2xf32>>> // CHECK: %[[T10:.*]] = llvm.insertvalue %[[T8]], %[[T9]][0] : !llvm.array<4 x array<3 x vector<2xf32>>> -// CHECK: %[[T11:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<4x1x2xf32> to !llvm.array<4 x array<1 x vector<2xf32>>> -// CHECK: %[[T12:.*]] = llvm.extractvalue %[[T11]][1, 0] : !llvm.array<4 x array<1 x vector<2xf32>>> -// CHECK: %[[T13:.*]] = builtin.unrealized_conversion_cast %[[T2]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>> -// CHECK: %[[T14:.*]] = llvm.insertvalue %[[T12]], %[[T13]][0] : !llvm.array<3 x vector<2xf32>> +// CHECK: %[[T12:.*]] = llvm.extractvalue %[[T3]][1, 0] : !llvm.array<4 x array<1 x vector<2xf32>>> +// CHECK: %[[T14:.*]] = llvm.insertvalue %[[T12]], %[[T5]][0] : !llvm.array<3 x vector<2xf32>> // CHECK: %[[T15:.*]] = llvm.insertvalue %[[T12]], %[[T14]][1] : !llvm.array<3 x vector<2xf32>> // CHECK: %[[T16:.*]] = llvm.insertvalue %[[T12]], %[[T15]][2] : !llvm.array<3 x vector<2xf32>> // CHECK: %[[T17:.*]] = llvm.insertvalue %[[T16]], %[[T10]][1] : !llvm.array<4 x array<3 x vector<2xf32>>> -// CHECK: %[[T18:.*]] = 
builtin.unrealized_conversion_cast %[[A]] : vector<4x1x2xf32> to !llvm.array<4 x array<1 x vector<2xf32>>> -// CHECK: %[[T19:.*]] = llvm.extractvalue %[[T18]][2, 0] : !llvm.array<4 x array<1 x vector<2xf32>>> -// CHECK: %[[T20:.*]] = builtin.unrealized_conversion_cast %[[T2]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>> -// CHECK: %[[T21:.*]] = llvm.insertvalue %[[T19]], %[[T20]][0] : !llvm.array<3 x vector<2xf32>> +// CHECK: %[[T19:.*]] = llvm.extractvalue %[[T3]][2, 0] : !llvm.array<4 x array<1 x vector<2xf32>>> +// CHECK: %[[T21:.*]] = llvm.insertvalue %[[T19]], %[[T5]][0] : !llvm.array<3 x vector<2xf32>> // CHECK: %[[T22:.*]] = llvm.insertvalue %[[T19]], %[[T21]][1] : !llvm.array<3 x vector<2xf32>> // CHECK: %[[T23:.*]] = llvm.insertvalue %[[T19]], %[[T22]][2] : !llvm.array<3 x vector<2xf32>> // CHECK: %[[T24:.*]] = llvm.insertvalue %[[T23]], %[[T17]][2] : !llvm.array<4 x array<3 x vector<2xf32>>> -// CHECK: %[[T25:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<4x1x2xf32> to !llvm.array<4 x array<1 x vector<2xf32>>> -// CHECK: %[[T26:.*]] = llvm.extractvalue %[[T25]][3, 0] : !llvm.array<4 x array<1 x vector<2xf32>>> -// CHECK: %[[T27:.*]] = builtin.unrealized_conversion_cast %[[T2]] : vector<3x2xf32> to !llvm.array<3 x vector<2xf32>> -// CHECK: %[[T28:.*]] = llvm.insertvalue %[[T26]], %[[T27]][0] : !llvm.array<3 x vector<2xf32>> +// CHECK: %[[T26:.*]] = llvm.extractvalue %[[T3]][3, 0] : !llvm.array<4 x array<1 x vector<2xf32>>> +// CHECK: %[[T28:.*]] = llvm.insertvalue %[[T26]], %[[T5]][0] : !llvm.array<3 x vector<2xf32>> // CHECK: %[[T29:.*]] = llvm.insertvalue %[[T26]], %[[T28]][1] : !llvm.array<3 x vector<2xf32>> // CHECK: %[[T30:.*]] = llvm.insertvalue %[[T26]], %[[T29]][2] : !llvm.array<3 x vector<2xf32>> // CHECK: %[[T31:.*]] = llvm.insertvalue %[[T30]], %[[T24]][3] : !llvm.array<4 x array<3 x vector<2xf32>>> @@ -286,11 +274,11 @@ // CHECK-SAME: %[[A:.*]]: vector<2xf32>, // CHECK-SAME: %[[B:.*]]: vector<3xf32>) // CHECK: %[[T2:.*]] = 
arith.constant dense<0.000000e+00> : vector<2x3xf32> +// CHECK: %[[T7:.*]] = builtin.unrealized_conversion_cast %[[T2]] : vector<2x3xf32> to !llvm.array<2 x vector<3xf32>> // CHECK: %[[T3:.*]] = llvm.mlir.constant(0 : i64) : i64 // CHECK: %[[T4:.*]] = llvm.extractelement %[[A]]{{\[}}%[[T3]] : i64] : vector<2xf32> // CHECK: %[[T5:.*]] = splat %[[T4]] : vector<3xf32> // CHECK: %[[T6:.*]] = arith.mulf %[[T5]], %[[B]] : vector<3xf32> -// CHECK: %[[T7:.*]] = builtin.unrealized_conversion_cast %[[T2]] : vector<2x3xf32> to !llvm.array<2 x vector<3xf32>> // CHECK: %[[T8:.*]] = llvm.insertvalue %[[T6]], %[[T7]][0] : !llvm.array<2 x vector<3xf32>> // CHECK: %[[T9:.*]] = llvm.mlir.constant(1 : i64) : i64 // CHECK: %[[T10:.*]] = llvm.extractelement %[[A]]{{\[}}%[[T9]] : i64] : vector<2xf32> @@ -309,15 +297,15 @@ // CHECK-LABEL: @outerproduct_index( // CHECK-SAME: %[[A:.*]]: vector<2xindex>, // CHECK-SAME: %[[B:.*]]: vector<3xindex>) -// CHECK: %[[T0:.*]] = arith.constant dense<0> : vector<2x3xindex> // CHECK: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<2xindex> to vector<2xi64> +// CHECK: %[[T0:.*]] = arith.constant dense<0> : vector<2x3xindex> +// CHECK: %[[T8:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<2x3xindex> to !llvm.array<2 x vector<3xi64>> // CHECK: %[[T2:.*]] = llvm.mlir.constant(0 : i64) : i64 // CHECK: %[[T3:.*]] = llvm.extractelement %[[T1]]{{\[}}%[[T2]] : i64] : vector<2xi64> // CHECK: %[[T4:.*]] = builtin.unrealized_conversion_cast %[[T3]] : i64 to index // CHECK: %[[T5:.*]] = splat %[[T4]] : vector<3xindex> // CHECK: %[[T6:.*]] = arith.muli %[[T5]], %[[B]] : vector<3xindex> // CHECK: %[[T7:.*]] = builtin.unrealized_conversion_cast %[[T6]] : vector<3xindex> to vector<3xi64> -// CHECK: %[[T8:.*]] = builtin.unrealized_conversion_cast %[[T0]] : vector<2x3xindex> to !llvm.array<2 x vector<3xi64>> // CHECK: %{{.*}} = llvm.insertvalue %[[T7]], %[[T8]][0] : !llvm.array<2 x vector<3xi64>> // ----- @@ -330,20 +318,19 @@ // CHECK-SAME: 
%[[A:.*]]: vector<2xf32>, // CHECK-SAME: %[[B:.*]]: vector<3xf32>, // CHECK-SAME: %[[C:.*]]: vector<2x3xf32>) -> vector<2x3xf32> +// CHECK: %[[T7:.*]] = builtin.unrealized_conversion_cast %[[C]] : vector<2x3xf32> to !llvm.array<2 x vector<3xf32>> // CHECK: %[[T3:.*]] = arith.constant dense<0.000000e+00> : vector<2x3xf32> +// CHECK: %[[T10:.*]] = builtin.unrealized_conversion_cast %[[T3]] : vector<2x3xf32> to !llvm.array<2 x vector<3xf32>> // CHECK: %[[T4:.*]] = llvm.mlir.constant(0 : i64) : i64 // CHECK: %[[T5:.*]] = llvm.extractelement %[[A]]{{\[}}%[[T4]] : i64] : vector<2xf32> // CHECK: %[[T6:.*]] = splat %[[T5]] : vector<3xf32> -// CHECK: %[[T7:.*]] = builtin.unrealized_conversion_cast %[[C]] : vector<2x3xf32> to !llvm.array<2 x vector<3xf32>> // CHECK: %[[T8:.*]] = llvm.extractvalue %[[T7]][0] : !llvm.array<2 x vector<3xf32>> // CHECK: %[[T9:.*]] = "llvm.intr.fmuladd"(%[[T6]], %[[B]], %[[T8]]) : (vector<3xf32>, vector<3xf32>, vector<3xf32>) -> vector<3xf32> -// CHECK: %[[T10:.*]] = builtin.unrealized_conversion_cast %[[T3]] : vector<2x3xf32> to !llvm.array<2 x vector<3xf32>> // CHECK: %[[T11:.*]] = llvm.insertvalue %[[T9]], %[[T10]][0] : !llvm.array<2 x vector<3xf32>> // CHECK: %[[T12:.*]] = llvm.mlir.constant(1 : i64) : i64 // CHECK: %[[T13:.*]] = llvm.extractelement %[[A]]{{\[}}%[[T12]] : i64] : vector<2xf32> // CHECK: %[[T14:.*]] = splat %[[T13]] : vector<3xf32> -// CHECK: %[[T15:.*]] = builtin.unrealized_conversion_cast %[[C]] : vector<2x3xf32> to !llvm.array<2 x vector<3xf32>> -// CHECK: %[[T16:.*]] = llvm.extractvalue %[[T15]][1] : !llvm.array<2 x vector<3xf32>> +// CHECK: %[[T16:.*]] = llvm.extractvalue %[[T7]][1] : !llvm.array<2 x vector<3xf32>> // CHECK: %[[T17:.*]] = "llvm.intr.fmuladd"(%[[T14]], %[[B]], %[[T16]]) : (vector<3xf32>, vector<3xf32>, vector<3xf32>) -> vector<3xf32> // CHECK: %[[T18:.*]] = llvm.insertvalue %[[T17]], %[[T11]][1] : !llvm.array<2 x vector<3xf32>> // CHECK: %[[T19:.*]] = builtin.unrealized_conversion_cast %[[T18]] : 
!llvm.array<2 x vector<3xf32>> to vector<2x3xf32> @@ -370,8 +357,8 @@ // CHECK-LABEL: @shuffle_1D_index_direct( // CHECK-SAME: %[[A:.*]]: vector<2xindex>, // CHECK-SAME: %[[B:.*]]: vector<2xindex>) -// CHECK: %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<2xindex> to vector<2xi64> -// CHECK: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[B]] : vector<2xindex> to vector<2xi64> +// CHECK-DAG: %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<2xindex> to vector<2xi64> +// CHECK-DAG: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[B]] : vector<2xindex> to vector<2xi64> // CHECK: %[[T2:.*]] = llvm.shufflevector %[[T0]], %[[T1]] [0, 1] : vector<2xi64>, vector<2xi64> // CHECK: %[[T3:.*]] = builtin.unrealized_conversion_cast %[[T2]] : vector<2xi64> to vector<2xindex> // CHECK: return %[[T3]] : vector<2xindex> @@ -417,8 +404,8 @@ // CHECK-LABEL: @shuffle_2D( // CHECK-SAME: %[[A:.*]]: vector<1x4xf32>, // CHECK-SAME: %[[B:.*]]: vector<2x4xf32>) -// CHECK: %[[VAL_0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<1x4xf32> to !llvm.array<1 x vector<4xf32>> -// CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[B]] : vector<2x4xf32> to !llvm.array<2 x vector<4xf32>> +// CHECK-DAG: %[[VAL_0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<1x4xf32> to !llvm.array<1 x vector<4xf32>> +// CHECK-DAG: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[B]] : vector<2x4xf32> to !llvm.array<2 x vector<4xf32>> // CHECK: %[[u0:.*]] = llvm.mlir.undef : !llvm.array<3 x vector<4xf32>> // CHECK: %[[e1:.*]] = llvm.extractvalue %[[VAL_1]][0] : !llvm.array<2 x vector<4xf32>> // CHECK: %[[i1:.*]] = llvm.insertvalue %[[e1]], %[[u0]][0] : !llvm.array<3 x vector<4xf32>> @@ -533,8 +520,8 @@ // CHECK-LABEL: @insert_index_element_into_vec_1d( // CHECK-SAME: %[[A:.*]]: index, // CHECK-SAME: %[[B:.*]]: vector<4xindex>) -// CHECK: %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : index to i64 -// CHECK: %[[T1:.*]] = 
builtin.unrealized_conversion_cast %[[B]] : vector<4xindex> to vector<4xi64> +// CHECK-DAG: %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : index to i64 +// CHECK-DAG: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[B]] : vector<4xindex> to vector<4xi64> // CHECK: %[[T3:.*]] = llvm.mlir.constant(3 : i64) : i64 // CHECK: %[[T4:.*]] = llvm.insertelement %[[T0]], %[[T1]][%[[T3]] : i64] : vector<4xi64> // CHECK: %[[T5:.*]] = builtin.unrealized_conversion_cast %[[T4]] : vector<4xi64> to vector<4xindex> @@ -845,8 +832,7 @@ // CHECK-LABEL: @extract_strided_index_slice1( // CHECK-SAME: %[[A:.*]]: vector<4xindex>) // CHECK: %[[T0:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<4xindex> to vector<4xi64> -// CHECK: %[[T1:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<4xindex> to vector<4xi64> -// CHECK: %[[T2:.*]] = llvm.shufflevector %[[T0]], %[[T1]] [2, 3] : vector<4xi64>, vector<4xi64> +// CHECK: %[[T2:.*]] = llvm.shufflevector %[[T0]], %[[T0]] [2, 3] : vector<4xi64>, vector<4xi64> // CHECK: %[[T3:.*]] = builtin.unrealized_conversion_cast %[[T2]] : vector<2xi64> to vector<2xindex> // CHECK: return %[[T3]] : vector<2xindex> @@ -875,14 +861,13 @@ } // CHECK-LABEL: @extract_strided_slice3( // CHECK-SAME: %[[ARG:.*]]: vector<4x8xf32>) +// CHECK: %[[A:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<4x8xf32> to !llvm.array<4 x vector<8xf32>> // CHECK: %[[VAL_1:.*]] = arith.constant 0.000000e+00 : f32 // CHECK: %[[VAL_2:.*]] = splat %[[VAL_1]] : vector<2x2xf32> -// CHECK: %[[A:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<4x8xf32> to !llvm.array<4 x vector<8xf32>> +// CHECK: %[[VAL_6:.*]] = builtin.unrealized_conversion_cast %[[VAL_2]] : vector<2x2xf32> to !llvm.array<2 x vector<2xf32>> // CHECK: %[[T2:.*]] = llvm.extractvalue %[[A]][2] : !llvm.array<4 x vector<8xf32>> // CHECK: %[[T3:.*]] = llvm.shufflevector %[[T2]], %[[T2]] [2, 3] : vector<8xf32>, vector<8xf32> -// CHECK: %[[VAL_6:.*]] = 
builtin.unrealized_conversion_cast %[[VAL_2]] : vector<2x2xf32> to !llvm.array<2 x vector<2xf32>> // CHECK: %[[T4:.*]] = llvm.insertvalue %[[T3]], %[[VAL_6]][0] : !llvm.array<2 x vector<2xf32>> -// CHECK: %[[A:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<4x8xf32> to !llvm.array<4 x vector<8xf32>> // CHECK: %[[T5:.*]] = llvm.extractvalue %[[A]][3] : !llvm.array<4 x vector<8xf32>> // CHECK: %[[T6:.*]] = llvm.shufflevector %[[T5]], %[[T5]] [2, 3] : vector<8xf32>, vector<8xf32> // CHECK: %[[T7:.*]] = llvm.insertvalue %[[T6]], %[[T4]][1] : !llvm.array<2 x vector<2xf32>> @@ -918,8 +903,8 @@ // CHECK-LABEL: @insert_strided_slice2 // // Subvector vector<2xf32> @0 into vector<4xf32> @2 +// CHECK: unrealized_conversion_cast %{{.*}} : vector<4x4xf32> to !llvm.array<4 x vector<4xf32>> // CHECK: llvm.extractvalue {{.*}}[0] : !llvm.array<2 x vector<2xf32>> -// CHECK-NEXT: unrealized_conversion_cast %{{.*}} : vector<4x4xf32> to !llvm.array<4 x vector<4xf32>> // CHECK-NEXT: llvm.extractvalue {{.*}}[2] : !llvm.array<4 x vector<4xf32>> // Element @0 -> element @2 // CHECK-NEXT: arith.constant 0 : index @@ -935,12 +920,10 @@ // CHECK-NEXT: arith.constant 3 : index // CHECK-NEXT: unrealized_conversion_cast %{{.*}} : index to i64 // CHECK-NEXT: llvm.insertelement {{.*}}, {{.*}}[{{.*}} : i64] : vector<4xf32> -// CHECK-NEXT: unrealized_conversion_cast %{{.*}} : vector<4x4xf32> to !llvm.array<4 x vector<4xf32>> // CHECK-NEXT: llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm.array<4 x vector<4xf32>> // // Subvector vector<2xf32> @1 into vector<4xf32> @3 // CHECK: llvm.extractvalue {{.*}}[1] : !llvm.array<2 x vector<2xf32>> -// CHECK-NEXT: unrealized_conversion_cast %{{.*}} : vector<4x4xf32> to !llvm.array<4 x vector<4xf32>> // CHECK-NEXT: llvm.extractvalue {{.*}}[3] : !llvm.array<4 x vector<4xf32>> // Element @0 -> element @2 // CHECK-NEXT: arith.constant 0 : index @@ -968,12 +951,11 @@ // CHECK-LABEL: @insert_strided_slice3( // CHECK-SAME: %[[A:.*]]: vector<2x4xf32>, // 
CHECK-SAME: %[[B:.*]]: vector<16x4x8xf32>) -// CHECK: %[[s2:.*]] = builtin.unrealized_conversion_cast %[[B]] : vector<16x4x8xf32> to !llvm.array<16 x array<4 x vector<8xf32>>> +// CHECK-DAG: %[[s2:.*]] = builtin.unrealized_conversion_cast %[[B]] : vector<16x4x8xf32> to !llvm.array<16 x array<4 x vector<8xf32>>> +// CHECK-DAG: %[[s4:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<2x4xf32> to !llvm.array<2 x vector<4xf32>> // CHECK: %[[s3:.*]] = llvm.extractvalue %[[s2]][0] : !llvm.array<16 x array<4 x vector<8xf32>>> -// CHECK: %[[s4:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<2x4xf32> to !llvm.array<2 x vector<4xf32>> // CHECK: %[[s5:.*]] = llvm.extractvalue %[[s4]][0] : !llvm.array<2 x vector<4xf32>> -// CHECK: %[[s6:.*]] = builtin.unrealized_conversion_cast %[[B]] : vector<16x4x8xf32> to !llvm.array<16 x array<4 x vector<8xf32>>> -// CHECK: %[[s7:.*]] = llvm.extractvalue %[[s6]][0, 0] : !llvm.array<16 x array<4 x vector<8xf32>>> +// CHECK: %[[s7:.*]] = llvm.extractvalue %[[s2]][0, 0] : !llvm.array<16 x array<4 x vector<8xf32>>> // CHECK: %[[s8:.*]] = arith.constant 0 : index // CHECK: %[[s9:.*]] = builtin.unrealized_conversion_cast %[[s8]] : index to i64 // CHECK: %[[s10:.*]] = llvm.extractelement %[[s5]]{{\[}}%[[s9]] : i64] : vector<4xf32> @@ -999,10 +981,8 @@ // CHECK: %[[s30:.*]] = builtin.unrealized_conversion_cast %[[s29]] : index to i64 // CHECK: %[[s31:.*]] = llvm.insertelement %[[s28]], %[[s25]]{{\[}}%[[s30]] : i64] : vector<8xf32> // CHECK: %[[s32:.*]] = llvm.insertvalue %[[s31]], %[[s3]][0] : !llvm.array<4 x vector<8xf32>> -// CHECK: %[[s33:.*]] = builtin.unrealized_conversion_cast %[[A]] : vector<2x4xf32> to !llvm.array<2 x vector<4xf32>> -// CHECK: %[[s34:.*]] = llvm.extractvalue %[[s33]][1] : !llvm.array<2 x vector<4xf32>> -// CHECK: %[[s35:.*]] = builtin.unrealized_conversion_cast %[[B]] : vector<16x4x8xf32> to !llvm.array<16 x array<4 x vector<8xf32>>> -// CHECK: %[[s36:.*]] = llvm.extractvalue %[[s35]][0, 1] : 
!llvm.array<16 x array<4 x vector<8xf32>>> +// CHECK: %[[s34:.*]] = llvm.extractvalue %[[s4]][1] : !llvm.array<2 x vector<4xf32>> +// CHECK: %[[s36:.*]] = llvm.extractvalue %[[s2]][0, 1] : !llvm.array<16 x array<4 x vector<8xf32>>> // CHECK: %[[s37:.*]] = arith.constant 0 : index // CHECK: %[[s38:.*]] = builtin.unrealized_conversion_cast %[[s37]] : index to i64 // CHECK: %[[s39:.*]] = llvm.extractelement %[[s34]]{{\[}}%[[s38]] : i64] : vector<4xf32> @@ -1028,8 +1008,7 @@ // CHECK: %[[s59:.*]] = builtin.unrealized_conversion_cast %[[s58]] : index to i64 // CHECK: %[[s60:.*]] = llvm.insertelement %[[s57]], %[[s54]]{{\[}}%[[s59]] : i64] : vector<8xf32> // CHECK: %[[s61:.*]] = llvm.insertvalue %[[s60]], %[[s32]][1] : !llvm.array<4 x vector<8xf32>> -// CHECK: %[[s62:.*]] = builtin.unrealized_conversion_cast %[[B]] : vector<16x4x8xf32> to !llvm.array<16 x array<4 x vector<8xf32>>> -// CHECK: %[[s63:.*]] = llvm.insertvalue %[[s61]], %[[s62]][0] : !llvm.array<16 x array<4 x vector<8xf32>>> +// CHECK: %[[s63:.*]] = llvm.insertvalue %[[s61]], %[[s2]][0] : !llvm.array<16 x array<4 x vector<8xf32>>> // CHECK: %[[s64:.*]] = builtin.unrealized_conversion_cast %[[s63]] : !llvm.array<16 x array<4 x vector<8xf32>>> to vector<16x4x8xf32> // CHECK: return %[[s64]] : vector<16x4x8xf32> @@ -1039,24 +1018,19 @@ // CHECK-LABEL: @vector_fma // CHECK-SAME: %[[A:.*]]: vector<8xf32> // CHECK-SAME: %[[B:.*]]: vector<2x4xf32> + // CHECK: %[[BL:.*]] = builtin.unrealized_conversion_cast %[[B]] : vector<2x4xf32> to !llvm.array<2 x vector<4xf32>> // CHECK: "llvm.intr.fmuladd" // CHECK-SAME: (vector<8xf32>, vector<8xf32>, vector<8xf32>) -> vector<8xf32> %0 = vector.fma %a, %a, %a : vector<8xf32> - // CHECK: %[[BL:.*]] = builtin.unrealized_conversion_cast %[[B]] : vector<2x4xf32> to !llvm.array<2 x vector<4xf32>> // CHECK: %[[b00:.*]] = llvm.extractvalue %[[BL]][0] : !llvm.array<2 x vector<4xf32>> - // CHECK: %[[BL:.*]] = builtin.unrealized_conversion_cast %[[B]] : vector<2x4xf32> to !llvm.array<2 x 
vector<4xf32>> // CHECK: %[[b01:.*]] = llvm.extractvalue %[[BL]][0] : !llvm.array<2 x vector<4xf32>> - // CHECK: %[[BL:.*]] = builtin.unrealized_conversion_cast %[[B]] : vector<2x4xf32> to !llvm.array<2 x vector<4xf32>> // CHECK: %[[b02:.*]] = llvm.extractvalue %[[BL]][0] : !llvm.array<2 x vector<4xf32>> // CHECK: %[[B0:.*]] = "llvm.intr.fmuladd"(%[[b00]], %[[b01]], %[[b02]]) : // CHECK-SAME: (vector<4xf32>, vector<4xf32>, vector<4xf32>) -> vector<4xf32> // CHECK: llvm.insertvalue %[[B0]], {{.*}}[0] : !llvm.array<2 x vector<4xf32>> - // CHECK: %[[BL:.*]] = builtin.unrealized_conversion_cast %[[B]] : vector<2x4xf32> to !llvm.array<2 x vector<4xf32>> // CHECK: %[[b10:.*]] = llvm.extractvalue %[[BL]][1] : !llvm.array<2 x vector<4xf32>> - // CHECK: %[[BL:.*]] = builtin.unrealized_conversion_cast %[[B]] : vector<2x4xf32> to !llvm.array<2 x vector<4xf32>> // CHECK: %[[b11:.*]] = llvm.extractvalue %[[BL]][1] : !llvm.array<2 x vector<4xf32>> - // CHECK: %[[BL:.*]] = builtin.unrealized_conversion_cast %[[B]] : vector<2x4xf32> to !llvm.array<2 x vector<4xf32>> // CHECK: %[[b12:.*]] = llvm.extractvalue %[[BL]][1] : !llvm.array<2 x vector<4xf32>> // CHECK: %[[B1:.*]] = "llvm.intr.fmuladd"(%[[b10]], %[[b11]], %[[b12]]) : // CHECK-SAME: (vector<4xf32>, vector<4xf32>, vector<4xf32>) -> vector<4xf32> diff --git a/mlir/test/Dialect/ArmSVE/memcpy.mlir b/mlir/test/Dialect/ArmSVE/memcpy.mlir --- a/mlir/test/Dialect/ArmSVE/memcpy.mlir +++ b/mlir/test/Dialect/ArmSVE/memcpy.mlir @@ -7,19 +7,18 @@ %vs = arm_sve.vector_scale : index %step = arith.muli %c4, %vs : index + // CHECK: [[SRCMRS:%[0-9]+]] = builtin.unrealized_conversion_cast [[SRC]] : memref to !llvm.struct<(ptr + // CHECK: [[DSTMRS:%[0-9]+]] = builtin.unrealized_conversion_cast [[DST]] : memref to !llvm.struct<(ptr // CHECK: scf.for [[LOOPIDX:%arg[0-9]+]] = {{.*}} scf.for %i0 = %c0 to %size step %step { - // CHECK: [[SRCMRS:%[0-9]+]] = builtin.unrealized_conversion_cast [[SRC]] : memref to !llvm.struct<(ptr // CHECK: 
[[SRCIDX:%[0-9]+]] = builtin.unrealized_conversion_cast [[LOOPIDX]] : index to i64 // CHECK: [[SRCMEM:%[0-9]+]] = llvm.extractvalue [[SRCMRS]][1] : !llvm.struct<(ptr // CHECK-NEXT: [[SRCPTR:%[0-9]+]] = llvm.getelementptr [[SRCMEM]]{{.}}[[SRCIDX]]{{.}} : (!llvm.ptr, i64) -> !llvm.ptr // CHECK-NEXT: [[SRCVPTR:%[0-9]+]] = llvm.bitcast [[SRCPTR]] : !llvm.ptr to !llvm.ptr> // CHECK-NEXT: [[LDVAL:%[0-9]+]] = llvm.load [[SRCVPTR]] : !llvm.ptr> %0 = arm_sve.load %src[%i0] : !arm_sve.vector<4xf32> from memref - // CHECK: [[DSTMRS:%[0-9]+]] = builtin.unrealized_conversion_cast [[DST]] : memref to !llvm.struct<(ptr - // CHECK: [[DSTIDX:%[0-9]+]] = builtin.unrealized_conversion_cast [[LOOPIDX]] : index to i64 // CHECK: [[DSTMEM:%[0-9]+]] = llvm.extractvalue [[DSTMRS]][1] : !llvm.struct<(ptr - // CHECK-NEXT: [[DSTPTR:%[0-9]+]] = llvm.getelementptr [[DSTMEM]]{{.}}[[DSTIDX]]{{.}} : (!llvm.ptr, i64) -> !llvm.ptr + // CHECK-NEXT: [[DSTPTR:%[0-9]+]] = llvm.getelementptr [[DSTMEM]]{{.}}[[SRCIDX]]{{.}} : (!llvm.ptr, i64) -> !llvm.ptr // CHECK-NEXT: [[DSTVPTR:%[0-9]+]] = llvm.bitcast [[DSTPTR]] : !llvm.ptr to !llvm.ptr> // CHECK-NEXT: llvm.store [[LDVAL]], [[DSTVPTR]] : !llvm.ptr> arm_sve.store %0, %dst[%i0] : !arm_sve.vector<4xf32> to memref diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir --- a/mlir/test/Dialect/Linalg/bufferize.mlir +++ b/mlir/test/Dialect/Linalg/bufferize.mlir @@ -139,8 +139,8 @@ // CHECK-LABEL: func @generic_with_init_tensor( // CHECK-SAME: %[[ARG0_TENSOR:.*]]: tensor<2x3x4xvector<3x4xi4>>, // CHECK-SAME: %[[ARG1_TENSOR:.*]]: tensor<3x2xf32>) -> tensor<3x2xf32> { -// CHECK: %[[ARG0_MEMREF:.*]] = memref.buffer_cast %[[ARG0_TENSOR]] : memref<2x3x4xvector<3x4xi4>> -// CHECK: %[[ARG1_MEMREF:.*]] = memref.buffer_cast %[[ARG1_TENSOR]] : memref<3x2xf32> +// CHECK-DAG: %[[ARG0_MEMREF:.*]] = memref.buffer_cast %[[ARG0_TENSOR]] : memref<2x3x4xvector<3x4xi4>> +// CHECK-DAG: %[[ARG1_MEMREF:.*]] = memref.buffer_cast 
%[[ARG1_TENSOR]] : memref<3x2xf32> // CHECK: %[[INIT_BUFFER:.*]] = memref.alloc() : memref<3x2xf32> // CHECK: linalg.copy(%[[ARG1_MEMREF]], %[[INIT_BUFFER]]) : memref<3x2xf32>, memref<3x2xf32> // CHECK: linalg.generic @@ -169,10 +169,11 @@ // CHECK-LABEL: func @bufferize_slice( // CHECK-SAME: %[[T:[0-9a-z]*]]: tensor func @bufferize_slice(%t : tensor) -> (tensor<2x3xf32>, tensor<2x?xf32>) { + // CHECK: %[[M:.*]] = memref.buffer_cast %[[T]] : memref + // CHECK: %[[IDX:.*]] = call @make_index() : () -> index %i0 = call @make_index() : () -> index - // CHECK: %[[M:.*]] = memref.buffer_cast %[[T]] : memref // CHECK-NEXT: %[[A0:.*]] = memref.alloc() : memref<2x3xf32> // CHECK-NEXT: %[[SM0:.*]] = memref.subview %[[M]][0, 0] [2, 3] [1, 1] // CHECK-SAME: memref to memref<2x3xf32, #[[$MAP0]]> @@ -204,6 +205,10 @@ // CHECK-SAME: %[[ST1:[0-9a-z]*]]: tensor<2x?xf32> func @bufferize_insert_slice(%t : tensor, %st0 : tensor<2x3xf32>, %st1 : tensor<2x?xf32>) -> (tensor, tensor) { + // CHECK-DAG: %[[M:.*]] = memref.buffer_cast %[[T]] : memref + // CHECK-DAG: %[[SM0:.*]] = memref.buffer_cast %[[ST0]] : memref<2x3xf32> + // CHECK-DAG: %[[SM1:.*]] = memref.buffer_cast %[[ST1]] : memref<2x?xf32> + %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index @@ -212,8 +217,6 @@ // CHECK: %[[IDX:.*]] = call @make_index() : () -> index - // CHECK-DAG: %[[M:.*]] = memref.buffer_cast %[[T]] : memref - // CHECK-DAG: %[[SM0:.*]] = memref.buffer_cast %[[ST0]] : memref<2x3xf32> // CHECK-NEXT: %[[DIM0:.*]] = tensor.dim %[[T]], %[[C0]] : tensor // CHECK-NEXT: %[[DIM1:.*]] = tensor.dim %[[T]], %[[C1]] : tensor // CHECK-NEXT: %[[M_COPY0:.*]] = memref.alloc(%[[DIM0]], %[[DIM1]]) : memref @@ -224,7 +227,6 @@ // CHECK-NEXT: %[[RT0:.*]] = memref.tensor_load %[[M_COPY0]] : memref %t0 = tensor.insert_slice %st0 into %t[0, 0][2, 3][1, 1] : tensor<2x3xf32> into tensor - // CHECK-DAG: %[[SM1:.*]] = memref.buffer_cast %[[ST1]] : memref<2x?xf32> // 
CHECK-NEXT: %[[M_COPY1:.*]] = memref.alloc(%[[DIM0]], %[[DIM1]]) : memref // CHECK-NEXT: linalg.copy(%[[M]], %[[M_COPY1]]) : memref, memref // CHECK-NEXT: %[[SUBVIEW1:.*]] = memref.subview %[[M_COPY1]][0, %[[IDX]]] [2, %[[IDX]]] [1, 2] @@ -285,13 +287,13 @@ // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index // CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK: %[[IN_MEMREF:.*]] = memref.buffer_cast %[[IN]] : memref<4x?x2x?xf32> // CHECK: %[[DIM1:.*]] = tensor.dim %[[IN]], %[[C1]] : tensor<4x?x2x?xf32> // CHECK: %[[OUT_DIM2:.*]] = arith.addi %[[OFFSET]], %[[C2]] : index // CHECK: %[[DIM3:.*]] = tensor.dim %[[IN]], %[[C3]] : tensor<4x?x2x?xf32> // CHECK: %[[OUT_DIM3:.*]] = arith.addi %[[DIM3]], %[[OFFSET]] : index // CHECK: %[[FILLED:.*]] = memref.alloc(%[[DIM1]], %[[OUT_DIM2]], %[[OUT_DIM3]]) : memref<4x?x?x?xf32> // CHECK: linalg.fill(%[[CST]], %[[FILLED]]) : f32, memref<4x?x?x?xf32> -// CHECK: %[[IN_MEMREF:.*]] = memref.buffer_cast %[[IN]] : memref<4x?x2x?xf32> // CHECK: %[[OUT:.*]] = memref.alloc(%[[DIM1]], %[[OUT_DIM2]], %[[OUT_DIM3]]) : memref<4x?x?x?xf32> // CHECK: linalg.copy(%[[FILLED]], %[[OUT]]) : memref<4x?x?x?xf32>, memref<4x?x?x?xf32> // CHECK: %[[INTERIOR:.*]] = memref.subview %[[OUT]][0, 0, %[[OFFSET]], 0] [4, %[[DIM1]], 2, %[[DIM3]]] [1, 1, 1, 1] : memref<4x?x?x?xf32> to memref<4x?x2x?xf32, #map> diff --git a/mlir/test/Dialect/Linalg/detensorize_0d.mlir b/mlir/test/Dialect/Linalg/detensorize_0d.mlir --- a/mlir/test/Dialect/Linalg/detensorize_0d.mlir +++ b/mlir/test/Dialect/Linalg/detensorize_0d.mlir @@ -57,8 +57,7 @@ // CHECK-DAG: %[[arg1_val:.*]] = tensor.extract %[[arg1]] // CHECK-DAG: %[[arg2_val:.*]] = tensor.extract %[[arg2]] // CHECK: %[[detensored_res:.*]] = arith.addf %[[arg1_val]], %[[arg2_val]] -// CHECK-DAG: %[[arg1_val2:.*]] = tensor.extract %[[arg1]] -// CHECK: %[[detensored_res2:.*]] = arith.mulf %[[arg1_val2]], %[[detensored_res]] +// CHECK: %[[detensored_res2:.*]] 
= arith.mulf %[[arg1_val]], %[[detensored_res]] // CHECK: %[[detensored_res3:.*]] = arith.divf %[[detensored_res]], %[[detensored_res2]] // CHECK: %[[new_tensor_res:.*]] = tensor.from_elements %[[detensored_res3]] // CHECK: %[[reshaped_tensor_res:.*]] = linalg.tensor_collapse_shape %[[new_tensor_res]] diff --git a/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir b/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir --- a/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir +++ b/mlir/test/Dialect/Linalg/detensorize_while_impure_cf.mlir @@ -73,10 +73,7 @@ // DET-ALL: linalg.yield %{{.*}} : i32 // DET-ALL: } -> tensor // DET-ALL: tensor.extract %{{.*}}[] : tensor -// DET-ALL: tensor.extract %{{.*}}[] : tensor -// DET-ALL: arith.cmpi slt, %{{.*}}, %{{.*}} : i32 -// DET-ALL: tensor.extract %{{.*}}[] : tensor -// DET-ALL: tensor.extract %{{.*}}[] : tensor +// DET-ALL: cmpi slt, %{{.*}}, %{{.*}} : i32 // DET-ALL: cond_br %{{.*}}, ^[[bb2:.*]](%{{.*}} : i32), ^[[bb3:.*]](%{{.*}} : i32) // DET-ALL: ^[[bb2]](%{{.*}}: i32) // DET-ALL: tensor.from_elements %{{.*}} : tensor<1xi32> @@ -99,8 +96,7 @@ // DET-CF: ^bb1(%{{.*}}: tensor<10xi32>) // DET-CF: %{{.*}} = linalg.generic {{{.*}}} ins(%{{.*}} : tensor<10xi32>) outs(%{{.*}} : tensor) { // DET-CF: tensor.extract %{{.*}}[] : tensor -// DET-CF: tensor.extract %{{.*}}[] : tensor -// DET-CF: arith.cmpi slt, %{{.*}}, %{{.*}} : i32 +// DET-CF: cmpi slt, %{{.*}}, %{{.*}} : i32 // DET-CF: cond_br %{{.*}}, ^bb2(%{{.*}} : tensor), ^bb3(%{{.*}} : tensor) // DET-CF: ^bb2(%{{.*}}: tensor) // DET-CF: %{{.*}} = linalg.generic {{{.*}}} ins(%{{.*}} : tensor) outs(%{{.*}} : tensor<10xi32>) { diff --git a/mlir/test/Dialect/SCF/bufferize.mlir b/mlir/test/Dialect/SCF/bufferize.mlir --- a/mlir/test/Dialect/SCF/bufferize.mlir +++ b/mlir/test/Dialect/SCF/bufferize.mlir @@ -4,11 +4,11 @@ // CHECK-SAME: %[[PRED:.*]]: i1, // CHECK-SAME: %[[TRUE_TENSOR:.*]]: tensor, // CHECK-SAME: %[[FALSE_TENSOR:.*]]: tensor) -> tensor { +// CHECK: 
%[[TRUE_MEMREF:.*]] = memref.buffer_cast %[[TRUE_TENSOR]] : memref +// CHECK: %[[FALSE_MEMREF:.*]] = memref.buffer_cast %[[FALSE_TENSOR]] : memref // CHECK: %[[RESULT_MEMREF:.*]] = scf.if %[[PRED]] -> (memref) { -// CHECK: %[[TRUE_MEMREF:.*]] = memref.buffer_cast %[[TRUE_TENSOR]] : memref // CHECK: scf.yield %[[TRUE_MEMREF]] : memref // CHECK: } else { -// CHECK: %[[FALSE_MEMREF:.*]] = memref.buffer_cast %[[FALSE_TENSOR]] : memref // CHECK: scf.yield %[[FALSE_MEMREF]] : memref // CHECK: } // CHECK: %[[RESULT_TENSOR:.*]] = memref.tensor_load %[[RESULT_MEMREF:.*]] : memref @@ -29,9 +29,7 @@ // CHECK-SAME: %[[STEP:.*]]: index) -> tensor { // CHECK: %[[MEMREF:.*]] = memref.buffer_cast %[[TENSOR]] : memref // CHECK: %[[RESULT_MEMREF:.*]] = scf.for %[[VAL_6:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] iter_args(%[[ITER:.*]] = %[[MEMREF]]) -> (memref) { -// CHECK: %[[TENSOR_ITER:.*]] = memref.tensor_load %[[ITER]] : memref -// CHECK: %[[MEMREF_YIELDED:.*]] = memref.buffer_cast %[[TENSOR_ITER]] : memref -// CHECK: scf.yield %[[MEMREF_YIELDED]] : memref +// CHECK: scf.yield %[[ITER]] : memref // CHECK: } // CHECK: %[[VAL_8:.*]] = memref.tensor_load %[[VAL_9:.*]] : memref // CHECK: return %[[VAL_8]] : tensor diff --git a/mlir/test/Dialect/Standard/bufferize.mlir b/mlir/test/Dialect/Standard/bufferize.mlir --- a/mlir/test/Dialect/Standard/bufferize.mlir +++ b/mlir/test/Dialect/Standard/bufferize.mlir @@ -4,8 +4,8 @@ // CHECK-SAME: %[[PRED:.*]]: i1, // CHECK-SAME: %[[TRUE_VAL:.*]]: tensor, // CHECK-SAME: %[[FALSE_VAL:.*]]: tensor) -> tensor { -// CHECK: %[[TRUE_VAL_MEMREF:.*]] = memref.buffer_cast %[[TRUE_VAL]] : memref -// CHECK: %[[FALSE_VAL_MEMREF:.*]] = memref.buffer_cast %[[FALSE_VAL]] : memref +// CHECK-DAG: %[[TRUE_VAL_MEMREF:.*]] = memref.buffer_cast %[[TRUE_VAL]] : memref +// CHECK-DAG: %[[FALSE_VAL_MEMREF:.*]] = memref.buffer_cast %[[FALSE_VAL]] : memref // CHECK: %[[RET_MEMREF:.*]] = select %[[PRED]], %[[TRUE_VAL_MEMREF]], %[[FALSE_VAL_MEMREF]] : memref // CHECK: 
%[[RET:.*]] = memref.tensor_load %[[RET_MEMREF]] : memref // CHECK: return %[[RET]] : tensor diff --git a/mlir/test/Dialect/Standard/func-bufferize.mlir b/mlir/test/Dialect/Standard/func-bufferize.mlir --- a/mlir/test/Dialect/Standard/func-bufferize.mlir +++ b/mlir/test/Dialect/Standard/func-bufferize.mlir @@ -2,22 +2,16 @@ // CHECK-LABEL: func @identity( // CHECK-SAME: %[[ARG:.*]]: memref) -> memref { -// CHECK: %[[TENSOR:.*]] = memref.tensor_load %[[ARG]] : memref -// CHECK: %[[MEMREF:.*]] = memref.buffer_cast %[[TENSOR]] : memref -// CHECK: return %[[MEMREF]] : memref +// CHECK: return %[[ARG]] : memref func @identity(%arg0: tensor) -> tensor { return %arg0 : tensor } // CHECK-LABEL: func @block_arguments( // CHECK-SAME: %[[ARG:.*]]: memref) -> memref { -// CHECK: %[[T1:.*]] = memref.tensor_load %[[ARG]] : memref -// CHECK: %[[M1:.*]] = memref.buffer_cast %[[T1]] : memref -// CHECK: br ^bb1(%[[M1]] : memref) +// CHECK: br ^bb1(%[[ARG]] : memref) // CHECK: ^bb1(%[[BBARG:.*]]: memref): -// CHECK: %[[T2:.*]] = memref.tensor_load %[[BBARG]] : memref -// CHECK: %[[M2:.*]] = memref.buffer_cast %[[T2]] : memref -// CHECK: return %[[M2]] : memref +// CHECK: return %[[BBARG]] : memref func @block_arguments(%arg0: tensor) -> tensor { br ^bb1(%arg0: tensor) ^bb1(%bbarg: tensor): @@ -35,9 +29,7 @@ } // CHECK-LABEL: func @call_sink( // CHECK-SAME: %[[ARG:.*]]: memref) { -// CHECK: %[[TENSOR:.*]] = memref.tensor_load %[[ARG]] : memref -// CHECK: %[[MEMREF:.*]] = memref.buffer_cast %[[TENSOR]] : memref -// CHECK: call @sink(%[[MEMREF]]) : (memref) -> () +// CHECK: call @sink(%[[ARG]]) : (memref) -> () // CHECK: return func private @sink(tensor) func @call_sink(%arg0: tensor) { diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir --- a/mlir/test/Dialect/Tensor/bufferize.mlir +++ b/mlir/test/Dialect/Tensor/bufferize.mlir @@ -74,11 +74,11 @@ // CHECK-LABEL: func @tensor.generate( // CHECK-SAME: %[[ARG:.*]]: tensor<*xf32>, // CHECK-SAME: 
%[[DYNAMIC_EXTENT:.*]]: index) -> tensor { +// CHECK: %[[CASTED:.*]] = memref.buffer_cast %[[ARG]] : memref<*xf32> // CHECK: %[[MEMREF:.*]] = memref.alloc(%[[DYNAMIC_EXTENT]]) : memref // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[C1:.*]] = arith.constant 1 : index // CHECK: scf.parallel (%[[I:.*]]) = (%[[C0]]) to (%[[DYNAMIC_EXTENT]]) step (%[[C1]]) { -// CHECK: %[[CASTED:.*]] = memref.buffer_cast %[[ARG]] : memref<*xf32> // CHECK: %[[ELEM:.*]] = memref.dim %[[CASTED]], %[[I]] : memref<*xf32> // CHECK: store %[[ELEM]], %[[MEMREF]][%[[I]]] : memref // CHECK: scf.yield diff --git a/mlir/test/Transforms/test-legalize-remapped-value.mlir b/mlir/test/Transforms/test-legalize-remapped-value.mlir --- a/mlir/test/Transforms/test-legalize-remapped-value.mlir +++ b/mlir/test/Transforms/test-legalize-remapped-value.mlir @@ -1,13 +1,28 @@ // RUN: mlir-opt %s -test-remapped-value | FileCheck %s // Simple test that exercises ConvertPatternRewriter::getRemappedValue. + +// CHECK-LABEL: func @remap_input_1_to_1 +// CHECK-SAME: (%[[ARG:.*]]: i32) +// CHECK-NEXT: %[[VAL:.*]] = "test.one_variadic_out_one_variadic_in1"(%[[ARG]], %[[ARG]]) +// CHECK-NEXT: "test.one_variadic_out_one_variadic_in1"(%[[VAL]], %[[VAL]]) + func @remap_input_1_to_1(%arg0: i32) { %0 = "test.one_variadic_out_one_variadic_in1"(%arg0) : (i32) -> i32 %1 = "test.one_variadic_out_one_variadic_in1"(%0) : (i32) -> i32 "test.return"() : () -> () } -// CHECK-LABEL: func @remap_input_1_to_1 -// CHECK-SAME: (%[[ARG:.*]]: i32) -// CHECK-NEXT: %[[VAL:.*]] = "test.one_variadic_out_one_variadic_in1"(%[[ARG]], %[[ARG]]) -// CHECK-NEXT: "test.one_variadic_out_one_variadic_in1"(%[[VAL]], %[[VAL]]) +// Test the case where an operation is converted before its operands are. 
+ +// CHECK-LABEL: func @remap_unconverted +// CHECK-NEXT: %[[VAL:.*]] = "test.type_producer"() : () -> f64 +// CHECK-NEXT: "test.type_consumer"(%[[VAL]]) : (f64) +func @remap_unconverted() { + %region_result = "test.remapped_value_region"() ({ + %result = "test.type_producer"() : () -> f32 + "test.return"(%result) : (f32) -> () + }) : () -> (f32) + "test.type_consumer"(%region_result) : (f32) -> () + "test.return"() : () -> () +} diff --git a/mlir/test/Transforms/test-legalize-type-conversion.mlir b/mlir/test/Transforms/test-legalize-type-conversion.mlir --- a/mlir/test/Transforms/test-legalize-type-conversion.mlir +++ b/mlir/test/Transforms/test-legalize-type-conversion.mlir @@ -10,13 +10,6 @@ // ----- -// expected-error@below {{failed to legalize conversion operation generated for block argument}} -func @test_invalid_arg_illegal_materialization(%arg0: i32) { - "foo.return"(%arg0) : (i32) -> () -} - -// ----- - // CHECK-LABEL: func @test_valid_arg_materialization func @test_valid_arg_materialization(%arg0: i64) { // CHECK: %[[ARG:.*]] = "test.type_producer" @@ -67,14 +60,6 @@ // ----- -func @test_invalid_result_legalization() { - // expected-error@below {{failed to legalize conversion operation generated for result #0 of operation 'test.type_producer' that remained live after conversion}} - %result = "test.type_producer"() : () -> i16 - "foo.return"(%result) : (i16) -> () -} - -// ----- - // CHECK-LABEL: func @test_valid_result_legalization func @test_valid_result_legalization() { // CHECK: %[[RESULT:.*]] = "test.type_producer"() : () -> f64 diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir --- a/mlir/test/Transforms/test-legalizer.mlir +++ b/mlir/test/Transforms/test-legalizer.mlir @@ -28,7 +28,6 @@ // CHECK-LABEL: func @remap_call_1_to_1(%arg0: f64) func @remap_call_1_to_1(%arg0: i64) { // CHECK-NEXT: call @remap_input_1_to_1(%arg0) : (f64) -> () - // expected-remark@+1 {{op 'std.call' is not legalizable}} call 
@remap_input_1_to_1(%arg0) : (i64) -> () // expected-remark@+1 {{op 'std.return' is not legalizable}} return @@ -36,7 +35,6 @@ // CHECK-LABEL: func @remap_input_1_to_N({{.*}}f16, {{.*}}f16) func @remap_input_1_to_N(%arg0: f32) -> f32 { - // CHECK-NEXT: [[CAST:%.*]] = "test.cast"(%arg0, %arg1) : (f16, f16) -> f32 // CHECK-NEXT: "test.return"{{.*}} : (f16, f16) -> () "test.return"(%arg0) : (f32) -> () } diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1497,13 +1497,25 @@ def TestMergeBlocksOp : TEST_Op<"merge_blocks"> { let summary = "merge_blocks operation"; let description = [{ - Test op with multiple blocks that are merged with Dialect Conversion" + Test op with multiple blocks that are merged with Dialect Conversion }]; let regions = (region AnyRegion:$body); let results = (outs Variadic:$result); } +def TestRemappedValueRegionOp : TEST_Op<"remapped_value_region", + [SingleBlock]> { + let summary = "remapped_value_region operation"; + let description = [{ + Test op that remaps values that haven't yet been converted in Dialect + Conversion. + }]; + + let regions = (region SizedRegion<1>:$body); + let results = (outs Variadic:$result); +} + def TestSignatureConversionUndoOp : TEST_Op<"signature_conversion_undo"> { let regions = (region AnyRegion); } diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -429,7 +429,8 @@ // Check if the first operation is a cast operation, if it is we use the // results directly. 
auto *defOp = operands[0].getDefiningOp(); - if (auto packerOp = llvm::dyn_cast_or_null(defOp)) { + if (auto packerOp = + llvm::dyn_cast_or_null(defOp)) { rewriter.replaceOpWithNewOp(op, packerOp.getOperands()); return success(); } @@ -586,16 +587,6 @@ addConversion(convertType); addArgumentMaterialization(materializeCast); addSourceMaterialization(materializeCast); - - /// Materialize the cast for one-to-one conversion from i64 to f64. - const auto materializeOneToOneCast = - [](OpBuilder &builder, IntegerType resultType, ValueRange inputs, - Location loc) -> Optional { - if (resultType.getWidth() == 42 && inputs.size() == 1) - return builder.create(loc, resultType, inputs).getResult(); - return llvm::None; - }; - addArgumentMaterialization(materializeOneToOneCast); } static LogicalResult convertType(Type t, SmallVectorImpl &results) { @@ -630,8 +621,6 @@ /// 1->N type mappings. static Optional materializeCast(OpBuilder &builder, Type resultType, ValueRange inputs, Location loc) { - if (inputs.size() == 1) - return inputs[0]; return builder.create(loc, resultType, inputs).getResult(); } }; @@ -684,6 +673,8 @@ return converter.isSignatureLegal(op.getType()) && converter.isLegal(&op.getBody()); }); + target.addDynamicallyLegalOp( + [&](CallOp op) { return converter.isLegal(op); }); // TestCreateUnregisteredOp creates `arith.constant` operation, // which was not added to target intentionally to test @@ -771,6 +762,16 @@ // to get the remapped value of an original value that was replaced using // ConversionPatternRewriter. 
namespace { +struct TestRemapValueTypeConverter : public TypeConverter { + using TypeConverter::TypeConverter; + + TestRemapValueTypeConverter() { + addConversion( + [](Float32Type type) { return Float64Type::get(type.getContext()); }); + addConversion([](Type type) { return type; }); + } +}; + /// Converter that replaces a one-result one-operand OneVResOneVOperandOp1 with /// a one-operand two-result OneVResOneVOperandOp1 by replicating its original /// operand twice. @@ -802,6 +803,36 @@ } }; +/// A rewriter pattern that tests remapping values that haven't been converted. +struct TestRemapValueInRegion + : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(TestRemappedValueRegionOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + Block &block = op.getBody().front(); + Operation *terminator = block.getTerminator(); + + // Merge the block into the parent region. + Block *parentBlock = op->getBlock(); + Block *finalBlock = rewriter.splitBlock(parentBlock, op->getIterator()); + rewriter.mergeBlocks(&block, parentBlock, ValueRange()); + rewriter.mergeBlocks(finalBlock, parentBlock, ValueRange()); + + // Replace the results of this operation with the remapped terminator + // values.
+ SmallVector terminatorOperands; + if (failed(rewriter.getRemappedValues(terminator->getOperands(), + terminatorOperands))) + return failure(); + + rewriter.eraseOp(terminator); + rewriter.replaceOp(op, terminatorOperands); + return success(); + } +}; + struct TestRemappedValue : public mlir::PassWrapper { StringRef getArgument() const final { return "test-remapped-value"; } @@ -809,18 +840,29 @@ return "Test public remapped value mechanism in ConversionPatternRewriter"; } void runOnFunction() override { + TestRemapValueTypeConverter typeConverter; + mlir::RewritePatternSet patterns(&getContext()); patterns.add(&getContext()); + patterns.add( + &getContext()); + patterns.add(typeConverter, &getContext()); mlir::ConversionTarget target(getContext()); target.addLegalOp(); + + // Expect the type_producer/type_consumer operations to only operate on f64. + target.addDynamicallyLegalOp( + [](TestTypeProducerOp op) { return op.getType().isF64(); }); + target.addDynamicallyLegalOp([](TestTypeConsumerOp op) { + return op.getOperand().getType().isF64(); + }); + // We make OneVResOneVOperandOp1 legal only when it has more that one // operand. This will trigger the conversion that will replace one-operand // OneVResOneVOperandOp1 with two-operand OneVResOneVOperandOp1. target.addDynamicallyLegalOp( - [](Operation *op) -> bool { - return std::distance(op->operand_begin(), op->operand_end()) > 1; - }); + [](Operation *op) { return op->getNumOperands() > 1; }); if (failed(mlir::applyFullConversion(getFunction(), target, std::move(patterns)))) {