diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h
@@ -70,13 +70,23 @@ public:
   using TypeConverter::convertType;
 
-  /// Create an LLVMTypeConverter using the default
-  /// LLVMTypeConverterCustomization.
-  LLVMTypeConverter(MLIRContext *ctx);
+  /// Can be passed when the type converter is expected to derive the bitwidth
+  /// of the index type from the data layout.
+  static constexpr unsigned kDeriveIndexBitwidthFromDataLayout = 0;
 
-  /// Create an LLVMTypeConverter using 'custom' customizations.
-  LLVMTypeConverter(MLIRContext *ctx,
-                    const LLVMTypeConverterCustomization &custom);
+  /// Create an LLVMTypeConverter using the default
+  /// LLVMTypeConverterCustomization. `indexBitwidth` specifies the size of the
+  /// index type, which defaults to the pointer size in the used LLVM data
+  /// layout.
+  LLVMTypeConverter(MLIRContext *ctx,
+                    unsigned indexBitwidth = kDeriveIndexBitwidthFromDataLayout);
+
+  /// Create an LLVMTypeConverter using 'custom' customizations. `indexBitwidth`
+  /// specifies the size of the index type, which defaults to the pointer size
+  /// in the used LLVM data layout.
+  LLVMTypeConverter(
+      MLIRContext *ctx, const LLVMTypeConverterCustomization &custom,
+      unsigned indexBitwidth = kDeriveIndexBitwidthFromDataLayout);
 
   /// Convert a function type. The arguments and results are converted one by
   /// one and results are packed into a wrapped LLVM IR structure type. `result`
@@ -119,6 +129,13 @@
                                    ArrayRef<Value> values,
                                    Location loc) override;
 
+  /// Gets the LLVM representation of the index type. The returned type is an
+  /// integer type with the size configured for this type converter.
+  LLVM::LLVMType getIndexType();
+
+  /// Gets the bitwidth of the index type when converted to LLVM.
+  unsigned getIndexTypeBitwidth() { return indexBitwidth; }
+
 protected:
   /// LLVM IR module used to parse/create types.
   llvm::Module *module;
@@ -178,12 +195,11 @@
   // Convert a 1D vector type into an LLVM vector type.
   Type convertVectorType(VectorType type);
 
-  // Get the LLVM representation of the index type based on the bitwidth of the
-  // pointer as defined by the data layout of the module.
-  LLVM::LLVMType getIndexType();
-
   /// Callbacks for customizing the type conversion.
   LLVMTypeConverterCustomization customizations;
+
+  /// Bitwidth of the index type.
+  unsigned indexBitwidth;
};
 
 /// Helper class to produce LLVM dialect operations extracting or inserting
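For readers of the patch, a minimal usage sketch of the new constructor parameter and accessors (not part of the change; the function name makeConverters and the surrounding setup are illustrative, assuming an MLIRContext with the LLVM dialect registered):

#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
#include <cassert>

void makeConverters(mlir::MLIRContext *ctx) {
  // Force a 32-bit index type regardless of the module's data layout.
  mlir::LLVMTypeConverter converter32(ctx, /*indexBitwidth=*/32);
  assert(converter32.getIndexTypeBitwidth() == 32);

  // Default: the sentinel kDeriveIndexBitwidthFromDataLayout (0) makes the
  // converter read the pointer size from the LLVM data layout instead.
  mlir::LLVMTypeConverter converterDefault(ctx);
  mlir::LLVM::LLVMType indexTy = converterDefault.getIndexType();
  (void)indexTy; // e.g. the equivalent of !llvm.i64 on a 64-bit target
}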
diff --git a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
--- a/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
+++ b/mlir/include/mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h
@@ -13,6 +13,7 @@
 namespace mlir {
 class LLVMTypeConverter;
+class ModuleOp;
 template <typename T> class OpPassBase;
 class OwningRewritePatternList;
 
@@ -53,13 +54,16 @@
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns,
     bool useAlloca = false);
 
+static constexpr unsigned kDeriveIndexBitwidthFromDataLayout = 0;
+
 /// Creates a pass to convert the Standard dialect into the LLVMIR dialect.
 /// By default stdlib malloc/free are used for allocating MemRef payloads.
 /// Specifying `useAlloca=true` emits stack allocations instead. In the future
 /// this may become an enum when we have concrete uses for other options.
-std::unique_ptr<OpPassBase<ModuleOp>>
-createLowerToLLVMPass(bool useAlloca = false, bool useBarePtrCallConv = false,
-                      bool emitCWrappers = false);
+std::unique_ptr<OpPassBase<ModuleOp>> createLowerToLLVMPass(
+    bool useAlloca = false, bool useBarePtrCallConv = false,
+    bool emitCWrappers = false,
+    unsigned indexBitwidth = kDeriveIndexBitwidthFromDataLayout);
 
 namespace LLVM {
 /// Make argument-taking successors of each block distinct. PHI nodes in LLVM
diff --git a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
--- a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
+++ b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h
@@ -34,16 +34,12 @@
         .Default(invalid);
   }
 
-  static unsigned getIndexBitWidth(LLVMTypeConverter &type_converter) {
-    auto dialect = type_converter.getDialect();
-    return dialect->getLLVMModule().getDataLayout().getPointerSizeInBits();
-  }
-
 public:
-  explicit GPUIndexIntrinsicOpLowering(LLVMTypeConverter &lowering_)
+  explicit GPUIndexIntrinsicOpLowering(LLVMTypeConverter &typeConverter)
       : ConvertToLLVMPattern(Op::getOperationName(),
-                             lowering_.getDialect()->getContext(), lowering_),
-        indexBitwidth(getIndexBitWidth(lowering_)) {}
+                             typeConverter.getDialect()->getContext(),
+                             typeConverter),
+        indexBitwidth(typeConverter.getIndexTypeBitwidth()) {}
 
   // Convert the kernel arguments to an LLVM type, preserve the rest.
   PatternMatchResult
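A sketch of driving the new pass parameter from C++ (illustrative, not part of the patch; assumes a PassManager is being populated):

#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
#include "mlir/Pass/PassManager.h"

void addLoweringWith32BitIndex(mlir::PassManager &pm) {
  // Keep the existing defaults for the boolean options and only override the
  // index bitwidth; 32 here forces a 32-bit integer for index values.
  pm.addPass(mlir::createLowerToLLVMPass(
      /*useAlloca=*/false, /*useBarePtrCallConv=*/false,
      /*emitCWrappers=*/false, /*indexBitwidth=*/32));
}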
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -121,16 +121,19 @@
 }
 
 /// Create an LLVMTypeConverter using default LLVMTypeConverterCustomization.
-LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx)
-    : LLVMTypeConverter(ctx, LLVMTypeConverterCustomization()) {}
+LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx, unsigned bitwidth)
+    : LLVMTypeConverter(ctx, LLVMTypeConverterCustomization(), bitwidth) {}
 
 /// Create an LLVMTypeConverter using 'custom' customizations.
 LLVMTypeConverter::LLVMTypeConverter(
-    MLIRContext *ctx, const LLVMTypeConverterCustomization &customs)
+    MLIRContext *ctx, const LLVMTypeConverterCustomization &customs,
+    unsigned bitwidth)
     : llvmDialect(ctx->getRegisteredDialect<LLVM::LLVMDialect>()),
-      customizations(customs) {
+      customizations(customs), indexBitwidth(bitwidth) {
   assert(llvmDialect && "LLVM IR dialect is not registered");
   module = &llvmDialect->getLLVMModule();
+  if (indexBitwidth == kDeriveIndexBitwidthFromDataLayout)
+    indexBitwidth = module->getDataLayout().getPointerSizeInBits();
 
   // Register conversions for the standard types.
   addConversion([&](FloatType type) { return convertFloatType(type); });
@@ -146,14 +149,15 @@
   addConversion([](LLVM::LLVMType type) { return type; });
 }
 
+constexpr unsigned LLVMTypeConverter::kDeriveIndexBitwidthFromDataLayout;
+
 /// Get the LLVM context.
 llvm::LLVMContext &LLVMTypeConverter::getLLVMContext() {
   return module->getContext();
 }
 
 LLVM::LLVMType LLVMTypeConverter::getIndexType() {
-  return LLVM::LLVMType::getIntNTy(
-      llvmDialect, module->getDataLayout().getPointerSizeInBits());
+  return LLVM::LLVMType::getIntNTy(llvmDialect, indexBitwidth);
 }
 
 Type LLVMTypeConverter::convertIndexType(IndexType type) {
@@ -720,10 +724,7 @@
   // Get the MLIR type wrapping the LLVM integer type whose bit width is defined
   // by the pointer size used in the LLVM module.
-  LLVM::LLVMType getIndexType() const {
-    return LLVM::LLVMType::getIntNTy(
-        &dialect, getModule().getDataLayout().getPointerSizeInBits());
-  }
+  LLVM::LLVMType getIndexType() const { return typeConverter.getIndexType(); }
 
   LLVM::LLVMType getVoidType() const {
     return LLVM::LLVMType::getVoidTy(&dialect);
@@ -3014,11 +3015,12 @@
 /// A pass converting MLIR operations into the LLVM IR dialect.
 struct LLVMLoweringPass : public ModulePass<LLVMLoweringPass> {
   /// Creates an LLVM lowering pass.
-  explicit LLVMLoweringPass(bool useAlloca, bool useBarePtrCallConv,
-                            bool emitCWrappers) {
+  LLVMLoweringPass(bool useAlloca, bool useBarePtrCallConv, bool emitCWrappers,
+                   unsigned indexBitwidth) {
     this->useAlloca = useAlloca;
     this->useBarePtrCallConv = useBarePtrCallConv;
     this->emitCWrappers = emitCWrappers;
+    this->indexBitwidth = indexBitwidth;
   }
   explicit LLVMLoweringPass() {}
   LLVMLoweringPass(const LLVMLoweringPass &pass) {}
@@ -3039,7 +3041,7 @@
     LLVMTypeConverterCustomization customs;
     customs.funcArgConverter = useBarePtrCallConv ? barePtrFuncArgTypeConverter
                                                   : structFuncArgTypeConverter;
-    LLVMTypeConverter typeConverter(&getContext(), customs);
+    LLVMTypeConverter typeConverter(&getContext(), customs, indexBitwidth);
 
     OwningRewritePatternList patterns;
     if (useBarePtrCallConv)
@@ -3072,6 +3074,13 @@
       *this, "emit-c-wrappers",
       llvm::cl::desc("Emit C-compatible wrapper functions"),
       llvm::cl::init(false)};
+
+  /// Configure the bitwidth of the index type when lowered to LLVM.
+  Option<unsigned> indexBitwidth{
+      *this, "index-bitwidth",
+      llvm::cl::desc(
+          "Bitwidth of the index type, 0 to use size of machine word"),
+      llvm::cl::init(LLVMTypeConverter::kDeriveIndexBitwidthFromDataLayout)};
 };
 } // end namespace
 
@@ -3083,9 +3092,9 @@
 
 std::unique_ptr<OpPassBase<ModuleOp>>
 mlir::createLowerToLLVMPass(bool useAlloca, bool useBarePtrCallConv,
-                            bool emitCWrappers) {
+                            bool emitCWrappers, unsigned indexBitwidth) {
   return std::make_unique<LLVMLoweringPass>(useAlloca, useBarePtrCallConv,
-                                            emitCWrappers);
+                                            emitCWrappers, indexBitwidth);
 }
 
 static PassRegistration<LLVMLoweringPass>
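On the command line, the option registered above surfaces as -convert-std-to-llvm='index-bitwidth=32', as exercised by the updated RUN lines in the tests below. The sentinel handling in the constructor boils down to the following rule; resolveIndexBitwidth is a hypothetical helper shown only to spell it out, not code from the patch:

#include "llvm/IR/DataLayout.h"

// A requested bitwidth of 0 (kDeriveIndexBitwidthFromDataLayout) means
// "use the pointer size reported by the module's data layout".
unsigned resolveIndexBitwidth(unsigned requested, const llvm::DataLayout &dl) {
  return requested == 0 ? dl.getPointerSizeInBits() : requested;
}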
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -92,15 +92,19 @@
     emitError(loc, "struct types are not supported in constants");
     return nullptr;
   }
+  // For integer types, we allow a mismatch in sizes as the index type in
+  // MLIR might have a different size than the index type in the LLVM module.
   if (auto intAttr = attr.dyn_cast<IntegerAttr>())
-    return llvm::ConstantInt::get(llvmType, intAttr.getValue());
+    return llvm::ConstantInt::get(
+        llvmType,
+        intAttr.getValue().sextOrTrunc(llvmType->getIntegerBitWidth()));
   if (auto floatAttr = attr.dyn_cast<FloatAttr>())
     return llvm::ConstantFP::get(llvmType, floatAttr.getValue());
   if (auto funcAttr = attr.dyn_cast<FlatSymbolRefAttr>())
     return functionMapping.lookup(funcAttr.getValue());
   if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
     auto *sequentialType = cast<llvm::SequentialType>(llvmType);
-    auto elementType = sequentialType->getElementType();
+    auto *elementType = sequentialType->getElementType();
     uint64_t numElements = sequentialType->getNumElements();
     // Splat value is a scalar. Extract it only if the element type is not
     // another sequence type. The recursion terminates because each step removes
@@ -115,7 +119,7 @@
     if (llvmType->isVectorTy())
       return llvm::ConstantVector::getSplat(numElements, child);
     if (llvmType->isArrayTy()) {
-      auto arrayType = llvm::ArrayType::get(elementType, numElements);
+      auto *arrayType = llvm::ArrayType::get(elementType, numElements);
       SmallVector<llvm::Constant *, 8> constants(numElements, child);
       return llvm::ConstantArray::get(arrayType, constants);
     }
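To see why the constant translation now calls sextOrTrunc: MLIR stores index-typed integer attributes with 64-bit storage, while the LLVM module may use a 32-bit index type, so the APInt has to be resized to the target integer width before building the constant. A standalone sketch of the semantics (illustrative only):

#include "llvm/ADT/APInt.h"

void apintResizeSketch() {
  llvm::APInt wide(/*numBits=*/64, /*val=*/42);
  llvm::APInt narrowed = wide.sextOrTrunc(32); // truncated to a 32-bit 42
  llvm::APInt kept = wide.sextOrTrunc(64);     // width already matches: no-op
  (void)narrowed;
  (void)kept;
}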
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
--- a/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-dynamic-memref-ops.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt -convert-std-to-llvm %s | FileCheck %s
+// RUN: mlir-opt -convert-std-to-llvm='index-bitwidth=32' %s | FileCheck %s
 
 // CHECK-LABEL: func @check_strided_memref_arguments(
 // CHECK-COUNT-2: !llvm<"float*">
diff --git a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
--- a/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
+++ b/mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
@@ -1,4 +1,5 @@
 // RUN: mlir-opt -convert-std-to-llvm %s -split-input-file | FileCheck %s
+// RUN: mlir-opt -convert-std-to-llvm='index-bitwidth=32' %s -split-input-file | FileCheck --check-prefix=CHECK32 %s
 
 // CHECK-LABEL: func @empty() {
 // CHECK-NEXT:  llvm.return
@@ -12,15 +13,21 @@
 func @body(index)
 
 // CHECK-LABEL: func @simple_loop() {
+// CHECK32-LABEL: func @simple_loop() {
 func @simple_loop() {
 ^bb0:
 // CHECK-NEXT:  llvm.br ^bb1
+// CHECK32-NEXT:  llvm.br ^bb1
   br ^bb1
 
 // CHECK-NEXT: ^bb1:   // pred: ^bb0
 // CHECK-NEXT:  {{.*}} = llvm.mlir.constant(1 : index) : !llvm.i64
 // CHECK-NEXT:  {{.*}} = llvm.mlir.constant(42 : index) : !llvm.i64
 // CHECK-NEXT:  llvm.br ^bb2({{.*}} : !llvm.i64)
+// CHECK32-NEXT: ^bb1:   // pred: ^bb0
+// CHECK32-NEXT:  {{.*}} = llvm.mlir.constant(1 : index) : !llvm.i32
+// CHECK32-NEXT:  {{.*}} = llvm.mlir.constant(42 : index) : !llvm.i32
+// CHECK32-NEXT:  llvm.br ^bb2({{.*}} : !llvm.i32)
 ^bb1:   // pred: ^bb0
   %c1 = constant 1 : index
   %c42 = constant 42 : index
@@ -29,6 +36,9 @@
 // CHECK:      ^bb2({{.*}}: !llvm.i64):   // 2 preds: ^bb1, ^bb3
 // CHECK-NEXT:  {{.*}} = llvm.icmp "slt" {{.*}}, {{.*}} : !llvm.i64
 // CHECK-NEXT:  llvm.cond_br {{.*}}, ^bb3, ^bb4
+// CHECK32:      ^bb2({{.*}}: !llvm.i32):   // 2 preds: ^bb1, ^bb3
+// CHECK32-NEXT:  {{.*}} = llvm.icmp "slt" {{.*}}, {{.*}} : !llvm.i32
+// CHECK32-NEXT:  llvm.cond_br {{.*}}, ^bb3, ^bb4
 ^bb2(%0: index):   // 2 preds: ^bb1, ^bb3
   %1 = cmpi "slt", %0, %c42 : index
   cond_br %1, ^bb3, ^bb4
@@ -38,6 +48,11 @@
 // CHECK-NEXT:  {{.*}} = llvm.mlir.constant(1 : index) : !llvm.i64
 // CHECK-NEXT:  {{.*}} = llvm.add {{.*}}, {{.*}} : !llvm.i64
 // CHECK-NEXT:  llvm.br ^bb2({{.*}} : !llvm.i64)
+// CHECK32:      ^bb3:   // pred: ^bb2
+// CHECK32-NEXT:  llvm.call @body({{.*}}) : (!llvm.i32) -> ()
+// CHECK32-NEXT:  {{.*}} = llvm.mlir.constant(1 : index) : !llvm.i32
+// CHECK32-NEXT:  {{.*}} = llvm.add {{.*}}, {{.*}} : !llvm.i32
+// CHECK32-NEXT:  llvm.br ^bb2({{.*}} : !llvm.i32)
 ^bb3:   // pred: ^bb2
   call @body(%0) : (index) -> ()
   %c1_0 = constant 1 : index
@@ -81,13 +96,18 @@
 }
 
 // CHECK-LABEL: func @body_args(!llvm.i64) -> !llvm.i64
+// CHECK32-LABEL: func @body_args(!llvm.i32) -> !llvm.i32
 func @body_args(index) -> index
 // CHECK-LABEL: func @other(!llvm.i64, !llvm.i32) -> !llvm.i32
+// CHECK32-LABEL: func @other(!llvm.i32, !llvm.i32) -> !llvm.i32
 func @other(index, i32) -> i32
 
 // CHECK-LABEL: func @func_args(%arg0: !llvm.i32, %arg1: !llvm.i32) -> !llvm.i32 {
 // CHECK-NEXT:  {{.*}} = llvm.mlir.constant(0 : i32) : !llvm.i32
 // CHECK-NEXT:  llvm.br ^bb1
+// CHECK32-LABEL: func @func_args(%arg0: !llvm.i32, %arg1: !llvm.i32) -> !llvm.i32 {
+// CHECK32-NEXT:  {{.*}} = llvm.mlir.constant(0 : i32) : !llvm.i32
+// CHECK32-NEXT:  llvm.br ^bb1
 func @func_args(i32, i32) -> i32 {
 ^bb0(%arg0: i32, %arg1: i32):
   %c0_i32 = constant 0 : i32
@@ -97,6 +117,10 @@
 // CHECK-NEXT:  {{.*}} = llvm.mlir.constant(0 : index) : !llvm.i64
 // CHECK-NEXT:  {{.*}} = llvm.mlir.constant(42 : index) : !llvm.i64
 // CHECK-NEXT:  llvm.br ^bb2({{.*}} : !llvm.i64)
+// CHECK32-NEXT: ^bb1:   // pred: ^bb0
+// CHECK32-NEXT:  {{.*}} = llvm.mlir.constant(0 : index) : !llvm.i32
+// CHECK32-NEXT:  {{.*}} = llvm.mlir.constant(42 : index) : !llvm.i32
+// CHECK32-NEXT:  llvm.br ^bb2({{.*}} : !llvm.i32)
 ^bb1:   // pred: ^bb0
   %c0 = constant 0 : index
   %c42 = constant 42 : index
@@ -105,6 +129,9 @@
 // CHECK-NEXT: ^bb2({{.*}}: !llvm.i64):   // 2 preds: ^bb1, ^bb3
 // CHECK-NEXT:  {{.*}} = llvm.icmp "slt" {{.*}}, {{.*}} : !llvm.i64
 // CHECK-NEXT:  llvm.cond_br {{.*}}, ^bb3, ^bb4
+// CHECK32-NEXT: ^bb2({{.*}}: !llvm.i32):   // 2 preds: ^bb1, ^bb3
+// CHECK32-NEXT:  {{.*}} = llvm.icmp "slt" {{.*}}, {{.*}} : !llvm.i32
+// CHECK32-NEXT:  llvm.cond_br {{.*}}, ^bb3, ^bb4
 ^bb2(%0: index):   // 2 preds: ^bb1, ^bb3
   %1 = cmpi "slt", %0, %c42 : index
   cond_br %1, ^bb3, ^bb4
@@ -117,6 +144,14 @@
 // CHECK-NEXT:  {{.*}} = llvm.mlir.constant(1 : index) : !llvm.i64
 // CHECK-NEXT:  {{.*}} = llvm.add {{.*}}, {{.*}} : !llvm.i64
 // CHECK-NEXT:  llvm.br ^bb2({{.*}} : !llvm.i64)
+// CHECK32-NEXT: ^bb3:   // pred: ^bb2
+// CHECK32-NEXT:  {{.*}} = llvm.call @body_args({{.*}}) : (!llvm.i32) -> !llvm.i32
+// CHECK32-NEXT:  {{.*}} = llvm.call @other({{.*}}, %arg0) : (!llvm.i32, !llvm.i32) -> !llvm.i32
+// CHECK32-NEXT:  {{.*}} = llvm.call @other({{.*}}, {{.*}}) : (!llvm.i32, !llvm.i32) -> !llvm.i32
+// CHECK32-NEXT:  {{.*}} = llvm.call @other({{.*}}, %arg1) : (!llvm.i32, !llvm.i32) -> !llvm.i32
+// CHECK32-NEXT:  {{.*}} = llvm.mlir.constant(1 : index) : !llvm.i32
+// CHECK32-NEXT:  {{.*}} = llvm.add {{.*}}, {{.*}} : !llvm.i32
+// CHECK32-NEXT:  llvm.br ^bb2({{.*}} : !llvm.i32)
 ^bb3:   // pred: ^bb2
   %2 = call @body_args(%0) : (index) -> index
   %3 = call @other(%2, %arg0) : (index, i32) -> i32
@@ -130,6 +165,10 @@
 // CHECK-NEXT:  {{.*}} = llvm.mlir.constant(0 : index) : !llvm.i64
 // CHECK-NEXT:  {{.*}} = llvm.call @other({{.*}}, {{.*}}) : (!llvm.i64, !llvm.i32) -> !llvm.i32
 // CHECK-NEXT:  llvm.return {{.*}} : !llvm.i32
+// CHECK32-NEXT: ^bb4:   // pred: ^bb2
+// CHECK32-NEXT:  {{.*}} = llvm.mlir.constant(0 : index) : !llvm.i32
+// CHECK32-NEXT:  {{.*}} = llvm.call @other({{.*}}, {{.*}}) : (!llvm.i32, !llvm.i32) -> !llvm.i32
+// CHECK32-NEXT:  llvm.return {{.*}} : !llvm.i32
 ^bb4:   // pred: ^bb2
   %c0_0 = constant 0 : index
   %7 = call @other(%c0_0, %c0_i32) : (index, i32) -> i32
@@ -137,12 +176,15 @@
 }
 
 // CHECK-LABEL: func @pre(!llvm.i64)
+// CHECK32-LABEL: func @pre(!llvm.i32)
 func @pre(index)
 
 // CHECK-LABEL: func @body2(!llvm.i64, !llvm.i64)
+// CHECK32-LABEL: func @body2(!llvm.i32, !llvm.i32)
 func @body2(index, index)
 
 // CHECK-LABEL: func @post(!llvm.i64)
+// CHECK32-LABEL: func @post(!llvm.i32)
 func @post(index)
 
 // CHECK-LABEL: func @imperfectly_nested_loops() {
@@ -326,14 +368,19 @@
 // CHECK-LABEL: func @get_f32() -> !llvm.float
 func @get_f32() -> (f32)
 // CHECK-LABEL: func @get_memref() -> !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }">
+// CHECK32-LABEL: func @get_memref() -> !llvm<"{ float*, float*, i32, [4 x i32], [4 x i32] }">
 func @get_memref() -> (memref<42x?x10x?xf32>)
 
 // CHECK-LABEL: func @multireturn() -> !llvm<"{ i64, float, { float*, float*, i64, [4 x i64], [4 x i64] } }"> {
+// CHECK32-LABEL: func @multireturn() -> !llvm<"{ i64, float, { float*, float*, i32, [4 x i32], [4 x i32] } }"> {
 func @multireturn() -> (i64, f32, memref<42x?x10x?xf32>) {
 ^bb0:
 // CHECK-NEXT:  {{.*}} = llvm.call @get_i64() : () -> !llvm.i64
 // CHECK-NEXT:  {{.*}} = llvm.call @get_f32() : () -> !llvm.float
 // CHECK-NEXT:  {{.*}} = llvm.call @get_memref() : () -> !llvm<"{ float*, float*, i64, [4 x i64], [4 x i64] }">
+// CHECK32-NEXT:  {{.*}} = llvm.call @get_i64() : () -> !llvm.i64
+// CHECK32-NEXT:  {{.*}} = llvm.call @get_f32() : () -> !llvm.float
+// CHECK32-NEXT:  {{.*}} = llvm.call @get_memref() : () -> !llvm<"{ float*, float*, i32, [4 x i32], [4 x i32] }">
   %0 = call @get_i64() : () -> (i64)
   %1 = call @get_f32() : () -> (f32)
   %2 = call @get_memref() : () -> (memref<42x?x10x?xf32>)
@@ -342,17 +389,27 @@
 // CHECK-NEXT:  {{.*}} = llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"{ i64, float, { float*, float*, i64, [4 x i64], [4 x i64] } }">
 // CHECK-NEXT:  {{.*}} = llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"{ i64, float, { float*, float*, i64, [4 x i64], [4 x i64] } }">
 // CHECK-NEXT:  llvm.return {{.*}} : !llvm<"{ i64, float, { float*, float*, i64, [4 x i64], [4 x i64] } }">
+// CHECK32-NEXT:  {{.*}} = llvm.mlir.undef : !llvm<"{ i64, float, { float*, float*, i32, [4 x i32], [4 x i32] } }">
+// CHECK32-NEXT:  {{.*}} = llvm.insertvalue {{.*}}, {{.*}}[0] : !llvm<"{ i64, float, { float*, float*, i32, [4 x i32], [4 x i32] } }">
+// CHECK32-NEXT:  {{.*}} = llvm.insertvalue {{.*}}, {{.*}}[1] : !llvm<"{ i64, float, { float*, float*, i32, [4 x i32], [4 x i32] } }">
+// CHECK32-NEXT:  {{.*}} = llvm.insertvalue {{.*}}, {{.*}}[2] : !llvm<"{ i64, float, { float*, float*, i32, [4 x i32], [4 x i32] } }">
+// CHECK32-NEXT:  llvm.return {{.*}} : !llvm<"{ i64, float, { float*, float*, i32, [4 x i32], [4 x i32] } }">
   return %0, %1, %2 : i64, f32, memref<42x?x10x?xf32>
 }
 
 // CHECK-LABEL: func @multireturn_caller() {
+// CHECK32-LABEL: func @multireturn_caller() {
 func @multireturn_caller() {
 ^bb0:
 // CHECK-NEXT:  {{.*}} = llvm.call @multireturn() : () -> !llvm<"{ i64, float, { float*, float*, i64, [4 x i64], [4 x i64] } }">
 // CHECK-NEXT:  {{.*}} = llvm.extractvalue {{.*}}[0] : !llvm<"{ i64, float, { float*, float*, i64, [4 x i64], [4 x i64] } }">
 // CHECK-NEXT:  {{.*}} = llvm.extractvalue {{.*}}[1] : !llvm<"{ i64, float, { float*, float*, i64, [4 x i64], [4 x i64] } }">
 // CHECK-NEXT:  {{.*}} = llvm.extractvalue {{.*}}[2] : !llvm<"{ i64, float, { float*, float*, i64, [4 x i64], [4 x i64] } }">
+// CHECK32-NEXT:  {{.*}} = llvm.call @multireturn() : () -> !llvm<"{ i64, float, { float*, float*, i32, [4 x i32], [4 x i32] } }">
+// CHECK32-NEXT:  {{.*}} = llvm.extractvalue {{.*}}[0] : !llvm<"{ i64, float, { float*, float*, i32, [4 x i32], [4 x i32] } }">
+// CHECK32-NEXT:  {{.*}} = llvm.extractvalue {{.*}}[1] : !llvm<"{ i64, float, { float*, float*, i32, [4 x i32], [4 x i32] } }">
+// CHECK32-NEXT:  {{.*}} = llvm.extractvalue {{.*}}[2] : !llvm<"{ i64, float, { float*, float*, i32, [4 x i32], [4 x i32] } }">
   %0:3 = call @multireturn() : () -> (i64, f32, memref<42x?x10x?xf32>)
   %1 = constant 42 : i64
 // CHECK:       {{.*}} = llvm.add {{.*}}, {{.*}} : !llvm.i64
@@ -757,9 +814,16 @@
 // CHECK:         %[[ARG0:[a-zA-Z0-9]*]]: !llvm.i64,
 // CHECK:         %[[ARG1:[a-zA-Z0-9]*]]: !llvm.i64,
 // CHECK:         %[[ARG2:.*]]: !llvm.i64)
+// CHECK32-LABEL: func @subview(
+// CHECK32-COUNT-2: !llvm<"float*">,
+// CHECK32-COUNT-5: {{%[a-zA-Z0-9]*}}: !llvm.i32,
+// CHECK32:         %[[ARG0:[a-zA-Z0-9]*]]: !llvm.i32,
+// CHECK32:         %[[ARG1:[a-zA-Z0-9]*]]: !llvm.i32,
+// CHECK32:         %[[ARG2:.*]]: !llvm.i32)
 func @subview(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>>, %arg0 : index, %arg1 : index, %arg2 : index) {
   // The last "insertvalue" that populates the memref descriptor from the function arguments.
   // CHECK: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1]
+  // CHECK32: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1]
 
   // CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
   // CHECK: %[[DESC0:.*]] = llvm.insertvalue %{{.*}}, %[[DESC]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
@@ -778,15 +842,34 @@
   // CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[ARG0]], %[[DESC4]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
   // CHECK: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : !llvm.i64
   // CHECK: llvm.insertvalue %[[DESCSTRIDE0]], %[[DESC5]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK32: %[[DESC:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">
+  // CHECK32: %[[DESC0:.*]] = llvm.insertvalue %{{.*}}, %[[DESC]][0] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">
+  // CHECK32: %[[DESC1:.*]] = llvm.insertvalue %{{.*}}, %[[DESC0]][1] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">
+  // CHECK32: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEMREF]][4, 0] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">
+  // CHECK32: %[[STRIDE1:.*]] = llvm.extractvalue %[[MEMREF]][4, 1] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">
+  // CHECK32: %[[OFF:.*]] = llvm.extractvalue %[[MEMREF]][2] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">
+  // CHECK32: %[[OFFINC:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : !llvm.i32
+  // CHECK32: %[[OFF1:.*]] = llvm.add %[[OFF]], %[[OFFINC]] : !llvm.i32
+  // CHECK32: %[[OFFINC1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : !llvm.i32
+  // CHECK32: %[[OFF2:.*]] = llvm.add %[[OFF1]], %[[OFFINC1]] : !llvm.i32
+  // CHECK32: %[[DESC2:.*]] = llvm.insertvalue %[[OFF2]], %[[DESC1]][2] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">
+  // CHECK32: %[[DESC3:.*]] = llvm.insertvalue %[[ARG1]], %[[DESC2]][3, 1] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">
+  // CHECK32: %[[DESCSTRIDE1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : !llvm.i32
+  // CHECK32: %[[DESC4:.*]] = llvm.insertvalue %[[DESCSTRIDE1]], %[[DESC3]][4, 1] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">
+  // CHECK32: %[[DESC5:.*]] = llvm.insertvalue %[[ARG0]], %[[DESC4]][3, 0] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">
+  // CHECK32: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : !llvm.i32
   %1 = subview %0[%arg0, %arg1][%arg0, %arg1][%arg0, %arg1] :
     memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>>
     to memref<?x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)>>
   return
 }
 
 // CHECK-LABEL: func @subview_const_size(
+// CHECK32-LABEL: func @subview_const_size(
 func @subview_const_size(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>>, %arg0 : index, %arg1 : index, %arg2 : index) {
   // The last "insertvalue" that populates the memref descriptor from the function arguments.
   // CHECK: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1]
+  // CHECK32: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1]
 
   // CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
   // CHECK: %[[DESC0:.*]] = llvm.insertvalue %{{.*}}, %[[DESC]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
@@ -807,15 +890,36 @@
   // CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[CST4]], %[[DESC4]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
   // CHECK: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : !llvm.i64
   // CHECK: llvm.insertvalue %[[DESCSTRIDE0]], %[[DESC5]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK32: %[[DESC:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">
+  // CHECK32: %[[DESC0:.*]] = llvm.insertvalue %{{.*}}, %[[DESC]][0] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">
+  // CHECK32: %[[DESC1:.*]] = llvm.insertvalue %{{.*}}, %[[DESC0]][1] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">
+  // CHECK32: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEMREF]][4, 0] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">
+  // CHECK32: %[[STRIDE1:.*]] = llvm.extractvalue %[[MEMREF]][4, 1] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">
+  // CHECK32: %[[CST4:.*]] = llvm.mlir.constant(4 : i64)
+  // CHECK32: %[[CST2:.*]] = llvm.mlir.constant(2 : i64)
+  // CHECK32: %[[OFF:.*]] = llvm.extractvalue %[[MEMREF]][2] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">
+  // CHECK32: %[[OFFINC:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : !llvm.i32
+  // CHECK32: %[[OFF1:.*]] = llvm.add %[[OFF]], %[[OFFINC]] : !llvm.i32
+  // CHECK32: %[[OFFINC1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : !llvm.i32
+  // CHECK32: %[[OFF2:.*]] = llvm.add %[[OFF1]], %[[OFFINC1]] : !llvm.i32
+  // CHECK32: %[[DESC2:.*]] = llvm.insertvalue %[[OFF2]], %[[DESC1]][2] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">
+  // CHECK32: %[[DESC3:.*]] = llvm.insertvalue %[[CST2]], %[[DESC2]][3, 1] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">
+  // CHECK32: %[[DESCSTRIDE1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : !llvm.i32
+  // CHECK32: %[[DESC4:.*]] = llvm.insertvalue %[[DESCSTRIDE1]], %[[DESC3]][4, 1] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">
+  // CHECK32: %[[DESC5:.*]] = llvm.insertvalue %[[CST4]], %[[DESC4]][3, 0] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">
+  // CHECK32: %[[DESCSTRIDE0:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : !llvm.i32
+  // CHECK32: llvm.insertvalue %[[DESCSTRIDE0]], %[[DESC5]][4, 0] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">
   %1 = subview %0[%arg0, %arg1][][%arg0, %arg1] :
     memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>>
     to memref<4x2xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)>>
   return
 }
 
 // CHECK-LABEL: func @subview_const_stride(
+// CHECK32-LABEL: func @subview_const_stride(
 func @subview_const_stride(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>>, %arg0 : index, %arg1 : index, %arg2 : index) {
   // The last "insertvalue" that populates the memref descriptor from the function arguments.
   // CHECK: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1]
+  // CHECK32: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1]
 
   // CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
   // CHECK: %[[DESC0:.*]] = llvm.insertvalue %{{.*}}, %[[DESC]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
@@ -834,15 +938,34 @@
   // CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[ARG0]], %[[DESC4]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
   // CHECK: %[[CST4:.*]] = llvm.mlir.constant(4 : i64)
   // CHECK: llvm.insertvalue %[[CST4]], %[[DESC5]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK32: %[[DESC:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">
+  // CHECK32: %[[DESC0:.*]] = llvm.insertvalue %{{.*}}, %[[DESC]][0] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">
+  // CHECK32: %[[DESC1:.*]] = llvm.insertvalue %{{.*}}, %[[DESC0]][1] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">
+  // CHECK32: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEMREF]][4, 0] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">
+  // CHECK32: %[[STRIDE1:.*]] = llvm.extractvalue %[[MEMREF]][4, 1] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">
+  // CHECK32: %[[OFF:.*]] = llvm.extractvalue %[[MEMREF]][2] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">
+  // CHECK32: %[[OFFINC:.*]] = llvm.mul %[[ARG0]], %[[STRIDE0]] : !llvm.i32
+  // CHECK32: %[[OFF1:.*]] = llvm.add %[[OFF]], %[[OFFINC]] : !llvm.i32
+  // CHECK32: %[[OFFINC1:.*]] = llvm.mul %[[ARG1]], %[[STRIDE1]] : !llvm.i32
+  // CHECK32: %[[OFF2:.*]] = llvm.add %[[OFF1]], %[[OFFINC1]] : !llvm.i32
+  // CHECK32: %[[DESC2:.*]] = llvm.insertvalue %[[OFF2]], %[[DESC1]][2] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">
+  // CHECK32: %[[DESC3:.*]] = llvm.insertvalue %[[ARG1]], %[[DESC2]][3, 1] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">
+  // CHECK32: %[[CST2:.*]] = llvm.mlir.constant(2 : i64)
+  // CHECK32: %[[DESC4:.*]] = llvm.insertvalue %[[CST2]], %[[DESC3]][4, 1] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">
+  // CHECK32: %[[DESC5:.*]] = llvm.insertvalue %[[ARG0]], %[[DESC4]][3, 0] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">
+  // CHECK32: %[[CST4:.*]] = llvm.mlir.constant(4 : i64)
+  // CHECK32: llvm.insertvalue %[[CST4]], %[[DESC5]][4, 0] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">
   %1 = subview %0[%arg0, %arg1][%arg0, %arg1][] :
     memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>>
     to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d0 * 4 + d1 * 2 + s0)>>
   return
 }
 
 // CHECK-LABEL: func @subview_const_stride_and_offset(
+// CHECK32-LABEL: func @subview_const_stride_and_offset(
 func @subview_const_stride_and_offset(%0 : memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>>) {
   // The last "insertvalue" that populates the memref descriptor from the function arguments.
   // CHECK: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1]
+  // CHECK32: %[[MEMREF:.*]] = llvm.insertvalue %{{.*}}, %{{.*}}[4, 1]
 
   // CHECK: %[[DESC:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
   // CHECK: %[[DESC0:.*]] = llvm.insertvalue %{{.*}}, %[[DESC]][0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
@@ -859,6 +982,21 @@
   // CHECK: %[[DESC5:.*]] = llvm.insertvalue %[[CST62]], %[[DESC4]][3, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
   // CHECK: %[[CST4:.*]] = llvm.mlir.constant(4 : i64)
   // CHECK: llvm.insertvalue %[[CST4]], %[[DESC5]][4, 0] : !llvm<"{ float*, float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK32: %[[DESC:.*]] = llvm.mlir.undef : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">
+  // CHECK32: %[[DESC0:.*]] = llvm.insertvalue %{{.*}}, %[[DESC]][0] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">
+  // CHECK32: %[[DESC1:.*]] = llvm.insertvalue %{{.*}}, %[[DESC0]][1] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">
+  // CHECK32: %[[STRIDE0:.*]] = llvm.extractvalue %[[MEMREF]][4, 0] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">
+  // CHECK32: %[[STRIDE1:.*]] = llvm.extractvalue %[[MEMREF]][4, 1] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">
+  // CHECK32: %[[CST62:.*]] = llvm.mlir.constant(62 : i64)
+  // CHECK32: %[[CST3:.*]] = llvm.mlir.constant(3 : i64)
+  // CHECK32: %[[CST8:.*]] = llvm.mlir.constant(8 : index)
+  // CHECK32: %[[DESC2:.*]] = llvm.insertvalue %[[CST8]], %[[DESC1]][2] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">
+  // CHECK32: %[[DESC3:.*]] = llvm.insertvalue %[[CST3]], %[[DESC2]][3, 1] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">
+  // CHECK32: %[[CST1:.*]] = llvm.mlir.constant(1 : i64)
+  // CHECK32: %[[DESC4:.*]] = llvm.insertvalue %[[CST1]], %[[DESC3]][4, 1] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">
+  // CHECK32: %[[DESC5:.*]] = llvm.insertvalue %[[CST62]], %[[DESC4]][3, 0] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">
+  // CHECK32: %[[CST4:.*]] = llvm.mlir.constant(4 : i64)
+  // CHECK32: llvm.insertvalue %[[CST4]], %[[DESC5]][4, 0] : !llvm<"{ float*, float*, i32, [2 x i32], [2 x i32] }">
   %1 = subview %0[][][] :
     memref<64x4xf32, affine_map<(d0, d1) -> (d0 * 4 + d1)>>
     to memref<62x3xf32, affine_map<(d0, d1) -> (d0 * 4 + d1 + 8)>>
   return
diff --git a/mlir/test/Target/llvmir.mlir b/mlir/test/Target/llvmir.mlir
--- a/mlir/test/Target/llvmir.mlir
+++ b/mlir/test/Target/llvmir.mlir
@@ -1180,4 +1180,4 @@
 // CHECK: freeze i32 undef
   %2 = llvm.freeze %1 : !llvm.i32
   llvm.return
-}
\ No newline at end of file
+}