diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md --- a/mlir/docs/DialectConversion.md +++ b/mlir/docs/DialectConversion.md @@ -217,16 +217,20 @@ template void addConversion(ConversionFnT &&callback); - /// This hook allows for materializing a conversion from a set of types into - /// one result type by generating a cast operation of some kind. The generated - /// operation should produce one result, of 'resultType', with the provided - /// 'inputs' as operands. This hook must be overridden when a type conversion + /// Register a materialization function, which must be convertible to the + /// following form + /// `Optional<Value>(PatternRewriter &, T, ValueRange, Location)`, + /// where `T` is any subclass of `Type`. This function is responsible for + /// creating an operation, using the PatternRewriter and Location provided, + /// that "casts" a range of values into a single value of the given type `T`. + /// It must return a Value of the converted type on success, an `llvm::None` + /// if it failed but other materialization can be attempted, and `nullptr` on + /// unrecoverable failure. It will only be called for (sub)types of `T`. + /// Materialization functions must be provided when a type conversion /// results in more than one type, or if a type conversion may persist after /// the conversion has finished. - virtual Operation *materializeConversion(PatternRewriter &rewriter, - Type resultType, - ArrayRef inputs, - Location loc); + template + void addMaterialization(FnT &&callback); }; ``` diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h @@ -122,11 +122,6 @@ /// pointers to memref descriptors for arguments. 
LLVM::LLVMType convertFunctionTypeCWrapper(FunctionType type); - /// Creates descriptor structs from individual values constituting them. - Operation *materializeConversion(PatternRewriter &rewriter, Type type, - ArrayRef values, - Location loc) override; - /// Gets the LLVM representation of the index type. The returned type is an /// integer type with the size configured for this type converter. LLVM::LLVMType getIndexType(); diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -113,6 +113,25 @@ registerConversion(wrapCallback(std::forward(callback))); } + /// Register a materialization function, which must be convertible to the + /// following form + /// `Optional<Value>(PatternRewriter &, T, ValueRange, Location)`, + /// where `T` is any subclass of `Type`. This function is responsible for + /// creating an operation, using the PatternRewriter and Location provided, + /// that "casts" a range of values into a single value of the given type `T`. + /// It must return a Value of the converted type on success, an `llvm::None` + /// if it failed but other materialization can be attempted, and `nullptr` on + /// unrecoverable failure. It will only be called for (sub)types of `T`. + /// Materialization functions must be provided when a type conversion + /// results in more than one type, or if a type conversion may persist after + /// the conversion has finished. + template ::template arg_t<1>> + void addMaterialization(FnT &&callback) { + registerMaterialization( + wrapMaterialization(std::forward(callback))); + } + /// Convert the given type. This function should return failure if no valid /// conversion exists, success otherwise. 
If the new set of types is empty, /// the type is removed and any usages of the existing value are expected to @@ -148,18 +167,10 @@ /// valid conversion for the signature on success, None otherwise. Optional convertBlockSignature(Block *block); - /// This hook allows for materializing a conversion from a set of types into - /// one result type by generating a cast operation of some kind. The generated - /// operation should produce one result, of 'resultType', with the provided - /// 'inputs' as operands. This hook must be overridden when a type conversion - /// results in more than one type, or if a type conversion may persist after - /// the conversion has finished. - virtual Operation *materializeConversion(PatternRewriter &rewriter, - Type resultType, - ArrayRef inputs, - Location loc) { - llvm_unreachable("expected 'materializeConversion' to be overridden"); - } + /// Materialize a conversion from a set of types into one result type by + /// generating a cast operation of some kind. + Value materializeConversion(PatternRewriter &rewriter, Location loc, + Type resultType, ValueRange inputs); private: /// The signature of the callback used to convert a type. If the new set of @@ -168,6 +179,9 @@ using ConversionCallbackFn = std::function(Type, SmallVectorImpl &)>; + using MaterializationCallbackFn = std::function( + PatternRewriter &, Type, ValueRange, Location)>; + /// Generate a wrapper for the given callback. This allows for accepting /// different callback forms, that all compose into a single version. /// With callback of form: `Optional(T)` @@ -204,8 +218,30 @@ conversions.emplace_back(std::move(callback)); } + /// Generate a wrapper for the given materialization callback. The callback + /// may take any subclass of `Type` and the wrapper will check for the target + /// type to be of the expected class before calling the callback. 
+ template + MaterializationCallbackFn wrapMaterialization(FnT &&callback) { + return [callback = std::forward(callback)]( + PatternRewriter &rewriter, Type resultType, ValueRange inputs, + Location loc) -> Optional { + if (T derivedType = resultType.dyn_cast()) + return callback(rewriter, derivedType, inputs, loc); + return llvm::None; + }; + } + + /// Register a materialization. + void registerMaterialization(MaterializationCallbackFn &&callback) { + materializations.emplace_back(std::move(callback)); + } + /// The set of registered conversion functions. SmallVector conversions; + + /// The list of registered materialization functions. + SmallVector materializations; }; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/StandardToLLVM.cpp @@ -150,6 +150,24 @@ // LLVMType is legal, so add a pass-through conversion. addConversion([](LLVM::LLVMType type) { return type; }); + + // Materialization for memrefs creates descriptor structs from individual + // values constituting them, when descriptors are used, i.e. more than one + // value represents a memref. + addMaterialization([&](PatternRewriter &rewriter, + UnrankedMemRefType resultType, ValueRange inputs, + Location loc) -> Optional { + if (inputs.size() == 1) + return llvm::None; + return UnrankedMemRefDescriptor::pack(rewriter, loc, *this, resultType, + inputs); + }); + addMaterialization([&](PatternRewriter &rewriter, MemRefType resultType, + ValueRange inputs, Location loc) -> Optional { + if (inputs.size() == 1) + return llvm::None; + return MemRefDescriptor::pack(rewriter, loc, *this, resultType, inputs); + }); } /// Returns the MLIR context. 
@@ -297,22 +315,6 @@ return LLVM::LLVMType::getFunctionTy(resultType, inputs, false); } -/// Creates descriptor structs from individual values constituting them. -Operation *LLVMTypeConverter::materializeConversion(PatternRewriter &rewriter, - Type type, - ArrayRef values, - Location loc) { - if (auto unrankedMemRefType = type.dyn_cast()) - return UnrankedMemRefDescriptor::pack(rewriter, loc, *this, - unrankedMemRefType, values) - .getDefiningOp(); - - auto memRefType = type.dyn_cast(); - assert(memRefType && "1->N conversion is only supported for memrefs"); - return MemRefDescriptor::pack(rewriter, loc, *this, memRefType, values) - .getDefiningOp(); -} - // Convert a MemRef to an LLVM type. The result is a MemRef descriptor which // contains: // 1. the pointer to the data buffer, followed by diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -305,27 +305,20 @@ // persist in the IR after conversion. if (!origArg.use_empty()) { rewriter.setInsertionPointToStart(newBlock); - auto *newOp = typeConverter->materializeConversion( - rewriter, origArg.getType(), llvm::None, loc); - origArg.replaceAllUsesWith(newOp->getResult(0)); + Value newArg = typeConverter->materializeConversion( + rewriter, loc, origArg.getType(), llvm::None); + assert(newArg && + "Couldn't materialize a block argument after 1->0 conversion"); + origArg.replaceAllUsesWith(newArg); } continue; } - // If mapping is 1-1, replace the remaining uses and drop the cast - // operation. - // FIXME(riverriddle) This should check that the result type and operand - // type are the same, otherwise it should force a conversion to be - // materialized. - if (argInfo->newArgSize == 1) { - origArg.replaceAllUsesWith( - mapping.lookupOrDefault(newBlock->getArgument(argInfo->newArgIdx))); - continue; - } - - // Otherwise this is a 1->N value mapping. 
+ // Otherwise this is a 1->1+ value mapping. Value castValue = argInfo->castValue; - assert(argInfo->newArgSize > 1 && castValue && "expected 1->N mapping"); + assert(argInfo->newArgSize >= 1); + assert(castValue); + assert(argInfo->newArgSize >= 1 && castValue && "expected 1->1+ mapping"); // If the argument is still used, replace it with the generated cast. if (!origArg.use_empty()) @@ -333,7 +326,7 @@ // If all users of the cast were removed, we can drop it. Otherwise, keep // the operation alive and let the user handle any remaining usages. - if (castValue.use_empty()) + if (castValue.use_empty() && castValue.getDefiningOp()) castValue.getDefiningOp()->erase(); } } @@ -389,22 +382,22 @@ continue; } - // If this is a 1->1 mapping, then map the argument directly. - if (inputMap->size == 1) { - mapping.map(origArg, newArgs[inputMap->inputNo]); - info.argInfo[i] = ConvertedArgInfo(inputMap->inputNo, inputMap->size); - continue; - } - - // Otherwise, this is a 1->N mapping. Call into the provided type converter - // to pack the new values. + // Otherwise, this is a 1->1+ mapping. Call into the provided type converter + // to pack the new values. For 1->1 mappings, if there is no materialization + // provided, use the argument directly instead. auto replArgs = newArgs.slice(inputMap->inputNo, inputMap->size); - Operation *cast = typeConverter->materializeConversion( - rewriter, origArg.getType(), replArgs, loc); - assert(cast->getNumResults() == 1); - mapping.map(origArg, cast->getResult(0)); + Value newArg = typeConverter + ? 
typeConverter->materializeConversion( + rewriter, loc, origArg.getType(), replArgs) + : Value(); + if (!newArg) { + assert(replArgs.size() == 1 && + "couldn't materialize the result of 1->N conversion"); + newArg = replArgs.front(); + } + mapping.map(origArg, newArg); info.argInfo[i] = - ConvertedArgInfo(inputMap->inputNo, inputMap->size, cast->getResult(0)); + ConvertedArgInfo(inputMap->inputNo, inputMap->size, newArg); } // Remove the original block from the region and return the new one. @@ -1755,6 +1748,15 @@ return success(); } +Value TypeConverter::materializeConversion(PatternRewriter &rewriter, + Location loc, Type resultType, + ValueRange inputs) { + for (MaterializationCallbackFn &fn : llvm::reverse(materializations)) + if (Optional result = fn(rewriter, resultType, inputs, loc)) + return result.getValue(); + return nullptr; +} + /// Create a default conversion pattern that rewrites the type signature of a /// FuncOp. namespace { diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir --- a/mlir/test/Transforms/test-legalizer.mlir +++ b/mlir/test/Transforms/test-legalizer.mlir @@ -48,6 +48,13 @@ "work"(%arg0) : (f32) -> () } +// CHECK-LABEL: func @remap_materialize_1_to_1(%{{.*}}: i43) +func @remap_materialize_1_to_1(%arg0: i42) { + // CHECK: %[[V:.*]] = "test.cast"(%arg0) : (i43) -> i42 + // CHECK: "test.return"(%[[V]]) + "test.return"(%arg0) : (i42) -> () +} + // CHECK-LABEL: func @remap_input_to_self func @remap_input_to_self(%arg0: index) { // CHECK-NOT: test.cast diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp --- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp +++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp @@ -448,7 +448,11 @@ namespace { struct TestTypeConverter : public TypeConverter { using TypeConverter::TypeConverter; - TestTypeConverter() { addConversion(convertType); } + TestTypeConverter() { + addConversion(convertType); + 
addMaterialization(materializeCast); + addMaterialization(materializeOneToOneCast); + } static LogicalResult convertType(Type t, SmallVectorImpl &results) { // Drop I16 types. @@ -461,6 +465,12 @@ return success(); } + // Convert I42 to I43. + if (t.isInteger(42)) { + results.push_back(IntegerType::get(43, t.getContext())); + return success(); + } + // Split F32 into F16,F16. if (t.isF32()) { results.assign(2, FloatType::getF16(t.getContext())); @@ -472,12 +482,24 @@ return success(); } - /// Override the hook to materialize a conversion. This is necessary because - /// we generate 1->N type mappings. - Operation *materializeConversion(PatternRewriter &rewriter, Type resultType, - ArrayRef inputs, - Location loc) override { - return rewriter.create(loc, resultType, inputs); + /// Hook for materializing a conversion. This is necessary because we generate + /// 1->N type mappings. + static Optional materializeCast(PatternRewriter &rewriter, + Type resultType, ValueRange inputs, + Location loc) { + if (inputs.size() == 1) + return inputs[0]; + return rewriter.create(loc, resultType, inputs).getResult(); + } + + /// Materialize the cast for the one-to-one conversion from i42 to i43. + static Optional materializeOneToOneCast(PatternRewriter &rewriter, + IntegerType resultType, + ValueRange inputs, + Location loc) { + if (resultType.getWidth() == 42 && inputs.size() == 1) + return rewriter.create(loc, resultType, inputs).getResult(); + return llvm::None; + } };