diff --git a/mlir/docs/DialectConversion.md b/mlir/docs/DialectConversion.md --- a/mlir/docs/DialectConversion.md +++ b/mlir/docs/DialectConversion.md @@ -196,15 +196,26 @@ ```c++ class TypeConverter { public: - /// This hook allows for converting a type. This function should return - /// failure if no valid conversion exists, success otherwise. If the new set - /// of types is empty, the type is removed and any usages of the existing - /// value are expected to be removed during conversion. - virtual LogicalResult convertType(Type t, SmallVectorImpl &results); - - /// This hook simplifies defining 1-1 type conversions. This function returns - /// the type to convert to on success, and a null type on failure. - virtual Type convertType(Type t); + /// Register a conversion function. A conversion function must be convertible + /// to any of the following forms (where `T` is a class derived from `Type`): + /// * Optional<Type>(T) + /// - This form represents a 1-1 type conversion. It should return nullptr + /// or `llvm::None` to signify failure. If `llvm::None` is returned, the + /// converter is allowed to try another conversion function to perform + /// the conversion. + /// * Optional<LogicalResult>(T, SmallVectorImpl<Type> &) + /// - This form represents a 1-N type conversion. It should return + /// `failure` or `llvm::None` to signify a failed conversion. If the new + /// set of types is empty, the type is removed and any usages of the + /// existing value are expected to be removed during conversion. If + /// `llvm::None` is returned, the converter is allowed to try another + /// conversion function to perform the conversion. + /// + /// When attempting to convert a type, e.g. via `convertType`, the + /// `TypeConverter` will invoke each of the converters starting with the one + /// most recently registered.
+ template + void addConversion(ConversionFnT &&callback); /// This hook allows for materializing a conversion from a set of types into /// one result type by generating a cast operation of some kind. The generated diff --git a/mlir/include/mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h b/mlir/include/mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h --- a/mlir/include/mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h +++ b/mlir/include/mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h @@ -16,14 +16,8 @@ class ModuleOp; template class OpPassBase; -class LinalgTypeConverter : public LLVMTypeConverter { -public: - using LLVMTypeConverter::LLVMTypeConverter; - Type convertType(Type t) override; -}; - /// Populate the given list with patterns that convert from Linalg to LLVM. -void populateLinalgToLLVMConversionPatterns(LinalgTypeConverter &converter, +void populateLinalgToLLVMConversionPatterns(LLVMTypeConverter &converter, OwningRewritePatternList &patterns, MLIRContext *ctx); diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h --- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h +++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h @@ -78,10 +78,6 @@ LLVMTypeConverter(MLIRContext *ctx, const LLVMTypeConverterCustomization &custom); - /// Convert types to LLVM IR. This calls `convertAdditionalType` to convert - /// non-standard or non-builtin types. - Type convertType(Type t) override; - /// Convert a function type. The arguments and results are converted one by /// one and results are packed into a wrapped LLVM IR structure type. `result` /// is populated with argument mapping. @@ -129,8 +125,6 @@ LLVM::LLVMDialect *llvmDialect; private: - Type convertStandardType(Type type); - // Convert a function type. The arguments and results are converted one by // one. 
Additionally, if the function returns more than one value, pack the // results into an LLVM IR structure type so that the converted function type diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVLowering.h @@ -27,10 +27,7 @@ /// pointers to structs. class SPIRVTypeConverter : public TypeConverter { public: - using TypeConverter::TypeConverter; - - /// Converts the given standard `type` to SPIR-V correspondence. - Type convertType(Type type) override; + SPIRVTypeConverter(); /// Gets the SPIR-V correspondence for the standard index type. static Type getIndexType(MLIRContext *context); diff --git a/mlir/include/mlir/Support/STLExtras.h b/mlir/include/mlir/Support/STLExtras.h --- a/mlir/include/mlir/Support/STLExtras.h +++ b/mlir/include/mlir/Support/STLExtras.h @@ -376,6 +376,10 @@ template using arg_t = typename std::tuple_element>::type; }; +/// Overload for non-class function type references. +template +struct FunctionTraits + : public FunctionTraits {}; } // end namespace mlir // Allow tuples to be usable as DenseMap keys. diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -91,15 +91,37 @@ SmallVector argTypes; }; - /// This hook allows for converting a type. This function should return - /// failure if no valid conversion exists, success otherwise. If the new set - /// of types is empty, the type is removed and any usages of the existing - /// value are expected to be removed during conversion. - virtual LogicalResult convertType(Type t, SmallVectorImpl &results); + /// Register a conversion function. 
A conversion function must be convertible + /// to any of the following forms (where `T` is a class derived from `Type`): + /// * Optional<Type>(T) + /// - This form represents a 1-1 type conversion. It should return nullptr + /// or `llvm::None` to signify failure. If `llvm::None` is returned, the + /// converter is allowed to try another conversion function to perform + /// the conversion. + /// * Optional<LogicalResult>(T, SmallVectorImpl<Type> &) + /// - This form represents a 1-N type conversion. It should return + /// `failure` or `llvm::None` to signify a failed conversion. If the new + /// set of types is empty, the type is removed and any usages of the + /// existing value are expected to be removed during conversion. If + /// `llvm::None` is returned, the converter is allowed to try another + /// conversion function to perform the conversion. + /// Note: When attempting to convert a type, e.g. via `convertType`, the + /// most recently added conversions will be invoked first. + template ::template arg_t<0>> + void addConversion(FnT &&callback) { + registerConversion(wrapCallback(std::forward(callback))); + } + + /// Convert the given type. This function should return failure if no valid + /// conversion exists, success otherwise. If the new set of types is empty, + /// the type is removed and any usages of the existing value are expected to + /// be removed during conversion. + LogicalResult convertType(Type t, SmallVectorImpl &results); /// This hook simplifies defining 1-1 type conversions. This function returns /// the type to convert to on success, and a null type on failure. - virtual Type convertType(Type t) { return t; } + Type convertType(Type t); /// Convert the given set of types, filling 'results' as necessary. This /// returns failure if the conversion of any of the types fails, success @@ -138,6 +160,50 @@ Location loc) { llvm_unreachable("expected 'materializeConversion' to be overridden"); } + +private: + /// The signature of the callback used to convert a type.
If the new set of + /// types is empty, the type is removed and any usages of the existing value + /// are expected to be removed during conversion. + using ConversionCallbackFn = + std::function(Type, SmallVectorImpl &)>; + + /// Generate a wrapper for the given callback. This allows for accepting + /// different callback forms, that all compose into a single version. + /// With callback of form: `Optional(T)` + template + std::enable_if_t::value, ConversionCallbackFn> + wrapCallback(FnT &&callback) { + return wrapCallback([=](T type, SmallVectorImpl &results) { + if (Optional resultOpt = callback(type)) { + bool wasSuccess = static_cast(resultOpt.getValue()); + if (wasSuccess) + results.push_back(resultOpt.getValue()); + return Optional(success(wasSuccess)); + } + return Optional(); + }); + } + /// With callback of form: `Optional(T, SmallVectorImpl<> &)` + template + std::enable_if_t::value, ConversionCallbackFn> + wrapCallback(FnT &&callback) { + return [=](Type type, + SmallVectorImpl &results) -> Optional { + T derivedType = type.dyn_cast(); + if (!derivedType) + return llvm::None; + return callback(derivedType, results); + }; + } + + /// Register a type conversion. + void registerConversion(ConversionCallbackFn callback) { + conversions.emplace_back(std::move(callback)); + } + + /// The set of registered conversion functions. + SmallVector conversions; }; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -29,25 +29,6 @@ namespace { -/// Derived type converter for GPU to NVVM lowering. The GPU dialect uses memory -/// space 5 for private memory attributions, but NVVM represents private -/// memory allocations as local `alloca`s in the default address space. 
This -/// converter drops the private memory space to support the use case above. -class NVVMTypeConverter : public LLVMTypeConverter { -public: - using LLVMTypeConverter::LLVMTypeConverter; - - Type convertType(Type type) override { - auto memref = type.dyn_cast(); - if (memref && - memref.getMemorySpace() == gpu::GPUDialect::getPrivateAddressSpace()) { - type = MemRefType::Builder(memref).setMemorySpace(0); - } - - return LLVMTypeConverter::convertType(type); - } -}; - /// Converts all_reduce op to LLVM/NVVM ops. struct GPUAllReduceOpLowering : public LLVMOpLowering { using AccumulatorFactory = @@ -684,8 +665,19 @@ public: void runOnOperation() override { gpu::GPUModuleOp m = getOperation(); + + /// MemRef conversion for GPU to NVVM lowering. The GPU dialect uses memory + /// space 5 for private memory attributions, but NVVM represents private + /// memory allocations as local `alloca`s in the default address space. This + /// converter drops the private memory space to support the use case above. + LLVMTypeConverter converter(m.getContext()); + converter.addConversion([&](MemRefType type) -> Optional { + if (type.getMemorySpace() != gpu::GPUDialect::getPrivateAddressSpace()) + return llvm::None; + return converter.convertType(MemRefType::Builder(type).setMemorySpace(0)); + }); + OwningRewritePatternList patterns; - NVVMTypeConverter converter(m.getContext()); populateStdToLLVMConversionPatterns(converter, patterns); populateGpuToNVVMConversionPatterns(converter, patterns); ConversionTarget target(getContext()); diff --git a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp --- a/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp +++ b/mlir/lib/Conversion/LinalgToLLVM/LinalgToLLVM.cpp @@ -73,30 +73,19 @@ .getPointerTo(); } -// Convert the given type to the LLVM IR Dialect type. 
The following -// conversions are supported: -// - an Index type is converted into an LLVM integer type with pointer -// bitwidth (analogous to intptr_t in C); -// - an Integer type is converted into an LLVM integer type of the same width; -// - an F32 type is converted into an LLVM float type -// - a Buffer, Range or View is converted into an LLVM structure type -// containing the respective dynamic values. -static Type convertLinalgType(Type t, LLVMTypeConverter &lowering) { +/// Convert the given range descriptor type to the LLVMIR dialect. +/// Range descriptor contains the range bounds and the step as 64-bit integers. +/// +/// struct { +/// int64_t min; +/// int64_t max; +/// int64_t step; +/// }; +static Type convertRangeType(RangeType t, LLVMTypeConverter &lowering) { auto *context = t.getContext(); auto int64Ty = lowering.convertType(IntegerType::get(64, context)) .cast(); - - // Range descriptor contains the range bounds and the step as 64-bit integers. - // - // struct { - // int64_t min; - // int64_t max; - // int64_t step; - // }; - if (t.isa()) - return LLVMType::getStructTy(int64Ty, int64Ty, int64Ty); - - return Type(); + return LLVMType::getStructTy(int64Ty, int64Ty, int64Ty); } namespace { @@ -146,7 +135,7 @@ ConversionPatternRewriter &rewriter) const override { auto rangeOp = cast(op); auto rangeDescriptorTy = - convertLinalgType(rangeOp.getResult().getType(), lowering); + convertRangeType(rangeOp.getType().cast(), lowering); edsc::ScopedContext context(rewriter, op->getLoc()); @@ -416,12 +405,6 @@ return fnNameAttr; } -Type LinalgTypeConverter::convertType(Type t) { - if (auto result = LLVMTypeConverter::convertType(t)) - return result; - return convertLinalgType(t, *this); -} - namespace { // LinalgOpConversion creates a new call to the @@ -553,10 +536,14 @@ /// Populate the given list with patterns that convert from Linalg to LLVM. 
void mlir::populateLinalgToLLVMConversionPatterns( - LinalgTypeConverter &converter, OwningRewritePatternList &patterns, + LLVMTypeConverter &converter, OwningRewritePatternList &patterns, MLIRContext *ctx) { patterns.insert(ctx, converter); + + // Populate the type conversions for the linalg types. + converter.addConversion( + [&](RangeType type) { return convertRangeType(type, converter); }); } namespace { @@ -570,7 +557,7 @@ // Convert to the LLVM IR dialect using the converter defined above. OwningRewritePatternList patterns; - LinalgTypeConverter converter(&getContext()); + LLVMTypeConverter converter(&getContext()); populateAffineToStdConversionPatterns(patterns, &getContext()); populateLoopToStdConversionPatterns(patterns, &getContext()); populateStdToLLVMConversionPatterns(converter, patterns, /*useAlloca=*/false, diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp --- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp +++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp @@ -12,7 +12,6 @@ //===----------------------------------------------------------------------===// #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" -#include "mlir/ADT/TypeSwitch.h" #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/Dialect/StandardOps/Ops.h" @@ -130,6 +129,19 @@ customizations(customs) { assert(llvmDialect && "LLVM IR dialect is not registered"); module = &llvmDialect->getLLVMModule(); + + // Register conversions for the standard types. 
+ addConversion([&](FloatType type) { return convertFloatType(type); }); + addConversion([&](FunctionType type) { return convertFunctionType(type); }); + addConversion([&](IndexType type) { return convertIndexType(type); }); + addConversion([&](IntegerType type) { return convertIntegerType(type); }); + addConversion([&](MemRefType type) { return convertMemRefType(type); }); + addConversion( + [&](UnrankedMemRefType type) { return convertUnrankedMemRefType(type); }); + addConversion([&](VectorType type) { return convertVectorType(type); }); + + // LLVMType is legal, so add a pass-through conversion. + addConversion([](LLVM::LLVMType type) { return type; }); } /// Get the LLVM context. @@ -359,22 +371,6 @@ return vectorType; } -// Dispatch based on the actual type. Return null type on error. -Type LLVMTypeConverter::convertStandardType(Type t) { - return TypeSwitch(t) - .Case([&](FloatType type) { return convertFloatType(type); }) - .Case([&](FunctionType type) { return convertFunctionType(type); }) - .Case([&](IndexType type) { return convertIndexType(type); }) - .Case([&](IntegerType type) { return convertIntegerType(type); }) - .Case([&](MemRefType type) { return convertMemRefType(type); }) - .Case([&](UnrankedMemRefType type) { - return convertUnrankedMemRefType(type); - }) - .Case([&](VectorType type) { return convertVectorType(type); }) - .Case([](LLVM::LLVMType type) { return type; }) - .Default([](Type) { return Type(); }); -} - LLVMOpLowering::LLVMOpLowering(StringRef rootOpName, MLIRContext *context, LLVMTypeConverter &lowering_, PatternBenefit benefit) @@ -2650,9 +2646,6 @@ populateStdToLLVMMemoryConversionPatters(converter, patterns, useAlloca); } -// Convert types using the stored LLVM IR module. -Type LLVMTypeConverter::convertType(Type t) { return convertStandardType(t); } - // Create an LLVM IR structure type if there is more than one result. 
Type LLVMTypeConverter::packFunctionResults(ArrayRef types) { assert(!types.empty() && "expected non-empty list of type"); diff --git a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp --- a/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVLowering.cpp @@ -98,41 +98,37 @@ return llvm::None; } -static Type convertStdType(Type type) { - // If the type is already valid in SPIR-V, directly return. - if (spirv::SPIRVDialect::isValidType(type)) { - return type; - } - - if (auto indexType = type.dyn_cast()) { - return SPIRVTypeConverter::getIndexType(type.getContext()); - } - - if (auto memRefType = type.dyn_cast()) { +SPIRVTypeConverter::SPIRVTypeConverter() { + addConversion([](Type type) -> Optional { + // If the type is already valid in SPIR-V, directly return. + return spirv::SPIRVDialect::isValidType(type) ? type : Optional(); + }); + addConversion([](IndexType indexType) { + return SPIRVTypeConverter::getIndexType(indexType.getContext()); + }); + addConversion([this](MemRefType memRefType) -> Type { // TODO(ravishankarm): For now only support default memory space. The memory // space description is not set is stone within MLIR, i.e. it depends on the // context it is being used. To map this to SPIR-V storage classes, we // should rely on the ABI attributes, and not on the memory space. This is // still evolving, and needs to be revisited when there is more clarity. - if (memRefType.getMemorySpace()) { + if (memRefType.getMemorySpace()) return Type(); - } - auto elementType = convertStdType(memRefType.getElementType()); - if (!elementType) { + auto elementType = convertType(memRefType.getElementType()); + if (!elementType) return Type(); - } auto elementSize = getTypeNumBytes(elementType); - if (!elementSize) { + if (!elementSize) return Type(); - } + // TODO(ravishankarm) : Handle dynamic shapes. 
if (memRefType.hasStaticShape()) { auto arraySize = getTypeNumBytes(memRefType); - if (!arraySize) { + if (!arraySize) return Type(); - } + auto arrayType = spirv::ArrayType::get( elementType, arraySize.getValue() / elementSize.getValue(), elementSize.getValue()); @@ -142,34 +138,31 @@ return spirv::PointerType::get(structType, spirv::StorageClass::StorageBuffer); } - } - - if (auto tensorType = type.dyn_cast()) { + return Type(); + }); + addConversion([this](TensorType tensorType) -> Type { // TODO(ravishankarm) : Handle dynamic shapes. - if (!tensorType.hasStaticShape()) { + if (!tensorType.hasStaticShape()) return Type(); - } - auto elementType = convertStdType(tensorType.getElementType()); - if (!elementType) { + + auto elementType = convertType(tensorType.getElementType()); + if (!elementType) return Type(); - } + auto elementSize = getTypeNumBytes(elementType); - if (!elementSize) { + if (!elementSize) return Type(); - } + auto tensorSize = getTypeNumBytes(tensorType); - if (!tensorSize) { + if (!tensorSize) return Type(); - } + return spirv::ArrayType::get(elementType, tensorSize.getValue() / elementSize.getValue(), elementSize.getValue()); - } - return Type(); + }); } -Type SPIRVTypeConverter::convertType(Type type) { return convertStdType(type); } - //===----------------------------------------------------------------------===// // FuncOp Conversion Patterns //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp --- a/mlir/lib/Transforms/DialectConversion.cpp +++ b/mlir/lib/Transforms/DialectConversion.cpp @@ -1646,13 +1646,26 @@ /// This hooks allows for converting a type. LogicalResult TypeConverter::convertType(Type t, SmallVectorImpl &results) { - if (auto newT = convertType(t)) { - results.push_back(newT); - return success(); - } + // Walk the added converters in reverse order to apply the most recently + // registered first. 
+ for (ConversionCallbackFn &converter : llvm::reverse(conversions)) + if (Optional result = converter(t, results)) + return *result; return failure(); } +/// This hook simplifies defining 1-1 type conversions. This function returns +/// the type to convert to on success, and a null type on failure. +Type TypeConverter::convertType(Type t) { + // Use the multi-type result version to convert the type. + SmallVector results; + if (failed(convertType(t, results))) + return nullptr; + + // Check to ensure that only one type was produced. + return results.size() == 1 ? results.front() : nullptr; +} + /// Convert the given set of types, filling 'results' as necessary. This /// returns failure if the conversion of any of the types fails, success /// otherwise. diff --git a/mlir/test/lib/TestDialect/TestPatterns.cpp b/mlir/test/lib/TestDialect/TestPatterns.cpp --- a/mlir/test/lib/TestDialect/TestPatterns.cpp +++ b/mlir/test/lib/TestDialect/TestPatterns.cpp @@ -305,8 +305,9 @@ namespace { struct TestTypeConverter : public TypeConverter { using TypeConverter::TypeConverter; + TestTypeConverter() { addConversion(convertType); } - LogicalResult convertType(Type t, SmallVectorImpl &results) override { + static LogicalResult convertType(Type t, SmallVectorImpl &results) { // Drop I16 types. if (t.isInteger(16)) return success();