diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -38,6 +38,19 @@ /// derive this class and implement the pure virtual functions. class TypeConverter { public: + /// Type of the hook for materializing cast operations. + using ConversionMaterializer = + std::function; + + /// Construct a TypeConverter. Derived classes must provide an implementation + /// of the ConversionMaterializer hook materializing a conversion from a set + /// of types into one result type by generating a cast operation of some kind. + /// The generated operation should produce one result, of 'resultType', with + /// the provided 'inputs' as operands. This hook must be provided when a + /// type conversion results in more than one type, or if a type conversion may + /// persist after the conversion has finished. + explicit TypeConverter(ConversionMaterializer materializer = nullptr) + : conversionMaterializer(materializer) {} virtual ~TypeConverter() = default; /// This class provides all of the information necessary to convert a type @@ -127,17 +140,8 @@ Optional convertBlockSignature(Block *block); /// This hook allows for materializing a conversion from a set of types into - /// one result type by generating a cast operation of some kind. The generated - /// operation should produce one result, of 'resultType', with the provided - /// 'inputs' as operands. This hook must be overridden when a type conversion - /// results in more than one type, or if a type conversion may persist after - /// the conversion has finished. - virtual Operation *materializeConversion(PatternRewriter &rewriter, - Type resultType, - ArrayRef inputs, - Location loc) { - llvm_unreachable("expected 'materializeConversion' to be overridden"); - } + /// one result type by generating a cast operation of some kind. + ConversionMaterializer conversionMaterializer; }; //===----------------------------------------------------------------------===// @@ -320,9 +324,6 @@ applySignatureConversion(Region *region, TypeConverter::SignatureConversion &conversion); - /// Replace all the uses of the block argument `from` with value `to`. - void replaceUsesOfBlockArgument(BlockArgument from, Value to); - /// Return the converted value that replaces 'key'. Return 'key' if there is /// no such a converted value. Value getRemappedValue(Value key); diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -652,25 +652,6 @@ rewriter.applySignatureConversion(&llvmFuncOp.getBody(), signatureConversion); - { - // For memref-typed arguments, insert the relevant loads in the beginning - // of the block to comply with the LLVM dialect calling convention. This - // needs to be done after signature conversion to get the right types. - OpBuilder::InsertionGuard guard(rewriter); - Block &block = llvmFuncOp.front(); - rewriter.setInsertionPointToStart(&block); - - for (auto en : llvm::enumerate(gpuFuncOp.getType().getInputs())) { - if (!en.value().isa() && - !en.value().isa()) - continue; - - BlockArgument arg = block.getArgument(en.index()); - Value loaded = rewriter.create(loc, arg); - rewriter.replaceUsesOfBlockArgument(arg, loaded); - } - } - rewriter.eraseOp(gpuFuncOp); return matchSuccess(); } diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -44,8 +44,19 @@ llvm::cl::desc("Replace emission of malloc/free by alloca"), llvm::cl::init(false)); +/// Materializes the "load" operation from a pointer to a memref descriptor. +/// This is used as a hook for type conversion. +Value materializeDescriptorLoad(OpBuilder &builder, Type type, + ValueRange pointer, Location loc) { + assert(pointer.size() == 1 && "only 1-1 conversion is supported"); + if (!type.isa() && !type.isa()) + return pointer[0]; + return builder.create(loc, pointer[0]); +} + LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx) - : llvmDialect(ctx->getRegisteredDialect()) { + : TypeConverter(materializeDescriptorLoad), + llvmDialect(ctx->getRegisteredDialect()) { assert(llvmDialect && "LLVM IR dialect is not registered"); module = &llvmDialect->getLLVMModule(); } @@ -541,17 +552,6 @@ // Tell the rewriter to convert the region signature. rewriter.applySignatureConversion(&newFuncOp.getBody(), result); - // Insert loads from memref descriptor pointers in function bodies. - if (!newFuncOp.getBody().empty()) { - Block *firstBlock = &newFuncOp.getBody().front(); - rewriter.setInsertionPoint(firstBlock, firstBlock->begin()); - for (unsigned idx : promotedArgIndices) { - BlockArgument arg = firstBlock->getArgument(idx); - Value loaded = rewriter.create(funcOp.getLoc(), arg); - rewriter.replaceUsesOfBlockArgument(arg, loaded); - } - } - rewriter.eraseOp(op); return matchSuccess(); } diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -277,27 +277,19 @@ // persist in the IR after conversion. if (!origArg.use_empty()) { rewriter.setInsertionPointToStart(newBlock); - auto *newOp = typeConverter->materializeConversion( + assert(typeConverter->conversionMaterializer && + "ConversionMaterializer must be provided if uses of a block " + "argument persist after 1->0 conversion"); + Value newArg = typeConverter->conversionMaterializer( rewriter, origArg.getType(), llvm::None, loc); - origArg.replaceAllUsesWith(newOp->getResult(0)); + origArg.replaceAllUsesWith(newArg); } continue; } - // If mapping is 1-1, replace the remaining uses and drop the cast - // operation. - // FIXME(riverriddle) This should check that the result type and operand - // type are the same, otherwise it should force a conversion to be - // materialized. - if (argInfo->newArgSize == 1) { - origArg.replaceAllUsesWith( - mapping.lookupOrDefault(newBlock->getArgument(argInfo->newArgIdx))); - continue; - } - - // Otherwise this is a 1->N value mapping. + // Otherwise this is a 1->1+ value mapping. Value castValue = argInfo->castValue; - assert(argInfo->newArgSize > 1 && castValue && "expected 1->N mapping"); + assert(argInfo->newArgSize >= 1 && castValue && "expected 1->1+ mapping"); // If the argument is still used, replace it with the generated cast. if (!origArg.use_empty()) @@ -305,7 +297,7 @@ // If all users of the cast were removed, we can drop it. Otherwise, keep // the operation alive and let the user handle any remaining usages. - if (castValue.use_empty()) + if (castValue.use_empty() && castValue.getDefiningOp()) castValue.getDefiningOp()->erase(); } } @@ -361,23 +353,22 @@ continue; } - // If this is a 1->1 mapping, then map the argument directly. - if (inputMap->size == 1) { - mapping.map(origArg, newArgs[inputMap->inputNo]); - info.argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size); - continue; - } - - // Otherwise, this is a 1->N mapping. Call into the provided type converter - // to pack the new values. + // Otherwise, this is a 1->1+ mapping. Call into the provided type converter + // to cast or pack the new values. For 1->1 mappings, the converter may be + // omitted, use the argument directly instead. auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size); - Operation *cast = typeConverter->materializeConversion( - rewriter, origArg.getType(), replArgs, loc); - assert(cast->getNumResults() == 1 && - cast->getNumOperands() == replArgs.size()); - mapping.map(origArg, cast->getResult(0)); + bool hasMaterializer = + typeConverter && typeConverter->conversionMaterializer; + assert((inputMap->size == 1 || hasMaterializer) && + "ConversionMaterializer must be provided for 1->N block argument " + "conversions"); + Value newArg = hasMaterializer + ? typeConverter->conversionMaterializer( + rewriter, origArg.getType(), replArgs, loc) + : replArgs.front(); + mapping.map(origArg, newArg); info.argInfo[i] = - ConvertedArgInfo(inputMap->inputNo, inputMap->size, cast->getResult(0)); + ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg); } // Remove the original block from the region and return the new one. @@ -859,16 +850,6 @@ return impl->applySignatureConversion(region, conversion); } -void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from, - Value to) { - for (auto &u : from.getUses()) { - if (u.getOwner() == to.getDefiningOp()) - continue; - u.getOwner()->replaceUsesOfWith(from, to); - } - impl->mapping.map(impl->mapping.lookupOrDefault(from), to); -} - /// Return the converted value that replaces 'key'. Return 'key' if there is /// no such a converted value. Value ConversionPatternRewriter::getRemappedValue(Value key) { diff --git a/mlir/test/Conversion/StandardToLLVM/convert-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-memref-ops.mlir --- a/mlir/test/Conversion/StandardToLLVM/convert-memref-ops.mlir +++ b/mlir/test/Conversion/StandardToLLVM/convert-memref-ops.mlir @@ -443,7 +443,6 @@ // CHECK-LABEL: func @static_memref_dim(%arg0: !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }*">) { func @static_memref_dim(%static : memref<42x32x15x13x27xf32>) { -// CHECK-NEXT: %[[ld:.*]] = llvm.load %{{.*}} : !llvm<"{ float*, float*, i64, [5 x i64], [5 x i64] }*"> // CHECK-NEXT: llvm.mlir.constant(42 : index) : !llvm.i64 %0 = dim %static, 0 : memref<42x32x15x13x27xf32> // CHECK-NEXT: llvm.mlir.constant(32 : index) : !llvm.i64 diff --git a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir --- a/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir +++ b/mlir/test/Conversion/StandardToLLVM/standard-to-llvm.mlir @@ -2,7 +2,6 @@ // CHECK-LABEL: func @address_space( // CHECK: %{{.*}}: !llvm<"{ float addrspace(7)*, float addrspace(7)*, i64, [1 x i64], [1 x i64] }*">) -// CHECK: llvm.load %{{.*}} : !llvm<"{ float addrspace(7)*, float addrspace(7)*, i64, [1 x i64], [1 x i64] }*"> func @address_space(%arg0 : memref<32xf32, affine_map<(d0) -> (d0)>, 7>) { %0 = alloc() : memref<32xf32, affine_map<(d0) -> (d0)>, 5> %1 = constant 7 : index diff --git a/mlir/test/lib/TestDialect/TestPatterns.cpp b/mlir/test/lib/TestDialect/TestPatterns.cpp --- a/mlir/test/lib/TestDialect/TestPatterns.cpp +++ b/mlir/test/lib/TestDialect/TestPatterns.cpp @@ -310,7 +310,8 @@ namespace { struct TestTypeConverter : public TypeConverter { - using TypeConverter::TypeConverter; + TestTypeConverter() + : TypeConverter(TestTypeConverter::materializeConversion) {} LogicalResult convertType(Type t, SmallVectorImpl &results) override { // Drop I16 types. @@ -334,11 +335,12 @@ return success(); } - /// Override the hook to materialize a conversion. This is necessary because + /// Implement the hook to materialize a conversion. This is necessary because /// we generate 1->N type mappings. - Operation *materializeConversion(PatternRewriter &rewriter, Type resultType, - ArrayRef inputs, - Location loc) override { + static Value materializeConversion(PatternRewriter &rewriter, Type resultType, + ValueRange inputs, Location loc) { + if (inputs.size() == 1) + return inputs[0]; return rewriter.create(loc, resultType, inputs); } };